BuddAI/tests/test_buddai.py
JamesTheGiblet 27601aa2ba Add comprehensive unit tests for BuddAI functionality
- Introduced 16 additional coverage tests in `test_additional_coverage.py` to enhance overall test coverage.
- Added 15 extended feature tests in `test_extended_features.py` to validate new functionalities.
- Implemented 27 final coverage tests in `test_final_coverage.py` to achieve a total of 100 tests.
- Created 2 fallback logic tests in `test_fallback_logic.py` to ensure proper fallback behavior based on confidence scores.
- Each test suite covers various aspects of the BuddAI system, including command handling, database interactions, and hardware detection.
2026-01-07 19:48:24 +00:00

626 lines
No EOL
25 KiB
Python

#!/usr/bin/env python3
"""
BuddAI v3.2 Test Suite
Comprehensive testing for all features
Author: James Gilbert
License: MIT
"""
import http.client
import importlib.util
import io
import os
import queue
import re
import shutil
import sqlite3
import sys
import tempfile
import unittest
import zipfile
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from unittest.mock import MagicMock, patch
# Load buddai_executive.py from the repository root by file path, so the tests
# run without the project being installed as a package. The module is put in
# sys.modules before execution (the importlib "import a source file directly"
# recipe) so anything importing ``buddai_executive`` by name gets this instance.
REPO_ROOT = Path(__file__).parent.parent
MODULE_PATH = REPO_ROOT / "buddai_executive.py"
# The executive imports sibling ``core.*`` modules; make them resolvable even
# when the tests are launched from outside the repository root.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
spec = importlib.util.spec_from_file_location("buddai_executive", MODULE_PATH)
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot build an import spec for {MODULE_PATH}")
buddai_module = importlib.util.module_from_spec(spec)
sys.modules["buddai_executive"] = buddai_module
spec.loader.exec_module(buddai_module)
BuddAI = buddai_module.BuddAI
class TestBuddAICore(unittest.TestCase):
    """Tests for BuddAI helpers and the ``buddai_executive`` module.

    The early tests exercise self-contained logic (regexes, SQLite usage,
    caching). Later tests use the real ``buddai_executive`` module; every test
    that constructs ``BuddAI`` first calls ``_use_temp_db`` so the developer's
    real database is never read or written.
    """

    # ------------------------------------------------------------------
    # Shared fixtures
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_core_importable():
        """Put REPO_ROOT on sys.path and pre-load ``core.buddai_storage`` if possible.

        Loading the storage module up front lets ``_use_temp_db`` patch its
        ``DB_PATH`` binding; a missing module is tolerated.
        """
        if str(REPO_ROOT) not in sys.path:
            sys.path.insert(0, str(REPO_ROOT))
        if 'core.buddai_storage' not in sys.modules:
            try:
                import core.buddai_storage  # noqa: F401
            except ImportError:
                pass

    def _start_patch(self, patcher):
        """Start *patcher*, schedule its ``stop`` as a cleanup, and return the mock."""
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    @staticmethod
    def _remove_quietly(path):
        """Best-effort delete of *path*.

        The file may still be held open by a live connection (notably on
        Windows); a failed delete only leaves a stray temp file.
        """
        try:
            if path.exists():
                path.unlink()
        except OSError:
            pass

    def _use_temp_db(self):
        """Redirect every known ``DB_PATH`` to a fresh temporary SQLite file.

        Patches ``DB_PATH`` on the executive module and, when loaded, on
        ``core.buddai_shared`` and ``core.buddai_storage``. Each module holding
        its own ``DB_PATH`` binding must be patched separately. Patches are
        undone and the file removed through ``addCleanup``, so cleanup runs even
        if the test fails or a later patch cannot be applied.

        Returns:
            Path: the temporary database file.
        """
        fd, raw_path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        db_path = Path(raw_path)
        # Registered first so it runs last (cleanups are LIFO): after every patch is stopped.
        self.addCleanup(self._remove_quietly, db_path)
        self._ensure_core_importable()
        self._start_patch(patch.object(buddai_module, 'DB_PATH', db_path))
        for module_name in ('core.buddai_shared', 'core.buddai_storage'):
            if module_name in sys.modules:
                self._start_patch(patch(f'{module_name}.DB_PATH', db_path))
        return db_path

    @staticmethod
    def _execute_ddl(db_path, *statements):
        """Run each SQL statement against *db_path*, commit, and always close."""
        conn = sqlite3.connect(db_path)
        try:
            for statement in statements:
                conn.execute(statement)
            conn.commit()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Test 1: Database Initialization
    # ------------------------------------------------------------------
    def test_database_init(self):
        """The core schema tables can be created and are listed in sqlite_master."""
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        started_at TIMESTAMP,
                        ended_at TIMESTAMP
                    )
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS messages (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT,
                        role TEXT,
                        content TEXT,
                        timestamp TIMESTAMP
                    )
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS repo_index (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        file_path TEXT,
                        repo_name TEXT,
                        function_name TEXT,
                        content TEXT,
                        last_modified TIMESTAMP
                    )
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS style_preferences (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        category TEXT,
                        preference TEXT,
                        confidence FLOAT,
                        extracted_at TIMESTAMP
                    )
                """)
                conn.commit()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = {row[0] for row in cursor.fetchall()}
            finally:
                # Close before TemporaryDirectory cleanup, which fails on Windows if the file is open.
                conn.close()
            required_tables = ['sessions', 'messages', 'repo_index', 'style_preferences']
            missing = [t for t in required_tables if t not in tables]
            self.assertFalse(missing, f"Missing tables: {missing}")

    # ------------------------------------------------------------------
    # Test 2: SQL Injection Prevention
    # ------------------------------------------------------------------
    def test_sql_injection_prevention(self):
        """A parameterized LIKE query treats an injection payload as plain text."""
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("""
                    CREATE TABLE repo_index (
                        id INTEGER PRIMARY KEY,
                        function_name TEXT,
                        content TEXT
                    )
                """)
                cursor.execute("INSERT INTO repo_index (function_name, content) VALUES (?, ?)",
                               ("testFunc", "test content"))
                conn.commit()
                malicious_input = "'; DROP TABLE repo_index; --"
                # SECURE: the payload is bound as a parameter, never spliced into SQL.
                cursor.execute("SELECT * FROM repo_index WHERE function_name LIKE ?",
                               (f"%{malicious_input}%",))
                results = cursor.fetchall()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='repo_index'")
                table_exists = cursor.fetchone() is not None
            finally:
                conn.close()
            self.assertTrue(table_exists, "Table was dropped - vulnerable to injection!")
            # The payload is matched literally, so nothing in the table matches it.
            self.assertEqual(results, [])

    # ------------------------------------------------------------------
    # Test 3: Auto-Learning Pattern Extraction
    # ------------------------------------------------------------------
    def test_auto_learning(self):
        """Style preferences (baud rate, pin style, timeout, PWM freq) are extracted from code."""
        sample_code = """
#include <Arduino.h>
#define MOTOR_PIN 5
const int TIMEOUT_MS = 5000;
void setup() {
Serial.begin(115200);
ledcSetup(0, 500, 8);
}
"""
        patterns = {
            'serial_baud': re.search(r'Serial\.begin\((\d+)\)', sample_code),
            'pin_style': 'define' if '#define' in sample_code else 'const',
            'timeout_value': re.search(r'TIMEOUT.*?(\d+)', sample_code),
            'pwm_freq': re.search(r'ledcSetup\([^,]+,\s*(\d+)', sample_code),
        }
        extracted = {}
        for key, value in patterns.items():
            if value:
                # Regex matches yield their first group; plain values are stringified.
                extracted[key] = value.group(1) if hasattr(value, 'group') else str(value)
        expected = {
            'serial_baud': '115200',
            'pin_style': 'define',
            'timeout_value': '5000',
            'pwm_freq': '500',
        }
        for key, expected_val in expected.items():
            self.assertEqual(extracted.get(key), expected_val, f"Failed to extract {key}")

    # ------------------------------------------------------------------
    # Test 4: Module Detection
    # ------------------------------------------------------------------
    def test_module_detection(self):
        """Keyword matching maps a request to the hardware modules it mentions."""
        MODULE_PATTERNS = {
            "ble": ["bluetooth", "ble", "wireless"],
            "servo": ["servo", "flipper", "weapon"],
            "motor": ["motor", "drive", "movement", "l298n"],
            "safety": ["safety", "timeout", "failsafe"],
        }
        test_cases = [
            ("Generate code with BLE and servo control", ["ble", "servo"]),
            ("Add motor driver with safety timeout", ["motor", "safety"]),
            ("Build complete robot with bluetooth, motors, and weapon", ["ble", "motor", "servo"]),
        ]

        def extract_modules(message):
            message_lower = message.lower()
            return [module for module, keywords in MODULE_PATTERNS.items()
                    if any(kw in message_lower for kw in keywords)]

        for message, expected_modules in test_cases:
            detected = extract_modules(message)
            self.assertEqual(set(detected), set(expected_modules), f"Failed for '{message}'")

    # ------------------------------------------------------------------
    # Test 5: Complexity Detection
    # ------------------------------------------------------------------
    def test_complexity_detection(self):
        """A request is complex with >=2 trigger words or >=3 hardware modules."""
        COMPLEX_TRIGGERS = ["complete", "entire", "full", "build entire"]
        MODULE_PATTERNS = {
            "ble": ["bluetooth", "ble"],
            "servo": ["servo"],
            "motor": ["motor"],
        }

        def is_complex(message):
            message_lower = message.lower()
            trigger_count = sum(1 for trigger in COMPLEX_TRIGGERS if trigger in message_lower)
            module_count = sum(1 for keywords in MODULE_PATTERNS.values()
                               if any(kw in message_lower for kw in keywords))
            return trigger_count >= 2 or module_count >= 3

        test_cases = [
            ("Generate a motor driver class", False),
            ("Build complete robot with BLE, servo, and motors", True),
            ("Create entire system with full integration", True),
            ("What pins should I use?", False),
        ]
        for message, expected_complex in test_cases:
            self.assertEqual(is_complex(message), expected_complex, f"Failed for '{message}'")

    # ------------------------------------------------------------------
    # Test 6: LRU Cache Performance
    # ------------------------------------------------------------------
    def test_lru_cache(self):
        """Repeated calls with the same key are served from the LRU cache."""
        call_count = 0

        @lru_cache(maxsize=128)
        def cached_function(keywords):
            nonlocal call_count
            call_count += 1
            return f"Result for {keywords}"

        cached_function(("motor", "servo"))
        first_count = call_count
        cached_function(("motor", "servo"))
        second_count = call_count
        cached_function(("ble", "battery"))
        third_count = call_count
        self.assertEqual(first_count, 1)
        self.assertEqual(second_count, 1)  # cache hit: the body did not run again
        self.assertEqual(third_count, 2)
        self.assertEqual(cached_function.cache_info().hits, 1)

    # ------------------------------------------------------------------
    # Test 7: Session Export
    # ------------------------------------------------------------------
    def test_session_export(self):
        """A session rendered to Markdown keeps its ID, code fences and headings."""
        with tempfile.TemporaryDirectory() as tmpdir:
            export_path = Path(tmpdir) / "test_export.md"
            session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
            messages = [
                ("user", "Generate motor code", "2025-12-28 10:00:00"),
                ("assistant", "```cpp\nvoid setupMotors() {}\n```", "2025-12-28 10:00:05"),
            ]
            parts = [
                "# BuddAI Session Export\n",
                f"**Session ID:** {session_id}\n\n",
                "---\n\n",
            ]
            for role, content, timestamp in messages:
                if role == 'user':
                    parts.append(f"## 🧑 James ({timestamp})\n{content}\n\n")
                else:
                    parts.append(f"## 🤖 BuddAI\n{content}\n\n")
            export_path.write_text("".join(parts), encoding='utf-8')
            self.assertTrue(export_path.exists())
            content = export_path.read_text(encoding='utf-8')
            self.assertIn(session_id, content)
            self.assertIn("```cpp", content)
            self.assertIn("## ", content)

    # ------------------------------------------------------------------
    # Test 8: Actionable Suggestions
    # ------------------------------------------------------------------
    def test_actionable_suggestions(self):
        """Motor code without smoothing or safety triggers both follow-up suggestions."""
        user_input = "Generate motor driver with L298N"
        generated_code = """
void setupMotors() {
    pinMode(MOTOR_PIN, OUTPUT);
}
"""
        suggestions = []
        if ("motor" in user_input.lower() or "servo" in user_input.lower()) and "applyForge" not in generated_code:
            suggestions.append({'text': "Apply Forge Theory smoothing?", 'action': 'add_forge'})
        if "L298N" in user_input and "safety" not in generated_code.lower():
            suggestions.append({'text': "Add 5s safety timeout?", 'action': 'add_safety'})
        self.assertEqual(len(suggestions), 2)
        self.assertEqual([s['action'] for s in suggestions], ['add_forge', 'add_safety'])

    # ------------------------------------------------------------------
    # Test 9: Repository Indexing
    # ------------------------------------------------------------------
    def test_repository_indexing(self):
        """Function names are extracted from .ino/.cpp (void/int) and .py (def) files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            repo_dir = Path(tmpdir) / "test_repo"
            repo_dir.mkdir()
            test_files = {
                "motor_driver.ino": "void setupMotors() { }",
                "servo_control.cpp": "void activateFlipper() { }",
                "utils.py": "def calculate_pwm(speed): pass",
            }
            for filename, content in test_files.items():
                (repo_dir / filename).write_text(content)
            indexed_functions = []
            for file_path in repo_dir.rglob('*'):
                if not file_path.is_file() or file_path.suffix not in ['.ino', '.cpp', '.py']:
                    continue
                content = file_path.read_text()
                if file_path.suffix in ['.ino', '.cpp']:
                    indexed_functions.extend(re.findall(r'\b(?:void|int)\s+(\w+)\s*\(', content))
                elif file_path.suffix == '.py':
                    indexed_functions.extend(re.findall(r'\bdef\s+(\w+)\s*\(', content))
            expected_functions = ['setupMotors', 'activateFlipper', 'calculate_pwm']
            self.assertEqual(set(indexed_functions), set(expected_functions))

    # ------------------------------------------------------------------
    # Test 10: Search Query Safety
    # ------------------------------------------------------------------
    def test_search_query_safety(self):
        """Keyword search SQL uses placeholders only and survives hostile queries when executed."""
        malicious_queries = [
            "'; DROP TABLE repo_index; --",
            "' OR '1'='1",
            "admin'--",
            "<script>alert('xss')</script>",
        ]
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute("CREATE TABLE repo_index (function_name TEXT, content TEXT)")
            conn.execute("INSERT INTO repo_index (function_name, content) VALUES (?, ?)",
                         ("setupMotors", "void setupMotors() { }"))
            for query in malicious_queries:
                keywords = re.findall(r'\b\w{4,}\b', query.lower())
                if not keywords:
                    continue
                conditions = ["(function_name LIKE ? OR content LIKE ?)"] * len(keywords)
                safe_sql = f"SELECT * FROM repo_index WHERE {' OR '.join(conditions)}"
                # User text must never appear in the SQL text itself...
                self.assertNotIn("DROP", safe_sql)
                self.assertNotIn("'; ", safe_sql)
                # ...only as bound parameters (two per keyword: name and content).
                params = [f"%{kw}%" for kw in keywords for _ in range(2)]
                rows = conn.execute(safe_sql, params).fetchall()
                self.assertEqual(rows, [], f"Unexpected match for {query!r}")
            remaining = conn.execute("SELECT COUNT(*) FROM repo_index").fetchone()[0]
            self.assertEqual(remaining, 1, "repo_index was modified by a search query")
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Test 11: Context Window Management
    # ------------------------------------------------------------------
    def test_context_window(self):
        """Slicing keeps only the five most recent messages, in order."""
        context_messages = []
        for i in range(20):
            context_messages.append({"role": "user", "content": f"Message {i}"})
            context_messages.append({"role": "assistant", "content": f"Response {i}"})
        limited_context = context_messages[-5:]
        self.assertEqual(len(limited_context), 5)
        self.assertEqual(limited_context[0]['content'], "Response 17")
        self.assertEqual(limited_context[-1]['content'], "Response 19")

    # ------------------------------------------------------------------
    # Test 12: Schedule Awareness
    # ------------------------------------------------------------------
    def test_schedule_awareness(self):
        """The personality manager reports a status that follows the (mocked) clock."""
        self._use_temp_db()
        with patch('core.buddai_personality.datetime') as mock_date:
            mock_date.now.return_value = datetime(2025, 12, 29, 6, 0, 0)
            buddai = BuddAI(server_mode=False)
            status = buddai.personality_manager.get_user_status()
            self.assertIn("Early Morning", status)
            mock_date.now.return_value = datetime(2025, 12, 29, 10, 0, 0)
            status = buddai.personality_manager.get_user_status()
            self.assertIn("Work Hours", status)

    # ------------------------------------------------------------------
    # Test 13: Modular Plan Generation
    # ------------------------------------------------------------------
    def test_modular_plan(self):
        """A two-module plan gets one task per module plus an integration task."""
        self._use_temp_db()
        buddai = BuddAI(server_mode=False)
        plan = buddai.prompt_engine.build_modular_plan(["ble", "servo"])
        self.assertEqual(len(plan), 3)
        tasks = [p['module'] for p in plan]
        self.assertIn("integration", tasks)

    # ------------------------------------------------------------------
    # Test 14: Session Management
    # ------------------------------------------------------------------
    def test_session_management(self):
        """Sessions can be created, renamed, listed and deleted."""
        self._use_temp_db()
        buddai = BuddAI(server_mode=False)
        sid = buddai.start_new_session()
        buddai.rename_session(sid, "Unit Test Session")
        sessions = buddai.get_sessions(limit=1)
        self.assertEqual(sessions[0]['title'], "Unit Test Session")
        buddai.delete_session(sid)
        sessions = buddai.get_sessions(limit=5)
        self.assertFalse(any(s['id'] == sid for s in sessions))

    # ------------------------------------------------------------------
    # Test 15: Rapid Session Creation
    # ------------------------------------------------------------------
    def test_rapid_session_creation(self):
        """Sessions started within the same second get unique, suffixed IDs."""
        self._use_temp_db()
        fixed_time = datetime(2025, 1, 1, 12, 0, 0)
        dt_patchers = [patch.object(buddai_module, 'datetime')]
        if 'core.buddai_storage' in sys.modules:
            dt_patchers.append(patch('core.buddai_storage.datetime'))
        for patcher in dt_patchers:
            mock_dt = self._start_patch(patcher)
            mock_dt.now.return_value = fixed_time
            # Also cover code that does ``import datetime`` and calls datetime.datetime.now().
            mock_dt.datetime.now.return_value = fixed_time
        buddai = BuddAI(server_mode=False)
        ids = [buddai.storage.current_session_id]
        for _ in range(4):
            ids.append(buddai.start_new_session())
        base_id = fixed_time.strftime("%Y%m%d_%H%M%S")
        expected = [base_id] + [f"{base_id}_{i}" for i in range(1, 5)]
        self.assertEqual(ids, expected)

    # ------------------------------------------------------------------
    # Test 16: Repository Isolation
    # ------------------------------------------------------------------
    def test_repo_isolation(self):
        """Functions indexed by one user are not visible in another user's search."""
        test_db = self._use_temp_db()
        with tempfile.TemporaryDirectory() as tmp_repo:
            repo_path = Path(tmp_repo)
            (repo_path / "user1_secret.py").write_text("def user1_secret_function():\n pass")
            with patch('builtins.print'):
                self._execute_ddl(
                    test_db,
                    "CREATE TABLE IF NOT EXISTS repo_index (id INTEGER PRIMARY KEY, file_path TEXT, repo_name TEXT, function_name TEXT, content TEXT, last_modified TIMESTAMP, user_id TEXT)",
                )
                buddai1 = BuddAI(user_id="user1", server_mode=False)
                buddai1.repo_manager.index_local_repositories(str(repo_path))
                buddai2 = BuddAI(user_id="user2", server_mode=False)
                res1 = buddai1.repo_manager.search_repositories("user1_secret_function")
                res2 = buddai2.repo_manager.search_repositories("user1_secret_function")
                self.assertTrue("Found 1 matches" in res1 or "user1_secret_function" in res1)
                self.assertIn("No functions found", res2)

    # ------------------------------------------------------------------
    # Test 17: Upload Security
    # ------------------------------------------------------------------
    def test_upload_security(self):
        """validate_upload rejects non-zip bytes; safe_extract_zip rejects zip-slip paths."""

        class MockUploadFile:
            # Minimal stand-in exposing the filename/file/content_type attributes
            # this test passes to validate_upload.
            def __init__(self, filename, content):
                self.filename = filename
                self.file = io.BytesIO(content)
                self.content_type = "application/zip"

        has_validate = hasattr(buddai_module, 'validate_upload')
        has_extract = hasattr(buddai_module, 'safe_extract_zip')
        if not (has_validate or has_extract):
            # Report as skipped rather than silently passing with zero assertions.
            self.skipTest("buddai_executive defines neither validate_upload nor safe_extract_zip")

        if has_validate:
            fake_zip = MockUploadFile("fake.zip", b"This is not a zip file")
            with self.assertRaises(ValueError):
                buddai_module.validate_upload(fake_zip)

        if has_extract:
            with tempfile.TemporaryDirectory() as tmpdir:
                malicious_zip_path = Path(tmpdir) / "slip.zip"
                extract_dir = Path(tmpdir) / "extract"
                extract_dir.mkdir()
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zf:
                    # An entry that would escape extract_dir if extracted naively.
                    zf.writestr('../evil.txt', 'malicious content')
                malicious_zip_path.write_bytes(zip_buffer.getvalue())
                with self.assertRaises(ValueError):
                    buddai_module.safe_extract_zip(malicious_zip_path, extract_dir)

    # ------------------------------------------------------------------
    # Test 18: WebSocket Logic
    # ------------------------------------------------------------------
    def test_websocket_logic(self):
        """chat_stream yields model chunks unchanged and requests streaming from call_model."""
        test_db = self._use_temp_db()
        with patch('builtins.print'):
            self._execute_ddl(
                test_db,
                "CREATE TABLE IF NOT EXISTS repo_index (id INTEGER PRIMARY KEY, file_path TEXT, repo_name TEXT, function_name TEXT, content TEXT, last_modified TIMESTAMP, user_id TEXT)",
                "CREATE TABLE IF NOT EXISTS messages (id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT, role TEXT, content TEXT, timestamp TIMESTAMP)",
            )
            buddai = BuddAI(server_mode=False)

            def fake_stream(*_args, **_kwargs):
                yield "Stream"
                yield "ing"
                yield "..."

            with patch.object(buddai, 'call_model', side_effect=fake_stream) as mock_call, \
                    patch.object(buddai.shadow_engine, 'get_all_suggestions', return_value=[]):
                chunks = list(buddai.chat_stream("Test Message", force_model="fast"))
                self.assertEqual("".join(chunks), "Streaming...")
                _, call_kwargs = mock_call.call_args
                self.assertTrue(call_kwargs.get('stream'))

    # ------------------------------------------------------------------
    # Test 19: Connection Pooling
    # ------------------------------------------------------------------
    def test_connection_pool(self):
        """A returned connection is reused by the next get_connection() call."""
        if not hasattr(buddai_module, 'OLLAMA_POOL'):
            self.skipTest("buddai_executive has no OLLAMA_POOL")
        pool = buddai_module.OLLAMA_POOL
        # Drain and close pooled connections so qsize() starts from zero.
        # Assumes pool.pool is a queue.Queue (it is used via get_nowait/qsize).
        while True:
            try:
                stale = pool.pool.get_nowait()
            except queue.Empty:
                break
            stale.close()
        conn1 = pool.get_connection()
        self.assertIsInstance(conn1, http.client.HTTPConnection)
        pool.return_connection(conn1)
        self.assertEqual(pool.pool.qsize(), 1)
        conn2 = pool.get_connection()
        # Hand the connection back afterwards so later tests see a usable pool.
        self.addCleanup(pool.return_connection, conn2)
        self.assertIs(conn2, conn1)

    # ------------------------------------------------------------------
    # Test 20: Feedback System
    # ------------------------------------------------------------------
    def test_feedback_system(self):
        """Positive feedback on a saved message is persisted to the feedback table."""
        test_db = self._use_temp_db()
        with patch('builtins.print'):
            self._execute_ddl(
                test_db,
                "CREATE TABLE IF NOT EXISTS feedback (message_id INTEGER, positive BOOLEAN, comment TEXT, timestamp TEXT)",
                "CREATE TABLE IF NOT EXISTS messages (id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT, role TEXT, content TEXT, timestamp TIMESTAMP)",
            )
            buddai = BuddAI(server_mode=False)
            msg_id = buddai.storage.save_message("assistant", "Test response")
            buddai.record_feedback(msg_id, True)
            conn = sqlite3.connect(test_db)
            try:
                row = conn.execute("SELECT positive FROM feedback WHERE message_id = ?",
                                   (msg_id,)).fetchone()
            finally:
                conn.close()
            self.assertIsNotNone(row, "Feedback row not found in database")
            self.assertEqual(row[0], 1)
if __name__ == "__main__":
    # Run as a script: hand control to unittest's command-line runner, which
    # collects every TestCase in this module and exits with the run's status.
    unittest.main(module="__main__")