Add MCP server

This commit is contained in:
Adam Outler
2025-11-25 02:19:56 +00:00
parent e90fbf17d3
commit 2bf3ff9f00
4 changed files with 1066 additions and 1 deletions

View File

@@ -71,6 +71,7 @@ from messaging.in_app import ( # noqa: E402 [flake8 lint suppression]
delete_notification,
mark_notification_as_read
)
from .tools_routes import tools_bp # noqa: E402 [flake8 lint suppression]
# Flask application
app = Flask(__name__)
@@ -87,7 +88,8 @@ CORS(
r"/dbquery/*": {"origins": "*"},
r"/messaging/*": {"origins": "*"},
r"/events/*": {"origins": "*"},
r"/logs/*": {"origins": "*"}
r"/logs/*": {"origins": "*"},
r"/api/tools/*": {"origins": "*"}
},
supports_credentials=True,
allow_headers=["Authorization", "Content-Type"],
@@ -97,6 +99,17 @@ CORS(
# -------------------------------------------------------------------
# Custom handler for 404 - Route not found
# -------------------------------------------------------------------
@app.before_request
def log_request_info():
"""Log details of every incoming request."""
# Filter out noisy requests if needed, but user asked for drastic logging
mylog("none", [f"[HTTP] {request.method} {request.path} from {request.remote_addr}"])
mylog("none", [f"[HTTP] Headers: {dict(request.headers)}"])
if request.method == "POST":
# Be careful with large bodies, but log first 1000 chars
data = request.get_data(as_text=True)
mylog("none", [f"[HTTP] Body: {data[:1000]}"])
@app.errorhandler(404)
def not_found(error):
response = {
@@ -775,3 +788,10 @@ def start_server(graphql_port, app_state):
# Update the state to indicate the server has started
app_state = updateState("Process: Idle", None, None, None, 1)
# Register Blueprints
app.register_blueprint(tools_bp, url_prefix='/api/tools')
if __name__ == "__main__":
# This block is for running the server directly for testing purposes
# In production, start_server is called from api.py
pass

View File

@@ -0,0 +1,687 @@
import subprocess
import shutil
import os
import re
from datetime import datetime, timedelta
from flask import Blueprint, request, jsonify
import sqlite3
from helper import get_setting_value
from database import get_temp_db_connection
tools_bp = Blueprint('tools', __name__)
def check_auth():
"""Check API_TOKEN authorization."""
token = request.headers.get("Authorization")
expected_token = f"Bearer {get_setting_value('API_TOKEN')}"
return token == expected_token
@tools_bp.route('/trigger_scan', methods=['POST'])
def trigger_scan():
"""
Forces NetAlertX to run a specific scan type immediately.
Arguments: scan_type (Enum: arp, nmap_fast, nmap_deep), target (optional IP/CIDR)
"""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
data = request.get_json()
scan_type = data.get('scan_type', 'nmap_fast')
target = data.get('target')
# Validate scan_type
if scan_type not in ['arp', 'nmap_fast', 'nmap_deep']:
return jsonify({"error": "Invalid scan_type. Must be 'arp', 'nmap_fast', or 'nmap_deep'"}), 400
# Determine command
cmd = []
if scan_type == 'arp':
# ARP scan usually requires sudo or root, assuming container runs as root or has caps
cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
if target:
cmd = ["arp-scan", target]
elif scan_type == 'nmap_fast':
cmd = ["nmap", "-F"]
if target:
cmd.append(target)
else:
# Default to local subnet if possible, or error if not easily determined
# For now, let's require target for nmap if not easily deducible,
# or try to get it from settings.
# NetAlertX usually knows its subnet.
# Let's try to get the scan subnet from settings if not provided.
scan_subnets = get_setting_value("SCAN_SUBNETS")
if scan_subnets:
# Take the first one for now
cmd.append(scan_subnets.split(',')[0].strip())
else:
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
elif scan_type == 'nmap_deep':
cmd = ["nmap", "-A", "-T4"]
if target:
cmd.append(target)
else:
scan_subnets = get_setting_value("SCAN_SUBNETS")
if scan_subnets:
cmd.append(scan_subnets.split(',')[0].strip())
else:
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
try:
# Run the command
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=True
)
return jsonify({
"success": True,
"scan_type": scan_type,
"command": " ".join(cmd),
"output": result.stdout.strip().split('\n')
})
except subprocess.CalledProcessError as e:
return jsonify({
"success": False,
"error": "Scan failed",
"details": e.stderr.strip()
}), 500
except Exception as e:
return jsonify({"error": str(e)}), 500
@tools_bp.route('/list_devices', methods=['POST'])
def list_devices():
"""List all devices."""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
cur.execute("SELECT devName, devMac, devLastIP as devIP, devVendor, devFirstConnection, devLastConnection FROM Devices ORDER BY devFirstConnection DESC")
rows = cur.fetchall()
devices = [dict(row) for row in rows]
return jsonify(devices)
except Exception as e:
return jsonify({"error": str(e)}), 500
finally:
conn.close()
@tools_bp.route('/get_device_info', methods=['POST'])
def get_device_info():
"""Get detailed info for a specific device."""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
data = request.get_json()
if not data or 'query' not in data:
return jsonify({"error": "Missing 'query' parameter"}), 400
query = data['query']
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
# Search by MAC, Name, or partial IP
sql = "SELECT * FROM Devices WHERE devMac LIKE ? OR devName LIKE ? OR devLastIP LIKE ?"
cur.execute(sql, (f"%{query}%", f"%{query}%", f"%{query}%"))
rows = cur.fetchall()
if not rows:
return jsonify({"message": "No devices found"}), 404
devices = [dict(row) for row in rows]
return jsonify(devices)
except Exception as e:
return jsonify({"error": str(e)}), 500
finally:
conn.close()
@tools_bp.route('/get_latest_device', methods=['POST'])
def get_latest_device():
"""Get full details of the most recently discovered device."""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
# Get the device with the most recent devFirstConnection
cur.execute("SELECT * FROM Devices ORDER BY devFirstConnection DESC LIMIT 1")
row = cur.fetchone()
if not row:
return jsonify({"message": "No devices found"}), 404
# Return as a list to be consistent with other endpoints
return jsonify([dict(row)])
except Exception as e:
return jsonify({"error": str(e)}), 500
finally:
conn.close()
@tools_bp.route('/get_open_ports', methods=['POST'])
def get_open_ports():
"""
Specific query for the port-scan results of a target.
Arguments: target (IP or MAC)
"""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
data = request.get_json()
target = data.get('target')
if not target:
return jsonify({"error": "Target is required"}), 400
# If MAC is provided, try to resolve to IP
if re.match(r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", target):
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
cur.execute("SELECT devLastIP FROM Devices WHERE devMac = ?", (target,))
row = cur.fetchone()
if row and row['devLastIP']:
target = row['devLastIP']
else:
return jsonify({"error": f"Could not resolve IP for MAC {target}"}), 404
finally:
conn.close()
try:
# Run nmap -F for fast port scan
cmd = ["nmap", "-F", target]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=True
)
# Parse output for open ports
open_ports = []
for line in result.stdout.split('\n'):
if '/tcp' in line and 'open' in line:
parts = line.split('/')
port = parts[0].strip()
service = line.split()[2] if len(line.split()) > 2 else "unknown"
open_ports.append({"port": int(port), "service": service})
return jsonify({
"success": True,
"target": target,
"open_ports": open_ports,
"raw_output": result.stdout.strip().split('\n')
})
except subprocess.CalledProcessError as e:
return jsonify({"success": False, "error": "Port scan failed", "details": e.stderr.strip()}), 500
except Exception as e:
return jsonify({"error": str(e)}), 500
@tools_bp.route('/get_network_topology', methods=['GET'])
def get_network_topology():
"""
Returns the "Parent/Child" relationships.
"""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
rows = cur.fetchall()
nodes = []
links = []
for row in rows:
nodes.append({
"id": row['devMac'],
"name": row['devName'],
"vendor": row['devVendor']
})
if row['devParentMAC']:
links.append({
"source": row['devParentMAC'],
"target": row['devMac'],
"port": row['devParentPort']
})
return jsonify({
"nodes": nodes,
"links": links
})
except Exception as e:
return jsonify({"error": str(e)}), 500
finally:
conn.close()
@tools_bp.route('/get_recent_alerts', methods=['POST'])
def get_recent_alerts():
"""
Fetches the last N system alerts.
Arguments: hours (lookback period, default 24)
"""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
data = request.get_json()
hours = data.get('hours', 24)
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
# Calculate cutoff time
cutoff = datetime.now() - timedelta(hours=int(hours))
cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S')
cur.execute("""
SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName
FROM Events
LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac
WHERE eve_DateTime > ?
ORDER BY eve_DateTime DESC
""", (cutoff_str,))
rows = cur.fetchall()
alerts = [dict(row) for row in rows]
return jsonify(alerts)
except Exception as e:
return jsonify({"error": str(e)}), 500
finally:
conn.close()
@tools_bp.route('/set_device_alias', methods=['POST'])
def set_device_alias():
"""
Updates the name (alias) of a device.
Arguments: mac, alias
"""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
data = request.get_json()
mac = data.get('mac')
alias = data.get('alias')
if not mac or not alias:
return jsonify({"error": "MAC and Alias are required"}), 400
conn = get_temp_db_connection()
cur = conn.cursor()
try:
cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac))
conn.commit()
if cur.rowcount == 0:
return jsonify({"error": "Device not found"}), 404
return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"})
except Exception as e:
return jsonify({"error": str(e)}), 500
finally:
conn.close()
@tools_bp.route('/wol_wake_device', methods=['POST'])
def wol_wake_device():
"""
Sends a Wake-on-LAN magic packet.
Arguments: mac OR ip
"""
if not check_auth():
return jsonify({"error": "Unauthorized"}), 401
data = request.get_json()
mac = data.get('mac')
ip = data.get('ip')
if not mac and not ip:
return jsonify({"error": "MAC address or IP address is required"}), 400
# Resolve IP to MAC if MAC is missing
if not mac and ip:
conn = get_temp_db_connection()
conn.row_factory = sqlite3.Row
cur = conn.cursor()
try:
# Try to find device by IP (devLastIP)
cur.execute("SELECT devMac FROM Devices WHERE devLastIP = ?", (ip,))
row = cur.fetchone()
if row and row['devMac']:
mac = row['devMac']
else:
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
except Exception as e:
return jsonify({"error": f"Database error: {str(e)}"}), 500
finally:
conn.close()
# Validate MAC
if not re.match(r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", mac):
return jsonify({"success": False, "error": f"Invalid MAC: {mac}"}), 400
try:
# Using wakeonlan command
result = subprocess.run(
["wakeonlan", mac], capture_output=True, text=True, check=True
)
return jsonify(
{
"success": True,
"message": f"WOL packet sent to {mac}",
"output": result.stdout.strip(),
}
)
except subprocess.CalledProcessError as e:
return jsonify(
{
"success": False,
"error": "Failed to send WOL packet",
"details": e.stderr.strip(),
}
), 500
@tools_bp.route('/openapi.json', methods=['GET'])
def openapi_spec():
"""Return OpenAPI specification for tools."""
# No auth required for spec to allow easy import, or require it if preferred.
# Open WebUI usually needs to fetch spec without auth first or handles it.
# We'll allow public access to spec for simplicity of import.
spec = {
"openapi": "3.0.0",
"info": {
"title": "NetAlertX Tools",
"description": "API for NetAlertX device management tools",
"version": "1.1.0"
},
"servers": [
{"url": "/api/tools"}
],
"paths": {
"/list_devices": {
"post": {
"summary": "List all devices (Summary)",
"description": (
"Retrieve a SUMMARY list of all devices, sorted by newest first. "
"IMPORTANT: This only provides basic info (Name, IP, Vendor). "
"For FULL details (like custom props, alerts, etc.), you MUST use 'get_device_info' or 'get_latest_device'."
),
"operationId": "list_devices",
"responses": {
"200": {
"description": "List of devices (Summary)",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"type": "object",
"properties": {
"devName": {"type": "string"},
"devMac": {"type": "string"},
"devIP": {"type": "string"},
"devVendor": {"type": "string"},
"devStatus": {"type": "string"},
"devFirstConnection": {"type": "string"},
"devLastConnection": {"type": "string"}
}
}
}
}
}
}
}
}
},
"/get_device_info": {
"post": {
"summary": "Get device info (Full Details)",
"description": (
"Get COMPREHENSIVE information about a specific device by MAC, Name, or partial IP. "
"Use this to see all available properties, alerts, and metadata not shown in the list."
),
"operationId": "get_device_info",
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "MAC address, Device Name, or partial IP to search for"
}
},
"required": ["query"]
}
}
}
},
"responses": {
"200": {
"description": "Device details (Full)",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {"type": "object"}
}
}
}
},
"404": {"description": "Device not found"}
}
}
},
"/get_latest_device": {
"post": {
"summary": "Get latest device (Full Details)",
"description": "Get COMPREHENSIVE information about the most recently discovered device (latest devFirstConnection).",
"operationId": "get_latest_device",
"responses": {
"200": {
"description": "Latest device details (Full)",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {"type": "object"}
}
}
}
},
"404": {"description": "No devices found"}
}
}
},
"/trigger_scan": {
"post": {
"summary": "Trigger Active Scan",
"description": "Forces NetAlertX to run a specific scan type immediately.",
"operationId": "trigger_scan",
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"scan_type": {
"type": "string",
"enum": ["arp", "nmap_fast", "nmap_deep"],
"default": "nmap_fast"
},
"target": {
"type": "string",
"description": "IP address or CIDR to scan"
}
}
}
}
}
},
"responses": {
"200": {"description": "Scan started/completed successfully"},
"400": {"description": "Invalid input"}
}
}
},
"/get_open_ports": {
"post": {
"summary": "Get Open Ports",
"description": "Specific query for the port-scan results of a target.",
"operationId": "get_open_ports",
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"target": {
"type": "string",
"description": "IP or MAC address"
}
},
"required": ["target"]
}
}
}
},
"responses": {
"200": {"description": "List of open ports"},
"404": {"description": "Target not found"}
}
}
},
"/get_network_topology": {
"get": {
"summary": "Get Network Topology",
"description": "Returns the Parent/Child relationships for network visualization.",
"operationId": "get_network_topology",
"responses": {
"200": {"description": "Graph data (nodes and links)"}
}
}
},
"/get_recent_alerts": {
"post": {
"summary": "Get Recent Alerts",
"description": "Fetches the last N system alerts.",
"operationId": "get_recent_alerts",
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"hours": {
"type": "integer",
"default": 24
}
}
}
}
}
},
"responses": {
"200": {"description": "List of alerts"}
}
}
},
"/set_device_alias": {
"post": {
"summary": "Set Device Alias",
"description": "Updates the name (alias) of a device.",
"operationId": "set_device_alias",
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"mac": {"type": "string"},
"alias": {"type": "string"}
},
"required": ["mac", "alias"]
}
}
}
},
"responses": {
"200": {"description": "Alias updated"},
"404": {"description": "Device not found"}
}
}
},
"/wol_wake_device": {
"post": {
"summary": "Wake on LAN",
"description": "Sends a Wake-on-LAN magic packet to the target MAC or IP. If IP is provided, it resolves to MAC first.",
"operationId": "wol_wake_device",
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"mac": {"type": "string", "description": "Target MAC address"},
"ip": {"type": "string", "description": "Target IP address (resolves to MAC)"}
}
}
}
}
},
"responses": {
"200": {"description": "WOL packet sent"},
"404": {"description": "IP not found"}
}
}
}
},
"components": {
"securitySchemes": {
"bearerAuth": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT"
}
}
},
"security": [
{"bearerAuth": []}
]
}
return jsonify(spec)

View File

@@ -0,0 +1,279 @@
import sys
import os
import pytest
from unittest.mock import patch, MagicMock
import subprocess
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402
from api_server.api_server_start import app # noqa: E402
@pytest.fixture(scope="session")
def api_token():
return get_setting_value("API_TOKEN")
@pytest.fixture
def client():
with app.test_client() as client:
yield client
def auth_headers(token):
return {"Authorization": f"Bearer {token}"}
# --- get_device_info Tests ---
@patch('api_server.tools_routes.get_temp_db_connection')
def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
"""Test get_device_info with partial IP search."""
mock_cursor = MagicMock()
# Mock return of a device with IP ending in .50
mock_cursor.fetchall.return_value = [
{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}
]
mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"query": ".50"}
response = client.post('/api/tools/get_device_info',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
devices = response.get_json()
assert len(devices) == 1
assert devices[0]["devLastIP"] == "192.168.1.50"
# Verify SQL query included 3 params (MAC, Name, IP)
args, _ = mock_cursor.execute.call_args
assert args[0].count("?") == 3
assert len(args[1]) == 3
# --- trigger_scan Tests ---
@patch('subprocess.run')
def test_trigger_scan_nmap_fast(mock_run, client, api_token):
"""Test trigger_scan with nmap_fast."""
mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0)
payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"}
response = client.post('/api/tools/trigger_scan',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "nmap -F 192.168.1.1" in data["command"]
mock_run.assert_called_once()
@patch('subprocess.run')
def test_trigger_scan_invalid_type(mock_run, client, api_token):
"""Test trigger_scan with invalid scan_type."""
payload = {"scan_type": "invalid_type", "target": "192.168.1.1"}
response = client.post('/api/tools/trigger_scan',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 400
mock_run.assert_not_called()
# --- get_open_ports Tests ---
@patch('subprocess.run')
def test_get_open_ports_ip(mock_run, client, api_token):
"""Test get_open_ports with an IP address."""
mock_output = """
Starting Nmap 7.80 ( https://nmap.org ) at 2023-10-27 10:00 UTC
Nmap scan report for 192.168.1.1
Host is up (0.0010s latency).
Not shown: 98 closed ports
PORT STATE SERVICE
22/tcp open ssh
80/tcp open http
Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
"""
mock_run.return_value = MagicMock(stdout=mock_output, returncode=0)
payload = {"target": "192.168.1.1"}
response = client.post('/api/tools/get_open_ports',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert len(data["open_ports"]) == 2
assert data["open_ports"][0]["port"] == 22
assert data["open_ports"][1]["service"] == "http"
@patch('api_server.tools_routes.get_temp_db_connection')
@patch('subprocess.run')
def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
"""Test get_open_ports with a MAC address that resolves to an IP."""
# Mock DB to resolve MAC to IP
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"}
mock_db_conn.return_value.cursor.return_value = mock_cursor
# Mock Nmap output
mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0)
payload = {"target": "AA:BB:CC:DD:EE:FF"}
response = client.post('/api/tools/get_open_ports',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["target"] == "192.168.1.50" # Should be resolved IP
mock_run.assert_called_once()
args, _ = mock_run.call_args
assert "192.168.1.50" in args[0]
# --- get_network_topology Tests ---
@patch('api_server.tools_routes.get_temp_db_connection')
def test_get_network_topology(mock_db_conn, client, api_token):
"""Test get_network_topology."""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = [
{"devName": "Router", "devMac": "AA:AA:AA:AA:AA:AA", "devParentMAC": None, "devParentPort": None, "devVendor": "VendorA"},
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"}
]
mock_db_conn.return_value.cursor.return_value = mock_cursor
response = client.get('/api/tools/get_network_topology',
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert len(data["nodes"]) == 2
assert len(data["links"]) == 1
assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA"
assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB"
# --- get_recent_alerts Tests ---
@patch('api_server.tools_routes.get_temp_db_connection')
def test_get_recent_alerts(mock_db_conn, client, api_token):
"""Test get_recent_alerts."""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = [
{"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"}
]
mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"hours": 24}
response = client.post('/api/tools/get_recent_alerts',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert len(data) == 1
assert data[0]["eve_EventType"] == "New Device"
# --- set_device_alias Tests ---
@patch('api_server.tools_routes.get_temp_db_connection')
def test_set_device_alias(mock_db_conn, client, api_token):
"""Test set_device_alias."""
mock_cursor = MagicMock()
mock_cursor.rowcount = 1 # Simulate successful update
mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
response = client.post('/api/tools/set_device_alias',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
@patch('api_server.tools_routes.get_temp_db_connection')
def test_set_device_alias_not_found(mock_db_conn, client, api_token):
"""Test set_device_alias when device is not found."""
mock_cursor = MagicMock()
mock_cursor.rowcount = 0 # Simulate no rows updated
mock_db_conn.return_value.cursor.return_value = mock_cursor
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
response = client.post('/api/tools/set_device_alias',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 404
# --- wol_wake_device Tests ---
@patch('subprocess.run')
def test_wol_wake_device(mock_subprocess, client, api_token):
"""Test wol_wake_device."""
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
mock_subprocess.return_value.returncode = 0
payload = {"mac": "AA:BB:CC:DD:EE:FF"}
response = client.post('/api/tools/wol_wake_device',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
@patch('api_server.tools_routes.get_temp_db_connection')
@patch('subprocess.run')
def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token):
"""Test wol_wake_device with IP address."""
# Mock DB for IP resolution
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = {"devMac": "AA:BB:CC:DD:EE:FF"}
mock_db_conn.return_value.cursor.return_value = mock_cursor
# Mock subprocess
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
mock_subprocess.return_value.returncode = 0
payload = {"ip": "192.168.1.50"}
response = client.post('/api/tools/wol_wake_device',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert "AA:BB:CC:DD:EE:FF" in data["message"]
# Verify DB lookup
mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",))
# Verify subprocess call
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
def test_wol_wake_device_invalid_mac(client, api_token):
"""Test wol_wake_device with invalid MAC."""
payload = {"mac": "invalid-mac"}
response = client.post('/api/tools/wol_wake_device',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 400
# --- openapi_spec Tests ---
def test_openapi_spec(client):
"""Test openapi_spec endpoint contains new paths."""
response = client.get('/api/tools/openapi.json')
assert response.status_code == 200
spec = response.get_json()
# Check for new endpoints
assert "/trigger_scan" in spec["paths"]
assert "/get_open_ports" in spec["paths"]
assert "/get_network_topology" in spec["paths"]
assert "/get_recent_alerts" in spec["paths"]
assert "/set_device_alias" in spec["paths"]
assert "/wol_wake_device" in spec["paths"]

View File

@@ -0,0 +1,79 @@
import sys
import os
import pytest
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
from api_server.api_server_start import app # noqa: E402 [flake8 lint suppression]
@pytest.fixture(scope="session")
def api_token():
return get_setting_value("API_TOKEN")
@pytest.fixture
def client():
with app.test_client() as client:
yield client
def auth_headers(token):
return {"Authorization": f"Bearer {token}"}
def test_openapi_spec(client):
"""Test OpenAPI spec endpoint."""
response = client.get('/api/tools/openapi.json')
assert response.status_code == 200
spec = response.get_json()
assert "openapi" in spec
assert "info" in spec
assert "paths" in spec
assert "/list_devices" in spec["paths"]
assert "/get_device_info" in spec["paths"]
def test_list_devices(client, api_token):
"""Test list_devices endpoint."""
response = client.post('/api/tools/list_devices', headers=auth_headers(api_token))
assert response.status_code == 200
devices = response.get_json()
assert isinstance(devices, list)
# If there are devices, check structure
if devices:
device = devices[0]
assert "devName" in device
assert "devMac" in device
def test_get_device_info(client, api_token):
"""Test get_device_info endpoint."""
# Test with a query that might not exist
payload = {"query": "nonexistent_device"}
response = client.post('/api/tools/get_device_info',
json=payload,
headers=auth_headers(api_token))
# Should return 404 if no match, or 200 with results
assert response.status_code in [200, 404]
if response.status_code == 200:
devices = response.get_json()
assert isinstance(devices, list)
elif response.status_code == 404:
# Expected for no matches
pass
def test_list_devices_unauthorized(client):
"""Test list_devices without authorization."""
response = client.post('/api/tools/list_devices')
assert response.status_code == 401
def test_get_device_info_unauthorized(client):
"""Test get_device_info without authorization."""
payload = {"query": "test"}
response = client.post('/api/tools/get_device_info', json=payload)
assert response.status_code == 401