feat: brain taxonomy — DB-backed folders/tags, sidebar, CRUD API
Backend:
- New Folder/Tag/ItemTag models with proper relational tables
- Taxonomy CRUD endpoints: list, create, rename, delete, merge tags
- Sidebar endpoint with folder/tag counts
- AI classification reads live folders/tags from DB, not hardcoded
- Default folders/tags seeded on first request per user
- folder_id FK on items for relational integrity

Frontend:
- Left sidebar with Folders/Tags tabs (like Karakeep)
- Click folder/tag to filter items
- "Manage" mode: add new folders/tags, delete existing
- Counts next to each folder/tag
- "All items" option to clear filter
- Replaces the old signal-strip cards

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
334
services/brain/app/api/taxonomy.py
Normal file
334
services/brain/app/api/taxonomy.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Taxonomy API — folder and tag CRUD, sidebar data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, func, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_user_id, get_db_session
|
||||
from app.models.taxonomy import Folder, Tag, ItemTag, slugify, ensure_user_taxonomy
|
||||
from app.models.item import Item
|
||||
|
||||
router = APIRouter(prefix="/api/taxonomy", tags=["taxonomy"])
|
||||
|
||||
|
||||
# ── Schemas ──
|
||||
|
||||
class FolderIn(BaseModel):
    """Request body for creating or renaming a folder."""
    # Display name; the URL-safe slug is derived server-side via slugify().
    name: str


class FolderOut(BaseModel):
    """Folder as returned to clients."""
    id: str
    name: str
    slug: str
    is_active: bool
    sort_order: int
    # Populated only by endpoints that compute counts (e.g. /sidebar);
    # list endpoints leave it at the default 0.
    item_count: int = 0
    # Allow construction directly from ORM rows (pydantic v2).
    model_config = {"from_attributes": True}


class TagIn(BaseModel):
    """Request body for creating or renaming a tag."""
    # Display name; the URL-safe slug is derived server-side via slugify().
    name: str


class TagOut(BaseModel):
    """Tag as returned to clients."""
    id: str
    name: str
    slug: str
    is_active: bool
    sort_order: int
    # Populated only by endpoints that compute counts (e.g. /sidebar).
    item_count: int = 0
    # Allow construction directly from ORM rows (pydantic v2).
    model_config = {"from_attributes": True}


class SidebarOut(BaseModel):
    """Everything the left sidebar needs in one payload."""
    folders: list[FolderOut]
    tags: list[TagOut]
    # Count of ALL of the user's items, for the "All items" entry.
    total_items: int


class MergeTagIn(BaseModel):
    """Request body for POST /tags/merge: fold source into target."""
    source_tag_id: str
    target_tag_id: str
|
||||
|
||||
|
||||
# ── Sidebar data ──
|
||||
|
||||
@router.get("/sidebar", response_model=SidebarOut)
async def get_sidebar(
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Return folders and tags (each with item counts) plus the user's
    total item count — a single payload for rendering the left sidebar.

    Seeds default folders/tags first if the user has none.
    """
    await ensure_user_taxonomy(db, user_id)

    # Folders with counts
    folder_rows = (await db.execute(
        select(Folder).where(Folder.user_id == user_id).order_by(Folder.sort_order, Folder.name)
    )).scalars().all()

    # Per-folder item counts in one grouped query; folders with no items
    # simply have no row and fall back to 0 below.
    folder_counts = {}
    for row in (await db.execute(
        select(Item.folder_id, func.count()).where(
            Item.user_id == user_id, Item.folder_id.isnot(None)
        ).group_by(Item.folder_id)
    )).all():
        folder_counts[row[0]] = row[1]

    folders = [FolderOut(
        id=f.id, name=f.name, slug=f.slug, is_active=f.is_active,
        sort_order=f.sort_order, item_count=folder_counts.get(f.id, 0)
    ) for f in folder_rows]

    # Tags with counts
    tag_rows = (await db.execute(
        select(Tag).where(Tag.user_id == user_id).order_by(Tag.sort_order, Tag.name)
    )).scalars().all()

    # Tag counts come from the item_tags join table; the join to Item
    # scopes the count to this user's items.
    tag_counts = {}
    for row in (await db.execute(
        select(ItemTag.tag_id, func.count()).join(Item, Item.id == ItemTag.item_id).where(
            Item.user_id == user_id
        ).group_by(ItemTag.tag_id)
    )).all():
        tag_counts[row[0]] = row[1]

    tags = [TagOut(
        id=t.id, name=t.name, slug=t.slug, is_active=t.is_active,
        sort_order=t.sort_order, item_count=tag_counts.get(t.id, 0)
    ) for t in tag_rows]

    # Total items
    total = (await db.execute(
        select(func.count()).where(Item.user_id == user_id)
    )).scalar() or 0

    return SidebarOut(folders=folders, tags=tags, total_items=total)
|
||||
|
||||
|
||||
# ── Folder CRUD ──
|
||||
|
||||
@router.get("/folders", response_model=list[FolderOut])
async def list_folders(
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Return every folder for the current user in display order."""
    await ensure_user_taxonomy(db, user_id)
    stmt = (
        select(Folder)
        .where(Folder.user_id == user_id)
        .order_by(Folder.sort_order, Folder.name)
    )
    result = await db.execute(stmt)
    return [
        FolderOut(
            id=row.id,
            name=row.name,
            slug=row.slug,
            is_active=row.is_active,
            sort_order=row.sort_order,
        )
        for row in result.scalars().all()
    ]
|
||||
|
||||
|
||||
@router.post("/folders", response_model=FolderOut, status_code=201)
async def create_folder(
    body: FolderIn,
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Create a new folder for the current user.

    Raises:
        400 if the name is empty/whitespace (or slugifies to nothing),
            or if a folder with the same slug already exists.
    """
    name = body.name.strip()
    slug = slugify(name)
    # Reject names that produce an empty slug — otherwise we would store a
    # folder with slug "" and collide with every other such folder under
    # the (user_id, slug) unique constraint.
    if not name or not slug:
        raise HTTPException(400, "Folder name cannot be empty")

    existing = (await db.execute(
        select(Folder).where(Folder.user_id == user_id, Folder.slug == slug)
    )).scalar_one_or_none()
    if existing:
        raise HTTPException(400, "Folder already exists")

    # New folders go to the end of the user's list.
    max_order = (await db.execute(
        select(func.max(Folder.sort_order)).where(Folder.user_id == user_id)
    )).scalar() or 0

    folder = Folder(id=str(uuid.uuid4()), user_id=user_id, name=name, slug=slug, sort_order=max_order + 1)
    db.add(folder)
    await db.commit()
    return FolderOut(id=folder.id, name=folder.name, slug=folder.slug, is_active=folder.is_active, sort_order=folder.sort_order)
|
||||
|
||||
|
||||
@router.patch("/folders/{folder_id}", response_model=FolderOut)
async def update_folder(
    folder_id: str,
    body: FolderIn,
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Rename a folder and keep items' denormalized folder name in sync.

    Raises:
        404 if the folder does not exist or belongs to another user.
        400 if the new name is empty or would collide with another
            folder's slug (previously this surfaced as a DB 500 via the
            (user_id, slug) unique constraint).
    """
    folder = (await db.execute(
        select(Folder).where(Folder.id == folder_id, Folder.user_id == user_id)
    )).scalar_one_or_none()
    if not folder:
        raise HTTPException(404, "Folder not found")

    name = body.name.strip()
    slug = slugify(name)
    if not name or not slug:
        raise HTTPException(400, "Folder name cannot be empty")

    # A *different* folder with the same slug would violate the unique
    # constraint — fail fast with a client error instead.
    conflict = (await db.execute(
        select(Folder).where(Folder.user_id == user_id, Folder.slug == slug, Folder.id != folder_id)
    )).scalar_one_or_none()
    if conflict:
        raise HTTPException(400, "Folder already exists")

    folder.name = name
    folder.slug = slug
    folder.updated_at = datetime.utcnow()

    # Update denormalized folder name on items
    await db.execute(
        update(Item).where(Item.folder_id == folder_id).values(folder=folder.name)
    )
    await db.commit()
    return FolderOut(id=folder.id, name=folder.name, slug=folder.slug, is_active=folder.is_active, sort_order=folder.sort_order)
|
||||
|
||||
|
||||
@router.delete("/folders/{folder_id}")
async def delete_folder(
    folder_id: str,
    fallback_folder_id: Optional[str] = Query(None),
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Delete a folder, moving its items into a fallback folder.

    The fallback is the explicitly supplied ``fallback_folder_id`` (which
    must belong to the same user), or else the user's first remaining
    folder. If no fallback exists, items keep their (now dangling)
    folder_id, which the FK's ON DELETE SET NULL clears.

    Raises:
        404 if the folder — or an explicitly supplied fallback — is not
            found for this user.
        400 if the fallback is the folder being deleted.
    """
    folder = (await db.execute(
        select(Folder).where(Folder.id == folder_id, Folder.user_id == user_id)
    )).scalar_one_or_none()
    if not folder:
        raise HTTPException(404, "Folder not found")

    if fallback_folder_id:
        # Moving items into the folder we are about to delete would strand them.
        if fallback_folder_id == folder_id:
            raise HTTPException(400, "Fallback folder cannot be the folder being deleted")
        # Scope the lookup to the current user — without this check items
        # could be re-homed into another user's folder.
        fallback = (await db.execute(
            select(Folder).where(Folder.id == fallback_folder_id, Folder.user_id == user_id)
        )).scalar_one_or_none()
        if not fallback:
            raise HTTPException(404, "Fallback folder not found")
    else:
        fallback = (await db.execute(
            select(Folder).where(Folder.user_id == user_id, Folder.id != folder_id).order_by(Folder.sort_order).limit(1)
        )).scalar_one_or_none()

    if fallback:
        await db.execute(
            update(Item).where(Item.folder_id == folder_id).values(folder_id=fallback.id, folder=fallback.name)
        )

    await db.execute(delete(Folder).where(Folder.id == folder_id))
    await db.commit()
    return {"status": "deleted", "items_moved_to": fallback.name if fallback else None}
|
||||
|
||||
|
||||
# ── Tag CRUD ──
|
||||
|
||||
@router.get("/tags", response_model=list[TagOut])
async def list_tags(
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Return every tag for the current user in display order."""
    await ensure_user_taxonomy(db, user_id)
    stmt = (
        select(Tag)
        .where(Tag.user_id == user_id)
        .order_by(Tag.sort_order, Tag.name)
    )
    result = await db.execute(stmt)
    return [
        TagOut(
            id=row.id,
            name=row.name,
            slug=row.slug,
            is_active=row.is_active,
            sort_order=row.sort_order,
        )
        for row in result.scalars().all()
    ]
|
||||
|
||||
|
||||
@router.post("/tags", response_model=TagOut, status_code=201)
async def create_tag(
    body: TagIn,
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Create a new tag for the current user.

    Raises:
        400 if the name is empty/whitespace (or slugifies to nothing),
            or if a tag with the same slug already exists.
    """
    name = body.name.strip()
    slug = slugify(name)
    # Reject names that produce an empty slug — otherwise we would store a
    # tag with slug "" and collide under the (user_id, slug) constraint.
    if not name or not slug:
        raise HTTPException(400, "Tag name cannot be empty")

    existing = (await db.execute(
        select(Tag).where(Tag.user_id == user_id, Tag.slug == slug)
    )).scalar_one_or_none()
    if existing:
        raise HTTPException(400, "Tag already exists")

    # New tags go to the end of the user's list.
    max_order = (await db.execute(
        select(func.max(Tag.sort_order)).where(Tag.user_id == user_id)
    )).scalar() or 0

    tag = Tag(id=str(uuid.uuid4()), user_id=user_id, name=name, slug=slug, sort_order=max_order + 1)
    db.add(tag)
    await db.commit()
    return TagOut(id=tag.id, name=tag.name, slug=tag.slug, is_active=tag.is_active, sort_order=tag.sort_order)
|
||||
|
||||
|
||||
@router.patch("/tags/{tag_id}", response_model=TagOut)
async def update_tag(
    tag_id: str,
    body: TagIn,
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Rename a tag and keep items' denormalized tags arrays in sync.

    Raises:
        404 if the tag does not exist or belongs to another user.
        400 if the new name is empty or would collide with another tag's
            slug (previously this surfaced as a DB 500 via the
            (user_id, slug) unique constraint).
    """
    tag = (await db.execute(
        select(Tag).where(Tag.id == tag_id, Tag.user_id == user_id)
    )).scalar_one_or_none()
    if not tag:
        raise HTTPException(404, "Tag not found")

    name = body.name.strip()
    slug = slugify(name)
    if not name or not slug:
        raise HTTPException(400, "Tag name cannot be empty")

    # A *different* tag with the same slug would violate the unique
    # constraint — fail fast with a client error instead.
    conflict = (await db.execute(
        select(Tag).where(Tag.user_id == user_id, Tag.slug == slug, Tag.id != tag_id)
    )).scalar_one_or_none()
    if conflict:
        raise HTTPException(400, "Tag already exists")

    old_name = tag.name
    tag.name = name
    tag.slug = slug
    tag.updated_at = datetime.utcnow()

    # Update denormalized tags array on items that had the old name.
    items_with_tag = (await db.execute(
        select(Item).join(ItemTag, ItemTag.item_id == Item.id).where(ItemTag.tag_id == tag_id)
    )).scalars().all()
    for item in items_with_tag:
        if item.tags and old_name in item.tags:
            renamed = [tag.name if t == old_name else t for t in item.tags]
            # Dedupe in case the item already carried the new name.
            item.tags = list(dict.fromkeys(renamed))

    await db.commit()
    return TagOut(id=tag.id, name=tag.name, slug=tag.slug, is_active=tag.is_active, sort_order=tag.sort_order)
|
||||
|
||||
|
||||
@router.delete("/tags/{tag_id}")
async def delete_tag(
    tag_id: str,
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Delete a tag, scrubbing it from items' denormalized tag arrays."""
    result = await db.execute(
        select(Tag).where(Tag.id == tag_id, Tag.user_id == user_id)
    )
    tag = result.scalar_one_or_none()
    if tag is None:
        raise HTTPException(404, "Tag not found")

    # Scrub the name from every item that carries this tag.
    tagged_items = (await db.execute(
        select(Item).join(ItemTag, ItemTag.item_id == Item.id).where(ItemTag.tag_id == tag_id)
    )).scalars().all()
    for itm in tagged_items:
        if itm.tags and tag.name in itm.tags:
            itm.tags = [name for name in itm.tags if name != tag.name]

    # Drop join-table rows first, then the tag itself.
    await db.execute(delete(ItemTag).where(ItemTag.tag_id == tag_id))
    await db.execute(delete(Tag).where(Tag.id == tag_id))
    await db.commit()
    return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/tags/merge")
async def merge_tags(
    body: MergeTagIn,
    user_id: str = Depends(get_user_id),
    db: AsyncSession = Depends(get_db_session),
):
    """Merge source tag into target tag. All items with source get target instead.

    Raises:
        400 if source and target are the same tag — the unguarded path
            would simply delete the tag and all of its associations.
        404 if either tag does not belong to the current user.
    """
    if body.source_tag_id == body.target_tag_id:
        raise HTTPException(400, "Cannot merge a tag into itself")

    source = (await db.execute(select(Tag).where(Tag.id == body.source_tag_id, Tag.user_id == user_id))).scalar_one_or_none()
    target = (await db.execute(select(Tag).where(Tag.id == body.target_tag_id, Tag.user_id == user_id))).scalar_one_or_none()
    if not source or not target:
        raise HTTPException(404, "Tag not found")

    # Re-point item_tags rows from source to target, skipping items that
    # already carry the target (the composite PK would reject duplicates).
    source_items = (await db.execute(
        select(ItemTag.item_id).where(ItemTag.tag_id == source.id)
    )).scalars().all()
    target_items = set((await db.execute(
        select(ItemTag.item_id).where(ItemTag.tag_id == target.id)
    )).scalars().all())

    for item_id in source_items:
        if item_id not in target_items:
            db.add(ItemTag(item_id=item_id, tag_id=target.id))

    # Keep the denormalized items.tags arrays in sync with the join table.
    items = (await db.execute(
        select(Item).join(ItemTag, ItemTag.item_id == Item.id).where(ItemTag.tag_id == source.id)
    )).scalars().all()
    for item in items:
        if item.tags:
            new_tags = [target.name if t == source.name else t for t in item.tags]
            item.tags = list(dict.fromkeys(new_tags))  # dedupe, preserving order

    # Delete the source tag and its remaining associations.
    await db.execute(delete(ItemTag).where(ItemTag.tag_id == source.id))
    await db.execute(delete(Tag).where(Tag.id == source.id))
    await db.commit()
    return {"status": "merged", "source": source.name, "target": target.name}
|
||||
@@ -6,6 +6,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.routes import router
|
||||
from app.api.taxonomy import router as taxonomy_router
|
||||
from app.config import DEBUG
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -23,6 +24,7 @@ app = FastAPI(
|
||||
|
||||
# No CORS — internal service only, accessed via gateway
|
||||
app.include_router(router)
|
||||
app.include_router(taxonomy_router)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -30,6 +32,7 @@ async def startup():
|
||||
from sqlalchemy import text as sa_text
|
||||
from app.database import engine, Base
|
||||
from app.models.item import Item, ItemAsset, AppLink # noqa: import to register models
|
||||
from app.models.taxonomy import Folder, Tag, ItemTag # noqa: register taxonomy tables
|
||||
|
||||
# Enable pgvector extension before creating tables
|
||||
async with engine.begin() as conn:
|
||||
|
||||
@@ -28,8 +28,9 @@ class Item(Base):
|
||||
url = Column(Text, nullable=True)
|
||||
raw_content = Column(Text, nullable=True) # original user input (note body, etc.)
|
||||
extracted_text = Column(Text, nullable=True) # full extracted text from page/doc
|
||||
folder = Column(String(64), nullable=True)
|
||||
tags = Column(ARRAY(String), nullable=True, default=list)
|
||||
folder_id = Column(UUID(as_uuid=False), ForeignKey("folders.id", ondelete="SET NULL"), nullable=True)
|
||||
folder = Column(String(64), nullable=True) # denormalized folder name for fast reads
|
||||
tags = Column(ARRAY(String), nullable=True, default=list) # denormalized tag names for fast reads
|
||||
summary = Column(Text, nullable=True)
|
||||
confidence = Column(Float, nullable=True)
|
||||
metadata_json = Column(JSONB, nullable=True, default=dict)
|
||||
|
||||
103
services/brain/app/models/taxonomy.py
Normal file
103
services/brain/app/models/taxonomy.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Database models for folders and tags — editable taxonomy."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DateTime, ForeignKey, Index, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
def new_id():
    """Return a fresh random UUID4 as a plain string (used as PK default)."""
    return f"{uuid.uuid4()}"
|
||||
|
||||
|
||||
def slugify(text: str) -> str:
    """Normalize *text* into a URL-safe slug.

    Lowercases and trims, drops everything except word characters,
    whitespace and hyphens, collapses runs of whitespace/underscores
    into a single hyphen, and strips hyphens from both ends.
    """
    cleaned = re.sub(r'[^\w\s-]', '', text.lower().strip())
    return re.sub(r'[\s_]+', '-', cleaned).strip('-')
|
||||
|
||||
|
||||
class Folder(Base):
    """A user-scoped folder; items reference it via items.folder_id."""
    __tablename__ = "folders"

    # String UUID (not native uuid) so the id round-trips cleanly through JSON.
    id = Column(UUID(as_uuid=False), primary_key=True, default=new_id)
    user_id = Column(String(64), nullable=False, index=True)
    name = Column(String(128), nullable=False)
    # URL-safe form of name, derived via slugify(); unique per user (below).
    slug = Column(String(128), nullable=False)
    # Inactive folders are excluded from the AI classification choices.
    is_active = Column(Boolean, default=True, nullable=False)
    # Display order in the sidebar; new folders get max(sort_order) + 1.
    sort_order = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

    __table_args__ = (
        # One folder per slug per user; names may repeat across users.
        UniqueConstraint('user_id', 'slug', name='uq_folder_user_slug'),
    )
|
||||
|
||||
|
||||
class Tag(Base):
    """A user-scoped tag; items link to it via the item_tags join table."""
    __tablename__ = "tags"

    # String UUID (not native uuid) so the id round-trips cleanly through JSON.
    id = Column(UUID(as_uuid=False), primary_key=True, default=new_id)
    user_id = Column(String(64), nullable=False, index=True)
    name = Column(String(128), nullable=False)
    # URL-safe form of name, derived via slugify(); unique per user (below).
    slug = Column(String(128), nullable=False)
    # Inactive tags are excluded from the AI classification choices.
    is_active = Column(Boolean, default=True, nullable=False)
    # Display order in the sidebar; new tags get max(sort_order) + 1.
    sort_order = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

    __table_args__ = (
        # One tag per slug per user; names may repeat across users.
        UniqueConstraint('user_id', 'slug', name='uq_tag_user_slug'),
    )
|
||||
|
||||
|
||||
class ItemTag(Base):
    """Join table linking items to tags (many-to-many).

    Composite primary key prevents duplicate item/tag pairs; both FKs
    cascade on delete so rows disappear with their item or tag.
    """
    __tablename__ = "item_tags"

    item_id = Column(UUID(as_uuid=False), ForeignKey("items.id", ondelete="CASCADE"), primary_key=True)
    tag_id = Column(UUID(as_uuid=False), ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
# Default folders to seed for new users (see ensure_user_taxonomy below).
DEFAULT_FOLDERS = ["Home", "Family", "Work", "Travel", "Knowledge", "Faith", "Projects"]

# Default tags to seed for new users (see ensure_user_taxonomy below).
# List order becomes the initial sort_order.
DEFAULT_TAGS = [
    "reference", "important", "legal", "financial", "insurance",
    "research", "idea", "guide", "tutorial", "setup", "how-to",
    "tools", "dev", "server", "selfhosted", "home-assistant",
    "shopping", "compare", "buy", "product",
    "family", "kids", "health", "travel", "faith",
    "video", "read-later", "books",
]
|
||||
|
||||
|
||||
async def ensure_user_taxonomy(db, user_id: str):
    """Seed default folders and tags for a new user if they have none.

    Commits only when something was actually seeded — the previous
    unconditional commit would silently flush any unrelated pending
    changes on the caller's session every time this guard ran.
    """
    from sqlalchemy import select, func

    seeded = False

    folder_count = (await db.execute(
        select(func.count()).where(Folder.user_id == user_id)
    )).scalar() or 0

    if folder_count == 0:
        # List position doubles as the initial sort_order.
        for i, name in enumerate(DEFAULT_FOLDERS):
            db.add(Folder(id=new_id(), user_id=user_id, name=name, slug=slugify(name), sort_order=i))
        await db.flush()
        seeded = True

    tag_count = (await db.execute(
        select(func.count()).where(Tag.user_id == user_id)
    )).scalar() or 0

    if tag_count == 0:
        for i, name in enumerate(DEFAULT_TAGS):
            db.add(Tag(id=new_id(), user_id=user_id, name=name, slug=slugify(name), sort_order=i))
        await db.flush()
        seeded = True

    if seeded:
        await db.commit()
|
||||
@@ -5,15 +5,17 @@ import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import OPENAI_API_KEY, OPENAI_MODEL, FOLDERS, TAGS
|
||||
from app.config import OPENAI_API_KEY, OPENAI_MODEL
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = f"""You are a classification engine for a personal "second brain" knowledge management system.
|
||||
|
||||
def build_system_prompt(folders: list[str], tags: list[str]) -> str:
|
||||
return f"""You are a classification engine for a personal "second brain" knowledge management system.
|
||||
|
||||
Given an item (URL, note, document, or file), you must return structured JSON with:
|
||||
- folder: exactly 1 from this list: {json.dumps(FOLDERS)}
|
||||
- tags: exactly 2 or 3 from this list: {json.dumps(TAGS)}
|
||||
- folder: exactly 1 from this list: {json.dumps(folders)}
|
||||
- tags: exactly 2 or 3 from this list: {json.dumps(tags)}
|
||||
- title: a concise, normalized title (max 80 chars)
|
||||
- summary: a 1-2 sentence summary of the content (for links/documents only)
|
||||
- corrected_text: for NOTES ONLY — return the original note text with spelling/grammar fixed. Keep the original meaning, tone, and structure. Only fix typos and obvious errors. Return empty string for non-notes.
|
||||
@@ -27,7 +29,9 @@ Rules:
|
||||
- For notes: the summary field should be a very short 5-10 word description, not a rewrite.
|
||||
- Always return valid JSON matching the schema exactly"""
|
||||
|
||||
RESPONSE_SCHEMA = {
|
||||
|
||||
def build_response_schema(folders: list[str], tags: list[str]) -> dict:
|
||||
return {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "classification",
|
||||
@@ -35,10 +39,10 @@ RESPONSE_SCHEMA = {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder": {"type": "string", "enum": FOLDERS},
|
||||
"folder": {"type": "string", "enum": folders},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": TAGS},
|
||||
"items": {"type": "string", "enum": tags},
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
},
|
||||
@@ -72,9 +76,15 @@ async def classify_item(
|
||||
url: str | None = None,
|
||||
title: str | None = None,
|
||||
text: str | None = None,
|
||||
folders: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
retries: int = 2,
|
||||
) -> dict:
|
||||
"""Call OpenAI to classify an item. Returns dict with folder, tags, title, summary, confidence."""
|
||||
from app.config import FOLDERS, TAGS
|
||||
folders = folders or FOLDERS
|
||||
tags = tags or TAGS
|
||||
|
||||
if not OPENAI_API_KEY:
|
||||
log.warning("No OPENAI_API_KEY set, returning defaults")
|
||||
return {
|
||||
@@ -86,6 +96,8 @@ async def classify_item(
|
||||
}
|
||||
|
||||
user_msg = build_user_prompt(item_type, url, title, text)
|
||||
system_prompt = build_system_prompt(folders, tags)
|
||||
response_schema = build_response_schema(folders, tags)
|
||||
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
@@ -96,10 +108,10 @@ async def classify_item(
|
||||
json={
|
||||
"model": OPENAI_MODEL,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_msg},
|
||||
],
|
||||
"response_format": RESPONSE_SCHEMA,
|
||||
"response_format": response_schema,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
)
|
||||
@@ -109,9 +121,9 @@ async def classify_item(
|
||||
result = json.loads(content)
|
||||
|
||||
# Validate folder and tags are in allowed sets
|
||||
if result["folder"] not in FOLDERS:
|
||||
result["folder"] = "Knowledge"
|
||||
result["tags"] = [t for t in result["tags"] if t in TAGS][:3]
|
||||
if result["folder"] not in folders:
|
||||
result["folder"] = folders[0] if folders else "Knowledge"
|
||||
result["tags"] = [t for t in result["tags"] if t in tags][:3]
|
||||
if len(result["tags"]) < 2:
|
||||
result["tags"] = (result["tags"] + ["reference", "read-later"])[:3]
|
||||
|
||||
|
||||
@@ -159,18 +159,53 @@ async def _process_item(item_id: str):
|
||||
if result.get("page_count"):
|
||||
item.metadata_json["page_count"] = result["page_count"]
|
||||
|
||||
# ── Step 2: AI classification ──
|
||||
log.info(f"Classifying item {item.id}")
|
||||
# ── Step 2: Fetch live taxonomy from DB ──
|
||||
from app.models.taxonomy import Folder as FolderModel, Tag as TagModel, ItemTag, ensure_user_taxonomy
|
||||
await ensure_user_taxonomy(db, item.user_id)
|
||||
|
||||
active_folders = (await db.execute(
|
||||
select(FolderModel).where(FolderModel.user_id == item.user_id, FolderModel.is_active == True)
|
||||
.order_by(FolderModel.sort_order)
|
||||
)).scalars().all()
|
||||
active_tags = (await db.execute(
|
||||
select(TagModel).where(TagModel.user_id == item.user_id, TagModel.is_active == True)
|
||||
.order_by(TagModel.sort_order)
|
||||
)).scalars().all()
|
||||
|
||||
folder_names = [f.name for f in active_folders]
|
||||
tag_names = [t.name for t in active_tags]
|
||||
folder_map = {f.name: f for f in active_folders}
|
||||
tag_map = {t.name: t for t in active_tags}
|
||||
|
||||
# ── Step 3: AI classification ──
|
||||
log.info(f"Classifying item {item.id} with {len(folder_names)} folders, {len(tag_names)} tags")
|
||||
classification = await classify_item(
|
||||
item_type=item.type,
|
||||
url=item.url,
|
||||
title=title,
|
||||
text=extracted_text,
|
||||
folders=folder_names,
|
||||
tags=tag_names,
|
||||
)
|
||||
|
||||
item.title = classification.get("title") or title or "Untitled"
|
||||
item.folder = classification.get("folder", "Knowledge")
|
||||
item.tags = classification.get("tags", ["reference", "read-later"])
|
||||
|
||||
# Set folder (relational + denormalized)
|
||||
classified_folder = classification.get("folder", folder_names[0] if folder_names else "Knowledge")
|
||||
item.folder = classified_folder
|
||||
if classified_folder in folder_map:
|
||||
item.folder_id = folder_map[classified_folder].id
|
||||
|
||||
# Set tags (relational + denormalized)
|
||||
classified_tags = classification.get("tags", [])
|
||||
item.tags = classified_tags
|
||||
|
||||
# Clear old item_tags and create new ones
|
||||
from sqlalchemy import delete as sa_delete
|
||||
await db.execute(sa_delete(ItemTag).where(ItemTag.item_id == item.id))
|
||||
for tag_name in classified_tags:
|
||||
if tag_name in tag_map:
|
||||
db.add(ItemTag(item_id=item.id, tag_id=tag_map[tag_name].id))
|
||||
|
||||
# For notes: replace raw_content with spell-corrected version
|
||||
corrected = classification.get("corrected_text", "")
|
||||
|
||||
Reference in New Issue
Block a user