import requests
import re
from collections import defaultdict
# Utilities for loading and working with models (including model-specific helpers)
from urllib.parse import urlparse
from accelerate.commands.estimate import check_has_model, create_empty_model
from accelerate.utils import compute_module_sizes, named_module_tensors
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError


def fetch_dictionary_content(model_id, timeout=30):
    """Download and parse a model's ``config.json`` from the Hugging Face Hub.

    Args:
        model_id: Hub repository id, e.g. ``"org/model"``.
        timeout: Seconds ``requests`` may wait on the server. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The parsed JSON content if the server answered with HTTP 200,
        otherwise ``None``.

    Raises:
        requests.exceptions.RequestException: on connection errors or when
            ``timeout`` is exceeded (propagated from ``requests.get``).
    """
    config_url = f"https://huggingface.co/{model_id}/raw/main/config.json"
    response = requests.get(config_url, timeout=timeout)

    # Only a 200 counts as success; any other status yields None.
    if response.status_code != 200:
        return None
    return response.json()
    
def load_parameter(model_dict, cand_keys, default=0):
    """Return the value stored under the first key of ``cand_keys`` present in ``model_dict``.

    Keys are tried in order, which lets a caller list several alternative
    names for the same setting.

    Args:
        model_dict: Mapping to look values up in.
        cand_keys: Iterable of candidate keys, tried in order.
        default: Value returned when none of the keys is present
            (``0`` by default, matching the historical behaviour).

    Returns:
        ``model_dict[key]`` for the first matching key, even if that value is
        falsy; otherwise ``default``.
    """
    for key in cand_keys:
        if key in model_dict:
            return model_dict[key]
    return default

# Reference: https://huggingface.co/spaces/hf-accelerate/model-memory-usage
def extract_from_url(name: str):
    """Convert a Hub URL into a repository id; pass any other string through.

    A string counts as a URL when ``urlparse`` finds both a scheme and a
    network location. For a URL, the path is returned with leading and
    trailing ``/`` removed, so both ``https://huggingface.co/org/model`` and
    ``https://huggingface.co/org/model/`` yield ``org/model``. Anything else
    is returned unchanged.
    """
    try:
        parsed = urlparse(name)
        is_url = bool(parsed.scheme and parsed.netloc)
    except Exception:
        # urlparse can raise (e.g. ValueError on a malformed netloc); treat
        # anything unparseable as a plain model name.
        is_url = False

    if not is_url:
        return name
    # Strip both ends: a trailing slash would otherwise leak into the repo id.
    return parsed.path.strip("/")


def translate_llama2(text):
    """Return ``text`` with a ``-hf`` suffix, appending it only when it is missing."""
    suffix = "-hf"
    return text if text.endswith(suffix) else f"{text}{suffix}"


def get_model(model_name: str, library: str, access_token: str):
    """Find a model on the Hub and initialize it on `meta` via accelerate's ``create_empty_model``.

    Args:
        model_name: Hub repository id, or a full Hub URL (converted with
            ``extract_from_url``).
        library: Library name passed to ``create_empty_model``; ``"auto"`` is
            translated to ``None`` before the call.
        access_token: Hub access token, forwarded to ``create_empty_model``.

    Returns:
        The model object returned by ``create_empty_model``.

    Raises:
        RuntimeError: for gated repos, missing repos, missing library
            metadata, and any other loading failure. The original exception
            is chained as ``__cause__``.
    """
    # Normalize a URL into a repo id *before* the Llama-2 rewrite, so the
    # "-hf" suffix is appended to the repo id itself rather than to a URL's
    # query/fragment (which extract_from_url would then discard).
    model_name = extract_from_url(model_name)
    if "meta-llama" in model_name:
        # NOTE(review): this appends "-hf" to every meta-llama repo, not only
        # Llama-2 ones; confirm that is intended for newer repos.
        model_name = translate_llama2(model_name)
    if library == "auto":
        library = None
    try:
        model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
    except GatedRepoError as e:
        raise RuntimeError(
            f"Model `{model_name}` is a gated model, please ensure to pass in your access token and try again if you have access. You can find your access token here : https://huggingface.co/settings/tokens. "
        ) from e
    except RepositoryNotFoundError as e:
        raise RuntimeError(f"Model `{model_name}` was not found on the Hub, please try another model name.") from e
    except ValueError as e:
        raise RuntimeError(
            f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
        ) from e
    except (RuntimeError, OSError) as e:
        # check_has_model inspects the error to guess which library was tried.
        library = check_has_model(e)
        if library != "unknown":
            raise RuntimeError(
                f"Tried to load `{model_name}` with `{library}` but a possible model to load was not found inside the repo."
            ) from e
        raise RuntimeError(
            f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
        ) from e
    except ImportError:
        # hacky way to check if it works with `trust_remote_code=False`
        model = create_empty_model(
            model_name, library_name=library, trust_remote_code=False, access_token=access_token
        )
    except Exception as e:
        raise RuntimeError(
            f"Model `{model_name}` had an error, please open a discussion on the model's page with the error message and name: `{e}`"
        ) from e
    return model

def get_module_tensors(model):
    """Map each tensor name yielded by ``named_module_tensors`` (recursively) to its ``.shape``.

    If a name is yielded more than once, the last occurrence wins.
    """
    return {
        tensor_name: t.shape
        for tensor_name, t in named_module_tensors(model, recurse=True)
    }


def classify_module(module_tensors):
    """Group tensor names by their digit-free ("generic") name.

    Every run of digits is removed from each name, and any dotted segment
    left empty is dropped. For example, ``model.layers.0.mlp.weight`` and
    ``model.layers.1.mlp.weight`` both map to ``model.layers.mlp.weight``.
    Digits inside a segment are removed as well (``fc1`` -> ``fc``).

    Args:
        module_tensors: Mapping of tensor name -> value (``get_module_tensors``
            stores tensor shapes as the values).

    Returns:
        A ``defaultdict(list)`` mapping each generic name to a list of
        single-entry ``{original_name: value}`` dicts, in input iteration order.
    """
    digit_runs = re.compile(r'\d+')  # hoisted: reused for every name
    module_classes = defaultdict(list)

    for name in module_tensors:
        stripped = digit_runs.sub('', name)
        # Dropping empty segments collapses every run of dots left behind by
        # removed indices ("h.0.1.weight" -> "h.weight", "0.bias" -> "bias");
        # a single replace('..', '.') would leave "h..weight" and ".bias".
        generic_name = '.'.join(part for part in stripped.split('.') if part)
        module_classes[generic_name].append({name: module_tensors[name]})

    return module_classes

def get_module_tensors_matched(filter_fn, module_classes_dict):
    """Collect the values of every entry whose generic name passes ``filter_fn``.

    ``filter_fn`` is called once per generic name, with the name lower-cased.
    For each accepted name, the values of its ``{name: value}`` entries are
    appended in order.
    """
    return [
        value
        for generic_name, entries in module_classes_dict.items()
        if filter_fn(generic_name.lower())
        for entry in entries
        for value in entry.values()
    ]


if __name__ == '__main__':
    # Smoke run: build a model on `meta`, group its tensors, and report the
    # summed weight sizes of the attention / MLP / embedding components.
    import torch

    model = get_model('NousResearch/Nous-Hermes-Llama2-13b', None, None)
    module_tensors = get_module_tensors(model)
    module_classes = classify_module(module_tensors)

    # NOTE(review): sizes are presumably byte counts assuming int8 storage;
    # they are divided by 1024**3 below to express them in GiB.
    sizes = compute_module_sizes(model, dtype=torch.int8)
    bytes_per_gib = 1024 ** 3
    size_dict = {
        'attn': 0,
        'mlp': 0,
        'embed': 0,
    }
    for module_name, size in sizes.items():
        # Only keys containing 'weight' are counted.
        if 'weight' not in module_name:
            continue
        for component in size_dict:
            if component in module_name:
                size_dict[component] += size / bytes_per_gib

    # Report the per-component summary; the raw `sizes` mapping has one entry
    # per module/parameter and was never the intended output.
    print(size_dict)