prithivMLmods committed · verified
Commit 7d0f94b · Parent(s): f21dff8

Update app.py

Files changed (1): app.py (+18 -59)
app.py CHANGED
@@ -5,13 +5,12 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from typing import List, Dict, Optional, Tuple

 DESCRIPTION = """
-# QwQ Distill
+# LlamaEXP
 """

-css = '''
+css ='''
 h1 {
   text-align: center;
   display: block;
@@ -31,76 +30,37 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+model_id = "prithivMLmods/Llama-Express.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16,
 )
-model.config.sliding_window = 4096
 model.eval()

-# Set the pad token ID if it's not already set
-if tokenizer.pad_token_id is None:
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
-# Define roles for the chat
-class Role:
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-# Default system message
-default_system = "You are a helpful assistant."
-
-def clear_session() -> List:
-    return "", []
-
-def modify_system_session(system: str) -> Tuple[str, str, List]:
-    if system is None or len(system) == 0:
-        system = default_system
-    return system, system, []
-
-def history_to_messages(history: List, system: str) -> List[Dict]:
-    messages = [{'role': Role.SYSTEM, 'content': system}]
-    for h in history:
-        messages.append({'role': Role.USER, 'content': h[0]})
-        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
-    return messages

 @spaces.GPU(duration=120)
 def generate(
-    query: Optional[str],
-    history: Optional[List],
-    system: str,
+    message: str,
+    chat_history: list[dict],
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    if query is None:
-        query = ''
-    if history is None:
-        history = []
-
-    # Convert history to messages
-    messages = history_to_messages(history, system)
-    messages.append({'role': Role.USER, 'content': query})
+    conversation = [*chat_history, {"role": "user", "content": message}]

-    # Apply chat template and tokenize
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)

-    # Set up the streamer for real-time text generation
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        **model_inputs,
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
@@ -109,12 +69,10 @@ def generate(
         temperature=temperature,
         num_beams=1,
         repetition_penalty=repetition_penalty,
-        pad_token_id=tokenizer.pad_token_id,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()

-    # Stream the output tokens
     outputs = []
     for text in streamer:
         outputs.append(text)
@@ -124,7 +82,6 @@
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System Message", value=default_system, lines=2),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -163,12 +120,14 @@ demo = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["Write a Python function to reverses a string if it's length is a multiple of 4."],
-        ["What is the volume of a pyramid with a rectangular base?"],
-        ["Explain the difference between List comprehension and Lambda in Python."],
+        ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
+        ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
+        ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
         ["What happens when the sun goes down?"],
     ],
+    cache_examp
     cache_examples=False,
+    type="messages",
     description=DESCRIPTION,
     css=css,
     fill_height=True,
@@ -176,4 +135,4 @@ demo = gr.ChatInterface(


 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(share=True)
+    demo.queue(max_size=20).launch()
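The core of the refactor above swaps Gradio's legacy tuple-pair chat history (and the hand-rolled Role/history_to_messages plumbing) for the OpenAI-style message dicts that gr.ChatInterface(type="messages") passes straight into generate. A minimal sketch, not part of the commit, contrasting the two history formats (both helper names are illustrative):

# Sketch: old vs. new history formats handed to generate().

def build_conversation_old(history: list[tuple[str, str]], system: str, query: str) -> list[dict]:
    # Pre-commit: Gradio sent history as [(user_msg, assistant_msg), ...] pairs,
    # which history_to_messages() had to flatten into role/content dicts.
    messages = [{"role": "system", "content": system}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": query})
    return messages

def build_conversation_new(chat_history: list[dict], message: str) -> list[dict]:
    # Post-commit: with type="messages", chat_history already arrives as
    # [{"role": ..., "content": ...}, ...], so a single splice is enough.
    return [*chat_history, {"role": "user", "content": message}]

pairs = [("Hi", "Hello!")]
msgs = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
# Both paths produce the same conversation (minus the dropped system prompt):
assert build_conversation_old(pairs, "You are a helpful assistant.", "How are you?")[1:] == \
       build_conversation_new(msgs, "How are you?")

Because Gradio now delivers ready-made message dicts, the commit can delete clear_session, modify_system_session, and the System Message textbox in one sweep.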
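The new length guard in generate keeps the prompt within MAX_INPUT_TOKEN_LENGTH by discarding the oldest tokens. A standalone sketch of the same left-truncation (the constant mirrors the Space's default of "4096"; the helper name is illustrative, not from app.py):

import torch

MAX_INPUT_TOKEN_LENGTH = 4096  # mirrors os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")

def truncate_left(input_ids: torch.Tensor, max_len: int = MAX_INPUT_TOKEN_LENGTH) -> torch.Tensor:
    # Keep only the most recent max_len tokens: the newest chat turns sit at the
    # end of the sequence, so keeping the tail drops the oldest context first.
    if input_ids.shape[1] > max_len:
        input_ids = input_ids[:, -max_len:]
    return input_ids

ids = torch.arange(5000).unsqueeze(0)          # fake (1, 5000) batch of token ids
assert truncate_left(ids).shape == (1, 4096)
assert truncate_left(ids)[0, 0].item() == 904  # the oldest 904 tokens were dropped

Since the trim runs after apply_chat_template, it can cut through the middle of an early turn; for a lightweight demo Space with a 4k budget that is a reasonable trade-off.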