import gradio as gr
from src.get_stock_code import StockCode_getter
from main import main, init_csv
import csv
import openpyxl
from src.myLogger import set_logger

# Module-level logger named "my_app", created by the project's set_logger
# helper with level "INFO"; used by the request handlers in this file.
logger = set_logger("my_app", level="INFO")


# Case: a single company name was entered
def company_scoring(company_name):
    """Classify and score one company and format the result as CSV-style text.

    Looks up the stock code for ``company_name``, passes both to ``main``,
    and builds a two-line string: the header that belongs to the returned
    group, followed by the result row.

    For 'Group 1-2' and 'Group 5' the row contains only the company name,
    the group and the related URLs. The scores are therefore only read for
    the other groups, so a missing or short ``scoring_result`` for those two
    groups no longer aborts the request.

    Args:
        company_name: Company name typed into the UI text box.

    Returns:
        An 8-tuple shaped like the Gradio outputs: the result text followed
        by seven ``None`` placeholders for the file outputs, which are not
        produced in single-company mode.
    """
    logger.info(f'企業名: {company_name}')
    stock_code_getter = StockCode_getter()
    header1_1, header1_2, header2, header3, header4, header5 = init_csv()
    stock_code = stock_code_getter.get_stock_code(company_name)
    group, related_url_list, _, _, scoring_result = main(company_name, stock_code)

    # Any group name not listed here falls back to the Group 5 header.
    headers_by_group = {
        'Group 1-1': header1_1,
        'Group 1-2': header1_2,
        'Group 2': header2,
        'Group 3': header3,
        'Group 4': header4,
    }
    header = ', '.join(headers_by_group.get(group, header5))

    # NOTE(review): related_url_list is joined as-is, so it is assumed to
    # already be a string -- confirm against main().
    if group in ('Group 1-2', 'Group 5'):
        fields = [company_name, group, related_url_list]
    else:
        # Total over score indices 0-4 and 6-9.
        # NOTE(review): index 5 is left out of the total, as before --
        # confirm that this is intended.
        counted_indices = (*range(0, 5), *range(6, 10))
        total = sum(scoring_result[i] for i in counted_indices)
        fields = [
            company_name,
            group,
            *(str(score) for score in scoring_result),
            str(total),
            related_url_list,
        ]

    output = header + '\n' + ','.join(fields)
    return output, None, None, None, None, None, None, None


# Case: multiple companies were supplied (uploaded file)
def company_list_scoring(file):
    """Classify every company listed in the first column of a workbook.

    The active sheet of ``file`` is read with openpyxl. For each non-empty
    cell in column 1 the stock code is looked up, ``main`` is run, and one
    row (company, group, related/unrelated/other URL lists) is appended to
    ``data/Classification_result.csv``.

    Empty cells are skipped instead of being passed on as a ``None``
    company name.

    Args:
        file: Uploaded workbook, in any form accepted by
            ``openpyxl.load_workbook``.

    Returns:
        An 8-tuple shaped like the Gradio outputs: a status message followed
        by the paths of the seven result CSV files.
    """
    stock_code_getter = StockCode_getter()
    # The headers returned by init_csv are not needed here; the call is kept
    # for whatever setup it performs.
    # NOTE(review): presumably it (re)creates the result CSVs -- confirm.
    init_csv()
    path = 'data/Classification_result.csv'

    workbook = openpyxl.load_workbook(file)
    sheet = workbook.active
    # min_col == max_col == 1, so every yielded row is a 1-tuple.
    for (cell,) in sheet.iter_rows(min_row=1, max_row=sheet.max_row, min_col=1, max_col=1):
        company_name = cell.value
        if company_name is None or not str(company_name).strip():
            continue  # blank cell inside the used range

        logger.info(f'User: {company_name}')
        stock_code = stock_code_getter.get_stock_code(company_name)
        group, related_url_list, unrelated_url_list, other_url_list, _ = main(company_name, stock_code)
        data = [company_name, group, related_url_list, unrelated_url_list, other_url_list]
        # Append after each company so rows already processed are on disk
        # even if a later company fails. A distinct name avoids shadowing
        # the ``file`` parameter.
        with open(path, mode='a', newline='') as csv_file:
            csv.writer(csv_file).writerow(data)

    return ('csvをダウンロードしてください', path,
            'data/Group1-1_result.csv', 'data/Group1-2_result.csv',
            'data/Group2_result.csv', 'data/Group3_result.csv',
            'data/Group4_result.csv', 'data/Group5_result.csv')


def answer(input, file):
    """Dispatch a UI request to single-company or file-based scoring.

    A non-blank company name takes precedence and is scored on its own.
    Otherwise the uploaded file is processed. If neither a name nor a file
    was supplied, a message is returned instead of failing while trying to
    open a missing file.

    Args:
        input: Company name from the text box; may be empty or ``None``.
        file: Uploaded file from the file input, or ``None``.

    Returns:
        An 8-tuple shaped like the Gradio outputs: a text value followed by
        seven file values (paths or ``None``).
    """
    if input and input.strip():
        return company_scoring(input)
    if file is None:
        return ('企業名を入力するか、ファイルをアップロードしてください',
                None, None, None, None, None, None, None)
    return company_list_scoring(file)


# Gradio UI: a text box plus a file upload as inputs; a text box plus one
# file component per result CSV as outputs. The number and order of outputs
# match the 8-tuples returned by the handlers in this file.
# NOTE(review): the upload is labelled "input.csv", but company_list_scoring
# reads it with openpyxl.load_workbook -- confirm the expected file format.
demo = gr.Interface(
    fn=answer,
    inputs=["textbox", gr.File(label="input.csv")],
    outputs=[
        "textbox",
        *(
            gr.File(label=result_label)
            for result_label in (
                "Classification_result.csv",
                "Group1-1_result.csv",
                "Group1-2_result.csv",
                "Group2_result.csv",
                "Group3_result.csv",
                "Group4_result.csv",
                "Group5_result.csv",
            )
        ),
    ],
)

# Launched at import time with the API page disabled and the server bound
# to "0.0.0.0".
demo.launch(server_name="0.0.0.0", show_api=False)