File size: 3,585 Bytes
cd63cf6
 
103de27
cd63cf6
103de27
cd63cf6
 
 
 
 
 
 
 
 
785e68c
cd63cf6
785e68c
cd63cf6
 
 
fc12e9a
 
 
 
 
 
 
 
cd63cf6
785e68c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc12e9a
cd63cf6
d0752a6
cd63cf6
 
 
103de27
cd63cf6
785e68c
103de27
 
 
 
 
 
 
 
 
785e68c
103de27
 
 
 
 
 
 
 
 
 
 
 
fc12e9a
103de27
cd63cf6
 
 
 
 
103de27
fc12e9a
cd63cf6
103de27
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from src.get_stock_code import StockCode_getter
from main import main, init_csv
import csv
import openpyxl
from src.myLogger import set_logger

# Module-wide logger obtained from the project's set_logger helper ("my_app", INFO level);
# shared by company_scoring and company_list_scoring below.
logger = set_logger("my_app", level="INFO")


# Case: a single company name was entered.
def company_scoring(company_name):
    """Classify and score one company and format the result as CSV-like text.

    Looks up the company's stock code, runs the project ``main`` pipeline and
    builds a two-line string: a header row (chosen by the returned group) and
    a data row.

    Args:
        company_name: Company name typed into the Gradio textbox.

    Returns:
        An 8-tuple matching the Gradio outputs: the formatted text followed by
        seven ``None`` values (no CSV files are produced in single-company mode).
    """
    logger.info(f'企業名: {company_name}')
    stock_code_getter = StockCode_getter()
    header1_1, header1_2, header2, header3, header4, header5 = init_csv()
    stock_code = stock_code_getter.get_stock_code(company_name)
    group, related_url_list, _, _, scoring_result = main(company_name, stock_code)

    headers_by_group = {
        'Group 1-1': header1_1,
        'Group 1-2': header1_2,
        'Group 2': header2,
        'Group 3': header3,
        'Group 4': header4,
    }
    # Any other label (including 'Group 5') falls back to the Group 5 header.
    header = ', '.join(headers_by_group.get(group, header5))

    # NOTE(review): related_url_list is concatenated directly, so it must be a
    # str here despite its name — confirm against main().
    if group == 'Group 1-2' or group == 'Group 5':
        # These groups are reported without scores, so scoring_result is not
        # touched (it may be empty/absent for unscored groups).
        result = company_name + ',' + group + ',' + related_url_list
    else:
        scoring_result_text = ''.join(str(score) + ',' for score in scoring_result)
        # Index 5 is deliberately left out of the total (same as the original
        # 0..4 + 6..9 loops). NOTE(review): confirm why score #5 is excluded.
        total = sum(scoring_result[i] for i in (0, 1, 2, 3, 4, 6, 7, 8, 9))
        result = (company_name + ',' + group + ',' + scoring_result_text
                  + str(total) + ',' + related_url_list)

    output = header + '\n' + result
    return output, None, None, None, None, None, None, None


# Case: multiple companies were supplied via an uploaded workbook.
def company_list_scoring(file):
    """Classify every company listed in column A of an uploaded workbook.

    Each non-blank cell in the first column of the active sheet is treated as
    a company name. For each one the stock code is looked up, the project
    ``main`` pipeline is run, and a row
    ``[company_name, group, related, unrelated, other]`` is appended to
    ``data/Classification_result.csv``.

    Args:
        file: The uploaded workbook as passed by ``gr.File`` (anything
            ``openpyxl.load_workbook`` accepts; .xlsx format, not CSV).

    Returns:
        An 8-tuple matching the Gradio outputs: a status message followed by
        the paths of the classification CSV and the six per-group CSVs.
    """
    stock_code_getter = StockCode_getter()
    # Headers are not needed here; called presumably to (re)initialise the
    # output CSV files — NOTE(review): confirm init_csv's side effects.
    init_csv()
    path = 'data/Classification_result.csv'

    workbook = openpyxl.load_workbook(file)
    sheet = workbook.active
    names = sheet.iter_rows(min_row=1, max_row=sheet.max_row,
                            min_col=1, max_col=1, values_only=True)
    for (company_name,) in names:
        if isinstance(company_name, str):
            company_name = company_name.strip()
        # Empty cells (value None) are common when max_row counts formatted
        # but empty rows; skip them instead of running the pipeline on None.
        if company_name is None or company_name == '':
            continue
        logger.info(f'User: {company_name}')
        stock_code = stock_code_getter.get_stock_code(company_name)
        group, related_url_list, unrelated_url_list, other_url_list, _ = main(company_name, stock_code)
        data = [company_name, group, related_url_list, unrelated_url_list, other_url_list]
        # Reopened per row so each row is on disk before the next main() call.
        with open(path, mode='a', newline='') as csv_file:
            csv.writer(csv_file).writerow(data)

    return ('csvをダウンロードしてください', path,
            'data/Group1-1_result.csv', 'data/Group1-2_result.csv',
            'data/Group2_result.csv', 'data/Group3_result.csv',
            'data/Group4_result.csv', 'data/Group5_result.csv')


def answer(input, file):
    """Gradio entry point: route to single- or multi-company scoring.

    Args:
        input: Text from the company-name textbox (may be empty or None).
        file: The uploaded workbook, or None when nothing was uploaded.

    Returns:
        The 8-tuple expected by the interface outputs (text + 7 file slots).
    """
    company_name = (input or '').strip()
    if company_name:
        return company_scoring(company_name)
    if file is None:
        # Neither a name nor a file: tell the user instead of letting
        # openpyxl fail on a missing upload.
        return ('企業名を入力するか、Excelファイルをアップロードしてください',
                None, None, None, None, None, None, None)
    return company_list_scoring(file)


# Gradio UI: a company-name textbox plus an optional workbook upload; outputs
# are a text box and seven downloadable CSV files (see company_list_scoring).
# The upload is read with openpyxl.load_workbook, which does not accept CSV,
# so the label advertises an .xlsx file.
demo = gr.Interface(fn=answer,
                    inputs=["textbox",
                            gr.File(label="input.xlsx")],
                    outputs=["textbox",
                             gr.File(label="Classification_result.csv"),
                             gr.File(label="Group1-1_result.csv"),
                             gr.File(label="Group1-2_result.csv"),
                             gr.File(label="Group2_result.csv"),
                             gr.File(label="Group3_result.csv"),
                             gr.File(label="Group4_result.csv"),
                             gr.File(label="Group5_result.csv")])

# NOTE: launches at import time (no __main__ guard), listening on all interfaces.
demo.launch(show_api=False, server_name="0.0.0.0")