import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.modeling_utils import PreTrainedModel
from huggingface_hub import create_repo, list_models
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import torch
from torch.nn.utils import prune
from tqdm import tqdm
import subprocess
import logging
import sys
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Ensure sentencepiece is installed
try:
    import sentencepiece
except ImportError:
    subprocess.check_call(['pip', 'install', 'sentencepiece'])
# Function to fetch open-weight LLM models
def fetch_open_weight_models():
    try:
        # Fetch a manageable subset of models for the dropdowns; listing the entire Hub
        # would return far too many entries to be usable.
        models = list_models(sort="downloads", direction=-1, limit=100)
        return list(models)
    except Exception as e:
        logging.error(f"Error fetching models: {e}")
        return []
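# list_models() yields huggingface_hub.ModelInfo objects; only their modelId field is
# used below to populate the dropdowns.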
# Custom function to retrieve just names from models list
def get_model_names():
    models = fetch_open_weight_models()
    model_names = [model.modelId for model in models if model.modelId is not None]
    return model_names
# Full merge-kit Pruning Function
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model using random unstructured pruning (the "merge-kit" option in the UI).

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.
        progress (gr.Progress): The progress object for visual feedback.

    Returns:
        PreTrainedModel: The pruned model.
    """
    total_params = sum(p.numel() for p in model.parameters())
    # The pruning fraction is the complement of the keep ratio: keeping, say, 40% of
    # the parameters means zeroing out the remaining 60% of each layer's weights.
    amount = 1 - (target_num_parameters / total_params)
    try:
        modules = list(model.named_modules())
        # Apply a random unstructured pruning mask to every Linear/Conv2d layer.
        for i, (name, module) in enumerate(tqdm(modules, desc="Pruning", file=sys.stdout)):
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.random_unstructured(module, name="weight", amount=amount)
            progress(0.5 * (i + 1) / len(modules))  # First half of the progress bar
        # Make the pruning permanent. prune.remove() folds the mask into the weight
        # tensor, so the zeros stay, but tensor shapes (and the parameter count) do not shrink.
        for i, (name, module) in enumerate(modules):
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.remove(module, name="weight")
            progress(0.5 + 0.5 * (i + 1) / len(modules))  # Second half of the progress bar
        return model
    except Exception as e:
        logging.error(f"Error during pruning: {e}")
        raise
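# Example (hypothetical numbers): pruning a 124M-parameter model with a 50% target calls
# merge_kit_prune(model, 62_000_000, progress) and zeroes roughly half of the weights in
# every Linear/Conv2d layer; the checkpoint size itself is not reduced.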
# Function to prune a model
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name=None, progress=gr.Progress(track_tqdm=True)):
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        log_messages.append('Model and tokenizer loaded successfully.')
        logging.info('Model and tokenizer loaded successfully.')
        total_params = sum(p.numel() for p in llm_model.parameters())
        # target_size is a percentage, so a value of 50 means "keep roughly half of the parameters".
        target_num_parameters = int(total_params * (target_size / 100))
        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        log_messages.append('Model pruned successfully.')
        logging.info('Model pruned successfully.')
        # Save the pruned model and tokenizer to the Hugging Face Hub
        create_repo(repo_name, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_name, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_name, token=hf_write_token)
        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
        # Create a visualization. Because unstructured pruning only zeroes weights, numel()
        # of the pruned model equals the original count, so the two bars mainly act as a sanity check.
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(['Original', 'Pruned'], [total_params, sum(p.numel() for p in pruned_model.parameters())])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        comparison_image = Image.open(buf).convert('RGB')
        return f"Pruned model saved to Hugging Face Hub in repository {repo_name}", comparison_image, '\n'.join(log_messages)
    except Exception as e:
        error_message = f"Detailed error: {repr(e)}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, '\n'.join(log_messages)
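# prune_model returns a (status message, comparison image, log text) triple; the Gradio
# wiring below surfaces the first two.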
# Define function to generate text
def generate_text(text, repo_name, hf_write_token):
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_name, token=hf_write_token)
        model = AutoModelForCausalLM.from_pretrained(repo_name, token=hf_write_token)
        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
        generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
        return generated_text
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        return f"Error: {repr(e)}"
# Function to create a Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")
        # Fetch available model names
        model_names = get_model_names()
        # Input components
        llm_model_name = gr.Dropdown(label="Choose a Large Language Model", choices=model_names, interactive=True)
        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruned_func_choice = gr.Radio(label="Pruning Function", choices=["merge-kit"], value="merge-kit", interactive=True)
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)

        # Gradio injects a Progress instance through the default argument below,
        # so no separate progress component is created here.
        def prune_model_with_progress(llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, progress=gr.Progress(track_tqdm=True)):
            if pruned_func_choice == "merge-kit":
                status, image, _logs = prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name, progress)
                return status, image
            return f"Pruning function '{pruned_func_choice}' not implemented.", None

        prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice], outputs=[pruning_status, visualization])

        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")
        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)
    return demo
# Create and launch the Gradio interface
demo = create_interface()
demo.launch()
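# Typical usage on a Space: pick a model from the dropdown (e.g. "gpt2"), set the target
# size slider, enter a Hugging Face write token and a destination repo name, click
# "Prune Model", then try the pruned checkpoint with the "Generate Text" button.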