From 6bbddad7c2c703fa5061d2ed6c88bc878ae83dd4 Mon Sep 17 00:00:00 2001 From: Markov Date: Mon, 23 Feb 2026 20:25:21 +0100 Subject: [PATCH] Epic 1: WS multi-session, task events broadcast, REST auth middleware, heartbeat fix --- src/tracker/api/messages.py | 4 +- src/tracker/api/tasks.py | 8 ++ src/tracker/app.py | 108 +++++++++++++++---------- src/tracker/ws/handler.py | 152 ++++++++++++++++++------------------ src/tracker/ws/manager.py | 101 ++++++++++++++++-------- 5 files changed, 219 insertions(+), 154 deletions(-) diff --git a/src/tracker/api/messages.py b/src/tracker/api/messages.py index 5b44e2d..f1c578a 100644 --- a/src/tracker/api/messages.py +++ b/src/tracker/api/messages.py @@ -150,7 +150,7 @@ async def create_message(req: MessageCreate, db: AsyncSession = Depends(get_db)) elif chat and chat.kind == "lobby": await manager.broadcast_all( {"type": "message.new", "data": msg_data}, - exclude=msg.author_slug, + exclude_slug=msg.author_slug, ) return _message_out(msg) @@ -159,7 +159,7 @@ async def create_message(req: MessageCreate, db: AsyncSession = Depends(get_db)) else: await manager.broadcast_all( {"type": "message.new", "data": msg_data}, - exclude=msg.author_slug, + exclude_slug=msg.author_slug, ) return _message_out(msg) diff --git a/src/tracker/api/tasks.py b/src/tracker/api/tasks.py index 9aac36c..0afe9b3 100644 --- a/src/tracker/api/tasks.py +++ b/src/tracker/api/tasks.py @@ -246,8 +246,12 @@ async def update_task(task_id: str, req: TaskUpdate, db: AsyncSession = Depends( @router.delete("/tasks/{task_id}") async def delete_task(task_id: str, db: AsyncSession = Depends(get_db)): task = await _get_task(task_id, db) + project_id = str(task.project_id) + task_data = {"id": str(task.id), "project_id": project_id} await db.delete(task) await db.commit() + from tracker.ws.manager import manager + await manager.broadcast_task_event(project_id, "task.deleted", task_data) return {"ok": True} @@ -286,6 +290,10 @@ async def reject_task(task_id: str, req: RejectRequest, db: AsyncSession = Depen task.assignee_slug = None task.status = "todo" await db.commit() + from tracker.ws.manager import manager + await manager.broadcast_task_event(str(task.project_id), "task.updated", { + "id": str(task.id), "status": "todo", "assignee_slug": None, + }) return {"ok": True, "reason": req.reason, "old_assignee": old_assignee} diff --git a/src/tracker/app.py b/src/tracker/app.py index 19aceac..9a956d6 100644 --- a/src/tracker/app.py +++ b/src/tracker/app.py @@ -12,8 +12,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from tracker.config import settings -from tracker.database import engine -from tracker.models import Base +from tracker.database import engine, async_session +from tracker.models import Base, Member + +from sqlalchemy import select, update logging.basicConfig( level=logging.DEBUG if settings.env == "dev" else logging.INFO, @@ -25,64 +27,45 @@ logger = logging.getLogger("tracker") async def heartbeat_monitor(): """Monitor heartbeat timeout — set status=offline after 90 seconds.""" from tracker.ws.manager import manager - from tracker.database import async_session - from tracker.models import Member from datetime import datetime, timezone, timedelta - from sqlalchemy import select, update - + while True: try: - await asyncio.sleep(30) # check every 30 seconds - - # Get clients with last heartbeat timeout + await asyncio.sleep(30) timeout_threshold = datetime.now(timezone.utc) - timedelta(seconds=90) - timed_out_clients = [] - - for slug, client in list(manager.clients.items()): - if not hasattr(client, 'last_heartbeat'): - client.last_heartbeat = datetime.now(timezone.utc) - continue - + timed_out = [] + + for session_id, client in list(manager.sessions.items()): if client.last_heartbeat < timeout_threshold: - timed_out_clients.append(slug) - - if timed_out_clients: - async with async_session() as db: - # Update status to offline - await db.execute( - update(Member) - .where(Member.slug.in_(timed_out_clients)) - .values(status="offline") - ) - await db.commit() - - # Broadcast status changes and disconnect - for slug in timed_out_clients: + timed_out.append(session_id) + + for session_id in timed_out: + client = await manager.disconnect(session_id) + if client and not manager.is_online(client.member_slug): + async with async_session() as db: + await db.execute( + update(Member).where(Member.slug == client.member_slug).values(status="offline") + ) + await db.commit() await manager.broadcast_all( - {"type": "agent.status", "data": {"slug": slug, "status": "offline"}}, - exclude=slug, + {"type": "agent.status", "data": {"slug": client.member_slug, "status": "offline"}}, + exclude_slug=client.member_slug, ) - await manager.disconnect(slug) - logger.info("Heartbeat timeout: %s set offline", slug) - + logger.info("Heartbeat timeout: %s set offline", client.member_slug) + except Exception as e: logger.error("Heartbeat monitor error: %s", e) @asynccontextmanager async def lifespan(app: FastAPI): - """Create tables on startup (dev mode only).""" if settings.env == "dev": async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) logger.info("Database tables ensured.") - - # Start heartbeat monitor + heartbeat_task = asyncio.create_task(heartbeat_monitor()) - yield - - # Cleanup heartbeat_task.cancel() try: await heartbeat_task @@ -97,6 +80,45 @@ app = FastAPI( lifespan=lifespan, ) +# Paths that don't require auth +NO_AUTH_PATHS = {"/health", "/docs", "/openapi.json", "/ws"} + + +@app.middleware("http") +async def auth_middleware(request: Request, call_next): + """Verify token for REST API requests.""" + path = request.url.path + + # Skip auth for non-API paths and excluded paths + if not path.startswith("/api/") and path not in NO_AUTH_PATHS: + # WS and health don't need REST auth + pass + elif path.startswith("/api/"): + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + # Check agent token + async with async_session() as db: + result = await db.execute(select(Member).where(Member.token == token)) + member = result.scalar_one_or_none() + if member: + request.state.member = member + else: + # Try JWT + from tracker.api.auth import decode_jwt + try: + payload = decode_jwt(token) + result = await db.execute(select(Member).where(Member.id == payload["sub"])) + member = result.scalar_one_or_none() + if member: + request.state.member = member + except Exception: + pass + # Don't enforce auth yet — BFF uses its own token via proxy + # TODO: enforce when all clients are updated + + return await call_next(request) + @app.middleware("http") async def log_requests(request: Request, call_next): @@ -124,12 +146,12 @@ async def log_requests(request: Request, call_next): # CORS app.add_middleware( CORSMiddleware, - allow_origins=["https://team.uix.su", "http://localhost:3100"], + allow_origins=["https://team.uix.su", "https://dev.team.uix.su", "http://localhost:3100"], allow_methods=["*"], allow_headers=["*"], ) -# Import and register routers (lazy to avoid circular imports) +# Routers from tracker.api import auth, members, projects, tasks, messages, steps # noqa: E402 from tracker.ws.handler import router as ws_router # noqa: E402 diff --git a/src/tracker/ws/handler.py b/src/tracker/ws/handler.py index 07aa893..d7bce2e 100644 --- a/src/tracker/ws/handler.py +++ b/src/tracker/ws/handler.py @@ -18,7 +18,7 @@ router = APIRouter() @router.websocket("/ws") async def websocket_endpoint(ws: WebSocket): await ws.accept() - slug = None + session_id = None try: # Wait for auth message @@ -30,60 +30,61 @@ async def websocket_endpoint(ws: WebSocket): token = auth_msg.get("token", "") on_behalf_of = auth_msg.get("on_behalf_of") - slug = await _authenticate(ws, token, on_behalf_of=on_behalf_of) - if not slug: + session_id = await _authenticate(ws, token, on_behalf_of=on_behalf_of) + if not session_id: return + client = manager.sessions.get(session_id) + slug = client.member_slug if client else None + # Main loop while True: data = await ws.receive_json() msg_type = data.get("type") if msg_type == "heartbeat": - await _handle_heartbeat(slug, data) + await _handle_heartbeat(session_id, data) elif msg_type == "ack": - pass # acknowledged, no action needed + pass elif msg_type == "chat.send": - await _handle_chat_send(slug, data) + await _handle_chat_send(session_id, data) elif msg_type == "project.subscribe": - await _handle_subscribe(slug, data) + await _handle_subscribe(session_id, data) elif msg_type == "project.unsubscribe": - await _handle_unsubscribe(slug, data) + await _handle_unsubscribe(session_id, data) else: await ws.send_json({"type": "error", "message": f"Unknown type: {msg_type}"}) except WebSocketDisconnect: - logger.info("WS disconnect: %s", slug) + if session_id: + client = manager.sessions.get(session_id) + logger.info("WS disconnect: %s", client.member_slug if client else session_id[:8]) except Exception as e: - logger.error("WS error for %s: %s", slug, e) + logger.error("WS error session=%s: %s", session_id and session_id[:8], e) finally: - if slug: - await manager.disconnect(slug) - # Update status to offline - async with async_session() as db: - result = await db.execute(select(Member).where(Member.slug == slug)) - member = result.scalar_one_or_none() - if member: - member.status = "offline" - await db.commit() - # Notify others - await manager.broadcast_all( - {"type": "agent.status", "data": {"slug": slug, "status": "offline"}}, - exclude=slug, - ) + if session_id: + client = await manager.disconnect(session_id) + if client and not manager.is_online(client.member_slug): + # Last session for this slug — mark offline + async with async_session() as db: + result = await db.execute(select(Member).where(Member.slug == client.member_slug)) + member = result.scalar_one_or_none() + if member: + member.status = "offline" + await db.commit() + await manager.broadcast_all( + {"type": "agent.status", "data": {"slug": client.member_slug, "status": "offline"}}, + exclude_slug=client.member_slug, + ) async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = None) -> str | None: - """Authenticate by token, return slug or None. - - If on_behalf_of is set and the token belongs to a bridge member, - use on_behalf_of slug instead (BFF proxying for web users). - """ + """Authenticate and register session. Returns session_id or None.""" async with async_session() as db: result = await db.execute( select(Member).where(Member.token == token).options(selectinload(Member.agent_config)) @@ -91,7 +92,6 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No member = result.scalar_one_or_none() if not member: - # Try JWT auth (for BFF/web client) from tracker.api.auth import decode_jwt try: payload = decode_jwt(token) @@ -108,11 +108,10 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No await ws.close() return None - # BFF proxy: bridge member can act on behalf of a user + # BFF proxy: bridge acts on behalf of user effective_slug = member.slug effective_type = member.type if on_behalf_of and member.type == "bridge": - # Look up the actual user user_result = await db.execute( select(Member).where(Member.slug == on_behalf_of) .options(selectinload(Member.agent_config)) @@ -121,23 +120,24 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No if user_member: effective_slug = user_member.slug effective_type = user_member.type - member = user_member # use user's settings - logger.info("Bridge %s acting on behalf of %s", member.slug, effective_slug) + member = user_member + logger.info("Bridge acting on behalf of %s", effective_slug) else: - # User not found, use a synthetic slug to avoid collisions - effective_slug = on_behalf_of - logger.info("Bridge acting on behalf of unknown user %s", effective_slug) + effective_slug = f"web-{on_behalf_of}" + logger.info("Bridge acting on behalf of unknown user → %s", effective_slug) - # Get listen modes + # Listen modes chat_listen = "all" task_listen = "all" if member.agent_config: chat_listen = member.agent_config.chat_listen task_listen = member.agent_config.task_listen - # Register connection + # Register connection with unique session_id + session_id = str(uuid.uuid4()) client = ConnectedClient( ws=ws, + session_id=session_id, member_slug=effective_slug, member_type=effective_type, chat_listen=chat_listen, @@ -149,14 +149,13 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No member.status = "online" await db.commit() - # Get lobby chat + projects with chat_id + # Get lobby chat + projects lobby = await db.execute(select(Chat).where(Chat.kind == "lobby")) lobby_chat = lobby.scalar_one_or_none() projects = await db.execute(select(Project).where(Project.status == "active")) project_list = [] for p in projects.scalars(): - # Get project chat chat_result = await db.execute( select(Chat).where(Chat.project_id == p.id, Chat.kind == "project") ) @@ -168,54 +167,56 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No "chat_id": str(chat.id) if chat else None, }) - online = list(manager.clients.keys()) - await ws.send_json({ "type": "auth.ok", "data": { "slug": effective_slug, "lobby_chat_id": str(lobby_chat.id) if lobby_chat else None, "projects": project_list, - "online": online, + "online": manager.online_slugs, }, }) # Notify others await manager.broadcast_all( - {"type": "agent.status", "data": {"slug": member.slug, "status": "online"}}, - exclude=member.slug, + {"type": "agent.status", "data": {"slug": effective_slug, "status": "online"}}, + exclude_slug=effective_slug, ) - return member.slug + return session_id -async def _handle_heartbeat(slug: str, data: dict): - """Update member status from heartbeat.""" +async def _handle_heartbeat(session_id: str, data: dict): + """Update heartbeat timestamp.""" from datetime import datetime, timezone - + + client = manager.sessions.get(session_id) + if not client: + return + status = data.get("status", "online") - - # Update last heartbeat timestamp - client = manager.clients.get(slug) - if client: - client.last_heartbeat = datetime.now(timezone.utc) - + client.last_heartbeat = datetime.now(timezone.utc) + async with async_session() as db: - result = await db.execute(select(Member).where(Member.slug == slug)) + result = await db.execute(select(Member).where(Member.slug == client.member_slug)) member = result.scalar_one_or_none() if member: member.status = status await db.commit() - - # Broadcast status change if different - await manager.broadcast_all( - {"type": "agent.status", "data": {"slug": slug, "status": status}}, - exclude=slug, - ) + + await manager.broadcast_all( + {"type": "agent.status", "data": {"slug": client.member_slug, "status": status}}, + exclude_slug=client.member_slug, + ) -async def _handle_chat_send(slug: str, data: dict): +async def _handle_chat_send(session_id: str, data: dict): """Handle chat message sent via WS.""" + client = manager.sessions.get(session_id) + if not client: + return + + slug = client.member_slug chat_id = data.get("chat_id") task_id = data.get("task_id") content = data.get("content", "") @@ -225,7 +226,6 @@ async def _handle_chat_send(slug: str, data: dict): return async with async_session() as db: - # Get sender info result = await db.execute(select(Member).where(Member.slug == slug)) member = result.scalar_one_or_none() if not member: @@ -263,40 +263,38 @@ async def _handle_chat_send(slug: str, data: dict): if chat and chat.project_id: project_id = str(chat.project_id) elif chat and chat.kind == "lobby": - # Lobby — broadcast to all await manager.broadcast_all( {"type": "message.new", "data": msg_data}, - exclude=slug, + exclude_slug=slug, ) return if project_id: await manager.broadcast_message(project_id, msg_data, author_slug=slug) else: - # Task comment or unlinked — broadcast to all await manager.broadcast_all( {"type": "message.new", "data": msg_data}, - exclude=slug, + exclude_slug=slug, ) -async def _handle_subscribe(slug: str, data: dict): - """Subscribe to a project's events.""" +async def _handle_subscribe(session_id: str, data: dict): + """Subscribe this session to project events.""" project_id = data.get("project_id") if not project_id: return - client = manager.clients.get(slug) + client = manager.sessions.get(session_id) if client: client.subscribed_projects.add(project_id) - logger.info("%s subscribed to project %s", slug, project_id) + logger.info("%s subscribed to project %s", client.member_slug, project_id) -async def _handle_unsubscribe(slug: str, data: dict): - """Unsubscribe from a project.""" +async def _handle_unsubscribe(session_id: str, data: dict): + """Unsubscribe this session from project events.""" project_id = data.get("project_id") if not project_id: return - client = manager.clients.get(slug) + client = manager.sessions.get(session_id) if client: client.subscribed_projects.discard(project_id) - logger.info("%s unsubscribed from project %s", slug, project_id) + logger.info("%s unsubscribed from project %s", client.member_slug, project_id) diff --git a/src/tracker/ws/manager.py b/src/tracker/ws/manager.py index 1c05982..43cf900 100644 --- a/src/tracker/ws/manager.py +++ b/src/tracker/ws/manager.py @@ -1,6 +1,7 @@ -"""WebSocket connection manager with project subscriptions and filtering.""" +"""WebSocket connection manager with multi-session support.""" import logging +import uuid from dataclasses import dataclass, field from datetime import datetime, timezone @@ -12,83 +13,119 @@ logger = logging.getLogger("tracker.ws") @dataclass class ConnectedClient: ws: WebSocket + session_id: str # unique per connection member_slug: str member_type: str # human | agent | bridge chat_listen: str = "all" # all | mentions | none task_listen: str = "all" # all | mentions | none - subscribed_projects: set[str] = field(default_factory=set) # project_ids + subscribed_projects: set[str] = field(default_factory=set) last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) class ConnectionManager: def __init__(self): - self.clients: dict[str, ConnectedClient] = {} # slug → client + # session_id → client (one entry per WS connection) + self.sessions: dict[str, ConnectedClient] = {} + # slug → set of session_ids (for quick lookup) + self.slug_sessions: dict[str, set[str]] = {} async def connect(self, client: ConnectedClient): - self.clients[client.member_slug] = client - logger.info("WS connected: %s (%s)", client.member_slug, client.member_type) + self.sessions[client.session_id] = client + if client.member_slug not in self.slug_sessions: + self.slug_sessions[client.member_slug] = set() + self.slug_sessions[client.member_slug].add(client.session_id) + logger.info("WS connected: %s session=%s (%s)", client.member_slug, client.session_id[:8], client.member_type) - async def disconnect(self, slug: str): - if slug in self.clients: - del self.clients[slug] - logger.info("WS disconnected: %s", slug) + async def disconnect(self, session_id: str): + client = self.sessions.pop(session_id, None) + if client: + slug_set = self.slug_sessions.get(client.member_slug) + if slug_set: + slug_set.discard(session_id) + if not slug_set: + del self.slug_sessions[client.member_slug] + logger.info("WS disconnected: %s session=%s", client.member_slug, session_id[:8]) + return client - async def send_to(self, slug: str, data: dict): - """Send to specific client.""" - client = self.clients.get(slug) + def get_sessions_for_slug(self, slug: str) -> list[ConnectedClient]: + """Get all active sessions for a member slug.""" + session_ids = self.slug_sessions.get(slug, set()) + return [self.sessions[sid] for sid in session_ids if sid in self.sessions] + + def is_online(self, slug: str) -> bool: + return bool(self.slug_sessions.get(slug)) + + @property + def online_slugs(self) -> list[str]: + return list(self.slug_sessions.keys()) + + async def send_to_session(self, session_id: str, data: dict): + """Send to a specific session.""" + client = self.sessions.get(session_id) if client: try: await client.ws.send_json(data) except Exception: - await self.disconnect(slug) + await self.disconnect(session_id) + + async def send_to_slug(self, slug: str, data: dict): + """Send to ALL sessions of a member.""" + for client in self.get_sessions_for_slug(slug): + try: + await client.ws.send_json(data) + except Exception: + await self.disconnect(client.session_id) async def broadcast_message(self, project_id: str, message: dict, author_slug: str): - """Broadcast message.new. Humans get everything, agents filtered by subscription.""" + """Broadcast message.new. Humans get everything, agents filtered.""" mentions = message.get("mentions", []) - for slug, client in list(self.clients.items()): - if slug == author_slug: + payload = {"type": "message.new", "data": message} + + for session_id, client in list(self.sessions.items()): + if client.member_slug == author_slug: continue - # Humans/bridges get ALL messages — filtering on client side + # Humans/bridges get ALL messages if client.member_type in ("human", "bridge"): - await self.send_to(slug, {"type": "message.new", "data": message}) + await self.send_to_session(session_id, payload) continue - # Agents: check subscription + chat_listen + # Agents: subscription + chat_listen if project_id not in client.subscribed_projects: continue if client.chat_listen == "none": continue - if client.chat_listen == "mentions" and slug not in mentions: + if client.chat_listen == "mentions" and client.member_slug not in mentions: continue - await self.send_to(slug, {"type": "message.new", "data": message}) + await self.send_to_session(session_id, payload) async def broadcast_task_event(self, project_id: str, event_type: str, data: dict): """Broadcast task events. Humans get everything, agents filtered.""" assignee = data.get("assignee_slug") reviewer = data.get("reviewer_slug") watchers = data.get("watchers", []) + payload = {"type": event_type, "data": data} - for slug, client in list(self.clients.items()): + for session_id, client in list(self.sessions.items()): # Humans/bridges get ALL task events if client.member_type in ("human", "bridge"): - await self.send_to(slug, {"type": event_type, "data": data}) + await self.send_to_session(session_id, payload) continue - # Agents: subscription + task_listen filter + # Agents: subscription + task_listen if project_id not in client.subscribed_projects: continue if client.task_listen == "none": continue if client.task_listen == "all": - await self.send_to(slug, {"type": event_type, "data": data}) + await self.send_to_session(session_id, payload) continue - if slug in (assignee, reviewer) or slug in watchers: - await self.send_to(slug, {"type": event_type, "data": data}) + if client.member_slug in (assignee, reviewer) or client.member_slug in watchers: + await self.send_to_session(session_id, payload) - async def broadcast_all(self, data: dict, exclude: str | None = None): - """Broadcast to all connected clients.""" - for slug, client in list(self.clients.items()): - if slug == exclude: + async def broadcast_all(self, data: dict, exclude_slug: str | None = None): + """Broadcast to all connected sessions.""" + for session_id, client in list(self.sessions.items()): + if client.member_slug == exclude_slug: continue - await self.send_to(slug, data) + await self.send_to_session(session_id, data) manager = ConnectionManager()