peterweber committed
Commit faae576 · verified · 1 Parent(s): 95af4e6

Update app.py

Files changed (1)
  1. app.py +56 -46
app.py CHANGED
@@ -1,24 +1,29 @@
  import os, re, difflib
  from typing import List
- import torch, gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM

- # -------- Model (lazy-load for fast startup) --------
- MODEL_ID = os.getenv("MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
- device = "cuda" if torch.cuda.is_available() else "cpu"
- _tok = None
- _mdl = None

  def load_model():
-     global _tok, _mdl
-     if _tok is None or _mdl is None:
-         _tok = AutoTokenizer.from_pretrained(MODEL_ID)
-         _mdl = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID, low_cpu_mem_usage=True, torch_dtype=torch.float32
-         ).to(device).eval()
-     return _tok, _mdl
-
- # -------- Protect / restore (keep citations, URLs, numbers, code) --------
  SENTINEL_OPEN, SENTINEL_CLOSE = "§§KEEP_OPEN§§", "§§KEEP_CLOSE§§"
  URL_RE = re.compile(r'(https?://\S+)')
  CODE_RE = re.compile(r'`{1,3}[\s\S]*?`{1,3}')
@@ -41,59 +46,64 @@ def restore(text: str, protected: List[str]):
  text = re.sub(rf"{SENTINEL_OPEN}(\d+){SENTINEL_CLOSE}", unwrap, text)
  return text.replace(SENTINEL_OPEN, "").replace(SENTINEL_CLOSE, "")

- # -------- Humanization engine --------
  SYSTEM = (
-     "You are an expert editor. Humanize the user's text: vary sentence length, remove filler, "
-     "break up long sentences, swap stiff phrasing for natural alternatives, and keep meaning identical. "
-     "Do not alter anything inside §§KEEP markers§§, including citations, URLs, numbers, and code. "
-     "Keep tone & region as requested. No em dashes—use simple punctuation."
  )

- def build_messages(text, tone, region, level, intensity):
      user = (
          f"Tone: {tone}. Region: {region} English. Reading level: {level}. "
          f"Humanization intensity: {intensity} (10 strongest).\n\n"
-         f"Rewrite this text. Keep placeholders intact:\n\n{text}"
      )
-     return [{"role":"system","content":SYSTEM},{"role":"user","content":user}]
-
- def apply_chat_template(tokenizer, messages):
-     return tokenizer.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True
      )

- def generate_once(prompt, temperature, max_new=512):
-     tok, mdl = load_model()
-     ids = tok(prompt, return_tensors="pt").to(device)
-     out = mdl.generate(
-         **ids, do_sample=True, temperature=temperature, top_p=0.95,
-         max_new_tokens=max_new, pad_token_id=tok.eos_token_id
      )
-     return tok.decode(out[0][ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
-
- def diff_ratio(a, b): return difflib.SequenceMatcher(None, a, b).ratio()

- def humanize_core(text, tone, region, level, intensity):
      protected_text, bag = protect(text)
-     msgs = build_messages(protected_text, tone, region, level, intensity)
-     prompt = apply_chat_template(load_model()[0], msgs)

-     # pass 1 (conservative), pass 2 (stronger) if output too similar
      draft = generate_once(prompt, temperature=0.35)
      if diff_ratio(protected_text, draft) > 0.97:
          draft = generate_once(prompt, temperature=0.9)

-     draft = draft.replace("—","-")
      final = restore(draft, bag)

-     # ensure protected spans survived
      for i, span in enumerate(bag):
          marker = f"{SENTINEL_OPEN}{i}{SENTINEL_CLOSE}"
          if marker in protected_text and span not in final:
              final = final.replace(marker, span)
      return final

- # -------- Gradio UI (also gives REST at /api/predict/) --------
  def ui_humanize(text, tone, region, level, intensity):
      return humanize_core(text, tone, region, level, int(intensity))
 
@@ -107,9 +117,9 @@ demo = gr.Interface(
          gr.Slider(1, 10, value=6, step=1, label="Humanization intensity"),
      ],
      outputs=gr.Textbox(label="Humanized"),
-     title="NoteCraft Humanizer (TinyLlama 1.1B Chat)",
      description="REST: POST /api/predict/ with { data: [text,tone,region,level,intensity] }",
  ).queue()

  if __name__ == "__main__":
-     demo.launch()
 
  import os, re, difflib
  from typing import List
+ import gradio as gr
+ from ctransformers import AutoModelForCausalLM

+ # ---------------- Model (GGUF on CPU) ----------------
+ # These defaults work on HF free CPU Spaces.
+ REPO_ID = os.getenv("LLAMA_GGUF_REPO", "bartowski/Llama-3.2-3B-Instruct-GGUF")
+ FILENAME = os.getenv("LLAMA_GGUF_FILE", "Llama-3.2-3B-Instruct-Q5_0.gguf")  # if not found, use Q8_0
+ MODEL_TYPE = "llama"

+ # lazy-load for fast startup
+ _llm = None
  def load_model():
+     global _llm
+     if _llm is None:
+         _llm = AutoModelForCausalLM.from_pretrained(
+             REPO_ID,
+             model_file=FILENAME,
+             model_type=MODEL_TYPE,
+             gpu_layers=0,
+             context_length=8192,
+         )
+     return _llm
+
+ # ---------------- Protect / restore ----------------
  SENTINEL_OPEN, SENTINEL_CLOSE = "§§KEEP_OPEN§§", "§§KEEP_CLOSE§§"
  URL_RE = re.compile(r'(https?://\S+)')
  CODE_RE = re.compile(r'`{1,3}[\s\S]*?`{1,3}')
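A quick way to sanity-check the new loader locally (a sketch, not part of the commit: it assumes `pip install ctransformers gradio`, enough RAM for the quantized weights, and that the chosen .gguf file exists in the repo; the inline comment above already names Q8_0 as the fallback quantization):

    import os
    # choose the Q8_0 fallback before app.py reads the env var at import time
    os.environ["LLAMA_GGUF_FILE"] = "Llama-3.2-3B-Instruct-Q8_0.gguf"
    from app import load_model  # importing app builds the Gradio Interface but does not launch it
    llm = load_model()  # first call downloads the GGUF from the Hub; later calls reuse it
    print(llm("Say hello in five words.", max_new_tokens=16, temperature=0.3))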
 
  text = re.sub(rf"{SENTINEL_OPEN}(\d+){SENTINEL_CLOSE}", unwrap, text)
  return text.replace(SENTINEL_OPEN, "").replace(SENTINEL_CLOSE, "")

+ # ---------------- Prompting (Llama 3.x chat template) ----------------
  SYSTEM = (
+     "You are an expert editor. Humanize the user's text: improve flow, vary sentence length, "
+     "split run-ons, replace stiff phrasing with natural alternatives, and preserve meaning. "
+     "Do NOT alter anything wrapped by §§KEEP_OPEN§§<id>§§KEEP_CLOSE§§ (citations, URLs, numbers, code). "
+     "Keep the requested tone and region. No em dashes—use simple punctuation."
  )

+ def build_prompt(text: str, tone: str, region: str, level: str, intensity: int) -> str:
      user = (
          f"Tone: {tone}. Region: {region} English. Reading level: {level}. "
          f"Humanization intensity: {intensity} (10 strongest).\n\n"
+         f"Rewrite this text. Keep markers intact:\n\n{text}"
      )
+     # Llama 3.x chat format
+     return (
+         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
+         f"{SYSTEM}\n"
+         "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
+         f"{user}\n"
+         "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
      )

+ def diff_ratio(a: str, b: str) -> float:
+     return difflib.SequenceMatcher(None, a, b).ratio()
+
+ def generate_once(prompt: str, temperature: float, max_new: int = 768) -> str:
+     llm = load_model()
+     out = llm(
+         prompt,
+         temperature=temperature,
+         top_p=0.95,
+         max_new_tokens=max_new,
+         stop=["<|eot_id|>"]
      )
+     return out.strip()

+ # ---------------- Main humanizer ----------------
+ def humanize_core(text: str, tone: str, region: str, level: str, intensity: int):
      protected_text, bag = protect(text)
+     prompt = build_prompt(protected_text, tone, region, level, intensity)

+     # pass 1 (conservative), pass 2 (stronger) if too similar
      draft = generate_once(prompt, temperature=0.35)
      if diff_ratio(protected_text, draft) > 0.97:
          draft = generate_once(prompt, temperature=0.9)

+     draft = draft.replace("—", "-")
      final = restore(draft, bag)

+     # ensure all protected spans survived
      for i, span in enumerate(bag):
          marker = f"{SENTINEL_OPEN}{i}{SENTINEL_CLOSE}"
          if marker in protected_text and span not in final:
              final = final.replace(marker, span)
      return final

+ # ---------------- Gradio UI (and REST at /api/predict/) ----------------
  def ui_humanize(text, tone, region, level, intensity):
      return humanize_core(text, tone, region, level, int(intensity))
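Note that protect() itself sits outside these hunks, so it never appears in the diff. Given URL_RE, CODE_RE, and the marker format that restore() and humanize_core() expect, a minimal sketch of what it plausibly does (an illustration only, not the Space's actual code; the real function presumably also covers the citations and numbers mentioned in SYSTEM):

    from typing import List, Tuple

    def protect_sketch(text: str) -> Tuple[str, List[str]]:
        bag: List[str] = []
        def stash(m: re.Match) -> str:
            # stash the span and leave a numbered sentinel in its place
            bag.append(m.group(0))
            return f"{SENTINEL_OPEN}{len(bag) - 1}{SENTINEL_CLOSE}"
        # run CODE_RE first so URLs inside backticks are stashed as one span
        for rx in (CODE_RE, URL_RE):
            text = rx.sub(stash, text)
        return text, bag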

          gr.Slider(1, 10, value=6, step=1, label="Humanization intensity"),
      ],
      outputs=gr.Textbox(label="Humanized"),
+     title="NoteCraft Humanizer (Llama-3.2-3B-Instruct)",
      description="REST: POST /api/predict/ with { data: [text,tone,region,level,intensity] }",
  ).queue()

  if __name__ == "__main__":
+     demo.launch()
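For completeness, the REST route advertised in the description string can be exercised like this (a sketch: it assumes the Space is public and that this Gradio version still serves the legacy /api/predict/ route; the URL is a placeholder):

    import requests

    resp = requests.post(
        "https://YOUR-SPACE.hf.space/api/predict/",  # placeholder Space URL
        json={"data": ["Text to humanize.", "Neutral", "US", "College", 6]},
        timeout=300,  # CPU generation on a 3B model can be slow
    )
    print(resp.json()["data"][0])  # the humanized text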