# app.py (safe, use /tmp for cache)
import os
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import tempfile
# --- Put caches in a writable temp dir to avoid permission errors ---
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # If even creating that directory fails, fall back to the bare temp dir.
    TMP_CACHE = tempfile.gettempdir()
# export environment vars before importing transformers
os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
os.environ["HF_HOME"] = TMP_CACHE
os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
os.environ["HF_METRICS_CACHE"] = TMP_CACHE
app = FastAPI(title="DirectEd LoRA API (safe startup)")
@app.get("/health")
def health():
return {"ok": True}
class Request(BaseModel):
    prompt: str
    max_new_tokens: int = 150
    temperature: float = 0.7
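# Example request body for /generate, matching the Request model above
# (the prompt text is just an illustration):
#   {"prompt": "Explain LoRA in one sentence.", "max_new_tokens": 100, "temperature": 0.7}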
pipe = None
@app.on_event("startup")
def load_model():
global pipe
try:
# heavy imports done during startup
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA" # <-- replace with your adapter repo
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
low_cpu_mem_usage=True,
torch_dtype="auto",
)
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
model.eval()
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
logging.info("Model and adapter loaded successfully.")
except Exception as e:
logging.exception("Failed to load model at startup: %s", e)
pipe = None
@app.post("/generate")
def generate(req: Request):
if pipe is None:
raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
try:
out = pipe(req.prompt, max_new_tokens=req.max_new_tokens, temperature=req.temperature, do_sample=True)
return {"response": out[0]["generated_text"]}
except Exception as e:
logging.exception("Generation failed: %s", e)
raise HTTPException(status_code=500, detail=str(e))
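# A minimal local smoke test (assumes uvicorn is installed and that port 7860,
# the Hugging Face Spaces default, is free; adjust host/port for your setup):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!", "max_new_tokens": 50}'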