import json
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone

from anthropic import AsyncAnthropic
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.chat import Chat
from app.models.message import Message
from app.models.skill import Skill
from app.services.context_service import DEFAULT_SYSTEM_PROMPT, get_primary_context, get_personal_context
from app.services.chat_service import get_chat, save_message
from app.services.memory_service import get_critical_memories, create_memory, get_user_memories
from app.services.document_service import search_documents
from app.services.notification_service import create_notification
from app.services.ws_manager import manager

logger = logging.getLogger(__name__)

client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)

# --- AI Tool Definitions ---
AI_TOOLS = [
    {
        "name": "save_memory",
        "description": (
            "Save important information to the user's memory. "
            "Use this when the user shares critical personal data, facts, preferences, "
            "or key details they want to remember across conversations."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "category": {
                    "type": "string",
                    "enum": ["health", "finance", "personal", "work", "document_summary", "other"],
                    "description": "Category of the memory entry",
                },
                "title": {"type": "string", "description": "Short title for the memory entry"},
                "content": {"type": "string", "description": "Detailed content of the memory entry"},
                "importance": {
                    "type": "string",
                    "enum": ["critical", "high", "medium", "low"],
                    "description": "Importance level",
                },
            },
            "required": ["category", "title", "content", "importance"],
        },
    },
    {
        "name": "search_documents",
        "description": (
            "Search the user's uploaded documents for relevant information. "
            "Use this when you need to find specific records, files, or notes the user has uploaded."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query to find relevant documents"},
            },
            "required": ["query"],
        },
    },
    {
        "name": "get_memory",
        "description": (
            "Retrieve the user's stored memories filtered by category. "
            "Use this to recall previously saved information."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "category": {
                    "type": "string",
                    "enum": ["health", "finance", "personal", "work", "document_summary", "other"],
                    "description": "Optional category filter. Omit to get all memories.",
                },
            },
            "required": [],
        },
    },
    {
        "name": "schedule_notification",
        "description": (
            "Schedule a notification or reminder for the user. "
            "Can be immediate or scheduled for a future time."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "title": {"type": "string", "description": "Notification title"},
                "body": {"type": "string", "description": "Notification body text"},
                "scheduled_at": {
                    "type": "string",
                    "description": "ISO 8601 datetime for scheduled delivery. Omit for immediate.",
                },
                "type": {
                    "type": "string",
                    "enum": ["reminder", "alert", "info"],
                    "description": "Notification type",
                    "default": "reminder",
                },
            },
            "required": ["title", "body"],
        },
    },
    {
        "name": "generate_pdf",
        "description": (
            "Generate a PDF report compilation from the user's data. "
            "Use this when the user asks for a document or summary of their stored information."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "title": {"type": "string", "description": "Title for the PDF report"},
            },
            "required": ["title"],
        },
    },
]

# Required input fields per tool, derived from the schemas above so the
# validation in _execute_tool cannot drift from the tool definitions.
_REQUIRED_TOOL_FIELDS = {
    tool["name"]: tuple(tool["input_schema"].get("required", ()))
    for tool in AI_TOOLS
}


async def _execute_tool(
    db: AsyncSession, user_id: uuid.UUID, tool_name: str, tool_input: dict
) -> str:
    """Execute an AI tool and return the result as a JSON string.

    Args:
        db: Session used for the tool's database work; committed by the
            branches that create rows.
        user_id: Owner of the data the tool reads or writes.
        tool_name: One of the names declared in ``AI_TOOLS``.
        tool_input: Input dict produced by the model for this tool call.

    Returns:
        A JSON-encoded result. Unknown tools, missing required fields and an
        unparseable ``scheduled_at`` yield ``{"error": ...}`` instead of
        raising, so the model can see the problem and correct its call.
    """
    # The model's input is not guaranteed to satisfy the schema; report missing
    # fields back to it rather than letting a KeyError abort the whole stream.
    missing = [
        field
        for field in _REQUIRED_TOOL_FIELDS.get(tool_name, ())
        if field not in tool_input
    ]
    if missing:
        return json.dumps(
            {"error": f"Missing required field(s) for {tool_name}: {', '.join(missing)}"}
        )

    if tool_name == "save_memory":
        entry = await create_memory(
            db,
            user_id,
            category=tool_input["category"],
            title=tool_input["title"],
            content=tool_input["content"],
            importance=tool_input["importance"],
        )
        await db.commit()
        return json.dumps({"status": "saved", "id": str(entry.id), "title": entry.title})

    elif tool_name == "search_documents":
        docs = await search_documents(db, user_id, tool_input["query"], limit=5)
        results = [
            {
                "filename": doc.original_filename,
                "doc_type": doc.doc_type,
                # Cap each excerpt so the tool result stays small.
                "excerpt": (doc.extracted_text or "")[:1000],
            }
            for doc in docs
        ]
        return json.dumps({"results": results, "count": len(results)})

    elif tool_name == "get_memory":
        category = tool_input.get("category")
        entries = await get_user_memories(db, user_id, category=category, is_active=True)
        items = [
            {
                "category": e.category,
                "title": e.title,
                "content": e.content,
                "importance": e.importance,
            }
            for e in entries
        ]
        return json.dumps({"entries": items, "count": len(items)})

    elif tool_name == "schedule_notification":
        scheduled_at = None
        raw_scheduled_at = tool_input.get("scheduled_at")
        if raw_scheduled_at:
            try:
                # datetime.fromisoformat() only accepts a trailing "Z" from
                # Python 3.11 on; normalise it so older runtimes parse it too.
                if isinstance(raw_scheduled_at, str) and raw_scheduled_at.endswith("Z"):
                    raw_scheduled_at = raw_scheduled_at[:-1] + "+00:00"
                scheduled_at = datetime.fromisoformat(raw_scheduled_at)
            except (TypeError, ValueError):
                return json.dumps(
                    {"error": f"Invalid scheduled_at (expected ISO 8601): {tool_input['scheduled_at']!r}"}
                )
            if scheduled_at.tzinfo is None:
                # Naive timestamps are interpreted as UTC.
                scheduled_at = scheduled_at.replace(tzinfo=timezone.utc)
            if scheduled_at <= datetime.now(timezone.utc):
                scheduled_at = None  # Past date → send immediately

        notif = await create_notification(
            db,
            user_id,
            title=tool_input["title"],
            body=tool_input["body"],
            type=tool_input.get("type", "reminder"),
            scheduled_at=scheduled_at,
        )
        await db.commit()

        # Push immediately if not scheduled
        if not scheduled_at:
            # Imported lazily — presumably to avoid a circular import with the
            # worker module; confirm before hoisting to module level.
            from app.workers.notification_sender import _serialize_notification

            await manager.send_to_user(user_id, {
                "type": "new_notification",
                "notification": _serialize_notification(notif),
            })

        return json.dumps({
            "status": "scheduled" if scheduled_at else "sent",
            "id": str(notif.id),
            "title": notif.title,
        })

    elif tool_name == "generate_pdf":
        # Imported lazily, as in the original; see note above.
        from app.services.pdf_service import generate_pdf_report

        pdf = await generate_pdf_report(db, user_id, title=tool_input["title"])
        await db.commit()
        return json.dumps({
            "status": "generated",
            "id": str(pdf.id),
            "title": pdf.title,
            "download_url": f"/api/v1/pdf/{pdf.id}/download",
        })

    return json.dumps({"error": f"Unknown tool: {tool_name}"})


async def assemble_context(
    db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
) -> tuple[str, list[dict]]:
    """Assemble system prompt and messages for Claude API.

    The system prompt is built from, in order: the primary context (or
    ``DEFAULT_SYSTEM_PROMPT``), the user's personal context, the chat's active
    skill prompt, critical memories, and document excerpts matching
    ``user_message``. The message list is the stored user/assistant history of
    the chat, ending with ``user_message``.

    Args:
        db: Session used for all lookups.
        chat_id: Chat whose history is loaded.
        user_id: Owner of the chat, context, memories and documents.
        user_message: The message currently being answered.

    Returns:
        ``(system_prompt, messages)`` ready to pass to ``client.messages.create``.
    """
    system_parts = []

    # 1. Primary context
    ctx = await get_primary_context(db)
    system_parts.append(ctx.content if ctx and ctx.content.strip() else DEFAULT_SYSTEM_PROMPT)

    # 2. Personal context
    personal_ctx = await get_personal_context(db, user_id)
    if personal_ctx and personal_ctx.content.strip():
        system_parts.append(f"---\nUser Context:\n{personal_ctx.content}")

    # 3. Active skill system prompt
    chat = await get_chat(db, chat_id, user_id)
    if chat.skill_id:
        result = await db.execute(select(Skill).where(Skill.id == chat.skill_id))
        skill = result.scalar_one_or_none()
        if skill and skill.is_active:
            system_parts.append(f"---\nSpecialist Role ({skill.name}):\n{skill.system_prompt}")

    # 4. Critical memory entries
    memories = await get_critical_memories(db, user_id)
    if memories:
        memory_lines = [f"- [{m.category}] {m.title}: {m.content}" for m in memories]
        system_parts.append("---\nUser Profile (Key Information):\n" + "\n".join(memory_lines))

    # 5. Relevant document excerpts (based on user message keywords)
    if user_message.strip():
        docs = await search_documents(db, user_id, user_message, limit=3)
        if docs:
            doc_lines = [
                f"[{d.original_filename} ({d.doc_type})]\n{(d.extracted_text or '')[:1500]}"
                for d in docs
            ]
            system_parts.append("---\nRelevant Document Excerpts:\n" + "\n\n".join(doc_lines))

    system_prompt = "\n\n".join(system_parts)

    # 6. Conversation history
    result = await db.execute(
        select(Message)
        .where(Message.chat_id == chat_id, Message.role.in_(["user", "assistant"]))
        .order_by(Message.created_at.asc())
    )
    history = result.scalars().all()
    # Skip stored messages with empty content (e.g. an assistant turn that
    # produced no text): the Messages API rejects empty message content.
    messages = [{"role": msg.role, "content": msg.content} for msg in history if msg.content]

    # 7. Current user message
    # stream_ai_response() persists the user message before calling this
    # function, so it is normally already the last history entry; appending it
    # unconditionally would send the same turn to the model twice.
    last = messages[-1] if messages else None
    already_present = (
        last is not None and last["role"] == "user" and last["content"] == user_message
    )
    if not already_present:
        messages.append({"role": "user", "content": user_message})

    return system_prompt, messages


def _sse_event(event: str, data: dict) -> str:
    """Format one Server-Sent Events frame with a JSON-encoded data payload."""
    return f"event: {event}\ndata: {json.dumps(data)}\n\n"


async def stream_ai_response(
    db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
) -> AsyncGenerator[str, None]:
    """Stream AI response as SSE events, with tool use support.

    Saves the user message, calls the model (up to five tool-use rounds),
    executes requested tools, saves the assistant's text and finally sets the
    chat title after the first exchange.

    Yields SSE frames in this order: ``message_start``, then any number of
    ``content_delta`` / ``tool_use`` / ``tool_result``, then ``message_end``.
    Any exception raised after the user message is saved is logged and
    reported as a single ``error`` frame instead of propagating.

    Args:
        db: Session used for persistence and tool execution.
        chat_id: Chat the exchange belongs to.
        user_id: Owner of the chat; ``get_chat`` is called with it first.
        user_message: Text of the user's new message.
    """
    # Verify ownership
    chat = await get_chat(db, chat_id, user_id)

    # Save user message
    await save_message(db, chat_id, "user", user_message)
    await db.commit()

    try:
        # Assemble context
        system_prompt, messages = await assemble_context(db, chat_id, user_id, user_message)

        assistant_msg_id = str(uuid.uuid4())
        yield _sse_event("message_start", {"message_id": assistant_msg_id})

        # Tool use loop
        content_parts: list[str] = []
        input_tokens = 0
        output_tokens = 0
        max_tool_rounds = 5

        for _ in range(max_tool_rounds):
            response = await client.messages.create(
                model=settings.CLAUDE_MODEL,
                max_tokens=4096,
                system=system_prompt,
                messages=messages,
                tools=AI_TOOLS,
            )
            # Every round is a separate API call; add up usage so the stored
            # metadata reflects the whole exchange, not just the last round.
            input_tokens += response.usage.input_tokens
            output_tokens += response.usage.output_tokens

            # Process content blocks
            tool_use_blocks = []
            for block in response.content:
                if block.type == "text":
                    content_parts.append(block.text)
                    yield _sse_event("content_delta", {"delta": block.text})
                elif block.type == "tool_use":
                    tool_use_blocks.append(block)
                    yield _sse_event("tool_use", {"tool": block.name, "input": block.input})

            # If no tool use, we're done
            if response.stop_reason != "tool_use" or not tool_use_blocks:
                break

            # Execute tools and continue conversation
            messages.append({"role": "assistant", "content": response.content})
            tool_results = []
            for tool_block in tool_use_blocks:
                result = await _execute_tool(db, user_id, tool_block.name, tool_block.input)
                tool_results.append({
                    "type": "tool_result",
                    "tool_use_id": tool_block.id,
                    "content": result,
                })
                yield _sse_event("tool_result", {"tool": tool_block.name, "result": result})
            messages.append({"role": "user", "content": tool_results})

        full_content = "".join(content_parts)
        metadata = {
            "model": response.model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

        # Save assistant message
        saved_msg = await save_message(db, chat_id, "assistant", full_content, metadata)
        await db.commit()

        # Update chat title if first exchange
        # Count in the database instead of loading every assistant row.
        result = await db.execute(
            select(func.count())
            .select_from(Message)
            .where(Message.chat_id == chat_id, Message.role == "assistant")
        )
        assistant_count = result.scalar_one()
        if assistant_count == 1 and chat.title == "New Chat":
            # Use the first non-empty line so a reply that starts with a blank
            # line (or has no text at all) does not blank out the title.
            first_line = next(
                (line.strip() for line in full_content.splitlines() if line.strip()),
                "",
            )
            title = first_line[:50].strip()
            if len(title) > 40:
                title = title[:40] + "..."
            if title:
                chat.title = title
                await db.commit()

        yield _sse_event("message_end", {
            "message_id": str(saved_msg.id),
            "metadata": metadata,
        })

    except Exception as e:
        # Boundary handler: the client only ever sees an SSE error frame, so
        # log the traceback here or the failure is invisible server-side.
        logger.exception("AI response streaming failed for chat %s", chat_id)
        yield _sse_event("error", {"detail": str(e)})