Update app.py
app.py CHANGED
@@ -5,13 +5,12 @@ import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from typing import List, Dict, Optional, Tuple

DESCRIPTION = """
-#
+# LlamaEXP
"""

-css =
+css ='''
h1 {
  text-align: center;
  display: block;
@@ -31,76 +30,37 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-model_id = "
+model_id = "prithivMLmods/Llama-Express.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
-model.config.sliding_window = 4096
model.eval()

-# Set the pad token ID if it's not already set
-if tokenizer.pad_token_id is None:
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
-# Define roles for the chat
-class Role:
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-# Default system message
-default_system = "You are a helpful assistant."
-
-def clear_session() -> List:
-    return "", []
-
-def modify_system_session(system: str) -> Tuple[str, str, List]:
-    if system is None or len(system) == 0:
-        system = default_system
-    return system, system, []
-
-def history_to_messages(history: List, system: str) -> List[Dict]:
-    messages = [{'role': Role.SYSTEM, 'content': system}]
-    for h in history:
-        messages.append({'role': Role.USER, 'content': h[0]})
-        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
-    return messages

@spaces.GPU(duration=120)
def generate(
-
-
-    system: str,
+    message: str,
+    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
-
-    query = ''
-    if history is None:
-        history = []
-
-    # Convert history to messages
-    messages = history_to_messages(history, system)
-    messages.append({'role': Role.USER, 'content': query})
+    conversation = [*chat_history, {"role": "user", "content": message}]

-
-
-
-
-
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)

-    # Set up the streamer for real-time text generation
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
-
+        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
@@ -109,12 +69,10 @@ def generate(
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
-        pad_token_id=tokenizer.pad_token_id,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

-    # Stream the output tokens
    outputs = []
    for text in streamer:
        outputs.append(text)
@@ -124,7 +82,6 @@ def generate(
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
-        gr.Textbox(label="System Message", value=default_system, lines=2),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
@@ -163,12 +120,14 @@ demo = gr.ChatInterface(
    ],
    stop_btn=None,
    examples=[
-        ["Write a Python function to reverses a string if it's length is a multiple of 4."],
-        ["
-        ["
+        ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
+        ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
+        ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
        ["What happens when the sun goes down?"],
    ],
+    cache_examp
    cache_examples=False,
+    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
@@ -176,4 +135,4 @@ demo = gr.ChatInterface(


if __name__ == "__main__":
-    demo.queue(max_size=20).launch(
+    demo.queue(max_size=20).launch()
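
A minimal standalone sketch of the streaming-generation pattern the updated generate() uses (messages-format history, tokenizer.apply_chat_template, and a TextIteratorStreamer fed by a background thread), pulled out of the Gradio app so it can be run on its own. The sample history, prompt, and generation settings here are illustrative and not part of the commit.

import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "prithivMLmods/Llama-Express.1"  # model used by this Space
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
model.eval()

# Messages-format history, as gr.ChatInterface(type="messages") would pass it in.
chat_history = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
message = "What happens when the sun goes down?"
conversation = [*chat_history, {"role": "user", "content": message}]

# Render the chat template, then generate in a background thread while streaming tokens.
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
)
Thread(target=model.generate, kwargs=generate_kwargs).start()

for chunk in streamer:
    print(chunk, end="", flush=True)  # partial text arrives as it is generated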
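
The commit also switches gr.ChatInterface to type="messages", so the history reaches generate() as a list of {"role": ..., "content": ...} dicts and each string the function yields updates the in-progress assistant reply. The sketch below only illustrates that wiring, with a toy echo generator standing in for the model call; it is an illustrative assumption, not code from this Space.

import gradio as gr
from collections.abc import Iterator

def generate(message: str, chat_history: list[dict], max_new_tokens: int = 1024) -> Iterator[str]:
    # chat_history arrives as [{"role": "user", "content": ...}, {"role": "assistant", ...}, ...];
    # each yielded string replaces the pending assistant message, which produces the streaming effect.
    reply = f"You said: {message} (max_new_tokens={max_new_tokens})"
    partial = ""
    for word in reply.split():
        partial += word + " "
        yield partial

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=[gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024)],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()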