import huggingface_hub
import re
    
class LlamaManager:
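    """Generates synthetic Python programming data (topics, questions, reference
    answers, and test cases) with Meta-Llama-3.1-70B-Instruct through the
    Hugging Face Inference API. Outputs are exchanged in a lightweight tag
    format ([L] lists, [A] items, etc.) that the parsing helpers below rely on."""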
    def __init__(self, llama_token = None, verbose = False):
        self.verbose = verbose
        
        if self.verbose:
            print("LlamaManager::__init__::Initializing LlamaManager")
        self.client = huggingface_hub.InferenceClient(
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            token=llama_token,
        )
        if self.verbose:
            print("LlamaManager::__init__::Initialized LlamaManager")
            
            
    def __get_items_between_tags(self, input_string, tag1, tag2):
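        """Return all (non-greedy) substrings of `input_string` found between
        `tag1` and `tag2`; the tags are expected to be regex-escaped."""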
        pattern = tag1 + '(.*?)' + tag2
        return re.findall(pattern, input_string, re.DOTALL)
        
    
    def __preprocess_for_auto_generate_questions_categories(self, available_categories):
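        """Wrap each category in [A]...[/A] tags for use in a pre-filled assistant turn."""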
        if self.verbose:
            print("LlamaManager::__preprocss_for_auto_generate_questions_categories::Preprocessing")
        out = ""
        for available_category in available_categories:
            out += f"[A]{available_category}[/A]"
        return out
    

    def __postprocess_for_auto_generate_questions_categories(self, out):
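        """Pull the [L]...[/L] list out of the raw model output and return its [A]...[/A] items."""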
        if self.verbose:
            print("LlamaManager::__postprocess_for_auto_generate_questions_categories::Postprocessing")
            
        matches = self.__get_items_between_tags(out, r"\[L\]", r"\[/L\]")
        if not matches:
            if self.verbose:
                print("LlamaManager::__postprocess_for_auto_generate_questions_categories::No content found")
            return []
        out = matches[0]
        out = self.__get_items_between_tags(out, r"\[A\]", r"\[/A\]")
        if not out:
            if self.verbose:
                print("LlamaManager::__postprocess_for_auto_generate_questions_categories::No categories found")
            return []
        return out
    
        
    def auto_generate_questions_categories(
        self, 
        count = 20, 
        available_categories = ["Variables"], 
        seed = 123,
        temperature = 1.0,
        top_p = 0.9,
        frequency_penalty = 0.0
        ):
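        """Ask the model for `count` basic Python topics. The assistant turn is
        pre-filled with `available_categories` to steer the output format, and
        the parsed topics are returned as a list of strings."""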
        available_content_for_assistant = self.__preprocess_for_auto_generate_questions_categories(available_categories)
        if self.verbose:
            print("LlamaManager::auto_generate_questions_categories::Generating questions categories")
        
        message_content = [
            {"role": "system", "content": "You are a synthetic data generator. You must only answer questions as a list. Each item of the list should be enclosed in [A] and [/A] tags. The list should be enclosed in [L] and [/L] tags."},
            {"role": "user", "content": f"Write me {count} basic topics for python programming"},
            {"role": "assistant", "content": f"[L]{available_content_for_assistant}"}
        ]
        
        out = self.client.chat_completion(
            messages = message_content,
            max_tokens = 1000,
            stream = False,
            seed = seed,
            temperature = temperature,
            top_p = top_p,
            frequency_penalty = frequency_penalty
        )
        
        categories = self.__postprocess_for_auto_generate_questions_categories(out.choices[0].message.content)
        if self.verbose:
            print("LlamaManager::auto_generate_questions_categories::Generated questions Categories")
        
        return categories
    
    
    def __postprocess_for_auto_generate_shots_for_category(self, out):
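        """Pull the [L]...[/L] list out of the raw model output and return its [A]...[/A] questions."""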
        if self.verbose:
            print("LlamaManager::__postprocess_for_auto_generate_shots_for_category::Postprocessing")
            
        matches = self.__get_items_between_tags(out, r"\[L\]", r"\[/L\]")
        if not matches:
            if self.verbose:
                print("LlamaManager::__postprocess_for_auto_generate_shots_for_category::No content found")
            return []
        out = matches[0]
        out = self.__get_items_between_tags(out, r"\[A\]", r"\[/A\]")
        if not out:
            if self.verbose:
                print("LlamaManager::__postprocess_for_auto_generate_shots_for_category::No questions found")
            return []
        return out

    
    def auto_generate_shots_for_category(
        self, 
        count, 
        category, 
        seed = 123,
        temperature = 1.0,
        top_p = 0.9,
        frequency_penalty = 0.0
        ):
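        """Generate `count` medium/hard example questions (shots) for `category`,
        using a one-shot For-Loop example to fix the output format."""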
        if self.verbose:
            print("LlamaManager::auto_generate_shots_for_category::Generating shots for category")
        
        message_content = [
            {"role": "system", "content": "You are a synthetic data generator. You must only answer questions as a list. Each item of the list should be enclosed in [A] and [/A] tags. The list should be enclosed in [L] and [/L] tags."},
            {"role": "user", "content": f"Write me 2 programming questions on the topic of For Loop in Python. The question should be of medium and hard difficulty. The question should involve use of just one function"},
            {"role": "assistant", "content": f"""[L]
             - [A]Write a program that takes a positive integer as input and computes the sum of its digits using a for loop.[/A]
             - [A]Write a program that generates a spiral matrix of size NxN, where N is always an odd number. Fill the spiral matrix with consecutive prime numbers in a clockwise spiral pattern, starting from the center of the matrix.[/A]
             """},
            {"role": "user", "content": f"Write me {count} programming questions on the topic of {category} in Python. The question should be of medium and hard difficulty. The question should involve use of just one function"},
            {"role": "assistant", "content": f"[L]"}
        ]
        
        out = self.client.chat_completion(
            messages = message_content,
            max_tokens = 1000,
            stream = False,
            seed = seed,
            temperature = temperature,
            top_p = top_p,
            frequency_penalty = frequency_penalty
        )
        
        shots = self.__postprocess_for_auto_generate_shots_for_category(out.choices[0].message.content + "[/L]")
        if self.verbose:
            print(f"LlamaManager::auto_generate_shots_for_category::Generated {count} shots for {category}")
        
        return shots
    
    
    def __preprocess_for_auto_generate_questions_from_shots(self, shots):
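        """Wrap each shot in [A]...[/A] tags for use in a pre-filled assistant turn."""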
        if self.verbose:
            print("LlamaManager::__preprocess_for_auto_generate_questions_from_shots::Preprocessing")
        out = ""
        for shot in shots:
            out += f"[A]{shot}[/A]"
        return out

    
    def __postprocess_for_auto_generate_questions_from_shots(self, out):
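        """Pull the [L]...[/L] list out of the raw model output and return its [A]...[/A] questions."""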
        if self.verbose:
            print("LlamaManager::__postprocess_for_auto_generate_questions_from_shots::Postprocessing")
            
        matches = self.__get_items_between_tags(out, r"\[L\]", r"\[/L\]")
        if not matches:
            if self.verbose:
                print("LlamaManager::__postprocess_for_auto_generate_questions_from_shots::No content found")
            return []
        out = matches[0]
        out = self.__get_items_between_tags(out, r"\[A\]", r"\[/A\]")
        if not out:
            if self.verbose:
                print("LlamaManager::__postprocess_for_auto_generate_questions_from_shots::No questions found")
            return []
        return out
    
    
    def auto_generate_questions_from_shots(
        self,
        count,
        category,
        shots,
        seed = 123,
        temperature = 1.0,
        top_p = 0.9,
        frequency_penalty = 0.0
        ):
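        """Generate `count` questions for `category`, seeding the pre-filled
        assistant turn with `shots`. Each retry feeds the questions gathered so
        far back in and raises the token budget by 500; generation stops early
        if the question count stalls for four consecutive iterations."""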
        available_content_for_assistant = self.__preprocess_for_auto_generate_questions_from_shots(shots)
        if self.verbose:
            print("LlamaManager::auto_generate_questions_from_shots::Generating questions from shots")
        
        message_content = [
            {"role": "system", "content": "You are a synthetic data generator. You must only answer questions as a list. Each item of the list should be enclosed in [A] and [/A] tags. The list should be enclosed in [L] and [/L] tags."},
            {"role": "user", "content": f"Write me {count} python programming questions which uses {category.lower()}"},
            {"role": "assistant", "content": f"[L]{available_content_for_assistant}"}
        ]
        
        previous_iteration_questions_count = []
        questions = []
        token_count = 1000
        while len(questions) < count:
            out = self.client.chat_completion(
                messages = message_content,
                max_tokens = token_count,
                stream = False,
                seed = seed,
                temperature = temperature,
                top_p = top_p,
                frequency_penalty = frequency_penalty
            )

            questions = self.__postprocess_for_auto_generate_questions_from_shots(out.choices[0].message.content + "[/L]")
            available_content_for_assistant = self.__preprocess_for_auto_generate_questions_from_shots(questions)
            previous_iteration_questions_count.append(len(questions))
            message_content = [
                {"role": "system", "content": "You are a synthetic data generator. You must only answer questions as a list. Each item of the list should be enclosed in [A] and [/A] tags. The list should be enclosed in [L] and [/L] tags."},
                {"role": "user", "content": f"Write me {count} python programming questions which uses {category.lower()}"},
                {"role": "assistant", "content": f"[L]{available_content_for_assistant}"}
            ]
            token_count += 500
            
            if len(previous_iteration_questions_count) > 3:
                if previous_iteration_questions_count[-1] == previous_iteration_questions_count[-2] == previous_iteration_questions_count[-3] == previous_iteration_questions_count[-4]:
                    if self.verbose:
                        print("LlamaManager::auto_generate_questions_from_shots::Generation could not be completed, stopping API calls")
                    break
        
        if self.verbose:    
            print("LlamaManager::auto_generate_questions_from_shots::Generated questions from shots")
        
        return questions
    
    
    def __postprocess_for_auto_generate_function_signature_from_question(self, out):
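        """Parse the function name ([F]), input parameters ([I]), and return
        type ([R]) out of the model's [A]...[/A] answer."""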
        if self.verbose:
            print("LlamaManager::__postprocess_for_auto_generate_function_signature_from_question::Postprocessing")
            
        answers = self.__get_items_between_tags(out, r"\[A\]", r"\[/A\]")
        if not answers:
            raise ValueError("Model output did not contain an [A]...[/A] block")
        out = answers[0]
        function_name = self.__get_items_between_tags(out, r"\[F\]", r"\[/F\]")[0]
        input_parameters = self.__get_items_between_tags(out, r"\[I\]", r"\[/I\]")
        return_type = self.__get_items_between_tags(out, r"\[R\]", r"\[/R\]")[0]
        return function_name, input_parameters, return_type
    
    def auto_generate_function_signature_from_question(
        self,
        question,
        seed = 123,
        temperature = 1.0,
        top_p = 0.9,
        frequency_penalty = 0.0
    ):
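        """Generate a function name, typed input parameters, and return type
        for `question`, using one worked example to fix the tag format."""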
        if self.verbose:
            print("LlamaManager::auto_generate_function_signature_from_question::Generating function signature from question")
            
        message_content = [
            {"role": "system", "content": """You are a synthetic data generator. 
                            You must answer the question between [A] and [/A] tags. 
                            The answer should include a function name, input parameters and return type.
                            The function name should be between [F] and [/F] tags.
                            Each input parameter should be between [I] and [/I] tags.
                            The return type should be between [R] and [/R] tags.
                            """},
            {"role": "user", "content": f"""Write me a function signature, input parameters and return type for the following question: 
                            Write a program that takes two positive integers as input and computes the sum of their digits using a for loop."""},
            {"role": "assistant", "content": f"[A][F]sum_of_digits[/F][I]num_1: int[/I][I]num_2: int[/I][R]int[/R][/A]"},
            {"role": "user", "content": f"Write me a function signature, input parameters and return type for the following question: {question}"},
            {"role": "assistant", "content": f"[A]"}
        ]
        
        out = self.client.chat_completion(
            messages = message_content,
            max_tokens = 1000,
            stream = False,
            seed = seed,
            temperature = temperature,
            top_p = top_p,
            frequency_penalty = frequency_penalty
        )
        
        function_name, input_parameters, return_type = self.__postprocess_for_auto_generate_function_signature_from_question(out.choices[0].message.content)
        if self.verbose:
            print("LlamaManager::auto_generate_function_signature_from_question::Generated function signature from question")
        
        return function_name, input_parameters, return_type
    
    
    def __postprocess_for_auto_generate_answers_and_tests(self, out):
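        """Parse the implementation ([F]) and assert-based test cases ([T])
        out of the model's [A]...[/A] answer."""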
        if self.verbose:
            print("LlamaManager::__postprocess_for_auto_generate_answers_and_tests::Postprocessing")
            
        answers = self.__get_items_between_tags(out, r"\[A\]", r"\[/A\]")
        if not answers:
            raise ValueError("Model output did not contain an [A]...[/A] block")
        out = answers[0]
        answer = self.__get_items_between_tags(out, r"\[F\]", r"\[/F\]")[0]
        test_cases = self.__get_items_between_tags(out, r"\[T\]", r"\[/T\]")
        return answer, test_cases
    
    
    def auto_generate_answers_and_tests(
        self,
        question,
        function_name,
        input_parameters,
        return_type,
        seed = 123,
        temperature = 1.0,
        top_p = 0.9,
        frequency_penalty = 0.0
    ):
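        """Generate a reference implementation and assert-based test cases for
        `question`, given its previously generated signature parts."""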
        if self.verbose:
            print("LlamaManager::auto_generate_answers_and_tests::Generating answers and test cases")
        
        function_signature = f"{function_name}({', '.join(input_parameters)}) -> {return_type}"
        
        message_content = [
            {"role": "system", "content": """You are a synthetic data generator. 
                            You must answer the question between [A] and [/A] tags. 
                            The answer should include a function implementation and test cases.
                            The function implementation should be between [F] and [/F] tags.
                            Each test case should be between [T] and [/T] tags.
                            Test cases must use assert statements.
                            Do not comment on the code. No need to explain the solution.
                            """},
            {"role": "user", "content": f"""Write me a function implementation along with the test cases for the following question: {question},
                            The function has the following signature: {function_signature}"""}
        ]
        
        out = self.client.chat_completion(
            messages = message_content,
            max_tokens = 1000,
            stream = False,
            seed = seed,
            temperature = temperature,
            top_p = top_p,
            frequency_penalty = frequency_penalty
        )
        
        answer, test_cases = self.__postprocess_for_auto_generate_answers_and_tests(out.choices[0].message.content)
        if self.verbose:
            print("LlamaManager::auto_generate_answers_and_tests::Generated answers and test cases")
        
        return answer, test_cases

        
if __name__ == "__main__":
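    # Note: "nope" is a placeholder; real API calls require a valid
    # Hugging Face token with access to this model.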
    llama_manager = LlamaManager("nope", True)
    categories = llama_manager.auto_generate_questions_categories(20)
    shots = llama_manager.auto_generate_shots_for_category(2, categories[3])
    questions = llama_manager.auto_generate_questions_from_shots(10, categories[3], shots, temperature = 0.5)
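
    # A minimal sketch of the remaining pipeline stages, assuming the calls
    # above succeeded and `questions` is non-empty:
    function_name, input_parameters, return_type = llama_manager.auto_generate_function_signature_from_question(questions[0])
    answer, test_cases = llama_manager.auto_generate_answers_and_tests(
        questions[0], function_name, input_parameters, return_type
    )
    print(f"Implementation:\n{answer}")
    print(f"Test cases:\n{test_cases}")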