Phase 4: Documents & Memory — upload, FTS, AI tools, context injection
Backend:
- Document + MemoryEntry models with Alembic migration (GIN FTS index)
- File upload endpoint with path traversal protection (sanitized filenames)
- Background document text extraction (PyMuPDF)
- Full-text search on extracted_text via PostgreSQL tsvector/tsquery
- Memory CRUD with enum-validated categories/importance, field allow-list
- AI tools: save_memory, search_documents, get_memory (Claude function calling)
- Tool execution loop in stream_ai_response (multi-turn tool use)
- Context assembly: injects critical memory + relevant doc excerpts
- File storage abstraction (local filesystem, S3-swappable)
- Secure file deletion (DB flush before disk delete)
Frontend:
- Document upload dialog (drag-and-drop + file picker)
- Document list with status badges, search, download (via authenticated blob)
- Document viewer with extracted text preview
- Memory list grouped by category with importance color coding
- Memory editor with category/importance dropdowns
- Documents + Memory pages with full CRUD
- Enabled sidebar navigation for both sections
Review fixes applied:
- Sanitized upload filenames (path traversal prevention)
- Download via axios blob (not bare <a href>, preserves auth)
- Route ordering: /search before /{id}/reindex
- Memory update allows is_active=False + field allow-list
- MemoryEditor form resets on mode switch
- Literal enum validation on category/importance schemas
- DB flush before file deletion for data integrity
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
"""Create documents and memory_entries tables
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2026-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
revision: str = "004"
|
||||
down_revision: Union[str, None] = "003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"documents",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("filename", sa.String(255), nullable=False),
|
||||
sa.Column("original_filename", sa.String(255), nullable=False),
|
||||
sa.Column("storage_path", sa.Text, nullable=False),
|
||||
sa.Column("mime_type", sa.String(100), nullable=False),
|
||||
sa.Column("file_size", sa.BigInteger, nullable=False),
|
||||
sa.Column("doc_type", sa.String(50), nullable=False, server_default="other"),
|
||||
sa.Column("extracted_text", sa.Text, nullable=True),
|
||||
sa.Column("processing_status", sa.String(20), nullable=False, server_default="pending"),
|
||||
sa.Column("metadata", JSONB, nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX ix_documents_fts ON documents USING gin(to_tsvector('english', coalesce(extracted_text, '')))"
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"memory_entries",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("category", sa.String(50), nullable=False),
|
||||
sa.Column("title", sa.String(255), nullable=False),
|
||||
sa.Column("content", sa.Text, nullable=False),
|
||||
sa.Column("source_document_id", UUID(as_uuid=True), sa.ForeignKey("documents.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("importance", sa.String(20), nullable=False, server_default="medium"),
|
||||
sa.Column("is_active", sa.Boolean, nullable=False, server_default=sa.text("true")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("memory_entries")
|
||||
op.execute("DROP INDEX IF EXISTS ix_documents_fts")
|
||||
op.drop_table("documents")
|
||||
126
backend/app/api/v1/documents.py
Normal file
126
backend/app/api/v1/documents.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile, File, HTTPException, status
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.document import DocumentListResponse, DocumentResponse, DocumentSearchRequest
|
||||
from app.services import document_service
|
||||
from app.utils.file_storage import save_upload, get_file_path
|
||||
from app.workers.document_processor import process_document
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
|
||||
ALLOWED_MIME_TYPES = [
|
||||
"application/pdf",
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/tiff",
|
||||
"image/webp",
|
||||
]
|
||||
|
||||
|
||||
@router.post("/", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_document(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
file: UploadFile = File(...),
|
||||
doc_type: str = Query(default="other"),
|
||||
):
|
||||
if file.content_type not in ALLOWED_MIME_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported file type: {file.content_type}")
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > settings.MAX_UPLOAD_SIZE_MB * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail=f"File too large. Max {settings.MAX_UPLOAD_SIZE_MB}MB")
|
||||
|
||||
doc_id = uuid.uuid4()
|
||||
safe_name = PurePosixPath(file.filename or "upload").name
|
||||
filename = f"{doc_id}_{safe_name}"
|
||||
storage_path = await save_upload(user.id, doc_id, filename, content)
|
||||
|
||||
doc = await document_service.create_document(
|
||||
db, user.id, filename, safe_name,
|
||||
storage_path, file.content_type or "application/octet-stream",
|
||||
len(content), doc_type,
|
||||
)
|
||||
|
||||
# Trigger background processing
|
||||
asyncio.create_task(process_document(doc.id, storage_path, file.content_type or ""))
|
||||
|
||||
return DocumentResponse.model_validate(doc)
|
||||
|
||||
|
||||
@router.get("/", response_model=DocumentListResponse)
|
||||
async def list_documents(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
doc_type: str | None = Query(default=None),
|
||||
processing_status: str | None = Query(default=None),
|
||||
):
|
||||
docs = await document_service.get_user_documents(db, user.id, doc_type, processing_status)
|
||||
return DocumentListResponse(documents=[DocumentResponse.model_validate(d) for d in docs])
|
||||
|
||||
|
||||
@router.get("/{doc_id}", response_model=DocumentResponse)
|
||||
async def get_document(
|
||||
doc_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
doc = await document_service.get_document(db, doc_id, user.id)
|
||||
return DocumentResponse.model_validate(doc)
|
||||
|
||||
|
||||
@router.get("/{doc_id}/download")
|
||||
async def download_document(
|
||||
doc_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
doc = await document_service.get_document(db, doc_id, user.id)
|
||||
file_path = get_file_path(doc.storage_path)
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=doc.original_filename,
|
||||
media_type=doc.mime_type,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{doc_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_document(
|
||||
doc_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
await document_service.delete_document(db, doc_id, user.id)
|
||||
|
||||
|
||||
@router.post("/search", response_model=DocumentListResponse)
|
||||
async def search_documents(
|
||||
data: DocumentSearchRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
docs = await document_service.search_documents(db, user.id, data.query)
|
||||
return DocumentListResponse(documents=[DocumentResponse.model_validate(d) for d in docs])
|
||||
|
||||
|
||||
@router.post("/{doc_id}/reindex", response_model=DocumentResponse)
|
||||
async def reindex_document(
|
||||
doc_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
doc = await document_service.get_document(db, doc_id, user.id)
|
||||
asyncio.create_task(process_document(doc.id, doc.storage_path, doc.mime_type))
|
||||
return DocumentResponse.model_validate(doc)
|
||||
70
backend/app/api/v1/memory.py
Normal file
70
backend/app/api/v1/memory.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.memory import (
|
||||
CreateMemoryRequest,
|
||||
MemoryEntryListResponse,
|
||||
MemoryEntryResponse,
|
||||
UpdateMemoryRequest,
|
||||
)
|
||||
from app.services import memory_service
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
|
||||
@router.post("/", response_model=MemoryEntryResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_memory(
|
||||
data: CreateMemoryRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
entry = await memory_service.create_memory(db, user.id, **data.model_dump())
|
||||
return MemoryEntryResponse.model_validate(entry)
|
||||
|
||||
|
||||
@router.get("/", response_model=MemoryEntryListResponse)
|
||||
async def list_memories(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
category: str | None = Query(default=None),
|
||||
importance: str | None = Query(default=None),
|
||||
is_active: bool | None = Query(default=None),
|
||||
):
|
||||
entries = await memory_service.get_user_memories(db, user.id, category, importance, is_active)
|
||||
return MemoryEntryListResponse(entries=[MemoryEntryResponse.model_validate(e) for e in entries])
|
||||
|
||||
|
||||
@router.get("/{entry_id}", response_model=MemoryEntryResponse)
|
||||
async def get_memory(
|
||||
entry_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
entry = await memory_service.get_memory(db, entry_id, user.id)
|
||||
return MemoryEntryResponse.model_validate(entry)
|
||||
|
||||
|
||||
@router.patch("/{entry_id}", response_model=MemoryEntryResponse)
|
||||
async def update_memory(
|
||||
entry_id: uuid.UUID,
|
||||
data: UpdateMemoryRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
entry = await memory_service.update_memory(db, entry_id, user.id, **data.model_dump(exclude_unset=True))
|
||||
return MemoryEntryResponse.model_validate(entry)
|
||||
|
||||
|
||||
@router.delete("/{entry_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_memory(
|
||||
entry_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
await memory_service.delete_memory(db, entry_id, user.id)
|
||||
@@ -5,6 +5,8 @@ from app.api.v1.chats import router as chats_router
|
||||
from app.api.v1.admin import router as admin_router
|
||||
from app.api.v1.skills import router as skills_router
|
||||
from app.api.v1.users import router as users_router
|
||||
from app.api.v1.documents import router as documents_router
|
||||
from app.api.v1.memory import router as memory_router
|
||||
|
||||
api_v1_router = APIRouter(prefix="/api/v1")
|
||||
|
||||
@@ -13,6 +15,8 @@ api_v1_router.include_router(chats_router)
|
||||
api_v1_router.include_router(admin_router)
|
||||
api_v1_router.include_router(skills_router)
|
||||
api_v1_router.include_router(users_router)
|
||||
api_v1_router.include_router(documents_router)
|
||||
api_v1_router.include_router(memory_router)
|
||||
|
||||
|
||||
@api_v1_router.get("/health")
|
||||
|
||||
@@ -15,6 +15,9 @@ class Settings(BaseSettings):
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
CLAUDE_MODEL: str = "claude-sonnet-4-20250514"
|
||||
|
||||
UPLOAD_DIR: str = "/data/uploads"
|
||||
MAX_UPLOAD_SIZE_MB: int = 20
|
||||
|
||||
FIRST_ADMIN_EMAIL: str = "admin@example.com"
|
||||
FIRST_ADMIN_USERNAME: str = "admin"
|
||||
FIRST_ADMIN_PASSWORD: str = "changeme_admin_password"
|
||||
|
||||
@@ -4,5 +4,7 @@ from app.models.chat import Chat
|
||||
from app.models.message import Message
|
||||
from app.models.context_file import ContextFile
|
||||
from app.models.skill import Skill
|
||||
from app.models.document import Document
|
||||
from app.models.memory_entry import MemoryEntry
|
||||
|
||||
__all__ = ["User", "Session", "Chat", "Message", "ContextFile", "Skill"]
|
||||
__all__ = ["User", "Session", "Chat", "Message", "ContextFile", "Skill", "Document", "MemoryEntry"]
|
||||
|
||||
33
backend/app/models/document.py
Normal file
33
backend/app/models/document.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import BigInteger, ForeignKey, Index, String, Text, func, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "documents"
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_documents_fts",
|
||||
text("to_tsvector('english', coalesce(extracted_text, ''))"),
|
||||
postgresql_using="gin",
|
||||
),
|
||||
)
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
filename: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
original_filename: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
storage_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
file_size: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
doc_type: Mapped[str] = mapped_column(String(50), nullable=False, default="other")
|
||||
extracted_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
processing_status: Mapped[str] = mapped_column(String(20), nullable=False, default="pending")
|
||||
metadata_: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="documents") # noqa: F821
|
||||
26
backend/app/models/memory_entry.py
Normal file
26
backend/app/models/memory_entry.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class MemoryEntry(Base):
|
||||
__tablename__ = "memory_entries"
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
category: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
source_document_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("documents.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
importance: Mapped[str] = mapped_column(String(20), nullable=False, default="medium")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="memory_entries") # noqa: F821
|
||||
source_document: Mapped["Document | None"] = relationship() # noqa: F821
|
||||
@@ -27,3 +27,5 @@ class User(Base):
|
||||
sessions: Mapped[list["Session"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
chats: Mapped[list["Chat"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
skills: Mapped[list["Skill"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
documents: Mapped[list["Document"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
memory_entries: Mapped[list["MemoryEntry"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
|
||||
32
backend/app/schemas/document.py
Normal file
32
backend/app/schemas/document.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DocumentResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
user_id: uuid.UUID
|
||||
filename: str
|
||||
original_filename: str
|
||||
mime_type: str
|
||||
file_size: int
|
||||
doc_type: str
|
||||
processing_status: str
|
||||
extracted_text: str | None = None
|
||||
metadata: dict | None = Field(None, alias="metadata_")
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True, "populate_by_name": True}
|
||||
|
||||
|
||||
class DocumentListResponse(BaseModel):
|
||||
documents: list[DocumentResponse]
|
||||
|
||||
|
||||
class UpdateDocumentRequest(BaseModel):
|
||||
doc_type: str | None = None
|
||||
|
||||
|
||||
class DocumentSearchRequest(BaseModel):
|
||||
query: str = Field(min_length=1, max_length=500)
|
||||
43
backend/app/schemas/memory.py
Normal file
43
backend/app/schemas/memory.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
CategoryType = Literal["condition", "medication", "allergy", "vital", "document_summary", "other"]
|
||||
ImportanceType = Literal["critical", "high", "medium", "low"]
|
||||
|
||||
|
||||
class CreateMemoryRequest(BaseModel):
|
||||
category: CategoryType
|
||||
title: str = Field(min_length=1, max_length=255)
|
||||
content: str = Field(min_length=1)
|
||||
source_document_id: uuid.UUID | None = None
|
||||
importance: ImportanceType = "medium"
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class UpdateMemoryRequest(BaseModel):
|
||||
category: CategoryType | None = None
|
||||
title: str | None = None
|
||||
content: str | None = None
|
||||
importance: ImportanceType | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
class MemoryEntryResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
user_id: uuid.UUID
|
||||
category: str
|
||||
title: str
|
||||
content: str
|
||||
source_document_id: uuid.UUID | None
|
||||
importance: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class MemoryEntryListResponse(BaseModel):
|
||||
entries: list[MemoryEntryResponse]
|
||||
@@ -12,9 +12,100 @@ from app.models.message import Message
|
||||
from app.models.skill import Skill
|
||||
from app.services.context_service import DEFAULT_SYSTEM_PROMPT, get_primary_context, get_personal_context
|
||||
from app.services.chat_service import get_chat, save_message
|
||||
from app.services.memory_service import get_critical_memories, create_memory, get_user_memories
|
||||
from app.services.document_service import search_documents
|
||||
|
||||
client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
|
||||
# --- AI Tool Definitions ---
|
||||
|
||||
AI_TOOLS = [
|
||||
{
|
||||
"name": "save_memory",
|
||||
"description": "Save important health information to the user's memory. Use this when the user shares critical health data like conditions, medications, allergies, or important health facts.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["condition", "medication", "allergy", "vital", "document_summary", "other"],
|
||||
"description": "Category of the memory entry",
|
||||
},
|
||||
"title": {"type": "string", "description": "Short title for the memory entry"},
|
||||
"content": {"type": "string", "description": "Detailed content of the memory entry"},
|
||||
"importance": {
|
||||
"type": "string",
|
||||
"enum": ["critical", "high", "medium", "low"],
|
||||
"description": "Importance level",
|
||||
},
|
||||
},
|
||||
"required": ["category", "title", "content", "importance"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "search_documents",
|
||||
"description": "Search the user's uploaded health documents for relevant information. Use this when you need to find specific health records, lab results, or consultation notes.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query to find relevant documents"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get_memory",
|
||||
"description": "Retrieve the user's stored health memories filtered by category. Use this to recall previously saved health information.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["condition", "medication", "allergy", "vital", "document_summary", "other"],
|
||||
"description": "Optional category filter. Omit to get all memories.",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def _execute_tool(
|
||||
db: AsyncSession, user_id: uuid.UUID, tool_name: str, tool_input: dict
|
||||
) -> str:
|
||||
"""Execute an AI tool and return the result as a string."""
|
||||
if tool_name == "save_memory":
|
||||
entry = await create_memory(
|
||||
db, user_id,
|
||||
category=tool_input["category"],
|
||||
title=tool_input["title"],
|
||||
content=tool_input["content"],
|
||||
importance=tool_input["importance"],
|
||||
)
|
||||
await db.commit()
|
||||
return json.dumps({"status": "saved", "id": str(entry.id), "title": entry.title})
|
||||
|
||||
elif tool_name == "search_documents":
|
||||
docs = await search_documents(db, user_id, tool_input["query"], limit=5)
|
||||
results = []
|
||||
for doc in docs:
|
||||
excerpt = (doc.extracted_text or "")[:1000]
|
||||
results.append({
|
||||
"filename": doc.original_filename,
|
||||
"doc_type": doc.doc_type,
|
||||
"excerpt": excerpt,
|
||||
})
|
||||
return json.dumps({"results": results, "count": len(results)})
|
||||
|
||||
elif tool_name == "get_memory":
|
||||
category = tool_input.get("category")
|
||||
entries = await get_user_memories(db, user_id, category=category, is_active=True)
|
||||
items = [{"category": e.category, "title": e.title, "content": e.content, "importance": e.importance} for e in entries]
|
||||
return json.dumps({"entries": items, "count": len(items)})
|
||||
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
|
||||
|
||||
async def assemble_context(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
|
||||
@@ -39,9 +130,25 @@ async def assemble_context(
|
||||
if skill and skill.is_active:
|
||||
system_parts.append(f"---\nSpecialist Role ({skill.name}):\n{skill.system_prompt}")
|
||||
|
||||
# 4. Critical memory entries
|
||||
memories = await get_critical_memories(db, user_id)
|
||||
if memories:
|
||||
memory_lines = [f"- [{m.category}] {m.title}: {m.content}" for m in memories]
|
||||
system_parts.append(f"---\nUser Health Profile:\n" + "\n".join(memory_lines))
|
||||
|
||||
# 5. Relevant document excerpts (based on user message keywords)
|
||||
if user_message.strip():
|
||||
docs = await search_documents(db, user_id, user_message, limit=3)
|
||||
if docs:
|
||||
doc_lines = []
|
||||
for d in docs:
|
||||
excerpt = (d.extracted_text or "")[:1500]
|
||||
doc_lines.append(f"[{d.original_filename} ({d.doc_type})]\n{excerpt}")
|
||||
system_parts.append(f"---\nRelevant Document Excerpts:\n" + "\n\n".join(doc_lines))
|
||||
|
||||
system_prompt = "\n\n".join(system_parts)
|
||||
|
||||
# 4. Conversation history
|
||||
# 6. Conversation history
|
||||
result = await db.execute(
|
||||
select(Message)
|
||||
.where(Message.chat_id == chat_id, Message.role.in_(["user", "assistant"]))
|
||||
@@ -50,7 +157,7 @@ async def assemble_context(
|
||||
history = result.scalars().all()
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in history]
|
||||
|
||||
# 5. Current user message
|
||||
# 7. Current user message
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
return system_prompt, messages
|
||||
@@ -63,7 +170,7 @@ def _sse_event(event: str, data: dict) -> str:
|
||||
async def stream_ai_response(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream AI response as SSE events."""
|
||||
"""Stream AI response as SSE events, with tool use support."""
|
||||
# Verify ownership
|
||||
chat = await get_chat(db, chat_id, user_id)
|
||||
|
||||
@@ -75,28 +182,53 @@ async def stream_ai_response(
|
||||
# Assemble context
|
||||
system_prompt, messages = await assemble_context(db, chat_id, user_id, user_message)
|
||||
|
||||
# Stream from Claude
|
||||
full_content = ""
|
||||
assistant_msg_id = str(uuid.uuid4())
|
||||
|
||||
yield _sse_event("message_start", {"message_id": assistant_msg_id})
|
||||
|
||||
async with client.messages.stream(
|
||||
model=settings.CLAUDE_MODEL,
|
||||
max_tokens=4096,
|
||||
system=system_prompt,
|
||||
messages=messages,
|
||||
) as stream:
|
||||
async for text in stream.text_stream:
|
||||
full_content += text
|
||||
yield _sse_event("content_delta", {"delta": text})
|
||||
# Tool use loop
|
||||
full_content = ""
|
||||
max_tool_rounds = 5
|
||||
|
||||
for _ in range(max_tool_rounds):
|
||||
response = await client.messages.create(
|
||||
model=settings.CLAUDE_MODEL,
|
||||
max_tokens=4096,
|
||||
system=system_prompt,
|
||||
messages=messages,
|
||||
tools=AI_TOOLS,
|
||||
)
|
||||
|
||||
# Process content blocks
|
||||
tool_use_blocks = []
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
full_content += block.text
|
||||
yield _sse_event("content_delta", {"delta": block.text})
|
||||
elif block.type == "tool_use":
|
||||
tool_use_blocks.append(block)
|
||||
yield _sse_event("tool_use", {"tool": block.name, "input": block.input})
|
||||
|
||||
# If no tool use, we're done
|
||||
if response.stop_reason != "tool_use" or not tool_use_blocks:
|
||||
break
|
||||
|
||||
# Execute tools and continue conversation
|
||||
messages.append({"role": "assistant", "content": response.content})
|
||||
tool_results = []
|
||||
for tool_block in tool_use_blocks:
|
||||
result = await _execute_tool(db, user_id, tool_block.name, tool_block.input)
|
||||
tool_results.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_block.id,
|
||||
"content": result,
|
||||
})
|
||||
yield _sse_event("tool_result", {"tool": tool_block.name, "result": result})
|
||||
messages.append({"role": "user", "content": tool_results})
|
||||
|
||||
# Get final message for metadata
|
||||
final_message = await stream.get_final_message()
|
||||
metadata = {
|
||||
"model": final_message.model,
|
||||
"input_tokens": final_message.usage.input_tokens,
|
||||
"output_tokens": final_message.usage.output_tokens,
|
||||
"model": response.model,
|
||||
"input_tokens": response.usage.input_tokens,
|
||||
"output_tokens": response.usage.output_tokens,
|
||||
}
|
||||
|
||||
# Save assistant message
|
||||
@@ -109,7 +241,6 @@ async def stream_ai_response(
|
||||
)
|
||||
assistant_count = len(result.scalars().all())
|
||||
if assistant_count == 1 and chat.title == "New Chat":
|
||||
# Auto-generate title from first few words
|
||||
title = full_content[:50].split("\n")[0].strip()
|
||||
if len(title) > 40:
|
||||
title = title[:40] + "..."
|
||||
|
||||
97
backend/app/services/document_service.py
Normal file
97
backend/app/services/document_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.document import Document
|
||||
from app.utils.file_storage import delete_file
|
||||
|
||||
|
||||
async def create_document(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
filename: str,
|
||||
original_filename: str,
|
||||
storage_path: str,
|
||||
mime_type: str,
|
||||
file_size: int,
|
||||
doc_type: str = "other",
|
||||
) -> Document:
|
||||
doc = Document(
|
||||
user_id=user_id,
|
||||
filename=filename,
|
||||
original_filename=original_filename,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
file_size=file_size,
|
||||
doc_type=doc_type,
|
||||
processing_status="pending",
|
||||
)
|
||||
db.add(doc)
|
||||
await db.flush()
|
||||
return doc
|
||||
|
||||
|
||||
async def get_document(db: AsyncSession, doc_id: uuid.UUID, user_id: uuid.UUID) -> Document:
|
||||
result = await db.execute(
|
||||
select(Document).where(Document.id == doc_id, Document.user_id == user_id)
|
||||
)
|
||||
doc = result.scalar_one_or_none()
|
||||
if not doc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")
|
||||
return doc
|
||||
|
||||
|
||||
async def get_user_documents(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
doc_type: str | None = None,
|
||||
processing_status: str | None = None,
|
||||
) -> list[Document]:
|
||||
stmt = select(Document).where(Document.user_id == user_id)
|
||||
if doc_type:
|
||||
stmt = stmt.where(Document.doc_type == doc_type)
|
||||
if processing_status:
|
||||
stmt = stmt.where(Document.processing_status == processing_status)
|
||||
stmt = stmt.order_by(Document.created_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def delete_document(db: AsyncSession, doc_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
doc = await get_document(db, doc_id, user_id)
|
||||
storage_path = doc.storage_path
|
||||
await db.delete(doc)
|
||||
await db.flush()
|
||||
delete_file(storage_path)
|
||||
|
||||
|
||||
async def search_documents(db: AsyncSession, user_id: uuid.UUID, query: str, limit: int = 5) -> list[Document]:
|
||||
stmt = (
|
||||
select(Document)
|
||||
.where(
|
||||
Document.user_id == user_id,
|
||||
Document.processing_status == "completed",
|
||||
text("to_tsvector('english', coalesce(extracted_text, '')) @@ plainto_tsquery('english', :query)"),
|
||||
)
|
||||
.params(query=query)
|
||||
.order_by(
|
||||
text("ts_rank(to_tsvector('english', coalesce(extracted_text, '')), plainto_tsquery('english', :query)) DESC")
|
||||
)
|
||||
.params(query=query)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def update_document_text(
|
||||
db: AsyncSession, doc_id: uuid.UUID, extracted_text: str, status_val: str = "completed"
|
||||
) -> None:
|
||||
result = await db.execute(select(Document).where(Document.id == doc_id))
|
||||
doc = result.scalar_one_or_none()
|
||||
if doc:
|
||||
doc.extracted_text = extracted_text
|
||||
doc.processing_status = status_val
|
||||
await db.flush()
|
||||
71
backend/app/services/memory_service.py
Normal file
71
backend/app/services/memory_service.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory_entry import MemoryEntry
|
||||
|
||||
|
||||
async def create_memory(db: AsyncSession, user_id: uuid.UUID, **kwargs) -> MemoryEntry:
|
||||
entry = MemoryEntry(user_id=user_id, **kwargs)
|
||||
db.add(entry)
|
||||
await db.flush()
|
||||
return entry
|
||||
|
||||
|
||||
async def get_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID) -> MemoryEntry:
|
||||
result = await db.execute(
|
||||
select(MemoryEntry).where(MemoryEntry.id == entry_id, MemoryEntry.user_id == user_id)
|
||||
)
|
||||
entry = result.scalar_one_or_none()
|
||||
if not entry:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Memory entry not found")
|
||||
return entry
|
||||
|
||||
|
||||
async def get_user_memories(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
category: str | None = None,
|
||||
importance: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
) -> list[MemoryEntry]:
|
||||
stmt = select(MemoryEntry).where(MemoryEntry.user_id == user_id)
|
||||
if category:
|
||||
stmt = stmt.where(MemoryEntry.category == category)
|
||||
if importance:
|
||||
stmt = stmt.where(MemoryEntry.importance == importance)
|
||||
if is_active is not None:
|
||||
stmt = stmt.where(MemoryEntry.is_active == is_active)
|
||||
stmt = stmt.order_by(MemoryEntry.created_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
ALLOWED_UPDATE_FIELDS = {"category", "title", "content", "importance", "is_active"}
|
||||
|
||||
|
||||
async def update_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID, **kwargs) -> MemoryEntry:
|
||||
entry = await get_memory(db, entry_id, user_id)
|
||||
for key, value in kwargs.items():
|
||||
if key in ALLOWED_UPDATE_FIELDS:
|
||||
setattr(entry, key, value)
|
||||
await db.flush()
|
||||
return entry
|
||||
|
||||
|
||||
async def delete_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
entry = await get_memory(db, entry_id, user_id)
|
||||
await db.delete(entry)
|
||||
|
||||
|
||||
async def get_critical_memories(db: AsyncSession, user_id: uuid.UUID) -> list[MemoryEntry]:
|
||||
result = await db.execute(
|
||||
select(MemoryEntry).where(
|
||||
MemoryEntry.user_id == user_id,
|
||||
MemoryEntry.is_active == True, # noqa: E712
|
||||
MemoryEntry.importance.in_(["critical", "high"]),
|
||||
).order_by(MemoryEntry.importance, MemoryEntry.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
0
backend/app/utils/__init__.py
Normal file
0
backend/app/utils/__init__.py
Normal file
34
backend/app/utils/file_storage.py
Normal file
34
backend/app/utils/file_storage.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _get_upload_dir(user_id: uuid.UUID, doc_id: uuid.UUID) -> Path:
|
||||
path = Path(settings.UPLOAD_DIR) / str(user_id) / str(doc_id)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
async def save_upload(user_id: uuid.UUID, doc_id: uuid.UUID, filename: str, content: bytes) -> str:
|
||||
directory = _get_upload_dir(user_id, doc_id)
|
||||
file_path = directory / filename
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
return str(file_path)
|
||||
|
||||
|
||||
def get_file_path(storage_path: str) -> Path:
|
||||
return Path(storage_path)
|
||||
|
||||
|
||||
def delete_file(storage_path: str) -> None:
|
||||
path = Path(storage_path)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
# Clean up empty parent dirs
|
||||
parent = path.parent
|
||||
if parent.exists() and not any(parent.iterdir()):
|
||||
parent.rmdir()
|
||||
19
backend/app/utils/text_extraction.py
Normal file
19
backend/app/utils/text_extraction.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def extract_text_from_pdf(file_path: str) -> str:
|
||||
import fitz # PyMuPDF
|
||||
|
||||
text_parts = []
|
||||
with fitz.open(file_path) as doc:
|
||||
for page in doc:
|
||||
text_parts.append(page.get_text())
|
||||
return "\n".join(text_parts).strip()
|
||||
|
||||
|
||||
def extract_text(file_path: str, mime_type: str) -> str:
|
||||
if mime_type == "application/pdf":
|
||||
return extract_text_from_pdf(file_path)
|
||||
# For images, we'd use pytesseract but skip for now as it requires system deps
|
||||
# For other types, return empty
|
||||
return ""
|
||||
0
backend/app/workers/__init__.py
Normal file
0
backend/app/workers/__init__.py
Normal file
32
backend/app/workers/document_processor.py
Normal file
32
backend/app/workers/document_processor.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import uuid
|
||||
|
||||
from app.database import async_session_factory
|
||||
from app.services.document_service import update_document_text
|
||||
from app.utils.text_extraction import extract_text
|
||||
|
||||
|
||||
async def process_document(doc_id: uuid.UUID, storage_path: str, mime_type: str) -> None:
|
||||
"""Background task: extract text from uploaded document."""
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
# Update status to processing
|
||||
from sqlalchemy import select
|
||||
from app.models.document import Document
|
||||
|
||||
result = await db.execute(select(Document).where(Document.id == doc_id))
|
||||
doc = result.scalar_one_or_none()
|
||||
if not doc:
|
||||
return
|
||||
doc.processing_status = "processing"
|
||||
await db.commit()
|
||||
|
||||
# Extract text
|
||||
text = extract_text(storage_path, mime_type)
|
||||
|
||||
# Update with extracted text
|
||||
await update_document_text(db, doc_id, text, "completed" if text else "failed")
|
||||
await db.commit()
|
||||
|
||||
except Exception:
|
||||
await update_document_text(db, doc_id, "", "failed")
|
||||
await db.commit()
|
||||
@@ -16,6 +16,8 @@ dependencies = [
|
||||
"python-multipart>=0.0.9",
|
||||
"httpx>=0.27.0",
|
||||
"anthropic>=0.40.0",
|
||||
"pymupdf>=1.24.0",
|
||||
"aiofiles>=24.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
97
backend/tests/test_documents.py
Normal file
97
backend/tests/test_documents.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import io
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client: AsyncClient):
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "docuser@example.com",
|
||||
"username": "docuser",
|
||||
"password": "testpass123",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
return {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
|
||||
async def test_upload_document(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/?doc_type=lab_result",
|
||||
headers=auth_headers,
|
||||
files={"file": ("test.pdf", b"%PDF-1.4 test content", "application/pdf")},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["original_filename"] == "test.pdf"
|
||||
assert data["doc_type"] == "lab_result"
|
||||
assert data["processing_status"] == "pending"
|
||||
|
||||
|
||||
async def test_upload_invalid_type(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/",
|
||||
headers=auth_headers,
|
||||
files={"file": ("test.exe", b"MZ...", "application/x-msdownload")},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
async def test_list_documents(client: AsyncClient, auth_headers: dict):
|
||||
# Upload first
|
||||
await client.post(
|
||||
"/api/v1/documents/",
|
||||
headers=auth_headers,
|
||||
files={"file": ("list_test.pdf", b"%PDF-1.4 content", "application/pdf")},
|
||||
)
|
||||
|
||||
resp = await client.get("/api/v1/documents/", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["documents"]) >= 1
|
||||
|
||||
|
||||
async def test_get_document(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/",
|
||||
headers=auth_headers,
|
||||
files={"file": ("get_test.pdf", b"%PDF-1.4 content", "application/pdf")},
|
||||
)
|
||||
doc_id = resp.json()["id"]
|
||||
|
||||
resp = await client.get(f"/api/v1/documents/{doc_id}", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == doc_id
|
||||
|
||||
|
||||
async def test_delete_document(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/",
|
||||
headers=auth_headers,
|
||||
files={"file": ("del_test.pdf", b"%PDF-1.4 content", "application/pdf")},
|
||||
)
|
||||
doc_id = resp.json()["id"]
|
||||
|
||||
resp = await client.delete(f"/api/v1/documents/{doc_id}", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
resp = await client.get(f"/api/v1/documents/{doc_id}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_document_ownership_isolation(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/",
|
||||
headers=auth_headers,
|
||||
files={"file": ("private.pdf", b"%PDF-1.4 content", "application/pdf")},
|
||||
)
|
||||
doc_id = resp.json()["id"]
|
||||
|
||||
# Register another user
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "docother@example.com",
|
||||
"username": "docother",
|
||||
"password": "testpass123",
|
||||
})
|
||||
other_headers = {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
resp = await client.get(f"/api/v1/documents/{doc_id}", headers=other_headers)
|
||||
assert resp.status_code == 404
|
||||
109
backend/tests/test_memory.py
Normal file
109
backend/tests/test_memory.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client: AsyncClient):
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "memuser@example.com",
|
||||
"username": "memuser",
|
||||
"password": "testpass123",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
return {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
|
||||
async def test_create_memory(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post("/api/v1/memory/", json={
|
||||
"category": "condition",
|
||||
"title": "Diabetes Type 2",
|
||||
"content": "Diagnosed in 2024, managed with metformin",
|
||||
"importance": "critical",
|
||||
}, headers=auth_headers)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["category"] == "condition"
|
||||
assert data["title"] == "Diabetes Type 2"
|
||||
assert data["importance"] == "critical"
|
||||
assert data["is_active"] is True
|
||||
|
||||
|
||||
async def test_list_memories(client: AsyncClient, auth_headers: dict):
|
||||
await client.post("/api/v1/memory/", json={
|
||||
"category": "allergy",
|
||||
"title": "Penicillin",
|
||||
"content": "Severe allergic reaction",
|
||||
"importance": "critical",
|
||||
}, headers=auth_headers)
|
||||
|
||||
resp = await client.get("/api/v1/memory/", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["entries"]) >= 1
|
||||
|
||||
|
||||
async def test_filter_by_category(client: AsyncClient, auth_headers: dict):
|
||||
await client.post("/api/v1/memory/", json={
|
||||
"category": "medication",
|
||||
"title": "Metformin",
|
||||
"content": "500mg twice daily",
|
||||
"importance": "high",
|
||||
}, headers=auth_headers)
|
||||
|
||||
resp = await client.get("/api/v1/memory/", params={"category": "medication"}, headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
entries = resp.json()["entries"]
|
||||
assert all(e["category"] == "medication" for e in entries)
|
||||
|
||||
|
||||
async def test_update_memory(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post("/api/v1/memory/", json={
|
||||
"category": "vital",
|
||||
"title": "Blood Pressure",
|
||||
"content": "130/85",
|
||||
"importance": "medium",
|
||||
}, headers=auth_headers)
|
||||
entry_id = resp.json()["id"]
|
||||
|
||||
resp = await client.patch(f"/api/v1/memory/{entry_id}", json={
|
||||
"content": "125/80 (improved)",
|
||||
"importance": "low",
|
||||
}, headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["content"] == "125/80 (improved)"
|
||||
assert resp.json()["importance"] == "low"
|
||||
|
||||
|
||||
async def test_delete_memory(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post("/api/v1/memory/", json={
|
||||
"category": "other",
|
||||
"title": "To Delete",
|
||||
"content": "Test",
|
||||
"importance": "low",
|
||||
}, headers=auth_headers)
|
||||
entry_id = resp.json()["id"]
|
||||
|
||||
resp = await client.delete(f"/api/v1/memory/{entry_id}", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
resp = await client.get(f"/api/v1/memory/{entry_id}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_memory_ownership_isolation(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post("/api/v1/memory/", json={
|
||||
"category": "condition",
|
||||
"title": "Private Info",
|
||||
"content": "Sensitive",
|
||||
"importance": "critical",
|
||||
}, headers=auth_headers)
|
||||
entry_id = resp.json()["id"]
|
||||
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "memother@example.com",
|
||||
"username": "memother",
|
||||
"password": "testpass123",
|
||||
})
|
||||
other_headers = {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
resp = await client.get(f"/api/v1/memory/{entry_id}", headers=other_headers)
|
||||
assert resp.status_code == 404
|
||||
Reference in New Issue
Block a user