From bb0c0e1c7433b2714bca6829c71f1dd6b2051351 Mon Sep 17 00:00:00 2001 From: Adam Outler Date: Mon, 19 Jan 2026 00:03:27 +0000 Subject: [PATCH] Coderabbit fixes: - Mac - Flask debug - Threaded flask - propagate token in GET requests - enhance spec docs - normalize MAC x2 - mcp disablement redundant private attribute - run all tests imports --- docs/DEBUG_API_SERVER.md | 13 ++ server/api_server/api_server_start.py | 20 +++- server/api_server/mcp_endpoint.py | 5 + server/api_server/openapi/schema_converter.py | 7 +- server/helper.py | 36 ++++++ server/initialise.py | 9 -- server/models/device_instance.py | 13 +- .../test_device_update_normalization.py | 70 +++++++++++ test/server/test_api_server_start.py | 112 ++++++++++++++++++ test/test_mcp_disablement.py | 2 - test/ui/run_all_tests.py | 37 +++--- test/ui/test_ui_devices.py | 20 +++- test/ui/test_ui_maintenance.py | 37 ++++-- 13 files changed, 326 insertions(+), 55 deletions(-) create mode 100644 test/api_endpoints/test_device_update_normalization.py create mode 100644 test/server/test_api_server_start.py diff --git a/docs/DEBUG_API_SERVER.md b/docs/DEBUG_API_SERVER.md index b8feac8a..4caafff8 100644 --- a/docs/DEBUG_API_SERVER.md +++ b/docs/DEBUG_API_SERVER.md @@ -38,6 +38,19 @@ All application settings can also be initialized via the `APP_CONF_OVERRIDE` doc There are several ways to check if the GraphQL server is running. +## Flask debug mode (environment) + +You can control whether the Flask development debugger is enabled by setting the environment variable `FLASK_DEBUG` (default: `False`). Enabling debug mode will turn on the interactive debugger which may expose a remote code execution (RCE) vector if the server is reachable; **only enable this for local development** and never in production. Valid truthy values are: `1`, `true`, `yes`, `on` (case-insensitive). + +In the running container you can set this variable via Docker Compose or your environment, for example: + +```yaml +environment: + - FLASK_DEBUG=1 +``` + +When enabled, the GraphQL server startup logs will indicate the debug setting. + ### Init Check You can navigate to System Info -> Init Check to see if `isGraphQLServerRunning` is ticked: diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index bea4490a..5e7eac66 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -13,7 +13,7 @@ INSTALL_PATH = os.getenv("NETALERTX_APP", "/app") sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from logger import mylog # noqa: E402 [flake8 lint suppression] -from helper import get_setting_value # noqa: E402 [flake8 lint suppression] +from helper import get_setting_value, get_env_setting_value # noqa: E402 [flake8 lint suppression] from db.db_helper import get_date_from_period # noqa: E402 [flake8 lint suppression] from app_state import updateState # noqa: E402 [flake8 lint suppression] @@ -1693,10 +1693,26 @@ def start_server(graphql_port, app_state): if app_state.graphQLServerStarted == 0: mylog("verbose", [f"[graphql endpoint] Starting on port: {graphql_port}"]) + # First check environment variable override (direct env like FLASK_DEBUG) + env_val = get_env_setting_value("FLASK_DEBUG", None) + if env_val is not None: + flask_debug = bool(env_val) + mylog("verbose", [f"[graphql endpoint] Flask debug mode: {flask_debug} (FLASK_DEBUG env override)"]) + else: + # Fall back to configured setting `FLASK_DEBUG` (from app.conf / overrides) + flask_debug = get_setting_value("FLASK_DEBUG") + # Normalize value to boolean in case it's stored as a string + if isinstance(flask_debug, str): + flask_debug = flask_debug.strip().lower() in ("1", "true", "yes", "on") + else: + flask_debug = bool(flask_debug) + + mylog("verbose", [f"[graphql endpoint] Flask debug mode: {flask_debug} (FLASK_DEBUG setting)"]) + # Start Flask app in a separate thread thread = threading.Thread( target=lambda: app.run( - host="0.0.0.0", port=graphql_port, debug=False, use_reloader=False + host="0.0.0.0", port=graphql_port, threaded=True,debug=flask_debug, use_reloader=False ) ) thread.start() diff --git a/server/api_server/mcp_endpoint.py b/server/api_server/mcp_endpoint.py index e9195155..005ff1ef 100644 --- a/server/api_server/mcp_endpoint.py +++ b/server/api_server/mcp_endpoint.py @@ -642,6 +642,11 @@ def _execute_tool(route: Dict[str, Any], args: Dict[str, Any]) -> Dict[str, Any] headers = {"Content-Type": "application/json"} if "Authorization" in request.headers: headers["Authorization"] = request.headers["Authorization"] + else: + # Propagate query token or fallback to configured API token for internal loopback + token = request.args.get("token") or get_setting_value('API_TOKEN') + if token: + headers["Authorization"] = f"Bearer {token}" filtered_body_args = {k: v for k, v in args.items() if f"{{{k}}}" not in route['path']} diff --git a/server/api_server/openapi/schema_converter.py b/server/api_server/openapi/schema_converter.py index 31a2d12b..c6979527 100644 --- a/server/api_server/openapi/schema_converter.py +++ b/server/api_server/openapi/schema_converter.py @@ -4,7 +4,7 @@ from typing import Dict, Any, Optional, Type, List from pydantic import BaseModel -def pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]: +def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]: """ Convert a Pydantic model to JSON Schema (OpenAPI 3.1 compatible). @@ -13,12 +13,13 @@ def pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]: Args: model: Pydantic BaseModel class + mode: Schema mode - "validation" (for inputs) or "serialization" (for outputs) Returns: JSON Schema dictionary """ # Pydantic v2 uses model_json_schema() - schema = model.model_json_schema(mode="serialization") + schema = model.model_json_schema(mode=mode) # Remove $defs if empty (cleaner output) if "$defs" in schema and not schema["$defs"]: @@ -169,7 +170,7 @@ def build_responses( # Success response (200) if response_model: # Strip validation from response schema to save tokens - schema = strip_validation(pydantic_to_json_schema(response_model)) + schema = strip_validation(pydantic_to_json_schema(response_model, mode="serialization")) schema = extract_definitions(schema, definitions) responses["200"] = { "description": "Successful response", diff --git a/server/helper.py b/server/helper.py index 543f2488..6eb21f8a 100755 --- a/server/helper.py +++ b/server/helper.py @@ -361,6 +361,42 @@ def setting_value_to_python_type(set_type, set_value): return value +# ------------------------------------------------------------------------------- +# Environment helper +def get_env_setting_value(key, default=None): + """Return a typed value from environment variable if present. + + - Parses booleans (1/0, true/false, yes/no, on/off). + - Tries to parse ints and JSON literals where sensible. + - Returns `default` when env var is not set. + """ + val = os.environ.get(key) + if val is None: + return default + + v = val.strip() + # Booleans + low = v.lower() + if low in ("1", "true", "yes", "on"): + return True + if low in ("0", "false", "no", "off"): + return False + + # Integer + try: + if re.fullmatch(r"-?\d+", v): + return int(v) + except Exception: + pass + + # JSON-like (list/object/true/false/null/number) + try: + return json.loads(v) + except Exception: + # Fallback to raw string + return v + + # ------------------------------------------------------------------------------- def updateSubnets(scan_subnets): """ diff --git a/server/initialise.py b/server/initialise.py index 5e3ad9e4..1c6f52aa 100755 --- a/server/initialise.py +++ b/server/initialise.py @@ -334,15 +334,6 @@ def importConfigs(pm, db, all_plugins): "[]", "General", ) - conf.FLASK_DEBUG = ccd( - "FLASK_DEBUG", - False, - c_d, - "Flask debug mode - SECURITY WARNING: Enabling enables interactive debugger with RCE risk. Configure via environment only, not exposed in UI.", - '{"dataType": "boolean","elements": []}', - "[]", - "system", - ) conf.VERSION = ccd( "VERSION", "", diff --git a/server/models/device_instance.py b/server/models/device_instance.py index 7e7085e4..430abf69 100755 --- a/server/models/device_instance.py +++ b/server/models/device_instance.py @@ -4,7 +4,7 @@ import re import sqlite3 import csv from io import StringIO -from front.plugins.plugin_helper import is_mac +from front.plugins.plugin_helper import is_mac, normalize_mac from logger import mylog from models.plugin_object_instance import PluginObjectInstance from database import get_temp_db_connection @@ -500,6 +500,9 @@ class DeviceInstance: def setDeviceData(self, mac, data): """Update or create a device.""" + normalized_mac = normalize_mac(mac) + normalized_parent_mac = normalize_mac(data.get("devParentMAC") or "") + conn = None try: if data.get("createNew", False): @@ -517,7 +520,7 @@ class DeviceInstance: """ values = ( - mac, + normalized_mac, data.get("devName") or "", data.get("devOwner") or "", data.get("devType") or "", @@ -527,7 +530,7 @@ class DeviceInstance: data.get("devGroup") or "", data.get("devLocation") or "", data.get("devComments") or "", - data.get("devParentMAC") or "", + normalized_parent_mac, data.get("devParentPort") or "", data.get("devSSID") or "", data.get("devSite") or "", @@ -569,7 +572,7 @@ class DeviceInstance: data.get("devGroup") or "", data.get("devLocation") or "", data.get("devComments") or "", - data.get("devParentMAC") or "", + normalized_parent_mac, data.get("devParentPort") or "", data.get("devSSID") or "", data.get("devSite") or "", @@ -583,7 +586,7 @@ class DeviceInstance: data.get("devIsNew") or 0, data.get("devIsArchived") or 0, data.get("devCustomProps") or "", - mac, + normalized_mac, ) conn = get_temp_db_connection() diff --git a/test/api_endpoints/test_device_update_normalization.py b/test/api_endpoints/test_device_update_normalization.py new file mode 100644 index 00000000..70176d5e --- /dev/null +++ b/test/api_endpoints/test_device_update_normalization.py @@ -0,0 +1,70 @@ + +import pytest +import random +from helper import get_setting_value +from api_server.api_server_start import app +from models.device_instance import DeviceInstance + +@pytest.fixture(scope="session") +def api_token(): + return get_setting_value("API_TOKEN") + +@pytest.fixture +def client(): + with app.test_client() as client: + yield client + +@pytest.fixture +def test_mac_norm(): + # Normalized MAC + return "AA:BB:CC:DD:EE:FF" + +@pytest.fixture +def test_parent_mac_input(): + # Lowercase input MAC + return "aa:bb:cc:dd:ee:00" + +@pytest.fixture +def test_parent_mac_norm(): + # Normalized expected MAC + return "AA:BB:CC:DD:EE:00" + +def auth_headers(token): + return {"Authorization": f"Bearer {token}"} + +def test_update_normalization(client, api_token, test_mac_norm, test_parent_mac_input, test_parent_mac_norm): + # 1. Create a device (using normalized MAC) + create_payload = { + "createNew": True, + "devName": "Normalization Test Device", + "devOwner": "Unit Test", + } + resp = client.post(f"/device/{test_mac_norm}", json=create_payload, headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.json.get("success") is True + + # 2. Update the device using LOWERCASE MAC in URL + # And set devParentMAC to LOWERCASE + update_payload = { + "devParentMAC": test_parent_mac_input, + "devName": "Updated Device" + } + # Using lowercase MAC in URL: aa:bb:cc:dd:ee:ff + lowercase_mac = test_mac_norm.lower() + + resp = client.post(f"/device/{lowercase_mac}", json=update_payload, headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.json.get("success") is True + + # 3. Verify in DB that devParentMAC is NORMALIZED + device_handler = DeviceInstance() + device = device_handler.getDeviceData(test_mac_norm) + + assert device is not None + assert device["devName"] == "Updated Device" + # This is the critical check: + assert device["devParentMAC"] == test_parent_mac_norm + assert device["devParentMAC"] != test_parent_mac_input # Should verify it changed from input if input was different case + + # Cleanup + device_handler.deleteDeviceByMAC(test_mac_norm) diff --git a/test/server/test_api_server_start.py b/test/server/test_api_server_start.py new file mode 100644 index 00000000..0259c942 --- /dev/null +++ b/test/server/test_api_server_start.py @@ -0,0 +1,112 @@ +from types import SimpleNamespace + +from server.api_server import api_server_start as api_mod + + +def _make_fake_thread(recorder): + class FakeThread: + def __init__(self, target=None): + self._target = target + + def start(self): + # call target synchronously for test + if self._target: + self._target() + + return FakeThread + + +def test_start_server_passes_debug_true(monkeypatch): + # Arrange + # Use the settings helper to provide the value + monkeypatch.setattr(api_mod, 'get_setting_value', lambda k: True if k == 'FLASK_DEBUG' else None) + + called = {} + + def fake_run(*args, **kwargs): + called['args'] = args + called['kwargs'] = kwargs + + monkeypatch.setattr(api_mod, 'app', api_mod.app) + monkeypatch.setattr(api_mod.app, 'run', fake_run) + + # Replace threading.Thread with a fake that executes target immediately + FakeThread = _make_fake_thread(called) + monkeypatch.setattr(api_mod.threading, 'Thread', FakeThread) + + # Prevent updateState side effects + monkeypatch.setattr(api_mod, 'updateState', lambda *a, **k: None) + + app_state = SimpleNamespace(graphQLServerStarted=0) + + # Act + api_mod.start_server(12345, app_state) + + # Assert + assert 'kwargs' in called + assert called['kwargs']['debug'] is True + assert called['kwargs']['host'] == '0.0.0.0' + assert called['kwargs']['port'] == 12345 + + +def test_start_server_passes_debug_false(monkeypatch): + # Arrange + monkeypatch.setattr(api_mod, 'get_setting_value', lambda k: False if k == 'FLASK_DEBUG' else None) + + called = {} + + def fake_run(*args, **kwargs): + called['args'] = args + called['kwargs'] = kwargs + + monkeypatch.setattr(api_mod, 'app', api_mod.app) + monkeypatch.setattr(api_mod.app, 'run', fake_run) + + FakeThread = _make_fake_thread(called) + monkeypatch.setattr(api_mod.threading, 'Thread', FakeThread) + + monkeypatch.setattr(api_mod, 'updateState', lambda *a, **k: None) + + app_state = SimpleNamespace(graphQLServerStarted=0) + + # Act + api_mod.start_server(22222, app_state) + + # Assert + assert 'kwargs' in called + assert called['kwargs']['debug'] is False + assert called['kwargs']['host'] == '0.0.0.0' + assert called['kwargs']['port'] == 22222 + + +def test_env_var_overrides_setting(monkeypatch): + # Arrange + # Ensure env override is present + monkeypatch.setenv('FLASK_DEBUG', '1') + # And the stored setting is False to ensure env takes precedence + monkeypatch.setattr(api_mod, 'get_setting_value', lambda k: False if k == 'FLASK_DEBUG' else None) + + called = {} + + def fake_run(*args, **kwargs): + called['args'] = args + called['kwargs'] = kwargs + + monkeypatch.setattr(api_mod, 'app', api_mod.app) + monkeypatch.setattr(api_mod.app, 'run', fake_run) + + FakeThread = _make_fake_thread(called) + monkeypatch.setattr(api_mod.threading, 'Thread', FakeThread) + + monkeypatch.setattr(api_mod, 'updateState', lambda *a, **k: None) + + app_state = SimpleNamespace(graphQLServerStarted=0) + + # Act + api_mod.start_server(33333, app_state) + + # Assert + assert 'kwargs' in called + assert called['kwargs']['debug'] is True + assert called['kwargs']['host'] == '0.0.0.0' + assert called['kwargs']['port'] == 33333 diff --git a/test/test_mcp_disablement.py b/test/test_mcp_disablement.py index 37a1b7f3..dcb7400f 100644 --- a/test/test_mcp_disablement.py +++ b/test/test_mcp_disablement.py @@ -9,10 +9,8 @@ from server.api_server import mcp_endpoint @pytest.fixture(autouse=True) def reset_registry(): registry.clear_registry() - registry._disabled_tools.clear() yield registry.clear_registry() - registry._disabled_tools.clear() def test_disable_tool_management(): diff --git a/test/ui/run_all_tests.py b/test/ui/run_all_tests.py index 67368a5e..44ceff51 100644 --- a/test/ui/run_all_tests.py +++ b/test/ui/run_all_tests.py @@ -5,15 +5,8 @@ Runs all page-specific UI tests and provides summary """ import sys -# Import all test modules -from .test_helpers import test_ui_dashboard -from .test_helpers import test_ui_devices -from .test_helpers import test_ui_network -from .test_helpers import test_ui_maintenance -from .test_helpers import test_ui_multi_edit -from .test_helpers import test_ui_notifications -from .test_helpers import test_ui_settings -from .test_helpers import test_ui_plugins +import os +import pytest def main(): @@ -22,22 +15,28 @@ def main(): print("NetAlertX UI Test Suite") print("=" * 70) + # Get directory of this script + base_dir = os.path.dirname(os.path.abspath(__file__)) + test_modules = [ - ("Dashboard", test_ui_dashboard), - ("Devices", test_ui_devices), - ("Network", test_ui_network), - ("Maintenance", test_ui_maintenance), - ("Multi-Edit", test_ui_multi_edit), - ("Notifications", test_ui_notifications), - ("Settings", test_ui_settings), - ("Plugins", test_ui_plugins), + ("Dashboard", "test_ui_dashboard.py"), + ("Devices", "test_ui_devices.py"), + ("Network", "test_ui_network.py"), + ("Maintenance", "test_ui_maintenance.py"), + ("Multi-Edit", "test_ui_multi_edit.py"), + ("Notifications", "test_ui_notifications.py"), + ("Settings", "test_ui_settings.py"), + ("Plugins", "test_ui_plugins.py"), ] results = {} - for name, module in test_modules: + for name, filename in test_modules: try: - result = module.run_tests() + print(f"\nRunning {name} tests...") + file_path = os.path.join(base_dir, filename) + # Run pytest + result = pytest.main([file_path, "-v"]) results[name] = result == 0 except Exception as e: print(f"\n✗ {name} tests failed with exception: {e}") diff --git a/test/ui/test_ui_devices.py b/test/ui/test_ui_devices.py index da4480dd..aef75df8 100644 --- a/test/ui/test_ui_devices.py +++ b/test/ui/test_ui_devices.py @@ -82,13 +82,21 @@ def test_add_device_with_generated_mac_ip(driver, api_token): wait_for_page_load(driver, timeout=10) # --- Click "Add Device" --- - add_buttons = driver.find_elements(By.CSS_SELECTOR, "button#btnAddDevice, button[onclick*='addDevice'], a[href*='deviceDetails.php?mac='], .btn-add-device") - if not add_buttons: + # Wait for the "New Device" link specifically to ensure it's loaded + add_selector = "a[href*='deviceDetails.php?mac=new'], button#btnAddDevice, .btn-add-device" + try: + add_button = wait_for_element_by_css(driver, add_selector, timeout=10) + except Exception: + # Fallback to broader search if specific selector fails add_buttons = driver.find_elements(By.XPATH, "//button[contains(text(),'Add') or contains(text(),'New')] | //a[contains(text(),'Add') or contains(text(),'New')]") - if not add_buttons: - assert True, "Add device button not found, skipping test" - return - add_buttons[0].click() + if add_buttons: + add_button = add_buttons[0] + else: + assert True, "Add device button not found, skipping test" + return + + # Use JavaScript click to bypass any transparent overlays from the chart + driver.execute_script("arguments[0].click();", add_button) # Wait for the device form to appear (use the NEWDEV_devMac field as indicator) wait_for_element_by_css(driver, "#NEWDEV_devMac", timeout=10) diff --git a/test/ui/test_ui_maintenance.py b/test/ui/test_ui_maintenance.py index 8c665eea..104e11dc 100644 --- a/test/ui/test_ui_maintenance.py +++ b/test/ui/test_ui_maintenance.py @@ -6,6 +6,7 @@ Tests CSV export/import, delete operations, database tools from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait +from selenium.webdriver.support import expected_conditions as EC from .test_helpers import BASE_URL, api_get, wait_for_page_load # noqa: E402 @@ -30,7 +31,10 @@ def test_export_csv_button_works(driver): import os import glob - driver.get(f"{BASE_URL}/maintenance.php") + # Use 127.0.0.1 instead of localhost to avoid IPv6 resolution issues in the browser + # which can lead to "Failed to fetch" if the server is only listening on IPv4. + target_url = f"{BASE_URL}/maintenance.php".replace("localhost", "127.0.0.1") + driver.get(target_url) wait_for_page_load(driver, timeout=10) # Clear any existing downloads @@ -38,13 +42,22 @@ def test_export_csv_button_works(driver): for f in glob.glob(f"{download_dir}/*.csv"): os.remove(f) + # Ensure the Backup/Restore tab is active so the button is in a clickable state + try: + tab = WebDriverWait(driver, 5).until( + EC.element_to_be_clickable((By.ID, "tab_BackupRestore_id")) + ) + tab.click() + except Exception: + pass + # Find the export button - export_btns = driver.find_elements(By.ID, "btnExportCSV") + try: + export_btn = WebDriverWait(driver, 10).until( + EC.presence_of_element_located((By.ID, "btnExportCSV")) + ) - if len(export_btns) > 0: - export_btn = export_btns[0] - - # Click it (JavaScript click works even if CSS hides it) + # Click it (JavaScript click works even if CSS hides it or if it's overlapped) driver.execute_script("arguments[0].click();", export_btn) # Wait for download to complete (up to 10 seconds) @@ -70,9 +83,15 @@ def test_export_csv_button_works(driver): # Download via blob/JavaScript - can't verify file in headless mode # Just verify button click didn't cause errors assert "error" not in driver.page_source.lower(), "Button click should not cause errors" - else: - # Button doesn't exist on this page - assert True, "Export button not found on this page" + except Exception as e: + # Check for alerts that might be blocking page_source access + try: + alert = driver.switch_to.alert + alert_text = alert.text + alert.accept() + assert False, f"Alert present: {alert_text}" + except Exception: + raise e def test_import_section_present(driver):