# inference.py
import os, sys, re, unicodedata, torch, torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Windows console UTF-8 (no-op on Linux)
if sys.platform.startswith("win"):
    try:
        sys.stdout.reconfigure(encoding="utf-8")
        sys.stderr.reconfigure(encoding="utf-8")
    except Exception:
        pass

# --- Host constraints (free tiers)
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
try:
    torch.set_num_threads(1)
except Exception:
    pass

# -------- Config --------
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "aleenarayamajhi/spotchecker-gpt2-medium-merged")
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or
            os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip()
DEVICE = "cpu"
DTYPE = torch.float32
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "96"))
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "1"))
USE_CACHE = False  # lower RAM

def _auth_kwargs():
    # Compatible with newer & older hub/transformers
    return ({"token": HF_TOKEN} if HF_TOKEN else {})

# -------- Mappings --------
DISEASE_TO_PATHOGEN = {
    "Phyllosticta Leaf Spot": "Phyllosticta spp.",
    "Cercospora Leaf Spot": "Cercospora spp.",
    "Septoria Leaf Spot": "Septoria spp.",
    "Spot Anthracnose": "Elsinoë corni",
    "Dogwood Anthracnose": "Discula destructiva",
    "Bacterial Leaf Scorch": "Xylella fastidiosa",
}
ALLOWED_DISEASES = list(DISEASE_TO_PATHOGEN.keys())

# -------- Cleaning helpers --------
CTRL_PATTERN = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]")

def clean_text(s: str) -> str:
    s = unicodedata.normalize("NFKC", s)
    s = (s.replace("\u00A0", " ").replace("\u200B", " ").replace("\ufeff", "")
         .replace("\u2009", " ").replace("\u202F", " ").replace("\u2060", " "))
    s = CTRL_PATTERN.sub("", s)
    s = re.sub(r"[ \t]+", " ", s)
    s = re.sub(r" *\n *", "\n", s)
    s = re.sub(r" *; *", "; ", s)
    s = re.sub(r" *– *", "–", s)
    return s.strip()
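
# Worked example (illustrative input, not real data); NFKC plus the rules above give:
#   clean_text("Brown\u00A0spots ;  leaf – margins")
#   -> "Brown spots; leaf–margins"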

def final_clean(text: str) -> str:
    text = unicodedata.normalize("NFKC", text).replace("Â", "").replace("\u00A0", " ")
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r" *\n *", "\n", text)
    return text.strip()

# -------- Load merged model (CPU, no Accelerate) --------
MODEL_READY = False
LOAD_ERROR = ""
tok = None
model = None
print("Loading merged model:", MODEL_REPO_ID, "(CPU)")
try:
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_REPO_ID, use_fast=True, **_auth_kwargs())
    except Exception as e_fast:
        print("Fast tokenizer failed; falling back to slow tokenizer:", e_fast)
        tok = AutoTokenizer.from_pretrained(MODEL_REPO_ID, use_fast=False, **_auth_kwargs())
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO_ID,
        torch_dtype=DTYPE,
        # IMPORTANT: do NOT set low_cpu_mem_usage or device_map → avoids Accelerate
        **_auth_kwargs(),
    ).to(DEVICE).eval()
    # ensure generation config & model config have pad/eos tokens
    model.generation_config.max_new_tokens = MAX_NEW_TOKENS
    model.generation_config.num_beams = NUM_BEAMS
    model.generation_config.do_sample = False
    model.generation_config.repetition_penalty = 1.05
    model.generation_config.no_repeat_ngram_size = 3
    model.generation_config.eos_token_id = tok.eos_token_id
    model.generation_config.pad_token_id = tok.eos_token_id
    model.generation_config.use_cache = USE_CACHE
    model.config.eos_token_id = tok.eos_token_id
    model.config.pad_token_id = tok.eos_token_id
    MODEL_READY = True
    print("Model ready on", DEVICE)
except Exception as e:
    LOAD_ERROR = f"{type(e).__name__}: {e}"
    print("AI model failed to load:", LOAD_ERROR)

# -------- Prompt / scoring --------
def training_header(user_text: str) -> str:
    return f"<BOS>User: {user_text.strip()}\nAssistant:\n"

@torch.no_grad()  # inference only; skipping autograd keeps CPU RAM low
def logprob_continuation(prefix: str, continuation: str) -> float:
    if not MODEL_READY:
        return -1e30
    max_len = getattr(tok, "model_max_length", 1024)
    full = prefix + continuation
    enc = tok(full, return_tensors="pt", truncation=True, max_length=max_len)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    out = model(**enc)
    logp = F.log_softmax(out.logits, dim=-1)
    pref_ids = tok(prefix, return_tensors="pt", truncation=True, max_length=max_len)["input_ids"].to(DEVICE)
    # logits at position i predict token i+1, so sum the log-probs of the
    # continuation tokens starting from the last prefix position
    start = pref_ids.shape[1] - 1
    end = enc["input_ids"].shape[1] - 1
    total = 0.0
    for i in range(max(0, start), max(0, end)):
        next_id = int(enc["input_ids"][0, i + 1])
        total += float(logp[0, i, next_id].item())
    return total
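
# Sketch of how the scorer is meant to be used (symptom text is illustrative):
#   h = training_header("small spots with purple rims")
#   logprob_continuation(h, "Disease: Spot Anthracnose\nPathogen: Elsinoë corni\n")
# The result is a summed token log-probability (always <= 0); only relative
# values across candidate continuations are meaningful.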

def choose_disease_by_joint(prefix: str) -> str:
    best_d, best_score = None, None
    for d in ALLOWED_DISEASES:
        p = DISEASE_TO_PATHOGEN[d]
        continuation = f"Disease: {d}\nPathogen: {p}\n"
        score = logprob_continuation(prefix, continuation)
        if (best_score is None) or (score > best_score):
            best_score, best_d = score, d
    return best_d or ALLOWED_DISEASES[0]
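
# This is closed-set classification by ranking rather than free generation:
# every allowed "Disease: ...\nPathogen: ..." block is scored under the model
# and the argmax wins, so the output pair is always one of the six mapped above.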

def generate_management(prefix_with_labels: str) -> str:
    if not MODEL_READY:
        return ""
    max_len = getattr(tok, "model_max_length", 1024)
    enc = tok(prefix_with_labels, return_tensors="pt", truncation=True, max_length=max_len)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    out = model.generate(**enc)  # uses generation_config set above
    gen_ids = out[0][enc["input_ids"].shape[1]:]
    text = tok.decode(gen_ids, skip_special_tokens=True)
    text = text.split("<EOS>")[0].split("\n\n")[0].strip()
    return clean_text(text)
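
# With the generation_config set at load time (do_sample=False, num_beams=1),
# decoding is greedy and capped at MAX_NEW_TOKENS; output is cut at the first
# "<EOS>" marker or blank line, so Management is typically a short paragraph.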

def generate_answer(user_text: str) -> str:
    if not MODEL_READY:
        return ("AI text analysis is unavailable on this free tier right now. "
                f"{('Reason: ' + LOAD_ERROR) if LOAD_ERROR else ''}").strip()
    h = training_header(clean_text(user_text))
    disease = choose_disease_by_joint(h)
    pathogen = DISEASE_TO_PATHOGEN[disease]
    labels_block = f"Disease: {disease}\nPathogen: {pathogen}\nManagement: "
    mgmt = generate_management(h + labels_block)
    return final_clean(f"{labels_block}{mgmt}")
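
# Expected output shape (symptoms and management wording are illustrative):
#   Disease: Dogwood Anthracnose
#   Pathogen: Discula destructiva
#   Management: <model-generated text>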

if __name__ == "__main__":
    if len(sys.argv) > 1:
        print(generate_answer(" ".join(sys.argv[1:])))
    else:
        while True:
            try:
                q = input("Symptoms: ").strip()
                if not q:
                    continue
                print(generate_answer(q))
            except (KeyboardInterrupt, EOFError):
                break
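
# Usage (assuming this file is saved as inference.py):
#   python inference.py "tan blotches with purple margins on dogwood leaves"
# With no arguments it enters an interactive loop prompting "Symptoms: ";
# exit with Ctrl-C or EOF.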