rayymaxx committed
Commit 4969e4b · 1 Parent(s): c3e0a3a

Made changes to app file

Files changed (1)
  1. app.py +19 -16
app.py CHANGED
@@ -1,19 +1,26 @@
+# app.py (safe, use /tmp for cache)
 import os
 import logging
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+import tempfile
 
-# --- Use a writable cache directory (current working dir) ---
-CACHE_DIR = os.path.join(os.getcwd(), "cache")  # /code/cache in the Dockerfile layout
-os.makedirs(CACHE_DIR, exist_ok=True)
-os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
-os.environ["HF_HOME"] = CACHE_DIR
-os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
-os.environ["HF_METRICS_CACHE"] = CACHE_DIR
+# --- Put caches in a writable temp dir to avoid permission errors ---
+TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
+try:
+    os.makedirs(TMP_CACHE, exist_ok=True)
+except Exception as e:
+    # if even this fails, fall back to tempfile.gettempdir()
+    TMP_CACHE = tempfile.gettempdir()
+
+# export environment vars before importing transformers
+os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
+os.environ["HF_HOME"] = TMP_CACHE
+os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
+os.environ["HF_METRICS_CACHE"] = TMP_CACHE
 
 app = FastAPI(title="DirectEd LoRA API (safe startup)")
 
-# lightweight health endpoint
 @app.get("/health")
 def health():
     return {"ok": True}
@@ -23,21 +30,19 @@ class Request(BaseModel):
     max_new_tokens: int = 150
     temperature: float = 0.7
 
-# Globals to be initialized on startup
 pipe = None
 
 @app.on_event("startup")
 def load_model():
     global pipe
     try:
-        # heavy imports inside startup so module import stays lightweight
+        # heavy imports done during startup
         from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
         from peft import PeftModel
 
-        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"  # unchanged
-        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"  # <<< replace with your adapter repo
+        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
+        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"  # <-- replace with your adapter repo
 
-        # load tokenizer + base model then attach adapter
         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
         base_model = AutoModelForCausalLM.from_pretrained(
             BASE_MODEL,
@@ -50,17 +55,15 @@ def load_model():
         model.eval()
 
         pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
-
         logging.info("Model and adapter loaded successfully.")
     except Exception as e:
-        # Keep server up; logs will show why load failed
         logging.exception("Failed to load model at startup: %s", e)
         pipe = None
 
 @app.post("/generate")
 def generate(req: Request):
     if pipe is None:
-        raise HTTPException(status_code=503, detail="Model not loaded yet. Check Space logs.")
+        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
     try:
         out = pipe(req.prompt, max_new_tokens=req.max_new_tokens, temperature=req.temperature, do_sample=True)
         return {"response": out[0]["generated_text"]}