mirror of
https://github.com/JamesTheGiblet/BuddAI.git
synced 2026-01-08 21:58:40 +00:00
Implement core skills: Code validation, model fine-tuning, and system diagnostics
- Added `ModelFineTuner` class for preparing training data and fine-tuning models based on user corrections. - Introduced `CodeValidator` class to validate generated code against various hardware and style rules, including safety checks and function naming conventions. - Developed skills for calculator operations, system information retrieval, weather fetching, and timer functionality. - Implemented a self-diagnostic skill to run unit tests and report results. - Created a dynamic skill loading mechanism to discover and register skills from the current directory. - Added unit tests for skills to ensure functionality and reliability.
This commit is contained in:
parent
743f9f311d
commit
f9fd27d228
28 changed files with 2398 additions and 4077 deletions
0
core/__init__.py
Normal file
0
core/__init__.py
Normal file
72
core/buddai_analytics.py
Normal file
72
core/buddai_analytics.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from core.buddai_shared import DB_PATH
|
||||
|
||||
class LearningMetrics:
|
||||
"""Measure BuddAI's improvement over time"""
|
||||
|
||||
def calculate_accuracy(self):
|
||||
"""What % of code is accepted without correction?"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total_responses,
|
||||
COUNT(CASE WHEN f.positive = 1 THEN 1 END) as positive_feedback,
|
||||
COUNT(CASE WHEN c.id IS NOT NULL THEN 1 END) as corrected
|
||||
FROM messages m
|
||||
LEFT JOIN feedback f ON m.id = f.message_id
|
||||
LEFT JOIN corrections c ON m.content LIKE '%' || c.original_code || '%'
|
||||
WHERE m.role = 'assistant'
|
||||
AND m.timestamp > ?
|
||||
""", (thirty_days_ago,))
|
||||
|
||||
total, positive, corrected = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
accuracy = (positive / total) * 100 if total and total > 0 else 0
|
||||
correction_rate = (corrected / total) * 100 if total and total > 0 else 0
|
||||
|
||||
return {
|
||||
"accuracy": accuracy,
|
||||
"correction_rate": correction_rate,
|
||||
"improvement": self.calculate_trend()
|
||||
}
|
||||
|
||||
def calculate_trend(self):
|
||||
"""Is BuddAI getting better over time?"""
|
||||
# Compare last 7 days vs previous 7 days
|
||||
recent = self.get_accuracy_for_period(7)
|
||||
previous = self.get_accuracy_for_period(7, offset=7)
|
||||
|
||||
improvement = recent - previous
|
||||
return f"+{improvement:.1f}%" if improvement > 0 else f"{improvement:.1f}%"
|
||||
|
||||
def get_accuracy_for_period(self, days: int, offset: int = 0) -> float:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
start_dt = (datetime.now() - timedelta(days=days + offset)).isoformat()
|
||||
end_dt = (datetime.now() - timedelta(days=offset)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
COUNT(CASE WHEN f.positive = 1 THEN 1 END) as positive
|
||||
FROM messages m
|
||||
LEFT JOIN feedback f ON m.id = f.message_id
|
||||
WHERE m.role = 'assistant'
|
||||
AND m.timestamp BETWEEN ? AND ?
|
||||
""", (start_dt, end_dt))
|
||||
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
if not row:
|
||||
return 0.0
|
||||
|
||||
total, positive = row
|
||||
return (positive / total) * 100 if total and total > 0 else 0.0
|
||||
181
core/buddai_knowledge.py
Normal file
181
core/buddai_knowledge.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import sqlite3
|
||||
import re
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
class RepoManager:
|
||||
"""Manages local repository indexing and retrieval (RAG)"""
|
||||
|
||||
def __init__(self, db_path: Path, user_id: str):
|
||||
self.db_path = db_path
|
||||
self.user_id = user_id
|
||||
|
||||
def is_search_query(self, message: str) -> bool:
|
||||
"""Check if this is a search query that should query repo_index"""
|
||||
message_lower = message.lower()
|
||||
search_triggers = [
|
||||
"show me", "find", "search for", "list all",
|
||||
"what functions", "which repos", "do i have",
|
||||
"where did i", "have i used", "examples of",
|
||||
"show all", "display"
|
||||
]
|
||||
return any(trigger in message_lower for trigger in search_triggers)
|
||||
|
||||
def search_repositories(self, query: str) -> str:
|
||||
"""Search repo_index for relevant functions and code"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM repo_index WHERE user_id = ?", (self.user_id,))
|
||||
count = cursor.fetchone()[0]
|
||||
print(f"\n🔍 Searching {count} indexed functions...\n")
|
||||
|
||||
# Extract keywords from query
|
||||
keywords = re.findall(r'\b\w{4,}\b', query.lower())
|
||||
# Add specific search terms
|
||||
specific_terms = []
|
||||
if "exponential" in query.lower() or "decay" in query.lower():
|
||||
specific_terms.append("applyForge")
|
||||
specific_terms.append("exp(")
|
||||
if "forge" in query.lower():
|
||||
specific_terms.append("Forge")
|
||||
keywords.extend(specific_terms)
|
||||
|
||||
if not keywords:
|
||||
print("❌ No search terms found")
|
||||
conn.close()
|
||||
return "No search terms provided."
|
||||
|
||||
# Build parameterized query
|
||||
conditions = []
|
||||
params = []
|
||||
for keyword in keywords:
|
||||
conditions.append("(function_name LIKE ? OR content LIKE ? OR repo_name LIKE ?)")
|
||||
params.extend([f"%{keyword}%", f"%{keyword}%", f"%{keyword}%"])
|
||||
|
||||
sql = f"SELECT repo_name, file_path, function_name, content FROM repo_index WHERE ({' OR '.join(conditions)}) AND user_id = ? ORDER BY last_modified DESC LIMIT 10"
|
||||
params.append(self.user_id)
|
||||
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
conn.close()
|
||||
if not results:
|
||||
return f"❌ No functions found matching: {', '.join(keywords)}\n\nTry: /index <path> to index more repositories"
|
||||
# Format results
|
||||
output = f"✅ Found {len(results)} matches for: {', '.join(set(keywords))}\n\n"
|
||||
for i, (repo, file_path, func, content) in enumerate(results, 1):
|
||||
# Extract relevant snippet
|
||||
lines = content.split('\n')
|
||||
snippet_lines = []
|
||||
for line in lines[:30]: # First 30 lines
|
||||
if any(kw in line.lower() for kw in keywords):
|
||||
snippet_lines.append(line)
|
||||
if len(snippet_lines) >= 10:
|
||||
break
|
||||
if not snippet_lines:
|
||||
snippet_lines = lines[:10]
|
||||
snippet = '\n'.join(snippet_lines)
|
||||
output += f"**{i}. {func}()** in {repo}\n"
|
||||
output += f" 📁 {Path(file_path).name}\n"
|
||||
output += f"\n```cpp\n{snippet}\n```\n"
|
||||
output += f" ---\n\n"
|
||||
return output
|
||||
|
||||
def index_local_repositories(self, root_path: str) -> None:
|
||||
"""Crawl directories and index .py, .ino, and .cpp files"""
|
||||
print(f"\n🔍 Indexing repositories in: {root_path}")
|
||||
path = Path(root_path)
|
||||
|
||||
if not path.exists():
|
||||
print(f"❌ Path not found: {root_path}")
|
||||
return
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
count = 0
|
||||
|
||||
for file_path in path.rglob('*'):
|
||||
if file_path.is_file() and file_path.suffix in ['.py', '.ino', '.cpp', '.h', '.js', '.jsx', '.html', '.css']:
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
functions = []
|
||||
|
||||
# Python parsing
|
||||
if file_path.suffix == '.py':
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
functions.append(node.name)
|
||||
except:
|
||||
pass
|
||||
|
||||
# C++/Arduino parsing
|
||||
elif file_path.suffix in ['.ino', '.cpp', '.h']:
|
||||
matches = re.findall(r'\b(?:void|int|bool|float|double|String|char)\s+(\w+)\s*\(', content)
|
||||
functions.extend(matches)
|
||||
|
||||
# JS/Web parsing
|
||||
elif file_path.suffix in ['.js', '.jsx']:
|
||||
matches = re.findall(r'(?:function\s+(\w+)|const\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>)', content)
|
||||
functions.extend([m[0] or m[1] for m in matches if m[0] or m[1]])
|
||||
|
||||
# HTML/CSS - Index as whole file
|
||||
elif file_path.suffix in ['.html', '.css']:
|
||||
functions.append("file_content")
|
||||
|
||||
# Determine repo name
|
||||
try:
|
||||
repo_name = file_path.relative_to(path).parts[0]
|
||||
except:
|
||||
repo_name = "unknown"
|
||||
|
||||
timestamp = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
|
||||
for func in functions:
|
||||
cursor.execute("""
|
||||
INSERT INTO repo_index (user_id, file_path, repo_name, function_name, content, last_modified)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (self.user_id, str(file_path), repo_name, func, content, timestamp.isoformat()))
|
||||
count += 1
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(f"✅ Indexed {count} functions across repositories")
|
||||
|
||||
def retrieve_style_context(self, message: str, prompt_template: str, user_name: str) -> str:
|
||||
"""Search repo_index for code snippets matching the request"""
|
||||
# Extract potential keywords (nouns/modules)
|
||||
keywords = re.findall(r'\b\w{4,}\b', message.lower())
|
||||
if not keywords:
|
||||
return ""
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build a search query for function names or repo names
|
||||
search_terms = " OR ".join([f"function_name LIKE '%{k}%'" for k in keywords])
|
||||
search_terms += " OR " + " OR ".join([f"repo_name LIKE '%{k}%'" for k in keywords])
|
||||
|
||||
query = f"SELECT repo_name, function_name, content FROM repo_index WHERE ({search_terms}) AND user_id = ? LIMIT 2"
|
||||
|
||||
cursor.execute(query, (self.user_id,))
|
||||
results = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
context_block = prompt_template.format(user_name=user_name)
|
||||
for repo, func, content in results:
|
||||
# Just grab the first 500 chars of the file to save context window
|
||||
snippet = content[:500] + "..."
|
||||
context_block += f"Repo: {repo} | Function: {func}\nCode:\n{snippet}\n---\n"
|
||||
|
||||
return context_block
|
||||
137
core/buddai_llm.py
Normal file
137
core/buddai_llm.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
import http.client
|
||||
import socket
|
||||
from typing import List, Dict, Union, Generator, Optional
|
||||
from core.buddai_shared import MODELS, OLLAMA_POOL, OLLAMA_HOST, OLLAMA_PORT
|
||||
|
||||
class OllamaClient:
|
||||
"""Handles communication with the local Ollama instance"""
|
||||
|
||||
def query(self, model_key: str, messages: List[Dict], stream: bool = False, options: Dict = None) -> Union[str, Generator[str, None, None]]:
|
||||
"""Send a chat request to Ollama"""
|
||||
model_name = MODELS.get(model_key, model_key) # Handle key or direct name
|
||||
|
||||
default_options = {
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"top_k": 1,
|
||||
"num_ctx": 1024
|
||||
}
|
||||
if options:
|
||||
default_options.update(options)
|
||||
|
||||
body = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"options": default_options
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
json_body = json.dumps(body)
|
||||
|
||||
# Retry logic for connection stability
|
||||
for attempt in range(3):
|
||||
conn = None
|
||||
try:
|
||||
# Re-serialize in case of modification (CPU fallback)
|
||||
json_body = json.dumps(body)
|
||||
|
||||
conn = OLLAMA_POOL.get_connection()
|
||||
conn.request("POST", "/api/chat", json_body, headers)
|
||||
response = conn.getresponse()
|
||||
|
||||
if stream:
|
||||
if response.status != 200:
|
||||
error_text = response.read().decode('utf-8')
|
||||
conn.close()
|
||||
|
||||
# GPU OOM Detection -> CPU Fallback
|
||||
if "CUDA" in error_text or "buffer" in error_text:
|
||||
if "num_gpu" not in body["options"]:
|
||||
print("⚠️ GPU OOM detected. Switching to CPU mode...")
|
||||
body["options"]["num_gpu"] = 0 # Force CPU
|
||||
continue
|
||||
|
||||
try:
|
||||
err_msg = f"Error {response.status}: {json.loads(error_text).get('error', error_text)}"
|
||||
except:
|
||||
err_msg = f"Error {response.status}: {error_text}"
|
||||
|
||||
if "num_gpu" in body["options"]:
|
||||
err_msg += "\n\n(⚠️ CPU Mode also failed. System RAM might be full.)"
|
||||
elif "CUDA" in err_msg or "buffer" in err_msg:
|
||||
err_msg += "\n\n(⚠️ GPU Out of Memory. Retrying on CPU failed.)"
|
||||
|
||||
return (x for x in [err_msg])
|
||||
|
||||
return self._stream_response(response, conn)
|
||||
|
||||
if response.status == 200:
|
||||
data = json.loads(response.read().decode('utf-8'))
|
||||
OLLAMA_POOL.return_connection(conn)
|
||||
return data.get("message", {}).get("content", "No response")
|
||||
else:
|
||||
error_text = response.read().decode('utf-8')
|
||||
conn.close()
|
||||
|
||||
if "CUDA" in error_text or "buffer" in error_text:
|
||||
if "num_gpu" not in body["options"]:
|
||||
print("⚠️ GPU OOM detected. Switching to CPU mode...")
|
||||
body["options"]["num_gpu"] = 0
|
||||
continue
|
||||
|
||||
return f"Error {response.status}: {error_text}"
|
||||
|
||||
except (http.client.NotConnected, BrokenPipeError, ConnectionResetError, socket.timeout) as e:
|
||||
if conn: conn.close()
|
||||
if attempt == 2:
|
||||
return f"Error: Connection failed. {str(e)}"
|
||||
continue
|
||||
except Exception as e:
|
||||
if conn: conn.close()
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def _stream_response(self, response, conn) -> Generator[str, None, None]:
|
||||
"""Yield chunks from HTTP response"""
|
||||
fully_consumed = False
|
||||
has_content = False
|
||||
try:
|
||||
while True:
|
||||
line = response.readline()
|
||||
if not line: break
|
||||
try:
|
||||
data = json.loads(line.decode('utf-8'))
|
||||
if "message" in data:
|
||||
content = data["message"].get("content", "")
|
||||
if content:
|
||||
has_content = True
|
||||
yield content
|
||||
if data.get("done"):
|
||||
fully_consumed = True
|
||||
break
|
||||
except: pass
|
||||
except Exception as e:
|
||||
yield f"\n[Stream Error: {str(e)}]"
|
||||
finally:
|
||||
if fully_consumed:
|
||||
OLLAMA_POOL.return_connection(conn)
|
||||
else:
|
||||
conn.close()
|
||||
|
||||
if not has_content and not fully_consumed:
|
||||
yield "\n[Error: Empty response from Ollama. Check if model is loaded.]"
|
||||
|
||||
def reset_gpu(self) -> str:
|
||||
"""Force unload models from GPU to free VRAM"""
|
||||
try:
|
||||
conn = http.client.HTTPConnection(OLLAMA_HOST, OLLAMA_PORT, timeout=10)
|
||||
for model in MODELS.values():
|
||||
body = json.dumps({"model": model, "keep_alive": 0})
|
||||
conn.request("POST", "/api/generate", body)
|
||||
resp = conn.getresponse()
|
||||
resp.read()
|
||||
conn.close()
|
||||
return "✅ GPU Memory Cleared (Models Unloaded)"
|
||||
except Exception as e:
|
||||
return f"❌ Error clearing GPU: {str(e)}"
|
||||
360
core/buddai_memory.py
Normal file
360
core/buddai_memory.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys, os, json, logging, sqlite3, datetime, pathlib, http.client, re, typing, zipfile, shutil, queue, socket, argparse, io, difflib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Tuple, Union, Generator
|
||||
|
||||
from core.buddai_shared import DB_PATH, MODULE_PATTERNS
|
||||
|
||||
class ShadowSuggestionEngine:
|
||||
"""Proactively suggests modules/settings based on user/project history."""
|
||||
def __init__(self, db_path: Path, user_id: str = "default"):
|
||||
self.db_path = db_path
|
||||
self.user_id = user_id
|
||||
|
||||
def lookup_recent_module_usage(self, module: str, limit: int = 5) -> List[Tuple[str, str, str]]:
|
||||
"""Look up recent usage patterns for a module from repo_index."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT file_path, content, last_modified FROM repo_index
|
||||
WHERE (function_name LIKE ? OR file_path LIKE ?) AND user_id = ?
|
||||
ORDER BY last_modified DESC LIMIT ?
|
||||
""",
|
||||
(f"%{module}%", f"%{module}%", self.user_id, limit)
|
||||
)
|
||||
results = cursor.fetchall()
|
||||
conn.close()
|
||||
return results
|
||||
|
||||
def suggest_for_module(self, module: str) -> Optional[str]:
|
||||
"""Return a proactive suggestion string for a module if pattern detected."""
|
||||
history = self.lookup_recent_module_usage(module)
|
||||
if not history:
|
||||
return None
|
||||
# Example: For 'motor', look for L298N and PWM frequency
|
||||
l298n_count = 0
|
||||
pwm_freqs = []
|
||||
for _, content, _ in history:
|
||||
if "L298N" in content or "l298n" in content:
|
||||
l298n_count += 1
|
||||
pwm_matches = re.findall(r'PWM_FREQ\s*=\s*(\d+)', content)
|
||||
pwm_freqs.extend([int(f) for f in pwm_matches])
|
||||
# Also look for explicit frequency in analogWrite or ledcSetup
|
||||
freq_matches = re.findall(r'(?:ledcSetup|analogWrite)\s*\([^,]+,\s*[^,]+,\s*(\d+)\)', content)
|
||||
pwm_freqs.extend([int(f) for f in freq_matches if f.isdigit()])
|
||||
if l298n_count >= 2:
|
||||
freq = max(set(pwm_freqs), key=pwm_freqs.count) if pwm_freqs else 500
|
||||
return f"I see you usually use the L298N with a {freq}Hz PWM frequency on the ESP32-C3. Should I prep that module?"
|
||||
return None
|
||||
|
||||
def get_proactive_suggestion(self, user_input: str) -> Optional[str]:
|
||||
"""
|
||||
V3.0 Proactive Hook:
|
||||
1. Identify "Concept" (e.g., 'flipper')
|
||||
2. Query repo_index for James's most frequent companion modules
|
||||
3. If 'flipper' often appears with 'safety_timeout', suggest it.
|
||||
"""
|
||||
# 1. Identify Concepts
|
||||
input_lower = user_input.lower()
|
||||
detected_modules = []
|
||||
for module, keywords in MODULE_PATTERNS.items():
|
||||
if any(kw in input_lower for kw in keywords):
|
||||
detected_modules.append(module)
|
||||
|
||||
if not detected_modules:
|
||||
return None
|
||||
|
||||
# 2. Query repo_index for correlations
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
suggestions = []
|
||||
for module in detected_modules:
|
||||
# Find files containing this module (simple heuristic)
|
||||
cursor.execute("SELECT content FROM repo_index WHERE content LIKE ? AND user_id = ? LIMIT 10", (f"%{module}%", self.user_id))
|
||||
rows = cursor.fetchall()
|
||||
if not rows: continue
|
||||
|
||||
# Check for companion modules
|
||||
companions = {}
|
||||
for (content,) in rows:
|
||||
content_lower = content.lower()
|
||||
for other_mod, other_kws in MODULE_PATTERNS.items():
|
||||
if other_mod != module and other_mod not in detected_modules:
|
||||
if any(kw in content_lower for kw in other_kws):
|
||||
companions[other_mod] = companions.get(other_mod, 0) + 1
|
||||
|
||||
# 3. Suggest if frequent (>50% correlation in sample)
|
||||
for other_mod, count in companions.items():
|
||||
if count >= len(rows) * 0.5:
|
||||
suggestions.append(f"I noticed '{module}' often appears with '{other_mod}' in your repos. Want to include that?")
|
||||
|
||||
conn.close()
|
||||
return " ".join(list(set(suggestions))) if suggestions else None
|
||||
|
||||
def get_all_suggestions(self, user_input: str, generated_code: str) -> List[str]:
|
||||
"""Aggregate all proactive suggestions into a list."""
|
||||
suggestions = []
|
||||
|
||||
# 1. Companion Modules
|
||||
companion = self.get_proactive_suggestion(user_input)
|
||||
if companion:
|
||||
suggestions.append(companion)
|
||||
|
||||
# 2. Module Settings
|
||||
input_lower = user_input.lower()
|
||||
for module, keywords in MODULE_PATTERNS.items():
|
||||
if any(kw in input_lower for kw in keywords):
|
||||
s = self.suggest_for_module(module)
|
||||
if s:
|
||||
suggestions.append(s)
|
||||
|
||||
# 3. Forge Theory Check
|
||||
if ("motor" in input_lower or "servo" in input_lower) and "applyForge" not in generated_code:
|
||||
suggestions.append("Apply Forge Theory smoothing to movement?")
|
||||
|
||||
# 4. Safety Check (L298N)
|
||||
if "L298N" in generated_code and "safety" not in generated_code.lower():
|
||||
suggestions.append("Drive system lacks safety timeout (GilBot_V2 uses 5s failsafe). Add that?")
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
|
||||
class AdaptiveLearner:
|
||||
"""Learn from every interaction"""
|
||||
|
||||
def learn_from_session(self, session_id: str):
|
||||
"""Analyze what worked/failed in a session"""
|
||||
print(f"🧠 Adaptive Learning: Analyzing Session {session_id}...")
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get all messages in session
|
||||
cursor.execute("""
|
||||
SELECT id, role, content
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY id ASC
|
||||
""", (session_id,))
|
||||
|
||||
messages = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
count = 0
|
||||
# Look for correction patterns
|
||||
for i, (msg_id, role, content) in enumerate(messages):
|
||||
if role == 'user' and i > 0:
|
||||
prev_msg = messages[i-1]
|
||||
prev_role = prev_msg[1]
|
||||
prev_content = prev_msg[2]
|
||||
|
||||
if prev_role == 'assistant':
|
||||
# Did James correct the previous response?
|
||||
if self.is_correction(content, prev_content):
|
||||
print(f" - Detected correction in msg #{msg_id}")
|
||||
self.learn_correction(prev_content, content)
|
||||
count += 1
|
||||
|
||||
# Did James ask for modification?
|
||||
if self.is_modification(content):
|
||||
print(f" - Detected preference in msg #{msg_id}")
|
||||
self.learn_preference(content)
|
||||
count += 1
|
||||
|
||||
if count == 0:
|
||||
print(" - No obvious corrections found.")
|
||||
|
||||
def is_correction(self, user_msg: str, ai_msg: str) -> bool:
|
||||
"""Detect if user is correcting AI"""
|
||||
correction_signals = [
|
||||
"actually", "no,", "wrong", "should be", "instead of",
|
||||
"not", "use", "don't use", "change", "fix", "error", "bug"
|
||||
]
|
||||
return any(signal in user_msg.lower() for signal in correction_signals)
|
||||
|
||||
def is_modification(self, user_msg: str) -> bool:
|
||||
"""Detect if user is expressing a preference"""
|
||||
signals = ["prefer", "i like", "always use", "style", "better", "make it"]
|
||||
return any(s in user_msg.lower() for s in signals)
|
||||
|
||||
def learn_correction(self, original: str, correction: str):
|
||||
"""Extract the lesson from a correction"""
|
||||
# Save the rule (Generic capture for now)
|
||||
rule_text = correction.split('\n')[0][:100]
|
||||
self.save_rule(rule_text, "context_dependent", correction[:100], confidence=0.5)
|
||||
|
||||
def learn_preference(self, content: str):
|
||||
"""Extract preference"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO style_preferences (user_id, category, preference, confidence, extracted_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", ("default", "learned_preference", content[:200], 0.6, datetime.now().isoformat()))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def save_rule(self, rule_text, find, replace, confidence):
|
||||
"""Save to code_rules table"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO code_rules
|
||||
(rule_text, pattern_find, pattern_replace, confidence, learned_from)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (rule_text, find, replace, confidence, 'adaptive_session'))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
|
||||
class SmartLearner:
|
||||
"""Extract patterns from corrections"""
|
||||
|
||||
def analyze_corrections(self, ai_interface=None):
|
||||
"""Find common patterns in your fixes"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Ensure processed column exists
|
||||
try:
|
||||
cursor.execute("ALTER TABLE corrections ADD COLUMN processed BOOLEAN DEFAULT 0")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
# Check pending count
|
||||
cursor.execute("SELECT COUNT(*) FROM corrections WHERE processed IS NOT 1")
|
||||
pending_count = cursor.fetchone()[0]
|
||||
|
||||
if pending_count == 0:
|
||||
conn.close()
|
||||
return []
|
||||
|
||||
# Process in small batches
|
||||
cursor.execute("""
|
||||
SELECT id, original_code, corrected_code, reason
|
||||
FROM corrections
|
||||
WHERE processed IS NOT 1
|
||||
LIMIT 5
|
||||
""")
|
||||
|
||||
corrections = cursor.fetchall()
|
||||
print(f" Processing {len(corrections)} of {pending_count} pending corrections...")
|
||||
patterns = []
|
||||
|
||||
for row_id, original, corrected, reason in corrections:
|
||||
# Strategy 1: Diff based (if corrected code exists)
|
||||
if corrected and original:
|
||||
# Extract what changed
|
||||
diff = self.diff_code(original, corrected)
|
||||
|
||||
# Classify the change
|
||||
if "analogWrite" in original and "ledcWrite" in corrected:
|
||||
patterns.append({
|
||||
"rule": "ESP32 uses ledcWrite not analogWrite",
|
||||
"find": "analogWrite",
|
||||
"replace": "ledcWrite",
|
||||
"hardware": "ESP32",
|
||||
"confidence": 1.0
|
||||
})
|
||||
|
||||
if "delay(" in original and "millis()" in corrected:
|
||||
patterns.append({
|
||||
"rule": "Use non-blocking millis() not delay()",
|
||||
"find": "delay\\(",
|
||||
"replace": "millis() based timing",
|
||||
"confidence": 0.9
|
||||
})
|
||||
|
||||
# Strategy 2: Reason based (LLM extraction)
|
||||
if reason and ai_interface:
|
||||
print(f" - Analyzing #{row_id}...", end="\r")
|
||||
# Use LLM to extract rule from text reason
|
||||
prompt = f"""You are extracting specific coding patterns from a user correction.
|
||||
|
||||
CRITICAL INSTRUCTIONS:
|
||||
1. If the correction contains code, formulas, or specific syntax, PRESERVE IT VERBATIM.
|
||||
2. Do NOT generalize. (e.g. DO NOT say "Use a smoothing formula". SAY "Use: current += (target - current) * k")
|
||||
3. Capture specific variable names, types, and values if mentioned.
|
||||
4. If the user provides a code snippet, the rule MUST contain that snippet.
|
||||
|
||||
User Correction:
|
||||
"{reason}"
|
||||
|
||||
Return ONLY the rule in this format (no markdown, no quotes):
|
||||
Rule: <specific technical rule with exact code/formulas>
|
||||
"""
|
||||
try:
|
||||
response = ai_interface.call_model("balanced", prompt, system_task=True)
|
||||
for line in response.splitlines():
|
||||
clean_line = line.strip().replace("**", "").replace("__", "")
|
||||
rule_text = None
|
||||
if "rule:" in clean_line.lower():
|
||||
parts = clean_line.split(":", 1)
|
||||
rule_text = parts[1].strip() if len(parts) > 1 else clean_line
|
||||
elif re.match(r'^[\d-]+\.', clean_line) or clean_line.startswith("- "):
|
||||
rule_text = re.sub(r'^[\d-]+\.?\s*', '', clean_line).strip()
|
||||
|
||||
if rule_text and len(rule_text) > 10 and rule_text != reason:
|
||||
patterns.append({
|
||||
"rule": rule_text,
|
||||
"find": "",
|
||||
"replace": "",
|
||||
"confidence": 0.85
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Mark as processed immediately
|
||||
cursor.execute("UPDATE corrections SET processed = 1 WHERE id = ?", (row_id,))
|
||||
conn.commit()
|
||||
|
||||
print(" - Batch complete. ")
|
||||
conn.close()
|
||||
|
||||
# Store learned rules
|
||||
if patterns:
|
||||
self.save_rules(patterns)
|
||||
|
||||
return patterns
|
||||
|
||||
def save_rules(self, patterns):
|
||||
"""Save to code_rules table"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS code_rules (
|
||||
id INTEGER PRIMARY KEY,
|
||||
rule_text TEXT,
|
||||
pattern_find TEXT,
|
||||
pattern_replace TEXT,
|
||||
context TEXT,
|
||||
confidence FLOAT,
|
||||
learned_from TEXT,
|
||||
times_applied INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
for p in patterns:
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO code_rules
|
||||
(rule_text, pattern_find, pattern_replace, confidence, learned_from)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (p['rule'], p['find'], p['replace'], p['confidence'], 'corrections'))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def diff_code(self, original: str, corrected: str) -> str:
|
||||
"""Generate a simple diff"""
|
||||
return "\n".join(difflib.unified_diff(
|
||||
original.splitlines(),
|
||||
corrected.splitlines(),
|
||||
fromfile='original',
|
||||
tofile='corrected',
|
||||
lineterm=''
|
||||
))
|
||||
126
core/buddai_personality.py
Normal file
126
core/buddai_personality.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Union, List
|
||||
from datetime import datetime
|
||||
|
||||
class PersonalityManager:
|
||||
"""Manages AI personality, prompts, and user schedules"""
|
||||
|
||||
def __init__(self):
|
||||
self.personality = self.load_personality()
|
||||
self.validate_personality_schema()
|
||||
|
||||
def load_personality(self) -> Dict:
|
||||
"""Loads personality from a JSON file."""
|
||||
personality_path = Path(__file__).parent.parent / "personality.json"
|
||||
if personality_path.exists():
|
||||
print("🧠 Loading custom personality...")
|
||||
with open(personality_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
# Default personality if file doesn't exist
|
||||
print("🧠 Using default 'James' personality.")
|
||||
return {
|
||||
"user_name": "James",
|
||||
"ai_name": "BuddAI",
|
||||
"welcome_message": "BuddAI Executive v4.0 - Decoupled & Personality Sync",
|
||||
"schedule_check_triggers": ["what should i be doing", "my schedule", "schedule check"],
|
||||
"schedule": {
|
||||
"weekdays": {"0-4": {"5.5-6.5": "Early Morning Build Session 🌅 (5:30-6:30 AM)", "6.5-17.0": "Work Hours (Facilities Caretaker) 🏢", "17.0-21.0": "Evening Build Session 🌙 (5:00-9:00 PM)", "default": "Rest Time 💤"}},
|
||||
"saturday": { "5": { "default": "Weekend Freedom 🎨 (Creative Mode)" } },
|
||||
"sunday": { "6": { "0-21.0": "Weekend Freedom 🎨 (Until 9 PM)", "default": "Rest Time 💤" } }
|
||||
},
|
||||
"style_scan_prompt": "Analyze this code sample from {user_name}'s repositories.\nExtract 3 distinct coding preferences or patterns.",
|
||||
"style_reference_prompt": "\n[REFERENCE STYLE FROM {user_name}'S PAST PROJECTS]\n",
|
||||
"integration_task_prompt": "INTEGRATION TASK: Combine modules into a cohesive GilBot system.\n\n[MODULES]\n{modules_summary}\n\n[FORGE PARAMETERS]\nSet k = {k_val} for all applyForge() calls.\n\n[REQUIREMENTS]\n1. Implement applyForge() math helper.\n2. Use k={k_val} to smooth motor and servo transitions.\n3. Ensure naming matches {user_name}'s style: activateFlipper(), setMotors()."
|
||||
}
|
||||
|
||||
def validate_personality_schema(self) -> bool:
|
||||
"""Validate the loaded personality against required schema."""
|
||||
if not self.personality:
|
||||
return False
|
||||
|
||||
required_structure = {
|
||||
"meta": ["version"],
|
||||
"identity": ["user_name", "ai_name"],
|
||||
"communication": ["welcome_message"],
|
||||
"work_cycles": ["schedule"],
|
||||
"forge_theory": ["enabled", "constants"],
|
||||
"prompts": ["style_scan", "integration_task"]
|
||||
}
|
||||
|
||||
missing = []
|
||||
|
||||
version = self.get_value("meta.version")
|
||||
if version and version != "4.0":
|
||||
print(f"⚠️ Warning: Personality version mismatch. Loaded: {version}, Expected: 4.0")
|
||||
|
||||
for section, keys in required_structure.items():
|
||||
if section not in self.personality:
|
||||
missing.append(f"Missing section: {section}")
|
||||
continue
|
||||
|
||||
for key in keys:
|
||||
if key not in self.personality[section]:
|
||||
missing.append(f"Missing key: {section}.{key}")
|
||||
|
||||
if missing:
|
||||
print("⚠️ Personality Schema Validation Failed:")
|
||||
for m in missing:
|
||||
print(f" - {m}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_value(self, path: Union[str, List[str]], default: Any = None) -> Any:
|
||||
"""Access nested personality keys using dot notation or list of keys."""
|
||||
if isinstance(path, str):
|
||||
keys = path.split('.')
|
||||
else:
|
||||
keys = path
|
||||
|
||||
val = self.personality
|
||||
for key in keys:
|
||||
if isinstance(val, dict):
|
||||
val = val.get(key)
|
||||
else:
|
||||
return default
|
||||
return val if val is not None else default
|
||||
|
||||
def get_user_status(self) -> str:
|
||||
"""Determine user's context based on defined schedule from personality file."""
|
||||
schedule = self.get_value("work_cycles.schedule")
|
||||
if not schedule:
|
||||
# Fallback for simple personality files
|
||||
schedule = self.personality.get("schedule")
|
||||
if not schedule:
|
||||
return "Schedule not defined."
|
||||
|
||||
now = datetime.now()
|
||||
day = now.weekday() # 0=Mon, 6=Sun
|
||||
t = now.hour + (now.minute / 60.0)
|
||||
|
||||
# Check all schedule groups (e.g., 'weekdays', 'weekends')
|
||||
for group, day_ranges in schedule.items():
|
||||
for day_range, time_slots in day_ranges.items():
|
||||
try:
|
||||
# Parse day range (e.g., "0-4" or "5")
|
||||
if '-' in day_range:
|
||||
start_day, end_day = map(int, day_range.split('-'))
|
||||
if not (start_day <= day <= end_day):
|
||||
continue
|
||||
elif int(day_range) != day:
|
||||
continue
|
||||
|
||||
# We found the right day group, now check time slots
|
||||
for time_range, status in time_slots.items():
|
||||
if time_range == "default": continue
|
||||
start_time, end_time = map(float, time_range.split('-'))
|
||||
if start_time <= t < end_time:
|
||||
return status.get("description", status) if isinstance(status, dict) else status
|
||||
|
||||
default_status = time_slots.get("default", "No status for this time.")
|
||||
return default_status.get("description", default_status) if isinstance(default_status, dict) else default_status
|
||||
|
||||
except (ValueError, TypeError): continue
|
||||
return "No schedule match for today."
|
||||
310
core/buddai_prompt_engine.py
Normal file
310
core/buddai_prompt_engine.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
import sqlite3
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
from core.buddai_shared import DB_PATH, COMPLEX_TRIGGERS, MODULE_PATTERNS
|
||||
|
||||
class PromptEngine:
|
||||
"""Handles prompt construction, hardware classification, and request analysis"""
|
||||
|
||||
def classify_hardware(self, user_message: str, context_messages: List[Dict] = None) -> dict:
|
||||
"""Detect what hardware this question is about"""
|
||||
|
||||
hardware = {
|
||||
"servo": False,
|
||||
"dc_motor": False,
|
||||
"button": False,
|
||||
"led": False,
|
||||
"sensor": False,
|
||||
"weapon": False
|
||||
}
|
||||
|
||||
msg_lower = user_message.lower()
|
||||
|
||||
# Helper to check keywords
|
||||
def has_keywords(text, keywords):
|
||||
return any(word in text for word in keywords)
|
||||
|
||||
# Keyword definitions
|
||||
servo_kws = ['servo', 'mg996', 'sg90']
|
||||
motor_kws = ['l298n', 'dc motor', 'motor driver', 'motor control']
|
||||
button_kws = ['button', 'switch', 'trigger']
|
||||
led_kws = ['led', 'light', 'brightness', 'indicator']
|
||||
# Removed 'state machine' from weapon_kws to allow abstract logic
|
||||
weapon_kws = ['weapon', 'combat', 'arming', 'fire', 'spinner', 'flipper']
|
||||
logic_kws = ['state machine', 'logic', 'structure', 'flow', 'armed', 'disarmed']
|
||||
|
||||
# 1. Check current message first
|
||||
detected_in_current = False
|
||||
|
||||
if has_keywords(msg_lower, servo_kws):
|
||||
hardware["servo"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, motor_kws):
|
||||
hardware["dc_motor"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, button_kws):
|
||||
hardware["button"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, led_kws):
|
||||
hardware["led"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, weapon_kws):
|
||||
hardware["weapon"] = True
|
||||
detected_in_current = True
|
||||
if has_keywords(msg_lower, logic_kws):
|
||||
# Logic detected: Clear context (don't set any hardware)
|
||||
detected_in_current = True
|
||||
|
||||
# 2. Context Switching: Only look back if NO hardware/logic detected in current message
|
||||
# and message is short (likely a follow-up command like "make it spin")
|
||||
if not detected_in_current and len(user_message.split()) < 10 and context_messages:
|
||||
recent = " ".join([m['content'].lower() for m in context_messages[-2:] if m['role'] == 'user'])
|
||||
|
||||
if has_keywords(recent, servo_kws): hardware["servo"] = True
|
||||
if has_keywords(recent, motor_kws): hardware["dc_motor"] = True
|
||||
if has_keywords(recent, button_kws): hardware["button"] = True
|
||||
if has_keywords(recent, led_kws): hardware["led"] = True
|
||||
if has_keywords(recent, weapon_kws): hardware["weapon"] = True
|
||||
|
||||
return hardware
|
||||
|
||||
def get_all_rules(self) -> List[str]:
|
||||
"""Get all learned rules as text"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT rule_text FROM code_rules ORDER BY confidence DESC LIMIT 50")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def filter_rules_by_hardware(self, all_rules, hardware):
|
||||
"""Only return rules relevant to detected hardware"""
|
||||
|
||||
relevant_rules = []
|
||||
|
||||
# Define rule categories
|
||||
servo_kws = ['servo', 'attach', 'setperiodhertz']
|
||||
motor_kws = ['l298n', 'in1', 'in2', 'motor driver']
|
||||
weapon_kws = ['arming', 'disarm', 'fire', 'combat']
|
||||
button_kws = ['button', 'switch', 'debounce', 'digitalread', 'input_pullup']
|
||||
|
||||
has_specific_context = hardware["servo"] or hardware["dc_motor"] or hardware["weapon"] or hardware["button"]
|
||||
|
||||
for rule in all_rules:
|
||||
rule_lower = rule.lower()
|
||||
|
||||
is_servo_rule = any(w in rule_lower for w in servo_kws)
|
||||
is_motor_rule = any(w in rule_lower for w in motor_kws)
|
||||
is_weapon_rule = any(w in rule_lower for w in weapon_kws)
|
||||
is_button_rule = any(w in rule_lower for w in button_kws)
|
||||
|
||||
# Pattern Over-application: Strict filtering
|
||||
if has_specific_context:
|
||||
if hardware["dc_motor"] and not hardware["servo"] and is_servo_rule: continue
|
||||
if hardware["servo"] and not hardware["dc_motor"] and is_motor_rule: continue
|
||||
if not hardware["weapon"] and is_weapon_rule: continue
|
||||
if not hardware["button"] and is_button_rule: continue
|
||||
|
||||
# If question is about weapons (logic), EXCLUDE servo rules unless servo explicitly requested
|
||||
if hardware["weapon"] and not hardware["servo"] and is_servo_rule: continue
|
||||
|
||||
else:
|
||||
# Generic context: Exclude all specific hardware rules
|
||||
if is_servo_rule or is_motor_rule or is_weapon_rule or is_button_rule: continue
|
||||
|
||||
relevant_rules.append(rule)
|
||||
|
||||
return relevant_rules
|
||||
|
||||
def build_enhanced_prompt(self, user_message: str, hardware_detected: str = None, context_messages: List[Dict] = None) -> str:
|
||||
"""Build prompt with FILTERED rules"""
|
||||
|
||||
# Classify hardware
|
||||
hardware = self.classify_hardware(user_message, context_messages)
|
||||
|
||||
# Get ALL rules
|
||||
all_rules = self.get_all_rules()
|
||||
|
||||
# Filter by relevance
|
||||
relevant_rules = self.filter_rules_by_hardware(all_rules, hardware)
|
||||
|
||||
# Build focused prompt
|
||||
hardware_context = []
|
||||
if hardware["servo"]: hardware_context.append("SERVO CONTROL")
|
||||
if hardware["dc_motor"]: hardware_context.append("DC MOTOR CONTROL")
|
||||
if hardware["button"]: hardware_context.append("BUTTON INPUTS")
|
||||
if hardware["led"]: hardware_context.append("LED STATUS")
|
||||
if hardware["weapon"]: hardware_context.append("WEAPON SYSTEM")
|
||||
|
||||
l298n_rules = ""
|
||||
if hardware["dc_motor"]:
|
||||
l298n_rules = """
|
||||
- L298N WIRING RULES (MANDATORY):
|
||||
1. IN1/IN2 = Digital Output (Direction). Use digitalWrite().
|
||||
2. ENA = PWM Output (Speed). Use ledcWrite().
|
||||
3. To Move: IN1/IN2 must be OPPOSITE (HIGH/LOW).
|
||||
4. To Stop: IN1/IN2 both LOW.
|
||||
5. DO NOT treat Motors like Servos (No 'position' or 'angle').
|
||||
- SAFETY RULES (MANDATORY):
|
||||
1. Implement a safety timeout (e.g., 5000ms).
|
||||
2. Stop motors if no signal is received within timeout.
|
||||
3. Use millis() for non-blocking timing.
|
||||
"""
|
||||
|
||||
weapon_rules = ""
|
||||
if hardware.get("weapon"):
|
||||
weapon_rules = """
|
||||
- COMBAT PROTOCOL (MANDATORY):
|
||||
1. LOGIC FOCUS: This is a State Machine request, NOT just servo movement.
|
||||
2. STATES: enum State { DISARMED, ARMING, ARMED, FIRING };
|
||||
3. TRANSITIONS: DISARMED -> ARMING (2s delay) -> ARMED -> FIRING.
|
||||
4. SAFETY: Auto-disarm after 10s idle. Fire only when ARMED.
|
||||
5. STRUCTURE: Use switch(currentState) { case ... } for logic.
|
||||
6. OUTPUTS: Control relays/LEDs/Motors based on state.
|
||||
"""
|
||||
|
||||
# Anti-bloat rules
|
||||
anti_bloat_rules = []
|
||||
if not hardware["button"]:
|
||||
anti_bloat_rules.append("- NO EXTRA INPUTS: Do NOT add buttons, switches, or digitalRead() unless explicitly requested.")
|
||||
if not hardware["servo"]:
|
||||
anti_bloat_rules.append("- NO EXTRA SERVOS: Do NOT add Servo objects or attach() unless explicitly requested.")
|
||||
if not hardware["dc_motor"]:
|
||||
anti_bloat_rules.append("- NO EXTRA MOTORS: Do NOT add motor driver code (L298N) unless explicitly requested.")
|
||||
|
||||
anti_bloat = "\n".join(anti_bloat_rules)
|
||||
|
||||
# Modularity rule
|
||||
modularity_rule = ""
|
||||
if "function" in user_message.lower() or "naming" in user_message.lower() or "modular" in user_message.lower():
|
||||
modularity_rule = """
|
||||
- CODE STRUCTURE (MANDATORY):
|
||||
1. NO MONOLITHIC LOOP: Break code into small, descriptive functions.
|
||||
2. NAMING: Use camelCase for functions (e.g., readBatteryVoltage(), updateDisplay()).
|
||||
3. loop() must ONLY call these functions, not contain raw logic.
|
||||
"""
|
||||
|
||||
# Status LED rule
|
||||
status_led_rule = ""
|
||||
if hardware["led"] and ("status" in user_message.lower() or "indicator" in user_message.lower()):
|
||||
status_led_rule = """
|
||||
- STATUS LED RULES (MANDATORY):
|
||||
1. NO BREATHING/FADING: Do not use simple PWM fading loops.
|
||||
2. USE STATES: Define enum LEDStatus { OFF, IDLE, ACTIVE, ERROR };
|
||||
3. IMPLEMENTATION: Create void setStatusLED(LEDStatus state).
|
||||
4. PATTERNS: IDLE=Slow Blink, ACTIVE=Solid On, ERROR=Fast Blink.
|
||||
"""
|
||||
|
||||
prompt = f"""You are generating code for: {', '.join(hardware_context)}
|
||||
You are an expert embedded developer.
|
||||
TARGET HARDWARE: {hardware_detected}
|
||||
ACTIVE MODULES: {', '.join(hardware_context) if hardware_context else "None (Logic Only)"}
|
||||
|
||||
CRITICAL: Only use code patterns relevant to the hardware mentioned.
|
||||
STRICT NEGATIVE CONSTRAINTS (DO NOT IGNORE):
|
||||
{anti_bloat}
|
||||
|
||||
MANDATORY HARDWARE RULES:
|
||||
{l298n_rules}
|
||||
{weapon_rules}
|
||||
{status_led_rule}
|
||||
{anti_bloat}
|
||||
{modularity_rule}
|
||||
|
||||
GENERAL GUIDELINES:
|
||||
- If DC MOTOR: Use L298N patterns (digitalWrite, ledcWrite)
|
||||
- If SERVO: Use ESP32Servo patterns (attach, write)
|
||||
- DO NOT mix servo code into motor questions
|
||||
- DO NOT mix motor code into servo questions
|
||||
|
||||
CRITICAL RULES (MUST FOLLOW):
|
||||
{chr(10).join(relevant_rules)}
|
||||
|
||||
USER REQUEST:
|
||||
{user_message}
|
||||
|
||||
Generate code following ALL rules above. Do not add unrequested features.
|
||||
FINAL CHECK:
|
||||
1. Did you add unrequested buttons? REMOVE THEM.
|
||||
2. Did you add unrequested servos? REMOVE THEM.
|
||||
3. Generate code ONLY for the hardware requested.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def is_simple_question(self, message: str) -> bool:
|
||||
"""Check if this is a simple question that should use FAST model"""
|
||||
message_lower = message.lower()
|
||||
|
||||
simple_triggers = [
|
||||
"what is", "what's", "who is", "who's", "when is",
|
||||
"how do i", "can you explain", "tell me about",
|
||||
"what are", "where is", "hi", "hello", "hey",
|
||||
"good morning", "good evening"
|
||||
]
|
||||
|
||||
# Also check if it's just a question without code keywords
|
||||
code_keywords = ["generate", "create", "write", "build", "code", "function"]
|
||||
|
||||
has_simple_trigger = any(trigger in message_lower for trigger in simple_triggers)
|
||||
has_code_keyword = any(keyword in message_lower for keyword in code_keywords)
|
||||
|
||||
# Simple if: has simple trigger AND no code keywords
|
||||
return has_simple_trigger and not has_code_keyword
|
||||
|
||||
def is_complex(self, message: str) -> bool:
|
||||
"""Check if request is too complex and should be broken down"""
|
||||
message_lower = message.lower()
|
||||
|
||||
# Count complexity triggers
|
||||
trigger_count = sum(1 for trigger in COMPLEX_TRIGGERS if trigger in message_lower)
|
||||
|
||||
# Count how many modules mentioned
|
||||
module_count = 0
|
||||
for module, keywords in MODULE_PATTERNS.items():
|
||||
# module is used for key, keywords for values
|
||||
if any(kw in message_lower for kw in keywords):
|
||||
module_count += 1
|
||||
|
||||
# Complex if: multiple triggers OR 3+ modules mentioned
|
||||
return trigger_count >= 2 or module_count >= 3
|
||||
|
||||
def extract_modules(self, message: str) -> List[str]:
|
||||
"""Extract which modules are needed"""
|
||||
message_lower = message.lower()
|
||||
needed_modules = []
|
||||
|
||||
for module, keywords in MODULE_PATTERNS.items():
|
||||
# module is used for key, keywords for values
|
||||
if any(kw in message_lower for kw in keywords):
|
||||
needed_modules.append(module)
|
||||
|
||||
return needed_modules
|
||||
|
||||
def build_modular_plan(self, modules: List[str]) -> List[Dict[str, str]]:
|
||||
"""Create a build plan from modules"""
|
||||
plan = []
|
||||
|
||||
module_tasks = {
|
||||
"ble": "BLE communication setup with phone app control",
|
||||
"servo": "Servo motor control for flipper/weapon",
|
||||
"motor": "Motor driver setup for movement (L298N)",
|
||||
"safety": "Safety timeout and failsafe systems",
|
||||
"battery": "Battery voltage monitoring",
|
||||
"sensor": "Sensor integration (distance/proximity)"
|
||||
}
|
||||
|
||||
for module in modules:
|
||||
if module in module_tasks:
|
||||
plan.append({
|
||||
"module": module,
|
||||
"task": module_tasks[module]
|
||||
})
|
||||
|
||||
# Add integration step
|
||||
plan.append({
|
||||
"module": "integration",
|
||||
"task": "Integrate all modules into complete system"
|
||||
})
|
||||
|
||||
return plan
|
||||
53
core/buddai_shared.py
Normal file
53
core/buddai_shared.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
import queue
|
||||
import http.client
|
||||
|
||||
# Global Config
|
||||
DATA_DIR = Path(__file__).parent.parent / "data"
|
||||
DB_PATH = DATA_DIR / "conversations.db"
|
||||
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "127.0.0.1")
|
||||
OLLAMA_PORT = int(os.getenv("OLLAMA_PORT", "11434"))
|
||||
|
||||
# Shared Models
|
||||
MODELS = {
|
||||
"fast": "qwen2.5-coder:1.5b",
|
||||
"balanced": "qwen2.5-coder:3b"
|
||||
}
|
||||
|
||||
# Shared Connection Pool logic to avoid "port in use" or "too many connections" errors
|
||||
class OllamaConnectionPool:
|
||||
def __init__(self, host, port, max_size=10):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.pool = queue.Queue(maxsize=max_size)
|
||||
def get_connection(self):
|
||||
try: return self.pool.get_nowait()
|
||||
except: return http.client.HTTPConnection(self.host, self.port, timeout=90)
|
||||
def return_connection(self, conn):
|
||||
try: self.pool.put_nowait(conn)
|
||||
except: conn.close()
|
||||
|
||||
OLLAMA_POOL = OllamaConnectionPool(OLLAMA_HOST, OLLAMA_PORT)
|
||||
|
||||
# Server Availability Check
|
||||
try:
|
||||
import fastapi
|
||||
import uvicorn
|
||||
SERVER_AVAILABLE = True
|
||||
except ImportError:
|
||||
SERVER_AVAILABLE = False
|
||||
|
||||
# Shared Patterns
|
||||
COMPLEX_TRIGGERS = [
|
||||
"multiple modules", "integrate", "combine", "modular", "state machine", "safety", "failsafe", "logic", "protocol", "integration"
|
||||
]
|
||||
MODULE_PATTERNS = {
|
||||
"ble": ["ble", "bluetooth", "phone app", "remote"],
|
||||
"servo": ["servo", "flipper", "arm", "mg996", "sg90"],
|
||||
"motor": ["motor", "drive", "l298n", "movement", "wheels"],
|
||||
"safety": ["safety", "timeout", "failsafe", "emergency"],
|
||||
"battery": ["battery", "voltage", "power"],
|
||||
"sensor": ["sensor", "distance", "proximity", "ultrasonic", "ir"]
|
||||
}
|
||||
240
core/buddai_storage.py
Normal file
240
core/buddai_storage.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
import sqlite3
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from core.buddai_shared import DB_PATH, DATA_DIR
|
||||
|
||||
class StorageManager:
|
||||
"""Manages Database, Sessions, and Backups"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.current_session_id = None
|
||||
self.ensure_data_dir()
|
||||
self.init_database()
|
||||
self.start_new_session()
|
||||
|
||||
def ensure_data_dir(self) -> None:
|
||||
DATA_DIR.mkdir(exist_ok=True)
|
||||
|
||||
def init_database(self) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Core Tables
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
started_at TIMESTAMP,
|
||||
ended_at TIMESTAMP,
|
||||
title TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT,
|
||||
role TEXT,
|
||||
content TEXT,
|
||||
timestamp TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS repo_index (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT,
|
||||
file_path TEXT,
|
||||
repo_name TEXT,
|
||||
function_name TEXT,
|
||||
content TEXT,
|
||||
last_modified TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS style_preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT,
|
||||
category TEXT,
|
||||
preference TEXT,
|
||||
confidence FLOAT,
|
||||
extracted_at TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS feedback (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id INTEGER,
|
||||
positive BOOLEAN,
|
||||
comment TEXT,
|
||||
timestamp TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS corrections (
|
||||
id INTEGER PRIMARY KEY,
|
||||
timestamp TEXT,
|
||||
original_code TEXT,
|
||||
corrected_code TEXT,
|
||||
reason TEXT,
|
||||
context TEXT,
|
||||
processed BOOLEAN DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS compilation_log (
|
||||
id INTEGER PRIMARY KEY,
|
||||
timestamp TEXT,
|
||||
code TEXT,
|
||||
success BOOLEAN,
|
||||
errors TEXT,
|
||||
hardware TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS code_rules (
|
||||
id INTEGER PRIMARY KEY,
|
||||
rule_text TEXT,
|
||||
pattern_find TEXT,
|
||||
pattern_replace TEXT,
|
||||
context TEXT,
|
||||
confidence FLOAT,
|
||||
learned_from TEXT,
|
||||
times_applied INTEGER DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
# Migrations (Idempotent)
|
||||
try: cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE sessions ADD COLUMN user_id TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE repo_index ADD COLUMN user_id TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE style_preferences ADD COLUMN user_id TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE feedback ADD COLUMN comment TEXT")
|
||||
except: pass
|
||||
try: cursor.execute("ALTER TABLE corrections ADD COLUMN processed BOOLEAN DEFAULT 0")
|
||||
except: pass
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create_session(self) -> str:
|
||||
now = datetime.now()
|
||||
base_id = now.strftime("%Y%m%d_%H%M%S")
|
||||
session_id = base_id
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO sessions (session_id, user_id, started_at) VALUES (?, ?, ?)",
|
||||
(session_id, self.user_id, now.isoformat())
|
||||
)
|
||||
conn.commit()
|
||||
break
|
||||
except sqlite3.IntegrityError:
|
||||
counter += 1
|
||||
session_id = f"{base_id}_{counter}"
|
||||
|
||||
conn.close()
|
||||
return session_id
|
||||
|
||||
def start_new_session(self) -> str:
|
||||
self.current_session_id = self.create_session()
|
||||
return self.current_session_id
|
||||
|
||||
def end_session(self) -> None:
|
||||
if not self.current_session_id: return
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE sessions SET ended_at = ? WHERE session_id = ?",
|
||||
(datetime.now().isoformat(), self.current_session_id)
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def save_message(self, role: str, content: str) -> int:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO messages (session_id, role, content, timestamp) VALUES (?, ?, ?, ?)",
|
||||
(self.current_session_id, role, content, datetime.now().isoformat())
|
||||
)
|
||||
msg_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return msg_id
|
||||
|
||||
def get_sessions(self, limit: int = 20) -> List[Dict[str, str]]:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT session_id, started_at, title FROM sessions WHERE user_id = ? ORDER BY started_at DESC LIMIT ?", (self.user_id, limit))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{"id": r[0], "date": r[1], "title": r[2] if len(r) > 2 else None} for r in rows]
|
||||
|
||||
def rename_session(self, session_id: str, new_title: str) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("UPDATE sessions SET title = ? WHERE session_id = ? AND user_id = ?", (new_title, session_id, self.user_id))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def delete_session(self, session_id: str) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM sessions WHERE session_id = ? AND user_id = ?", (session_id, self.user_id))
|
||||
if cursor.rowcount > 0:
|
||||
cursor.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def clear_current_session(self) -> None:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM messages WHERE session_id = ?", (self.current_session_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def load_session(self, session_id: str) -> List[Dict[str, str]]:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT 1 FROM sessions WHERE session_id = ? AND user_id = ?", (session_id, self.user_id))
|
||||
if not cursor.fetchone():
|
||||
conn.close()
|
||||
return []
|
||||
|
||||
cursor.execute("SELECT id, role, content, timestamp FROM messages WHERE session_id = ? ORDER BY id ASC", (session_id,))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
self.current_session_id = session_id
|
||||
return [{"id": r[0], "role": r[1], "content": r[2], "timestamp": r[3]} for r in rows]
|
||||
|
||||
def create_backup(self) -> Tuple[bool, str]:
|
||||
if not DB_PATH.exists(): return False, "Database file not found."
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_dir = DATA_DIR / "backups"
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
backup_path = backup_dir / f"conversations_{timestamp}.db"
|
||||
try:
|
||||
src = sqlite3.connect(DB_PATH); dst = sqlite3.connect(backup_path)
|
||||
with dst: src.backup(dst)
|
||||
dst.close(); src.close()
|
||||
return True, str(backup_path)
|
||||
except Exception as e: return False, str(e)
|
||||
41
core/buddai_training.py
Normal file
41
core/buddai_training.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import sqlite3
|
||||
import json
|
||||
from core.buddai_shared import DB_PATH, DATA_DIR
|
||||
|
||||
class ModelFineTuner:
|
||||
"""Fine-tune local model on YOUR corrections"""
|
||||
|
||||
def prepare_training_data(self):
|
||||
"""Convert corrections to training format"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT original_code, corrected_code, reason
|
||||
FROM corrections
|
||||
""")
|
||||
|
||||
training_data = []
|
||||
for original, corrected, reason in cursor.fetchall():
|
||||
training_data.append({
|
||||
"prompt": f"Generate code for: {reason}",
|
||||
"completion": corrected,
|
||||
"negative_example": original
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
# Save as JSONL for fine-tuning
|
||||
output_path = DATA_DIR / 'training_data.jsonl'
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for item in training_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
return f"Exported {len(training_data)} examples to {output_path}"
|
||||
|
||||
def fine_tune_model(self):
|
||||
"""Fine-tune Qwen on your corrections"""
|
||||
# This requires:
|
||||
# 1. Export training data
|
||||
# 2. Use Ollama modelfile or external training
|
||||
# 3. Create custom model: qwen2.5-coder-james:3b
|
||||
pass
|
||||
524
core/buddai_validation.py
Normal file
524
core/buddai_validation.py
Normal file
|
|
@ -0,0 +1,524 @@
|
|||
import re
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
class CodeValidator:
|
||||
"""Validate generated code before showing to user"""
|
||||
|
||||
def find_line(self, code: str, substring: str) -> int:
|
||||
for i, line in enumerate(code.splitlines(), 1):
|
||||
if substring in line:
|
||||
return i
|
||||
return -1
|
||||
|
||||
def has_safety_timeout(self, code: str) -> bool:
|
||||
# Simple heuristic: needs millis, subtraction, and a comparison to a value/constant
|
||||
# We want to avoid matching debounce logic (usually < 100ms)
|
||||
if "millis()" not in code: return False
|
||||
|
||||
# Check for constants like SAFETY_TIMEOUT, MOTOR_TIMEOUT
|
||||
if re.search(r'>\s*[A-Z_]*TIMEOUT', code):
|
||||
return True
|
||||
|
||||
# Check for state machine timeout (Combat Protocol)
|
||||
if "DISARM" in code and "millis" in code and ">" in code:
|
||||
return True
|
||||
|
||||
# Check for numeric literals > 500 (Debounce is usually 50)
|
||||
comparisons = re.findall(r'>\s*(\d+)', code)
|
||||
return any(int(val) > 500 for val in comparisons)
|
||||
|
||||
def matches_style(self, code: str) -> bool:
|
||||
# Placeholder for style matching logic
|
||||
return True
|
||||
|
||||
def apply_style(self, code: str) -> str:
|
||||
# Placeholder for style application
|
||||
return code
|
||||
|
||||
def refactor_loop_to_function(self, code: str) -> str:
|
||||
"""Extract loop body into runSystemLogic()"""
|
||||
loop_match = re.search(r'void\s+loop\s*\(\s*\)\s*\{', code)
|
||||
if not loop_match: return code
|
||||
|
||||
start_idx = loop_match.end()
|
||||
brace_count = 1
|
||||
loop_body_end = -1
|
||||
|
||||
for i, char in enumerate(code[start_idx:], start=start_idx):
|
||||
if char == '{': brace_count += 1
|
||||
elif char == '}': brace_count -= 1
|
||||
|
||||
if brace_count == 0:
|
||||
loop_body_end = i
|
||||
break
|
||||
|
||||
if loop_body_end == -1: return code
|
||||
|
||||
body = code[start_idx:loop_body_end]
|
||||
new_code = code[:start_idx] + "\n runSystemLogic();\n" + code[loop_body_end:]
|
||||
new_code += "\n\nvoid runSystemLogic() {" + body + "}\n"
|
||||
return new_code
|
||||
|
||||
def validate(self, code: str, hardware: str, user_message: str = "") -> Tuple[bool, List[Dict]]:
|
||||
"""Check code against known rules"""
|
||||
issues = []
|
||||
|
||||
# Check 1: ESP32 PWM
|
||||
if "ESP32" in hardware.upper():
|
||||
if "analogWrite" in code:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, "analogWrite"),
|
||||
"message": "ESP32 doesn't support analogWrite(). Use ledcWrite()",
|
||||
"fix": lambda c: c.replace("analogWrite", "ledcWrite")
|
||||
})
|
||||
|
||||
# Check 2: Non-blocking code
|
||||
if "delay(" in code and "motor" in code.lower():
|
||||
issues.append({
|
||||
"severity": "warning",
|
||||
"line": self.find_line(code, "delay"),
|
||||
"message": "Using delay() in motor code blocks safety checks",
|
||||
"fix": lambda c: c # No auto-fix
|
||||
})
|
||||
|
||||
# Check 3: Safety timeout
|
||||
if ("motor" in code.lower() or "servo" in code.lower()):
|
||||
if not self.has_safety_timeout(code):
|
||||
# Context-aware stop logic
|
||||
is_servo = "Servo" in code and "L298N" not in code
|
||||
stop_logic = " // STOP MOTORS\n ledcWrite(0, 0);\n ledcWrite(1, 0);"
|
||||
if is_servo:
|
||||
stop_logic = " // STOP SERVO\n // Implement safe position (e.g. myServo.write(90));"
|
||||
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Critical: No safety timeout detected (must be > 500ms).",
|
||||
"fix": lambda c, sl=stop_logic: "#define SAFETY_TIMEOUT 5000\nunsigned long lastCommand = 0;\n" + \
|
||||
re.sub(r'(void\s+loop\s*\(\s*\)\s*\{)', \
|
||||
rf'\1\n // [AUTO-FIX] Safety Timeout\n if (millis() - lastCommand > SAFETY_TIMEOUT) {{\n{sl}\n }}\n', c)
|
||||
})
|
||||
|
||||
# Check 4: L298N PWM Pin Misuse
|
||||
pwm_pins = re.findall(r'ledcAttachPin\s*\(\s*(\w+)\s*,', code)
|
||||
for pin in pwm_pins:
|
||||
# Check if digitalWrite is used on this pin
|
||||
if re.search(r'digitalWrite\s*\(\s*' + re.escape(pin) + r'\s*,', code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, f"digitalWrite({pin}"),
|
||||
"message": f"Conflict: PWM pin '{pin}' used with digitalWrite(). Use ledcWrite() for speed control.",
|
||||
"fix": lambda c, p=pin: re.sub(r'digitalWrite\s*\(\s*' + re.escape(p) + r'\s*,\s*[^)]+\);?', f'// [Fixed] Removed conflicting digitalWrite on PWM pin {p}', c)
|
||||
})
|
||||
|
||||
# Check 5: Broken Debounce Logic (Type Mismatch)
|
||||
# Example: if (buttonState != lastDebounceTime)
|
||||
bad_debounce = re.search(r'if\s*\(\s*\w+\s*[!=]=\s*\w*DebounceTime\s*\)', code)
|
||||
if bad_debounce:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, bad_debounce.group(0)),
|
||||
"message": "Type Mismatch: Comparing button state (int) with time (long).",
|
||||
"fix": lambda c: c.replace(bad_debounce.group(0), "if ((millis() - lastDebounceTime) > debounceDelay)")
|
||||
})
|
||||
|
||||
# Check 6: Safety Timeout Value
|
||||
timeout_match = re.search(r'#define\s+SAFETY_TIMEOUT\s+(\d+)', code)
|
||||
if timeout_match and int(timeout_match.group(1)) > 5000:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, timeout_match.group(0)),
|
||||
"message": f"Safety timeout {timeout_match.group(1)}ms is too long (Max: 5000ms).",
|
||||
"fix": lambda c: re.sub(r'(#define\s+SAFETY_TIMEOUT\s+)\d+', r'\g<1>5000', c)
|
||||
})
|
||||
|
||||
# Check 7: Broken Safety Timer Logic (Static Init)
|
||||
bad_static = re.search(r'static\s+unsigned\s+long\s+(\w+)\s*=\s*millis\(\);', code)
|
||||
if bad_static:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, bad_static.group(0)),
|
||||
"message": "Static timer initialized with millis() prevents reset. Initialize to 0.",
|
||||
"fix": lambda c: c.replace(bad_static.group(0), f"static unsigned long {bad_static.group(1)} = 0;")
|
||||
})
|
||||
|
||||
# Check 8: Incomplete Motor Logic (L298N Validation)
|
||||
# If user explicitly asks for L298N or DC Motor, OR asks for 'motor' without 'servo'
|
||||
is_l298n_request = "l298n" in user_message.lower() or "dc motor" in user_message.lower() or ("motor" in user_message.lower() and "servo" not in user_message.lower())
|
||||
|
||||
if is_l298n_request:
|
||||
# 1. Check for Direction Pins (IN1/IN2)
|
||||
if not re.search(r'(?:#define|const\s+int)\s+\w*(?:IN1|IN2|DIR)\w*', code, re.IGNORECASE):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Missing L298N Direction Pins (IN1/IN2).",
|
||||
"fix": lambda c: "// [AUTO-FIX] L298N Definitions\n#define IN1 18\n#define IN2 19\n" + c
|
||||
})
|
||||
|
||||
# 2. Check for PWM Pin (ENA)
|
||||
if not re.search(r'(?:#define|const\s+int)\s+\w*(?:ENA|ENB|PWM)\w*', code, re.IGNORECASE):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Missing L298N PWM Pin (ENA).",
|
||||
"fix": lambda c: "#define ENA 21 // [AUTO-FIX] Missing PWM Pin\n" + c
|
||||
})
|
||||
|
||||
# 3. Check for Direction Control (digitalWrite)
|
||||
if "digitalWrite" not in code:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "L298N requires digitalWrite() for direction control.",
|
||||
"fix": lambda c: re.sub(r'(void\s+loop\s*\(\s*\)\s*\{)', r'\1\n // [AUTO-FIX] Set Direction\n digitalWrite(IN1, HIGH);\n digitalWrite(IN2, LOW);\n', c)
|
||||
})
|
||||
|
||||
# Check 9: Unnecessary Wire.h
|
||||
wire_include = re.search(r'#include\s+[<"]Wire\.h[>"]', code)
|
||||
if wire_include:
|
||||
# Check if Wire is actually used (excluding the include itself)
|
||||
rest_of_code = code.replace(wire_include.group(0), "")
|
||||
if not re.search(r'\bWire\b', rest_of_code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, wire_include.group(0)),
|
||||
"message": "Unnecessary #include <Wire.h> detected.",
|
||||
"fix": lambda c: re.sub(r'#include\s+[<"]Wire\.h[>"]', '// [Auto-Fix] Removed unnecessary Wire.h', c)
|
||||
})
|
||||
|
||||
# Check 10: High-Frequency Serial Logging
|
||||
if ("Serial.print" in code or "Serial.write" in code) and \
|
||||
("motor" in code.lower() or "servo" in code.lower()):
|
||||
# Check for throttling pattern (simple heuristic for timer variables)
|
||||
if not re.search(r'(print|log|debug|serial)\s*Timer', code, re.IGNORECASE) and \
|
||||
not re.search(r'last\s*(Print|Log|Debug)', code, re.IGNORECASE):
|
||||
issues.append({
|
||||
"severity": "warning",
|
||||
"line": self.find_line(code, "Serial.print"),
|
||||
"message": "Serial logging in motor loops causes jitter. Ensure it's throttled (e.g. every 100ms).",
|
||||
"fix": lambda c: c + "\n// [Performance] Warning: Serial.print() inside loops can interrupt motor timing."
|
||||
})
|
||||
|
||||
# Check 11: Feature Bloat (Unrequested Button)
|
||||
if user_message:
|
||||
msg_lower = user_message.lower()
|
||||
# If user didn't ask for inputs/buttons
|
||||
if not any(w in msg_lower for w in ['button', 'switch', 'input', 'trigger']):
|
||||
# Pattern 1: Variable assignment (int btn = digitalRead(...))
|
||||
for match in re.finditer(r'(?:int|bool|byte)\s+(\w*(?:button|btn|switch)\w*)\s*=\s*digitalRead\s*\([^;]+;', code, re.IGNORECASE):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, match.group(0)),
|
||||
"message": f"Feature Bloat: Unrequested button code detected ('{match.group(1)}').",
|
||||
"fix": lambda c, m=match.group(0): c.replace(m, "")
|
||||
})
|
||||
|
||||
# Pattern 2: Direct usage in conditions (if (digitalRead(BUTTON_PIN)...))
|
||||
for match in re.finditer(r'digitalRead\s*\(\s*(\w*(?:BUTTON|BTN|SWITCH)\w*)\s*\)', code, re.IGNORECASE):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, match.group(0)),
|
||||
"message": f"Feature Bloat: Unrequested button check detected ('{match.group(1)}').",
|
||||
"fix": lambda c, m=match.group(0): c.replace(m, "0")
|
||||
})
|
||||
|
||||
# Pattern 3: pinMode(..., INPUT)
|
||||
for match in re.finditer(r'pinMode\s*\(\s*\w+\s*,\s*INPUT(?:_PULLUP)?\s*\);', code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, match.group(0)),
|
||||
"message": "Feature Bloat: Unrequested input pin configuration.",
|
||||
"fix": lambda c, m=match.group(0): c.replace(m, "")
|
||||
})
|
||||
|
||||
# Pattern 4: Unused button variable initialization (int btn = LOW;)
|
||||
for match in re.finditer(r'(?:int|bool|byte)\s+(\w*(?:button|btn|switch)\w*)\s*=\s*(?:LOW|HIGH|0|1|false|true)\s*;', code, re.IGNORECASE):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, match.group(0)),
|
||||
"message": f"Feature Bloat: Unused button variable '{match.group(1)}'.",
|
||||
"fix": lambda c, m=match.group(0): c.replace(m, "")
|
||||
})
|
||||
|
||||
# Check 14: State Machine for Weapons (Combat Protocol)
|
||||
if "weapon" in user_message.lower() or "combat" in user_message.lower() or "state machine" in user_message.lower():
|
||||
if "enum" not in code and "bool isArmed" not in code:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Combat code requires a State Machine (enum State or bool isArmed).",
|
||||
"fix": lambda c: c.replace("void setup", "\n// [AUTO-FIX] State Machine\nenum State { DISARMED, ARMING, ARMED, FIRING };\nState currentState = DISARMED;\nunsigned long stateTimer = 0;\n\nvoid setup") if "void setup" in c else "// [AUTO-FIX] State Machine\nenum State { DISARMED, ARMING, ARMED, FIRING };\nState currentState = DISARMED;\n" + c
|
||||
})
|
||||
|
||||
if "Serial.read" not in code and "Serial.available" not in code:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Missing Serial Command handling (e.g., 'A' to Arm).",
|
||||
"fix": lambda c: c.replace("void loop() {", "void loop() {\n if (Serial.available()) {\n char cmd = Serial.read();\n // Handle commands\n }\n")
|
||||
})
|
||||
|
||||
# Check 15: Function Naming Conventions (camelCase)
|
||||
# Exclude standard Arduino functions
|
||||
func_defs = re.finditer(r'\b(void|int|bool|float|double|String|char|long|unsigned(?:\s+long)?)\s+([a-zA-Z0-9_]+)\s*\(', code)
|
||||
for match in func_defs:
|
||||
func_name = match.group(2)
|
||||
if func_name in ['setup', 'loop', 'main']: continue
|
||||
|
||||
# Check if camelCase (starts with lowercase, no underscores unless specific style)
|
||||
if not re.match(r'^[a-z][a-zA-Z0-9]*$', func_name):
|
||||
# Check if it's snake_case or PascalCase
|
||||
suggestion = func_name
|
||||
if '_' in func_name: # snake_case -> camelCase
|
||||
components = func_name.split('_')
|
||||
suggestion = components[0].lower() + ''.join(x.title() for x in components[1:])
|
||||
elif func_name[0].isupper(): # PascalCase -> camelCase
|
||||
suggestion = func_name[0].lower() + func_name[1:]
|
||||
|
||||
issues.append({
|
||||
"severity": "warning",
|
||||
"line": self.find_line(code, match.group(0)),
|
||||
"message": f"Style: Function '{func_name}' should be camelCase (e.g., '{suggestion}').",
|
||||
"fix": lambda c, old=func_name, new=suggestion: c.replace(old, new)
|
||||
})
|
||||
|
||||
# Check 16: Monolithic Code Structure
|
||||
if "function" in user_message.lower() or "naming" in user_message.lower() or "modular" in user_message.lower():
|
||||
has_custom_funcs = False
|
||||
for match in re.finditer(r'\b(void|int|bool|float|double|String|char|long|unsigned(?:\s+long)?)\s+([a-zA-Z0-9_]+)\s*\(', code):
|
||||
if match.group(2) not in ['setup', 'loop', 'main']:
|
||||
has_custom_funcs = True
|
||||
break
|
||||
|
||||
if not has_custom_funcs:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Structure Violation: Request asked for functions but code is monolithic.",
|
||||
"fix": lambda c: c.replace("void loop() {", "void loop() {\n runSystemLogic();\n}\n\nvoid runSystemLogic() {") + "\n}"
|
||||
})
|
||||
|
||||
# Check 17: Loop Length (Modularity)
|
||||
if "function" in user_message.lower() or "naming" in user_message.lower() or "modular" in user_message.lower():
|
||||
loop_match = re.search(r'void\s+loop\s*\(\s*\)\s*\{', code)
|
||||
if loop_match:
|
||||
start_idx = loop_match.end()
|
||||
brace_count = 1
|
||||
loop_body = ""
|
||||
|
||||
for char in code[start_idx:]:
|
||||
if char == '{': brace_count += 1
|
||||
elif char == '}': brace_count -= 1
|
||||
|
||||
if brace_count == 0:
|
||||
break
|
||||
loop_body += char
|
||||
|
||||
# Count significant lines
|
||||
lines = [line.strip() for line in loop_body.split('\n')]
|
||||
significant_lines = [l for l in lines if l and not l.startswith('//') and not l.startswith('/*') and l != '']
|
||||
|
||||
if len(significant_lines) >= 10:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": f"Modularity Violation: loop() has {len(significant_lines)} lines (limit 10). Move logic to functions.",
|
||||
"fix": lambda c: self.refactor_loop_to_function(c)
|
||||
})
|
||||
|
||||
# Check 18: ADC Resolution (ESP32)
|
||||
if "ESP32" in hardware.upper():
|
||||
adc_res_match = re.search(r'#define\s+(\w*ADC\w*RES\w*)\s+(\d+)', code, re.IGNORECASE)
|
||||
if adc_res_match:
|
||||
val = int(adc_res_match.group(2))
|
||||
if val not in [4095, 4096]:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, adc_res_match.group(0)),
|
||||
"message": f"Hardware Mismatch: ESP32 ADC is 12-bit (4095), not {val}.",
|
||||
"fix": lambda c, old=adc_res_match.group(0), name=adc_res_match.group(1): c.replace(old, f"#define {name} 4095")
|
||||
})
|
||||
|
||||
# Check 20: Hardcoded 10-bit ADC math
|
||||
# Matches / 1023, / 1023.0, / 1024.0 (avoiding / 1024 int for bytes)
|
||||
for match in re.finditer(r'/\s*(1023(?:\.0?)?f?|1024(?:\.0)f?)', code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, match.group(0)),
|
||||
"message": "Hardware Mismatch: ESP32 ADC is 12-bit. Use 4095.0, not 1023/1024.",
|
||||
"fix": lambda c, m=match.group(0): c.replace(m, "/ 4095.0")
|
||||
})
|
||||
|
||||
# Check 21: Status LED Pattern
|
||||
if "status" in user_message.lower() and ("led" in user_message.lower() or "indicator" in user_message.lower()):
|
||||
# Detect breathing logic (incrementing duty cycle in loop)
|
||||
breathing_match = re.search(r'(?:dutyCycle|brightness)\s*(\+=|\+\+|\-=|\-\-)', code)
|
||||
if breathing_match:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, breathing_match.group(0)),
|
||||
"message": "Wrong Pattern: Status indicators should use Blink Patterns (States), not Breathing/Fading.",
|
||||
"fix": lambda c: c + "\n// [Fix Required] Implement setStatusLED(LEDStatus state) instead of fading."
|
||||
})
|
||||
|
||||
# Check for missing Enum
|
||||
if not re.search(r'enum\s+(?:StatusState|LEDStatus)\s*\{', code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Missing Status Enum: Status LEDs require a state machine (enum LEDStatus {OFF, IDLE, ACTIVE, ERROR}).",
|
||||
"fix": lambda c: c.replace("void setup", "\n// [AUTO-FIX] Status Enum\nenum LEDStatus { OFF, IDLE, ACTIVE, ERROR };\nLEDStatus currentStatus = IDLE;\nunsigned long lastBlink = 0;\n\nvoid setup") if "void setup" in c else "// [AUTO-FIX] Status Enum\nenum LEDStatus { OFF, IDLE, ACTIVE, ERROR };\nLEDStatus currentStatus = IDLE;\nunsigned long lastBlink = 0;\n" + c
|
||||
})
|
||||
|
||||
# Check 19: Unnecessary Debouncing (Analog/Battery)
|
||||
if "battery" in user_message.lower() or "voltage" in user_message.lower() or "analog" in user_message.lower():
|
||||
if "button" not in user_message.lower():
|
||||
debounce_match = re.search(r'(?:debounce|lastDebounceTime)', code, re.IGNORECASE)
|
||||
if debounce_match:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, debounce_match.group(0)),
|
||||
"message": "Logic Error: Debouncing detected in analog/battery code. Analog sensors don't need debouncing.",
|
||||
"fix": lambda c: re.sub(r'.*debounce.*', '// [Fixed] Removed unnecessary debounce logic', c, flags=re.IGNORECASE)
|
||||
})
|
||||
|
||||
# Check 12: Undefined Pin Constants
|
||||
pin_vars = set(re.findall(r'(?:digitalRead|digitalWrite|pinMode|ledcAttachPin)\s*\(\s*([a-zA-Z_]\w+)', code))
|
||||
for var in pin_vars:
|
||||
if var in ['LED_BUILTIN', 'HIGH', 'LOW', 'INPUT', 'OUTPUT', 'INPUT_PULLUP', 'true', 'false']:
|
||||
continue
|
||||
|
||||
# Check if defined
|
||||
is_defined = re.search(r'#define\s+' + re.escape(var) + r'\b', code) or \
|
||||
re.search(r'\b(?:const\s+)?(?:int|byte|uint8_t|short)\s+' + re.escape(var) + r'\s*=', code)
|
||||
|
||||
if not is_defined:
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": f"Undefined variable '{var}' used in pin operation.",
|
||||
"fix": lambda c, v=var: f"#define {v} 2 // [Auto-Fix] Defined missing pin\n" + c
|
||||
})
|
||||
|
||||
# Check 22: Misused Debouncing (Animation Timing)
|
||||
if "brightness" in code or "fade" in code:
|
||||
misused_debounce = re.search(r'if\s*\(\s*\(?\s*millis\(\)\s*-\s*\w+\s*\)?\s*>\s*(\w*DEBOUNCE\w*)\s*\)\s*\{', code, re.IGNORECASE)
|
||||
if misused_debounce:
|
||||
var_name = misused_debounce.group(1)
|
||||
# Check if the block actually modifies brightness (simple heuristic lookahead)
|
||||
start_index = misused_debounce.end()
|
||||
snippet = code[start_index:start_index+200]
|
||||
if any(x in snippet for x in ['brightness', 'fade', 'dutyCycle', 'ledcWrite']):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"line": self.find_line(code, var_name),
|
||||
"message": f"Semantic Error: Using {var_name} for animation/fading. Use UPDATE_INTERVAL or FADE_SPEED.",
|
||||
"fix": lambda c, v=var_name: c.replace(v, "FADE_SPEED" if v.isupper() else "fadeSpeed")
|
||||
})
|
||||
|
||||
# Check 24: Unused Variables in Setup
|
||||
setup_match = re.search(r'void\s+setup\s*\(\s*\)\s*\{', code)
|
||||
if setup_match:
|
||||
start_idx = setup_match.end()
|
||||
brace_count = 1
|
||||
setup_body = ""
|
||||
for char in code[start_idx:]:
|
||||
if char == '{': brace_count += 1
|
||||
elif char == '}': brace_count -= 1
|
||||
if brace_count == 0: break
|
||||
setup_body += char
|
||||
|
||||
clean_body = re.sub(r'//.*', '', setup_body)
|
||||
clean_body = re.sub(r'/\*.*?\*/', '', clean_body, flags=re.DOTALL)
|
||||
|
||||
local_vars = re.finditer(r'\b((?:static\s+)?(?:const\s+)?(?:int|float|bool|char|String|long|double|byte|uint8_t|unsigned(?:\s+long)?))\s+([a-zA-Z_]\w*)\s*(?:=|;)', clean_body)
|
||||
|
||||
for match in local_vars:
|
||||
var_type = match.group(1)
|
||||
var_name = match.group(2)
|
||||
if len(re.findall(r'\b' + re.escape(var_name) + r'\b', clean_body)) == 1:
|
||||
issues.append({
|
||||
"severity": "warning",
|
||||
"line": self.find_line(code, f"{var_type} {var_name}"),
|
||||
"message": f"Unused variable '{var_name}' in setup().",
|
||||
"fix": lambda c, v=var_name, t=var_type: re.sub(r'\b' + re.escape(t) + r'\s+' + re.escape(v) + r'[^;]*;\s*', '', c)
|
||||
})
|
||||
|
||||
# Check 25: Missing Serial.begin
|
||||
if re.search(r'Serial\.(?:print|write|println|printf)', code) and not re.search(r'Serial\.begin\s*\(', code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Missing Serial.begin() initialization.",
|
||||
"fix": lambda c: re.sub(r'void\s+setup\s*\(\s*\)\s*\{', r'void setup() {\n Serial.begin(115200);', c, count=1)
|
||||
})
|
||||
|
||||
# Check 26: Missing Wire.begin
|
||||
if re.search(r'Wire\.(?!h\b|begin\b)', code) and not re.search(r'Wire\.begin\s*\(', code):
|
||||
issues.append({
|
||||
"severity": "error",
|
||||
"message": "Missing Wire.begin() initialization for I2C.",
|
||||
"fix": lambda c: re.sub(r'void\s+setup\s*\(\s*\)\s*\{', r'void setup() {\n Wire.begin();', c, count=1)
|
||||
})
|
||||
|
||||
return len([i for i in issues if i['severity'] == 'error']) == 0, issues
|
||||
|
||||
def auto_fix(self, code: str, issues: List[Dict]) -> str:
|
||||
"""Automatically fix known issues"""
|
||||
fixed_code = code
|
||||
|
||||
for issue in issues:
|
||||
if 'fix' in issue and issue['severity'] == 'error':
|
||||
fixed_code = issue['fix'](fixed_code)
|
||||
|
||||
return fixed_code
|
||||
|
||||
class HardwareProfile:
|
||||
"""Learn hardware-specific patterns"""
|
||||
|
||||
ESP32_PATTERNS = {
|
||||
"pwm_setup": {
|
||||
"correct": "ledcSetup(channel, freq, resolution)",
|
||||
"wrong": ["analogWrite", "pwmWrite"],
|
||||
"learned_from": "James's corrections"
|
||||
},
|
||||
"serial_baud": {
|
||||
"preferred": 115200,
|
||||
"alternatives": [9600, 57600],
|
||||
"confidence": 1.0
|
||||
},
|
||||
"safety_timeout": {
|
||||
"standard": 5000,
|
||||
"pattern": "millis() - lastTime > TIMEOUT",
|
||||
"confidence": 1.0
|
||||
}
|
||||
}
|
||||
|
||||
HARDWARE_KEYWORDS = {
|
||||
"ESP32-C3": ["esp32", "esp32c3", "c3", "esp-32"],
|
||||
"Arduino Uno": ["uno", "arduino uno", "atmega328p"],
|
||||
"Raspberry Pi Pico": ["pico", "rp2040"]
|
||||
}
|
||||
|
||||
def detect_hardware(self, message: str) -> Optional[str]:
|
||||
msg_lower = message.lower()
|
||||
for hw, keywords in self.HARDWARE_KEYWORDS.items():
|
||||
if any(k in msg_lower for k in keywords):
|
||||
return hw
|
||||
return None
|
||||
|
||||
def apply_hardware_rules(self, code: str, hardware: str) -> str:
|
||||
"""Apply known hardware patterns"""
|
||||
if hardware == "ESP32-C3":
|
||||
# Apply ESP32-specific fixes
|
||||
code = self.fix_pwm(code)
|
||||
code = self.fix_serial(code)
|
||||
code = self.add_safety(code)
|
||||
return code
|
||||
|
||||
def fix_pwm(self, code: str) -> str:
|
||||
for wrong in self.ESP32_PATTERNS["pwm_setup"]["wrong"]:
|
||||
if wrong in code:
|
||||
if wrong == "analogWrite":
|
||||
code = code.replace("analogWrite", "ledcWrite")
|
||||
return code
|
||||
|
||||
def fix_serial(self, code: str) -> str:
|
||||
preferred = self.ESP32_PATTERNS["serial_baud"]["preferred"]
|
||||
return re.sub(r'Serial\.begin\(\s*\d+\s*\)', f'Serial.begin({preferred})', code)
|
||||
|
||||
def add_safety(self, code: str) -> str:
|
||||
if "motor" in code.lower() and "millis()" not in code:
|
||||
code += "\n// [BuddAI Safety] Warning: No non-blocking timeout detected. Consider adding safety timeout."
|
||||
return code
|
||||
Loading…
Add table
Add a link
Reference in a new issue