Phase 3: Skills & Context — skill system, personal context, context layering
Backend: - Skill model + migration (with FK on chats.skill_id) - Personal + general skill CRUD services with access isolation - Admin skill CRUD endpoints (POST/GET/PATCH/DELETE /admin/skills) - User skill CRUD endpoints (POST/GET/PATCH/DELETE /skills/) - Personal context GET/PUT at /users/me/context - Extended context assembly: primary + personal context + skill prompt - Chat creation/update now accepts skill_id with validation Frontend: - Skill selector dropdown in chat header (grouped: general + personal) - Reusable skill editor form component - Admin skills management page (/admin/skills) - Personal skills page (/skills) - Personal context editor page (/profile/context) - Updated sidebar: Skills, My Context nav items + admin skills link - English + Russian translations for all skill/context UI Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
"""Create skills table and add FK on chats.skill_id
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2026-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
revision: str = "003"
|
||||
down_revision: Union[str, None] = "002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"skills",
|
||||
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=True, index=True),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column("description", sa.Text, nullable=True),
|
||||
sa.Column("system_prompt", sa.Text, nullable=False),
|
||||
sa.Column("icon", sa.String(50), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean, nullable=False, server_default=sa.text("true")),
|
||||
sa.Column("sort_order", sa.Integer, nullable=False, server_default=sa.text("0")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"fk_chats_skill_id",
|
||||
"chats",
|
||||
"skills",
|
||||
["skill_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_chats_skill_id", "chats", type_="foreignkey")
|
||||
op.drop_table("skills")
|
||||
@@ -1,17 +1,26 @@
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import require_admin
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.chat import ContextFileResponse, UpdateContextRequest
|
||||
from app.services import context_service
|
||||
from app.schemas.skill import (
|
||||
CreateSkillRequest,
|
||||
SkillListResponse,
|
||||
SkillResponse,
|
||||
UpdateSkillRequest,
|
||||
)
|
||||
from app.services import context_service, skill_service
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
# --- Context ---
|
||||
|
||||
@router.get("/context", response_model=ContextFileResponse | None)
|
||||
async def get_primary_context(
|
||||
_admin: Annotated[User, Depends(require_admin)],
|
||||
@@ -31,3 +40,44 @@ async def update_primary_context(
|
||||
):
|
||||
ctx = await context_service.upsert_primary_context(db, data.content, admin.id)
|
||||
return ContextFileResponse.model_validate(ctx)
|
||||
|
||||
|
||||
# --- Skills ---
|
||||
|
||||
@router.get("/skills", response_model=SkillListResponse)
|
||||
async def list_general_skills(
|
||||
_admin: Annotated[User, Depends(require_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
skills = await skill_service.get_general_skills(db)
|
||||
return SkillListResponse(skills=[SkillResponse.model_validate(s) for s in skills])
|
||||
|
||||
|
||||
@router.post("/skills", response_model=SkillResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_general_skill(
|
||||
data: CreateSkillRequest,
|
||||
_admin: Annotated[User, Depends(require_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
skill = await skill_service.create_general_skill(db, **data.model_dump())
|
||||
return SkillResponse.model_validate(skill)
|
||||
|
||||
|
||||
@router.patch("/skills/{skill_id}", response_model=SkillResponse)
|
||||
async def update_general_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: UpdateSkillRequest,
|
||||
_admin: Annotated[User, Depends(require_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
skill = await skill_service.update_general_skill(db, skill_id, **data.model_dump(exclude_unset=True))
|
||||
return SkillResponse.model_validate(skill)
|
||||
|
||||
|
||||
@router.delete("/skills/{skill_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_general_skill(
|
||||
skill_id: uuid.UUID,
|
||||
_admin: Annotated[User, Depends(require_admin)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
await skill_service.delete_general_skill(db, skill_id)
|
||||
|
||||
@@ -17,7 +17,7 @@ from app.schemas.chat import (
|
||||
SendMessageRequest,
|
||||
UpdateChatRequest,
|
||||
)
|
||||
from app.services import chat_service
|
||||
from app.services import chat_service, skill_service
|
||||
from app.services.ai_service import stream_ai_response
|
||||
|
||||
router = APIRouter(prefix="/chats", tags=["chats"])
|
||||
@@ -29,7 +29,9 @@ async def create_chat(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
chat = await chat_service.create_chat(db, user, data.title)
|
||||
if data.skill_id:
|
||||
await skill_service.validate_skill_accessible(db, data.skill_id, user.id)
|
||||
chat = await chat_service.create_chat(db, user, data.title, data.skill_id)
|
||||
return ChatResponse.model_validate(chat)
|
||||
|
||||
|
||||
@@ -60,7 +62,9 @@ async def update_chat(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
chat = await chat_service.update_chat(db, chat_id, user.id, data.title, data.is_archived)
|
||||
if data.skill_id:
|
||||
await skill_service.validate_skill_accessible(db, data.skill_id, user.id)
|
||||
chat = await chat_service.update_chat(db, chat_id, user.id, data.title, data.is_archived, data.skill_id)
|
||||
return ChatResponse.model_validate(chat)
|
||||
|
||||
|
||||
|
||||
@@ -3,12 +3,16 @@ from fastapi import APIRouter
|
||||
from app.api.v1.auth import router as auth_router
|
||||
from app.api.v1.chats import router as chats_router
|
||||
from app.api.v1.admin import router as admin_router
|
||||
from app.api.v1.skills import router as skills_router
|
||||
from app.api.v1.users import router as users_router
|
||||
|
||||
api_v1_router = APIRouter(prefix="/api/v1")
|
||||
|
||||
api_v1_router.include_router(auth_router)
|
||||
api_v1_router.include_router(chats_router)
|
||||
api_v1_router.include_router(admin_router)
|
||||
api_v1_router.include_router(skills_router)
|
||||
api_v1_router.include_router(users_router)
|
||||
|
||||
|
||||
@api_v1_router.get("/health")
|
||||
|
||||
70
backend/app/api/v1/skills.py
Normal file
70
backend/app/api/v1/skills.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.skill import (
|
||||
CreateSkillRequest,
|
||||
SkillListResponse,
|
||||
SkillResponse,
|
||||
UpdateSkillRequest,
|
||||
)
|
||||
from app.services import skill_service
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["skills"])
|
||||
|
||||
|
||||
@router.get("/", response_model=SkillListResponse)
|
||||
async def list_skills(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
include_general: bool = Query(default=True),
|
||||
):
|
||||
skills = await skill_service.get_accessible_skills(db, user.id, include_general)
|
||||
return SkillListResponse(skills=[SkillResponse.model_validate(s) for s in skills])
|
||||
|
||||
|
||||
@router.post("/", response_model=SkillResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_skill(
|
||||
data: CreateSkillRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
skill = await skill_service.create_personal_skill(db, user.id, **data.model_dump())
|
||||
return SkillResponse.model_validate(skill)
|
||||
|
||||
|
||||
@router.get("/{skill_id}", response_model=SkillResponse)
|
||||
async def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
skill = await skill_service.get_skill(db, skill_id, user.id)
|
||||
return SkillResponse.model_validate(skill)
|
||||
|
||||
|
||||
@router.patch("/{skill_id}", response_model=SkillResponse)
|
||||
async def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: UpdateSkillRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
skill = await skill_service.update_personal_skill(
|
||||
db, skill_id, user.id, **data.model_dump(exclude_unset=True)
|
||||
)
|
||||
return SkillResponse.model_validate(skill)
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
await skill_service.delete_personal_skill(db, skill_id, user.id)
|
||||
33
backend/app/api/v1/users.py
Normal file
33
backend/app/api/v1/users.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.chat import ContextFileResponse, UpdateContextRequest
|
||||
from app.services import context_service
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("/me/context", response_model=ContextFileResponse | None)
|
||||
async def get_personal_context(
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
ctx = await context_service.get_personal_context(db, user.id)
|
||||
if not ctx:
|
||||
return None
|
||||
return ContextFileResponse.model_validate(ctx)
|
||||
|
||||
|
||||
@router.put("/me/context", response_model=ContextFileResponse)
|
||||
async def update_personal_context(
|
||||
data: UpdateContextRequest,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
ctx = await context_service.upsert_personal_context(db, user.id, data.content)
|
||||
return ContextFileResponse.model_validate(ctx)
|
||||
@@ -3,5 +3,6 @@ from app.models.session import Session
|
||||
from app.models.chat import Chat
|
||||
from app.models.message import Message
|
||||
from app.models.context_file import ContextFile
|
||||
from app.models.skill import Skill
|
||||
|
||||
__all__ = ["User", "Session", "Chat", "Message", "ContextFile"]
|
||||
__all__ = ["User", "Session", "Chat", "Message", "ContextFile", "Skill"]
|
||||
|
||||
@@ -15,11 +15,14 @@ class Chat(Base):
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False, default="New Chat")
|
||||
skill_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
|
||||
skill_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("skills.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
is_archived: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="chats") # noqa: F821
|
||||
skill: Mapped["Skill | None"] = relationship() # noqa: F821
|
||||
messages: Mapped[list["Message"]] = relationship(back_populates="chat", cascade="all, delete-orphan") # noqa: F821
|
||||
|
||||
23
backend/app/models/skill.py
Normal file
23
backend/app/models/skill.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Skill(Base):
|
||||
__tablename__ = "skills"
|
||||
|
||||
user_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
system_prompt: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
icon: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
user: Mapped["User | None"] = relationship(back_populates="skills") # noqa: F821
|
||||
@@ -26,3 +26,4 @@ class User(Base):
|
||||
|
||||
sessions: Mapped[list["Session"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
chats: Mapped[list["Chat"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
skills: Mapped[list["Skill"]] = relationship(back_populates="user", cascade="all, delete-orphan") # noqa: F821
|
||||
|
||||
@@ -6,11 +6,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class CreateChatRequest(BaseModel):
|
||||
title: str | None = None
|
||||
skill_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class UpdateChatRequest(BaseModel):
|
||||
title: str | None = None
|
||||
is_archived: bool | None = None
|
||||
skill_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
|
||||
40
backend/app/schemas/skill.py
Normal file
40
backend/app/schemas/skill.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateSkillRequest(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=100)
|
||||
description: str | None = None
|
||||
system_prompt: str = Field(min_length=1)
|
||||
icon: str | None = None
|
||||
is_active: bool = True
|
||||
sort_order: int = 0
|
||||
|
||||
|
||||
class UpdateSkillRequest(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
description: str | None = None
|
||||
system_prompt: str | None = Field(default=None, min_length=1)
|
||||
icon: str | None = None
|
||||
is_active: bool | None = None
|
||||
sort_order: int | None = None
|
||||
|
||||
|
||||
class SkillResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
user_id: uuid.UUID | None
|
||||
name: str
|
||||
description: str | None
|
||||
system_prompt: str
|
||||
icon: str | None
|
||||
is_active: bool
|
||||
sort_order: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SkillListResponse(BaseModel):
|
||||
skills: list[SkillResponse]
|
||||
@@ -9,31 +9,48 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.config import settings
|
||||
from app.models.chat import Chat
|
||||
from app.models.message import Message
|
||||
from app.services.context_service import DEFAULT_SYSTEM_PROMPT, get_primary_context
|
||||
from app.models.skill import Skill
|
||||
from app.services.context_service import DEFAULT_SYSTEM_PROMPT, get_primary_context, get_personal_context
|
||||
from app.services.chat_service import get_chat, save_message
|
||||
|
||||
client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
|
||||
|
||||
async def assemble_context(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_message: str
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID, user_message: str
|
||||
) -> tuple[str, list[dict]]:
|
||||
"""Assemble system prompt and messages for Claude API."""
|
||||
system_parts = []
|
||||
|
||||
# 1. Primary context
|
||||
ctx = await get_primary_context(db)
|
||||
system_prompt = ctx.content if ctx and ctx.content.strip() else DEFAULT_SYSTEM_PROMPT
|
||||
system_parts.append(ctx.content if ctx and ctx.content.strip() else DEFAULT_SYSTEM_PROMPT)
|
||||
|
||||
# 2. Conversation history
|
||||
# 2. Personal context
|
||||
personal_ctx = await get_personal_context(db, user_id)
|
||||
if personal_ctx and personal_ctx.content.strip():
|
||||
system_parts.append(f"---\nUser Context:\n{personal_ctx.content}")
|
||||
|
||||
# 3. Active skill system prompt
|
||||
chat = await get_chat(db, chat_id, user_id)
|
||||
if chat.skill_id:
|
||||
result = await db.execute(select(Skill).where(Skill.id == chat.skill_id))
|
||||
skill = result.scalar_one_or_none()
|
||||
if skill and skill.is_active:
|
||||
system_parts.append(f"---\nSpecialist Role ({skill.name}):\n{skill.system_prompt}")
|
||||
|
||||
system_prompt = "\n\n".join(system_parts)
|
||||
|
||||
# 4. Conversation history
|
||||
result = await db.execute(
|
||||
select(Message)
|
||||
.where(Message.chat_id == chat_id, Message.role.in_(["user", "assistant"]))
|
||||
.order_by(Message.created_at.asc())
|
||||
)
|
||||
history = result.scalars().all()
|
||||
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in history]
|
||||
|
||||
# 3. Current user message
|
||||
# 5. Current user message
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
return system_prompt, messages
|
||||
@@ -56,7 +73,7 @@ async def stream_ai_response(
|
||||
|
||||
try:
|
||||
# Assemble context
|
||||
system_prompt, messages = await assemble_context(db, chat_id, user_message)
|
||||
system_prompt, messages = await assemble_context(db, chat_id, user_id, user_message)
|
||||
|
||||
# Stream from Claude
|
||||
full_content = ""
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.models.message import Message
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
async def create_chat(db: AsyncSession, user: User, title: str | None = None) -> Chat:
|
||||
async def create_chat(db: AsyncSession, user: User, title: str | None = None, skill_id: uuid.UUID | None = None) -> Chat:
|
||||
count = await db.scalar(
|
||||
select(func.count()).select_from(Chat).where(
|
||||
Chat.user_id == user.id, Chat.is_archived == False # noqa: E712
|
||||
@@ -21,7 +21,7 @@ async def create_chat(db: AsyncSession, user: User, title: str | None = None) ->
|
||||
detail="Chat limit reached. Archive or delete existing chats.",
|
||||
)
|
||||
|
||||
chat = Chat(user_id=user.id, title=title or "New Chat")
|
||||
chat = Chat(user_id=user.id, title=title or "New Chat", skill_id=skill_id)
|
||||
db.add(chat)
|
||||
await db.flush()
|
||||
return chat
|
||||
@@ -51,12 +51,15 @@ async def get_chat(db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID) ->
|
||||
async def update_chat(
|
||||
db: AsyncSession, chat_id: uuid.UUID, user_id: uuid.UUID,
|
||||
title: str | None = None, is_archived: bool | None = None,
|
||||
skill_id: uuid.UUID | None = None,
|
||||
) -> Chat:
|
||||
chat = await get_chat(db, chat_id, user_id)
|
||||
if title is not None:
|
||||
chat.title = title
|
||||
if is_archived is not None:
|
||||
chat.is_archived = is_archived
|
||||
if skill_id is not None:
|
||||
chat.skill_id = skill_id
|
||||
await db.flush()
|
||||
return chat
|
||||
|
||||
|
||||
@@ -23,6 +23,34 @@ async def get_primary_context(db: AsyncSession) -> ContextFile | None:
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_personal_context(db: AsyncSession, user_id: uuid.UUID) -> ContextFile | None:
|
||||
result = await db.execute(
|
||||
select(ContextFile).where(ContextFile.type == "personal", ContextFile.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def upsert_personal_context(
|
||||
db: AsyncSession, user_id: uuid.UUID, content: str
|
||||
) -> ContextFile:
|
||||
ctx = await get_personal_context(db, user_id)
|
||||
if ctx:
|
||||
ctx.content = content
|
||||
ctx.version = ctx.version + 1
|
||||
ctx.updated_by = user_id
|
||||
else:
|
||||
ctx = ContextFile(
|
||||
type="personal",
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
version=1,
|
||||
updated_by=user_id,
|
||||
)
|
||||
db.add(ctx)
|
||||
await db.flush()
|
||||
return ctx
|
||||
|
||||
|
||||
async def upsert_primary_context(
|
||||
db: AsyncSession, content: str, admin_user_id: uuid.UUID
|
||||
) -> ContextFile:
|
||||
|
||||
97
backend/app/services/skill_service.py
Normal file
97
backend/app/services/skill_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.skill import Skill
|
||||
|
||||
|
||||
async def get_accessible_skills(
|
||||
db: AsyncSession, user_id: uuid.UUID, include_general: bool = True
|
||||
) -> list[Skill]:
|
||||
conditions = [Skill.user_id == user_id]
|
||||
if include_general:
|
||||
conditions.append(Skill.user_id.is_(None))
|
||||
stmt = select(Skill).where(or_(*conditions), Skill.is_active == True).order_by(Skill.sort_order) # noqa: E712
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_skill(db: AsyncSession, skill_id: uuid.UUID, user_id: uuid.UUID | None = None) -> Skill:
|
||||
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
||||
skill = result.scalar_one_or_none()
|
||||
if not skill:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Skill not found")
|
||||
# Access check: must be general or owned by user
|
||||
if user_id and skill.user_id is not None and skill.user_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Skill not found")
|
||||
return skill
|
||||
|
||||
|
||||
async def validate_skill_accessible(db: AsyncSession, skill_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
"""Validate skill exists and is accessible by user (general or owned). Raises 404 if not."""
|
||||
await get_skill(db, skill_id, user_id)
|
||||
|
||||
|
||||
# --- Personal skills ---
|
||||
|
||||
async def create_personal_skill(db: AsyncSession, user_id: uuid.UUID, **kwargs) -> Skill:
|
||||
skill = Skill(user_id=user_id, **kwargs)
|
||||
db.add(skill)
|
||||
await db.flush()
|
||||
return skill
|
||||
|
||||
|
||||
async def update_personal_skill(db: AsyncSession, skill_id: uuid.UUID, user_id: uuid.UUID, **kwargs) -> Skill:
|
||||
skill = await get_skill(db, skill_id, user_id)
|
||||
if skill.user_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot edit general skills")
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
setattr(skill, key, value)
|
||||
await db.flush()
|
||||
return skill
|
||||
|
||||
|
||||
async def delete_personal_skill(db: AsyncSession, skill_id: uuid.UUID, user_id: uuid.UUID) -> None:
|
||||
skill = await get_skill(db, skill_id, user_id)
|
||||
if skill.user_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot delete general skills")
|
||||
await db.delete(skill)
|
||||
|
||||
|
||||
# --- General (admin) skills ---
|
||||
|
||||
async def get_general_skills(db: AsyncSession) -> list[Skill]:
|
||||
result = await db.execute(
|
||||
select(Skill).where(Skill.user_id.is_(None)).order_by(Skill.sort_order)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def create_general_skill(db: AsyncSession, **kwargs) -> Skill:
|
||||
skill = Skill(user_id=None, **kwargs)
|
||||
db.add(skill)
|
||||
await db.flush()
|
||||
return skill
|
||||
|
||||
|
||||
async def update_general_skill(db: AsyncSession, skill_id: uuid.UUID, **kwargs) -> Skill:
|
||||
result = await db.execute(select(Skill).where(Skill.id == skill_id, Skill.user_id.is_(None)))
|
||||
skill = result.scalar_one_or_none()
|
||||
if not skill:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="General skill not found")
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
setattr(skill, key, value)
|
||||
await db.flush()
|
||||
return skill
|
||||
|
||||
|
||||
async def delete_general_skill(db: AsyncSession, skill_id: uuid.UUID) -> None:
|
||||
result = await db.execute(select(Skill).where(Skill.id == skill_id, Skill.user_id.is_(None)))
|
||||
skill = result.scalar_one_or_none()
|
||||
if not skill:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="General skill not found")
|
||||
await db.delete(skill)
|
||||
142
backend/tests/test_skills.py
Normal file
142
backend/tests/test_skills.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def user_headers(client: AsyncClient):
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "skilluser@example.com",
|
||||
"username": "skilluser",
|
||||
"password": "testpass123",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
return {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def other_user_headers(client: AsyncClient):
|
||||
resp = await client.post("/api/v1/auth/register", json={
|
||||
"email": "skillother@example.com",
|
||||
"username": "skillother",
|
||||
"password": "testpass123",
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
return {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||
|
||||
|
||||
# --- Personal Skills ---
|
||||
|
||||
async def test_create_personal_skill(client: AsyncClient, user_headers: dict):
|
||||
resp = await client.post("/api/v1/skills/", json={
|
||||
"name": "Nutritionist",
|
||||
"description": "Diet and nutrition advice",
|
||||
"system_prompt": "You are a nutritionist.",
|
||||
}, headers=user_headers)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "Nutritionist"
|
||||
assert data["user_id"] is not None
|
||||
|
||||
|
||||
async def test_list_personal_skills(client: AsyncClient, user_headers: dict):
|
||||
await client.post("/api/v1/skills/", json={
|
||||
"name": "Test Skill",
|
||||
"system_prompt": "Test prompt",
|
||||
}, headers=user_headers)
|
||||
|
||||
resp = await client.get("/api/v1/skills/", params={"include_general": False}, headers=user_headers)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["skills"]) >= 1
|
||||
|
||||
|
||||
async def test_update_personal_skill(client: AsyncClient, user_headers: dict):
|
||||
resp = await client.post("/api/v1/skills/", json={
|
||||
"name": "Old Name",
|
||||
"system_prompt": "Prompt",
|
||||
}, headers=user_headers)
|
||||
skill_id = resp.json()["id"]
|
||||
|
||||
resp = await client.patch(f"/api/v1/skills/{skill_id}", json={
|
||||
"name": "New Name",
|
||||
}, headers=user_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "New Name"
|
||||
|
||||
|
||||
async def test_delete_personal_skill(client: AsyncClient, user_headers: dict):
|
||||
resp = await client.post("/api/v1/skills/", json={
|
||||
"name": "To Delete",
|
||||
"system_prompt": "Prompt",
|
||||
}, headers=user_headers)
|
||||
skill_id = resp.json()["id"]
|
||||
|
||||
resp = await client.delete(f"/api/v1/skills/{skill_id}", headers=user_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
|
||||
async def test_cannot_access_other_users_skill(client: AsyncClient, user_headers: dict, other_user_headers: dict):
|
||||
resp = await client.post("/api/v1/skills/", json={
|
||||
"name": "Private Skill",
|
||||
"system_prompt": "Prompt",
|
||||
}, headers=user_headers)
|
||||
skill_id = resp.json()["id"]
|
||||
|
||||
# Other user can't see it
|
||||
resp = await client.get(f"/api/v1/skills/{skill_id}", headers=other_user_headers)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# --- Admin Skills ---
|
||||
|
||||
async def test_non_admin_cannot_manage_general_skills(client: AsyncClient, user_headers: dict):
|
||||
resp = await client.get("/api/v1/admin/skills", headers=user_headers)
|
||||
assert resp.status_code == 403
|
||||
|
||||
resp = await client.post("/api/v1/admin/skills", json={
|
||||
"name": "General",
|
||||
"system_prompt": "Prompt",
|
||||
}, headers=user_headers)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
# --- Personal Context ---
|
||||
|
||||
async def test_personal_context_crud(client: AsyncClient, user_headers: dict):
|
||||
# Initially null
|
||||
resp = await client.get("/api/v1/users/me/context", headers=user_headers)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Create
|
||||
resp = await client.put("/api/v1/users/me/context", json={
|
||||
"content": "I have diabetes type 2",
|
||||
}, headers=user_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["content"] == "I have diabetes type 2"
|
||||
assert data["version"] == 1
|
||||
|
||||
# Update
|
||||
resp = await client.put("/api/v1/users/me/context", json={
|
||||
"content": "I have diabetes type 2 and hypertension",
|
||||
}, headers=user_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["version"] == 2
|
||||
|
||||
|
||||
# --- Chat with Skill ---
|
||||
|
||||
async def test_create_chat_with_skill(client: AsyncClient, user_headers: dict):
|
||||
# Create a skill first
|
||||
resp = await client.post("/api/v1/skills/", json={
|
||||
"name": "Cardiologist",
|
||||
"system_prompt": "You are a cardiologist.",
|
||||
}, headers=user_headers)
|
||||
skill_id = resp.json()["id"]
|
||||
|
||||
# Create chat with skill
|
||||
resp = await client.post("/api/v1/chats/", json={
|
||||
"title": "Heart Consultation",
|
||||
"skill_id": skill_id,
|
||||
}, headers=user_headers)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["skill_id"] == skill_id
|
||||
Reference in New Issue
Block a user