feat: JWT auth in WebSocket handler
Some checks failed
Deploy Tracker / deploy (push) Failing after 2s
Some checks failed
Deploy Tracker / deploy (push) Failing after 2s
This commit is contained in:
parent
dbd20ff550
commit
e39c26d321
@ -16,23 +16,29 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket):
|
||||
async def websocket_endpoint(ws: WebSocket, token: str = ""):
|
||||
await ws.accept()
|
||||
session_id = None
|
||||
|
||||
try:
|
||||
# Wait for auth message
|
||||
auth_msg = await ws.receive_json()
|
||||
if auth_msg.get("type") != "auth":
|
||||
await ws.send_json({"type": "auth.error", "message": "First message must be auth"})
|
||||
await ws.close()
|
||||
return
|
||||
# Try query param token first (for direct JWT auth)
|
||||
if token:
|
||||
session_id = await _authenticate(ws, token)
|
||||
if not session_id:
|
||||
return
|
||||
else:
|
||||
# Wait for auth message (backward compatibility with agents)
|
||||
auth_msg = await ws.receive_json()
|
||||
if auth_msg.get("type") != "auth":
|
||||
await ws.send_json({"type": "auth.error", "message": "First message must be auth"})
|
||||
await ws.close()
|
||||
return
|
||||
|
||||
token = auth_msg.get("token", "")
|
||||
on_behalf_of = auth_msg.get("on_behalf_of")
|
||||
session_id = await _authenticate(ws, token, on_behalf_of=on_behalf_of)
|
||||
if not session_id:
|
||||
return
|
||||
token = auth_msg.get("token", "")
|
||||
on_behalf_of = auth_msg.get("on_behalf_of")
|
||||
session_id = await _authenticate(ws, token, on_behalf_of=on_behalf_of)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
client = manager.sessions.get(session_id)
|
||||
slug = client.member_slug if client else None
|
||||
@ -86,22 +92,28 @@ async def websocket_endpoint(ws: WebSocket):
|
||||
async def _authenticate(ws: WebSocket, token: str, on_behalf_of: str | None = None) -> str | None:
|
||||
"""Authenticate and register session. Returns session_id or None."""
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
select(Member).where(Member.token == token).options(selectinload(Member.agent_config))
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
member = None
|
||||
|
||||
if not member:
|
||||
# Check if it's an agent token (starts with 'tb-')
|
||||
if token.startswith("tb-"):
|
||||
result = await db.execute(
|
||||
select(Member).where(Member.token == token).options(selectinload(Member.agent_config))
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
else:
|
||||
# Try JWT decode
|
||||
from tracker.api.auth import decode_jwt
|
||||
try:
|
||||
payload = decode_jwt(token)
|
||||
member_id = payload["sub"]
|
||||
result = await db.execute(
|
||||
select(Member).where(Member.id == payload["sub"])
|
||||
select(Member).where(Member.id == member_id)
|
||||
.options(selectinload(Member.agent_config))
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("JWT auth successful for member_id=%s", member_id)
|
||||
except Exception as e:
|
||||
logger.warning("JWT decode failed: %s", e)
|
||||
|
||||
if not member:
|
||||
await ws.send_json({"type": "auth.error", "message": "Invalid token"})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user