mirror of
https://github.com/jokob-sk/NetAlertX.git
synced 2025-12-07 09:36:05 -08:00
Tidy up
This commit is contained in:
@@ -118,6 +118,7 @@ def log_request_info():
|
|||||||
data = request.get_data(as_text=True)
|
data = request.get_data(as_text=True)
|
||||||
mylog("none", [f"[HTTP] Body: {data[:1000]}"])
|
mylog("none", [f"[HTTP] Body: {data[:1000]}"])
|
||||||
|
|
||||||
|
|
||||||
@app.errorhandler(404)
|
@app.errorhandler(404)
|
||||||
def not_found(error):
|
def not_found(error):
|
||||||
response = {
|
response = {
|
||||||
@@ -797,6 +798,7 @@ def start_server(graphql_port, app_state):
|
|||||||
# Update the state to indicate the server has started
|
# Update the state to indicate the server has started
|
||||||
app_state = updateState("Process: Idle", None, None, None, 1)
|
app_state = updateState("Process: Idle", None, None, None, 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# This block is for running the server directly for testing purposes
|
# This block is for running the server directly for testing purposes
|
||||||
# In production, start_server is called from api.py
|
# In production, start_server is called from api.py
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""MCP bridge routes exposing NetAlertX tool endpoints via JSON-RPC."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import queue
|
import queue
|
||||||
@@ -16,11 +18,13 @@ openapi_spec_cache = None
|
|||||||
|
|
||||||
API_BASE_URL = "http://localhost:20212/api/tools"
|
API_BASE_URL = "http://localhost:20212/api/tools"
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_spec():
|
def get_openapi_spec():
|
||||||
|
"""Fetch and cache the tools OpenAPI specification from the local API server."""
|
||||||
global openapi_spec_cache
|
global openapi_spec_cache
|
||||||
if openapi_spec_cache:
|
if openapi_spec_cache:
|
||||||
return openapi_spec_cache
|
return openapi_spec_cache
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch from local server
|
# Fetch from local server
|
||||||
# We use localhost because this code runs on the server
|
# We use localhost because this code runs on the server
|
||||||
@@ -32,7 +36,9 @@ def get_openapi_spec():
|
|||||||
print(f"Error fetching OpenAPI spec: {e}")
|
print(f"Error fetching OpenAPI spec: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def map_openapi_to_mcp_tools(spec):
|
def map_openapi_to_mcp_tools(spec):
|
||||||
|
"""Convert OpenAPI paths into MCP tool descriptors."""
|
||||||
tools = []
|
tools = []
|
||||||
if not spec or "paths" not in spec:
|
if not spec or "paths" not in spec:
|
||||||
return tools
|
return tools
|
||||||
@@ -49,14 +55,14 @@ def map_openapi_to_mcp_tools(spec):
|
|||||||
"required": []
|
"required": []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract parameters from requestBody if present
|
# Extract parameters from requestBody if present
|
||||||
if "requestBody" in details:
|
if "requestBody" in details:
|
||||||
content = details["requestBody"].get("content", {})
|
content = details["requestBody"].get("content", {})
|
||||||
if "application/json" in content:
|
if "application/json" in content:
|
||||||
schema = content["application/json"].get("schema", {})
|
schema = content["application/json"].get("schema", {})
|
||||||
tool["inputSchema"] = schema
|
tool["inputSchema"] = schema
|
||||||
|
|
||||||
# Extract parameters from 'parameters' list (query/path params) - simplistic support
|
# Extract parameters from 'parameters' list (query/path params) - simplistic support
|
||||||
if "parameters" in details:
|
if "parameters" in details:
|
||||||
for param in details["parameters"]:
|
for param in details["parameters"]:
|
||||||
@@ -73,12 +79,14 @@ def map_openapi_to_mcp_tools(spec):
|
|||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
def process_mcp_request(data):
|
def process_mcp_request(data):
|
||||||
|
"""Handle incoming MCP JSON-RPC requests and route them to tools."""
|
||||||
method = data.get("method")
|
method = data.get("method")
|
||||||
msg_id = data.get("id")
|
msg_id = data.get("id")
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
|
|
||||||
if method == "initialize":
|
if method == "initialize":
|
||||||
response = {
|
response = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
@@ -94,11 +102,11 @@ def process_mcp_request(data):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
elif method == "notifications/initialized":
|
elif method == "notifications/initialized":
|
||||||
# No response needed for notification
|
# No response needed for notification
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif method == "tools/list":
|
elif method == "tools/list":
|
||||||
spec = get_openapi_spec()
|
spec = get_openapi_spec()
|
||||||
tools = map_openapi_to_mcp_tools(spec)
|
tools = map_openapi_to_mcp_tools(spec)
|
||||||
@@ -109,17 +117,17 @@ def process_mcp_request(data):
|
|||||||
"tools": tools
|
"tools": tools
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
elif method == "tools/call":
|
elif method == "tools/call":
|
||||||
params = data.get("params", {})
|
params = data.get("params", {})
|
||||||
tool_name = params.get("name")
|
tool_name = params.get("name")
|
||||||
tool_args = params.get("arguments", {})
|
tool_args = params.get("arguments", {})
|
||||||
|
|
||||||
# Find the endpoint for this tool
|
# Find the endpoint for this tool
|
||||||
spec = get_openapi_spec()
|
spec = get_openapi_spec()
|
||||||
target_path = None
|
target_path = None
|
||||||
target_method = None
|
target_method = None
|
||||||
|
|
||||||
if spec and "paths" in spec:
|
if spec and "paths" in spec:
|
||||||
for path, methods in spec["paths"].items():
|
for path, methods in spec["paths"].items():
|
||||||
for m, details in methods.items():
|
for m, details in methods.items():
|
||||||
@@ -129,7 +137,7 @@ def process_mcp_request(data):
|
|||||||
break
|
break
|
||||||
if target_path:
|
if target_path:
|
||||||
break
|
break
|
||||||
|
|
||||||
if target_path:
|
if target_path:
|
||||||
try:
|
try:
|
||||||
# Make the request to the local API
|
# Make the request to the local API
|
||||||
@@ -139,16 +147,16 @@ def process_mcp_request(data):
|
|||||||
}
|
}
|
||||||
if "Authorization" in request.headers:
|
if "Authorization" in request.headers:
|
||||||
headers["Authorization"] = request.headers["Authorization"]
|
headers["Authorization"] = request.headers["Authorization"]
|
||||||
|
|
||||||
url = f"{API_BASE_URL}{target_path}"
|
url = f"{API_BASE_URL}{target_path}"
|
||||||
|
|
||||||
if target_method == "POST":
|
if target_method == "POST":
|
||||||
api_res = requests.post(url, json=tool_args, headers=headers)
|
api_res = requests.post(url, json=tool_args, headers=headers)
|
||||||
elif target_method == "GET":
|
elif target_method == "GET":
|
||||||
api_res = requests.get(url, params=tool_args, headers=headers)
|
api_res = requests.get(url, params=tool_args, headers=headers)
|
||||||
else:
|
else:
|
||||||
api_res = None
|
api_res = None
|
||||||
|
|
||||||
if api_res:
|
if api_res:
|
||||||
content = []
|
content = []
|
||||||
try:
|
try:
|
||||||
@@ -157,12 +165,12 @@ def process_mcp_request(data):
|
|||||||
"type": "text",
|
"type": "text",
|
||||||
"text": json.dumps(json_content, indent=2)
|
"text": json.dumps(json_content, indent=2)
|
||||||
})
|
})
|
||||||
except:
|
except (ValueError, json.JSONDecodeError):
|
||||||
content.append({
|
content.append({
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": api_res.text
|
"text": api_res.text
|
||||||
})
|
})
|
||||||
|
|
||||||
is_error = api_res.status_code >= 400
|
is_error = api_res.status_code >= 400
|
||||||
response = {
|
response = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
@@ -194,27 +202,29 @@ def process_mcp_request(data):
|
|||||||
"id": msg_id,
|
"id": msg_id,
|
||||||
"error": {"code": -32601, "message": f"Tool {tool_name} not found"}
|
"error": {"code": -32601, "message": f"Tool {tool_name} not found"}
|
||||||
}
|
}
|
||||||
|
|
||||||
elif method == "ping":
|
elif method == "ping":
|
||||||
response = {
|
response = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": msg_id,
|
"id": msg_id,
|
||||||
"result": {}
|
"result": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Unknown method
|
# Unknown method
|
||||||
if msg_id: # Only respond if it's a request (has id)
|
if msg_id: # Only respond if it's a request (has id)
|
||||||
response = {
|
response = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": msg_id,
|
"id": msg_id,
|
||||||
"error": {"code": -32601, "message": "Method not found"}
|
"error": {"code": -32601, "message": "Method not found"}
|
||||||
}
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@mcp_bp.route('/sse', methods=['GET', 'POST'])
|
@mcp_bp.route('/sse', methods=['GET', 'POST'])
|
||||||
def handle_sse():
|
def handle_sse():
|
||||||
|
"""Expose an SSE endpoint that streams MCP responses to connected clients."""
|
||||||
if request.method == 'POST':
|
if request.method == 'POST':
|
||||||
# Handle verification or keep-alive pings
|
# Handle verification or keep-alive pings
|
||||||
try:
|
try:
|
||||||
@@ -228,25 +238,26 @@ def handle_sse():
|
|||||||
return "", 202
|
return "", 202
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return jsonify({"status": "ok", "message": "MCP SSE endpoint active"}), 200
|
return jsonify({"status": "ok", "message": "MCP SSE endpoint active"}), 200
|
||||||
|
|
||||||
session_id = uuid.uuid4().hex
|
session_id = uuid.uuid4().hex
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
|
|
||||||
with sessions_lock:
|
with sessions_lock:
|
||||||
sessions[session_id] = q
|
sessions[session_id] = q
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
|
"""Yield SSE messages for queued MCP responses until the client disconnects."""
|
||||||
# Send the endpoint event
|
# Send the endpoint event
|
||||||
# The client should POST to /api/mcp/messages?session_id=<session_id>
|
# The client should POST to /api/mcp/messages?session_id=<session_id>
|
||||||
yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n"
|
yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# Wait for messages
|
# Wait for messages
|
||||||
message = q.get(timeout=20) # Keep-alive timeout
|
message = q.get(timeout=20) # Keep-alive timeout
|
||||||
yield f"event: message\ndata: {json.dumps(message)}\n\n"
|
yield f"event: message\ndata: {json.dumps(message)}\n\n"
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
# Send keep-alive comment
|
# Send keep-alive comment
|
||||||
@@ -258,12 +269,14 @@ def handle_sse():
|
|||||||
|
|
||||||
return Response(stream_with_context(stream()), mimetype='text/event-stream')
|
return Response(stream_with_context(stream()), mimetype='text/event-stream')
|
||||||
|
|
||||||
|
|
||||||
@mcp_bp.route('/messages', methods=['POST'])
|
@mcp_bp.route('/messages', methods=['POST'])
|
||||||
def handle_messages():
|
def handle_messages():
|
||||||
|
"""Receive MCP JSON-RPC messages and enqueue responses for an SSE session."""
|
||||||
session_id = request.args.get('session_id')
|
session_id = request.args.get('session_id')
|
||||||
if not session_id:
|
if not session_id:
|
||||||
return jsonify({"error": "Missing session_id"}), 400
|
return jsonify({"error": "Missing session_id"}), 400
|
||||||
|
|
||||||
with sessions_lock:
|
with sessions_lock:
|
||||||
if session_id not in sessions:
|
if session_id not in sessions:
|
||||||
return jsonify({"error": "Session not found"}), 404
|
return jsonify({"error": "Session not found"}), 404
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from flask import Blueprint, request, jsonify
|
from flask import Blueprint, request, jsonify
|
||||||
@@ -39,25 +37,25 @@ def trigger_scan():
|
|||||||
cmd = []
|
cmd = []
|
||||||
if scan_type == 'arp':
|
if scan_type == 'arp':
|
||||||
# ARP scan usually requires sudo or root, assuming container runs as root or has caps
|
# ARP scan usually requires sudo or root, assuming container runs as root or has caps
|
||||||
cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
|
cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
|
||||||
if target:
|
if target:
|
||||||
cmd = ["arp-scan", target]
|
cmd = ["arp-scan", target]
|
||||||
elif scan_type == 'nmap_fast':
|
elif scan_type == 'nmap_fast':
|
||||||
cmd = ["nmap", "-F"]
|
cmd = ["nmap", "-F"]
|
||||||
if target:
|
if target:
|
||||||
cmd.append(target)
|
cmd.append(target)
|
||||||
else:
|
else:
|
||||||
# Default to local subnet if possible, or error if not easily determined
|
# Default to local subnet if possible, or error if not easily determined
|
||||||
# For now, let's require target for nmap if not easily deducible,
|
# For now, let's require target for nmap if not easily deducible,
|
||||||
# or try to get it from settings.
|
# or try to get it from settings.
|
||||||
# NetAlertX usually knows its subnet.
|
# NetAlertX usually knows its subnet.
|
||||||
# Let's try to get the scan subnet from settings if not provided.
|
# Let's try to get the scan subnet from settings if not provided.
|
||||||
scan_subnets = get_setting_value("SCAN_SUBNETS")
|
scan_subnets = get_setting_value("SCAN_SUBNETS")
|
||||||
if scan_subnets:
|
if scan_subnets:
|
||||||
# Take the first one for now
|
# Take the first one for now
|
||||||
cmd.append(scan_subnets.split(',')[0].strip())
|
cmd.append(scan_subnets.split(',')[0].strip())
|
||||||
else:
|
else:
|
||||||
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
|
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
|
||||||
elif scan_type == 'nmap_deep':
|
elif scan_type == 'nmap_deep':
|
||||||
cmd = ["nmap", "-A", "-T4"]
|
cmd = ["nmap", "-A", "-T4"]
|
||||||
if target:
|
if target:
|
||||||
@@ -65,9 +63,9 @@ def trigger_scan():
|
|||||||
else:
|
else:
|
||||||
scan_subnets = get_setting_value("SCAN_SUBNETS")
|
scan_subnets = get_setting_value("SCAN_SUBNETS")
|
||||||
if scan_subnets:
|
if scan_subnets:
|
||||||
cmd.append(scan_subnets.split(',')[0].strip())
|
cmd.append(scan_subnets.split(',')[0].strip())
|
||||||
else:
|
else:
|
||||||
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
|
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run the command
|
# Run the command
|
||||||
@@ -212,7 +210,7 @@ def get_open_ports():
|
|||||||
text=True,
|
text=True,
|
||||||
check=True
|
check=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse output for open ports
|
# Parse output for open ports
|
||||||
open_ports = []
|
open_ports = []
|
||||||
for line in result.stdout.split('\n'):
|
for line in result.stdout.split('\n'):
|
||||||
@@ -250,10 +248,10 @@ def get_network_topology():
|
|||||||
try:
|
try:
|
||||||
cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
|
cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
links = []
|
links = []
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
nodes.append({
|
nodes.append({
|
||||||
"id": row['devMac'],
|
"id": row['devMac'],
|
||||||
@@ -299,16 +297,16 @@ def get_recent_alerts():
|
|||||||
cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S')
|
cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName
|
SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName
|
||||||
FROM Events
|
FROM Events
|
||||||
LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac
|
LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac
|
||||||
WHERE eve_DateTime > ?
|
WHERE eve_DateTime > ?
|
||||||
ORDER BY eve_DateTime DESC
|
ORDER BY eve_DateTime DESC
|
||||||
""", (cutoff_str,))
|
""", (cutoff_str,))
|
||||||
|
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
alerts = [dict(row) for row in rows]
|
alerts = [dict(row) for row in rows]
|
||||||
|
|
||||||
return jsonify(alerts)
|
return jsonify(alerts)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"error": str(e)}), 500
|
return jsonify({"error": str(e)}), 500
|
||||||
@@ -338,10 +336,10 @@ def set_device_alias():
|
|||||||
try:
|
try:
|
||||||
cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac))
|
cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
if cur.rowcount == 0:
|
if cur.rowcount == 0:
|
||||||
return jsonify({"error": "Device not found"}), 404
|
return jsonify({"error": "Device not found"}), 404
|
||||||
|
|
||||||
return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"})
|
return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"error": str(e)}), 500
|
return jsonify({"error": str(e)}), 500
|
||||||
@@ -379,7 +377,7 @@ def wol_wake_device():
|
|||||||
else:
|
else:
|
||||||
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
|
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"error": f"Database error: {str(e)}"}), 500
|
return jsonify({"error": f"Database error: {str(e)}"}), 500
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import sys
|
|||||||
import os
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
import subprocess
|
|
||||||
|
|
||||||
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
|
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
|
||||||
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
|
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
|
||||||
@@ -10,20 +9,23 @@ sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
|
|||||||
from helper import get_setting_value # noqa: E402
|
from helper import get_setting_value # noqa: E402
|
||||||
from api_server.api_server_start import app # noqa: E402
|
from api_server.api_server_start import app # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def api_token():
|
def api_token():
|
||||||
return get_setting_value("API_TOKEN")
|
return get_setting_value("API_TOKEN")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
with app.test_client() as client:
|
with app.test_client() as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
def auth_headers(token):
|
def auth_headers(token):
|
||||||
return {"Authorization": f"Bearer {token}"}
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
# --- get_device_info Tests ---
|
|
||||||
|
|
||||||
|
# --- get_device_info Tests ---
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
|
def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
|
||||||
"""Test get_device_info with partial IP search."""
|
"""Test get_device_info with partial IP search."""
|
||||||
@@ -33,53 +35,55 @@ def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
|
|||||||
{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}
|
{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}
|
||||||
]
|
]
|
||||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
payload = {"query": ".50"}
|
payload = {"query": ".50"}
|
||||||
response = client.post('/api/tools/get_device_info',
|
response = client.post('/api/tools/get_device_info',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
devices = response.get_json()
|
devices = response.get_json()
|
||||||
assert len(devices) == 1
|
assert len(devices) == 1
|
||||||
assert devices[0]["devLastIP"] == "192.168.1.50"
|
assert devices[0]["devLastIP"] == "192.168.1.50"
|
||||||
|
|
||||||
# Verify SQL query included 3 params (MAC, Name, IP)
|
# Verify SQL query included 3 params (MAC, Name, IP)
|
||||||
args, _ = mock_cursor.execute.call_args
|
args, _ = mock_cursor.execute.call_args
|
||||||
assert args[0].count("?") == 3
|
assert args[0].count("?") == 3
|
||||||
assert len(args[1]) == 3
|
assert len(args[1]) == 3
|
||||||
|
|
||||||
# --- trigger_scan Tests ---
|
|
||||||
|
|
||||||
|
# --- trigger_scan Tests ---
|
||||||
@patch('subprocess.run')
|
@patch('subprocess.run')
|
||||||
def test_trigger_scan_nmap_fast(mock_run, client, api_token):
|
def test_trigger_scan_nmap_fast(mock_run, client, api_token):
|
||||||
"""Test trigger_scan with nmap_fast."""
|
"""Test trigger_scan with nmap_fast."""
|
||||||
mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0)
|
mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0)
|
||||||
|
|
||||||
payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"}
|
payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"}
|
||||||
response = client.post('/api/tools/trigger_scan',
|
response = client.post('/api/tools/trigger_scan',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
assert "nmap -F 192.168.1.1" in data["command"]
|
assert "nmap -F 192.168.1.1" in data["command"]
|
||||||
mock_run.assert_called_once()
|
mock_run.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch('subprocess.run')
|
@patch('subprocess.run')
|
||||||
def test_trigger_scan_invalid_type(mock_run, client, api_token):
|
def test_trigger_scan_invalid_type(mock_run, client, api_token):
|
||||||
"""Test trigger_scan with invalid scan_type."""
|
"""Test trigger_scan with invalid scan_type."""
|
||||||
payload = {"scan_type": "invalid_type", "target": "192.168.1.1"}
|
payload = {"scan_type": "invalid_type", "target": "192.168.1.1"}
|
||||||
response = client.post('/api/tools/trigger_scan',
|
response = client.post('/api/tools/trigger_scan',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
# --- get_open_ports Tests ---
|
# --- get_open_ports Tests ---
|
||||||
|
|
||||||
|
|
||||||
@patch('subprocess.run')
|
@patch('subprocess.run')
|
||||||
def test_get_open_ports_ip(mock_run, client, api_token):
|
def test_get_open_ports_ip(mock_run, client, api_token):
|
||||||
"""Test get_open_ports with an IP address."""
|
"""Test get_open_ports with an IP address."""
|
||||||
@@ -94,12 +98,12 @@ PORT STATE SERVICE
|
|||||||
Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
|
Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
|
||||||
"""
|
"""
|
||||||
mock_run.return_value = MagicMock(stdout=mock_output, returncode=0)
|
mock_run.return_value = MagicMock(stdout=mock_output, returncode=0)
|
||||||
|
|
||||||
payload = {"target": "192.168.1.1"}
|
payload = {"target": "192.168.1.1"}
|
||||||
response = client.post('/api/tools/get_open_ports',
|
response = client.post('/api/tools/get_open_ports',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
@@ -107,6 +111,7 @@ Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
|
|||||||
assert data["open_ports"][0]["port"] == 22
|
assert data["open_ports"][0]["port"] == 22
|
||||||
assert data["open_ports"][1]["service"] == "http"
|
assert data["open_ports"][1]["service"] == "http"
|
||||||
|
|
||||||
|
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
@patch('subprocess.run')
|
@patch('subprocess.run')
|
||||||
def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
|
def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
|
||||||
@@ -115,24 +120,24 @@ def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
|
|||||||
mock_cursor = MagicMock()
|
mock_cursor = MagicMock()
|
||||||
mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"}
|
mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"}
|
||||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
# Mock Nmap output
|
# Mock Nmap output
|
||||||
mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0)
|
mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0)
|
||||||
|
|
||||||
payload = {"target": "AA:BB:CC:DD:EE:FF"}
|
payload = {"target": "AA:BB:CC:DD:EE:FF"}
|
||||||
response = client.post('/api/tools/get_open_ports',
|
response = client.post('/api/tools/get_open_ports',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["target"] == "192.168.1.50" # Should be resolved IP
|
assert data["target"] == "192.168.1.50" # Should be resolved IP
|
||||||
mock_run.assert_called_once()
|
mock_run.assert_called_once()
|
||||||
args, _ = mock_run.call_args
|
args, _ = mock_run.call_args
|
||||||
assert "192.168.1.50" in args[0]
|
assert "192.168.1.50" in args[0]
|
||||||
|
|
||||||
# --- get_network_topology Tests ---
|
|
||||||
|
|
||||||
|
# --- get_network_topology Tests ---
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
def test_get_network_topology(mock_db_conn, client, api_token):
|
def test_get_network_topology(mock_db_conn, client, api_token):
|
||||||
"""Test get_network_topology."""
|
"""Test get_network_topology."""
|
||||||
@@ -142,10 +147,10 @@ def test_get_network_topology(mock_db_conn, client, api_token):
|
|||||||
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"}
|
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"}
|
||||||
]
|
]
|
||||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
response = client.get('/api/tools/get_network_topology',
|
response = client.get('/api/tools/get_network_topology',
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert len(data["nodes"]) == 2
|
assert len(data["nodes"]) == 2
|
||||||
@@ -153,8 +158,8 @@ def test_get_network_topology(mock_db_conn, client, api_token):
|
|||||||
assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA"
|
assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA"
|
||||||
assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB"
|
assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB"
|
||||||
|
|
||||||
# --- get_recent_alerts Tests ---
|
|
||||||
|
|
||||||
|
# --- get_recent_alerts Tests ---
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
def test_get_recent_alerts(mock_db_conn, client, api_token):
|
def test_get_recent_alerts(mock_db_conn, client, api_token):
|
||||||
"""Test get_recent_alerts."""
|
"""Test get_recent_alerts."""
|
||||||
@@ -163,67 +168,69 @@ def test_get_recent_alerts(mock_db_conn, client, api_token):
|
|||||||
{"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"}
|
{"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"}
|
||||||
]
|
]
|
||||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
payload = {"hours": 24}
|
payload = {"hours": 24}
|
||||||
response = client.post('/api/tools/get_recent_alerts',
|
response = client.post('/api/tools/get_recent_alerts',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert len(data) == 1
|
assert len(data) == 1
|
||||||
assert data[0]["eve_EventType"] == "New Device"
|
assert data[0]["eve_EventType"] == "New Device"
|
||||||
|
|
||||||
# --- set_device_alias Tests ---
|
|
||||||
|
|
||||||
|
# --- set_device_alias Tests ---
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
def test_set_device_alias(mock_db_conn, client, api_token):
|
def test_set_device_alias(mock_db_conn, client, api_token):
|
||||||
"""Test set_device_alias."""
|
"""Test set_device_alias."""
|
||||||
mock_cursor = MagicMock()
|
mock_cursor = MagicMock()
|
||||||
mock_cursor.rowcount = 1 # Simulate successful update
|
mock_cursor.rowcount = 1 # Simulate successful update
|
||||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
|
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
|
||||||
response = client.post('/api/tools/set_device_alias',
|
response = client.post('/api/tools/set_device_alias',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
|
|
||||||
|
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
def test_set_device_alias_not_found(mock_db_conn, client, api_token):
|
def test_set_device_alias_not_found(mock_db_conn, client, api_token):
|
||||||
"""Test set_device_alias when device is not found."""
|
"""Test set_device_alias when device is not found."""
|
||||||
mock_cursor = MagicMock()
|
mock_cursor = MagicMock()
|
||||||
mock_cursor.rowcount = 0 # Simulate no rows updated
|
mock_cursor.rowcount = 0 # Simulate no rows updated
|
||||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||||
|
|
||||||
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
|
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
|
||||||
response = client.post('/api/tools/set_device_alias',
|
response = client.post('/api/tools/set_device_alias',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
# --- wol_wake_device Tests ---
|
|
||||||
|
|
||||||
|
# --- wol_wake_device Tests ---
|
||||||
@patch('subprocess.run')
|
@patch('subprocess.run')
|
||||||
def test_wol_wake_device(mock_subprocess, client, api_token):
|
def test_wol_wake_device(mock_subprocess, client, api_token):
|
||||||
"""Test wol_wake_device."""
|
"""Test wol_wake_device."""
|
||||||
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
|
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
|
||||||
mock_subprocess.return_value.returncode = 0
|
mock_subprocess.return_value.returncode = 0
|
||||||
|
|
||||||
payload = {"mac": "AA:BB:CC:DD:EE:FF"}
|
payload = {"mac": "AA:BB:CC:DD:EE:FF"}
|
||||||
response = client.post('/api/tools/wol_wake_device',
|
response = client.post('/api/tools/wol_wake_device',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
|
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
|
||||||
|
|
||||||
|
|
||||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||||
@patch('subprocess.run')
|
@patch('subprocess.run')
|
||||||
def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token):
|
def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token):
|
||||||
@@ -238,38 +245,39 @@ def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token)
|
|||||||
mock_subprocess.return_value.returncode = 0
|
mock_subprocess.return_value.returncode = 0
|
||||||
|
|
||||||
payload = {"ip": "192.168.1.50"}
|
payload = {"ip": "192.168.1.50"}
|
||||||
response = client.post('/api/tools/wol_wake_device',
|
response = client.post('/api/tools/wol_wake_device',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
assert "AA:BB:CC:DD:EE:FF" in data["message"]
|
assert "AA:BB:CC:DD:EE:FF" in data["message"]
|
||||||
|
|
||||||
# Verify DB lookup
|
# Verify DB lookup
|
||||||
mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",))
|
mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",))
|
||||||
|
|
||||||
# Verify subprocess call
|
# Verify subprocess call
|
||||||
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
|
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
|
||||||
|
|
||||||
|
|
||||||
def test_wol_wake_device_invalid_mac(client, api_token):
|
def test_wol_wake_device_invalid_mac(client, api_token):
|
||||||
"""Test wol_wake_device with invalid MAC."""
|
"""Test wol_wake_device with invalid MAC."""
|
||||||
payload = {"mac": "invalid-mac"}
|
payload = {"mac": "invalid-mac"}
|
||||||
response = client.post('/api/tools/wol_wake_device',
|
response = client.post('/api/tools/wol_wake_device',
|
||||||
json=payload,
|
json=payload,
|
||||||
headers=auth_headers(api_token))
|
headers=auth_headers(api_token))
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
# --- openapi_spec Tests ---
|
|
||||||
|
|
||||||
|
# --- openapi_spec Tests ---
|
||||||
def test_openapi_spec(client):
|
def test_openapi_spec(client):
|
||||||
"""Test openapi_spec endpoint contains new paths."""
|
"""Test openapi_spec endpoint contains new paths."""
|
||||||
response = client.get('/api/tools/openapi.json')
|
response = client.get('/api/tools/openapi.json')
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
spec = response.get_json()
|
spec = response.get_json()
|
||||||
|
|
||||||
# Check for new endpoints
|
# Check for new endpoints
|
||||||
assert "/trigger_scan" in spec["paths"]
|
assert "/trigger_scan" in spec["paths"]
|
||||||
assert "/get_open_ports" in spec["paths"]
|
assert "/get_open_ports" in spec["paths"]
|
||||||
|
|||||||
Reference in New Issue
Block a user