# inference.py
import os, re, sys, unicodedata

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Windows console UTF-8 (no-op on Linux)
if sys.platform.startswith("win"):
    try:
        sys.stdout.reconfigure(encoding="utf-8")
        sys.stderr.reconfigure(encoding="utf-8")
    except Exception:
        pass
# --- Host constraints (free tiers)
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
try:
    torch.set_num_threads(1)
except Exception:
    pass
# -------- Config --------
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "aleenarayamajhi/spotchecker-gpt2-medium-merged")
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or
            os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip()
DEVICE = "cpu"
DTYPE = torch.float32
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "96"))
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "1"))
USE_CACHE = False # lower RAM
def _auth_kwargs():
    # Works with both newer and older hub/transformers versions.
    return {"token": HF_TOKEN} if HF_TOKEN else {}
# -------- Mappings --------
DISEASE_TO_PATHOGEN = {
    "Phyllosticta Leaf Spot": "Phyllosticta spp.",
    "Cercospora Leaf Spot": "Cercospora spp.",
    "Septoria Leaf Spot": "Septoria spp.",
    "Spot Anthracnose": "Elsinoë corni",
    "Dogwood Anthracnose": "Discula destructiva",
    "Bacterial Leaf Scorch": "Xylella fastidiosa",
}
ALLOWED_DISEASES = list(DISEASE_TO_PATHOGEN.keys())
# -------- Cleaning helpers --------
CTRL_PATTERN = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]")
def clean_text(s: str) -> str:
    s = unicodedata.normalize("NFKC", s)
    s = (s.replace("\u00A0", " ").replace("\u200B", " ").replace("\ufeff", "")
         .replace("\u2009", " ").replace("\u202F", " ").replace("\u2060", " "))
    s = CTRL_PATTERN.sub("", s)
    s = re.sub(r"[ \t]+", " ", s)   # collapse runs of spaces/tabs
    s = re.sub(r" *\n *", "\n", s)  # trim spaces around newlines
    s = re.sub(r" *; *", "; ", s)   # normalize semicolon spacing
    s = re.sub(r" *– *", "–", s)    # tighten spacing around en dashes
    return s.strip()
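# Illustrative example (hypothetical input): the Unicode-space replacements and
# the regex passes compose, so
#   clean_text("Leaf\u00A0spots ;  brown\u2009margins")
# returns "Leaf spots; brown margins".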
def final_clean(text: str) -> str:
    text = unicodedata.normalize("NFKC", text).replace("Â", "").replace("\u00A0", " ")
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r" *\n *", "\n", text)
    return text.strip()
# -------- Load merged model (CPU, no Accelerate) --------
MODEL_READY = False
LOAD_ERROR = ""
tok = None
model = None
print("Loading merged model:", MODEL_REPO_ID, "(CPU)")
try:
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_REPO_ID, use_fast=True, **_auth_kwargs())
    except Exception as e_fast:
        print("Fast tokenizer failed; falling back to slow tokenizer:", e_fast)
        tok = AutoTokenizer.from_pretrained(MODEL_REPO_ID, use_fast=False, **_auth_kwargs())
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO_ID,
        torch_dtype=DTYPE,
        # IMPORTANT: do NOT set low_cpu_mem_usage or device_map → avoids Accelerate
        **_auth_kwargs(),
    ).to(DEVICE).eval()
    # Ensure both the generation config and the model config carry pad/eos ids.
    model.generation_config.max_new_tokens = MAX_NEW_TOKENS
    model.generation_config.num_beams = NUM_BEAMS
    model.generation_config.do_sample = False
    model.generation_config.repetition_penalty = 1.05
    model.generation_config.no_repeat_ngram_size = 3
    model.generation_config.eos_token_id = tok.eos_token_id
    model.generation_config.pad_token_id = tok.eos_token_id
    model.generation_config.use_cache = USE_CACHE
    model.config.eos_token_id = tok.eos_token_id
    model.config.pad_token_id = tok.eos_token_id
    MODEL_READY = True
    print("Model ready on", DEVICE)
except Exception as e:
    LOAD_ERROR = f"{type(e).__name__}: {e}"
    print("AI model failed to load:", LOAD_ERROR)
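# Optional smoke test (a sketch, not part of the original flow): confirm the
# loaded tokenizer/model pair produces logits of the expected vocabulary size.
#   if MODEL_READY:
#       ids = tok("test", return_tensors="pt")["input_ids"].to(DEVICE)
#       assert model(ids).logits.shape[-1] == model.config.vocab_size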
# -------- Prompt / scoring --------
def training_header(user_text: str) -> str:
    return f"<BOS>User: {user_text.strip()}\nAssistant:\n"
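# The literal "<BOS>"/"Assistant:" framing presumably mirrors the fine-tuning
# prompt format, e.g.
#   training_header("brown leaf spots") == "<BOS>User: brown leaf spots\nAssistant:\n"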
@torch.inference_mode()
def logprob_continuation(prefix: str, continuation: str) -> float:
    if not MODEL_READY:
        return -1e30
    max_len = getattr(tok, "model_max_length", 1024)
    full = prefix + continuation
    enc = tok(full, return_tensors="pt", truncation=True, max_length=max_len)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    out = model(**enc)
    logp = F.log_softmax(out.logits, dim=-1)
    # Logits at position i predict token i+1, so score the continuation by
    # summing the log-probability of each of its tokens under teacher forcing.
    # Note: tokenizing the prefix separately is an approximation; BPE may merge
    # tokens differently at the prefix/continuation boundary.
    pref_ids = tok(prefix, return_tensors="pt", truncation=True, max_length=max_len)["input_ids"].to(DEVICE)
    start = pref_ids.shape[1] - 1
    end = enc["input_ids"].shape[1] - 1
    total = 0.0
    for i in range(max(0, start), max(0, end)):
        next_id = int(enc["input_ids"][0, i + 1])
        total += float(logp[0, i, next_id].item())
    return total
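# Worked example (hypothetical token ids): with prefix tokens [A, B] and full
# tokens [A, B, C, D], start = 1 and end = 3, so the loop adds
# log P(C | A, B) + log P(D | A, B, C), i.e. the joint log-probability of the
# continuation given the prefix.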
@torch.inference_mode()
def choose_disease_by_joint(prefix: str) -> str:
    best_d, best_score = None, None
    for d in ALLOWED_DISEASES:
        p = DISEASE_TO_PATHOGEN[d]
        continuation = f"Disease: {d}\nPathogen: {p}\n"
        score = logprob_continuation(prefix, continuation)
        if (best_score is None) or (score > best_score):
            best_score, best_d = score, d
    return best_d or ALLOWED_DISEASES[0]
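# Sketch (not in the original API): the same scorer can rank every candidate,
# which helps when debugging near-ties; h is assumed to be a prompt built with
# training_header().
#   scores = {d: logprob_continuation(h, f"Disease: {d}\nPathogen: {DISEASE_TO_PATHOGEN[d]}\n")
#             for d in ALLOWED_DISEASES}
#   ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)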
@torch.inference_mode()
def generate_management(prefix_with_labels: str) -> str:
    if not MODEL_READY:
        return ""
    max_len = getattr(tok, "model_max_length", 1024)
    enc = tok(prefix_with_labels, return_tensors="pt", truncation=True, max_length=max_len)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    out = model.generate(**enc)  # uses the generation_config set above
    gen_ids = out[0][enc["input_ids"].shape[1]:]
    text = tok.decode(gen_ids, skip_special_tokens=True)
    text = text.split("<EOS>")[0].split("\n\n")[0].strip()
    return clean_text(text)
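# Illustrative example (hypothetical output): if the model continues with
# "Prune infected twigs; rake fallen leaves.\n\nUser: ...", the splits on the
# literal "<EOS>" marker and on "\n\n" keep only the first management block.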
@torch.inference_mode()
def generate_answer(user_text: str) -> str:
    if not MODEL_READY:
        return ("AI text analysis is unavailable on this free tier right now. "
                f"{('Reason: ' + LOAD_ERROR) if LOAD_ERROR else ''}").strip()
    h = training_header(clean_text(user_text))
    disease = choose_disease_by_joint(h)
    pathogen = DISEASE_TO_PATHOGEN[disease]
    labels_block = f"Disease: {disease}\nPathogen: {pathogen}\nManagement: "
    mgmt = generate_management(h + labels_block)
    return final_clean(f"{labels_block}{mgmt}")
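# Example usage (the output shown is hypothetical):
#   generate_answer("tan spots with purple borders on dogwood leaves")
#   -> "Disease: Spot Anthracnose\nPathogen: Elsinoë corni\nManagement: ..."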
if __name__ == "__main__":
    if len(sys.argv) > 1:
        print(generate_answer(" ".join(sys.argv[1:])))
    else:
        while True:
            try:
                q = input("Symptoms: ").strip()
                if not q:
                    continue
                print(generate_answer(q))
            except (KeyboardInterrupt, EOFError):
                break