Tidy up
@@ -118,6 +118,7 @@ def log_request_info():
    data = request.get_data(as_text=True)
    mylog("none", [f"[HTTP] Body: {data[:1000]}"])


@app.errorhandler(404)
def not_found(error):
    response = {
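For context, the hunk above sits inside log_request_info(), which reads the incoming request body and truncates it to 1,000 characters before logging, and is followed by a JSON 404 handler. A minimal self-contained sketch of the same pattern, using the standard logging module in place of NetAlertX's mylog helper (the before_request wiring and handler body are illustrative, not copied from this file):

# Minimal sketch: request-body logging plus a JSON 404 handler.
import logging

from flask import Flask, jsonify, request

app = Flask(__name__)
logging.basicConfig(level=logging.DEBUG)


@app.before_request
def log_request_info():
    # Truncate the body so large payloads do not flood the log.
    data = request.get_data(as_text=True)
    logging.debug("[HTTP] %s %s Body: %s", request.method, request.path, data[:1000])


@app.errorhandler(404)
def not_found(error):
    # Return a JSON error body instead of Flask's default HTML page.
    return jsonify({"error": "Not found", "path": request.path}), 404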
@@ -797,6 +798,7 @@ def start_server(graphql_port, app_state):
    # Update the state to indicate the server has started
    app_state = updateState("Process: Idle", None, None, None, 1)


if __name__ == "__main__":
    # This block is for running the server directly for testing purposes
    # In production, start_server is called from api.py
@@ -1,3 +1,5 @@
"""MCP bridge routes exposing NetAlertX tool endpoints via JSON-RPC."""

import json
import uuid
import queue
@@ -16,11 +18,13 @@ openapi_spec_cache = None

API_BASE_URL = "http://localhost:20212/api/tools"


def get_openapi_spec():
    """Fetch and cache the tools OpenAPI specification from the local API server."""
    global openapi_spec_cache
    if openapi_spec_cache:
        return openapi_spec_cache

    try:
        # Fetch from local server
        # We use localhost because this code runs on the server
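get_openapi_spec() memoizes the specification in the module-level openapi_spec_cache so repeated tool listings do not re-fetch it from the local API. A rough stand-alone sketch of that fetch-and-cache pattern; the spec URL and timeout below are assumptions for illustration, not values taken from this commit:

# Rough sketch of the fetch-and-cache pattern shown above.
import requests

OPENAPI_URL = "http://localhost:20212/api/tools/openapi.json"  # assumed path
_spec_cache = None


def get_openapi_spec():
    """Return the cached spec, fetching it once from the local API server."""
    global _spec_cache
    if _spec_cache:
        return _spec_cache
    try:
        res = requests.get(OPENAPI_URL, timeout=5)
        res.raise_for_status()
        _spec_cache = res.json()
        return _spec_cache
    except requests.RequestException as e:
        print(f"Error fetching OpenAPI spec: {e}")
        return None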
@@ -32,7 +36,9 @@ def get_openapi_spec():
        print(f"Error fetching OpenAPI spec: {e}")
        return None


def map_openapi_to_mcp_tools(spec):
    """Convert OpenAPI paths into MCP tool descriptors."""
    tools = []
    if not spec or "paths" not in spec:
        return tools
@@ -49,14 +55,14 @@ def map_openapi_to_mcp_tools(spec):
                    "required": []
                }
            }

            # Extract parameters from requestBody if present
            if "requestBody" in details:
                content = details["requestBody"].get("content", {})
                if "application/json" in content:
                    schema = content["application/json"].get("schema", {})
                    tool["inputSchema"] = schema

            # Extract parameters from 'parameters' list (query/path params) - simplistic support
            if "parameters" in details:
                for param in details["parameters"]:
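To make the mapping concrete, here is a hypothetical OpenAPI operation next to the MCP tool descriptor such a conversion would produce; the tool name, description, and schema values are illustrative, not taken from the NetAlertX spec:

# Hypothetical input: one OpenAPI path item as it might appear in spec["paths"].
openapi_operation = {
    "/trigger-scan": {
        "post": {
            "operationId": "trigger_scan",
            "summary": "Trigger a network scan",
            "requestBody": {
                "content": {
                    "application/json": {
                        "schema": {
                            "type": "object",
                            "properties": {
                                "scan_type": {"type": "string"},
                                "target": {"type": "string"},
                            },
                            "required": ["scan_type"],
                        }
                    }
                }
            },
        }
    }
}

# Expected output: the corresponding MCP tool descriptor, with the request-body
# schema carried over as the tool's inputSchema.
mcp_tool = {
    "name": "trigger_scan",
    "description": "Trigger a network scan",
    "inputSchema": {
        "type": "object",
        "properties": {
            "scan_type": {"type": "string"},
            "target": {"type": "string"},
        },
        "required": ["scan_type"],
    },
}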
@@ -73,12 +79,14 @@ def map_openapi_to_mcp_tools(spec):
            tools.append(tool)
    return tools


def process_mcp_request(data):
    """Handle incoming MCP JSON-RPC requests and route them to tools."""
    method = data.get("method")
    msg_id = data.get("id")

    response = None

    if method == "initialize":
        response = {
            "jsonrpc": "2.0",
@@ -94,11 +102,11 @@ def process_mcp_request(data):
                }
            }
        }

    elif method == "notifications/initialized":
        # No response needed for notification
        pass

    elif method == "tools/list":
        spec = get_openapi_spec()
        tools = map_openapi_to_mcp_tools(spec)
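For reference, a tools/list exchange in MCP's JSON-RPC 2.0 framing looks roughly like the following; the tool entry is illustrative, and the real list is whatever map_openapi_to_mcp_tools() returns:

# Illustrative JSON-RPC 2.0 exchange for listing tools (tool names are made up).
tools_list_request = {
    "jsonrpc": "2.0",
    "id": 2,
    "method": "tools/list",
    "params": {},
}

tools_list_response = {
    "jsonrpc": "2.0",
    "id": 2,
    "result": {
        "tools": [
            {
                "name": "trigger_scan",
                "description": "Trigger a network scan",
                "inputSchema": {"type": "object", "properties": {}, "required": []},
            }
        ]
    },
}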
@@ -109,17 +117,17 @@ def process_mcp_request(data):
                "tools": tools
            }
        }

    elif method == "tools/call":
        params = data.get("params", {})
        tool_name = params.get("name")
        tool_args = params.get("arguments", {})

        # Find the endpoint for this tool
        spec = get_openapi_spec()
        target_path = None
        target_method = None

        if spec and "paths" in spec:
            for path, methods in spec["paths"].items():
                for m, details in methods.items():
@@ -129,7 +137,7 @@ def process_mcp_request(data):
                        break
                if target_path:
                    break

        if target_path:
            try:
                # Make the request to the local API
@@ -139,16 +147,16 @@ def process_mcp_request(data):
                }
                if "Authorization" in request.headers:
                    headers["Authorization"] = request.headers["Authorization"]

                url = f"{API_BASE_URL}{target_path}"

                if target_method == "POST":
                    api_res = requests.post(url, json=tool_args, headers=headers)
                elif target_method == "GET":
                    api_res = requests.get(url, params=tool_args, headers=headers)
                else:
                    api_res = None

                if api_res:
                    content = []
                    try:
@@ -157,12 +165,12 @@ def process_mcp_request(data):
                            "type": "text",
                            "text": json.dumps(json_content, indent=2)
                        })
-                   except:
+                   except (ValueError, json.JSONDecodeError):
                        content.append({
                            "type": "text",
                            "text": api_res.text
                        })

                    is_error = api_res.status_code >= 400
                    response = {
                        "jsonrpc": "2.0",
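The branch above wraps the proxied API response in an MCP tools/call result: a JSON body is pretty-printed into a text content block, a non-JSON body falls back to raw text, and the HTTP status presumably drives the result's error flag. The exact envelope is truncated in this hunk, but a successful call would come back shaped roughly like this (values illustrative):

# Illustrative shape of a tools/call result built from a proxied API response.
tools_call_response = {
    "jsonrpc": "2.0",
    "id": 3,
    "result": {
        "content": [
            {
                "type": "text",
                "text": "{\n  \"success\": true,\n  \"devices_found\": 12\n}",
            }
        ],
        "isError": False,  # assumed to mirror is_error computed from the HTTP status
    },
}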
@@ -194,27 +202,29 @@ def process_mcp_request(data):
                "id": msg_id,
                "error": {"code": -32601, "message": f"Tool {tool_name} not found"}
            }

    elif method == "ping":
        response = {
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": {}
        }

    else:
        # Unknown method
        if msg_id: # Only respond if it's a request (has id)
            response = {
                "jsonrpc": "2.0",
                "id": msg_id,
                "error": {"code": -32601, "message": "Method not found"}
            }

    return response


@mcp_bp.route('/sse', methods=['GET', 'POST'])
def handle_sse():
    """Expose an SSE endpoint that streams MCP responses to connected clients."""
    if request.method == 'POST':
        # Handle verification or keep-alive pings
        try:
@@ -228,25 +238,26 @@ def handle_sse():
            return "", 202
        except Exception:
            pass

        return jsonify({"status": "ok", "message": "MCP SSE endpoint active"}), 200

    session_id = uuid.uuid4().hex
    q = queue.Queue()

    with sessions_lock:
        sessions[session_id] = q

    def stream():
        """Yield SSE messages for queued MCP responses until the client disconnects."""
        # Send the endpoint event
        # The client should POST to /api/mcp/messages?session_id=<session_id>
        yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n"

        try:
            while True:
                try:
                    # Wait for messages
                    message = q.get(timeout=20) # Keep-alive timeout
                    yield f"event: message\ndata: {json.dumps(message)}\n\n"
                except queue.Empty:
                    # Send keep-alive comment
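As the comments describe, a client first receives an endpoint event carrying its session-specific /api/mcp/messages URL, then POSTs JSON-RPC messages there while reading responses from the SSE stream. A rough client-side sketch of that handshake; the base URL and port are assumptions for illustration, not deployment defaults confirmed by this commit:

# Rough client-side sketch of the SSE + messages handshake described above.
import json

import requests

BASE = "http://localhost:20211"  # assumed NetAlertX web/API address

with requests.get(f"{BASE}/api/mcp/sse", stream=True, timeout=65) as sse:
    messages_path = None
    for raw in sse.iter_lines(decode_unicode=True):
        if raw and raw.startswith("data: ") and messages_path is None:
            # The first data line belongs to the "endpoint" event and carries
            # the session-specific /api/mcp/messages URL.
            messages_path = raw[len("data: "):]
            # Send a JSON-RPC request for the tool list to that endpoint.
            requests.post(
                f"{BASE}{messages_path}",
                json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
                timeout=10,
            )
        elif raw and raw.startswith("data: "):
            # Subsequent "message" events carry the JSON-RPC responses.
            print(json.loads(raw[len("data: "):]))
            break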
@@ -258,12 +269,14 @@ def handle_sse():

    return Response(stream_with_context(stream()), mimetype='text/event-stream')


@mcp_bp.route('/messages', methods=['POST'])
def handle_messages():
    """Receive MCP JSON-RPC messages and enqueue responses for an SSE session."""
    session_id = request.args.get('session_id')
    if not session_id:
        return jsonify({"error": "Missing session_id"}), 400

    with sessions_lock:
        if session_id not in sessions:
            return jsonify({"error": "Session not found"}), 404
@@ -1,6 +1,4 @@
import subprocess
import shutil
import os
import re
from datetime import datetime, timedelta
from flask import Blueprint, request, jsonify
@@ -39,25 +37,25 @@ def trigger_scan():
    cmd = []
    if scan_type == 'arp':
        # ARP scan usually requires sudo or root, assuming container runs as root or has caps
        cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
        if target:
            cmd = ["arp-scan", target]
    elif scan_type == 'nmap_fast':
        cmd = ["nmap", "-F"]
        if target:
            cmd.append(target)
        else:
            # Default to local subnet if possible, or error if not easily determined
            # For now, let's require target for nmap if not easily deducible,
            # or try to get it from settings.
            # NetAlertX usually knows its subnet.
            # Let's try to get the scan subnet from settings if not provided.
            scan_subnets = get_setting_value("SCAN_SUBNETS")
            if scan_subnets:
                # Take the first one for now
                cmd.append(scan_subnets.split(',')[0].strip())
            else:
                return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
    elif scan_type == 'nmap_deep':
        cmd = ["nmap", "-A", "-T4"]
        if target:
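The cmd list built above is later handed to subprocess (the get_open_ports hunk below shows text=True and check=True in use). A minimal sketch of running such an argument list and capturing its output, independent of this file's exact error handling; the target subnet is illustrative and arp-scan/nmap must be installed, possibly with elevated privileges:

# Minimal sketch of running a scan command built as a list of arguments.
import subprocess

cmd = ["nmap", "-F", "192.168.1.0/24"]  # illustrative target

try:
    result = subprocess.run(
        cmd,
        capture_output=True,  # collect stdout/stderr instead of printing them
        text=True,            # decode bytes to str
        timeout=300,          # don't let a scan hang the request forever
        check=True,           # raise CalledProcessError on non-zero exit
    )
    print(result.stdout)
except FileNotFoundError:
    print(f"{cmd[0]} is not installed in this environment")
except subprocess.CalledProcessError as e:
    print(f"Scan failed with exit code {e.returncode}: {e.stderr}")
except subprocess.TimeoutExpired:
    print("Scan timed out")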
@@ -65,9 +63,9 @@ def trigger_scan():
        else:
            scan_subnets = get_setting_value("SCAN_SUBNETS")
            if scan_subnets:
                cmd.append(scan_subnets.split(',')[0].strip())
            else:
                return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400

    try:
        # Run the command
@@ -212,7 +210,7 @@ def get_open_ports():
            text=True,
            check=True
        )

        # Parse output for open ports
        open_ports = []
        for line in result.stdout.split('\n'):
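The parsing loop itself is not visible in the hunk; a plausible sketch of extracting open ports from nmap's normal output with a regular expression follows (the real get_open_ports() logic may differ):

# Plausible sketch of parsing nmap's normal output for open ports.
import re

sample_stdout = """\
PORT     STATE SERVICE
22/tcp   open  ssh
80/tcp   open  http
443/tcp  closed https
"""

open_ports = []
for line in sample_stdout.split('\n'):
    # Lines look like "22/tcp   open  ssh"
    match = re.match(r'^(\d+)/(tcp|udp)\s+open\s+(\S+)', line)
    if match:
        open_ports.append({
            "port": int(match.group(1)),
            "protocol": match.group(2),
            "service": match.group(3),
        })

print(open_ports)  # [{'port': 22, ...}, {'port': 80, ...}]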
@@ -250,10 +248,10 @@ def get_network_topology():
    try:
        cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
        rows = cur.fetchall()

        nodes = []
        links = []

        for row in rows:
            nodes.append({
                "id": row['devMac'],
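Each Devices row becomes a node, and rows with a devParentMAC presumably also yield a link to their upstream device. A compact sketch of that transformation over plain dicts standing in for the query results; the exact column handling in get_network_topology() may differ:

# Compact sketch of turning device rows into a node/link graph.
rows = [
    {"devName": "Router", "devMac": "AA:BB:CC:00:00:01", "devParentMAC": None,
     "devParentPort": None, "devVendor": "VendorA"},
    {"devName": "Laptop", "devMac": "AA:BB:CC:00:00:02", "devParentMAC": "AA:BB:CC:00:00:01",
     "devParentPort": "1", "devVendor": "VendorB"},
]

nodes = []
links = []
for row in rows:
    nodes.append({
        "id": row["devMac"],
        "name": row["devName"],
        "vendor": row["devVendor"],
    })
    if row["devParentMAC"]:
        # Link the device to its upstream switch/AP via the recorded port.
        links.append({
            "source": row["devParentMAC"],
            "target": row["devMac"],
            "port": row["devParentPort"],
        })

print({"nodes": nodes, "links": links})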
@@ -299,16 +297,16 @@ def get_recent_alerts():
        cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S')

        cur.execute("""
            SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName
            FROM Events
            LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac
            WHERE eve_DateTime > ?
            ORDER BY eve_DateTime DESC
        """, (cutoff_str,))

        rows = cur.fetchall()
        alerts = [dict(row) for row in rows]

        return jsonify(alerts)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@@ -338,10 +336,10 @@ def set_device_alias():
    try:
        cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac))
        conn.commit()

        if cur.rowcount == 0:
            return jsonify({"error": "Device not found"}), 404

        return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@@ -379,7 +377,7 @@ def wol_wake_device():
        else:
            return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
    except Exception as e:
        return jsonify({"error": f"Database error: {str(e)}"}), 500
    finally:
        conn.close()
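wol_wake_device() resolves an IP to a MAC through the database before waking the device; the wake itself is typically a UDP broadcast of a Wake-on-LAN magic packet. A generic sketch of that step, not taken from NetAlertX (which may shell out to a wakeonlan/etherwake binary instead):

# Generic Wake-on-LAN magic packet sketch; treat it as illustrative only.
import re
import socket


def send_magic_packet(mac, broadcast="255.255.255.255", port=9):
    """Broadcast the WoL magic packet: 6x 0xFF followed by the MAC repeated 16 times."""
    clean = re.sub(r"[^0-9A-Fa-f]", "", mac)
    if len(clean) != 12:
        raise ValueError(f"Invalid MAC address: {mac}")
    payload = b"\xff" * 6 + bytes.fromhex(clean) * 16
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.sendto(payload, (broadcast, port))


send_magic_packet("AA:BB:CC:DD:EE:FF")  # example MAC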