aleenarayamajhi commited on
Commit
55aa0d5
·
verified ·
1 Parent(s): 2a382a9

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +181 -178
inference.py CHANGED
@@ -1,178 +1,181 @@
1
- # inference.py
2
- import os, sys, re, unicodedata, torch, torch.nn.functional as F
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
-
5
- # --- Windows console UTF-8
6
- if sys.platform.startswith("win"):
7
- try:
8
- sys.stdout.reconfigure(encoding="utf-8")
9
- sys.stderr.reconfigure(encoding="utf-8")
10
- except Exception:
11
- pass
12
-
13
- # --- Host constraints (free tiers)
14
- os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
15
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
16
- try:
17
- torch.set_num_threads(1)
18
- except Exception:
19
- pass
20
-
21
- # -------- Config --------
22
- MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "aleenarayamajhi/spotchecker-gpt2-medium-merged")
23
- HF_TOKEN = os.getenv("HF_TOKEN", "").strip() # optional if repo is public
24
-
25
- DEVICE = "cpu"
26
- DTYPE = torch.float32
27
- MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "96"))
28
- NUM_BEAMS = int(os.getenv("NUM_BEAMS", "1"))
29
- USE_CACHE = False
30
-
31
- def _auth_kwargs():
32
- return {"token": HF_TOKEN} if HF_TOKEN else {}
33
-
34
- # -------- Mappings --------
35
- DISEASE_TO_PATHOGEN = {
36
- "Phyllosticta Leaf Spot": "Phyllosticta spp.",
37
- "Cercospora Leaf Spot": "Cercospora spp.",
38
- "Septoria Leaf Spot": "Septoria spp.",
39
- "Spot Anthracnose": "Elsinoë corni",
40
- "Dogwood Anthracnose": "Discula destructiva",
41
- "Bacterial Leaf Scorch": "Xylella fastidiosa",
42
- }
43
- ALLOWED_DISEASES = list(DISEASE_TO_PATHOGEN.keys())
44
-
45
- # -------- Cleaning helpers (minimal) --------
46
- CTRL_PATTERN = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]")
47
- def clean_text(s: str) -> str:
48
- s = unicodedata.normalize("NFKC", s)
49
- s = (s.replace("\u00A0", " ").replace("\u200B", " ").replace("\ufeff", "")
50
- .replace("\u2009", " ").replace("\u202F", " ").replace("\u2060", " "))
51
- s = CTRL_PATTERN.sub("", s)
52
- s = re.sub(r"[ \t]+", " ", s)
53
- s = re.sub(r" *\n *", "\n", s)
54
- s = re.sub(r" *; *", "; ", s)
55
- s = re.sub(r" *– *", "", s)
56
- return s.strip()
57
-
58
- def final_clean(text: str) -> str:
59
- text = unicodedata.normalize("NFKC", text).replace("Â", "").replace("\u00A0", " ")
60
- text = re.sub(r"[ \t]+", " ", text)
61
- text = re.sub(r" *\n *", "\n", text)
62
- return text.strip()
63
-
64
- # -------- Load merged model (non-crashing) --------
65
- MODEL_READY = False
66
- LOAD_ERROR = ""
67
- tok = None
68
- model = None
69
-
70
- print("Loading merged model:", MODEL_REPO_ID, "(CPU)")
71
- try:
72
- try:
73
- tok = AutoTokenizer.from_pretrained(
74
- MODEL_REPO_ID, use_fast=True, force_download=True, **_auth_kwargs()
75
- )
76
- except Exception as e_fast:
77
- print("Fast tokenizer failed; falling back to slow tokenizer:", e_fast)
78
- tok = AutoTokenizer.from_pretrained(
79
- MODEL_REPO_ID, use_fast=False, force_download=True, **_auth_kwargs()
80
- )
81
-
82
- if tok.pad_token is None:
83
- tok.pad_token = tok.eos_token
84
-
85
- model = AutoModelForCausalLM.from_pretrained(
86
- MODEL_REPO_ID, torch_dtype=DTYPE, low_cpu_mem_usage=True, **_auth_kwargs()
87
- ).to(DEVICE).eval()
88
-
89
- # generation config
90
- gc = model.generation_config
91
- gc.max_new_tokens = MAX_NEW_TOKENS
92
- gc.num_beams = NUM_BEAMS
93
- gc.do_sample = False
94
- gc.repetition_penalty = 1.05
95
- gc.no_repeat_ngram_size = 3
96
- gc.eos_token_id = tok.eos_token_id
97
- gc.pad_token_id = tok.eos_token_id
98
- gc.use_cache = USE_CACHE
99
-
100
- MODEL_READY = True
101
- print("Model ready on", DEVICE)
102
-
103
- except Exception as e:
104
- LOAD_ERROR = f"{type(e).__name__}: {e}"
105
- print("AI model failed to load:", LOAD_ERROR)
106
-
107
- # -------- Prompt / scoring --------
108
- def training_header(user_text: str) -> str:
109
- return f"<BOS>User: {user_text.strip()}\nAssistant:\n"
110
-
111
- @torch.inference_mode()
112
- def logprob_continuation(prefix: str, continuation: str) -> float:
113
- if not MODEL_READY:
114
- return -1e30
115
- max_len = getattr(tok, "model_max_length", 1024)
116
- full = prefix + continuation
117
- enc = tok(full, return_tensors="pt", truncation=True, max_length=max_len)
118
- enc = {k: v.to(DEVICE) for k, v in enc.items()}
119
- out = model(**enc)
120
- logp = F.log_softmax(out.logits, dim=-1)
121
-
122
- pref_ids = tok(prefix, return_tensors="pt", truncation=True, max_length=max_len)["input_ids"].to(DEVICE)
123
- start = pref_ids.shape[1] - 1
124
- end = enc["input_ids"].shape[1] - 1
125
- total = 0.0
126
- for i in range(max(0, start), max(0, end)):
127
- next_id = int(enc["input_ids"][0, i + 1])
128
- total += float(logp[0, i, next_id].item())
129
- return total
130
-
131
- @torch.inference_mode()
132
- def choose_disease_by_joint(prefix: str) -> str:
133
- best_d, best_score = None, None
134
- for d in ALLOWED_DISEASES:
135
- p = DISEASE_TO_PATHOGEN[d]
136
- continuation = f"Disease: {d}\nPathogen: {p}\n"
137
- score = logprob_continuation(prefix, continuation)
138
- if (best_score is None) or (score > best_score):
139
- best_score, best_d = score, d
140
- return best_d or ALLOWED_DISEASES[0]
141
-
142
- @torch.inference_mode()
143
- def generate_management(prefix_with_labels: str) -> str:
144
- if not MODEL_READY:
145
- return ""
146
- max_len = getattr(tok, "model_max_length", 1024)
147
- enc = tok(prefix_with_labels, return_tensors="pt", truncation=True, max_length=max_len)
148
- enc = {k: v.to(DEVICE) for k, v in enc.items()}
149
- out = model.generate(**enc)
150
- gen_ids = out[0][enc["input_ids"].shape[1]:]
151
- text = tok.decode(gen_ids, skip_special_tokens=True)
152
- text = text.split("<EOS>")[0].split("\n\n")[0].strip()
153
- return clean_text(text)
154
-
155
- @torch.inference_mode()
156
- def generate_answer(user_text: str) -> str:
157
- if not MODEL_READY:
158
- return ("AI text analysis is unavailable on this free tier right now. "
159
- f"{('Reason: ' + LOAD_ERROR) if LOAD_ERROR else ''}").strip()
160
- h = training_header(clean_text(user_text))
161
- disease = choose_disease_by_joint(h)
162
- pathogen = DISEASE_TO_PATHOGEN[disease]
163
- labels_block = f"Disease: {disease}\nPathogen: {pathogen}\nManagement: "
164
- mgmt = generate_management(h + labels_block)
165
- return final_clean(f"{labels_block}{mgmt}")
166
-
167
- if __name__ == "__main__":
168
- if len(sys.argv) > 1:
169
- print(generate_answer(" ".join(sys.argv[1:])))
170
- else:
171
- while True:
172
- try:
173
- q = input("Symptoms: ").strip()
174
- if not q:
175
- continue
176
- print(generate_answer(q))
177
- except (KeyboardInterrupt, EOFError):
178
- break
 
 
 
 
1
+ # inference.py
2
+ import os, sys, re, unicodedata, torch, torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # --- Windows console UTF-8 (no-op on Linux)
6
+ if sys.platform.startswith("win"):
7
+ try:
8
+ sys.stdout.reconfigure(encoding="utf-8")
9
+ sys.stderr.reconfigure(encoding="utf-8")
10
+ except Exception:
11
+ pass
12
+
13
+ # --- Host constraints (free tiers)
14
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
15
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
16
+ try:
17
+ torch.set_num_threads(1)
18
+ except Exception:
19
+ pass
20
+
21
+ # -------- Config --------
22
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "aleenarayamajhi/spotchecker-gpt2-medium-merged")
23
+ HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or
24
+ os.getenv("HUGGINGFACE_HUB_TOKEN") or "").strip()
25
+
26
+ DEVICE = "cpu"
27
+ DTYPE = torch.float32
28
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "96"))
29
+ NUM_BEAMS = int(os.getenv("NUM_BEAMS", "1"))
30
+ USE_CACHE = False # lower RAM
31
+
32
+ def _auth_kwargs():
33
+ # Compatible with newer & older hub/transformers
34
+ return ({"token": HF_TOKEN} if HF_TOKEN else {})
35
+
36
+ # -------- Mappings --------
37
+ DISEASE_TO_PATHOGEN = {
38
+ "Phyllosticta Leaf Spot": "Phyllosticta spp.",
39
+ "Cercospora Leaf Spot": "Cercospora spp.",
40
+ "Septoria Leaf Spot": "Septoria spp.",
41
+ "Spot Anthracnose": "Elsinoë corni",
42
+ "Dogwood Anthracnose": "Discula destructiva",
43
+ "Bacterial Leaf Scorch": "Xylella fastidiosa",
44
+ }
45
+ ALLOWED_DISEASES = list(DISEASE_TO_PATHOGEN.keys())
46
+
47
+ # -------- Cleaning helpers --------
48
+ CTRL_PATTERN = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]")
49
+ def clean_text(s: str) -> str:
50
+ s = unicodedata.normalize("NFKC", s)
51
+ s = (s.replace("\u00A0", " ").replace("\u200B", " ").replace("\ufeff", "")
52
+ .replace("\u2009", " ").replace("\u202F", " ").replace("\u2060", " "))
53
+ s = CTRL_PATTERN.sub("", s)
54
+ s = re.sub(r"[ \t]+", " ", s)
55
+ s = re.sub(r" *\n *", "\n", s)
56
+ s = re.sub(r" *; *", "; ", s)
57
+ s = re.sub(r" *– *", "–", s)
58
+ return s.strip()
59
+
60
+ def final_clean(text: str) -> str:
61
+ text = unicodedata.normalize("NFKC", text).replace("Â", "").replace("\u00A0", " ")
62
+ text = re.sub(r"[ \t]+", " ", text)
63
+ text = re.sub(r" *\n *", "\n", text)
64
+ return text.strip()
65
+
66
+ # -------- Load merged model (CPU, no Accelerate) --------
67
+ MODEL_READY = False
68
+ LOAD_ERROR = ""
69
+ tok = None
70
+ model = None
71
+
72
+ print("Loading merged model:", MODEL_REPO_ID, "(CPU)")
73
+ try:
74
+ try:
75
+ tok = AutoTokenizer.from_pretrained(MODEL_REPO_ID, use_fast=True, **_auth_kwargs())
76
+ except Exception as e_fast:
77
+ print("Fast tokenizer failed; falling back to slow tokenizer:", e_fast)
78
+ tok = AutoTokenizer.from_pretrained(MODEL_REPO_ID, use_fast=False, **_auth_kwargs())
79
+
80
+ if tok.pad_token is None:
81
+ tok.pad_token = tok.eos_token
82
+
83
+ model = AutoModelForCausalLM.from_pretrained(
84
+ MODEL_REPO_ID,
85
+ torch_dtype=DTYPE,
86
+ # IMPORTANT: do NOT set low_cpu_mem_usage or device_map → avoids Accelerate
87
+ **_auth_kwargs(),
88
+ ).to(DEVICE).eval()
89
+
90
+ # ensure generation config & model config have pad/eos tokens
91
+ model.generation_config.max_new_tokens = MAX_NEW_TOKENS
92
+ model.generation_config.num_beams = NUM_BEAMS
93
+ model.generation_config.do_sample = False
94
+ model.generation_config.repetition_penalty = 1.05
95
+ model.generation_config.no_repeat_ngram_size = 3
96
+ model.generation_config.eos_token_id = tok.eos_token_id
97
+ model.generation_config.pad_token_id = tok.eos_token_id
98
+ model.generation_config.use_cache = USE_CACHE
99
+
100
+ model.config.eos_token_id = tok.eos_token_id
101
+ model.config.pad_token_id = tok.eos_token_id
102
+
103
+ MODEL_READY = True
104
+ print("Model ready on", DEVICE)
105
+
106
+ except Exception as e:
107
+ LOAD_ERROR = f"{type(e).__name__}: {e}"
108
+ print("AI model failed to load:", LOAD_ERROR)
109
+
110
+ # -------- Prompt / scoring --------
111
+ def training_header(user_text: str) -> str:
112
+ return f"<BOS>User: {user_text.strip()}\nAssistant:\n"
113
+
114
+ @torch.inference_mode()
115
+ def logprob_continuation(prefix: str, continuation: str) -> float:
116
+ if not MODEL_READY:
117
+ return -1e30
118
+ max_len = getattr(tok, "model_max_length", 1024)
119
+ full = prefix + continuation
120
+ enc = tok(full, return_tensors="pt", truncation=True, max_length=max_len)
121
+ enc = {k: v.to(DEVICE) for k, v in enc.items()}
122
+ out = model(**enc)
123
+ logp = F.log_softmax(out.logits, dim=-1)
124
+
125
+ pref_ids = tok(prefix, return_tensors="pt", truncation=True, max_length=max_len)["input_ids"].to(DEVICE)
126
+ start = pref_ids.shape[1] - 1
127
+ end = enc["input_ids"].shape[1] - 1
128
+ total = 0.0
129
+ for i in range(max(0, start), max(0, end)):
130
+ next_id = int(enc["input_ids"][0, i + 1])
131
+ total += float(logp[0, i, next_id].item())
132
+ return total
133
+
134
+ @torch.inference_mode()
135
+ def choose_disease_by_joint(prefix: str) -> str:
136
+ best_d, best_score = None, None
137
+ for d in ALLOWED_DISEASES:
138
+ p = DISEASE_TO_PATHOGEN[d]
139
+ continuation = f"Disease: {d}\nPathogen: {p}\n"
140
+ score = logprob_continuation(prefix, continuation)
141
+ if (best_score is None) or (score > best_score):
142
+ best_score, best_d = score, d
143
+ return best_d or ALLOWED_DISEASES[0]
144
+
145
+ @torch.inference_mode()
146
+ def generate_management(prefix_with_labels: str) -> str:
147
+ if not MODEL_READY:
148
+ return ""
149
+ max_len = getattr(tok, "model_max_length", 1024)
150
+ enc = tok(prefix_with_labels, return_tensors="pt", truncation=True, max_length=max_len)
151
+ enc = {k: v.to(DEVICE) for k, v in enc.items()}
152
+ out = model.generate(**enc) # uses generation_config set above
153
+ gen_ids = out[0][enc["input_ids"].shape[1]:]
154
+ text = tok.decode(gen_ids, skip_special_tokens=True)
155
+ text = text.split("<EOS>")[0].split("\n\n")[0].strip()
156
+ return clean_text(text)
157
+
158
+ @torch.inference_mode()
159
+ def generate_answer(user_text: str) -> str:
160
+ if not MODEL_READY:
161
+ return ("AI text analysis is unavailable on this free tier right now. "
162
+ f"{('Reason: ' + LOAD_ERROR) if LOAD_ERROR else ''}").strip()
163
+ h = training_header(clean_text(user_text))
164
+ disease = choose_disease_by_joint(h)
165
+ pathogen = DISEASE_TO_PATHOGEN[disease]
166
+ labels_block = f"Disease: {disease}\nPathogen: {pathogen}\nManagement: "
167
+ mgmt = generate_management(h + labels_block)
168
+ return final_clean(f"{labels_block}{mgmt}")
169
+
170
+ if __name__ == "__main__":
171
+ if len(sys.argv) > 1:
172
+ print(generate_answer(" ".join(sys.argv[1:])))
173
+ else:
174
+ while True:
175
+ try:
176
+ q = input("Symptoms: ").strip()
177
+ if not q:
178
+ continue
179
+ print(generate_answer(q))
180
+ except (KeyboardInterrupt, EOFError):
181
+ break