import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.modeling_utils import PreTrainedModel
from huggingface_hub import create_repo, list_models
from tqdm import tqdm
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import torch
from torch.nn.utils import prune
import subprocess
import logging
import sys
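
# Assumed runtime dependencies (not pinned here): gradio, transformers, torch,
# huggingface_hub, matplotlib, Pillow, tqdm, and sentencepiece.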

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Ensure sentencepiece is installed
try:
    import sentencepiece
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])

# Function to fetch open-weight LLM models
def fetch_open_weight_models():
    try:
        # Listing every model on the Hub is impractical; cap the number of results.
        models = list_models(limit=100)
        return list(models)
    except Exception as e:
        logging.error(f"Error fetching models: {e}")
        return []

# Custom function to retrieve just names from models list
def get_model_names():
    models = fetch_open_weight_models()
    model_names = [model.id for model in models if model.id is not None]
    return model_names

# Full merge-kit Pruning Function 
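# Note: despite the name, this is a simple sketch that uses PyTorch's built-in
# random unstructured pruning (torch.nn.utils.prune); it does not call the mergekit library.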
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model using a merge-kit approach.
    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.
        progress (gr.Progress): The progress object for visual feedback.
    Returns:
        PreTrainedModel: The pruned model.
    """
    total_params = sum(p.numel() for p in model.parameters())
    # Fraction of weights to zero out so that roughly target_num_parameters remain.
    amount = max(0.0, min(1.0, 1 - (target_num_parameters / total_params)))

    modules = list(model.named_modules())
    num_modules = len(modules)

    try:
        # Prune the model (gr.Progress expects a fraction between 0 and 1)
        for i, (name, module) in enumerate(tqdm(modules, desc="Pruning", file=sys.stdout)):
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.random_unstructured(module, name="weight", amount=amount)
            progress(0.5 * (i + 1) / num_modules)  # First half of the progress bar

        # Make the pruning permanent by removing the re-parametrization
        for i, (name, module) in enumerate(modules):
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.remove(module, name="weight")
            progress(0.5 + 0.5 * (i + 1) / num_modules)  # Second half of the progress bar
        
        return model
    except Exception as e:
        logging.error(f"Error during pruning: {e}")
        raise e

# Function to prune a model
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name=None, progress=gr.Progress(track_tqdm=True)):
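    """Loads a causal LM, randomly zeroes weights so that roughly `target_size` percent
    of the parameters remain, pushes the result to `repo_name` on the Hugging Face Hub,
    and returns a status message, a size-comparison plot, and the accumulated log.
    `base_model_name` is accepted for the UI but not used by this pruning path."""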
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        
        log_messages.append('Model and tokenizer loaded successfully.')
        logging.info('Model and tokenizer loaded successfully.')
        
        total_params = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(total_params * (target_size / 100))
        
        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        
        log_messages.append('Model pruned successfully.')
        logging.info('Model pruned successfully.')
        
        # Save the pruned model
        create_repo(repo_name, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_name, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_name, token=hf_write_token)
        
        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
        
        # Create a visualization: unstructured pruning zeroes weights rather than
        # deleting them, so compare non-zero parameter counts.
        pruned_nonzero = sum(int((p != 0).sum()) for p in pruned_model.parameters())
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(['Original', 'Pruned (non-zero)'], [total_params, pruned_nonzero])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        comparison_image = Image.open(buf)
        plt.close(fig)

        return f"Pruned model saved to Hugging Face Hub in repository {repo_name}", comparison_image, '\n'.join(log_messages)
    
    except Exception as e:
        error_message = f"Detailed error: {repr(e)}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, '\n'.join(log_messages)

# Define function to generate text
def generate_text(text, repo_name, hf_write_token):
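    """Loads the pruned model and tokenizer from `repo_name` and returns a short completion of `text`."""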
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_name, token=hf_write_token)
        model = AutoModelForCausalLM.from_pretrained(repo_name, token=hf_write_token)
        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
        generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
        return generated_text
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        return f"Error: {repr(e)}"

# Function to create a Gradio interface
def create_interface():
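    """Builds the Gradio Blocks UI for pruning a model and trying out the pruned checkpoint."""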
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")
        
        # Fetch available model names
        model_names = get_model_names()

        # Input components
        llm_model_name = gr.Dropdown(label="Choose a Large Language Model", choices=model_names, interactive=True)
        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruned_func_choice = gr.Radio(label="Pruning Function", choices=["merge-kit"], value="merge-kit", interactive=True)
        
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        pruning_log = gr.Textbox(label="Pruning Log", lines=6, interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)

        # Pruning handler; declaring progress as a default gr.Progress() argument lets
        # Gradio inject and track the progress bar for this event.
        def prune_model_with_progress(llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, progress=gr.Progress(track_tqdm=True)):
            if pruned_func_choice == "merge-kit":
                return prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name, progress)
            return f"Pruning function '{pruned_func_choice}' not implemented.", None, None

        prune_button.click(
            fn=prune_model_with_progress,
            inputs=[llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice],
            outputs=[pruning_status, visualization, pruning_log],
        )

        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)

    return demo

# Create and launch the Gradio interface
demo = create_interface()
demo.launch()