import os

import torch
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Paths
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_MODEL_DIR = "./lora_model"
QLORA_MODEL_DIR = "./Qlora_model"
ADALORA_MODEL_DIR = "./adalora_model"
cache_dir = "./cache"

# Prompt template (TinyLlama chat format)
PROMPT_TEMPLATE = """<|system|>
You are Jack Patel. Answer questions about yourself using only information you were trained on.
If you don't know something specific about yourself, say "I don't have that information."
If the user's question is not about Jack Patel, answer as an AI assistant using your general knowledge.
Always respond in 2 to 3 short sentences.
<|user|>
{prompt}
<|assistant|>
"""

app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Global cache to avoid reloading models on every request
model_cache = {}


def load_model(adapter_path):
    """Load the base model plus a PEFT adapter, caching the pair by adapter path."""
    if adapter_path in model_cache:
        return model_cache[adapter_path]

    print(f"🔄 Loading model from: {adapter_path}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        cache_dir=cache_dir,
    )
    model = PeftModel.from_pretrained(base, adapter_path)
    model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    model_cache[adapter_path] = (tokenizer, model)
    return tokenizer, model


def generate_response(prompt, tokenizer, model):
    """Format the prompt, run generation, and return only the assistant's reply."""
    full_prompt = PROMPT_TEMPLATE.format(prompt=prompt)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text after the assistant marker; fall back to the full decode.
    return decoded.split("<|assistant|>")[-1].strip() if "<|assistant|>" in decoded else decoded.strip()


@app.get("/", response_class=HTMLResponse)
async def form_get(request: Request):
    return templates.TemplateResponse("index.html", {
        "request": request,
        "result": None,
        "model": "",
        "prompt": "",
        "data_count": 0
    })


@app.post("/", response_class=HTMLResponse)
async def form_post(
    request: Request,
    prompt: str = Form(...),
    model_type: str = Form(...)
):
    model_paths = {
        "lora": LORA_MODEL_DIR,
        "Qlora1": QLORA_MODEL_DIR,
        "adalora": ADALORA_MODEL_DIR
    }
    model_labels = {
        "lora": "LoRA - lora-tinyllama-final",
        "Qlora1": "QLoRA - lora-tinyllama-final1",
        "adalora": "AdaLoRA - adalora-tinyllama-final"
    }

    adapter_path = model_paths.get(model_type)
    model_label = model_labels.get(model_type, model_type.upper())

    if not adapter_path or not os.path.exists(adapter_path):
        return templates.TemplateResponse("index.html", {
            "request": request,
            "result": "Invalid or missing model selected.",
            "model": model_label,
            "prompt": prompt,
            "data_count": 0
        })

    try:
        tokenizer, model = load_model(adapter_path)
        result = generate_response(prompt, tokenizer, model)
    except Exception as e:
        result = f"Error generating response: {str(e)}"

    return templates.TemplateResponse("index.html", {
        "request": request,
        "result": result,
        "model": model_label,
        "prompt": prompt,
        "data_count": 0  # Replace with real data count if available
    })


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)