This commit is contained in:
Adam Outler
2025-11-28 21:13:20 +00:00
parent 541b932b6d
commit 5e4ad10fe0
4 changed files with 128 additions and 107 deletions

View File

@@ -118,6 +118,7 @@ def log_request_info():
data = request.get_data(as_text=True) data = request.get_data(as_text=True)
mylog("none", [f"[HTTP] Body: {data[:1000]}"]) mylog("none", [f"[HTTP] Body: {data[:1000]}"])
@app.errorhandler(404) @app.errorhandler(404)
def not_found(error): def not_found(error):
response = { response = {
@@ -797,6 +798,7 @@ def start_server(graphql_port, app_state):
# Update the state to indicate the server has started # Update the state to indicate the server has started
app_state = updateState("Process: Idle", None, None, None, 1) app_state = updateState("Process: Idle", None, None, None, 1)
if __name__ == "__main__": if __name__ == "__main__":
# This block is for running the server directly for testing purposes # This block is for running the server directly for testing purposes
# In production, start_server is called from api.py # In production, start_server is called from api.py

View File

@@ -1,3 +1,5 @@
"""MCP bridge routes exposing NetAlertX tool endpoints via JSON-RPC."""
import json import json
import uuid import uuid
import queue import queue
@@ -16,7 +18,9 @@ openapi_spec_cache = None
API_BASE_URL = "http://localhost:20212/api/tools" API_BASE_URL = "http://localhost:20212/api/tools"
def get_openapi_spec(): def get_openapi_spec():
"""Fetch and cache the tools OpenAPI specification from the local API server."""
global openapi_spec_cache global openapi_spec_cache
if openapi_spec_cache: if openapi_spec_cache:
return openapi_spec_cache return openapi_spec_cache
@@ -32,7 +36,9 @@ def get_openapi_spec():
print(f"Error fetching OpenAPI spec: {e}") print(f"Error fetching OpenAPI spec: {e}")
return None return None
def map_openapi_to_mcp_tools(spec): def map_openapi_to_mcp_tools(spec):
"""Convert OpenAPI paths into MCP tool descriptors."""
tools = [] tools = []
if not spec or "paths" not in spec: if not spec or "paths" not in spec:
return tools return tools
@@ -73,7 +79,9 @@ def map_openapi_to_mcp_tools(spec):
tools.append(tool) tools.append(tool)
return tools return tools
def process_mcp_request(data): def process_mcp_request(data):
"""Handle incoming MCP JSON-RPC requests and route them to tools."""
method = data.get("method") method = data.get("method")
msg_id = data.get("id") msg_id = data.get("id")
@@ -157,7 +165,7 @@ def process_mcp_request(data):
"type": "text", "type": "text",
"text": json.dumps(json_content, indent=2) "text": json.dumps(json_content, indent=2)
}) })
except: except (ValueError, json.JSONDecodeError):
content.append({ content.append({
"type": "text", "type": "text",
"text": api_res.text "text": api_res.text
@@ -204,7 +212,7 @@ def process_mcp_request(data):
else: else:
# Unknown method # Unknown method
if msg_id: # Only respond if it's a request (has id) if msg_id: # Only respond if it's a request (has id)
response = { response = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": msg_id, "id": msg_id,
@@ -213,8 +221,10 @@ def process_mcp_request(data):
return response return response
@mcp_bp.route('/sse', methods=['GET', 'POST']) @mcp_bp.route('/sse', methods=['GET', 'POST'])
def handle_sse(): def handle_sse():
"""Expose an SSE endpoint that streams MCP responses to connected clients."""
if request.method == 'POST': if request.method == 'POST':
# Handle verification or keep-alive pings # Handle verification or keep-alive pings
try: try:
@@ -238,6 +248,7 @@ def handle_sse():
sessions[session_id] = q sessions[session_id] = q
def stream(): def stream():
"""Yield SSE messages for queued MCP responses until the client disconnects."""
# Send the endpoint event # Send the endpoint event
# The client should POST to /api/mcp/messages?session_id=<session_id> # The client should POST to /api/mcp/messages?session_id=<session_id>
yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n" yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n"
@@ -246,7 +257,7 @@ def handle_sse():
while True: while True:
try: try:
# Wait for messages # Wait for messages
message = q.get(timeout=20) # Keep-alive timeout message = q.get(timeout=20) # Keep-alive timeout
yield f"event: message\ndata: {json.dumps(message)}\n\n" yield f"event: message\ndata: {json.dumps(message)}\n\n"
except queue.Empty: except queue.Empty:
# Send keep-alive comment # Send keep-alive comment
@@ -258,8 +269,10 @@ def handle_sse():
return Response(stream_with_context(stream()), mimetype='text/event-stream') return Response(stream_with_context(stream()), mimetype='text/event-stream')
@mcp_bp.route('/messages', methods=['POST']) @mcp_bp.route('/messages', methods=['POST'])
def handle_messages(): def handle_messages():
"""Receive MCP JSON-RPC messages and enqueue responses for an SSE session."""
session_id = request.args.get('session_id') session_id = request.args.get('session_id')
if not session_id: if not session_id:
return jsonify({"error": "Missing session_id"}), 400 return jsonify({"error": "Missing session_id"}), 400

View File

@@ -1,6 +1,4 @@
import subprocess import subprocess
import shutil
import os
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
@@ -39,9 +37,9 @@ def trigger_scan():
cmd = [] cmd = []
if scan_type == 'arp': if scan_type == 'arp':
# ARP scan usually requires sudo or root, assuming container runs as root or has caps # ARP scan usually requires sudo or root, assuming container runs as root or has caps
cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
if target: if target:
cmd = ["arp-scan", target] cmd = ["arp-scan", target]
elif scan_type == 'nmap_fast': elif scan_type == 'nmap_fast':
cmd = ["nmap", "-F"] cmd = ["nmap", "-F"]
if target: if target:
@@ -54,10 +52,10 @@ def trigger_scan():
# Let's try to get the scan subnet from settings if not provided. # Let's try to get the scan subnet from settings if not provided.
scan_subnets = get_setting_value("SCAN_SUBNETS") scan_subnets = get_setting_value("SCAN_SUBNETS")
if scan_subnets: if scan_subnets:
# Take the first one for now # Take the first one for now
cmd.append(scan_subnets.split(',')[0].strip()) cmd.append(scan_subnets.split(',')[0].strip())
else: else:
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
elif scan_type == 'nmap_deep': elif scan_type == 'nmap_deep':
cmd = ["nmap", "-A", "-T4"] cmd = ["nmap", "-A", "-T4"]
if target: if target:
@@ -65,9 +63,9 @@ def trigger_scan():
else: else:
scan_subnets = get_setting_value("SCAN_SUBNETS") scan_subnets = get_setting_value("SCAN_SUBNETS")
if scan_subnets: if scan_subnets:
cmd.append(scan_subnets.split(',')[0].strip()) cmd.append(scan_subnets.split(',')[0].strip())
else: else:
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
try: try:
# Run the command # Run the command
@@ -379,7 +377,7 @@ def wol_wake_device():
else: else:
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404 return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
except Exception as e: except Exception as e:
return jsonify({"error": f"Database error: {str(e)}"}), 500 return jsonify({"error": f"Database error: {str(e)}"}), 500
finally: finally:
conn.close() conn.close()

View File

@@ -2,7 +2,6 @@ import sys
import os import os
import pytest import pytest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import subprocess
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app') INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
@@ -10,20 +9,23 @@ sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402 from helper import get_setting_value # noqa: E402
from api_server.api_server_start import app # noqa: E402 from api_server.api_server_start import app # noqa: E402
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def api_token(): def api_token():
return get_setting_value("API_TOKEN") return get_setting_value("API_TOKEN")
@pytest.fixture @pytest.fixture
def client(): def client():
with app.test_client() as client: with app.test_client() as client:
yield client yield client
def auth_headers(token): def auth_headers(token):
return {"Authorization": f"Bearer {token}"} return {"Authorization": f"Bearer {token}"}
# --- get_device_info Tests ---
# --- get_device_info Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_get_device_info_ip_partial(mock_db_conn, client, api_token): def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
"""Test get_device_info with partial IP search.""" """Test get_device_info with partial IP search."""
@@ -49,8 +51,8 @@ def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
assert args[0].count("?") == 3 assert args[0].count("?") == 3
assert len(args[1]) == 3 assert len(args[1]) == 3
# --- trigger_scan Tests ---
# --- trigger_scan Tests ---
@patch('subprocess.run') @patch('subprocess.run')
def test_trigger_scan_nmap_fast(mock_run, client, api_token): def test_trigger_scan_nmap_fast(mock_run, client, api_token):
"""Test trigger_scan with nmap_fast.""" """Test trigger_scan with nmap_fast."""
@@ -67,6 +69,7 @@ def test_trigger_scan_nmap_fast(mock_run, client, api_token):
assert "nmap -F 192.168.1.1" in data["command"] assert "nmap -F 192.168.1.1" in data["command"]
mock_run.assert_called_once() mock_run.assert_called_once()
@patch('subprocess.run') @patch('subprocess.run')
def test_trigger_scan_invalid_type(mock_run, client, api_token): def test_trigger_scan_invalid_type(mock_run, client, api_token):
"""Test trigger_scan with invalid scan_type.""" """Test trigger_scan with invalid scan_type."""
@@ -80,6 +83,7 @@ def test_trigger_scan_invalid_type(mock_run, client, api_token):
# --- get_open_ports Tests --- # --- get_open_ports Tests ---
@patch('subprocess.run') @patch('subprocess.run')
def test_get_open_ports_ip(mock_run, client, api_token): def test_get_open_ports_ip(mock_run, client, api_token):
"""Test get_open_ports with an IP address.""" """Test get_open_ports with an IP address."""
@@ -107,6 +111,7 @@ Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
assert data["open_ports"][0]["port"] == 22 assert data["open_ports"][0]["port"] == 22
assert data["open_ports"][1]["service"] == "http" assert data["open_ports"][1]["service"] == "http"
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
@patch('subprocess.run') @patch('subprocess.run')
def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token): def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
@@ -126,13 +131,13 @@ def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["target"] == "192.168.1.50" # Should be resolved IP assert data["target"] == "192.168.1.50" # Should be resolved IP
mock_run.assert_called_once() mock_run.assert_called_once()
args, _ = mock_run.call_args args, _ = mock_run.call_args
assert "192.168.1.50" in args[0] assert "192.168.1.50" in args[0]
# --- get_network_topology Tests ---
# --- get_network_topology Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_get_network_topology(mock_db_conn, client, api_token): def test_get_network_topology(mock_db_conn, client, api_token):
"""Test get_network_topology.""" """Test get_network_topology."""
@@ -153,8 +158,8 @@ def test_get_network_topology(mock_db_conn, client, api_token):
assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA" assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA"
assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB" assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB"
# --- get_recent_alerts Tests ---
# --- get_recent_alerts Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_get_recent_alerts(mock_db_conn, client, api_token): def test_get_recent_alerts(mock_db_conn, client, api_token):
"""Test get_recent_alerts.""" """Test get_recent_alerts."""
@@ -174,13 +179,13 @@ def test_get_recent_alerts(mock_db_conn, client, api_token):
assert len(data) == 1 assert len(data) == 1
assert data[0]["eve_EventType"] == "New Device" assert data[0]["eve_EventType"] == "New Device"
# --- set_device_alias Tests ---
# --- set_device_alias Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_set_device_alias(mock_db_conn, client, api_token): def test_set_device_alias(mock_db_conn, client, api_token):
"""Test set_device_alias.""" """Test set_device_alias."""
mock_cursor = MagicMock() mock_cursor = MagicMock()
mock_cursor.rowcount = 1 # Simulate successful update mock_cursor.rowcount = 1 # Simulate successful update
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
@@ -192,11 +197,12 @@ def test_set_device_alias(mock_db_conn, client, api_token):
data = response.get_json() data = response.get_json()
assert data["success"] is True assert data["success"] is True
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_set_device_alias_not_found(mock_db_conn, client, api_token): def test_set_device_alias_not_found(mock_db_conn, client, api_token):
"""Test set_device_alias when device is not found.""" """Test set_device_alias when device is not found."""
mock_cursor = MagicMock() mock_cursor = MagicMock()
mock_cursor.rowcount = 0 # Simulate no rows updated mock_cursor.rowcount = 0 # Simulate no rows updated
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
@@ -206,8 +212,8 @@ def test_set_device_alias_not_found(mock_db_conn, client, api_token):
assert response.status_code == 404 assert response.status_code == 404
# --- wol_wake_device Tests ---
# --- wol_wake_device Tests ---
@patch('subprocess.run') @patch('subprocess.run')
def test_wol_wake_device(mock_subprocess, client, api_token): def test_wol_wake_device(mock_subprocess, client, api_token):
"""Test wol_wake_device.""" """Test wol_wake_device."""
@@ -224,6 +230,7 @@ def test_wol_wake_device(mock_subprocess, client, api_token):
assert data["success"] is True assert data["success"] is True
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
@patch('subprocess.run') @patch('subprocess.run')
def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token): def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token):
@@ -253,6 +260,7 @@ def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token)
# Verify subprocess call # Verify subprocess call
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
def test_wol_wake_device_invalid_mac(client, api_token): def test_wol_wake_device_invalid_mac(client, api_token):
"""Test wol_wake_device with invalid MAC.""" """Test wol_wake_device with invalid MAC."""
payload = {"mac": "invalid-mac"} payload = {"mac": "invalid-mac"}
@@ -262,8 +270,8 @@ def test_wol_wake_device_invalid_mac(client, api_token):
assert response.status_code == 400 assert response.status_code == 400
# --- openapi_spec Tests ---
# --- openapi_spec Tests ---
def test_openapi_spec(client): def test_openapi_spec(client):
"""Test openapi_spec endpoint contains new paths.""" """Test openapi_spec endpoint contains new paths."""
response = client.get('/api/tools/openapi.json') response = client.get('/api/tools/openapi.json')