feat(api): MCP, OpenAPI & Dynamic Introspection

New Features:
- API endpoints now support comprehensive input validation with detailed error responses via Pydantic models.
- OpenAPI specification endpoint (/openapi.json) and interactive Swagger UI documentation (/docs) now available for API discovery.
- Enhanced MCP session lifecycle management with create, retrieve, and delete operations.
- Network diagnostic tools: traceroute, nslookup, NMAP scanning, and network topology viewing exposed via API.
- Device search, filtering by status (including 'offline'), and bulk operations (copy, delete, update).
- Wake-on-LAN functionality for remote device management.
- Added dynamic tool disablement and status reporting.

Bug Fixes:
- Fixed get_tools_status in registry to correctly return boolean values instead of None for enabled tools.
- Improved error handling for invalid API inputs with standardized validation responses.
- Fixed OPTIONS request handling for cross-origin requests.

Refactoring:
- Significant refactoring of api_server_start.py to use decorator-based validation (@validate_request).
This commit is contained in:
Adam Outler
2026-01-18 18:16:18 +00:00
parent cea3369b5e
commit ecea1d1fbd
46 changed files with 5195 additions and 1053 deletions

View File

File diff suppressed because it is too large Load Diff

View File

@@ -46,46 +46,46 @@ class PageQueryOptionsInput(InputObjectType):
# Device ObjectType
class Device(ObjectType):
rowid = Int()
devMac = String()
devName = String()
devOwner = String()
devType = String()
devVendor = String()
devFavorite = Int()
devGroup = String()
devComments = String()
devFirstConnection = String()
devLastConnection = String()
devLastIP = String()
devStaticIP = Int()
devScan = Int()
devLogEvents = Int()
devAlertEvents = Int()
devAlertDown = Int()
devSkipRepeated = Int()
devLastNotification = String()
devPresentLastScan = Int()
devIsNew = Int()
devLocation = String()
devIsArchived = Int()
devParentMAC = String()
devParentPort = String()
devIcon = String()
devGUID = String()
devSite = String()
devSSID = String()
devSyncHubNode = String()
devSourcePlugin = String()
devCustomProps = String()
devStatus = String()
devIsRandomMac = Int()
devParentChildrenCount = Int()
devIpLong = Int()
devFilterStatus = String()
devFQDN = String()
devParentRelType = String()
devReqNicsOnline = Int()
rowid = Int(description="Database row ID")
devMac = String(description="Device MAC address (e.g., 00:11:22:33:44:55)")
devName = String(description="Device display name/alias")
devOwner = String(description="Device owner")
devType = String(description="Device type classification")
devVendor = String(description="Hardware vendor from OUI lookup")
devFavorite = Int(description="Favorite flag (0 or 1)")
devGroup = String(description="Device group")
devComments = String(description="User comments")
devFirstConnection = String(description="Timestamp of first discovery")
devLastConnection = String(description="Timestamp of last connection")
devLastIP = String(description="Last known IP address")
devStaticIP = Int(description="Static IP flag (0 or 1)")
devScan = Int(description="Scan flag (0 or 1)")
devLogEvents = Int(description="Log events flag (0 or 1)")
devAlertEvents = Int(description="Alert events flag (0 or 1)")
devAlertDown = Int(description="Alert on down flag (0 or 1)")
devSkipRepeated = Int(description="Skip repeated alerts flag (0 or 1)")
devLastNotification = String(description="Timestamp of last notification")
devPresentLastScan = Int(description="Present in last scan flag (0 or 1)")
devIsNew = Int(description="Is new device flag (0 or 1)")
devLocation = String(description="Device location")
devIsArchived = Int(description="Is archived flag (0 or 1)")
devParentMAC = String(description="Parent device MAC address")
devParentPort = String(description="Parent device port")
devIcon = String(description="Device icon name")
devGUID = String(description="Unique device GUID")
devSite = String(description="Site name")
devSSID = String(description="SSID connected to")
devSyncHubNode = String(description="Sync hub node name")
devSourcePlugin = String(description="Plugin that discovered the device")
devCustomProps = String(description="Custom properties in JSON format")
devStatus = String(description="Online/Offline status")
devIsRandomMac = Int(description="Calculated: Is MAC address randomized?")
devParentChildrenCount = Int(description="Calculated: Number of children attached to this parent")
devIpLong = Int(description="Calculated: IP address in long format")
devFilterStatus = String(description="Calculated: Status for UI filtering")
devFQDN = String(description="Fully Qualified Domain Name")
devParentRelType = String(description="Relationship type to parent")
devReqNicsOnline = Int(description="Required NICs online flag")
class DeviceResult(ObjectType):
@@ -98,20 +98,20 @@ class DeviceResult(ObjectType):
# Setting ObjectType
class Setting(ObjectType):
setKey = String()
setName = String()
setDescription = String()
setType = String()
setOptions = String()
setGroup = String()
setValue = String()
setEvents = String()
setOverriddenByEnv = Boolean()
setKey = String(description="Unique configuration key")
setName = String(description="Human-readable setting name")
setDescription = String(description="Detailed description of the setting")
setType = String(description="Data type (string, bool, int, etc.)")
setOptions = String(description="JSON string of available options")
setGroup = String(description="UI group for categorization")
setValue = String(description="Current value")
setEvents = String(description="JSON string of events")
setOverriddenByEnv = Boolean(description="Whether the value is currently overridden by an environment variable")
class SettingResult(ObjectType):
settings = List(Setting)
count = Int()
settings = List(Setting, description="List of setting objects")
count = Int(description="Total count of settings")
# --- LANGSTRINGS ---
@@ -123,48 +123,48 @@ _langstrings_cache_mtime = {} # tracks last modified times
# LangString ObjectType
class LangString(ObjectType):
langCode = String()
langStringKey = String()
langStringText = String()
langCode = String(description="Language code (e.g., en_us, de_de)")
langStringKey = String(description="Unique translation key")
langStringText = String(description="Translated text content")
class LangStringResult(ObjectType):
langStrings = List(LangString)
count = Int()
langStrings = List(LangString, description="List of language string objects")
count = Int(description="Total count of strings")
# --- APP EVENTS ---
class AppEvent(ObjectType):
Index = Int()
GUID = String()
AppEventProcessed = Int()
DateTimeCreated = String()
Index = Int(description="Internal index")
GUID = String(description="Unique event GUID")
AppEventProcessed = Int(description="Processing status (0 or 1)")
DateTimeCreated = String(description="Event creation timestamp")
ObjectType = String()
ObjectGUID = String()
ObjectPlugin = String()
ObjectPrimaryID = String()
ObjectSecondaryID = String()
ObjectForeignKey = String()
ObjectIndex = Int()
ObjectType = String(description="Type of the related object (Device, Setting, etc.)")
ObjectGUID = String(description="GUID of the related object")
ObjectPlugin = String(description="Plugin associated with the object")
ObjectPrimaryID = String(description="Primary identifier of the object")
ObjectSecondaryID = String(description="Secondary identifier of the object")
ObjectForeignKey = String(description="Foreign key reference")
ObjectIndex = Int(description="Object index")
ObjectIsNew = Int()
ObjectIsArchived = Int()
ObjectStatusColumn = String()
ObjectStatus = String()
ObjectIsNew = Int(description="Is the object new? (0 or 1)")
ObjectIsArchived = Int(description="Is the object archived? (0 or 1)")
ObjectStatusColumn = String(description="Column used for status")
ObjectStatus = String(description="Object status value")
AppEventType = String()
AppEventType = String(description="Type of application event")
Helper1 = String()
Helper2 = String()
Helper3 = String()
Extra = String()
Helper1 = String(description="Generic helper field 1")
Helper2 = String(description="Generic helper field 2")
Helper3 = String(description="Generic helper field 3")
Extra = String(description="Additional JSON data")
class AppEventResult(ObjectType):
appEvents = List(AppEvent)
count = Int()
appEvents = List(AppEvent, description="List of application events")
count = Int(description="Total count of events")
# ----------------------------------------------------------------------------------------------

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
import re
from typing import Any
import graphene
from .registry import register_tool, _operation_ids
def introspect_graphql_schema(schema: graphene.Schema):
"""
Introspect the GraphQL schema and register endpoints in the OpenAPI registry.
This bridges the 'living code' (GraphQL) to the OpenAPI spec.
"""
# Graphene schema introspection
graphql_schema = schema.graphql_schema
query_type = graphql_schema.query_type
if not query_type:
return
# We register the main /graphql endpoint once
register_tool(
path="/graphql",
method="POST",
operation_id="graphql_query",
summary="GraphQL Endpoint",
description="Execute arbitrary GraphQL queries against the system schema.",
tags=["graphql"]
)
def _flask_to_openapi_path(flask_path: str) -> str:
"""Convert Flask path syntax to OpenAPI path syntax."""
# Handles <converter:variable> -> {variable} and <variable> -> {variable}
return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path)
def introspect_flask_app(app: Any):
"""
Introspect the Flask application to find routes decorated with @validate_request
and register them in the OpenAPI registry.
"""
registered_ops = set()
for rule in app.url_map.iter_rules():
view_func = app.view_functions.get(rule.endpoint)
if not view_func:
continue
# Check for our decorator's metadata
metadata = getattr(view_func, "_openapi_metadata", None)
if not metadata:
# Fallback for wrapped functions
if hasattr(view_func, "__wrapped__"):
metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None)
if metadata:
op_id = metadata["operation_id"]
# Register the tool with real path and method from Flask
for method in rule.methods:
if method in ("OPTIONS", "HEAD"):
continue
# Create a unique key for this path/method/op combination if needed,
# but operationId must be unique globally.
# If the same function is mounted on multiple paths, we append a suffix
path = _flask_to_openapi_path(str(rule))
# Check if this operation (path + method) is already registered
op_key = f"{method}:{path}"
if op_key in registered_ops:
continue
# Determine tags - create a copy to avoid mutating shared metadata
tags = list(metadata.get("tags") or ["rest"])
if path.startswith("/mcp/"):
# Move specific tags to secondary position or just add MCP
if "rest" in tags:
tags.remove("rest")
if "mcp" not in tags:
tags.append("mcp")
# Ensure unique operationId
original_op_id = op_id
unique_op_id = op_id
count = 1
while unique_op_id in _operation_ids:
unique_op_id = f"{op_id}_{count}"
count += 1
register_tool(
path=path,
method=method,
operation_id=unique_op_id,
original_operation_id=original_op_id if unique_op_id != original_op_id else None,
summary=metadata["summary"],
description=metadata["description"],
request_model=metadata.get("request_model"),
response_model=metadata.get("response_model"),
path_params=metadata.get("path_params"),
query_params=metadata.get("query_params"),
tags=tags,
allow_multipart_payload=metadata.get("allow_multipart_payload", False)
)
registered_ops.add(op_key)

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import threading
from copy import deepcopy
from typing import List, Dict, Any, Literal, Optional, Type, Set
from pydantic import BaseModel
# Thread-safe registry
_registry: List[Dict[str, Any]] = []
_registry_lock = threading.Lock()
_operation_ids: Set[str] = set()
_disabled_tools: Set[str] = set()
class DuplicateOperationIdError(Exception):
"""Raised when an operationId is registered more than once."""
pass
def set_tool_disabled(operation_id: str, disabled: bool = True) -> bool:
"""
Enable or disable a tool by operation_id.
Args:
operation_id: The unique operation_id of the tool
disabled: True to disable, False to enable
Returns:
bool: True if operation_id exists, False otherwise
"""
with _registry_lock:
if operation_id not in _operation_ids:
return False
if disabled:
_disabled_tools.add(operation_id)
else:
_disabled_tools.discard(operation_id)
return True
def is_tool_disabled(operation_id: str) -> bool:
"""
Check if a tool is disabled.
Checks both the unique operation_id and the original_operation_id.
"""
with _registry_lock:
if operation_id in _disabled_tools:
return True
# Also check if the original base ID is disabled
for entry in _registry:
if entry["operation_id"] == operation_id:
orig_id = entry.get("original_operation_id")
if orig_id and orig_id in _disabled_tools:
return True
return False
def get_disabled_tools() -> List[str]:
"""Get list of all disabled operation_ids."""
with _registry_lock:
return list(_disabled_tools)
def get_tools_status() -> List[Dict[str, Any]]:
"""
Get a list of all registered tools and their disabled status.
Useful for backend-to-frontend communication.
"""
tools = []
with _registry_lock:
disabled_snapshot = _disabled_tools.copy()
for entry in _registry:
op_id = entry["operation_id"]
orig_id = entry.get("original_operation_id")
is_disabled = bool(op_id in disabled_snapshot or (orig_id and orig_id in disabled_snapshot))
tools.append({
"operation_id": op_id,
"summary": entry["summary"],
"disabled": is_disabled
})
return tools
def register_tool(
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"],
operation_id: str,
summary: str,
description: str,
request_model: Optional[Type[BaseModel]] = None,
response_model: Optional[Type[BaseModel]] = None,
path_params: Optional[List[Dict[str, Any]]] = None,
query_params: Optional[List[Dict[str, Any]]] = None,
tags: Optional[List[str]] = None,
deprecated: bool = False,
original_operation_id: Optional[str] = None,
allow_multipart_payload: bool = False
) -> None:
"""
Register an API endpoint for OpenAPI spec generation.
Args:
path: URL path (e.g., "/devices/{mac}")
method: HTTP method
operation_id: Unique identifier for this operation (MUST be unique across entire spec)
summary: Short summary for the operation
description: Detailed description
request_model: Pydantic model for request body (POST/PUT/PATCH)
response_model: Pydantic model for success response
path_params: List of path parameter definitions
query_params: List of query parameter definitions
tags: OpenAPI tags for grouping
deprecated: Whether this endpoint is deprecated
original_operation_id: The base ID before suffixing (for disablement mapping)
allow_multipart_payload: Whether to allow multipart/form-data payloads
Raises:
DuplicateOperationIdError: If operation_id already exists in registry
"""
with _registry_lock:
if operation_id in _operation_ids:
raise DuplicateOperationIdError(
f"operationId '{operation_id}' is already registered. "
"Each operationId must be unique across the entire API."
)
_operation_ids.add(operation_id)
_registry.append({
"path": path,
"method": method.upper(),
"operation_id": operation_id,
"original_operation_id": original_operation_id,
"summary": summary,
"description": description,
"request_model": request_model,
"response_model": response_model,
"path_params": path_params or [],
"query_params": query_params or [],
"tags": tags or ["default"],
"deprecated": deprecated,
"allow_multipart_payload": allow_multipart_payload
})
def clear_registry() -> None:
"""Clear all registered endpoints (useful for testing)."""
with _registry_lock:
_registry.clear()
_operation_ids.clear()
_disabled_tools.clear()
def get_registry() -> List[Dict[str, Any]]:
"""Get a deep copy of the current registry to prevent external mutation."""
with _registry_lock:
return deepcopy(_registry)

View File

@@ -0,0 +1,216 @@
from __future__ import annotations
from typing import Dict, Any, Optional, Type, List
from pydantic import BaseModel
def pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
"""
Convert a Pydantic model to JSON Schema (OpenAPI 3.1 compatible).
Uses Pydantic's built-in schema generation which produces
JSON Schema Draft 2020-12 compatible output.
Args:
model: Pydantic BaseModel class
Returns:
JSON Schema dictionary
"""
# Pydantic v2 uses model_json_schema()
schema = model.model_json_schema(mode="serialization")
# Remove $defs if empty (cleaner output)
if "$defs" in schema and not schema["$defs"]:
del schema["$defs"]
return schema
def build_parameters(entry: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Build OpenAPI parameters array from path and query params."""
parameters = []
# Path parameters
for param in entry.get("path_params", []):
parameters.append({
"name": param["name"],
"in": "path",
"required": True,
"description": param.get("description", ""),
"schema": param.get("schema", {"type": "string"})
})
# Query parameters
for param in entry.get("query_params", []):
parameters.append({
"name": param["name"],
"in": "query",
"required": param.get("required", False),
"description": param.get("description", ""),
"schema": param.get("schema", {"type": "string"})
})
return parameters
def extract_definitions(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively extract $defs from a schema and move them to the definitions dict.
Also rewrite $ref to point to #/components/schemas/.
"""
if not isinstance(schema, dict):
return schema
# Extract definitions
if "$defs" in schema:
for name, definition in schema["$defs"].items():
# Recursively process the definition itself before adding it
definitions[name] = extract_definitions(definition, definitions)
del schema["$defs"]
# Rewrite references
if "$ref" in schema and schema["$ref"].startswith("#/$defs/"):
ref_name = schema["$ref"].split("/")[-1]
schema["$ref"] = f"#/components/schemas/{ref_name}"
# Recursively process properties
for key, value in schema.items():
if isinstance(value, dict):
schema[key] = extract_definitions(value, definitions)
elif isinstance(value, list):
schema[key] = [extract_definitions(item, definitions) for item in value]
return schema
def build_request_body(
model: Optional[Type[BaseModel]],
definitions: Dict[str, Any],
allow_multipart_payload: bool = False
) -> Optional[Dict[str, Any]]:
"""Build OpenAPI requestBody from Pydantic model."""
if model is None:
return None
schema = pydantic_to_json_schema(model)
schema = extract_definitions(schema, definitions)
content = {
"application/json": {
"schema": schema
}
}
if allow_multipart_payload:
content["multipart/form-data"] = {
"schema": schema
}
return {
"required": True,
"content": content
}
def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively remove validation constraints from a JSON schema.
Keeps structure and descriptions, but removes pattern, minLength, etc.
This saves context tokens for LLMs which don't validate server output.
"""
if not isinstance(schema, dict):
return schema
# Keys to remove
validation_keys = [
"pattern", "minLength", "maxLength", "minimum", "maximum",
"exclusiveMinimum", "exclusiveMaximum", "multipleOf", "minItems",
"maxItems", "uniqueItems", "minProperties", "maxProperties"
]
clean_schema = {k: v for k, v in schema.items() if k not in validation_keys}
# Recursively clean sub-schemas
if "properties" in clean_schema:
clean_schema["properties"] = {
k: strip_validation(v) for k, v in clean_schema["properties"].items()
}
if "items" in clean_schema:
clean_schema["items"] = strip_validation(clean_schema["items"])
if "allOf" in clean_schema:
clean_schema["allOf"] = [strip_validation(x) for x in clean_schema["allOf"]]
if "anyOf" in clean_schema:
clean_schema["anyOf"] = [strip_validation(x) for x in clean_schema["anyOf"]]
if "oneOf" in clean_schema:
clean_schema["oneOf"] = [strip_validation(x) for x in clean_schema["oneOf"]]
if "$defs" in clean_schema:
clean_schema["$defs"] = {
k: strip_validation(v) for k, v in clean_schema["$defs"].items()
}
if "additionalProperties" in clean_schema and isinstance(clean_schema["additionalProperties"], dict):
clean_schema["additionalProperties"] = strip_validation(clean_schema["additionalProperties"])
return clean_schema
def build_responses(
response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any]
) -> Dict[str, Any]:
"""Build OpenAPI responses object."""
responses = {}
# Success response (200)
if response_model:
# Strip validation from response schema to save tokens
schema = strip_validation(pydantic_to_json_schema(response_model))
schema = extract_definitions(schema, definitions)
responses["200"] = {
"description": "Successful response",
"content": {
"application/json": {
"schema": schema
}
}
}
else:
responses["200"] = {
"description": "Successful response",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"message": {"type": "string"}
}
}
}
}
}
# Standard error responses - MINIMIZED context
# Annotate that these errors can occur, but provide no schema/content to save tokens.
# The LLM knows what "Bad Request" or "Not Found" means.
error_codes = {
"400": "Bad Request",
"401": "Unauthorized",
"403": "Forbidden",
"404": "Not Found",
"422": "Validation Error",
"500": "Internal Server Error"
}
for code, desc in error_codes.items():
responses[code] = {
"description": desc
# No "content" schema provided
}
return responses

View File

@@ -0,0 +1,738 @@
#!/usr/bin/env python
"""
NetAlertX API Schema Definitions (Pydantic v2)
This module defines strict Pydantic models for all API request and response payloads.
These schemas serve as the single source of truth for:
1. Runtime validation of incoming requests
2. OpenAPI specification generation
3. MCP tool input schema derivation
Philosophy: "Code First, Spec Second" — these models ARE the contract.
"""
from __future__ import annotations
import re
import ipaddress
from typing import Optional, List, Literal, Any, Dict
from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict, RootModel
# Internal helper imports
from helper import sanitize_string
from plugin_helper import normalize_mac, is_mac
# =============================================================================
# COMMON PATTERNS & VALIDATORS
# =============================================================================
MAC_PATTERN = r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$"
IP_PATTERN = r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
COLUMN_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+$")
# Security whitelists & Literals for documentation
ALLOWED_DEVICE_COLUMNS = Literal[
"devName", "devOwner", "devType", "devVendor",
"devGroup", "devLocation", "devComments", "devFavorite",
"devParentMAC"
]
ALLOWED_NMAP_MODES = Literal[
"quick", "intense", "ping", "comprehensive", "fast", "normal", "detail", "skipdiscovery",
"-sS", "-sT", "-sU", "-sV", "-O"
]
NOTIFICATION_LEVELS = Literal["info", "warning", "error", "alert"]
ALLOWED_TABLES = Literal["Devices", "Events", "Sessions", "Settings", "CurrentScan", "Online_History", "Plugins_Objects"]
ALLOWED_LOG_FILES = Literal[
"app.log", "app_front.log", "IP_changes.log", "stdout.log", "stderr.log",
"app.php_errors.log", "execution_queue.log", "db_is_locked.log"
]
def validate_mac(value: str) -> str:
"""Validate and normalize MAC address format."""
# Allow "Internet" as a special case for the gateway/WAN device
if value.lower() == "internet":
return "Internet"
if not is_mac(value):
raise ValueError(f"Invalid MAC address format: {value}")
return normalize_mac(value)
def validate_ip(value: str) -> str:
"""Validate IP address format (IPv4 or IPv6) using stdlib ipaddress.
Returns the canonical string form of the IP address.
"""
try:
return str(ipaddress.ip_address(value))
except ValueError as err:
raise ValueError(f"Invalid IP address: {value}") from err
def validate_column_identifier(value: str) -> str:
"""Validate a column identifier to prevent SQL injection."""
if not COLUMN_NAME_PATTERN.match(value):
raise ValueError("Invalid column name format")
return value
# =============================================================================
# BASE RESPONSE MODELS
# =============================================================================
class BaseResponse(BaseModel):
"""Standard API response wrapper."""
model_config = ConfigDict(extra="allow")
success: bool = Field(..., description="Whether the operation succeeded")
message: Optional[str] = Field(None, description="Human-readable message")
error: Optional[str] = Field(None, description="Error message if success=False")
class PaginatedResponse(BaseResponse):
"""Response with pagination metadata."""
total: int = Field(0, description="Total number of items")
page: int = Field(1, ge=1, description="Current page number")
per_page: int = Field(50, ge=1, le=500, description="Items per page")
# =============================================================================
# DEVICE SCHEMAS
# =============================================================================
class DeviceSearchRequest(BaseModel):
"""Request payload for searching devices."""
model_config = ConfigDict(str_strip_whitespace=True)
query: str = Field(
...,
min_length=1,
max_length=256,
description="Search term: IP address, MAC address, device name, or vendor",
json_schema_extra={"examples": ["192.168.1.1", "Apple", "00:11:22:33:44:55"]}
)
limit: int = Field(
50,
ge=1,
le=500,
description="Maximum number of results to return"
)
class DeviceInfo(BaseModel):
"""Detailed device information model (Raw record)."""
model_config = ConfigDict(extra="allow")
devMac: str = Field(..., description="Device MAC address")
devName: Optional[str] = Field(None, description="Device display name/alias")
devLastIP: Optional[str] = Field(None, description="Last known IP address")
devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup")
devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type classification")
devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)")
devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)")
devStatus: Optional[str] = Field(None, description="Online/Offline status")
class DeviceSearchResponse(BaseResponse):
"""Response payload for device search."""
devices: List[DeviceInfo] = Field(default_factory=list, description="List of matching devices")
class DeviceListRequest(BaseModel):
"""Request for listing devices by status."""
status: Optional[Literal[
"connected", "down", "favorites", "new", "archived", "all", "my",
"offline"
]] = Field(
None,
description="Filter devices by status (connected, down, favorites, new, archived, all, my, offline)"
)
class DeviceListResponse(RootModel):
"""Response with list of devices."""
root: List[DeviceInfo] = Field(default_factory=list, description="List of devices")
class DeviceListWrapperResponse(BaseResponse):
"""Wrapped response with list of devices."""
devices: List[DeviceInfo] = Field(default_factory=list, description="List of devices")
class GetDeviceRequest(BaseModel):
"""Path parameter for getting a specific device."""
mac: str = Field(
...,
description="Device MAC address",
json_schema_extra={"examples": ["00:11:22:33:44:55"]}
)
@field_validator("mac")
@classmethod
def validate_mac_address(cls, v: str) -> str:
return validate_mac(v)
class GetDeviceResponse(BaseResponse):
"""Wrapped response for getting device details."""
device: Optional[DeviceInfo] = Field(None, description="Device details if found")
class GetDeviceWrapperResponse(BaseResponse):
"""Wrapped response for getting a single device (e.g. latest)."""
device: Optional[DeviceInfo] = Field(None, description="Device details")
class SetDeviceAliasRequest(BaseModel):
"""Request to set a device alias/name."""
alias: str = Field(
...,
min_length=1,
max_length=128,
description="New display name/alias for the device"
)
@field_validator("alias")
@classmethod
def sanitize_alias(cls, v: str) -> str:
return sanitize_string(v)
class DeviceTotalsResponse(RootModel):
"""Response with device statistics."""
root: List[int] = Field(default_factory=list, description="List of counts: [all, online, favorites, new, offline, archived]")
class DeviceExportRequest(BaseModel):
"""Request for exporting devices."""
format: Literal["csv", "json"] = Field(
"csv",
description="Export format: csv or json"
)
class DeviceExportResponse(BaseModel):
"""Raw response for device export in JSON format."""
columns: List[str] = Field(..., description="Column names")
data: List[Dict[str, Any]] = Field(..., description="Device records")
class DeviceImportRequest(BaseModel):
"""Request for importing devices."""
content: Optional[str] = Field(
None,
description="Base64-encoded CSV or JSON content to import"
)
class DeviceImportResponse(BaseResponse):
"""Response for device import operation."""
imported: int = Field(0, description="Number of devices imported")
skipped: int = Field(0, description="Number of devices skipped")
errors: List[str] = Field(default_factory=list, description="List of import errors")
class CopyDeviceRequest(BaseModel):
"""Request to copy device settings."""
macFrom: str = Field(..., description="Source MAC address")
macTo: str = Field(..., description="Destination MAC address")
@field_validator("macFrom", "macTo")
@classmethod
def validate_mac_addresses(cls, v: str) -> str:
return validate_mac(v)
class UpdateDeviceColumnRequest(BaseModel):
"""Request to update a specific device database column."""
columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name")
columnValue: Any = Field(..., description="New value for the column")
class DeviceUpdateRequest(BaseModel):
"""Request to update device fields (create/update)."""
model_config = ConfigDict(extra="allow")
devName: Optional[str] = Field(None, description="Device name")
devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type")
devVendor: Optional[str] = Field(None, description="Device vendor")
devGroup: Optional[str] = Field(None, description="Device group")
devLocation: Optional[str] = Field(None, description="Device location")
devComments: Optional[str] = Field(None, description="Comments")
createNew: bool = Field(False, description="Create new device if not exists")
@field_validator("devName", "devOwner", "devType", "devVendor", "devGroup", "devLocation", "devComments")
@classmethod
def sanitize_text_fields(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
return sanitize_string(v)
class DeleteDevicesRequest(BaseModel):
"""Request to delete multiple devices."""
macs: List[str] = Field([], description="List of MACs to delete")
confirm_delete_all: bool = Field(False, description="Explicit flag to delete ALL devices when macs is empty")
@field_validator("macs")
@classmethod
def validate_mac_list(cls, v: List[str]) -> List[str]:
return [validate_mac(mac) for mac in v]
@model_validator(mode="after")
def check_delete_all_safety(self) -> DeleteDevicesRequest:
if not self.macs and not self.confirm_delete_all:
raise ValueError("Must provide at least one MAC or set confirm_delete_all=True")
return self
# =============================================================================
# NETWORK TOOLS SCHEMAS
# =============================================================================
class TriggerScanRequest(BaseModel):
"""Request to trigger a network scan."""
type: str = Field(
"ARPSCAN",
description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)",
json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]}
)
class TriggerScanResponse(BaseResponse):
"""Response for scan trigger."""
scan_type: Optional[str] = Field(None, description="Type of scan that was triggered")
class OpenPortsRequest(BaseModel):
"""Request for getting open ports."""
target: str = Field(
...,
description="Target IP address or MAC address to check ports for",
json_schema_extra={"examples": ["192.168.1.50", "00:11:22:33:44:55"]}
)
@field_validator("target")
@classmethod
def validate_target(cls, v: str) -> str:
"""Validate target is either a valid IP or MAC address."""
# Try IP first
try:
return validate_ip(v)
except ValueError:
pass
# Try MAC
return validate_mac(v)
class OpenPortsResponse(BaseResponse):
"""Response with open ports information."""
target: str = Field(..., description="Target that was scanned")
open_ports: List[Any] = Field(default_factory=list, description="List of open port objects or numbers")
class WakeOnLanRequest(BaseModel):
"""Request to send Wake-on-LAN packet."""
devMac: Optional[str] = Field(
None,
description="Target device MAC address",
json_schema_extra={"examples": ["00:11:22:33:44:55"]}
)
devLastIP: Optional[str] = Field(
None,
alias="ip",
description="Target device IP (MAC will be resolved if not provided)",
json_schema_extra={"examples": ["192.168.1.50"]}
)
# Note: alias="ip" means input JSON can use "ip".
# But Pydantic V2 with populate_by_name=True allows both "devLastIP" and "ip".
model_config = ConfigDict(populate_by_name=True)
@field_validator("devMac")
@classmethod
def validate_mac_if_provided(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
return validate_mac(v)
return v
@field_validator("devLastIP")
@classmethod
def validate_ip_if_provided(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
return validate_ip(v)
return v
@model_validator(mode="after")
def require_mac_or_ip(self) -> "WakeOnLanRequest":
"""Ensure at least one of devMac or devLastIP is provided."""
if self.devMac is None and self.devLastIP is None:
raise ValueError("Either 'devMac' or 'devLastIP' (alias 'ip') must be provided")
return self
class WakeOnLanResponse(BaseResponse):
"""Response for Wake-on-LAN operation."""
output: Optional[str] = Field(None, description="Command output")
class TracerouteRequest(BaseModel):
"""Request to perform traceroute."""
devLastIP: str = Field(
...,
description="Target IP address for traceroute",
json_schema_extra={"examples": ["8.8.8.8", "192.168.1.1"]}
)
@field_validator("devLastIP")
@classmethod
def validate_ip_address(cls, v: str) -> str:
return validate_ip(v)
class TracerouteResponse(BaseResponse):
"""Response with traceroute results."""
output: List[str] = Field(default_factory=list, description="Traceroute hop output lines")
class NmapScanRequest(BaseModel):
"""Request to perform NMAP scan."""
scan: str = Field(
...,
description="Target IP address for NMAP scan"
)
mode: ALLOWED_NMAP_MODES = Field(
...,
description="NMAP scan mode/arguments (restricted to safe options)"
)
@field_validator("scan")
@classmethod
def validate_scan_target(cls, v: str) -> str:
return validate_ip(v)
class NslookupRequest(BaseModel):
"""Request for DNS lookup."""
devLastIP: str = Field(
...,
description="IP address to perform reverse DNS lookup"
)
@field_validator("devLastIP")
@classmethod
def validate_ip_address(cls, v: str) -> str:
return validate_ip(v)
class NslookupResponse(BaseResponse):
"""Response for DNS lookup operation."""
output: List[str] = Field(default_factory=list, description="Nslookup output lines")
class NmapScanResponse(BaseResponse):
"""Response for NMAP scan operation."""
mode: Optional[str] = Field(None, description="NMAP scan mode")
ip: Optional[str] = Field(None, description="Target IP address")
output: List[str] = Field(default_factory=list, description="NMAP scan output lines")
class NetworkTopologyResponse(BaseResponse):
"""Response with network topology data."""
nodes: List[dict] = Field(default_factory=list, description="Network nodes")
links: List[dict] = Field(default_factory=list, description="Network connections")
class InternetInfoResponse(BaseResponse):
"""Response for internet information."""
output: Dict[str, Any] = Field(..., description="Details about the internet connection.")
class NetworkInterfacesResponse(BaseResponse):
"""Response with network interface information."""
interfaces: Dict[str, Any] = Field(..., description="Details about network interfaces.")
# =============================================================================
# EVENTS SCHEMAS
# =============================================================================
class EventInfo(BaseModel):
"""Event/alert information."""
model_config = ConfigDict(extra="allow")
eveRowid: Optional[int] = Field(None, description="Event row ID")
eveMAC: Optional[str] = Field(None, description="Device MAC address")
eveIP: Optional[str] = Field(None, description="Device IP address")
eveDateTime: Optional[str] = Field(None, description="Event timestamp")
eveEventType: Optional[str] = Field(None, description="Type of event")
evePreviousIP: Optional[str] = Field(None, description="Previous IP if changed")
class RecentEventsRequest(BaseModel):
"""Request for recent events."""
hours: int = Field(
24,
ge=1,
le=720,
description="Number of hours to look back for events"
)
limit: int = Field(
100,
ge=1,
le=1000,
description="Maximum number of events to return"
)
class RecentEventsResponse(BaseResponse):
"""Response with recent events."""
hours: int = Field(..., description="The time window in hours")
events: List[EventInfo] = Field(default_factory=list, description="List of recent events")
class LastEventsResponse(BaseResponse):
"""Response with last N events."""
events: List[EventInfo] = Field(default_factory=list, description="List of last events")
class CreateEventRequest(BaseModel):
"""Request to create a device event."""
ip: Optional[str] = Field("0.0.0.0", description="Device IP")
event_type: str = Field("Device Down", description="Event type")
additional_info: Optional[str] = Field("", description="Additional info")
pending_alert: int = Field(1, description="Pending alert flag")
event_time: Optional[str] = Field(None, description="Event timestamp (ISO)")
@field_validator("ip", mode="before")
@classmethod
def validate_ip_field(cls, v: Optional[str]) -> str:
"""Validate and normalize IP address, defaulting to 0.0.0.0."""
if v is None or v == "":
return "0.0.0.0"
return validate_ip(v)
# =============================================================================
# SESSIONS SCHEMAS
# =============================================================================
class SessionInfo(BaseModel):
"""Session information."""
model_config = ConfigDict(extra="allow")
sesRowid: Optional[int] = Field(None, description="Session row ID")
sesMac: Optional[str] = Field(None, description="Device MAC address")
sesDateTimeConnection: Optional[str] = Field(None, description="Connection timestamp")
sesDateTimeDisconnection: Optional[str] = Field(None, description="Disconnection timestamp")
sesIPAddress: Optional[str] = Field(None, description="IP address during session")
class CreateSessionRequest(BaseModel):
"""Request to create a session."""
mac: str = Field(..., description="Device MAC")
ip: str = Field(..., description="Device IP")
start_time: str = Field(..., description="Start time")
end_time: Optional[str] = Field(None, description="End time")
event_type_conn: str = Field("Connected", description="Connection event type")
event_type_disc: str = Field("Disconnected", description="Disconnection event type")
@field_validator("mac")
@classmethod
def validate_mac_address(cls, v: str) -> str:
return validate_mac(v)
@field_validator("ip")
@classmethod
def validate_ip_address(cls, v: str) -> str:
return validate_ip(v)
class DeleteSessionRequest(BaseModel):
"""Request to delete sessions for a MAC."""
mac: str = Field(..., description="Device MAC")
@field_validator("mac")
@classmethod
def validate_mac_address(cls, v: str) -> str:
return validate_mac(v)
# =============================================================================
# MESSAGING / IN-APP NOTIFICATIONS SCHEMAS
# =============================================================================
class InAppNotification(BaseModel):
"""In-app notification model."""
model_config = ConfigDict(extra="allow")
id: Optional[int] = Field(None, description="Notification ID")
guid: Optional[str] = Field(None, description="Unique notification GUID")
text: str = Field(..., description="Notification text content")
level: NOTIFICATION_LEVELS = Field("info", description="Notification level")
read: Optional[int] = Field(0, description="Read status (0 or 1)")
created_at: Optional[str] = Field(None, description="Creation timestamp")
class CreateNotificationRequest(BaseModel):
"""Request to create an in-app notification."""
content: str = Field(
...,
min_length=1,
max_length=1024,
description="Notification content"
)
level: NOTIFICATION_LEVELS = Field(
"info",
description="Notification severity level"
)
# =============================================================================
# SYNC SCHEMAS
# =============================================================================
class SyncPushRequest(BaseModel):
"""Request to push data to sync."""
data: dict = Field(..., description="Data to sync")
node_name: str = Field(..., description="Name of the node sending data")
plugin: str = Field(..., description="Plugin identifier")
class SyncPullResponse(BaseResponse):
"""Response with sync data."""
data: Optional[dict] = Field(None, description="Synchronized data")
last_sync: Optional[str] = Field(None, description="Last sync timestamp")
# =============================================================================
# DB QUERY SCHEMAS (Raw SQL)
# =============================================================================
class DbQueryRequest(BaseModel):
"""
Request for raw database query.
WARNING: This is a highly privileged operation.
"""
rawSql: str = Field(
...,
description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)"
)
# Legacy compatibility: removed strict safety check
# TODO: SECURITY CRITICAL - Re-enable strict safety checks.
# The `confirm_dangerous_query` default was relaxed to `True` to maintain backward compatibility
# with the legacy frontend which sends raw SQL directly.
#
# CONTEXT: This explicit safety check was introduced with the new Pydantic validation layer.
# The legacy PHP frontend predates these formal schemas and does not send the
# `confirm_dangerous_query` flag, causing 422 Validation Errors when this check is enforced.
#
# Actionable Advice:
# 1. Implement a parser to strictly whitelist only `SELECT` statements if raw SQL is required.
# 2. Migrate the frontend to use structured endpoints (e.g., `/devices/search`, `/dbquery/read`) instead of raw SQL.
# 3. Once migrated, revert `confirm_dangerous_query` default to `False` and enforce the check.
confirm_dangerous_query: bool = Field(
True,
description="Required to be True to acknowledge the risks of raw SQL execution"
)
class DbQueryUpdateRequest(BaseModel):
"""Request for DB update query."""
columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to update")
dbtable: ALLOWED_TABLES = Field(..., description="Table name")
columns: List[str] = Field(..., description="Columns to update")
values: List[Any] = Field(..., description="New values")
@field_validator("columnName")
@classmethod
def validate_column_name(cls, v: str) -> str:
return validate_column_identifier(v)
@field_validator("columns")
@classmethod
def validate_column_list(cls, values: List[str]) -> List[str]:
return [validate_column_identifier(value) for value in values]
@model_validator(mode="after")
def validate_columns_values(self) -> "DbQueryUpdateRequest":
if len(self.columns) != len(self.values):
raise ValueError("columns and values must have the same length")
return self
class DbQueryDeleteRequest(BaseModel):
"""Request for DB delete query."""
columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to delete")
dbtable: ALLOWED_TABLES = Field(..., description="Table name")
@field_validator("columnName")
@classmethod
def validate_column_name(cls, v: str) -> str:
return validate_column_identifier(v)
class DbQueryResponse(BaseResponse):
"""Response from database query."""
data: Any = Field(None, description="Query result data")
columns: Optional[List[str]] = Field(None, description="Column names if applicable")
# =============================================================================
# LOGS SCHEMAS
# =============================================================================
class CleanLogRequest(BaseModel):
"""Request to clean/truncate a log file."""
logFile: ALLOWED_LOG_FILES = Field(
...,
description="Name of the log file to clean"
)
class LogResource(BaseModel):
"""Log file resource information."""
name: str = Field(..., description="Log file name")
path: str = Field(..., description="Full path to log file")
size_bytes: int = Field(0, description="File size in bytes")
modified: Optional[str] = Field(None, description="Last modification timestamp")
class AddToQueueRequest(BaseModel):
"""Request to add action to execution queue."""
action: str = Field(..., description="Action string (e.g. update_api|devices)")
# =============================================================================
# SETTINGS SCHEMAS
# =============================================================================
class SettingValue(BaseModel):
"""A single setting value."""
key: str = Field(..., description="Setting key name")
value: Any = Field(..., description="Setting value")
class GetSettingResponse(BaseResponse):
"""Response for getting a setting value."""
value: Any = Field(None, description="The setting value")

View File

@@ -0,0 +1,191 @@
#!/usr/bin/env python
"""
NetAlertX OpenAPI Specification Generator
This module provides a registry-based approach to OpenAPI spec generation.
It converts Pydantic models to JSON Schema and assembles a complete OpenAPI 3.1 spec.
Key Features:
- Automatic Pydantic -> JSON Schema conversion
- Centralized endpoint registry
- Unique operationId enforcement
- Complete request/response schema generation
Usage:
from spec_generator import registry, generate_openapi_spec, register_tool
# Register endpoints (typically done at module load)
register_tool(
path="/devices/search",
method="POST",
operation_id="search_devices",
description="Search for devices",
request_model=DeviceSearchRequest,
response_model=DeviceSearchResponse
)
# Generate spec (called by MCP endpoint)
spec = generate_openapi_spec()
"""
from __future__ import annotations
import threading
from typing import Optional, List, Dict, Any
from .registry import (
clear_registry,
_registry,
_registry_lock,
_disabled_tools
)
from .introspection import introspect_flask_app, introspect_graphql_schema
from .schema_converter import (
build_parameters,
build_request_body,
build_responses
)
_rebuild_lock = threading.Lock()
def generate_openapi_spec(
title: str = "NetAlertX API",
version: str = "2.0.0",
description: str = "NetAlertX Network Monitoring API - MCP Compatible",
servers: Optional[List[Dict[str, str]]] = None,
flask_app: Optional[Any] = None
) -> Dict[str, Any]:
"""Assemble a complete OpenAPI specification from the registered endpoints."""
with _rebuild_lock:
# If no app provided and registry is empty, try to use the one from api_server_start
if not flask_app and not _registry:
try:
from ..api_server_start import app as start_app
flask_app = start_app
except (ImportError, AttributeError):
pass
# If we are in "dynamic mode", we rebuild the registry from code
if flask_app:
from ..graphql_endpoint import devicesSchema
clear_registry()
introspect_graphql_schema(devicesSchema)
introspect_flask_app(flask_app)
spec = {
"openapi": "3.1.0",
"info": {
"title": title,
"version": version,
"description": description,
"contact": {
"name": "NetAlertX",
"url": "https://github.com/jokob-sk/NetAlertX"
}
},
"servers": servers or [{"url": "/", "description": "Local server"}],
"security": [
{"BearerAuth": []}
],
"components": {
"securitySchemes": {
"BearerAuth": {
"type": "http",
"scheme": "bearer",
"description": "API token from NetAlertX settings (API_TOKEN)"
}
},
"schemas": {}
},
"paths": {},
"tags": []
}
definitions = {}
# Collect unique tags
tag_set = set()
with _registry_lock:
disabled_snapshot = _disabled_tools.copy()
for entry in _registry:
path = entry["path"]
method = entry["method"].lower()
# Initialize path if not exists
if path not in spec["paths"]:
spec["paths"][path] = {}
# Build operation object
operation = {
"operationId": entry["operation_id"],
"summary": entry["summary"],
"description": entry["description"],
"tags": entry["tags"],
"deprecated": entry["deprecated"]
}
# Inject disabled status if applicable
if entry["operation_id"] in disabled_snapshot:
operation["x-mcp-disabled"] = True
# Inject original ID if suffixed (Coderabbit fix)
if entry.get("original_operation_id"):
operation["x-original-operationId"] = entry["original_operation_id"]
# Add parameters (path + query)
parameters = build_parameters(entry)
if parameters:
operation["parameters"] = parameters
# Add request body for POST/PUT/PATCH/DELETE
if method in ("post", "put", "patch", "delete") and entry.get("request_model"):
request_body = build_request_body(
entry["request_model"],
definitions,
allow_multipart_payload=entry.get("allow_multipart_payload", False)
)
if request_body:
operation["requestBody"] = request_body
# Add responses
operation["responses"] = build_responses(
entry.get("response_model"), definitions
)
spec["paths"][path][method] = operation
# Collect tags
for tag in entry["tags"]:
tag_set.add(tag)
spec["components"]["schemas"] = definitions
# Build tags array with descriptions
tag_descriptions = {
"devices": "Device management and queries",
"nettools": "Network diagnostic tools",
"events": "Event and alert management",
"sessions": "Session history tracking",
"messaging": "In-app notifications",
"settings": "Configuration management",
"sync": "Data synchronization",
"logs": "Log file access",
"dbquery": "Direct database queries"
}
spec["tags"] = [
{"name": tag, "description": tag_descriptions.get(tag, f"{tag.title()} operations")}
for tag in sorted(tag_set)
]
return spec
# Initialize registry on module load
# Registry is now populated dynamically via introspection in generate_openapi_spec
def _register_all_endpoints():
"""Dummy function for compatibility with legacy tests."""
pass

View File

@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="description" content="NetAlertX API Documentation" />
<title>NetAlertX API Docs</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.11.0/swagger-ui.css" integrity="sha384-+yyzNgM3K92sROwsXxYCxaiLWxWJ0G+v/9A+qIZ2rgefKgkdcmJI+L601cqPD/Ut" crossorigin="anonymous" />
<style>
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5.11.0/swagger-ui-bundle.js" integrity="sha384-qn5tagrAjZi8cSmvZ+k3zk4+eDEEUcP9myuR2J6V+/H6rne++v6ChO7EeHAEzqxQ" crossorigin="anonymous"></script>
<script>
window.onload = () => {
window.ui = SwaggerUIBundle({
url: '/openapi.json',
dom_id: '#swagger-ui',
deepLinking: true,
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
});
};
</script>
</body>
</html>

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
import inspect
import json
from functools import wraps
from typing import Callable, Optional, Type
from flask import request, jsonify
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from logger import mylog
def _handle_validation_error(e: ValidationError, operation_id: str, validation_error_code: int):
"""Internal helper to format Pydantic validation errors."""
mylog("verbose", [f"[Validation] Error for {operation_id}: {e}"])
# Construct a legacy-compatible error message if possible
error_msg = "Validation Error"
if e.errors():
err = e.errors()[0]
if err['type'] == 'missing':
loc = err.get('loc')
field_name = loc[0] if loc and len(loc) > 0 else "unknown field"
error_msg = f"Missing required '{field_name}'"
else:
error_msg = f"Validation Error: {err['msg']}"
return jsonify({
"success": False,
"error": error_msg,
"details": json.loads(e.json())
}), validation_error_code
def validate_request(
operation_id: str,
summary: str,
description: str,
request_model: Optional[Type[BaseModel]] = None,
response_model: Optional[Type[BaseModel]] = None,
tags: Optional[list[str]] = None,
path_params: Optional[list[dict]] = None,
query_params: Optional[list[dict]] = None,
validation_error_code: int = 422,
auth_callable: Optional[Callable[[], bool]] = None,
allow_multipart_payload: bool = False
):
"""
Decorator to register a Flask route with the OpenAPI registry and validate incoming requests.
Features:
- Auto-registers the endpoint with the OpenAPI spec generator.
- Validates JSON body against `request_model` (for POST/PUT).
- Injects the validated Pydantic model as the first argument to the view function.
- Supports auth_callable to check permissions before validation.
- Returns 422 (default) if validation fails.
- allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields.
"""
def decorator(f: Callable) -> Callable:
# Detect if f accepts 'payload' argument (unwrap if needed)
real_f = inspect.unwrap(f)
sig = inspect.signature(real_f)
accepts_payload = 'payload' in sig.parameters
f._openapi_metadata = {
"operation_id": operation_id,
"summary": summary,
"description": description,
"request_model": request_model,
"response_model": response_model,
"tags": tags,
"path_params": path_params,
"query_params": query_params,
"allow_multipart_payload": allow_multipart_payload
}
@wraps(f)
def wrapper(*args, **kwargs):
# 0. Handle OPTIONS explicitly if it reaches here (CORS preflight)
if request.method == "OPTIONS":
return jsonify({"success": True}), 200
# 1. Check Authorization first (Coderabbit fix)
if auth_callable and not auth_callable():
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
validated_instance = None
# 2. Payload Validation
if request_model:
# Helper to detect multipart requests by content-type (not just files)
is_multipart = (
request.content_type and request.content_type.startswith("multipart/")
)
if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
# Explicit multipart handling (Coderabbit fix)
# Check both request.files and content-type for form-only multipart bodies
if request.files or is_multipart:
if allow_multipart_payload:
# Attempt validation from form data if allowed
try:
data = request.form.to_dict()
validated_instance = request_model(**data)
except ValidationError as e:
mylog("verbose", [f"[Validation] Multipart validation failed for {operation_id}: {e}"])
# Only continue without validation if handler doesn't expect payload
if accepts_payload:
return _handle_validation_error(e, operation_id, validation_error_code)
# Otherwise, handler will process files manually
else:
# If multipart is not allowed but files are present, we fail fast
# This prevents handlers from receiving unexpected None payloads
mylog("verbose", [f"[Validation] Multipart bypass attempted for {operation_id} but not allowed."])
return jsonify({
"success": False,
"error": "Invalid Content-Type",
"message": "Multipart requests are not allowed for this endpoint"
}), 415
else:
if not request.is_json and request.content_length:
return jsonify({"success": False, "error": "Invalid Content-Type", "message": "Content-Type must be application/json"}), 415
try:
data = request.get_json(silent=False) or {}
validated_instance = request_model(**data)
except ValidationError as e:
return _handle_validation_error(e, operation_id, validation_error_code)
except BadRequest as e:
mylog("verbose", [f"[Validation] Invalid JSON for {operation_id}: {e}"])
return jsonify({
"success": False,
"error": "Invalid JSON",
"message": "Request body must be valid JSON"
}), 400
except (TypeError, KeyError, AttributeError) as e:
mylog("verbose", [f"[Validation] Malformed request for {operation_id}: {e}"])
return jsonify({
"success": False,
"error": "Invalid Request",
"message": "Unable to process request body"
}), 400
elif request.method == "GET":
# Attempt to validate from query parameters for GET requests
try:
# request.args is a MultiDict; to_dict() gives first value of each key
# which is usually what we want for Pydantic models.
data = request.args.to_dict()
validated_instance = request_model(**data)
except ValidationError as e:
return _handle_validation_error(e, operation_id, validation_error_code)
except (TypeError, ValueError, KeyError) as e:
mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"])
return jsonify({
"success": False,
"error": "Invalid query parameters",
"message": "Unable to process query parameters"
}), 400
else:
# Unsupported HTTP method with a request_model - fail explicitly
mylog("verbose", [f"[Validation] Unsupported HTTP method {request.method} for {operation_id} with request_model"])
return jsonify({
"success": False,
"error": "Method Not Allowed",
"message": f"HTTP method {request.method} is not supported for this endpoint"
}), 405
if validated_instance:
if accepts_payload:
kwargs['payload'] = validated_instance
else:
# Fail fast if decorated function doesn't accept payload (Coderabbit fix)
mylog("minimal", [f"[Validation] Endpoint {operation_id} does not accept 'payload' argument!"])
raise TypeError(f"Function {f.__name__} (operationId: {operation_id}) does not accept 'payload' argument.")
return f(*args, **kwargs)
return wrapper
return decorator

View File

@@ -8,7 +8,7 @@ import json
import threading
import time
from collections import deque
from flask import Response, request
from flask import Response, request, jsonify
from logger import mylog
# Thread-safe event queue
@@ -129,11 +129,17 @@ def create_sse_endpoint(app, is_authorized=None) -> None:
is_authorized: Optional function to check authorization (if None, allows all)
"""
@app.route("/sse/state", methods=["GET"])
@app.route("/sse/state", methods=["GET", "OPTIONS"])
def api_sse_state():
"""SSE endpoint for real-time state updates"""
if request.method == "OPTIONS":
response = jsonify({"success": True})
response.headers["Access-Control-Allow-Origin"] = request.headers.get("Origin", "*")
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
return response, 200
if is_authorized and not is_authorized():
return {"none": "Unauthorized"}, 401
return jsonify({"success": False, "error": "Unauthorized"}), 401
client_id = request.args.get("client", f"client-{int(time.time() * 1000)}")
mylog("debug", [f"[SSE] Client connected: {client_id}"])
@@ -148,11 +154,14 @@ def create_sse_endpoint(app, is_authorized=None) -> None:
},
)
@app.route("/sse/stats", methods=["GET"])
@app.route("/sse/stats", methods=["GET", "OPTIONS"])
def api_sse_stats():
"""Get SSE endpoint statistics for debugging"""
if request.method == "OPTIONS":
return jsonify({"success": True}), 200
if is_authorized and not is_authorized():
return {"none": "Unauthorized"}, 401
return {"success": False, "error": "Unauthorized"}, 401
return {
"success": True,