hhelesto committed
Commit ebfe80b · verified · 1 Parent(s): 64eb70f

Update app.py

Files changed (1):
  app.py +38 -13
app.py CHANGED
@@ -1,7 +1,8 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 from peft import PeftModel
 import gradio as gr
+import threading
 
 # --- Load Model & Tokenizer ---
 
@@ -30,24 +31,48 @@ model = PeftModel.from_pretrained(model, adapter_model_name)
 model = model.merge_and_unload()
 print("Model ready!")
 
-# --- Gradio Function ---
+# --- Gradio Streaming Function ---
 
-def generate_text(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_new_tokens=100)
-    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return text
-
-# --- Gradio Interface ---
-def chat_fn(message, history):
+def chat_fn(message, history, max_tokens):
     prompt = message
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_new_tokens=100)
-    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return text
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        skip_special_tokens=True
+    )
+
+    generation_kwargs = dict(
+        **inputs,
+        streamer=streamer,
+        max_new_tokens=int(max_tokens),
+        temperature=0.7,
+    )
+
+    thread = threading.Thread(
+        target=model.generate,
+        kwargs=generation_kwargs
+    )
+    thread.start()
+
+    partial_text = ""
+    for new_text in streamer:
+        partial_text += new_text
+        yield partial_text
+
+# --- Create Dropdown Component for max tokens ---
+
+dropdown = gr.Dropdown(
+    choices=["100", "200", "300"],
+    value="100",
+    label="Max New Tokens"
+)
+
+# --- Launch Gradio Chat Interface ---
 
 gr.ChatInterface(
     fn=chat_fn,
+    additional_inputs=[dropdown],
     title="LLM Finetuned Comment Generator",
     description="Chat with the model.",
 ).launch(share=True)
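
Note on the streaming change: model.generate() blocks until generation finishes, so the new chat_fn runs it on a worker thread and iterates TextIteratorStreamer on the caller's thread, yielding the accumulated text. gr.ChatInterface treats each yielded value as the full reply so far, which is why the loop yields partial_text rather than each chunk. A minimal self-contained sketch of the same pattern ("gpt2" is a stand-in checkpoint for illustration; the Space loads its own base model plus PEFT adapter):

    import threading
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Stand-in checkpoint, not the Space's model.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    def stream_reply(prompt, max_new_tokens=100):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # skip_prompt=True keeps the echoed prompt out of the stream; the
        # committed code omits it, so its replies begin with the prompt text.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # generate() blocks, so run it on a worker thread and consume the
        # decoded chunks here as they arrive.
        thread = threading.Thread(
            target=model.generate,
            kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens),
        )
        thread.start()
        partial = ""
        for chunk in streamer:
            partial += chunk
            yield partial  # the UI redraws the growing reply on each yield
        thread.join()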
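
A caveat on generation_kwargs: generate() defaults to greedy decoding, and transformers warns that temperature is ignored when do_sample is not set, so temperature=0.7 currently has no effect. If sampled output is intended, the kwargs would also need the flag:

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,   # without this, temperature is ignored
        temperature=0.7,
    )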
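
On the interface wiring: gr.ChatInterface passes each component in additional_inputs as an extra positional argument after (message, history), which is how the dropdown's value reaches the new max_tokens parameter. The choices are strings, hence the int(max_tokens) cast inside chat_fn. A numeric component would avoid the cast; one possible alternative (not what this commit uses):

    dropdown = gr.Slider(minimum=100, maximum=300, step=100, value=100, label="Max New Tokens")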