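"""Auto-Diffusers code generator.

Generates diffusers inference code optimized for the local hardware by
prompting Google's Gemini API, optionally letting the model call tools that
fetch live HuggingFace documentation and model metadata.
"""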
import os
import logging
import json
from typing import Dict, Optional

from dotenv import load_dotenv
import google.generativeai as genai

from hardware_detector import HardwareDetector
from optimization_knowledge import get_optimization_guide

# Optional dependencies for tool calling (web fetching and HTML parsing).
try:
    import requests
    from urllib.parse import urlparse
    from bs4 import BeautifulSoup
    TOOLS_AVAILABLE = True
except ImportError:
    TOOLS_AVAILABLE = False
    requests = None
    urlparse = None
    BeautifulSoup = None

load_dotenv()

# Log everything to a file and mirror it to the console.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('auto_diffusers.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class AutoDiffusersGenerator:
    # Single place to change the Gemini model used throughout the class.
    GEMINI_MODEL = 'gemini-2.5-flash-preview-05-20'

    def __init__(self, api_key: str):
        logger.info("Initializing AutoDiffusersGenerator")
        logger.debug(f"API key length: {len(api_key) if api_key else 'None'}")

        try:
            genai.configure(api_key=api_key)

            if TOOLS_AVAILABLE:
                self.tools = self._create_tools()
                self.model = genai.GenerativeModel(self.GEMINI_MODEL, tools=self.tools)
                logger.info("Successfully configured Gemini AI model with tools")
            else:
                self.tools = None
                self.model = genai.GenerativeModel(self.GEMINI_MODEL)
                logger.warning("Tool calling dependencies not available, running without tools")
        except Exception as e:
            logger.error(f"Failed to configure Gemini AI: {e}")
            raise

        try:
            self.hardware_detector = HardwareDetector()
            logger.info("Hardware detector initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize hardware detector: {e}")
            raise

    def _create_tools(self):
        """Create function tools for Gemini to use."""
        logger.debug("Creating tools for Gemini")

        if not TOOLS_AVAILABLE:
            logger.warning("Tools dependencies not available, returning empty tools")
            return []

        def fetch_huggingface_docs(url: str) -> str:
            """Fetch documentation from HuggingFace URLs."""
            logger.info("🌐 TOOL CALL: fetch_huggingface_docs")
            logger.info(f"📋 Requested URL: {url}")

            try:
                # Only allow HuggingFace domains.
                parsed = urlparse(url)
                logger.debug(f"URL validation - Domain: {parsed.netloc}, Path: {parsed.path}")

                if not any(domain in parsed.netloc for domain in ['huggingface.co', 'hf.co']):
                    error_msg = "Error: URL must be from huggingface.co domain"
                    logger.warning(f"❌ URL validation failed: {error_msg}")
                    return error_msg

                logger.info(f"✅ URL validation passed for domain: {parsed.netloc}")

                headers = {
                    'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
                }

                logger.info(f"🔄 Fetching content from: {url}")
                response = requests.get(url, headers=headers, timeout=10)
                response.raise_for_status()
                logger.info(f"✅ HTTP {response.status_code} - Successfully fetched {len(response.text)} characters")

                logger.info("🔍 Parsing HTML content...")
                soup = BeautifulSoup(response.text, 'html.parser')

                # Extract the text of content-bearing elements.
                content = ""
                element_count = 0
                for element in soup.find_all(['p', 'pre', 'code', 'h1', 'h2', 'h3', 'h4', 'li']):
                    text = element.get_text().strip()
                    if text:
                        content += text + "\n"
                        element_count += 1

                logger.info(f"📄 Extracted content from {element_count} HTML elements")

                # Cap the content so the follow-up prompt stays small.
                original_length = len(content)
                if len(content) > 5000:
                    content = content[:5000] + "...[truncated]"
                    logger.info(f"✂️ Content truncated from {original_length} to 5000 characters")

                logger.info(f"📊 Final processed content: {len(content)} characters")

                preview = content[:200].replace('\n', ' ')
                logger.info(f"📋 Content preview: {preview}...")

                # Log the section headers found on the page.
                sections = []
                for header in soup.find_all(['h1', 'h2', 'h3']):
                    header_text = header.get_text().strip()
                    if header_text:
                        sections.append(header_text)

                if sections:
                    logger.info(f"📑 Found sections: {', '.join(sections[:5])}{'...' if len(sections) > 5 else ''}")

                logger.info("✅ Content extraction completed successfully")
                return content

            except Exception as e:
                logger.error(f"❌ Error fetching {url}: {type(e).__name__}: {e}")
                return f"Error fetching documentation: {str(e)}"

        def fetch_model_info(model_id: str) -> str:
            """Fetch model information from HuggingFace API."""
            logger.info("🤖 TOOL CALL: fetch_model_info")
            logger.info(f"📋 Requested model: {model_id}")
            try:
                api_url = f"https://huggingface.co/api/models/{model_id}"
                logger.info(f"🔄 Fetching model info from: {api_url}")
                headers = {
                    'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
                }

                response = requests.get(api_url, headers=headers, timeout=10)
                response.raise_for_status()
                logger.info(f"✅ HTTP {response.status_code} - Model API response received")

                model_data = response.json()
                logger.info(f"📊 Raw API response contains {len(model_data)} fields")

                # Keep only the fields that matter for code generation.
                info = {
                    'model_id': model_data.get('id', model_id),
                    'pipeline_tag': model_data.get('pipeline_tag', 'unknown'),
                    'tags': model_data.get('tags', []),
                    'library_name': model_data.get('library_name', 'unknown'),
                    'downloads': model_data.get('downloads', 0),
                    'likes': model_data.get('likes', 0)
                }

                logger.info("📋 Extracted model info:")
                logger.info(f"  - Pipeline: {info['pipeline_tag']}")
                logger.info(f"  - Library: {info['library_name']}")
                logger.info(f"  - Downloads: {info['downloads']:,}")
                logger.info(f"  - Likes: {info['likes']:,}")
                logger.info(f"  - Tags: {len(info['tags'])} tags")

                result = json.dumps(info, indent=2)
                logger.info(f"✅ Model info formatting completed ({len(result)} characters)")
                return result

            except Exception as e:
                logger.error(f"Error fetching model info for {model_id}: {e}")
                return f"Error fetching model information: {str(e)}"

        def search_optimization_guides(query: str) -> str:
            """Search for optimization guides and best practices."""
            logger.info("🔍 TOOL CALL: search_optimization_guides")
            logger.info(f"📋 Search query: '{query}'")
            try:
                # Known diffusers optimization guide pages to match against.
                docs_urls = [
                    "https://huggingface.co/docs/diffusers/optimization/fp16",
                    "https://huggingface.co/docs/diffusers/optimization/memory",
                    "https://huggingface.co/docs/diffusers/optimization/torch2",
                    "https://huggingface.co/docs/diffusers/optimization/mps",
                    "https://huggingface.co/docs/diffusers/optimization/xformers"
                ]

                logger.info(f"🔎 Searching through {len(docs_urls)} optimization guide URLs...")

                results = []
                matched_urls = []
                for url in docs_urls:
                    # A URL matches if any query keyword appears in it.
                    if any(keyword in url for keyword in query.lower().split()):
                        logger.info(f"✅ URL matched query: {url}")
                        matched_urls.append(url)
                        content = fetch_huggingface_docs(url)
                        if not content.startswith("Error"):
                            results.append(f"From {url}:\n{content[:1000]}...\n")
                            logger.info(f"📄 Successfully processed content from {url}")
                        else:
                            logger.warning(f"❌ Failed to fetch content from {url}")
                    else:
                        logger.debug(f"⏭️ URL skipped (no match): {url}")

                logger.info(f"📊 Search completed: {len(matched_urls)} URLs matched, {len(results)} successful fetches")

                if results:
                    final_result = "\n".join(results)
                    logger.info(f"✅ Returning combined content ({len(final_result)} characters)")
                    return final_result
                else:
                    logger.warning("❌ No specific optimization guides found for the query")
                    return "No specific optimization guides found for the query"

            except Exception as e:
                logger.error(f"Error searching optimization guides: {e}")
                return f"Error searching guides: {str(e)}"

        # Function declarations advertised to Gemini (JSON-schema style).
        tools = [
            {
                "function_declarations": [
                    {
                        "name": "fetch_huggingface_docs",
                        "description": "Fetch current documentation from HuggingFace URLs for diffusers library, models, or optimization guides",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "url": {
                                    "type": "string",
                                    "description": "The HuggingFace documentation URL to fetch"
                                }
                            },
                            "required": ["url"]
                        }
                    },
                    {
                        "name": "fetch_model_info",
                        "description": "Fetch current model information and metadata from HuggingFace API",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "model_id": {
                                    "type": "string",
                                    "description": "The HuggingFace model ID (e.g., 'black-forest-labs/FLUX.1-schnell')"
                                }
                            },
                            "required": ["model_id"]
                        }
                    },
                    {
                        "name": "search_optimization_guides",
                        "description": "Search for optimization guides and best practices for diffusers models",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Search query for optimization topics (e.g., 'memory', 'fp16', 'torch compile')"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            }
        ]
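
        # Note: each declaration above must match one of the local functions
        # below by name. A quick manual check (hypothetical usage; the model
        # ID is just an example):
        #   gen = AutoDiffusersGenerator(api_key)
        #   print(gen.tool_functions['fetch_model_info'](
        #       'black-forest-labs/FLUX.1-schnell'))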

        # Dispatch table used when Gemini requests one of these functions.
        self.tool_functions = {
            'fetch_huggingface_docs': fetch_huggingface_docs,
            'fetch_model_info': fetch_model_info,
            'search_optimization_guides': search_optimization_guides
        }

        logger.info(f"Created {len(tools[0]['function_declarations'])} tools for Gemini")
        return tools

    def generate_optimized_code(self,
                                model_name: str,
                                prompt_text: str,
                                image_size: tuple = (768, 1360),
                                num_inference_steps: int = 4,
                                use_manual_specs: bool = False,
                                manual_specs: Optional[Dict] = None,
                                memory_analysis: Optional[Dict] = None) -> str:
        """Generate optimized diffusers code based on hardware specs and memory analysis.

        image_size is (height, width).
        """
        logger.info(f"Starting code generation for model: {model_name}")
        logger.debug(f"Parameters: prompt='{prompt_text[:50]}...', size={image_size}, steps={num_inference_steps}")
        logger.debug(f"Manual specs: {use_manual_specs}, Memory analysis provided: {memory_analysis is not None}")

        if use_manual_specs and manual_specs:
            logger.info("Using manual hardware specifications")
            hardware_specs = manual_specs
            logger.debug(f"Manual specs: {hardware_specs}")

            # Pick an optimization profile from the available VRAM.
            if hardware_specs.get('gpu_info'):
                vram_gb = hardware_specs['gpu_info'][0]['memory_mb'] / 1024
                logger.debug(f"GPU detected with {vram_gb:.1f} GB VRAM")

                if vram_gb >= 16:
                    optimization_profile = 'performance'
                elif vram_gb >= 8:
                    optimization_profile = 'balanced'
                else:
                    optimization_profile = 'memory_efficient'
            else:
                optimization_profile = 'cpu_only'
                logger.info("No GPU detected, using CPU-only profile")

            logger.info(f"Selected optimization profile: {optimization_profile}")
        else:
            logger.info("Using automatic hardware detection")
            hardware_specs = self.hardware_detector.specs
            optimization_profile = self.hardware_detector.get_optimization_profile()
            logger.debug(f"Detected specs: {hardware_specs}")
            logger.info(f"Auto-detected optimization profile: {optimization_profile}")

        logger.debug("Creating generation prompt for Gemini API")
        system_prompt = self._create_generation_prompt(
            model_name, prompt_text, image_size, num_inference_steps,
            hardware_specs, optimization_profile, memory_analysis
        )
        logger.debug(f"Prompt length: {len(system_prompt)} characters")

        logger.info("=" * 80)
        logger.info("PROMPT SENT TO GEMINI API:")
        logger.info("=" * 80)
        logger.info(system_prompt)
        logger.info("=" * 80)

        try:
            logger.info("Sending request to Gemini API")
            response = self.model.generate_content(system_prompt)

            # If Gemini requested a tool call, execute it locally and send a
            # follow-up prompt that includes the tool result.
            if self.tools and response.candidates[0].content.parts:
                for part in response.candidates[0].content.parts:
                    if hasattr(part, 'function_call') and part.function_call:
                        function_name = part.function_call.name
                        function_args = dict(part.function_call.args)

                        logger.info("🛠️ " + "=" * 60)
                        logger.info(f"🛠️ GEMINI REQUESTED TOOL CALL: {function_name}")
                        logger.info("🛠️ " + "=" * 60)
                        logger.info(f"📋 Tool arguments: {function_args}")

                        if function_name in self.tool_functions:
                            logger.info("✅ Tool function found, executing...")
                            tool_result = self.tool_functions[function_name](**function_args)
                            logger.info("🛠️ " + "=" * 60)
                            logger.info(f"🛠️ TOOL EXECUTION COMPLETED: {function_name}")
                            logger.info("🛠️ " + "=" * 60)
                            logger.info(f"📊 Tool result length: {len(str(tool_result))} characters")

                            preview = str(tool_result)[:300].replace('\n', ' ')
                            logger.info(f"📋 Tool result preview: {preview}...")
                            logger.info("🛠️ " + "=" * 60)

                            follow_up_prompt = f"""
{system_prompt}

ADDITIONAL CONTEXT FROM TOOLS:
Tool: {function_name}
Result: {tool_result}

Please use this current information to generate the most up-to-date and optimized code.
"""

                            logger.info("=" * 80)
                            logger.info("FOLLOW-UP PROMPT SENT TO GEMINI API (WITH TOOL RESULTS):")
                            logger.info("=" * 80)
                            logger.info(follow_up_prompt)
                            logger.info("=" * 80)

                            logger.info("Generating final response with tool context")
                            final_response = self.model.generate_content(follow_up_prompt)
                            logger.info("Successfully received final response from Gemini API")
                            logger.debug(f"Final response length: {len(final_response.text)} characters")
                            return final_response.text

            logger.info("Successfully received response from Gemini API (no tools used)")
            logger.debug(f"Response length: {len(response.text)} characters")
            return response.text

        except Exception as e:
            logger.error(f"Error generating code: {str(e)}")
            return f"Error generating code: {str(e)}"

    def _create_generation_prompt(self,
                                  model_name: str,
                                  prompt_text: str,
                                  image_size: tuple,
                                  num_inference_steps: int,
                                  hardware_specs: Dict,
                                  optimization_profile: str,
                                  memory_analysis: Optional[Dict] = None) -> str:
        """Create the prompt for Gemini API to generate optimized code."""

        base_prompt = f"""
You are an expert in optimizing diffusers library code for different hardware configurations.

NOTE: This system includes curated optimization knowledge from HuggingFace documentation.

TASK: Generate optimized Python code for running a diffusion model with the following specifications:
- Model: {model_name}
- Prompt: "{prompt_text}"
- Image size: {image_size[0]}x{image_size[1]}
- Inference steps: {num_inference_steps}

HARDWARE SPECIFICATIONS:
- Platform: {hardware_specs['platform']} ({hardware_specs['architecture']})
- CPU Cores: {hardware_specs['cpu_count']}
- CUDA Available: {hardware_specs['cuda_available']}
- MPS Available: {hardware_specs['mps_available']}
- Optimization Profile: {optimization_profile}
"""

        if hardware_specs.get('gpu_info'):
            base_prompt += f"- GPU: {hardware_specs['gpu_info'][0]['name']} ({hardware_specs['gpu_info'][0]['memory_mb'] / 1024:.1f} GB VRAM)\n"

        if hardware_specs.get('user_dtype'):
            base_prompt += f"- User specified dtype: {hardware_specs['user_dtype']}\n"

        if memory_analysis:
            memory_info = memory_analysis.get('memory_info', {})
            recommendations = memory_analysis.get('recommendations', {})

            base_prompt += "\nMEMORY ANALYSIS:\n"
            if memory_info.get('estimated_inference_memory_fp16_gb'):
                base_prompt += f"- Model Memory Requirements: {memory_info['estimated_inference_memory_fp16_gb']} GB (FP16 inference)\n"
            if memory_info.get('memory_fp16_gb'):
                base_prompt += f"- Model Weights Size: {memory_info['memory_fp16_gb']} GB (FP16)\n"
            if recommendations.get('recommendations'):
                base_prompt += f"- Memory Recommendation: {', '.join(recommendations['recommendations'])}\n"
            if recommendations.get('recommended_precision'):
                base_prompt += f"- Recommended Precision: {recommendations['recommended_precision']}\n"
            if recommendations.get('cpu_offload'):
                base_prompt += f"- CPU Offloading Required: {recommendations['cpu_offload']}\n"
            if recommendations.get('attention_slicing'):
                base_prompt += f"- Attention Slicing Recommended: {recommendations['attention_slicing']}\n"
            if recommendations.get('vae_slicing'):
                base_prompt += f"- VAE Slicing Recommended: {recommendations['vae_slicing']}\n"

        base_prompt += f"""
OPTIMIZATION KNOWLEDGE BASE:
{get_optimization_guide()}

IMPORTANT: For FLUX.1-schnell models, do NOT include the guidance_scale parameter, as it is not needed.

Using the OPTIMIZATION KNOWLEDGE BASE above, generate Python code that:

1. **Selects the best optimization techniques** for the specific hardware profile
2. **Applies appropriate memory optimizations** based on available VRAM
3. **Uses optimal data types** for the target hardware:
   - User specified dtype (if provided): use exactly as specified
   - Apple Silicon (MPS): prefer torch.bfloat16
   - NVIDIA GPUs: prefer torch.float16 or torch.bfloat16
   - CPU only: use torch.float32
4. **Implements hardware-specific optimizations** (CUDA, MPS, CPU)
5. **Follows model-specific guidelines** (e.g., FLUX guidance_scale handling)

IMPORTANT GUIDELINES:
- Reference the OPTIMIZATION KNOWLEDGE BASE to select appropriate techniques
- Include all necessary imports
- Add brief comments explaining optimization choices
- Generate compact, production-ready code
- Inline values where possible for concise code
- Generate ONLY the Python code, no explanations before or after the code block
"""

        return base_prompt

    def run_interactive_mode(self):
        """Run the generator in interactive mode."""
        print("=== Auto-Diffusers Code Generator ===")
        print("This tool generates optimized diffusers code based on your hardware.\n")

        print("=== Hardware Detection ===")
        self.hardware_detector.print_specs()

        use_manual = input("\nUse manual hardware input? (y/n): ").lower() == 'y'

        print("\n=== Model Configuration ===")
        model_name = input("Model name (default: black-forest-labs/FLUX.1-schnell): ").strip()
        if not model_name:
            model_name = "black-forest-labs/FLUX.1-schnell"

        prompt_text = input("Prompt text (default: A cat holding a sign that says hello world): ").strip()
        if not prompt_text:
            prompt_text = "A cat holding a sign that says hello world"

        try:
            width = int(input("Image width (default: 1360): ") or "1360")
            height = int(input("Image height (default: 768): ") or "768")
            steps = int(input("Inference steps (default: 4): ") or "4")
        except ValueError:
            width, height, steps = 1360, 768, 4

        print("\n=== Generating Optimized Code ===")

        # generate_optimized_code expects image_size as (height, width).
        optimized_code = self.generate_optimized_code(
            model_name=model_name,
            prompt_text=prompt_text,
            image_size=(height, width),
            num_inference_steps=steps,
            use_manual_specs=use_manual
        )

        print("\n" + "=" * 60)
        print("OPTIMIZED DIFFUSERS CODE:")
        print("=" * 60)
        print(optimized_code)
        print("=" * 60)


def main():
    # Accept either GOOGLE_API_KEY or GEMINI_API_KEY, then fall back to a prompt.
    api_key = os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
    if not api_key:
        api_key = input("Enter your Gemini API key: ").strip()
    if not api_key:
        print("API key is required!")
        return

    generator = AutoDiffusersGenerator(api_key)
    generator.run_interactive_mode()


if __name__ == "__main__":
    main()
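
# Programmatic usage sketch (assumes GOOGLE_API_KEY is set in the environment
# or in a .env file; the model and prompt values below are illustrative):
#
#   generator = AutoDiffusersGenerator(os.getenv("GOOGLE_API_KEY"))
#   code = generator.generate_optimized_code(
#       model_name="black-forest-labs/FLUX.1-schnell",
#       prompt_text="A cat holding a sign that says hello world",
#       image_size=(768, 1360),   # (height, width)
#       num_inference_steps=4,
#   )
#   print(code)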