mirror of
https://github.com/JamesTheGiblet/BuddAI.git
synced 2026-01-08 21:58:40 +00:00
1034 lines
No EOL
34 KiB
Python
1034 lines
No EOL
34 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
BuddAI v3.2 Test Suite
|
|
Comprehensive testing for all features
|
|
|
|
Author: James Gilbert
|
|
License: MIT
|
|
"""
|
|
|
|
import sys
|
|
import importlib.util
|
|
from unittest.mock import MagicMock, patch
|
|
import sqlite3
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import os
|
|
import io
|
|
import zipfile
|
|
import http.client
|
|
|
|
# Dynamic import of buddai_v3.2.py
|
|
# The target file name contains a dot ("buddai_v3.2.py"), so it cannot be
# imported with a normal ``import`` statement; load it from its path instead.
REPO_ROOT = Path(__file__).parent.parent
MODULE_PATH = REPO_ROOT / "buddai_v3.2.py"
spec = importlib.util.spec_from_file_location("buddai_v3_2", MODULE_PATH)
if spec is None or spec.loader is None:
    # spec_from_file_location() may return None (and a spec may carry no
    # loader); fail with a clear message instead of an AttributeError below.
    raise ImportError(f"Cannot load BuddAI module from {MODULE_PATH}")
buddai_module = importlib.util.module_from_spec(spec)
# Register under the alias before executing so that patch('buddai_v3_2.<name>')
# targets used by the tests can be resolved through sys.modules.
sys.modules["buddai_v3_2"] = buddai_module
spec.loader.exec_module(buddai_module)
BuddAI = buddai_module.BuddAI
|
|
|
|
# Test utilities
|
|
class TestColors:
    """ANSI escape sequences used to colorize the test runner's output."""

    PASS = '\033[92m'  # bright green
    FAIL = '\033[91m'  # bright red
    INFO = '\033[94m'  # bright blue
    WARN = '\033[93m'  # bright yellow
    END = '\033[0m'  # reset all attributes
|
|
|
|
def print_test(name):
    """Print a colored banner announcing the test called *name*."""
    banner = f"\n{TestColors.INFO}🧪 Testing: {name}{TestColors.END}"
    print(banner)
|
|
|
|
def print_pass(message):
    """Print *message* as a success line (PASS color, check mark)."""
    line = f" {TestColors.PASS}✅ {message}{TestColors.END}"
    print(line)
|
|
|
|
def print_fail(message):
    """Print *message* as a failure line (FAIL color, cross mark)."""
    line = f" {TestColors.FAIL}❌ {message}{TestColors.END}"
    print(line)
|
|
|
|
def print_warn(message):
    """Print *message* as a warning line (WARN color, warning sign)."""
    line = f" {TestColors.WARN}⚠️ {message}{TestColors.END}"
    print(line)
|
|
|
|
|
|
# Test 1: Database Initialization
|
|
def test_database_init():
    """Create the core tables in a throwaway SQLite file and confirm they exist.

    The schema is declared locally in this test; nothing from the BuddAI
    module is exercised.

    Returns:
        bool: True when every required table is present, False otherwise.
    """
    print_test("Database Initialization")

    required_tables = ['sessions', 'messages', 'repo_index', 'style_preferences']
    schema_statements = [
        """
        CREATE TABLE IF NOT EXISTS sessions (
            session_id TEXT PRIMARY KEY,
            started_at TIMESTAMP,
            ended_at TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT,
            role TEXT,
            content TEXT,
            timestamp TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS repo_index (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            file_path TEXT,
            repo_name TEXT,
            function_name TEXT,
            content TEXT,
            last_modified TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS style_preferences (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            category TEXT,
            preference TEXT,
            confidence FLOAT,
            extracted_at TIMESTAMP
        )
        """,
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test.db"
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            # Create tables
            for statement in schema_statements:
                cursor.execute(statement)
            conn.commit()

            # Verify tables exist
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]
        finally:
            # Close even when a statement raises: an open handle on test.db
            # can prevent the temporary directory from being removed.
            conn.close()

    missing = [t for t in required_tables if t not in tables]
    if missing:
        print_fail(f"Missing tables: {', '.join(missing)}")
        return False

    print_pass(f"All {len(required_tables)} tables created successfully")
    return True
|
|
|
|
|
|
# Test 2: SQL Injection Prevention
|
|
def test_sql_injection_prevention():
    """Check that a parameterized LIKE query neutralizes a DROP TABLE payload.

    A one-row ``repo_index`` table is created in a temporary database, the
    malicious string is bound as a query parameter, and the table is then
    looked up in ``sqlite_master`` to prove it was not dropped.

    Returns:
        bool: True if the table survives, False if it is gone or the query
        itself fails.
    """
    print_test("SQL Injection Prevention")

    # Test malicious input
    malicious_input = "'; DROP TABLE repo_index; --"

    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test.db"
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE repo_index (
                    id INTEGER PRIMARY KEY,
                    function_name TEXT,
                    content TEXT
                )
            """)

            # Insert test data
            cursor.execute(
                "INSERT INTO repo_index (function_name, content) VALUES (?, ?)",
                ("testFunc", "test content"),
            )
            conn.commit()

            try:
                # SECURE: Parameterized query
                cursor.execute(
                    "SELECT * FROM repo_index WHERE function_name LIKE ?",
                    (f"%{malicious_input}%",),
                )
                cursor.fetchall()  # consume the result set; rows are irrelevant here

                # Verify table still exists
                cursor.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='repo_index'"
                )
                table_exists = cursor.fetchone() is not None
            except sqlite3.Error as e:
                print_fail(f"Query failed: {e}")
                return False
        finally:
            # Runs on every path (including the early return above) so the
            # database file is never left open when the temp dir is removed.
            conn.close()

    if not table_exists:
        print_fail("Table was dropped - vulnerable to injection!")
        return False

    print_pass("Parameterized queries prevent SQL injection")
    print_pass("Table survived malicious input")
    return True
|
|
|
|
|
|
# Test 3: Auto-Learning Pattern Extraction
|
|
def test_auto_learning():
    """Extract style facts from an Arduino snippet and compare with expectations.

    Four probes are run against a fixed sample: serial baud rate, pin
    declaration style, timeout value and PWM frequency.

    Returns:
        bool: True when every extracted value equals its expected string.
    """
    import re

    print_test("Auto-Learning Pattern Extraction")

    sample_code = """
    #include <Arduino.h>

    #define MOTOR_PIN 5
    const int TIMEOUT_MS = 5000;

    void setup() {
        Serial.begin(115200);
        ledcSetup(0, 500, 8);
    }
    """

    # Each probe is either a regex match object (or None) or a plain string.
    probes = {
        'serial_baud': re.search(r'Serial\.begin\((\d+)\)', sample_code),
        'pin_style': 'define' if '#define' in sample_code else 'const',
        'timeout_value': re.search(r'TIMEOUT.*?(\d+)', sample_code),
        'pwm_freq': re.search(r'ledcSetup\([^,]+,\s*(\d+)', sample_code),
    }

    # Match objects contribute their first group; plain values are stringified.
    extracted = {
        key: (found.group(1) if hasattr(found, 'group') else str(found))
        for key, found in probes.items()
        if found
    }

    expected = {
        'serial_baud': '115200',
        'pin_style': 'define',
        'timeout_value': '5000',
        'pwm_freq': '500'
    }

    mismatches = 0
    for key, expected_val in expected.items():
        actual_val = extracted.get(key)
        if actual_val != expected_val:
            print_fail(f"Failed to extract {key} (got {actual_val}, expected {expected_val})")
            mismatches += 1
            continue
        print_pass(f"Extracted {key}: {actual_val}")

    return mismatches == 0
|
|
|
|
|
|
# Test 4: Module Detection
|
|
def test_module_detection():
    """Verify keyword-based module detection on three sample requests.

    The keyword table and the detector are defined locally in this test.

    Returns:
        bool: True when every message yields exactly the expected module set.
    """
    print_test("Module Detection")

    MODULE_PATTERNS = {
        "ble": ["bluetooth", "ble", "wireless"],
        "servo": ["servo", "flipper", "weapon"],
        "motor": ["motor", "drive", "movement", "l298n"],
        "safety": ["safety", "timeout", "failsafe"],
    }

    test_cases = [
        ("Generate code with BLE and servo control", ["ble", "servo"]),
        ("Add motor driver with safety timeout", ["motor", "safety"]),
        ("Build complete robot with bluetooth, motors, and weapon", ["ble", "motor", "servo"]),
    ]

    def extract_modules(message):
        """Return modules whose keywords occur as substrings of *message* (case-insensitive)."""
        text = message.lower()
        return [
            module
            for module, keywords in MODULE_PATTERNS.items()
            if any(kw in text for kw in keywords)
        ]

    failures = 0
    for message, expected_modules in test_cases:
        detected = extract_modules(message)
        # Order is irrelevant, so compare as sets.
        if set(detected) != set(expected_modules):
            print_fail(f"Expected {expected_modules}, got {detected}")
            failures += 1
            continue
        print_pass(f"Detected: {detected} from '{message[:50]}...'")

    return failures == 0
|
|
|
|
|
|
# Test 5: Complexity Detection
|
|
def test_complexity_detection():
    """Verify the complex/simple request classifier on four sample messages.

    A message counts as complex when it contains at least two trigger
    phrases or mentions at least three modules. The tables and the
    classifier are defined locally in this test.

    Returns:
        bool: True when every message is classified as expected.
    """
    print_test("Complexity Detection")

    COMPLEX_TRIGGERS = ["complete", "entire", "full", "build entire"]
    MODULE_PATTERNS = {
        "ble": ["bluetooth", "ble"],
        "servo": ["servo"],
        "motor": ["motor"],
    }

    def is_complex(message):
        """Return True for >= 2 trigger hits or >= 3 module hits in *message*."""
        text = message.lower()
        triggers_hit = [trigger for trigger in COMPLEX_TRIGGERS if trigger in text]
        modules_hit = [
            module
            for module, keywords in MODULE_PATTERNS.items()
            if any(kw in text for kw in keywords)
        ]
        return len(triggers_hit) >= 2 or len(modules_hit) >= 3

    test_cases = [
        ("Generate a motor driver class", False),
        ("Build complete robot with BLE, servo, and motors", True),
        ("Create entire system with full integration", True),
        ("What pins should I use?", False),
    ]

    failures = 0
    for message, expected_complex in test_cases:
        detected = is_complex(message)
        if detected != expected_complex:
            print_fail(f"Expected {expected_complex}, got {detected} for '{message}'")
            failures += 1
            continue
        complexity = "COMPLEX" if detected else "SIMPLE"
        print_pass(f"{complexity}: '{message}'")

    return failures == 0
|
|
|
|
|
|
# Test 6: LRU Cache Performance
|
|
def test_lru_cache():
    """Confirm that functools.lru_cache skips re-execution for a repeated key.

    A counting function is called with the same tuple twice and then with a
    different tuple; the body must have run once, once and twice respectively.

    Returns:
        bool: True when the observed call counts are 1, 1 and 2.
    """
    from functools import lru_cache
    import time

    print_test("LRU Cache Performance")

    call_count = 0

    @lru_cache(maxsize=128)
    def cached_function(keywords):
        nonlocal call_count
        call_count += 1
        time.sleep(0.01)  # Simulate slow operation
        return f"Result for {keywords}"

    # Same key twice (second one should be served from cache), then a new key.
    queries = [("motor", "servo"), ("motor", "servo"), ("ble", "battery")]
    observed = []
    for query in queries:
        cached_function(query)
        observed.append(call_count)

    first_count, second_count, third_count = observed

    if (first_count, second_count, third_count) != (1, 1, 2):
        print_fail(f"Cache not working properly: {first_count}, {second_count}, {third_count}")
        return False

    print_pass("Cache working: 2nd call skipped execution")
    print_pass(f"Function called {call_count} times for 3 queries")
    return True
|
|
|
|
|
|
# Test 7: Session Export
|
|
def test_session_export():
    """Write a Markdown session export to a temp file and validate its content.

    The export is assembled locally from two hard-coded messages; the file is
    then read back and checked for the session ID, a ``cpp`` code fence and
    the user header.

    Returns:
        bool: True when the file exists and all three checks pass.
    """
    print_test("Session Export")

    with tempfile.TemporaryDirectory() as tmpdir:
        export_path = Path(tmpdir) / "test_export.md"

        # Simulate export
        session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        messages = [
            ("user", "Generate motor code", "2025-12-28 10:00:00"),
            ("assistant", "```cpp\nvoid setupMotors() {}\n```", "2025-12-28 10:00:05"),
        ]

        sections = [
            "# BuddAI Session Export\n",
            f"**Session ID:** {session_id}\n\n",
            "---\n\n",
        ]
        for role, body, timestamp in messages:
            if role == 'user':
                sections.append(f"## 🧑 James ({timestamp})\n{body}\n\n")
            else:
                sections.append(f"## 🤖 BuddAI\n{body}\n\n")

        export_path.write_text("".join(sections), encoding='utf-8')

        # Verify export
        if not export_path.exists():
            print_fail("Export file not created")
            return False

        content = export_path.read_text(encoding='utf-8')
        has_session_id = session_id in content
        has_code = "```cpp" in content
        has_headers = "## " in content and "James" in content  # More flexible check

        if has_session_id and has_code and has_headers:
            print_pass("Export file created with correct format")
            print_pass(f"File size: {len(content)} bytes")
            return True

        # Report each failed check, in a fixed order.
        for passed, complaint in (
            (has_session_id, "Missing session ID"),
            (has_code, "Missing code blocks"),
            (has_headers, "Missing headers"),
        ):
            if not passed:
                print_fail(complaint)
        return False
|
|
|
|
|
|
# Test 8: Actionable Suggestions
|
|
def test_actionable_suggestions():
    """Check that a motor/L298N request yields exactly two follow-up suggestions.

    The suggestion rules are re-implemented locally: one fires for motor or
    servo requests whose code lacks ``applyForge``, the other for L298N
    requests whose code does not mention "safety".

    Returns:
        bool: True when exactly two suggestions are produced.
    """
    print_test("Actionable Suggestions")

    user_input = "Generate motor driver with L298N"
    generated_code = """
    void setupMotors() {
        pinMode(MOTOR_PIN, OUTPUT);
    }
    """

    request = user_input.lower()
    suggestions = []

    # Forge Theory Check
    mentions_actuator = "motor" in request or "servo" in request
    if mentions_actuator and "applyForge" not in generated_code:
        forge_suggestion = {
            'text': "Apply Forge Theory smoothing?",
            'action': 'add_forge',
            'code': "float applyForge(float current, float target, float k) { return target + (current - target) * exp(-k); }"
        }
        suggestions.append(forge_suggestion)

    # Safety Check
    if "L298N" in user_input and "safety" not in generated_code.lower():
        safety_suggestion = {
            'text': "Add 5s safety timeout?",
            'action': 'add_safety',
            'code': "unsigned long lastCommandTime = 0;\nconst unsigned long TIMEOUT_MS = 5000;"
        }
        suggestions.append(safety_suggestion)

    if len(suggestions) != 2:
        print_fail(f"Expected 2 suggestions, got {len(suggestions)}")
        return False

    print_pass(f"Generated {len(suggestions)} actionable suggestions")
    for i, s in enumerate(suggestions, 1):
        print_pass(f" {i}. {s['text']} (action: {s['action']})")
        if s['code']:
            print_pass(f" Code snippet: {len(s['code'])} chars")
    return True
|
|
|
|
|
|
# Test 9: Repository Indexing
|
|
def test_repository_indexing():
    """Index function names from a small fake repository using regexes.

    Three files (.ino, .cpp, .py) are written to a temporary directory and
    scanned with a per-language pattern; the collected names must equal the
    expected set.

    Returns:
        bool: True when the indexed names match the expected names exactly.
    """
    import re

    print_test("Repository Indexing")

    # C-like sources: names following a ``void`` or ``int`` return type.
    c_like_pattern = r'\b(?:void|int)\s+(\w+)\s*\('
    patterns_by_suffix = {
        '.ino': c_like_pattern,
        '.cpp': c_like_pattern,
        '.py': r'\bdef\s+(\w+)\s*\(',
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create test repository structure
        repo_dir = Path(tmpdir) / "test_repo"
        repo_dir.mkdir()

        # Create test files
        test_files = {
            "motor_driver.ino": """
void setupMotors() {
    Serial.begin(115200);
    pinMode(MOTOR_PIN, OUTPUT);
}

void driveForward(int speed) {
    digitalWrite(MOTOR_PIN, HIGH);
}
""",
            "servo_control.cpp": """
#include <Servo.h>

void activateFlipper() {
    servo.write(90);
}
""",
            "utils.py": """
def calculate_pwm(speed):
    return int(speed * 255 / 100)

def apply_forge(current, target, k):
    return target + (current - target) * math.exp(-k)
"""
        }

        for filename, source_text in test_files.items():
            (repo_dir / filename).write_text(source_text)

        # Simulate indexing
        indexed_functions = []
        for file_path in repo_dir.rglob('*'):
            pattern = patterns_by_suffix.get(file_path.suffix)
            if pattern is None or not file_path.is_file():
                continue
            indexed_functions.extend(re.findall(pattern, file_path.read_text()))

    expected_functions = ['setupMotors', 'driveForward', 'activateFlipper', 'calculate_pwm', 'apply_forge']

    found, wanted = set(indexed_functions), set(expected_functions)
    if found == wanted:
        print_pass(f"Indexed {len(indexed_functions)} functions correctly")
        for func in indexed_functions:
            print_pass(f" - {func}()")
        return True

    missing = wanted - found
    extra = found - wanted
    if missing:
        print_fail(f"Missing functions: {missing}")
    if extra:
        print_warn(f"Extra functions: {extra}")
    return False
|
|
|
|
|
|
# Test 10: Search Query Safety
|
|
def test_search_query_safety():
    """Check that search SQL built from hostile input contains only placeholders.

    Keywords (word runs of four or more characters) are pulled from each
    query; one ``LIKE ?`` condition is generated per keyword and the values
    are kept in a separate parameter list. No SQL is executed.

    Returns:
        bool: True unless generated SQL contains ``DROP`` or ``'; ``.
    """
    import re

    print_test("Search Query Safety")

    malicious_queries = [
        "'; DROP TABLE repo_index; --",
        "' OR '1'='1",
        "admin'--",
        "<script>alert('xss')</script>",
    ]

    placeholder_clause = "(function_name LIKE ? OR content LIKE ?)"
    success = True

    for query in malicious_queries:
        # Extract keywords safely
        keywords = re.findall(r'\b\w{4,}\b', query.lower())

        # Build parameterized query: user text only ever lands in params
        conditions = [placeholder_clause] * len(keywords)
        params = [f"%{keyword}%" for keyword in keywords for _ in range(2)]

        if not conditions:
            print_pass(f"Rejected empty query: '{query}'")
            continue

        # Verify no SQL injection possible
        safe_sql = f"SELECT * FROM repo_index WHERE {' OR '.join(conditions)}"
        # SQL should only contain placeholders
        if "DROP" in safe_sql or "'; " in safe_sql:
            print_fail(f"Potential injection: '{query}'")
            success = False
        else:
            print_pass(f"Safely handled: '{query[:30]}...'")

    return success
|
|
|
|
|
|
# Test 11: Context Window Management
|
|
def test_context_window():
    """Check that slicing keeps only the five most recent context messages.

    Forty messages (20 user/assistant pairs) are generated and truncated
    with a ``[-5:]`` slice.

    Returns:
        bool: True when exactly five messages remain.
    """
    print_test("Context Window Management")

    # Add many messages
    context_messages = []
    for i in range(20):
        context_messages.extend((
            {"role": "user", "content": f"Message {i}"},
            {"role": "assistant", "content": f"Response {i}"},
        ))

    # Simulate limiting to last 5 messages
    limited_context = context_messages[-5:]
    kept = len(limited_context)

    if kept != 5:
        print_fail(f"Context not limited properly: {kept} messages")
        return False

    oldest, newest = limited_context[0], limited_context[-1]
    print_pass(f"Context limited to {kept} messages (from {len(context_messages)})")
    print_pass(f"Oldest kept: '{oldest['content']}'")
    print_pass(f"Newest kept: '{newest['content']}'")
    return True
|
|
|
|
|
|
# Test 12: Schedule Awareness (New)
|
|
def test_schedule_awareness():
    """Check BuddAI.get_user_status() at two mocked times on a Monday.

    The module-level ``datetime`` of the BuddAI module is replaced by a mock
    whose ``now()`` returns a fixed time. ``DB_PATH`` is redirected to a
    temporary file so that constructing BuddAI does not touch the real
    database.

    Returns:
        bool: True when 06:00 reports "Early Morning" and 10:00 reports
        "Work Hours"; False on the first mismatch.
    """
    print_test("Schedule Awareness")

    # (mocked time, substring expected in status, label for pass, label for fail)
    checks = [
        (datetime(2025, 12, 29, 6, 0, 0), "Early Morning", "6:00 AM Mon", "Morning"),
        (datetime(2025, 12, 29, 10, 0, 0), "Work Hours", "10:00 AM Mon", "Work"),
    ]

    # Use a named temporary file, matching the other BuddAI-backed tests
    fd, test_db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    test_db = Path(test_db_path)

    try:
        with patch('buddai_v3_2.DB_PATH', test_db):
            # Mock datetime to test different times
            with patch('buddai_v3_2.datetime') as mock_date:
                # The instance is created while the first time is in effect.
                mock_date.now.return_value = checks[0][0]
                buddai = BuddAI(server_mode=False)

                for moment, expected, pass_label, fail_label in checks:
                    mock_date.now.return_value = moment
                    status = buddai.get_user_status()
                    if expected not in status:
                        print_fail(f"Failed {fail_label} check: {status}")
                        return False
                    print_pass(f"{pass_label} -> {status}")
    finally:
        # Best-effort cleanup; a lingering lock on the file must not fail the test
        try:
            if test_db.exists():
                os.unlink(test_db)
        except OSError:
            pass

    return True
|
|
|
|
|
|
# Test 13: Modular Plan Generation (New)
|
|
def test_modular_plan():
    """Check that BuddAI.build_modular_plan() adds an integration step.

    For the modules ``["ble", "servo"]`` the plan is expected to contain
    three entries, including ones whose ``'module'`` is ``"ble"`` and
    ``"integration"``. ``DB_PATH`` is redirected to a temporary file so that
    constructing BuddAI does not touch the real database.

    Returns:
        bool: True when the plan has the expected shape, False otherwise.
    """
    print_test("Modular Plan Generation")

    # Use a named temporary file, matching the other BuddAI-backed tests
    fd, test_db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    test_db = Path(test_db_path)

    try:
        with patch('buddai_v3_2.DB_PATH', test_db):
            buddai = BuddAI(server_mode=False)
            modules = ["ble", "servo"]
            plan = buddai.build_modular_plan(modules)

        # Expect 3 steps: ble, servo, integration
        if len(plan) == 3:
            tasks = [p['module'] for p in plan]
            if "integration" in tasks and "ble" in tasks:
                print_pass(f"Generated {len(plan)} steps including Integration")
                return True

        print_fail(f"Plan generation failed. Got {len(plan)} steps: {plan}")
        return False
    finally:
        # Best-effort cleanup; a lingering lock on the file must not fail the test
        try:
            if test_db.exists():
                os.unlink(test_db)
        except OSError:
            pass
|
|
|
|
|
|
# Test 14: Session Management (New)
|
|
def test_session_management():
    """Exercise create, rename and delete of a BuddAI session on a temp DB.

    Returns:
        bool: True when the renamed title is read back and the session is
        absent after deletion; False on the first failed step.
    """
    print_test("Session Management (CRUD)")

    # Use a named temporary file to handle Windows file locking better
    handle, raw_path = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    test_db = Path(raw_path)

    try:
        with patch('buddai_v3_2.DB_PATH', test_db):
            buddai = BuddAI(server_mode=False)

            # 1. Create
            sid = buddai.start_new_session()
            print_pass(f"Created session: {sid}")

            # 2. Rename
            buddai.rename_session(sid, "Unit Test Session")
            latest = buddai.get_sessions(limit=1)
            if not (latest and latest[0]['title'] == "Unit Test Session"):
                print_fail("Rename failed")
                return False
            print_pass("Renamed session successfully")

            # 3. Delete
            buddai.delete_session(sid)
            remaining_ids = [s['id'] for s in buddai.get_sessions(limit=5)]
            if sid in remaining_ids:
                print_fail("Delete failed - session still exists")
                return False
            print_pass("Deleted session successfully")
    finally:
        # Manual cleanup with error suppression for Windows locks
        try:
            if test_db.exists():
                os.unlink(test_db)
        except Exception:
            pass

    return True
|
|
|
|
|
|
# Test 15: Rapid Session Creation (Collision Handling)
|
|
def test_rapid_session_creation():
    """Check session ID de-duplication when the clock does not advance.

    With ``datetime.now()`` frozen, the session created by the constructor
    plus four more must yield the bare timestamp ID followed by ``_1`` to
    ``_4`` suffixed IDs.

    Returns:
        bool: True when the generated IDs equal the expected sequence.
    """
    print_test("Rapid Session Creation (Collision Handling)")

    # Use a named temporary file to handle Windows file locking better
    handle, raw_path = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    test_db = Path(raw_path)

    try:
        # Mock datetime to return a fixed time, forcing ID collisions
        fixed_time = datetime(2025, 1, 1, 12, 0, 0)

        with patch('buddai_v3_2.DB_PATH', test_db), patch('buddai_v3_2.datetime') as mock_dt:
            mock_dt.now.return_value = fixed_time

            buddai = BuddAI(server_mode=False)

            # Session from __init__, then 4 more created rapidly
            ids = [buddai.session_id] + [buddai.start_new_session() for _ in range(4)]

            # Verify format
            base_id = fixed_time.strftime("%Y%m%d_%H%M%S")
            expected = [base_id, *(f"{base_id}_{i}" for i in range(1, 5))]

            if ids != expected:
                print_fail(f"Unexpected ID format. Expected {expected}, got {ids}")
                return False

            print_pass(f"Generated unique IDs with suffixes: {ids}")
            return True
    finally:
        try:
            if test_db.exists():
                os.unlink(test_db)
        except Exception:
            pass
|
|
|
|
# Test 16: Repository Isolation (Multi-User)
|
|
def test_repo_isolation():
    """Check that code indexed by one user is not searchable by another.

    "user1" indexes a temporary repository containing a uniquely named
    function; both "user1" and "user2" then search for that name against the
    same temporary database.

    Returns:
        bool: True when user1 finds the function and user2 gets
        "No functions found"; False otherwise.
    """
    print_test("Repository Isolation (Multi-User)")

    # Use a named temporary file for DB
    handle, raw_path = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    test_db = Path(raw_path)

    needle = "user1_secret_function"

    # Create a temp directory for repo
    with tempfile.TemporaryDirectory() as tmp_repo:
        repo_path = Path(tmp_repo)

        # Create a unique file for User 1
        (repo_path / "user1_secret.py").write_text("def user1_secret_function():\n pass")

        try:
            # Suppress internal prints to keep test output clean
            with patch('buddai_v3_2.DB_PATH', test_db), patch('builtins.print'):
                # User 1 indexes the repo
                buddai1 = BuddAI(user_id="user1", server_mode=False)
                buddai1.index_local_repositories(str(repo_path))

                # User 2 instance
                buddai2 = BuddAI(user_id="user2", server_mode=False)

                # Both users search for the same function
                res1 = buddai1.search_repositories(needle)
                res2 = buddai2.search_repositories(needle)

            # Verify User 1 found it
            if not any(marker in res1 for marker in ("Found 1 matches", needle)):
                print_fail(f"User 1 failed to find code: {res1}")
                return False
            print_pass("User 1 found their indexed code")

            # Verify User 2 did NOT find it
            if "No functions found" not in res2:
                print_fail(f"User 2 saw restricted code: {res2}")
                return False
            print_pass("User 2 could not see User 1's code")
        finally:
            try:
                if test_db.exists():
                    os.unlink(test_db)
            except Exception:
                pass

    return True
|
|
|
|
# Test 17: Upload Security (Hardening)
|
|
def test_upload_security():
    """Exercise the upload hardening helpers of the BuddAI module, if present.

    Two independent checks are run, each only when the helper exists on the
    loaded module; a missing helper is reported as a warning and skipped:

    1. ``validate_upload`` must reject a ``.zip`` upload whose bytes are not
       a ZIP archive with a ``ValueError`` mentioning "Invalid ZIP file header".
    2. ``safe_extract_zip`` must reject an archive member named
       ``../evil.txt`` with a ``ValueError`` mentioning "Malicious zip member".

    Returns:
        bool: False if a present helper accepts the bad input or raises a
        ``ValueError`` with an unexpected message; True otherwise.
    """
    print_test("Upload Security (Hardening)")

    # 1. Test Magic Byte Check
    # We need to mock UploadFile since it's a FastAPI class
    class MockUploadFile:
        """Minimal stand-in exposing filename, file and content_type."""

        def __init__(self, filename, content):
            self.filename = filename
            self.file = io.BytesIO(content)
            self.content_type = "application/zip"

    if hasattr(buddai_module, 'validate_upload'):
        # Create a fake zip (text file renamed)
        fake_zip = MockUploadFile("fake.zip", b"This is not a zip file")
        try:
            buddai_module.validate_upload(fake_zip)
        except ValueError as e:
            if "Invalid ZIP file header" not in str(e):
                print_fail(f"Unexpected error: {e}")
                return False
            print_pass("Magic byte check rejected invalid zip header")
        else:
            print_fail("Magic byte check failed (accepted invalid zip)")
            return False
    else:
        print_warn("Skipping magic byte check (validate_upload not available)")

    # 2. Test Zip Slip Protection
    if hasattr(buddai_module, 'safe_extract_zip'):
        with tempfile.TemporaryDirectory() as tmpdir:
            malicious_zip_path = Path(tmpdir) / "slip.zip"
            extract_dir = Path(tmpdir) / "extract"
            extract_dir.mkdir()

            # Create a zip file with a member pointing to parent directory
            # We use zipfile to craft this manually
            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, 'w') as zf:
                zf.writestr('../evil.txt', 'malicious content')
            malicious_zip_path.write_bytes(zip_buffer.getvalue())

            try:
                buddai_module.safe_extract_zip(malicious_zip_path, extract_dir)
            except ValueError as e:
                if "Malicious zip member" not in str(e):
                    print_fail(f"Unexpected error during extraction: {e}")
                    return False
                print_pass("Zip Slip protection caught directory traversal")
            else:
                print_fail("Zip Slip protection failed (extracted malicious file)")
                return False
    else:
        # Mirror the first check: make a skipped check visible in the output
        # rather than silently counting it as a pass.
        print_warn("Skipping zip slip check (safe_extract_zip not available)")

    return True
|
|
|
|
# Test 18: WebSocket Logic (Streaming)
|
|
def test_websocket_logic():
    """Check BuddAI.chat_stream() against a stubbed streaming model.

    ``call_model`` is replaced by a generator yielding three chunks and the
    shadow engine's suggestions are stubbed to an empty list. The test then
    verifies the concatenated stream, that ``call_model`` received
    ``stream=True``, and that the last context message is the assistant reply.

    Returns:
        bool: True when all three verifications pass, False on the first
        failure.
    """
    print_test("WebSocket Logic (Streaming)")

    # Use a named temporary file for DB
    handle, raw_path = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    test_db = Path(raw_path)

    # Mock call_model to return a generator
    def mock_generator(*args, **kwargs):
        for piece in ("Stream", "ing", "..."):
            yield piece

    try:
        with patch('buddai_v3_2.DB_PATH', test_db):
            # Suppress prints during init
            with patch('builtins.print'):
                buddai = BuddAI(server_mode=False)

            # Mock shadow engine to avoid DB lookups or side effects affecting output
            with patch.object(buddai, 'call_model', side_effect=mock_generator) as mock_call, \
                    patch.object(buddai.shadow_engine, 'get_all_suggestions', return_value=[]):

                # Execute
                full_text = "".join(buddai.chat_stream("Test Message", force_model="fast"))

                # Verify 1: Content
                if full_text != "Streaming...":
                    print_fail(f"Stream content mismatch. Got: '{full_text}'")
                    return False
                print_pass("Streamed content matches expected output")

                # Verify 2: Stream flag passed to model
                _, kwargs = mock_call.call_args
                if kwargs.get('stream') is not True:
                    print_fail(f"call_model arguments incorrect: {kwargs}")
                    return False
                print_pass("call_model invoked with stream=True")

                # Verify 3: Context saved
                last_msg = buddai.context_messages[-1]
                saved = (last_msg['role'], last_msg['content']) == ('assistant', "Streaming...")
                if not saved:
                    print_fail("Context update failed")
                    return False
                print_pass("Conversation context updated correctly")
    finally:
        try:
            if test_db.exists():
                os.unlink(test_db)
        except Exception:
            pass

    return True
|
|
|
|
# Test 19: Connection Pooling
|
|
def test_connection_pool():
    """Exercise the module-level OLLAMA_POOL: get, return, reuse and overflow.

    The pool is drained first, then the test checks that a connection can be
    obtained, that returning it makes the queue size 1, that the next
    ``get_connection()`` hands back the same object, and that returning 15
    extra connections leaves the queue reported as full.

    Returns:
        bool: True when all four steps behave as expected, False on the
        first failure (including a missing ``OLLAMA_POOL``).
    """
    print_test("Connection Pooling")

    if not hasattr(buddai_module, 'OLLAMA_POOL'):
        print_fail("OLLAMA_POOL not found in module")
        return False

    pool = buddai_module.OLLAMA_POOL

    # Drain pool first to ensure clean state for test
    while not pool.pool.empty():
        try:
            c = pool.pool.get_nowait()
            c.close()
        except Exception:
            # Stop draining on any ordinary error, but let KeyboardInterrupt
            # and SystemExit propagate (a bare ``except`` would swallow them).
            break

    # 1. Get a connection (should create new)
    conn1 = pool.get_connection()
    if not isinstance(conn1, http.client.HTTPConnection):
        print_fail("get_connection did not return HTTPConnection")
        return False
    print_pass("Successfully retrieved connection from pool")

    # 2. Return connection
    pool.return_connection(conn1)
    if pool.pool.qsize() != 1:
        print_fail(f"Pool size incorrect after return. Expected 1, got {pool.pool.qsize()}")
        return False
    print_pass("Connection returned to pool (size=1)")

    # 3. Reuse connection
    conn2 = pool.get_connection()
    if conn2 is not conn1:
        print_fail("Pool created new connection instead of reusing")
        return False
    print_pass("Pool reused the existing connection object")

    # 4. Overflow handling
    # Fill beyond max size (default 10)
    # conn2 is currently checked out, so pool is empty
    for _ in range(15):
        c = http.client.HTTPConnection("localhost", 11434)
        pool.return_connection(c)

    if not pool.pool.full():
        print_fail(f"Pool overflow handling failed. Size: {pool.pool.qsize()}")
        return False

    print_pass("Pool capped at max size, excess connections discarded")
    return True
|
|
|
|
# Test 20: Feedback System
|
|
def test_feedback_system():
    """Record positive feedback for a message and verify the stored row.

    A message is saved through BuddAI against a temporary database, positive
    feedback is recorded for it, and the ``feedback`` table is then queried
    directly for that message ID.

    Returns:
        bool: True when a row exists whose ``positive`` column equals 1,
        False otherwise.
    """
    print_test("Feedback System")

    # Use a named temporary file for DB
    fd, test_db_path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    test_db = Path(test_db_path)

    try:
        with patch('buddai_v3_2.DB_PATH', test_db):
            # Suppress prints
            with patch('builtins.print'):
                buddai = BuddAI(server_mode=False)

                # 1. Create a message to rate
                msg_id = buddai.save_message("assistant", "Test response")

                # 2. Record positive feedback
                buddai.record_feedback(msg_id, True)

            # 3. Verify in DB
            conn = sqlite3.connect(test_db)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT positive FROM feedback WHERE message_id = ?", (msg_id,))
                row = cursor.fetchone()
            finally:
                # Close even if the query raises, so the handle does not keep
                # the temp file open when the cleanup below tries to delete it.
                conn.close()

        if row and row[0] == 1:  # Boolean true is 1 in sqlite
            print_pass("Positive feedback recorded successfully")
            return True

        print_fail(f"Feedback not recorded correctly. Got: {row}")
        return False
    finally:
        # Best-effort cleanup; a lingering lock on the file must not fail the test
        try:
            if test_db.exists():
                os.unlink(test_db)
        except OSError:
            pass
|
|
|
|
# Main Test Runner
|
|
def run_all_tests():
    """Run every test in order, print a colored summary, and report success.

    A test that raises is reported as crashed and counted as failed; the
    remaining tests still run.

    Returns:
        bool: True only when every test returned a truthy result.
    """
    divider = "=" * 60

    print("\n" + divider)
    print("🔥 BuddAI v3.2 Comprehensive Test Suite")
    print(divider)

    tests = [
        ("Database Initialization", test_database_init),
        ("SQL Injection Prevention", test_sql_injection_prevention),
        ("Auto-Learning", test_auto_learning),
        ("Module Detection", test_module_detection),
        ("Complexity Detection", test_complexity_detection),
        ("LRU Cache", test_lru_cache),
        ("Session Export", test_session_export),
        ("Actionable Suggestions", test_actionable_suggestions),
        ("Repository Indexing", test_repository_indexing),
        ("Search Query Safety", test_search_query_safety),
        ("Context Window", test_context_window),
        ("Schedule Awareness", test_schedule_awareness),
        ("Modular Plan", test_modular_plan),
        ("Session Management", test_session_management),
        ("Rapid Session Creation", test_rapid_session_creation),
        ("Repository Isolation", test_repo_isolation),
        ("Upload Security", test_upload_security),
        ("WebSocket Logic", test_websocket_logic),
        ("Connection Pooling", test_connection_pool),
        ("Feedback System", test_feedback_system),
    ]

    results = []
    for name, test_func in tests:
        try:
            outcome = test_func()
        except Exception as e:
            # A crash in one test must not abort the whole run.
            print_fail(f"Test crashed: {e}")
            outcome = False
        results.append((name, outcome))

    # Summary
    print("\n" + divider)
    print("📊 Test Results Summary")
    print(divider)

    total = len(results)
    passed = len([name for name, outcome in results if outcome])

    for name, outcome in results:
        if outcome:
            status = f"{TestColors.PASS}✅ PASS{TestColors.END}"
        else:
            status = f"{TestColors.FAIL}❌ FAIL{TestColors.END}"
        print(f"{status} - {name}")

    print("\n" + divider)
    percentage = int((passed / total) * 100)

    # Pick the headline by pass rate: all, at least 80%, or fewer.
    if passed == total:
        color, headline = TestColors.PASS, "🎉 ALL TESTS PASSED"
    elif passed >= total * 0.8:
        color, headline = TestColors.WARN, "⚠️ MOST TESTS PASSED"
    else:
        color, headline = TestColors.FAIL, "❌ TESTS FAILED"
    print(f"{color}{headline}: {passed}/{total} ({percentage}%){TestColors.END}")

    print(divider + "\n")

    return passed == total
|
|
|
|
|
|
if __name__ == "__main__":
    # Exit status is 0 only when every test passed, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)