suvadityamuk committed
Commit 272fd5b
1 Parent(s): bd59709

chore: made optims

Signed-off-by: Suvaditya Mukherjee <[email protected]>

Files changed (2)
  1. app.py +39 -5
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,8 +1,11 @@
 import os
 import re
 import json
+import time
+import wandb
 import torch
 import spaces
+import psutil
 import pymupdf
 import gradio as gr
 from qdrant_client import QdrantClient
@@ -42,9 +45,16 @@ def generate_answer(chat_history):
         add_generation_prompt=True,
     )
     tool_prompt = tool_prompt.to(model.device)
-    out = model.generate(**tool_prompt, max_new_tokens=512)
+    out = model.generate(
+        **tool_prompt,
+        max_new_tokens=512,
+        do_sample=True,
+        top_p=0.95,
+        num_beams=4
+    )
     generated_text = out[0, tool_prompt['input_ids'].shape[1]:]
     generated_text = tokenizer.decode(generated_text)
+    torch.cuda.empty_cache()
     return generated_text
 
 def parse_tool_request(tool_call, top_k=5):
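For context on the hunk above: passing do_sample=True and top_p together with num_beams=4 makes transformers run beam-search multinomial sampling instead of plain greedy decoding, and the prompt tokens are sliced off before decoding. A minimal standalone sketch of the same decoding settings, assuming any Hugging Face causal LM; the small model id and the prompt below are illustrative only (the Space itself loads Qwen/Qwen2.5-3B-Instruct):

    # Sketch: the same decoding settings on a small causal LM (model id is illustrative).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    lm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16)

    inputs = tok("What does the resume say about Python experience?", return_tensors="pt").to(lm.device)
    out = lm.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,   # sample tokens instead of taking the argmax
        top_p=0.95,       # nucleus sampling over the top 95% of probability mass
        num_beams=4,      # combined with do_sample: beam-search multinomial sampling
    )
    # Keep only the newly generated tokens, mirroring the slicing in generate_answer()
    print(tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))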
@@ -97,7 +107,7 @@ if __name__ == "__main__":
     fulltext = merge_strings_with_prefix(fulltext)
 
     # Embed the sentences
-    client = QdrantClient(":memory:")
+    client = QdrantClient(":memory:", optimize_for_ram_usage=True)
 
     client.set_model("sentence-transformers/all-MiniLM-L6-v2")
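The retrieval that backs the tool call runs through qdrant-client's fastembed integration (set_model plus add/query) on an in-memory store. A short sketch of that flow, assuming the same sentence-transformers model; the collection name and documents below are placeholders, not the Space's real data:

    # Sketch: in-memory Qdrant retrieval via the fastembed integration (placeholder data).
    from qdrant_client import QdrantClient

    client = QdrantClient(":memory:")
    client.set_model("sentence-transformers/all-MiniLM-L6-v2")

    client.add(
        collection_name="resume",
        documents=["Built a RAG demo with Qdrant.", "Five years of PyTorch on GPUs."],
    )

    hits = client.query(collection_name="resume", query_text="GPU experience", limit=5)
    for hit in hits:
        print(hit.score, hit.document)  # top-k chunks handed back to the LLM as tool output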
 
@@ -115,7 +125,7 @@ if __name__ == "__main__":
         parallel=0,
     )
 
-    # FOR QWEN, THIS IS WORKING
+    wandb.init(project="resume-rag", name="zerogpu-run")
 
     model_name = "Qwen/Qwen2.5-3B-Instruct"
@@ -128,6 +138,7 @@ if __name__ == "__main__":
         }
         chat_history.append(current_message)
 
+        start_time = time.time()
         # Generate LLM answer
        generated_text = generate_answer(chat_history)
@@ -137,7 +148,6 @@ if __name__ == "__main__":
 
         # If tool call was requested
         if query_results is not None and tool_query is not None:
-            print("Inside")
             # Update chat history with result of tool call
             chat_history = update_chat_history(
                 chat_history, tool_query, query_results
@@ -145,13 +155,37 @@ if __name__ == "__main__":
             # Generate result from the
             generated_text = generate_answer(chat_history)
 
+        metrics = {
+            "conversation": {
+                "turn": len(chat_history) // 2,
+                "history": chat_history,
+                "current_question": message,
+                "current_answer": generated_text[:-10],
+                "tool_query": tool_query,
+                "rag_results": query_results
+            },
+            "performance": {
+                "response_time": time.time() - start_time,
+                "gpu_memory_used": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
+                "cpu_memory": psutil.Process().memory_info().rss,
+                "gpu_utilization": torch.cuda.utilization() if torch.cuda.is_available() else 0
+            }
+        }
+        wandb.log(metrics)
+
+        wandb.finish()
+
         return generated_text[:-10]
 
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         device_map="auto",
-        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_quant_type="nf4"
+        )
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
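The new per-turn logging combines wandb with psutil and the torch.cuda counters. A self-contained sketch of that pattern, assuming wandb credentials are already configured; the project, run, and metric names here are illustrative:

    # Sketch: per-turn latency and memory logging with wandb, psutil and torch.cuda.
    # Assumes WANDB_API_KEY (or wandb.login()) is set up; names are illustrative.
    import time
    import psutil
    import torch
    import wandb

    wandb.init(project="resume-rag-example", name="example-run")

    start_time = time.time()
    # ... run generation for one chat turn here ...
    wandb.log({
        "performance": {
            "response_time": time.time() - start_time,
            "cpu_memory": psutil.Process().memory_info().rss,  # resident set size, in bytes
            "gpu_memory_used": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
            "gpu_utilization": torch.cuda.utilization() if torch.cuda.is_available() else 0,
        }
    })
    wandb.finish()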
 
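The model is now loaded with 4-bit NF4 weights through bitsandbytes instead of 8-bit. A minimal load sketch with the same quantization settings, assuming a CUDA device and the bitsandbytes and accelerate packages are installed:

    # Sketch: loading the chat model with 4-bit NF4 quantization (requires a CUDA GPU).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # store weights in 4-bit
        bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.float16,  # dtype used for matmuls at runtime
    )

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-3B-Instruct",
        device_map="auto",
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")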
 
requirements.txt CHANGED
@@ -9,4 +9,6 @@ torchvision
 torchaudio
 accelerate
 bitsandbytes
-optimum
+optimum
+wandb
+psutil