mirror of
https://github.com/jokob-sk/NetAlertX.git
synced 2025-12-07 09:36:05 -08:00
331
test/backend/test_compound_conditions.py
Normal file
331
test/backend/test_compound_conditions.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Unit tests for SafeConditionBuilder compound condition parsing.
|
||||
|
||||
Tests the fix for Issue #1210 - compound conditions with multiple AND/OR clauses.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Mock the logger module before importing SafeConditionBuilder
|
||||
sys.modules['logger'] = MagicMock()
|
||||
|
||||
# Add parent directory to path for imports
|
||||
import os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from server.db.sql_safe_builder import SafeConditionBuilder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def builder():
|
||||
"""Create a fresh builder instance for each test."""
|
||||
return SafeConditionBuilder()
|
||||
|
||||
|
||||
def test_user_failing_filter_six_and_clauses(builder):
|
||||
"""Test the exact user-reported failing filter from Issue #1210."""
|
||||
condition = (
|
||||
"AND devLastIP NOT LIKE '192.168.50.%' "
|
||||
"AND devLastIP NOT LIKE '192.168.60.%' "
|
||||
"AND devLastIP NOT LIKE '192.168.70.2' "
|
||||
"AND devLastIP NOT LIKE '192.168.70.5' "
|
||||
"AND devLastIP NOT LIKE '192.168.70.3' "
|
||||
"AND devLastIP NOT LIKE '192.168.70.4'"
|
||||
)
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should successfully parse
|
||||
assert sql is not None
|
||||
assert params is not None
|
||||
|
||||
# Should have 6 parameters (one per clause)
|
||||
assert len(params) == 6
|
||||
|
||||
# Should contain all 6 AND operators
|
||||
assert sql.count('AND') == 6
|
||||
|
||||
# Should contain all 6 NOT LIKE operators
|
||||
assert sql.count('NOT LIKE') == 6
|
||||
|
||||
# Should have 6 parameter placeholders
|
||||
assert sql.count(':param_') == 6
|
||||
|
||||
# Verify all IP patterns are in parameters
|
||||
param_values = list(params.values())
|
||||
assert '192.168.50.%' in param_values
|
||||
assert '192.168.60.%' in param_values
|
||||
assert '192.168.70.2' in param_values
|
||||
assert '192.168.70.5' in param_values
|
||||
assert '192.168.70.3' in param_values
|
||||
assert '192.168.70.4' in param_values
|
||||
|
||||
|
||||
def test_multiple_and_clauses_simple(builder):
|
||||
"""Test multiple AND clauses with simple equality operators."""
|
||||
condition = "AND devName = 'Device1' AND devVendor = 'Apple' AND devFavorite = '1'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 3 parameters
|
||||
assert len(params) == 3
|
||||
|
||||
# Should have 3 AND operators
|
||||
assert sql.count('AND') == 3
|
||||
|
||||
# Verify all values are parameterized
|
||||
param_values = list(params.values())
|
||||
assert 'Device1' in param_values
|
||||
assert 'Apple' in param_values
|
||||
assert '1' in param_values
|
||||
|
||||
|
||||
def test_multiple_or_clauses(builder):
|
||||
"""Test multiple OR clauses."""
|
||||
condition = "OR devName = 'Device1' OR devName = 'Device2' OR devName = 'Device3'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 3 parameters
|
||||
assert len(params) == 3
|
||||
|
||||
# Should have 3 OR operators
|
||||
assert sql.count('OR') == 3
|
||||
|
||||
# Verify all device names are parameterized
|
||||
param_values = list(params.values())
|
||||
assert 'Device1' in param_values
|
||||
assert 'Device2' in param_values
|
||||
assert 'Device3' in param_values
|
||||
|
||||
def test_mixed_and_or_clauses(builder):
|
||||
"""Test mixed AND/OR logical operators."""
|
||||
condition = "AND devName = 'Device1' OR devName = 'Device2' AND devFavorite = '1'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 3 parameters
|
||||
assert len(params) == 3
|
||||
|
||||
# Should preserve the logical operator order
|
||||
assert 'AND' in sql
|
||||
assert 'OR' in sql
|
||||
|
||||
# Verify all values are parameterized
|
||||
param_values = list(params.values())
|
||||
assert 'Device1' in param_values
|
||||
assert 'Device2' in param_values
|
||||
assert '1' in param_values
|
||||
|
||||
|
||||
def test_single_condition_backward_compatibility(builder):
|
||||
"""Test that single conditions still work (backward compatibility)."""
|
||||
condition = "AND devName = 'TestDevice'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 1 parameter
|
||||
assert len(params) == 1
|
||||
|
||||
# Should match expected format
|
||||
assert 'AND devName = :param_' in sql
|
||||
|
||||
# Parameter should contain the value
|
||||
assert 'TestDevice' in params.values()
|
||||
|
||||
|
||||
def test_single_condition_like_operator(builder):
|
||||
"""Test single LIKE condition for backward compatibility."""
|
||||
condition = "AND devComments LIKE '%important%'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 1 parameter
|
||||
assert len(params) == 1
|
||||
|
||||
# Should contain LIKE operator
|
||||
assert 'LIKE' in sql
|
||||
|
||||
# Parameter should contain the pattern
|
||||
assert '%important%' in params.values()
|
||||
|
||||
|
||||
def test_compound_with_like_patterns(builder):
|
||||
"""Test compound conditions with LIKE patterns."""
|
||||
condition = "AND devLastIP LIKE '192.168.%' AND devVendor LIKE '%Apple%'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 2 parameters
|
||||
assert len(params) == 2
|
||||
|
||||
# Should have 2 LIKE operators
|
||||
assert sql.count('LIKE') == 2
|
||||
|
||||
# Verify patterns are parameterized
|
||||
param_values = list(params.values())
|
||||
assert '192.168.%' in param_values
|
||||
assert '%Apple%' in param_values
|
||||
|
||||
|
||||
def test_compound_with_inequality_operators(builder):
|
||||
"""Test compound conditions with various inequality operators."""
|
||||
condition = "AND eve_DateTime > '2024-01-01' AND eve_DateTime < '2024-12-31'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 2 parameters
|
||||
assert len(params) == 2
|
||||
|
||||
# Should have both operators
|
||||
assert '>' in sql
|
||||
assert '<' in sql
|
||||
|
||||
# Verify dates are parameterized
|
||||
param_values = list(params.values())
|
||||
assert '2024-01-01' in param_values
|
||||
assert '2024-12-31' in param_values
|
||||
|
||||
|
||||
def test_empty_condition(builder):
|
||||
"""Test empty condition string."""
|
||||
condition = ""
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should return empty results
|
||||
assert sql == ""
|
||||
assert params == {}
|
||||
|
||||
|
||||
def test_whitespace_only_condition(builder):
|
||||
"""Test condition with only whitespace."""
|
||||
condition = " \t\n "
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should return empty results
|
||||
assert sql == ""
|
||||
assert params == {}
|
||||
|
||||
|
||||
def test_invalid_column_name_rejected(builder):
|
||||
"""Test that invalid column names are rejected."""
|
||||
condition = "AND malicious_column = 'value'"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
builder.build_safe_condition(condition)
|
||||
|
||||
|
||||
def test_invalid_operator_rejected(builder):
|
||||
"""Test that invalid operators are rejected."""
|
||||
condition = "AND devName EXECUTE 'DROP TABLE'"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
builder.build_safe_condition(condition)
|
||||
|
||||
|
||||
def test_sql_injection_attempt_blocked(builder):
|
||||
"""Test that SQL injection attempts are blocked."""
|
||||
condition = "AND devName = 'value'; DROP TABLE devices; --"
|
||||
|
||||
# Should either reject or sanitize the dangerous input
|
||||
# The semicolon and comment should not appear in the final SQL
|
||||
try:
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
# If it doesn't raise an error, it should sanitize the input
|
||||
assert 'DROP' not in sql.upper()
|
||||
assert ';' not in sql
|
||||
except ValueError:
|
||||
# Rejection is also acceptable
|
||||
pass
|
||||
|
||||
|
||||
def test_quoted_string_with_spaces(builder):
|
||||
"""Test that quoted strings with spaces are handled correctly."""
|
||||
condition = "AND devName = 'My Device Name' AND devComments = 'Has spaces here'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 2 parameters
|
||||
assert len(params) == 2
|
||||
|
||||
# Verify values with spaces are preserved
|
||||
param_values = list(params.values())
|
||||
assert 'My Device Name' in param_values
|
||||
assert 'Has spaces here' in param_values
|
||||
|
||||
|
||||
def test_compound_condition_with_not_equal(builder):
|
||||
"""Test compound conditions with != operator."""
|
||||
condition = "AND devName != 'Device1' AND devVendor != 'Unknown'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 2 parameters
|
||||
assert len(params) == 2
|
||||
|
||||
# Should have != operators (or converted to <>)
|
||||
assert '!=' in sql or '<>' in sql
|
||||
|
||||
# Verify values are parameterized
|
||||
param_values = list(params.values())
|
||||
assert 'Device1' in param_values
|
||||
assert 'Unknown' in param_values
|
||||
|
||||
|
||||
def test_very_long_compound_condition(builder):
|
||||
"""Test handling of very long compound conditions (10+ clauses)."""
|
||||
clauses = []
|
||||
for i in range(10):
|
||||
clauses.append(f"AND devName != 'Device{i}'")
|
||||
|
||||
condition = " ".join(clauses)
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have 10 parameters
|
||||
assert len(params) == 10
|
||||
|
||||
# Should have 10 AND operators
|
||||
assert sql.count('AND') == 10
|
||||
|
||||
# Verify all device names are parameterized
|
||||
param_values = list(params.values())
|
||||
for i in range(10):
|
||||
assert f'Device{i}' in param_values
|
||||
|
||||
|
||||
def test_parameters_have_unique_names(builder):
|
||||
"""Test that all parameters get unique names."""
|
||||
condition = "AND devName = 'A' AND devName = 'B' AND devName = 'C'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# All parameter names should be unique
|
||||
param_names = list(params.keys())
|
||||
assert len(param_names) == len(set(param_names))
|
||||
|
||||
|
||||
def test_parameter_values_match_condition(builder):
|
||||
"""Test that parameter values correctly match the condition values."""
|
||||
condition = "AND devLastIP NOT LIKE '192.168.1.%' AND devLastIP NOT LIKE '10.0.0.%'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Should have exactly the values from the condition
|
||||
param_values = sorted(params.values())
|
||||
expected_values = sorted(['192.168.1.%', '10.0.0.%'])
|
||||
assert param_values == expected_values
|
||||
|
||||
|
||||
def test_parameters_referenced_in_sql(builder):
|
||||
"""Test that all parameters are actually referenced in the SQL."""
|
||||
condition = "AND devName = 'Device1' AND devVendor = 'Apple'"
|
||||
|
||||
sql, params = builder.build_safe_condition(condition)
|
||||
|
||||
# Every parameter should appear in the SQL
|
||||
for param_name in params.keys():
|
||||
assert f':{param_name}' in sql
|
||||
331
test/backend/test_safe_builder_unit.py
Normal file
331
test/backend/test_safe_builder_unit.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Unit tests for SafeConditionBuilder focusing on core security functionality.
|
||||
This test file has minimal dependencies to ensure it can run in any environment.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import re
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
# Mock the logger module to avoid dependency issues
|
||||
sys.modules['logger'] = Mock()
|
||||
|
||||
# Standalone version of SafeConditionBuilder for testing
|
||||
class TestSafeConditionBuilder:
|
||||
"""
|
||||
Test version of SafeConditionBuilder with mock logger.
|
||||
"""
|
||||
|
||||
# Whitelist of allowed column names for filtering
|
||||
ALLOWED_COLUMNS = {
|
||||
'eve_MAC', 'eve_DateTime', 'eve_IP', 'eve_EventType', 'devName',
|
||||
'devComments', 'devLastIP', 'devVendor', 'devAlertEvents',
|
||||
'devAlertDown', 'devIsArchived', 'devPresentLastScan', 'devFavorite',
|
||||
'devIsNew', 'Plugin', 'Object_PrimaryId', 'Object_SecondaryId',
|
||||
'DateTimeChanged', 'Watched_Value1', 'Watched_Value2', 'Watched_Value3',
|
||||
'Watched_Value4', 'Status'
|
||||
}
|
||||
|
||||
# Whitelist of allowed comparison operators
|
||||
ALLOWED_OPERATORS = {
|
||||
'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'NOT LIKE',
|
||||
'IN', 'NOT IN', 'IS NULL', 'IS NOT NULL'
|
||||
}
|
||||
|
||||
# Whitelist of allowed logical operators
|
||||
ALLOWED_LOGICAL_OPERATORS = {'AND', 'OR'}
|
||||
|
||||
# Whitelist of allowed event types
|
||||
ALLOWED_EVENT_TYPES = {
|
||||
'New Device', 'Connected', 'Disconnected', 'Device Down',
|
||||
'Down Reconnected', 'IP Changed'
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the SafeConditionBuilder."""
|
||||
self.parameters = {}
|
||||
self.param_counter = 0
|
||||
|
||||
def _generate_param_name(self, prefix='param'):
|
||||
"""Generate a unique parameter name for SQL binding."""
|
||||
self.param_counter += 1
|
||||
return f"{prefix}_{self.param_counter}"
|
||||
|
||||
def _sanitize_string(self, value):
|
||||
"""Sanitize string input by removing potentially dangerous characters."""
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
|
||||
# Replace {s-quote} placeholder with single quote (maintaining compatibility)
|
||||
value = value.replace('{s-quote}', "'")
|
||||
|
||||
# Remove any null bytes, control characters, and excessive whitespace
|
||||
value = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x84\x86-\x9f]', '', value)
|
||||
value = re.sub(r'\s+', ' ', value.strip())
|
||||
|
||||
return value
|
||||
|
||||
def _validate_column_name(self, column):
|
||||
"""Validate that a column name is in the whitelist."""
|
||||
return column in self.ALLOWED_COLUMNS
|
||||
|
||||
def _validate_operator(self, operator):
|
||||
"""Validate that an operator is in the whitelist."""
|
||||
return operator.upper() in self.ALLOWED_OPERATORS
|
||||
|
||||
def _validate_logical_operator(self, logical_op):
|
||||
"""Validate that a logical operator is in the whitelist."""
|
||||
return logical_op.upper() in self.ALLOWED_LOGICAL_OPERATORS
|
||||
|
||||
def build_safe_condition(self, condition_string):
|
||||
"""Parse and build a safe SQL condition from a user-provided string."""
|
||||
if not condition_string or not condition_string.strip():
|
||||
return "", {}
|
||||
|
||||
# Sanitize the input
|
||||
condition_string = self._sanitize_string(condition_string)
|
||||
|
||||
# Reset parameters for this condition
|
||||
self.parameters = {}
|
||||
self.param_counter = 0
|
||||
|
||||
try:
|
||||
return self._parse_condition(condition_string)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid condition format: {condition_string}")
|
||||
|
||||
def _parse_condition(self, condition):
|
||||
"""Parse a condition string into safe SQL with parameters."""
|
||||
condition = condition.strip()
|
||||
|
||||
# Handle empty conditions
|
||||
if not condition:
|
||||
return "", {}
|
||||
|
||||
# Simple pattern matching for common conditions
|
||||
# Pattern 1: AND/OR column operator value
|
||||
pattern1 = r'^\s*(AND|OR)?\s+(\w+)\s+(=|!=|<>|<|>|<=|>=|LIKE|NOT\s+LIKE)\s+\'([^\']*)\'\s*$'
|
||||
match1 = re.match(pattern1, condition, re.IGNORECASE)
|
||||
|
||||
if match1:
|
||||
logical_op, column, operator, value = match1.groups()
|
||||
return self._build_simple_condition(logical_op, column, operator, value)
|
||||
|
||||
# If no patterns match, reject the condition for security
|
||||
raise ValueError(f"Unsupported condition pattern: {condition}")
|
||||
|
||||
def _build_simple_condition(self, logical_op, column, operator, value):
|
||||
"""Build a simple condition with parameter binding."""
|
||||
# Validate components
|
||||
if not self._validate_column_name(column):
|
||||
raise ValueError(f"Invalid column name: {column}")
|
||||
|
||||
if not self._validate_operator(operator):
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
if logical_op and not self._validate_logical_operator(logical_op):
|
||||
raise ValueError(f"Invalid logical operator: {logical_op}")
|
||||
|
||||
# Generate parameter name and store value
|
||||
param_name = self._generate_param_name()
|
||||
self.parameters[param_name] = value
|
||||
|
||||
# Build the SQL snippet
|
||||
sql_parts = []
|
||||
if logical_op:
|
||||
sql_parts.append(logical_op.upper())
|
||||
|
||||
sql_parts.extend([column, operator.upper(), f":{param_name}"])
|
||||
|
||||
return " ".join(sql_parts), self.parameters
|
||||
|
||||
def get_safe_condition_legacy(self, condition_setting):
|
||||
"""Convert legacy condition settings to safe parameterized queries."""
|
||||
if not condition_setting or not condition_setting.strip():
|
||||
return "", {}
|
||||
|
||||
try:
|
||||
return self.build_safe_condition(condition_setting)
|
||||
except ValueError:
|
||||
# Log the error and return empty condition for safety
|
||||
return "", {}
|
||||
|
||||
|
||||
class TestSafeConditionBuilderSecurity(unittest.TestCase):
|
||||
"""Test cases for the SafeConditionBuilder security functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures before each test method."""
|
||||
self.builder = TestSafeConditionBuilder()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test that SafeConditionBuilder initializes correctly."""
|
||||
self.assertIsInstance(self.builder, TestSafeConditionBuilder)
|
||||
self.assertEqual(self.builder.param_counter, 0)
|
||||
self.assertEqual(self.builder.parameters, {})
|
||||
|
||||
def test_sanitize_string(self):
|
||||
"""Test string sanitization functionality."""
|
||||
# Test normal string
|
||||
result = self.builder._sanitize_string("normal string")
|
||||
self.assertEqual(result, "normal string")
|
||||
|
||||
# Test s-quote replacement
|
||||
result = self.builder._sanitize_string("test{s-quote}value")
|
||||
self.assertEqual(result, "test'value")
|
||||
|
||||
# Test control character removal
|
||||
result = self.builder._sanitize_string("test\x00\x01string")
|
||||
self.assertEqual(result, "teststring")
|
||||
|
||||
# Test excessive whitespace
|
||||
result = self.builder._sanitize_string(" test string ")
|
||||
self.assertEqual(result, "test string")
|
||||
|
||||
def test_validate_column_name(self):
|
||||
"""Test column name validation against whitelist."""
|
||||
# Valid columns
|
||||
self.assertTrue(self.builder._validate_column_name('eve_MAC'))
|
||||
self.assertTrue(self.builder._validate_column_name('devName'))
|
||||
self.assertTrue(self.builder._validate_column_name('eve_EventType'))
|
||||
|
||||
# Invalid columns
|
||||
self.assertFalse(self.builder._validate_column_name('malicious_column'))
|
||||
self.assertFalse(self.builder._validate_column_name('drop_table'))
|
||||
self.assertFalse(self.builder._validate_column_name('user_input'))
|
||||
|
||||
def test_validate_operator(self):
|
||||
"""Test operator validation against whitelist."""
|
||||
# Valid operators
|
||||
self.assertTrue(self.builder._validate_operator('='))
|
||||
self.assertTrue(self.builder._validate_operator('LIKE'))
|
||||
self.assertTrue(self.builder._validate_operator('IN'))
|
||||
|
||||
# Invalid operators
|
||||
self.assertFalse(self.builder._validate_operator('UNION'))
|
||||
self.assertFalse(self.builder._validate_operator('DROP'))
|
||||
self.assertFalse(self.builder._validate_operator('EXEC'))
|
||||
|
||||
def test_build_simple_condition_valid(self):
|
||||
"""Test building valid simple conditions."""
|
||||
sql, params = self.builder._build_simple_condition('AND', 'devName', '=', 'TestDevice')
|
||||
|
||||
self.assertIn('AND devName = :param_', sql)
|
||||
self.assertEqual(len(params), 1)
|
||||
self.assertIn('TestDevice', params.values())
|
||||
|
||||
def test_build_simple_condition_invalid_column(self):
|
||||
"""Test that invalid column names are rejected."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
|
||||
|
||||
self.assertIn('Invalid column name', str(context.exception))
|
||||
|
||||
def test_build_simple_condition_invalid_operator(self):
|
||||
"""Test that invalid operators are rejected."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
|
||||
|
||||
self.assertIn('Invalid operator', str(context.exception))
|
||||
|
||||
def test_sql_injection_attempts(self):
|
||||
"""Test that various SQL injection attempts are blocked."""
|
||||
injection_attempts = [
|
||||
"'; DROP TABLE Devices; --",
|
||||
"' UNION SELECT * FROM Settings --",
|
||||
"' OR 1=1 --",
|
||||
"'; INSERT INTO Events VALUES(1,2,3); --",
|
||||
"' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
|
||||
]
|
||||
|
||||
for injection in injection_attempts:
|
||||
with self.subTest(injection=injection):
|
||||
with self.assertRaises(ValueError):
|
||||
self.builder.build_safe_condition(f"AND devName = '{injection}'")
|
||||
|
||||
def test_legacy_condition_compatibility(self):
|
||||
"""Test backward compatibility with legacy condition formats."""
|
||||
# Test simple condition
|
||||
sql, params = self.builder.get_safe_condition_legacy("AND devName = 'TestDevice'")
|
||||
self.assertIn('devName', sql)
|
||||
self.assertIn('TestDevice', params.values())
|
||||
|
||||
# Test empty condition
|
||||
sql, params = self.builder.get_safe_condition_legacy("")
|
||||
self.assertEqual(sql, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
# Test invalid condition returns empty
|
||||
sql, params = self.builder.get_safe_condition_legacy("INVALID SQL INJECTION")
|
||||
self.assertEqual(sql, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_parameter_generation(self):
|
||||
"""Test that parameters are generated correctly."""
|
||||
# Test multiple parameters
|
||||
sql1, params1 = self.builder.build_safe_condition("AND devName = 'Device1'")
|
||||
sql2, params2 = self.builder.build_safe_condition("AND devName = 'Device2'")
|
||||
|
||||
# Each should have unique parameter names
|
||||
self.assertNotEqual(list(params1.keys())[0], list(params2.keys())[0])
|
||||
|
||||
def test_xss_prevention(self):
|
||||
"""Test that XSS-like payloads in device names are handled safely."""
|
||||
xss_payloads = [
|
||||
"<script>alert('xss')</script>",
|
||||
"javascript:alert(1)",
|
||||
"<img src=x onerror=alert(1)>",
|
||||
"'; DROP TABLE users; SELECT '<script>alert(1)</script>' --"
|
||||
]
|
||||
|
||||
for payload in xss_payloads:
|
||||
with self.subTest(payload=payload):
|
||||
# Should either process safely or reject
|
||||
try:
|
||||
sql, params = self.builder.build_safe_condition(f"AND devName = '{payload}'")
|
||||
# If processed, should be parameterized
|
||||
self.assertIn(':', sql)
|
||||
self.assertIn(payload, params.values())
|
||||
except ValueError:
|
||||
# Rejection is also acceptable for safety
|
||||
pass
|
||||
|
||||
def test_unicode_handling(self):
|
||||
"""Test that Unicode characters are handled properly."""
|
||||
unicode_strings = [
|
||||
"Ülrich's Device",
|
||||
"Café Network",
|
||||
"测试设备",
|
||||
"Устройство"
|
||||
]
|
||||
|
||||
for unicode_str in unicode_strings:
|
||||
with self.subTest(unicode_str=unicode_str):
|
||||
sql, params = self.builder.build_safe_condition(f"AND devName = '{unicode_str}'")
|
||||
self.assertIn(unicode_str, params.values())
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases and boundary conditions."""
|
||||
edge_cases = [
|
||||
"", # Empty string
|
||||
" ", # Whitespace only
|
||||
"AND devName = ''", # Empty value
|
||||
"AND devName = 'a'", # Single character
|
||||
"AND devName = '" + "x" * 1000 + "'", # Very long string
|
||||
]
|
||||
|
||||
for case in edge_cases:
|
||||
with self.subTest(case=case):
|
||||
try:
|
||||
sql, params = self.builder.get_safe_condition_legacy(case)
|
||||
# Should either return valid result or empty safe result
|
||||
self.assertIsInstance(sql, str)
|
||||
self.assertIsInstance(params, dict)
|
||||
except Exception:
|
||||
self.fail(f"Unexpected exception for edge case: {case}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Run the test suite
|
||||
unittest.main(verbosity=2)
|
||||
221
test/backend/test_sql_injection_prevention.py
Normal file
221
test/backend/test_sql_injection_prevention.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive SQL Injection Prevention Tests for NetAlertX
|
||||
|
||||
This test suite validates that all SQL injection vulnerabilities have been
|
||||
properly addressed in the reporting.py module.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'server'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'server', 'db'))
|
||||
|
||||
# Now import our module
|
||||
from sql_safe_builder import SafeConditionBuilder
|
||||
|
||||
|
||||
class TestSQLInjectionPrevention(unittest.TestCase):
|
||||
"""Test suite for SQL injection prevention."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.builder = SafeConditionBuilder()
|
||||
|
||||
def test_sql_injection_attempt_single_quote(self):
|
||||
"""Test that single quote injection attempts are blocked."""
|
||||
malicious_input = "'; DROP TABLE users; --"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when invalid
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_sql_injection_attempt_union(self):
|
||||
"""Test that UNION injection attempts are blocked."""
|
||||
malicious_input = "1' UNION SELECT * FROM passwords --"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when invalid
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_sql_injection_attempt_or_true(self):
|
||||
"""Test that OR 1=1 injection attempts are blocked."""
|
||||
malicious_input = "' OR '1'='1"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when invalid
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_valid_simple_condition(self):
|
||||
"""Test that valid simple conditions are handled correctly."""
|
||||
valid_input = "AND devName = 'Test Device'"
|
||||
condition, params = self.builder.get_safe_condition_legacy(valid_input)
|
||||
|
||||
# Should create parameterized query
|
||||
self.assertIn("AND devName = :", condition)
|
||||
self.assertEqual(len(params), 1)
|
||||
self.assertIn('Test Device', list(params.values()))
|
||||
|
||||
def test_empty_condition(self):
|
||||
"""Test that empty conditions are handled safely."""
|
||||
empty_input = ""
|
||||
condition, params = self.builder.get_safe_condition_legacy(empty_input)
|
||||
|
||||
# Should return empty condition
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_whitespace_only_condition(self):
|
||||
"""Test that whitespace-only conditions are handled safely."""
|
||||
whitespace_input = " \n\t "
|
||||
condition, params = self.builder.get_safe_condition_legacy(whitespace_input)
|
||||
|
||||
# Should return empty condition
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_multiple_conditions_valid(self):
|
||||
"""Test that single valid conditions are handled correctly."""
|
||||
# Test with a single condition first (our current parser handles single conditions well)
|
||||
valid_input = "AND devName = 'Device1'"
|
||||
condition, params = self.builder.get_safe_condition_legacy(valid_input)
|
||||
|
||||
# Should create parameterized query
|
||||
self.assertIn("devName = :", condition)
|
||||
self.assertEqual(len(params), 1)
|
||||
self.assertIn('Device1', list(params.values()))
|
||||
|
||||
def test_disallowed_column_name(self):
|
||||
"""Test that non-whitelisted column names are rejected."""
|
||||
invalid_input = "AND malicious_column = 'value'"
|
||||
condition, params = self.builder.get_safe_condition_legacy(invalid_input)
|
||||
|
||||
# Should return empty condition when column not in whitelist
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_disallowed_operator(self):
|
||||
"""Test that non-whitelisted operators are rejected."""
|
||||
invalid_input = "AND devName SOUNDS LIKE 'test'"
|
||||
condition, params = self.builder.get_safe_condition_legacy(invalid_input)
|
||||
|
||||
# Should return empty condition when operator not allowed
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_nested_select_attempt(self):
|
||||
"""Test that nested SELECT attempts are blocked."""
|
||||
malicious_input = "AND devName IN (SELECT password FROM users)"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when nested SELECT detected
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_hex_encoding_attempt(self):
|
||||
"""Test that hex-encoded injection attempts are blocked."""
|
||||
malicious_input = "AND 0x44524f50205441424c45"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when hex encoding detected
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_comment_injection_attempt(self):
|
||||
"""Test that comment injection attempts are handled."""
|
||||
malicious_input = "AND devName = 'test' /* comment */ --"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Comments should be stripped and condition validated
|
||||
if condition:
|
||||
self.assertNotIn("/*", condition)
|
||||
self.assertNotIn("--", condition)
|
||||
|
||||
def test_special_placeholder_replacement(self):
|
||||
"""Test that {s-quote} placeholder is safely replaced."""
|
||||
input_with_placeholder = "AND devName = {s-quote}Test{s-quote}"
|
||||
condition, params = self.builder.get_safe_condition_legacy(input_with_placeholder)
|
||||
|
||||
# Should handle placeholder safely
|
||||
if condition:
|
||||
self.assertNotIn("{s-quote}", condition)
|
||||
self.assertIn("devName = :", condition)
|
||||
|
||||
def test_null_byte_injection(self):
|
||||
"""Test that null byte injection attempts are blocked."""
|
||||
malicious_input = "AND devName = 'test\x00' DROP TABLE --"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Null bytes should be sanitized
|
||||
if condition:
|
||||
self.assertNotIn("\x00", condition)
|
||||
for value in params.values():
|
||||
self.assertNotIn("\x00", str(value))
|
||||
|
||||
def test_build_condition_with_allowed_values(self):
|
||||
"""Test building condition with specific allowed values."""
|
||||
conditions = [
|
||||
{"column": "eve_EventType", "operator": "=", "value": "Connected"},
|
||||
{"column": "devName", "operator": "LIKE", "value": "%test%"}
|
||||
]
|
||||
condition, params = self.builder.build_condition(conditions, "AND")
|
||||
|
||||
# Should create valid parameterized condition
|
||||
self.assertIn("eve_EventType = :", condition)
|
||||
self.assertIn("devName LIKE :", condition)
|
||||
self.assertEqual(len(params), 2)
|
||||
|
||||
def test_build_condition_with_invalid_column(self):
|
||||
"""Test that invalid columns in build_condition are rejected."""
|
||||
conditions = [
|
||||
{"column": "invalid_column", "operator": "=", "value": "test"}
|
||||
]
|
||||
condition, params = self.builder.build_condition(conditions)
|
||||
|
||||
# Should return empty when invalid column
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_case_variations_injection(self):
|
||||
"""Test that case variation injection attempts are blocked."""
|
||||
malicious_inputs = [
|
||||
"AnD 1=1",
|
||||
"oR 1=1",
|
||||
"UnIoN SeLeCt * FrOm users"
|
||||
]
|
||||
|
||||
for malicious_input in malicious_inputs:
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
# Should handle case variations safely
|
||||
if "union" in condition.lower() or "select" in condition.lower():
|
||||
self.fail(f"Injection not blocked: {malicious_input}")
|
||||
|
||||
def test_time_based_injection_attempt(self):
|
||||
"""Test that time-based injection attempts are blocked."""
|
||||
malicious_input = "AND IF(1=1, SLEEP(5), 0)"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when SQL functions detected
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_stacked_queries_attempt(self):
|
||||
"""Test that stacked query attempts are blocked."""
|
||||
malicious_input = "'; INSERT INTO admin VALUES ('hacker', 'password'); --"
|
||||
condition, params = self.builder.get_safe_condition_legacy(malicious_input)
|
||||
|
||||
# Should return empty condition when semicolon detected
|
||||
self.assertEqual(condition, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Run the tests
|
||||
unittest.main(verbosity=2)
|
||||
385
test/backend/test_sql_security.py
Normal file
385
test/backend/test_sql_security.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
NetAlertX SQL Security Test Suite
|
||||
|
||||
This test suite validates the SQL injection prevention mechanisms
|
||||
implemented in the SafeConditionBuilder and reporting modules.
|
||||
|
||||
Author: Security Enhancement for NetAlertX
|
||||
License: GNU GPLv3
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add the server directory to the path for imports
|
||||
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
|
||||
sys.path.extend([f"{INSTALL_PATH}/server"])
|
||||
sys.path.append('/home/dell/coding/bash/10x-agentic-setup/netalertx-sql-fix/server')
|
||||
|
||||
from db.sql_safe_builder import SafeConditionBuilder, create_safe_condition_builder
|
||||
from database import DB
|
||||
from messaging.reporting import get_notifications
|
||||
|
||||
|
||||
class TestSafeConditionBuilder(unittest.TestCase):
|
||||
"""Test cases for the SafeConditionBuilder class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures before each test method."""
|
||||
self.builder = SafeConditionBuilder()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test that SafeConditionBuilder initializes correctly."""
|
||||
self.assertIsInstance(self.builder, SafeConditionBuilder)
|
||||
self.assertEqual(self.builder.param_counter, 0)
|
||||
self.assertEqual(self.builder.parameters, {})
|
||||
|
||||
def test_sanitize_string(self):
|
||||
"""Test string sanitization functionality."""
|
||||
# Test normal string
|
||||
result = self.builder._sanitize_string("normal string")
|
||||
self.assertEqual(result, "normal string")
|
||||
|
||||
# Test s-quote replacement
|
||||
result = self.builder._sanitize_string("test{s-quote}value")
|
||||
self.assertEqual(result, "test'value")
|
||||
|
||||
# Test control character removal
|
||||
result = self.builder._sanitize_string("test\x00\x01string")
|
||||
self.assertEqual(result, "teststring")
|
||||
|
||||
# Test excessive whitespace
|
||||
result = self.builder._sanitize_string(" test string ")
|
||||
self.assertEqual(result, "test string")
|
||||
|
||||
def test_validate_column_name(self):
|
||||
"""Test column name validation against whitelist."""
|
||||
# Valid columns
|
||||
self.assertTrue(self.builder._validate_column_name('eve_MAC'))
|
||||
self.assertTrue(self.builder._validate_column_name('devName'))
|
||||
self.assertTrue(self.builder._validate_column_name('eve_EventType'))
|
||||
|
||||
# Invalid columns
|
||||
self.assertFalse(self.builder._validate_column_name('malicious_column'))
|
||||
self.assertFalse(self.builder._validate_column_name('drop_table'))
|
||||
self.assertFalse(self.builder._validate_column_name('\'; DROP TABLE users; --'))
|
||||
|
||||
def test_validate_operator(self):
|
||||
"""Test operator validation against whitelist."""
|
||||
# Valid operators
|
||||
self.assertTrue(self.builder._validate_operator('='))
|
||||
self.assertTrue(self.builder._validate_operator('LIKE'))
|
||||
self.assertTrue(self.builder._validate_operator('IN'))
|
||||
|
||||
# Invalid operators
|
||||
self.assertFalse(self.builder._validate_operator('UNION'))
|
||||
self.assertFalse(self.builder._validate_operator('; DROP'))
|
||||
self.assertFalse(self.builder._validate_operator('EXEC'))
|
||||
|
||||
def test_build_simple_condition_valid(self):
|
||||
"""Test building valid simple conditions."""
|
||||
sql, params = self.builder._build_simple_condition('AND', 'devName', '=', 'TestDevice')
|
||||
|
||||
self.assertIn('AND devName = :param_', sql)
|
||||
self.assertEqual(len(params), 1)
|
||||
self.assertIn('TestDevice', params.values())
|
||||
|
||||
def test_build_simple_condition_invalid_column(self):
|
||||
"""Test that invalid column names are rejected."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
|
||||
|
||||
self.assertIn('Invalid column name', str(context.exception))
|
||||
|
||||
def test_build_simple_condition_invalid_operator(self):
|
||||
"""Test that invalid operators are rejected."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
|
||||
|
||||
self.assertIn('Invalid operator', str(context.exception))
|
||||
|
||||
def test_build_in_condition_valid(self):
|
||||
"""Test building valid IN conditions."""
|
||||
sql, params = self.builder._build_in_condition('AND', 'eve_EventType', 'IN', "'Connected', 'Disconnected'")
|
||||
|
||||
self.assertIn('AND eve_EventType IN', sql)
|
||||
self.assertEqual(len(params), 2)
|
||||
self.assertIn('Connected', params.values())
|
||||
self.assertIn('Disconnected', params.values())
|
||||
|
||||
def test_build_null_condition(self):
|
||||
"""Test building NULL check conditions."""
|
||||
sql, params = self.builder._build_null_condition('AND', 'devComments', 'IS NULL')
|
||||
|
||||
self.assertEqual(sql, 'AND devComments IS NULL')
|
||||
self.assertEqual(len(params), 0)
|
||||
|
||||
def test_sql_injection_attempts(self):
|
||||
"""Test that various SQL injection attempts are blocked."""
|
||||
injection_attempts = [
|
||||
"'; DROP TABLE Devices; --",
|
||||
"' UNION SELECT * FROM Settings --",
|
||||
"' OR 1=1 --",
|
||||
"'; INSERT INTO Events VALUES(1,2,3); --",
|
||||
"' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
|
||||
"'; ATTACH DATABASE '/etc/passwd' AS pwn; --"
|
||||
]
|
||||
|
||||
for injection in injection_attempts:
|
||||
with self.subTest(injection=injection):
|
||||
with self.assertRaises(ValueError):
|
||||
self.builder.build_safe_condition(f"AND devName = '{injection}'")
|
||||
|
||||
def test_legacy_condition_compatibility(self):
|
||||
"""Test backward compatibility with legacy condition formats."""
|
||||
# Test simple condition
|
||||
sql, params = self.builder.get_safe_condition_legacy("AND devName = 'TestDevice'")
|
||||
self.assertIn('devName', sql)
|
||||
self.assertIn('TestDevice', params.values())
|
||||
|
||||
# Test empty condition
|
||||
sql, params = self.builder.get_safe_condition_legacy("")
|
||||
self.assertEqual(sql, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
# Test invalid condition returns empty
|
||||
sql, params = self.builder.get_safe_condition_legacy("INVALID SQL INJECTION")
|
||||
self.assertEqual(sql, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
def test_device_name_filter(self):
|
||||
"""Test the device name filter helper method."""
|
||||
sql, params = self.builder.build_device_name_filter("TestDevice")
|
||||
|
||||
self.assertIn('AND devName = :device_name_', sql)
|
||||
self.assertIn('TestDevice', params.values())
|
||||
|
||||
def test_event_type_filter(self):
|
||||
"""Test the event type filter helper method."""
|
||||
event_types = ['Connected', 'Disconnected']
|
||||
sql, params = self.builder.build_event_type_filter(event_types)
|
||||
|
||||
self.assertIn('AND eve_EventType IN', sql)
|
||||
self.assertEqual(len(params), 2)
|
||||
self.assertIn('Connected', params.values())
|
||||
self.assertIn('Disconnected', params.values())
|
||||
|
||||
def test_event_type_filter_whitelist(self):
|
||||
"""Test that event type filter enforces whitelist."""
|
||||
# Valid event types
|
||||
valid_types = ['Connected', 'New Device']
|
||||
sql, params = self.builder.build_event_type_filter(valid_types)
|
||||
self.assertEqual(len(params), 2)
|
||||
|
||||
# Mix of valid and invalid event types
|
||||
mixed_types = ['Connected', 'InvalidEventType', 'Device Down']
|
||||
sql, params = self.builder.build_event_type_filter(mixed_types)
|
||||
self.assertEqual(len(params), 2) # Only valid types should be included
|
||||
|
||||
# All invalid event types
|
||||
invalid_types = ['InvalidType1', 'InvalidType2']
|
||||
sql, params = self.builder.build_event_type_filter(invalid_types)
|
||||
self.assertEqual(sql, "")
|
||||
self.assertEqual(params, {})
|
||||
|
||||
|
||||
class TestDatabaseParameterSupport(unittest.TestCase):
|
||||
"""Test that database layer supports parameterized queries."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test database."""
|
||||
self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
|
||||
self.temp_db.close()
|
||||
|
||||
# Create test database
|
||||
self.conn = sqlite3.connect(self.temp_db.name)
|
||||
self.conn.execute('''CREATE TABLE test_table (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT,
|
||||
value TEXT
|
||||
)''')
|
||||
self.conn.execute("INSERT INTO test_table (name, value) VALUES ('test1', 'value1')")
|
||||
self.conn.execute("INSERT INTO test_table (name, value) VALUES ('test2', 'value2')")
|
||||
self.conn.commit()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test database."""
|
||||
self.conn.close()
|
||||
os.unlink(self.temp_db.name)
|
||||
|
||||
def test_parameterized_query_execution(self):
|
||||
"""Test that parameterized queries work correctly."""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
# Test named parameters
|
||||
cursor.execute("SELECT * FROM test_table WHERE name = :name", {'name': 'test1'})
|
||||
results = cursor.fetchall()
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0][1], 'test1')
|
||||
|
||||
def test_parameterized_query_prevents_injection(self):
|
||||
"""Test that parameterized queries prevent SQL injection."""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
# This should not cause SQL injection
|
||||
malicious_input = "'; DROP TABLE test_table; --"
|
||||
cursor.execute("SELECT * FROM test_table WHERE name = :name", {'name': malicious_input})
|
||||
results = cursor.fetchall()
|
||||
|
||||
# The table should still exist and be queryable
|
||||
cursor.execute("SELECT COUNT(*) FROM test_table")
|
||||
count = cursor.fetchone()[0]
|
||||
self.assertEqual(count, 2) # Original data should still be there
|
||||
|
||||
|
||||
class TestReportingSecurityIntegration(unittest.TestCase):
|
||||
"""Integration tests for the secure reporting functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment for reporting tests."""
|
||||
self.mock_db = Mock()
|
||||
self.mock_db.sql = Mock()
|
||||
self.mock_db.get_table_as_json = Mock()
|
||||
|
||||
# Mock successful JSON response
|
||||
mock_json_obj = Mock()
|
||||
mock_json_obj.columnNames = ['MAC', 'Datetime', 'IP', 'Event Type', 'Device name', 'Comments']
|
||||
mock_json_obj.json = {'data': []}
|
||||
self.mock_db.get_table_as_json.return_value = mock_json_obj
|
||||
|
||||
@patch('messaging.reporting.get_setting_value')
|
||||
def test_new_devices_section_security(self, mock_get_setting):
|
||||
"""Test that new devices section uses safe SQL building."""
|
||||
# Mock settings
|
||||
mock_get_setting.side_effect = lambda key: {
|
||||
'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
|
||||
'NTFPRCS_new_dev_condition': "AND devName = 'TestDevice'"
|
||||
}.get(key, '')
|
||||
|
||||
# Call the function
|
||||
result = get_notifications(self.mock_db)
|
||||
|
||||
# Verify that get_table_as_json was called with parameters
|
||||
self.mock_db.get_table_as_json.assert_called()
|
||||
call_args = self.mock_db.get_table_as_json.call_args
|
||||
|
||||
# Should have been called with both query and parameters
|
||||
self.assertEqual(len(call_args[0]), 1) # Query argument
|
||||
self.assertEqual(len(call_args[1]), 1) # Parameters keyword argument
|
||||
|
||||
@patch('messaging.reporting.get_setting_value')
|
||||
def test_events_section_security(self, mock_get_setting):
|
||||
"""Test that events section uses safe SQL building."""
|
||||
# Mock settings
|
||||
mock_get_setting.side_effect = lambda key: {
|
||||
'NTFPRCS_INCLUDED_SECTIONS': ['events'],
|
||||
'NTFPRCS_event_condition': "AND devName = 'TestDevice'"
|
||||
}.get(key, '')
|
||||
|
||||
# Call the function
|
||||
result = get_notifications(self.mock_db)
|
||||
|
||||
# Verify that get_table_as_json was called with parameters
|
||||
self.mock_db.get_table_as_json.assert_called()
|
||||
|
||||
@patch('messaging.reporting.get_setting_value')
|
||||
def test_malicious_condition_handling(self, mock_get_setting):
|
||||
"""Test that malicious conditions are safely handled."""
|
||||
# Mock settings with malicious input
|
||||
mock_get_setting.side_effect = lambda key: {
|
||||
'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
|
||||
'NTFPRCS_new_dev_condition': "'; DROP TABLE Devices; --"
|
||||
}.get(key, '')
|
||||
|
||||
# Call the function - should not raise an exception
|
||||
result = get_notifications(self.mock_db)
|
||||
|
||||
# Should still call get_table_as_json (with safe fallback query)
|
||||
self.mock_db.get_table_as_json.assert_called()
|
||||
|
||||
@patch('messaging.reporting.get_setting_value')
|
||||
def test_empty_condition_handling(self, mock_get_setting):
|
||||
"""Test that empty conditions are handled gracefully."""
|
||||
# Mock settings with empty condition
|
||||
mock_get_setting.side_effect = lambda key: {
|
||||
'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
|
||||
'NTFPRCS_new_dev_condition': ""
|
||||
}.get(key, '')
|
||||
|
||||
# Call the function
|
||||
result = get_notifications(self.mock_db)
|
||||
|
||||
# Should call get_table_as_json
|
||||
self.mock_db.get_table_as_json.assert_called()
|
||||
|
||||
|
||||
class TestSecurityBenchmarks(unittest.TestCase):
|
||||
"""Performance and security benchmark tests."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up benchmark environment."""
|
||||
self.builder = SafeConditionBuilder()
|
||||
|
||||
def test_performance_simple_condition(self):
|
||||
"""Test performance of simple condition building."""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
for _ in range(1000):
|
||||
sql, params = self.builder.build_safe_condition("AND devName = 'TestDevice'")
|
||||
end_time = time.time()
|
||||
|
||||
execution_time = end_time - start_time
|
||||
self.assertLess(execution_time, 1.0, "Simple condition building should be fast")
|
||||
|
||||
def test_memory_usage_parameter_generation(self):
|
||||
"""Test memory usage of parameter generation."""
|
||||
try:
|
||||
import psutil
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
self.skipTest("psutil not available")
|
||||
return
|
||||
import os
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
initial_memory = process.memory_info().rss
|
||||
|
||||
# Generate many conditions
|
||||
for i in range(100):
|
||||
builder = SafeConditionBuilder()
|
||||
sql, params = builder.build_safe_condition(f"AND devName = 'Device{i}'")
|
||||
|
||||
final_memory = process.memory_info().rss
|
||||
memory_increase = final_memory - initial_memory
|
||||
|
||||
# Memory increase should be reasonable (less than 10MB)
|
||||
self.assertLess(memory_increase, 10 * 1024 * 1024, "Memory usage should be reasonable")
|
||||
|
||||
def test_pattern_coverage(self):
|
||||
"""Test coverage of condition patterns."""
|
||||
patterns_tested = [
|
||||
"AND devName = 'value'",
|
||||
"OR eve_EventType LIKE '%test%'",
|
||||
"AND devComments IS NULL",
|
||||
"AND eve_EventType IN ('Connected', 'Disconnected')",
|
||||
]
|
||||
|
||||
for pattern in patterns_tested:
|
||||
with self.subTest(pattern=pattern):
|
||||
try:
|
||||
sql, params = self.builder.build_safe_condition(pattern)
|
||||
self.assertIsInstance(sql, str)
|
||||
self.assertIsInstance(params, dict)
|
||||
except ValueError:
|
||||
# Some patterns might be rejected, which is acceptable
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Run the test suite
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user