"""CRUD helpers for per-user ``MemoryEntry`` rows on an async SQLAlchemy session.

None of these helpers commit; they only add, flush, query or delete through
the session they are given.
"""

import uuid

from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.memory_entry import MemoryEntry


async def create_memory(db: AsyncSession, user_id: uuid.UUID, **kwargs) -> MemoryEntry:
    """Create a ``MemoryEntry`` for *user_id* and flush it to the session.

    Args:
        db: Session the new entry is added to.
        user_id: Owner of the new entry.
        **kwargs: Forwarded unchanged to the ``MemoryEntry`` constructor.

    Returns:
        The newly added entry, after ``db.flush()`` has been awaited.
    """
    new_entry = MemoryEntry(user_id=user_id, **kwargs)
    db.add(new_entry)
    await db.flush()
    return new_entry


async def get_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID) -> MemoryEntry:
    """Fetch one entry by id, restricted to rows owned by *user_id*.

    Args:
        db: Session used to run the query.
        entry_id: Primary identifier of the entry to load.
        user_id: Owner the entry must belong to.

    Returns:
        The matching entry.

    Raises:
        HTTPException: 404 when no row matches both the id and the owner.
    """
    lookup = select(MemoryEntry).where(
        MemoryEntry.id == entry_id,
        MemoryEntry.user_id == user_id,
    )
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Memory entry not found",
    )


async def get_user_memories(
    db: AsyncSession,
    user_id: uuid.UUID,
    category: str | None = None,
    importance: str | None = None,
    is_active: bool | None = None,
) -> list[MemoryEntry]:
    """List a user's entries, newest first, with optional filters.

    Args:
        db: Session used to run the query.
        user_id: Owner whose entries are listed.
        category: When truthy, keep only entries with this category.
        importance: When truthy, keep only entries with this importance.
        is_active: When not ``None``, keep only entries with this flag value.

    Returns:
        Matching entries ordered by ``created_at`` descending.
    """
    # The owner filter always applies; the rest are appended only when requested.
    filters = [MemoryEntry.user_id == user_id]
    if category:
        filters.append(MemoryEntry.category == category)
    if importance:
        filters.append(MemoryEntry.importance == importance)
    # ``False`` is a meaningful filter value, hence the explicit ``None`` test.
    if is_active is not None:
        filters.append(MemoryEntry.is_active == is_active)

    query = select(MemoryEntry).where(*filters).order_by(MemoryEntry.created_at.desc())
    rows = await db.execute(query)
    return list(rows.scalars().all())


ALLOWED_UPDATE_FIELDS = {"category", "title", "content", "importance", "is_active"}


async def update_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID, **kwargs) -> MemoryEntry:
    """Apply whitelisted field changes to one of the user's entries.

    Args:
        db: Session used to load and flush the entry.
        entry_id: Identifier of the entry to modify.
        user_id: Owner the entry must belong to.
        **kwargs: Candidate field values; keys outside
            ``ALLOWED_UPDATE_FIELDS`` are silently skipped.

    Returns:
        The entry after the accepted fields were set and the session flushed.

    Raises:
        HTTPException: 404 (from ``get_memory``) when the entry is not found.
    """
    target = await get_memory(db, entry_id, user_id)
    for field, new_value in kwargs.items():
        if field not in ALLOWED_UPDATE_FIELDS:
            continue
        setattr(target, field, new_value)
    await db.flush()
    return target


async def delete_memory(db: AsyncSession, entry_id: uuid.UUID, user_id: uuid.UUID) -> None:
    """Delete one of the user's entries through the session.

    Args:
        db: Session used to load and delete the entry.
        entry_id: Identifier of the entry to remove.
        user_id: Owner the entry must belong to.

    Raises:
        HTTPException: 404 (from ``get_memory``) when the entry is not found.
    """
    # No flush is issued here, unlike create_memory and update_memory.
    doomed = await get_memory(db, entry_id, user_id)
    await db.delete(doomed)


async def get_critical_memories(db: AsyncSession, user_id: uuid.UUID) -> list[MemoryEntry]:
    """Return the user's active entries whose importance is critical or high.

    Args:
        db: Session used to run the query.
        user_id: Owner whose entries are listed.

    Returns:
        Matching entries ordered by ``importance`` ascending, then by
        ``created_at`` descending.
    """
    # NOTE(review): the sort uses the raw importance column; whether "critical"
    # sorts ahead of "high" depends on the column type - confirm against the model.
    query = (
        select(MemoryEntry)
        .where(
            MemoryEntry.user_id == user_id,
            MemoryEntry.is_active == True,  # noqa: E712
            MemoryEntry.importance.in_(["critical", "high"]),
        )
        .order_by(MemoryEntry.importance, MemoryEntry.created_at.desc())
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())