joermd committed
Commit 96853ae · verified · 1 Parent(s): d43e6e8

Update app.py

Files changed (1)
  1. app.py +93 -38
app.py CHANGED
@@ -1,68 +1,123 @@

Old version:

  import gradio as gr
- import spaces
  import transformers
  import torch

- model_id = "joermd/speedy-llama2"

- pipeline = transformers.pipeline(
-     "text-generation",
-     model=model_id,
-     model_kwargs={"torch_dtype": torch.bfloat16},
-     device_map="auto",
- )

- terminators = [
-     pipeline.tokenizer.eos_token_id,
-     pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]

- @spaces.GPU
  def respond(
-     message,
      history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
  ):
-     messages = []
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
      messages.append({"role": "user", "content": message})
-
      outputs = pipeline(
          messages,
-         max_new_tokens=256,
          eos_token_id=terminators,
      )

-     yield outputs[0]["generated_text"][-1]["content"]

- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="Kamu adalah seorang asisten yang baik", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
      ],
  )

-
  if __name__ == "__main__":
-     demo.launch()
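The removed one-liner `outputs[0]["generated_text"][-1]["content"]` leans on the pipeline's chat-format I/O: when the input is a list of role/content dicts, recent transformers versions apply the tokenizer's chat template and return the whole conversation with the model's reply appended. A minimal sketch of that round trip (a hypothetical standalone snippet, not part of this commit; assumes `pipeline` and `tokenizer` built as in app.py and a tokenizer that ships a chat template):

    # Hypothetical sketch, not part of the commit.
    chat = [{"role": "user", "content": "Halo!"}]

    # Roughly the prompt string the model actually sees; the pipeline
    # applies this chat template internally before generating.
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    out = pipeline(chat, max_new_tokens=32)
    # With chat input, generated_text is the conversation including the
    # new reply, so the assistant's text is the last message's content.
    reply = out[0]["generated_text"][-1]["content"]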
 
New version:

  import gradio as gr
  import transformers
  import torch

+ # First install required dependencies:
+ #   pip install tiktoken sentencepiece

+ def initialize_pipeline():
+     model_id = "joermd/speedy-llama2"
+     tokenizer = transformers.AutoTokenizer.from_pretrained(
+         model_id,
+         trust_remote_code=True,
+         use_fast=False,  # use the slow tokenizer to avoid tiktoken issues
+     )
+
+     model = transformers.AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+
+     pipeline = transformers.pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         device_map="auto",
+     )
+
+     return pipeline, tokenizer
+
+ # Initialize the pipeline and tokenizer once at startup
+ pipeline, tokenizer = initialize_pipeline()

+ def format_chat_prompt(messages, system_message):
+     """Format the chat history into the role/content messages the model expects."""
+     formatted_messages = []
+     if system_message:
+         formatted_messages.append({"role": "system", "content": system_message})
+
+     for msg in messages:
+         if msg[0]:  # user message
+             formatted_messages.append({"role": "user", "content": msg[0]})
+         if msg[1]:  # assistant message
+             formatted_messages.append({"role": "assistant", "content": msg[1]})
+
+     return formatted_messages

  def respond(
+     message: str,
      history: list[tuple[str, str]],
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
  ):
+     """Generate a response using the pipeline."""
+     messages = format_chat_prompt(history, system_message)
      messages.append({"role": "user", "content": message})
+
+     # Stop on the model's EOS token, plus "<|eot_id|>" when the vocab has it
+     terminators = [
+         tokenizer.eos_token_id,
+         tokenizer.convert_tokens_to_ids("<|eot_id|>") if "<|eot_id|>" in tokenizer.get_vocab() else None,
+     ]
+     terminators = [t for t in terminators if t is not None]
+
      outputs = pipeline(
          messages,
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=True,
          eos_token_id=terminators,
+         pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
      )

+     # Extract the generated reply (chat-format pipelines return the full message list)
+     try:
+         response = outputs[0]["generated_text"]
+         if isinstance(response, list) and len(response) > 0 and isinstance(response[-1], dict):
+             response = response[-1].get("content", "")
+     except (IndexError, KeyError, AttributeError):
+         response = "I apologize, but I couldn't generate a proper response."
+
+     yield response

+ # Create the Gradio interface
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
+         gr.Textbox(
+             value="Kamu adalah seorang asisten yang baik",
+             label="System message",
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=2048,
+             value=512,
+             step=1,
+             label="Max new tokens",
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=4.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature",
+         ),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
+             label="Top-p (nucleus sampling)",
          ),
      ],
+     title="Chat Assistant",
+     description="A conversational AI assistant powered by Llama-2",
  )

  if __name__ == "__main__":
+     demo.launch()
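To sanity-check the new `respond` outside Gradio, a minimal smoke test could look like the following (hypothetical, not part of the commit; the argument values are placeholders, the signature is the one defined above):

    # Hypothetical local smoke test; run after importing respond from app.py.
    for chunk in respond(
        "Halo, apa kabar?",  # Indonesian: "Hello, how are you?"
        history=[],          # no prior turns
        system_message="Kamu adalah seorang asisten yang baik",  # "You are a kind assistant"
        max_tokens=128,
        temperature=0.7,
        top_p=0.95,
    ):
        print(chunk)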