Commit
·
472e2e9
0
Parent(s):
Sure! Pl
Browse files- api_service.py +614 -0
- data/.gitkeep +2 -0
- evaluation/.gitkeep +2 -0
- models/.gitkeep +2 -0
- raw_dataset.json +0 -0
- requirements.txt +12 -0
- setup-guide.md +342 -0
- test_api.py +160 -0
- training_pipeline.py +772 -0
api_service.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Backend Code Generation API Service
|
| 4 |
+
===================================
|
| 5 |
+
|
| 6 |
+
Production-ready API service for serving the trained backend code generation model.
|
| 7 |
+
Provides RESTful endpoints for generating complete backend applications.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import asyncio
import json
import logging
import os
import tempfile
import uuid
import zipfile
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Dict, Optional, Any

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 26 |
+
|
| 27 |
+
# Configure logging
|
| 28 |
+
logging.basicConfig(level=logging.INFO)
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
# Pydantic models for API
|
| 32 |
+
class CodeGenerationRequest(BaseModel):
    """Request payload for POST /api/v1/generate."""
    description: str = Field(..., description="Description of the backend application to generate")
    framework: str = Field(..., description="Target framework (express, fastapi, django, flask)")
    language: str = Field(..., description="Programming language (javascript, python)")
    requirements: List[str] = Field(default=[], description="List of specific requirements")
    project_name: Optional[str] = Field(default=None, description="Custom project name")

    class Config:
        # NOTE(review): `schema_extra` is the pydantic v1 config key; v2 renamed
        # it to `json_schema_extra` — confirm the pinned pydantic version.
        schema_extra = {
            "example": {
                "description": "E-commerce API with user authentication and product management",
                "framework": "fastapi",
                "language": "python",
                "requirements": [
                    "User registration and login",
                    "JWT authentication",
                    "Product CRUD operations",
                    "Shopping cart functionality",
                    "Order management"
                ],
                "project_name": "ecommerce-api"
            }
        }
|
| 55 |
+
|
| 56 |
+
class GenerationResponse(BaseModel):
    """Immediate acknowledgement returned by POST /api/v1/generate."""
    task_id: str         # id used to poll status and download the result
    status: str          # set to "accepted" by generate_backend
    message: str         # human-readable summary
    estimated_time: int  # rough ETA in seconds
|
| 61 |
+
|
| 62 |
+
class GenerationStatus(BaseModel):
    """Mutable progress record for a task; also the /status response model."""
    task_id: str
    status: str  # pending, processing, completed, failed
    progress: int  # 0-100
    message: str
    generated_files: Optional[Dict[str, str]] = None  # filename -> "Generated" marker
    download_url: Optional[str] = None  # local filesystem path of the result ZIP
    error: Optional[str] = None  # populated only when status == "failed"
|
| 70 |
+
|
| 71 |
+
class GeneratedProject(BaseModel):
    """Full description of a generated project.

    NOTE(review): not referenced by any endpoint in this module — presumably
    kept for client-side typing; confirm before removing.
    """
    project_name: str
    framework: str
    language: str
    files: Dict[str, str]  # filename -> file contents
    structure: Dict[str, Any]
    setup_instructions: List[str]
    features: List[str]
|
| 79 |
+
|
| 80 |
+
# Global model instance
|
| 81 |
+
class ModelManager:
    """Owns the fine-tuned causal LM + tokenizer and serves generations.

    A single module-level instance is shared by all request handlers; the
    model is loaded once at application startup via ``load_model``.
    """

    def __init__(self):
        self.model = None      # AutoModelForCausalLM, set by load_model
        self.tokenizer = None  # AutoTokenizer, set by load_model
        # Prefer GPU when available; dtype/placement follow this in load_model.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded = False    # flipped to True only after a successful load

    async def load_model(self, model_path: str = "./trained_model"):
        """Load the trained model asynchronously"""
        # NOTE(review): from_pretrained blocks the event loop despite the async
        # signature; acceptable only because this runs once at startup.
        try:
            logger.info(f"Loading model from {model_path} on {self.device}")

            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                # Half precision on GPU to halve memory; full precision on CPU.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                # device_map="auto" lets accelerate place layers across GPU(s).
                device_map="auto" if self.device == "cuda" else None
            )

            if self.device == "cpu":
                self.model = self.model.to(self.device)

            self.loaded = True
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def generate_code(self, prompt: str, max_tokens: int = 1024) -> str:
        """Generate code using the trained model.

        Raises:
            RuntimeError: if called before ``load_model`` has completed.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded")

        inputs = self.tokenizer.encode(prompt, return_tensors='pt')
        inputs = inputs.to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                # NOTE(review): max_length counts prompt tokens as well, so a
                # long prompt leaves little or no room for new tokens —
                # max_new_tokens may be intended; confirm.
                max_length=min(max_tokens, 1024),
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,  # sampling: output is non-deterministic
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt by slicing off the decoded input prefix.
        return generated_text[len(self.tokenizer.decode(inputs[0], skip_special_tokens=True)):]
|
| 132 |
+
|
| 133 |
+
# Global instances shared by all request handlers.
model_manager = ModelManager()
# In-memory task store: task_id -> GenerationStatus.
# NOTE(review): process-local — state is lost on restart and is not shared
# across multiple uvicorn workers; confirm single-worker deployment.
generation_tasks = {}  # Store generation tasks

# FastAPI app
app = FastAPI(
    title="Backend Code Generation API",
    description="AI-powered backend application generator",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 154 |
+
|
| 155 |
+
@app.on_event("startup")
async def startup_event():
    """Load the fine-tuned model once when the service boots."""
    await model_manager.load_model(os.getenv("MODEL_PATH", "./trained_model"))
|
| 160 |
+
|
| 161 |
+
@app.get("/")
async def root():
    """Describe the service and advertise its main endpoints."""
    endpoint_map = {
        "generate": "/api/v1/generate",
        "status": "/api/v1/status/{task_id}",
        "download": "/api/v1/download/{task_id}",
        "health": "/health",
    }
    return {
        "service": "Backend Code Generation API",
        "version": "1.0.0",
        "status": "running",
        "model_loaded": model_manager.loaded,
        "endpoints": endpoint_map,
    }
|
| 176 |
+
|
| 177 |
+
@app.get("/health")
async def health_check():
    """Liveness/readiness probe.

    Returns service status, a timezone-aware UTC timestamp, whether the
    model has finished loading, and the inference device (None while the
    model is still unavailable).
    """
    return {
        "status": "OK",
        # datetime.utcnow() is deprecated and yields a *naive* datetime;
        # emit an aware UTC timestamp instead (ISO string gains "+00:00").
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "model_loaded": model_manager.loaded,
        "device": model_manager.device if model_manager.loaded else None
    }
|
| 186 |
+
|
| 187 |
+
@app.post("/api/v1/generate", response_model=GenerationResponse)
async def generate_backend(
    request: CodeGenerationRequest,
    background_tasks: BackgroundTasks
):
    """Kick off asynchronous generation of a complete backend application.

    Returns immediately with a task id; progress is polled via the status
    endpoint and the finished archive fetched via the download endpoint.
    """
    if not model_manager.loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Register the task before scheduling it so a fast poll can't miss it.
    new_task_id = str(uuid.uuid4())
    generation_tasks[new_task_id] = GenerationStatus(
        task_id=new_task_id,
        status="pending",
        progress=0,
        message="Task queued for processing",
    )

    # Run the heavy generation outside the request/response cycle.
    background_tasks.add_task(generate_project_background, new_task_id, request)

    return GenerationResponse(
        task_id=new_task_id,
        status="accepted",
        message="Code generation started",
        estimated_time=60,  # seconds
    )
|
| 221 |
+
|
| 222 |
+
@app.get("/api/v1/status/{task_id}", response_model=GenerationStatus)
async def get_generation_status(task_id: str):
    """Return the current status record for a generation task."""
    try:
        return generation_tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Task not found")
|
| 230 |
+
|
| 231 |
+
@app.get("/api/v1/download/{task_id}")
async def download_generated_project(task_id: str):
    """Serve the finished project back to the client as a ZIP archive."""
    record = generation_tasks.get(task_id)

    # Guard clauses: unknown task, unfinished work, missing artifact.
    if record is None:
        raise HTTPException(status_code=404, detail="Task not found")
    if record.status != "completed":
        raise HTTPException(status_code=400, detail="Generation not completed")
    if not record.download_url:
        raise HTTPException(status_code=404, detail="Download file not available")
    if not os.path.exists(record.download_url):
        raise HTTPException(status_code=404, detail="Download file not found")

    return FileResponse(
        path=record.download_url,
        filename=f"generated_project_{task_id}.zip",
        media_type="application/zip",
    )
|
| 254 |
+
|
| 255 |
+
@app.delete("/api/v1/cleanup/{task_id}")
async def cleanup_task(task_id: str):
    """Delete a task's ZIP artifact and drop it from the in-memory store."""
    if task_id not in generation_tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    record = generation_tasks[task_id]

    # Remove the generated archive first (if one was ever produced),
    # then forget the task itself.
    if record.download_url and os.path.exists(record.download_url):
        os.remove(record.download_url)
    del generation_tasks[task_id]

    return {"message": "Task cleaned up successfully"}
|
| 272 |
+
|
| 273 |
+
async def generate_project_background(task_id: str, request: CodeGenerationRequest):
    """Background task for generating the complete project.

    Mutates the shared GenerationStatus record in ``generation_tasks`` so
    pollers of the status endpoint see progress; on failure the task is
    marked "failed" with the error message instead of raising.
    """

    task = generation_tasks[task_id]

    try:
        # Update status
        task.status = "processing"
        task.progress = 10
        task.message = "Analyzing requirements..."

        # Create the generation prompt
        prompt = create_generation_prompt(request)

        # Update progress
        task.progress = 30
        task.message = "Generating application structure..."

        # Generate code using the model.
        # NOTE(review): this is a blocking, compute-heavy call inside an
        # async task — it will stall the event loop while generating.
        generated_code = model_manager.generate_code(prompt, max_tokens=1024)

        # Update progress
        task.progress = 60
        task.message = "Processing generated code..."

        # Parse and structure the generated code
        project_files = parse_generated_code(generated_code, request)

        # Update progress
        task.progress = 80
        task.message = "Creating project files..."

        # Create downloadable ZIP file
        zip_path = create_project_zip(task_id, project_files, request)

        # Complete the task (set payload fields before flipping status is not
        # guaranteed here; single-threaded event loop makes this safe).
        task.status = "completed"
        task.progress = 100
        task.message = "Project generated successfully"
        task.generated_files = {name: "Generated" for name in project_files.keys()}
        task.download_url = zip_path

    except Exception as e:
        logger.error(f"Generation failed for task {task_id}: {e}")
        task.status = "failed"
        task.error = str(e)
        task.message = "Generation failed"
|
| 320 |
+
|
| 321 |
+
def create_generation_prompt(request: CodeGenerationRequest) -> str:
    """Render the request into the newline-separated prompt the model expects."""
    parts = [
        f"Description: {request.description}",
        f"Framework: {request.framework}",
        f"Language: {request.language}",
    ]

    # Optional sections only appear when the request supplies them.
    if request.requirements:
        parts.append("Requirements:")
        parts.extend(f"- {req}" for req in request.requirements)

    if request.project_name:
        parts.append(f"Project Name: {request.project_name}")

    parts.append("Generate the complete backend application with all necessary files:")

    return "\n".join(parts)
|
| 341 |
+
|
| 342 |
+
def parse_generated_code(generated_code: str, request: "CodeGenerationRequest") -> Dict[str, str]:
    """Split the model's raw output into individual project files.

    The model delimits files with header lines of the form
    ``--- path/to/file ---`` and optionally terminates a file with a
    ``--- End ---`` marker line.

    Bug fix: ``--- End ---`` itself matches the generic header pattern
    (starts with ``--- `` and ends with `` ---``), so the previous version
    treated it as the start of a file named "End" and could emit a bogus
    empty ``"End"`` entry. End markers are now recognised before headers
    and simply close the current file.

    Falls back to a minimal framework skeleton when no file headers are
    found at all.
    """
    files: Dict[str, str] = {}
    current_file = None
    current_content = []

    for line in generated_code.split('\n'):
        if line.startswith('--- End ---'):
            # Explicit end-of-file marker: flush and close the current file.
            if current_file:
                files[current_file] = '\n'.join(current_content)
            current_file = None
            current_content = []
        elif line.startswith('--- ') and line.endswith(' ---'):
            # New file header: flush the previous file first.
            if current_file:
                files[current_file] = '\n'.join(current_content)
            current_file = line.replace('--- ', '').replace(' ---', '').strip()
            current_content = []
        elif current_file:
            current_content.append(line)

    # Flush the final file if the output ended without an End marker.
    if current_file and current_content:
        files[current_file] = '\n'.join(current_content)

    # If parsing failed entirely, create a basic structure for the framework.
    if not files:
        files = create_fallback_structure(request)

    return files
|
| 374 |
+
|
| 375 |
+
def create_fallback_structure(request: CodeGenerationRequest) -> Dict[str, str]:
    """Create a basic project structure if parsing fails.

    Returns a minimal "hello world" skeleton for the requested framework so
    the user always receives something runnable even when the model output
    is unusable. Unrecognised frameworks get a README only.
    """

    if request.framework.lower() == 'fastapi':
        return {
            'main.py': f'''from fastapi import FastAPI

app = FastAPI(title="{request.description}")

@app.get("/")
async def root():
    return {{"message": "Hello from {request.description}"}}

@app.get("/health")
async def health():
    return {{"status": "OK"}}
''',
            'requirements.txt': '''fastapi==0.104.1
uvicorn[standard]==0.24.0'''
        }

    elif request.framework.lower() == 'express':
        return {
            'app.js': f'''const express = require('express');
const app = express();

app.get('/', (req, res) => {{
    res.json({{ message: 'Hello from {request.description}' }});
}});

app.get('/health', (req, res) => {{
    res.json({{ status: 'OK' }});
}});

const PORT = process.env.PORT || 3000;
app.listen(PORT, () => {{
    console.log(`Server running on port ${{PORT}}`);
}});
''',
            'package.json': json.dumps({
                "name": request.project_name or "generated-backend",
                "version": "1.0.0",
                "main": "app.js",
                "dependencies": {
                    "express": "^4.18.2"
                }
            }, indent=2)
        }

    else:
        return {
            'README.md': f'# {request.description}\n\nGenerated backend application using {request.framework}'
        }
|
| 428 |
+
|
| 429 |
+
def create_project_zip(task_id: str, files: Dict[str, str], request: "CodeGenerationRequest") -> str:
    """Bundle the generated files into a ZIP archive and return its path.

    Every file is placed under a top-level directory named after the
    project, alongside a SETUP.md with framework setup instructions.

    Bug fix: the archive member name previously used a hard-coded string
    instead of the loop's ``filename``, so every entry overwrote the
    previous one; each file now keeps its own name inside the archive.
    """

    # Write the archive into the system temp directory, keyed by task id.
    temp_dir = tempfile.gettempdir()
    zip_path = os.path.join(temp_dir, f"project_{task_id}.zip")

    project_name = request.project_name or f"generated_{request.framework}_app"

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for filename, content in files.items():
            # Nest each file under the project directory inside the archive.
            arcname = f"{project_name}/{filename}"
            zipf.writestr(arcname, content)

        # Add a README with setup instructions
        setup_instructions = get_setup_instructions(request.framework)
        zipf.writestr(f"{project_name}/SETUP.md", setup_instructions)

    return zip_path
|
| 449 |
+
|
| 450 |
+
def get_setup_instructions(framework: str) -> str:
    """Return framework-specific SETUP.md content for the generated ZIP.

    Unknown frameworks receive a generic pointer to their own docs.
    """

    instructions = {
        'fastapi': '''# Setup Instructions

1. Install dependencies:
```bash
pip install -r requirements.txt
```

2. Run the application:
```bash
uvicorn main:app --reload
```

3. Access the API:
- API: http://localhost:8000
- Docs: http://localhost:8000/docs
''',
        'express': '''# Setup Instructions

1. Install dependencies:
```bash
npm install
```

2. Run the application:
```bash
node app.js
```

3. Access the API:
- API: http://localhost:3000
''',
        'django': '''# Setup Instructions

1. Install dependencies:
```bash
pip install -r requirements.txt
```

2. Run migrations:
```bash
python manage.py migrate
```

3. Run the application:
```bash
python manage.py runserver
```

4. Access the API:
- API: http://localhost:8000
- Admin: http://localhost:8000/admin
''',
        'flask': '''# Setup Instructions

1. Install dependencies:
```bash
pip install -r requirements.txt
```

2. Run the application:
```bash
python run.py
```

3. Access the API:
- API: http://localhost:5000
'''
    }

    return instructions.get(framework, '# Setup Instructions\n\nRefer to the framework documentation for setup instructions.')
|
| 524 |
+
|
| 525 |
+
# Additional utility endpoints
|
| 526 |
+
@app.get("/api/v1/frameworks")
async def list_supported_frameworks():
    """Advertise the framework/language combinations the generator supports."""
    supported = [
        ("fastapi", "python", "Modern, fast, web framework for building APIs"),
        ("express", "javascript", "Fast, unopinionated web framework for Node.js"),
        ("django", "python", "High-level Python web framework"),
        ("flask", "python", "Lightweight WSGI web application framework"),
    ]
    return {
        "frameworks": [
            {"name": name, "language": language, "description": description}
            for name, language, description in supported
        ]
    }
|
| 553 |
+
|
| 554 |
+
@app.get("/api/v1/examples")
async def get_example_requests():
    """Return ready-to-use example payloads for POST /api/v1/generate.

    Static showcase data only; safe to edit without touching generation
    logic.
    """
    return {
        "examples": [
            {
                "name": "E-commerce API",
                "request": {
                    "description": "Complete e-commerce backend with user management and product catalog",
                    "framework": "fastapi",
                    "language": "python",
                    "requirements": [
                        "User registration and authentication",
                        "Product CRUD operations",
                        "Shopping cart functionality",
                        "Order management",
                        "Payment processing integration"
                    ]
                }
            },
            {
                "name": "Task Management System",
                "request": {
                    "description": "Task management system with team collaboration",
                    "framework": "express",
                    "language": "javascript",
                    "requirements": [
                        "User authentication with JWT",
                        "Task CRUD operations",
                        "Team and project management",
                        "Real-time notifications",
                        "File attachments"
                    ]
                }
            },
            {
                "name": "Blog Platform",
                "request": {
                    "description": "Blog platform with content management",
                    "framework": "django",
                    "language": "python",
                    "requirements": [
                        "Article management",
                        "User comments and ratings",
                        "Category and tag system",
                        "SEO optimization",
                        "Media file handling"
                    ]
                }
            }
        ]
    }
|
| 606 |
+
|
| 607 |
+
if __name__ == "__main__":
    # Development entry point: serve on all interfaces with auto-reload.
    import uvicorn
    uvicorn.run("api_service:app", host="0.0.0.0", port=8000, reload=True)
|
data/.gitkeep
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
evaluation/.gitkeep
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
models/.gitkeep
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
raw_dataset.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
datasets
|
| 4 |
+
pandas
|
| 5 |
+
numpy
|
| 6 |
+
aiohttp
|
| 7 |
+
requests
|
| 8 |
+
accelerate
|
| 9 |
+
fastapi
|
| 10 |
+
uvicorn
|
| 11 |
+
python-multipart
|
| 12 |
+
|
setup-guide.md
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend Code Generation Model - Setup & Usage Guide
|
| 2 |
+
|
| 3 |
+
## 🛠️ Installation & Setup
|
| 4 |
+
|
| 5 |
+
### 1. Install Dependencies
|
| 6 |
+
```bash
|
| 7 |
+
pip install torch transformers datasets pandas numpy aiohttp requests
|
| 8 |
+
pip install accelerate # For faster training
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
### 2. Set Environment Variables
|
| 12 |
+
```bash
|
| 13 |
+
# Optional: GitHub token for collecting real repositories
|
| 14 |
+
export GITHUB_TOKEN="your_github_token_here"
|
| 15 |
+
|
| 16 |
+
# For GPU training (if available)
|
| 17 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### 3. Directory Structure
|
| 21 |
+
```
|
| 22 |
+
backend-ai-trainer/
|
| 23 |
+
├── training_pipeline.py # Main pipeline code
|
| 24 |
+
├── data/
|
| 25 |
+
│ ├── raw_dataset.json # Collected training data
|
| 26 |
+
│ └── processed/ # Preprocessed data
|
| 27 |
+
├── models/
|
| 28 |
+
│ ├── backend_code_model/ # Trained model output
|
| 29 |
+
│ └── checkpoints/ # Training checkpoints
|
| 30 |
+
└── evaluation/
|
| 31 |
+
├── test_cases.json # Test scenarios
|
| 32 |
+
└── results/ # Evaluation results
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## 🏃♂️ Quick Start
|
| 36 |
+
|
| 37 |
+
### Option A: Full Automated Pipeline
|
| 38 |
+
```python
|
| 39 |
+
import asyncio
|
| 40 |
+
from training_pipeline import TrainingPipeline
|
| 41 |
+
|
| 42 |
+
config = {
|
| 43 |
+
'base_model': 'microsoft/DialoGPT-medium',
|
| 44 |
+
'output_dir': './models/backend_code_model',
|
| 45 |
+
'github_token': 'your_token_here', # Optional
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
pipeline = TrainingPipeline(config)
|
| 49 |
+
asyncio.run(pipeline.run_full_pipeline())
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Option B: Step-by-Step Execution
|
| 53 |
+
|
| 54 |
+
#### Step 1: Collect Training Data
|
| 55 |
+
```python
|
| 56 |
+
from training_pipeline import DataCollector
|
| 57 |
+
import asyncio
|
| 58 |
+
|
| 59 |
+
collector = DataCollector()
|
| 60 |
+
|
| 61 |
+
# Collect from GitHub (requires token)
|
| 62 |
+
github_queries = [
|
| 63 |
+
'express api backend',
|
| 64 |
+
'fastapi python backend',
|
| 65 |
+
'django rest api',
|
| 66 |
+
'nodejs backend server',
|
| 67 |
+
'flask api backend'
|
| 68 |
+
]
|
| 69 |
+
|
| 70 |
+
asyncio.run(collector.collect_github_repositories(github_queries, max_repos=100))
|
| 71 |
+
|
| 72 |
+
# Generate synthetic examples
|
| 73 |
+
collector.generate_synthetic_examples(count=500)
|
| 74 |
+
|
| 75 |
+
# Save dataset
|
| 76 |
+
collector.save_dataset('training_data.json')
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
#### Step 2: Preprocess Data
|
| 80 |
+
```python
|
| 81 |
+
from training_pipeline import DataPreprocessor
|
| 82 |
+
|
| 83 |
+
preprocessor = DataPreprocessor()
|
| 84 |
+
processed_examples = preprocessor.preprocess_examples(collector.collected_examples)
|
| 85 |
+
training_dataset = preprocessor.create_training_dataset(processed_examples)
|
| 86 |
+
|
| 87 |
+
print(f"Created dataset with {len(training_dataset)} examples")
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
#### Step 3: Train Model
|
| 91 |
+
```python
|
| 92 |
+
from training_pipeline import CodeGenerationModel
|
| 93 |
+
|
| 94 |
+
model = CodeGenerationModel('microsoft/DialoGPT-medium')
|
| 95 |
+
model.fine_tune(training_dataset, output_dir='./trained_model')
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
#### Step 4: Generate Code
|
| 99 |
+
```python
|
| 100 |
+
# Generate a complete backend application
|
| 101 |
+
generated_code = model.generate_code(
|
| 102 |
+
description="E-commerce API with user authentication and product management",
|
| 103 |
+
framework="fastapi",
|
| 104 |
+
language="python"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
print("Generated Backend Application:")
|
| 108 |
+
print("=" * 50)
|
| 109 |
+
print(generated_code)
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## 🎯 Training Configuration Options
|
| 113 |
+
|
| 114 |
+
### Model Selection
|
| 115 |
+
```python
|
| 116 |
+
# Lightweight for testing
|
| 117 |
+
config['base_model'] = 'microsoft/DialoGPT-small'
|
| 118 |
+
|
| 119 |
+
# Balanced performance
|
| 120 |
+
config['base_model'] = 'microsoft/DialoGPT-medium'
|
| 121 |
+
|
| 122 |
+
# High quality (requires more resources)
|
| 123 |
+
config['base_model'] = 'microsoft/DialoGPT-large'
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### Training Parameters
|
| 127 |
+
```python
|
| 128 |
+
training_config = {
|
| 129 |
+
'num_epochs': 5, # More epochs = better learning
|
| 130 |
+
'batch_size': 4, # Adjust based on GPU memory
|
| 131 |
+
'learning_rate': 5e-5, # Conservative learning rate
|
| 132 |
+
'max_length': 2048, # Maximum token length
|
| 133 |
+
'warmup_steps': 500, # Learning rate warmup
|
| 134 |
+
'save_steps': 1000, # Checkpoint frequency
|
| 135 |
+
}
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
### Framework Coverage
|
| 139 |
+
The pipeline supports these backend frameworks:
|
| 140 |
+
|
| 141 |
+
**Node.js Frameworks:**
|
| 142 |
+
- Express.js - Most popular Node.js framework
|
| 143 |
+
- NestJS - Enterprise-grade framework
|
| 144 |
+
- Koa.js - Lightweight alternative
|
| 145 |
+
|
| 146 |
+
**Python Frameworks:**
|
| 147 |
+
- FastAPI - Modern, high-performance API framework
|
| 148 |
+
- Django - Full-featured web framework
|
| 149 |
+
- Flask - Lightweight and flexible
|
| 150 |
+
|
| 151 |
+
**Go Frameworks:**
|
| 152 |
+
- Gin - HTTP web framework
|
| 153 |
+
- Fiber - Express-inspired framework
|
| 154 |
+
|
| 155 |
+
## 📊 Evaluation & Testing
|
| 156 |
+
|
| 157 |
+
### Automatic Quality Assessment
|
| 158 |
+
```python
|
| 159 |
+
from training_pipeline import ModelEvaluator
|
| 160 |
+
|
| 161 |
+
evaluator = ModelEvaluator()
|
| 162 |
+
|
| 163 |
+
# Test specific code generation
|
| 164 |
+
generated_code = model.generate_code(
|
| 165 |
+
description="User authentication API with JWT tokens",
|
| 166 |
+
framework="express",
|
| 167 |
+
language="javascript"
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Get quality scores
|
| 171 |
+
quality_scores = evaluator.evaluate_code_quality(generated_code, "javascript")
|
| 172 |
+
print(f"Syntax Correctness: {quality_scores['syntax_correctness']:.2f}")
|
| 173 |
+
print(f"Completeness: {quality_scores['completeness']:.2f}")
|
| 174 |
+
print(f"Best Practices: {quality_scores['best_practices']:.2f}")
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### Comprehensive Benchmarking
|
| 178 |
+
```python
|
| 179 |
+
test_cases = [
|
| 180 |
+
{
|
| 181 |
+
'description': 'REST API for task management with user authentication',
|
| 182 |
+
'framework': 'express',
|
| 183 |
+
'language': 'javascript'
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
'description': 'GraphQL API for social media platform',
|
| 187 |
+
'framework': 'fastapi',
|
| 188 |
+
'language': 'python'
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
'description': 'Microservice for payment processing',
|
| 192 |
+
'framework': 'gin',
|
| 193 |
+
'language': 'go'
|
| 194 |
+
}
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
benchmark_results = evaluator.benchmark_model(model, test_cases)
|
| 198 |
+
print("Overall Performance:", benchmark_results)
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
## 🚀 Advanced Usage
|
| 202 |
+
|
| 203 |
+
### Custom Data Sources
|
| 204 |
+
```python
|
| 205 |
+
# Add your own training examples
|
| 206 |
+
custom_examples = [
|
| 207 |
+
{
|
| 208 |
+
'description': 'Custom API requirement',
|
| 209 |
+
'requirements': ['Custom feature 1', 'Custom feature 2'],
|
| 210 |
+
'framework': 'fastapi',
|
| 211 |
+
'language': 'python',
|
| 212 |
+
'code_files': {
|
| 213 |
+
'main.py': '# Your custom code here',
|
| 214 |
+
'requirements.txt': 'fastapi\nuvicorn'
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
# Add to training data
|
| 220 |
+
collector.collected_examples.extend([CodeExample(**ex) for ex in custom_examples])
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
### Fine-tuning on Specific Domains
|
| 224 |
+
```python
|
| 225 |
+
# Focus training on specific application types
|
| 226 |
+
domain_specific_queries = [
|
| 227 |
+
'microservices architecture',
|
| 228 |
+
'api gateway implementation',
|
| 229 |
+
'database orm integration',
|
| 230 |
+
'authentication middleware',
|
| 231 |
+
'rate limiting api'
|
| 232 |
+
]
|
| 233 |
+
|
| 234 |
+
asyncio.run(collector.collect_github_repositories(domain_specific_queries))
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
### Export Trained Model
|
| 238 |
+
```python
|
| 239 |
+
# Save model for deployment
|
| 240 |
+
model.model.save_pretrained('./production_model')
|
| 241 |
+
model.tokenizer.save_pretrained('./production_model')
|
| 242 |
+
|
| 243 |
+
# Load for inference
|
| 244 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 245 |
+
|
| 246 |
+
production_model = AutoModelForCausalLM.from_pretrained('./production_model')
|
| 247 |
+
production_tokenizer = AutoTokenizer.from_pretrained('./production_model')
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## 🔧 Troubleshooting
|
| 251 |
+
|
| 252 |
+
### Common Issues
|
| 253 |
+
|
| 254 |
+
**1. Out of Memory Errors**
|
| 255 |
+
```python
|
| 256 |
+
# Reduce batch size
|
| 257 |
+
config['per_device_train_batch_size'] = 1
|
| 258 |
+
config['gradient_accumulation_steps'] = 4
|
| 259 |
+
|
| 260 |
+
# Use gradient checkpointing
|
| 261 |
+
config['gradient_checkpointing'] = True
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
**2. Slow Training**
|
| 265 |
+
```python
|
| 266 |
+
# Enable mixed precision (if GPU supports it)
|
| 267 |
+
config['fp16'] = True
|
| 268 |
+
|
| 269 |
+
# Parallelize data loading with multiple worker processes
|
| 270 |
+
config['dataloader_num_workers'] = 4
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
**3. Poor Code Quality**
|
| 274 |
+
```python
|
| 275 |
+
# Increase training data diversity
|
| 276 |
+
collector.generate_synthetic_examples(count=1000)
|
| 277 |
+
|
| 278 |
+
# Extend training duration
|
| 279 |
+
config['num_train_epochs'] = 10
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
### Performance Optimization
|
| 283 |
+
|
| 284 |
+
**For CPU Training:**
|
| 285 |
+
```python
|
| 286 |
+
config['dataloader_pin_memory'] = False
|
| 287 |
+
config['per_device_train_batch_size'] = 1
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
**For GPU Training:**
|
| 291 |
+
```python
|
| 292 |
+
config['fp16'] = True
|
| 293 |
+
config['dataloader_pin_memory'] = True
|
| 294 |
+
config['per_device_train_batch_size'] = 4
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
## 📈 Expected Results
|
| 298 |
+
|
| 299 |
+
After training on ~500-1000 examples, you should expect:
|
| 300 |
+
|
| 301 |
+
- **Syntax Correctness**: 85-95%
|
| 302 |
+
- **Code Completeness**: 80-90%
|
| 303 |
+
- **Best Practices**: 70-85%
|
| 304 |
+
- **Framework Coverage**: All major Node.js and Python frameworks
|
| 305 |
+
- **Generation Speed**: 2-5 seconds per application
|
| 306 |
+
|
| 307 |
+
## 🔄 Continuous Improvement
|
| 308 |
+
|
| 309 |
+
### Regular Retraining
|
| 310 |
+
```python
|
| 311 |
+
# Schedule weekly data collection
|
| 312 |
+
import schedule
|
| 313 |
+
|
| 314 |
+
def update_training_data():
|
| 315 |
+
asyncio.run(collector.collect_github_repositories(['new backend trends']))
|
| 316 |
+
|
| 317 |
+
schedule.every().week.do(update_training_data)
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
### A/B Testing Different Models
|
| 321 |
+
```python
|
| 322 |
+
models_to_compare = [
|
| 323 |
+
'microsoft/DialoGPT-medium',
|
| 324 |
+
'microsoft/DialoGPT-large',
|
| 325 |
+
'gpt2-medium'
|
| 326 |
+
]
|
| 327 |
+
|
| 328 |
+
for base_model in models_to_compare:
|
| 329 |
+
model = CodeGenerationModel(base_model)
|
| 330 |
+
results = evaluator.benchmark_model(model, test_cases)
|
| 331 |
+
print(f"{base_model}: {results}")
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
## 🎯 Next Steps
|
| 335 |
+
|
| 336 |
+
1. **Start Small**: Begin with synthetic data and 100-200 examples
|
| 337 |
+
2. **Add Real Data**: Integrate GitHub repositories gradually
|
| 338 |
+
3. **Evaluate Regularly**: Monitor quality metrics after each training session
|
| 339 |
+
4. **Expand Frameworks**: Add support for new frameworks as needed
|
| 340 |
+
5. **Production Deploy**: Export model for API deployment
|
| 341 |
+
|
| 342 |
+
This pipeline provides a complete foundation for building your own backend code generation AI. The modular design allows you to customize and extend each component based on your specific needs.
|
test_api.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script for the Backend Code Generation API
|
| 4 |
+
===============================================
|
| 5 |
+
|
| 6 |
+
Simple test script to verify the API is working correctly.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
# API base URL
|
| 15 |
+
BASE_URL = "http://localhost:8000"
|
| 16 |
+
|
| 17 |
+
def test_health():
    """Hit the /health endpoint and report whether it answered 200 OK."""
    print("Testing health endpoint...")
    resp = requests.get(f"{BASE_URL}/health")
    print(f"Status: {resp.status_code}")
    print(f"Response: {resp.json()}")
    return resp.status_code == 200
|
| 24 |
+
|
| 25 |
+
def test_generate_code():
    """Submit a code-generation job and poll its status until it settles.

    Returns True when the job reports 'completed', False on failure,
    bad status responses, or timeout (30 polls x 10 s = 5 minutes).
    """
    print("\nTesting code generation...")

    # Payload describing a small sample project.
    payload = {
        "description": "Simple REST API for task management",
        "framework": "fastapi",
        "language": "python",
        "requirements": [
            "User authentication",
            "Task CRUD operations",
            "Task status tracking"
        ],
        "project_name": "task-manager-api"
    }

    # Kick off the asynchronous generation job.
    resp = requests.post(f"{BASE_URL}/api/v1/generate", json=payload)
    print(f"Generation request status: {resp.status_code}")

    if resp.status_code != 200:
        print(f"❌ Generation request failed: {resp.text}")
        return False

    task_id = resp.json()["task_id"]
    print(f"Task ID: {task_id}")

    # Poll the status endpoint until the job finishes or we give up.
    print("Polling for completion...")
    for _ in range(30):  # Wait up to 5 minutes
        status_resp = requests.get(f"{BASE_URL}/api/v1/status/{task_id}")
        if status_resp.status_code != 200:
            print(f"Failed to get status: {status_resp.status_code}")
            return False

        status = status_resp.json()
        print(f"Status: {status['status']} - {status['message']} ({status['progress']}%)")

        if status["status"] == "completed":
            print("✅ Generation completed!")
            if status.get("download_url"):
                print(f"Download URL: {status['download_url']}")
            return True
        if status["status"] == "failed":
            print(f"❌ Generation failed: {status.get('error', 'Unknown error')}")
            return False

        time.sleep(10)  # Wait 10 seconds between polls

    print("⏰ Timeout waiting for completion")
    return False
|
| 78 |
+
|
| 79 |
+
def test_frameworks():
    """List the supported frameworks; True when the endpoint answers 200."""
    print("\nTesting frameworks endpoint...")
    resp = requests.get(f"{BASE_URL}/api/v1/frameworks")
    print(f"Status: {resp.status_code}")
    if resp.status_code != 200:
        return False
    frameworks = resp.json()
    print(f"Supported frameworks: {len(frameworks['frameworks'])}")
    for fw in frameworks['frameworks']:
        print(f" - {fw['name']} ({fw['language']})")
    return True
|
| 91 |
+
|
| 92 |
+
def test_examples():
    """List the available examples; True when the endpoint answers 200."""
    print("\nTesting examples endpoint...")
    resp = requests.get(f"{BASE_URL}/api/v1/examples")
    print(f"Status: {resp.status_code}")
    if resp.status_code != 200:
        return False
    examples = resp.json()
    print(f"Available examples: {len(examples['examples'])}")
    for ex in examples['examples']:
        print(f" - {ex['name']}")
    return True
|
| 104 |
+
|
| 105 |
+
def main():
    """Run the whole API smoke-test suite and print a pass/fail summary."""
    print("🚀 Testing Backend Code Generation API")
    print("=" * 50)

    # Bail out early when the service is not reachable at all.
    try:
        root_resp = requests.get(f"{BASE_URL}/", timeout=5)
        if root_resp.status_code != 200:
            print("❌ API is not running. Please start it with: python api_service.py")
            return
    except requests.exceptions.RequestException:
        print("❌ Cannot connect to API. Please start it with: python api_service.py")
        return

    print("✅ API is running")

    # (display name, callable) pairs, executed in order.
    tests = [
        ("Health Check", test_health),
        ("Frameworks List", test_frameworks),
        ("Examples List", test_examples),
        ("Code Generation", test_generate_code),
    ]

    results = []
    for test_name, test_func in tests:
        print(f"\n{'='*20} {test_name} {'='*20}")
        try:
            results.append((test_name, test_func()))
        except Exception as e:
            # A crashing test counts as a failure but must not stop the suite.
            print(f"❌ Test failed with error: {e}")
            results.append((test_name, False))

    # Summary table.
    print(f"\n{'='*50}")
    print("📊 Test Results Summary:")
    print("=" * 50)

    passed = 0
    for test_name, ok in results:
        print(f"{test_name}: {'✅ PASS' if ok else '❌ FAIL'}")
        if ok:
            passed += 1

    print(f"\nPassed: {passed}/{len(results)} tests")

    if passed == len(results):
        print("🎉 All tests passed!")
    else:
        print("⚠️ Some tests failed. Check the output above for details.")
|
| 158 |
+
|
| 159 |
+
# Entry point: run the full smoke-test suite when executed as a script.
if __name__ == "__main__":
    main()
|
training_pipeline.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Backend Code Generation Model Training Pipeline
|
| 4 |
+
===============================================
|
| 5 |
+
|
| 6 |
+
A comprehensive training pipeline for building an AI model that generates
|
| 7 |
+
framework-agnostic backend code with full application scaffolding.
|
| 8 |
+
|
| 9 |
+
Features:
|
| 10 |
+
- Data collection from multiple sources
|
| 11 |
+
- Multi-framework support (Express.js, FastAPI, Django, Flask, etc.)
|
| 12 |
+
- Full application scaffolding generation
|
| 13 |
+
- Model training with transformer architecture
|
| 14 |
+
- Evaluation and benchmarking tools
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import asyncio
|
| 21 |
+
import aiohttp
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import numpy as np
|
| 24 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 25 |
+
from dataclasses import dataclass, asdict
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
from torch.utils.data import Dataset, DataLoader
|
| 30 |
+
from transformers import (
|
| 31 |
+
AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
|
| 32 |
+
Trainer, DataCollatorForLanguageModeling
|
| 33 |
+
)
|
| 34 |
+
from datasets import Dataset as HFDataset
|
| 35 |
+
import ast
|
| 36 |
+
import subprocess
|
| 37 |
+
import tempfile
|
| 38 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 39 |
+
import requests
|
| 40 |
+
import time
|
| 41 |
+
import random
|
| 42 |
+
|
| 43 |
+
# Configure logging
|
| 44 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class CodeExample:
    """Represents a single training example"""
    # Natural-language summary of what the application does.
    description: str
    # Short functional-requirement strings (e.g. "Uses flask", "Implements ...").
    requirements: List[str]
    # Backend framework identifier (e.g. 'express', 'fastapi', 'django').
    framework: str
    # Primary implementation language (e.g. 'python', 'javascript').
    language: str
    code_files: Dict[str, str]  # filename -> content
    # Layout summary: file/directory lists plus has_tests / has_docs flags.
    project_structure: Dict[str, Any]
    # Provenance info (GitHub stars, forks, URL, timestamps) for collected repos.
    metadata: Dict[str, Any]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class DataCollector:
|
| 61 |
+
"""Collects training data from various sources"""
|
| 62 |
+
|
| 63 |
+
def __init__(self):
    # Optional auth token; unauthenticated GitHub requests still work but
    # are subject to a much lower API rate limit.
    self.github_token = os.getenv('GITHUB_TOKEN')
    # Accumulates every example gathered from GitHub or synthesized locally.
    self.collected_examples: List[CodeExample] = []
|
| 66 |
+
|
| 67 |
+
async def collect_github_repositories(self, queries: List[str], max_repos: int = 100):
    """Collect backend projects from GitHub, one search pass per query."""
    logger.info("Starting GitHub repository collection...")

    # Attach the auth token when one was configured.
    auth_headers = {'Authorization': f'token {self.github_token}'} if self.github_token else {}

    async with aiohttp.ClientSession(headers=auth_headers) as session:
        # Split the repository budget evenly across queries (at least one each).
        quota = max(1, max_repos // max(1, len(queries)))
        for search_query in queries:
            await self._search_github_repos(session, search_query, quota)
|
| 77 |
+
|
| 78 |
+
async def _search_github_repos(self, session: aiohttp.ClientSession, query: str, limit: int):
    """Search GitHub for repositories matching query.

    Args:
        session: Shared aiohttp session (already carries auth headers, if any).
        query: Free-text GitHub repository search query.
        limit: Maximum number of repositories to request for this query.
    """
    # Plain string literal — nothing is interpolated (was a placeholder-free f-string).
    url = "https://api.github.com/search/repositories"
    params = {
        'q': query,
        'sort': 'stars',
        'order': 'desc',
        'per_page': min(limit, 100)  # GitHub caps per_page at 100
    }

    try:
        async with session.get(url, params=params) as response:
            if response.status == 200:
                data = await response.json()
                for repo in data.get('items', []):
                    await self._process_repository(session, repo)
            else:
                # Non-200 (rate limit, bad query, ...): log and skip this query.
                logger.warning(f"GitHub API request failed: {response.status}")
    except Exception as e:
        # Collection is best-effort; never let one query abort the run.
        logger.error(f"Error searching GitHub: {e}")
|
| 98 |
+
|
| 99 |
+
async def _process_repository(self, session: aiohttp.ClientSession, repo: Dict):
    """Fetch a repository's top-level contents and convert it to an example."""
    logger.info(f"Processing repository: {repo.get('full_name', '<unknown>')}")

    try:
        contents_url = f"https://api.github.com/repos/{repo['full_name']}/contents"
        async with session.get(contents_url) as response:
            if response.status == 200:
                listing = await response.json()
                await self._extract_code_example(session, repo, listing)
    except Exception as e:
        # Skip repos we cannot read; log and keep going.
        logger.error(f"Error processing repository {repo.get('full_name')}: {e}")
|
| 111 |
+
|
| 112 |
+
async def _extract_code_example(self, session: aiohttp.ClientSession, repo: Dict, contents: List[Dict]):
    """Extract a structured code example from repository"""
    # Skip repositories whose framework or language cannot be determined.
    framework = self._identify_framework(contents, repo.get('description', ''))
    language = self._identify_language(contents)

    if not framework or not language:
        return

    # Download only the files deemed important for training (manifests,
    # entry points, common backend modules).
    code_files: Dict[str, str] = {}
    for item in contents:
        if item.get('type') == 'file' and self._is_important_file(item.get('name', '')):
            try:
                async with session.get(item['download_url']) as response:
                    if response.status == 200:
                        content = await response.text()
                        code_files[item['name']] = content
            except Exception:
                # Best effort: a single unreadable file must not drop the repo.
                continue

    # Only record the repo when at least one file was actually fetched.
    if code_files:
        example = CodeExample(
            description=repo.get('description', ''),
            requirements=self._extract_requirements(code_files),
            framework=framework,
            language=language,
            code_files=code_files,
            project_structure=self._analyze_structure(contents),
            metadata={
                'stars': repo.get('stargazers_count', 0),
                'forks': repo.get('forks_count', 0),
                'url': repo.get('html_url'),
                'created_at': repo.get('created_at'),
                'updated_at': repo.get('updated_at')
            }
        )
        self.collected_examples.append(example)
|
| 148 |
+
|
| 149 |
+
def _identify_framework(self, contents: List[Dict], description: str) -> Optional[str]:
|
| 150 |
+
"""Identify the backend framework used"""
|
| 151 |
+
filenames = [item.get('name', '').lower() for item in contents if item.get('type') == 'file']
|
| 152 |
+
|
| 153 |
+
frameworks = {
|
| 154 |
+
'express': ['package.json', 'app.js', 'server.js'],
|
| 155 |
+
'fastapi': ['requirements.txt', 'main.py', 'app.py'],
|
| 156 |
+
'django': ['manage.py', 'settings.py', 'requirements.txt'],
|
| 157 |
+
'flask': ['app.py', 'requirements.txt'],
|
| 158 |
+
'nestjs': ['nest-cli.json', 'package.json'],
|
| 159 |
+
'koa': ['package.json'],
|
| 160 |
+
'gin': ['go.mod', 'main.go'],
|
| 161 |
+
'fiber': ['go.mod', 'main.go'],
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
for framework, required_files in frameworks.items():
|
| 165 |
+
if all(any(req in filename for filename in filenames) for req in required_files[:2]):
|
| 166 |
+
return framework
|
| 167 |
+
|
| 168 |
+
desc_lower = description.lower()
|
| 169 |
+
for framework in frameworks.keys():
|
| 170 |
+
if framework in desc_lower:
|
| 171 |
+
return framework
|
| 172 |
+
|
| 173 |
+
return None
|
| 174 |
+
|
| 175 |
+
def _identify_language(self, contents: List[Dict]) -> Optional[str]:
|
| 176 |
+
"""Identify primary programming language"""
|
| 177 |
+
extensions: Dict[str, int] = {}
|
| 178 |
+
for item in contents:
|
| 179 |
+
if item.get('type') == 'file':
|
| 180 |
+
ext = Path(item.get('name', '')).suffix.lower()
|
| 181 |
+
if ext:
|
| 182 |
+
extensions[ext] = extensions.get(ext, 0) + 1
|
| 183 |
+
|
| 184 |
+
lang_map = {
|
| 185 |
+
'.js': 'javascript',
|
| 186 |
+
'.ts': 'typescript',
|
| 187 |
+
'.py': 'python',
|
| 188 |
+
'.go': 'go',
|
| 189 |
+
'.java': 'java',
|
| 190 |
+
'.cs': 'csharp',
|
| 191 |
+
'.rb': 'ruby',
|
| 192 |
+
'.php': 'php'
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
if extensions:
|
| 196 |
+
most_common_ext = max(extensions.items(), key=lambda x: x[1])[0]
|
| 197 |
+
return lang_map.get(most_common_ext)
|
| 198 |
+
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
def _is_important_file(self, filename: str) -> bool:
|
| 202 |
+
"""Check if file is important for training"""
|
| 203 |
+
important_patterns = [
|
| 204 |
+
'package.json', 'requirements.txt', 'go.mod', 'pom.xml',
|
| 205 |
+
'dockerfile', 'docker-compose.yml', 'readme.md',
|
| 206 |
+
'app.py', 'main.py', 'server.js', 'app.js', 'index.js',
|
| 207 |
+
'settings.py', 'config.py', 'routes.py', 'models.py',
|
| 208 |
+
'controller.js', 'service.js', 'middleware.js'
|
| 209 |
+
]
|
| 210 |
+
|
| 211 |
+
filename_lower = filename.lower()
|
| 212 |
+
return any(pattern in filename_lower for pattern in important_patterns)
|
| 213 |
+
|
| 214 |
+
def _extract_requirements(self, code_files: Dict[str, str]) -> List[str]:
|
| 215 |
+
"""Extract functional requirements from code"""
|
| 216 |
+
requirements: List[str] = []
|
| 217 |
+
|
| 218 |
+
if 'package.json' in code_files:
|
| 219 |
+
try:
|
| 220 |
+
pkg_data = json.loads(code_files['package.json'])
|
| 221 |
+
deps = list(pkg_data.get('dependencies', {}).keys())
|
| 222 |
+
requirements.extend([f"Uses {dep}" for dep in deps[:5]])
|
| 223 |
+
except Exception:
|
| 224 |
+
pass
|
| 225 |
+
|
| 226 |
+
if 'requirements.txt' in code_files:
|
| 227 |
+
lines = code_files['requirements.txt'].strip().split('\n')
|
| 228 |
+
deps = [line.split('==')[0].split('>=')[0].strip() for line in lines if line.strip()]
|
| 229 |
+
requirements.extend([f"Uses {dep}" for dep in deps[:5]])
|
| 230 |
+
|
| 231 |
+
for filename, content in code_files.items():
|
| 232 |
+
if filename.endswith(('.js', '.py')):
|
| 233 |
+
endpoints = self._extract_endpoints(content)
|
| 234 |
+
requirements.extend(endpoints)
|
| 235 |
+
|
| 236 |
+
return requirements[:10]
|
| 237 |
+
|
| 238 |
+
def _extract_endpoints(self, code_content: str) -> List[str]:
|
| 239 |
+
"""Extract API endpoints from code"""
|
| 240 |
+
endpoints: List[str] = []
|
| 241 |
+
lines = code_content.split('\n')
|
| 242 |
+
|
| 243 |
+
for line in lines:
|
| 244 |
+
s = line.strip()
|
| 245 |
+
if any(method in s for method in ['app.get(', 'app.post(', 'app.put(', 'app.delete(']):
|
| 246 |
+
endpoints.append(f"Implements {s}")
|
| 247 |
+
elif any(decorator in s for decorator in ['@app.get(', '@app.post(', '@app.put(', '@app.delete(']):
|
| 248 |
+
endpoints.append(f"Implements {s}")
|
| 249 |
+
elif 'def ' in s and any(word in s for word in ['get', 'post', 'put', 'delete']):
|
| 250 |
+
endpoints.append(f"Implements {s}")
|
| 251 |
+
|
| 252 |
+
return endpoints[:5]
|
| 253 |
+
|
| 254 |
+
def _analyze_structure(self, contents: List[Dict]) -> Dict[str, Any]:
|
| 255 |
+
"""Analyze project structure"""
|
| 256 |
+
structure: Dict[str, Any] = {
|
| 257 |
+
'files': [],
|
| 258 |
+
'directories': [],
|
| 259 |
+
'total_files': 0,
|
| 260 |
+
'has_tests': False,
|
| 261 |
+
'has_docs': False
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
for item in contents:
|
| 265 |
+
if item.get('type') == 'file':
|
| 266 |
+
name = item.get('name', '')
|
| 267 |
+
structure['files'].append(name)
|
| 268 |
+
structure['total_files'] += 1
|
| 269 |
+
if 'test' in name.lower():
|
| 270 |
+
structure['has_tests'] = True
|
| 271 |
+
if name.lower() in ['readme.md', 'docs.md']:
|
| 272 |
+
structure['has_docs'] = True
|
| 273 |
+
elif item.get('type') == 'dir':
|
| 274 |
+
structure['directories'].append(item.get('name', ''))
|
| 275 |
+
|
| 276 |
+
return structure
|
| 277 |
+
|
| 278 |
+
def generate_synthetic_examples(self, count: int = 100):
    """Create synthetic training examples from fixed project templates.

    Each example pairs a project description with generated starter
    code for a randomly chosen framework and is appended to
    ``self.collected_examples``.

    Args:
        count: Number of examples to generate.
    """
    logger.info(f"Generating {count} synthetic examples...")

    templates = [
        {
            'description': 'REST API for user management',
            'requirements': ['User registration', 'User authentication', 'Profile management'],
            'frameworks': ['express', 'fastapi', 'django'],
        },
        {
            'description': 'E-commerce backend API',
            'requirements': ['Product catalog', 'Shopping cart', 'Order processing', 'Payment integration'],
            'frameworks': ['nestjs', 'fastapi', 'django'],
        },
        {
            'description': 'Task management system',
            'requirements': ['Task CRUD operations', 'User assignments', 'Status tracking'],
            'frameworks': ['express', 'flask', 'gin'],
        },
        {
            'description': 'Blog platform backend',
            'requirements': ['Article management', 'User comments', 'Category system'],
            'frameworks': ['express', 'django', 'fastapi'],
        },
    ]

    for _ in range(count):
        chosen = random.choice(templates)
        chosen_framework = random.choice(chosen['frameworks'])

        files = self._generate_code_for_template(chosen, chosen_framework)
        # Language follows the framework family.
        is_python = chosen_framework in ['fastapi', 'django', 'flask']

        self.collected_examples.append(CodeExample(
            description=chosen['description'],
            requirements=chosen['requirements'],
            framework=chosen_framework,
            language='python' if is_python else 'javascript',
            code_files=files,
            project_structure=self._generate_synthetic_structure(chosen_framework),
            metadata={'synthetic': True},
        ))
|
| 322 |
+
|
| 323 |
+
def _generate_code_for_template(self, template: Dict, framework: str) -> Dict[str, str]:
    """Generate starter code files for a template and framework.

    Args:
        template: Synthetic project template; only ``description`` is
            used (as the package name for Node projects).
        framework: Target framework name ('express', 'fastapi', ...).

    Returns:
        Mapping of filename -> file content. Frameworks other than
        'express' and 'fastapi' get a single placeholder file.
    """
    if framework == 'express':
        # Node/Express project: manifest plus a minimal server file.
        return {
            'package.json': json.dumps({
                "name": template['description'].lower().replace(' ', '-'),
                "version": "1.0.0",
                "dependencies": {
                    "express": "^4.18.0",
                    "mongoose": "^6.0.0",
                    "bcrypt": "^5.0.0",
                    "jsonwebtoken": "^8.5.0"
                }
            }, indent=2),
            'app.js': '''const express = require('express');
const mongoose = require('mongoose');
const app = express();

// Middleware
app.use(express.json());

// Routes
app.get('/health', (req, res) => {
    res.json({ status: 'OK' });
});

// Start server
const PORT = process.env.PORT || 3000;
app.listen(PORT, () => {
    console.log(`Server running on port ${PORT}`);
});

module.exports = app;'''
        }
    elif framework == 'fastapi':
        # Python/FastAPI project: pinned requirements plus a minimal app.
        return {
            'requirements.txt': '''fastapi==0.68.0
uvicorn==0.15.0
sqlalchemy==1.4.23
pydantic==1.8.2''',
            'main.py': '''from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

class Item(BaseModel):
    id: Optional[int] = None
    name: str
    description: str

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/health")
async def health_check():
    return {"status": "OK"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)'''
        }
    else:
        # No template implemented for this framework yet.
        return {'placeholder.txt': 'Generated code placeholder'}
|
| 388 |
+
|
| 389 |
+
def _generate_synthetic_structure(self, framework: str) -> Dict[str, Any]:
|
| 390 |
+
"""Generate project structure for framework"""
|
| 391 |
+
if framework in ['express', 'nestjs']:
|
| 392 |
+
return {
|
| 393 |
+
'files': ['package.json', 'app.js', 'README.md'],
|
| 394 |
+
'directories': ['routes', 'controllers', 'middleware', 'models'],
|
| 395 |
+
'total_files': 3,
|
| 396 |
+
'has_tests': True,
|
| 397 |
+
'has_docs': True
|
| 398 |
+
}
|
| 399 |
+
elif framework in ['fastapi', 'django', 'flask']:
|
| 400 |
+
return {
|
| 401 |
+
'files': ['requirements.txt', 'main.py', 'README.md'],
|
| 402 |
+
'directories': ['models', 'routes', 'services'],
|
| 403 |
+
'total_files': 3,
|
| 404 |
+
'has_tests': True,
|
| 405 |
+
'has_docs': True
|
| 406 |
+
}
|
| 407 |
+
else:
|
| 408 |
+
return {}
|
| 409 |
+
|
| 410 |
+
def save_dataset(self, filepath: str):
    """Write all collected examples to *filepath* as pretty-printed JSON.

    Args:
        filepath: Destination path for the dataset dump.
    """
    records = [asdict(example) for example in self.collected_examples]
    with open(filepath, 'w', encoding='utf-8') as handle:
        json.dump(records, handle, indent=2, ensure_ascii=False)
    logger.info(f"Saved {len(records)} examples to {filepath}")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class DataPreprocessor:
    """Preprocesses collected examples into a tokenized HF dataset.

    Wraps a causal-LM tokenizer and converts ``CodeExample`` objects
    into input/output text pairs suitable for fine-tuning.
    """

    def __init__(self, tokenizer_name: str = "microsoft/DialoGPT-medium"):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # GPT-2-family tokenizers ship without a pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Ensure we do not exceed model's maximum positional embeddings (GPT-2/DialoGPT: 1024)
        try:
            model_max = getattr(self.tokenizer, 'model_max_length', 1024)
            # Some tokenizers set a very large sentinel value; cap at 1024 for GPT-2 family
            if model_max and model_max > 0 and model_max < 100000:
                self.max_length = min(1024, int(model_max))
            else:
                self.max_length = 1024
        except Exception:
            self.max_length = 1024

    def preprocess_examples(self, examples: List[CodeExample]) -> List[Dict[str, str]]:
        """Convert examples to training format (prompt/target text pairs)."""
        processed: List[Dict[str, str]] = []

        for example in examples:
            processed.append({
                'input': self._create_input_text(example),
                'output': self._create_output_text(example),
                'framework': example.framework,
                'language': example.language,
            })

        return processed

    def _create_input_text(self, example: CodeExample) -> str:
        """Create the prompt text: description, framework, requirements."""
        input_parts: List[str] = [
            f"Description: {example.description}",
            f"Framework: {example.framework}",
            f"Language: {example.language}",
            "Requirements:",
        ]

        for req in example.requirements:
            input_parts.append(f"- {req}")

        input_parts.append("Generate the backend application:")

        return "\n".join(input_parts)

    def _create_output_text(self, example: CodeExample) -> str:
        """Create the target text: structure listing plus each file's code."""
        output_parts: List[str] = []

        output_parts.append("Project Structure:")
        for directory in example.project_structure.get('directories', []):
            output_parts.append(f"/{directory}/")

        output_parts.append("\nGenerated Files:")

        for filename, content in example.code_files.items():
            # BUG FIX: the header previously printed the literal text
            # "(unknown)" for every file and left `filename` unused;
            # interpolate the actual filename so targets are usable.
            output_parts.append(f"\n--- {filename} ---")
            output_parts.append(content)
            output_parts.append("--- End ---")

        return "\n".join(output_parts)

    def create_training_dataset(self, processed_examples: List[Dict[str, str]]) -> HFDataset:
        """Create a tokenized Hugging Face dataset for causal-LM training."""

        def tokenize_function(examples: Dict[str, List[str]]):
            # Concatenate prompt and target with sentinel markers so the
            # model can learn where the answer begins.
            texts: List[str] = []
            for inp, out in zip(examples['input'], examples['output']):
                texts.append(f"<|startoftext|>{inp}<|separator|>{out}<|endoftext|>")

            return self.tokenizer(
                texts,
                truncation=True,
                padding=True,
                max_length=self.max_length
            )

        dataset_dict = {
            'input': [ex['input'] for ex in processed_examples],
            'output': [ex['output'] for ex in processed_examples],
            'framework': [ex['framework'] for ex in processed_examples],
            'language': [ex['language'] for ex in processed_examples]
        }

        dataset = HFDataset.from_dict(dataset_dict)
        return dataset.map(tokenize_function, batched=True)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class CodeGenerationModel:
    """Custom model for backend code generation.

    Wraps a Hugging Face causal LM and tokenizer, exposing fine-tuning
    and prompt-based generation helpers.
    """

    def __init__(self, base_model: str = "microsoft/DialoGPT-medium"):
        self.base_model = base_model
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = AutoModelForCausalLM.from_pretrained(base_model)

        # GPT-2-family tokenizers have no pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def fine_tune(self, dataset: HFDataset, output_dir: str = "./trained_model"):
        """Fine-tune the model on backend code generation.

        Args:
            dataset: Tokenized dataset (see DataPreprocessor).
            output_dir: Where checkpoints and the final model are saved.
        """
        logger.info("Starting model fine-tuning...")

        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=1,  # Reduced from 3
            per_device_train_batch_size=1,  # Reduced from 2 for memory
            per_device_eval_batch_size=1,  # Reduced from 2
            warmup_steps=50,  # Reduced from 500
            max_steps=100,  # Drastically reduced from 2000
            logging_steps=10,  # More frequent logging
            save_steps=50,  # More frequent saves
            save_total_limit=2,
            prediction_loss_only=True,
            fp16=torch.cuda.is_available(),
            dataloader_pin_memory=False,
            gradient_accumulation_steps=4,  # Accumulate gradients for effective larger batch
            learning_rate=5e-5,  # Explicit learning rate
        )

        # Causal-LM objective: mlm=False makes labels the shifted inputs.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # 80/20 train/eval split.
        train_size = int(0.8 * len(dataset))
        eval_size = len(dataset) - train_size
        train_dataset, eval_dataset = torch.utils.data.random_split(
            dataset, [train_size, eval_size]
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

        trainer.train()
        trainer.save_model()

        logger.info("Fine-tuning completed!")

    def generate_code(self, description: str, framework: str, language: str) -> str:
        """Generate backend code for given requirements.

        Returns only the newly generated continuation, not the prompt.
        """
        input_text = (
            f"Description: {description}\n"
            f"Framework: {framework}\n"
            f"Language: {language}\n"
            f"Generate the backend application:"
        )

        # Respect model's max position embeddings (GPT-2/DialoGPT is typically 1024)
        model_max_len = getattr(self.tokenizer, 'model_max_length', 1024)
        max_len = 1024 if model_max_len is None or model_max_len > 100000 else min(1024, int(model_max_len))

        inputs = self.tokenizer.encode(input_text, return_tensors='pt', truncation=True, max_length=max_len)

        # Disable dropout during inference.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=max_len,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # BUG FIX: slicing the decoded string by len(input_text) was
        # fragile — tokenize+decode does not round-trip the prompt
        # exactly, and truncation may shorten it. Slice off the prompt
        # in token space and decode only the newly generated tokens.
        new_tokens = outputs[0][inputs.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class ModelEvaluator:
    """Evaluates model performance"""

    def __init__(self):
        # Scratch space for accumulated metric values.
        self.metrics: Dict[str, float] = {}

    def evaluate_code_quality(self, generated_code: str, language: str) -> Dict[str, float]:
        """Score generated code on syntax, completeness and best practices."""
        return {
            'syntax_correctness': self._check_syntax(generated_code, language),
            'completeness': self._check_completeness(generated_code),
            'best_practices': self._check_best_practices(generated_code, language),
        }

    def _check_syntax(self, code: str, language: str) -> float:
        """Return 1.0/0.0 for Python via ast; a brace heuristic for JS; 0.5 otherwise."""
        if language == 'python':
            try:
                ast.parse(code)
            except SyntaxError:
                return 0.0
            return 1.0
        if language == 'javascript':
            has_braces = '{' in code and '}' in code
            return 0.8 if has_braces else 0.5
        return 0.5

    def _check_completeness(self, code: str) -> float:
        """Fraction (capped at 1.0) of structural indicators present, scaled by 3."""
        markers = (
            'import', 'require', 'function', 'def', 'class',
            'app.', 'router.', '@app.', 'app.listen', 'if __name__',
        )
        hits = sum(1 for marker in markers if marker in code)
        return min(hits / 3.0, 1.0)

    def _check_best_practices(self, code: str, language: str) -> float:
        """Score 0.0-1.0 for error handling, comments and language conventions."""
        score = 0.0

        # Any error handling?
        if 'try:' in code or 'catch' in code:
            score += 0.2

        # Any comments at all?
        if any(tok in code for tok in ['#', '//', '/*']):
            score += 0.2

        # Language-specific convention checks.
        if language == 'python' and 'if __name__ == "__main__"' in code:
            score += 0.2
        elif language == 'javascript' and ('const' in code or 'let' in code):
            score += 0.2

        return min(score, 1.0)

    def benchmark_model(self, model: 'CodeGenerationModel', test_cases: List[Dict]) -> Dict[str, float]:
        """Average code-quality scores of *model* over *test_cases*."""
        totals = {'syntax': 0.0, 'completeness': 0.0, 'best_practices': 0.0}

        for idx, case in enumerate(test_cases):
            code = model.generate_code(
                case['description'],
                case['framework'],
                case['language'],
            )

            quality = self.evaluate_code_quality(code, case['language'])

            totals['syntax'] += quality['syntax_correctness']
            totals['completeness'] += quality['completeness']
            totals['best_practices'] += quality['best_practices']

            logger.info(f"Test case {idx+1}: {quality}")

        divisor = max(1, len(test_cases))
        return {key: value / divisor for key, value in totals.items()}
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
class TrainingPipeline:
    """Main training pipeline orchestrator"""

    def __init__(self, config: Dict[str, Any]):
        # Wire up every stage of the pipeline from a single config dict.
        self.config = config
        self.data_collector = DataCollector()
        self.preprocessor = DataPreprocessor(config.get('tokenizer', 'microsoft/DialoGPT-medium'))
        self.model = CodeGenerationModel(config.get('base_model', 'microsoft/DialoGPT-medium'))
        self.evaluator = ModelEvaluator()

    async def run_full_pipeline(self):
        """Collect data, preprocess, fine-tune and benchmark the model."""
        logger.info("Starting full training pipeline...")

        # Step 1: data collection (GitHub mining only when a token is set).
        logger.info("Step 1: Collecting training data...")
        if self.data_collector.github_token:
            search_queries = [
                'express api backend',
                'fastapi python backend',
                'django rest api',
                'nodejs backend server',
                'flask api backend',
            ]
            await self.data_collector.collect_github_repositories(search_queries, max_repos=50)

        self.data_collector.generate_synthetic_examples(count=200)
        self.data_collector.save_dataset('raw_dataset.json')

        # Step 2: turn raw examples into a tokenized dataset.
        logger.info("Step 2: Preprocessing data...")
        prepared = self.preprocessor.preprocess_examples(self.data_collector.collected_examples)
        train_data = self.preprocessor.create_training_dataset(prepared)

        # Step 3: fine-tune the base model.
        logger.info("Step 3: Training model...")
        self.model.fine_tune(train_data, output_dir=self.config.get('output_dir', './trained_model'))

        # Step 4: benchmark on a fixed set of prompts.
        logger.info("Step 4: Evaluating model...")
        eval_cases = [
            {
                'description': 'REST API for user management with authentication',
                'framework': 'express',
                'language': 'javascript',
            },
            {
                'description': 'FastAPI backend for e-commerce platform',
                'framework': 'fastapi',
                'language': 'python',
            },
            {
                'description': 'Django REST API for blog platform',
                'framework': 'django',
                'language': 'python',
            },
        ]

        results = self.evaluator.benchmark_model(self.model, eval_cases)
        logger.info(f"Benchmark results: {results}")

        logger.info("Training pipeline completed!")
        return results
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def main() -> None:
    """Run the full training pipeline, then smoke-test generation."""
    config = {
        'base_model': 'microsoft/DialoGPT-medium',
        'tokenizer': 'microsoft/DialoGPT-medium',
        'output_dir': './backend_code_model',
        # NOTE(review): TrainingPipeline does not read this key in this
        # module; presumably DataCollector picks up the token itself —
        # confirm against its constructor.
        'github_token': os.getenv('GITHUB_TOKEN'),
    }

    pipeline = TrainingPipeline(config)

    asyncio.run(pipeline.run_full_pipeline())

    logger.info("\nTesting trained model...")
    generated_code = pipeline.model.generate_code(
        description="Create a REST API for managing tasks with CRUD operations",
        framework="express",
        language="javascript",
    )

    print("\nGenerated Code:")
    print("=" * 50)
    print(generated_code)


if __name__ == "__main__":
    main()
|
| 771 |
+
|
| 772 |
+
|