mirror of
https://github.com/jokob-sk/NetAlertX.git
synced 2026-04-05 01:31:49 -07:00
Fixes for Coderabbit review
This commit is contained in:
@@ -7,6 +7,7 @@ import os
|
|||||||
from flask import Flask, request, jsonify, Response
|
from flask import Flask, request, jsonify, Response
|
||||||
from models.device_instance import DeviceInstance # noqa: E402
|
from models.device_instance import DeviceInstance # noqa: E402
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
|
from werkzeug.exceptions import HTTPException
|
||||||
|
|
||||||
# Register NetAlertX directories
|
# Register NetAlertX directories
|
||||||
INSTALL_PATH = os.getenv("NETALERTX_APP", "/app")
|
INSTALL_PATH = os.getenv("NETALERTX_APP", "/app")
|
||||||
@@ -105,6 +106,8 @@ app = Flask(__name__)
|
|||||||
@app.errorhandler(Exception)
|
@app.errorhandler(Exception)
|
||||||
def handle_500_error(e):
|
def handle_500_error(e):
|
||||||
"""Global error handler for uncaught exceptions."""
|
"""Global error handler for uncaught exceptions."""
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
|
return e
|
||||||
mylog("none", [f"[API] Uncaught exception: {e}"])
|
mylog("none", [f"[API] Uncaught exception: {e}"])
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
|
|||||||
@@ -679,6 +679,7 @@ class CreateEventRequest(BaseModel):
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
# SESSIONS SCHEMAS
|
# SESSIONS SCHEMAS
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
class SessionInfo(BaseModel):
|
||||||
"""Session information."""
|
"""Session information."""
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
extra="allow",
|
extra="allow",
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ def generate_openapi_spec(
|
|||||||
# Apply default disabled tools from setting `MCP_DISABLED_TOOLS`, env var, or hard-coded defaults
|
# Apply default disabled tools from setting `MCP_DISABLED_TOOLS`, env var, or hard-coded defaults
|
||||||
# Format: comma-separated operation IDs, e.g. "dbquery_read,dbquery_write"
|
# Format: comma-separated operation IDs, e.g. "dbquery_read,dbquery_write"
|
||||||
try:
|
try:
|
||||||
disabled_env = ""
|
disabled_env = None
|
||||||
# Prefer setting from app.conf/settings when available
|
# Prefer setting from app.conf/settings when available
|
||||||
try:
|
try:
|
||||||
from helper import get_setting_value
|
from helper import get_setting_value
|
||||||
@@ -88,9 +88,9 @@ def generate_openapi_spec(
|
|||||||
# If helper is unavailable, fall back to environment
|
# If helper is unavailable, fall back to environment
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if disabled_env is None:
|
if not disabled_env:
|
||||||
env_val = os.getenv("MCP_DISABLED_TOOLS")
|
env_val = os.getenv("MCP_DISABLED_TOOLS")
|
||||||
if env_val is not None:
|
if env_val:
|
||||||
disabled_env = env_val.strip()
|
disabled_env = env_val.strip()
|
||||||
|
|
||||||
# If still not set, apply safe hard-coded defaults
|
# If still not set, apply safe hard-coded defaults
|
||||||
|
|||||||
63
test/api_endpoints/test_mcp_disabled_tools.py
Normal file
63
test/api_endpoints/test_mcp_disabled_tools.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
# Use cwd as fallback if env var is not set, assuming running from project root
|
||||||
|
INSTALL_PATH = os.getenv('NETALERTX_APP', os.getcwd())
|
||||||
|
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
|
||||||
|
|
||||||
|
from api_server.openapi.spec_generator import generate_openapi_spec
|
||||||
|
from api_server.api_server_start import app
|
||||||
|
|
||||||
|
class TestMCPDisabledTools:
|
||||||
|
|
||||||
|
def test_disabled_tools_via_env_var(self):
|
||||||
|
"""Test that MCP_DISABLED_TOOLS env var disables specific tools."""
|
||||||
|
# Clean registry first to ensure clean state
|
||||||
|
from api_server.openapi.registry import clear_registry
|
||||||
|
clear_registry()
|
||||||
|
|
||||||
|
# Mock get_setting_value to return None (simulating no config setting)
|
||||||
|
# and mock os.getenv to return our target list
|
||||||
|
with patch("helper.get_setting_value", return_value=None), \
|
||||||
|
patch.dict(os.environ, {"MCP_DISABLED_TOOLS": "search_devices_api"}):
|
||||||
|
|
||||||
|
spec = generate_openapi_spec(flask_app=app)
|
||||||
|
|
||||||
|
# Locate the operation
|
||||||
|
# search_devices_api is usually mapped to /devices/search [POST] or similar
|
||||||
|
# We search the spec for the operationId
|
||||||
|
|
||||||
|
found = False
|
||||||
|
for path, methods in spec["paths"].items():
|
||||||
|
for method, op in methods.items():
|
||||||
|
if op["operationId"] == "search_devices_api":
|
||||||
|
assert op.get("x-mcp-disabled") is True
|
||||||
|
found = True
|
||||||
|
|
||||||
|
assert found, "search_devices_api operation not found in spec"
|
||||||
|
|
||||||
|
def test_disabled_tools_default_fallback(self):
|
||||||
|
"""Test fallback to defaults when no setting or env var exists."""
|
||||||
|
from api_server.openapi.registry import clear_registry
|
||||||
|
clear_registry()
|
||||||
|
|
||||||
|
with patch("helper.get_setting_value", return_value=None), \
|
||||||
|
patch.dict(os.environ, {}, clear=True): # Clear env to ensure no MCP_DISABLED_TOOLS
|
||||||
|
|
||||||
|
spec = generate_openapi_spec(flask_app=app)
|
||||||
|
|
||||||
|
# Default is "dbquery_read,dbquery_write"
|
||||||
|
|
||||||
|
# Check dbquery_read
|
||||||
|
found_read = False
|
||||||
|
for path, methods in spec["paths"].items():
|
||||||
|
for method, op in methods.items():
|
||||||
|
if op["operationId"] == "dbquery_read":
|
||||||
|
assert op.get("x-mcp-disabled") is True
|
||||||
|
found_read = True
|
||||||
|
|
||||||
|
assert found_read, "dbquery_read should be disabled by default"
|
||||||
|
|
||||||
Reference in New Issue
Block a user