Epic 1: WS multi-session, task events broadcast, REST auth middleware, heartbeat fix
Some checks failed
Deploy Tracker / deploy (push) Failing after 1s
Some checks failed
Deploy Tracker / deploy (push) Failing after 1s
This commit is contained in:
parent
a1be497cfe
commit
6bbddad7c2
@ -150,7 +150,7 @@ async def create_message(req: MessageCreate, db: AsyncSession = Depends(get_db))
|
|||||||
elif chat and chat.kind == "lobby":
|
elif chat and chat.kind == "lobby":
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "message.new", "data": msg_data},
|
{"type": "message.new", "data": msg_data},
|
||||||
exclude=msg.author_slug,
|
exclude_slug=msg.author_slug,
|
||||||
)
|
)
|
||||||
return _message_out(msg)
|
return _message_out(msg)
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ async def create_message(req: MessageCreate, db: AsyncSession = Depends(get_db))
|
|||||||
else:
|
else:
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "message.new", "data": msg_data},
|
{"type": "message.new", "data": msg_data},
|
||||||
exclude=msg.author_slug,
|
exclude_slug=msg.author_slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _message_out(msg)
|
return _message_out(msg)
|
||||||
|
|||||||
@ -246,8 +246,12 @@ async def update_task(task_id: str, req: TaskUpdate, db: AsyncSession = Depends(
|
|||||||
@router.delete("/tasks/{task_id}")
|
@router.delete("/tasks/{task_id}")
|
||||||
async def delete_task(task_id: str, db: AsyncSession = Depends(get_db)):
|
async def delete_task(task_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
task = await _get_task(task_id, db)
|
task = await _get_task(task_id, db)
|
||||||
|
project_id = str(task.project_id)
|
||||||
|
task_data = {"id": str(task.id), "project_id": project_id}
|
||||||
await db.delete(task)
|
await db.delete(task)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
from tracker.ws.manager import manager
|
||||||
|
await manager.broadcast_task_event(project_id, "task.deleted", task_data)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@ -286,6 +290,10 @@ async def reject_task(task_id: str, req: RejectRequest, db: AsyncSession = Depen
|
|||||||
task.assignee_slug = None
|
task.assignee_slug = None
|
||||||
task.status = "todo"
|
task.status = "todo"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
from tracker.ws.manager import manager
|
||||||
|
await manager.broadcast_task_event(str(task.project_id), "task.updated", {
|
||||||
|
"id": str(task.id), "status": "todo", "assignee_slug": None,
|
||||||
|
})
|
||||||
return {"ok": True, "reason": req.reason, "old_assignee": old_assignee}
|
return {"ok": True, "reason": req.reason, "old_assignee": old_assignee}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -12,8 +12,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from tracker.config import settings
|
from tracker.config import settings
|
||||||
from tracker.database import engine
|
from tracker.database import engine, async_session
|
||||||
from tracker.models import Base
|
from tracker.models import Base, Member
|
||||||
|
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG if settings.env == "dev" else logging.INFO,
|
level=logging.DEBUG if settings.env == "dev" else logging.INFO,
|
||||||
@ -25,45 +27,31 @@ logger = logging.getLogger("tracker")
|
|||||||
async def heartbeat_monitor():
|
async def heartbeat_monitor():
|
||||||
"""Monitor heartbeat timeout — set status=offline after 90 seconds."""
|
"""Monitor heartbeat timeout — set status=offline after 90 seconds."""
|
||||||
from tracker.ws.manager import manager
|
from tracker.ws.manager import manager
|
||||||
from tracker.database import async_session
|
|
||||||
from tracker.models import Member
|
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from sqlalchemy import select, update
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(30) # check every 30 seconds
|
await asyncio.sleep(30)
|
||||||
|
|
||||||
# Get clients with last heartbeat timeout
|
|
||||||
timeout_threshold = datetime.now(timezone.utc) - timedelta(seconds=90)
|
timeout_threshold = datetime.now(timezone.utc) - timedelta(seconds=90)
|
||||||
timed_out_clients = []
|
timed_out = []
|
||||||
|
|
||||||
for slug, client in list(manager.clients.items()):
|
|
||||||
if not hasattr(client, 'last_heartbeat'):
|
|
||||||
client.last_heartbeat = datetime.now(timezone.utc)
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
for session_id, client in list(manager.sessions.items()):
|
||||||
if client.last_heartbeat < timeout_threshold:
|
if client.last_heartbeat < timeout_threshold:
|
||||||
timed_out_clients.append(slug)
|
timed_out.append(session_id)
|
||||||
|
|
||||||
if timed_out_clients:
|
for session_id in timed_out:
|
||||||
|
client = await manager.disconnect(session_id)
|
||||||
|
if client and not manager.is_online(client.member_slug):
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
# Update status to offline
|
|
||||||
await db.execute(
|
await db.execute(
|
||||||
update(Member)
|
update(Member).where(Member.slug == client.member_slug).values(status="offline")
|
||||||
.where(Member.slug.in_(timed_out_clients))
|
|
||||||
.values(status="offline")
|
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Broadcast status changes and disconnect
|
|
||||||
for slug in timed_out_clients:
|
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "agent.status", "data": {"slug": slug, "status": "offline"}},
|
{"type": "agent.status", "data": {"slug": client.member_slug, "status": "offline"}},
|
||||||
exclude=slug,
|
exclude_slug=client.member_slug,
|
||||||
)
|
)
|
||||||
await manager.disconnect(slug)
|
logger.info("Heartbeat timeout: %s set offline", client.member_slug)
|
||||||
logger.info("Heartbeat timeout: %s set offline", slug)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Heartbeat monitor error: %s", e)
|
logger.error("Heartbeat monitor error: %s", e)
|
||||||
@ -71,18 +59,13 @@ async def heartbeat_monitor():
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Create tables on startup (dev mode only)."""
|
|
||||||
if settings.env == "dev":
|
if settings.env == "dev":
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
logger.info("Database tables ensured.")
|
logger.info("Database tables ensured.")
|
||||||
|
|
||||||
# Start heartbeat monitor
|
|
||||||
heartbeat_task = asyncio.create_task(heartbeat_monitor())
|
heartbeat_task = asyncio.create_task(heartbeat_monitor())
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
heartbeat_task.cancel()
|
heartbeat_task.cancel()
|
||||||
try:
|
try:
|
||||||
await heartbeat_task
|
await heartbeat_task
|
||||||
@ -97,6 +80,45 @@ app = FastAPI(
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Paths that don't require auth
|
||||||
|
NO_AUTH_PATHS = {"/health", "/docs", "/openapi.json", "/ws"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def auth_middleware(request: Request, call_next):
|
||||||
|
"""Verify token for REST API requests."""
|
||||||
|
path = request.url.path
|
||||||
|
|
||||||
|
# Skip auth for non-API paths and excluded paths
|
||||||
|
if not path.startswith("/api/") and path not in NO_AUTH_PATHS:
|
||||||
|
# WS and health don't need REST auth
|
||||||
|
pass
|
||||||
|
elif path.startswith("/api/"):
|
||||||
|
auth_header = request.headers.get("authorization", "")
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
token = auth_header[7:]
|
||||||
|
# Check agent token
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(select(Member).where(Member.token == token))
|
||||||
|
member = result.scalar_one_or_none()
|
||||||
|
if member:
|
||||||
|
request.state.member = member
|
||||||
|
else:
|
||||||
|
# Try JWT
|
||||||
|
from tracker.api.auth import decode_jwt
|
||||||
|
try:
|
||||||
|
payload = decode_jwt(token)
|
||||||
|
result = await db.execute(select(Member).where(Member.id == payload["sub"]))
|
||||||
|
member = result.scalar_one_or_none()
|
||||||
|
if member:
|
||||||
|
request.state.member = member
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Don't enforce auth yet — BFF uses its own token via proxy
|
||||||
|
# TODO: enforce when all clients are updated
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def log_requests(request: Request, call_next):
|
async def log_requests(request: Request, call_next):
|
||||||
@ -124,12 +146,12 @@ async def log_requests(request: Request, call_next):
|
|||||||
# CORS
|
# CORS
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["https://team.uix.su", "http://localhost:3100"],
|
allow_origins=["https://team.uix.su", "https://dev.team.uix.su", "http://localhost:3100"],
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import and register routers (lazy to avoid circular imports)
|
# Routers
|
||||||
from tracker.api import auth, members, projects, tasks, messages, steps # noqa: E402
|
from tracker.api import auth, members, projects, tasks, messages, steps # noqa: E402
|
||||||
from tracker.ws.handler import router as ws_router # noqa: E402
|
from tracker.ws.handler import router as ws_router # noqa: E402
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ router = APIRouter()
|
|||||||
@router.websocket("/ws")
|
@router.websocket("/ws")
|
||||||
async def websocket_endpoint(ws: WebSocket):
|
async def websocket_endpoint(ws: WebSocket):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
slug = None
|
session_id = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wait for auth message
|
# Wait for auth message
|
||||||
@ -30,60 +30,61 @@ async def websocket_endpoint(ws: WebSocket):
|
|||||||
|
|
||||||
token = auth_msg.get("token", "")
|
token = auth_msg.get("token", "")
|
||||||
on_behalf_of = auth_msg.get("on_behalf_of")
|
on_behalf_of = auth_msg.get("on_behalf_of")
|
||||||
slug = await _authenticate(ws, token, on_behalf_of=on_behalf_of)
|
session_id = await _authenticate(ws, token, on_behalf_of=on_behalf_of)
|
||||||
if not slug:
|
if not session_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
client = manager.sessions.get(session_id)
|
||||||
|
slug = client.member_slug if client else None
|
||||||
|
|
||||||
# Main loop
|
# Main loop
|
||||||
while True:
|
while True:
|
||||||
data = await ws.receive_json()
|
data = await ws.receive_json()
|
||||||
msg_type = data.get("type")
|
msg_type = data.get("type")
|
||||||
|
|
||||||
if msg_type == "heartbeat":
|
if msg_type == "heartbeat":
|
||||||
await _handle_heartbeat(slug, data)
|
await _handle_heartbeat(session_id, data)
|
||||||
|
|
||||||
elif msg_type == "ack":
|
elif msg_type == "ack":
|
||||||
pass # acknowledged, no action needed
|
pass
|
||||||
|
|
||||||
elif msg_type == "chat.send":
|
elif msg_type == "chat.send":
|
||||||
await _handle_chat_send(slug, data)
|
await _handle_chat_send(session_id, data)
|
||||||
|
|
||||||
elif msg_type == "project.subscribe":
|
elif msg_type == "project.subscribe":
|
||||||
await _handle_subscribe(slug, data)
|
await _handle_subscribe(session_id, data)
|
||||||
|
|
||||||
elif msg_type == "project.unsubscribe":
|
elif msg_type == "project.unsubscribe":
|
||||||
await _handle_unsubscribe(slug, data)
|
await _handle_unsubscribe(session_id, data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await ws.send_json({"type": "error", "message": f"Unknown type: {msg_type}"})
|
await ws.send_json({"type": "error", "message": f"Unknown type: {msg_type}"})
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.info("WS disconnect: %s", slug)
|
if session_id:
|
||||||
|
client = manager.sessions.get(session_id)
|
||||||
|
logger.info("WS disconnect: %s", client.member_slug if client else session_id[:8])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("WS error for %s: %s", slug, e)
|
logger.error("WS error session=%s: %s", session_id and session_id[:8], e)
|
||||||
finally:
|
finally:
|
||||||
if slug:
|
if session_id:
|
||||||
await manager.disconnect(slug)
|
client = await manager.disconnect(session_id)
|
||||||
# Update status to offline
|
if client and not manager.is_online(client.member_slug):
|
||||||
|
# Last session for this slug — mark offline
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
result = await db.execute(select(Member).where(Member.slug == slug))
|
result = await db.execute(select(Member).where(Member.slug == client.member_slug))
|
||||||
member = result.scalar_one_or_none()
|
member = result.scalar_one_or_none()
|
||||||
if member:
|
if member:
|
||||||
member.status = "offline"
|
member.status = "offline"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
# Notify others
|
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "agent.status", "data": {"slug": slug, "status": "offline"}},
|
{"type": "agent.status", "data": {"slug": client.member_slug, "status": "offline"}},
|
||||||
exclude=slug,
|
exclude_slug=client.member_slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = None) -> str | None:
|
async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = None) -> str | None:
|
||||||
"""Authenticate by token, return slug or None.
|
"""Authenticate and register session. Returns session_id or None."""
|
||||||
|
|
||||||
If on_behalf_of is set and the token belongs to a bridge member,
|
|
||||||
use on_behalf_of slug instead (BFF proxying for web users).
|
|
||||||
"""
|
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Member).where(Member.token == token).options(selectinload(Member.agent_config))
|
select(Member).where(Member.token == token).options(selectinload(Member.agent_config))
|
||||||
@ -91,7 +92,6 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No
|
|||||||
member = result.scalar_one_or_none()
|
member = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not member:
|
if not member:
|
||||||
# Try JWT auth (for BFF/web client)
|
|
||||||
from tracker.api.auth import decode_jwt
|
from tracker.api.auth import decode_jwt
|
||||||
try:
|
try:
|
||||||
payload = decode_jwt(token)
|
payload = decode_jwt(token)
|
||||||
@ -108,11 +108,10 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No
|
|||||||
await ws.close()
|
await ws.close()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# BFF proxy: bridge member can act on behalf of a user
|
# BFF proxy: bridge acts on behalf of user
|
||||||
effective_slug = member.slug
|
effective_slug = member.slug
|
||||||
effective_type = member.type
|
effective_type = member.type
|
||||||
if on_behalf_of and member.type == "bridge":
|
if on_behalf_of and member.type == "bridge":
|
||||||
# Look up the actual user
|
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(Member).where(Member.slug == on_behalf_of)
|
select(Member).where(Member.slug == on_behalf_of)
|
||||||
.options(selectinload(Member.agent_config))
|
.options(selectinload(Member.agent_config))
|
||||||
@ -121,23 +120,24 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No
|
|||||||
if user_member:
|
if user_member:
|
||||||
effective_slug = user_member.slug
|
effective_slug = user_member.slug
|
||||||
effective_type = user_member.type
|
effective_type = user_member.type
|
||||||
member = user_member # use user's settings
|
member = user_member
|
||||||
logger.info("Bridge %s acting on behalf of %s", member.slug, effective_slug)
|
logger.info("Bridge acting on behalf of %s", effective_slug)
|
||||||
else:
|
else:
|
||||||
# User not found, use a synthetic slug to avoid collisions
|
effective_slug = f"web-{on_behalf_of}"
|
||||||
effective_slug = on_behalf_of
|
logger.info("Bridge acting on behalf of unknown user → %s", effective_slug)
|
||||||
logger.info("Bridge acting on behalf of unknown user %s", effective_slug)
|
|
||||||
|
|
||||||
# Get listen modes
|
# Listen modes
|
||||||
chat_listen = "all"
|
chat_listen = "all"
|
||||||
task_listen = "all"
|
task_listen = "all"
|
||||||
if member.agent_config:
|
if member.agent_config:
|
||||||
chat_listen = member.agent_config.chat_listen
|
chat_listen = member.agent_config.chat_listen
|
||||||
task_listen = member.agent_config.task_listen
|
task_listen = member.agent_config.task_listen
|
||||||
|
|
||||||
# Register connection
|
# Register connection with unique session_id
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
client = ConnectedClient(
|
client = ConnectedClient(
|
||||||
ws=ws,
|
ws=ws,
|
||||||
|
session_id=session_id,
|
||||||
member_slug=effective_slug,
|
member_slug=effective_slug,
|
||||||
member_type=effective_type,
|
member_type=effective_type,
|
||||||
chat_listen=chat_listen,
|
chat_listen=chat_listen,
|
||||||
@ -149,14 +149,13 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No
|
|||||||
member.status = "online"
|
member.status = "online"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Get lobby chat + projects with chat_id
|
# Get lobby chat + projects
|
||||||
lobby = await db.execute(select(Chat).where(Chat.kind == "lobby"))
|
lobby = await db.execute(select(Chat).where(Chat.kind == "lobby"))
|
||||||
lobby_chat = lobby.scalar_one_or_none()
|
lobby_chat = lobby.scalar_one_or_none()
|
||||||
|
|
||||||
projects = await db.execute(select(Project).where(Project.status == "active"))
|
projects = await db.execute(select(Project).where(Project.status == "active"))
|
||||||
project_list = []
|
project_list = []
|
||||||
for p in projects.scalars():
|
for p in projects.scalars():
|
||||||
# Get project chat
|
|
||||||
chat_result = await db.execute(
|
chat_result = await db.execute(
|
||||||
select(Chat).where(Chat.project_id == p.id, Chat.kind == "project")
|
select(Chat).where(Chat.project_id == p.id, Chat.kind == "project")
|
||||||
)
|
)
|
||||||
@ -168,54 +167,56 @@ async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = No
|
|||||||
"chat_id": str(chat.id) if chat else None,
|
"chat_id": str(chat.id) if chat else None,
|
||||||
})
|
})
|
||||||
|
|
||||||
online = list(manager.clients.keys())
|
|
||||||
|
|
||||||
await ws.send_json({
|
await ws.send_json({
|
||||||
"type": "auth.ok",
|
"type": "auth.ok",
|
||||||
"data": {
|
"data": {
|
||||||
"slug": effective_slug,
|
"slug": effective_slug,
|
||||||
"lobby_chat_id": str(lobby_chat.id) if lobby_chat else None,
|
"lobby_chat_id": str(lobby_chat.id) if lobby_chat else None,
|
||||||
"projects": project_list,
|
"projects": project_list,
|
||||||
"online": online,
|
"online": manager.online_slugs,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
# Notify others
|
# Notify others
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "agent.status", "data": {"slug": member.slug, "status": "online"}},
|
{"type": "agent.status", "data": {"slug": effective_slug, "status": "online"}},
|
||||||
exclude=member.slug,
|
exclude_slug=effective_slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
return member.slug
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
async def _handle_heartbeat(slug: str, data: dict):
|
async def _handle_heartbeat(session_id: str, data: dict):
|
||||||
"""Update member status from heartbeat."""
|
"""Update heartbeat timestamp."""
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
status = data.get("status", "online")
|
client = manager.sessions.get(session_id)
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
|
||||||
# Update last heartbeat timestamp
|
status = data.get("status", "online")
|
||||||
client = manager.clients.get(slug)
|
|
||||||
if client:
|
|
||||||
client.last_heartbeat = datetime.now(timezone.utc)
|
client.last_heartbeat = datetime.now(timezone.utc)
|
||||||
|
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
result = await db.execute(select(Member).where(Member.slug == slug))
|
result = await db.execute(select(Member).where(Member.slug == client.member_slug))
|
||||||
member = result.scalar_one_or_none()
|
member = result.scalar_one_or_none()
|
||||||
if member:
|
if member:
|
||||||
member.status = status
|
member.status = status
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Broadcast status change if different
|
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "agent.status", "data": {"slug": slug, "status": status}},
|
{"type": "agent.status", "data": {"slug": client.member_slug, "status": status}},
|
||||||
exclude=slug,
|
exclude_slug=client.member_slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_chat_send(slug: str, data: dict):
|
async def _handle_chat_send(session_id: str, data: dict):
|
||||||
"""Handle chat message sent via WS."""
|
"""Handle chat message sent via WS."""
|
||||||
|
client = manager.sessions.get(session_id)
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
|
||||||
|
slug = client.member_slug
|
||||||
chat_id = data.get("chat_id")
|
chat_id = data.get("chat_id")
|
||||||
task_id = data.get("task_id")
|
task_id = data.get("task_id")
|
||||||
content = data.get("content", "")
|
content = data.get("content", "")
|
||||||
@ -225,7 +226,6 @@ async def _handle_chat_send(slug: str, data: dict):
|
|||||||
return
|
return
|
||||||
|
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
# Get sender info
|
|
||||||
result = await db.execute(select(Member).where(Member.slug == slug))
|
result = await db.execute(select(Member).where(Member.slug == slug))
|
||||||
member = result.scalar_one_or_none()
|
member = result.scalar_one_or_none()
|
||||||
if not member:
|
if not member:
|
||||||
@ -263,40 +263,38 @@ async def _handle_chat_send(slug: str, data: dict):
|
|||||||
if chat and chat.project_id:
|
if chat and chat.project_id:
|
||||||
project_id = str(chat.project_id)
|
project_id = str(chat.project_id)
|
||||||
elif chat and chat.kind == "lobby":
|
elif chat and chat.kind == "lobby":
|
||||||
# Lobby — broadcast to all
|
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "message.new", "data": msg_data},
|
{"type": "message.new", "data": msg_data},
|
||||||
exclude=slug,
|
exclude_slug=slug,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if project_id:
|
if project_id:
|
||||||
await manager.broadcast_message(project_id, msg_data, author_slug=slug)
|
await manager.broadcast_message(project_id, msg_data, author_slug=slug)
|
||||||
else:
|
else:
|
||||||
# Task comment or unlinked — broadcast to all
|
|
||||||
await manager.broadcast_all(
|
await manager.broadcast_all(
|
||||||
{"type": "message.new", "data": msg_data},
|
{"type": "message.new", "data": msg_data},
|
||||||
exclude=slug,
|
exclude_slug=slug,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_subscribe(slug: str, data: dict):
|
async def _handle_subscribe(session_id: str, data: dict):
|
||||||
"""Subscribe to a project's events."""
|
"""Subscribe this session to project events."""
|
||||||
project_id = data.get("project_id")
|
project_id = data.get("project_id")
|
||||||
if not project_id:
|
if not project_id:
|
||||||
return
|
return
|
||||||
client = manager.clients.get(slug)
|
client = manager.sessions.get(session_id)
|
||||||
if client:
|
if client:
|
||||||
client.subscribed_projects.add(project_id)
|
client.subscribed_projects.add(project_id)
|
||||||
logger.info("%s subscribed to project %s", slug, project_id)
|
logger.info("%s subscribed to project %s", client.member_slug, project_id)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_unsubscribe(slug: str, data: dict):
|
async def _handle_unsubscribe(session_id: str, data: dict):
|
||||||
"""Unsubscribe from a project."""
|
"""Unsubscribe this session from project events."""
|
||||||
project_id = data.get("project_id")
|
project_id = data.get("project_id")
|
||||||
if not project_id:
|
if not project_id:
|
||||||
return
|
return
|
||||||
client = manager.clients.get(slug)
|
client = manager.sessions.get(session_id)
|
||||||
if client:
|
if client:
|
||||||
client.subscribed_projects.discard(project_id)
|
client.subscribed_projects.discard(project_id)
|
||||||
logger.info("%s unsubscribed from project %s", slug, project_id)
|
logger.info("%s unsubscribed from project %s", client.member_slug, project_id)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""WebSocket connection manager with project subscriptions and filtering."""
|
"""WebSocket connection manager with multi-session support."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
@ -12,83 +13,119 @@ logger = logging.getLogger("tracker.ws")
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConnectedClient:
|
class ConnectedClient:
|
||||||
ws: WebSocket
|
ws: WebSocket
|
||||||
|
session_id: str # unique per connection
|
||||||
member_slug: str
|
member_slug: str
|
||||||
member_type: str # human | agent | bridge
|
member_type: str # human | agent | bridge
|
||||||
chat_listen: str = "all" # all | mentions | none
|
chat_listen: str = "all" # all | mentions | none
|
||||||
task_listen: str = "all" # all | mentions | none
|
task_listen: str = "all" # all | mentions | none
|
||||||
subscribed_projects: set[str] = field(default_factory=set) # project_ids
|
subscribed_projects: set[str] = field(default_factory=set)
|
||||||
last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
class ConnectionManager:
|
class ConnectionManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.clients: dict[str, ConnectedClient] = {} # slug → client
|
# session_id → client (one entry per WS connection)
|
||||||
|
self.sessions: dict[str, ConnectedClient] = {}
|
||||||
|
# slug → set of session_ids (for quick lookup)
|
||||||
|
self.slug_sessions: dict[str, set[str]] = {}
|
||||||
|
|
||||||
async def connect(self, client: ConnectedClient):
|
async def connect(self, client: ConnectedClient):
|
||||||
self.clients[client.member_slug] = client
|
self.sessions[client.session_id] = client
|
||||||
logger.info("WS connected: %s (%s)", client.member_slug, client.member_type)
|
if client.member_slug not in self.slug_sessions:
|
||||||
|
self.slug_sessions[client.member_slug] = set()
|
||||||
|
self.slug_sessions[client.member_slug].add(client.session_id)
|
||||||
|
logger.info("WS connected: %s session=%s (%s)", client.member_slug, client.session_id[:8], client.member_type)
|
||||||
|
|
||||||
async def disconnect(self, slug: str):
|
async def disconnect(self, session_id: str):
|
||||||
if slug in self.clients:
|
client = self.sessions.pop(session_id, None)
|
||||||
del self.clients[slug]
|
if client:
|
||||||
logger.info("WS disconnected: %s", slug)
|
slug_set = self.slug_sessions.get(client.member_slug)
|
||||||
|
if slug_set:
|
||||||
|
slug_set.discard(session_id)
|
||||||
|
if not slug_set:
|
||||||
|
del self.slug_sessions[client.member_slug]
|
||||||
|
logger.info("WS disconnected: %s session=%s", client.member_slug, session_id[:8])
|
||||||
|
return client
|
||||||
|
|
||||||
async def send_to(self, slug: str, data: dict):
|
def get_sessions_for_slug(self, slug: str) -> list[ConnectedClient]:
|
||||||
"""Send to specific client."""
|
"""Get all active sessions for a member slug."""
|
||||||
client = self.clients.get(slug)
|
session_ids = self.slug_sessions.get(slug, set())
|
||||||
|
return [self.sessions[sid] for sid in session_ids if sid in self.sessions]
|
||||||
|
|
||||||
|
def is_online(self, slug: str) -> bool:
|
||||||
|
return bool(self.slug_sessions.get(slug))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def online_slugs(self) -> list[str]:
|
||||||
|
return list(self.slug_sessions.keys())
|
||||||
|
|
||||||
|
async def send_to_session(self, session_id: str, data: dict):
|
||||||
|
"""Send to a specific session."""
|
||||||
|
client = self.sessions.get(session_id)
|
||||||
if client:
|
if client:
|
||||||
try:
|
try:
|
||||||
await client.ws.send_json(data)
|
await client.ws.send_json(data)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self.disconnect(slug)
|
await self.disconnect(session_id)
|
||||||
|
|
||||||
|
async def send_to_slug(self, slug: str, data: dict):
|
||||||
|
"""Send to ALL sessions of a member."""
|
||||||
|
for client in self.get_sessions_for_slug(slug):
|
||||||
|
try:
|
||||||
|
await client.ws.send_json(data)
|
||||||
|
except Exception:
|
||||||
|
await self.disconnect(client.session_id)
|
||||||
|
|
||||||
async def broadcast_message(self, project_id: str, message: dict, author_slug: str):
|
async def broadcast_message(self, project_id: str, message: dict, author_slug: str):
|
||||||
"""Broadcast message.new. Humans get everything, agents filtered by subscription."""
|
"""Broadcast message.new. Humans get everything, agents filtered."""
|
||||||
mentions = message.get("mentions", [])
|
mentions = message.get("mentions", [])
|
||||||
for slug, client in list(self.clients.items()):
|
payload = {"type": "message.new", "data": message}
|
||||||
if slug == author_slug:
|
|
||||||
|
for session_id, client in list(self.sessions.items()):
|
||||||
|
if client.member_slug == author_slug:
|
||||||
continue
|
continue
|
||||||
# Humans/bridges get ALL messages — filtering on client side
|
# Humans/bridges get ALL messages
|
||||||
if client.member_type in ("human", "bridge"):
|
if client.member_type in ("human", "bridge"):
|
||||||
await self.send_to(slug, {"type": "message.new", "data": message})
|
await self.send_to_session(session_id, payload)
|
||||||
continue
|
continue
|
||||||
# Agents: check subscription + chat_listen
|
# Agents: subscription + chat_listen
|
||||||
if project_id not in client.subscribed_projects:
|
if project_id not in client.subscribed_projects:
|
||||||
continue
|
continue
|
||||||
if client.chat_listen == "none":
|
if client.chat_listen == "none":
|
||||||
continue
|
continue
|
||||||
if client.chat_listen == "mentions" and slug not in mentions:
|
if client.chat_listen == "mentions" and client.member_slug not in mentions:
|
||||||
continue
|
continue
|
||||||
await self.send_to(slug, {"type": "message.new", "data": message})
|
await self.send_to_session(session_id, payload)
|
||||||
|
|
||||||
async def broadcast_task_event(self, project_id: str, event_type: str, data: dict):
|
async def broadcast_task_event(self, project_id: str, event_type: str, data: dict):
|
||||||
"""Broadcast task events. Humans get everything, agents filtered."""
|
"""Broadcast task events. Humans get everything, agents filtered."""
|
||||||
assignee = data.get("assignee_slug")
|
assignee = data.get("assignee_slug")
|
||||||
reviewer = data.get("reviewer_slug")
|
reviewer = data.get("reviewer_slug")
|
||||||
watchers = data.get("watchers", [])
|
watchers = data.get("watchers", [])
|
||||||
|
payload = {"type": event_type, "data": data}
|
||||||
|
|
||||||
for slug, client in list(self.clients.items()):
|
for session_id, client in list(self.sessions.items()):
|
||||||
# Humans/bridges get ALL task events
|
# Humans/bridges get ALL task events
|
||||||
if client.member_type in ("human", "bridge"):
|
if client.member_type in ("human", "bridge"):
|
||||||
await self.send_to(slug, {"type": event_type, "data": data})
|
await self.send_to_session(session_id, payload)
|
||||||
continue
|
continue
|
||||||
# Agents: subscription + task_listen filter
|
# Agents: subscription + task_listen
|
||||||
if project_id not in client.subscribed_projects:
|
if project_id not in client.subscribed_projects:
|
||||||
continue
|
continue
|
||||||
if client.task_listen == "none":
|
if client.task_listen == "none":
|
||||||
continue
|
continue
|
||||||
if client.task_listen == "all":
|
if client.task_listen == "all":
|
||||||
await self.send_to(slug, {"type": event_type, "data": data})
|
await self.send_to_session(session_id, payload)
|
||||||
continue
|
continue
|
||||||
if slug in (assignee, reviewer) or slug in watchers:
|
if client.member_slug in (assignee, reviewer) or client.member_slug in watchers:
|
||||||
await self.send_to(slug, {"type": event_type, "data": data})
|
await self.send_to_session(session_id, payload)
|
||||||
|
|
||||||
async def broadcast_all(self, data: dict, exclude: str | None = None):
|
async def broadcast_all(self, data: dict, exclude_slug: str | None = None):
|
||||||
"""Broadcast to all connected clients."""
|
"""Broadcast to all connected sessions."""
|
||||||
for slug, client in list(self.clients.items()):
|
for session_id, client in list(self.sessions.items()):
|
||||||
if slug == exclude:
|
if client.member_slug == exclude_slug:
|
||||||
continue
|
continue
|
||||||
await self.send_to(slug, data)
|
await self.send_to_session(session_id, data)
|
||||||
|
|
||||||
|
|
||||||
manager = ConnectionManager()
|
manager = ConnectionManager()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user