"""FastAPI service that serves a LoRA-adapted causal LM for text generation.

The heavy ML libraries (transformers / peft) are imported and the model is
loaded inside the startup hook rather than at import time. Any loading failure
is caught and logged, so the server still starts: ``/health`` keeps answering
and ``/generate`` returns 503 until a model is available.
"""
import logging
import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# --- Use a writable cache directory (current working dir) ---
# Set before transformers / huggingface_hub are imported (that happens lazily in
# load_model), since those libraries resolve their cache paths from the
# environment when first imported.
CACHE_DIR = os.path.join(os.getcwd(), "cache")  # /code/cache in the Dockerfile layout
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HF_METRICS_CACHE"] = CACHE_DIR

# Without a configured handler the root logger drops INFO records, so the
# "loaded successfully" message would never be visible. basicConfig does
# nothing if the hosting process has already configured the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DirectEd LoRA API (safe startup)")


# lightweight health endpoint
@app.get("/health")
def health():
    """Liveness probe; answers regardless of whether the model is loaded."""
    return {"ok": True}


class Request(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    prompt: str
    # ge=1: a zero/negative token budget is rejected up front with a 422
    # instead of failing inside generation with a 500.
    max_new_tokens: int = Field(150, ge=1)
    # 0 selects greedy decoding (see generate); negative values are rejected.
    temperature: float = Field(0.7, ge=0.0)


# Globals to be initialized on startup.
# Text-generation pipeline; stays None if load_model() failed.
pipe = None


# NOTE: on_event is deprecated in recent FastAPI releases in favour of lifespan
# handlers; kept here so the startup behaviour is unchanged.
@app.on_event("startup")
def load_model():
    """Load the base model, attach the LoRA adapter and build the global ``pipe``.

    Any exception is logged and leaves ``pipe`` as ``None`` so the server keeps
    running; ``/generate`` then answers 503.
    """
    global pipe
    try:
        # heavy imports inside startup so module import stays lightweight
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"  # unchanged
        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"  # <<< replace with your adapter repo

        # load tokenizer + base model then attach adapter
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype="auto",
        )
        model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
        model.eval()
        # The model was already placed on devices by device_map="auto" above;
        # pipeline() only uses device_map when it loads a model from a name,
        # so it is not passed again here.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        logger.info("Model and adapter loaded successfully.")
    except Exception as e:
        # Keep server up; logs will show why load failed
        logger.exception("Failed to load model at startup: %s", e)
        pipe = None


@app.post("/generate")
def generate(req: Request):
    """Run the text-generation pipeline on ``req.prompt``.

    Returns ``{"response": ...}`` holding ``generated_text`` of the first
    pipeline output (with default pipeline settings this includes the prompt
    followed by the completion).

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(500): generation raised; the detail is the error text.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Check Space logs.")

    gen_kwargs = {"max_new_tokens": req.max_new_tokens}
    if req.temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=req.temperature)
    else:
        # Sampling requires a strictly positive temperature, so temperature=0
        # means deterministic greedy decoding rather than an error.
        gen_kwargs["do_sample"] = False

    try:
        out = pipe(req.prompt, **gen_kwargs)
        return {"response": out[0]["generated_text"]}
    except Exception as e:
        logger.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e