import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import gradio as gr
import os
import json
import gc
from huggingface_hub import hf_hub_download
import shutil
import tempfile
# Free up memory
gc.collect()

print("Setting up model loading...")

# Create a temporary directory for model modifications
temp_dir = tempfile.mkdtemp()
print(f"Created temporary directory: {temp_dir}")

# Your model name
model_name = "unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit"
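
# Note: this checkpoint appears to ship weights that are already
# bitsandbytes-quantized to 4 bits. Deleting "quantization_config" from
# config.json (below) only changes how transformers interprets the checkpoint;
# if loading still fails, the packed 4-bit weight tensors themselves are the
# likely cause, and starting from a non-quantized base repo may be simpler.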
try:
    # Download the config.json file
    print("Downloading configuration file...")
    config_path = hf_hub_download(repo_id=model_name, filename="config.json")

    # Load and modify the config to remove quantization
    print("Modifying configuration...")
    with open(config_path, "r") as file:
        config_dict = json.load(file)

    # Remove any quantization configs
    if "quantization_config" in config_dict:
        del config_dict["quantization_config"]

    # Save the modified config to the temp directory
    modified_config_path = os.path.join(temp_dir, "config.json")
    with open(modified_config_path, "w") as file:
        json.dump(config_dict, file)
    print("Modified configuration saved")

    # Now try to load with the modified config
    print("Loading model with modified configuration...")
    config = AutoConfig.from_pretrained(temp_dir)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    print("Model loaded successfully")
except Exception as e:
    print(f"Error during custom loading: {e}")

    # If the first approach fails, try a direct approach with explicit params
    print("Attempting alternative loading method...")
    try:
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            load_in_4bit=False,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        print("Model loaded with alternative method")
    except Exception as e2:
        print(f"Error with alternative loading: {e2}")
        raise RuntimeError("Failed to load model in any format") from e2
finally:
    # Clean up the temp directory whether loading succeeded or not
    shutil.rmtree(temp_dir)
    print("Cleaned up temporary directory")

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
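
# Llama tokenizers often ship without a pad token, which can trigger warnings
# or errors during generation; falling back to EOS is a common workaround.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token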

# Function to generate response
def generate_response(message, history):
    # Format history for the model
    prompt = ""
    if history:
        for user_msg, assistant_msg in history:
            prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    prompt += f"User: {message}\nAssistant: "

    print("Tokenizing input...")
    inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)

    # Free up memory before generation
    gc.collect()

    print("Generating response...")
    with torch.no_grad():  # Disable gradient calculation to save memory
        outputs = base_model.generate(
            **inputs,
            max_new_tokens=256,  # Reduced from 300 to conserve memory
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    print("Decoding response...")
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the assistant's response
    if "Assistant: " in response:
        response = response.split("Assistant: ")[-1]
    return response
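
# Alternative prompt construction (a sketch, assuming a recent transformers
# version): Llama 3.2 Instruct models are trained on a specific chat format, so
# the tokenizer's built-in chat template usually yields better results than the
# hand-rolled "User:/Assistant:" prompt above.
def generate_response_with_template(message, history):
    # Rebuild the conversation as role/content messages
    messages = []
    for user_msg, assistant_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # apply_chat_template inserts the model's expected special tokens and the
    # generation prompt for us
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(base_model.device)

    with torch.no_grad():
        outputs = base_model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
        )

    # Decode only the tokens generated after the prompt
    return tokenizer.decode(outputs[0, input_ids.shape[-1]:], skip_special_tokens=True)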

# Launch Gradio UI with memory-efficient settings
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("### 🦙 Chat with Your Fine-tuned LLaMA 3.2 3B")
    # Note: retry_btn/clear_btn were removed from ChatInterface in Gradio 5;
    # if this raises a TypeError, pin gradio<5 or drop those two kwargs.
    chatbot = gr.ChatInterface(
        generate_response,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
        submit_btn="Send",
        retry_btn="Retry",
        clear_btn="Clear",
    )

print("Launching interface...")
demo.launch(share=False, show_api=False)  # Disable sharing and the API to save resources
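
# Optional tweak (an assumption, not part of the original flow): on Spaces it
# is common to enable the request queue so long generations don't hit
# connection timeouts, e.g. call demo.queue(max_size=8) before demo.launch().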