Update src/streamlit_app.py

src/streamlit_app.py CHANGED (+56 -0)

@@ -13,9 +13,13 @@ import numpy as np
 import re
 import plotly.express as px
 import plotly.graph_objects as go
+import torch
 from typing import Optional, Tuple, List, Dict
 from run3 import estimate_training_time_and_cost,get_gpu_teraflops,get_gpu_cost_per_tflop_hour
 from utils import get_all_models_from_database
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
+
+
 
 # ADD THIS BLOCK HERE (Line 16)
 # Language configuration
@@ -116,6 +120,58 @@ TRANSLATIONS = {
 }
 }
 
+@st.cache_resource
+def load_llama3_pipeline():
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-3.1-8B-Instruct",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto" if torch.cuda.is_available() else None
+    )
+    return tokenizer, model
+
+tokenizer, model = load_llama3_pipeline()
+
+st.title("🧠 Chat with Llama 3.1 8B (Instruct)")
+
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = [
+        {"role": "system", "content": "You are a helpful, concise assistant."}
+    ]
+
+user_input = st.text_input("You:", key="user_input")
+
+if user_input:
+    st.session_state.chat_history.append({"role": "user", "content": user_input})
+
+    # Format messages into prompt
+    messages = st.session_state.chat_history
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    with st.spinner("Llama 3 is thinking..."):
+        output = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
+    response = decoded.split(prompt)[-1].strip()
+
+    st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+# Display conversation
+for msg in st.session_state.chat_history:
+    if msg["role"] == "user":
+        st.markdown(f"**You:** {msg['content']}")
+    elif msg["role"] == "assistant":
+        st.markdown(f"**AI:** {msg['content']}")
+
 def get_text(key, lang='en'):
     """Get translated text for given key and language"""
     return TRANSLATIONS.get(lang, TRANSLATIONS['en']).get(key, key)
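
The two new imports make torch and transformers runtime dependencies of the Space, so both would need to be declared in its requirements file (not part of this commit). meta-llama/Llama-3.1-8B-Instruct is also a gated checkpoint on the Hugging Face Hub, so from_pretrained only succeeds with an authenticated token. A minimal sketch, assuming the Space exposes an HF_TOKEN secret (the secret name is an assumption, not shown in this commit):

import os
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical: read a token from an HF_TOKEN secret and pass it to from_pretrained.
hf_token = os.environ.get("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", token=hf_token)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", token=hf_token)

At float16 the 8B parameters alone occupy roughly 8e9 × 2 bytes ≈ 16 GB, so the device_map="auto" branch assumes a GPU with at least that much memory; the float32 CPU fallback roughly doubles the footprint.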
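
One behavioral note on the decoding step: prompt comes from apply_chat_template with its special tokens intact, while decoded is produced with skip_special_tokens=True, so decoded generally does not contain prompt verbatim and decoded.split(prompt)[-1] falls back to the whole decoded text, system and user turns included. A sketch of a more robust alternative, assuming the same inputs, output, and tokenizer as in the diff, is to slice off the prompt tokens before decoding:

# Decode only the tokens generated after the prompt.
prompt_len = inputs["input_ids"].shape[-1]
new_tokens = output[0][prompt_len:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()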
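
TextStreamer is imported but never used, and it writes tokens to stdout rather than to the Streamlit page. If streaming into the UI is the goal, the usual pattern is TextIteratorStreamer plus a background thread, with st.write_stream (available in recent Streamlit releases) rendering the chunks. A sketch under those assumptions, reusing tokenizer, model, inputs, and the already-imported st from the app:

from threading import Thread
from transformers import TextIteratorStreamer

# Stream only the newly generated text, skipping the prompt and special tokens.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512)).start()
response = st.write_stream(streamer)  # renders chunks as they arrive and returns the full string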