import json
import logging
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from sqlmodel import Session, select

from app.database import engine, get_session
from app.websocket import manager
from app.auth import decode_access_token
from app.models import User, Channel

logger = logging.getLogger(__name__)

router = APIRouter()


def _authorization_error(username: str, channel_id: int) -> Optional[str]:
    """Return a close reason if *username* may not join *channel_id*, else None.

    A negative ``channel_id`` addresses a direct-message stream, where
    ``-channel_id`` is the owning user's id; only that user may subscribe.
    A non-negative ``channel_id`` is a regular channel, which the user may
    join only if the channel's department is one of the user's departments.

    The DB session is opened and closed here so it is not held for the
    lifetime of the WebSocket connection.
    """
    with Session(engine) as session:
        user = session.exec(select(User).where(User.username == username)).first()
        if not user:
            return "User not found"

        if channel_id < 0:
            # Direct-message connection: only the owner may listen.
            return None if -channel_id == user.id else "Access denied"

        channel = session.get(Channel, channel_id)
        if not channel:
            return "Channel not found"

        # Read departments while the session is still open (relationship access).
        user_dept_ids = {dept.id for dept in user.departments}
        if channel.department_id not in user_dept_ids:
            return "Access denied"
        return None


@router.websocket("/ws/{channel_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    channel_id: int,
    token: str = Query(...),
):
    """WebSocket endpoint for real-time channel messages and direct messages.

    The client authenticates with ``?token=...``. On failure the socket is
    closed with code 1008 (policy violation) and a reason string.

    After connecting, each text frame is expected to be a JSON object. Its
    ``content`` field is broadcast to every client on the channel, tagged with
    the sender's username. Malformed frames get a personal ``error`` reply and
    do not terminate the connection.
    """
    # Authenticate user via token
    username = decode_access_token(token)
    if not username:
        await websocket.close(code=1008, reason="Invalid authentication")
        return

    reason = _authorization_error(username, channel_id)
    if reason is not None:
        await websocket.close(code=1008, reason=reason)
        return

    # Connect to channel
    await manager.connect(websocket, channel_id)
    try:
        # Send welcome message
        await manager.send_personal_message(
            json.dumps({
                "type": "system",
                "message": f"Connected to channel {channel_id}"
            }),
            websocket
        )

        # Listen for messages
        while True:
            data = await websocket.receive_text()
            # In production, you'd save to DB before broadcasting.
            try:
                message_data = json.loads(data)
            except json.JSONDecodeError:
                await manager.send_personal_message(
                    json.dumps({"type": "error", "message": "Invalid JSON"}),
                    websocket
                )
                continue

            # Valid JSON that is not an object (list, string, number, null)
            # would otherwise crash on .get() and drop the connection.
            if not isinstance(message_data, dict):
                await manager.send_personal_message(
                    json.dumps({"type": "error", "message": "Message must be a JSON object"}),
                    websocket
                )
                continue

            # Broadcast to all clients in the channel
            await manager.broadcast_to_channel(
                {
                    "type": "message",
                    "content": message_data.get("content", ""),
                    "sender": username,
                    "channel_id": channel_id
                },
                channel_id
            )
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in `finally`.
        pass
    except Exception:
        logger.exception("WebSocket error on channel %s", channel_id)
    finally:
        # Always unregister, including on cancellation during shutdown.
        manager.disconnect(websocket, channel_id)