mirror of
https://github.com/jokob-sk/NetAlertX.git
synced 2026-03-31 07:12:23 -07:00
feat: implement Server-Sent Events (SSE) for real-time updates and notifications
This commit is contained in:
@@ -58,6 +58,9 @@ from .mcp_endpoint import ( # noqa: E402 [flake8 lint suppression]
|
||||
mcp_messages,
|
||||
openapi_spec
|
||||
)
|
||||
from .sse_endpoint import ( # noqa: E402 [flake8 lint suppression]
|
||||
create_sse_endpoint
|
||||
)
|
||||
# tools and mcp routes have been moved into this module (api_server_start)
|
||||
|
||||
# Flask application
|
||||
@@ -81,7 +84,8 @@ CORS(
|
||||
r"/logs/*": {"origins": "*"},
|
||||
r"/api/tools/*": {"origins": "*"},
|
||||
r"/auth/*": {"origins": "*"},
|
||||
r"/mcp/*": {"origins": "*"}
|
||||
r"/mcp/*": {"origins": "*"},
|
||||
r"/sse/*": {"origins": "*"}
|
||||
},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
@@ -1084,8 +1088,16 @@ def check_auth():
|
||||
# Background Server Start
|
||||
# --------------------------
|
||||
def is_authorized():
|
||||
token = request.headers.get("Authorization")
|
||||
is_authorized = token == f"Bearer {get_setting_value('API_TOKEN')}"
|
||||
expected_token = get_setting_value('API_TOKEN')
|
||||
|
||||
# Check Authorization header first (primary method)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
header_token = auth_header.split()[-1] if auth_header.startswith("Bearer ") else ""
|
||||
|
||||
# Also check query string token (for SSE and other streaming endpoints)
|
||||
query_token = request.args.get("token", "")
|
||||
|
||||
is_authorized = (header_token == expected_token) or (query_token == expected_token)
|
||||
|
||||
if not is_authorized:
|
||||
msg = "[api] Unauthorized access attempt - make sure your GRAPHQL_PORT and API_TOKEN settings are correct."
|
||||
@@ -1095,6 +1107,10 @@ def is_authorized():
|
||||
return is_authorized
|
||||
|
||||
|
||||
# Mount SSE endpoints after is_authorized is defined (avoid circular import)
|
||||
create_sse_endpoint(app, is_authorized)
|
||||
|
||||
|
||||
def start_server(graphql_port, app_state):
|
||||
"""Start the GraphQL server in a background thread."""
|
||||
|
||||
|
||||
48
server/api_server/sse_broadcast.py
Normal file
48
server/api_server/sse_broadcast.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Integration layer to broadcast state changes via SSE
|
||||
Call these functions from the backend whenever state changes occur
|
||||
"""
|
||||
from logger import mylog
|
||||
from .sse_endpoint import broadcast_event
|
||||
|
||||
|
||||
def broadcast_state_update(current_state: str, settings_imported: float = None, **kwargs) -> None:
|
||||
"""
|
||||
Broadcast a state update to all connected SSE clients
|
||||
Call this from app_state.updateState() or equivalent
|
||||
|
||||
Args:
|
||||
current_state: The new application state string
|
||||
settings_imported: Optional timestamp of last settings import
|
||||
**kwargs: Additional state data to broadcast
|
||||
"""
|
||||
try:
|
||||
state_data = {
|
||||
"currentState": current_state,
|
||||
"timestamp": kwargs.get("timestamp"),
|
||||
**({"settingsImported": settings_imported} if settings_imported else {}),
|
||||
**{k: v for k, v in kwargs.items() if k not in ["timestamp"]},
|
||||
}
|
||||
broadcast_event("state_update", state_data)
|
||||
except ImportError:
|
||||
pass # SSE not available, silently skip
|
||||
except Exception as e:
|
||||
mylog("debug", [f"[SSE] Failed to broadcast state update: {e}"])
|
||||
|
||||
|
||||
def broadcast_unread_notifications_count(count: int) -> None:
|
||||
"""
|
||||
Broadcast unread notifications count to all connected SSE clients
|
||||
Call this from messaging.in_app functions when notifications change
|
||||
|
||||
Args:
|
||||
count: Number of unread notifications (must be int)
|
||||
"""
|
||||
try:
|
||||
# Ensure count is an integer
|
||||
count = int(count) if count else 0
|
||||
broadcast_event("unread_notifications_count_update", {"count": count})
|
||||
except ImportError:
|
||||
pass # SSE not available, silently skip
|
||||
except Exception as e:
|
||||
mylog("debug", [f"[SSE] Failed to broadcast unread count update: {e}"])
|
||||
164
server/api_server/sse_endpoint.py
Normal file
164
server/api_server/sse_endpoint.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
SSE (Server-Sent Events) Endpoint
|
||||
Provides real-time state updates to frontend via HTTP streaming
|
||||
Reduces polling overhead from 60+ requests/minute to 1 persistent connection
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from flask import Response, request
|
||||
from logger import mylog
|
||||
|
||||
# Thread-safe event queue
|
||||
_event_queue = deque(maxlen=100) # Keep last 100 events
|
||||
_queue_lock = threading.Lock()
|
||||
_subscribers = set() # Track active subscribers
|
||||
_subscribers_lock = threading.Lock()
|
||||
|
||||
|
||||
class StateChangeEvent:
|
||||
"""Represents a state change event to broadcast"""
|
||||
|
||||
def __init__(self, event_type: str, data: dict, timestamp: float = None):
|
||||
self.event_type = event_type # 'state_update', 'settings_changed', 'device_update', etc
|
||||
self.data = data
|
||||
self.timestamp = timestamp or time.time()
|
||||
self.id = int(self.timestamp * 1000) # Use millisecond timestamp as ID
|
||||
|
||||
def to_sse_format(self) -> str:
|
||||
"""Convert to SSE format with error handling"""
|
||||
try:
|
||||
return f"id: {self.id}\nevent: {self.event_type}\ndata: {json.dumps(self.data)}\n\n"
|
||||
except Exception as e:
|
||||
mylog("none", [f"[SSE] Failed to serialize event: {e}"])
|
||||
return ""
|
||||
|
||||
|
||||
def broadcast_event(event_type: str, data: dict) -> None:
|
||||
"""
|
||||
Broadcast an event to all connected SSE clients
|
||||
Called by backend when state changes occur
|
||||
"""
|
||||
try:
|
||||
event = StateChangeEvent(event_type, data)
|
||||
with _queue_lock:
|
||||
_event_queue.append(event)
|
||||
mylog("debug", [f"[SSE] Broadcasted event: {event_type}"])
|
||||
except Exception as e:
|
||||
mylog("none", [f"[SSE] Failed to broadcast event: {e}"])
|
||||
|
||||
|
||||
def register_subscriber(client_id: str) -> None:
|
||||
"""Track new SSE subscriber"""
|
||||
with _subscribers_lock:
|
||||
_subscribers.add(client_id)
|
||||
mylog("debug", [f"[SSE] Subscriber registered: {client_id} (total: {len(_subscribers)})"])
|
||||
|
||||
|
||||
def unregister_subscriber(client_id: str) -> None:
|
||||
"""Track disconnected SSE subscriber"""
|
||||
with _subscribers_lock:
|
||||
_subscribers.discard(client_id)
|
||||
mylog(
|
||||
"debug",
|
||||
[f"[SSE] Subscriber unregistered: {client_id} (remaining: {len(_subscribers)})"],
|
||||
)
|
||||
|
||||
|
||||
def get_subscriber_count() -> int:
|
||||
"""Get number of active SSE connections"""
|
||||
with _subscribers_lock:
|
||||
return len(_subscribers)
|
||||
|
||||
|
||||
def sse_stream(client_id: str):
|
||||
"""
|
||||
Generator for SSE stream
|
||||
Yields events to client with reconnect guidance
|
||||
"""
|
||||
register_subscriber(client_id)
|
||||
|
||||
# Send initial connection message
|
||||
yield "id: 0\nevent: connected\ndata: {}\nretry: 3000\n\n"
|
||||
|
||||
# Send initial unread notifications count on connect
|
||||
try:
|
||||
from messaging.in_app import get_unread_notifications
|
||||
initial_notifications = get_unread_notifications().json
|
||||
unread_count = len(initial_notifications) if isinstance(initial_notifications, list) else 0
|
||||
broadcast_event("unread_notifications_count_update", {"count": unread_count})
|
||||
except Exception as e:
|
||||
mylog("debug", [f"[SSE] Failed to broadcast initial unread count: {e}"])
|
||||
|
||||
last_event_id = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Check for new events since last_event_id
|
||||
with _queue_lock:
|
||||
new_events = [
|
||||
e for e in _event_queue if e.id > last_event_id
|
||||
]
|
||||
|
||||
if new_events:
|
||||
for event in new_events:
|
||||
sse_data = event.to_sse_format()
|
||||
if sse_data:
|
||||
yield sse_data
|
||||
last_event_id = event.id
|
||||
else:
|
||||
# Send keepalive every 30 seconds to prevent connection timeout
|
||||
time.sleep(1)
|
||||
if int(time.time()) % 30 == 0:
|
||||
yield ": keepalive\n\n"
|
||||
|
||||
except GeneratorExit:
|
||||
unregister_subscriber(client_id)
|
||||
except Exception as e:
|
||||
mylog("none", [f"[SSE] Stream error for {client_id}: {e}"])
|
||||
unregister_subscriber(client_id)
|
||||
|
||||
|
||||
def create_sse_endpoint(app, is_authorized=None) -> None:
|
||||
"""Mount SSE endpoints to Flask app - /sse/state and /sse/stats
|
||||
|
||||
Args:
|
||||
app: Flask app instance
|
||||
is_authorized: Optional function to check authorization (if None, allows all)
|
||||
"""
|
||||
|
||||
@app.route("/sse/state", methods=["GET"])
|
||||
def api_sse_state():
|
||||
"""SSE endpoint for real-time state updates"""
|
||||
if is_authorized and not is_authorized():
|
||||
return {"none": "Unauthorized"}, 401
|
||||
|
||||
client_id = request.args.get("client", f"client-{int(time.time() * 1000)}")
|
||||
mylog("debug", [f"[SSE] Client connected: {client_id}"])
|
||||
|
||||
return Response(
|
||||
sse_stream(client_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
@app.route("/sse/stats", methods=["GET"])
|
||||
def api_sse_stats():
|
||||
"""Get SSE endpoint statistics for debugging"""
|
||||
if is_authorized and not is_authorized():
|
||||
return {"none": "Unauthorized"}, 401
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"connected_clients": get_subscriber_count(),
|
||||
"queued_events": len(_event_queue),
|
||||
"max_queue_size": _event_queue.maxlen,
|
||||
}
|
||||
|
||||
mylog("info", ["[SSE] Endpoints mounted: /sse/state, /sse/stats"])
|
||||
Reference in New Issue
Block a user