# app.py (safe, use /tmp for cache)
import os
import logging
import tempfile

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# --- Put caches in a writable temp dir to avoid permission errors ---
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # if even this fails, fall back to tempfile.gettempdir()
    TMP_CACHE = tempfile.gettempdir()

# export environment vars before importing transformers
os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
os.environ["HF_HOME"] = TMP_CACHE
os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
os.environ["HF_METRICS_CACHE"] = TMP_CACHE

logging.basicConfig(level=logging.INFO)

app = FastAPI(title="DirectEd LoRA API (safe startup)")

# Lightweight health check so the Space responds even while/if the model is not loaded.
@app.get("/health")
def health():
    return {"ok": True}

class Request(BaseModel):
    prompt: str
    max_new_tokens: int = 150
    temperature: float = 0.7

pipe = None

@app.on_event("startup")
def load_model():
    global pipe
    try:
        # heavy imports done during startup
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"  # <-- replace with your adapter repo

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype="auto",
        )
        model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
        model.eval()
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
        logging.info("Model and adapter loaded successfully.")
    except Exception as e:
        logging.exception("Failed to load model at startup: %s", e)
        pipe = None

@app.post("/generate")
def generate(req: Request):
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        out = pipe(req.prompt, max_new_tokens=req.max_new_tokens, temperature=req.temperature, do_sample=True)
        return {"response": out[0]["generated_text"]}
    except Exception as e:
        logging.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
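
Once the app is running (for example locally with uvicorn app:app --host 0.0.0.0 --port 7860, or on a Space), the endpoints can be smoke-tested with a small client. This is a minimal sketch only: the base URL and port are assumptions to adjust for your deployment, and it assumes the requests library is installed.

# client.py -- hypothetical smoke test for the API in app.py
import requests

BASE_URL = "http://localhost:7860"  # assumption: replace with your Space URL

# Health check: should return {"ok": True} even if the model failed to load.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Generation request mirroring the Request model (prompt plus optional sampling params).
payload = {
    "prompt": "Explain LoRA fine-tuning in one sentence.",
    "max_new_tokens": 80,
    "temperature": 0.7,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["response"])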