import time
import torch
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
                          OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                          XLNetLMHeadModel, XLNetTokenizer,
                          TransfoXLLMHeadModel, TransfoXLTokenizer,
                          CTRLLMHeadModel, CTRLTokenizer)

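# Registry of the models this server can serve. "size" is the approximate GPU
# memory footprint in MB (an estimate used only for placement decisions, not an
# exact measurement), and "checkpoint" is the identifier the weights are loaded from.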
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small"
    }, "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt"
    }, "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet"
    }, "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp"
    }, "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium"
    }, "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large"
    }, "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small"
    }, "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl"
    }, "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3000,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "pplm"
    }, "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl"
    }, "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True
            }
        }
    }
}

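# Fixed per-device memory overhead in MB (CUDA context and allocator bookkeeping),
# counted once per GPU on top of the per-model sizes.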
memory_overhead = 500

class GPU:
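    """Tracks the models assigned to a single CUDA device and its remaining memory budget."""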
    def __init__(self, id):
        self.id = id
        self.models = []
        # Total device memory in MB, minus a 1 GB safety margin.
        self.total_memory = torch.cuda.get_device_properties(
            "cuda:{}".format(id)).total_memory / 1_000_000 - 1_000

        print("Initialized GPU with device", "cuda:{}".format(id))

    def register_model(self, model, cached_path=None):
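        """Assign `model` to this device if it fits; return True on success, False otherwise."""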
        if self.total_memory_used() + model["size"] < self.total_memory:
            model["device"] = "cuda:{}".format(self.id)

            if cached_path:
                model["cached_path"] = cached_path
    
            self.models.append(model)
            return True
        return False

    def total_memory_used(self):
        return sum(model["size"] for model in self.models) + memory_overhead

    def __repr__(self):
        usage = round(100 * self.total_memory_used() / self.total_memory)
        return str(
            [(model["checkpoint"], model["size"]) for model in self.models] +
            ["{}%".format(usage), "cuda:{}".format(self.id)]
        )


class GPUHandler:
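    """Distributes the requested models across the available GPUs using first-fit placement."""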
    def __init__(self, model_list, gpu_ids, cached_models=None):
        if cached_models is None:
            cached_models = {}

        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))

        self.sanity_check([model_metadata[model] for model in model_list])
        
        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model["checkpoint"], "on GPU", gpu)
                break
        else:
            # No GPU had enough free memory left for this model.
            raise ValueError("Could not load model {}".format(model["checkpoint"]))

    def sanity_check(self, model_list):
        """Dry-run the placement on throwaway GPU objects so we fail fast if the models cannot all fit."""
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]

        for model in model_list:
            for gpu in temp_gpus:
                if gpu.register_model(model):
                    break
            else:
                raise RuntimeError("Sanity check failed: models do not fit on the available GPUs.")

        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"