Phase 4: Documents & Memory — upload, FTS, AI tools, context injection

Backend:
- Document + MemoryEntry models with Alembic migration (GIN FTS index)
- File upload endpoint with path traversal protection (sanitized filenames)
- Background document text extraction (PyMuPDF)
- Full-text search on extracted_text via PostgreSQL tsvector/tsquery
- Memory CRUD with enum-validated categories/importance, field allow-list
- AI tools: save_memory, search_documents, get_memory (Claude function calling)
- Tool execution loop in stream_ai_response (multi-turn tool use)
- Context assembly: injects critical memory + relevant doc excerpts
- File storage abstraction (local filesystem, S3-swappable)
- Secure file deletion (DB flush before disk delete)

Frontend:
- Document upload dialog (drag-and-drop + file picker)
- Document list with status badges, search, download (via authenticated blob)
- Document viewer with extracted text preview
- Memory list grouped by category with importance color coding
- Memory editor with category/importance dropdowns
- Documents + Memory pages with full CRUD
- Enabled sidebar navigation for both sections

Review fixes applied:
- Sanitized upload filenames (path traversal prevention)
- Download via axios blob (not bare <a href>, preserves auth)
- Route ordering: /search before /{id}/reindex
- Memory update allows is_active=False + field allow-list
- MemoryEditor form resets on mode switch
- Literal enum validation on category/importance schemas
- DB flush before file deletion for data integrity

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-19 13:46:59 +03:00
parent 03afb7a075
commit 8b8fe916f0
37 changed files with 1921 additions and 26 deletions

View File

@@ -12,9 +12,100 @@ from app.models.message import Message
from app.models.skill import Skill
from app.services.context_service import DEFAULT_SYSTEM_PROMPT, get_primary_context, get_personal_context
from app.services.chat_service import get_chat, save_message
from app.services.memory_service import get_critical_memories, create_memory, get_user_memories
from app.services.document_service import search_documents
client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)
# --- AI Tool Definitions ---
# Tool schemas advertised to Claude via the Messages API `tools=` parameter.
# Each entry uses the Anthropic function-calling format: a tool name, a
# natural-language description the model reads to decide when to invoke it,
# and a JSON Schema for the expected input.  Tool names must match the
# dispatch in `_execute_tool`; the category/importance enums here mirror the
# values accepted by the memory schemas.
AI_TOOLS = [
    # Persist one health fact (condition, medication, allergy, ...) for the user.
    {
        "name": "save_memory",
        "description": "Save important health information to the user's memory. Use this when the user shares critical health data like conditions, medications, allergies, or important health facts.",
        "input_schema": {
            "type": "object",
            "properties": {
                "category": {
                    "type": "string",
                    "enum": ["condition", "medication", "allergy", "vital", "document_summary", "other"],
                    "description": "Category of the memory entry",
                },
                "title": {"type": "string", "description": "Short title for the memory entry"},
                "content": {"type": "string", "description": "Detailed content of the memory entry"},
                "importance": {
                    "type": "string",
                    "enum": ["critical", "high", "medium", "low"],
                    "description": "Importance level",
                },
            },
            # All four fields are mandatory so every saved memory is fully classified.
            "required": ["category", "title", "content", "importance"],
        },
    },
    # Full-text search over the user's processed document uploads.
    {
        "name": "search_documents",
        "description": "Search the user's uploaded health documents for relevant information. Use this when you need to find specific health records, lab results, or consultation notes.",
        "input_schema": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query to find relevant documents"},
            },
            "required": ["query"],
        },
    },
    # Recall stored memories; category filter is optional (no required args).
    {
        "name": "get_memory",
        "description": "Retrieve the user's stored health memories filtered by category. Use this to recall previously saved health information.",
        "input_schema": {
            "type": "object",
            "properties": {
                "category": {
                    "type": "string",
                    "enum": ["condition", "medication", "allergy", "vital", "document_summary", "other"],
                    "description": "Optional category filter. Omit to get all memories.",
                },
            },
            "required": [],
        },
    },
]
async def _execute_tool(
    db: AsyncSession, user_id: uuid.UUID, tool_name: str, tool_input: dict
) -> str:
    """Run one AI-requested tool and return its result as a JSON string.

    Dispatches on *tool_name* to the matching service call; an
    unrecognized name yields a JSON error payload rather than raising,
    so the model receives the failure as a tool_result.
    """
    if tool_name == "save_memory":
        saved = await create_memory(
            db, user_id,
            category=tool_input["category"],
            title=tool_input["title"],
            content=tool_input["content"],
            importance=tool_input["importance"],
        )
        # Commit immediately so the memory survives even if streaming aborts later.
        await db.commit()
        return json.dumps({"status": "saved", "id": str(saved.id), "title": saved.title})

    if tool_name == "search_documents":
        matches = await search_documents(db, user_id, tool_input["query"], limit=5)
        payload = [
            {
                "filename": m.original_filename,
                "doc_type": m.doc_type,
                # Cap excerpts to keep the tool_result small for the model.
                "excerpt": (m.extracted_text or "")[:1000],
            }
            for m in matches
        ]
        return json.dumps({"results": payload, "count": len(payload)})

    if tool_name == "get_memory":
        entries = await get_user_memories(
            db, user_id, category=tool_input.get("category"), is_active=True
        )
        serialized = [
            {"category": e.category, "title": e.title, "content": e.content, "importance": e.importance}
            for e in entries
        ]
        return json.dumps({"entries": serialized, "count": len(serialized)})

    return json.dumps({"error": f"Unknown tool: {tool_name}"})
async def assemble_context(
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
@@ -39,9 +130,25 @@ async def assemble_context(
if skill and skill.is_active:
system_parts.append(f"---\nSpecialist Role ({skill.name}):\n{skill.system_prompt}")
# 4. Critical memory entries
memories = await get_critical_memories(db, user_id)
if memories:
memory_lines = [f"- [{m.category}] {m.title}: {m.content}" for m in memories]
system_parts.append(f"---\nUser Health Profile:\n" + "\n".join(memory_lines))
# 5. Relevant document excerpts (based on user message keywords)
if user_message.strip():
docs = await search_documents(db, user_id, user_message, limit=3)
if docs:
doc_lines = []
for d in docs:
excerpt = (d.extracted_text or "")[:1500]
doc_lines.append(f"[{d.original_filename} ({d.doc_type})]\n{excerpt}")
system_parts.append(f"---\nRelevant Document Excerpts:\n" + "\n\n".join(doc_lines))
system_prompt = "\n\n".join(system_parts)
# 4. Conversation history
# 6. Conversation history
result = await db.execute(
select(Message)
.where(Message.chat_id == chat_id, Message.role.in_(["user", "assistant"]))
@@ -50,7 +157,7 @@ async def assemble_context(
history = result.scalars().all()
messages = [{"role": msg.role, "content": msg.content} for msg in history]
# 5. Current user message
# 7. Current user message
messages.append({"role": "user", "content": user_message})
return system_prompt, messages
@@ -63,7 +170,7 @@ def _sse_event(event: str, data: dict) -> str:
async def stream_ai_response(
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
) -> AsyncGenerator[str, None]:
"""Stream AI response as SSE events."""
"""Stream AI response as SSE events, with tool use support."""
# Verify ownership
chat = await get_chat(db, chat_id, user_id)
@@ -75,28 +182,53 @@ async def stream_ai_response(
# Assemble context
system_prompt, messages = await assemble_context(db, chat_id, user_id, user_message)
# Stream from Claude
full_content = ""
assistant_msg_id = str(uuid.uuid4())
yield _sse_event("message_start", {"message_id": assistant_msg_id})
async with client.messages.stream(
model=settings.CLAUDE_MODEL,
max_tokens=4096,
system=system_prompt,
messages=messages,
) as stream:
async for text in stream.text_stream:
full_content += text
yield _sse_event("content_delta", {"delta": text})
# Tool use loop
full_content = ""
max_tool_rounds = 5
for _ in range(max_tool_rounds):
response = await client.messages.create(
model=settings.CLAUDE_MODEL,
max_tokens=4096,
system=system_prompt,
messages=messages,
tools=AI_TOOLS,
)
# Process content blocks
tool_use_blocks = []
for block in response.content:
if block.type == "text":
full_content += block.text
yield _sse_event("content_delta", {"delta": block.text})
elif block.type == "tool_use":
tool_use_blocks.append(block)
yield _sse_event("tool_use", {"tool": block.name, "input": block.input})
# If no tool use, we're done
if response.stop_reason != "tool_use" or not tool_use_blocks:
break
# Execute tools and continue conversation
messages.append({"role": "assistant", "content": response.content})
tool_results = []
for tool_block in tool_use_blocks:
result = await _execute_tool(db, user_id, tool_block.name, tool_block.input)
tool_results.append({
"type": "tool_result",
"tool_use_id": tool_block.id,
"content": result,
})
yield _sse_event("tool_result", {"tool": tool_block.name, "result": result})
messages.append({"role": "user", "content": tool_results})
# Get final message for metadata
final_message = await stream.get_final_message()
metadata = {
"model": final_message.model,
"input_tokens": final_message.usage.input_tokens,
"output_tokens": final_message.usage.output_tokens,
"model": response.model,
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens,
}
# Save assistant message
@@ -109,7 +241,6 @@ async def stream_ai_response(
)
assistant_count = len(result.scalars().all())
if assistant_count == 1 and chat.title == "New Chat":
# Auto-generate title from first few words
title = full_content[:50].split("\n")[0].strip()
if len(title) > 40:
title = title[:40] + "..."

View File

@@ -0,0 +1,97 @@
import uuid
from fastapi import HTTPException, status
from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.document import Document
from app.utils.file_storage import delete_file
async def create_document(
    db: AsyncSession,
    user_id: uuid.UUID,
    filename: str,
    original_filename: str,
    storage_path: str,
    mime_type: str,
    file_size: int,
    doc_type: str = "other",
) -> Document:
    """Insert a new Document row in the "pending" processing state.

    The session is flushed (not committed) so server-generated fields
    such as the id are populated before returning; the caller owns the
    transaction commit.
    """
    record = Document(
        user_id=user_id,
        filename=filename,
        original_filename=original_filename,
        storage_path=storage_path,
        mime_type=mime_type,
        file_size=file_size,
        doc_type=doc_type,
        processing_status="pending",
    )
    db.add(record)
    await db.flush()
    return record
async def get_document(db: AsyncSession, doc_id: uuid.UUID, user_id: uuid.UUID) -> Document:
    """Fetch one document owned by *user_id*, raising 404 if absent.

    Scoping the query by user_id doubles as the ownership check:
    another user's document is indistinguishable from a missing one.
    """
    row = (
        await db.execute(
            select(Document).where(Document.id == doc_id, Document.user_id == user_id)
        )
    ).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")
    return row
async def get_user_documents(
    db: AsyncSession,
    user_id: uuid.UUID,
    doc_type: str | None = None,
    processing_status: str | None = None,
) -> list[Document]:
    """List a user's documents, newest first, with optional filters.

    Falsy filter values (None or empty string) mean "no filter".
    """
    conditions = [Document.user_id == user_id]
    if doc_type:
        conditions.append(Document.doc_type == doc_type)
    if processing_status:
        conditions.append(Document.processing_status == processing_status)
    query = (
        select(Document)
        .where(*conditions)
        .order_by(Document.created_at.desc())
    )
    return list((await db.execute(query)).scalars().all())
async def delete_document(db: AsyncSession, doc_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Delete a document row, then its backing file on disk.

    Ownership is enforced inside get_document (404 otherwise).  The DB
    delete is flushed *before* touching disk so a database failure
    aborts the operation before the file is irreversibly removed.
    """
    record = await get_document(db, doc_id, user_id)
    # Capture the path now; the ORM object may be expired after delete/flush.
    path_on_disk = record.storage_path
    await db.delete(record)
    await db.flush()
    delete_file(path_on_disk)
async def search_documents(db: AsyncSession, user_id: uuid.UUID, query: str, limit: int = 5) -> list[Document]:
    """Full-text search over a user's fully-processed documents.

    Uses PostgreSQL FTS: `extracted_text` is matched against *query*
    via plainto_tsquery (which treats the input as plain words, so raw
    user text is safe — no tsquery operators are interpreted), and
    results are ordered by ts_rank relevance.  Only documents whose
    extraction completed are searchable.

    Args:
        query: Free-text search string; always passed as a bound
            parameter, never interpolated into the SQL.
        limit: Maximum number of documents to return.

    Returns:
        Matching documents, most relevant first.
    """
    # Named text() fragments keep the statement readable; both share the
    # same :query placeholder, so one .params() call binds it for the
    # predicate and the ordering alike (the original bound it twice).
    match_clause = text(
        "to_tsvector('english', coalesce(extracted_text, '')) "
        "@@ plainto_tsquery('english', :query)"
    )
    rank_clause = text(
        "ts_rank(to_tsvector('english', coalesce(extracted_text, '')), "
        "plainto_tsquery('english', :query)) DESC"
    )
    stmt = (
        select(Document)
        .where(
            Document.user_id == user_id,
            Document.processing_status == "completed",
            match_clause,
        )
        .order_by(rank_clause)
        .params(query=query)
        .limit(limit)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())
async def update_document_text(
    db: AsyncSession, doc_id: uuid.UUID, extracted_text: str, status_val: str = "completed"
) -> None:
    """Store extraction output on a document and update its status.

    Looks the document up by id alone (no user scoping) — intended for
    the background extraction worker, not request handlers.  A missing
    document is silently ignored, presumably because it may have been
    deleted while extraction was running — TODO confirm with callers.
    """
    found = (
        await db.execute(select(Document).where(Document.id == doc_id))
    ).scalar_one_or_none()
    if found is None:
        return
    found.extracted_text = extracted_text
    found.processing_status = status_val
    await db.flush()

View File

@@ -0,0 +1,71 @@
import uuid
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory_entry import MemoryEntry
async def create_memory(db: AsyncSession, user_id: uuid.UUID, **kwargs) -> MemoryEntry:
    """Create a memory entry for *user_id*; flushed but not committed.

    Field validation happens upstream (the enum-constrained schemas and
    AI tool input definitions), so **kwargs is passed straight through.
    """
    record = MemoryEntry(user_id=user_id, **kwargs)
    db.add(record)
    # Flush so the generated id is available to the caller pre-commit.
    await db.flush()
    return record
async def get_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID) -> MemoryEntry:
    """Fetch one memory entry owned by *user_id*, raising 404 if absent.

    Filtering on user_id inside the query doubles as the ownership
    check: another user's entry is reported as not found.
    """
    row = (
        await db.execute(
            select(MemoryEntry).where(MemoryEntry.id == entry_id, MemoryEntry.user_id == user_id)
        )
    ).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Memory entry not found")
    return row
async def get_user_memories(
    db: AsyncSession,
    user_id: uuid.UUID,
    category: str | None = None,
    importance: str | None = None,
    is_active: bool | None = None,
) -> list[MemoryEntry]:
    """List a user's memory entries, newest first, with optional filters.

    is_active is compared against None explicitly so that passing False
    filters for inactive entries instead of disabling the filter.
    """
    criteria = [MemoryEntry.user_id == user_id]
    if category:
        criteria.append(MemoryEntry.category == category)
    if importance:
        criteria.append(MemoryEntry.importance == importance)
    if is_active is not None:
        criteria.append(MemoryEntry.is_active == is_active)
    query = (
        select(MemoryEntry)
        .where(*criteria)
        .order_by(MemoryEntry.created_at.desc())
    )
    return list((await db.execute(query)).scalars().all())
# Allow-list of client-editable fields; id/user_id are deliberately
# absent so identity and ownership can never be overwritten.
ALLOWED_UPDATE_FIELDS = {"category", "title", "content", "importance", "is_active"}
async def update_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID, **kwargs) -> MemoryEntry:
    """Apply allow-listed field updates to an owned memory entry.

    Unknown keys are silently dropped.  Falsy values (e.g.
    is_active=False) are applied normally — membership in the
    allow-list, not truthiness, decides whether a key takes effect.
    Raises 404 via get_memory if the entry isn't owned by *user_id*.
    """
    entry = await get_memory(db, entry_id, user_id)
    for field in ALLOWED_UPDATE_FIELDS & kwargs.keys():
        setattr(entry, field, kwargs[field])
    await db.flush()
    return entry
async def delete_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Delete a memory entry after verifying ownership (404 if absent).

    No flush here — the deletion is emitted when the caller's
    transaction flushes or commits.
    """
    entry = await get_memory(db, entry_id, user_id)
    await db.delete(entry)
async def get_critical_memories(db: AsyncSession, user_id: uuid.UUID) -> list[MemoryEntry]:
    """Return the user's active memories with "critical" or "high" importance.

    Used by context assembly to inject the user's health profile into the
    system prompt, so only the most important active entries are fetched.

    NOTE(review): importance is a plain string column, so this sort is
    alphabetical — "critical" happens to order before "high" today, but
    the ordering breaks if enum values that sort differently are ever
    added; consider an explicit CASE-based ordering.
    """
    result = await db.execute(
        select(MemoryEntry).where(
            MemoryEntry.user_id == user_id,
            MemoryEntry.is_active == True,  # noqa: E712  (SQLAlchemy needs ==, not `is`)
            MemoryEntry.importance.in_(["critical", "high"]),
        ).order_by(MemoryEntry.importance, MemoryEntry.created_at.desc())
    )
    return list(result.scalars().all())