feat: security hardening — SSRF guard, template sandbox timeout, webhook log prune, auth & backup polish
- Add outbound URL validation (SSRF) for webhook/Discord/Slack/ntfy/Matrix dispatch - Template renderer: input/output caps and thread-based render timeout - Webhook log filter: strip Authorization/signature/token-like headers; atomic prune - Auth/JWT/backup/config tightening; misc frontend UX fixes
This commit is contained in:
@@ -31,6 +31,23 @@ def _backup_dir():
|
||||
return app_config.data_dir / "backups"
|
||||
|
||||
|
||||
def _resolve_backup_file(filename: str):
|
||||
"""Validate filename and resolve to a path strictly inside the backup dir."""
|
||||
if not filename.startswith("backup-") or not filename.endswith(".json"):
|
||||
raise HTTPException(status_code=404, detail="Backup file not found")
|
||||
if "/" in filename or "\\" in filename or ".." in filename or "\x00" in filename:
|
||||
raise HTTPException(status_code=404, detail="Backup file not found")
|
||||
base = _backup_dir().resolve()
|
||||
candidate = (base / filename).resolve()
|
||||
try:
|
||||
candidate.relative_to(base)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Backup file not found")
|
||||
if not candidate.is_file():
|
||||
raise HTTPException(status_code=404, detail="Backup file not found")
|
||||
return candidate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -194,9 +211,7 @@ async def download_backup_file(
|
||||
user: User = Depends(require_admin),
|
||||
):
|
||||
"""Download a specific backup file."""
|
||||
filepath = _backup_dir() / filename
|
||||
if not filepath.is_file() or not filename.startswith("backup-"):
|
||||
raise HTTPException(status_code=404, detail="Backup file not found")
|
||||
filepath = _resolve_backup_file(filename)
|
||||
|
||||
try:
|
||||
content = json.loads(filepath.read_text(encoding="utf-8"))
|
||||
@@ -215,9 +230,6 @@ async def delete_backup_file(
|
||||
user: User = Depends(require_admin),
|
||||
):
|
||||
"""Delete a specific backup file."""
|
||||
filepath = _backup_dir() / filename
|
||||
if not filepath.is_file() or not filename.startswith("backup-"):
|
||||
raise HTTPException(status_code=404, detail="Backup file not found")
|
||||
|
||||
filepath = _resolve_backup_file(filename)
|
||||
filepath.unlink()
|
||||
return {"deleted": filename}
|
||||
|
||||
@@ -350,12 +350,29 @@ def _verify_generic_webhook_auth(
|
||||
return False
|
||||
|
||||
|
||||
_SENSITIVE_HEADER_SUBSTR = (
|
||||
"token", "auth", "key", "secret", "signature", "password", "credential",
|
||||
"cookie", "x-api", "x-hub-signature",
|
||||
)
|
||||
|
||||
|
||||
def _is_sensitive_header(name: str) -> bool:
|
||||
n = name.lower()
|
||||
return any(s in n for s in _SENSITIVE_HEADER_SUBSTR)
|
||||
|
||||
|
||||
def _filter_headers(raw_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Keep only safe headers for logging (no Authorization)."""
|
||||
"""Keep only safe headers for logging (strip Authorization, signatures, tokens).
|
||||
|
||||
Allowlist base set of known-safe headers, accept X-* only if they do not
|
||||
match any sensitive substring (token/auth/key/secret/signature/...).
|
||||
"""
|
||||
safe: dict[str, str] = {}
|
||||
for k, v in raw_headers.items():
|
||||
kl = k.lower()
|
||||
if kl in ("content-type", "user-agent") or kl.startswith("x-"):
|
||||
if _is_sensitive_header(kl):
|
||||
continue
|
||||
if kl in ("content-type", "user-agent", "content-length", "accept") or kl.startswith("x-"):
|
||||
safe[k] = v
|
||||
return safe
|
||||
|
||||
@@ -384,26 +401,26 @@ async def _save_webhook_log(
|
||||
error_message=error_message,
|
||||
))
|
||||
await session.flush()
|
||||
count_result = await session.exec(
|
||||
select(func.count(WebhookPayloadLog.id))
|
||||
# Atomic prune: DELETE anything for this provider outside the newest
|
||||
# max_count rows. Avoids the COUNT -> SELECT -> DELETE race.
|
||||
keep_subq = (
|
||||
select(WebhookPayloadLog.id)
|
||||
.where(WebhookPayloadLog.provider_id == provider_id)
|
||||
.order_by(WebhookPayloadLog.created_at.desc(), WebhookPayloadLog.id.desc())
|
||||
.limit(max_count)
|
||||
.subquery()
|
||||
)
|
||||
await session.execute(
|
||||
sa_delete(WebhookPayloadLog)
|
||||
.where(WebhookPayloadLog.provider_id == provider_id)
|
||||
.where(~WebhookPayloadLog.id.in_(select(keep_subq.c.id)))
|
||||
)
|
||||
total = count_result.one()
|
||||
if total > max_count:
|
||||
oldest = await session.exec(
|
||||
select(WebhookPayloadLog.id)
|
||||
.where(WebhookPayloadLog.provider_id == provider_id)
|
||||
.order_by(WebhookPayloadLog.created_at.asc())
|
||||
.limit(total - max_count)
|
||||
)
|
||||
ids_to_delete = list(oldest.all())
|
||||
if ids_to_delete:
|
||||
await session.execute(
|
||||
sa_delete(WebhookPayloadLog)
|
||||
.where(WebhookPayloadLog.id.in_(ids_to_delete))
|
||||
)
|
||||
except Exception:
|
||||
_LOGGER.warning("Failed to save webhook payload log for provider %d", provider_id, exc_info=True)
|
||||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/webhook/{token}")
|
||||
@@ -436,6 +453,8 @@ async def generic_webhook(token: str, request: Request):
|
||||
# Parse JSON payload
|
||||
try:
|
||||
payload = await request.json()
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Payload must be a JSON object")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
if store_payloads:
|
||||
async with AsyncSession(get_engine()) as log_session:
|
||||
|
||||
@@ -22,12 +22,15 @@ async def get_current_user(
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
|
||||
user_id = int(payload["sub"])
|
||||
token_version = int(payload.get("ver", 1))
|
||||
except (jwt.PyJWTError, KeyError, ValueError) as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token") from exc
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
if token_version != user.token_version:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token revoked")
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -9,15 +9,26 @@ from ..config import settings
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def create_access_token(user_id: int, role: str) -> str:
|
||||
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
payload = {"sub": str(user_id), "role": role, "type": "access", "exp": expire}
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"role": role,
|
||||
"type": "access",
|
||||
"ver": token_version,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: int) -> str:
|
||||
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days)
|
||||
payload = {"sub": str(user_id), "type": "refresh", "exp": expire}
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"type": "refresh",
|
||||
"ver": token_version,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
|
||||
@@ -69,8 +69,8 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
|
||||
await session.refresh(user)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id, user.role),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
access_token=create_access_token(user.id, user.role, user.token_version),
|
||||
refresh_token=create_refresh_token(user.id, user.token_version),
|
||||
)
|
||||
|
||||
|
||||
@@ -83,29 +83,33 @@ async def login(request: Request, body: LoginRequest, session: AsyncSession = De
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id, user.role),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
access_token=create_access_token(user.id, user.role, user.token_version),
|
||||
refresh_token=create_refresh_token(user.id, user.token_version),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(body: RefreshRequest, session: AsyncSession = Depends(get_session)):
|
||||
@limiter.limit("10/minute")
|
||||
async def refresh(request: Request, body: RefreshRequest, session: AsyncSession = Depends(get_session)):
|
||||
import jwt as pyjwt
|
||||
try:
|
||||
payload = decode_token(body.refresh_token)
|
||||
if payload.get("type") != "refresh":
|
||||
raise HTTPException(status_code=401, detail="Invalid token type")
|
||||
user_id = int(payload["sub"])
|
||||
token_version = int(payload.get("ver", 1))
|
||||
except (pyjwt.PyJWTError, KeyError, ValueError) as exc:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if token_version != user.token_version:
|
||||
raise HTTPException(status_code=401, detail="Refresh token revoked")
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id, user.role),
|
||||
refresh_token=create_refresh_token(user.id),
|
||||
access_token=create_access_token(user.id, user.role, user.token_version),
|
||||
refresh_token=create_refresh_token(user.id, user.token_version),
|
||||
)
|
||||
|
||||
|
||||
@@ -130,6 +134,7 @@ async def change_password(
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
|
||||
user.hashed_password = _hash_password(body.new_password)
|
||||
user.token_version = (user.token_version or 1) + 1
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
return {"success": True}
|
||||
|
||||
@@ -14,10 +14,19 @@ class Settings(BaseSettings):
|
||||
secret_key: str = "change-me-in-production"
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.secret_key == "change-me-in-production" and not self.debug:
|
||||
if self.secret_key == "change-me-in-production":
|
||||
raise ValueError(
|
||||
"SECURITY: Cannot start with default secret_key in production. "
|
||||
"Set NOTIFY_BRIDGE_SECRET_KEY environment variable."
|
||||
"SECURITY: Refusing to start with the default secret_key. "
|
||||
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
|
||||
"before starting the server (debug mode included)."
|
||||
)
|
||||
if len(self.secret_key) < 32:
|
||||
raise ValueError(
|
||||
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
|
||||
)
|
||||
if "*" in self.cors_allowed_origins.split(","):
|
||||
raise ValueError(
|
||||
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
|
||||
)
|
||||
|
||||
access_token_expire_minutes: int = 60
|
||||
|
||||
@@ -18,8 +18,23 @@ logger = logging.getLogger(__name__)
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IDENT_RE = __import__("re").compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def _assert_ident(ident: str, kind: str = "identifier") -> str:
|
||||
"""Guard against SQL injection in dynamically interpolated identifiers.
|
||||
|
||||
All table/column names flow through here before being embedded into f-strings,
|
||||
so attacker-controlled values cannot break out even if they reach this layer.
|
||||
"""
|
||||
if not isinstance(ident, str) or not _IDENT_RE.match(ident):
|
||||
raise ValueError(f"Unsafe {kind}: {ident!r}")
|
||||
return ident
|
||||
|
||||
|
||||
async def _has_column(conn, table: str, column: str) -> bool:
|
||||
"""Check if a column exists in a SQLite table."""
|
||||
_assert_ident(table, "table")
|
||||
cols = await conn.run_sync(
|
||||
lambda sync_conn: [
|
||||
row[1]
|
||||
@@ -1187,3 +1202,15 @@ async def migrate_notification_slot_locale(engine: AsyncEngine) -> None:
|
||||
"Merged system notification template configs for %s (EN=%d, RU=%d) into %d",
|
||||
provider_type, en_id, ru_id, en_id,
|
||||
)
|
||||
|
||||
|
||||
async def migrate_user_token_version(engine: AsyncEngine) -> None:
|
||||
"""Add token_version column to user for JWT revocation on password change."""
|
||||
async with engine.begin() as conn:
|
||||
if not await _has_table(conn, "user"):
|
||||
return
|
||||
if not await _has_column(conn, "user", "token_version"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
|
||||
)
|
||||
logger.info("Added token_version column to user table")
|
||||
|
||||
@@ -19,6 +19,7 @@ class User(SQLModel, table=True):
|
||||
username: str = Field(index=True, unique=True)
|
||||
hashed_password: str
|
||||
role: str = Field(default="user")
|
||||
token_version: int = Field(default=1)
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ async def lifespan(app: FastAPI):
|
||||
await init_db()
|
||||
# Run data migrations (idempotent)
|
||||
from .database.engine import get_engine
|
||||
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale
|
||||
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version
|
||||
engine = get_engine()
|
||||
await migrate_schema(engine)
|
||||
await migrate_tracker_targets(engine)
|
||||
@@ -63,6 +63,7 @@ async def lifespan(app: FastAPI):
|
||||
await migrate_receivers_from_config(engine)
|
||||
await migrate_command_slot_locale(engine)
|
||||
await migrate_notification_slot_locale(engine)
|
||||
await migrate_user_token_version(engine)
|
||||
from .database.seeds import seed_all
|
||||
await seed_all()
|
||||
# Configure webhook secret from DB setting (falls back to env var)
|
||||
|
||||
@@ -34,6 +34,44 @@ _LOGGER = logging.getLogger(__name__)
|
||||
# Fields to skip when serializing TrackingConfig into the generic `fields` dict
|
||||
_TRACKING_SKIP = frozenset(("id", "user_id", "provider_type", "name", "icon", "created_at"))
|
||||
|
||||
# Import-time config hardening limits
|
||||
_MAX_CONFIG_DEPTH = 6
|
||||
_MAX_CONFIG_KEYS = 200
|
||||
_MAX_STRING_LEN = 8192
|
||||
|
||||
|
||||
def _sanitize_config(value: Any, depth: int = 0) -> Any:
|
||||
"""Clamp imported config values to safe shapes before persistence.
|
||||
|
||||
Rejects anything that is not a JSON-compatible primitive/container, truncates
|
||||
over-long strings, and caps dict/list sizes. Returns a defensively-copied
|
||||
structure; the caller should never see attacker-controlled references.
|
||||
"""
|
||||
if depth > _MAX_CONFIG_DEPTH:
|
||||
raise ValueError("Config nesting exceeds maximum depth")
|
||||
if value is None or isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value[:_MAX_STRING_LEN]
|
||||
if isinstance(value, list):
|
||||
if len(value) > _MAX_CONFIG_KEYS:
|
||||
raise ValueError("Config list exceeds maximum length")
|
||||
return [_sanitize_config(v, depth + 1) for v in value]
|
||||
if isinstance(value, dict):
|
||||
if len(value) > _MAX_CONFIG_KEYS:
|
||||
raise ValueError("Config dict exceeds maximum key count")
|
||||
cleaned: dict[str, Any] = {}
|
||||
for k, v in value.items():
|
||||
if not isinstance(k, str):
|
||||
raise ValueError("Config keys must be strings")
|
||||
if len(k) > 128:
|
||||
raise ValueError(f"Config key too long: {k[:40]}...")
|
||||
cleaned[k] = _sanitize_config(v, depth + 1)
|
||||
return cleaned
|
||||
raise ValueError(f"Unsupported config value type: {type(value).__name__}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export
|
||||
@@ -530,9 +568,14 @@ async def import_backup(
|
||||
)
|
||||
if name is None:
|
||||
continue
|
||||
try:
|
||||
safe_cfg = _sanitize_config(p.config or {})
|
||||
except ValueError as exc:
|
||||
result.warnings.append(f"Skipped provider '{p.name}': {exc}")
|
||||
continue
|
||||
new_p = ServiceProvider(
|
||||
user_id=user_id, type=p.type, name=name,
|
||||
icon=p.icon, config=p.config,
|
||||
icon=p.icon, config=safe_cfg,
|
||||
)
|
||||
session.add(new_p)
|
||||
await session.flush()
|
||||
@@ -635,17 +678,27 @@ async def import_backup(
|
||||
)
|
||||
if name is None:
|
||||
continue
|
||||
try:
|
||||
safe_tgt_cfg = _sanitize_config(tgt.config or {})
|
||||
except ValueError as exc:
|
||||
result.warnings.append(f"Skipped target '{tgt.name}': {exc}")
|
||||
continue
|
||||
new_tgt = NotificationTarget(
|
||||
user_id=user_id, type=tgt.type, name=name,
|
||||
icon=tgt.icon, config=tgt.config,
|
||||
icon=tgt.icon, config=safe_tgt_cfg,
|
||||
chat_action=tgt.chat_action,
|
||||
)
|
||||
session.add(new_tgt)
|
||||
await session.flush()
|
||||
id_map["targets"][tgt.id] = new_tgt.id
|
||||
for r in tgt.receivers:
|
||||
try:
|
||||
safe_r_cfg = _sanitize_config(r.config or {})
|
||||
except ValueError as exc:
|
||||
result.warnings.append(f"Skipped receiver in '{tgt.name}': {exc}")
|
||||
continue
|
||||
session.add(TargetReceiver(
|
||||
target_id=new_tgt.id, name=r.name, config=r.config,
|
||||
target_id=new_tgt.id, name=r.name, config=safe_r_cfg,
|
||||
receiver_key=r.receiver_key, locale=r.locale,
|
||||
enabled=r.enabled,
|
||||
))
|
||||
|
||||
@@ -249,6 +249,22 @@ async def load_link_data(
|
||||
event_key = s.slot_name.removeprefix("message_") if s.slot_name.startswith("message_") else s.slot_name
|
||||
slots_by_config.setdefault(s.config_id, {}).setdefault(event_key, {})[s.locale] = s.template
|
||||
|
||||
# Pre-resolve broadcast children in one query to avoid N+1 per-child fetches
|
||||
broadcast_child_ids: set[int] = set()
|
||||
for tt in active_links:
|
||||
target = target_map.get(tt.target_id)
|
||||
if target and target.type == "broadcast":
|
||||
disabled_ids = set(target.config.get("disabled_child_ids", []))
|
||||
for cid in target.config.get("child_target_ids", []):
|
||||
if cid not in disabled_ids:
|
||||
broadcast_child_ids.add(cid)
|
||||
child_target_map: dict[int, NotificationTarget] = {}
|
||||
if broadcast_child_ids:
|
||||
child_rows = await session.exec(
|
||||
select(NotificationTarget).where(NotificationTarget.id.in_(broadcast_child_ids))
|
||||
)
|
||||
child_target_map = {t.id: t for t in child_rows.all()}
|
||||
|
||||
link_data: list[dict[str, Any]] = []
|
||||
for tt in active_links:
|
||||
target = target_map.get(tt.target_id)
|
||||
@@ -262,14 +278,13 @@ async def load_link_data(
|
||||
template_config = tmpl_map.get(tmpl_id) if tmpl_id else None
|
||||
template_slots = slots_by_config.get(template_config.id) if template_config else None
|
||||
|
||||
# Broadcast target: expand into child targets
|
||||
# Broadcast target: expand into child targets (pre-loaded above)
|
||||
if target.type == "broadcast":
|
||||
child_ids = target.config.get("child_target_ids", [])
|
||||
disabled_ids = set(target.config.get("disabled_child_ids", []))
|
||||
for child_id in child_ids:
|
||||
for child_id in target.config.get("child_target_ids", []):
|
||||
if child_id in disabled_ids:
|
||||
continue
|
||||
child_target = await session.get(NotificationTarget, child_id)
|
||||
child_target = child_target_map.get(child_id)
|
||||
if not child_target or child_target.type == "broadcast":
|
||||
continue
|
||||
resolved = await _resolve_target(session, child_target)
|
||||
|
||||
Reference in New Issue
Block a user