Files
NetAlertX/test/test_sql_security.py
jokob-sk a981c9eec1 integration tests cleanup
Signed-off-by: jokob-sk <jokob.sk@gmail.com>
2025-09-21 16:17:20 +10:00

381 lines
15 KiB
Python
Executable File

"""
NetAlertX SQL Security Test Suite
This test suite validates the SQL injection prevention mechanisms
implemented in the SafeConditionBuilder and reporting modules.
Author: Security Enhancement for NetAlertX
License: GNU GPLv3
"""
import sys
import unittest
import sqlite3
import tempfile
import os
from unittest.mock import Mock, patch, MagicMock
# Make the NetAlertX server modules importable from two candidate locations:
#   1. the install path used inside the container (/app/server), and
#   2. the `server` directory of the checkout this test file belongs to
#      (<repo>/test/<this file>  ->  <repo>/server).
# The second entry replaces a hard-coded path into one developer's home
# directory, which only ever resolved on that single machine.
INSTALL_PATH = "/app"
sys.path.extend([f"{INSTALL_PATH}/server"])
sys.path.append(
    os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "server")
    )
)
from db.sql_safe_builder import SafeConditionBuilder, create_safe_condition_builder
from database import DB
from messaging.reporting import get_notifications
class TestSafeConditionBuilder(unittest.TestCase):
    """Unit tests exercising SafeConditionBuilder's helpers and entry points."""

    def setUp(self):
        """Give every test its own freshly constructed builder."""
        self.builder = SafeConditionBuilder()

    def test_initialization(self):
        """Check the state of a builder right after construction."""
        fresh = self.builder
        self.assertIsInstance(fresh, SafeConditionBuilder)
        self.assertEqual(fresh.param_counter, 0)
        self.assertEqual(fresh.parameters, {})

    def test_sanitize_string(self):
        """Check _sanitize_string against a table of (raw, expected) pairs."""
        cases = (
            # plain text
            ("normal string", "normal string"),
            # {s-quote} placeholder
            ("test{s-quote}value", "test'value"),
            # control characters
            ("test\x00\x01string", "teststring"),
            # surrounding whitespace
            (" test string ", "test string"),
        )
        for raw, expected in cases:
            self.assertEqual(self.builder._sanitize_string(raw), expected)

    def test_validate_column_name(self):
        """Check which column names _validate_column_name accepts and rejects."""
        check = self.builder._validate_column_name

        for accepted in ('eve_MAC', 'devName', 'eve_EventType'):
            self.assertTrue(check(accepted))

        for rejected in ('malicious_column', 'drop_table', "'; DROP TABLE users; --"):
            self.assertFalse(check(rejected))

    def test_validate_operator(self):
        """Check which operators _validate_operator accepts and rejects."""
        check = self.builder._validate_operator

        for accepted in ('=', 'LIKE', 'IN'):
            self.assertTrue(check(accepted))

        for rejected in ('UNION', '; DROP', 'EXEC'):
            self.assertFalse(check(rejected))

    def test_build_simple_condition_valid(self):
        """A valid column/operator pair yields a clause with one bound value."""
        clause, bound = self.builder._build_simple_condition(
            'AND', 'devName', '=', 'TestDevice'
        )
        self.assertIn('AND devName = :param_', clause)
        self.assertEqual(len(bound), 1)
        self.assertIn('TestDevice', bound.values())

    def test_build_simple_condition_invalid_column(self):
        """An unknown column name must raise ValueError."""
        with self.assertRaises(ValueError) as caught:
            self.builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
        message = str(caught.exception)
        self.assertIn('Invalid column name', message)

    def test_build_simple_condition_invalid_operator(self):
        """An unknown operator must raise ValueError."""
        with self.assertRaises(ValueError) as caught:
            self.builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
        message = str(caught.exception)
        self.assertIn('Invalid operator', message)

    def test_build_in_condition_valid(self):
        """An IN list of two quoted values yields two bound values."""
        clause, bound = self.builder._build_in_condition(
            'AND', 'eve_EventType', 'IN', "'Connected', 'Disconnected'"
        )
        self.assertIn('AND eve_EventType IN', clause)
        self.assertEqual(len(bound), 2)
        bound_values = bound.values()
        self.assertIn('Connected', bound_values)
        self.assertIn('Disconnected', bound_values)

    def test_build_null_condition(self):
        """An IS NULL check produces a literal clause and no bound values."""
        clause, bound = self.builder._build_null_condition('AND', 'devComments', 'IS NULL')
        self.assertEqual(clause, 'AND devComments IS NULL')
        self.assertEqual(len(bound), 0)

    def test_sql_injection_attempts(self):
        """Each injection payload embedded in a condition must raise ValueError."""
        payloads = [
            "'; DROP TABLE Devices; --",
            "' UNION SELECT * FROM Settings --",
            "' OR 1=1 --",
            "'; INSERT INTO Events VALUES(1,2,3); --",
            "' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
            "'; ATTACH DATABASE '/etc/passwd' AS pwn; --",
        ]
        for payload in payloads:
            with self.subTest(injection=payload):
                with self.assertRaises(ValueError):
                    self.builder.build_safe_condition(f"AND devName = '{payload}'")

    def test_legacy_condition_compatibility(self):
        """Check get_safe_condition_legacy on valid, empty and invalid input."""
        legacy = self.builder.get_safe_condition_legacy

        # Well-formed condition
        clause, bound = legacy("AND devName = 'TestDevice'")
        self.assertIn('devName', clause)
        self.assertIn('TestDevice', bound.values())

        # Empty input and unparseable input are both expected to come back empty
        for raw in ("", "INVALID SQL INJECTION"):
            clause, bound = legacy(raw)
            self.assertEqual(clause, "")
            self.assertEqual(bound, {})

    def test_device_name_filter(self):
        """build_device_name_filter binds the device name as a parameter."""
        clause, bound = self.builder.build_device_name_filter("TestDevice")
        self.assertIn('AND devName = :device_name_', clause)
        self.assertIn('TestDevice', bound.values())

    def test_event_type_filter(self):
        """build_event_type_filter binds each event type as a parameter."""
        clause, bound = self.builder.build_event_type_filter(['Connected', 'Disconnected'])
        self.assertIn('AND eve_EventType IN', clause)
        self.assertEqual(len(bound), 2)
        bound_values = bound.values()
        self.assertIn('Connected', bound_values)
        self.assertIn('Disconnected', bound_values)

    def test_event_type_filter_whitelist(self):
        """build_event_type_filter keeps only event types it recognises."""
        make_filter = self.builder.build_event_type_filter

        # Every entry is a known event type
        _, bound = make_filter(['Connected', 'New Device'])
        self.assertEqual(len(bound), 2)

        # One unknown entry among two known ones: only the known ones remain
        _, bound = make_filter(['Connected', 'InvalidEventType', 'Device Down'])
        self.assertEqual(len(bound), 2)

        # Nothing recognised: the filter is expected to be empty
        clause, bound = make_filter(['InvalidType1', 'InvalidType2'])
        self.assertEqual(clause, "")
        self.assertEqual(bound, {})
class TestDatabaseParameterSupport(unittest.TestCase):
    """Test that the sqlite3 database layer supports parameterized queries."""

    def setUp(self):
        """Create a temporary SQLite database holding two known rows.

        Resources are released through ``addCleanup`` rather than
        ``tearDown``: cleanups still run when a later statement in
        ``setUp`` raises, whereas ``tearDown`` would be skipped and the
        temporary file left behind.
        """
        self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
        self.temp_db.close()
        # Registered first so it runs last (cleanups are last-in-first-out):
        # the connection is closed before the file is removed.
        self.addCleanup(os.unlink, self.temp_db.name)

        # Create test database
        self.conn = sqlite3.connect(self.temp_db.name)
        self.addCleanup(self.conn.close)

        self.conn.execute('''CREATE TABLE test_table (
            id INTEGER PRIMARY KEY,
            name TEXT,
            value TEXT
        )''')
        self.conn.execute("INSERT INTO test_table (name, value) VALUES ('test1', 'value1')")
        self.conn.execute("INSERT INTO test_table (name, value) VALUES ('test2', 'value2')")
        self.conn.commit()

    def test_parameterized_query_execution(self):
        """A named parameter selects exactly the matching row."""
        cursor = self.conn.cursor()

        # Test named parameters
        cursor.execute("SELECT * FROM test_table WHERE name = :name", {'name': 'test1'})
        results = cursor.fetchall()

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0][1], 'test1')

    def test_parameterized_query_prevents_injection(self):
        """A bound value containing SQL is compared as data, not executed."""
        cursor = self.conn.cursor()

        # This should not cause SQL injection
        malicious_input = "'; DROP TABLE test_table; --"
        cursor.execute("SELECT * FROM test_table WHERE name = :name", {'name': malicious_input})
        results = cursor.fetchall()

        # The payload is only a literal to compare against, and no row has that name
        self.assertEqual(results, [])

        # The table should still exist and be queryable
        cursor.execute("SELECT COUNT(*) FROM test_table")
        count = cursor.fetchone()[0]
        self.assertEqual(count, 2)  # Original data should still be there
class TestReportingSecurityIntegration(unittest.TestCase):
    """Run get_notifications against a mocked database and mocked settings."""

    def setUp(self):
        """Build a mock database whose get_table_as_json returns an empty table."""
        # Canned result object handed back by get_table_as_json
        empty_table = Mock()
        empty_table.columnNames = ['MAC', 'Datetime', 'IP', 'Event Type', 'Device name', 'Comments']
        empty_table.json = {'data': []}

        self.mock_db = Mock()
        self.mock_db.sql = Mock()
        self.mock_db.get_table_as_json = Mock()
        self.mock_db.get_table_as_json.return_value = empty_table

    @patch('messaging.reporting.get_setting_value')
    def test_new_devices_section_security(self, mock_get_setting):
        """The new-devices query is issued with one positional and one keyword argument."""

        def fake_setting(key):
            # Unknown keys fall back to an empty string
            configured = {
                'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
                'NTFPRCS_new_dev_condition': "AND devName = 'TestDevice'",
            }
            return configured.get(key, '')

        mock_get_setting.side_effect = fake_setting

        get_notifications(self.mock_db)

        table_call = self.mock_db.get_table_as_json
        table_call.assert_called()

        # Inspect the most recent call: (positional args, keyword args)
        recorded = table_call.call_args
        positional = recorded[0]
        keyword = recorded[1]
        self.assertEqual(len(positional), 1)  # Query argument
        self.assertEqual(len(keyword), 1)  # Parameters keyword argument

    @patch('messaging.reporting.get_setting_value')
    def test_events_section_security(self, mock_get_setting):
        """The events section results in a get_table_as_json call."""

        def fake_setting(key):
            configured = {
                'NTFPRCS_INCLUDED_SECTIONS': ['events'],
                'NTFPRCS_event_condition': "AND devName = 'TestDevice'",
            }
            return configured.get(key, '')

        mock_get_setting.side_effect = fake_setting

        get_notifications(self.mock_db)

        self.mock_db.get_table_as_json.assert_called()

    @patch('messaging.reporting.get_setting_value')
    def test_malicious_condition_handling(self, mock_get_setting):
        """A malicious condition setting neither raises nor suppresses the query."""

        def fake_setting(key):
            configured = {
                'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
                'NTFPRCS_new_dev_condition': "'; DROP TABLE Devices; --",
            }
            return configured.get(key, '')

        mock_get_setting.side_effect = fake_setting

        # Must complete without raising
        get_notifications(self.mock_db)

        # The table query is still expected to be issued
        self.mock_db.get_table_as_json.assert_called()

    @patch('messaging.reporting.get_setting_value')
    def test_empty_condition_handling(self, mock_get_setting):
        """An empty condition setting still results in a get_table_as_json call."""

        def fake_setting(key):
            configured = {
                'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
                'NTFPRCS_new_dev_condition': "",
            }
            return configured.get(key, '')

        mock_get_setting.side_effect = fake_setting

        get_notifications(self.mock_db)

        self.mock_db.get_table_as_json.assert_called()
class TestSecurityBenchmarks(unittest.TestCase):
    """Coarse performance and memory sanity checks for SafeConditionBuilder."""

    def setUp(self):
        """Set up benchmark environment."""
        self.builder = SafeConditionBuilder()

    def test_performance_simple_condition(self):
        """Building 1000 simple conditions must take less than one second."""
        import time

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted, which would corrupt the interval.
        start_time = time.perf_counter()
        for _ in range(1000):
            self.builder.build_safe_condition("AND devName = 'TestDevice'")
        execution_time = time.perf_counter() - start_time

        self.assertLess(execution_time, 1.0, "Simple condition building should be fast")

    def test_memory_usage_parameter_generation(self):
        """Building 100 conditions must grow process RSS by less than 10 MB.

        Skipped when the optional third-party ``psutil`` package is missing.
        """
        try:
            import psutil
        except ImportError:
            # psutil is not a file-level import; report a skip, not an error.
            self.skipTest("psutil is not installed")

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        # Generate many conditions, each with its own builder
        for i in range(100):
            builder = SafeConditionBuilder()
            builder.build_safe_condition(f"AND devName = 'Device{i}'")

        final_memory = process.memory_info().rss
        memory_increase = final_memory - initial_memory

        # Memory increase should be reasonable (less than 10MB)
        self.assertLess(memory_increase, 10 * 1024 * 1024, "Memory usage should be reasonable")

    def test_pattern_coverage(self):
        """Each sample pattern is either rejected with ValueError or yields (str, dict)."""
        patterns_tested = [
            "AND devName = 'value'",
            "OR eve_EventType LIKE '%test%'",
            "AND devComments IS NULL",
            "AND eve_EventType IN ('Connected', 'Disconnected')",
        ]
        for pattern in patterns_tested:
            with self.subTest(pattern=pattern):
                # Only the builder call is guarded, so the type assertions
                # below sit outside the ValueError handler.
                try:
                    sql, params = self.builder.build_safe_condition(pattern)
                except ValueError:
                    # Some patterns might be rejected, which is acceptable
                    pass
                else:
                    self.assertIsInstance(sql, str)
                    self.assertIsInstance(params, dict)
if __name__ == "__main__":
    # Executed as a script: run every test case in this module, verbosely.
    unittest.main(verbosity=2)