diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index 6c18746d..d81f20b8 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -3,7 +3,7 @@ from flask import Flask, request, jsonify, Response from flask_cors import CORS from .graphql_endpoint import devicesSchema from .device_endpoint import get_device_data, set_device_data, delete_device, delete_device_events, reset_device_props, copy_device, update_device_column -from .devices_endpoint import delete_unknown_devices, delete_all_with_empty_macs, delete_devices, export_devices, import_csv +from .devices_endpoint import delete_unknown_devices, delete_all_with_empty_macs, delete_devices, export_devices, import_csv, devices_totals, devices_by_status from .events_endpoint import delete_events, delete_events_30, get_events from .history_endpoint import delete_online_history from .prometheus_endpoint import getMetricStats @@ -150,12 +150,6 @@ def api_delete_unknown_devices(): return jsonify({"error": "Forbidden"}), 403 return delete_unknown_devices() -@app.route("/devices/totals", methods=["GET"]) -def api_get_devices_totals(): - if not is_authorized(): - return jsonify({"error": "Forbidden"}), 403 - return get_devices_totals() - @app.route("/devices/export", methods=["GET"]) @app.route("/devices/export/", methods=["GET"]) @@ -172,6 +166,21 @@ def api_import_csv(): return jsonify({"error": "Forbidden"}), 403 return import_csv(request.files.get("file")) +@app.route("/devices/totals", methods=["GET"]) +def api_devices_totals(): + if not is_authorized(): + return jsonify({"error": "Forbidden"}), 403 + return devices_totals() + +@app.route("/devices/by-status", methods=["GET"]) +def api_devices_by_status(): + if not is_authorized(): + return jsonify({"error": "Forbidden"}), 403 + + status = request.args.get("status", "") if request.args else None + + return devices_by_status(status) + # -------------------------- # Online history # -------------------------- diff --git a/server/api_server/device_endpoint.py b/server/api_server/device_endpoint.py index c1234743..1064b7a1 100755 --- a/server/api_server/device_endpoint.py +++ b/server/api_server/device_endpoint.py @@ -161,37 +161,39 @@ def set_device_data(mac, data): devSourcePlugin ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ + values = ( mac, - data.get("name", ""), - data.get("owner", ""), - data.get("type", ""), - data.get("vendor", ""), - data.get("icon", ""), - data.get("favorite", 0), - data.get("group", ""), - data.get("location", ""), - data.get("comments", ""), - data.get("networknode", ""), - data.get("networknodeport", ""), - data.get("ssid", ""), - data.get("networksite", ""), - data.get("staticIP", 0), - data.get("scancycle", 0), - data.get("alertevents", 0), - data.get("alertdown", 0), - data.get("relType", "default"), - data.get("reqNics", 0), - data.get("skiprepeated", 0), - data.get("newdevice", 0), - data.get("archived", 0), + data.get("devName", ""), + data.get("devOwner", ""), + data.get("devType", ""), + data.get("devVendor", ""), + data.get("devIcon", ""), + data.get("devFavorite", 0), + data.get("devGroup", ""), + data.get("devLocation", ""), + data.get("devComments", ""), + data.get("devParentMAC", ""), + data.get("devParentPort", ""), + data.get("devSSID", ""), + data.get("devSite", ""), + data.get("devStaticIP", 0), + data.get("devScan", 0), + data.get("devAlertEvents", 0), + data.get("devAlertDown", 0), + data.get("devParentRelType", "default"), + data.get("devReqNicsOnline", 0), + data.get("devSkipRepeated", 0), + data.get("devIsNew", 0), + data.get("devIsArchived", 0), data.get("devLastConnection", datetime.now().strftime("%Y-%m-%d %H:%M:%S")), data.get("devFirstConnection", datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - data.get("ip", ""), + data.get("devLastIP", ""), data.get("devGUID", ""), data.get("devCustomProps", ""), - "DUMMY" + data.get("devSourcePlugin", "DUMMY"), ) + else: sql = """ UPDATE Devices SET @@ -204,31 +206,31 @@ def set_device_data(mac, data): WHERE devMac=? """ values = ( - data.get("name", ""), - data.get("owner", ""), - data.get("type", ""), - data.get("vendor", ""), - data.get("icon", ""), - data.get("favorite", 0), - data.get("group", ""), - data.get("location", ""), - data.get("comments", ""), - data.get("networknode", ""), - data.get("networknodeport", ""), - data.get("ssid", ""), - data.get("networksite", ""), - data.get("staticIP", 0), - data.get("scancycle", 0), - data.get("alertevents", 0), - data.get("alertdown", 0), - data.get("relType", "default"), - data.get("reqNics", 0), - data.get("skiprepeated", 0), - data.get("newdevice", 0), - data.get("archived", 0), - data.get("devCustomProps", ""), - mac - ) + data.get("devName", ""), + data.get("devOwner", ""), + data.get("devType", ""), + data.get("devVendor", ""), + data.get("devIcon", ""), + data.get("devFavorite", 0), + data.get("devGroup", ""), + data.get("devLocation", ""), + data.get("devComments", ""), + data.get("devParentMAC", ""), + data.get("devParentPort", ""), + data.get("devSSID", ""), + data.get("devSite", ""), + data.get("devStaticIP", 0), + data.get("devScan", 0), + data.get("devAlertEvents", 0), + data.get("devAlertDown", 0), + data.get("devParentRelType", "default"), + data.get("devReqNicsOnline", 0), + data.get("devSkipRepeated", 0), + data.get("devIsNew", 0), + data.get("devIsArchived", 0), + data.get("devCustomProps", ""), + mac + ) conn = get_temp_db_connection() cur = conn.cursor() diff --git a/server/api_server/devices_endpoint.py b/server/api_server/devices_endpoint.py index 07e84431..ca86a346 100755 --- a/server/api_server/devices_endpoint.py +++ b/server/api_server/devices_endpoint.py @@ -20,7 +20,7 @@ sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from database import get_temp_db_connection from helper import is_random_mac, format_date, get_setting_value -from db.db_helper import get_table_json +from db.db_helper import get_table_json, get_device_condition_by_status # -------------------------- @@ -192,4 +192,58 @@ def import_csv(file_storage=None): "success": True, "inserted": row_count, "skipped_lines": skipped - }) \ No newline at end of file + }) + +def devices_totals(): + conn = get_temp_db_connection() + sql = conn.cursor() + + # Build a combined query with sub-selects for each status + query = f""" + SELECT + (SELECT COUNT(*) FROM Devices {get_device_condition_by_status('my')}) AS devices, + (SELECT COUNT(*) FROM Devices {get_device_condition_by_status('connected')}) AS connected, + (SELECT COUNT(*) FROM Devices {get_device_condition_by_status('favorites')}) AS favorites, + (SELECT COUNT(*) FROM Devices {get_device_condition_by_status('new')}) AS new, + (SELECT COUNT(*) FROM Devices {get_device_condition_by_status('down')}) AS down, + (SELECT COUNT(*) FROM Devices {get_device_condition_by_status('archived')}) AS archived + """ + sql.execute(query) + row = sql.fetchone() # returns a tuple like (devices, connected, favorites, new, down, archived) + + conn.close() + + # Return counts as JSON array + return jsonify(list(row)) + + +def devices_by_status(status=None): + """ + Return devices filtered by status. + """ + + conn = get_temp_db_connection() + sql = conn.cursor() + + # Build condition for SQL + condition = get_device_condition_by_status(status) if status else "" + + query = f"SELECT * FROM Devices {condition}" + sql.execute(query) + + table_data = [] + for row in sql.fetchall(): + r = dict(row) # Convert sqlite3.Row to dict for .get() + dev_name = r.get("devName", "") + if r.get("devFavorite") == 1: + dev_name = f' {dev_name}' + + table_data.append({ + "id": r.get("devMac", ""), + "title": dev_name, + "favorite": r.get("devFavorite", 0) + }) + + conn.close() + return jsonify(table_data) + diff --git a/test/test_device_endpoints.py b/test/test_device_endpoints.py index c9504e9e..95787c43 100755 --- a/test/test_device_endpoints.py +++ b/test/test_device_endpoints.py @@ -32,10 +32,10 @@ def auth_headers(token): def test_create_device(client, api_token, test_mac): payload = { "createNew": True, - "name": "Test Device", - "owner": "Unit Test", - "type": "Router", - "vendor": "TestVendor", + "devType": "Test Device", + "devOwner": "Unit Test", + "devType": "Router", + "devVendor": "TestVendor", } resp = client.post(f"/device/{test_mac}", json=payload, headers=auth_headers(api_token)) assert resp.status_code == 200 @@ -69,7 +69,7 @@ def test_delete_device(client, api_token, test_mac): def test_copy_device(client, api_token, test_mac): # Step 1: Create the source device - payload = {"createNew": True, "name": "Source Device"} + payload = {"createNew": True} resp = client.post(f"/device/{test_mac}", json=payload, headers=auth_headers(api_token)) assert resp.status_code == 200 assert resp.json.get("success") is True diff --git a/test/test_devices_endpoints.py b/test/test_devices_endpoints.py index 0d3443e8..2409bb8e 100755 --- a/test/test_devices_endpoints.py +++ b/test/test_devices_endpoints.py @@ -34,10 +34,10 @@ def auth_headers(token): def create_dummy(client, api_token, test_mac): payload = { "createNew": True, - "name": "Test Device", - "owner": "Unit Test", - "type": "Router", - "vendor": "TestVendor", + "devName": "Test Device", + "devOwner": "Unit Test", + "devType": "Router", + "devVendor": "TestVendor", } resp = client.post(f"/device/{test_mac}", json=payload, headers=auth_headers(api_token)) @@ -105,6 +105,8 @@ def test_export_import_cycle_base64(client, api_token, test_mac): assert resp.status_code == 200 csv_data = resp.data.decode("utf-8") + print(csv_data) + # Ensure our dummy device is in the CSV assert test_mac in csv_data assert "Test Device" in csv_data @@ -126,6 +128,51 @@ def test_export_import_cycle_base64(client, api_token, test_mac): assert resp.json.get("inserted") >= 1 assert resp.json.get("skipped_lines") == [] +def test_devices_totals(client, api_token, test_mac): + # 1. Create a dummy device + create_dummy(client, api_token, test_mac) + + # 2. Call the totals endpoint + resp = client.get("/devices/totals", headers=auth_headers(api_token)) + assert resp.status_code == 200 + + # 3. Ensure the response is a JSON list + data = resp.json + assert isinstance(data, list) + assert len(data) == 6 # devices, connected, favorites, new, down, archived + + # 4. Check that at least 1 device exists + assert data[0] >= 1 # 'devices' count includes the dummy device + + +def test_devices_by_status(client, api_token, test_mac): + # 1. Create a dummy device + create_dummy(client, api_token, test_mac) + + # 2. Request devices by a valid status + resp = client.get("/devices/by-status?status=my", headers=auth_headers(api_token)) + assert resp.status_code == 200 + data = resp.json + assert isinstance(data, list) + assert any(d["id"] == test_mac for d in data) + + # 3. Request devices with an invalid/unknown status + resp_invalid = client.get("/devices/by-status?status=invalid_status", headers=auth_headers(api_token)) + assert resp_invalid.status_code == 200 + # Should return empty list for unknown status + assert resp_invalid.json == [] + + # 4. Check favorite formatting if devFavorite = 1 + # Update dummy device to favorite + client.post( + f"/device/{test_mac}", + json={"devFavorite": 1}, + headers=auth_headers(api_token) + ) + resp_fav = client.get("/devices/by-status?status=my", headers=auth_headers(api_token)) + fav_data = next((d for d in resp_fav.json if d["id"] == test_mac), None) + assert fav_data is not None + assert "★" in fav_data["title"] def test_delete_test_devices(client, api_token, test_mac):