Phase 2: Chat & AI Core — Claude API streaming, chat UI, admin context
Backend: - Chat, Message, ContextFile models + Alembic migration - Chat CRUD with per-user limit enforcement (max_chats) - SSE streaming endpoint: saves user message, streams Claude response, saves assistant message with token usage metadata - Context assembly: primary context file + conversation history - Admin context CRUD (GET/PUT with version tracking) - Anthropic SDK integration with async streaming - Chat ownership isolation (users can't access each other's chats) Frontend: - Chat page with sidebar chat list + main chat window - Real-time SSE streaming via fetch + ReadableStream - Message bubbles with Markdown rendering (react-markdown) - Auto-growing message input (Enter to send, Shift+Enter newline) - Zustand chat store for streaming state management - Admin primary context editor with unsaved changes warning - Updated routing: /chat, /chat/:chatId, /admin/context - Enabled Chat and Admin sidebar navigation - English + Russian translations for all new UI Infrastructure: - nginx: disabled proxy buffering for SSE support - Added ANTHROPIC_API_KEY and CLAUDE_MODEL to config Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
"""Create chats, messages, and context_files tables
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chats",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("title", sa.String(255), nullable=False, server_default="New Chat"),
|
||||
sa.Column("skill_id", UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("is_archived", sa.Boolean, nullable=False, server_default=sa.text("false")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"messages",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("chat_id", UUID(as_uuid=True), sa.ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("role", sa.String(20), nullable=False),
|
||||
sa.Column("content", sa.Text, nullable=False),
|
||||
sa.Column("metadata", JSONB, nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"context_files",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("type", sa.String(20), nullable=False),
|
||||
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("content", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("version", sa.Integer, nullable=False, server_default=sa.text("1")),
|
||||
sa.Column("updated_by", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.UniqueConstraint("type", "user_id", name="uq_context_files_type_user"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("context_files")
|
||||
op.drop_table("messages")
|
||||
op.drop_table("chats")
|
||||
33
backend/app/api/v1/admin.py
Normal file
33
backend/app/api/v1/admin.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import require_admin
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.chat import ContextFileResponse, UpdateContextRequest
|
||||
from app.services import context_service
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/context", response_model=ContextFileResponse | None)
|
||||
async def get_primary_context(
|
||||
_admin: Annotated[User, Depends(require_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
ctx = await context_service.get_primary_context(db)
|
||||
if not ctx:
|
||||
return None
|
||||
return ContextFileResponse.model_validate(ctx)
|
||||
|
||||
|
||||
@router.put("/context", response_model=ContextFileResponse)
|
||||
async def update_primary_context(
|
||||
data: UpdateContextRequest,
|
||||
admin: Annotated[User, Depends(require_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
ctx = await context_service.upsert_primary_context(db, data.content, admin.id)
|
||||
return ContextFileResponse.model_validate(ctx)
|
||||
103
backend/app/api/v1/chats.py
Normal file
103
backend/app/api/v1/chats.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.chat import (
|
||||
ChatListResponse,
|
||||
ChatResponse,
|
||||
CreateChatRequest,
|
||||
MessageListResponse,
|
||||
MessageResponse,
|
||||
SendMessageRequest,
|
||||
UpdateChatRequest,
|
||||
)
|
||||
from app.services import chat_service
|
||||
from app.services.ai_service import stream_ai_response
|
||||
|
||||
router = APIRouter(prefix="/chats", tags=["chats"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ChatResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_chat(
|
||||
data: CreateChatRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
chat = await chat_service.create_chat(db, user, data.title)
|
||||
return ChatResponse.model_validate(chat)
|
||||
|
||||
|
||||
@router.get("/", response_model=ChatListResponse)
|
||||
async def list_chats(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
archived: bool | None = Query(default=None),
|
||||
):
|
||||
chats = await chat_service.get_user_chats(db, user.id, archived)
|
||||
return ChatListResponse(chats=[ChatResponse.model_validate(c) for c in chats])
|
||||
|
||||
|
||||
@router.get("/{chat_id}", response_model=ChatResponse)
|
||||
async def get_chat(
|
||||
chat_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
chat = await chat_service.get_chat(db, chat_id, user.id)
|
||||
return ChatResponse.model_validate(chat)
|
||||
|
||||
|
||||
@router.patch("/{chat_id}", response_model=ChatResponse)
|
||||
async def update_chat(
|
||||
chat_id: uuid.UUID,
|
||||
data: UpdateChatRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
chat = await chat_service.update_chat(db, chat_id, user.id, data.title, data.is_archived)
|
||||
return ChatResponse.model_validate(chat)
|
||||
|
||||
|
||||
@router.delete("/{chat_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_chat(
|
||||
chat_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
await chat_service.delete_chat(db, chat_id, user.id)
|
||||
|
||||
|
||||
@router.get("/{chat_id}/messages", response_model=MessageListResponse)
|
||||
async def list_messages(
|
||||
chat_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
limit: int = Query(default=50, le=200),
|
||||
before: uuid.UUID | None = Query(default=None),
|
||||
):
|
||||
messages = await chat_service.get_messages(db, chat_id, user.id, limit, before)
|
||||
return MessageListResponse(messages=[MessageResponse.model_validate(m) for m in messages])
|
||||
|
||||
|
||||
@router.post("/{chat_id}/messages")
|
||||
async def send_message(
|
||||
chat_id: uuid.UUID,
|
||||
data: SendMessageRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
return StreamingResponse(
|
||||
stream_ai_response(db, chat_id, user.id, data.content),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
@@ -1,10 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1.auth import router as auth_router
|
||||
from app.api.v1.chats import router as chats_router
|
||||
from app.api.v1.admin import router as admin_router
|
||||
|
||||
api_v1_router = APIRouter(prefix="/api/v1")
|
||||
|
||||
api_v1_router.include_router(auth_router)
|
||||
api_v1_router.include_router(chats_router)
|
||||
api_v1_router.include_router(admin_router)
|
||||
|
||||
|
||||
@api_v1_router.get("/health")
|
||||
|
||||
@@ -12,6 +12,9 @@ class Settings(BaseSettings):
|
||||
|
||||
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost", "http://localhost:3000"]
|
||||
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
CLAUDE_MODEL: str = "claude-sonnet-4-20250514"
|
||||
|
||||
FIRST_ADMIN_EMAIL: str = "admin@example.com"
|
||||
FIRST_ADMIN_USERNAME: str = "admin"
|
||||
FIRST_ADMIN_PASSWORD: str = "changeme_admin_password"
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from app.models.user import User
|
||||
from app.models.session import Session
|
||||
from app.models.chat import Chat
|
||||
from app.models.message import Message
|
||||
from app.models.context_file import ContextFile
|
||||
|
||||
__all__ = ["User", "Session"]
|
||||
__all__ = ["User", "Session", "Chat", "Message", "ContextFile"]
|
||||
|
||||
25
backend/app/models/chat.py
Normal file
25
backend/app/models/chat.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chats"
|
||||
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False, default="New Chat")
|
||||
skill_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
|
||||
is_archived: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="chats") # noqa: F821
|
||||
messages: Mapped[list["Message"]] = relationship(back_populates="chat", cascade="all, delete-orphan") # noqa: F821
|
||||
26
backend/app/models/context_file.py
Normal file
26
backend/app/models/context_file.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ContextFile(Base):
|
||||
__tablename__ = "context_files"
|
||||
__table_args__ = (UniqueConstraint("type", "user_id", name="uq_context_files_type_user"),)
|
||||
|
||||
type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
user_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
updated_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
|
||||
)
|
||||
20
backend/app/models/message.py
Normal file
20
backend/app/models/message.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
chat_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("chats.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
metadata_: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
||||
|
||||
chat: Mapped["Chat"] = relationship(back_populates="messages") # noqa: F821
|
||||
@@ -25,3 +25,4 @@ class User(Base):
|
||||
)
|
||||
|
||||
sessions: Mapped[list["Session"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
chats: Mapped[list["Chat"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
|
||||
@@ -28,6 +28,7 @@ class UserResponse(BaseModel):
|
||||
full_name: str | None
|
||||
role: str
|
||||
is_active: bool
|
||||
max_chats: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
62
backend/app/schemas/chat.py
Normal file
62
backend/app/schemas/chat.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateChatRequest(BaseModel):
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class UpdateChatRequest(BaseModel):
|
||||
title: str | None = None
|
||||
is_archived: bool | None = None
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
content: str = Field(min_length=1, max_length=50000)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
user_id: uuid.UUID
|
||||
title: str
|
||||
skill_id: uuid.UUID | None
|
||||
is_archived: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
chats: list[ChatResponse]
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
chat_id: uuid.UUID
|
||||
role: str
|
||||
content: str
|
||||
metadata: dict | None = Field(None, alias="metadata_")
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True, "populate_by_name": True}
|
||||
|
||||
|
||||
class MessageListResponse(BaseModel):
|
||||
messages: list[MessageResponse]
|
||||
|
||||
|
||||
class ContextFileResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
type: str
|
||||
content: str
|
||||
version: int
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UpdateContextRequest(BaseModel):
|
||||
content: str
|
||||
108
backend/app/services/ai_service.py
Normal file
108
backend/app/services/ai_service.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.chat import Chat
|
||||
from app.models.message import Message
|
||||
from app.services.context_service import DEFAULT_SYSTEM_PROMPT, get_primary_context
|
||||
from app.services.chat_service import get_chat, save_message
|
||||
|
||||
client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
|
||||
|
||||
async def assemble_context(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_message: str
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Assemble system prompt and messages for Claude API."""
|
||||
# 1. Primary context
|
||||
ctx = await get_primary_context(db)
|
||||
system_prompt = ctx.content if ctx and ctx.content.strip() else DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
# 2. Conversation history
|
||||
result = await db.execute(
|
||||
select(Message)
|
||||
.where(Message.chat_id == chat_id, Message.role.in_(["user", "assistant"]))
|
||||
.order_by(Message.created_at.asc())
|
||||
)
|
||||
history = result.scalars().all()
|
||||
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in history]
|
||||
|
||||
# 3. Current user message
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
return system_prompt, messages
|
||||
|
||||
|
||||
def _sse_event(event: str, data: dict) -> str:
|
||||
return f"event: {event}\ndata: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
async def stream_ai_response(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream AI response as SSE events."""
|
||||
# Verify ownership
|
||||
chat = await get_chat(db, chat_id, user_id)
|
||||
|
||||
# Save user message
|
||||
await save_message(db, chat_id, "user", user_message)
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
# Assemble context
|
||||
system_prompt, messages = await assemble_context(db, chat_id, user_message)
|
||||
|
||||
# Stream from Claude
|
||||
full_content = ""
|
||||
assistant_msg_id = str(uuid.uuid4())
|
||||
|
||||
yield _sse_event("message_start", {"message_id": assistant_msg_id})
|
||||
|
||||
async with client.messages.stream(
|
||||
model=settings.CLAUDE_MODEL,
|
||||
max_tokens=4096,
|
||||
system=system_prompt,
|
||||
messages=messages,
|
||||
) as stream:
|
||||
async for text in stream.text_stream:
|
||||
full_content += text
|
||||
yield _sse_event("content_delta", {"delta": text})
|
||||
|
||||
# Get final message for metadata
|
||||
final_message = await stream.get_final_message()
|
||||
metadata = {
|
||||
"model": final_message.model,
|
||||
"input_tokens": final_message.usage.input_tokens,
|
||||
"output_tokens": final_message.usage.output_tokens,
|
||||
}
|
||||
|
||||
# Save assistant message
|
||||
saved_msg = await save_message(db, chat_id, "assistant", full_content, metadata)
|
||||
await db.commit()
|
||||
|
||||
# Update chat title if first exchange
|
||||
result = await db.execute(
|
||||
select(Message).where(Message.chat_id == chat_id, Message.role == "assistant")
|
||||
)
|
||||
assistant_count = len(result.scalars().all())
|
||||
if assistant_count == 1 and chat.title == "New Chat":
|
||||
# Auto-generate title from first few words
|
||||
title = full_content[:50].split("\n")[0].strip()
|
||||
if len(title) > 40:
|
||||
title = title[:40] + "..."
|
||||
chat.title = title
|
||||
await db.commit()
|
||||
|
||||
yield _sse_event("message_end", {
|
||||
"message_id": str(saved_msg.id),
|
||||
"metadata": metadata,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
yield _sse_event("error", {"detail": str(e)})
|
||||
93
backend/app/services/chat_service.py
Normal file
93
backend/app/services/chat_service.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.chat import Chat
|
||||
from app.models.message import Message
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
async def create_chat(db: AsyncSession, user: User, title: str | None = None) -> Chat:
|
||||
count = await db.scalar(
|
||||
select(func.count()).select_from(Chat).where(
|
||||
Chat.user_id == user.id, Chat.is_archived == False # noqa: E712
|
||||
)
|
||||
)
|
||||
if user.role != "admin" and count >= user.max_chats:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Chat limit reached. Archive or delete existing chats.",
|
||||
)
|
||||
|
||||
chat = Chat(user_id=user.id, title=title or "New Chat")
|
||||
db.add(chat)
|
||||
await db.flush()
|
||||
return chat
|
||||
|
||||
|
||||
async def get_user_chats(
|
||||
db: AsyncSession, user_id: uuid.UUID, archived: bool | None = None
|
||||
) -> list[Chat]:
|
||||
stmt = select(Chat).where(Chat.user_id == user_id)
|
||||
if archived is not None:
|
||||
stmt = stmt.where(Chat.is_archived == archived)
|
||||
stmt = stmt.order_by(Chat.updated_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_chat(db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID) -> Chat:
|
||||
result = await db.execute(
|
||||
select(Chat).where(Chat.id == chat_id, Chat.user_id == user_id)
|
||||
)
|
||||
chat = result.scalar_one_or_none()
|
||||
if not chat:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
||||
return chat
|
||||
|
||||
|
||||
async def update_chat(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID,
|
||||
title: str | None = None, is_archived: bool | None = None,
|
||||
) -> Chat:
|
||||
chat = await get_chat(db, chat_id, user_id)
|
||||
if title is not None:
|
||||
chat.title = title
|
||||
if is_archived is not None:
|
||||
chat.is_archived = is_archived
|
||||
await db.flush()
|
||||
return chat
|
||||
|
||||
|
||||
async def delete_chat(db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
chat = await get_chat(db, chat_id, user_id)
|
||||
await db.delete(chat)
|
||||
|
||||
|
||||
async def get_messages(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID,
|
||||
limit: int = 50, before: uuid.UUID | None = None,
|
||||
) -> list[Message]:
|
||||
# Verify ownership
|
||||
await get_chat(db, chat_id, user_id)
|
||||
|
||||
stmt = select(Message).where(Message.chat_id == chat_id)
|
||||
if before:
|
||||
before_msg = await db.get(Message, before)
|
||||
if before_msg:
|
||||
stmt = stmt.where(Message.created_at < before_msg.created_at)
|
||||
stmt = stmt.order_by(Message.created_at.asc()).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def save_message(
|
||||
db: AsyncSession, chat_id: uuid.UUID, role: str, content: str,
|
||||
metadata: dict | None = None,
|
||||
) -> Message:
|
||||
message = Message(chat_id=chat_id, role=role, content=content, metadata_=metadata)
|
||||
db.add(message)
|
||||
await db.flush()
|
||||
return message
|
||||
44
backend/app/services/context_service.py
Normal file
44
backend/app/services/context_service.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.context_file import ContextFile
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a personal AI health assistant. Your role is to:
|
||||
- Help users understand their health data and medical documents
|
||||
- Provide health-related recommendations based on uploaded information
|
||||
- Schedule reminders for checkups, medications, and health-related activities
|
||||
- Compile health summaries when requested
|
||||
- Answer health questions clearly and compassionately
|
||||
|
||||
Always be empathetic, accurate, and clear. When uncertain, recommend consulting a healthcare professional.
|
||||
You can communicate in English and Russian based on the user's preference."""
|
||||
|
||||
|
||||
async def get_primary_context(db: AsyncSession) -> ContextFile | None:
|
||||
result = await db.execute(
|
||||
select(ContextFile).where(ContextFile.type == "primary", ContextFile.user_id.is_(None))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def upsert_primary_context(
|
||||
db: AsyncSession, content: str, admin_user_id: uuid.UUID
|
||||
) -> ContextFile:
|
||||
ctx = await get_primary_context(db)
|
||||
if ctx:
|
||||
ctx.content = content
|
||||
ctx.version = ctx.version + 1
|
||||
ctx.updated_by = admin_user_id
|
||||
else:
|
||||
ctx = ContextFile(
|
||||
type="primary",
|
||||
user_id=None,
|
||||
content=content,
|
||||
version=1,
|
||||
updated_by=admin_user_id,
|
||||
)
|
||||
db.add(ctx)
|
||||
await db.flush()
|
||||
return ctx
|
||||
@@ -15,6 +15,7 @@ dependencies = [
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.9",
|
||||
"httpx>=0.27.0",
|
||||
"anthropic>=0.40.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
167
backend/tests/test_chats.py
Normal file
167
backend/tests/test_chats.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(client: AsyncClient):
|
||||
"""Register a user and return auth headers."""
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "chatuser@example.com",
|
||||
"username": "chatuser",
|
||||
"password": "testpass123",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
token = resp.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def admin_headers(client: AsyncClient):
|
||||
"""Register a user, manually set them as admin via the DB."""
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "admin_chat@example.com",
|
||||
"username": "admin_chat",
|
||||
"password": "adminpass123",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
token = resp.json()["access_token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chat_id(client: AsyncClient, auth_headers: dict):
|
||||
"""Create a chat and return its ID."""
|
||||
resp = await client.post("/api/v1/chats/", json={"title": "Test Chat"}, headers=auth_headers)
|
||||
assert resp.status_code == 201
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
# --- Chat CRUD ---
|
||||
|
||||
async def test_create_chat(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post("/api/v1/chats/", json={"title": "My Chat"}, headers=auth_headers)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["title"] == "My Chat"
|
||||
assert data["is_archived"] is False
|
||||
|
||||
|
||||
async def test_create_chat_default_title(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.post("/api/v1/chats/", json={}, headers=auth_headers)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["title"] == "New Chat"
|
||||
|
||||
|
||||
async def test_list_chats(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
resp = await client.get("/api/v1/chats/", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
chats = resp.json()["chats"]
|
||||
assert any(c["id"] == chat_id for c in chats)
|
||||
|
||||
|
||||
async def test_get_chat(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
resp = await client.get(f"/api/v1/chats/{chat_id}", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == chat_id
|
||||
|
||||
|
||||
async def test_update_chat_title(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
resp = await client.patch(
|
||||
f"/api/v1/chats/{chat_id}",
|
||||
json={"title": "Updated Title"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["title"] == "Updated Title"
|
||||
|
||||
|
||||
async def test_archive_chat(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
resp = await client.patch(
|
||||
f"/api/v1/chats/{chat_id}",
|
||||
json={"is_archived": True},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["is_archived"] is True
|
||||
|
||||
|
||||
async def test_delete_chat(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
resp = await client.delete(f"/api/v1/chats/{chat_id}", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
resp = await client.get(f"/api/v1/chats/{chat_id}", headers=auth_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# --- Chat Limit ---
|
||||
|
||||
async def test_chat_limit(client: AsyncClient):
|
||||
"""Create a user with max_chats=2 and verify limit enforcement."""
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "limited@example.com",
|
||||
"username": "limiteduser",
|
||||
"password": "testpass123",
|
||||
})
|
||||
token = resp.json()["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# The default max_chats is 10, create 10 chats
|
||||
for i in range(10):
|
||||
resp = await client.post("/api/v1/chats/", json={"title": f"Chat {i}"}, headers=headers)
|
||||
assert resp.status_code == 201
|
||||
|
||||
# 11th should fail
|
||||
resp = await client.post("/api/v1/chats/", json={"title": "Over limit"}, headers=headers)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# --- Ownership Isolation ---
|
||||
|
||||
async def test_cannot_access_other_users_chat(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
# Register another user
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "other@example.com",
|
||||
"username": "otheruser",
|
||||
"password": "testpass123",
|
||||
})
|
||||
other_token = resp.json()["access_token"]
|
||||
other_headers = {"Authorization": f"Bearer {other_token}"}
|
||||
|
||||
# Try to access first user's chat
|
||||
resp = await client.get(f"/api/v1/chats/{chat_id}", headers=other_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# --- Messages ---
|
||||
|
||||
async def test_get_messages_empty(client: AsyncClient, auth_headers: dict, chat_id: str):
|
||||
resp = await client.get(f"/api/v1/chats/{chat_id}/messages", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["messages"] == []
|
||||
|
||||
|
||||
# --- Admin Context ---
|
||||
|
||||
async def test_get_context_unauthenticated(client: AsyncClient, auth_headers: dict):
|
||||
resp = await client.get("/api/v1/admin/context", headers=auth_headers)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
async def test_admin_context_crud(client: AsyncClient):
|
||||
"""Test context CRUD with a direct DB admin (simplified: register + test endpoint access)."""
|
||||
# Note: This tests the endpoint structure. Full admin test would require
|
||||
# setting the user role to admin in the DB.
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "ctxadmin@example.com",
|
||||
"username": "ctxadmin",
|
||||
"password": "testpass123",
|
||||
})
|
||||
token = resp.json()["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# Regular user should get 403
|
||||
resp = await client.get("/api/v1/admin/context", headers=headers)
|
||||
assert resp.status_code == 403
|
||||
|
||||
resp = await client.put("/api/v1/admin/context", json={"content": "test"}, headers=headers)
|
||||
assert resp.status_code == 403
|
||||
Reference in New Issue
Block a user