mirror of
https://github.com/JamesTheGiblet/BuddAI.git
synced 2026-01-08 21:58:40 +00:00
Implement core skills: Code validation, model fine-tuning, and system diagnostics
- Added `ModelFineTuner` class for preparing training data and fine-tuning models based on user corrections. - Introduced `CodeValidator` class to validate generated code against various hardware and style rules, including safety checks and function naming conventions. - Developed skills for calculator operations, system information retrieval, weather fetching, and timer functionality. - Implemented a self-diagnostic skill to run unit tests and report results. - Created a dynamic skill loading mechanism to discover and register skills from the current directory. - Added unit tests for skills to ensure functionality and reliability.
This commit is contained in:
parent
743f9f311d
commit
f9fd27d228
28 changed files with 2398 additions and 4077 deletions
1211
buddai_executive.py
1211
buddai_executive.py
File diff suppressed because it is too large
Load diff
|
|
@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Tuple, Union, Generator
|
|||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from buddai_shared import SERVER_AVAILABLE, DATA_DIR, DB_PATH, MODELS, OLLAMA_HOST, OLLAMA_PORT
|
||||
from core.buddai_shared import SERVER_AVAILABLE, DATA_DIR, DB_PATH, MODELS, OLLAMA_HOST, OLLAMA_PORT
|
||||
from buddai_executive import BuddAI
|
||||
|
||||
# (Removed duplicate definitions of check_ollama, is_port_available, and main to resolve indentation and duplication errors)
|
||||
|
|
@ -166,7 +166,7 @@ app.mount("/web", StaticFiles(directory=frontend_path, html=True), name="web")
|
|||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root(request: Request):
|
||||
server_buddai = buddai_manager.get_instance("default")
|
||||
status = server_buddai.get_user_status()
|
||||
status = server_buddai.personality_manager.get_user_status()
|
||||
|
||||
public_url = getattr(request.app.state, "public_url", "")
|
||||
qr_section = ""
|
||||
|
|
@ -562,7 +562,7 @@ async def upload_repo(file: UploadFile = File(...), user_id: str = Header("defau
|
|||
extract_path = uploads_dir / file_location.stem
|
||||
extract_path.mkdir(exist_ok=True)
|
||||
safe_extract_zip(file_location, extract_path)
|
||||
server_buddai.index_local_repositories(extract_path)
|
||||
server_buddai.repo_manager.index_local_repositories(extract_path)
|
||||
file_location.unlink() # Cleanup zip
|
||||
return {"message": f"✅ Successfully indexed {safe_name}"}
|
||||
else:
|
||||
|
|
@ -572,7 +572,7 @@ async def upload_repo(file: UploadFile = File(...), user_id: str = Header("defau
|
|||
target_dir.mkdir(exist_ok=True)
|
||||
final_path = target_dir / safe_name
|
||||
shutil.move(str(file_location), str(final_path))
|
||||
server_buddai.index_local_repositories(target_dir)
|
||||
server_buddai.repo_manager.index_local_repositories(target_dir)
|
||||
return {"message": f"✅ Successfully indexed {safe_name}"}
|
||||
|
||||
return {"message": f"✅ Successfully uploaded {safe_name}"}
|
||||
|
|
|
|||
0
core/__init__.py
Normal file
0
core/__init__.py
Normal file
72
core/buddai_analytics.py
Normal file
72
core/buddai_analytics.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from core.buddai_shared import DB_PATH
|
||||
|
||||
class LearningMetrics:
|
||||
"""Measure BuddAI's improvement over time"""
|
||||
|
||||
def calculate_accuracy(self):
|
||||
"""What % of code is accepted without correction?"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total_responses,
|
||||
COUNT(CASE WHEN f.positive = 1 THEN 1 END) as positive_feedback,
|
||||
COUNT(CASE WHEN c.id IS NOT NULL THEN 1 END) as corrected
|
||||
FROM messages m
|
||||
LEFT JOIN feedback f ON m.id = f.message_id
|
||||
LEFT JOIN corrections c ON m.content LIKE '%' || c.original_code || '%'
|
||||
WHERE m.role = 'assistant'
|
||||
AND m.timestamp > ?
|
||||
""", (thirty_days_ago,))
|
||||
|
||||
total, positive, corrected = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
accuracy = (positive / total) * 100 if total and total > 0 else 0
|
||||
correction_rate = (corrected / total) * 100 if total and total > 0 else 0
|
||||
|
||||
return {
|
||||
"accuracy": accuracy,
|
||||
"correction_rate": correction_rate,
|
||||
"improvement": self.calculate_trend()
|
||||
}
|
||||
|
||||
def calculate_trend(self):
|
||||
"""Is BuddAI getting better over time?"""
|
||||
# Compare last 7 days vs previous 7 days
|
||||
recent = self.get_accuracy_for_period(7)
|
||||
previous = self.get_accuracy_for_period(7, offset=7)
|
||||
|
||||
improvement = recent - previous
|
||||
return f"+{improvement:.1f}%" if improvement > 0 else f"{improvement:.1f}%"
|
||||
|
||||
def get_accuracy_for_period(self, days: int, offset: int = 0) -> float:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
start_dt = (datetime.now() - timedelta(days=days + offset)).isoformat()
|
||||
end_dt = (datetime.now() - timedelta(days=offset)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COUNT(CASE WHEN f.positive = 1 THEN 1 END) as positive
|
||||
FROM messages m
|
||||
LEFT JOIN feedback f ON m.id = f.message_id
|
||||
WHERE m.role = 'assistant'
|
||||
AND m.timestamp BETWEEN ? AND ?
|
||||
""", (start_dt, end_dt))
|
||||
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
if not row:
|
||||
return 0.0
|
||||
|
||||
total, positive = row
|
||||
return (positive / total) * 100 if total and total > 0 else 0.0
|
||||
181
core/buddai_knowledge.py
Normal file
181
core/buddai_knowledge.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import sqlite3
|
||||
import re
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
class RepoManager:
|
||||
"""Manages local repository indexing and retrieval (RAG)"""
|
||||
|
||||
def __init__(self, db_path: Path, user_id: str):
|
||||
self.db_path = db_path
|
||||
self.user_id = user_id
|
||||
|
||||
def is_search_query(self, message: str) -> bool:
|
||||
"""Check if this is a search query that should query repo_index"""
|
||||
message_lower = message.lower()
|
||||
search_triggers = [
|
||||
"show me", "find", "search for", "list all",
|
||||
"what functions", "which repos", "do i have",
|
||||
"where did i", "have i used", "examples of",
|
||||
"show all", "display"
|
||||
]
|
||||
return any(trigger in message_lower for trigger in search_triggers)
|
||||
|
||||
def search_repositories(self, query: str) -> str:
|
||||
"""Search repo_index for relevant functions and code"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM repo_index WHERE user_id = ?", (self.user_id,))
|
||||
count = cursor.fetchone()[0]
|
||||
print(f"\n🔍 Searching {count} indexed functions...\n")
|
||||
|
||||
# Extract keywords from query
|
||||
keywords = re.findall(r'\b\w{4,}\b', query.lower())
|
||||
# Add specific search terms
|
||||
specific_terms = []
|
||||
if "exponential" in query.lower() or "decay" in query.lower():
|
||||
specific_terms.append("applyForge")
|
||||
specific_terms.append("exp(")
|
||||
if "forge" in query.lower():
|
||||
specific_terms.append("Forge")
|
||||
keywords.extend(specific_terms)
|
||||
|
||||
if not keywords:
|
||||
print("❌ No search terms found")
|
||||
conn.close()
|
||||
return "No search terms provided."
|
||||
|
||||
# Build parameterized query
|
||||
conditions = []
|
||||
params = []
|
||||
for keyword in keywords:
|
||||
conditions.append("(function_name LIKE ? OR content LIKE ? OR repo_name LIKE ?)")
|
||||
params.extend([f"%{keyword}%", f"%{keyword}%", f"%{keyword}%"])
|
||||
|
||||
sql = f"SELECT repo_name, file_path, function_name, content FROM repo_index WHERE ({' OR '.join(conditions)}) AND user_id = ? ORDER BY last_modified DESC LIMIT 10"
|
||||
params.append(self.user_id)
|
||||
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
conn.close()
|
||||
if not results:
|
||||
return f"❌ No functions found matching: {', '.join(keywords)}\n\nTry: /index <path> to index more repositories"
|
||||
# Format results
|
||||
output = f"✅ Found {len(results)} matches for: {', '.join(set(keywords))}\n\n"
|
||||
for i, (repo, file_path, func, content) in enumerate(results, 1):
|
||||
# Extract relevant snippet
|
||||
lines = content.split('\n')
|
||||
snippet_lines = []
|
||||
for line in lines[:30]: # First 30 lines
|
||||
if any(kw in line.lower() for kw in keywords):
|
||||
snippet_lines.append(line)
|
||||
if len(snippet_lines) >= 10:
|
||||
break
|
||||
if not snippet_lines:
|
||||
snippet_lines = lines[:10]
|
||||
snippet = '\n'.join(snippet_lines)
|
||||
output += f"**{i}. {func}()** in {repo}\n"
|
||||
output += f" 📁 {Path(file_path).name}\n"
|
||||
output += f"\n```cpp\n{snippet}\n```\n"
|
||||
output += f" ---\n\n"
|
||||
return output
|
||||
|
||||
def index_local_repositories(self, root_path: str) -> None:
|
||||
"""Crawl directories and index .py, .ino, and .cpp files"""
|
||||
print(f"\n🔍 Indexing repositories in: {root_path}")
|
||||
path = Path(root_path)
|
||||
|
||||
if not path.exists():
|
||||
print(f"❌ Path not found: {root_path}")
|
||||
return
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
count = 0
|
||||
|
||||
for file_path in path.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix in ['.py', '.ino', '.cpp', '.h', '.js', '.jsx', '.html', '.css']:
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
functions = []
|
||||
|
||||
# Python parsing
|
||||
if file_path.suffix == '.py':
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
functions.append(node.name)
|
||||
except:
|
||||
pass
|
||||
|
||||
# C++/Arduino parsing
|
||||
elif file_path.suffix in ['.ino', '.cpp', '.h']:
|
||||
matches = re.findall(r'\b(?:void|int|bool|float|double|String|char)\s+(\w+)\s*\(', content)
|
||||
functions.extend(matches)
|
||||
|
||||
# JS/Web parsing
|
||||
elif file_path.suffix in ['.js', '.jsx']:
|
||||
matches = re.findall(r'(?:function\s+(\w+)|const\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>)', content)
|
||||
functions.extend([m[0] or m[1] for m in matches if m[0] or m[1]])
|
||||
|
||||
# HTML/CSS - Index as whole file
|
||||
elif file_path.suffix in ['.html', '.css']:
|
||||
functions.append("file_content")
|
||||
|
||||
# Determine repo name
|
||||
try:
|
||||
repo_name = file_path.relative_to(path).parts[0]
|
||||
except:
|
||||
repo_name = "unknown"
|
||||
|
||||
timestamp = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
|
||||
for func in functions:
|
||||
cursor.execute("""
|
||||
INSERT INTO repo_index (user_id, file_path, repo_name, function_name, content, last_modified)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (self.user_id, str(file_path), repo_name, func, content, timestamp.isoformat()))
|
||||
count += 1
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(f"✅ Indexed {count} functions across repositories")
|
||||
|
||||
def retrieve_style_context(self, message: str, prompt_template: str, user_name: str) -> str:
|
||||
"""Search repo_index for code snippets matching the request"""
|
||||
# Extract potential keywords (nouns/modules)
|
||||
keywords = re.findall(r'\b\w{4,}\b', message.lower())
|
||||
if not keywords:
|
||||
return ""
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build a search query for function names or repo names
|
||||
search_terms = " OR ".join([f"function_name LIKE '%{k}%'" for k in keywords])
|
||||
search_terms += " OR " + " OR ".join([f"repo_name LIKE '%{k}%'" for k in keywords])
|
||||
|
||||
query = f"SELECT repo_name, function_name, content FROM repo_index WHERE ({search_terms}) AND user_id = ? LIMIT 2"
|
||||
|
||||
cursor.execute(query, (self.user_id,))
|
||||
results = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
context_block = prompt_template.format(user_name=user_name)
|
||||
for repo, func, content in results:
|
||||
# Just grab the first 500 chars of the file to save context window
|
||||
snippet = content[:500] + "..."
|
||||
context_block += f"Repo: {repo} | Function: {func}\nCode:\n{snippet}\n---\n"
|
||||
|
||||
return context_block
|
||||
137
core/buddai_llm.py
Normal file
137
core/buddai_llm.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
import http.client
|
||||
import socket
|
||||
from typing import List, Dict, Union, Generator, Optional
|
||||
from core.buddai_shared import MODELS, OLLAMA_POOL, OLLAMA_HOST, OLLAMA_PORT
|
||||
|
||||
class OllamaClient:
|
||||
"""Handles communication with the local Ollama instance"""
|
||||
|
||||
def query(self, model_key: str, messages: List[Dict], stream: bool = False, options: Dict = None) -> Union[str, Generator[str, None, None]]:
|
||||
"""Send a chat request to Ollama"""
|
||||
model_name = MODELS.get(model_key, model_key) # Handle key or direct name
|
||||
|
||||
default_options = {
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": 1,
|
||||
"num_ctx": 1024
|
||||
}
|
||||
if options:
|
||||
default_options.update(options)
|
||||
|
||||
body = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"options": default_options
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
json_body = json.dumps(body)
|
||||
|
||||
# Retry logic for connection stability
|
||||
for attempt in range(3):
|
||||
conn = None
|
||||
try:
|
||||
# Re-serialize in case of modification (CPU fallback)
|
||||
json_body = json.dumps(body)
|
||||
|
||||
conn = OLLAMA_POOL.get_connection()
|
||||
conn.request("POST", "/api/chat", json_body, headers)
|
||||
response = conn.getresponse()
|
||||
|
||||
if stream:
|
||||
if response.status != 200:
|
||||
error_text = response.read().decode('utf-8')
|
||||
conn.close()
|
||||
|
||||
# GPU OOM Detection -> CPU Fallback
|
||||
if "CUDA" in error_text or "buffer" in error_text:
|
||||
if "num_gpu" not in body["options"]:
|
||||
print("⚠️ GPU OOM detected. Switching to CPU mode...")
|
||||
body["options"]["num_gpu"] = 0 # Force CPU
|
||||
continue
|
||||
|
||||
try:
|
||||
err_msg = f"Error {response.status}: {json.loads(error_text).get('error', error_text)}"
|
||||
except:
|
||||
err_msg = f"Error {response.status}: {error_text}"
|
||||
|
||||
if "num_gpu" in body["options"]:
|
||||
err_msg += "\n\n(⚠️ CPU Mode also failed. System RAM might be full.)"
|
||||
elif "CUDA" in err_msg or "buffer" in err_msg:
|
||||
err_msg += "\n\n(⚠️ GPU Out of Memory. Retrying on CPU failed.)"
|
||||
|
||||
return (x for x in [err_msg])
|
||||
|
||||
return self._stream_response(response, conn)
|
||||
|
||||
if response.status == 200:
|
||||
data = json.loads(response.read().decode('utf-8'))
|
||||
OLLAMA_POOL.return_connection(conn)
|
||||
return data.get("message", {}).get("content", "No response")
|
||||
else:
|
||||
error_text = response.read().decode('utf-8')
|
||||
conn.close()
|
||||
|
||||
if "CUDA" in error_text or "buffer" in error_text:
|
||||
if "num_gpu" not in body["options"]:
|
||||
print("⚠️ GPU OOM detected. Switching to CPU mode...")
|
||||
body["options"]["num_gpu"] = 0
|
||||
continue
|
||||
|
||||
return f"Error {response.status}: {error_text}"
|
||||
|
||||
except (http.client.NotConnected, BrokenPipeError, ConnectionResetError, socket.timeout) as e:
|
||||
if conn: conn.close()
|
||||
if attempt == 2:
|
||||
return f"Error: Connection failed. {str(e)}"
|
||||
continue
|
||||
except Exception as e:
|
||||
if conn: conn.close()
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def _stream_response(self, response, conn) -> Generator[str, None, None]:
|
||||
"""Yield chunks from HTTP response"""
|
||||
fully_consumed = False
|
||||
has_content = False
|
||||
try:
|
||||
while True:
|
||||
line = response.readline()
|
||||
if not line: break
|
||||
try:
|
||||
data = json.loads(line.decode('utf-8'))
|
||||
if "message" in data:
|
||||
content = data["message"].get("content", "")
|
||||
if content:
|
||||
has_content = True
|
||||
yield content
|
||||
if data.get("done"):
|
||||
fully_consumed = True
|
||||
break
|
||||
except: pass
|
||||
except Exception as e:
|
||||
yield f"\n[Stream Error: {str(e)}]"
|
||||
finally:
|
||||
if fully_consumed:
|
||||
OLLAMA_POOL.return_connection(conn)
|
||||
else:
|
||||
conn.close()
|
||||
|
||||
if not has_content and not fully_consumed:
|
||||
yield "\n[Error: Empty response from Ollama. Check if model is loaded.]"
|
||||
|
||||
def reset_gpu(self) -> str:
|
||||
"""Force unload models from GPU to free VRAM"""
|
||||
try:
|
||||
conn = http.client.HTTPConnection(OLLAMA_HOST, OLLAMA_PORT, timeout=10)
|
||||
for model in MODELS.values():
|
||||
body = json.dumps({"model": model, "keep_alive": 0})
|
||||
conn.request("POST", "/api/generate", body)
|
||||
resp = conn.getresponse()
|
||||
resp.read()
|
||||
conn.close()
|
||||
return "✅ GPU Memory Cleared (Models Unloaded)"
|
||||
except Exception as e:
|
||||
return f"❌ Error clearing GPU: {str(e)}"
|
||||
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Tuple, Union, Generator
|
||||
|
||||
from buddai_shared import DB_PATH, MODULE_PATTERNS
|
||||
from core.buddai_shared import DB_PATH, MODULE_PATTERNS
|
||||
|
||||
class ShadowSuggestionEngine:
|
||||
"""Proactively suggests modules/settings based on user/project history."""
|
||||
126
core/buddai_personality.py
Normal file
126
core/buddai_personality.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Union, List
|
||||
from datetime import datetime
|
||||
|
||||
class PersonalityManager:
|
||||
"""Manages AI personality, prompts, and user schedules"""
|
||||
|
||||
def __init__(self):
|
||||
self.personality = self.load_personality()
|
||||
self.validate_personality_schema()
|
||||
|
||||
def load_personality(self) -> Dict:
|
||||
"""Loads personality from a JSON file."""
|
||||
personality_path = Path(__file__).parent.parent / "personality.json"
|
||||
if personality_path.exists():
|
||||
print("🧠 Loading custom personality...")
|
||||
with open(personality_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
# Default personality if file doesn't exist
|
||||
print("🧠 Using default 'James' personality.")
|
||||
return {
|
||||
"user_name": "James",
|
||||
"ai_name": "BuddAI",
|
||||
"welcome_message": "BuddAI Executive v4.0 - Decoupled & Personality Sync",
|
||||
"schedule_check_triggers": ["what should i be doing", "my schedule", "schedule check"],
|
||||
"schedule": {
|
||||
"weekdays": {"0-4": {"5.5-6.5": "Early Morning Build Session 🌅 (5:30-6:30 AM)", "6.5-17.0": "Work Hours (Facilities Caretaker) 🏢", "17.0-21.0": "Evening Build Session 🌙 (5:00-9:00 PM)", "default": "Rest Time 💤"}},
|
||||
"saturday": { "5": { "default": "Weekend Freedom 🎨 (Creative Mode)" } },
|
||||
"sunday": { "6": { "0-21.0": "Weekend Freedom 🎨 (Until 9 PM)", "default": "Rest Time 💤" } }
|
||||
},
|
||||
"style_scan_prompt": "Analyze this code sample from {user_name}'s repositories.\nExtract 3 distinct coding preferences or patterns.",
|
||||
"style_reference_prompt": "\n[REFERENCE STYLE FROM {user_name}'S PAST PROJECTS]\n",
|
||||
"integration_task_prompt": "INTEGRATION TASK: Combine modules into a cohesive GilBot system.\n\n[MODULES]\n{modules_summary}\n\n[FORGE PARAMETERS]\nSet k = {k_val} for all applyForge() calls.\n\n[REQUIREMENTS]\n1. Implement applyForge() math helper.\n2. Use k={k_val} to smooth motor and servo transitions.\n3. Ensure naming matches {user_name}'s style: activateFlipper(), setMotors()."
|
||||
}
|
||||
|
||||
def validate_personality_schema(self) -> bool:
|
||||
"""Validate the loaded personality against required schema."""
|
||||
if not self.personality:
|
||||
return False
|
||||
|
||||
required_structure = {
|
||||
"meta": ["version"],
|
||||
"identity": ["user_name", "ai_name"],
|
||||
"communication": ["welcome_message"],
|
||||
"work_cycles": ["schedule"],
|
||||
"forge_theory": ["enabled", "constants"],
|
||||
"prompts": ["style_scan", "integration_task"]
|
||||
}
|
||||
|
||||
missing = []
|
||||
|
||||
version = self.get_value("meta.version")
|
||||
if version and version != "4.0":
|
||||
print(f"⚠️ Warning: Personality version mismatch. Loaded: {version}, Expected: 4.0")
|
||||
|
||||
for section, keys in required_structure.items():
|
||||
if section not in self.personality:
|
||||
missing.append(f"Missing section: {section}")
|
||||
continue
|
||||
|
||||
for key in keys:
|
||||
if key not in self.personality[section]:
|
||||
missing.append(f"Missing key: {section}.{key}")
|
||||
|
||||
if missing:
|
||||
print("⚠️ Personality Schema Validation Failed:")
|
||||
for m in missing:
|
||||
print(f" - {m}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_value(self, path: Union[str, List[str]], default: Any = None) -> Any:
|
||||
"""Access nested personality keys using dot notation or list of keys."""
|
||||
if isinstance(path, str):
|
||||
keys = path.split('.')
|
||||
else:
|
||||
keys = path
|
||||
|
||||
val = self.personality
|
||||
for key in keys:
|
||||
if isinstance(val, dict):
|
||||
val = val.get(key)
|
||||
else:
|
||||
return default
|
||||
return val if val is not None else default
|
||||
|
||||
def get_user_status(self) -> str:
|
||||
"""Determine user's context based on defined schedule from personality file."""
|
||||
schedule = self.get_value("work_cycles.schedule")
|
||||
if not schedule:
|
||||
# Fallback for simple personality files
|
||||
schedule = self.personality.get("schedule")
|
||||
if not schedule:
|
||||
return "Schedule not defined."
|
||||
|
||||
now = datetime.now()
|
||||
day = now.weekday() # 0=Mon, 6=Sun
|
||||
t = now.hour + (now.minute / 60.0)
|
||||
|
||||
# Check all schedule groups (e.g., 'weekdays', 'weekends')
|
||||
for group, day_ranges in schedule.items():
|
||||
for day_range, time_slots in day_ranges.items():
|
||||
try:
|
||||
# Parse day range (e.g., "0-4" or "5")
|
||||
if '-' in day_range:
|
||||
start_day, end_day = map(int, day_range.split('-'))
|
||||
if not (start_day <= day <= end_day):
|
||||
continue
|
||||
elif int(day_range) != day:
|
||||
continue
|
||||
|
||||
# We found the right day group, now check time slots
|
||||
for time_range, status in time_slots.items():
|
||||
if time_range == "default": continue
|
||||
start_time, end_time = map(float, time_range.split('-'))
|
||||
if start_time <= t < end_time:
|
||||
return status.get("description", status) if isinstance(status, dict) else status
|
||||
|
||||
default_status = time_slots.get("default", "No status for this time.")
|
||||
return default_status.get("description", default_status) if isinstance(default_status, dict) else default_status
|
||||
|
||||
except (ValueError, TypeError): continue
|
||||
return "No schedule match for today."
|
||||
310
core/buddai_prompt_engine.py
Normal file
310
core/buddai_prompt_engine.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
import sqlite3
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
from core.buddai_shared import DB_PATH, COMPLEX_TRIGGERS, MODULE_PATTERNS
|
||||
|
||||
class PromptEngine:
|
||||
"""Handles prompt construction, hardware classification, and request analysis"""
|
||||
|
||||
def classify_hardware(self, user_message: str, context_messages: List[Dict] = None) -> dict:
|
||||
"""Detect what hardware this question is about"""
|
||||
|
||||
hardware = {
|
||||
"servo": False,
|
||||
"dc_motor": False,
|
||||
"button": False,
|
||||
"led": False,
|
||||
"sensor": False,
|
||||
"weapon": False
|
||||
}
|
||||
|
||||
msg_lower = user_message.lower()
|
||||
|
||||
# Helper to check keywords
|
||||
def has_keywords(text, keywords):
|
||||
return any(word in text for word in keywords)
|
||||
|
||||
# Keyword definitions
|
||||
servo_kws = ['servo', 'mg996', 'sg90']
|
||||
motor_kws = ['l298n', 'dc motor', 'motor driver', 'motor control']
|
||||
button_kws = ['button', 'switch', 'trigger']
|
||||
led_kws = ['led', 'light', 'brightness', 'indicator']
|
||||
# Removed 'state machine' from weapon_kws to allow abstract logic
|
||||
weapon_kws = ['weapon', 'combat', 'arming', 'fire', 'spinner', 'flipper']
|
||||
logic_kws = ['state machine', 'logic', 'structure', 'flow', 'armed', 'disarmed']
|
||||
|
||||
# 1. Check current message first
|
||||
detected_in_current = False
|
||||
|
||||
if has_keywords(msg_lower, servo_kws):
|
||||
hardware["servo"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, motor_kws):
|
||||
hardware["dc_motor"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, button_kws):
|
||||
hardware["button"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, led_kws):
|
||||
hardware["led"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, weapon_kws):
|
||||
hardware["weapon"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, logic_kws):
|
||||
# Logic detected: Clear context (don't set any hardware)
|
||||
detected_in_current = True
|
||||
|
||||
# 2. Context Switching: Only look back if NO hardware/logic detected in current message
|
||||
# and message is short (likely a follow-up command like "make it spin")
|
||||
if not detected_in_current and len(user_message.split()) < 10 and context_messages:
|
||||
recent = " ".join([m['content'].lower() for m in context_messages[-2:] if m['role'] == 'user'])
|
||||
|
||||
if has_keywords(recent, servo_kws): hardware["servo"] = True
|
||||
if has_keywords(recent, motor_kws): hardware["dc_motor"] = True
|
||||
if has_keywords(recent, button_kws): hardware["button"] = True
|
||||
if has_keywords(recent, led_kws): hardware["led"] = True
|
||||
if has_keywords(recent, weapon_kws): hardware["weapon"] = True
|
||||
|
||||
return hardware
|
||||
|
||||
def get_all_rules(self) -> List[str]:
|
||||
"""Get all learned rules as text"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT rule_text FROM code_rules ORDER BY confidence DESC LIMIT 50")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def filter_rules_by_hardware(self, all_rules, hardware):
|
||||
"""Only return rules relevant to detected hardware"""
|
||||
|
||||
relevant_rules = []
|
||||
|
||||
# Define rule categories
|
||||
servo_kws = ['servo', 'attach', 'setperiodhertz']
|
||||
motor_kws = ['l298n', 'in1', 'in2', 'motor driver']
|
||||
weapon_kws = ['arming', 'disarm', 'fire', 'combat']
|
||||
button_kws = ['button', 'switch', 'debounce', 'digitalread', 'input_pullup']
|
||||
|
||||
has_specific_context = hardware["servo"] or hardware["dc_motor"] or hardware["weapon"] or hardware["button"]
|
||||
|
||||
for rule in all_rules:
|
||||
rule_lower = rule.lower()
|
||||
|
||||
is_servo_rule = any(w in rule_lower for w in servo_kws)
|
||||
is_motor_rule = any(w in rule_lower for w in motor_kws)
|
||||
is_weapon_rule = any(w in rule_lower for w in weapon_kws)
|
||||
is_button_rule = any(w in rule_lower for w in button_kws)
|
||||
|
||||
# Pattern Over-application: Strict filtering
|
||||
if has_specific_context:
|
||||
if hardware["dc_motor"] and not hardware["servo"] and is_servo_rule: continue
|
||||
if hardware["servo"] and not hardware["dc_motor"] and is_motor_rule: continue
|
||||
if not hardware["weapon"] and is_weapon_rule: continue
|
||||
if not hardware["button"] and is_button_rule: continue
|
||||
|
||||
# If question is about weapons (logic), EXCLUDE servo rules unless servo explicitly requested
|
||||
if hardware["weapon"] and not hardware["servo"] and is_servo_rule: continue
|
||||
|
||||
else:
|
||||
# Generic context: Exclude all specific hardware rules
|
||||
if is_servo_rule or is_motor_rule or is_weapon_rule or is_button_rule: continue
|
||||
|
||||
relevant_rules.append(rule)
|
||||
|
||||
return relevant_rules
|
||||
|
||||
def build_enhanced_prompt(self, user_message: str, hardware_detected: str = None, context_messages: List[Dict] = None) -> str:
|
||||
"""Build prompt with FILTERED rules"""
|
||||
|
||||
# Classify hardware
|
||||
hardware = self.classify_hardware(user_message, context_messages)
|
||||
|
||||
# Get ALL rules
|
||||
all_rules = self.get_all_rules()
|
||||
|
||||
# Filter by relevance
|
||||
relevant_rules = self.filter_rules_by_hardware(all_rules, hardware)
|
||||
|
||||
# Build focused prompt
|
||||
hardware_context = []
|
||||
if hardware["servo"]: hardware_context.append("SERVO CONTROL")
|
||||
if hardware["dc_motor"]: hardware_context.append("DC MOTOR CONTROL")
|
||||
if hardware["button"]: hardware_context.append("BUTTON INPUTS")
|
||||
if hardware["led"]: hardware_context.append("LED STATUS")
|
||||
if hardware["weapon"]: hardware_context.append("WEAPON SYSTEM")
|
||||
|
||||
l298n_rules = ""
|
||||
if hardware["dc_motor"]:
|
||||
l298n_rules = """
|
||||
- L298N WIRING RULES (MANDATORY):
|
||||
1. IN1/IN2 = Digital Output (Direction). Use digitalWrite().
|
||||
2. ENA = PWM Output (Speed). Use ledcWrite().
|
||||
3. To Move: IN1/IN2 must be OPPOSITE (HIGH/LOW).
|
||||
4. To Stop: IN1/IN2 both LOW.
|
||||
5. DO NOT treat Motors like Servos (No 'position' or 'angle').
|
||||
- SAFETY RULES (MANDATORY):
|
||||
1. Implement a safety timeout (e.g., 5000ms).
|
||||
2. Stop motors if no signal is received within timeout.
|
||||
3. Use millis() for non-blocking timing.
|
||||
"""
|
||||
|
||||
weapon_rules = ""
|
||||
if hardware.get("weapon"):
|
||||
weapon_rules = """
|
||||
- COMBAT PROTOCOL (MANDATORY):
|
||||
1. LOGIC FOCUS: This is a State Machine request, NOT just servo movement.
|
||||
2. STATES: enum State { DISARMED, ARMING, ARMED, FIRING };
|
||||
3. TRANSITIONS: DISARMED -> ARMING (2s delay) -> ARMED -> FIRING.
|
||||
4. SAFETY: Auto-disarm after 10s idle. Fire only when ARMED.
|
||||
5. STRUCTURE: Use switch(currentState) { case ... } for logic.
|
||||
6. OUTPUTS: Control relays/LEDs/Motors based on state.
|
||||
"""
|
||||
|
||||
# Anti-bloat rules
|
||||
anti_bloat_rules = []
|
||||
if not hardware["button"]:
|
||||
anti_bloat_rules.append("- NO EXTRA INPUTS: Do NOT add buttons, switches, or digitalRead() unless explicitly requested.")
|
||||
if not hardware["servo"]:
|
||||
anti_bloat_rules.append("- NO EXTRA SERVOS: Do NOT add Servo objects or attach() unless explicitly requested.")
|
||||
if not hardware["dc_motor"]:
|
||||
anti_bloat_rules.append("- NO EXTRA MOTORS: Do NOT add motor driver code (L298N) unless explicitly requested.")
|
||||
|
||||
anti_bloat = "\n".join(anti_bloat_rules)
|
||||
|
||||
# Modularity rule
|
||||
modularity_rule = ""
|
||||
if "function" in user_message.lower() or "naming" in user_message.lower() or "modular" in user_message.lower():
|
||||
modularity_rule = """
|
||||
- CODE STRUCTURE (MANDATORY):
|
||||
1. NO MONOLITHIC LOOP: Break code into small, descriptive functions.
|
||||
2. NAMING: Use camelCase for functions (e.g., readBatteryVoltage(), updateDisplay()).
|
||||
3. loop() must ONLY call these functions, not contain raw logic.
|
||||
"""
|
||||
|
||||
# Status LED rule
|
||||
status_led_rule = ""
|
||||
if hardware["led"] and ("status" in user_message.lower() or "indicator" in user_message.lower()):
|
||||
status_led_rule = """
|
||||
- STATUS LED RULES (MANDATORY):
|
||||
1. NO BREATHING/FADING: Do not use simple PWM fading loops.
|
||||
2. USE STATES: Define enum LEDStatus { OFF, IDLE, ACTIVE, ERROR };
|
||||
3. IMPLEMENTATION: Create void setStatusLED(LEDStatus state).
|
||||
4. PATTERNS: IDLE=Slow Blink, ACTIVE=Solid On, ERROR=Fast Blink.
|
||||
"""
|
||||
|
||||
prompt = f"""You are generating code for: {', '.join(hardware_context)}
|
||||
You are an expert embedded developer.
|
||||
TARGET HARDWARE: {hardware_detected}
|
||||
ACTIVE MODULES: {', '.join(hardware_context) if hardware_context else "None (Logic Only)"}
|
||||
|
||||
CRITICAL: Only use code patterns relevant to the hardware mentioned.
|
||||
STRICT NEGATIVE CONSTRAINTS (DO NOT IGNORE):
|
||||
{anti_bloat}
|
||||
|
||||
MANDATORY HARDWARE RULES:
|
||||
{l298n_rules}
|
||||
{weapon_rules}
|
||||
{status_led_rule}
|
||||
{anti_bloat}
|
||||
{modularity_rule}
|
||||
|
||||
GENERAL GUIDELINES:
|
||||
- If DC MOTOR: Use L298N patterns (digitalWrite, ledcWrite)
|
||||
- If SERVO: Use ESP32Servo patterns (attach, write)
|
||||
- DO NOT mix servo code into motor questions
|
||||
- DO NOT mix motor code into servo questions
|
||||
|
||||
CRITICAL RULES (MUST FOLLOW):
|
||||
{chr(10).join(relevant_rules)}
|
||||
|
||||
USER REQUEST:
|
||||
{user_message}
|
||||
|
||||
Generate code following ALL rules above. Do not add unrequested features.
|
||||
FINAL CHECK:
|
||||
1. Did you add unrequested buttons? REMOVE THEM.
|
||||
2. Did you add unrequested servos? REMOVE THEM.
|
||||
3. Generate code ONLY for the hardware requested.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def is_simple_question(self, message: str) -> bool:
|
||||
"""Check if this is a simple question that should use FAST model"""
|
||||
message_lower = message.lower()
|
||||
|
||||
simple_triggers = [
|
||||
"what is", "what's", "who is", "who's", "when is",
|
||||
"how do i", "can you explain", "tell me about",
|
||||
"what are", "where is", "hi", "hello", "hey",
|
||||
"good morning", "good evening"
|
||||
]
|
||||
|
||||
# Also check if it's just a question without code keywords
|
||||
code_keywords = ["generate", "create", "write", "build", "code", "function"]
|
||||
|
||||
has_simple_trigger = any(trigger in message_lower for trigger in simple_triggers)
|
||||
has_code_keyword = any(keyword in message_lower for keyword in code_keywords)
|
||||
|
||||
# Simple if: has simple trigger AND no code keywords
|
||||
return has_simple_trigger and not has_code_keyword
|
||||
|
||||
def is_complex(self, message: str) -> bool:
|
||||
"""Check if request is too complex and should be broken down"""
|
||||
message_lower = message.lower()
|
||||
|
||||
# Count complexity triggers
|
||||
trigger_count = sum(1 for trigger in COMPLEX_TRIGGERS if trigger in message_lower)
|
||||
|
||||
# Count how many modules mentioned
|
||||
module_count = 0
|
||||
for module, keywords in MODULE_PATTERNS.items():
|
||||
# module is used for key, keywords for values
|
||||
if any(kw in message_lower for kw in keywords):
|
||||
module_count += 1
|
||||
|
||||
# Complex if: multiple triggers OR 3+ modules mentioned
|
||||
return trigger_count >= 2 or module_count >= 3
|
||||
|
||||
def extract_modules(self, message: str) -> List[str]:
|
||||
"""Extract which modules are needed"""
|
||||
message_lower = message.lower()
|
||||
needed_modules = []
|
||||
|
||||
for module, keywords in MODULE_PATTERNS.items():
|
||||
# module is used for key, keywords for values
|
||||
if any(kw in message_lower for kw in keywords):
|
||||
needed_modules.append(module)
|
||||
|
||||
return needed_modules
|
||||
|
||||
def build_modular_plan(self, modules: List[str]) -> List[Dict[str, str]]:
|
||||
"""Create a build plan from modules"""
|
||||
plan = []
|
||||
|
||||
module_tasks = {
|
||||
"ble": "BLE communication setup with phone app control",
|
||||
"servo": "Servo motor control for flipper/weapon",
|
||||
"motor": "Motor driver setup for movement (L298N)",
|
||||
"safety": "Safety timeout and failsafe systems",
|
||||
"battery": "Battery voltage monitoring",
|
||||
"sensor": "Sensor integration (distance/proximity)"
|
||||
}
|
||||
|
||||
for module in modules:
|
||||
if module in module_tasks:
|
||||
plan.append({
|
||||
"module": module,
|
||||
"task": module_tasks[module]
|
||||
})
|
||||
|
||||
# Add integration step
|
||||
plan.append({
|
||||
"module": "integration",
|
||||
"task": "Integrate all modules into complete system"
|
||||
})
|
||||
|
||||
return plan
|
||||
|
|
@ -5,7 +5,7 @@ import queue
|
|||
import http.client
|
||||
|
||||
# Global Config
|
||||
DATA_DIR = Path(__file__).parent / "data"
|
||||
DATA_DIR = Path(__file__).parent.parent / "data"
|
||||
DB_PATH = DATA_DIR / "conversations.db"
|
||||
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "127.0.0.1")
|
||||
OLLAMA_PORT = int(os.getenv("OLLAMA_PORT", "11434"))
|
||||
240
core/buddai_storage.py
Normal file
240
core/buddai_storage.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
import sqlite3
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from core.buddai_shared import DB_PATH, DATA_DIR
|
||||
|
||||
class StorageManager:
|
||||
"""Manages Database, Sessions, and Backups"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.current_session_id = None
|
||||
self.ensure_data_dir()
|
||||
self.init_database()
|
||||
self.start_new_session()
|
||||
|
||||
def ensure_data_dir(self) -> None:
|
||||
DATA_DIR.mkdir(exist_ok=True)
|
||||
|
||||
def init_database(self) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Core Tables
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
started_at TIMESTAMP,
|
||||
ended_at TIMESTAMP,
|
||||
title TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT,
|
||||
role TEXT,
|
||||
content TEXT,
|
||||
timestamp TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS repo_index (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT,
|
||||
file_path TEXT,
|
||||
repo_name TEXT,
|
||||
function_name TEXT,
|
||||
content TEXT,
|
||||
last_modified TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS style_preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT,
|
||||
category TEXT,
|
||||
preference TEXT,
|
||||
confidence FLOAT,
|
||||
extracted_at TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS feedback (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id INTEGER,
|
||||
positive BOOLEAN,
|
||||
comment TEXT,
|
||||
timestamp TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS corrections (
|
||||
id INTEGER PRIMARY KEY,
|
||||
timestamp TEXT,
|
||||
original_code TEXT,
|
||||
corrected_code TEXT,
|
||||
reason TEXT,
|
||||
context TEXT,
|
||||
processed BOOLEAN DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS compilation_log (
|
||||
id INTEGER PRIMARY KEY,
|
||||
timestamp TEXT,
|
||||
code TEXT,
|
||||
success BOOLEAN,
|
||||
errors TEXT,
|
||||
hardware TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS code_rules (
|
||||
id INTEGER PRIMARY KEY,
|
||||
rule_text TEXT,
|
||||
pattern_find TEXT,
|
||||
pattern_replace TEXT,
|
||||
context TEXT,
|
||||
confidence FLOAT,
|
||||
learned_from TEXT,
|
||||
times_applied INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
# Migrations (Idempotent)
|
||||
try: cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE sessions ADD COLUMN user_id TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE repo_index ADD COLUMN user_id TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE style_preferences ADD COLUMN user_id TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE feedback ADD COLUMN comment TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE corrections ADD COLUMN processed BOOLEAN DEFAULT 0")
|
||||
except: pass
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create_session(self) -> str:
|
||||
now = datetime.now()
|
||||
base_id = now.strftime("%Y%m%d_%H%M%S")
|
||||
session_id = base_id
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO sessions (session_id, user_id, started_at) VALUES (?, ?, ?)",
|
||||
(session_id, self.user_id, now.isoformat())
|
||||
)
|
||||
conn.commit()
|
||||
break
|
||||
except sqlite3.IntegrityError:
|
||||
counter += 1
|
||||
session_id = f"{base_id}_{counter}"
|
||||
|
||||
conn.close()
|
||||
return session_id
|
||||
|
||||
def start_new_session(self) -> str:
|
||||
self.current_session_id = self.create_session()
|
||||
return self.current_session_id
|
||||
|
||||
def end_session(self) -> None:
|
||||
if not self.current_session_id: return
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE sessions SET ended_at = ? WHERE session_id = ?",
|
||||
(datetime.now().isoformat(), self.current_session_id)
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def save_message(self, role: str, content: str) -> int:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO messages (session_id, role, content, timestamp) VALUES (?, ?, ?, ?)",
|
||||
(self.current_session_id, role, content, datetime.now().isoformat())
|
||||
)
|
||||
msg_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return msg_id
|
||||
|
||||
def get_sessions(self, limit: int = 20) -> List[Dict[str, str]]:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT session_id, started_at, title FROM sessions WHERE user_id = ? ORDER BY started_at DESC LIMIT ?", (self.user_id, limit))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{"id": r[0], "date": r[1], "title": r[2] if len(r) > 2 else None} for r in rows]
|
||||
|
||||
def rename_session(self, session_id: str, new_title: str) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("UPDATE sessions SET title = ? WHERE session_id = ? AND user_id = ?", (new_title, session_id, self.user_id))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_session(self, session_id: str) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM sessions WHERE session_id = ? AND user_id = ?", (session_id, self.user_id))
|
||||
if cursor.rowcount > 0:
|
||||
cursor.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def clear_current_session(self) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM messages WHERE session_id = ?", (self.current_session_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def load_session(self, session_id: str) -> List[Dict[str, str]]:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT 1 FROM sessions WHERE session_id = ? AND user_id = ?", (session_id, self.user_id))
|
||||
if not cursor.fetchone():
|
||||
conn.close()
|
||||
return []
|
||||
|
||||
cursor.execute("SELECT id, role, content, timestamp FROM messages WHERE session_id = ? ORDER BY id ASC", (session_id,))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
self.current_session_id = session_id
|
||||
return [{"id": r[0], "role": r[1], "content": r[2], "timestamp": r[3]} for r in rows]
|
||||
|
||||
def create_backup(self) -> Tuple[bool, str]:
|
||||
if not DB_PATH.exists(): return False, "Database file not found."
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_dir = DATA_DIR / "backups"
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
backup_path = backup_dir / f"conversations_{timestamp}.db"
|
||||
try:
|
||||
src = sqlite3.connect(DB_PATH); dst = sqlite3.connect(backup_path)
|
||||
with dst: src.backup(dst)
|
||||
dst.close(); src.close()
|
||||
return True, str(backup_path)
|
||||
except Exception as e: return False, str(e)
|
||||
41
core/buddai_training.py
Normal file
41
core/buddai_training.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import sqlite3
|
||||
import json
|
||||
from core.buddai_shared import DB_PATH, DATA_DIR
|
||||
|
||||
class ModelFineTuner:
|
||||
"""Fine-tune local model on YOUR corrections"""
|
||||
|
||||
def prepare_training_data(self):
|
||||
"""Convert corrections to training format"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT original_code, corrected_code, reason
|
||||
FROM corrections
|
||||
""")
|
||||
|
||||
training_data = []
|
||||
for original, corrected, reason in cursor.fetchall():
|
||||
training_data.append({
|
||||
"prompt": f"Generate code for: {reason}",
|
||||
"completion": corrected,
|
||||
"negative_example": original
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
# Save as JSONL for fine-tuning
|
||||
output_path = DATA_DIR / 'training_data.jsonl'
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for item in training_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
return f"Exported {len(training_data)} examples to {output_path}"
|
||||
|
||||
def fine_tune_model(self):
|
||||
"""Fine-tune Qwen on your corrections"""
|
||||
# This requires:
|
||||
# 1. Export training data
|
||||
# 2. Use Ollama modelfile or external training
|
||||
# 3. Create custom model: qwen2.5-coder-james:3b
|
||||
pass
|
||||
|
|
@ -1,10 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys, os, json, logging, sqlite3, datetime, pathlib, http.client, re, typing, zipfile, shutil, queue, socket, argparse, io, difflib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Tuple, Union, Generator
|
||||
|
||||
from buddai_shared import DB_PATH
|
||||
import re
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
class CodeValidator:
|
||||
"""Validate generated code before showing to user"""
|
||||
|
|
@ -469,8 +464,6 @@ class CodeValidator:
|
|||
|
||||
return fixed_code
|
||||
|
||||
|
||||
|
||||
class HardwareProfile:
|
||||
"""Learn hardware-specific patterns"""
|
||||
|
||||
|
|
@ -528,75 +521,4 @@ class HardwareProfile:
|
|||
def add_safety(self, code: str) -> str:
|
||||
if "motor" in code.lower() and "millis()" not in code:
|
||||
code += "\n// [BuddAI Safety] Warning: No non-blocking timeout detected. Consider adding safety timeout."
|
||||
return code
|
||||
|
||||
|
||||
|
||||
class LearningMetrics:
|
||||
"""Measure BuddAI's improvement over time"""
|
||||
|
||||
def calculate_accuracy(self):
|
||||
"""What % of code is accepted without correction?"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total_responses,
|
||||
COUNT(CASE WHEN f.positive = 1 THEN 1 END) as positive_feedback,
|
||||
COUNT(CASE WHEN c.id IS NOT NULL THEN 1 END) as corrected
|
||||
FROM messages m
|
||||
LEFT JOIN feedback f ON m.id = f.message_id
|
||||
LEFT JOIN corrections c ON m.content LIKE '%' || c.original_code || '%'
|
||||
WHERE m.role = 'assistant'
|
||||
AND m.timestamp > ?
|
||||
""", (thirty_days_ago,))
|
||||
|
||||
total, positive, corrected = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
accuracy = (positive / total) * 100 if total and total > 0 else 0
|
||||
correction_rate = (corrected / total) * 100 if total and total > 0 else 0
|
||||
|
||||
return {
|
||||
"accuracy": accuracy,
|
||||
"correction_rate": correction_rate,
|
||||
"improvement": self.calculate_trend()
|
||||
}
|
||||
|
||||
def calculate_trend(self):
|
||||
"""Is BuddAI getting better over time?"""
|
||||
# Compare last 7 days vs previous 7 days
|
||||
recent = self.get_accuracy_for_period(7)
|
||||
previous = self.get_accuracy_for_period(7, offset=7)
|
||||
|
||||
improvement = recent - previous
|
||||
return f"+{improvement:.1f}%" if improvement > 0 else f"{improvement:.1f}%"
|
||||
|
||||
def get_accuracy_for_period(self, days: int, offset: int = 0) -> float:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
start_dt = (datetime.now() - timedelta(days=days + offset)).isoformat()
|
||||
end_dt = (datetime.now() - timedelta(days=offset)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COUNT(CASE WHEN f.positive = 1 THEN 1 END) as positive
|
||||
FROM messages m
|
||||
LEFT JOIN feedback f ON m.id = f.message_id
|
||||
WHERE m.role = 'assistant'
|
||||
AND m.timestamp BETWEEN ? AND ?
|
||||
""", (start_dt, end_dt))
|
||||
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
if not row:
|
||||
return 0.0
|
||||
|
||||
total, positive = row
|
||||
return (positive / total) * 100 if total and total > 0 else 0.0
|
||||
return code
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
def decouple_exocortex(source_file):
|
||||
with open(source_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Define the file splits based on class/block signatures
|
||||
splits = {
|
||||
"buddai_memory.py": ["class ShadowSuggestionEngine", "class AdaptiveLearner", "class SmartLearner"],
|
||||
"buddai_logic.py": ["class CodeValidator", "class HardwareProfile", "class LearningMetrics"],
|
||||
"buddai_executive.py": ["class OllamaConnectionPool", "class BuddAI", "class ModelFineTuner"],
|
||||
"buddai_server.py": ["if SERVER_AVAILABLE:", "app = FastAPI", "class BuddAIManager"]
|
||||
}
|
||||
|
||||
print(f"🚀 Surgical extraction of {source_file} initiated...")
|
||||
|
||||
# Extraction logic for classes/blocks
|
||||
for filename, markers in splits.items():
|
||||
extracted_sections = []
|
||||
for marker in markers:
|
||||
# Simple extraction based on class indentation/block end
|
||||
pattern = re.compile(rf"{re.escape(marker)}.*?(?=\nclass |\nif __name__ ==|\nif SERVER_AVAILABLE)", re.DOTALL)
|
||||
match = pattern.search(content)
|
||||
if match:
|
||||
extracted_sections.append(match.group(0))
|
||||
|
||||
if extracted_sections:
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
f.write("#!/usr/bin/env python3\n")
|
||||
f.write("import sys, os, json, logging, sqlite3, datetime, pathlib, http.client, re, typing, zipfile, shutil, queue, socket, argparse, io, difflib\n")
|
||||
f.write("from pathlib import Path\nfrom datetime import datetime, timedelta\nfrom typing import Optional, List, Dict, Tuple, Union, Generator\n\n")
|
||||
f.write("try:\n from fastapi import FastAPI, File, Header, Response, UploadFile, WebSocketDisconnect, Request, WebSocket\n from fastapi.middleware.cors import CORSMiddleware\n from fastapi.responses import FileResponse, HTMLResponse, JSONResponse\n from fastapi.staticfiles import StaticFiles\n from pydantic import BaseModel\n import uvicorn\nexcept ImportError:\n pass\n\n")
|
||||
f.write("\n\n".join(extracted_sections))
|
||||
print(f"✅ Created {filename}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Use the script's directory to find main.py reliably
|
||||
source_path = os.path.join(os.path.dirname(__file__), "main.py")
|
||||
if os.path.exists(source_path):
|
||||
decouple_exocortex(source_path)
|
||||
else:
|
||||
print(f"❌ Error: Could not find {source_path}")
|
||||
2
main.py
2
main.py
|
|
@ -11,7 +11,7 @@ import socket
|
|||
import uvicorn
|
||||
|
||||
# --- Import The Organs ---
|
||||
from buddai_shared import OLLAMA_HOST, OLLAMA_PORT, SERVER_AVAILABLE
|
||||
from core.buddai_shared import OLLAMA_HOST, OLLAMA_PORT, SERVER_AVAILABLE
|
||||
from buddai_executive import BuddAI
|
||||
|
||||
# If server dependencies are present, import the app
|
||||
|
|
|
|||
41
skills/__init__.py
Normal file
41
skills/__init__.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import importlib
|
||||
import pkgutil
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Configure local logger
|
||||
logger = logging.getLogger("BuddAI-Skills")
|
||||
|
||||
def load_registry():
|
||||
"""
|
||||
Dynamically discovers and loads skill modules from the current directory.
|
||||
Returns a dictionary mapping skill IDs to their executable functions and metadata.
|
||||
"""
|
||||
registry = {}
|
||||
package_dir = Path(__file__).parent
|
||||
|
||||
# Iterate over all .py files in this directory
|
||||
for _, name, _ in pkgutil.iter_modules([str(package_dir)]):
|
||||
try:
|
||||
# Import the module relative to this package
|
||||
module = importlib.import_module(f".{name}", __package__)
|
||||
|
||||
# Verify the Skill Interface (must have 'meta' and 'run')
|
||||
if hasattr(module, "meta") and hasattr(module, "run"):
|
||||
metadata = module.meta()
|
||||
skill_id = name
|
||||
|
||||
registry[skill_id] = {
|
||||
"name": metadata.get("name", skill_id),
|
||||
"triggers": metadata.get("triggers", []),
|
||||
"description": metadata.get("description", "No description provided."),
|
||||
"run": module.run
|
||||
}
|
||||
logger.info(f"🧩 Skill Loaded: {metadata.get('name')} [{skill_id}]")
|
||||
else:
|
||||
logger.debug(f"Skipping {name}: Does not implement Skill Interface.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error loading skill '{name}': {e}")
|
||||
|
||||
return registry
|
||||
47
skills/calculator.py
Normal file
47
skills/calculator.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import re
|
||||
|
||||
def meta():
|
||||
"""
|
||||
Defines the metadata for the Calculator skill.
|
||||
"""
|
||||
return {
|
||||
"name": "Smart Calculator",
|
||||
"description": "Performs basic arithmetic operations detected in the prompt.",
|
||||
"triggers": ["calculate", "compute", "solve", "math", "+", "-", "*", "/"]
|
||||
}
|
||||
|
||||
def run(payload):
|
||||
"""
|
||||
Executes the calculation logic.
|
||||
Accepts a string prompt or a dictionary context.
|
||||
"""
|
||||
# Normalize input
|
||||
prompt = payload if isinstance(payload, str) else payload.get("prompt", "")
|
||||
|
||||
# 1. Extract the mathematical expression
|
||||
# Regex looks for sequences of numbers and operators
|
||||
# Allowed: digits, whitespace, +, -, *, /, ., (, )
|
||||
match = re.search(r'([\d\.\s\+\-\*\/\(\)]+)', prompt)
|
||||
|
||||
if not match:
|
||||
return None # Fallback to LLM if no math found
|
||||
|
||||
expression = match.group(0).strip()
|
||||
|
||||
if not any(char.isdigit() for char in expression):
|
||||
return None
|
||||
|
||||
# 2. Safety Check (Double verification)
|
||||
allowed_chars = set("0123456789.+-*/() ")
|
||||
if not set(expression).issubset(allowed_chars):
|
||||
return "Calculation aborted: Invalid characters detected."
|
||||
|
||||
# 3. Execute
|
||||
try:
|
||||
# pylint: disable=eval-used
|
||||
result = eval(expression, {"__builtins__": None}, {})
|
||||
return f"🧮 Result: {expression} = {result}"
|
||||
except ZeroDivisionError:
|
||||
return "🧮 Error: Division by zero is not allowed."
|
||||
except Exception as e:
|
||||
return f"🧮 Calculation Error: {str(e)}"
|
||||
27
skills/system_info.py
Normal file
27
skills/system_info.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import psutil
|
||||
|
||||
def meta():
|
||||
"""
|
||||
Metadata for the System Info skill.
|
||||
"""
|
||||
return {
|
||||
"name": "System Info",
|
||||
"description": "Reports current CPU and RAM usage.",
|
||||
"triggers": ["cpu usage", "ram usage", "memory usage", "system stats", "how much ram", "cpu load"]
|
||||
}
|
||||
|
||||
def run(payload):
|
||||
"""
|
||||
Fetches system metrics.
|
||||
"""
|
||||
# interval=0.1 ensures we get a fresh sample (blocking briefly)
|
||||
cpu_usage = psutil.cpu_percent(interval=0.1)
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
total_gb = mem.total / (1024 ** 3)
|
||||
used_gb = mem.used / (1024 ** 3)
|
||||
percent_used = mem.percent
|
||||
|
||||
return (f"🖥️ System Vital Signs:\n"
|
||||
f" 🧠 CPU Load: {cpu_usage}%\n"
|
||||
f" 💾 RAM Usage: {percent_used}% ({used_gb:.1f}GB / {total_gb:.1f}GB)")
|
||||
68
skills/test_all.py
Normal file
68
skills/test_all.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import unittest
|
||||
import io
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def meta():
|
||||
"""
|
||||
Metadata for the Test Runner skill.
|
||||
"""
|
||||
return {
|
||||
"name": "Self-Diagnostic",
|
||||
"description": "Runs the internal unit test suite (tests/*.py).",
|
||||
"triggers": ["test all", "run tests", "self diagnostic", "check systems", "verify integrity"]
|
||||
}
|
||||
|
||||
def run(payload):
|
||||
"""
|
||||
Discovers and runs tests in the tests/ directory.
|
||||
"""
|
||||
# Root dir is parent of skills/ (i.e., buddAI/)
|
||||
root_dir = Path(__file__).parent.parent
|
||||
tests_dir = root_dir / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
return "❌ Diagnostics failed: 'tests' directory not found."
|
||||
|
||||
# Capture output
|
||||
log_capture = io.StringIO()
|
||||
|
||||
# Create a test runner that writes to our capture stream
|
||||
runner = unittest.TextTestRunner(stream=log_capture, verbosity=1)
|
||||
loader = unittest.TestLoader()
|
||||
|
||||
try:
|
||||
# Ensure root_dir is in sys.path so tests can import 'core', 'skills', etc.
|
||||
if str(root_dir) not in sys.path:
|
||||
sys.path.insert(0, str(root_dir))
|
||||
|
||||
# Discover tests
|
||||
suite = loader.discover(str(tests_dir), pattern="test_*.py", top_level_dir=str(root_dir))
|
||||
|
||||
num_tests = suite.countTestCases()
|
||||
if num_tests == 0:
|
||||
return "⚠️ No tests found in tests/ directory."
|
||||
|
||||
# Run tests
|
||||
result = runner.run(suite)
|
||||
|
||||
# Get output string
|
||||
output = log_capture.getvalue()
|
||||
|
||||
# Construct response
|
||||
header = "✅ **All Systems Operational**" if result.wasSuccessful() else "❌ **System Failures Detected**"
|
||||
stats = f"Executed {result.testsRun} tests."
|
||||
|
||||
if not result.wasSuccessful():
|
||||
stats += f"\n🔴 Failures: {len(result.failures)}"
|
||||
stats += f"\n⚠️ Errors: {len(result.errors)}"
|
||||
|
||||
# Limit output length for chat
|
||||
console_output = output
|
||||
if len(console_output) > 1500:
|
||||
console_output = "..." + console_output[-1500:]
|
||||
|
||||
return f"{header}\n{stats}\n\n**Console Output:**\n```text\n{console_output}\n```"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ Execution Error: {str(e)}"
|
||||
44
skills/timer.py
Normal file
44
skills/timer.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import time
|
||||
import re
|
||||
import threading
|
||||
|
||||
def meta():
|
||||
"""
|
||||
Metadata for the Timer skill.
|
||||
"""
|
||||
return {
|
||||
"name": "Timer",
|
||||
"description": "Sets a non-blocking timer (background thread).",
|
||||
"triggers": ["timer", "sleep", "wait for"]
|
||||
}
|
||||
|
||||
def run(payload):
|
||||
"""
|
||||
Executes the blocking sleep.
|
||||
"""
|
||||
prompt = payload if isinstance(payload, str) else payload.get("prompt", "")
|
||||
|
||||
# Regex to capture number and optional unit (e.g., "5", "5s", "5 minutes")
|
||||
match = re.search(r'(\d+)\s*(seconds?|secs?|s|minutes?|mins?|m)?', prompt.lower())
|
||||
|
||||
if not match:
|
||||
return None # Fallback to LLM if no time found
|
||||
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
duration = amount
|
||||
if unit and unit.startswith('m'):
|
||||
duration *= 60
|
||||
|
||||
if duration > 3600:
|
||||
return f"❌ Timer too long ({duration}s). Max 1 hour."
|
||||
|
||||
def _timer_thread():
|
||||
time.sleep(duration)
|
||||
print(f"\n\n⏰ 🔔 BEEP! Timer finished ({duration}s).\n")
|
||||
|
||||
t = threading.Thread(target=_timer_thread, daemon=True)
|
||||
t.start()
|
||||
|
||||
return f"⏰ Timer started for {duration} seconds (running in background)..."
|
||||
44
skills/weather.py
Normal file
44
skills/weather.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import urllib.request
|
||||
import urllib.parse
|
||||
import re
|
||||
|
||||
def meta():
|
||||
"""
|
||||
Metadata for the Weather skill.
|
||||
"""
|
||||
return {
|
||||
"name": "Weather",
|
||||
"description": "Fetches current weather using wttr.in (no API key required).",
|
||||
"triggers": ["weather", "temperature", "forecast"]
|
||||
}
|
||||
|
||||
def run(payload):
|
||||
"""
|
||||
Fetches weather data.
|
||||
"""
|
||||
prompt = payload if isinstance(payload, str) else payload.get("prompt", "")
|
||||
|
||||
location = ""
|
||||
# Extract location: "weather in London", "weather for Paris"
|
||||
match = re.search(r'\b(?:in|for|at)\s+(.+)', prompt, re.IGNORECASE)
|
||||
if match:
|
||||
location = match.group(1).strip().rstrip("?.!")
|
||||
|
||||
try:
|
||||
query = urllib.parse.quote(location) if location else ""
|
||||
# format=3 gives a concise one-line output (e.g., "London: ⛅️ +13°C")
|
||||
url = f"https://wttr.in/{query}?format=3"
|
||||
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
headers={'User-Agent': 'curl/7.68.0'} # Mimic curl to ensure text output
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=5) as response:
|
||||
if response.status == 200:
|
||||
result = response.read().decode('utf-8').strip()
|
||||
return f"🌦️ {result}"
|
||||
return f"❌ Weather error: {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ Failed to fetch weather: {str(e)}"
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
29
tests/test_all.py
Normal file
29
tests/test_all.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import unittest
|
||||
import sys
|
||||
import os
|
||||
|
||||
def run_suite():
|
||||
"""
|
||||
Discover and run all tests in the tests/ directory.
|
||||
"""
|
||||
# Get directories
|
||||
tests_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(tests_dir)
|
||||
|
||||
# Add project root to sys.path to allow imports of 'core', 'skills', etc.
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# Discover tests
|
||||
loader = unittest.TestLoader()
|
||||
suite = loader.discover(tests_dir, pattern="test_*.py", top_level_dir=project_root)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_suite()
|
||||
sys.exit(0 if success else 1)
|
||||
|
|
@ -22,10 +22,10 @@ import http.client
|
|||
|
||||
# Dynamic import of buddai_v3.2.py
|
||||
REPO_ROOT = Path(__file__).parent.parent
|
||||
MODULE_PATH = REPO_ROOT / "buddai_v3.2.py"
|
||||
spec = importlib.util.spec_from_file_location("buddai_v3_2", MODULE_PATH)
|
||||
MODULE_PATH = REPO_ROOT / "buddai_executive.py"
|
||||
spec = importlib.util.spec_from_file_location("buddai_executive", MODULE_PATH)
|
||||
buddai_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["buddai_v3_2"] = buddai_module
|
||||
sys.modules["buddai_executive"] = buddai_module
|
||||
spec.loader.exec_module(buddai_module)
|
||||
BuddAI = buddai_module.BuddAI
|
||||
|
||||
|
|
@ -564,12 +564,12 @@ def test_schedule_awareness():
|
|||
print_test("Schedule Awareness")
|
||||
|
||||
# Mock datetime to test different times
|
||||
with patch('buddai_v3_2.datetime') as mock_date:
|
||||
with patch('core.buddai_personality.datetime') as mock_date:
|
||||
# 1. Early Morning (Monday 6:00 AM)
|
||||
mock_date.now.return_value = datetime(2025, 12, 29, 6, 0, 0)
|
||||
|
||||
buddai = BuddAI(server_mode=False)
|
||||
status = buddai.get_user_status()
|
||||
status = buddai.personality_manager.get_user_status()
|
||||
|
||||
if "Early Morning" in status:
|
||||
print_pass(f"6:00 AM Mon -> {status}")
|
||||
|
|
@ -579,7 +579,7 @@ def test_schedule_awareness():
|
|||
|
||||
# 2. Work Hours (Monday 10:00 AM)
|
||||
mock_date.now.return_value = datetime(2025, 12, 29, 10, 0, 0)
|
||||
status = buddai.get_user_status()
|
||||
status = buddai.personality_manager.get_user_status()
|
||||
|
||||
if "Work Hours" in status:
|
||||
print_pass(f"10:00 AM Mon -> {status}")
|
||||
|
|
@ -619,7 +619,7 @@ def test_session_management():
|
|||
test_db = Path(test_db_path)
|
||||
|
||||
try:
|
||||
with patch('buddai_v3_2.DB_PATH', test_db):
|
||||
with patch('buddai_executive.DB_PATH', test_db):
|
||||
buddai = BuddAI(server_mode=False)
|
||||
|
||||
# 1. Create
|
||||
|
|
@ -665,8 +665,8 @@ def test_rapid_session_creation():
|
|||
# Mock datetime to return a fixed time, forcing ID collisions
|
||||
fixed_time = datetime(2025, 1, 1, 12, 0, 0)
|
||||
|
||||
with patch('buddai_v3_2.DB_PATH', test_db):
|
||||
with patch('buddai_v3_2.datetime') as mock_dt:
|
||||
with patch('buddai_executive.DB_PATH', test_db):
|
||||
with patch('buddai_executive.datetime') as mock_dt:
|
||||
mock_dt.now.return_value = fixed_time
|
||||
|
||||
buddai = BuddAI(server_mode=False)
|
||||
|
|
@ -711,7 +711,7 @@ def test_repo_isolation():
|
|||
(repo_path / "user1_secret.py").write_text("def user1_secret_function():\n pass")
|
||||
|
||||
try:
|
||||
with patch('buddai_v3_2.DB_PATH', test_db):
|
||||
with patch('buddai_executive.DB_PATH', test_db):
|
||||
# Suppress internal prints to keep test output clean
|
||||
with patch('builtins.print'):
|
||||
# User 1 indexes the repo
|
||||
|
|
@ -815,7 +815,7 @@ def test_websocket_logic():
|
|||
test_db = Path(test_db_path)
|
||||
|
||||
try:
|
||||
with patch('buddai_v3_2.DB_PATH', test_db):
|
||||
with patch('buddai_executive.DB_PATH', test_db):
|
||||
# Suppress prints during init
|
||||
with patch('builtins.print'):
|
||||
buddai = BuddAI(server_mode=False)
|
||||
|
|
@ -932,7 +932,7 @@ def test_feedback_system():
|
|||
test_db = Path(test_db_path)
|
||||
|
||||
try:
|
||||
with patch('buddai_v3_2.DB_PATH', test_db):
|
||||
with patch('buddai_executive.DB_PATH', test_db):
|
||||
# Suppress prints
|
||||
with patch('builtins.print'):
|
||||
buddai = BuddAI(server_mode=False)
|
||||
|
|
|
|||
|
|
@ -10,18 +10,19 @@ import importlib.util
|
|||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
from core.buddai_prompt_engine import PromptEngine
|
||||
|
||||
# Load buddai_v3.2.py dynamically due to version number in filename
|
||||
REPO_ROOT = Path(__file__).parent.parent
|
||||
MODULE_PATH = REPO_ROOT / "buddai_v3.2.py"
|
||||
MODULE_PATH = REPO_ROOT / "buddai_executive.py"
|
||||
|
||||
if not MODULE_PATH.exists():
|
||||
print(f"Error: Could not find {MODULE_PATH}")
|
||||
sys.exit(1)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("buddai_v3_2", MODULE_PATH)
|
||||
spec = importlib.util.spec_from_file_location("buddai_executive", MODULE_PATH)
|
||||
buddai_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["buddai_v3_2"] = buddai_module
|
||||
sys.modules["buddai_executive"] = buddai_module
|
||||
spec.loader.exec_module(buddai_module)
|
||||
|
||||
BuddAI = buddai_module.BuddAI
|
||||
|
|
@ -52,20 +53,20 @@ class TestBuddAITypesAndLogic(unittest.TestCase):
|
|||
self.assertEqual(chat_hints['return'], str)
|
||||
|
||||
# is_complex
|
||||
self.assertEqual(BuddAI.is_complex.__annotations__['return'], bool)
|
||||
self.assertEqual(PromptEngine.is_complex.__annotations__['return'], bool)
|
||||
|
||||
# extract_modules
|
||||
self.assertEqual(BuddAI.extract_modules.__annotations__['return'], List[str])
|
||||
self.assertEqual(PromptEngine.extract_modules.__annotations__['return'], List[str])
|
||||
|
||||
# build_modular_plan
|
||||
self.assertEqual(BuddAI.build_modular_plan.__annotations__['return'], List[Dict[str, str]])
|
||||
self.assertEqual(PromptEngine.build_modular_plan.__annotations__['return'], List[Dict[str, str]])
|
||||
|
||||
def test_routing_simple_question(self):
|
||||
"""Test that simple questions route to the FAST model"""
|
||||
with patch.object(self.buddai, 'call_model', return_value="Fast response") as mock_call:
|
||||
response = self.buddai._route_request("What is a servo?", force_model=None, forge_mode="2")
|
||||
|
||||
mock_call.assert_called_with("fast", "What is a servo?")
|
||||
mock_call.assert_called_with("fast", "What is a servo?", system_task=True)
|
||||
self.assertEqual(response, "Fast response")
|
||||
|
||||
def test_routing_complex_request(self):
|
||||
|
|
@ -74,7 +75,7 @@ class TestBuddAITypesAndLogic(unittest.TestCase):
|
|||
|
||||
with patch.object(self.buddai, 'execute_modular_build', return_value="Modular code") as mock_build:
|
||||
# Mock is_complex to ensure it returns True for this test case
|
||||
with patch.object(self.buddai, 'is_complex', return_value=True):
|
||||
with patch.object(self.buddai.prompt_engine, 'is_complex', return_value=True):
|
||||
response = self.buddai._route_request(complex_msg, force_model=None, forge_mode="2")
|
||||
|
||||
mock_build.assert_called()
|
||||
|
|
@ -84,11 +85,11 @@ class TestBuddAITypesAndLogic(unittest.TestCase):
|
|||
"""Test that search queries route to repository search"""
|
||||
search_msg = "Show me functions using applyForge"
|
||||
|
||||
with patch.object(self.buddai, 'search_repositories', return_value="Search results") as mock_search:
|
||||
with patch.object(self.buddai.repo_manager, 'search_repositories', return_value="Search results") as mock_search:
|
||||
# Mock is_search_query to ensure True
|
||||
with patch.object(self.buddai, 'is_search_query', return_value=True):
|
||||
with patch.object(self.buddai.repo_manager, 'is_search_query', return_value=True):
|
||||
# Ensure is_complex is False so it doesn't preempt search
|
||||
with patch.object(self.buddai, 'is_complex', return_value=False):
|
||||
with patch.object(self.buddai.prompt_engine, 'is_complex', return_value=False):
|
||||
response = self.buddai._route_request(search_msg, force_model=None, forge_mode="2")
|
||||
|
||||
mock_search.assert_called_with(search_msg)
|
||||
|
|
@ -105,7 +106,7 @@ class TestBuddAITypesAndLogic(unittest.TestCase):
|
|||
def test_extract_modules(self):
|
||||
"""Verify module extraction logic"""
|
||||
msg = "I need a robot with bluetooth and a flipper weapon"
|
||||
modules = self.buddai.extract_modules(msg)
|
||||
modules = self.buddai.prompt_engine.extract_modules(msg)
|
||||
self.assertIn("ble", modules)
|
||||
self.assertIn("servo", modules)
|
||||
self.assertNotIn("motor", modules)
|
||||
|
|
|
|||
|
|
@ -18,18 +18,25 @@ import json
|
|||
|
||||
# Dynamic import of buddai_v3.2.py
|
||||
REPO_ROOT = Path(__file__).parent.parent
|
||||
MODULE_PATH = REPO_ROOT / "buddai_v3.2.py"
|
||||
spec = importlib.util.spec_from_file_location("buddai_v3_2", MODULE_PATH)
|
||||
MODULE_PATH = REPO_ROOT / "buddai_executive.py"
|
||||
spec = importlib.util.spec_from_file_location("buddai_executive", MODULE_PATH)
|
||||
buddai_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["buddai_v3_2"] = buddai_module
|
||||
sys.modules["buddai_executive"] = buddai_module
|
||||
spec.loader.exec_module(buddai_module)
|
||||
|
||||
# Check for server dependencies
|
||||
SERVER_AVAILABLE = getattr(buddai_module, "SERVER_AVAILABLE", False)
|
||||
|
||||
if SERVER_AVAILABLE:
|
||||
# Load buddai_server.py dynamically to get 'app'
|
||||
SERVER_PATH = REPO_ROOT / "buddai_server.py"
|
||||
spec_server = importlib.util.spec_from_file_location("buddai_server", SERVER_PATH)
|
||||
server_module = importlib.util.module_from_spec(spec_server)
|
||||
sys.modules["buddai_server"] = server_module
|
||||
spec_server.loader.exec_module(server_module)
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
app = buddai_module.app
|
||||
app = server_module.app
|
||||
client = TestClient(app)
|
||||
else:
|
||||
print("⚠️ Server dependencies missing. Integration tests skipped.")
|
||||
|
|
@ -43,7 +50,7 @@ class TestBuddAIIntegration(unittest.TestCase):
|
|||
os.close(self.db_fd)
|
||||
|
||||
# Patch DB_PATH in the module
|
||||
self.db_patcher = patch("buddai_v3_2.DB_PATH", Path(self.db_path))
|
||||
self.db_patcher = patch("buddai_executive.DB_PATH", Path(self.db_path))
|
||||
self.mock_db_path = self.db_patcher.start()
|
||||
|
||||
# Reset the manager to ensure fresh BuddAI instances connected to temp DB
|
||||
|
|
@ -143,9 +150,9 @@ class TestBuddAIIntegration(unittest.TestCase):
|
|||
def test_upload_api(self):
|
||||
"""Test file upload endpoint"""
|
||||
with tempfile.TemporaryDirectory() as tmp_data_dir:
|
||||
with patch("buddai_v3_2.DATA_DIR", Path(tmp_data_dir)):
|
||||
with patch("buddai_executive.DATA_DIR", Path(tmp_data_dir)):
|
||||
# Mock indexing to avoid parsing logic
|
||||
with patch.object(buddai_module.BuddAI, 'index_local_repositories') as mock_index:
|
||||
with patch.object(buddai_module.RepoManager, 'index_local_repositories') as mock_index:
|
||||
|
||||
# Create dummy file
|
||||
files = {'file': ('test.py', b'print("hello")', 'text/x-python')}
|
||||
|
|
|
|||
56
tests/test_skills.py
Normal file
56
tests/test_skills.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path so we can import 'skills'
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from skills import load_registry
|
||||
|
||||
class TestSkills(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.registry = load_registry()
|
||||
|
||||
def test_registry_loading(self):
|
||||
"""Ensure skills are discovered and loaded"""
|
||||
self.assertGreater(len(self.registry), 0, "No skills loaded")
|
||||
# Check for core skills
|
||||
self.assertIn("calculator", self.registry)
|
||||
self.assertIn("weather", self.registry)
|
||||
self.assertIn("timer", self.registry)
|
||||
self.assertIn("system_info", self.registry)
|
||||
self.assertIn("test_all", self.registry)
|
||||
|
||||
def test_calculator_logic(self):
|
||||
"""Verify calculator skill math"""
|
||||
calc = self.registry["calculator"]["run"]
|
||||
self.assertIn("4", calc("Calculate 2 + 2"))
|
||||
self.assertIn("25", calc("5 * 5"))
|
||||
self.assertIsNone(calc("No math here"))
|
||||
|
||||
def test_timer_parsing(self):
|
||||
"""Verify timer parses duration correctly"""
|
||||
timer = self.registry["timer"]["run"]
|
||||
# We use 0 seconds to avoid waiting during tests
|
||||
response = timer("Set a timer for 0 seconds")
|
||||
self.assertIn("Timer started", response)
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_weather_mock(self, mock_urlopen):
|
||||
"""Verify weather skill with mocked network"""
|
||||
# Mock the API response so tests work offline
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = b"London: +15C"
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
weather = self.registry["weather"]["run"]
|
||||
result = weather("Weather in London")
|
||||
|
||||
self.assertIn("London: +15C", result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue