Merge pull request #1476 from adamoutler/openapi-mcp-improvements

Standardize and Optimize OpenAPI & MCP for AI Agents
This commit is contained in:
Jokob @NetAlertX
2026-01-31 14:35:24 +11:00
committed by GitHub
12 changed files with 746 additions and 143 deletions

View File

@@ -7,6 +7,7 @@ import os
from flask import Flask, request, jsonify, Response from flask import Flask, request, jsonify, Response
from models.device_instance import DeviceInstance # noqa: E402 from models.device_instance import DeviceInstance # noqa: E402
from flask_cors import CORS from flask_cors import CORS
from werkzeug.exceptions import HTTPException
# Register NetAlertX directories # Register NetAlertX directories
INSTALL_PATH = os.getenv("NETALERTX_APP", "/app") INSTALL_PATH = os.getenv("NETALERTX_APP", "/app")
@@ -59,7 +60,8 @@ from .mcp_endpoint import (
mcp_sse, mcp_sse,
mcp_messages, mcp_messages,
openapi_spec, openapi_spec,
) # noqa: E402 [flake8 lint suppression] get_openapi_spec,
)
# validation and schemas for MCP v2 # validation and schemas for MCP v2
from .openapi.validation import validate_request # noqa: E402 [flake8 lint suppression] from .openapi.validation import validate_request # noqa: E402 [flake8 lint suppression]
from .openapi.schemas import ( # noqa: E402 [flake8 lint suppression] from .openapi.schemas import ( # noqa: E402 [flake8 lint suppression]
@@ -100,6 +102,20 @@ from .sse_endpoint import ( # noqa: E402 [flake8 lint suppression]
app = Flask(__name__) app = Flask(__name__)
@app.errorhandler(500)
@app.errorhandler(Exception)
def handle_500_error(e):
"""Global error handler for uncaught exceptions."""
if isinstance(e, HTTPException):
return e
mylog("none", [f"[API] Uncaught exception: {e}"])
return jsonify({
"success": False,
"error": "Internal Server Error",
"message": "Something went wrong on the server"
}), 500
# Parse CORS origins from environment or use safe defaults # Parse CORS origins from environment or use safe defaults
_cors_origins_env = os.environ.get("CORS_ORIGINS", "") _cors_origins_env = os.environ.get("CORS_ORIGINS", "")
_cors_origins = [ _cors_origins = [
@@ -599,7 +615,7 @@ def api_device_open_ports(payload=None):
@validate_request( @validate_request(
operation_id="get_all_devices", operation_id="get_all_devices",
summary="Get All Devices", summary="Get All Devices",
description="Retrieve a list of all devices in the system.", description="Retrieve a list of all devices in the system. Returns all records. No pagination supported.",
response_model=DeviceListWrapperResponse, response_model=DeviceListWrapperResponse,
tags=["devices"], tags=["devices"],
auth_callable=is_authorized auth_callable=is_authorized
@@ -662,7 +678,7 @@ def api_delete_unknown_devices(payload=None):
@app.route("/devices/export", methods=["GET"]) @app.route("/devices/export", methods=["GET"])
@app.route("/devices/export/<format>", methods=["GET"]) @app.route("/devices/export/<format>", methods=["GET"])
@validate_request( @validate_request(
operation_id="export_devices", operation_id="export_devices_all",
summary="Export Devices", summary="Export Devices",
description="Export all devices in CSV or JSON format.", description="Export all devices in CSV or JSON format.",
query_params=[{ query_params=[{
@@ -679,7 +695,8 @@ def api_delete_unknown_devices(payload=None):
}], }],
response_model=DeviceExportResponse, response_model=DeviceExportResponse,
tags=["devices"], tags=["devices"],
auth_callable=is_authorized auth_callable=is_authorized,
response_content_types=["application/json", "text/csv"]
) )
def api_export_devices(format=None, payload=None): def api_export_devices(format=None, payload=None):
export_format = (format or request.args.get("format", "csv")).lower() export_format = (format or request.args.get("format", "csv")).lower()
@@ -747,7 +764,7 @@ def api_devices_totals(payload=None):
@app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST']) @app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST'])
@app.route("/devices/by-status", methods=["GET", "POST"]) @app.route("/devices/by-status", methods=["GET", "POST"])
@validate_request( @validate_request(
operation_id="list_devices_by_status", operation_id="list_devices_by_status_api",
summary="List Devices by Status", summary="List Devices by Status",
description="List devices filtered by their online/offline status.", description="List devices filtered by their online/offline status.",
request_model=DeviceListRequest, request_model=DeviceListRequest,
@@ -763,7 +780,30 @@ def api_devices_totals(payload=None):
"connected", "down", "favorites", "new", "archived", "all", "my", "connected", "down", "favorites", "new", "archived", "all", "my",
"offline" "offline"
]} ]}
}] }],
links={
"GetOpenPorts": {
"operationId": "get_open_ports",
"parameters": {
"target": "$response.body#/0/devLastIP"
},
"description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the list."
},
"WakeOnLan": {
"operationId": "wake_on_lan",
"parameters": {
"devMac": "$response.body#/0/devMac"
},
"description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the list."
},
"UpdateDevice": {
"operationId": "update_device",
"parameters": {
"mac": "$response.body#/0/devMac"
},
"description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the list."
}
}
) )
def api_devices_by_status(payload: DeviceListRequest = None): def api_devices_by_status(payload: DeviceListRequest = None):
status = payload.status if payload else request.args.get("status") status = payload.status if payload else request.args.get("status")
@@ -774,13 +814,43 @@ def api_devices_by_status(payload: DeviceListRequest = None):
@app.route('/mcp/sse/devices/search', methods=['POST']) @app.route('/mcp/sse/devices/search', methods=['POST'])
@app.route('/devices/search', methods=['POST']) @app.route('/devices/search', methods=['POST'])
@validate_request( @validate_request(
operation_id="search_devices", operation_id="search_devices_api",
summary="Search Devices", summary="Search Devices",
description="Search for devices based on various criteria like name, IP, MAC, or vendor.", description="Search for devices based on various criteria like name, IP, MAC, or vendor.",
request_model=DeviceSearchRequest, request_model=DeviceSearchRequest,
response_model=DeviceSearchResponse, response_model=DeviceSearchResponse,
tags=["devices"], tags=["devices"],
auth_callable=is_authorized auth_callable=is_authorized,
links={
"GetOpenPorts": {
"operationId": "get_open_ports",
"parameters": {
"target": "$response.body#/devices/0/devLastIP"
},
"description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the search results."
},
"WakeOnLan": {
"operationId": "wake_on_lan",
"parameters": {
"devMac": "$response.body#/devices/0/devMac"
},
"description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the search results."
},
"NmapScan": {
"operationId": "run_nmap_scan",
"parameters": {
"scan": "$response.body#/devices/0/devLastIP"
},
"description": "The `scan` parameter for `run_nmap_scan` requires an IP or range. Use the `devLastIP` from the first device in the search results."
},
"UpdateDevice": {
"operationId": "update_device",
"parameters": {
"mac": "$response.body#/devices/0/devMac"
},
"description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the search results."
}
}
) )
def api_devices_search(payload=None): def api_devices_search(payload=None):
"""Device search: accepts 'query' in JSON and maps to device info/search.""" """Device search: accepts 'query' in JSON and maps to device info/search."""
@@ -884,9 +954,13 @@ def api_devices_network_topology(payload=None):
auth_callable=is_authorized auth_callable=is_authorized
) )
def api_wakeonlan(payload=None): def api_wakeonlan(payload=None):
data = request.get_json(silent=True) or {} if payload:
mac = data.get("devMac") mac = payload.mac
ip = data.get("devLastIP") or data.get('ip') ip = payload.devLastIP
else:
data = request.get_json(silent=True) or {}
mac = data.get("mac") or data.get("devMac")
ip = data.get("devLastIP") or data.get('ip')
if not mac and ip: if not mac and ip:
@@ -1011,7 +1085,7 @@ def api_network_interfaces(payload=None):
@app.route('/mcp/sse/nettools/trigger-scan', methods=['POST']) @app.route('/mcp/sse/nettools/trigger-scan', methods=['POST'])
@app.route("/nettools/trigger-scan", methods=["GET"]) @app.route("/nettools/trigger-scan", methods=["GET", "POST"])
@validate_request( @validate_request(
operation_id="trigger_network_scan", operation_id="trigger_network_scan",
summary="Trigger Network Scan", summary="Trigger Network Scan",
@@ -1300,13 +1374,25 @@ def api_create_event(mac, payload=None):
@app.route("/events/<mac>", methods=["DELETE"]) @app.route("/events/<mac>", methods=["DELETE"])
@validate_request( @validate_request(
operation_id="delete_events_by_mac", operation_id="delete_events",
summary="Delete Events by MAC", summary="Delete Events",
description="Delete all events for a specific device MAC address.", description="Delete events by device MAC address or older than a specified number of days.",
path_params=[{ path_params=[{
"name": "mac", "name": "mac",
"description": "Device MAC address", "description": "Device MAC address or number of days",
"schema": {"type": "string"} "schema": {
"oneOf": [
{
"type": "integer",
"description": "Number of days (e.g., 30) to delete events older than this value."
},
{
"type": "string",
"pattern": "^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$",
"description": "Device MAC address to delete all events for a specific device."
}
]
}
}], }],
response_model=BaseResponse, response_model=BaseResponse,
tags=["events"], tags=["events"],
@@ -1315,6 +1401,7 @@ def api_create_event(mac, payload=None):
def api_events_by_mac(mac, payload=None): def api_events_by_mac(mac, payload=None):
"""Delete events for a specific device MAC; string converter keeps this distinct from /events/<int:days>.""" """Delete events for a specific device MAC; string converter keeps this distinct from /events/<int:days>."""
device_handler = DeviceInstance() device_handler = DeviceInstance()
result = device_handler.deleteDeviceEvents(mac) result = device_handler.deleteDeviceEvents(mac)
return jsonify(result) return jsonify(result)
@@ -1338,7 +1425,7 @@ def api_delete_all_events(payload=None):
@validate_request( @validate_request(
operation_id="get_all_events", operation_id="get_all_events",
summary="Get Events", summary="Get Events",
description="Retrieve a list of events, optionally filtered by MAC.", description="Retrieve a list of events, optionally filtered by MAC. Returns all matching records. No pagination supported.",
query_params=[{ query_params=[{
"name": "mac", "name": "mac",
"description": "Filter by Device MAC", "description": "Filter by Device MAC",
@@ -1372,7 +1459,8 @@ def api_get_events(payload=None):
}], }],
response_model=BaseResponse, response_model=BaseResponse,
tags=["events"], tags=["events"],
auth_callable=is_authorized auth_callable=is_authorized,
exclude_from_spec=True
) )
def api_delete_old_events(days: int, payload=None): def api_delete_old_events(days: int, payload=None):
""" """
@@ -1406,7 +1494,7 @@ def api_get_events_totals(payload=None):
@app.route('/mcp/sse/events/recent', methods=['GET', 'POST']) @app.route('/mcp/sse/events/recent', methods=['GET', 'POST'])
@app.route('/events/recent', methods=['GET']) @app.route('/events/recent', methods=['GET', 'POST'])
@validate_request( @validate_request(
operation_id="get_recent_events", operation_id="get_recent_events",
summary="Get Recent Events", summary="Get Recent Events",
@@ -1426,7 +1514,7 @@ def api_events_default_24h(payload=None):
@app.route('/mcp/sse/events/last', methods=['GET', 'POST']) @app.route('/mcp/sse/events/last', methods=['GET', 'POST'])
@app.route('/events/last', methods=['GET']) @app.route('/events/last', methods=['GET', 'POST'])
@validate_request( @validate_request(
operation_id="get_last_events", operation_id="get_last_events",
summary="Get Last Events", summary="Get Last Events",
@@ -1763,7 +1851,7 @@ def sync_endpoint_post(payload=None):
@validate_request( @validate_request(
operation_id="check_auth", operation_id="check_auth",
summary="Check Authentication", summary="Check Authentication",
description="Check if the current API token is valid.", description="Check if the current API token is valid. Note: tokens must be generated externally via the UI or CLI.",
response_model=BaseResponse, response_model=BaseResponse,
tags=["auth"], tags=["auth"],
auth_callable=is_authorized auth_callable=is_authorized
@@ -1778,6 +1866,14 @@ def check_auth(payload=None):
# Mount SSE endpoints after is_authorized is defined (avoid circular import) # Mount SSE endpoints after is_authorized is defined (avoid circular import)
create_sse_endpoint(app, is_authorized) create_sse_endpoint(app, is_authorized)
# Apply environment-driven MCP disablement by regenerating the OpenAPI spec.
# This populates the registry and applies any operation IDs listed in MCP_DISABLED_TOOLS.
try:
get_openapi_spec(force_refresh=True, flask_app=app)
mylog("verbose", [f"[MCP] Applied MCP_DISABLED_TOOLS: {os.environ.get('MCP_DISABLED_TOOLS', '')}"])
except Exception as e:
mylog("none", [f"[MCP] Error applying MCP_DISABLED_TOOLS: {e}"])
def start_server(graphql_port, app_state): def start_server(graphql_port, app_state):
"""Start the GraphQL server in a background thread.""" """Start the GraphQL server in a background thread."""

View File

@@ -309,6 +309,7 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
This function transforms OpenAPI operations into MCP-compatible tool schemas, This function transforms OpenAPI operations into MCP-compatible tool schemas,
ensuring proper inputSchema derivation from request bodies and parameters. ensuring proper inputSchema derivation from request bodies and parameters.
It deduplicates tools by their original operationId, preferring /mcp/ routes.
Args: Args:
spec: OpenAPI specification dictionary spec: OpenAPI specification dictionary
@@ -316,10 +317,10 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
Returns: Returns:
List of MCP tool definitions with name, description, and inputSchema List of MCP tool definitions with name, description, and inputSchema
""" """
tools = [] tools_map = {}
if not spec or "paths" not in spec: if not spec or "paths" not in spec:
return tools return []
for path, methods in spec["paths"].items(): for path, methods in spec["paths"].items():
for method, details in methods.items(): for method, details in methods.items():
@@ -327,6 +328,9 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
continue continue
operation_id = details["operationId"] operation_id = details["operationId"]
# Deduplicate using the original operationId (before suffixing)
# or the unique operationId as fallback.
original_op_id = details.get("x-original-operationId", operation_id)
# Build inputSchema from requestBody and parameters # Build inputSchema from requestBody and parameters
input_schema = { input_schema = {
@@ -382,31 +386,82 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
tool = { tool = {
"name": operation_id, "name": operation_id,
"description": details.get("description", details.get("summary", "")), "description": details.get("description", details.get("summary", "")),
"inputSchema": input_schema "inputSchema": input_schema,
"_original_op_id": original_op_id,
"_is_mcp": path.startswith("/mcp/"),
"_is_post": method.upper() == "POST"
} }
tools.append(tool) # Preference logic for deduplication:
# 1. Prefer /mcp/ routes over standard ones.
# 2. Prefer POST methods over GET for the same logic (usually more robust body validation).
existing = tools_map.get(original_op_id)
if not existing:
tools_map[original_op_id] = tool
else:
# Upgrade if current is MCP and existing is not
mcp_upgrade = tool["_is_mcp"] and not existing["_is_mcp"]
# Upgrade if same route type but current is POST and existing is GET
method_upgrade = (tool["_is_mcp"] == existing["_is_mcp"]) and tool["_is_post"] and not existing["_is_post"]
if mcp_upgrade or method_upgrade:
tools_map[original_op_id] = tool
return tools # Final cleanup: remove internal preference flags and ensure tools have the original names
# unless we explicitly want the suffixed ones.
# The user said "Eliminate Duplicate Tool Names", so we should use original_op_id as the tool name.
final_tools = []
_tool_name_to_operation_id: Dict[str, str] = {}
for tool in tools_map.values():
actual_operation_id = tool["name"] # Save before overwriting
tool["name"] = tool["_original_op_id"]
_tool_name_to_operation_id[tool["name"]] = actual_operation_id
del tool["_original_op_id"]
del tool["_is_mcp"]
del tool["_is_post"]
final_tools.append(tool)
return final_tools
def find_route_for_tool(tool_name: str) -> Optional[Dict[str, Any]]: def find_route_for_tool(tool_name: str) -> Optional[Dict[str, Any]]:
""" """
Find the registered route for a given tool name (operationId). Find the registered route for a given tool name (operationId).
Handles exact matches and deduplicated original IDs.
Args: Args:
tool_name: The operationId to look up tool_name: The operationId or original_operation_id to look up
Returns: Returns:
Route dictionary with path, method, and models, or None if not found Route dictionary with path, method, and models, or None if not found
""" """
registry = get_registry() registry = get_registry()
candidates = []
for entry in registry: for entry in registry:
# Exact match (priority) - if the client passed the specific suffixed ID
if entry["operation_id"] == tool_name: if entry["operation_id"] == tool_name:
return entry return entry
if entry.get("original_operation_id") == tool_name:
candidates.append(entry)
return None if not candidates:
return None
# Apply same preference logic as map_openapi_to_mcp_tools to ensure we pick the
# same route definition that generated the tool schema.
# Priority 1: MCP routes (they have specialized paths/behavior)
mcp_candidates = [c for c in candidates if c["path"].startswith("/mcp/")]
pool = mcp_candidates if mcp_candidates else candidates
# Priority 2: POST methods (usually preferred for tools)
post_candidates = [c for c in pool if c["method"].upper() == "POST"]
if post_candidates:
return post_candidates[0]
# Fallback: return the first from the best pool available
return pool[0]
# ============================================================================= # =============================================================================

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Any from typing import Any, Dict, Optional
import graphene import graphene
from .registry import register_tool, _operation_ids from .registry import register_tool, _operation_ids
from .schemas import GraphQLRequest
from .schema_converter import pydantic_to_json_schema, resolve_schema_refs
def introspect_graphql_schema(schema: graphene.Schema): def introspect_graphql_schema(schema: graphene.Schema):
@@ -26,6 +28,7 @@ def introspect_graphql_schema(schema: graphene.Schema):
operation_id="graphql_query", operation_id="graphql_query",
summary="GraphQL Endpoint", summary="GraphQL Endpoint",
description="Execute arbitrary GraphQL queries against the system schema.", description="Execute arbitrary GraphQL queries against the system schema.",
request_model=GraphQLRequest,
tags=["graphql"] tags=["graphql"]
) )
@@ -36,6 +39,20 @@ def _flask_to_openapi_path(flask_path: str) -> str:
return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path) return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path)
def _get_openapi_metadata(func: Any) -> Optional[Dict[str, Any]]:
"""Recursively find _openapi_metadata in wrapped functions."""
# Check current function
metadata = getattr(func, "_openapi_metadata", None)
if metadata:
return metadata
# Check __wrapped__ (standard for @wraps)
if hasattr(func, "__wrapped__"):
return _get_openapi_metadata(func.__wrapped__)
return None
def introspect_flask_app(app: Any): def introspect_flask_app(app: Any):
""" """
Introspect the Flask application to find routes decorated with @validate_request Introspect the Flask application to find routes decorated with @validate_request
@@ -47,14 +64,13 @@ def introspect_flask_app(app: Any):
if not view_func: if not view_func:
continue continue
# Check for our decorator's metadata # Check for our decorator's metadata recursively
metadata = getattr(view_func, "_openapi_metadata", None) metadata = _get_openapi_metadata(view_func)
if not metadata:
# Fallback for wrapped functions
if hasattr(view_func, "__wrapped__"):
metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None)
if metadata: if metadata:
if metadata.get("exclude_from_spec"):
continue
op_id = metadata["operation_id"] op_id = metadata["operation_id"]
# Register the tool with real path and method from Flask # Register the tool with real path and method from Flask
@@ -75,20 +91,72 @@ def introspect_flask_app(app: Any):
# Determine tags - create a copy to avoid mutating shared metadata # Determine tags - create a copy to avoid mutating shared metadata
tags = list(metadata.get("tags") or ["rest"]) tags = list(metadata.get("tags") or ["rest"])
if path.startswith("/mcp/"): if path.startswith("/mcp/"):
# Move specific tags to secondary position or just add MCP # For MCP endpoints, we want them exclusively in the 'mcp' tag section
if "rest" in tags: tags = ["mcp"]
tags.remove("rest")
if "mcp" not in tags:
tags.append("mcp")
# Ensure unique operationId # Ensure unique operationId
original_op_id = op_id original_op_id = op_id
unique_op_id = op_id unique_op_id = op_id
# Semantic naming strategy for duplicates
if unique_op_id in _operation_ids:
# Construct a semantic suffix to replace numeric ones
# Priority: /mcp/ prefix and HTTP method
suffix = ""
if path.startswith("/mcp/"):
suffix = "_mcp"
if method.upper() == "POST":
suffix += "_post"
elif method.upper() == "GET":
suffix += "_get"
if suffix:
candidate = f"{op_id}{suffix}"
if candidate not in _operation_ids:
unique_op_id = candidate
# Fallback to numeric suffixes if semantic naming didn't ensure uniqueness
count = 1 count = 1
while unique_op_id in _operation_ids: while unique_op_id in _operation_ids:
unique_op_id = f"{op_id}_{count}" unique_op_id = f"{op_id}_{count}"
count += 1 count += 1
# Filter path_params to only include those that are actually in the path
path_params = metadata.get("path_params")
if path_params:
path_params = [
p for p in path_params
if f"{{{p['name']}}}" in path
]
# Auto-generate query_params from request_model for GET requests
query_params = metadata.get("query_params")
if method == 'GET' and not query_params and metadata.get("request_model"):
try:
schema = pydantic_to_json_schema(metadata["request_model"])
defs = schema.get("$defs", {})
properties = schema.get("properties", {})
query_params = []
for name, prop in properties.items():
is_required = name in schema.get("required", [])
# Resolve references to inlined definitions (preserving Enums)
resolved_prop = resolve_schema_refs(prop, defs)
# Create param definition
param_def = {
"name": name,
"in": "query",
"required": is_required,
"description": prop.get("description", ""),
"schema": resolved_prop
}
# Remove description from schema to avoid duplication
if "description" in param_def["schema"]:
del param_def["schema"]["description"]
query_params.append(param_def)
except Exception:
pass # Fallback to empty if schema generation fails
register_tool( register_tool(
path=path, path=path,
method=method, method=method,
@@ -98,9 +166,11 @@ def introspect_flask_app(app: Any):
description=metadata["description"], description=metadata["description"],
request_model=metadata.get("request_model"), request_model=metadata.get("request_model"),
response_model=metadata.get("response_model"), response_model=metadata.get("response_model"),
path_params=metadata.get("path_params"), path_params=path_params,
query_params=metadata.get("query_params"), query_params=query_params,
tags=tags, tags=tags,
allow_multipart_payload=metadata.get("allow_multipart_payload", False) allow_multipart_payload=metadata.get("allow_multipart_payload", False),
response_content_types=metadata.get("response_content_types"),
links=metadata.get("links")
) )
registered_ops.add(op_key) registered_ops.add(op_key)

View File

@@ -96,7 +96,9 @@ def register_tool(
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
deprecated: bool = False, deprecated: bool = False,
original_operation_id: Optional[str] = None, original_operation_id: Optional[str] = None,
allow_multipart_payload: bool = False allow_multipart_payload: bool = False,
response_content_types: Optional[List[str]] = None,
links: Optional[Dict[str, Any]] = None
) -> None: ) -> None:
""" """
Register an API endpoint for OpenAPI spec generation. Register an API endpoint for OpenAPI spec generation.
@@ -115,6 +117,8 @@ def register_tool(
deprecated: Whether this endpoint is deprecated deprecated: Whether this endpoint is deprecated
original_operation_id: The base ID before suffixing (for disablement mapping) original_operation_id: The base ID before suffixing (for disablement mapping)
allow_multipart_payload: Whether to allow multipart/form-data payloads allow_multipart_payload: Whether to allow multipart/form-data payloads
response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"])
links: Dictionary of OpenAPI links to include in the response definition.
Raises: Raises:
DuplicateOperationIdError: If operation_id already exists in registry DuplicateOperationIdError: If operation_id already exists in registry
@@ -140,7 +144,9 @@ def register_tool(
"query_params": query_params or [], "query_params": query_params or [],
"tags": tags or ["default"], "tags": tags or ["default"],
"deprecated": deprecated, "deprecated": deprecated,
"allow_multipart_payload": allow_multipart_payload "allow_multipart_payload": allow_multipart_payload,
"response_content_types": response_content_types or ["application/json"],
"links": links
}) })

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Dict, Any, Optional, Type, List from typing import Dict, Any, Optional, Type, List
from pydantic import BaseModel from pydantic import BaseModel
from .schemas import ErrorResponse, BaseResponse
def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]: def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]:
@@ -161,57 +162,124 @@ def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]:
return clean_schema return clean_schema
def resolve_schema_refs(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively resolve $ref in schema by inlining the definition.
Useful for standalone schema parts like query parameters where global definitions aren't available.
"""
if not isinstance(schema, dict):
return schema
if "$ref" in schema:
ref = schema["$ref"]
# Handle #/$defs/Name syntax
if ref.startswith("#/$defs/"):
def_name = ref.split("/")[-1]
if def_name in definitions:
# Inline the definition (and resolve its refs recursively)
inlined = resolve_schema_refs(definitions[def_name], definitions)
# Merge any extra keys from the original schema (e.g. description override)
# Schema keys take precedence over definition keys
return {**inlined, **{k: v for k, v in schema.items() if k != "$ref"}}
# Recursively resolve properties
resolved = {}
for k, v in schema.items():
if k == "items":
resolved[k] = resolve_schema_refs(v, definitions)
elif k == "properties":
resolved[k] = {pk: resolve_schema_refs(pv, definitions) for pk, pv in v.items()}
elif k in ("allOf", "anyOf", "oneOf"):
resolved[k] = [resolve_schema_refs(i, definitions) for i in v]
else:
resolved[k] = v
return resolved
def build_responses( def build_responses(
response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any] response_model: Optional[Type[BaseModel]],
definitions: Dict[str, Any],
response_content_types: Optional[List[str]] = None,
links: Optional[Dict[str, Any]] = None,
method: str = "post"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Build OpenAPI responses object.""" """Build OpenAPI responses object."""
responses = {} responses = {}
# Success response (200) # Use a fresh list for response content types to avoid a shared mutable default.
if response_model: if response_content_types is None:
# Strip validation from response schema to save tokens response_content_types = ["application/json"]
schema = strip_validation(pydantic_to_json_schema(response_model, mode="serialization"))
schema = extract_definitions(schema, definitions)
responses["200"] = {
"description": "Successful response",
"content": {
"application/json": {
"schema": schema
}
}
}
else: else:
responses["200"] = { # Copy provided list to ensure each call gets its own list
"description": "Successful response", response_content_types = list(response_content_types)
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"message": {"type": "string"}
}
}
}
}
}
# Standard error responses - MINIMIZED context # Success response (200)
# Annotate that these errors can occur, but provide no schema/content to save tokens. effective_model = response_model or BaseResponse
# The LLM knows what "Bad Request" or "Not Found" means. schema = strip_validation(pydantic_to_json_schema(effective_model, mode="serialization"))
error_codes = { schema = extract_definitions(schema, definitions)
"400": "Bad Request",
"401": "Unauthorized", content = {}
"403": "Forbidden", for ct in response_content_types:
"404": "Not Found", if ct == "application/json":
"422": "Validation Error", content[ct] = {"schema": schema}
"500": "Internal Server Error" else:
# For non-JSON types like CSV, we don't necessarily use the JSON schema
content[ct] = {"schema": {"type": "string", "format": "binary"}}
response_obj = {
"description": "Successful response",
"content": content
}
if links:
response_obj["links"] = links
responses["200"] = response_obj
# Standard error responses
error_configs = {
"400": ("Invalid JSON", "Request body must be valid JSON"),
"401": ("Unauthorized", None),
"403": ("Forbidden", "ERROR: Not authorized"),
"404": ("API route not found", "The requested URL /example/path was not found on the server."),
"422": ("Validation Error", None),
"500": ("Internal Server Error", "Something went wrong on the server")
} }
for code, desc in error_codes.items(): for code, (error_val, message_val) in error_configs.items():
# Generate a fresh schema for each error to customize examples
error_schema_raw = strip_validation(pydantic_to_json_schema(ErrorResponse, mode="serialization"))
error_schema = extract_definitions(error_schema_raw, definitions)
# Inject status-specific example
if "examples" in error_schema and len(error_schema["examples"]) > 0:
example = {
"success": False,
"error": error_val
}
if message_val:
example["message"] = message_val
if code == "422":
example["error"] = "Validation Error: Input should be a valid string"
example["details"] = [
{
"input": "invalid_value",
"loc": ["field_name"],
"msg": "Input should be a valid string",
"type": "string_type",
"url": "https://errors.pydantic.dev/2.12/v/string_type"
}
]
error_schema["examples"] = [example]
responses[code] = { responses[code] = {
"description": desc "description": error_val,
# No "content" schema provided "content": {
"application/json": {
"schema": error_schema
}
}
} }
return responses return responses

View File

@@ -52,6 +52,16 @@ ALLOWED_LOG_FILES = Literal[
"app.php_errors.log", "execution_queue.log", "db_is_locked.log" "app.php_errors.log", "execution_queue.log", "db_is_locked.log"
] ]
ALLOWED_SCAN_TYPES = Literal["ARPSCAN", "NMAPDEV", "NMAP", "INTRNT", "AVAHISCAN", "NBTSCAN"]
ALLOWED_SESSION_CONNECTION_TYPES = Literal["Connected", "Reconnected", "New Device", "Down Reconnected"]
ALLOWED_SESSION_DISCONNECTION_TYPES = Literal["Disconnected", "Device Down", "Timeout"]
ALLOWED_EVENT_TYPES = Literal[
"Device Down", "New Device", "Connected", "Disconnected",
"IP Changed", "Down Reconnected", "<missing event>"
]
def validate_mac(value: str) -> str: def validate_mac(value: str) -> str:
"""Validate and normalize MAC address format.""" """Validate and normalize MAC address format."""
@@ -89,14 +99,42 @@ def validate_column_identifier(value: str) -> str:
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
"""Standard API response wrapper.""" """
model_config = ConfigDict(extra="allow") Standard API response wrapper.
Note: The API often returns 200 OK for most operations; clients MUST parse the 'success'
boolean field to determine if the operation was actually successful.
"""
model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"success": True
}]
}
)
success: bool = Field(..., description="Whether the operation succeeded") success: bool = Field(..., description="Whether the operation succeeded")
message: Optional[str] = Field(None, description="Human-readable message") message: Optional[str] = Field(None, description="Human-readable message")
error: Optional[str] = Field(None, description="Error message if success=False") error: Optional[str] = Field(None, description="Error message if success=False")
class ErrorResponse(BaseResponse):
"""Standard error response model with details."""
model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"success": False,
"error": "Error message"
}]
}
)
success: bool = Field(False, description="Always False for errors")
details: Optional[Any] = Field(None, description="Detailed error information (e.g., validation errors)")
code: Optional[str] = Field(None, description="Internal error code")
class PaginatedResponse(BaseResponse): class PaginatedResponse(BaseResponse):
"""Response with pagination metadata.""" """Response with pagination metadata."""
total: int = Field(0, description="Total number of items") total: int = Field(0, description="Total number of items")
@@ -130,7 +168,19 @@ class DeviceSearchRequest(BaseModel):
class DeviceInfo(BaseModel): class DeviceInfo(BaseModel):
"""Detailed device information model (Raw record).""" """Detailed device information model (Raw record)."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"devMac": "00:11:22:33:44:55",
"devName": "My iPhone",
"devLastIP": "192.168.1.10",
"devVendor": "Apple",
"devStatus": "online",
"devFavorite": 0
}]
}
)
devMac: str = Field(..., description="Device MAC address") devMac: str = Field(..., description="Device MAC address")
devName: Optional[str] = Field(None, description="Device display name/alias") devName: Optional[str] = Field(None, description="Device display name/alias")
@@ -138,13 +188,27 @@ class DeviceInfo(BaseModel):
devPrimaryIPv4: Optional[str] = Field(None, description="Primary IPv4 address") devPrimaryIPv4: Optional[str] = Field(None, description="Primary IPv4 address")
devPrimaryIPv6: Optional[str] = Field(None, description="Primary IPv6 address") devPrimaryIPv6: Optional[str] = Field(None, description="Primary IPv6 address")
devVlan: Optional[str] = Field(None, description="VLAN identifier") devVlan: Optional[str] = Field(None, description="VLAN identifier")
devForceStatus: Optional[str] = Field(None, description="Force device status (online/offline/dont_force)") devForceStatus: Optional[Literal["online", "offline", "dont_force"]] = Field(
"dont_force",
description="Force device status (online/offline/dont_force)"
)
devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup") devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup")
devOwner: Optional[str] = Field(None, description="Device owner") devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type classification") devType: Optional[str] = Field(None, description="Device type classification")
devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)") devFavorite: Optional[int] = Field(
devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)") 0,
devStatus: Optional[str] = Field(None, description="Online/Offline status") description="Favorite flag (0=False, 1=True). Legacy boolean representation.",
json_schema_extra={"enum": [0, 1]}
)
devPresentLastScan: Optional[int] = Field(
None,
description="Present in last scan (0 or 1)",
json_schema_extra={"enum": [0, 1]}
)
devStatus: Optional[Literal["online", "offline"]] = Field(
None,
description="Online/Offline status"
)
devMacSource: Optional[str] = Field(None, description="Source of devMac (USER, LOCKED, or plugin prefix)") devMacSource: Optional[str] = Field(None, description="Source of devMac (USER, LOCKED, or plugin prefix)")
devNameSource: Optional[str] = Field(None, description="Source of devName") devNameSource: Optional[str] = Field(None, description="Source of devName")
devFQDNSource: Optional[str] = Field(None, description="Source of devFQDN") devFQDNSource: Optional[str] = Field(None, description="Source of devFQDN")
@@ -169,7 +233,17 @@ class DeviceListRequest(BaseModel):
"offline" "offline"
]] = Field( ]] = Field(
None, None,
description="Filter devices by status (connected, down, favorites, new, archived, all, my, offline)" description=(
"Filter devices by status:\n"
"- connected: Active devices present in the last scan\n"
"- down: Devices with active 'Device Down' alert\n"
"- favorites: Devices marked as favorite\n"
"- new: Devices flagged as new\n"
"- archived: Devices moved to archive\n"
"- all: All active (non-archived) devices\n"
"- my: All active devices (alias for 'all')\n"
"- offline: Devices not present in the last scan"
)
) )
@@ -270,12 +344,23 @@ class CopyDeviceRequest(BaseModel):
class UpdateDeviceColumnRequest(BaseModel): class UpdateDeviceColumnRequest(BaseModel):
"""Request to update a specific device database column.""" """Request to update a specific device database column."""
columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name") columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name")
columnValue: Any = Field(..., description="New value for the column") columnValue: Union[str, int, bool, None] = Field(
...,
description="New value for the column. Must match the column's expected data type (e.g., string for devName, integer for devFavorite).",
json_schema_extra={
"oneOf": [
{"type": "string"},
{"type": "integer"},
{"type": "boolean"},
{"type": "null"}
]
}
)
class LockDeviceFieldRequest(BaseModel): class LockDeviceFieldRequest(BaseModel):
"""Request to lock/unlock a device field.""" """Request to lock/unlock a device field."""
fieldName: Optional[str] = Field(None, description="Field name to lock/unlock (devMac, devName, devLastIP, etc.)") fieldName: str = Field(..., description="Field name to lock/unlock (e.g., devName, devVendor). Required.")
lock: bool = Field(True, description="True to lock the field, False to unlock") lock: bool = Field(True, description="True to lock the field, False to unlock")
@@ -301,12 +386,18 @@ class DeviceUpdateRequest(BaseModel):
devName: Optional[str] = Field(None, description="Device name") devName: Optional[str] = Field(None, description="Device name")
devOwner: Optional[str] = Field(None, description="Device owner") devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type") devType: Optional[str] = Field(
None,
description="Device type",
json_schema_extra={
"examples": ["Phone", "Laptop", "Desktop", "Router", "IoT", "Camera", "Server", "TV"]
}
)
devVendor: Optional[str] = Field(None, description="Device vendor") devVendor: Optional[str] = Field(None, description="Device vendor")
devGroup: Optional[str] = Field(None, description="Device group") devGroup: Optional[str] = Field(None, description="Device group")
devLocation: Optional[str] = Field(None, description="Device location") devLocation: Optional[str] = Field(None, description="Device location")
devComments: Optional[str] = Field(None, description="Comments") devComments: Optional[str] = Field(None, description="Comments")
createNew: bool = Field(False, description="Create new device if not exists") createNew: bool = Field(False, description="If True, creates a new device. Recommended to provide at least devName and devVendor. If False, updates existing device.")
@field_validator("devName", "devOwner", "devType", "devVendor", "devGroup", "devLocation", "devComments") @field_validator("devName", "devOwner", "devType", "devVendor", "devGroup", "devLocation", "devComments")
@classmethod @classmethod
@@ -340,10 +431,9 @@ class DeleteDevicesRequest(BaseModel):
class TriggerScanRequest(BaseModel): class TriggerScanRequest(BaseModel):
"""Request to trigger a network scan.""" """Request to trigger a network scan."""
type: str = Field( type: ALLOWED_SCAN_TYPES = Field(
"ARPSCAN", "ARPSCAN",
description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)", description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)"
json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]}
) )
@@ -381,8 +471,9 @@ class OpenPortsResponse(BaseResponse):
class WakeOnLanRequest(BaseModel): class WakeOnLanRequest(BaseModel):
"""Request to send Wake-on-LAN packet.""" """Request to send Wake-on-LAN packet."""
devMac: Optional[str] = Field( mac: Optional[str] = Field(
None, None,
alias="devMac",
description="Target device MAC address", description="Target device MAC address",
json_schema_extra={"examples": ["00:11:22:33:44:55"]} json_schema_extra={"examples": ["00:11:22:33:44:55"]}
) )
@@ -396,7 +487,7 @@ class WakeOnLanRequest(BaseModel):
# But Pydantic V2 with populate_by_name=True allows both "devLastIP" and "ip". # But Pydantic V2 with populate_by_name=True allows both "devLastIP" and "ip".
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@field_validator("devMac") @field_validator("mac")
@classmethod @classmethod
def validate_mac_if_provided(cls, v: Optional[str]) -> Optional[str]: def validate_mac_if_provided(cls, v: Optional[str]) -> Optional[str]:
if v is not None: if v is not None:
@@ -412,15 +503,19 @@ class WakeOnLanRequest(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def require_mac_or_ip(self) -> "WakeOnLanRequest": def require_mac_or_ip(self) -> "WakeOnLanRequest":
"""Ensure at least one of devMac or devLastIP is provided.""" """Ensure at least one of mac or devLastIP is provided."""
if self.devMac is None and self.devLastIP is None: if self.mac is None and self.devLastIP is None:
raise ValueError("Either 'devMac' or 'devLastIP' (alias 'ip') must be provided") raise ValueError("Either devMac (aka mac) or devLastIP (aka ip) must be provided")
return self return self
class WakeOnLanResponse(BaseResponse): class WakeOnLanResponse(BaseResponse):
"""Response for Wake-on-LAN operation.""" """Response for Wake-on-LAN operation."""
output: Optional[str] = Field(None, description="Command output") output: Optional[str] = Field(
None,
description="Command output",
json_schema_extra={"examples": ["Sent magic packet to AA:BB:CC:DD:EE:FF"]}
)
class TracerouteRequest(BaseModel): class TracerouteRequest(BaseModel):
@@ -446,7 +541,7 @@ class NmapScanRequest(BaseModel):
"""Request to perform NMAP scan.""" """Request to perform NMAP scan."""
scan: str = Field( scan: str = Field(
..., ...,
description="Target IP address for NMAP scan" description="Target IP address for NMAP scan (Single IP only, no CIDR/ranges/hostnames)."
) )
mode: ALLOWED_NMAP_MODES = Field( mode: ALLOWED_NMAP_MODES = Field(
..., ...,
@@ -507,7 +602,17 @@ class NetworkInterfacesResponse(BaseResponse):
class EventInfo(BaseModel): class EventInfo(BaseModel):
"""Event/alert information.""" """Event/alert information."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"eveMAC": "00:11:22:33:44:55",
"eveIP": "192.168.1.10",
"eveDateTime": "2024-01-29 10:00:00",
"eveEventType": "Device Down"
}]
}
)
eveRowid: Optional[int] = Field(None, description="Event row ID") eveRowid: Optional[int] = Field(None, description="Event row ID")
eveMAC: Optional[str] = Field(None, description="Device MAC address") eveMAC: Optional[str] = Field(None, description="Device MAC address")
@@ -547,9 +652,19 @@ class LastEventsResponse(BaseResponse):
class CreateEventRequest(BaseModel): class CreateEventRequest(BaseModel):
"""Request to create a device event.""" """Request to create a device event."""
ip: Optional[str] = Field("0.0.0.0", description="Device IP") ip: Optional[str] = Field("0.0.0.0", description="Device IP")
event_type: str = Field("Device Down", description="Event type") event_type: str = Field(
"Device Down",
description="Event type",
json_schema_extra={
"examples": ["Device Down", "New Device", "Connected", "Disconnected", "IP Changed", "Down Reconnected", "<missing event>"]
}
)
additional_info: Optional[str] = Field("", description="Additional info") additional_info: Optional[str] = Field("", description="Additional info")
pending_alert: int = Field(1, description="Pending alert flag") pending_alert: int = Field(
1,
description="Pending alert flag (0 or 1)",
json_schema_extra={"enum": [0, 1]}
)
event_time: Optional[str] = Field(None, description="Event timestamp (ISO)") event_time: Optional[str] = Field(None, description="Event timestamp (ISO)")
@field_validator("ip", mode="before") @field_validator("ip", mode="before")
@@ -564,11 +679,19 @@ class CreateEventRequest(BaseModel):
# ============================================================================= # =============================================================================
# SESSIONS SCHEMAS # SESSIONS SCHEMAS
# ============================================================================= # =============================================================================
class SessionInfo(BaseModel): class SessionInfo(BaseModel):
"""Session information.""" """Session information."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"sesMac": "00:11:22:33:44:55",
"sesDateTimeConnection": "2024-01-29 08:00:00",
"sesDateTimeDisconnection": "2024-01-29 09:00:00",
"sesIPAddress": "192.168.1.10"
}]
}
)
sesRowid: Optional[int] = Field(None, description="Session row ID") sesRowid: Optional[int] = Field(None, description="Session row ID")
sesMac: Optional[str] = Field(None, description="Device MAC address") sesMac: Optional[str] = Field(None, description="Device MAC address")
@@ -583,8 +706,20 @@ class CreateSessionRequest(BaseModel):
ip: str = Field(..., description="Device IP") ip: str = Field(..., description="Device IP")
start_time: str = Field(..., description="Start time") start_time: str = Field(..., description="Start time")
end_time: Optional[str] = Field(None, description="End time") end_time: Optional[str] = Field(None, description="End time")
event_type_conn: str = Field("Connected", description="Connection event type") event_type_conn: str = Field(
event_type_disc: str = Field("Disconnected", description="Disconnection event type") "Connected",
description="Connection event type",
json_schema_extra={
"examples": ["Connected", "Reconnected", "New Device", "Down Reconnected"]
}
)
event_type_disc: str = Field(
"Disconnected",
description="Disconnection event type",
json_schema_extra={
"examples": ["Disconnected", "Device Down", "Timeout"]
}
)
@field_validator("mac") @field_validator("mac")
@classmethod @classmethod
@@ -620,7 +755,11 @@ class InAppNotification(BaseModel):
guid: Optional[str] = Field(None, description="Unique notification GUID") guid: Optional[str] = Field(None, description="Unique notification GUID")
text: str = Field(..., description="Notification text content") text: str = Field(..., description="Notification text content")
level: NOTIFICATION_LEVELS = Field("info", description="Notification level") level: NOTIFICATION_LEVELS = Field("info", description="Notification level")
read: Optional[int] = Field(0, description="Read status (0 or 1)") read: Optional[int] = Field(
0,
description="Read status (0 or 1)",
json_schema_extra={"enum": [0, 1]}
)
created_at: Optional[str] = Field(None, description="Creation timestamp") created_at: Optional[str] = Field(None, description="Creation timestamp")
@@ -665,10 +804,12 @@ class DbQueryRequest(BaseModel):
""" """
Request for raw database query. Request for raw database query.
WARNING: This is a highly privileged operation. WARNING: This is a highly privileged operation.
Can be used to read settings by querying the 'Settings' table.
""" """
rawSql: str = Field( rawSql: str = Field(
..., ...,
description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)" description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)",
json_schema_extra={"examples": ["U0VMRUNUICogRlJPTSBTZXR0aW5ncw=="]}
) )
# Legacy compatibility: removed strict safety check # Legacy compatibility: removed strict safety check
# TODO: SECURITY CRITICAL - Re-enable strict safety checks. # TODO: SECURITY CRITICAL - Re-enable strict safety checks.
@@ -690,9 +831,23 @@ class DbQueryRequest(BaseModel):
class DbQueryUpdateRequest(BaseModel): class DbQueryUpdateRequest(BaseModel):
"""Request for DB update query.""" """
Request for DB update query.
Can be used to update settings by targeting the 'Settings' table.
"""
columnName: str = Field(..., description="Column to filter by") columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to update") id: List[Union[str, int]] = Field(
...,
description="List of IDs to update. Use MAC address strings for 'Devices' table, and integer RowIDs for all other tables.",
json_schema_extra={
"items": {
"oneOf": [
{"type": "string", "description": "A string identifier (e.g., MAC address)"},
{"type": "integer", "description": "A numeric row ID"}
]
}
}
)
dbtable: ALLOWED_TABLES = Field(..., description="Table name") dbtable: ALLOWED_TABLES = Field(..., description="Table name")
columns: List[str] = Field(..., description="Columns to update") columns: List[str] = Field(..., description="Columns to update")
values: List[Any] = Field(..., description="New values") values: List[Any] = Field(..., description="New values")
@@ -715,9 +870,23 @@ class DbQueryUpdateRequest(BaseModel):
class DbQueryDeleteRequest(BaseModel): class DbQueryDeleteRequest(BaseModel):
"""Request for DB delete query.""" """
Request for DB delete query.
Can be used to delete settings by targeting the 'Settings' table.
"""
columnName: str = Field(..., description="Column to filter by") columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to delete") id: List[Union[str, int]] = Field(
...,
description="List of IDs to delete. Use MAC address strings for 'Devices' table, and integer RowIDs for all other tables.",
json_schema_extra={
"items": {
"oneOf": [
{"type": "string", "description": "A string identifier (e.g., MAC address)"},
{"type": "integer", "description": "A numeric row ID"}
]
}
}
)
dbtable: ALLOWED_TABLES = Field(..., description="Table name") dbtable: ALLOWED_TABLES = Field(..., description="Table name")
@field_validator("columnName") @field_validator("columnName")
@@ -772,3 +941,14 @@ class SettingValue(BaseModel):
class GetSettingResponse(BaseResponse): class GetSettingResponse(BaseResponse):
"""Response for getting a setting value.""" """Response for getting a setting value."""
value: Any = Field(None, description="The setting value") value: Any = Field(None, description="The setting value")
# =============================================================================
# GRAPHQL SCHEMAS
# =============================================================================
class GraphQLRequest(BaseModel):
"""Request payload for GraphQL queries."""
query: str = Field(..., description="GraphQL query string", json_schema_extra={"examples": ["{ devices { devMac devName } }"]})
variables: Optional[Dict[str, Any]] = Field(None, description="Variables for the GraphQL query")

View File

@@ -29,7 +29,7 @@ Usage:
""" """
from __future__ import annotations from __future__ import annotations
import os
import threading import threading
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
@@ -52,7 +52,7 @@ _rebuild_lock = threading.Lock()
def generate_openapi_spec( def generate_openapi_spec(
title: str = "NetAlertX API", title: str = "NetAlertX API",
version: str = "2.0.0", version: str = "2.0.0",
description: str = "NetAlertX Network Monitoring API - MCP Compatible", description: str = "NetAlertX Network Monitoring API - Official Documentation - MCP Compatible",
servers: Optional[List[Dict[str, str]]] = None, servers: Optional[List[Dict[str, str]]] = None,
flask_app: Optional[Any] = None flask_app: Optional[Any] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -74,18 +74,58 @@ def generate_openapi_spec(
introspect_graphql_schema(devicesSchema) introspect_graphql_schema(devicesSchema)
introspect_flask_app(flask_app) introspect_flask_app(flask_app)
# Apply default disabled tools from setting `MCP_DISABLED_TOOLS`, env var, or hard-coded defaults
# Format: comma-separated operation IDs, e.g. "dbquery_read,dbquery_write"
try:
disabled_env = None
# Prefer setting from app.conf/settings when available
try:
from helper import get_setting_value
setting_val = get_setting_value("MCP_DISABLED_TOOLS")
if setting_val:
disabled_env = str(setting_val).strip()
except Exception:
# If helper is unavailable, fall back to environment
pass
if not disabled_env:
env_val = os.getenv("MCP_DISABLED_TOOLS")
if env_val:
disabled_env = env_val.strip()
# If still not set, apply safe hard-coded defaults
if not disabled_env:
disabled_env = "dbquery_read,dbquery_write"
if disabled_env:
from .registry import set_tool_disabled
for op in [p.strip() for p in disabled_env.split(",") if p.strip()]:
set_tool_disabled(op, True)
except Exception:
# Never fail spec generation due to disablement application issues
pass
spec = { spec = {
"openapi": "3.1.0", "openapi": "3.1.0",
"info": { "info": {
"title": title, "title": title,
"version": version, "version": version,
"description": description, "description": description,
"termsOfService": "https://github.com/netalertx/NetAlertX/blob/main/LICENSE.txt",
"contact": { "contact": {
"name": "NetAlertX", "name": "Open Source Project - NetAlertX - Github",
"url": "https://github.com/jokob-sk/NetAlertX" "url": "https://github.com/netalertx/NetAlertX"
},
"license": {
"name": "Licensed under GPLv3",
"url": "https://www.gnu.org/licenses/gpl-3.0.html"
} }
}, },
"servers": servers or [{"url": "/", "description": "Local server"}], "externalDocs": {
"description": "NetAlertX Official Documentation",
"url": "https://docs.netalertx.com/"
},
"servers": servers or [{"url": "/", "description": "This NetAlertX instance"}],
"security": [ "security": [
{"BearerAuth": []} {"BearerAuth": []}
], ],
@@ -152,7 +192,11 @@ def generate_openapi_spec(
# Add responses # Add responses
operation["responses"] = build_responses( operation["responses"] = build_responses(
entry.get("response_model"), definitions entry.get("response_model"),
definitions,
response_content_types=entry.get("response_content_types", ["application/json"]),
links=entry.get("links"),
method=method
) )
spec["paths"][path][method] = operation spec["paths"][path][method] = operation

View File

@@ -44,7 +44,11 @@ def validate_request(
query_params: Optional[list[dict]] = None, query_params: Optional[list[dict]] = None,
validation_error_code: int = 422, validation_error_code: int = 422,
auth_callable: Optional[Callable[[], bool]] = None, auth_callable: Optional[Callable[[], bool]] = None,
allow_multipart_payload: bool = False allow_multipart_payload: bool = False,
exclude_from_spec: bool = False,
response_content_types: Optional[list[str]] = None,
links: Optional[dict] = None,
error_responses: Optional[dict] = None
): ):
""" """
Decorator to register a Flask route with the OpenAPI registry and validate incoming requests. Decorator to register a Flask route with the OpenAPI registry and validate incoming requests.
@@ -56,6 +60,10 @@ def validate_request(
- Supports auth_callable to check permissions before validation. - Supports auth_callable to check permissions before validation.
- Returns 422 (default) if validation fails. - Returns 422 (default) if validation fails.
- allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields. - allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields.
- exclude_from_spec: If True, this endpoint will be omitted from the generated OpenAPI specification.
- response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"]).
- links: Dictionary of OpenAPI links to include in the response definition.
- error_responses: Dictionary of custom error examples (e.g. {"404": "Device not found"}).
""" """
def decorator(f: Callable) -> Callable: def decorator(f: Callable) -> Callable:
@@ -73,7 +81,11 @@ def validate_request(
"tags": tags, "tags": tags,
"path_params": path_params, "path_params": path_params,
"query_params": query_params, "query_params": query_params,
"allow_multipart_payload": allow_multipart_payload "allow_multipart_payload": allow_multipart_payload,
"exclude_from_spec": exclude_from_spec,
"response_content_types": response_content_types,
"links": links,
"error_responses": error_responses
} }
@wraps(f) @wraps(f)
@@ -150,6 +162,7 @@ def validate_request(
data = request.args.to_dict() data = request.args.to_dict()
validated_instance = request_model(**data) validated_instance = request_model(**data)
except ValidationError as e: except ValidationError as e:
# Use configured validation error code (default 422)
return _handle_validation_error(e, operation_id, validation_error_code) return _handle_validation_error(e, operation_id, validation_error_code)
except (TypeError, ValueError, KeyError) as e: except (TypeError, ValueError, KeyError) as e:
mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"]) mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"])

View File

@@ -0,0 +1,63 @@
import os
import sys
import pytest
from unittest.mock import patch, MagicMock
# Use cwd as fallback if env var is not set, assuming running from project root
INSTALL_PATH = os.getenv('NETALERTX_APP', os.getcwd())
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from api_server.openapi.spec_generator import generate_openapi_spec
from api_server.api_server_start import app
class TestMCPDisabledTools:
def test_disabled_tools_via_env_var(self):
"""Test that MCP_DISABLED_TOOLS env var disables specific tools."""
# Clean registry first to ensure clean state
from api_server.openapi.registry import clear_registry
clear_registry()
# Mock get_setting_value to return None (simulating no config setting)
# and mock os.getenv to return our target list
with patch("helper.get_setting_value", return_value=None), \
patch.dict(os.environ, {"MCP_DISABLED_TOOLS": "search_devices_api"}):
spec = generate_openapi_spec(flask_app=app)
# Locate the operation
# search_devices_api is usually mapped to /devices/search [POST] or similar
# We search the spec for the operationId
found = False
for path, methods in spec["paths"].items():
for method, op in methods.items():
if op["operationId"] == "search_devices_api":
assert op.get("x-mcp-disabled") is True
found = True
assert found, "search_devices_api operation not found in spec"
def test_disabled_tools_default_fallback(self):
"""Test fallback to defaults when no setting or env var exists."""
from api_server.openapi.registry import clear_registry
clear_registry()
with patch("helper.get_setting_value", return_value=None), \
patch.dict(os.environ, {}, clear=True): # Clear env to ensure no MCP_DISABLED_TOOLS
spec = generate_openapi_spec(flask_app=app)
# Default is "dbquery_read,dbquery_write"
# Check dbquery_read
found_read = False
for path, methods in spec["paths"].items():
for method, op in methods.items():
if op["operationId"] == "dbquery_read":
assert op.get("x-mcp-disabled") is True
found_read = True
assert found_read, "dbquery_read should be disabled by default"

View File

@@ -66,7 +66,7 @@ class TestPydanticSchemas:
"""WakeOnLanRequest should validate MAC format.""" """WakeOnLanRequest should validate MAC format."""
# Valid MAC # Valid MAC
req = WakeOnLanRequest(devMac="00:11:22:33:44:55") req = WakeOnLanRequest(devMac="00:11:22:33:44:55")
assert req.devMac == "00:11:22:33:44:55" assert req.mac == "00:11:22:33:44:55"
# Invalid MAC # Invalid MAC
# with pytest.raises(ValidationError): # with pytest.raises(ValidationError):
@@ -76,7 +76,7 @@ class TestPydanticSchemas:
"""WakeOnLanRequest should accept either MAC or IP.""" """WakeOnLanRequest should accept either MAC or IP."""
req_mac = WakeOnLanRequest(devMac="00:11:22:33:44:55") req_mac = WakeOnLanRequest(devMac="00:11:22:33:44:55")
req_ip = WakeOnLanRequest(devLastIP="192.168.1.50") req_ip = WakeOnLanRequest(devLastIP="192.168.1.50")
assert req_mac.devMac is not None assert req_mac.mac is not None
assert req_ip.devLastIP == "192.168.1.50" assert req_ip.devLastIP == "192.168.1.50"
def test_traceroute_request_ip_validation(self): def test_traceroute_request_ip_validation(self):
@@ -197,7 +197,7 @@ class TestOpenAPISpecGenerator:
f"Path param '{param_name}' not defined: {method.upper()} {path}" f"Path param '{param_name}' not defined: {method.upper()} {path}"
def test_standard_error_responses(self): def test_standard_error_responses(self):
"""Operations should have minimal standard error responses (400, 403, 404, etc) without schema bloat.""" """Operations should have standard error responses (400, 403, 404, etc)."""
spec = generate_openapi_spec() spec = generate_openapi_spec()
expected_minimal_codes = ["400", "401", "403", "404", "500", "422"] expected_minimal_codes = ["400", "401", "403", "404", "500", "422"]
@@ -207,21 +207,28 @@ class TestOpenAPISpecGenerator:
continue continue
responses = details.get("responses", {}) responses = details.get("responses", {})
for code in expected_minimal_codes: for code in expected_minimal_codes:
assert code in responses, f"Missing minimal {code} response in: {method.upper()} {path}." assert code in responses, f"Missing {code} response in: {method.upper()} {path}."
# Verify no "content" or schema is present (minimalism) # Content should now be present (BaseResponse/Error schema)
assert "content" not in responses[code], f"Response {code} in {method.upper()} {path} should not have content/schema." assert "content" in responses[code], f"Response {code} in {method.upper()} {path} should have content/schema."
class TestMCPToolMapping: class TestMCPToolMapping:
"""Test MCP tool generation from OpenAPI spec.""" """Test MCP tool generation from OpenAPI spec."""
def test_tools_match_registry_count(self): def test_tools_match_registry_count(self):
"""Number of MCP tools should match registered endpoints.""" """Number of MCP tools should match unique original operation IDs in registry."""
spec = generate_openapi_spec() spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec) tools = map_openapi_to_mcp_tools(spec)
registry = get_registry() registry = get_registry()
assert len(tools) == len(registry) # Count unique operation IDs (accounting for our deduplication logic)
unique_ops = set()
for entry in registry:
# We used x-original-operationId for deduplication logic, or operation_id if not present
op_id = entry.get("original_operation_id") or entry["operation_id"]
unique_ops.add(op_id)
assert len(tools) == len(unique_ops)
def test_tools_have_input_schema(self): def test_tools_have_input_schema(self):
"""All MCP tools should have inputSchema.""" """All MCP tools should have inputSchema."""
@@ -239,9 +246,9 @@ class TestMCPToolMapping:
spec = generate_openapi_spec() spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec) tools = map_openapi_to_mcp_tools(spec)
search_tool = next((t for t in tools if t["name"] == "search_devices"), None) search_tool = next((t for t in tools if t["name"] == "search_devices_api"), None)
assert search_tool is not None assert search_tool is not None
assert "query" in search_tool["inputSchema"].get("required", []) assert "query" in search_tool["inputSchema"]["required"]
def test_tool_descriptions_present(self): def test_tool_descriptions_present(self):
"""All tools should have non-empty descriptions.""" """All tools should have non-empty descriptions."""

View File

@@ -99,8 +99,9 @@ class TestDeviceFieldLock:
json=payload, json=payload,
headers=auth_headers headers=auth_headers
) )
assert resp.status_code == 400 assert resp.status_code == 422
assert "fieldName is required" in resp.json.get("error", "") # Pydantic error message format for missing fields
assert "Missing required 'fieldName'" in resp.json.get("error", "")
def test_lock_field_invalid_field_name(self, client, test_mac, auth_headers): def test_lock_field_invalid_field_name(self, client, test_mac, auth_headers):
"""Lock endpoint rejects untracked fields.""" """Lock endpoint rejects untracked fields."""