170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
"""Provider-agnostic update checker service."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
from functools import total_ordering
|
|
from typing import Any, Optional
|
|
|
|
from .release_provider import ReleaseProvider
|
|
from .websocket_manager import ws_manager
|
|
|
|
logger = logging.getLogger(__name__)

# Matches "X.Y.Z" followed by an alpha/beta/rc pre-release label and its number,
# with an optional "-" or "." separator on either side of the label, e.g.
# "1.2.3-rc.1", "1.2.3.beta-2" or "1.2.3rc1".  Matching is case-insensitive.
# Groups: (1) release "X.Y.Z", (2) label, (3) pre-release number.
_PRE_PATTERN = re.compile(r"^(\d+\.\d+\.\d+)[-.]?(alpha|beta|rc)[.-]?(\d+)$", re.I)

# Rank of each pre-release label: a lower rank sorts earlier.
_PRE_ORDER = dict(alpha=0, beta=1, rc=2)
|
|
|
|
|
|
@total_ordering
class _Version:
    """Lightweight PEP 440-ish version for comparison without packaging dep.

    Supports: X.Y.Z and X.Y.Z-{alpha,beta,rc}.N

    Pre-releases sort before the corresponding stable release.  Trailing zero
    release components are insignificant, so ``1.0`` compares equal to
    ``1.0.0`` (as in PEP 440).  Instances hash consistently with equality.
    """

    __slots__ = ("_release", "_pre")

    def __init__(self, release: tuple[int, ...], pre: Optional[tuple[int, int]]) -> None:
        # release: numeric release components, e.g. (1, 2, 3).
        # pre: (label rank, number) where rank 0=alpha, 1=beta, 2=rc;
        #      None for a stable release.
        self._release = release
        self._pre = pre

    def _cmp_key(self) -> tuple[tuple[int, ...], int, tuple[int, ...]]:
        """Return the tuple that all comparisons and hashing are based on.

        Trailing zeros are stripped from the release so "1.0" and "1.0.0" are
        equal.  The middle element places every pre-release (0) before the
        stable release (1) of the same version; pre-releases of the same
        version then order by (label rank, number).
        """
        release = self._release
        while release and release[-1] == 0:
            release = release[:-1]
        if self._pre is None:
            return (release, 1, ())
        return (release, 0, tuple(self._pre))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _Version):
            return NotImplemented
        return self._cmp_key() == other._cmp_key()

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, _Version):
            return NotImplemented
        return self._cmp_key() < other._cmp_key()

    def __hash__(self) -> int:
        # Must agree with __eq__, so hash the same normalized key.
        return hash(self._cmp_key())

    def __repr__(self) -> str:
        v = ".".join(str(p) for p in self._release)
        if self._pre is not None:
            labels = {0: "alpha", 1: "beta", 2: "rc"}
            v += f"-{labels[self._pre[0]]}.{self._pre[1]}"
        return f"_Version('{v}')"
|
|
|
|
|
|
def _parse_version(raw: str) -> _Version:
    """Parse a version tag for comparison.

    Surrounding whitespace and a leading ``v``/``V`` prefix are ignored, as is
    semver build metadata after ``+`` (it never affects precedence).

    Examples:
        v0.3.0-alpha.1  → (0,3,0) pre=(0,1)  (sorts below 0.3.0)
        v0.3.0-rc.3     → (0,3,0) pre=(2,3)
        v1.0.0          → (1,0,0) pre=None
        V1.0.0+build.7  → (1,0,0) pre=None

    Raises:
        ValueError: If ``raw`` is not a recognized version string.
    """
    # Strip whitespace *before* the prefix so " v1.0.0" is handled too.
    cleaned = raw.strip().lstrip("vV")
    # Build metadata ("+abc") must be ignored when determining precedence.
    cleaned = cleaned.partition("+")[0]

    m = _PRE_PATTERN.match(cleaned)
    if m:
        base = tuple(int(x) for x in m.group(1).split("."))
        pre_label = m.group(2).lower()
        pre_num = int(m.group(3))
        return _Version(base, (_PRE_ORDER[pre_label], pre_num))

    try:
        release = tuple(int(x) for x in cleaned.split("."))
    except ValueError:
        raise ValueError(f"Unrecognized version string: {raw!r}") from None
    return _Version(release, None)
|
|
|
|
|
|
class UpdateChecker:
    """Periodically checks for new releases using a ReleaseProvider.

    The most recent "update available" result is cached (see
    ``cached_update``), and each newly detected release is broadcast once to
    clients via ``ws_manager``.
    """

    def __init__(self, provider: ReleaseProvider, current_version: str) -> None:
        """Create a checker that compares releases against ``current_version``.

        Args:
            provider: Source of the latest release (``get_latest_release``).
            current_version: Version string of the running application; it is
                parsed once, here.

        Raises:
            Any error raised by ``_parse_version`` for an unparseable
            ``current_version`` (e.g. ``ValueError``).
        """
        self._provider = provider
        self._current_version = current_version
        self._current_parsed = _parse_version(current_version)
        self._task: Optional[asyncio.Task] = None
        self._cached_update: Optional[dict[str, Any]] = None

    @property
    def cached_update(self) -> Optional[dict[str, Any]]:
        """Return the cached update info, or None if no newer release is known."""
        return self._cached_update

    async def check_for_update(self) -> Optional[dict[str, Any]]:
        """Check once for a release newer than the running version.

        Returns:
            Dict with "current"/"latest"/"url" keys if the provider's latest
            release is strictly newer than the current version; None if the
            provider returned no release or it is not newer.

        Raises:
            Whatever ``get_latest_release()`` or version parsing raises; the
            background loop logs such errors and keeps running.
        """
        release = await self._provider.get_latest_release()
        if release is None:
            return None

        latest_parsed = _parse_version(release.version)
        if latest_parsed <= self._current_parsed:
            return None

        return {
            "current": self._current_version,
            "latest": release.version,
            "url": release.url,
        }

    async def start(self, interval: int, initial_delay: float = 5.0) -> None:
        """Start periodic update checking in a background task.

        The first check runs after ``initial_delay`` seconds (giving the server
        time to finish starting), then every ``interval`` seconds.  Calling
        this while a check loop is still running is a no-op; a loop that has
        already finished (e.g. was cancelled externally) is replaced.
        """
        if self._task is not None and not self._task.done():
            return

        self._task = asyncio.create_task(self._check_loop(interval, initial_delay))
        logger.info("Update checker started (interval: %ds)", interval)

    async def stop(self) -> None:
        """Stop periodic update checking and wait for the loop to exit."""
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                # Expected: the loop ends by propagating the cancellation.
                pass
            self._task = None
            logger.info("Update checker stopped")

    async def _check_loop(self, interval: int, initial_delay: float = 5.0) -> None:
        """Background loop that checks for updates periodically.

        Errors from an individual check are logged and the loop continues.
        Cancellation is intentionally *not* caught: ``asyncio.CancelledError``
        derives from ``BaseException`` (Python 3.8+), so it passes through
        ``except Exception`` and ends the task as cancelled, which ``stop()``
        handles.
        """
        # Initial delay to let the server finish starting
        await asyncio.sleep(initial_delay)

        while True:
            try:
                update = await self.check_for_update()

                if update and update != self._cached_update:
                    self._cached_update = update
                    logger.info(
                        "New version available: %s → %s (%s)",
                        update["current"],
                        update["latest"],
                        update["url"],
                    )
                    await ws_manager.broadcast(
                        {"type": "update_available", "data": update}
                    )
                elif update is None and self._cached_update is not None:
                    # Version was updated (or release removed) — clear cache
                    self._cached_update = None
            except Exception as e:
                logger.warning("Update check failed: %s", e)

            await asyncio.sleep(interval)
|