import keras
import keras_hub

# Model preset handles ("hf://<organization>/<model>"), grouped by
# approximate parameter count.
model_presets = [
    # 7-9B params models
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # "keras/gemma_1.1_instruct_7b_en", # won't fit?
    # 1-3B params models
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]

# Prefixes stripped from a preset handle, one after the other and in this
# order, to obtain its short label.
_LABEL_PREFIXES = ("hf://", "google/", "keras/", "meta-llama/")


def _preset_to_label(preset):
    """Return `preset` with the scheme and organization prefixes removed."""
    label = preset
    for prefix in _LABEL_PREFIXES:
        label = label.removeprefix(prefix)
    return label


# Short labels, index-aligned with `model_presets`.
# This is a real list rather than a lazy `map` object: a `map` is exhausted
# after its first iteration and supports neither len() nor indexing.
model_labels = [_preset_to_label(preset) for preset in model_presets]


def preset_to_website_url(preset):
    """Return the Hugging Face web page URL for a preset handle.

    Args:
        preset: Preset handle. A leading "hf://" is dropped if present;
            the rest is appended to the site root unchanged.

    Returns:
        The page URL as a string, using the HTTPS scheme.
    """
    repo_path = preset.removeprefix("hf://")
    # Link over HTTPS so the URL does not start with a cleartext request.
    return "https://huggingface.co/" + repo_path


def get_appropriate_chat_template(preset):
    """Pick the chat template name to use with `preset`.

    Returns "Vicuna" when the lowercase substring "vicuna" occurs anywhere
    in `preset` (the match is case-sensitive), and "auto" otherwise.
    """
    if "vicuna" in preset:
        return "Vicuna"
    return "auto"


def _override_layout(layout_map, key, axes):
    """Set `layout_map[key]` to `axes`, replacing an existing entry if any.

    This makes a patch apply cleanly whether or not the upstream default
    layout map already defines `key`.
    """
    # Compare against the stored keys themselves rather than using
    # `key in layout_map` / `layout_map[key]`.
    # NOTE(review): per the Keras LayoutMap docs, item lookup falls back to
    # regex matching and assigning an existing key raises; confirm against
    # the installed Keras version.
    if key in tuple(layout_map):
        del layout_map[key]
    layout_map[key] = axes


def get_default_layout_map(preset_name, device_mesh):
    """Return the KerasHub default layout map for a preset, locally patched.

    The architecture is chosen by case-sensitive substring match on
    `preset_name`.

    Args:
        preset_name: Preset handle, e.g. "hf://google/gemma-2b-it-keras".
        device_mesh: Device mesh forwarded to the backbone's
            `get_layout_map`; the patches below use its "batch" and
            "model" axis names.

    Returns:
        The patched layout map, or None when `preset_name` matches none of
        "Llama", "mistral", "vicuna" or "gemma".
    """
    # Llama's default layout map works for mistral and vicuna
    # because their transformer layers have the same names.
    if any(tag in preset_name for tag in ("Llama", "mistral", "vicuna")):
        layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # Default layout map patch:
        # This entry is missing for some Llama models
        # (TODO: fix this in keras_hub)
        _override_layout(
            layout_map,
            "token_embedding/reverse_embeddings",
            ("batch", "model"),
        )
        return layout_map

    if "gemma" in preset_name:
        layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)

        if "gemma-2b-" in preset_name:
            # Default layout map patch:
            # Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM]
            # Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM]
            # However:
            # The default layout map for KQV weights on Gemma is: (model_dim,data_dim,None)
            # Which means sharding NB_HEADS on the "model" dimension.
            # But gemma-2b-it-keras has only 1 head so this won't work: must patch it
            # TODO: fix this in the Gemma layout map in Keras hub.
            _override_layout(
                layout_map,
                "decoder_block.*attention.*(query|key|value).kernel",
                (None, "model", "batch"),
            )

        return layout_map

    # Unrecognized preset: no default layout map is available.
    return None


def log_applied_layout_map(model):
    """Print the sharding applied to a sample of the model's weights.

    Prints the model's class name, then, for the "token_embedding" layer
    and one decoder block of `model.backbone`, one line per weight with its
    path, shape, sharding spec and dtype. Model classes whose name contains
    none of "Gemma", "Llama" or "Mistral" only get a notice.

    NOTE(review): reads `variable.value.sharding.spec`, so this presumably
    needs a backend whose variable values expose `.sharding` - confirm.
    """
    architecture = type(model).__name__
    print("Model class:", architecture)

    # (class-name fragment, decoder block to inspect); first match wins.
    sampled_block_by_family = (
        ("Gemma", "decoder_block_1"),
        ("Llama", "transformer_layer_1"),  # works for Llama (Vicuna) and Llama3
        ("Mistral", "transformer_layer_1"),
    )
    block_name = next(
        (
            name
            for fragment, name in sampled_block_by_family
            if fragment in architecture
        ),
        None,
    )
    if block_name is None:
        print("Unknown architecture. Cannot display the applied layout.")
        return

    # See how layer sharding was applied
    token_embedding = model.backbone.get_layer("token_embedding")
    print(token_embedding)
    sampled_block = model.backbone.get_layer(block_name)
    print(type(sampled_block))
    # The backslash continuations sit inside the string literal, so the
    # leading indentation of each continuation line is part of the printed
    # column spacing; keep it unchanged.
    for weight in token_embedding.weights + sampled_block.weights:
        print(
            f"{weight.path:<58}  \
                {str(weight.shape):<16}  \
                {str(weight.value.sharding.spec):<35} \
                {str(weight.dtype)}"
        )


def load_model(preset):
    """Load `preset` as a bfloat16 causal LM under a model-parallel scope.

    Builds a device mesh over every device reported by Keras, wraps the
    layout map from `get_default_layout_map` in a `ModelParallel`
    distribution, loads the model inside that distribution's scope, and
    prints the resulting layout with `log_applied_layout_map`.

    Args:
        preset: Preset handle to load.

    Returns:
        The loaded causal LM model.
    """
    all_devices = keras.distribution.list_devices()
    # 1 x N mesh: a size-1 "batch" axis and all N devices on the "model" axis.
    mesh = keras.distribution.DeviceMesh(
        shape=(1, len(all_devices)),
        axis_names=["batch", "model"],
        devices=all_devices,
    )
    layout_map = get_default_layout_map(preset, mesh)
    distribution = keras.distribution.ModelParallel(
        layout_map=layout_map,
        batch_dim_name="batch",
    )

    with distribution.scope():
        # These two buggy models need this workaround to be loaded in bfloat16:
        # the backbone is loaded with the dtype and the task model is then
        # assembled from the backbone and the preprocessor by hand.
        if "google/gemma-2-instruct-9b-keras" in preset:
            backbone = keras_hub.models.GemmaBackbone.from_preset(
                preset, dtype="bfloat16"
            )
            preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                preset
            )
            model = keras_hub.models.GemmaCausalLM(
                backbone=backbone, preprocessor=preprocessor
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            backbone = keras_hub.models.Llama3Backbone.from_preset(
                preset, dtype="bfloat16"
            )
            preprocessor = keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                preset
            )
            model = keras_hub.models.Llama3CausalLM(
                backbone=backbone, preprocessor=preprocessor
            )
        else:
            # Every other preset loads directly through the generic task class.
            model = keras_hub.models.CausalLM.from_preset(
                preset, dtype="bfloat16"
            )

    log_applied_layout_map(model)
    return model