import subprocess
import re
import pandas as pd
import plotly.express as px
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from prettytable import PrettyTable
import streamlit as st

#st.title('Code Generation on the CoNaLa Dataset')

import subprocess
import re
import pandas as pd
import plotly.express as px
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from prettytable import PrettyTable

#browser.gatherUsageStats=False

class CodeGenerator:
    """Generate Python code from natural language and lint it with flake8.

    The MarianCG CoNaLa model and its tokenizer are loaded once, when the
    instance is created. Linting runs the ``flake8`` executable, which must
    be available on ``PATH``.
    """

    # Name under which flake8 reports code it reads from stdin. Keeping the
    # historical file name means reports look exactly as they did when the
    # snippet was written to ./temp.py.
    _DISPLAY_NAME = "temp.py"
    # Matches flake8 report lines "temp.py:<row>:<col>: <CODE> <text>" and
    # captures (row, code). "[A-Z]+\d+" also accepts multi-letter plugin
    # codes such as "SIM102", which a single-character prefix would miss.
    _ERROR_LINE_RE = re.compile(r"temp\.py:(\d+):\d+: ([A-Z]+\d+)")

    def __init__(self):
        checkpoint = "AhmedSSoliman/MarianCG-CoNaLa-Large"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    def generate_code(self, nl_input):
        """Return the code the model generates for the description *nl_input*.

        Only the first generated sequence is decoded, with special tokens removed.
        """
        input_ids = self.tokenizer.encode(nl_input, return_tensors="pt")
        output_ids = self.model.generate(input_ids)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def _run_flake8(self, code):
        """Lint *code* with flake8 and return its decoded ``(stdout, stderr)``.

        The snippet is piped through stdin instead of being written to a
        shared ``temp.py`` in the working directory. Concurrent sessions
        therefore cannot overwrite each other's snippet, and no existing file
        called temp.py is clobbered.
        """
        result = subprocess.run(
            ["flake8", "--count", "--stdin-display-name", self._DISPLAY_NAME, "-"],
            input=code.encode("utf-8"),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        return result.stdout.decode(), result.stderr.decode()

    def check_code(self, code):
        """Lint one snippet and return flake8's raw stdout report.

        The report has one line per problem, followed by the total count
        (``--count``). Anything flake8 writes to stderr is discarded.
        """
        output, _ = self._run_flake8(code)
        return output

    def check_code_list(self, code_list):
        """Lint each snippet separately and summarize all problems found.

        Side effects: prints a table and a per-code summary, and shows a bar
        chart (see ``_process_output``).

        Returns:
            One "CODE: count" line per error code, joined with newlines.
        """
        outputs = []
        errors = []
        for code in code_list:
            output, error = self._run_flake8(code)
            outputs.append(output)
            errors.append(error)
        return self._process_output("".join(outputs), "".join(errors))

    def _process_output(self, output, error):
        """Count, display and format the error codes found in a flake8 report.

        flake8's stdout report is used when it is non-empty. Otherwise the
        stderr text is used instead.
        """
        # Previously the stderr branch referenced an undefined name
        # (NameError). Picking the text once removes the duplicated branches.
        report = output if output else error
        counts = self._get_error_counts(report)
        self.show_variables_in_table(counts, report)
        self.visualize_all_errors(counts)
        self.visualize_error_types(counts)
        return self._format_error_counts(counts)

    def _get_error_counts(self, output):
        """Return ``{error_code: occurrences}`` for the report lines in *output*.

        Keys appear in order of first occurrence.
        """
        from collections import Counter

        counts = Counter(code for _row, code in self._ERROR_LINE_RE.findall(output))
        return dict(counts)

    def _format_error_counts(self, error_counts):
        """Render *error_counts* as newline-separated ``"CODE: count"`` lines."""
        return "\n".join(f"{error_type}: {count}" for error_type, count in error_counts.items())

    def visualize_all_errors(self, error_counts):
        """Print each error code with its count, separated by blank lines."""
        for error_type, count in error_counts.items():
            print(f"{error_type}: {count}\n")

    def visualize_error_types(self, error_counts):
        """Show a horizontal bar chart of how often each error code occurs.

        NOTE: ``fig.show()`` uses Plotly's own renderer and does not go
        through Streamlit (it is not ``st.plotly_chart``).
        """
        df = pd.DataFrame({'Error Type': list(error_counts.keys()), 'Count': list(error_counts.values())})
        fig = px.bar(df, x='Count', y='Error Type', title='Error Occurrences in The Generated Code')
        fig.update_layout(
            title={
                'text': "Error Occurrences in The Generated Code",
                'x': 0.5,
                'y': 0.96,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            xaxis_title="Error Counts",
            yaxis_title="Error Codes"
        )
        fig.show()

    def show_variables_in_table(self, output_counts, output):
        """Print a one-row table holding the counts dict and the raw report."""
        table = PrettyTable()
        table.field_names = ["Error Code", "Message"]
        table.add_row([output_counts, output])
        print(table)

    def display_variables(self, output, error):
        """Show *output* and *error* side by side as a one-row DataFrame.

        Uses the ``display`` builtin when one exists (IPython/Jupyter installs
        it). Elsewhere it falls back to ``print`` instead of raising NameError.
        """
        import builtins

        show = getattr(builtins, "display", print)
        output_df = pd.DataFrame({"Output": [output]})
        error_df = pd.DataFrame({"Error": [error]})
        show(pd.concat([output_df, error_df], axis=1))






import autopep8
import black
import isort
import pylint.lint
import autoimport
from yapf.yapflib.yapf_api import FormatCode  # reformat a string of code

class PythonCodeFormatter:
    """Clean up a snippet of Python source with a chain of formatters.

    ``format()`` runs isort -> black -> autoimport -> autopep8 -> yapf. When a
    step fails with one of the tolerated errors, the output of the last step
    that succeeded is returned instead of raising.
    """

    def __init__(self, code):
        # Turn every '▁' (U+2581) into a space and trim the ends. U+2581 is
        # presumably the SentencePiece word-boundary marker leaking out of the
        # model's decoded text -- confirm against the tokenizer output.
        self.code = code.replace('▁', ' ').strip()

    def load_code_from_file(self, filename):
        """Replace ``self.code`` with the contents of *filename*, read as UTF-8."""
        with open(filename, 'r', encoding='utf-8') as f:
            self.code = f.read()

    def format(self):
        """Format ``self.code`` and return the result.

        The result is also stored back into ``self.code``, so a later
        ``save()`` writes the corrected code rather than the original input.

        Returns:
            The formatted source. If a step raises ``ValueError``, or the
            ``RuntimeError`` 'Project root not found.', the output of the last
            successful step is returned instead.

        Raises:
            RuntimeError: any other ``RuntimeError`` raised by the pipeline.
        """
        # Seed with the input so that a failure in the very first step still
        # has something to return (previously this raised UnboundLocalError).
        formatted_code = self.code
        try:
            # Sort and group the imports.
            formatted_code = isort.code(formatted_code)
            # Apply black's formatting.
            formatted_code = black.format_str(formatted_code, mode=black.Mode())
            # Add missing import statements.
            formatted_code = autoimport.fix_code(formatted_code)
            # Fix remaining PEP 8 issues.
            formatted_code = autopep8.fix_code(formatted_code)
            # Final pass with yapf; its "changed" flag is not needed.
            formatted_code, _ = FormatCode(formatted_code)
        except RuntimeError as error:
            # Tolerate only this specific message (presumably raised by
            # autoimport when it cannot locate a project root); anything else
            # propagates to the caller.
            if str(error) != 'Project root not found.':
                raise
        except ValueError:
            # Deliberate best effort (e.g. black rejecting source it cannot
            # parse): keep the output of the last step that succeeded.
            pass

        self.code = formatted_code
        return formatted_code

    def save(self, filename):
        """Write ``self.code`` (the formatted code once ``format()`` ran) as UTF-8."""
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(self.code)





# st.cache_resource exists in Streamlit >= 1.18. On older releases, fall back
# to a no-op decorator, which reloads the model on every rerun as before.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_code_generator():
    """Create the shared CodeGenerator.

    Streamlit re-executes the whole script on every widget interaction. With
    ``st.cache_resource``, the model is loaded once per server process and
    the same instance is reused across reruns.
    """
    return CodeGenerator()


code_generator = _load_code_generator()


# Streamlit app
def main():
    """Render the Streamlit UI: generate code, lint it, and show a formatted version."""
    st.title('Code Generator and Error Checker')
    nl_input = st.text_area('Enter natural language input for code generation')
    if not st.button('Generate Code'):
        return
    if not nl_input.strip():
        # Nothing to translate; don't run the model on an empty prompt.
        st.warning('Please enter a description of the code to generate.')
        return

    # Generate code
    output_code = code_generator.generate_code(nl_input)
    st.subheader('Generated Code')
    st.code(output_code, language='python')

    # Check code for errors
    st.subheader('Error Check')
    error_message = code_generator.check_code(output_code)
    st.write('Error Counts:')
    # st.text keeps flake8's one-problem-per-line layout. st.write would render
    # the report as Markdown and merge its lines into a single paragraph.
    st.text(error_message)

    # Reformat the generated code
    st.subheader('Error Correction')
    formatted_code = PythonCodeFormatter(output_code).format()
    st.write('Code after correction:')
    # Show code as code: through st.write/Markdown, '#' comments would become
    # headings and '_'/'*' would turn into emphasis.
    st.code(formatted_code, language='python')

        


if __name__ == "__main__":
    # Launch the UI when this file is executed as a script.
    main()