joermd committed
Commit f03201b · verified · 1 Parent(s): d64cccd

Update app.py

Files changed (1)
  1. app.py +121 -2
app.py CHANGED
@@ -1,4 +1,123 @@
  import gradio as gr
 
- gr.load("models/joermd/speedy-llama2").launch()
-
+ import transformers
+ import torch
+
+ # First install required dependencies
+ # pip install tiktoken sentencepiece
+
+ def initialize_pipeline():
+     model_id = "joermd/speedy-llama2"
+     tokenizer = transformers.AutoTokenizer.from_pretrained(
+         model_id,
+         trust_remote_code=True,
+         use_fast=False  # Use slow tokenizer to avoid tiktoken issues
+     )
+
+     model = transformers.AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+
+     pipeline = transformers.pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         device_map="auto"
+     )
+
+     return pipeline, tokenizer
+
+ # Initialize pipeline and tokenizer
+ pipeline, tokenizer = initialize_pipeline()
+
+ def format_chat_prompt(messages, system_message):
+     """Format the chat messages into a prompt the model can understand"""
+     formatted_messages = []
+     if system_message:
+         formatted_messages.append({"role": "system", "content": system_message})
+
+     for msg in messages:
+         if msg[0]:  # User message
+             formatted_messages.append({"role": "user", "content": msg[0]})
+         if msg[1]:  # Assistant message
+             formatted_messages.append({"role": "assistant", "content": msg[1]})
+
+     return formatted_messages
+
+ def respond(
+     message: str,
+     history: list[tuple[str, str]],
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+ ):
+     """Generate response using the pipeline"""
+     messages = format_chat_prompt(history, system_message)
+     messages.append({"role": "user", "content": message})
+
+     # Define terminators
+     terminators = [
+         tokenizer.eos_token_id,
+         tokenizer.convert_tokens_to_ids("<|eot_id|>") if "<|eot_id|>" in tokenizer.get_vocab() else None
+     ]
+     terminators = [t for t in terminators if t is not None]
+
+     outputs = pipeline(
+         messages,
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=True,
+         eos_token_id=terminators,
+         pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
+     )
+
+     # Extract the generated response
+     try:
+         response = outputs[0]["generated_text"]
+         if isinstance(response, list) and len(response) > 0 and isinstance(response[-1], dict):
+             response = response[-1].get("content", "")
+     except (IndexError, KeyError, AttributeError):
+         response = "I apologize, but I couldn't generate a proper response."
+
+     yield response
+
+ # Create the Gradio interface
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(
+ value="Kamu adalah seorang asisten yang baik",
+             label="System message"
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=2048,
+             value=512,
+             step=1,
+             label="Max new tokens"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=4.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)"
+         ),
+     ],
+     title="Chat Assistant",
+     description="A conversational AI assistant powered by Llama-2"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()