api layer v0.2.2 - CSV import/export, refactor
Some checks failed
Code checks / check-url-paths (push) Has been cancelled
docker / docker_dev (push) Has been cancelled
Deploy MkDocs / deploy (push) Has been cancelled

This commit is contained in:
jokob-sk
2025-08-19 07:56:54 +10:00
parent 9c71a8ecab
commit 962bbaa5a1
14 changed files with 738 additions and 284 deletions

View File

@@ -38,17 +38,16 @@
case 'deleteActHistory': deleteActHistory(); break;
case 'deleteDeviceEvents': deleteDeviceEvents(); break;
case 'resetDeviceProps': resetDeviceProps(); break;
case 'ExportCSV': ExportCSV(); break;
case 'ImportCSV': ImportCSV(); break;
case 'ExportCSV': ExportCSV(); break; // todo
case 'ImportCSV': ImportCSV(); break; // todo
case 'getDevicesTotals': getDevicesTotals(); break;
case 'getDevicesListCalendar': getDevicesListCalendar(); break; //todo: slowly deprecate this
case 'getDevicesTotals': getDevicesTotals(); break; // todo
case 'getDevicesListCalendar': getDevicesListCalendar(); break; // todo
case 'updateNetworkLeaf': updateNetworkLeaf(); break;
case 'updateNetworkLeaf': updateNetworkLeaf(); break; // todo
case 'getDevices': getDevices(); break;
case 'copyFromDevice': copyFromDevice(); break;
case 'wakeonlan': wakeonlan(); break;
case 'wakeonlan': wakeonlan(); break; // todo
default: logServerConsole ('Action: '. $action); break;
}
@@ -737,37 +736,6 @@ function getDevicesListCalendar() {
// Query Device Data
//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
function getDevices() {
    global $db;

    // Fetch MAC + name for every device.
    $sql = 'select devMac, devName from Devices';
    $result = $db->query($sql);

    // Array of row objects to serialize.
    $tableData = array();

    while ($row = $result -> fetchArray (SQLITE3_ASSOC)) {
        // Substitute a placeholder for missing values.
        $tableData[] = array(
            'id'   => handleNull($row['devMac'], "(unknown)"),
            'name' => handleNull($row['devName'], "(unknown)")
        );
    }

    // NOTE: the former "if empty then reset to []" check was a no-op and was removed.
    echo (json_encode ($tableData));
}
// ----------------------------------------------------------------------------------------
function updateNetworkLeaf()
{

View File

@@ -2,9 +2,9 @@ import threading
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from .graphql_endpoint import devicesSchema
from .device_endpoint import get_device_data, set_device_data, delete_device, delete_device_events, reset_device_props
from .devices_endpoint import delete_unknown_devices, delete_all_with_empty_macs, delete_devices
from .events_endpoint import delete_device_events, delete_events, delete_events_30, get_events
from .device_endpoint import get_device_data, set_device_data, delete_device, delete_device_events, reset_device_props, copy_device, update_device_column
from .devices_endpoint import delete_unknown_devices, delete_all_with_empty_macs, delete_devices, export_devices, import_csv
from .events_endpoint import delete_events, delete_events_30, get_events
from .history_endpoint import delete_online_history
from .prometheus_endpoint import getMetricStats
from .sync_endpoint import handle_sync_post, handle_sync_get
@@ -97,6 +97,34 @@ def api_reset_device_props(mac):
return jsonify({"error": "Forbidden"}), 403
return reset_device_props(mac, request.json)
@app.route("/device/copy", methods=["POST"])
def api_copy_device():
    """Copy one device's record onto another MAC (POST body: macFrom, macTo)."""
    if not is_authorized():
        return jsonify({"error": "Forbidden"}), 403

    payload = request.get_json() or {}
    source_mac = payload.get("macFrom")
    target_mac = payload.get("macTo")

    # Both endpoints of the copy must be supplied.
    if not (source_mac and target_mac):
        return jsonify({"success": False, "error": "macFrom and macTo are required"}), 400

    return copy_device(source_mac, target_mac)
@app.route("/device/<mac>/update-column", methods=["POST"])
def api_update_device_column(mac):
    """
    Update a single column of the device identified by `mac`.

    POST body: {"columnName": "...", "columnValue": "..."}
    Returns 400 when columnName is missing/empty or columnValue is absent.
    """
    if not is_authorized():
        return jsonify({"error": "Forbidden"}), 403
    data = request.get_json() or {}
    column_name = data.get("columnName")
    column_value = data.get("columnValue")
    # Require the column name, but only require the value to be PRESENT:
    # the original `not column_value` check wrongly rejected falsy values
    # such as "" or 0, which are legitimate column contents.
    if not column_name or column_value is None:
        return jsonify({"success": False, "error": "columnName and columnValue are required"}), 400
    return update_device_column(mac, column_name, column_value)
# --------------------------
# Devices Collections
# --------------------------
@@ -129,6 +157,21 @@ def api_get_devices_totals():
return get_devices_totals()
@app.route("/devices/export", methods=["GET"])
@app.route("/devices/export/<format>", methods=["GET"])
def api_export_devices(format=None):
    """Export the devices table; format comes from the path or ?format=, default "csv"."""
    if not is_authorized():
        return jsonify({"error": "Forbidden"}), 403

    chosen = format if format else request.args.get("format", "csv")
    return export_devices(chosen.lower())
@app.route("/devices/import", methods=["POST"])
def api_import_csv():
    """Import devices from an uploaded CSV (multipart field "file")."""
    if not is_authorized():
        return jsonify({"error": "Forbidden"}), 403
    uploaded = request.files.get("file")
    return import_csv(uploaded)
# --------------------------
# Online history
# --------------------------
@@ -144,7 +187,7 @@ def api_delete_online_history():
# --------------------------
@app.route("/events/<mac>", methods=["DELETE"])
def api_delete_device_events(mac):
def api_events_by_mac(mac):
if not is_authorized():
return jsonify({"error": "Forbidden"}), 403
return delete_device_events(mac)
@@ -156,7 +199,7 @@ def api_delete_all_events():
return delete_events()
@app.route("/events", methods=["GET"])
def api_delete_all_events():
def api_get_events():
if not is_authorized():
return jsonify({"error": "Forbidden"}), 403
@@ -170,22 +213,6 @@ def api_delete_old_events():
return jsonify({"error": "Forbidden"}), 403
return delete_events_30()
# --------------------------
# CSV Import / Export
# --------------------------
@app.route("/devices/export", methods=["GET"])
def api_export_csv():
if not is_authorized():
return jsonify({"error": "Forbidden"}), 403
return export_csv()
@app.route("/devices/import", methods=["POST"])
def api_import_csv():
if not is_authorized():
return jsonify({"error": "Forbidden"}), 403
return import_csv(request.files.get("file"))
# --------------------------
# Prometheus metrics endpoint
# --------------------------

View File

@@ -14,8 +14,8 @@ INSTALL_PATH="/app"
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from database import get_temp_db_connection
from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value
from helper import is_random_mac, format_date, get_setting_value
from db.db_helper import row_to_json, get_date_from_period
# --------------------------
# Device Endpoints Functions
@@ -272,3 +272,63 @@ def reset_device_props(mac, data=None):
conn.close()
return jsonify({"success": True})
def update_device_column(mac, column_name, column_value):
    """
    Update a single column for the device identified by `mac`.

    Example: update_device_column("AA:BB:CC:DD:EE:FF", "devParentMAC", "Internet")

    Returns:
        Flask JSON response; 400 for an invalid column name, 404 when no
        device row matched.
    """
    import re  # local import: column-name validation only

    # The column name cannot be bound as a SQL parameter, so restrict it to a
    # plain identifier to block SQL injection (the original comment claimed a
    # whitelist but never enforced one).
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", column_name or ""):
        return jsonify({"success": False, "error": "Invalid column name"}), 400

    conn = get_temp_db_connection()
    try:
        cur = conn.cursor()
        cur.execute(f"UPDATE Devices SET {column_name}=? WHERE devMac=?", (column_value, mac))
        conn.commit()
        if cur.rowcount > 0:
            return jsonify({"success": True})
        return jsonify({"success": False, "error": "Device not found"}), 404
    finally:
        # The original returned before its conn.close(), leaking the connection.
        conn.close()
def copy_device(mac_from, mac_to):
    """
    Copy a device entry from one MAC to another.
    If a device already exists with mac_to, it will be replaced.

    Parameters:
        mac_from (str): MAC address of the source device row.
        mac_to (str): MAC address the copy is written under.

    Returns:
        Flask JSON response: {"success": True, ...} on success, or
        {"success": False, "error": ...} when any SQL step fails.
        NOTE(review): failures are returned with HTTP 200 — confirm callers
        inspect the "success" flag rather than the status code.
    """
    conn = get_temp_db_connection()
    cur = conn.cursor()
    try:
        # Drop temporary table if exists
        cur.execute("DROP TABLE IF EXISTS temp_devices")
        # Create temporary table with source device
        cur.execute("CREATE TABLE temp_devices AS SELECT * FROM Devices WHERE devMac = ?", (mac_from,))
        # Update temporary table to target MAC
        cur.execute("UPDATE temp_devices SET devMac = ?", (mac_to,))
        # Delete previous entry with target MAC
        cur.execute("DELETE FROM Devices WHERE devMac = ?", (mac_to,))
        # Insert new entry from temporary table
        cur.execute("INSERT INTO Devices SELECT * FROM temp_devices WHERE devMac = ?", (mac_to,))
        # Drop temporary table
        cur.execute("DROP TABLE temp_devices")
        conn.commit()
        return jsonify({"success": True, "message": f"Device copied from {mac_from} to {mac_to}"})
    except Exception as e:
        # Roll back the partial copy so the Devices table is left untouched.
        conn.rollback()
        return jsonify({"success": False, "error": str(e)})
    finally:
        conn.close()

View File

@@ -5,16 +5,22 @@ import subprocess
import argparse
import os
import pathlib
import base64
import re
import sys
from datetime import datetime
from flask import jsonify, request
from flask import jsonify, request, Response
import csv
import io
from io import StringIO
# Register NetAlertX directories
INSTALL_PATH="/app"
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from database import get_temp_db_connection
from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value
from helper import is_random_mac, format_date, get_setting_value
from db.db_helper import get_table_json
# --------------------------
@@ -72,4 +78,118 @@ def delete_unknown_devices():
cur.execute("""DELETE FROM Devices WHERE devName='(unknown)' OR devName='(name not found)'""")
conn.commit()
conn.close()
return jsonify({"success": True, "deleted": cur.rowcount})
return jsonify({"success": True, "deleted": cur.rowcount})
def export_devices(export_format):
    """
    Export all rows of the Devices table in the requested format.

    Parameters:
        export_format (str): "json" or "csv" (already lower-cased by caller).

    Returns:
        Flask response: a JSON payload, a CSV attachment, or a 400 error
        for an unsupported format.
    """
    conn = get_temp_db_connection()
    cur = conn.cursor()
    # Fetch all devices as a json_obj ({"data": [...]} plus column names).
    devices_json = get_table_json(cur, "SELECT * FROM Devices")
    conn.close()

    rows = devices_json["data"]
    # Prefer the cursor-reported column names; fall back to the first row's keys.
    columns = devices_json.columnNames or (list(rows[0].keys()) if rows else [])

    if export_format == "json":
        return jsonify({
            "data": list(rows),
            "columns": list(columns)
        })
    elif export_format == "csv":
        si = StringIO()
        writer = csv.DictWriter(si, fieldnames=columns, quoting=csv.QUOTE_ALL)
        writer.writeheader()
        # Original accessed the same rows via devices_json.json["data"];
        # use the single `rows` binding consistently.
        writer.writerows(rows)
        return Response(
            si.getvalue(),
            mimetype="text/csv",
            headers={"Content-Disposition": "attachment; filename=devices.csv"},
        )
    else:
        return jsonify({"error": f"Unsupported format '{export_format}'"}), 400
def import_csv(file_storage=None):
    """
    Import devices from CSV, replacing the entire Devices table.

    Data source priority:
      1. JSON body with base64-encoded CSV in `content`
      2. uploaded file (`file_storage`)
      3. local fallback file /app/config/devices.csv

    Returns:
        Flask JSON response with the number of inserted rows and the
        1-based indexes of skipped lines, or an error with 400/404.
    """
    import sqlite3  # local import: not present in the module-level imports

    data = ""
    skipped = []

    # 1. Try JSON `content` (base64-encoded CSV)
    if request.is_json and request.json.get("content"):
        try:
            data = base64.b64decode(request.json["content"], validate=True).decode("utf-8")
        except Exception as e:
            return jsonify({"error": f"Base64 decode failed: {e}"}), 400
    # 2. Otherwise, try uploaded file
    elif file_storage:
        data = file_storage.read().decode("utf-8")
    # 3. Fallback: local file (mirrors the legacy PHP importer)
    else:
        local_file = "/app/config/devices.csv"
        try:
            with open(local_file, "r", encoding="utf-8") as f:
                data = f.read()
        except FileNotFoundError:
            return jsonify({"error": "CSV file missing"}), 404

    if not data:
        return jsonify({"error": "No CSV data found"}), 400

    # Collapse newlines inside quoted fields so splitlines() keeps rows intact.
    data = re.sub(
        r'"([^"]*)"',
        lambda m: m.group(0).replace("\n", " "),
        data
    )

    # --- Parse CSV ---
    reader = csv.reader(data.splitlines())
    try:
        header = [h.strip() for h in next(reader)]
    except StopIteration:
        return jsonify({"error": "CSV missing header"}), 400

    # Header names are interpolated into the INSERT statement, so restrict
    # them to plain identifiers to prevent SQL injection — and validate
    # BEFORE wiping the table so a bad file cannot destroy existing data.
    if not header or not all(re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", h) for h in header):
        return jsonify({"error": "Invalid column name in CSV header"}), 400

    # --- Wipe Devices table (import is a full replace) ---
    conn = get_temp_db_connection()
    cur = conn.cursor()
    cur.execute("DELETE FROM Devices")

    # --- Prepare insert ---
    placeholders = ",".join(["?"] * len(header))
    insert_sql = f"INSERT INTO Devices ({', '.join(header)}) VALUES ({placeholders})"

    row_count = 0
    for idx, row in enumerate(reader, start=1):
        # Skip rows whose field count does not match the header.
        if len(row) != len(header):
            skipped.append(idx)
            continue
        try:
            cur.execute(insert_sql, [col.strip() for col in row])
            row_count += 1
        except sqlite3.Error as e:
            mylog("error", [f"[ImportCSV] SQL ERROR row {idx}: {e}"])
            skipped.append(idx)

    conn.commit()
    conn.close()

    return jsonify({
        "success": True,
        "inserted": row_count,
        "skipped_lines": skipped
    })

View File

@@ -14,7 +14,8 @@ INSTALL_PATH="/app"
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from database import get_temp_db_connection
from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value
from helper import is_random_mac, format_date, get_setting_value
from db.db_helper import row_to_json
# --------------------------
@@ -68,16 +69,4 @@ def delete_events():
return jsonify({"success": True, "message": "Deleted all events"})
def delete_device_events(mac):
"""Delete all events"""
conn = get_temp_db_connection()
cur = conn.cursor()
sql = "DELETE FROM Events WHERE eve_MAC= ? "
cur.execute(sql, (mac,))
conn.commit()
conn.close()
return jsonify({"success": True, "message": "Deleted all events for the device"})

View File

@@ -14,7 +14,7 @@ INSTALL_PATH="/app"
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from database import get_temp_db_connection
from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value
from helper import is_random_mac, format_date, get_setting_value
# --------------------------------------------------

View File

@@ -8,7 +8,8 @@ import json
from const import fullDbPath, sql_devices_stats, sql_devices_all, sql_generateGuid
from logger import mylog
from helper import json_obj, initOrSetParam, row_to_json, timeNowTZ
from helper import timeNowTZ
from db.db_helper import row_to_json, get_table_json, json_obj
from workflows.app_events import AppEvent_obj
from db.db_upgrade import ensure_column, ensure_views, ensure_CurrentScan, ensure_plugins_tables, ensure_Parameters, ensure_Settings, ensure_Indexes
@@ -121,26 +122,41 @@ class DB():
AppEvent_obj(self)
#-------------------------------------------------------------------------------
# #-------------------------------------------------------------------------------
# def get_table_as_json(self, sqlQuery):
# # mylog('debug',[ '[Database] - get_table_as_json - Query: ', sqlQuery])
# try:
# self.sql.execute(sqlQuery)
# columnNames = list(map(lambda x: x[0], self.sql.description))
# rows = self.sql.fetchall()
# except sqlite3.Error as e:
# mylog('verbose',[ '[Database] - SQL ERROR: ', e])
# return json_obj({}, []) # return empty object
# result = {"data":[]}
# for row in rows:
# tmp = row_to_json(columnNames, row)
# result["data"].append(tmp)
# # mylog('debug',[ '[Database] - get_table_as_json - returning ', len(rows), " rows with columns: ", columnNames])
# # mylog('debug',[ '[Database] - get_table_as_json - returning json ', json.dumps(result) ])
# return json_obj(result, columnNames)
def get_table_as_json(self, sqlQuery):
# mylog('debug',[ '[Database] - get_table_as_json - Query: ', sqlQuery])
"""
Wrapper to use the central get_table_as_json helper.
"""
try:
self.sql.execute(sqlQuery)
columnNames = list(map(lambda x: x[0], self.sql.description))
rows = self.sql.fetchall()
except sqlite3.Error as e:
mylog('verbose',[ '[Database] - SQL ERROR: ', e])
return json_obj({}, []) # return empty object
result = {"data":[]}
for row in rows:
tmp = row_to_json(columnNames, row)
result["data"].append(tmp)
result = get_table_json(self.sql, sqlQuery)
except Exception as e:
mylog('verbose', ['[Database] - get_table_as_json ERROR:', e])
return json_obj({}, []) # return empty object on failure
# mylog('debug',[ '[Database] - get_table_as_json - returning ', len(rows), " rows with columns: ", columnNames])
# mylog('debug',[ '[Database] - get_table_as_json - returning json ', json.dumps(result) ])
return json_obj(result, columnNames)
return result
#-------------------------------------------------------------------------------
# referece from here: https://codereview.stackexchange.com/questions/241043/interface-class-for-sqlite-databases

269
server/db/db_helper.py Executable file
View File

@@ -0,0 +1,269 @@
import sys
import sqlite3
# Register NetAlertX directories
INSTALL_PATH="/app"
sys.path.extend([f"{INSTALL_PATH}/server"])
from helper import if_byte_then_to_str
from logger import mylog
#-------------------------------------------------------------------------------
# Return the SQL WHERE clause for filtering devices based on their status.
def get_device_condition_by_status(device_status):
    """
    Map a device-status keyword to its SQL WHERE clause.

    Parameters:
        device_status (str): One of 'all', 'my', 'connected', 'favorites',
            'new', 'down', 'archived'.

    Returns:
        str: The matching WHERE clause; 'WHERE 1=0' (matches nothing) for
        any unrecognized status.
    """
    active = 'WHERE devIsArchived=0'
    clause_map = {
        'all': active,
        'my': active,
        'connected': active + ' AND devPresentLastScan=1',
        'favorites': active + ' AND devFavorite=1',
        'new': active + ' AND devIsNew=1',
        'down': active + ' AND devAlertDown != 0 AND devPresentLastScan=0',
        'archived': 'WHERE devIsArchived=1',
    }
    return clause_map.get(device_status, 'WHERE 1=0')
#-------------------------------------------------------------------------------
# Creates a JSON-like dictionary from a database row
def row_to_json(names, row):
    """
    Convert a database row into a JSON-like dictionary keyed by column name.

    Parameters:
        names (list of str): Column names for the row.
        row: A mapping-style row indexable by column NAME (e.g. sqlite3.Row
            or a dict); a plain tuple would raise here.

    Returns:
        dict: {column_name: value}, with byte values converted to strings
        via `if_byte_then_to_str`.

    Example:
        names = ['id', 'name']
        row = {'id': 1, 'name': b'Example'}
        row_to_json(names, row)
        # Returns: {'id': 1, 'name': 'Example'}
    """
    # The original iterated with enumerate() but never used the index;
    # a dict comprehension over the names is equivalent and clearer.
    return {name: if_byte_then_to_str(row[name]) for name in names}
#-------------------------------------------------------------------------------
def sanitize_SQL_input(val):
    """
    Make a value safe for direct embedding in SQL text.

    Parameters:
        val (any): The value to sanitize.

    Returns:
        '' when val is None; strings with single quotes replaced by
        underscores; any other type unchanged.
    """
    if val is None:
        return ''
    return val.replace("'", "_") if isinstance(val, str) else val
# -------------------------------------------------------------------------------------------
def get_date_from_period(period):
    """
    Translate a period label into an SQLite date expression.

    Parameters:
        period (str): '7 days', '1 month', '1 year' or '100 years';
            anything else defaults to 1 day.

    Returns:
        str: An expression like "date('now', '-7 day')".
    """
    lookup = {
        '7 days': 7,
        '1 month': 30,
        '1 year': 365,
        '100 years': 3650,  # actually 10 years in original PHP
    }
    return f"date('now', '-{lookup.get(period, 1)} day')"
#-------------------------------------------------------------------------------
def print_table_schema(db, table):
    """
    Log the schema (PRAGMA table_info) of a table at debug level.

    Parameters:
        db: Database wrapper exposing a `sql` cursor.
        table (str): Name of the table to inspect.

    Returns:
        None. Logs a notice when the table is missing or has no columns.
    """
    cursor = db.sql
    cursor.execute(f"PRAGMA table_info({table})")
    rows = cursor.fetchall()

    if not rows:
        mylog('none', f'[Schema] Table "{table}" not found or has no columns.')
        return

    mylog('debug', f'[Schema] Structure for table: {table}')
    header = f"{'cid':<4} {'name':<20} {'type':<10} {'notnull':<8} {'default':<10} {'pk':<2}"
    mylog('debug', header)
    mylog('debug', '-' * len(header))

    # Each row is (cid, name, type, notnull, dflt_value, pk).
    for cid, col_name, col_type, notnull, default, pk in rows:
        mylog('debug', f"{cid:<4} {col_name:<20} {col_type:<10} {notnull:<8} {str(default):<10} {pk:<2}")
#-------------------------------------------------------------------------------
# Generate a WHERE condition for SQLite based on a list of values.
def list_to_where(logical_operator, column_name, condition_operator, values_list):
    """
    Build a parenthesized SQLite WHERE condition from a list of values.

    Parameters:
        logical_operator (str): 'AND' or 'OR', joining the per-value tests.
        column_name (str): Column being filtered.
        condition_operator (str): Comparison operator ('LIKE', '=', ...).
        values_list (list): Values to match; '{s-quote}' tokens become "'".

    Returns:
        str: "(col OP 'v1' AND col OP 'v2' ...)", or "" for an empty list.
    """
    if not values_list:
        return ""

    # Restore literal single quotes encoded as {s-quote}.
    unescaped = [value.replace("{s-quote}", "'") for value in values_list]

    # NOTE(review): values are interpolated directly into SQL text —
    # callers must only pass trusted values.
    clauses = [f"{column_name} {condition_operator} '{value}'" for value in unescaped]
    joiner = f" {logical_operator} "
    return f"({joiner.join(clauses)})"
#-------------------------------------------------------------------------------
def get_table_json(sql, sql_query):
    """
    Run a query and wrap the result set in a json_obj.

    Args:
        sql: Cursor-like object supporting execute(), description and fetchall().
        sql_query (str): The SQL query to execute.

    Returns:
        json_obj: {"data": [row dicts]} plus the column names; an empty
        json_obj when the query raises a sqlite3 error.
    """
    try:
        sql.execute(sql_query)
        names = [description[0] for description in sql.description]
        fetched = sql.fetchall()
    except sqlite3.Error as e:
        mylog('verbose', ['[Database] - SQL ERROR: ', e])
        # Empty object lets callers proceed without their own try/except.
        return json_obj({}, [])

    payload = {"data": [row_to_json(names, record) for record in fetched]}
    return json_obj(payload, names)
#-------------------------------------------------------------------------------
class json_obj:
    """
    Lightweight wrapper pairing query-result data with its column names.

    Behaves like a read-only dict over the underlying JSON payload while
    keeping column metadata available separately.

    Attributes:
        json (dict): JSON-style data returned from the database.
        columnNames (list): Column names corresponding to the data.
    """

    def __init__(self, jsn, columnNames):
        """Store the payload and its column names."""
        self.json = jsn
        self.columnNames = columnNames

    def get(self, key, default=None):
        """Dict-like .get() over the payload."""
        return self.json.get(key, default)

    def keys(self):
        """Keys of the payload."""
        return self.json.keys()

    def items(self):
        """(key, value) pairs of the payload."""
        return self.json.items()

    def __getitem__(self, key):
        """Bracket access (obj[key]) into the payload."""
        return self.json[key]

View File

@@ -18,7 +18,6 @@ import hashlib
import random
import string
import ipaddress
import dns.resolver
import conf
from const import *
@@ -53,22 +52,6 @@ def get_timezone_offset():
return offset_formatted
#-------------------------------------------------------------------------------
def updateSubnets(scan_subnets):
subnets = []
# multiple interfaces
if type(scan_subnets) is list:
for interface in scan_subnets :
subnets.append(interface)
# one interface only
else:
subnets.append(scan_subnets)
return subnets
#-------------------------------------------------------------------------------
# File system permission handling
#-------------------------------------------------------------------------------
@@ -217,12 +200,6 @@ def get_setting(key):
return None
#-------------------------------------------------------------------------------
# Settings
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# Return setting value
def get_setting_value(key):
@@ -248,8 +225,6 @@ def get_setting_value(key):
#-------------------------------------------------------------------------------
# Convert the setting value to the corresponding python type
def setting_value_to_python_type(set_type, set_value):
value = '----not processed----'
@@ -341,6 +316,30 @@ def setting_value_to_python_type(set_type, set_value):
return value
#-------------------------------------------------------------------------------
def updateSubnets(scan_subnets):
    """
    Normalize scan subnet input into a list of subnets.

    Parameters:
        scan_subnets (str or list): A single subnet string or a list of them.

    Returns:
        list: All subnets; a single subnet becomes a one-element list.
    """
    # A list means multiple interfaces; copy it so callers keep ownership.
    if isinstance(scan_subnets, list):
        return list(scan_subnets)
    # Otherwise wrap the single interface.
    return [scan_subnets]
#-------------------------------------------------------------------------------
# Reverse transformed values if needed
def reverseTransformers(val, transformers):
@@ -360,41 +359,6 @@ def reverseTransformers(val, transformers):
else:
return reverse_transformers(val, transformers)
#-------------------------------------------------------------------------------
# Generate a WHERE condition for SQLite based on a list of values.
def list_to_where(logical_operator, column_name, condition_operator, values_list):
"""
Generate a WHERE condition for SQLite based on a list of values.
Parameters:
- logical_operator: The logical operator ('AND' or 'OR') to combine conditions.
- column_name: The name of the column to filter on.
- condition_operator: The condition operator ('LIKE', 'NOT LIKE', '=', '!=', etc.).
- values_list: A list of values to be included in the condition.
Returns:
- A string representing the WHERE condition.
"""
# If the list is empty, return an empty string
if not values_list:
return ""
# Replace {s-quote} with single quote in values_list
values_list = [value.replace("{s-quote}", "'") for value in values_list]
# Build the WHERE condition for the first value
condition = f"{column_name} {condition_operator} '{values_list[0]}'"
# Add the rest of the values using the logical operator
for value in values_list[1:]:
condition += f" {logical_operator} {column_name} {condition_operator} '{value}'"
return f'({condition})'
#-------------------------------------------------------------------------------
# IP validation methods
@@ -432,6 +396,19 @@ def check_IP_format (pIP):
# String manipulation methods
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
def generate_random_string(length):
    """Return a random alphanumeric (letters + digits) string of `length` chars."""
    pool = string.ascii_letters + string.digits
    picked = [random.choice(pool) for _ in range(length)]
    return ''.join(picked)
#-------------------------------------------------------------------------------
def extract_between_strings(text, start, end):
    """
    Return the substring of `text` between the first `start` marker and the
    next `end` marker after it, or "" when either marker is absent.
    """
    begin = text.find(start)
    if begin == -1:
        return ""
    content_start = begin + len(start)
    finish = text.find(end, content_start)
    if finish == -1:
        return ""
    return text[content_start:finish]
#-------------------------------------------------------------------------------
@@ -474,7 +451,6 @@ def removeDuplicateNewLines(text):
return text
#-------------------------------------------------------------------------------
def sanitize_string(input):
if isinstance(input, bytes):
input = input.decode('utf-8')
@@ -482,15 +458,6 @@ def sanitize_string(input):
return input
#-------------------------------------------------------------------------------
def sanitize_SQL_input(val):
if val is None:
return ''
if isinstance(val, str):
return val.replace("'", "_")
return val # Return non-string values as they are
#-------------------------------------------------------------------------------
# Function to normalize the string and remove diacritics
def normalize_string(text):
@@ -501,8 +468,29 @@ def normalize_string(text):
# Filter out diacritics and unwanted characters
return ''.join(c for c in normalized_text if unicodedata.category(c) != 'Mn')
# ------------------------------------------------------------------------------
# MAC and IP helper methods
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------------
def is_random_mac(mac: str) -> bool:
    """Determine if a MAC address is random, respecting user-defined prefixes not to mark as random."""
    # Heuristic: locally administered MACs have 2/6/A/E as the second nibble.
    looks_random = mac[1].upper() in ("2", "6", "A", "E")
    # Fetched unconditionally, matching the original call order.
    excluded_prefixes = get_setting_value("UI_NOT_RANDOM_MAC")
    # A user-listed prefix overrides the randomness heuristic.
    if looks_random and any(mac.upper().startswith(p.upper()) for p in excluded_prefixes):
        return False
    return looks_random
# -------------------------------------------------------------------------------------------
def generate_mac_links (html, deviceUrl):
p = re.compile(r'(?:[0-9a-fA-F]:?){12}')
@@ -514,15 +502,6 @@ def generate_mac_links (html, deviceUrl):
return html
#-------------------------------------------------------------------------------
def extract_between_strings(text, start, end):
start_index = text.find(start)
end_index = text.find(end, start_index + len(start))
if start_index != -1 and end_index != -1:
return text[start_index + len(start):end_index]
else:
return ""
#-------------------------------------------------------------------------------
def extract_mac_addresses(text):
mac_pattern = r"([0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2})"
@@ -536,11 +515,6 @@ def extract_ip_addresses(text):
return ip_addresses
#-------------------------------------------------------------------------------
def generate_random_string(length):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))
# Helper function to determine if a MAC address is random
def is_random_mac(mac):
# Check if second character matches "2", "6", "A", "E" (case insensitive)
@@ -555,13 +529,14 @@ def is_random_mac(mac):
break
return is_random
#-------------------------------------------------------------------------------
# Helper function to calculate number of children
def get_number_of_children(mac, devices):
    """Count devices whose devParentMAC equals `mac` (whitespace-trimmed on both sides)."""
    target = mac.strip()
    total = 0
    for device in devices:
        if device.get("devParentMAC", "").strip() == target:
            total += 1
    return total
#-------------------------------------------------------------------------------
# Function to convert IP to a long integer
def format_ip_long(ip_address):
try:
@@ -596,8 +571,6 @@ def add_json_list (row, list):
return list
#-------------------------------------------------------------------------------
# Checks if the object has a __dict__ attribute. If it does, it assumes that it's an instance of a class and serializes its attributes dynamically.
class NotiStrucEncoder(json.JSONEncoder):
@@ -607,19 +580,6 @@ class NotiStrucEncoder(json.JSONEncoder):
return obj.__dict__
return super().default(obj)
#-------------------------------------------------------------------------------
# Creates a JSON object from a DB row
def row_to_json(names, row):
    """Convert a DB row into a dict keyed by column name.

    Byte values are converted to str via if_byte_then_to_str.
    (Removed a dead `index` counter that was incremented but never read.)
    """
    return {name: if_byte_then_to_str(row[name]) for name in names}
#-------------------------------------------------------------------------------
# Get language strings from plugin JSON
def collect_lang_strings(json, pref, stringSqlParams):
@@ -633,7 +593,7 @@ def collect_lang_strings(json, pref, stringSqlParams):
return stringSqlParams
#-------------------------------------------------------------------------------
# Misc
# Date and time methods
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------------
@@ -661,65 +621,6 @@ def format_date_iso(date1: str) -> str:
dt = datetime.datetime.fromisoformat(date1) if isinstance(date1, str) else date1
return dt.isoformat()
# -------------------------------------------------------------------------------------------
def is_random_mac(mac: str) -> bool:
    """Determine if a MAC address is random, respecting user-defined prefixes not to mark as random."""
    # Guard: an empty or malformed MAC cannot be classified (previously raised IndexError)
    if not mac or len(mac) < 2:
        return False
    # Locally-administered MACs have 2/6/A/E as the second hex digit
    is_random = mac[1].upper() in ["2", "6", "A", "E"]
    if is_random:
        # Only fetch the exclusion prefixes when actually needed
        prefixes = get_setting_value("UI_NOT_RANDOM_MAC")
        # If detected as random, make sure it doesn't start with a prefix the user wants to exclude
        for prefix in prefixes:
            if mac.upper().startswith(prefix.upper()):
                is_random = False
                break
    return is_random
# -------------------------------------------------------------------------------------------
def get_date_from_period(period):
    """
    Convert a period request parameter into an SQLite date expression.
    Equivalent to PHP getDateFromPeriod().
    Returns a string like "date('now', '-7 day')"
    """
    if period == '7 days':
        days = 7
    elif period == '1 month':
        days = 30
    elif period == '1 year':
        days = 365
    elif period == '100 years':
        days = 3650  # actually 10 years in original PHP
    else:
        days = 1  # default 1 day
    return f"date('now', '-{days} day')"
#-------------------------------------------------------------------------------
def print_table_schema(db, table):
    """Log the column layout of an SQLite table at debug level via mylog."""
    cursor = db.sql
    cursor.execute(f"PRAGMA table_info({table})")
    rows = cursor.fetchall()

    # Nothing to print for an unknown/empty table
    if not rows:
        mylog('none', f'[Schema] Table "{table}" not found or has no columns.')
        return

    mylog('debug', f'[Schema] Structure for table: {table}')
    header = f"{'cid':<4} {'name':<20} {'type':<10} {'notnull':<8} {'default':<10} {'pk':<2}"
    mylog('debug', header)
    mylog('debug', '-' * len(header))

    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    for cid, col_name, col_type, notnull, dflt_value, pk in rows:
        mylog('debug', f"{cid:<4} {col_name:<20} {col_type:<10} {notnull:<8} {str(dflt_value):<10} {pk:<2}")
#-------------------------------------------------------------------------------
def checkNewVersion():
mylog('debug', [f"[Version check] Checking if new version available"])
@@ -761,22 +662,6 @@ def checkNewVersion():
return newVersion
#-------------------------------------------------------------------------------
def initOrSetParam(db, parID, parValue):
    """Insert a row into Parameters, or update par_Value if par_ID already exists.

    Uses a parameterized upsert: the previous version concatenated parID/parValue
    straight into the SQL string, which broke on quotes and was injectable.
    """
    sql = db.sql
    sql.execute(
        "INSERT INTO Parameters(par_ID, par_Value) VALUES(?, ?) "
        "ON CONFLICT(par_ID) DO UPDATE SET par_Value = excluded.par_Value",
        (str(parID), str(parValue)),
    )
    db.commitDB()
#-------------------------------------------------------------------------------
class json_obj:
    """Lightweight container pairing a JSON payload with its column names."""

    def __init__(self, jsn, columnNames):
        # Raw JSON payload (stored as given, no validation)
        self.json = jsn
        # Column names describing the payload's fields
        self.columnNames = columnNames
#-------------------------------------------------------------------------------
class noti_obj:
def __init__(self, json, text, html):

View File

@@ -12,7 +12,7 @@ import re
# Register NetAlertX libraries
import conf
from const import fullConfPath, applicationPath, fullConfFolder, default_tz
from helper import fixPermissions, collect_lang_strings, updateSubnets, initOrSetParam, isJsonObject, setting_value_to_python_type, timeNowTZ, get_setting_value, generate_random_string
from helper import fixPermissions, collect_lang_strings, updateSubnets, isJsonObject, setting_value_to_python_type, timeNowTZ, get_setting_value, generate_random_string
from app_state import updateState
from logger import mylog
from api import update_api

View File

@@ -8,12 +8,13 @@ import re
INSTALL_PATH="/app"
sys.path.extend([f"{INSTALL_PATH}/server"])
from helper import timeNowTZ, get_setting_value, list_to_where, check_IP_format, sanitize_SQL_input
from helper import timeNowTZ, get_setting_value, check_IP_format
from logger import mylog
from const import vendorsPath, vendorsPathNewest, sql_generateGuid
from models.device_instance import DeviceInstance
from scan.name_resolution import NameResolver
from scan.device_heuristics import guess_icon, guess_type
from db.db_helper import sanitize_SQL_input, list_to_where
#-------------------------------------------------------------------------------
# Removing devices from the CurrentScan DB table which the user chose to ignore by MAC or IP

View File

@@ -6,7 +6,8 @@ sys.path.extend([f"{INSTALL_PATH}/server"])
import conf
from scan.device_handling import create_new_devices, print_scan_stats, save_scanned_devices, exclude_ignored_devices, update_devices_data_from_scan
from helper import timeNowTZ, print_table_schema, get_setting_value
from helper import timeNowTZ, get_setting_value
from db.db_helper import print_table_schema
from logger import mylog, Logger
from messaging.reporting import skip_repeated_notifications

View File

@@ -66,3 +66,56 @@ def test_delete_device(client, api_token, test_mac):
resp = client.delete(f"/device/{test_mac}/delete", headers=auth_headers(api_token))
assert resp.status_code == 200
assert resp.json.get("success") is True
def test_copy_device(client, api_token, test_mac):
    """Create a device, copy it to a fresh MAC, then verify and clean up."""
    headers = auth_headers(api_token)

    # Create the source device
    resp = client.post(
        f"/device/{test_mac}",
        json={"createNew": True, "name": "Source Device"},
        headers=headers,
    )
    assert resp.status_code == 200
    assert resp.json.get("success") is True

    # Generate a random target MAC in the AA:BB:CC test range
    suffix = ":".join(f"{random.randint(0,255):02X}" for _ in range(3))
    target_mac = "AA:BB:CC:" + suffix

    # Copy source -> target
    resp = client.post(
        "/device/copy",
        json={"macFrom": test_mac, "macTo": target_mac},
        headers=headers,
    )
    assert resp.status_code == 200
    assert resp.json.get("success") is True

    # The copied device must now be retrievable under the target MAC
    resp = client.get(f"/device/{target_mac}", headers=headers)
    assert resp.status_code == 200
    assert resp.json.get("devMac") == target_mac

    # Cleanup: delete both devices
    client.delete(f"/device/{test_mac}/delete", headers=headers)
    client.delete(f"/device/{target_mac}/delete", headers=headers)
def test_update_device_column(client, api_token, test_mac):
    """Updating a column succeeds for an existing device and 404s for a missing one."""
    headers = auth_headers(api_token)
    column_payload = {"columnName": "devParentMAC", "columnValue": "Internet"}

    # First, create the device
    client.post(f"/device/{test_mac}", json={"createNew": True}, headers=headers)

    # Update its parent MAC
    resp = client.post(
        f"/device/{test_mac}/update-column",
        json=column_payload,
        headers=headers,
    )
    assert resp.status_code == 200
    assert resp.json.get("success") is True

    # Updating a device that does not exist must fail with 404
    resp_missing = client.post(
        "/device/11:22:33:44:55:66/update-column",
        json=column_payload,
        headers=headers,
    )
    assert resp_missing.status_code == 404
    assert resp_missing.json.get("success") is False

View File

@@ -1,6 +1,7 @@
import sys
import pathlib
import sqlite3
import base64
import random
import string
import uuid
@@ -29,9 +30,8 @@ def test_mac():
def auth_headers(token):
return {"Authorization": f"Bearer {token}"}
def test_delete_devices_with_macs(client, api_token, test_mac):
# First create device so it exists
def create_dummy(client, api_token, test_mac):
payload = {
"createNew": True,
"name": "Test Device",
@@ -40,6 +40,10 @@ def test_delete_devices_with_macs(client, api_token, test_mac):
"vendor": "TestVendor",
}
resp = client.post(f"/device/{test_mac}", json=payload, headers=auth_headers(api_token))
def test_delete_devices_with_macs(client, api_token, test_mac):
# First create device so it exists
create_dummy(client, api_token, test_mac)
client.post(f"/device/{test_mac}", json={"createNew": True}, headers=auth_headers(api_token))
@@ -48,14 +52,6 @@ def test_delete_devices_with_macs(client, api_token, test_mac):
assert resp.status_code == 200
assert resp.json.get("success") is True
def test_delete_test_devices(client, api_token, test_mac):
    """Wildcard-delete every device in the AA:BB:CC test range."""
    headers = auth_headers(api_token)
    payload = {"macs": ["AA:BB:CC:*"]}
    resp = client.delete("/devices", json=payload, headers=headers)
    assert resp.status_code == 200
    assert resp.json.get("success") is True
def test_delete_all_empty_macs(client, api_token):
resp = client.delete("/devices/empty-macs", headers=auth_headers(api_token))
assert resp.status_code == 200
@@ -68,3 +64,72 @@ def test_delete_unknown_devices(client, api_token):
assert resp.status_code == 200
assert resp.json.get("success") is True
def test_export_devices_csv(client, api_token, test_mac):
    """CSV export returns text/csv as an attachment and includes the test device."""
    create_dummy(client, api_token, test_mac)

    resp = client.get("/devices/export/csv", headers=auth_headers(api_token))
    assert resp.status_code == 200
    assert resp.mimetype == "text/csv"

    disposition = resp.headers.get("Content-disposition", "")
    assert "attachment; filename=devices.csv" in disposition

    # The exported body must contain the device we just created
    assert test_mac in resp.data.decode()
def test_export_devices_json(client, api_token, test_mac):
    """JSON export returns a JSON body whose data list contains the test device."""
    create_dummy(client, api_token, test_mac)

    resp = client.get("/devices/export/json", headers=auth_headers(api_token))
    assert resp.status_code == 200
    assert resp.is_json

    devices = resp.get_json()["data"]
    matches = [dev for dev in devices if dev.get("devMac") == test_mac]
    assert matches
def test_export_devices_invalid_format(client, api_token):
    """An unsupported export format yields 400 with an explanatory error."""
    headers = auth_headers(api_token)
    resp = client.get("/devices/export/invalid", headers=headers)
    assert resp.status_code == 400
    assert "Unsupported format" in resp.json.get("error")
def test_export_import_cycle_base64(client, api_token, test_mac):
    """Round-trip: export devices as CSV, re-import the base64-encoded CSV."""
    headers = auth_headers(api_token)

    # Create a dummy device so the export is non-empty
    create_dummy(client, api_token, test_mac)

    # Export devices as CSV and sanity-check its contents
    resp = client.get("/devices/export/csv", headers=headers)
    assert resp.status_code == 200
    csv_data = resp.data.decode("utf-8")
    assert test_mac in csv_data
    assert "Test Device" in csv_data

    # The import endpoint expects the CSV base64-encoded inside a JSON payload
    csv_base64 = base64.b64encode(csv_data.encode("utf-8")).decode("utf-8")

    resp = client.post(
        "/devices/import",
        json={"content": csv_base64},
        headers={**headers, "Content-Type": "application/json"},
    )
    assert resp.status_code == 200
    assert resp.json.get("success") is True

    # Import must report at least one inserted row and no skipped lines
    assert resp.json.get("inserted") >= 1
    assert resp.json.get("skipped_lines") == []
def test_delete_test_devices(client, api_token, test_mac):
    # Delete by MAC wildcard covering the AA:BB:CC test range
    resp = client.delete(
        "/devices",
        json={"macs": ["AA:BB:CC:*"]},
        headers=auth_headers(api_token),
    )
    assert resp.status_code == 200
    assert resp.json.get("success") is True