This commit is contained in:
Adam Outler
2025-11-28 21:13:20 +00:00
parent 541b932b6d
commit 5e4ad10fe0
4 changed files with 128 additions and 107 deletions

View File

@@ -118,6 +118,7 @@ def log_request_info():
data = request.get_data(as_text=True) data = request.get_data(as_text=True)
mylog("none", [f"[HTTP] Body: {data[:1000]}"]) mylog("none", [f"[HTTP] Body: {data[:1000]}"])
@app.errorhandler(404) @app.errorhandler(404)
def not_found(error): def not_found(error):
response = { response = {
@@ -797,6 +798,7 @@ def start_server(graphql_port, app_state):
# Update the state to indicate the server has started # Update the state to indicate the server has started
app_state = updateState("Process: Idle", None, None, None, 1) app_state = updateState("Process: Idle", None, None, None, 1)
if __name__ == "__main__": if __name__ == "__main__":
# This block is for running the server directly for testing purposes # This block is for running the server directly for testing purposes
# In production, start_server is called from api.py # In production, start_server is called from api.py

View File

@@ -1,3 +1,5 @@
"""MCP bridge routes exposing NetAlertX tool endpoints via JSON-RPC."""
import json import json
import uuid import uuid
import queue import queue
@@ -16,11 +18,13 @@ openapi_spec_cache = None
API_BASE_URL = "http://localhost:20212/api/tools" API_BASE_URL = "http://localhost:20212/api/tools"
def get_openapi_spec(): def get_openapi_spec():
"""Fetch and cache the tools OpenAPI specification from the local API server."""
global openapi_spec_cache global openapi_spec_cache
if openapi_spec_cache: if openapi_spec_cache:
return openapi_spec_cache return openapi_spec_cache
try: try:
# Fetch from local server # Fetch from local server
# We use localhost because this code runs on the server # We use localhost because this code runs on the server
@@ -32,7 +36,9 @@ def get_openapi_spec():
print(f"Error fetching OpenAPI spec: {e}") print(f"Error fetching OpenAPI spec: {e}")
return None return None
def map_openapi_to_mcp_tools(spec): def map_openapi_to_mcp_tools(spec):
"""Convert OpenAPI paths into MCP tool descriptors."""
tools = [] tools = []
if not spec or "paths" not in spec: if not spec or "paths" not in spec:
return tools return tools
@@ -49,14 +55,14 @@ def map_openapi_to_mcp_tools(spec):
"required": [] "required": []
} }
} }
# Extract parameters from requestBody if present # Extract parameters from requestBody if present
if "requestBody" in details: if "requestBody" in details:
content = details["requestBody"].get("content", {}) content = details["requestBody"].get("content", {})
if "application/json" in content: if "application/json" in content:
schema = content["application/json"].get("schema", {}) schema = content["application/json"].get("schema", {})
tool["inputSchema"] = schema tool["inputSchema"] = schema
# Extract parameters from 'parameters' list (query/path params) - simplistic support # Extract parameters from 'parameters' list (query/path params) - simplistic support
if "parameters" in details: if "parameters" in details:
for param in details["parameters"]: for param in details["parameters"]:
@@ -73,12 +79,14 @@ def map_openapi_to_mcp_tools(spec):
tools.append(tool) tools.append(tool)
return tools return tools
def process_mcp_request(data): def process_mcp_request(data):
"""Handle incoming MCP JSON-RPC requests and route them to tools."""
method = data.get("method") method = data.get("method")
msg_id = data.get("id") msg_id = data.get("id")
response = None response = None
if method == "initialize": if method == "initialize":
response = { response = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
@@ -94,11 +102,11 @@ def process_mcp_request(data):
} }
} }
} }
elif method == "notifications/initialized": elif method == "notifications/initialized":
# No response needed for notification # No response needed for notification
pass pass
elif method == "tools/list": elif method == "tools/list":
spec = get_openapi_spec() spec = get_openapi_spec()
tools = map_openapi_to_mcp_tools(spec) tools = map_openapi_to_mcp_tools(spec)
@@ -109,17 +117,17 @@ def process_mcp_request(data):
"tools": tools "tools": tools
} }
} }
elif method == "tools/call": elif method == "tools/call":
params = data.get("params", {}) params = data.get("params", {})
tool_name = params.get("name") tool_name = params.get("name")
tool_args = params.get("arguments", {}) tool_args = params.get("arguments", {})
# Find the endpoint for this tool # Find the endpoint for this tool
spec = get_openapi_spec() spec = get_openapi_spec()
target_path = None target_path = None
target_method = None target_method = None
if spec and "paths" in spec: if spec and "paths" in spec:
for path, methods in spec["paths"].items(): for path, methods in spec["paths"].items():
for m, details in methods.items(): for m, details in methods.items():
@@ -129,7 +137,7 @@ def process_mcp_request(data):
break break
if target_path: if target_path:
break break
if target_path: if target_path:
try: try:
# Make the request to the local API # Make the request to the local API
@@ -139,16 +147,16 @@ def process_mcp_request(data):
} }
if "Authorization" in request.headers: if "Authorization" in request.headers:
headers["Authorization"] = request.headers["Authorization"] headers["Authorization"] = request.headers["Authorization"]
url = f"{API_BASE_URL}{target_path}" url = f"{API_BASE_URL}{target_path}"
if target_method == "POST": if target_method == "POST":
api_res = requests.post(url, json=tool_args, headers=headers) api_res = requests.post(url, json=tool_args, headers=headers)
elif target_method == "GET": elif target_method == "GET":
api_res = requests.get(url, params=tool_args, headers=headers) api_res = requests.get(url, params=tool_args, headers=headers)
else: else:
api_res = None api_res = None
if api_res: if api_res:
content = [] content = []
try: try:
@@ -157,12 +165,12 @@ def process_mcp_request(data):
"type": "text", "type": "text",
"text": json.dumps(json_content, indent=2) "text": json.dumps(json_content, indent=2)
}) })
except: except (ValueError, json.JSONDecodeError):
content.append({ content.append({
"type": "text", "type": "text",
"text": api_res.text "text": api_res.text
}) })
is_error = api_res.status_code >= 400 is_error = api_res.status_code >= 400
response = { response = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
@@ -194,27 +202,29 @@ def process_mcp_request(data):
"id": msg_id, "id": msg_id,
"error": {"code": -32601, "message": f"Tool {tool_name} not found"} "error": {"code": -32601, "message": f"Tool {tool_name} not found"}
} }
elif method == "ping": elif method == "ping":
response = { response = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": msg_id, "id": msg_id,
"result": {} "result": {}
} }
else: else:
# Unknown method # Unknown method
if msg_id: # Only respond if it's a request (has id) if msg_id: # Only respond if it's a request (has id)
response = { response = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": msg_id, "id": msg_id,
"error": {"code": -32601, "message": "Method not found"} "error": {"code": -32601, "message": "Method not found"}
} }
return response return response
@mcp_bp.route('/sse', methods=['GET', 'POST']) @mcp_bp.route('/sse', methods=['GET', 'POST'])
def handle_sse(): def handle_sse():
"""Expose an SSE endpoint that streams MCP responses to connected clients."""
if request.method == 'POST': if request.method == 'POST':
# Handle verification or keep-alive pings # Handle verification or keep-alive pings
try: try:
@@ -228,25 +238,26 @@ def handle_sse():
return "", 202 return "", 202
except Exception: except Exception:
pass pass
return jsonify({"status": "ok", "message": "MCP SSE endpoint active"}), 200 return jsonify({"status": "ok", "message": "MCP SSE endpoint active"}), 200
session_id = uuid.uuid4().hex session_id = uuid.uuid4().hex
q = queue.Queue() q = queue.Queue()
with sessions_lock: with sessions_lock:
sessions[session_id] = q sessions[session_id] = q
def stream(): def stream():
"""Yield SSE messages for queued MCP responses until the client disconnects."""
# Send the endpoint event # Send the endpoint event
# The client should POST to /api/mcp/messages?session_id=<session_id> # The client should POST to /api/mcp/messages?session_id=<session_id>
yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n" yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n"
try: try:
while True: while True:
try: try:
# Wait for messages # Wait for messages
message = q.get(timeout=20) # Keep-alive timeout message = q.get(timeout=20) # Keep-alive timeout
yield f"event: message\ndata: {json.dumps(message)}\n\n" yield f"event: message\ndata: {json.dumps(message)}\n\n"
except queue.Empty: except queue.Empty:
# Send keep-alive comment # Send keep-alive comment
@@ -258,12 +269,14 @@ def handle_sse():
return Response(stream_with_context(stream()), mimetype='text/event-stream') return Response(stream_with_context(stream()), mimetype='text/event-stream')
@mcp_bp.route('/messages', methods=['POST']) @mcp_bp.route('/messages', methods=['POST'])
def handle_messages(): def handle_messages():
"""Receive MCP JSON-RPC messages and enqueue responses for an SSE session."""
session_id = request.args.get('session_id') session_id = request.args.get('session_id')
if not session_id: if not session_id:
return jsonify({"error": "Missing session_id"}), 400 return jsonify({"error": "Missing session_id"}), 400
with sessions_lock: with sessions_lock:
if session_id not in sessions: if session_id not in sessions:
return jsonify({"error": "Session not found"}), 404 return jsonify({"error": "Session not found"}), 404

View File

@@ -1,6 +1,4 @@
import subprocess import subprocess
import shutil
import os
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
@@ -39,25 +37,25 @@ def trigger_scan():
cmd = [] cmd = []
if scan_type == 'arp': if scan_type == 'arp':
# ARP scan usually requires sudo or root, assuming container runs as root or has caps # ARP scan usually requires sudo or root, assuming container runs as root or has caps
cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
if target: if target:
cmd = ["arp-scan", target] cmd = ["arp-scan", target]
elif scan_type == 'nmap_fast': elif scan_type == 'nmap_fast':
cmd = ["nmap", "-F"] cmd = ["nmap", "-F"]
if target: if target:
cmd.append(target) cmd.append(target)
else: else:
# Default to local subnet if possible, or error if not easily determined # Default to local subnet if possible, or error if not easily determined
# For now, let's require target for nmap if not easily deducible, # For now, let's require target for nmap if not easily deducible,
# or try to get it from settings. # or try to get it from settings.
# NetAlertX usually knows its subnet. # NetAlertX usually knows its subnet.
# Let's try to get the scan subnet from settings if not provided. # Let's try to get the scan subnet from settings if not provided.
scan_subnets = get_setting_value("SCAN_SUBNETS") scan_subnets = get_setting_value("SCAN_SUBNETS")
if scan_subnets: if scan_subnets:
# Take the first one for now # Take the first one for now
cmd.append(scan_subnets.split(',')[0].strip()) cmd.append(scan_subnets.split(',')[0].strip())
else: else:
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
elif scan_type == 'nmap_deep': elif scan_type == 'nmap_deep':
cmd = ["nmap", "-A", "-T4"] cmd = ["nmap", "-A", "-T4"]
if target: if target:
@@ -65,9 +63,9 @@ def trigger_scan():
else: else:
scan_subnets = get_setting_value("SCAN_SUBNETS") scan_subnets = get_setting_value("SCAN_SUBNETS")
if scan_subnets: if scan_subnets:
cmd.append(scan_subnets.split(',')[0].strip()) cmd.append(scan_subnets.split(',')[0].strip())
else: else:
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
try: try:
# Run the command # Run the command
@@ -212,7 +210,7 @@ def get_open_ports():
text=True, text=True,
check=True check=True
) )
# Parse output for open ports # Parse output for open ports
open_ports = [] open_ports = []
for line in result.stdout.split('\n'): for line in result.stdout.split('\n'):
@@ -250,10 +248,10 @@ def get_network_topology():
try: try:
cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices") cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
rows = cur.fetchall() rows = cur.fetchall()
nodes = [] nodes = []
links = [] links = []
for row in rows: for row in rows:
nodes.append({ nodes.append({
"id": row['devMac'], "id": row['devMac'],
@@ -299,16 +297,16 @@ def get_recent_alerts():
cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S') cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S')
cur.execute(""" cur.execute("""
SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName
FROM Events FROM Events
LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac
WHERE eve_DateTime > ? WHERE eve_DateTime > ?
ORDER BY eve_DateTime DESC ORDER BY eve_DateTime DESC
""", (cutoff_str,)) """, (cutoff_str,))
rows = cur.fetchall() rows = cur.fetchall()
alerts = [dict(row) for row in rows] alerts = [dict(row) for row in rows]
return jsonify(alerts) return jsonify(alerts)
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@@ -338,10 +336,10 @@ def set_device_alias():
try: try:
cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac)) cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac))
conn.commit() conn.commit()
if cur.rowcount == 0: if cur.rowcount == 0:
return jsonify({"error": "Device not found"}), 404 return jsonify({"error": "Device not found"}), 404
return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"}) return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"})
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@@ -379,7 +377,7 @@ def wol_wake_device():
else: else:
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404 return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
except Exception as e: except Exception as e:
return jsonify({"error": f"Database error: {str(e)}"}), 500 return jsonify({"error": f"Database error: {str(e)}"}), 500
finally: finally:
conn.close() conn.close()

View File

@@ -2,7 +2,6 @@ import sys
import os import os
import pytest import pytest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import subprocess
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app') INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
@@ -10,20 +9,23 @@ sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402 from helper import get_setting_value # noqa: E402
from api_server.api_server_start import app # noqa: E402 from api_server.api_server_start import app # noqa: E402
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def api_token(): def api_token():
return get_setting_value("API_TOKEN") return get_setting_value("API_TOKEN")
@pytest.fixture @pytest.fixture
def client(): def client():
with app.test_client() as client: with app.test_client() as client:
yield client yield client
def auth_headers(token): def auth_headers(token):
return {"Authorization": f"Bearer {token}"} return {"Authorization": f"Bearer {token}"}
# --- get_device_info Tests ---
# --- get_device_info Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_get_device_info_ip_partial(mock_db_conn, client, api_token): def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
"""Test get_device_info with partial IP search.""" """Test get_device_info with partial IP search."""
@@ -33,53 +35,55 @@ def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"} {"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}
] ]
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"query": ".50"} payload = {"query": ".50"}
response = client.post('/api/tools/get_device_info', response = client.post('/api/tools/get_device_info',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
devices = response.get_json() devices = response.get_json()
assert len(devices) == 1 assert len(devices) == 1
assert devices[0]["devLastIP"] == "192.168.1.50" assert devices[0]["devLastIP"] == "192.168.1.50"
# Verify SQL query included 3 params (MAC, Name, IP) # Verify SQL query included 3 params (MAC, Name, IP)
args, _ = mock_cursor.execute.call_args args, _ = mock_cursor.execute.call_args
assert args[0].count("?") == 3 assert args[0].count("?") == 3
assert len(args[1]) == 3 assert len(args[1]) == 3
# --- trigger_scan Tests ---
# --- trigger_scan Tests ---
@patch('subprocess.run') @patch('subprocess.run')
def test_trigger_scan_nmap_fast(mock_run, client, api_token): def test_trigger_scan_nmap_fast(mock_run, client, api_token):
"""Test trigger_scan with nmap_fast.""" """Test trigger_scan with nmap_fast."""
mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0) mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0)
payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"} payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"}
response = client.post('/api/tools/trigger_scan', response = client.post('/api/tools/trigger_scan',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["success"] is True assert data["success"] is True
assert "nmap -F 192.168.1.1" in data["command"] assert "nmap -F 192.168.1.1" in data["command"]
mock_run.assert_called_once() mock_run.assert_called_once()
@patch('subprocess.run') @patch('subprocess.run')
def test_trigger_scan_invalid_type(mock_run, client, api_token): def test_trigger_scan_invalid_type(mock_run, client, api_token):
"""Test trigger_scan with invalid scan_type.""" """Test trigger_scan with invalid scan_type."""
payload = {"scan_type": "invalid_type", "target": "192.168.1.1"} payload = {"scan_type": "invalid_type", "target": "192.168.1.1"}
response = client.post('/api/tools/trigger_scan', response = client.post('/api/tools/trigger_scan',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 400 assert response.status_code == 400
mock_run.assert_not_called() mock_run.assert_not_called()
# --- get_open_ports Tests --- # --- get_open_ports Tests ---
@patch('subprocess.run') @patch('subprocess.run')
def test_get_open_ports_ip(mock_run, client, api_token): def test_get_open_ports_ip(mock_run, client, api_token):
"""Test get_open_ports with an IP address.""" """Test get_open_ports with an IP address."""
@@ -94,12 +98,12 @@ PORT STATE SERVICE
Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
""" """
mock_run.return_value = MagicMock(stdout=mock_output, returncode=0) mock_run.return_value = MagicMock(stdout=mock_output, returncode=0)
payload = {"target": "192.168.1.1"} payload = {"target": "192.168.1.1"}
response = client.post('/api/tools/get_open_ports', response = client.post('/api/tools/get_open_ports',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["success"] is True assert data["success"] is True
@@ -107,6 +111,7 @@ Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
assert data["open_ports"][0]["port"] == 22 assert data["open_ports"][0]["port"] == 22
assert data["open_ports"][1]["service"] == "http" assert data["open_ports"][1]["service"] == "http"
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
@patch('subprocess.run') @patch('subprocess.run')
def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token): def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
@@ -115,24 +120,24 @@ def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
mock_cursor = MagicMock() mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"} mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"}
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
# Mock Nmap output # Mock Nmap output
mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0) mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0)
payload = {"target": "AA:BB:CC:DD:EE:FF"} payload = {"target": "AA:BB:CC:DD:EE:FF"}
response = client.post('/api/tools/get_open_ports', response = client.post('/api/tools/get_open_ports',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["target"] == "192.168.1.50" # Should be resolved IP assert data["target"] == "192.168.1.50" # Should be resolved IP
mock_run.assert_called_once() mock_run.assert_called_once()
args, _ = mock_run.call_args args, _ = mock_run.call_args
assert "192.168.1.50" in args[0] assert "192.168.1.50" in args[0]
# --- get_network_topology Tests ---
# --- get_network_topology Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_get_network_topology(mock_db_conn, client, api_token): def test_get_network_topology(mock_db_conn, client, api_token):
"""Test get_network_topology.""" """Test get_network_topology."""
@@ -142,10 +147,10 @@ def test_get_network_topology(mock_db_conn, client, api_token):
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"} {"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"}
] ]
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
response = client.get('/api/tools/get_network_topology', response = client.get('/api/tools/get_network_topology',
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert len(data["nodes"]) == 2 assert len(data["nodes"]) == 2
@@ -153,8 +158,8 @@ def test_get_network_topology(mock_db_conn, client, api_token):
assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA" assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA"
assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB" assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB"
# --- get_recent_alerts Tests ---
# --- get_recent_alerts Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_get_recent_alerts(mock_db_conn, client, api_token): def test_get_recent_alerts(mock_db_conn, client, api_token):
"""Test get_recent_alerts.""" """Test get_recent_alerts."""
@@ -163,67 +168,69 @@ def test_get_recent_alerts(mock_db_conn, client, api_token):
{"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"} {"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"}
] ]
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"hours": 24} payload = {"hours": 24}
response = client.post('/api/tools/get_recent_alerts', response = client.post('/api/tools/get_recent_alerts',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert len(data) == 1 assert len(data) == 1
assert data[0]["eve_EventType"] == "New Device" assert data[0]["eve_EventType"] == "New Device"
# --- set_device_alias Tests ---
# --- set_device_alias Tests ---
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_set_device_alias(mock_db_conn, client, api_token): def test_set_device_alias(mock_db_conn, client, api_token):
"""Test set_device_alias.""" """Test set_device_alias."""
mock_cursor = MagicMock() mock_cursor = MagicMock()
mock_cursor.rowcount = 1 # Simulate successful update mock_cursor.rowcount = 1 # Simulate successful update
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
response = client.post('/api/tools/set_device_alias', response = client.post('/api/tools/set_device_alias',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["success"] is True assert data["success"] is True
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
def test_set_device_alias_not_found(mock_db_conn, client, api_token): def test_set_device_alias_not_found(mock_db_conn, client, api_token):
"""Test set_device_alias when device is not found.""" """Test set_device_alias when device is not found."""
mock_cursor = MagicMock() mock_cursor = MagicMock()
mock_cursor.rowcount = 0 # Simulate no rows updated mock_cursor.rowcount = 0 # Simulate no rows updated
mock_db_conn.return_value.cursor.return_value = mock_cursor mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
response = client.post('/api/tools/set_device_alias', response = client.post('/api/tools/set_device_alias',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 404 assert response.status_code == 404
# --- wol_wake_device Tests ---
# --- wol_wake_device Tests ---
@patch('subprocess.run') @patch('subprocess.run')
def test_wol_wake_device(mock_subprocess, client, api_token): def test_wol_wake_device(mock_subprocess, client, api_token):
"""Test wol_wake_device.""" """Test wol_wake_device."""
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF" mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
mock_subprocess.return_value.returncode = 0 mock_subprocess.return_value.returncode = 0
payload = {"mac": "AA:BB:CC:DD:EE:FF"} payload = {"mac": "AA:BB:CC:DD:EE:FF"}
response = client.post('/api/tools/wol_wake_device', response = client.post('/api/tools/wol_wake_device',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["success"] is True assert data["success"] is True
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
@patch('api_server.tools_routes.get_temp_db_connection') @patch('api_server.tools_routes.get_temp_db_connection')
@patch('subprocess.run') @patch('subprocess.run')
def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token): def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token):
@@ -238,38 +245,39 @@ def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token)
mock_subprocess.return_value.returncode = 0 mock_subprocess.return_value.returncode = 0
payload = {"ip": "192.168.1.50"} payload = {"ip": "192.168.1.50"}
response = client.post('/api/tools/wol_wake_device', response = client.post('/api/tools/wol_wake_device',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data["success"] is True assert data["success"] is True
assert "AA:BB:CC:DD:EE:FF" in data["message"] assert "AA:BB:CC:DD:EE:FF" in data["message"]
# Verify DB lookup # Verify DB lookup
mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",)) mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",))
# Verify subprocess call # Verify subprocess call
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
def test_wol_wake_device_invalid_mac(client, api_token): def test_wol_wake_device_invalid_mac(client, api_token):
"""Test wol_wake_device with invalid MAC.""" """Test wol_wake_device with invalid MAC."""
payload = {"mac": "invalid-mac"} payload = {"mac": "invalid-mac"}
response = client.post('/api/tools/wol_wake_device', response = client.post('/api/tools/wol_wake_device',
json=payload, json=payload,
headers=auth_headers(api_token)) headers=auth_headers(api_token))
assert response.status_code == 400 assert response.status_code == 400
# --- openapi_spec Tests ---
# --- openapi_spec Tests ---
def test_openapi_spec(client): def test_openapi_spec(client):
"""Test openapi_spec endpoint contains new paths.""" """Test openapi_spec endpoint contains new paths."""
response = client.get('/api/tools/openapi.json') response = client.get('/api/tools/openapi.json')
assert response.status_code == 200 assert response.status_code == 200
spec = response.get_json() spec = response.get_json()
# Check for new endpoints # Check for new endpoints
assert "/trigger_scan" in spec["paths"] assert "/trigger_scan" in spec["paths"]
assert "/get_open_ports" in spec["paths"] assert "/get_open_ports" in spec["paths"]