from ast import literal_eval
from nltk.stem import PorterStemmer, WordNetLemmatizer

# Entity Extraction


def generate_ner_docs_prompt(query):
    """Build the few-shot prompt for company and time-span extraction.

    The prompt is a fixed instruction with worked examples, followed by a
    blank line and the question framed as an ``###Input:`` line with an
    ``ASSISTANT:`` cue on the line after it.

    Args:
        query: The user's question; it is interpolated into the prompt as-is.

    Returns:
        The complete prompt string.
    """
    instructions_and_examples = """USER: Extract the company names and time duration mentioned in the question. The entities should be extracted in the following format: {"companies": list of companies mentioned in the question,"start-duration": ("start-quarter", "start-year"), "end-duration": ("end-quarter", "end-year")}. Return {"companies": None, "start-duration": (None, None), "end-duration": (None, None)} if the entities are not found.

Examples:
What is Nvidia's visibility in the data center business in Q2 2020?
{"companies": ["Nvidia"], "start-duration": ("Q2", "2020"), "end-duration": ("Q2", "2020")}
What is Intel's update on the server chip road map and strategy for 2019?
{"companies": ["Intel"], "start-duration": ("Q1", "2019"), "end-duration": ("Q4", "2019")}
What are the opportunities and challenges in the Indian market for Amazon in 2016?
{"companies": ["Amazon"], "start-duration": ("Q1", "2016"), "end-duration": ("Q4", "2016")}
What did analysts ask about the Cisco's Webex?
{"companies": ["Cisco"], "start-duration": (None, None), "end-duration": (None, None)}
What is the comparative performance analysis between Intel and AMD in key overlapping segments such as PC, Gaming, and Data Centers in Q1 to Q3 2016?
{"companies": ["Intel", "AMD"], "start-duration": ("Q1", "2016"), "end-duration": ("Q3", "2016")}
How has the growth been for AMD in the PC market in 2020?
{"companies": ["AMD"], "start-duration": ("Q1", "2020"), "end-duration": ("Q4", "2020")}
How did Microsoft and Amazon perform in terms of reliability and scalability of cloud for the years 2016 and 2017?
{"companies": ["Microsoft", "Amazon"], "start-duration": ("Q1", "2016"), "end-duration": ("Q4", "2017")}"""
    question_block = f"###Input: {query}\nASSISTANT:"
    # A blank line separates the last example from the actual question.
    return "\n\n".join((instructions_and_examples, question_block))


def extract_entities_docs(query, model):
    """Ask the model which companies and time span a question refers to.

    The question is wrapped in the prompt built by
    ``generate_ner_docs_prompt`` and sent to ``model.predict``. The reply is
    expected to be the text of a Python-literal dictionary of the form
    {"companies": list of companies mentioned in the question,"start-duration": ("start-quarter", "start-year"), "end-duration": ("end-quarter", "end-year")}
    and is parsed with ``literal_eval``.

    Args:
        query: The user's question.
        model: Object exposing ``predict(prompt, api_name=...)`` that returns
            the reply as a string (presumably a Gradio client; confirm
            against the caller).

    Returns:
        The tuple ``(companies, start_quarter, start_year, end_quarter,
        end_year)`` taken from the parsed reply.

    Raises:
        ValueError or SyntaxError: If the reply is not a valid Python literal.
        KeyError: If the parsed reply lacks one of the expected keys.
    """
    raw_reply = model.predict(
        generate_ner_docs_prompt(query), api_name="/predict"
    )
    print(raw_reply)

    parsed_reply = literal_eval(raw_reply.strip())
    (start_quarter, start_year) = parsed_reply["start-duration"]
    (end_quarter, end_year) = parsed_reply["end-duration"]
    extracted = (
        parsed_reply["companies"],
        start_quarter,
        start_year,
        end_quarter,
        end_year,
    )
    print(extracted)
    return extracted


def year_quarter_range(start_quarter, start_year, end_quarter, end_year):
    """List every (quarter, year) pair from the start to the end, inclusive.

    Example:
    year_quarter_range("Q2", "2020", "Q3", "2021")
    [('Q2', '2020'), ('Q3', '2020'), ('Q4', '2020'), ('Q1', '2021'), ('Q2', '2021'), ('Q3', '2021')]

    Args:
        start_quarter: One of "Q1".."Q4", or None.
        start_year: Year convertible with ``int``, or None.
        end_quarter: One of "Q1".."Q4", or None.
        end_year: Year convertible with ``int``, or None.

    Returns:
        A list of ``(quarter, year)`` tuples with the year as a string. The
        list is empty when any argument is None or when the end lies before
        the start.

    Raises:
        ValueError: If a quarter is not one of "Q1".."Q4" or a year cannot be
            converted to an integer.
    """
    bounds = (start_quarter, start_year, end_quarter, end_year)
    if any(bound is None for bound in bounds):
        return []

    quarter_labels = ["Q1", "Q2", "Q3", "Q4"]
    per_year = len(quarter_labels)
    first_slot = quarter_labels.index(start_quarter)
    last_slot = quarter_labels.index(end_quarter)

    # Place every quarter on one linear timeline (year * 4 + quarter slot) so
    # the whole span is a single contiguous integer range.
    first_position = int(start_year) * per_year + first_slot
    last_position = int(end_year) * per_year + last_slot

    span = []
    for position in range(first_position, last_position + 1):
        year, slot = divmod(position, per_year)
        span.append((quarter_labels[slot], str(year)))
    print(span)
    return span


def clean_companies(company_list):
    """Return the ticker symbols for a list of company names or tickers.

    Matching is case-insensitive. A known company name is mapped to its
    ticker; a value that already is a known ticker is returned upper-cased.
    Anything else is dropped.

    Args:
        company_list: Iterable of company names and/or ticker symbols, or
            None. The entity-extraction prompt instructs the model to return
            None for "companies" when no company is found, so None is treated
            as "no companies".

    Returns:
        A list of upper-case ticker symbols in input order. The list is empty
        when ``company_list`` is None or empty.
    """
    if not company_list:
        return []

    company_ticker_map = {
        "apple": "AAPL",
        "amd": "AMD",
        "amazon": "AMZN",
        "cisco": "CSCO",
        "google": "GOOGL",
        "microsoft": "MSFT",
        "nvidia": "NVDA",
        "asml": "ASML",
        "intel": "INTC",
        "micron": "MU",
    }
    # Lower-cased tickers, so callers may pass "msft" as well as "MSFT".
    known_tickers = {ticker.lower() for ticker in company_ticker_map.values()}

    ticker_list = []
    for company in company_list:
        if not isinstance(company, str):
            # Skip stray non-string entries (e.g. None) instead of crashing.
            continue
        name = company.lower()
        if name in company_ticker_map:
            ticker_list.append(company_ticker_map[name])
        elif name in known_tickers:
            ticker_list.append(company.upper())
    return ticker_list


def ticker_year_quarter_tuples_creator(ticker_list, year_quarter_range_list):
    """Combine every ticker with every (quarter, year) pair.

    Args:
        ticker_list: Ticker symbols.
        year_quarter_range_list: ``(quarter, year)`` pairs.

    Returns:
        A list of ``(ticker, quarter, year)`` tuples, grouped by ticker in
        input order. The list is empty when either input is empty.
    """
    return [
        (ticker, quarter, year)
        for ticker in ticker_list
        for quarter, year in year_quarter_range_list
    ]


# Keyword Extraction


def generate_ner_keywords_prompt(query):
    """Build the few-shot prompt for extracting a question's topic keywords.

    The prompt is a fixed instruction with worked examples, followed on the
    next line by the question framed as an ``###Input:`` line with an
    ``ASSISTANT:`` cue on the line after it.

    Args:
        query: The user's question; it is interpolated into the prompt as-is.

    Returns:
        The complete prompt string.
    """
    instructions_and_examples = """USER: Extract the entities which describe the key theme and topics being asked in the question. Extract the entities in the following format: {"entities":["keywords"]}.
Examples:
What is Intel's update on the server chip roadmap and strategy for Q1 2019?
{"entities":["server"]}
What are the opportunities and challenges in the Indian market for Amazon from Q1 to Q3 in 2016?
{"entities":["indian"]}
What is the comparative performance analysis between Intel and AMD in key overlapping segments such as PC, Gaming, and Data Centers in Q1 2016?
{"entities":["PC","Gaming","Data Centers"]}
What was Google's and Microsoft's capex spend for the last 2 years?
{"entities":["capex"]}
What did analysts ask about the cloud during Microsoft's earnings call in Q1 2018?
{"entities":["cloud"]}
What was the growth in Apple services revenue for 2017 Q3?
{"entities":["services"]}"""
    question_block = f"###Input: {query}\nASSISTANT:"
    # Unlike the company/duration prompt, only a single newline separates the
    # examples from the question here.
    return "\n".join((instructions_and_examples, question_block))


def extract_entities_keywords(query, model):
    """Ask the model for the topic keywords of a question.

    The question is wrapped in the prompt built by
    ``generate_ner_keywords_prompt`` and sent to ``model.predict``. The reply
    is expected to be the text of a Python-literal dictionary of the form
    {"entities":["keywords"]}
    and is parsed with ``literal_eval``.

    Args:
        query: The user's question.
        model: Object exposing ``predict(prompt, api_name=...)`` that returns
            the reply as a string (presumably a Gradio client; confirm
            against the caller).

    Returns:
        The value stored under "entities" in the parsed reply.

    Raises:
        ValueError or SyntaxError: If the reply is not a valid Python literal.
        KeyError: If the parsed reply has no "entities" key.
    """
    raw_reply = model.predict(
        generate_ner_keywords_prompt(query), api_name="/predict"
    )
    print(raw_reply)
    parsed_reply = literal_eval(raw_reply.strip())
    return parsed_reply["entities"]


def expand_list_of_lists(list_of_lists):
    """
    Flattens one level of nesting.
    Args:
      list_of_lists: A list of lists of strings.
    Returns:
      A single list holding the items of every inner list, in order.
    """
    return [item for group in list_of_lists for item in group]


def all_keywords_combs(list_of_cleaned_keywords):
    """Return each keyword together with its lower-cased, stemmed and lemmatized forms.

    The input list is left untouched; earlier versions extended the caller's
    list in place as a side effect.

    Args:
        list_of_cleaned_keywords: Iterable of keyword strings.

    Returns:
        A new, sorted list of the unique strings among the original keywords
        and their lower-cased, Porter-stemmed and WordNet-lemmatized variants.
        Sorting keeps the output stable between runs, since set iteration
        order is arbitrary.
    """
    stemmer = PorterStemmer()
    lemmatizer = WordNetLemmatizer()

    variants = set()
    for keyword in list_of_cleaned_keywords:
        variants.add(keyword)
        variants.add(keyword.lower())
        variants.add(stemmer.stem(keyword))
        variants.add(lemmatizer.lemmatize(keyword))
    return sorted(variants)


def create_incorrect_entities_list():
    """Return the words that must never be kept as topic keywords.

    The base vocabulary covers quarter and year tokens, the supported company
    names, and generic question words. It is passed through
    ``all_keywords_combs`` so the variants produced there are included too.

    Returns:
        The list produced by ``all_keywords_combs`` for the base vocabulary.
    """
    period_terms = [
        "q1", "q2", "q3", "q4",
        "2016", "2017", "2018", "2019", "2020",
    ]
    company_terms = [
        "apple", "amd", "amazon", "cisco", "google",
        "microsoft", "nvidia", "asml", "intel", "micron",
    ]
    generic_terms = [
        "strategy", "roadmap", "impact", "opportunities", "challenges",
        "growth", "performance", "analysis", "segments", "comparative",
        "overlapping", "acquisition", "revenue",
    ]
    return all_keywords_combs(period_terms + company_terms + generic_terms)


def clean_keywords_all_combs(keywords_list):
    """Split keywords into words, drop unwanted words and expand the rest.

    Each keyword is split on whitespace and lower-cased. Words found in the
    list from ``create_incorrect_entities_list`` are removed, and the
    remaining words are expanded with ``all_keywords_combs``.

    Args:
        keywords_list: Keyword strings, possibly multi-word (e.g.
            "Data Centers"). None or an empty list yields an empty result.

    Returns:
        The list produced by ``all_keywords_combs`` for the surviving words.
    """
    if not keywords_list:
        return []

    # A set gives constant-time membership tests in the filter below.
    words_to_remove = set(create_incorrect_entities_list())

    # split() without arguments collapses runs of whitespace, so repeated,
    # leading or trailing spaces never yield an empty-string "keyword"
    # (which split(" ") would produce).
    lower_texts = [
        word.lower() for text in keywords_list for word in text.split()
    ]
    cleaned_keywords = [
        text for text in lower_texts if text not in words_to_remove
    ]
    return all_keywords_combs(cleaned_keywords)