import uuid

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import decode_access_token
from app.database import async_session_factory
from app.models.user import User
from app.services.ws_manager import manager
from app.services.notification_service import get_unread_count

router = APIRouter(tags=["websocket"])


async def _authenticate_ws(token: str) -> uuid.UUID | None:
    """Resolve an access token to the id of an active user.

    Args:
        token: Access token supplied by the WebSocket client.

    Returns:
        The user's UUID when the token carries a well-formed ``sub`` claim
        that matches an active ``User`` row; ``None`` otherwise.
    """
    payload = decode_access_token(token)
    # NOTE(review): decode_access_token is defined elsewhere; this guard
    # covers the case where it signals an invalid token with a falsy value
    # instead of raising -- confirm against app.core.security.
    if not payload:
        return None

    subject = payload.get("sub")
    # uuid.UUID() raises AttributeError/TypeError (not ValueError) for
    # non-string input, so reject anything that is not a non-empty string
    # before attempting to parse it.
    if not subject or not isinstance(subject, str):
        return None

    try:
        uid = uuid.UUID(subject)
    except ValueError:
        return None

    async with async_session_factory() as db:
        result = await db.execute(
            select(User).where(User.id == uid, User.is_active == True)  # noqa: E712
        )
        user = result.scalar_one_or_none()

    if not user:
        return None
    return uid


@router.websocket("/ws/notifications")
async def ws_notifications(websocket: WebSocket, token: str = Query(...)):
    """Serve the notifications WebSocket for an authenticated user.

    The client passes its access token as the ``token`` query parameter.
    Unauthenticated clients are closed with code 4001. Authenticated clients
    are registered with the connection manager, receive one ``unread_count``
    message, and then stay connected until they disconnect.

    Args:
        websocket: The incoming WebSocket connection.
        token: Access token taken from the query string.
    """
    user_id = await _authenticate_ws(token)
    if not user_id:
        # 4001 lies in the 4000-4999 range that RFC 6455 reserves for
        # application-defined close codes.
        await websocket.close(code=4001, reason="Unauthorized")
        return

    # NOTE(review): manager.connect presumably accepts the handshake and
    # registers the socket for this user -- confirm in ws_manager.
    await manager.connect(user_id, websocket)
    try:
        # Send initial unread count
        async with async_session_factory() as db:
            count = await get_unread_count(db, user_id)
        await websocket.send_json({"type": "unread_count", "count": count})

        # Keep alive - incoming text is read and discarded; the loop ends
        # when receive_text raises WebSocketDisconnect.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister, whether the client left or an error occurred.
        manager.disconnect(user_id, websocket)