Rajesh0279 committed
Commit 679b323 · verified · 1 parent: 6549e3d

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +56 -0
src/streamlit_app.py CHANGED
@@ -13,9 +13,13 @@ import numpy as np
 import re
 import plotly.express as px
 import plotly.graph_objects as go
+import torch
 from typing import Optional, Tuple, List, Dict
 from run3 import estimate_training_time_and_cost,get_gpu_teraflops,get_gpu_cost_per_tflop_hour
 from utils import get_all_models_from_database
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+
 
 # ADD THIS BLOCK HERE (Line 16)
 # Language configuration
@@ -116,6 +120,58 @@ TRANSLATIONS = {
 }
 }
 
+@st.cache_resource
+def load_llama3_pipeline():
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-3.1-8B-Instruct",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto" if torch.cuda.is_available() else None
+    )
+    return tokenizer, model
+
+tokenizer, model = load_llama3_pipeline()
+
+st.title("🧠 Chat with Llama 3.1 8B (Instruct)")
+
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = [
+        {"role": "system", "content": "You are a helpful, concise assistant."}
+    ]
+
+user_input = st.text_input("You:", key="user_input")
+
+if user_input:
+    st.session_state.chat_history.append({"role": "user", "content": user_input})
+
+    # Format messages into prompt
+    messages = st.session_state.chat_history
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    with st.spinner("Llama 3 is thinking..."):
+        output = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+    response = decoded.split(prompt)[-1].strip()
+
+    st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+# Display conversation
+for msg in st.session_state.chat_history:
+    if msg["role"] == "user":
+        st.markdown(f"**You:** {msg['content']}")
+    elif msg["role"] == "assistant":
+        st.markdown(f"**AI:** {msg['content']}")
+
 def get_text(key, lang='en'):
     """Get translated text for given key and language"""
     return TRANSLATIONS.get(lang, TRANSLATIONS['en']).get(key, key)
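A note on the new loader: meta-llama/Llama-3.1-8B-Instruct is a gated model, and its float16 weights alone take roughly 16 GB of GPU memory (the float32 CPU fallback, roughly twice that), so load_llama3_pipeline will only succeed on fairly large hardware. A minimal sketch of one common mitigation, 4-bit quantization; this is not part of the commit and assumes the bitsandbytes package is installed:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Sketch (not in this commit): load the 8B model in 4 bits so the
    # weights fit in roughly 6 GB of GPU memory instead of ~16 GB.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct",
        quantization_config=bnb_config,
        device_map="auto",
    )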
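One caveat in the generation block: decoded.split(prompt)[-1] is brittle, because prompt still contains the special tokens inserted by apply_chat_template while decoded was produced with skip_special_tokens=True, so the split can fail to match and hand back the entire transcript. A more robust sketch, assuming the same tokenizer, inputs, and output as above, decodes only the newly generated tokens:

    # Sketch (not in this commit): model.generate returns prompt + continuation,
    # so slicing by the prompt's token count isolates the new tokens exactly.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        output[0][prompt_length:],  # generated continuation only
        skip_special_tokens=True,
    ).strip()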
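Finally, TextStreamer is imported but never used; it also prints to stdout, which a Streamlit page cannot show. If streaming was the intent, here is a sketch using transformers' TextIteratorStreamer together with st.write_stream, neither of which appears in this commit:

    from threading import Thread
    from transformers import TextIteratorStreamer

    # Sketch (not in this commit): run generation in a background thread and
    # let the main thread render tokens in the page as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer),
    )
    thread.start()
    response = st.write_stream(streamer)  # returns the concatenated text
    thread.join()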