"""WebSocket connection bookkeeping for channel broadcasts and user presence."""

import json
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional

from fastapi import WebSocket, WebSocketDisconnect

logger = logging.getLogger(__name__)

# Channel id that global listeners (presence connections) are registered on.
_GLOBAL_CHANNEL_ID = 0

# Message types that are mirrored to the global listeners in addition to
# the channel they were sent to.
_MIRRORED_MESSAGE_TYPES = ("message", "direct_message")


class ConnectionManager:
    """Track open WebSocket connections per channel and per user.

    ``active_connections`` maps a channel id to the sockets subscribed to it.
    ``user_connections`` maps a user id to an info dict (``websocket``,
    ``channel_id``, ``last_activity``, ``connected_at``) describing that
    user's most recently opened connection that is still registered.

    A user may hold several connections at once (for example a presence
    socket on channel 0 plus a channel socket); the user is reported as
    online until the last of them has been disconnected.
    """

    def __init__(self):
        # Maps channel_id to list of WebSocket connections
        self.active_connections: Dict[int, List[WebSocket]] = {}
        # Maps user_id to the info dict of their most recent open connection
        self.user_connections: Dict[int, Dict] = {}
        # Maps user_id to the info dicts of ALL their open connections,
        # oldest first, so closing one socket does not hide the others.
        self._user_sockets: Dict[int, List[Dict]] = {}

    async def connect(self, websocket: WebSocket, channel_id: int, user_id: int):
        """Accept a new WebSocket connection and register it.

        Args:
            websocket: The socket to accept.
            channel_id: Channel the socket subscribes to.
            user_id: User that owns the socket.
        """
        await websocket.accept()
        self.active_connections.setdefault(channel_id, []).append(websocket)

        now = time.time()
        info = {
            'websocket': websocket,
            'channel_id': channel_id,
            'last_activity': now,
            'connected_at': now,
        }
        self._user_sockets.setdefault(user_id, []).append(info)
        # The newest connection is the one exposed through user_connections.
        self.user_connections[user_id] = info

    def disconnect(self, websocket: WebSocket, channel_id: int, user_id: int):
        """Remove a WebSocket connection.

        The socket is removed from ``channel_id``. The user's entry in
        ``user_connections`` is removed only when this was their last open
        connection; otherwise it is re-pointed at the most recent connection
        that is still open. Calling this for a socket that is already gone
        is a no-op.

        Args:
            websocket: The socket to remove.
            channel_id: Channel the socket was registered on.
            user_id: User that owns the socket.
        """
        self._remove_from_channel(websocket, channel_id)

        remaining = [
            info
            for info in self._user_sockets.get(user_id, [])
            if info['websocket'] is not websocket
        ]
        if remaining:
            self._user_sockets[user_id] = remaining
            self.user_connections[user_id] = remaining[-1]
        else:
            self._user_sockets.pop(user_id, None)
            self.user_connections.pop(user_id, None)

    def update_activity(self, user_id: int):
        """Update last activity time for a user (no-op for unknown users)."""
        info = self.user_connections.get(user_id)
        if info is not None:
            info['last_activity'] = time.time()

    def get_user_status(self, user_id: int) -> str:
        """Return ``'online'`` if the user has an open connection, else ``'offline'``."""
        # User is online as long as they have an active connection
        return 'online' if user_id in self.user_connections else 'offline'

    def get_all_user_statuses(self) -> Dict[int, str]:
        """Return the status of every user that currently has a connection."""
        return {
            user_id: self.get_user_status(user_id)
            for user_id in self.user_connections
        }

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send a text message to a specific WebSocket."""
        await websocket.send_text(message)

    async def broadcast_to_channel(self, message: dict, channel_id: int):
        """Broadcast a message to all connections in a channel.

        Messages whose ``type`` is ``"message"`` or ``"direct_message"`` are
        also sent to the global listeners on channel 0. If the target channel
        has no connections nothing is sent at all, including to the global
        listeners. Sockets whose send raises are removed from the manager.

        Args:
            message: JSON-serialisable payload.
            channel_id: Channel to broadcast to.
        """
        if channel_id not in self.active_connections:
            return

        payload = json.dumps(message)
        await self._send_to_all(channel_id, payload)

        # Skip the mirror pass when the target already is the global channel,
        # otherwise its listeners would receive the same payload twice.
        if (
            channel_id != _GLOBAL_CHANNEL_ID
            and message.get("type") in _MIRRORED_MESSAGE_TYPES
        ):
            await self._send_to_all(_GLOBAL_CHANNEL_ID, payload)

    async def broadcast_user_status_update(self, user_id: int, status: str):
        """Broadcast a user status update to all connected clients.

        Sockets whose send raises are removed from the manager.

        Args:
            user_id: User whose status changed.
            status: New status value to announce.
        """
        payload = json.dumps({
            "type": "user_status_update",
            "user_id": user_id,
            "status": status,
            "timestamp": time.time(),
        })

        # Broadcast to all channels (presence connections are on channel 0).
        # Iterate over a snapshot of the keys: pruning a dead socket can
        # delete a channel, which must not happen to a dict being iterated.
        for channel_id in list(self.active_connections):
            await self._send_to_all(channel_id, payload)

    async def _send_to_all(self, channel_id: int, payload: str) -> None:
        """Send ``payload`` to every socket on ``channel_id``; prune failures."""
        dead = []
        # Snapshot the list: other tasks may connect or disconnect while a
        # send is being awaited, mutating the live list under the loop.
        for websocket in list(self.active_connections.get(channel_id, ())):
            try:
                await websocket.send_text(payload)
            except Exception:
                logger.debug(
                    "Send failed on channel %s; dropping connection",
                    channel_id,
                    exc_info=True,
                )
                dead.append(websocket)

        for websocket in dead:
            owner = self._owner_of(websocket)
            if owner is None:
                # No user is known for this socket; still stop sending to it.
                self._remove_from_channel(websocket, channel_id)
            else:
                self.disconnect(websocket, channel_id, owner)

    def _owner_of(self, websocket: WebSocket) -> Optional[int]:
        """Return the id of the user that registered ``websocket``, if any."""
        for user_id, infos in self._user_sockets.items():
            if any(info['websocket'] is websocket for info in infos):
                return user_id
        return None

    def _remove_from_channel(self, websocket: WebSocket, channel_id: int) -> None:
        """Remove ``websocket`` from ``channel_id`` and drop the channel if empty."""
        connections = self.active_connections.get(channel_id)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        # Clean up empty channel lists
        if not connections:
            del self.active_connections[channel_id]


# Global connection manager instance
manager = ConnectionManager()