feat: Enhance authoritative field handling with new locking mechanisms and update tests

This commit is contained in:
Jokob @NetAlertX
2026-01-21 04:46:07 +00:00
parent 97e684dba4
commit 54d01f0a65
5 changed files with 309 additions and 79 deletions

View File

@@ -17,6 +17,7 @@ sys.path.extend([f"{INSTALL_PATH}/server"])
from logger import mylog # noqa: E402 [flake8 lint suppression]
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
from db.db_helper import row_to_json # noqa: E402 [flake8 lint suppression]
# Map of field to its source tracking field
@@ -149,8 +150,6 @@ def enforce_source_on_user_update(devMac, updates_dict, conn):
conn: Database connection object.
"""
cur = conn.cursor()
# Check if field has a corresponding source and should be updated
cur = conn.cursor()
try:
@@ -160,10 +159,9 @@ def enforce_source_on_user_update(devMac, updates_dict, conn):
device_columns = set()
updates_to_apply = {}
for field_name, new_value in updates_dict.items():
for field_name in updates_dict.keys():
if field_name in FIELD_SOURCE_MAP:
source_field = FIELD_SOURCE_MAP[field_name]
# User is updating this field, so mark it as USER
if not device_columns or source_field in device_columns:
updates_to_apply[source_field] = "USER"
@@ -179,17 +177,62 @@ def enforce_source_on_user_update(devMac, updates_dict, conn):
try:
cur.execute(sql, values)
conn.commit()
mylog(
"debug",
[f"[enforce_source_on_user_update] Updated sources for {devMac}: {updates_to_apply}"],
)
except Exception as e:
mylog("none", [f"[enforce_source_on_user_update] ERROR: {e}"])
conn.rollback()
raise
def get_locked_field_overrides(devMac, updates_dict, conn):
"""
For user updates, restore values for any fields whose *Source is LOCKED.
Args:
devMac: The MAC address of the device being updated.
updates_dict: Dict of field -> value being updated.
conn: Database connection object.
Returns:
tuple(set, dict): (locked_fields, overrides)
locked_fields: set of field names that are locked
overrides: dict of field -> existing value to preserve
"""
tracked_fields = [field for field in updates_dict.keys() if field in FIELD_SOURCE_MAP]
if not tracked_fields:
return set(), {}
select_columns = tracked_fields + [FIELD_SOURCE_MAP[field] for field in tracked_fields]
select_clause = ", ".join(select_columns)
cur = conn.cursor()
try:
cur.execute(
f"SELECT {select_clause} FROM Devices WHERE devMac=?",
(devMac,),
)
row = cur.fetchone()
except Exception:
row = None
if not row:
return set(), {}
row_data = row_to_json(list(row.keys()), row)
locked_fields = set()
overrides = {}
for field in tracked_fields:
source_field = FIELD_SOURCE_MAP[field]
if row_data.get(source_field) == "LOCKED":
locked_fields.add(field)
overrides[field] = row_data.get(field) or ""
return locked_fields, overrides
def lock_field(devMac, field_name, conn):
"""
Lock a field so it won't be overwritten by plugins.

View File

@@ -9,7 +9,13 @@ from logger import mylog
from models.plugin_object_instance import PluginObjectInstance
from database import get_temp_db_connection
from db.db_helper import get_table_json, get_device_condition_by_status, row_to_json, get_date_from_period
from db.authoritative_handler import enforce_source_on_user_update, lock_field, unlock_field, FIELD_SOURCE_MAP
from db.authoritative_handler import (
enforce_source_on_user_update,
get_locked_field_overrides,
lock_field,
unlock_field,
FIELD_SOURCE_MAP,
)
from helper import is_random_mac, get_setting_value
from utils.datetime_utils import timeNowDB, format_date
@@ -504,6 +510,67 @@ class DeviceInstance:
normalized_mac = normalize_mac(mac)
normalized_parent_mac = normalize_mac(data.get("devParentMAC") or "")
fields_updated_by_set_device_data = {
"devName",
"devOwner",
"devType",
"devVendor",
"devIcon",
"devFavorite",
"devGroup",
"devLocation",
"devComments",
"devParentMAC",
"devParentPort",
"devSSID",
"devSite",
"devStaticIP",
"devScan",
"devAlertEvents",
"devAlertDown",
"devParentRelType",
"devReqNicsOnline",
"devSkipRepeated",
"devIsNew",
"devIsArchived",
"devCustomProps",
}
# Only mark USER for tracked fields that this method actually updates.
tracked_update_fields = set(FIELD_SOURCE_MAP.keys()) & fields_updated_by_set_device_data
tracked_update_fields.discard("devMac")
locked_fields = set()
pre_update_tracked_values = {}
if not data.get("createNew", False):
conn_preview = get_temp_db_connection()
try:
locked_fields, overrides = get_locked_field_overrides(
normalized_mac,
data,
conn_preview,
)
if overrides:
data.update(overrides)
# Capture pre-update values for tracked fields so we can mark USER only
# when the user actually changes the value.
tracked_fields_in_payload = [
k for k in data.keys() if k in tracked_update_fields
]
if tracked_fields_in_payload:
select_clause = ", ".join(tracked_fields_in_payload)
cur_preview = conn_preview.cursor()
cur_preview.execute(
f"SELECT {select_clause} FROM Devices WHERE devMac=?",
(normalized_mac,),
)
row = cur_preview.fetchone()
if row:
pre_update_tracked_values = row_to_json(list(row.keys()), row)
finally:
conn_preview.close()
conn = None
try:
if data.get("createNew", False):
@@ -619,7 +686,30 @@ class DeviceInstance:
# Enforce source tracking on user updates
# User-updated fields should have their *Source set to "USER"
user_updated_fields = {k: v for k, v in data.items() if k in FIELD_SOURCE_MAP}
def _normalize_tracked_value(value):
if value is None:
return ""
if isinstance(value, str):
return value.strip()
return str(value)
user_updated_fields = {}
if not data.get("createNew", False):
for field_name in tracked_update_fields:
if field_name in locked_fields:
continue
if field_name not in data:
continue
if field_name == "devParentMAC":
new_value = normalized_parent_mac
else:
new_value = data.get(field_name)
old_value = pre_update_tracked_values.get(field_name)
if _normalize_tracked_value(old_value) != _normalize_tracked_value(new_value):
user_updated_fields[field_name] = new_value
if user_updated_fields and not data.get("createNew", False):
try:
enforce_source_on_user_update(normalized_mac, user_updated_fields, conn)