Restore test_safe_builder_unit.py to upstream version (remove local changes)

This commit is contained in:
Adam Outler
2025-10-24 20:32:50 +00:00
parent 7f74c2d6f3
commit 32f9111f66

View File

@@ -4,15 +4,15 @@ This test file has minimal dependencies to ensure it can run in any environment.
""" """
import sys import sys
import unittest
import re import re
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
# Mock the logger module to avoid dependency issues # Mock the logger module to avoid dependency issues
sys.modules['logger'] = Mock() sys.modules['logger'] = Mock()
# Standalone version of SafeConditionBuilder for testing # Standalone version of SafeConditionBuilder for testing
class SafeConditionBuilder: class TestSafeConditionBuilder:
""" """
Test version of SafeConditionBuilder with mock logger. Test version of SafeConditionBuilder with mock logger.
""" """
@@ -152,182 +152,180 @@ class SafeConditionBuilder:
return "", {} return "", {}
@pytest.fixture class TestSafeConditionBuilderSecurity(unittest.TestCase):
def builder(): """Test cases for the SafeConditionBuilder security functionality."""
"""Fixture to provide a fresh SafeConditionBuilder instance for each test."""
return SafeConditionBuilder() def setUp(self):
"""Set up test fixtures before each test method."""
self.builder = TestSafeConditionBuilder()
def test_initialization(self):
"""Test that SafeConditionBuilder initializes correctly."""
self.assertIsInstance(self.builder, TestSafeConditionBuilder)
self.assertEqual(self.builder.param_counter, 0)
self.assertEqual(self.builder.parameters, {})
def test_sanitize_string(self):
"""Test string sanitization functionality."""
# Test normal string
result = self.builder._sanitize_string("normal string")
self.assertEqual(result, "normal string")
# Test s-quote replacement
result = self.builder._sanitize_string("test{s-quote}value")
self.assertEqual(result, "test'value")
# Test control character removal
result = self.builder._sanitize_string("test\x00\x01string")
self.assertEqual(result, "teststring")
# Test excessive whitespace
result = self.builder._sanitize_string(" test string ")
self.assertEqual(result, "test string")
def test_validate_column_name(self):
"""Test column name validation against whitelist."""
# Valid columns
self.assertTrue(self.builder._validate_column_name('eve_MAC'))
self.assertTrue(self.builder._validate_column_name('devName'))
self.assertTrue(self.builder._validate_column_name('eve_EventType'))
# Invalid columns
self.assertFalse(self.builder._validate_column_name('malicious_column'))
self.assertFalse(self.builder._validate_column_name('drop_table'))
self.assertFalse(self.builder._validate_column_name('user_input'))
def test_validate_operator(self):
"""Test operator validation against whitelist."""
# Valid operators
self.assertTrue(self.builder._validate_operator('='))
self.assertTrue(self.builder._validate_operator('LIKE'))
self.assertTrue(self.builder._validate_operator('IN'))
# Invalid operators
self.assertFalse(self.builder._validate_operator('UNION'))
self.assertFalse(self.builder._validate_operator('DROP'))
self.assertFalse(self.builder._validate_operator('EXEC'))
def test_build_simple_condition_valid(self):
"""Test building valid simple conditions."""
sql, params = self.builder._build_simple_condition('AND', 'devName', '=', 'TestDevice')
self.assertIn('AND devName = :param_', sql)
self.assertEqual(len(params), 1)
self.assertIn('TestDevice', params.values())
def test_build_simple_condition_invalid_column(self):
"""Test that invalid column names are rejected."""
with self.assertRaises(ValueError) as context:
self.builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
self.assertIn('Invalid column name', str(context.exception))
def test_build_simple_condition_invalid_operator(self):
"""Test that invalid operators are rejected."""
with self.assertRaises(ValueError) as context:
self.builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
self.assertIn('Invalid operator', str(context.exception))
def test_sql_injection_attempts(self):
"""Test that various SQL injection attempts are blocked."""
injection_attempts = [
"'; DROP TABLE Devices; --",
"' UNION SELECT * FROM Settings --",
"' OR 1=1 --",
"'; INSERT INTO Events VALUES(1,2,3); --",
"' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
]
for injection in injection_attempts:
with self.subTest(injection=injection):
with self.assertRaises(ValueError):
self.builder.build_safe_condition(f"AND devName = '{injection}'")
def test_legacy_condition_compatibility(self):
"""Test backward compatibility with legacy condition formats."""
# Test simple condition
sql, params = self.builder.get_safe_condition_legacy("AND devName = 'TestDevice'")
self.assertIn('devName', sql)
self.assertIn('TestDevice', params.values())
# Test empty condition
sql, params = self.builder.get_safe_condition_legacy("")
self.assertEqual(sql, "")
self.assertEqual(params, {})
# Test invalid condition returns empty
sql, params = self.builder.get_safe_condition_legacy("INVALID SQL INJECTION")
self.assertEqual(sql, "")
self.assertEqual(params, {})
def test_parameter_generation(self):
"""Test that parameters are generated correctly."""
# Test multiple parameters
sql1, params1 = self.builder.build_safe_condition("AND devName = 'Device1'")
sql2, params2 = self.builder.build_safe_condition("AND devName = 'Device2'")
# Each should have unique parameter names
self.assertNotEqual(list(params1.keys())[0], list(params2.keys())[0])
def test_xss_prevention(self):
"""Test that XSS-like payloads in device names are handled safely."""
xss_payloads = [
"<script>alert('xss')</script>",
"javascript:alert(1)",
"<img src=x onerror=alert(1)>",
"'; DROP TABLE users; SELECT '<script>alert(1)</script>' --"
]
for payload in xss_payloads:
with self.subTest(payload=payload):
# Should either process safely or reject
try:
sql, params = self.builder.build_safe_condition(f"AND devName = '{payload}'")
# If processed, should be parameterized
self.assertIn(':', sql)
self.assertIn(payload, params.values())
except ValueError:
# Rejection is also acceptable for safety
pass
def test_unicode_handling(self):
"""Test that Unicode characters are handled properly."""
unicode_strings = [
"Ülrich's Device",
"Café Network",
"测试设备",
"Устройство"
]
for unicode_str in unicode_strings:
with self.subTest(unicode_str=unicode_str):
sql, params = self.builder.build_safe_condition(f"AND devName = '{unicode_str}'")
self.assertIn(unicode_str, params.values())
def test_edge_cases(self):
"""Test edge cases and boundary conditions."""
edge_cases = [
"", # Empty string
" ", # Whitespace only
"AND devName = ''", # Empty value
"AND devName = 'a'", # Single character
"AND devName = '" + "x" * 1000 + "'", # Very long string
]
for case in edge_cases:
with self.subTest(case=case):
try:
sql, params = self.builder.get_safe_condition_legacy(case)
# Should either return valid result or empty safe result
self.assertIsInstance(sql, str)
self.assertIsInstance(params, dict)
except Exception:
self.fail(f"Unexpected exception for edge case: {case}")
def test_initialization(builder): if __name__ == '__main__':
"""Test that SafeConditionBuilder initializes correctly.""" # Run the test suite
assert isinstance(builder, SafeConditionBuilder) unittest.main(verbosity=2)
assert builder.param_counter == 0
assert builder.parameters == {}
def test_sanitize_string(builder):
"""Test string sanitization functionality."""
# Test normal string
result = builder._sanitize_string("normal string")
assert result == "normal string"
# Test s-quote replacement
result = builder._sanitize_string("test{s-quote}value")
assert result == "test'value"
# Test control character removal
result = builder._sanitize_string("test\x00\x01string")
assert result == "teststring"
# Test excessive whitespace
result = builder._sanitize_string(" test string ")
assert result == "test string"
def test_validate_column_name(builder):
"""Test column name validation against whitelist."""
# Valid columns
assert builder._validate_column_name('eve_MAC')
assert builder._validate_column_name('devName')
assert builder._validate_column_name('eve_EventType')
# Invalid columns
assert not builder._validate_column_name('malicious_column')
assert not builder._validate_column_name('drop_table')
assert not builder._validate_column_name('user_input')
def test_validate_operator(builder):
"""Test operator validation against whitelist."""
# Valid operators
assert builder._validate_operator('=')
assert builder._validate_operator('LIKE')
assert builder._validate_operator('IN')
# Invalid operators
assert not builder._validate_operator('UNION')
assert not builder._validate_operator('DROP')
assert not builder._validate_operator('EXEC')
def test_build_simple_condition_valid(builder):
"""Test building valid simple conditions."""
sql, params = builder._build_simple_condition('AND', 'devName', '=', 'TestDevice')
assert 'AND devName = :param_' in sql
assert len(params) == 1
assert 'TestDevice' in params.values()
def test_build_simple_condition_invalid_column(builder):
"""Test that invalid column names are rejected."""
with pytest.raises(ValueError) as exc_info:
builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
assert 'Invalid column name' in str(exc_info.value)
def test_build_simple_condition_invalid_operator(builder):
"""Test that invalid operators are rejected."""
with pytest.raises(ValueError) as exc_info:
builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
assert 'Invalid operator' in str(exc_info.value)
def test_sql_injection_attempts(builder):
"""Test that various SQL injection attempts are blocked."""
injection_attempts = [
"'; DROP TABLE Devices; --",
"' UNION SELECT * FROM Settings --",
"' OR 1=1 --",
"'; INSERT INTO Events VALUES(1,2,3); --",
"' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
]
for injection in injection_attempts:
with pytest.raises(ValueError):
builder.build_safe_condition(f"AND devName = '{injection}'")
def test_legacy_condition_compatibility(builder):
"""Test backward compatibility with legacy condition formats."""
# Test simple condition
sql, params = builder.get_safe_condition_legacy("AND devName = 'TestDevice'")
assert 'devName' in sql
assert 'TestDevice' in params.values()
# Test empty condition
sql, params = builder.get_safe_condition_legacy("")
assert sql == ""
assert params == {}
# Test invalid condition returns empty
sql, params = builder.get_safe_condition_legacy("INVALID SQL INJECTION")
assert sql == ""
assert params == {}
def test_parameter_generation(builder):
"""Test that parameters are generated correctly."""
# Test single parameter
sql, params = builder.build_safe_condition("AND devName = 'Device1'")
# Should have 1 parameter
assert len(params) == 1
assert 'param_1' in params
def test_xss_prevention(builder):
"""Test that XSS-like payloads in device names are handled safely."""
xss_payloads = [
"<script>alert('xss')</script>",
"javascript:alert(1)",
"<img src=x onerror=alert(1)>",
"'; DROP TABLE users; SELECT '<script>alert(1)</script>' --"
]
for payload in xss_payloads:
# Should either process safely or reject
try:
sql, params = builder.build_safe_condition(f"AND devName = '{payload}'")
# If processed, should be parameterized
assert ':' in sql
assert payload in params.values()
except ValueError:
# Rejection is also acceptable for safety
pass
def test_unicode_handling(builder):
"""Test that Unicode characters are handled properly."""
unicode_strings = [
"Ülrichs Device",
"Café Network",
"测试设备",
"Устройство"
]
for unicode_str in unicode_strings:
sql, params = builder.build_safe_condition(f"AND devName = '{unicode_str}'")
assert unicode_str in params.values()
def test_edge_cases(builder):
"""Test edge cases and boundary conditions."""
edge_cases = [
"", # Empty string
" ", # Whitespace only
"AND devName = ''", # Empty value
"AND devName = 'a'", # Single character
"AND devName = '" + "x" * 1000 + "'", # Very long string
]
for case in edge_cases:
try:
sql, params = builder.get_safe_condition_legacy(case)
# Should either return valid result or empty safe result
assert isinstance(sql, str)
assert isinstance(params, dict)
except Exception:
pytest.fail(f"Unexpected exception for edge case: {case}")