import gradio as gr
from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, AutoModelForCausalLM
from huggingface_hub import create_repo, HfApi, list_models
from transformers.modeling_utils import PreTrainedModel
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
import subprocess
import logging
import sys

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Ensure sentencepiece is installed
try:
    import sentencepiece  # noqa: F401  (imported only to check availability)
except ImportError:
    # Install into *this* interpreter's environment: a bare `pip` found on
    # PATH may belong to a different Python installation.
    logging.info("sentencepiece not found; installing it with pip.")
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])

# Function to fetch open-weight LLM models
def fetch_open_weight_models(limit=None):
    """Return model-info objects listed by the Hugging Face Hub.

    No filtering is applied here: this is the raw ``list_models`` listing.

    Args:
        limit: Optional maximum number of models to request. ``None`` (the
            default) requests the full listing, matching the old behaviour.

    Returns:
        list: The model-info objects, or ``[]`` if the Hub request fails
        (the error is logged, not raised).
    """
    try:
        # list_models() may hand back a lazy, paginated iterator. Materialize
        # it here so network errors raised while paging are caught by this
        # handler instead of escaping from the caller's loop.
        return list(list_models(limit=limit))
    except Exception as e:
        logging.error(f"Error fetching models: {e}")
        return []

# Custom function to retrieve just names from models list
def get_model_names(models=None):
    """Return the repository IDs of Hub models.

    Args:
        models: Optional iterable of model-info objects. When omitted,
            ``fetch_open_weight_models()`` is called to obtain them.

    Returns:
        list[str]: The IDs of the models that have one, in input order.
    """
    if models is None:
        models = fetch_open_weight_models()
    names = []
    for model in models:
        # Current huggingface_hub ModelInfo objects carry the repo ID as
        # ``id``; ``modelId`` is the legacy attribute name. Accept either.
        name = getattr(model, "id", None) or getattr(model, "modelId", None)
        if name is not None:
            names.append(name)
    return names

# Full merge-kit Pruning Function
def merge_kit_prune(model: "PreTrainedModel", target_num_parameters: int, progress: "gr.Progress" = None) -> "PreTrainedModel":
    """Randomly prune the Linear/Conv2d weights of ``model`` in place.

    The pruning ratio is ``1 - target_num_parameters / total_params``, clamped
    to ``[0, 1]``, where ``total_params`` counts every parameter of the model.
    That ratio is applied to the ``weight`` of each ``torch.nn.Linear`` and
    ``torch.nn.Conv2d`` module via ``prune.random_unstructured``. The pruning
    is then made permanent with ``prune.remove``.

    Unstructured pruning zeroes weights; it does not shrink tensors, so
    ``model.parameters()`` reports the same element count afterwards. Biases
    and other module types are left untouched, so the fraction of zeroed
    parameters over the whole model is generally lower than the ratio above.

    Args:
        model: The model to prune. It is modified in place.
        target_num_parameters: Desired number of (non-zero) parameters.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            with ``fraction`` in ``[0, 1]``, e.g. a ``gr.Progress``.

    Returns:
        The same ``model`` object, pruned.

    Raises:
        Exception: Any error from ``torch.nn.utils.prune`` is logged and
            re-raised unchanged.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if total_params == 0:
        # Nothing to prune; also avoids dividing by zero below.
        return model

    # prune.* rejects float amounts outside [0, 1]. The raw ratio leaves that
    # range when the target exceeds the current size or is negative.
    amount = min(max(1 - target_num_parameters / total_params, 0.0), 1.0)

    # Snapshot the prunable modules once instead of re-walking
    # model.named_modules() on every progress update.
    prunable = [m for m in model.modules() if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d))]
    count = len(prunable)

    def _report(fraction, desc):
        if progress is not None:
            progress(fraction, desc=desc)

    try:
        # First half of the progress bar: attach random masks.
        for i, module in enumerate(prunable, start=1):
            prune.random_unstructured(module, name="weight", amount=amount)
            _report(0.5 * i / count, "Pruning")

        # Second half: bake the masks into the weights and drop the
        # weight_orig / weight_mask reparametrization.
        for i, module in enumerate(prunable, start=1):
            prune.remove(module, name="weight")
            _report(0.5 + 0.5 * i / count, "Removing pruning masks")
    except Exception as e:
        logging.exception(f"Error during pruning: {e}")
        raise

    return model

# Function to prune a model
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name=None, progress=gr.Progress(track_tqdm=True)):
    """Load a causal LM, prune it, push it to the Hub and chart the result.

    Args:
        llm_model_name: Hub ID of the model (and tokenizer) to load.
        target_size: Target size as a percentage of the original parameter
            count (the UI slider supplies 1-100).
        hf_write_token: Token used to create the repository and push to it.
        repo_name: Destination repository on the Hub.
        base_model_name: Currently unused.
        progress: Progress tracker forwarded to ``merge_kit_prune``.

    Returns:
        tuple: ``(status_message, chart, log_text)``. ``chart`` is a base64
        PNG data URI on success. On any error, ``status_message`` is
        ``"Detailed error: ..."``, ``chart`` is ``None``, and nothing is
        raised.
    """
    log_messages = []

    def _log(message):
        log_messages.append(message)
        logging.info(message)

    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        _log('Model and tokenizer loaded successfully.')

        total_params = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(total_params * (target_size / 100))

        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        _log('Model pruned successfully.')

        # `use_auth_token` is deprecated in transformers; `token` replaces it.
        create_repo(repo_name, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_name, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_name, token=hf_write_token)
        status = f"Pruned model saved to Hugging Face Hub in repository {repo_name}"
        _log(status)

        # Pruning zeroes weights without changing numel(), so count non-zero
        # entries; otherwise both bars would always be identical.
        nonzero_params = sum(int(torch.count_nonzero(p)) for p in pruned_model.parameters())
        chart = _size_chart_data_uri(total_params, nonzero_params)

        return status, chart, '\n'.join(log_messages)

    except Exception as e:
        error_message = f"Detailed error: {repr(e)}"
        log_messages.append(error_message)
        logging.exception(error_message)
        return error_message, None, '\n'.join(log_messages)


# Render the original-vs-pruned parameter counts as a PNG data URI.
def _size_chart_data_uri(original_params, pruned_params):
    """Draw a two-bar chart and return it as a ``data:image/png;base64,...`` URI."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(['Original', 'Pruned (non-zero)'], [original_params, pruned_params])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every open figure alive; close it so a long-running
        # server does not accumulate one figure per request.
        plt.close(fig)
    # NOTE(review): confirm the installed Gradio version's gr.Image accepts a
    # data URI string as output; otherwise return a PIL image / array instead.
    return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"

# Define function to generate text
def generate_text(text, repo_name, hf_write_token):
    """Generate a continuation of ``text`` with the model stored in ``repo_name``.

    The tokenizer and model are downloaded/loaded on every call. Generation
    uses beam search with ``num_beams=5`` and ``max_length=50``.

    Args:
        text: Prompt text.
        repo_name: Hub repository holding the model and tokenizer.
        hf_write_token: Token used to authenticate the downloads.

    Returns:
        str: The ``generated_text`` field of the first pipeline result, or
        ``"Error: ..."`` if anything fails (errors are logged, not raised).
    """
    try:
        # `use_auth_token` is deprecated in transformers; `token` replaces it.
        tokenizer = AutoTokenizer.from_pretrained(repo_name, token=hf_write_token)
        model = AutoModelForCausalLM.from_pretrained(repo_name, token=hf_write_token)
        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
        results = generator(text, max_length=50, num_beams=5, num_return_sequences=1)
        return results[0]['generated_text']
    except Exception as e:
        logging.exception(f"Error during text generation: {e}")
        return f"Error: {repr(e)}"

# Function to create a Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for pruning and text generation.

    The model dropdowns are populated at build time via ``get_model_names()``.
    NOTE(review): that listing is unfiltered and may be very large; consider
    limiting it or switching to a free-text input.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

        # Fetch available model names
        model_names = get_model_names()

        # Input components
        llm_model_name = gr.Dropdown(label="Choose a Large Language Model", choices=model_names, interactive=True)
        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruned_func_choice = gr.Radio(label="Pruning Function", choices=["merge-kit"], value="merge-kit", interactive=True)

        # Output components. prune_model returns (status, chart, logs), so the
        # click handler needs three outputs; the logs get their own textbox.
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)
        pruning_logs = gr.Textbox(label="Pruning Log", interactive=False)

        # Gradio only injects a live progress tracker when a gr.Progress
        # instance is the *default value* of a handler parameter. A Progress
        # object created elsewhere is never connected to the UI.
        def prune_model_with_progress(llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, progress=gr.Progress(track_tqdm=True)):
            if pruned_func_choice == "merge-kit":
                return prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name, progress)
            return f"Pruning function '{pruned_func_choice}' not implemented.", None, None

        prune_button.click(
            fn=prune_model_with_progress,
            inputs=[llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice],
            outputs=[pruning_status, visualization, pruning_logs],
        )

        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)

    return demo

# Create the Gradio interface. It is kept as a module-level `demo` so code that
# imports this file can still reach it, but the blocking server launch only
# happens when the file is executed as a script.
demo = create_interface()

if __name__ == "__main__":
    demo.launch()