import uuid

from fastapi import HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.chat import Chat
from app.models.message import Message
from app.models.user import User


async def create_chat(
    db: AsyncSession,
    user: User,
    title: str | None = None,
    skill_id: uuid.UUID | None = None,
) -> Chat:
    """Create a new chat owned by ``user`` and flush it to the session.

    Non-admin users may hold at most ``user.max_chats`` non-archived chats.
    This function flushes but does not commit.

    Args:
        db: Active async session.
        user: Owner of the new chat.
        title: Chat title; a falsy value (``None`` or ``""``) becomes "New Chat".
        skill_id: Optional skill to attach to the chat.

    Returns:
        The newly created, flushed ``Chat``.

    Raises:
        HTTPException: 403 when a non-admin user already has ``max_chats``
            non-archived chats.
    """
    # Admins are exempt from the limit, so only count chats for everyone else.
    if user.role != "admin":
        active_count = await db.scalar(
            select(func.count())
            .select_from(Chat)
            .where(
                Chat.user_id == user.id,
                Chat.is_archived == False,  # noqa: E712
            )
        )
        if active_count >= user.max_chats:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Chat limit reached. Archive or delete existing chats.",
            )

    chat = Chat(user_id=user.id, title=title or "New Chat", skill_id=skill_id)
    db.add(chat)
    # Flush so the caller can read server/default-generated columns (e.g. id).
    await db.flush()
    return chat


async def get_user_chats(
    db: AsyncSession, user_id: uuid.UUID, archived: bool | None = None
) -> list[Chat]:
    """Return the chats of ``user_id``, most recently updated first.

    Args:
        db: Active async session.
        user_id: Owner whose chats are listed.
        archived: ``True``/``False`` to filter on the archived flag;
            ``None`` returns both archived and active chats.
    """
    stmt = select(Chat).where(Chat.user_id == user_id)
    if archived is not None:
        stmt = stmt.where(Chat.is_archived == archived)
    stmt = stmt.order_by(Chat.updated_at.desc())
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def get_chat(db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID) -> Chat:
    """Load a chat, enforcing that it belongs to ``user_id``.

    Raises:
        HTTPException: 404 when the chat does not exist or is owned by another
            user (both cases look identical to the caller, so ownership of
            other users' chats is not revealed).
    """
    result = await db.execute(
        select(Chat).where(Chat.id == chat_id, Chat.user_id == user_id)
    )
    chat = result.scalar_one_or_none()
    if chat is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
    return chat


async def update_chat(
    db: AsyncSession,
    chat_id: uuid.UUID,
    user_id: uuid.UUID,
    title: str | None = None,
    is_archived: bool | None = None,
    skill_id: uuid.UUID | None = None,
) -> Chat:
    """Apply a partial update to a chat owned by ``user_id``.

    Each field left as ``None`` is not modified. As a consequence,
    ``skill_id`` cannot be cleared back to ``None`` through this function.

    Raises:
        HTTPException: 404 if the chat is missing or not owned by ``user_id``.
    """
    chat = await get_chat(db, chat_id, user_id)
    if title is not None:
        chat.title = title
    if is_archived is not None:
        chat.is_archived = is_archived
    if skill_id is not None:
        chat.skill_id = skill_id
    await db.flush()
    return chat


async def delete_chat(db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Delete a chat owned by ``user_id`` (no commit is issued here).

    Raises:
        HTTPException: 404 if the chat is missing or not owned by ``user_id``.
    """
    chat = await get_chat(db, chat_id, user_id)
    await db.delete(chat)


async def get_messages(
    db: AsyncSession,
    chat_id: uuid.UUID,
    user_id: uuid.UUID,
    limit: int = 50,
    before: uuid.UUID | None = None,
) -> list[Message]:
    """Return one page of a chat's messages in chronological order.

    The page holds the newest ``limit`` messages; when ``before`` is given,
    only messages created strictly earlier than that message are considered,
    so passing the oldest id of the current page yields the previous page.
    Messages sharing the cursor's exact ``created_at`` are excluded by the
    strict comparison.

    Args:
        db: Active async session.
        chat_id: Chat whose messages are read.
        user_id: Caller; must own the chat.
        limit: Maximum number of messages to return.
        before: Id of a message in this chat used as a paging cursor. An
            unknown id, or one belonging to a different chat, is ignored.

    Raises:
        HTTPException: 404 if the chat is missing or not owned by ``user_id``.
    """
    # Ownership check; raises 404 for chats the caller cannot see.
    await get_chat(db, chat_id, user_id)

    stmt = select(Message).where(Message.chat_id == chat_id)
    if before is not None:
        cursor = await db.get(Message, before)
        # Only honour a cursor from this chat; a foreign message's timestamp
        # must not drive the filter.
        if cursor is not None and cursor.chat_id == chat_id:
            stmt = stmt.where(Message.created_at < cursor.created_at)

    # LIMIT must apply to the newest rows, otherwise the oldest page is
    # always returned; sort descending, then flip back to chronological.
    stmt = stmt.order_by(Message.created_at.desc()).limit(limit)
    result = await db.execute(stmt)
    messages = list(result.scalars().all())
    messages.reverse()
    return messages


async def save_message(
    db: AsyncSession,
    chat_id: uuid.UUID,
    role: str,
    content: str,
    metadata: dict | None = None,
) -> Message:
    """Add a message to ``chat_id`` and flush it.

    No ownership check is performed here; callers are expected to have
    validated access to the chat. This function does not touch
    ``Chat.updated_at`` itself.
    NOTE(review): whether the chat list re-sorts after a new message depends
    on model/DB configuration not visible here — confirm.

    Args:
        db: Active async session.
        chat_id: Target chat.
        role: Message role string, stored as given.
        content: Message body.
        metadata: Optional extra data, stored in the model's ``metadata_``
            attribute (presumably named so to avoid SQLAlchemy's reserved
            declarative ``metadata`` attribute).

    Returns:
        The flushed ``Message``.
    """
    message = Message(chat_id=chat_id, role=role, content=content, metadata_=metadata)
    db.add(message)
    await db.flush()
    return message