mirror of
https://github.com/jokob-sk/NetAlertX.git
synced 2025-12-07 09:36:05 -08:00
@@ -25,7 +25,7 @@
|
||||
// even within this container and connect to them as needed.
|
||||
// "--network=host",
|
||||
],
|
||||
"mounts": [
|
||||
"mounts": [
|
||||
"source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" //used for testing various conditions in docker
|
||||
],
|
||||
// ATTENTION: If running with --network=host, COMMENT `forwardPorts` OR ELSE THERE WILL BE NO WEBUI!
|
||||
@@ -88,7 +88,7 @@
|
||||
}
|
||||
},
|
||||
"terminal.integrated.defaultProfile.linux": "zsh",
|
||||
|
||||
|
||||
// Python testing configuration
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.testing.unittestEnabled": false,
|
||||
|
||||
5
.github/copilot-instructions.md
vendored
5
.github/copilot-instructions.md
vendored
@@ -39,6 +39,7 @@ Backend loop phases (see `server/__main__.py` and `server/plugin.py`): `once`, `
|
||||
## API/Endpoints quick map
|
||||
- Flask app: `server/api_server/api_server_start.py` exposes routes like `/device/<mac>`, `/devices`, `/devices/export/{csv,json}`, `/devices/import`, `/devices/totals`, `/devices/by-status`, plus `nettools`, `events`, `sessions`, `dbquery`, `metrics`, `sync`.
|
||||
- Authorization: all routes expect header `Authorization: Bearer <API_TOKEN>` via `get_setting_value('API_TOKEN')`.
|
||||
- All responses need to return `"success":<False:True>` and if `False` an "error" message needs to be returned, e.g. `{"success": False, "error": f"No stored open ports for Device"}`
|
||||
|
||||
## Conventions & helpers to reuse
|
||||
- Settings: add/modify via `ccd()` in `server/initialise.py` or per‑plugin manifest. Never hardcode ports or secrets; use `get_setting_value()`.
|
||||
@@ -85,7 +86,7 @@ Backend loop phases (see `server/__main__.py` and `server/plugin.py`): `once`, `
|
||||
- Above all, use the simplest possible code that meets the need so it can be easily audited and maintained.
|
||||
- Always leave logging enabled. If there is a possiblity it will be difficult to debug with current logging, add more logging.
|
||||
- Always run the testFailure tool before executing any tests to gather current failure information and avoid redundant runs.
|
||||
- Always prioritize using the appropriate tools in the environment first. As an example if a test is failing use `testFailure` then `runTests`. Never `runTests` first.
|
||||
- Always prioritize using the appropriate tools in the environment first. As an example if a test is failing use `testFailure` then `runTests`. Never `runTests` first.
|
||||
- Docker tests take an extremely long time to run. Avoid changes to docker or tests until you've examined the exisiting testFailures and runTests results.
|
||||
- Environment tools are designed specifically for your use in this project and running them in this order will give you the best results.
|
||||
- Environment tools are designed specifically for your use in this project and running them in this order will give you the best results.
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from plugin_helper import Plugin_Objects # noqa: E402 [flake8 lint suppression]
|
||||
from logger import mylog, Logger # noqa: E402 [flake8 lint suppression]
|
||||
from const import logPath # noqa: E402 [flake8 lint suppression]
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
|
||||
import conf # noqa: E402 [flake8 lint suppression]
|
||||
from pytz import timezone # noqa: E402 [flake8 lint suppression]
|
||||
@@ -98,9 +97,7 @@ def main():
|
||||
{"devMac": "00:11:22:33:44:57", "devLastIP": "192.168.1.82"},
|
||||
]
|
||||
else:
|
||||
db = DB()
|
||||
db.open()
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
devices = (
|
||||
device_handler.getAll()
|
||||
if get_setting_value("REFRESH_FQDN")
|
||||
|
||||
@@ -11,7 +11,6 @@ from plugin_helper import Plugin_Objects # noqa: E402 [flake8 lint suppression]
|
||||
from logger import mylog, Logger # noqa: E402 [flake8 lint suppression]
|
||||
from const import logPath # noqa: E402 [flake8 lint suppression]
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
|
||||
import conf # noqa: E402 [flake8 lint suppression]
|
||||
from pytz import timezone # noqa: E402 [flake8 lint suppression]
|
||||
@@ -38,15 +37,11 @@ def main():
|
||||
|
||||
timeout = get_setting_value('DIGSCAN_RUN_TIMEOUT')
|
||||
|
||||
# Create a database connection
|
||||
db = DB() # instance of class DB
|
||||
db.open()
|
||||
|
||||
# Initialize the Plugin obj output file
|
||||
plugin_objects = Plugin_Objects(RESULT_FILE)
|
||||
|
||||
# Create a DeviceInstance instance
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
# Retrieve devices
|
||||
if get_setting_value("REFRESH_FQDN"):
|
||||
|
||||
@@ -15,7 +15,6 @@ from plugin_helper import Plugin_Objects # noqa: E402 [flake8 lint suppression]
|
||||
from logger import mylog, Logger # noqa: E402 [flake8 lint suppression]
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from const import logPath # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
|
||||
import conf # noqa: E402 [flake8 lint suppression]
|
||||
from pytz import timezone # noqa: E402 [flake8 lint suppression]
|
||||
@@ -41,15 +40,11 @@ def main():
|
||||
args = get_setting_value('ICMP_ARGS')
|
||||
in_regex = get_setting_value('ICMP_IN_REGEX')
|
||||
|
||||
# Create a database connection
|
||||
db = DB() # instance of class DB
|
||||
db.open()
|
||||
|
||||
# Initialize the Plugin obj output file
|
||||
plugin_objects = Plugin_Objects(RESULT_FILE)
|
||||
|
||||
# Create a DeviceInstance instance
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
# Retrieve devices
|
||||
all_devices = device_handler.getAll()
|
||||
|
||||
@@ -12,7 +12,6 @@ from plugin_helper import Plugin_Objects # noqa: E402 [flake8 lint suppression]
|
||||
from logger import mylog, Logger # noqa: E402 [flake8 lint suppression]
|
||||
from const import logPath # noqa: E402 [flake8 lint suppression]
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
|
||||
import conf # noqa: E402 [flake8 lint suppression]
|
||||
from pytz import timezone # noqa: E402 [flake8 lint suppression]
|
||||
@@ -40,15 +39,11 @@ def main():
|
||||
# timeout = get_setting_value('NBLOOKUP_RUN_TIMEOUT')
|
||||
timeout = 20
|
||||
|
||||
# Create a database connection
|
||||
db = DB() # instance of class DB
|
||||
db.open()
|
||||
|
||||
# Initialize the Plugin obj output file
|
||||
plugin_objects = Plugin_Objects(RESULT_FILE)
|
||||
|
||||
# Create a DeviceInstance instance
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
# Retrieve devices
|
||||
if get_setting_value("REFRESH_FQDN"):
|
||||
|
||||
@@ -15,7 +15,6 @@ from plugin_helper import Plugin_Objects # noqa: E402 [flake8 lint suppression]
|
||||
from logger import mylog, Logger # noqa: E402 [flake8 lint suppression]
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from const import logPath # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
|
||||
import conf # noqa: E402 [flake8 lint suppression]
|
||||
from pytz import timezone # noqa: E402 [flake8 lint suppression]
|
||||
@@ -39,15 +38,11 @@ def main():
|
||||
|
||||
timeout = get_setting_value('NSLOOKUP_RUN_TIMEOUT')
|
||||
|
||||
# Create a database connection
|
||||
db = DB() # instance of class DB
|
||||
db.open()
|
||||
|
||||
# Initialize the Plugin obj output file
|
||||
plugin_objects = Plugin_Objects(RESULT_FILE)
|
||||
|
||||
# Create a DeviceInstance instance
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
# Retrieve devices
|
||||
if get_setting_value("REFRESH_FQDN"):
|
||||
|
||||
@@ -256,13 +256,11 @@ def main():
|
||||
start_time = time.time()
|
||||
|
||||
mylog("verbose", [f"[{pluginName}] starting execution"])
|
||||
from database import DB
|
||||
|
||||
from models.device_instance import DeviceInstance
|
||||
|
||||
db = DB() # instance of class DB
|
||||
db.open()
|
||||
# Create a DeviceInstance instance
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
# Retrieve configuration settings
|
||||
# these should be self-explanatory
|
||||
omada_sites = []
|
||||
|
||||
@@ -13,7 +13,6 @@ from plugin_helper import Plugin_Objects # noqa: E402 [flake8 lint suppression]
|
||||
from logger import mylog, Logger # noqa: E402 [flake8 lint suppression]
|
||||
from const import logPath # noqa: E402 [flake8 lint suppression]
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
|
||||
import conf # noqa: E402 [flake8 lint suppression]
|
||||
|
||||
@@ -44,12 +43,8 @@ def main():
|
||||
|
||||
mylog('verbose', [f'[{pluginName}] broadcast_ips value {broadcast_ips}'])
|
||||
|
||||
# Create a database connection
|
||||
db = DB() # instance of class DB
|
||||
db.open()
|
||||
|
||||
# Create a DeviceInstance instance
|
||||
device_handler = DeviceInstance(db)
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
# Retrieve devices
|
||||
if 'offline' in devices_to_wake:
|
||||
|
||||
@@ -2,13 +2,8 @@ import threading
|
||||
import sys
|
||||
import os
|
||||
|
||||
from flask import Flask, request, jsonify, Response, stream_with_context
|
||||
import json
|
||||
import uuid
|
||||
import queue
|
||||
from flask import Flask, request, jsonify, Response
|
||||
import requests
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from models.device_instance import DeviceInstance # noqa: E402
|
||||
from flask_cors import CORS
|
||||
|
||||
@@ -70,9 +65,12 @@ from .dbquery_endpoint import read_query, write_query, update_query, delete_quer
|
||||
from .sync_endpoint import handle_sync_post, handle_sync_get # noqa: E402 [flake8 lint suppression]
|
||||
from .logs_endpoint import clean_log # noqa: E402 [flake8 lint suppression]
|
||||
from models.user_events_queue_instance import UserEventsQueueInstance # noqa: E402 [flake8 lint suppression]
|
||||
from database import DB # noqa: E402 [flake8 lint suppression]
|
||||
from models.plugin_object_instance import PluginObjectInstance # noqa: E402 [flake8 lint suppression]
|
||||
|
||||
from models.event_instance import EventInstance # noqa: E402 [flake8 lint suppression]
|
||||
# Import tool logic from the MCP/tools module to reuse behavior (no blueprints)
|
||||
from plugin_helper import is_mac # noqa: E402 [flake8 lint suppression]
|
||||
# is_mac is provided in mcp_endpoint and used by those handlers
|
||||
# mcp_endpoint contains helper functions; routes moved into this module to keep a single place for routes
|
||||
from messaging.in_app import ( # noqa: E402 [flake8 lint suppression]
|
||||
write_notification,
|
||||
mark_all_notifications_read,
|
||||
@@ -81,14 +79,17 @@ from messaging.in_app import ( # noqa: E402 [flake8 lint suppression]
|
||||
delete_notification,
|
||||
mark_notification_as_read
|
||||
)
|
||||
from .tools_routes import openapi_spec as tools_openapi_spec # noqa: E402 [flake8 lint suppression]
|
||||
from .mcp_endpoint import ( # noqa: E402 [flake8 lint suppression]
|
||||
mcp_sse,
|
||||
mcp_messages,
|
||||
openapi_spec
|
||||
)
|
||||
# tools and mcp routes have been moved into this module (api_server_start)
|
||||
|
||||
# Flask application
|
||||
app = Flask(__name__)
|
||||
|
||||
# Register Blueprints
|
||||
# No separate blueprints for tools or mcp - routes are registered below
|
||||
|
||||
CORS(
|
||||
app,
|
||||
resources={
|
||||
@@ -103,30 +104,22 @@ CORS(
|
||||
r"/messaging/*": {"origins": "*"},
|
||||
r"/events/*": {"origins": "*"},
|
||||
r"/logs/*": {"origins": "*"},
|
||||
r"/api/tools/*": {"origins": "*"}
|
||||
r"/auth/*": {"origins": "*"}
|
||||
r"/api/tools/*": {"origins": "*"},
|
||||
r"/auth/*": {"origins": "*"},
|
||||
r"/mcp/*": {"origins": "*"}
|
||||
},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
# -----------------------------------------------
|
||||
# DB model instances for helper usage
|
||||
# -----------------------------------------------
|
||||
db_helper = DB()
|
||||
db_helper.open()
|
||||
device_handler = DeviceInstance(db_helper)
|
||||
plugin_object_handler = PluginObjectInstance(db_helper)
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# MCP bridge variables + helpers (moved from mcp_routes)
|
||||
# -------------------------------------------------------------------------------
|
||||
mcp_sessions = {}
|
||||
mcp_sessions_lock = threading.Lock()
|
||||
|
||||
mcp_openapi_spec_cache = None
|
||||
|
||||
BACKEND_PORT = get_setting_value("GRAPHQL_PORT")
|
||||
API_BASE_URL = f"http://localhost:{BACKEND_PORT}/api/tools"
|
||||
API_BASE_URL = f"http://localhost:{BACKEND_PORT}"
|
||||
|
||||
|
||||
def get_openapi_spec_local():
|
||||
@@ -134,7 +127,7 @@ def get_openapi_spec_local():
|
||||
if mcp_openapi_spec_cache:
|
||||
return mcp_openapi_spec_cache
|
||||
try:
|
||||
resp = requests.get(f"{API_BASE_URL}/openapi.json", timeout=10)
|
||||
resp = requests.get(f"{API_BASE_URL}/mcp/openapi.json", timeout=10)
|
||||
resp.raise_for_status()
|
||||
mcp_openapi_spec_cache = resp.json()
|
||||
return mcp_openapi_spec_cache
|
||||
@@ -143,161 +136,18 @@ def get_openapi_spec_local():
|
||||
return None
|
||||
|
||||
|
||||
def map_openapi_to_mcp_tools(spec):
|
||||
tools = []
|
||||
if not spec or 'paths' not in spec:
|
||||
return tools
|
||||
for path, methods in spec['paths'].items():
|
||||
for method, details in methods.items():
|
||||
if 'operationId' in details:
|
||||
tool = {
|
||||
'name': details['operationId'],
|
||||
'description': details.get('description', details.get('summary', '')),
|
||||
'inputSchema': {'type': 'object', 'properties': {}, 'required': []},
|
||||
}
|
||||
if 'requestBody' in details:
|
||||
content = details['requestBody'].get('content', {})
|
||||
if 'application/json' in content:
|
||||
schema = content['application/json'].get('schema', {})
|
||||
tool['inputSchema'] = schema.copy()
|
||||
if 'properties' not in tool['inputSchema']:
|
||||
tool['inputSchema']['properties'] = {}
|
||||
if 'parameters' in details:
|
||||
for param in details['parameters']:
|
||||
if param.get('in') == 'query':
|
||||
tool['inputSchema']['properties'][param['name']] = {
|
||||
'type': param.get('schema', {}).get('type', 'string'),
|
||||
'description': param.get('description', ''),
|
||||
}
|
||||
if param.get('required'):
|
||||
tool['inputSchema'].setdefault('required', []).append(param['name'])
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
|
||||
def process_mcp_request(data):
|
||||
method = data.get('method')
|
||||
msg_id = data.get('id')
|
||||
response = None
|
||||
if method == 'initialize':
|
||||
response = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': msg_id,
|
||||
'result': {
|
||||
'protocolVersion': '2024-11-05',
|
||||
'capabilities': {'tools': {}},
|
||||
'serverInfo': {'name': 'NetAlertX', 'version': '1.0.0'},
|
||||
},
|
||||
}
|
||||
elif method == 'notifications/initialized':
|
||||
pass
|
||||
elif method == 'tools/list':
|
||||
spec = get_openapi_spec_local()
|
||||
tools = map_openapi_to_mcp_tools(spec)
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'result': {'tools': tools}}
|
||||
elif method == 'tools/call':
|
||||
params = data.get('params', {})
|
||||
tool_name = params.get('name')
|
||||
tool_args = params.get('arguments', {})
|
||||
spec = get_openapi_spec_local()
|
||||
target_path = None
|
||||
target_method = None
|
||||
if spec and 'paths' in spec:
|
||||
for path, methods in spec['paths'].items():
|
||||
for m, details in methods.items():
|
||||
if details.get('operationId') == tool_name:
|
||||
target_path = path
|
||||
target_method = m.upper()
|
||||
break
|
||||
if target_path:
|
||||
break
|
||||
if target_path:
|
||||
try:
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if 'Authorization' in request.headers:
|
||||
headers['Authorization'] = request.headers['Authorization']
|
||||
url = f"{API_BASE_URL}{target_path}"
|
||||
if target_method == 'POST':
|
||||
api_res = requests.post(url, json=tool_args, headers=headers, timeout=30)
|
||||
elif target_method == 'GET':
|
||||
api_res = requests.get(url, params=tool_args, headers=headers, timeout=30)
|
||||
else:
|
||||
api_res = None
|
||||
if api_res:
|
||||
content = []
|
||||
try:
|
||||
json_content = api_res.json()
|
||||
content.append({'type': 'text', 'text': json.dumps(json_content, indent=2)})
|
||||
except Exception:
|
||||
content.append({'type': 'text', 'text': api_res.text})
|
||||
is_error = api_res.status_code >= 400
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'result': {'content': content, 'isError': is_error}}
|
||||
else:
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'error': {'code': -32601, 'message': f"Method {target_method} not supported"}}
|
||||
except Exception as e:
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'result': {'content': [{'type': 'text', 'text': f"Error calling tool: {str(e)}"}], 'isError': True}}
|
||||
else:
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'error': {'code': -32601, 'message': f"Tool {tool_name} not found"}}
|
||||
elif method == 'ping':
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'result': {}}
|
||||
else:
|
||||
if msg_id:
|
||||
response = {'jsonrpc': '2.0', 'id': msg_id, 'error': {'code': -32601, 'message': 'Method not found'}}
|
||||
return response
|
||||
|
||||
|
||||
@app.route('/api/mcp/sse', methods=['GET', 'POST'])
|
||||
@app.route('/mcp/sse', methods=['GET', 'POST'])
|
||||
def api_mcp_sse():
|
||||
if request.method == 'POST':
|
||||
try:
|
||||
data = request.get_json(silent=True)
|
||||
if data and 'method' in data and 'jsonrpc' in data:
|
||||
response = process_mcp_request(data)
|
||||
if response:
|
||||
return jsonify(response)
|
||||
else:
|
||||
return '', 202
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).debug(f'SSE POST processing error: {e}')
|
||||
return jsonify({'status': 'ok', 'message': 'MCP SSE endpoint active'}), 200
|
||||
|
||||
session_id = uuid.uuid4().hex
|
||||
q = queue.Queue()
|
||||
with mcp_sessions_lock:
|
||||
mcp_sessions[session_id] = q
|
||||
|
||||
def stream():
|
||||
yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n"
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = q.get(timeout=20)
|
||||
yield f"event: message\ndata: {json.dumps(message)}\n\n"
|
||||
except queue.Empty:
|
||||
yield ": keep-alive\n\n"
|
||||
except GeneratorExit:
|
||||
with mcp_sessions_lock:
|
||||
if session_id in mcp_sessions:
|
||||
del mcp_sessions[session_id]
|
||||
return Response(stream_with_context(stream()), mimetype='text/event-stream')
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
return mcp_sse()
|
||||
|
||||
|
||||
@app.route('/api/mcp/messages', methods=['POST'])
|
||||
def api_mcp_messages():
|
||||
session_id = request.args.get('session_id')
|
||||
if not session_id:
|
||||
return jsonify({"error": "Missing session_id"}), 400
|
||||
with mcp_sessions_lock:
|
||||
if session_id not in mcp_sessions:
|
||||
return jsonify({"error": "Session not found"}), 404
|
||||
q = mcp_sessions[session_id]
|
||||
data = request.json
|
||||
if not data:
|
||||
return jsonify({"error": "Invalid JSON"}), 400
|
||||
response = process_mcp_request(data)
|
||||
if response:
|
||||
q.put(response)
|
||||
return jsonify({"status": "accepted"}), 202
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
return mcp_messages()
|
||||
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
@@ -365,188 +215,12 @@ def graphql_endpoint():
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Tools endpoints (moved from tools_routes)
|
||||
# --------------------------
|
||||
|
||||
|
||||
@app.route('/api/tools/trigger_scan', methods=['POST'])
|
||||
def api_trigger_scan():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json() or {}
|
||||
scan_type = data.get('scan_type', 'nmap_fast')
|
||||
# Map requested scan type to plugin prefix
|
||||
plugin_prefix = None
|
||||
if scan_type in ['nmap_fast', 'nmap_deep']:
|
||||
plugin_prefix = 'NMAPDEV'
|
||||
elif scan_type == 'arp':
|
||||
plugin_prefix = 'ARPSCAN'
|
||||
else:
|
||||
return jsonify({"error": "Invalid scan_type. Must be 'arp', 'nmap_fast', or 'nmap_deep'"}), 400
|
||||
|
||||
queue_instance = UserEventsQueueInstance()
|
||||
action = f"run|{plugin_prefix}"
|
||||
success, message = queue_instance.add_event(action)
|
||||
if success:
|
||||
return jsonify({"success": True, "message": f"Triggered plugin {plugin_prefix} via ad-hoc queue."})
|
||||
else:
|
||||
return jsonify({"success": False, "error": message}), 500
|
||||
|
||||
|
||||
@app.route('/api/tools/list_devices', methods=['POST'])
|
||||
def api_tools_list_devices():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
return get_all_devices()
|
||||
|
||||
|
||||
@app.route('/api/tools/get_device_info', methods=['POST'])
|
||||
def api_tools_get_device_info():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
data = request.get_json(silent=True) or {}
|
||||
query = data.get('query')
|
||||
if not query:
|
||||
return jsonify({"error": "Missing 'query' parameter"}), 400
|
||||
# if MAC -> device endpoint
|
||||
if is_mac(query):
|
||||
return get_device_data(query)
|
||||
# search by name or IP
|
||||
matches = device_handler.search(query)
|
||||
if not matches:
|
||||
return jsonify({"message": "No devices found"}), 404
|
||||
return jsonify(matches)
|
||||
|
||||
|
||||
@app.route('/api/tools/get_latest_device', methods=['POST'])
|
||||
def api_tools_get_latest_device():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
latest = device_handler.getLatest()
|
||||
if not latest:
|
||||
return jsonify({"message": "No devices found"}), 404
|
||||
return jsonify([latest])
|
||||
|
||||
|
||||
@app.route('/api/tools/get_open_ports', methods=['POST'])
|
||||
def api_tools_get_open_ports():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
data = request.get_json(silent=True) or {}
|
||||
target = data.get('target')
|
||||
if not target:
|
||||
return jsonify({"error": "Target is required"}), 400
|
||||
|
||||
# If MAC is provided, use plugin objects to get port entries
|
||||
if is_mac(target):
|
||||
entries = plugin_object_handler.getByPrimary('NMAP', target.lower())
|
||||
open_ports = []
|
||||
for e in entries:
|
||||
try:
|
||||
port = int(e.get('Object_SecondaryID', 0))
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
service = e.get('Watched_Value2', 'unknown')
|
||||
open_ports.append({"port": port, "service": service})
|
||||
return jsonify({"success": True, "target": target, "open_ports": open_ports, "raw": entries})
|
||||
|
||||
# If IP provided, try to resolve to MAC and proceed
|
||||
# Use device handler to resolve IP
|
||||
device = device_handler.getByIP(target)
|
||||
if device and device.get('devMac'):
|
||||
mac = device.get('devMac')
|
||||
entries = plugin_object_handler.getByPrimary('NMAP', mac.lower())
|
||||
open_ports = []
|
||||
for e in entries:
|
||||
try:
|
||||
port = int(e.get('Object_SecondaryID', 0))
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
service = e.get('Watched_Value2', 'unknown')
|
||||
open_ports.append({"port": port, "service": service})
|
||||
return jsonify({"success": True, "target": target, "open_ports": open_ports, "raw": entries})
|
||||
|
||||
# No plugin data found; as fallback use nettools nmap_scan (may run subprocess)
|
||||
# Note: Prefer plugin data (NMAP) when available
|
||||
res = nmap_scan(target, 'fast')
|
||||
return res
|
||||
|
||||
|
||||
@app.route('/api/tools/get_network_topology', methods=['GET'])
|
||||
def api_tools_get_network_topology():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
topo = device_handler.getNetworkTopology()
|
||||
return jsonify(topo)
|
||||
|
||||
|
||||
@app.route('/api/tools/get_recent_alerts', methods=['POST'])
|
||||
def api_tools_get_recent_alerts():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
data = request.get_json(silent=True) or {}
|
||||
hours = int(data.get('hours', 24))
|
||||
# Reuse get_events() - which returns a Flask response with JSON containing 'events'
|
||||
res = get_events()
|
||||
events_json = res.get_json() if hasattr(res, 'get_json') else None
|
||||
events = events_json.get('events', []) if events_json else []
|
||||
cutoff = datetime.now() - timedelta(hours=hours)
|
||||
filtered = [e for e in events if 'eve_DateTime' in e and datetime.strptime(e['eve_DateTime'], '%Y-%m-%d %H:%M:%S') > cutoff]
|
||||
return jsonify(filtered)
|
||||
|
||||
|
||||
@app.route('/api/tools/set_device_alias', methods=['POST'])
|
||||
def api_tools_set_device_alias():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
data = request.get_json(silent=True) or {}
|
||||
mac = data.get('mac')
|
||||
alias = data.get('alias')
|
||||
if not mac or not alias:
|
||||
return jsonify({"error": "MAC and Alias are required"}), 400
|
||||
return update_device_column(mac, 'devName', alias)
|
||||
|
||||
|
||||
@app.route('/api/tools/wol_wake_device', methods=['POST'])
|
||||
def api_tools_wol_wake_device():
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
data = request.get_json(silent=True) or {}
|
||||
mac = data.get('mac')
|
||||
ip = data.get('ip')
|
||||
if not mac and not ip:
|
||||
return jsonify({"error": "MAC or IP is required"}), 400
|
||||
# Resolve IP to MAC if needed
|
||||
if not mac and ip:
|
||||
device = device_handler.getByIP(ip)
|
||||
if not device or not device.get('devMac'):
|
||||
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
|
||||
mac = device.get('devMac')
|
||||
# Validate mac using is_mac helper
|
||||
if not is_mac(mac):
|
||||
return jsonify({"success": False, "error": f"Invalid MAC: {mac}"}), 400
|
||||
return wakeonlan(mac)
|
||||
|
||||
|
||||
@app.route('/api/tools/openapi.json', methods=['GET'])
|
||||
def api_tools_openapi_spec():
|
||||
# Minimal OpenAPI spec for tools
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "NetAlertX Tools", "version": "1.1.0"},
|
||||
"servers": [{"url": "/api/tools"}],
|
||||
"paths": {}
|
||||
}
|
||||
return jsonify(spec)
|
||||
# Tools endpoints are registered via `mcp_endpoint.tools_bp` blueprint.
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Settings Endpoints
|
||||
# --------------------------
|
||||
|
||||
|
||||
@app.route("/settings/<setKey>", methods=["GET"])
|
||||
def api_get_setting(setKey):
|
||||
if not is_authorized():
|
||||
@@ -558,8 +232,7 @@ def api_get_setting(setKey):
|
||||
# --------------------------
|
||||
# Device Endpoints
|
||||
# --------------------------
|
||||
|
||||
|
||||
@app.route('/mcp/sse/device/<mac>', methods=['GET', 'POST'])
|
||||
@app.route("/device/<mac>", methods=["GET"])
|
||||
def api_get_device(mac):
|
||||
if not is_authorized():
|
||||
@@ -625,11 +298,45 @@ def api_update_device_column(mac):
|
||||
return update_device_column(mac, column_name, column_value)
|
||||
|
||||
|
||||
@app.route('/mcp/sse/device/<mac>/set-alias', methods=['POST'])
|
||||
@app.route('/device/<mac>/set-alias', methods=['POST'])
|
||||
def api_device_set_alias(mac):
|
||||
"""Set the device alias - convenience wrapper around update_device_column."""
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
data = request.get_json() or {}
|
||||
alias = data.get('alias')
|
||||
if not alias:
|
||||
return jsonify({"success": False, "message": "ERROR: Missing parameters", "error": "alias is required"}), 400
|
||||
return update_device_column(mac, 'devName', alias)
|
||||
|
||||
|
||||
@app.route('/mcp/sse/device/open_ports', methods=['POST'])
|
||||
@app.route('/device/open_ports', methods=['POST'])
|
||||
def api_device_open_ports():
|
||||
"""Get stored NMAP open ports for a target IP or MAC."""
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
target = data.get('target')
|
||||
if not target:
|
||||
return jsonify({"success": False, "error": "Target (IP or MAC) is required"}), 400
|
||||
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
# Use DeviceInstance method to get stored open ports
|
||||
open_ports = device_handler.getOpenPorts(target)
|
||||
|
||||
if not open_ports:
|
||||
return jsonify({"success": False, "error": f"No stored open ports for {target}. Run a scan with `/nettools/trigger-scan`"}), 404
|
||||
|
||||
return jsonify({"success": True, "target": target, "open_ports": open_ports})
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Devices Collections
|
||||
# --------------------------
|
||||
|
||||
|
||||
@app.route("/devices", methods=["GET"])
|
||||
def api_get_devices():
|
||||
if not is_authorized():
|
||||
@@ -685,6 +392,7 @@ def api_devices_totals():
|
||||
return devices_totals()
|
||||
|
||||
|
||||
@app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST'])
|
||||
@app.route("/devices/by-status", methods=["GET"])
|
||||
def api_devices_by_status():
|
||||
if not is_authorized():
|
||||
@@ -695,15 +403,88 @@ def api_devices_by_status():
|
||||
return devices_by_status(status)
|
||||
|
||||
|
||||
@app.route('/mcp/sse/devices/search', methods=['POST'])
|
||||
@app.route('/devices/search', methods=['POST'])
|
||||
def api_devices_search():
|
||||
"""Device search: accepts 'query' in JSON and maps to device info/search."""
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
query = data.get('query')
|
||||
|
||||
if not query:
|
||||
return jsonify({"error": "Missing 'query' parameter"}), 400
|
||||
|
||||
if is_mac(query):
|
||||
device_data = get_device_data(query)
|
||||
if device_data:
|
||||
return jsonify({"success": True, "devices": [device_data.get_json()]})
|
||||
else:
|
||||
return jsonify({"success": False, "error": "Device not found"}), 404
|
||||
|
||||
# Create fresh DB instance for this thread
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
matches = device_handler.search(query)
|
||||
|
||||
if not matches:
|
||||
return jsonify({"success": False, "error": "No devices found"}), 404
|
||||
|
||||
return jsonify({"success": True, "devices": matches})
|
||||
|
||||
|
||||
@app.route('/mcp/sse/devices/latest', methods=['GET'])
|
||||
@app.route('/devices/latest', methods=['GET'])
|
||||
def api_devices_latest():
|
||||
"""Get latest device (most recent) - maps to DeviceInstance.getLatest()."""
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
latest = device_handler.getLatest()
|
||||
|
||||
if not latest:
|
||||
return jsonify({"message": "No devices found"}), 404
|
||||
return jsonify([latest])
|
||||
|
||||
|
||||
@app.route('/mcp/sse/devices/network/topology', methods=['GET'])
|
||||
@app.route('/devices/network/topology', methods=['GET'])
|
||||
def api_devices_network_topology():
|
||||
"""Network topology mapping."""
|
||||
if not is_authorized():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
result = device_handler.getNetworkTopology()
|
||||
|
||||
return jsonify(result)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Net tools
|
||||
# --------------------------
|
||||
@app.route('/mcp/sse/nettools/wakeonlan', methods=['POST'])
|
||||
@app.route("/nettools/wakeonlan", methods=["POST"])
|
||||
def api_wakeonlan():
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
|
||||
mac = request.json.get("devMac")
|
||||
data = request.json or {}
|
||||
mac = data.get("devMac")
|
||||
ip = data.get("devLastIP") or data.get('ip')
|
||||
if not mac and ip:
|
||||
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
dev = device_handler.getByIP(ip)
|
||||
|
||||
if not dev or not dev.get('devMac'):
|
||||
return jsonify({"success": False, "message": "ERROR: Device not found", "error": "MAC not resolved"}), 404
|
||||
mac = dev.get('devMac')
|
||||
return wakeonlan(mac)
|
||||
|
||||
|
||||
@@ -764,11 +545,42 @@ def api_internet_info():
|
||||
return internet_info()
|
||||
|
||||
|
||||
@app.route('/mcp/sse/nettools/trigger-scan', methods=['POST'])
|
||||
@app.route("/nettools/trigger-scan", methods=["GET"])
|
||||
def api_trigger_scan():
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
scan_type = data.get('type', 'ARPSCAN')
|
||||
|
||||
# Validate scan type
|
||||
loaded_plugins = get_setting_value('LOADED_PLUGINS')
|
||||
if scan_type not in loaded_plugins:
|
||||
return jsonify({"success": False, "error": f"Invalid scan type. Must be one of: {', '.join(loaded_plugins)}"}), 400
|
||||
|
||||
queue = UserEventsQueueInstance()
|
||||
|
||||
action = f"run|{scan_type}"
|
||||
|
||||
queue.add_event(action)
|
||||
|
||||
return jsonify({"success": True, "message": f"Scan triggered for type: {scan_type}"}), 200
|
||||
|
||||
|
||||
# --------------------------
|
||||
# MCP Server
|
||||
# --------------------------
|
||||
@app.route('/mcp/sse/openapi.json', methods=['GET'])
|
||||
def api_openapi_spec():
|
||||
if not is_authorized():
|
||||
return jsonify({"Success": False, "error": "Unauthorized"}), 401
|
||||
return openapi_spec()
|
||||
|
||||
|
||||
# --------------------------
|
||||
# DB query
|
||||
# --------------------------
|
||||
|
||||
|
||||
@app.route("/dbquery/read", methods=["POST"])
|
||||
def dbquery_read():
|
||||
if not is_authorized():
|
||||
@@ -791,6 +603,7 @@ def dbquery_write():
|
||||
data = request.get_json() or {}
|
||||
raw_sql_b64 = data.get("rawSql")
|
||||
if not raw_sql_b64:
|
||||
|
||||
return jsonify({"success": False, "message": "ERROR: Missing parameters", "error": "rawSql is required"}), 400
|
||||
|
||||
return write_query(raw_sql_b64)
|
||||
@@ -856,11 +669,13 @@ def api_delete_online_history():
|
||||
|
||||
@app.route("/logs", methods=["DELETE"])
|
||||
def api_clean_log():
|
||||
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
|
||||
file = request.args.get("file")
|
||||
if not file:
|
||||
|
||||
return jsonify({"success": False, "message": "ERROR: Missing parameters", "error": "Missing 'file' query parameter"}), 400
|
||||
|
||||
return clean_log(file)
|
||||
@@ -895,8 +710,6 @@ def api_add_to_execution_queue():
|
||||
# --------------------------
|
||||
# Device Events
|
||||
# --------------------------
|
||||
|
||||
|
||||
@app.route("/events/create/<mac>", methods=["POST"])
|
||||
def api_create_event(mac):
|
||||
if not is_authorized():
|
||||
@@ -960,6 +773,44 @@ def api_get_events_totals():
|
||||
return get_events_totals(period)
|
||||
|
||||
|
||||
@app.route('/mcp/sse/events/recent', methods=['GET', 'POST'])
|
||||
@app.route('/events/recent', methods=['GET'])
|
||||
def api_events_default_24h():
|
||||
return api_events_recent(24) # Reuse handler
|
||||
|
||||
|
||||
@app.route('/mcp/sse/events/last', methods=['GET', 'POST'])
|
||||
@app.route('/events/last', methods=['GET'])
|
||||
def get_last_events():
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
|
||||
# Create fresh DB instance for this thread
|
||||
event_handler = EventInstance()
|
||||
|
||||
return event_handler.get_last_n(10)
|
||||
|
||||
|
||||
@app.route('/events/<int:hours>', methods=['GET'])
|
||||
def api_events_recent(hours):
|
||||
"""Return events from the last <hours> hours using EventInstance."""
|
||||
|
||||
if not is_authorized():
|
||||
return jsonify({"success": False, "error": "Unauthorized"}), 401
|
||||
|
||||
# Validate hours input
|
||||
if hours <= 0:
|
||||
return jsonify({"success": False, "error": "Hours must be > 0"}), 400
|
||||
try:
|
||||
# Create fresh DB instance for this thread
|
||||
event_handler = EventInstance()
|
||||
|
||||
events = event_handler.get_by_hours(hours)
|
||||
|
||||
return jsonify({"success": True, "hours": hours, "count": len(events), "events": events}), 200
|
||||
|
||||
except Exception as ex:
|
||||
return jsonify({"success": False, "error": str(ex)}), 500
|
||||
|
||||
# --------------------------
|
||||
# Sessions
|
||||
# --------------------------
|
||||
|
||||
@@ -228,7 +228,8 @@ def devices_totals():
|
||||
|
||||
def devices_by_status(status=None):
|
||||
"""
|
||||
Return devices filtered by status.
|
||||
Return devices filtered by status. Returns all if no status provided.
|
||||
Possible statuses: my, connected, favorites, new, down, archived
|
||||
"""
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
|
||||
207
server/api_server/mcp_endpoint.py
Normal file
207
server/api_server/mcp_endpoint.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import threading
|
||||
from flask import Blueprint, request, jsonify, Response, stream_with_context
|
||||
from helper import get_setting_value
|
||||
from helper import mylog
|
||||
# from .events_endpoint import get_events # will import locally where needed
|
||||
import requests
|
||||
import json
|
||||
import uuid
|
||||
import queue
|
||||
|
||||
# Blueprints
|
||||
mcp_bp = Blueprint('mcp', __name__)
|
||||
tools_bp = Blueprint('tools', __name__)
|
||||
|
||||
mcp_sessions = {}
|
||||
mcp_sessions_lock = threading.Lock()
|
||||
|
||||
|
||||
def check_auth():
|
||||
token = request.headers.get("Authorization")
|
||||
expected_token = f"Bearer {get_setting_value('API_TOKEN')}"
|
||||
return token == expected_token
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Specs
|
||||
# --------------------------
|
||||
def openapi_spec():
|
||||
# Spec matching actual available routes for MCP tools
|
||||
mylog("verbose", ["[MCP] OpenAPI spec requested"])
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "NetAlertX Tools", "version": "1.1.0"},
|
||||
"servers": [{"url": "/"}],
|
||||
"paths": {
|
||||
"/devices/by-status": {"post": {"operationId": "list_devices"}},
|
||||
"/device/{mac}": {"post": {"operationId": "get_device_info"}},
|
||||
"/devices/search": {"post": {"operationId": "search_devices"}},
|
||||
"/devices/latest": {"get": {"operationId": "get_latest_device"}},
|
||||
"/nettools/trigger-scan": {"post": {"operationId": "trigger_scan"}},
|
||||
"/device/open_ports": {"post": {"operationId": "get_open_ports"}},
|
||||
"/devices/network/topology": {"get": {"operationId": "get_network_topology"}},
|
||||
"/events/recent": {"get": {"operationId": "get_recent_alerts"}, "post": {"operationId": "get_recent_alerts"}},
|
||||
"/events/last": {"get": {"operationId": "get_last_events"}, "post": {"operationId": "get_last_events"}},
|
||||
"/device/{mac}/set-alias": {"post": {"operationId": "set_device_alias"}},
|
||||
"/nettools/wakeonlan": {"post": {"operationId": "wol_wake_device"}}
|
||||
}
|
||||
}
|
||||
return jsonify(spec)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# MCP SSE/JSON-RPC Endpoint
|
||||
# --------------------------
|
||||
|
||||
|
||||
# Sessions for SSE
|
||||
_sessions = {}
|
||||
_sessions_lock = __import__('threading').Lock()
|
||||
_openapi_spec_cache = None
|
||||
API_BASE_URL = f"http://localhost:{get_setting_value('GRAPHQL_PORT')}"
|
||||
|
||||
|
||||
def get_openapi_spec():
|
||||
global _openapi_spec_cache
|
||||
# Clear cache on each call for now to ensure fresh spec
|
||||
_openapi_spec_cache = None
|
||||
if _openapi_spec_cache:
|
||||
return _openapi_spec_cache
|
||||
try:
|
||||
r = requests.get(f"{API_BASE_URL}/mcp/openapi.json", timeout=10)
|
||||
r.raise_for_status()
|
||||
_openapi_spec_cache = r.json()
|
||||
return _openapi_spec_cache
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def map_openapi_to_mcp_tools(spec):
|
||||
tools = []
|
||||
if not spec or 'paths' not in spec:
|
||||
return tools
|
||||
for path, methods in spec['paths'].items():
|
||||
for method, details in methods.items():
|
||||
if 'operationId' in details:
|
||||
tool = {'name': details['operationId'], 'description': details.get('description', ''), 'inputSchema': {'type': 'object', 'properties': {}, 'required': []}}
|
||||
if 'requestBody' in details:
|
||||
content = details['requestBody'].get('content', {})
|
||||
if 'application/json' in content:
|
||||
schema = content['application/json'].get('schema', {})
|
||||
tool['inputSchema'] = schema.copy()
|
||||
if 'parameters' in details:
|
||||
for param in details['parameters']:
|
||||
if param.get('in') == 'query':
|
||||
tool['inputSchema']['properties'][param['name']] = {'type': param.get('schema', {}).get('type', 'string'), 'description': param.get('description', '')}
|
||||
if param.get('required'):
|
||||
tool['inputSchema']['required'].append(param['name'])
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
|
||||
def process_mcp_request(data):
|
||||
method = data.get('method')
|
||||
msg_id = data.get('id')
|
||||
if method == 'initialize':
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'result': {'protocolVersion': '2024-11-05', 'capabilities': {'tools': {}}, 'serverInfo': {'name': 'NetAlertX', 'version': '1.0.0'}}}
|
||||
if method == 'notifications/initialized':
|
||||
return None
|
||||
if method == 'tools/list':
|
||||
spec = get_openapi_spec()
|
||||
tools = map_openapi_to_mcp_tools(spec)
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'result': {'tools': tools}}
|
||||
if method == 'tools/call':
|
||||
params = data.get('params', {})
|
||||
tool_name = params.get('name')
|
||||
tool_args = params.get('arguments', {})
|
||||
spec = get_openapi_spec()
|
||||
target_path = None
|
||||
target_method = None
|
||||
if spec and 'paths' in spec:
|
||||
for path, methods in spec['paths'].items():
|
||||
for m, details in methods.items():
|
||||
if details.get('operationId') == tool_name:
|
||||
target_path = path
|
||||
target_method = m.upper()
|
||||
break
|
||||
if target_path:
|
||||
break
|
||||
if not target_path:
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'error': {'code': -32601, 'message': f"Tool {tool_name} not found"}}
|
||||
try:
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if 'Authorization' in request.headers:
|
||||
headers['Authorization'] = request.headers['Authorization']
|
||||
url = f"{API_BASE_URL}{target_path}"
|
||||
if target_method == 'POST':
|
||||
api_res = requests.post(url, json=tool_args, headers=headers, timeout=30)
|
||||
else:
|
||||
api_res = requests.get(url, params=tool_args, headers=headers, timeout=30)
|
||||
content = []
|
||||
try:
|
||||
json_content = api_res.json()
|
||||
content.append({'type': 'text', 'text': json.dumps(json_content, indent=2)})
|
||||
except Exception:
|
||||
content.append({'type': 'text', 'text': api_res.text})
|
||||
is_error = api_res.status_code >= 400
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'result': {'content': content, 'isError': is_error}}
|
||||
except Exception as e:
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'result': {'content': [{'type': 'text', 'text': f"Error calling tool: {str(e)}"}], 'isError': True}}
|
||||
if method == 'ping':
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'result': {}}
|
||||
if msg_id:
|
||||
return {'jsonrpc': '2.0', 'id': msg_id, 'error': {'code': -32601, 'message': 'Method not found'}}
|
||||
|
||||
|
||||
def mcp_messages():
|
||||
session_id = request.args.get('session_id')
|
||||
if not session_id:
|
||||
return jsonify({"error": "Missing session_id"}), 400
|
||||
with mcp_sessions_lock:
|
||||
if session_id not in mcp_sessions:
|
||||
return jsonify({"error": "Session not found"}), 404
|
||||
q = mcp_sessions[session_id]
|
||||
data = request.json
|
||||
if not data:
|
||||
return jsonify({"error": "Invalid JSON"}), 400
|
||||
response = process_mcp_request(data)
|
||||
if response:
|
||||
q.put(response)
|
||||
return jsonify({"status": "accepted"}), 202
|
||||
|
||||
|
||||
def mcp_sse():
|
||||
if request.method == 'POST':
|
||||
try:
|
||||
data = request.get_json(silent=True)
|
||||
if data and 'method' in data and 'jsonrpc' in data:
|
||||
response = process_mcp_request(data)
|
||||
if response:
|
||||
return jsonify(response)
|
||||
else:
|
||||
return '', 202
|
||||
except Exception as e:
|
||||
mylog("none", f'SSE POST processing error: {e}')
|
||||
return jsonify({'status': 'ok', 'message': 'MCP SSE endpoint active'}), 200
|
||||
|
||||
session_id = uuid.uuid4().hex
|
||||
q = queue.Queue()
|
||||
with mcp_sessions_lock:
|
||||
mcp_sessions[session_id] = q
|
||||
|
||||
def stream():
|
||||
yield f"event: endpoint\ndata: /mcp/messages?session_id={session_id}\n\n"
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = q.get(timeout=20)
|
||||
yield f"event: message\ndata: {json.dumps(message)}\n\n"
|
||||
except queue.Empty:
|
||||
yield ": keep-alive\n\n"
|
||||
except GeneratorExit:
|
||||
with mcp_sessions_lock:
|
||||
if session_id in mcp_sessions:
|
||||
del mcp_sessions[session_id]
|
||||
return Response(stream_with_context(stream()), mimetype='text/event-stream')
|
||||
@@ -1,686 +0,0 @@
|
||||
import subprocess
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, request, jsonify
|
||||
import sqlite3
|
||||
from helper import get_setting_value
|
||||
from database import get_temp_db_connection
|
||||
|
||||
tools_bp = Blueprint('tools', __name__)
|
||||
|
||||
|
||||
def check_auth():
|
||||
"""Check API_TOKEN authorization."""
|
||||
token = request.headers.get("Authorization")
|
||||
expected_token = f"Bearer {get_setting_value('API_TOKEN')}"
|
||||
return token == expected_token
|
||||
|
||||
|
||||
@tools_bp.route('/trigger_scan', methods=['POST'])
|
||||
def trigger_scan():
|
||||
"""
|
||||
Forces NetAlertX to run a specific scan type immediately.
|
||||
Arguments: scan_type (Enum: arp, nmap_fast, nmap_deep), target (optional IP/CIDR)
|
||||
"""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json()
|
||||
scan_type = data.get('scan_type', 'nmap_fast')
|
||||
target = data.get('target')
|
||||
|
||||
# Validate scan_type
|
||||
if scan_type not in ['arp', 'nmap_fast', 'nmap_deep']:
|
||||
return jsonify({"error": "Invalid scan_type. Must be 'arp', 'nmap_fast', or 'nmap_deep'"}), 400
|
||||
|
||||
# Determine command
|
||||
cmd = []
|
||||
if scan_type == 'arp':
|
||||
# ARP scan usually requires sudo or root, assuming container runs as root or has caps
|
||||
cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection
|
||||
if target:
|
||||
cmd = ["arp-scan", target]
|
||||
elif scan_type == 'nmap_fast':
|
||||
cmd = ["nmap", "-F"]
|
||||
if target:
|
||||
cmd.append(target)
|
||||
else:
|
||||
# Default to local subnet if possible, or error if not easily determined
|
||||
# For now, let's require target for nmap if not easily deducible,
|
||||
# or try to get it from settings.
|
||||
# NetAlertX usually knows its subnet.
|
||||
# Let's try to get the scan subnet from settings if not provided.
|
||||
scan_subnets = get_setting_value("SCAN_SUBNETS")
|
||||
if scan_subnets:
|
||||
# Take the first one for now
|
||||
cmd.append(scan_subnets.split(',')[0].strip())
|
||||
else:
|
||||
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
|
||||
elif scan_type == 'nmap_deep':
|
||||
cmd = ["nmap", "-A", "-T4"]
|
||||
if target:
|
||||
cmd.append(target)
|
||||
else:
|
||||
scan_subnets = get_setting_value("SCAN_SUBNETS")
|
||||
if scan_subnets:
|
||||
cmd.append(scan_subnets.split(',')[0].strip())
|
||||
else:
|
||||
return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400
|
||||
|
||||
try:
|
||||
# Run the command
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"scan_type": scan_type,
|
||||
"command": " ".join(cmd),
|
||||
"output": result.stdout.strip().split('\n')
|
||||
})
|
||||
except subprocess.CalledProcessError as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "Scan failed",
|
||||
"details": e.stderr.strip()
|
||||
}), 500
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@tools_bp.route('/list_devices', methods=['POST'])
|
||||
def list_devices():
|
||||
"""List all devices."""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
cur.execute("SELECT devName, devMac, devLastIP as devIP, devVendor, devFirstConnection, devLastConnection FROM Devices ORDER BY devFirstConnection DESC")
|
||||
rows = cur.fetchall()
|
||||
devices = [dict(row) for row in rows]
|
||||
return jsonify(devices)
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@tools_bp.route('/get_device_info', methods=['POST'])
|
||||
def get_device_info():
|
||||
"""Get detailed info for a specific device."""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json()
|
||||
if not data or 'query' not in data:
|
||||
return jsonify({"error": "Missing 'query' parameter"}), 400
|
||||
|
||||
query = data['query']
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
# Search by MAC, Name, or partial IP
|
||||
sql = "SELECT * FROM Devices WHERE devMac LIKE ? OR devName LIKE ? OR devLastIP LIKE ?"
|
||||
cur.execute(sql, (f"%{query}%", f"%{query}%", f"%{query}%"))
|
||||
rows = cur.fetchall()
|
||||
|
||||
if not rows:
|
||||
return jsonify({"message": "No devices found"}), 404
|
||||
|
||||
devices = [dict(row) for row in rows]
|
||||
return jsonify(devices)
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@tools_bp.route('/get_latest_device', methods=['POST'])
|
||||
def get_latest_device():
|
||||
"""Get full details of the most recently discovered device."""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
# Get the device with the most recent devFirstConnection
|
||||
cur.execute("SELECT * FROM Devices ORDER BY devFirstConnection DESC LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return jsonify({"message": "No devices found"}), 404
|
||||
|
||||
# Return as a list to be consistent with other endpoints
|
||||
return jsonify([dict(row)])
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@tools_bp.route('/get_open_ports', methods=['POST'])
|
||||
def get_open_ports():
|
||||
"""
|
||||
Specific query for the port-scan results of a target.
|
||||
Arguments: target (IP or MAC)
|
||||
"""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json()
|
||||
target = data.get('target')
|
||||
|
||||
if not target:
|
||||
return jsonify({"error": "Target is required"}), 400
|
||||
|
||||
# If MAC is provided, try to resolve to IP
|
||||
if re.match(r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", target):
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute("SELECT devLastIP FROM Devices WHERE devMac = ?", (target,))
|
||||
row = cur.fetchone()
|
||||
if row and row['devLastIP']:
|
||||
target = row['devLastIP']
|
||||
else:
|
||||
return jsonify({"error": f"Could not resolve IP for MAC {target}"}), 404
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
try:
|
||||
# Run nmap -F for fast port scan
|
||||
cmd = ["nmap", "-F", target]
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
# Parse output for open ports
|
||||
open_ports = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if '/tcp' in line and 'open' in line:
|
||||
parts = line.split('/')
|
||||
port = parts[0].strip()
|
||||
service = line.split()[2] if len(line.split()) > 2 else "unknown"
|
||||
open_ports.append({"port": int(port), "service": service})
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"target": target,
|
||||
"open_ports": open_ports,
|
||||
"raw_output": result.stdout.strip().split('\n')
|
||||
})
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
return jsonify({"success": False, "error": "Port scan failed", "details": e.stderr.strip()}), 500
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@tools_bp.route('/get_network_topology', methods=['GET'])
|
||||
def get_network_topology():
|
||||
"""
|
||||
Returns the "Parent/Child" relationships.
|
||||
"""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
|
||||
rows = cur.fetchall()
|
||||
|
||||
nodes = []
|
||||
links = []
|
||||
|
||||
for row in rows:
|
||||
nodes.append({
|
||||
"id": row['devMac'],
|
||||
"name": row['devName'],
|
||||
"vendor": row['devVendor']
|
||||
})
|
||||
if row['devParentMAC']:
|
||||
links.append({
|
||||
"source": row['devParentMAC'],
|
||||
"target": row['devMac'],
|
||||
"port": row['devParentPort']
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
"nodes": nodes,
|
||||
"links": links
|
||||
})
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@tools_bp.route('/get_recent_alerts', methods=['POST'])
|
||||
def get_recent_alerts():
|
||||
"""
|
||||
Fetches the last N system alerts.
|
||||
Arguments: hours (lookback period, default 24)
|
||||
"""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json()
|
||||
hours = data.get('hours', 24)
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
# Calculate cutoff time
|
||||
cutoff = datetime.now() - timedelta(hours=int(hours))
|
||||
cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
cur.execute("""
|
||||
SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName
|
||||
FROM Events
|
||||
LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac
|
||||
WHERE eve_DateTime > ?
|
||||
ORDER BY eve_DateTime DESC
|
||||
""", (cutoff_str,))
|
||||
|
||||
rows = cur.fetchall()
|
||||
alerts = [dict(row) for row in rows]
|
||||
|
||||
return jsonify(alerts)
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@tools_bp.route('/set_device_alias', methods=['POST'])
|
||||
def set_device_alias():
|
||||
"""
|
||||
Updates the name (alias) of a device.
|
||||
Arguments: mac, alias
|
||||
"""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json()
|
||||
mac = data.get('mac')
|
||||
alias = data.get('alias')
|
||||
|
||||
if not mac or not alias:
|
||||
return jsonify({"error": "MAC and Alias are required"}), 400
|
||||
|
||||
conn = get_temp_db_connection()
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac))
|
||||
conn.commit()
|
||||
|
||||
if cur.rowcount == 0:
|
||||
return jsonify({"error": "Device not found"}), 404
|
||||
|
||||
return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"})
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
@tools_bp.route('/wol_wake_device', methods=['POST'])
|
||||
def wol_wake_device():
|
||||
"""
|
||||
Sends a Wake-on-LAN magic packet.
|
||||
Arguments: mac OR ip
|
||||
"""
|
||||
if not check_auth():
|
||||
return jsonify({"error": "Unauthorized"}), 401
|
||||
|
||||
data = request.get_json()
|
||||
mac = data.get('mac')
|
||||
ip = data.get('ip')
|
||||
|
||||
if not mac and not ip:
|
||||
return jsonify({"error": "MAC address or IP address is required"}), 400
|
||||
|
||||
# Resolve IP to MAC if MAC is missing
|
||||
if not mac and ip:
|
||||
conn = get_temp_db_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
# Try to find device by IP (devLastIP)
|
||||
cur.execute("SELECT devMac FROM Devices WHERE devLastIP = ?", (ip,))
|
||||
row = cur.fetchone()
|
||||
if row and row['devMac']:
|
||||
mac = row['devMac']
|
||||
else:
|
||||
return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"Database error: {str(e)}"}), 500
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# Validate MAC
|
||||
if not re.match(r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", mac):
|
||||
return jsonify({"success": False, "error": f"Invalid MAC: {mac}"}), 400
|
||||
|
||||
try:
|
||||
# Using wakeonlan command
|
||||
result = subprocess.run(
|
||||
["wakeonlan", mac], capture_output=True, text=True, check=True, timeout=10
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"WOL packet sent to {mac}",
|
||||
"output": result.stdout.strip(),
|
||||
}
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
return jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Failed to send WOL packet",
|
||||
"details": e.stderr.strip(),
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
@tools_bp.route('/openapi.json', methods=['GET'])
|
||||
def openapi_spec():
|
||||
"""Return OpenAPI specification for tools."""
|
||||
# No auth required for spec to allow easy import, or require it if preferred.
|
||||
# Open WebUI usually needs to fetch spec without auth first or handles it.
|
||||
# We'll allow public access to spec for simplicity of import.
|
||||
|
||||
spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "NetAlertX Tools",
|
||||
"description": "API for NetAlertX device management tools",
|
||||
"version": "1.1.0"
|
||||
},
|
||||
"servers": [
|
||||
{"url": "/api/tools"}
|
||||
],
|
||||
"paths": {
|
||||
"/list_devices": {
|
||||
"post": {
|
||||
"summary": "List all devices (Summary)",
|
||||
"description": (
|
||||
"Retrieve a SUMMARY list of all devices, sorted by newest first. "
|
||||
"IMPORTANT: This only provides basic info (Name, IP, Vendor). "
|
||||
"For FULL details (like custom props, alerts, etc.), you MUST use 'get_device_info' or 'get_latest_device'."
|
||||
),
|
||||
"operationId": "list_devices",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "List of devices (Summary)",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"devName": {"type": "string"},
|
||||
"devMac": {"type": "string"},
|
||||
"devIP": {"type": "string"},
|
||||
"devVendor": {"type": "string"},
|
||||
"devStatus": {"type": "string"},
|
||||
"devFirstConnection": {"type": "string"},
|
||||
"devLastConnection": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/get_device_info": {
|
||||
"post": {
|
||||
"summary": "Get device info (Full Details)",
|
||||
"description": (
|
||||
"Get COMPREHENSIVE information about a specific device by MAC, Name, or partial IP. "
|
||||
"Use this to see all available properties, alerts, and metadata not shown in the list."
|
||||
),
|
||||
"operationId": "get_device_info",
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "MAC address, Device Name, or partial IP to search for"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Device details (Full)",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {"description": "Device not found"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/get_latest_device": {
|
||||
"post": {
|
||||
"summary": "Get latest device (Full Details)",
|
||||
"description": "Get COMPREHENSIVE information about the most recently discovered device (latest devFirstConnection).",
|
||||
"operationId": "get_latest_device",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Latest device details (Full)",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {"description": "No devices found"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/trigger_scan": {
|
||||
"post": {
|
||||
"summary": "Trigger Active Scan",
|
||||
"description": "Forces NetAlertX to run a specific scan type immediately.",
|
||||
"operationId": "trigger_scan",
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"scan_type": {
|
||||
"type": "string",
|
||||
"enum": ["arp", "nmap_fast", "nmap_deep"],
|
||||
"default": "nmap_fast"
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "IP address or CIDR to scan"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {"description": "Scan started/completed successfully"},
|
||||
"400": {"description": "Invalid input"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/get_open_ports": {
|
||||
"post": {
|
||||
"summary": "Get Open Ports",
|
||||
"description": "Specific query for the port-scan results of a target.",
|
||||
"operationId": "get_open_ports",
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "IP or MAC address"
|
||||
}
|
||||
},
|
||||
"required": ["target"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {"description": "List of open ports"},
|
||||
"404": {"description": "Target not found"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/get_network_topology": {
|
||||
"get": {
|
||||
"summary": "Get Network Topology",
|
||||
"description": "Returns the Parent/Child relationships for network visualization.",
|
||||
"operationId": "get_network_topology",
|
||||
"responses": {
|
||||
"200": {"description": "Graph data (nodes and links)"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/get_recent_alerts": {
|
||||
"post": {
|
||||
"summary": "Get Recent Alerts",
|
||||
"description": "Fetches the last N system alerts.",
|
||||
"operationId": "get_recent_alerts",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"hours": {
|
||||
"type": "integer",
|
||||
"default": 24
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {"description": "List of alerts"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/set_device_alias": {
|
||||
"post": {
|
||||
"summary": "Set Device Alias",
|
||||
"description": "Updates the name (alias) of a device.",
|
||||
"operationId": "set_device_alias",
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mac": {"type": "string"},
|
||||
"alias": {"type": "string"}
|
||||
},
|
||||
"required": ["mac", "alias"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {"description": "Alias updated"},
|
||||
"404": {"description": "Device not found"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/wol_wake_device": {
|
||||
"post": {
|
||||
"summary": "Wake on LAN",
|
||||
"description": "Sends a Wake-on-LAN magic packet to the target MAC or IP. If IP is provided, it resolves to MAC first.",
|
||||
"operationId": "wol_wake_device",
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mac": {"type": "string", "description": "Target MAC address"},
|
||||
"ip": {"type": "string", "description": "Target IP address (resolves to MAC)"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {"description": "WOL packet sent"},
|
||||
"404": {"description": "IP not found"}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"securitySchemes": {
|
||||
"bearerAuth": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{"bearerAuth": []}
|
||||
]
|
||||
}
|
||||
return jsonify(spec)
|
||||
@@ -1,121 +1,134 @@
|
||||
from front.plugins.plugin_helper import is_mac
|
||||
from logger import mylog
|
||||
from models.plugin_object_instance import PluginObjectInstance
|
||||
from database import get_temp_db_connection
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Device object handling (WIP)
|
||||
# -------------------------------------------------------------------------------
|
||||
class DeviceInstance:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
# Get all
|
||||
def getAll(self):
|
||||
self.db.sql.execute("""
|
||||
SELECT * FROM Devices
|
||||
""")
|
||||
return self.db.sql.fetchall()
|
||||
|
||||
# Get all with unknown names
|
||||
def getUnknown(self):
|
||||
self.db.sql.execute("""
|
||||
SELECT * FROM Devices WHERE devName in ("(unknown)", "(name not found)", "" )
|
||||
""")
|
||||
return self.db.sql.fetchall()
|
||||
|
||||
# Get specific column value based on devMac
|
||||
def getValueWithMac(self, column_name, devMac):
|
||||
query = f"SELECT {column_name} FROM Devices WHERE devMac = ?"
|
||||
self.db.sql.execute(query, (devMac,))
|
||||
result = self.db.sql.fetchone()
|
||||
return result[column_name] if result else None
|
||||
|
||||
# Get all down
|
||||
def getDown(self):
|
||||
self.db.sql.execute("""
|
||||
SELECT * FROM Devices WHERE devAlertDown = 1 and devPresentLastScan = 0
|
||||
""")
|
||||
return self.db.sql.fetchall()
|
||||
|
||||
# Get all down
|
||||
def getOffline(self):
|
||||
self.db.sql.execute("""
|
||||
SELECT * FROM Devices WHERE devPresentLastScan = 0
|
||||
""")
|
||||
return self.db.sql.fetchall()
|
||||
|
||||
# Get a device by devGUID
|
||||
def getByGUID(self, devGUID):
|
||||
self.db.sql.execute("SELECT * FROM Devices WHERE devGUID = ?", (devGUID,))
|
||||
result = self.db.sql.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
# Check if a device exists by devGUID
|
||||
def exists(self, devGUID):
|
||||
self.db.sql.execute(
|
||||
"SELECT COUNT(*) AS count FROM Devices WHERE devGUID = ?", (devGUID,)
|
||||
)
|
||||
result = self.db.sql.fetchone()
|
||||
return result["count"] > 0
|
||||
|
||||
# Get a device by its last IP address
|
||||
def getByIP(self, ip):
|
||||
self.db.sql.execute("SELECT * FROM Devices WHERE devLastIP = ?", (ip,))
|
||||
row = self.db.sql.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
# Search devices by partial mac, name or IP
|
||||
def search(self, query):
|
||||
like = f"%{query}%"
|
||||
self.db.sql.execute(
|
||||
"SELECT * FROM Devices WHERE devMac LIKE ? OR devName LIKE ? OR devLastIP LIKE ?",
|
||||
(like, like, like),
|
||||
)
|
||||
rows = self.db.sql.fetchall()
|
||||
# --- helpers --------------------------------------------------------------
|
||||
def _fetchall(self, query, params=()):
|
||||
conn = get_temp_db_connection()
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
conn.close()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
# Get the most recently discovered device
|
||||
def getLatest(self):
|
||||
self.db.sql.execute("SELECT * FROM Devices ORDER BY devFirstConnection DESC LIMIT 1")
|
||||
row = self.db.sql.fetchone()
|
||||
def _fetchone(self, query, params=()):
|
||||
conn = get_temp_db_connection()
|
||||
row = conn.execute(query, params).fetchone()
|
||||
conn.close()
|
||||
return dict(row) if row else None
|
||||
|
||||
def getNetworkTopology(self):
|
||||
"""Returns nodes and links for the current Devices table.
|
||||
def _execute(self, query, params=()):
|
||||
conn = get_temp_db_connection()
|
||||
cur = conn.cursor()
|
||||
cur.execute(query, params)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
Nodes: {id, name, vendor}
|
||||
Links: {source, target, port}
|
||||
"""
|
||||
self.db.sql.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices")
|
||||
rows = self.db.sql.fetchall()
|
||||
nodes = []
|
||||
links = []
|
||||
for row in rows:
|
||||
nodes.append({"id": row['devMac'], "name": row['devName'], "vendor": row['devVendor']})
|
||||
if row['devParentMAC']:
|
||||
links.append({"source": row['devParentMAC'], "target": row['devMac'], "port": row['devParentPort']})
|
||||
# --- public API -----------------------------------------------------------
|
||||
def getAll(self):
|
||||
return self._fetchall("SELECT * FROM Devices")
|
||||
|
||||
def getUnknown(self):
|
||||
return self._fetchall("""
|
||||
SELECT * FROM Devices
|
||||
WHERE devName IN ("(unknown)", "(name not found)", "")
|
||||
""")
|
||||
|
||||
def getValueWithMac(self, column_name, devMac):
|
||||
row = self._fetchone(f"""
|
||||
SELECT {column_name} FROM Devices WHERE devMac = ?
|
||||
""", (devMac,))
|
||||
return row.get(column_name) if row else None
|
||||
|
||||
def getDown(self):
|
||||
return self._fetchall("""
|
||||
SELECT * FROM Devices
|
||||
WHERE devAlertDown = 1 AND devPresentLastScan = 0
|
||||
""")
|
||||
|
||||
def getOffline(self):
|
||||
return self._fetchall("""
|
||||
SELECT * FROM Devices
|
||||
WHERE devPresentLastScan = 0
|
||||
""")
|
||||
|
||||
def getByGUID(self, devGUID):
|
||||
return self._fetchone("""
|
||||
SELECT * FROM Devices WHERE devGUID = ?
|
||||
""", (devGUID,))
|
||||
|
||||
def exists(self, devGUID):
|
||||
row = self._fetchone("""
|
||||
SELECT COUNT(*) as count FROM Devices WHERE devGUID = ?
|
||||
""", (devGUID,))
|
||||
return row['count'] > 0 if row else False
|
||||
|
||||
def getByIP(self, ip):
|
||||
return self._fetchone("""
|
||||
SELECT * FROM Devices WHERE devLastIP = ?
|
||||
""", (ip,))
|
||||
|
||||
def search(self, query):
|
||||
like = f"%{query}%"
|
||||
return self._fetchall("""
|
||||
SELECT * FROM Devices
|
||||
WHERE devMac LIKE ? OR devName LIKE ? OR devLastIP LIKE ?
|
||||
""", (like, like, like))
|
||||
|
||||
def getLatest(self):
|
||||
return self._fetchone("""
|
||||
SELECT * FROM Devices
|
||||
ORDER BY devFirstConnection DESC LIMIT 1
|
||||
""")
|
||||
|
||||
def getNetworkTopology(self):
|
||||
rows = self._fetchall("""
|
||||
SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices
|
||||
""")
|
||||
nodes = [{"id": r["devMac"], "name": r["devName"], "vendor": r["devVendor"]} for r in rows]
|
||||
links = [{"source": r["devParentMAC"], "target": r["devMac"], "port": r["devParentPort"]}
|
||||
for r in rows if r["devParentMAC"]]
|
||||
return {"nodes": nodes, "links": links}
|
||||
|
||||
# Update a specific field for a device
|
||||
def updateField(self, devGUID, field, value):
|
||||
if not self.exists(devGUID):
|
||||
m = f"[Device] In 'updateField': GUID {devGUID} not found."
|
||||
mylog("none", m)
|
||||
raise ValueError(m)
|
||||
msg = f"[Device] updateField: GUID {devGUID} not found"
|
||||
mylog("none", msg)
|
||||
raise ValueError(msg)
|
||||
self._execute(f"UPDATE Devices SET {field}=? WHERE devGUID=?", (value, devGUID))
|
||||
|
||||
self.db.sql.execute(
|
||||
f"""
|
||||
UPDATE Devices SET {field} = ? WHERE devGUID = ?
|
||||
""",
|
||||
(value, devGUID),
|
||||
)
|
||||
self.db.commitDB()
|
||||
|
||||
# Delete a device by devGUID
|
||||
def delete(self, devGUID):
|
||||
if not self.exists(devGUID):
|
||||
m = f"[Device] In 'delete': GUID {devGUID} not found."
|
||||
mylog("none", m)
|
||||
raise ValueError(m)
|
||||
msg = f"[Device] delete: GUID {devGUID} not found"
|
||||
mylog("none", msg)
|
||||
raise ValueError(msg)
|
||||
self._execute("DELETE FROM Devices WHERE devGUID=?", (devGUID,))
|
||||
|
||||
self.db.sql.execute("DELETE FROM Devices WHERE devGUID = ?", (devGUID,))
|
||||
self.db.commitDB()
|
||||
def resolvePrimaryID(self, target):
|
||||
if is_mac(target):
|
||||
return target.lower()
|
||||
dev = self.getByIP(target)
|
||||
return dev['devMac'].lower() if dev else None
|
||||
|
||||
def getOpenPorts(self, target):
|
||||
primary = self.resolvePrimaryID(target)
|
||||
if not primary:
|
||||
return []
|
||||
|
||||
objs = PluginObjectInstance().getByField(
|
||||
plugPrefix='NMAP',
|
||||
matchedColumn='Object_PrimaryID',
|
||||
matchedKey=primary,
|
||||
returnFields=['Object_SecondaryID', 'Watched_Value2']
|
||||
)
|
||||
|
||||
ports = []
|
||||
for o in objs:
|
||||
|
||||
port = int(o.get('Object_SecondaryID') or 0)
|
||||
|
||||
ports.append({"port": port, "service": o.get('Watched_Value2', '')})
|
||||
|
||||
return ports
|
||||
|
||||
106
server/models/event_instance.py
Normal file
106
server/models/event_instance.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from datetime import datetime, timedelta
|
||||
from logger import mylog
|
||||
from database import get_temp_db_connection
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Event handling (Matches table: Events)
|
||||
# -------------------------------------------------------------------------------
|
||||
class EventInstance:
|
||||
|
||||
def _conn(self):
|
||||
"""Always return a new DB connection (thread-safe)."""
|
||||
return get_temp_db_connection()
|
||||
|
||||
def _rows_to_list(self, rows):
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
# Get all events
|
||||
def get_all(self):
|
||||
conn = self._conn()
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM Events ORDER BY eve_DateTime DESC"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
return self._rows_to_list(rows)
|
||||
|
||||
# --- Get last n events ---
|
||||
def get_last_n(self, n=10):
|
||||
conn = self._conn()
|
||||
rows = conn.execute("""
|
||||
SELECT * FROM Events
|
||||
ORDER BY eve_DateTime DESC
|
||||
LIMIT ?
|
||||
""", (n,)).fetchall()
|
||||
return self._rows_to_list(rows)
|
||||
|
||||
# --- Specific helper for last 10 ---
|
||||
def get_last(self):
|
||||
return self.get_last_n(10)
|
||||
|
||||
# Get events in the last 24h
|
||||
def get_recent(self):
|
||||
since = datetime.now() - timedelta(hours=24)
|
||||
conn = self._conn()
|
||||
rows = conn.execute("""
|
||||
SELECT * FROM Events
|
||||
WHERE eve_DateTime >= ?
|
||||
ORDER BY eve_DateTime DESC
|
||||
""", (since,)).fetchall()
|
||||
conn.close()
|
||||
return self._rows_to_list(rows)
|
||||
|
||||
# Get events from last N hours
|
||||
def get_by_hours(self, hours: int):
|
||||
if hours <= 0:
|
||||
mylog("warn", f"[Events] get_by_hours({hours}) -> invalid value")
|
||||
return []
|
||||
|
||||
since = datetime.now() - timedelta(hours=hours)
|
||||
conn = self._conn()
|
||||
rows = conn.execute("""
|
||||
SELECT * FROM Events
|
||||
WHERE eve_DateTime >= ?
|
||||
ORDER BY eve_DateTime DESC
|
||||
""", (since,)).fetchall()
|
||||
conn.close()
|
||||
return self._rows_to_list(rows)
|
||||
|
||||
# Get events in a date range
|
||||
def get_by_range(self, start: datetime, end: datetime):
|
||||
if end < start:
|
||||
mylog("error", f"[Events] get_by_range invalid: {start} > {end}")
|
||||
raise ValueError("Start must not be after end")
|
||||
|
||||
conn = self._conn()
|
||||
rows = conn.execute("""
|
||||
SELECT * FROM Events
|
||||
WHERE eve_DateTime BETWEEN ? AND ?
|
||||
ORDER BY eve_DateTime DESC
|
||||
""", (start, end)).fetchall()
|
||||
conn.close()
|
||||
return self._rows_to_list(rows)
|
||||
|
||||
# Insert new event
|
||||
def add(self, mac, ip, eventType, info="", pendingAlert=True, pairRow=None):
|
||||
conn = self._conn()
|
||||
conn.execute("""
|
||||
INSERT INTO Events (
|
||||
eve_MAC, eve_IP, eve_DateTime,
|
||||
eve_EventType, eve_AdditionalInfo,
|
||||
eve_PendingAlertEmail, eve_PairEventRowid
|
||||
) VALUES (?,?,?,?,?,?,?)
|
||||
""", (mac, ip, datetime.now(), eventType, info,
|
||||
1 if pendingAlert else 0, pairRow))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Delete old events
|
||||
def delete_older_than(self, days: int):
|
||||
cutoff = datetime.now() - timedelta(days=days)
|
||||
conn = self._conn()
|
||||
result = conn.execute("DELETE FROM Events WHERE eve_DateTime < ?", (cutoff,))
|
||||
conn.commit()
|
||||
deleted_count = result.rowcount
|
||||
conn.close()
|
||||
return deleted_count
|
||||
@@ -1,79 +1,91 @@
|
||||
from logger import mylog
|
||||
from database import get_temp_db_connection
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Plugin object handling (WIP)
|
||||
# Plugin object handling (THREAD-SAFE REWRITE)
|
||||
# -------------------------------------------------------------------------------
|
||||
class PluginObjectInstance:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
# Get all plugin objects
|
||||
def getAll(self):
|
||||
self.db.sql.execute("""
|
||||
SELECT * FROM Plugins_Objects
|
||||
""")
|
||||
return self.db.sql.fetchall()
|
||||
|
||||
# Get plugin object by ObjectGUID
|
||||
def getByGUID(self, ObjectGUID):
|
||||
self.db.sql.execute(
|
||||
"SELECT * FROM Plugins_Objects WHERE ObjectGUID = ?", (ObjectGUID,)
|
||||
)
|
||||
result = self.db.sql.fetchone()
|
||||
return dict(result) if result else None
|
||||
|
||||
# Check if a plugin object exists by ObjectGUID
|
||||
def exists(self, ObjectGUID):
|
||||
self.db.sql.execute(
|
||||
"SELECT COUNT(*) AS count FROM Plugins_Objects WHERE ObjectGUID = ?",
|
||||
(ObjectGUID,),
|
||||
)
|
||||
result = self.db.sql.fetchone()
|
||||
return result["count"] > 0
|
||||
|
||||
# Get objects by plugin name
|
||||
def getByPlugin(self, plugin):
|
||||
self.db.sql.execute("SELECT * FROM Plugins_Objects WHERE Plugin = ?", (plugin,))
|
||||
return self.db.sql.fetchall()
|
||||
|
||||
# Get plugin objects by primary ID and plugin name
|
||||
def getByPrimary(self, plugin, primary_id):
|
||||
self.db.sql.execute(
|
||||
"SELECT * FROM Plugins_Objects WHERE Plugin = ? AND Object_PrimaryID = ?",
|
||||
(plugin, primary_id),
|
||||
)
|
||||
rows = self.db.sql.fetchall()
|
||||
# -------------- Internal DB helper wrappers --------------------------------
|
||||
def _fetchall(self, query, params=()):
|
||||
conn = get_temp_db_connection()
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
conn.close()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
# Get objects by status
|
||||
def getByStatus(self, status):
|
||||
self.db.sql.execute("SELECT * FROM Plugins_Objects WHERE Status = ?", (status,))
|
||||
return self.db.sql.fetchall()
|
||||
def _fetchone(self, query, params=()):
|
||||
conn = get_temp_db_connection()
|
||||
row = conn.execute(query, params).fetchone()
|
||||
conn.close()
|
||||
return dict(row) if row else None
|
||||
|
||||
def _execute(self, query, params=()):
|
||||
conn = get_temp_db_connection()
|
||||
conn.execute(query, params)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — identical behaviour, now thread-safe + self-contained
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def getAll(self):
|
||||
return self._fetchall("SELECT * FROM Plugins_Objects")
|
||||
|
||||
def getByGUID(self, ObjectGUID):
|
||||
return self._fetchone(
|
||||
"SELECT * FROM Plugins_Objects WHERE ObjectGUID = ?", (ObjectGUID,)
|
||||
)
|
||||
|
||||
def exists(self, ObjectGUID):
|
||||
row = self._fetchone("""
|
||||
SELECT COUNT(*) AS count FROM Plugins_Objects WHERE ObjectGUID = ?
|
||||
""", (ObjectGUID,))
|
||||
return row["count"] > 0 if row else False
|
||||
|
||||
def getByPlugin(self, plugin):
|
||||
return self._fetchall(
|
||||
"SELECT * FROM Plugins_Objects WHERE Plugin = ?", (plugin,)
|
||||
)
|
||||
|
||||
def getByField(self, plugPrefix, matchedColumn, matchedKey, returnFields=None):
|
||||
rows = self._fetchall(
|
||||
f"SELECT * FROM Plugins_Objects WHERE Plugin = ? AND {matchedColumn} = ?",
|
||||
(plugPrefix, matchedKey.lower())
|
||||
)
|
||||
|
||||
if not returnFields:
|
||||
return rows
|
||||
|
||||
return [{f: row.get(f) for f in returnFields} for row in rows]
|
||||
|
||||
def getByPrimary(self, plugin, primary_id):
|
||||
return self._fetchall("""
|
||||
SELECT * FROM Plugins_Objects
|
||||
WHERE Plugin = ? AND Object_PrimaryID = ?
|
||||
""", (plugin, primary_id))
|
||||
|
||||
def getByStatus(self, status):
|
||||
return self._fetchall("""
|
||||
SELECT * FROM Plugins_Objects WHERE Status = ?
|
||||
""", (status,))
|
||||
|
||||
# Update a specific field for a plugin object
|
||||
def updateField(self, ObjectGUID, field, value):
|
||||
if not self.exists(ObjectGUID):
|
||||
m = f"[PluginObject] In 'updateField': GUID {ObjectGUID} not found."
|
||||
mylog("none", m)
|
||||
raise ValueError(m)
|
||||
msg = f"[PluginObject] updateField: GUID {ObjectGUID} not found."
|
||||
mylog("none", msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.db.sql.execute(
|
||||
f"""
|
||||
UPDATE Plugins_Objects SET {field} = ? WHERE ObjectGUID = ?
|
||||
""",
|
||||
(value, ObjectGUID),
|
||||
self._execute(
|
||||
f"UPDATE Plugins_Objects SET {field}=? WHERE ObjectGUID=?",
|
||||
(value, ObjectGUID)
|
||||
)
|
||||
self.db.commitDB()
|
||||
|
||||
# Delete a plugin object by ObjectGUID
|
||||
def delete(self, ObjectGUID):
|
||||
if not self.exists(ObjectGUID):
|
||||
m = f"[PluginObject] In 'delete': GUID {ObjectGUID} not found."
|
||||
mylog("none", m)
|
||||
raise ValueError(m)
|
||||
msg = f"[PluginObject] delete: GUID {ObjectGUID} not found."
|
||||
mylog("none", msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.db.sql.execute(
|
||||
"DELETE FROM Plugins_Objects WHERE ObjectGUID = ?", (ObjectGUID,)
|
||||
)
|
||||
self.db.commitDB()
|
||||
self._execute("DELETE FROM Plugins_Objects WHERE ObjectGUID=?", (ObjectGUID,))
|
||||
|
||||
@@ -650,7 +650,7 @@ def update_devices_names(pm):
|
||||
|
||||
sql = pm.db.sql
|
||||
resolver = NameResolver(pm.db)
|
||||
device_handler = DeviceInstance(pm.db)
|
||||
device_handler = DeviceInstance()
|
||||
|
||||
nameNotFound = "(name not found)"
|
||||
|
||||
|
||||
@@ -42,13 +42,13 @@ class UpdateFieldAction(Action):
|
||||
# currently unused
|
||||
if isinstance(obj, dict) and "ObjectGUID" in obj:
|
||||
mylog("debug", f"[WF] Updating Object '{obj}' ")
|
||||
plugin_instance = PluginObjectInstance(self.db)
|
||||
plugin_instance = PluginObjectInstance()
|
||||
plugin_instance.updateField(obj["ObjectGUID"], self.field, self.value)
|
||||
processed = True
|
||||
|
||||
elif isinstance(obj, dict) and "devGUID" in obj:
|
||||
mylog("debug", f"[WF] Updating Device '{obj}' ")
|
||||
device_instance = DeviceInstance(self.db)
|
||||
device_instance = DeviceInstance()
|
||||
device_instance.updateField(obj["devGUID"], self.field, self.value)
|
||||
processed = True
|
||||
|
||||
@@ -79,13 +79,13 @@ class DeleteObjectAction(Action):
|
||||
# currently unused
|
||||
if isinstance(obj, dict) and "ObjectGUID" in obj:
|
||||
mylog("debug", f"[WF] Updating Object '{obj}' ")
|
||||
plugin_instance = PluginObjectInstance(self.db)
|
||||
plugin_instance = PluginObjectInstance()
|
||||
plugin_instance.delete(obj["ObjectGUID"])
|
||||
processed = True
|
||||
|
||||
elif isinstance(obj, dict) and "devGUID" in obj:
|
||||
mylog("debug", f"[WF] Updating Device '{obj}' ")
|
||||
device_instance = DeviceInstance(self.db)
|
||||
device_instance = DeviceInstance()
|
||||
device_instance.delete(obj["devGUID"])
|
||||
processed = True
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
|
||||
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
|
||||
@@ -25,82 +26,94 @@ def auth_headers(token):
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# --- get_device_info Tests ---
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
# --- Device Search Tests ---
|
||||
|
||||
@patch('models.device_instance.get_temp_db_connection')
|
||||
def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
|
||||
"""Test get_device_info with partial IP search."""
|
||||
mock_cursor = MagicMock()
|
||||
# Mock return of a device with IP ending in .50
|
||||
mock_cursor.fetchall.return_value = [
|
||||
"""Test device search with partial IP search."""
|
||||
# Mock database connection - DeviceInstance._fetchall calls conn.execute().fetchall()
|
||||
mock_conn = MagicMock()
|
||||
mock_execute_result = MagicMock()
|
||||
mock_execute_result.fetchall.return_value = [
|
||||
{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}
|
||||
]
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
mock_conn.execute.return_value = mock_execute_result
|
||||
mock_db_conn.return_value = mock_conn
|
||||
|
||||
payload = {"query": ".50"}
|
||||
response = client.post('/api/tools/get_device_info',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
devices = response.get_json()
|
||||
assert len(devices) == 1
|
||||
assert devices[0]["devLastIP"] == "192.168.1.50"
|
||||
|
||||
# Verify SQL query included 3 params (MAC, Name, IP)
|
||||
args, _ = mock_cursor.execute.call_args
|
||||
assert args[0].count("?") == 3
|
||||
assert len(args[1]) == 3
|
||||
|
||||
|
||||
# --- trigger_scan Tests ---
|
||||
@patch('subprocess.run')
|
||||
def test_trigger_scan_nmap_fast(mock_run, client, api_token):
|
||||
"""Test trigger_scan with nmap_fast."""
|
||||
mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0)
|
||||
|
||||
payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"}
|
||||
response = client.post('/api/tools/trigger_scan',
|
||||
response = client.post('/devices/search',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["success"] is True
|
||||
assert "nmap -F 192.168.1.1" in data["command"]
|
||||
mock_run.assert_called_once()
|
||||
assert len(data["devices"]) == 1
|
||||
assert data["devices"][0]["devLastIP"] == "192.168.1.50"
|
||||
|
||||
|
||||
@patch('subprocess.run')
|
||||
def test_trigger_scan_invalid_type(mock_run, client, api_token):
|
||||
"""Test trigger_scan with invalid scan_type."""
|
||||
payload = {"scan_type": "invalid_type", "target": "192.168.1.1"}
|
||||
response = client.post('/api/tools/trigger_scan',
|
||||
# --- Trigger Scan Tests ---
|
||||
|
||||
@patch('api_server.api_server_start.UserEventsQueueInstance')
|
||||
def test_trigger_scan_ARPSCAN(mock_queue_class, client, api_token):
|
||||
"""Test trigger_scan with ARPSCAN type."""
|
||||
mock_queue = MagicMock()
|
||||
mock_queue_class.return_value = mock_queue
|
||||
|
||||
payload = {"type": "ARPSCAN"}
|
||||
response = client.post('/mcp/sse/nettools/trigger-scan',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["success"] is True
|
||||
mock_queue.add_event.assert_called_once()
|
||||
call_args = mock_queue.add_event.call_args[0]
|
||||
assert "run|ARPSCAN" in call_args[0]
|
||||
|
||||
|
||||
@patch('api_server.api_server_start.UserEventsQueueInstance')
|
||||
def test_trigger_scan_invalid_type(mock_queue_class, client, api_token):
|
||||
"""Test trigger_scan with invalid scan type."""
|
||||
mock_queue = MagicMock()
|
||||
mock_queue_class.return_value = mock_queue
|
||||
|
||||
payload = {"type": "invalid_type", "target": "192.168.1.0/24"}
|
||||
response = client.post('/mcp/sse/nettools/trigger-scan',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 400
|
||||
mock_run.assert_not_called()
|
||||
data = response.get_json()
|
||||
assert data["success"] is False
|
||||
|
||||
|
||||
# --- get_open_ports Tests ---
|
||||
|
||||
|
||||
@patch('subprocess.run')
|
||||
def test_get_open_ports_ip(mock_run, client, api_token):
|
||||
@patch('models.plugin_object_instance.get_temp_db_connection')
|
||||
@patch('models.device_instance.get_temp_db_connection')
|
||||
def test_get_open_ports_ip(mock_plugin_db_conn, mock_device_db_conn, client, api_token):
|
||||
"""Test get_open_ports with an IP address."""
|
||||
mock_output = """
|
||||
Starting Nmap 7.80 ( https://nmap.org ) at 2023-10-27 10:00 UTC
|
||||
Nmap scan report for 192.168.1.1
|
||||
Host is up (0.0010s latency).
|
||||
Not shown: 98 closed ports
|
||||
PORT STATE SERVICE
|
||||
22/tcp open ssh
|
||||
80/tcp open http
|
||||
Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
|
||||
"""
|
||||
mock_run.return_value = MagicMock(stdout=mock_output, returncode=0)
|
||||
# Mock database connections for both device lookup and plugin objects
|
||||
mock_conn = MagicMock()
|
||||
mock_execute_result = MagicMock()
|
||||
|
||||
# Mock for PluginObjectInstance.getByField (returns port data)
|
||||
mock_execute_result.fetchall.return_value = [
|
||||
{"Object_SecondaryID": "22", "Watched_Value2": "ssh"},
|
||||
{"Object_SecondaryID": "80", "Watched_Value2": "http"}
|
||||
]
|
||||
# Mock for DeviceInstance.getByIP (returns device with MAC)
|
||||
mock_execute_result.fetchone.return_value = {"devMac": "AA:BB:CC:DD:EE:FF"}
|
||||
|
||||
mock_conn.execute.return_value = mock_execute_result
|
||||
mock_plugin_db_conn.return_value = mock_conn
|
||||
mock_device_db_conn.return_value = mock_conn
|
||||
|
||||
payload = {"target": "192.168.1.1"}
|
||||
response = client.post('/api/tools/get_open_ports',
|
||||
response = client.post('/device/open_ports',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
@@ -112,43 +125,46 @@ Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds
|
||||
assert data["open_ports"][1]["service"] == "http"
|
||||
|
||||
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
@patch('subprocess.run')
|
||||
def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token):
|
||||
@patch('models.plugin_object_instance.get_temp_db_connection')
|
||||
def test_get_open_ports_mac_resolve(mock_plugin_db_conn, client, api_token):
|
||||
"""Test get_open_ports with a MAC address that resolves to an IP."""
|
||||
# Mock DB to resolve MAC to IP
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"}
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
|
||||
# Mock Nmap output
|
||||
mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0)
|
||||
# Mock database connection for MAC-based open ports query
|
||||
mock_conn = MagicMock()
|
||||
mock_execute_result = MagicMock()
|
||||
mock_execute_result.fetchall.return_value = [
|
||||
{"Object_SecondaryID": "80", "Watched_Value2": "http"}
|
||||
]
|
||||
mock_conn.execute.return_value = mock_execute_result
|
||||
mock_plugin_db_conn.return_value = mock_conn
|
||||
|
||||
payload = {"target": "AA:BB:CC:DD:EE:FF"}
|
||||
response = client.post('/api/tools/get_open_ports',
|
||||
response = client.post('/device/open_ports',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["target"] == "192.168.1.50" # Should be resolved IP
|
||||
mock_run.assert_called_once()
|
||||
args, _ = mock_run.call_args
|
||||
assert "192.168.1.50" in args[0]
|
||||
assert data["success"] is True
|
||||
assert "target" in data
|
||||
assert len(data["open_ports"]) == 1
|
||||
assert data["open_ports"][0]["port"] == 80
|
||||
|
||||
|
||||
# --- get_network_topology Tests ---
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
@patch('models.device_instance.get_temp_db_connection')
|
||||
def test_get_network_topology(mock_db_conn, client, api_token):
|
||||
"""Test get_network_topology."""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = [
|
||||
# Mock database connection for topology query
|
||||
mock_conn = MagicMock()
|
||||
mock_execute_result = MagicMock()
|
||||
mock_execute_result.fetchall.return_value = [
|
||||
{"devName": "Router", "devMac": "AA:AA:AA:AA:AA:AA", "devParentMAC": None, "devParentPort": None, "devVendor": "VendorA"},
|
||||
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"}
|
||||
]
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
mock_conn.execute.return_value = mock_execute_result
|
||||
mock_db_conn.return_value = mock_conn
|
||||
|
||||
response = client.get('/api/tools/get_network_topology',
|
||||
response = client.get('/devices/network/topology',
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -160,92 +176,71 @@ def test_get_network_topology(mock_db_conn, client, api_token):
|
||||
|
||||
|
||||
# --- get_recent_alerts Tests ---
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
@patch('models.event_instance.get_temp_db_connection')
|
||||
def test_get_recent_alerts(mock_db_conn, client, api_token):
|
||||
"""Test get_recent_alerts."""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchall.return_value = [
|
||||
{"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"}
|
||||
# Mock database connection for events query
|
||||
mock_conn = MagicMock()
|
||||
mock_execute_result = MagicMock()
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
mock_execute_result.fetchall.return_value = [
|
||||
{"eve_DateTime": now, "eve_EventType": "New Device", "eve_MAC": "AA:BB:CC:DD:EE:FF"}
|
||||
]
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
mock_conn.execute.return_value = mock_execute_result
|
||||
mock_db_conn.return_value = mock_conn
|
||||
|
||||
payload = {"hours": 24}
|
||||
response = client.post('/api/tools/get_recent_alerts',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
response = client.get('/events/recent',
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["eve_EventType"] == "New Device"
|
||||
assert data["success"] is True
|
||||
assert data["hours"] == 24
|
||||
|
||||
|
||||
# --- set_device_alias Tests ---
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
def test_set_device_alias(mock_db_conn, client, api_token):
|
||||
# --- Device Alias Tests ---
|
||||
|
||||
@patch('api_server.api_server_start.update_device_column')
|
||||
def test_set_device_alias(mock_update_col, client, api_token):
|
||||
"""Test set_device_alias."""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.rowcount = 1 # Simulate successful update
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
mock_update_col.return_value = {"success": True, "message": "Device alias updated"}
|
||||
|
||||
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
|
||||
response = client.post('/api/tools/set_device_alias',
|
||||
payload = {"alias": "New Device Name"}
|
||||
response = client.post('/device/AA:BB:CC:DD:EE:FF/set-alias',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["success"] is True
|
||||
mock_update_col.assert_called_once_with("AA:BB:CC:DD:EE:FF", "devName", "New Device Name")
|
||||
|
||||
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
def test_set_device_alias_not_found(mock_db_conn, client, api_token):
|
||||
@patch('api_server.api_server_start.update_device_column')
|
||||
def test_set_device_alias_not_found(mock_update_col, client, api_token):
|
||||
"""Test set_device_alias when device is not found."""
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.rowcount = 0 # Simulate no rows updated
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
mock_update_col.return_value = {"success": False, "error": "Device not found"}
|
||||
|
||||
payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"}
|
||||
response = client.post('/api/tools/set_device_alias',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# --- wol_wake_device Tests ---
|
||||
@patch('subprocess.run')
|
||||
def test_wol_wake_device(mock_subprocess, client, api_token):
|
||||
"""Test wol_wake_device."""
|
||||
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
|
||||
mock_subprocess.return_value.returncode = 0
|
||||
|
||||
payload = {"mac": "AA:BB:CC:DD:EE:FF"}
|
||||
response = client.post('/api/tools/wol_wake_device',
|
||||
payload = {"alias": "New Device Name"}
|
||||
response = client.post('/device/FF:FF:FF:FF:FF:FF/set-alias',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["success"] is True
|
||||
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
|
||||
assert data["success"] is False
|
||||
assert "Device not found" in data["error"]
|
||||
|
||||
|
||||
@patch('api_server.tools_routes.get_temp_db_connection')
|
||||
@patch('subprocess.run')
|
||||
def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token):
|
||||
"""Test wol_wake_device with IP address."""
|
||||
# Mock DB for IP resolution
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = {"devMac": "AA:BB:CC:DD:EE:FF"}
|
||||
mock_db_conn.return_value.cursor.return_value = mock_cursor
|
||||
# --- Wake-on-LAN Tests ---
|
||||
|
||||
# Mock subprocess
|
||||
mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF"
|
||||
mock_subprocess.return_value.returncode = 0
|
||||
@patch('api_server.api_server_start.wakeonlan')
|
||||
def test_wol_wake_device(mock_wakeonlan, client, api_token):
|
||||
"""Test wol_wake_device."""
|
||||
mock_wakeonlan.return_value = {"success": True, "message": "WOL packet sent to AA:BB:CC:DD:EE:FF"}
|
||||
|
||||
payload = {"ip": "192.168.1.50"}
|
||||
response = client.post('/api/tools/wol_wake_device',
|
||||
payload = {"devMac": "AA:BB:CC:DD:EE:FF"}
|
||||
response = client.post('/nettools/wakeonlan',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
@@ -254,34 +249,58 @@ def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token)
|
||||
assert data["success"] is True
|
||||
assert "AA:BB:CC:DD:EE:FF" in data["message"]
|
||||
|
||||
# Verify DB lookup
|
||||
mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",))
|
||||
|
||||
# Verify subprocess call
|
||||
mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True)
|
||||
|
||||
|
||||
def test_wol_wake_device_invalid_mac(client, api_token):
|
||||
"""Test wol_wake_device with invalid MAC."""
|
||||
payload = {"mac": "invalid-mac"}
|
||||
response = client.post('/api/tools/wol_wake_device',
|
||||
payload = {"devMac": "invalid-mac"}
|
||||
response = client.post('/nettools/wakeonlan',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert data["success"] is False
|
||||
|
||||
|
||||
# --- openapi_spec Tests ---
|
||||
def test_openapi_spec(client):
|
||||
"""Test openapi_spec endpoint contains new paths."""
|
||||
response = client.get('/api/tools/openapi.json')
|
||||
# --- OpenAPI Spec Tests ---
|
||||
|
||||
# --- Latest Device Tests ---
|
||||
|
||||
@patch('models.device_instance.get_temp_db_connection')
|
||||
def test_get_latest_device(mock_db_conn, client, api_token):
|
||||
"""Test get_latest_device endpoint."""
|
||||
# Mock database connection for latest device query
|
||||
mock_conn = MagicMock()
|
||||
mock_execute_result = MagicMock()
|
||||
mock_execute_result.fetchone.return_value = {
|
||||
"devName": "Latest Device",
|
||||
"devMac": "AA:BB:CC:DD:EE:FF",
|
||||
"devLastIP": "192.168.1.100",
|
||||
"devFirstConnection": "2025-12-07 10:30:00"
|
||||
}
|
||||
mock_conn.execute.return_value = mock_execute_result
|
||||
mock_db_conn.return_value = mock_conn
|
||||
|
||||
response = client.get('/devices/latest',
|
||||
headers=auth_headers(api_token))
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["devName"] == "Latest Device"
|
||||
assert data[0]["devMac"] == "AA:BB:CC:DD:EE:FF"
|
||||
|
||||
|
||||
def test_openapi_spec(client, api_token):
|
||||
"""Test openapi_spec endpoint contains MCP tool paths."""
|
||||
response = client.get('/mcp/sse/openapi.json', headers=auth_headers(api_token))
|
||||
assert response.status_code == 200
|
||||
spec = response.get_json()
|
||||
|
||||
# Check for new endpoints
|
||||
assert "/trigger_scan" in spec["paths"]
|
||||
assert "/get_open_ports" in spec["paths"]
|
||||
assert "/get_network_topology" in spec["paths"]
|
||||
assert "/get_recent_alerts" in spec["paths"]
|
||||
assert "/set_device_alias" in spec["paths"]
|
||||
assert "/wol_wake_device" in spec["paths"]
|
||||
# Check for MCP tool endpoints in the spec with correct paths
|
||||
assert "/nettools/trigger-scan" in spec["paths"]
|
||||
assert "/device/open_ports" in spec["paths"]
|
||||
assert "/devices/network/topology" in spec["paths"]
|
||||
assert "/events/recent" in spec["paths"]
|
||||
assert "/device/{mac}/set-alias" in spec["paths"]
|
||||
assert "/nettools/wakeonlan" in spec["paths"]
|
||||
@@ -1,79 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
|
||||
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
|
||||
|
||||
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
|
||||
from api_server.api_server_start import app # noqa: E402 [flake8 lint suppression]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api_token():
|
||||
return get_setting_value("API_TOKEN")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with app.test_client() as client:
|
||||
yield client
|
||||
|
||||
|
||||
def auth_headers(token):
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def test_openapi_spec(client):
|
||||
"""Test OpenAPI spec endpoint."""
|
||||
response = client.get('/api/tools/openapi.json')
|
||||
assert response.status_code == 200
|
||||
spec = response.get_json()
|
||||
assert "openapi" in spec
|
||||
assert "info" in spec
|
||||
assert "paths" in spec
|
||||
assert "/list_devices" in spec["paths"]
|
||||
assert "/get_device_info" in spec["paths"]
|
||||
|
||||
|
||||
def test_list_devices(client, api_token):
|
||||
"""Test list_devices endpoint."""
|
||||
response = client.post('/api/tools/list_devices', headers=auth_headers(api_token))
|
||||
assert response.status_code == 200
|
||||
devices = response.get_json()
|
||||
assert isinstance(devices, list)
|
||||
# If there are devices, check structure
|
||||
if devices:
|
||||
device = devices[0]
|
||||
assert "devName" in device
|
||||
assert "devMac" in device
|
||||
|
||||
|
||||
def test_get_device_info(client, api_token):
|
||||
"""Test get_device_info endpoint."""
|
||||
# Test with a query that might not exist
|
||||
payload = {"query": "nonexistent_device"}
|
||||
response = client.post('/api/tools/get_device_info',
|
||||
json=payload,
|
||||
headers=auth_headers(api_token))
|
||||
# Should return 404 if no match, or 200 with results
|
||||
assert response.status_code in [200, 404]
|
||||
if response.status_code == 200:
|
||||
devices = response.get_json()
|
||||
assert isinstance(devices, list)
|
||||
elif response.status_code == 404:
|
||||
# Expected for no matches
|
||||
pass
|
||||
|
||||
|
||||
def test_list_devices_unauthorized(client):
|
||||
"""Test list_devices without authorization."""
|
||||
response = client.post('/api/tools/list_devices')
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_get_device_info_unauthorized(client):
|
||||
"""Test get_device_info without authorization."""
|
||||
payload = {"query": "test"}
|
||||
response = client.post('/api/tools/get_device_info', json=payload)
|
||||
assert response.status_code == 401
|
||||
Reference in New Issue
Block a user