mirror of
https://github.com/jokob-sk/NetAlertX.git
synced 2026-03-31 07:12:23 -07:00
feat(api): MCP, OpenAPI & Dynamic Introspection
New Features: - API endpoints now support comprehensive input validation with detailed error responses via Pydantic models. - OpenAPI specification endpoint (/openapi.json) and interactive Swagger UI documentation (/docs) now available for API discovery. - Enhanced MCP session lifecycle management with create, retrieve, and delete operations. - Network diagnostic tools: traceroute, nslookup, NMAP scanning, and network topology viewing exposed via API. - Device search, filtering by status (including 'offline'), and bulk operations (copy, delete, update). - Wake-on-LAN functionality for remote device management. - Added dynamic tool disablement and status reporting. Bug Fixes: - Fixed get_tools_status in registry to correctly return boolean values instead of None for enabled tools. - Improved error handling for invalid API inputs with standardized validation responses. - Fixed OPTIONS request handling for cross-origin requests. Refactoring: - Significant refactoring of api_server_start.py to use decorator-based validation (@validate_request).
This commit is contained in:
0
server/api_server/openapi/__init__.py
Normal file
0
server/api_server/openapi/__init__.py
Normal file
106
server/api_server/openapi/introspection.py
Normal file
106
server/api_server/openapi/introspection.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
import graphene
|
||||
|
||||
from .registry import register_tool, _operation_ids
|
||||
|
||||
|
||||
def introspect_graphql_schema(schema: graphene.Schema):
    """
    Walk the Graphene schema and publish its entry point to the OpenAPI registry.

    This bridges the 'living code' (GraphQL) to the OpenAPI spec: rather than
    enumerating every query field, the single /graphql endpoint is advertised.
    """
    # Nothing to expose when the schema has no query root.
    if not schema.graphql_schema.query_type:
        return

    # The whole GraphQL surface is reachable through one POST endpoint,
    # so it is registered exactly once.
    register_tool(
        path="/graphql",
        method="POST",
        operation_id="graphql_query",
        summary="GraphQL Endpoint",
        description="Execute arbitrary GraphQL queries against the system schema.",
        tags=["graphql"]
    )
|
||||
|
||||
|
||||
def _flask_to_openapi_path(flask_path: str) -> str:
|
||||
"""Convert Flask path syntax to OpenAPI path syntax."""
|
||||
# Handles <converter:variable> -> {variable} and <variable> -> {variable}
|
||||
return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path)
|
||||
|
||||
|
||||
def introspect_flask_app(app: Any):
    """
    Introspect the Flask application to find routes decorated with @validate_request
    and register them in the OpenAPI registry.

    For every URL rule the view function's ``_openapi_metadata`` attribute
    (attached by the validation decorator) is read and each non-OPTIONS/HEAD
    method is registered as an OpenAPI operation. operationIds are suffixed
    ``_1``, ``_2``, ... when the same function is mounted on multiple paths.
    """
    # Tracks "METHOD:path" pairs already handled in this introspection pass.
    registered_ops = set()
    for rule in app.url_map.iter_rules():
        view_func = app.view_functions.get(rule.endpoint)
        if not view_func:
            continue

        # Check for our decorator's metadata
        metadata = getattr(view_func, "_openapi_metadata", None)
        if not metadata:
            # Fallback for wrapped functions
            # (functools.wraps exposes the inner function via __wrapped__)
            if hasattr(view_func, "__wrapped__"):
                metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None)

        if metadata:
            op_id = metadata["operation_id"]

            # Register the tool with real path and method from Flask
            for method in rule.methods:
                # OPTIONS/HEAD are auto-added by Flask and not useful in the spec.
                if method in ("OPTIONS", "HEAD"):
                    continue

                # Create a unique key for this path/method/op combination if needed,
                # but operationId must be unique globally.
                # If the same function is mounted on multiple paths, we append a suffix
                path = _flask_to_openapi_path(str(rule))

                # Check if this operation (path + method) is already registered
                op_key = f"{method}:{path}"
                if op_key in registered_ops:
                    continue

                # Determine tags - create a copy to avoid mutating shared metadata
                tags = list(metadata.get("tags") or ["rest"])
                if path.startswith("/mcp/"):
                    # Move specific tags to secondary position or just add MCP
                    if "rest" in tags:
                        tags.remove("rest")
                    if "mcp" not in tags:
                        tags.append("mcp")

                # Ensure unique operationId
                # NOTE(review): _operation_ids is registry-internal state read here
                # without holding its lock — presumably safe because introspection
                # runs single-threaded at startup; confirm against callers.
                original_op_id = op_id
                unique_op_id = op_id
                count = 1
                while unique_op_id in _operation_ids:
                    unique_op_id = f"{op_id}_{count}"
                    count += 1

                register_tool(
                    path=path,
                    method=method,
                    operation_id=unique_op_id,
                    # Only record the base ID when a suffix was actually applied,
                    # so disablement of the base ID covers all mounted copies.
                    original_operation_id=original_op_id if unique_op_id != original_op_id else None,
                    summary=metadata["summary"],
                    description=metadata["description"],
                    request_model=metadata.get("request_model"),
                    response_model=metadata.get("response_model"),
                    path_params=metadata.get("path_params"),
                    query_params=metadata.get("query_params"),
                    tags=tags,
                    allow_multipart_payload=metadata.get("allow_multipart_payload", False)
                )
                registered_ops.add(op_key)
|
||||
158
server/api_server/openapi/registry.py
Normal file
158
server/api_server/openapi/registry.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import List, Dict, Any, Literal, Optional, Type, Set
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Thread-safe registry: all module-level state below must only be touched
# while holding _registry_lock.
# Registered endpoint entries, in registration order.
_registry: List[Dict[str, Any]] = []
_registry_lock = threading.Lock()
# Every operationId ever registered (uniqueness guard for register_tool).
_operation_ids: Set[str] = set()
# operationIds currently disabled at runtime (see set_tool_disabled).
_disabled_tools: Set[str] = set()
|
||||
|
||||
|
||||
class DuplicateOperationIdError(Exception):
    """Signals that the same operationId was registered twice; operationIds
    must be unique across the whole API spec."""
|
||||
|
||||
|
||||
def set_tool_disabled(operation_id: str, disabled: bool = True) -> bool:
    """
    Toggle a tool's disabled state by its operation_id.

    Args:
        operation_id: The unique operation_id of the tool
        disabled: True to disable, False to enable

    Returns:
        bool: True if operation_id exists, False otherwise
    """
    with _registry_lock:
        known = operation_id in _operation_ids
        if known:
            # add() disables; discard() re-enables (and is a no-op when
            # the id was never disabled).
            mutate = _disabled_tools.add if disabled else _disabled_tools.discard
            mutate(operation_id)
        return known
|
||||
|
||||
|
||||
def is_tool_disabled(operation_id: str) -> bool:
    """
    Report whether a tool is disabled.

    Considers both the (possibly suffixed) unique operation_id and the
    original base operation_id recorded at registration time.
    """
    with _registry_lock:
        if operation_id in _disabled_tools:
            return True

        # A tool is also considered disabled when its base ID was disabled.
        return any(
            entry.get("original_operation_id")
            and entry.get("original_operation_id") in _disabled_tools
            for entry in _registry
            if entry["operation_id"] == operation_id
        )
|
||||
|
||||
|
||||
def get_disabled_tools() -> List[str]:
    """Return the operation_ids currently disabled (snapshot copy)."""
    with _registry_lock:
        return [*_disabled_tools]
|
||||
|
||||
|
||||
def get_tools_status() -> List[Dict[str, Any]]:
    """
    Summarize every registered tool with its disabled flag.

    Intended for backend-to-frontend status reporting; "disabled" is always
    a plain bool, covering both the unique and the original operation_id.
    """
    with _registry_lock:
        # Snapshot so the comprehension sees one consistent disabled-set.
        disabled_now = _disabled_tools.copy()
        return [
            {
                "operation_id": entry["operation_id"],
                "summary": entry["summary"],
                "disabled": bool(
                    entry["operation_id"] in disabled_now
                    or (
                        entry.get("original_operation_id")
                        and entry.get("original_operation_id") in disabled_now
                    )
                ),
            }
            for entry in _registry
        ]
|
||||
|
||||
|
||||
def register_tool(
    path: str,
    method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"],
    operation_id: str,
    summary: str,
    description: str,
    request_model: Optional[Type[BaseModel]] = None,
    response_model: Optional[Type[BaseModel]] = None,
    path_params: Optional[List[Dict[str, Any]]] = None,
    query_params: Optional[List[Dict[str, Any]]] = None,
    tags: Optional[List[str]] = None,
    deprecated: bool = False,
    original_operation_id: Optional[str] = None,
    allow_multipart_payload: bool = False
) -> None:
    """
    Add an API endpoint to the OpenAPI generation registry.

    Args:
        path: URL path (e.g., "/devices/{mac}")
        method: HTTP method
        operation_id: Unique identifier for this operation (MUST be unique
            across the entire spec)
        summary: Short summary for the operation
        description: Detailed description
        request_model: Pydantic model for request body (POST/PUT/PATCH)
        response_model: Pydantic model for success response
        path_params: List of path parameter definitions
        query_params: List of query parameter definitions
        tags: OpenAPI tags for grouping
        deprecated: Whether this endpoint is deprecated
        original_operation_id: The base ID before suffixing (for disablement
            mapping)
        allow_multipart_payload: Whether to allow multipart/form-data payloads

    Raises:
        DuplicateOperationIdError: If operation_id already exists in registry
    """
    with _registry_lock:
        # Uniqueness is enforced eagerly so a bad registration fails loudly
        # at startup instead of producing an invalid spec later.
        if operation_id in _operation_ids:
            raise DuplicateOperationIdError(
                f"operationId '{operation_id}' is already registered. "
                "Each operationId must be unique across the entire API."
            )
        _operation_ids.add(operation_id)

        entry = dict(
            path=path,
            method=method.upper(),
            operation_id=operation_id,
            original_operation_id=original_operation_id,
            summary=summary,
            description=description,
            request_model=request_model,
            response_model=response_model,
            path_params=path_params or [],
            query_params=query_params or [],
            tags=tags or ["default"],
            deprecated=deprecated,
            allow_multipart_payload=allow_multipart_payload,
        )
        _registry.append(entry)
|
||||
|
||||
|
||||
def clear_registry() -> None:
    """Reset every piece of registry state (useful for testing)."""
    with _registry_lock:
        for store in (_registry, _operation_ids, _disabled_tools):
            store.clear()
|
||||
|
||||
|
||||
def get_registry() -> List[Dict[str, Any]]:
    """Return a deep copy of the registry so callers cannot mutate it."""
    with _registry_lock:
        snapshot = deepcopy(_registry)
    return snapshot
|
||||
216
server/api_server/openapi/schema_converter.py
Normal file
216
server/api_server/openapi/schema_converter.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Any, Optional, Type, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """
    Render a Pydantic model as JSON Schema (OpenAPI 3.1 compatible).

    Relies on Pydantic v2's built-in generator, which emits JSON Schema
    Draft 2020-12 compatible output.

    Args:
        model: Pydantic BaseModel class

    Returns:
        JSON Schema dictionary
    """
    schema = model.model_json_schema(mode="serialization")

    # An empty "$defs" mapping carries no information — drop it for
    # cleaner output. (Missing or non-empty "$defs" is kept as-is.)
    if "$defs" in schema and not schema["$defs"]:
        schema.pop("$defs")

    return schema
|
||||
|
||||
|
||||
def build_parameters(entry: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Assemble the OpenAPI "parameters" array from a registry entry.

    Path parameters come first (always required per the OpenAPI spec),
    followed by query parameters (required only when flagged). Missing
    descriptions default to "" and missing schemas to {"type": "string"}.
    """
    path_parameters = [
        {
            "name": spec["name"],
            "in": "path",
            "required": True,
            "description": spec.get("description", ""),
            "schema": spec.get("schema", {"type": "string"}),
        }
        for spec in entry.get("path_params", [])
    ]

    query_parameters = [
        {
            "name": spec["name"],
            "in": "query",
            "required": spec.get("required", False),
            "description": spec.get("description", ""),
            "schema": spec.get("schema", {"type": "string"}),
        }
        for spec in entry.get("query_params", [])
    ]

    return path_parameters + query_parameters
|
||||
|
||||
|
||||
def extract_definitions(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]:
    """Hoist "$defs" out of *schema* into the shared *definitions* pool.

    Mutates *schema* in place (and returns it): local "$defs" entries are
    moved into *definitions* (each processed recursively first), and every
    "#/$defs/..." reference is rewritten to "#/components/schemas/...".
    Non-dict nodes pass through untouched.
    """
    if not isinstance(schema, dict):
        return schema

    # Move local definitions into the shared pool, recursing into each one
    # so nested $defs/$ref structures are normalized as well.
    local_defs = schema.pop("$defs", None)
    if local_defs is not None:
        for def_name, def_schema in local_defs.items():
            definitions[def_name] = extract_definitions(def_schema, definitions)

    # Redirect local references to the OpenAPI components section.
    if "$ref" in schema and schema["$ref"].startswith("#/$defs/"):
        ref_name = schema["$ref"].rsplit("/", 1)[-1]
        schema["$ref"] = f"#/components/schemas/{ref_name}"

    # Walk nested containers so deeply nested references get rewritten too.
    for key, value in schema.items():
        if isinstance(value, list):
            schema[key] = [extract_definitions(item, definitions) for item in value]
        elif isinstance(value, dict):
            schema[key] = extract_definitions(value, definitions)

    return schema
|
||||
|
||||
|
||||
def build_request_body(
    model: Optional[Type[BaseModel]],
    definitions: Dict[str, Any],
    allow_multipart_payload: bool = False
) -> Optional[Dict[str, Any]]:
    """Derive the OpenAPI requestBody object from a Pydantic model.

    Returns None when no request model is declared. The same schema is
    offered as application/json and — when allow_multipart_payload is set —
    additionally as multipart/form-data.
    """
    if model is None:
        return None

    body_schema = extract_definitions(pydantic_to_json_schema(model), definitions)

    media_types = ["application/json"]
    if allow_multipart_payload:
        media_types.append("multipart/form-data")

    return {
        "required": True,
        "content": {media_type: {"schema": body_schema} for media_type in media_types},
    }
|
||||
|
||||
|
||||
def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *schema* without validation constraints.

    Structure, types, and descriptions survive; pattern, min/max bounds,
    item/property count limits, etc. are dropped — recursively through
    properties, items, allOf/anyOf/oneOf, $defs, and additionalProperties.
    This saves context tokens for LLMs which don't validate server output.
    """
    if not isinstance(schema, dict):
        return schema

    # Constraint keywords to strip at every level.
    dropped = frozenset((
        "pattern", "minLength", "maxLength", "minimum", "maximum",
        "exclusiveMinimum", "exclusiveMaximum", "multipleOf", "minItems",
        "maxItems", "uniqueItems", "minProperties", "maxProperties",
    ))

    cleaned = {key: value for key, value in schema.items() if key not in dropped}

    # Mapping-valued keywords: recurse into each named sub-schema.
    for map_key in ("properties", "$defs"):
        if map_key in cleaned:
            cleaned[map_key] = {
                name: strip_validation(sub) for name, sub in cleaned[map_key].items()
            }

    # List-valued combinators: recurse into each alternative.
    for list_key in ("allOf", "anyOf", "oneOf"):
        if list_key in cleaned:
            cleaned[list_key] = [strip_validation(sub) for sub in cleaned[list_key]]

    if "items" in cleaned:
        cleaned["items"] = strip_validation(cleaned["items"])

    if "additionalProperties" in cleaned and isinstance(cleaned["additionalProperties"], dict):
        cleaned["additionalProperties"] = strip_validation(cleaned["additionalProperties"])

    return cleaned
|
||||
|
||||
|
||||
def build_responses(
    response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any]
) -> Dict[str, Any]:
    """Construct the OpenAPI responses object for one operation.

    The 200 response uses the declared model's schema (with validation
    constraints stripped to save LLM context tokens) or a generic
    success/message envelope when no model is given. Standard error codes
    are listed description-only — no content schema — to keep the spec small;
    the LLM knows what "Bad Request" or "Not Found" means.
    """
    if response_model:
        # Strip validation from response schema to save tokens.
        success_schema = extract_definitions(
            strip_validation(pydantic_to_json_schema(response_model)), definitions
        )
    else:
        # Fallback envelope used by endpoints without a response model.
        success_schema = {
            "type": "object",
            "properties": {
                "success": {"type": "boolean"},
                "message": {"type": "string"}
            }
        }

    responses: Dict[str, Any] = {
        "200": {
            "description": "Successful response",
            "content": {
                "application/json": {
                    "schema": success_schema
                }
            }
        }
    }

    # Standard error responses - MINIMIZED context (no "content" schema).
    for code, description in (
        ("400", "Bad Request"),
        ("401", "Unauthorized"),
        ("403", "Forbidden"),
        ("404", "Not Found"),
        ("422", "Validation Error"),
        ("500", "Internal Server Error"),
    ):
        responses[code] = {"description": description}

    return responses
|
||||
738
server/api_server/openapi/schemas.py
Normal file
738
server/api_server/openapi/schemas.py
Normal file
@@ -0,0 +1,738 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
NetAlertX API Schema Definitions (Pydantic v2)
|
||||
|
||||
This module defines strict Pydantic models for all API request and response payloads.
|
||||
These schemas serve as the single source of truth for:
|
||||
1. Runtime validation of incoming requests
|
||||
2. OpenAPI specification generation
|
||||
3. MCP tool input schema derivation
|
||||
|
||||
Philosophy: "Code First, Spec Second" — these models ARE the contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import ipaddress
|
||||
from typing import Optional, List, Literal, Any, Dict
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict, RootModel
|
||||
|
||||
# Internal helper imports
|
||||
from helper import sanitize_string
|
||||
from plugin_helper import normalize_mac, is_mac
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# COMMON PATTERNS & VALIDATORS
|
||||
# =============================================================================
|
||||
|
||||
# Regex for a 6-octet MAC address with ":" or "-" separators.
MAC_PATTERN = r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$"
# Strict dotted-quad IPv4 regex (each octet constrained to 0-255).
IP_PATTERN = r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
# Identifier whitelist used to guard dynamic column names (SQL-injection defense).
COLUMN_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+$")

# Security whitelists & Literals for documentation.
# Columns a client may update directly on a device record.
ALLOWED_DEVICE_COLUMNS = Literal[
    "devName", "devOwner", "devType", "devVendor",
    "devGroup", "devLocation", "devComments", "devFavorite",
    "devParentMAC"
]

# NMAP modes/arguments clients may request (restricted to safe options).
ALLOWED_NMAP_MODES = Literal[
    "quick", "intense", "ping", "comprehensive", "fast", "normal", "detail", "skipdiscovery",
    "-sS", "-sT", "-sU", "-sV", "-O"
]

# Severity levels accepted for user-facing notifications.
NOTIFICATION_LEVELS = Literal["info", "warning", "error", "alert"]

# Database tables exposed through the generic table endpoints.
ALLOWED_TABLES = Literal["Devices", "Events", "Sessions", "Settings", "CurrentScan", "Online_History", "Plugins_Objects"]

# Log files that may be fetched via the API.
ALLOWED_LOG_FILES = Literal[
    "app.log", "app_front.log", "IP_changes.log", "stdout.log", "stderr.log",
    "app.php_errors.log", "execution_queue.log", "db_is_locked.log"
]
|
||||
|
||||
|
||||
def validate_mac(value: str) -> str:
    """Validate a MAC address and return it in normalized form.

    The literal "internet" (any case) is accepted as the pseudo-device
    representing the gateway/WAN and returned as "Internet".

    Raises:
        ValueError: when the value is not a recognizable MAC address.
    """
    if value.lower() == "internet":
        return "Internet"

    if is_mac(value):
        return normalize_mac(value)

    raise ValueError(f"Invalid MAC address format: {value}")
|
||||
|
||||
|
||||
def validate_ip(value: str) -> str:
    """Validate an IPv4 or IPv6 address via stdlib ipaddress.

    Returns the canonical string form (e.g. lowercase, compressed IPv6).

    Raises:
        ValueError: when the value does not parse as an IP address.
    """
    try:
        parsed = ipaddress.ip_address(value)
    except ValueError as err:
        raise ValueError(f"Invalid IP address: {value}") from err
    return str(parsed)
|
||||
|
||||
|
||||
def validate_column_identifier(value: str) -> str:
    """Reject column identifiers outside [A-Za-z0-9_] to block SQL injection."""
    if COLUMN_NAME_PATTERN.match(value):
        return value
    raise ValueError("Invalid column name format")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BASE RESPONSE MODELS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
    """Standard API response wrapper (envelope for all JSON endpoints)."""
    # extra="allow" lets individual endpoints attach ad-hoc fields
    # to the envelope without declaring a dedicated subclass.
    model_config = ConfigDict(extra="allow")

    success: bool = Field(..., description="Whether the operation succeeded")
    message: Optional[str] = Field(None, description="Human-readable message")
    error: Optional[str] = Field(None, description="Error message if success=False")
|
||||
|
||||
|
||||
class PaginatedResponse(BaseResponse):
    """Response envelope extended with pagination metadata."""
    total: int = Field(0, description="Total number of items")
    page: int = Field(1, ge=1, description="Current page number")
    # Page size is capped at 500 to bound response payloads.
    per_page: int = Field(50, ge=1, le=500, description="Items per page")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DEVICE SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DeviceSearchRequest(BaseModel):
    """Request payload for searching devices."""
    # Strip surrounding whitespace from all string fields before validation.
    model_config = ConfigDict(str_strip_whitespace=True)

    query: str = Field(
        ...,
        min_length=1,
        max_length=256,
        description="Search term: IP address, MAC address, device name, or vendor",
        json_schema_extra={"examples": ["192.168.1.1", "Apple", "00:11:22:33:44:55"]}
    )
    limit: int = Field(
        50,
        ge=1,
        le=500,
        description="Maximum number of results to return"
    )
|
||||
|
||||
|
||||
class DeviceInfo(BaseModel):
    """Detailed device information model (raw DB record).

    extra="allow" keeps any additional Devices-table columns returned by
    queries instead of silently dropping them.
    """
    model_config = ConfigDict(extra="allow")

    devMac: str = Field(..., description="Device MAC address")
    devName: Optional[str] = Field(None, description="Device display name/alias")
    devLastIP: Optional[str] = Field(None, description="Last known IP address")
    devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup")
    devOwner: Optional[str] = Field(None, description="Device owner")
    devType: Optional[str] = Field(None, description="Device type classification")
    devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)")
    devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)")
    devStatus: Optional[str] = Field(None, description="Online/Offline status")
|
||||
|
||||
|
||||
class DeviceSearchResponse(BaseResponse):
    """Response payload for device search."""
    devices: List[DeviceInfo] = Field(default_factory=list, description="List of matching devices")
|
||||
|
||||
|
||||
class DeviceListRequest(BaseModel):
    """Request for listing devices, optionally filtered by status."""
    # None means "no filter" — return all devices.
    status: Optional[Literal[
        "connected", "down", "favorites", "new", "archived", "all", "my",
        "offline"
    ]] = Field(
        None,
        description="Filter devices by status (connected, down, favorites, new, archived, all, my, offline)"
    )
|
||||
|
||||
|
||||
class DeviceListResponse(RootModel):
    """Bare-array response: serializes as a JSON list of devices."""
    root: List[DeviceInfo] = Field(default_factory=list, description="List of devices")
|
||||
|
||||
|
||||
class DeviceListWrapperResponse(BaseResponse):
    """Enveloped variant of the device list (success/message + devices)."""
    devices: List[DeviceInfo] = Field(default_factory=list, description="List of devices")
|
||||
|
||||
|
||||
class GetDeviceRequest(BaseModel):
    """Path parameter for getting a specific device."""
    mac: str = Field(
        ...,
        description="Device MAC address",
        json_schema_extra={"examples": ["00:11:22:33:44:55"]}
    )

    @field_validator("mac")
    @classmethod
    def validate_mac_address(cls, v: str) -> str:
        # Normalizes the MAC; raises ValueError for malformed input.
        return validate_mac(v)
|
||||
|
||||
|
||||
class GetDeviceResponse(BaseResponse):
    """Wrapped response for getting device details; device is None when not found."""
    device: Optional[DeviceInfo] = Field(None, description="Device details if found")
|
||||
|
||||
|
||||
class GetDeviceWrapperResponse(BaseResponse):
    """Wrapped response for getting a single device (e.g. latest)."""
    device: Optional[DeviceInfo] = Field(None, description="Device details")
|
||||
|
||||
|
||||
class SetDeviceAliasRequest(BaseModel):
    """Request to set a device alias/name."""
    alias: str = Field(
        ...,
        min_length=1,
        max_length=128,
        description="New display name/alias for the device"
    )

    @field_validator("alias")
    @classmethod
    def sanitize_alias(cls, v: str) -> str:
        # Strip unsafe characters before the alias is stored/displayed.
        return sanitize_string(v)
|
||||
|
||||
|
||||
class DeviceTotalsResponse(RootModel):
    """Bare-array response with device statistics (positional counts)."""
    root: List[int] = Field(default_factory=list, description="List of counts: [all, online, favorites, new, offline, archived]")
|
||||
|
||||
|
||||
class DeviceExportRequest(BaseModel):
    """Request for exporting devices."""
    format: Literal["csv", "json"] = Field(
        "csv",
        description="Export format: csv or json"
    )
|
||||
|
||||
|
||||
class DeviceExportResponse(BaseModel):
    """Raw response for device export in JSON format (no success envelope)."""
    columns: List[str] = Field(..., description="Column names")
    data: List[Dict[str, Any]] = Field(..., description="Device records")
|
||||
|
||||
|
||||
class DeviceImportRequest(BaseModel):
    """Request for importing devices."""
    # Optional: the endpoint may alternatively accept a multipart file upload.
    content: Optional[str] = Field(
        None,
        description="Base64-encoded CSV or JSON content to import"
    )
|
||||
|
||||
|
||||
class DeviceImportResponse(BaseResponse):
    """Response for device import operation."""
    imported: int = Field(0, description="Number of devices imported")
    skipped: int = Field(0, description="Number of devices skipped")
    errors: List[str] = Field(default_factory=list, description="List of import errors")
|
||||
|
||||
|
||||
class CopyDeviceRequest(BaseModel):
    """Request to copy device settings from one device to another."""
    macFrom: str = Field(..., description="Source MAC address")
    macTo: str = Field(..., description="Destination MAC address")

    @field_validator("macFrom", "macTo")
    @classmethod
    def validate_mac_addresses(cls, v: str) -> str:
        # Both endpoints of the copy must be valid, normalized MACs.
        return validate_mac(v)
|
||||
|
||||
|
||||
class UpdateDeviceColumnRequest(BaseModel):
    """Request to update a specific device database column."""
    # Restricted to the ALLOWED_DEVICE_COLUMNS whitelist — never arbitrary SQL identifiers.
    columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name")
    columnValue: Any = Field(..., description="New value for the column")
|
||||
|
||||
|
||||
class DeviceUpdateRequest(BaseModel):
    """Request to update device fields (create/update)."""
    # extra="allow" permits updating additional Devices-table columns
    # beyond the explicitly declared ones.
    model_config = ConfigDict(extra="allow")

    devName: Optional[str] = Field(None, description="Device name")
    devOwner: Optional[str] = Field(None, description="Device owner")
    devType: Optional[str] = Field(None, description="Device type")
    devVendor: Optional[str] = Field(None, description="Device vendor")
    devGroup: Optional[str] = Field(None, description="Device group")
    devLocation: Optional[str] = Field(None, description="Device location")
    devComments: Optional[str] = Field(None, description="Comments")
    createNew: bool = Field(False, description="Create new device if not exists")

    @field_validator("devName", "devOwner", "devType", "devVendor", "devGroup", "devLocation", "devComments")
    @classmethod
    def sanitize_text_fields(cls, v: Optional[str]) -> Optional[str]:
        # Sanitize only when provided; None means "leave field unchanged".
        if v is None:
            return v
        return sanitize_string(v)
|
||||
|
||||
|
||||
class DeleteDevicesRequest(BaseModel):
    """Request to delete multiple devices.

    Safety contract: an empty ``macs`` list only deletes ALL devices when the
    caller explicitly opts in via ``confirm_delete_all``.
    """
    # default_factory=list instead of a literal [] default: avoids the
    # mutable-default pitfall and matches the convention used by every
    # other list field in this module (e.g. DeviceSearchResponse.devices).
    macs: List[str] = Field(default_factory=list, description="List of MACs to delete")
    confirm_delete_all: bool = Field(False, description="Explicit flag to delete ALL devices when macs is empty")

    @field_validator("macs")
    @classmethod
    def validate_mac_list(cls, v: List[str]) -> List[str]:
        # Normalize every entry; raises ValueError on the first invalid MAC.
        return [validate_mac(mac) for mac in v]

    @model_validator(mode="after")
    def check_delete_all_safety(self) -> DeleteDevicesRequest:
        # Guard against accidentally wiping the whole device inventory.
        if not self.macs and not self.confirm_delete_all:
            raise ValueError("Must provide at least one MAC or set confirm_delete_all=True")
        return self
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NETWORK TOOLS SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TriggerScanRequest(BaseModel):
    """Request to trigger a network scan."""
    # Plugin identifier; defaults to the ARP-based discovery scan.
    type: str = Field(
        "ARPSCAN",
        description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)",
        json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]}
    )
|
||||
|
||||
|
||||
class TriggerScanResponse(BaseResponse):
    """Response for scan trigger."""
    scan_type: Optional[str] = Field(None, description="Type of scan that was triggered")
|
||||
|
||||
|
||||
class OpenPortsRequest(BaseModel):
    """Request for getting open ports."""
    target: str = Field(
        ...,
        description="Target IP address or MAC address to check ports for",
        json_schema_extra={"examples": ["192.168.1.50", "00:11:22:33:44:55"]}
    )

    @field_validator("target")
    @classmethod
    def validate_target(cls, v: str) -> str:
        """Validate target is either a valid IP or MAC address."""
        # Try IP first; fall through to MAC validation, whose ValueError
        # becomes the validation failure if neither form matches.
        try:
            return validate_ip(v)
        except ValueError:
            pass
        # Try MAC
        return validate_mac(v)
|
||||
|
||||
|
||||
class OpenPortsResponse(BaseResponse):
    """Response with open ports information."""
    target: str = Field(..., description="Target that was scanned")
    # Items may be plain port numbers or richer per-port objects, depending on scanner output.
    open_ports: List[Any] = Field(default_factory=list, description="List of open port objects or numbers")
|
||||
|
||||
|
||||
class WakeOnLanRequest(BaseModel):
    """Request to send Wake-on-LAN packet.

    At least one of devMac / devLastIP must be supplied (enforced below).
    """
    devMac: Optional[str] = Field(
        None,
        description="Target device MAC address",
        json_schema_extra={"examples": ["00:11:22:33:44:55"]}
    )
    devLastIP: Optional[str] = Field(
        None,
        alias="ip",
        description="Target device IP (MAC will be resolved if not provided)",
        json_schema_extra={"examples": ["192.168.1.50"]}
    )
    # Note: alias="ip" means input JSON can use "ip".
    # But Pydantic V2 with populate_by_name=True allows both "devLastIP" and "ip".
    model_config = ConfigDict(populate_by_name=True)

    @field_validator("devMac")
    @classmethod
    def validate_mac_if_provided(cls, v: Optional[str]) -> Optional[str]:
        # Field is optional; only validate when present.
        if v is not None:
            return validate_mac(v)
        return v

    @field_validator("devLastIP")
    @classmethod
    def validate_ip_if_provided(cls, v: Optional[str]) -> Optional[str]:
        # Field is optional; only validate when present.
        if v is not None:
            return validate_ip(v)
        return v

    @model_validator(mode="after")
    def require_mac_or_ip(self) -> "WakeOnLanRequest":
        """Ensure at least one of devMac or devLastIP is provided."""
        if self.devMac is None and self.devLastIP is None:
            raise ValueError("Either 'devMac' or 'devLastIP' (alias 'ip') must be provided")
        return self
|
||||
|
||||
|
||||
class WakeOnLanResponse(BaseResponse):
|
||||
"""Response for Wake-on-LAN operation."""
|
||||
output: Optional[str] = Field(None, description="Command output")
|
||||
|
||||
|
||||
class TracerouteRequest(BaseModel):
    """Request to perform traceroute."""
    devLastIP: str = Field(
        ...,
        description="Target IP address for traceroute",
        json_schema_extra={"examples": ["8.8.8.8", "192.168.1.1"]}
    )

    @field_validator("devLastIP")
    @classmethod
    def validate_ip_address(cls, v: str) -> str:
        # Reject anything that is not a well-formed IP before the tool runs.
        return validate_ip(v)


class TracerouteResponse(BaseResponse):
    """Response with traceroute results."""
    output: List[str] = Field(default_factory=list, description="Traceroute hop output lines")


class NmapScanRequest(BaseModel):
    """Request to perform NMAP scan."""
    scan: str = Field(
        ...,
        description="Target IP address for NMAP scan"
    )
    # mode is constrained to the ALLOWED_NMAP_MODES whitelist so arbitrary
    # nmap arguments cannot be injected through the API.
    mode: ALLOWED_NMAP_MODES = Field(
        ...,
        description="NMAP scan mode/arguments (restricted to safe options)"
    )

    @field_validator("scan")
    @classmethod
    def validate_scan_target(cls, v: str) -> str:
        # Target must parse as an IP address (validate_ip raises otherwise).
        return validate_ip(v)
|
||||
|
||||
|
||||
class NslookupRequest(BaseModel):
    """Request for DNS lookup."""
    devLastIP: str = Field(
        ...,
        description="IP address to perform reverse DNS lookup"
    )

    @field_validator("devLastIP")
    @classmethod
    def validate_ip_address(cls, v: str) -> str:
        # Only a valid IP is accepted for the reverse lookup.
        return validate_ip(v)


class NslookupResponse(BaseResponse):
    """Response for DNS lookup operation."""
    output: List[str] = Field(default_factory=list, description="Nslookup output lines")


class NmapScanResponse(BaseResponse):
    """Response for NMAP scan operation."""
    mode: Optional[str] = Field(None, description="NMAP scan mode")
    ip: Optional[str] = Field(None, description="Target IP address")
    output: List[str] = Field(default_factory=list, description="NMAP scan output lines")


class NetworkTopologyResponse(BaseResponse):
    """Response with network topology data."""
    # Node/link dicts are passed through untyped; shape defined by the producer.
    nodes: List[dict] = Field(default_factory=list, description="Network nodes")
    links: List[dict] = Field(default_factory=list, description="Network connections")


class InternetInfoResponse(BaseResponse):
    """Response for internet information."""
    output: Dict[str, Any] = Field(..., description="Details about the internet connection.")


class NetworkInterfacesResponse(BaseResponse):
    """Response with network interface information."""
    interfaces: Dict[str, Any] = Field(..., description="Details about network interfaces.")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# EVENTS SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class EventInfo(BaseModel):
    """Event/alert information."""
    # extra="allow" keeps unknown DB columns instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    eveRowid: Optional[int] = Field(None, description="Event row ID")
    eveMAC: Optional[str] = Field(None, description="Device MAC address")
    eveIP: Optional[str] = Field(None, description="Device IP address")
    eveDateTime: Optional[str] = Field(None, description="Event timestamp")
    eveEventType: Optional[str] = Field(None, description="Type of event")
    evePreviousIP: Optional[str] = Field(None, description="Previous IP if changed")


class RecentEventsRequest(BaseModel):
    """Request for recent events."""
    # Look-back window bounded to 1 hour .. 720 hours (30 days).
    hours: int = Field(
        24,
        ge=1,
        le=720,
        description="Number of hours to look back for events"
    )
    # Hard cap of 1000 rows guards against oversized responses.
    limit: int = Field(
        100,
        ge=1,
        le=1000,
        description="Maximum number of events to return"
    )


class RecentEventsResponse(BaseResponse):
    """Response with recent events."""
    hours: int = Field(..., description="The time window in hours")
    events: List[EventInfo] = Field(default_factory=list, description="List of recent events")


class LastEventsResponse(BaseResponse):
    """Response with last N events."""
    events: List[EventInfo] = Field(default_factory=list, description="List of last events")


class CreateEventRequest(BaseModel):
    """Request to create a device event."""
    ip: Optional[str] = Field("0.0.0.0", description="Device IP")
    event_type: str = Field("Device Down", description="Event type")
    additional_info: Optional[str] = Field("", description="Additional info")
    pending_alert: int = Field(1, description="Pending alert flag")
    event_time: Optional[str] = Field(None, description="Event timestamp (ISO)")

    # mode="before" so None/"" can be normalized before type coercion runs.
    @field_validator("ip", mode="before")
    @classmethod
    def validate_ip_field(cls, v: Optional[str]) -> str:
        """Validate and normalize IP address, defaulting to 0.0.0.0."""
        if v is None or v == "":
            return "0.0.0.0"
        return validate_ip(v)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SESSIONS SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SessionInfo(BaseModel):
    """Session information."""
    # extra="allow" keeps unknown DB columns instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    sesRowid: Optional[int] = Field(None, description="Session row ID")
    sesMac: Optional[str] = Field(None, description="Device MAC address")
    sesDateTimeConnection: Optional[str] = Field(None, description="Connection timestamp")
    sesDateTimeDisconnection: Optional[str] = Field(None, description="Disconnection timestamp")
    sesIPAddress: Optional[str] = Field(None, description="IP address during session")


class CreateSessionRequest(BaseModel):
    """Request to create a session."""
    mac: str = Field(..., description="Device MAC")
    ip: str = Field(..., description="Device IP")
    start_time: str = Field(..., description="Start time")
    # end_time is optional: an open session has no disconnection yet.
    end_time: Optional[str] = Field(None, description="End time")
    event_type_conn: str = Field("Connected", description="Connection event type")
    event_type_disc: str = Field("Disconnected", description="Disconnection event type")

    @field_validator("mac")
    @classmethod
    def validate_mac_address(cls, v: str) -> str:
        return validate_mac(v)

    @field_validator("ip")
    @classmethod
    def validate_ip_address(cls, v: str) -> str:
        return validate_ip(v)


class DeleteSessionRequest(BaseModel):
    """Request to delete sessions for a MAC."""
    mac: str = Field(..., description="Device MAC")

    @field_validator("mac")
    @classmethod
    def validate_mac_address(cls, v: str) -> str:
        return validate_mac(v)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MESSAGING / IN-APP NOTIFICATIONS SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class InAppNotification(BaseModel):
    """In-app notification model."""
    # extra="allow" keeps unknown columns instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    id: Optional[int] = Field(None, description="Notification ID")
    guid: Optional[str] = Field(None, description="Unique notification GUID")
    text: str = Field(..., description="Notification text content")
    # Level restricted to the NOTIFICATION_LEVELS literal set; defaults to "info".
    level: NOTIFICATION_LEVELS = Field("info", description="Notification level")
    read: Optional[int] = Field(0, description="Read status (0 or 1)")
    created_at: Optional[str] = Field(None, description="Creation timestamp")


class CreateNotificationRequest(BaseModel):
    """Request to create an in-app notification."""
    # Non-empty content, capped at 1 KiB.
    content: str = Field(
        ...,
        min_length=1,
        max_length=1024,
        description="Notification content"
    )
    level: NOTIFICATION_LEVELS = Field(
        "info",
        description="Notification severity level"
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SYNC SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SyncPushRequest(BaseModel):
    """Request to push data to sync."""
    data: dict = Field(..., description="Data to sync")
    node_name: str = Field(..., description="Name of the node sending data")
    plugin: str = Field(..., description="Plugin identifier")


class SyncPullResponse(BaseResponse):
    """Response with sync data."""
    data: Optional[dict] = Field(None, description="Synchronized data")
    last_sync: Optional[str] = Field(None, description="Last sync timestamp")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DB QUERY SCHEMAS (Raw SQL)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DbQueryRequest(BaseModel):
    """
    Request for raw database query.
    WARNING: This is a highly privileged operation.
    """
    rawSql: str = Field(
        ...,
        description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)"
    )
    # Legacy compatibility: removed strict safety check
    # TODO: SECURITY CRITICAL - Re-enable strict safety checks.
    # The `confirm_dangerous_query` default was relaxed to `True` to maintain backward compatibility
    # with the legacy frontend which sends raw SQL directly.
    #
    # CONTEXT: This explicit safety check was introduced with the new Pydantic validation layer.
    # The legacy PHP frontend predates these formal schemas and does not send the
    # `confirm_dangerous_query` flag, causing 422 Validation Errors when this check is enforced.
    #
    # Actionable Advice:
    # 1. Implement a parser to strictly whitelist only `SELECT` statements if raw SQL is required.
    # 2. Migrate the frontend to use structured endpoints (e.g., `/devices/search`, `/dbquery/read`) instead of raw SQL.
    # 3. Once migrated, revert `confirm_dangerous_query` default to `False` and enforce the check.
    confirm_dangerous_query: bool = Field(
        True,
        description="Required to be True to acknowledge the risks of raw SQL execution"
    )


class DbQueryUpdateRequest(BaseModel):
    """Request for DB update query."""
    columnName: str = Field(..., description="Column to filter by")
    id: List[Any] = Field(..., description="List of IDs to update")
    # Table restricted to the ALLOWED_TABLES whitelist (prevents arbitrary table access).
    dbtable: ALLOWED_TABLES = Field(..., description="Table name")
    columns: List[str] = Field(..., description="Columns to update")
    values: List[Any] = Field(..., description="New values")

    @field_validator("columnName")
    @classmethod
    def validate_column_name(cls, v: str) -> str:
        # Identifier validation guards against SQL injection via column names.
        return validate_column_identifier(v)

    @field_validator("columns")
    @classmethod
    def validate_column_list(cls, values: List[str]) -> List[str]:
        return [validate_column_identifier(value) for value in values]

    @model_validator(mode="after")
    def validate_columns_values(self) -> "DbQueryUpdateRequest":
        # columns[i] is paired with values[i], so the lists must align.
        if len(self.columns) != len(self.values):
            raise ValueError("columns and values must have the same length")
        return self


class DbQueryDeleteRequest(BaseModel):
    """Request for DB delete query."""
    columnName: str = Field(..., description="Column to filter by")
    id: List[Any] = Field(..., description="List of IDs to delete")
    dbtable: ALLOWED_TABLES = Field(..., description="Table name")

    @field_validator("columnName")
    @classmethod
    def validate_column_name(cls, v: str) -> str:
        return validate_column_identifier(v)


class DbQueryResponse(BaseResponse):
    """Response from database query."""
    data: Any = Field(None, description="Query result data")
    columns: Optional[List[str]] = Field(None, description="Column names if applicable")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOGS SCHEMAS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CleanLogRequest(BaseModel):
    """Request to clean/truncate a log file."""
    # File name restricted to the ALLOWED_LOG_FILES whitelist
    # (prevents arbitrary file paths reaching the log-cleaning code).
    logFile: ALLOWED_LOG_FILES = Field(
        ...,
        description="Name of the log file to clean"
    )


class LogResource(BaseModel):
    """Log file resource information."""
    name: str = Field(..., description="Log file name")
    path: str = Field(..., description="Full path to log file")
    size_bytes: int = Field(0, description="File size in bytes")
    modified: Optional[str] = Field(None, description="Last modification timestamp")


class AddToQueueRequest(BaseModel):
    """Request to add action to execution queue."""
    action: str = Field(..., description="Action string (e.g. update_api|devices)")


class SettingValue(BaseModel):
    """A single setting value."""
    key: str = Field(..., description="Setting key name")
    value: Any = Field(..., description="Setting value")


class GetSettingResponse(BaseResponse):
    """Response for getting a setting value."""
    value: Any = Field(None, description="The setting value")
|
||||
191
server/api_server/openapi/spec_generator.py
Normal file
191
server/api_server/openapi/spec_generator.py
Normal file
@@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
NetAlertX OpenAPI Specification Generator
|
||||
|
||||
This module provides a registry-based approach to OpenAPI spec generation.
|
||||
It converts Pydantic models to JSON Schema and assembles a complete OpenAPI 3.1 spec.
|
||||
|
||||
Key Features:
|
||||
- Automatic Pydantic -> JSON Schema conversion
|
||||
- Centralized endpoint registry
|
||||
- Unique operationId enforcement
|
||||
- Complete request/response schema generation
|
||||
|
||||
Usage:
|
||||
from spec_generator import registry, generate_openapi_spec, register_tool
|
||||
|
||||
# Register endpoints (typically done at module load)
|
||||
register_tool(
|
||||
path="/devices/search",
|
||||
method="POST",
|
||||
operation_id="search_devices",
|
||||
description="Search for devices",
|
||||
request_model=DeviceSearchRequest,
|
||||
response_model=DeviceSearchResponse
|
||||
)
|
||||
|
||||
# Generate spec (called by MCP endpoint)
|
||||
spec = generate_openapi_spec()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from .registry import (
|
||||
clear_registry,
|
||||
_registry,
|
||||
_registry_lock,
|
||||
_disabled_tools
|
||||
)
|
||||
from .introspection import introspect_flask_app, introspect_graphql_schema
|
||||
from .schema_converter import (
|
||||
build_parameters,
|
||||
build_request_body,
|
||||
build_responses
|
||||
)
|
||||
|
||||
_rebuild_lock = threading.Lock()
|
||||
|
||||
|
||||
def generate_openapi_spec(
    title: str = "NetAlertX API",
    version: str = "2.0.0",
    description: str = "NetAlertX Network Monitoring API - MCP Compatible",
    servers: Optional[List[Dict[str, str]]] = None,
    flask_app: Optional[Any] = None
) -> Dict[str, Any]:
    """Assemble a complete OpenAPI specification from the registered endpoints.

    Args:
        title: Value for ``info.title`` in the generated spec.
        version: Value for ``info.version``.
        description: Value for ``info.description``.
        servers: Optional ``servers`` array; defaults to a single local server.
        flask_app: Flask app to introspect. If omitted and the registry is
            empty, the app from ``api_server_start`` is used when importable.

    Returns:
        A dict representing an OpenAPI 3.1 document.
    """

    # Serialize spec rebuilds: the registry is cleared and repopulated
    # below, so two overlapping calls must not interleave.
    with _rebuild_lock:
        # If no app provided and registry is empty, try to use the one from api_server_start
        if not flask_app and not _registry:
            try:
                from ..api_server_start import app as start_app
                flask_app = start_app
            except (ImportError, AttributeError):
                # Fall back to whatever is already registered (possibly nothing).
                pass

        # If we are in "dynamic mode", we rebuild the registry from code
        if flask_app:
            # Imported lazily to avoid a circular import at module load time.
            from ..graphql_endpoint import devicesSchema
            clear_registry()
            introspect_graphql_schema(devicesSchema)
            introspect_flask_app(flask_app)

        # Skeleton of the OpenAPI 3.1 document; paths/schemas/tags filled in below.
        spec = {
            "openapi": "3.1.0",
            "info": {
                "title": title,
                "version": version,
                "description": description,
                "contact": {
                    "name": "NetAlertX",
                    "url": "https://github.com/jokob-sk/NetAlertX"
                }
            },
            "servers": servers or [{"url": "/", "description": "Local server"}],
            "security": [
                {"BearerAuth": []}
            ],
            "components": {
                "securitySchemes": {
                    "BearerAuth": {
                        "type": "http",
                        "scheme": "bearer",
                        "description": "API token from NetAlertX settings (API_TOKEN)"
                    }
                },
                "schemas": {}
            },
            "paths": {},
            "tags": []
        }

        # Shared JSON-Schema definitions, populated by the build_* helpers.
        definitions = {}

        # Collect unique tags
        tag_set = set()

        with _registry_lock:
            # Snapshot disabled tools so the check below is consistent
            # even if the set changes concurrently.
            disabled_snapshot = _disabled_tools.copy()
            for entry in _registry:
                path = entry["path"]
                method = entry["method"].lower()

                # Initialize path if not exists
                if path not in spec["paths"]:
                    spec["paths"][path] = {}

                # Build operation object
                operation = {
                    "operationId": entry["operation_id"],
                    "summary": entry["summary"],
                    "description": entry["description"],
                    "tags": entry["tags"],
                    "deprecated": entry["deprecated"]
                }

                # Inject disabled status if applicable
                if entry["operation_id"] in disabled_snapshot:
                    operation["x-mcp-disabled"] = True

                # Inject original ID if suffixed (Coderabbit fix)
                if entry.get("original_operation_id"):
                    operation["x-original-operationId"] = entry["original_operation_id"]

                # Add parameters (path + query)
                parameters = build_parameters(entry)
                if parameters:
                    operation["parameters"] = parameters

                # Add request body for POST/PUT/PATCH/DELETE
                if method in ("post", "put", "patch", "delete") and entry.get("request_model"):
                    request_body = build_request_body(
                        entry["request_model"],
                        definitions,
                        allow_multipart_payload=entry.get("allow_multipart_payload", False)
                    )
                    if request_body:
                        operation["requestBody"] = request_body

                # Add responses
                operation["responses"] = build_responses(
                    entry.get("response_model"), definitions
                )

                spec["paths"][path][method] = operation

                # Collect tags
                for tag in entry["tags"]:
                    tag_set.add(tag)

        spec["components"]["schemas"] = definitions

        # Build tags array with descriptions
        tag_descriptions = {
            "devices": "Device management and queries",
            "nettools": "Network diagnostic tools",
            "events": "Event and alert management",
            "sessions": "Session history tracking",
            "messaging": "In-app notifications",
            "settings": "Configuration management",
            "sync": "Data synchronization",
            "logs": "Log file access",
            "dbquery": "Direct database queries"
        }

        # Unknown tags get a generic "<Tag> operations" description.
        spec["tags"] = [
            {"name": tag, "description": tag_descriptions.get(tag, f"{tag.title()} operations")}
            for tag in sorted(tag_set)
        ]

        return spec
|
||||
|
||||
|
||||
# Initialize registry on module load
|
||||
# Registry is now populated dynamically via introspection in generate_openapi_spec
|
||||
def _register_all_endpoints():
|
||||
"""Dummy function for compatibility with legacy tests."""
|
||||
pass
|
||||
31
server/api_server/openapi/swagger.html
Normal file
31
server/api_server/openapi/swagger.html
Normal file
@@ -0,0 +1,31 @@
|
||||
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <meta name="description" content="NetAlertX API Documentation" />
  <title>NetAlertX API Docs</title>
  <link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.11.0/swagger-ui.css" integrity="sha384-+yyzNgM3K92sROwsXxYCxaiLWxWJ0G+v/9A+qIZ2rgefKgkdcmJI+L601cqPD/Ut" crossorigin="anonymous" />
  <style>
    body { margin: 0; padding: 0; }
  </style>
</head>
<body>
  <div id="swagger-ui"></div>
  <script src="https://unpkg.com/swagger-ui-dist@5.11.0/swagger-ui-bundle.js" integrity="sha384-qn5tagrAjZi8cSmvZ+k3zk4+eDEEUcP9myuR2J6V+/H6rne++v6ChO7EeHAEzqxQ" crossorigin="anonymous"></script>
  <script>
    // Render Swagger UI against the dynamically generated spec.
    window.onload = () => {
      window.ui = SwaggerUIBundle({
        url: '/openapi.json',
        dom_id: '#swagger-ui',
        deepLinking: true,
        // Only swagger-ui-bundle.js is loaded above. The standalone preset
        // lives in the separate swagger-ui-standalone-preset.js script, so
        // "SwaggerUIBundle.SwaggerUIStandalonePreset" resolved to undefined
        // here. "BaseLayout" only requires the apis preset, so the broken
        // reference is removed.
        presets: [
          SwaggerUIBundle.presets.apis
        ],
        layout: "BaseLayout",
      });
    };
  </script>
</body>
</html>
|
||||
181
server/api_server/openapi/validation.py
Normal file
181
server/api_server/openapi/validation.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional, Type
|
||||
from flask import request, jsonify
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from logger import mylog
|
||||
|
||||
|
||||
def _handle_validation_error(e: ValidationError, operation_id: str, validation_error_code: int):
    """Format a Pydantic ValidationError into a JSON error response.

    Logs the failure, derives a short legacy-compatible message from the
    first reported error, and returns a (response, status) pair using the
    configured validation error code.
    """
    mylog("verbose", [f"[Validation] Error for {operation_id}: {e}"])

    # Derive a concise, legacy-compatible message from the first error entry.
    error_msg = "Validation Error"
    errors = e.errors()
    if errors:
        first = errors[0]
        if first['type'] == 'missing':
            loc = first.get('loc')
            field_name = loc[0] if loc else "unknown field"
            error_msg = f"Missing required '{field_name}'"
        else:
            error_msg = f"Validation Error: {first['msg']}"

    body = {
        "success": False,
        "error": error_msg,
        "details": json.loads(e.json()),
    }
    return jsonify(body), validation_error_code
|
||||
|
||||
|
||||
def validate_request(
    operation_id: str,
    summary: str,
    description: str,
    request_model: Optional[Type[BaseModel]] = None,
    response_model: Optional[Type[BaseModel]] = None,
    tags: Optional[list[str]] = None,
    path_params: Optional[list[dict]] = None,
    query_params: Optional[list[dict]] = None,
    validation_error_code: int = 422,
    auth_callable: Optional[Callable[[], bool]] = None,
    allow_multipart_payload: bool = False
):
    """
    Decorator to register a Flask route with the OpenAPI registry and validate incoming requests.

    Features:
    - Auto-registers the endpoint with the OpenAPI spec generator.
    - Validates JSON body against `request_model` (for POST/PUT).
    - Injects the validated Pydantic model as the first argument to the view function.
    - Supports auth_callable to check permissions before validation.
    - Returns 422 (default) if validation fails.
    - allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields.
    """

    def decorator(f: Callable) -> Callable:
        # Detect if f accepts 'payload' argument (unwrap if needed)
        real_f = inspect.unwrap(f)
        sig = inspect.signature(real_f)
        accepts_payload = 'payload' in sig.parameters

        # Stash OpenAPI metadata on the original function so the spec
        # generator's introspection can discover it via this attribute.
        f._openapi_metadata = {
            "operation_id": operation_id,
            "summary": summary,
            "description": description,
            "request_model": request_model,
            "response_model": response_model,
            "tags": tags,
            "path_params": path_params,
            "query_params": query_params,
            "allow_multipart_payload": allow_multipart_payload
        }

        @wraps(f)
        def wrapper(*args, **kwargs):
            # 0. Handle OPTIONS explicitly if it reaches here (CORS preflight)
            if request.method == "OPTIONS":
                return jsonify({"success": True}), 200

            # 1. Check Authorization first (Coderabbit fix)
            if auth_callable and not auth_callable():
                return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403

            validated_instance = None

            # 2. Payload Validation
            if request_model:
                # Helper to detect multipart requests by content-type (not just files)
                is_multipart = (
                    request.content_type and request.content_type.startswith("multipart/")
                )

                if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
                    # Explicit multipart handling (Coderabbit fix)
                    # Check both request.files and content-type for form-only multipart bodies
                    if request.files or is_multipart:
                        if allow_multipart_payload:
                            # Attempt validation from form data if allowed
                            try:
                                data = request.form.to_dict()
                                validated_instance = request_model(**data)
                            except ValidationError as e:
                                mylog("verbose", [f"[Validation] Multipart validation failed for {operation_id}: {e}"])
                                # Only continue without validation if handler doesn't expect payload
                                if accepts_payload:
                                    return _handle_validation_error(e, operation_id, validation_error_code)
                                # Otherwise, handler will process files manually
                        else:
                            # If multipart is not allowed but files are present, we fail fast
                            # This prevents handlers from receiving unexpected None payloads
                            mylog("verbose", [f"[Validation] Multipart bypass attempted for {operation_id} but not allowed."])
                            return jsonify({
                                "success": False,
                                "error": "Invalid Content-Type",
                                "message": "Multipart requests are not allowed for this endpoint"
                            }), 415
                    else:
                        # content_length guard lets empty bodies through (treated
                        # as {} below); non-JSON non-empty bodies are rejected.
                        if not request.is_json and request.content_length:
                            return jsonify({"success": False, "error": "Invalid Content-Type", "message": "Content-Type must be application/json"}), 415

                        try:
                            # silent=False makes malformed JSON raise BadRequest
                            # (handled below) instead of silently returning None.
                            data = request.get_json(silent=False) or {}
                            validated_instance = request_model(**data)
                        except ValidationError as e:
                            return _handle_validation_error(e, operation_id, validation_error_code)
                        except BadRequest as e:
                            mylog("verbose", [f"[Validation] Invalid JSON for {operation_id}: {e}"])
                            return jsonify({
                                "success": False,
                                "error": "Invalid JSON",
                                "message": "Request body must be valid JSON"
                            }), 400
                        except (TypeError, KeyError, AttributeError) as e:
                            mylog("verbose", [f"[Validation] Malformed request for {operation_id}: {e}"])
                            return jsonify({
                                "success": False,
                                "error": "Invalid Request",
                                "message": "Unable to process request body"
                            }), 400
                elif request.method == "GET":
                    # Attempt to validate from query parameters for GET requests
                    try:
                        # request.args is a MultiDict; to_dict() gives first value of each key
                        # which is usually what we want for Pydantic models.
                        data = request.args.to_dict()
                        validated_instance = request_model(**data)
                    except ValidationError as e:
                        return _handle_validation_error(e, operation_id, validation_error_code)
                    except (TypeError, ValueError, KeyError) as e:
                        mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"])
                        return jsonify({
                            "success": False,
                            "error": "Invalid query parameters",
                            "message": "Unable to process query parameters"
                        }), 400
                else:
                    # Unsupported HTTP method with a request_model - fail explicitly
                    mylog("verbose", [f"[Validation] Unsupported HTTP method {request.method} for {operation_id} with request_model"])
                    return jsonify({
                        "success": False,
                        "error": "Method Not Allowed",
                        "message": f"HTTP method {request.method} is not supported for this endpoint"
                    }), 405

            # NOTE(review): this truthiness check relies on model instances
            # being truthy (pydantic BaseModel does not define __bool__) — confirm.
            if validated_instance:
                if accepts_payload:
                    kwargs['payload'] = validated_instance
                else:
                    # Fail fast if decorated function doesn't accept payload (Coderabbit fix)
                    mylog("minimal", [f"[Validation] Endpoint {operation_id} does not accept 'payload' argument!"])
                    raise TypeError(f"Function {f.__name__} (operationId: {operation_id}) does not accept 'payload' argument.")

            return f(*args, **kwargs)

        return wrapper
    return decorator
|
||||
Reference in New Issue
Block a user