Improve OpenAPI specs

This commit is contained in:
Adam Outler
2026-01-29 23:06:05 +00:00
parent f54ba4817e
commit ed4e0388cc
10 changed files with 533 additions and 133 deletions

View File

@@ -100,6 +100,18 @@ from .sse_endpoint import ( # noqa: E402 [flake8 lint suppression]
app = Flask(__name__) app = Flask(__name__)
@app.errorhandler(500)
@app.errorhandler(Exception)
def handle_500_error(e):
"""Global error handler for uncaught exceptions."""
mylog("none", [f"[API] Uncaught exception: {e}"])
return jsonify({
"success": False,
"error": "Internal Server Error",
"message": "Something went wrong on the server"
}), 500
# Parse CORS origins from environment or use safe defaults # Parse CORS origins from environment or use safe defaults
_cors_origins_env = os.environ.get("CORS_ORIGINS", "") _cors_origins_env = os.environ.get("CORS_ORIGINS", "")
_cors_origins = [ _cors_origins = [
@@ -599,7 +611,7 @@ def api_device_open_ports(payload=None):
@validate_request( @validate_request(
operation_id="get_all_devices", operation_id="get_all_devices",
summary="Get All Devices", summary="Get All Devices",
description="Retrieve a list of all devices in the system.", description="Retrieve a list of all devices in the system. Returns all records. No pagination supported.",
response_model=DeviceListWrapperResponse, response_model=DeviceListWrapperResponse,
tags=["devices"], tags=["devices"],
auth_callable=is_authorized auth_callable=is_authorized
@@ -662,7 +674,7 @@ def api_delete_unknown_devices(payload=None):
@app.route("/devices/export", methods=["GET"]) @app.route("/devices/export", methods=["GET"])
@app.route("/devices/export/<format>", methods=["GET"]) @app.route("/devices/export/<format>", methods=["GET"])
@validate_request( @validate_request(
operation_id="export_devices", operation_id="export_devices_all",
summary="Export Devices", summary="Export Devices",
description="Export all devices in CSV or JSON format.", description="Export all devices in CSV or JSON format.",
query_params=[{ query_params=[{
@@ -679,7 +691,8 @@ def api_delete_unknown_devices(payload=None):
}], }],
response_model=DeviceExportResponse, response_model=DeviceExportResponse,
tags=["devices"], tags=["devices"],
auth_callable=is_authorized auth_callable=is_authorized,
response_content_types=["application/json", "text/csv"]
) )
def api_export_devices(format=None, payload=None): def api_export_devices(format=None, payload=None):
export_format = (format or request.args.get("format", "csv")).lower() export_format = (format or request.args.get("format", "csv")).lower()
@@ -747,7 +760,7 @@ def api_devices_totals(payload=None):
@app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST']) @app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST'])
@app.route("/devices/by-status", methods=["GET", "POST"]) @app.route("/devices/by-status", methods=["GET", "POST"])
@validate_request( @validate_request(
operation_id="list_devices_by_status", operation_id="list_devices_by_status_api",
summary="List Devices by Status", summary="List Devices by Status",
description="List devices filtered by their online/offline status.", description="List devices filtered by their online/offline status.",
request_model=DeviceListRequest, request_model=DeviceListRequest,
@@ -763,7 +776,30 @@ def api_devices_totals(payload=None):
"connected", "down", "favorites", "new", "archived", "all", "my", "connected", "down", "favorites", "new", "archived", "all", "my",
"offline" "offline"
]} ]}
}] }],
links={
"GetOpenPorts": {
"operationId": "get_open_ports",
"parameters": {
"target": "$response.body#/0/devLastIP"
},
"description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the list."
},
"WakeOnLan": {
"operationId": "wake_on_lan",
"parameters": {
"devMac": "$response.body#/0/devMac"
},
"description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the list."
},
"UpdateDevice": {
"operationId": "update_device",
"parameters": {
"mac": "$response.body#/0/devMac"
},
"description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the list."
}
}
) )
def api_devices_by_status(payload: DeviceListRequest = None): def api_devices_by_status(payload: DeviceListRequest = None):
status = payload.status if payload else request.args.get("status") status = payload.status if payload else request.args.get("status")
@@ -774,13 +810,43 @@ def api_devices_by_status(payload: DeviceListRequest = None):
@app.route('/mcp/sse/devices/search', methods=['POST']) @app.route('/mcp/sse/devices/search', methods=['POST'])
@app.route('/devices/search', methods=['POST']) @app.route('/devices/search', methods=['POST'])
@validate_request( @validate_request(
operation_id="search_devices", operation_id="search_devices_api",
summary="Search Devices", summary="Search Devices",
description="Search for devices based on various criteria like name, IP, MAC, or vendor.", description="Search for devices based on various criteria like name, IP, MAC, or vendor.",
request_model=DeviceSearchRequest, request_model=DeviceSearchRequest,
response_model=DeviceSearchResponse, response_model=DeviceSearchResponse,
tags=["devices"], tags=["devices"],
auth_callable=is_authorized auth_callable=is_authorized,
links={
"GetOpenPorts": {
"operationId": "get_open_ports",
"parameters": {
"target": "$response.body#/devices/0/devLastIP"
},
"description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the search results."
},
"WakeOnLan": {
"operationId": "wake_on_lan",
"parameters": {
"devMac": "$response.body#/devices/0/devMac"
},
"description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the search results."
},
"NmapScan": {
"operationId": "run_nmap_scan",
"parameters": {
"scan": "$response.body#/devices/0/devLastIP"
},
"description": "The `scan` parameter for `run_nmap_scan` requires an IP or range. Use the `devLastIP` from the first device in the search results."
},
"UpdateDevice": {
"operationId": "update_device",
"parameters": {
"mac": "$response.body#/devices/0/devMac"
},
"description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the search results."
}
}
) )
def api_devices_search(payload=None): def api_devices_search(payload=None):
"""Device search: accepts 'query' in JSON and maps to device info/search.""" """Device search: accepts 'query' in JSON and maps to device info/search."""
@@ -1011,7 +1077,7 @@ def api_network_interfaces(payload=None):
@app.route('/mcp/sse/nettools/trigger-scan', methods=['POST']) @app.route('/mcp/sse/nettools/trigger-scan', methods=['POST'])
@app.route("/nettools/trigger-scan", methods=["GET"]) @app.route("/nettools/trigger-scan", methods=["GET", "POST"])
@validate_request( @validate_request(
operation_id="trigger_network_scan", operation_id="trigger_network_scan",
summary="Trigger Network Scan", summary="Trigger Network Scan",
@@ -1300,13 +1366,25 @@ def api_create_event(mac, payload=None):
@app.route("/events/<mac>", methods=["DELETE"]) @app.route("/events/<mac>", methods=["DELETE"])
@validate_request( @validate_request(
operation_id="delete_events_by_mac", operation_id="delete_events",
summary="Delete Events by MAC", summary="Delete Events",
description="Delete all events for a specific device MAC address.", description="Delete events by device MAC address or older than a specified number of days.",
path_params=[{ path_params=[{
"name": "mac", "name": "mac",
"description": "Device MAC address", "description": "Device MAC address or number of days",
"schema": {"type": "string"} "schema": {
"oneOf": [
{
"type": "integer",
"description": "Number of days (e.g., 30) to delete events older than this value."
},
{
"type": "string",
"pattern": "^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$",
"description": "Device MAC address to delete all events for a specific device."
}
]
}
}], }],
response_model=BaseResponse, response_model=BaseResponse,
tags=["events"], tags=["events"],
@@ -1315,6 +1393,7 @@ def api_create_event(mac, payload=None):
def api_events_by_mac(mac, payload=None): def api_events_by_mac(mac, payload=None):
"""Delete events for a specific device MAC; string converter keeps this distinct from /events/<int:days>.""" """Delete events for a specific device MAC; string converter keeps this distinct from /events/<int:days>."""
device_handler = DeviceInstance() device_handler = DeviceInstance()
result = device_handler.deleteDeviceEvents(mac) result = device_handler.deleteDeviceEvents(mac)
return jsonify(result) return jsonify(result)
@@ -1338,7 +1417,7 @@ def api_delete_all_events(payload=None):
@validate_request( @validate_request(
operation_id="get_all_events", operation_id="get_all_events",
summary="Get Events", summary="Get Events",
description="Retrieve a list of events, optionally filtered by MAC.", description="Retrieve a list of events, optionally filtered by MAC. Returns all matching records. No pagination supported.",
query_params=[{ query_params=[{
"name": "mac", "name": "mac",
"description": "Filter by Device MAC", "description": "Filter by Device MAC",
@@ -1372,7 +1451,8 @@ def api_get_events(payload=None):
}], }],
response_model=BaseResponse, response_model=BaseResponse,
tags=["events"], tags=["events"],
auth_callable=is_authorized auth_callable=is_authorized,
exclude_from_spec=True
) )
def api_delete_old_events(days: int, payload=None): def api_delete_old_events(days: int, payload=None):
""" """
@@ -1406,7 +1486,7 @@ def api_get_events_totals(payload=None):
@app.route('/mcp/sse/events/recent', methods=['GET', 'POST']) @app.route('/mcp/sse/events/recent', methods=['GET', 'POST'])
@app.route('/events/recent', methods=['GET']) @app.route('/events/recent', methods=['GET', 'POST'])
@validate_request( @validate_request(
operation_id="get_recent_events", operation_id="get_recent_events",
summary="Get Recent Events", summary="Get Recent Events",
@@ -1426,7 +1506,7 @@ def api_events_default_24h(payload=None):
@app.route('/mcp/sse/events/last', methods=['GET', 'POST']) @app.route('/mcp/sse/events/last', methods=['GET', 'POST'])
@app.route('/events/last', methods=['GET']) @app.route('/events/last', methods=['GET', 'POST'])
@validate_request( @validate_request(
operation_id="get_last_events", operation_id="get_last_events",
summary="Get Last Events", summary="Get Last Events",
@@ -1763,7 +1843,7 @@ def sync_endpoint_post(payload=None):
@validate_request( @validate_request(
operation_id="check_auth", operation_id="check_auth",
summary="Check Authentication", summary="Check Authentication",
description="Check if the current API token is valid.", description="Check if the current API token is valid. Note: tokens must be generated externally via the UI or CLI.",
response_model=BaseResponse, response_model=BaseResponse,
tags=["auth"], tags=["auth"],
auth_callable=is_authorized auth_callable=is_authorized

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Any from typing import Any, Dict, Optional
import graphene import graphene
from .registry import register_tool, _operation_ids from .registry import register_tool, _operation_ids
from .schemas import GraphQLRequest
from .schema_converter import pydantic_to_json_schema, resolve_schema_refs
def introspect_graphql_schema(schema: graphene.Schema): def introspect_graphql_schema(schema: graphene.Schema):
@@ -26,6 +28,7 @@ def introspect_graphql_schema(schema: graphene.Schema):
operation_id="graphql_query", operation_id="graphql_query",
summary="GraphQL Endpoint", summary="GraphQL Endpoint",
description="Execute arbitrary GraphQL queries against the system schema.", description="Execute arbitrary GraphQL queries against the system schema.",
request_model=GraphQLRequest,
tags=["graphql"] tags=["graphql"]
) )
@@ -36,6 +39,20 @@ def _flask_to_openapi_path(flask_path: str) -> str:
return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path) return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path)
def _get_openapi_metadata(func: Any) -> Optional[Dict[str, Any]]:
"""Recursively find _openapi_metadata in wrapped functions."""
# Check current function
metadata = getattr(func, "_openapi_metadata", None)
if metadata:
return metadata
# Check __wrapped__ (standard for @wraps)
if hasattr(func, "__wrapped__"):
return _get_openapi_metadata(func.__wrapped__)
return None
def introspect_flask_app(app: Any): def introspect_flask_app(app: Any):
""" """
Introspect the Flask application to find routes decorated with @validate_request Introspect the Flask application to find routes decorated with @validate_request
@@ -47,14 +64,13 @@ def introspect_flask_app(app: Any):
if not view_func: if not view_func:
continue continue
# Check for our decorator's metadata # Check for our decorator's metadata recursively
metadata = getattr(view_func, "_openapi_metadata", None) metadata = _get_openapi_metadata(view_func)
if not metadata:
# Fallback for wrapped functions
if hasattr(view_func, "__wrapped__"):
metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None)
if metadata: if metadata:
if metadata.get("exclude_from_spec"):
continue
op_id = metadata["operation_id"] op_id = metadata["operation_id"]
# Register the tool with real path and method from Flask # Register the tool with real path and method from Flask
@@ -75,11 +91,8 @@ def introspect_flask_app(app: Any):
# Determine tags - create a copy to avoid mutating shared metadata # Determine tags - create a copy to avoid mutating shared metadata
tags = list(metadata.get("tags") or ["rest"]) tags = list(metadata.get("tags") or ["rest"])
if path.startswith("/mcp/"): if path.startswith("/mcp/"):
# Move specific tags to secondary position or just add MCP # For MCP endpoints, we want them exclusively in the 'mcp' tag section
if "rest" in tags: tags = ["mcp"]
tags.remove("rest")
if "mcp" not in tags:
tags.append("mcp")
# Ensure unique operationId # Ensure unique operationId
original_op_id = op_id original_op_id = op_id
@@ -89,6 +102,38 @@ def introspect_flask_app(app: Any):
unique_op_id = f"{op_id}_{count}" unique_op_id = f"{op_id}_{count}"
count += 1 count += 1
# Filter path_params to only include those that are actually in the path
path_params = metadata.get("path_params")
if path_params:
path_params = [
p for p in path_params
if f"{{{p['name']}}}" in path
]
# Auto-generate query_params from request_model for GET requests
query_params = metadata.get("query_params")
if method == 'GET' and not query_params and metadata.get("request_model"):
try:
schema = pydantic_to_json_schema(metadata["request_model"])
properties = schema.get("properties", {})
query_params = []
for name, prop in properties.items():
is_required = name in schema.get("required", [])
# Create param definition, preserving enum/schema
param_def = {
"name": name,
"in": "query",
"required": is_required,
"description": prop.get("description", ""),
"schema": prop
}
# Remove description from schema to avoid duplication
if "description" in param_def["schema"]:
del param_def["schema"]["description"]
query_params.append(param_def)
except Exception:
pass # Fallback to empty if schema generation fails
register_tool( register_tool(
path=path, path=path,
method=method, method=method,
@@ -98,9 +143,11 @@ def introspect_flask_app(app: Any):
description=metadata["description"], description=metadata["description"],
request_model=metadata.get("request_model"), request_model=metadata.get("request_model"),
response_model=metadata.get("response_model"), response_model=metadata.get("response_model"),
path_params=metadata.get("path_params"), path_params=path_params,
query_params=metadata.get("query_params"), query_params=query_params,
tags=tags, tags=tags,
allow_multipart_payload=metadata.get("allow_multipart_payload", False) allow_multipart_payload=metadata.get("allow_multipart_payload", False),
response_content_types=metadata.get("response_content_types"),
links=metadata.get("links")
) )
registered_ops.add(op_key) registered_ops.add(op_key)

View File

@@ -96,7 +96,9 @@ def register_tool(
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
deprecated: bool = False, deprecated: bool = False,
original_operation_id: Optional[str] = None, original_operation_id: Optional[str] = None,
allow_multipart_payload: bool = False allow_multipart_payload: bool = False,
response_content_types: Optional[List[str]] = None,
links: Optional[Dict[str, Any]] = None
) -> None: ) -> None:
""" """
Register an API endpoint for OpenAPI spec generation. Register an API endpoint for OpenAPI spec generation.
@@ -115,6 +117,8 @@ def register_tool(
deprecated: Whether this endpoint is deprecated deprecated: Whether this endpoint is deprecated
original_operation_id: The base ID before suffixing (for disablement mapping) original_operation_id: The base ID before suffixing (for disablement mapping)
allow_multipart_payload: Whether to allow multipart/form-data payloads allow_multipart_payload: Whether to allow multipart/form-data payloads
response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"])
links: Dictionary of OpenAPI links to include in the response definition.
Raises: Raises:
DuplicateOperationIdError: If operation_id already exists in registry DuplicateOperationIdError: If operation_id already exists in registry
@@ -140,7 +144,9 @@ def register_tool(
"query_params": query_params or [], "query_params": query_params or [],
"tags": tags or ["default"], "tags": tags or ["default"],
"deprecated": deprecated, "deprecated": deprecated,
"allow_multipart_payload": allow_multipart_payload "allow_multipart_payload": allow_multipart_payload,
"response_content_types": response_content_types or ["application/json"],
"links": links
}) })

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Dict, Any, Optional, Type, List from typing import Dict, Any, Optional, Type, List
from pydantic import BaseModel from pydantic import BaseModel
from .schemas import ErrorResponse, BaseResponse
def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]: def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]:
@@ -161,57 +162,124 @@ def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]:
return clean_schema return clean_schema
def resolve_schema_refs(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively resolve $ref in schema by inlining the definition.
Useful for standalone schema parts like query parameters where global definitions aren't available.
"""
if not isinstance(schema, dict):
return schema
if "$ref" in schema:
ref = schema["$ref"]
# Handle #/$defs/Name syntax
if ref.startswith("#/$defs/"):
def_name = ref.split("/")[-1]
if def_name in definitions:
# Inline the definition (and resolve its refs recursively)
inlined = resolve_schema_refs(definitions[def_name], definitions)
# Merge any extra keys from the original schema (e.g. description override)
# Schema keys take precedence over definition keys
return {**inlined, **{k: v for k, v in schema.items() if k != "$ref"}}
# Recursively resolve properties
resolved = {}
for k, v in schema.items():
if k == "items":
resolved[k] = resolve_schema_refs(v, definitions)
elif k == "properties":
resolved[k] = {pk: resolve_schema_refs(pv, definitions) for pk, pv in v.items()}
elif k in ("allOf", "anyOf", "oneOf"):
resolved[k] = [resolve_schema_refs(i, definitions) for i in v]
else:
resolved[k] = v
return resolved
def build_responses( def build_responses(
response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any] response_model: Optional[Type[BaseModel]],
definitions: Dict[str, Any],
response_content_types: Optional[List[str]] = None,
links: Optional[Dict[str, Any]] = None,
method: str = "post"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Build OpenAPI responses object.""" """Build OpenAPI responses object."""
responses = {} responses = {}
# Success response (200) # Use a fresh list for response content types to avoid a shared mutable default.
if response_model: if response_content_types is None:
# Strip validation from response schema to save tokens response_content_types = ["application/json"]
schema = strip_validation(pydantic_to_json_schema(response_model, mode="serialization"))
schema = extract_definitions(schema, definitions)
responses["200"] = {
"description": "Successful response",
"content": {
"application/json": {
"schema": schema
}
}
}
else: else:
responses["200"] = { # Copy provided list to ensure each call gets its own list
"description": "Successful response", response_content_types = list(response_content_types)
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"message": {"type": "string"}
}
}
}
}
}
# Standard error responses - MINIMIZED context # Success response (200)
# Annotate that these errors can occur, but provide no schema/content to save tokens. effective_model = response_model or BaseResponse
# The LLM knows what "Bad Request" or "Not Found" means. schema = strip_validation(pydantic_to_json_schema(effective_model, mode="serialization"))
error_codes = { schema = extract_definitions(schema, definitions)
"400": "Bad Request",
"401": "Unauthorized", content = {}
"403": "Forbidden", for ct in response_content_types:
"404": "Not Found", if ct == "application/json":
"422": "Validation Error", content[ct] = {"schema": schema}
"500": "Internal Server Error" else:
# For non-JSON types like CSV, we don't necessarily use the JSON schema
content[ct] = {"schema": {"type": "string", "format": "binary"}}
response_obj = {
"description": "Successful response",
"content": content
}
if links:
response_obj["links"] = links
responses["200"] = response_obj
# Standard error responses
error_configs = {
"400": ("Invalid JSON", "Request body must be valid JSON"),
"401": ("Unauthorized", None),
"403": ("Forbidden", "ERROR: Not authorized"),
"404": ("API route not found", "The requested URL /example/path was not found on the server."),
"422": ("Validation Error", None),
"500": ("Internal Server Error", "Something went wrong on the server")
} }
for code, desc in error_codes.items(): for code, (error_val, message_val) in error_configs.items():
# Generate a fresh schema for each error to customize examples
error_schema_raw = strip_validation(pydantic_to_json_schema(ErrorResponse, mode="serialization"))
error_schema = extract_definitions(error_schema_raw, definitions)
# Inject status-specific example
if "examples" in error_schema and len(error_schema["examples"]) > 0:
example = {
"success": False,
"error": error_val
}
if message_val:
example["message"] = message_val
if code == "422":
example["error"] = "Validation Error: Input should be a valid string"
example["details"] = [
{
"input": "invalid_value",
"loc": ["field_name"],
"msg": "Input should be a valid string",
"type": "string_type",
"url": "https://errors.pydantic.dev/2.12/v/string_type"
}
]
error_schema["examples"] = [example]
responses[code] = { responses[code] = {
"description": desc "description": error_val,
# No "content" schema provided "content": {
"application/json": {
"schema": error_schema
}
}
} }
return responses return responses

View File

@@ -52,6 +52,16 @@ ALLOWED_LOG_FILES = Literal[
"app.php_errors.log", "execution_queue.log", "db_is_locked.log" "app.php_errors.log", "execution_queue.log", "db_is_locked.log"
] ]
ALLOWED_SCAN_TYPES = Literal["ARPSCAN", "NMAPDEV", "NMAP", "INTRNT", "AVAHISCAN", "NBTSCAN"]
ALLOWED_SESSION_CONNECTION_TYPES = Literal["Connected", "Reconnected", "New Device", "Down Reconnected"]
ALLOWED_SESSION_DISCONNECTION_TYPES = Literal["Disconnected", "Device Down", "Timeout"]
ALLOWED_EVENT_TYPES = Literal[
"Device Down", "Device Up", "New Device", "Connected", "Disconnected",
"IP Changed", "Down Reconnected"
]
def validate_mac(value: str) -> str: def validate_mac(value: str) -> str:
"""Validate and normalize MAC address format.""" """Validate and normalize MAC address format."""
@@ -89,14 +99,42 @@ def validate_column_identifier(value: str) -> str:
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
"""Standard API response wrapper.""" """
model_config = ConfigDict(extra="allow") Standard API response wrapper.
Note: The API often returns 200 OK for most operations; clients MUST parse the 'success'
boolean field to determine if the operation was actually successful.
"""
model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"success": True
}]
}
)
success: bool = Field(..., description="Whether the operation succeeded") success: bool = Field(..., description="Whether the operation succeeded")
message: Optional[str] = Field(None, description="Human-readable message") message: Optional[str] = Field(None, description="Human-readable message")
error: Optional[str] = Field(None, description="Error message if success=False") error: Optional[str] = Field(None, description="Error message if success=False")
class ErrorResponse(BaseResponse):
"""Standard error response model with details."""
model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"success": False,
"error": "Error message"
}]
}
)
success: bool = Field(False, description="Always False for errors")
details: Optional[Any] = Field(None, description="Detailed error information (e.g., validation errors)")
code: Optional[str] = Field(None, description="Internal error code")
class PaginatedResponse(BaseResponse): class PaginatedResponse(BaseResponse):
"""Response with pagination metadata.""" """Response with pagination metadata."""
total: int = Field(0, description="Total number of items") total: int = Field(0, description="Total number of items")
@@ -130,7 +168,19 @@ class DeviceSearchRequest(BaseModel):
class DeviceInfo(BaseModel): class DeviceInfo(BaseModel):
"""Detailed device information model (Raw record).""" """Detailed device information model (Raw record)."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"devMac": "00:11:22:33:44:55",
"devName": "My iPhone",
"devLastIP": "192.168.1.10",
"devVendor": "Apple",
"devStatus": "online",
"devFavorite": 0
}]
}
)
devMac: str = Field(..., description="Device MAC address") devMac: str = Field(..., description="Device MAC address")
devName: Optional[str] = Field(None, description="Device display name/alias") devName: Optional[str] = Field(None, description="Device display name/alias")
@@ -138,13 +188,27 @@ class DeviceInfo(BaseModel):
devPrimaryIPv4: Optional[str] = Field(None, description="Primary IPv4 address") devPrimaryIPv4: Optional[str] = Field(None, description="Primary IPv4 address")
devPrimaryIPv6: Optional[str] = Field(None, description="Primary IPv6 address") devPrimaryIPv6: Optional[str] = Field(None, description="Primary IPv6 address")
devVlan: Optional[str] = Field(None, description="VLAN identifier") devVlan: Optional[str] = Field(None, description="VLAN identifier")
devForceStatus: Optional[str] = Field(None, description="Force device status (online/offline/dont_force)") devForceStatus: Optional[Literal["online", "offline", "dont_force"]] = Field(
"dont_force",
description="Force device status (online/offline/dont_force)"
)
devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup") devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup")
devOwner: Optional[str] = Field(None, description="Device owner") devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type classification") devType: Optional[str] = Field(None, description="Device type classification")
devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)") devFavorite: Optional[int] = Field(
devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)") 0,
devStatus: Optional[str] = Field(None, description="Online/Offline status") description="Favorite flag (0=False, 1=True). Legacy boolean representation.",
json_schema_extra={"enum": [0, 1]}
)
devPresentLastScan: Optional[int] = Field(
None,
description="Present in last scan (0 or 1)",
json_schema_extra={"enum": [0, 1]}
)
devStatus: Optional[Literal["online", "offline"]] = Field(
None,
description="Online/Offline status"
)
devMacSource: Optional[str] = Field(None, description="Source of devMac (USER, LOCKED, or plugin prefix)") devMacSource: Optional[str] = Field(None, description="Source of devMac (USER, LOCKED, or plugin prefix)")
devNameSource: Optional[str] = Field(None, description="Source of devName") devNameSource: Optional[str] = Field(None, description="Source of devName")
devFQDNSource: Optional[str] = Field(None, description="Source of devFQDN") devFQDNSource: Optional[str] = Field(None, description="Source of devFQDN")
@@ -270,7 +334,18 @@ class CopyDeviceRequest(BaseModel):
class UpdateDeviceColumnRequest(BaseModel): class UpdateDeviceColumnRequest(BaseModel):
"""Request to update a specific device database column.""" """Request to update a specific device database column."""
columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name") columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name")
columnValue: Any = Field(..., description="New value for the column") columnValue: Union[str, int, bool, None] = Field(
...,
description="New value for the column",
json_schema_extra={
"oneOf": [
{"type": "string"},
{"type": "integer"},
{"type": "boolean"},
{"type": "null"}
]
}
)
class LockDeviceFieldRequest(BaseModel): class LockDeviceFieldRequest(BaseModel):
@@ -301,7 +376,13 @@ class DeviceUpdateRequest(BaseModel):
devName: Optional[str] = Field(None, description="Device name") devName: Optional[str] = Field(None, description="Device name")
devOwner: Optional[str] = Field(None, description="Device owner") devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type") devType: Optional[str] = Field(
None,
description="Device type",
json_schema_extra={
"examples": ["Phone", "Laptop", "Desktop", "Router", "IoT", "Camera", "Server", "TV"]
}
)
devVendor: Optional[str] = Field(None, description="Device vendor") devVendor: Optional[str] = Field(None, description="Device vendor")
devGroup: Optional[str] = Field(None, description="Device group") devGroup: Optional[str] = Field(None, description="Device group")
devLocation: Optional[str] = Field(None, description="Device location") devLocation: Optional[str] = Field(None, description="Device location")
@@ -340,10 +421,9 @@ class DeleteDevicesRequest(BaseModel):
class TriggerScanRequest(BaseModel): class TriggerScanRequest(BaseModel):
"""Request to trigger a network scan.""" """Request to trigger a network scan."""
type: str = Field( type: ALLOWED_SCAN_TYPES = Field(
"ARPSCAN", "ARPSCAN",
description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)", description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)"
json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]}
) )
@@ -420,7 +500,11 @@ class WakeOnLanRequest(BaseModel):
class WakeOnLanResponse(BaseResponse): class WakeOnLanResponse(BaseResponse):
"""Response for Wake-on-LAN operation.""" """Response for Wake-on-LAN operation."""
output: Optional[str] = Field(None, description="Command output") output: Optional[str] = Field(
None,
description="Command output",
json_schema_extra={"examples": ["Sent magic packet to AA:BB:CC:DD:EE:FF"]}
)
class TracerouteRequest(BaseModel): class TracerouteRequest(BaseModel):
@@ -507,7 +591,17 @@ class NetworkInterfacesResponse(BaseResponse):
class EventInfo(BaseModel): class EventInfo(BaseModel):
"""Event/alert information.""" """Event/alert information."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"eveMAC": "00:11:22:33:44:55",
"eveIP": "192.168.1.10",
"eveDateTime": "2024-01-29 10:00:00",
"eveEventType": "Device Down"
}]
}
)
eveRowid: Optional[int] = Field(None, description="Event row ID") eveRowid: Optional[int] = Field(None, description="Event row ID")
eveMAC: Optional[str] = Field(None, description="Device MAC address") eveMAC: Optional[str] = Field(None, description="Device MAC address")
@@ -547,9 +641,19 @@ class LastEventsResponse(BaseResponse):
class CreateEventRequest(BaseModel): class CreateEventRequest(BaseModel):
"""Request to create a device event.""" """Request to create a device event."""
ip: Optional[str] = Field("0.0.0.0", description="Device IP") ip: Optional[str] = Field("0.0.0.0", description="Device IP")
event_type: str = Field("Device Down", description="Event type") event_type: str = Field(
"Device Down",
description="Event type",
json_schema_extra={
"examples": ["Device Down", "Device Up", "New Device", "Connected", "Disconnected"]
}
)
additional_info: Optional[str] = Field("", description="Additional info") additional_info: Optional[str] = Field("", description="Additional info")
pending_alert: int = Field(1, description="Pending alert flag") pending_alert: int = Field(
1,
description="Pending alert flag (0 or 1)",
json_schema_extra={"enum": [0, 1]}
)
event_time: Optional[str] = Field(None, description="Event timestamp (ISO)") event_time: Optional[str] = Field(None, description="Event timestamp (ISO)")
@field_validator("ip", mode="before") @field_validator("ip", mode="before")
@@ -564,11 +668,18 @@ class CreateEventRequest(BaseModel):
# ============================================================================= # =============================================================================
# SESSIONS SCHEMAS # SESSIONS SCHEMAS
# ============================================================================= # =============================================================================
class SessionInfo(BaseModel):
"""Session information.""" """Session information."""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(
extra="allow",
json_schema_extra={
"examples": [{
"sesMac": "00:11:22:33:44:55",
"sesDateTimeConnection": "2024-01-29 08:00:00",
"sesDateTimeDisconnection": "2024-01-29 09:00:00",
"sesIPAddress": "192.168.1.10"
}]
}
)
sesRowid: Optional[int] = Field(None, description="Session row ID") sesRowid: Optional[int] = Field(None, description="Session row ID")
sesMac: Optional[str] = Field(None, description="Device MAC address") sesMac: Optional[str] = Field(None, description="Device MAC address")
@@ -583,8 +694,20 @@ class CreateSessionRequest(BaseModel):
ip: str = Field(..., description="Device IP") ip: str = Field(..., description="Device IP")
start_time: str = Field(..., description="Start time") start_time: str = Field(..., description="Start time")
end_time: Optional[str] = Field(None, description="End time") end_time: Optional[str] = Field(None, description="End time")
event_type_conn: str = Field("Connected", description="Connection event type") event_type_conn: str = Field(
event_type_disc: str = Field("Disconnected", description="Disconnection event type") "Connected",
description="Connection event type",
json_schema_extra={
"examples": ["Connected", "Reconnected", "New Device", "Down Reconnected"]
}
)
event_type_disc: str = Field(
"Disconnected",
description="Disconnection event type",
json_schema_extra={
"examples": ["Disconnected", "Device Down", "Timeout"]
}
)
@field_validator("mac") @field_validator("mac")
@classmethod @classmethod
@@ -620,7 +743,11 @@ class InAppNotification(BaseModel):
guid: Optional[str] = Field(None, description="Unique notification GUID") guid: Optional[str] = Field(None, description="Unique notification GUID")
text: str = Field(..., description="Notification text content") text: str = Field(..., description="Notification text content")
level: NOTIFICATION_LEVELS = Field("info", description="Notification level") level: NOTIFICATION_LEVELS = Field("info", description="Notification level")
read: Optional[int] = Field(0, description="Read status (0 or 1)") read: Optional[int] = Field(
0,
description="Read status (0 or 1)",
json_schema_extra={"enum": [0, 1]}
)
created_at: Optional[str] = Field(None, description="Creation timestamp") created_at: Optional[str] = Field(None, description="Creation timestamp")
@@ -665,10 +792,12 @@ class DbQueryRequest(BaseModel):
""" """
Request for raw database query. Request for raw database query.
WARNING: This is a highly privileged operation. WARNING: This is a highly privileged operation.
Can be used to read settings by querying the 'Settings' table.
""" """
rawSql: str = Field( rawSql: str = Field(
..., ...,
description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)" description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)",
json_schema_extra={"examples": ["U0VMRUNUICogRlJPTSBTZXR0aW5ncw=="]}
) )
# Legacy compatibility: removed strict safety check # Legacy compatibility: removed strict safety check
# TODO: SECURITY CRITICAL - Re-enable strict safety checks. # TODO: SECURITY CRITICAL - Re-enable strict safety checks.
@@ -690,9 +819,23 @@ class DbQueryRequest(BaseModel):
class DbQueryUpdateRequest(BaseModel): class DbQueryUpdateRequest(BaseModel):
"""Request for DB update query.""" """
Request for DB update query.
Can be used to update settings by targeting the 'Settings' table.
"""
columnName: str = Field(..., description="Column to filter by") columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to update") id: List[Union[str, int]] = Field(
...,
description="List of IDs to update (strings for MACs, integers for row IDs)",
json_schema_extra={
"items": {
"oneOf": [
{"type": "string", "description": "A string identifier (e.g., MAC address)"},
{"type": "integer", "description": "A numeric row ID"}
]
}
}
)
dbtable: ALLOWED_TABLES = Field(..., description="Table name") dbtable: ALLOWED_TABLES = Field(..., description="Table name")
columns: List[str] = Field(..., description="Columns to update") columns: List[str] = Field(..., description="Columns to update")
values: List[Any] = Field(..., description="New values") values: List[Any] = Field(..., description="New values")
@@ -715,9 +858,23 @@ class DbQueryUpdateRequest(BaseModel):
class DbQueryDeleteRequest(BaseModel): class DbQueryDeleteRequest(BaseModel):
"""Request for DB delete query.""" """
Request for DB delete query.
Can be used to delete settings by targeting the 'Settings' table.
"""
columnName: str = Field(..., description="Column to filter by") columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to delete") id: List[Union[str, int]] = Field(
...,
description="List of IDs to delete (strings for MACs, integers for row IDs)",
json_schema_extra={
"items": {
"oneOf": [
{"type": "string", "description": "A string identifier (e.g., MAC address)"},
{"type": "integer", "description": "A numeric row ID"}
]
}
}
)
dbtable: ALLOWED_TABLES = Field(..., description="Table name") dbtable: ALLOWED_TABLES = Field(..., description="Table name")
@field_validator("columnName") @field_validator("columnName")
@@ -772,3 +929,14 @@ class SettingValue(BaseModel):
class GetSettingResponse(BaseResponse): class GetSettingResponse(BaseResponse):
"""Response for getting a setting value.""" """Response for getting a setting value."""
value: Any = Field(None, description="The setting value") value: Any = Field(None, description="The setting value")
# =============================================================================
# GRAPHQL SCHEMAS
# =============================================================================
class GraphQLRequest(BaseModel):
"""Request payload for GraphQL queries."""
query: str = Field(..., description="GraphQL query string", json_schema_extra={"examples": ["{ devices { devMac devName } }"]})
variables: Optional[Dict[str, Any]] = Field(None, description="Variables for the GraphQL query")

View File

@@ -52,7 +52,7 @@ _rebuild_lock = threading.Lock()
def generate_openapi_spec( def generate_openapi_spec(
title: str = "NetAlertX API", title: str = "NetAlertX API",
version: str = "2.0.0", version: str = "2.0.0",
description: str = "NetAlertX Network Monitoring API - MCP Compatible", description: str = "NetAlertX Network Monitoring API - Official Documentation - MCP Compatible",
servers: Optional[List[Dict[str, str]]] = None, servers: Optional[List[Dict[str, str]]] = None,
flask_app: Optional[Any] = None flask_app: Optional[Any] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -80,12 +80,21 @@ def generate_openapi_spec(
"title": title, "title": title,
"version": version, "version": version,
"description": description, "description": description,
"termsOfService": "https://github.com/netalertx/NetAlertX/blob/main/LICENSE.txt",
"contact": { "contact": {
"name": "NetAlertX", "name": "Open Source Project - NetAlertX - Github",
"url": "https://github.com/jokob-sk/NetAlertX" "url": "https://github.com/netalertx/NetAlertX"
},
"license": {
"name": "Licensed under GPLv3",
"url": "https://www.gnu.org/licenses/gpl-3.0.html"
} }
}, },
"servers": servers or [{"url": "/", "description": "Local server"}], "externalDocs": {
"description": "NetAlertX Official Documentation",
"url": "https://docs.netalertx.com/"
},
"servers": servers or [{"url": "/", "description": "This NetAlertX instance"}],
"security": [ "security": [
{"BearerAuth": []} {"BearerAuth": []}
], ],
@@ -152,7 +161,11 @@ def generate_openapi_spec(
# Add responses # Add responses
operation["responses"] = build_responses( operation["responses"] = build_responses(
entry.get("response_model"), definitions entry.get("response_model"),
definitions,
response_content_types=entry.get("response_content_types", ["application/json"]),
links=entry.get("links"),
method=method
) )
spec["paths"][path][method] = operation spec["paths"][path][method] = operation

View File

@@ -44,7 +44,11 @@ def validate_request(
query_params: Optional[list[dict]] = None, query_params: Optional[list[dict]] = None,
validation_error_code: int = 422, validation_error_code: int = 422,
auth_callable: Optional[Callable[[], bool]] = None, auth_callable: Optional[Callable[[], bool]] = None,
allow_multipart_payload: bool = False allow_multipart_payload: bool = False,
exclude_from_spec: bool = False,
response_content_types: Optional[list[str]] = None,
links: Optional[dict] = None,
error_responses: Optional[dict] = None
): ):
""" """
Decorator to register a Flask route with the OpenAPI registry and validate incoming requests. Decorator to register a Flask route with the OpenAPI registry and validate incoming requests.
@@ -56,6 +60,10 @@ def validate_request(
- Supports auth_callable to check permissions before validation. - Supports auth_callable to check permissions before validation.
- Returns 422 (default) if validation fails. - Returns 422 (default) if validation fails.
- allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields. - allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields.
- exclude_from_spec: If True, this endpoint will be omitted from the generated OpenAPI specification.
- response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"]).
- links: Dictionary of OpenAPI links to include in the response definition.
- error_responses: Dictionary of custom error examples (e.g. {"404": "Device not found"}).
""" """
def decorator(f: Callable) -> Callable: def decorator(f: Callable) -> Callable:
@@ -73,7 +81,11 @@ def validate_request(
"tags": tags, "tags": tags,
"path_params": path_params, "path_params": path_params,
"query_params": query_params, "query_params": query_params,
"allow_multipart_payload": allow_multipart_payload "allow_multipart_payload": allow_multipart_payload,
"exclude_from_spec": exclude_from_spec,
"response_content_types": response_content_types,
"links": links,
"error_responses": error_responses
} }
@wraps(f) @wraps(f)
@@ -150,6 +162,7 @@ def validate_request(
data = request.args.to_dict() data = request.args.to_dict()
validated_instance = request_model(**data) validated_instance = request_model(**data)
except ValidationError as e: except ValidationError as e:
# Use configured validation error code (default 422)
return _handle_validation_error(e, operation_id, validation_error_code) return _handle_validation_error(e, operation_id, validation_error_code)
except (TypeError, ValueError, KeyError) as e: except (TypeError, ValueError, KeyError) as e:
mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"]) mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"])

View File

@@ -50,28 +50,33 @@ def write_notification(content, level="alert", timestamp=None):
} }
# If file exists, load existing data, otherwise initialize as empty list # If file exists, load existing data, otherwise initialize as empty list
if os.path.exists(NOTIFICATION_API_FILE): try:
with open(NOTIFICATION_API_FILE, "r") as file: if os.path.exists(NOTIFICATION_API_FILE):
# Check if the file object is of type _io.TextIOWrapper with open(NOTIFICATION_API_FILE, "r") as file:
if isinstance(file, _io.TextIOWrapper): file_contents = file.read().strip()
file_contents = file.read() # Read file contents if file_contents:
if file_contents == "": notifications = json.loads(file_contents)
file_contents = "[]" # If file is empty, initialize as empty list if not isinstance(notifications, list):
mylog("error", "[Notification] Invalid format: not a list, resetting")
# mylog('debug', ['[Notification] User Notifications file: ', file_contents]) notifications = []
notifications = json.loads(file_contents) # Parse JSON data else:
else: notifications = []
mylog("none", "[Notification] File is not of type _io.TextIOWrapper") else:
notifications = [] notifications = []
else: except Exception as e:
mylog("error", [f"[Notification] Error reading notifications file: {e}"])
notifications = [] notifications = []
# Append new notification # Append new notification
notifications.append(notification) notifications.append(notification)
# Write updated data back to file # Write updated data back to file
with open(NOTIFICATION_API_FILE, "w") as file: try:
json.dump(notifications, file, indent=4) with open(NOTIFICATION_API_FILE, "w") as file:
json.dump(notifications, file, indent=4)
except Exception as e:
mylog("error", [f"[Notification] Error writing to notifications file: {e}"])
# Don't re-raise, just log. This prevents the API from crashing 500.
# Broadcast unread count update # Broadcast unread count update
try: try:

View File

@@ -197,7 +197,7 @@ class TestOpenAPISpecGenerator:
f"Path param '{param_name}' not defined: {method.upper()} {path}" f"Path param '{param_name}' not defined: {method.upper()} {path}"
def test_standard_error_responses(self): def test_standard_error_responses(self):
"""Operations should have minimal standard error responses (400, 403, 404, etc) without schema bloat.""" """Operations should have standard error responses (400, 403, 404, etc)."""
spec = generate_openapi_spec() spec = generate_openapi_spec()
expected_minimal_codes = ["400", "401", "403", "404", "500", "422"] expected_minimal_codes = ["400", "401", "403", "404", "500", "422"]
@@ -207,9 +207,9 @@ class TestOpenAPISpecGenerator:
continue continue
responses = details.get("responses", {}) responses = details.get("responses", {})
for code in expected_minimal_codes: for code in expected_minimal_codes:
assert code in responses, f"Missing minimal {code} response in: {method.upper()} {path}." assert code in responses, f"Missing {code} response in: {method.upper()} {path}."
# Verify no "content" or schema is present (minimalism) # Content should now be present (BaseResponse/Error schema)
assert "content" not in responses[code], f"Response {code} in {method.upper()} {path} should not have content/schema." assert "content" in responses[code], f"Response {code} in {method.upper()} {path} should have content/schema."
class TestMCPToolMapping: class TestMCPToolMapping:
@@ -239,9 +239,9 @@ class TestMCPToolMapping:
spec = generate_openapi_spec() spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec) tools = map_openapi_to_mcp_tools(spec)
search_tool = next((t for t in tools if t["name"] == "search_devices"), None) search_tool = next((t for t in tools if t["name"] == "search_devices_api"), None)
assert search_tool is not None assert search_tool is not None
assert "query" in search_tool["inputSchema"].get("required", []) assert "query" in search_tool["inputSchema"]["required"]
def test_tool_descriptions_present(self): def test_tool_descriptions_present(self):
"""All tools should have non-empty descriptions.""" """All tools should have non-empty descriptions."""