speedy-llm / app.py
joermd's picture
Update app.py
a526abd verified
import numpy as np
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# التحقق من توفر GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
@st.cache_resource
def load_model():
"""
تحميل النموذج والمُرمِّز مع التخزين المؤقت
"""
model_name = "joermd/speedy-llama2"
# تهيئة الـtokenizer أولاً
tokenizer = AutoTokenizer.from_pretrained(model_name)
# تهيئة النموذج مع إعدادات مناسبة
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
low_cpu_mem_usage=True,
device_map="auto"
)
return model, tokenizer
def reset_conversation():
'''
إعادة تعيين المحادثة
'''
st.session_state.conversation = []
st.session_state.messages = []
return None
def format_prompt(prompt):
"""
تنسيق المدخل بالطريقة المناسبة لنموذج LLaMA
"""
return f"<s>[INST] {prompt} [/INST]"
def generate_response(model, tokenizer, prompt, temperature=0.7, max_length=500):
"""
توليد استجابة من النموذج
"""
try:
# تنسيق المدخل
formatted_prompt = format_prompt(prompt)
# تحضير المدخلات
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
# توليد النص
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=max_length,
temperature=temperature,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
repetition_penalty=1.2 # لتجنب التكرار
)
# فك ترميز النص المولد
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# إزالة النص الأصلي من الاستجابة
response = response.replace(formatted_prompt, "").strip()
return response
except Exception as e:
return f"حدث خطأ أثناء توليد الاستجابة: {str(e)}"
# تهيئة Streamlit
st.title("LLaMA-2 Chat 🦙")
# إضافة أزرار التحكم في الشريط الجانبي
with st.sidebar:
st.header("إعدادات")
temperature = st.slider("درجة الإبداعية", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
help="قيمة أعلى تعني إجابات أكثر إبداعية وتنوعاً")
max_tokens = st.slider("الحد الأقصى للكلمات", min_value=50, max_value=1000, value=500, step=50,
help="الحد الأقصى لطول الإجابة")
if st.button("مسح المحادثة"):
reset_conversation()
st.markdown("---")
st.markdown("""
### معلومات النموذج
- **النموذج:** Speedy LLaMA-2
- **الجهاز:** {}
""".format("GPU ⚡" if device == "cuda" else "CPU 💻"))
# تحميل النموذج
try:
with st.spinner("جاري تحميل النموذج... قد يستغرق هذا بضع دقائق..."):
model, tokenizer = load_model()
st.sidebar.success("تم تحميل النموذج بنجاح! 🎉")
except Exception as e:
st.error(f"حدث خطأ أثناء تحميل النموذج: {str(e)}")
st.error("""
تأكد من تثبيت جميع المكتبات المطلوبة:
```bash
pip install transformers torch accelerate streamlit
```
""")
st.stop()
# تهيئة سجل المحادثة
if "messages" not in st.session_state:
st.session_state.messages = []
# عرض المحادثة السابقة
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.write(message["content"])
# معالجة إدخال المستخدم
if prompt := st.chat_input("اكتب رسالتك هنا..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)
with st.chat_message("assistant"):
with st.spinner("جاري التفكير..."):
response = generate_response(
model=model,
tokenizer=tokenizer,
prompt=prompt,
temperature=temperature,
max_length=max_tokens
)
st.write(response)
st.session_state.messages.append({"role": "assistant", "content": response})