Spaces:
Sleeping
Sleeping
import gradio as gr | |
from src.get_stock_code import StockCode_getter | |
from main import main, init_csv | |
import csv | |
import openpyxl | |
from src.myLogger import set_logger | |
# Module-wide logger shared by the request handlers in this file.
_LOGGER_NAME = "my_app"
logger = set_logger(_LOGGER_NAME, level="INFO")
# Groups whose result row carries no score columns (only name, group and URLs).
_UNSCORED_GROUPS = frozenset({'Group 1-2', 'Group 5'})
# Positions in scoring_result that are summed into the total; index 5 is
# skipped. NOTE(review): confirm that excluding index 5 is intentional.
_TOTAL_SCORE_INDICES = (0, 1, 2, 3, 4, 6, 7, 8, 9)


def _format_scores(scoring_result):
    """Render per-item scores as a CSV fragment and compute their total.

    Args:
        scoring_result: indexable sequence of numeric scores with at least
            10 entries; a shorter sequence raises ``IndexError``.

    Returns:
        ``(text, total)`` where ``text`` is every score converted with ``str``
        and followed by a comma (e.g. ``'1,2,'``), and ``total`` is ``str`` of
        the sum of the entries at ``_TOTAL_SCORE_INDICES``.
    """
    text = ''.join(str(score) + ',' for score in scoring_result)
    total = sum(scoring_result[i] for i in _TOTAL_SCORE_INDICES)
    return text, str(total)


# Handles the case where a single company name is entered.
def company_scoring(company_name):
    """Classify and score one company and format the result for the UI.

    Looks up the company's stock code, runs ``main`` on it, and renders two
    lines of comma-separated text: the header for the company's group, then
    the company's data row.

    Args:
        company_name: company name typed by the user.

    Returns:
        An 8-tuple matching the Gradio outputs: the result text followed by
        seven ``None`` values (no files are produced in single-company mode).
    """
    logger.info(f'企業名: {company_name}')
    stock_code_getter = StockCode_getter()
    header1_1, header1_2, header2, header3, header4, header5 = init_csv()
    stock_code = stock_code_getter.get_stock_code(company_name)
    group, related_url_list, _, _, scoring_result = main(company_name, stock_code)

    # Any group not listed here falls back to the Group 5 header.
    header = {
        'Group 1-1': header1_1,
        'Group 1-2': header1_2,
        'Group 2': header2,
        'Group 3': header3,
        'Group 4': header4,
    }.get(group, header5)
    header = ', '.join(header)

    if group in _UNSCORED_GROUPS:
        result = company_name + ',' + group + ',' + related_url_list
    else:
        # Scores are formatted only when the row actually includes them, so
        # an empty/short scoring_result for an unscored group cannot raise.
        scoring_result_text, total = _format_scores(scoring_result)
        result = (company_name + ',' + group + ',' + scoring_result_text
                  + total + ',' + related_url_list)
    return header + '\n' + result, None, None, None, None, None, None, None
# Destination of the multi-company classification summary (one row per company).
_CLASSIFICATION_RESULT_PATH = 'data/Classification_result.csv'


# Handles the case where multiple companies are supplied in a workbook.
def company_list_scoring(file):
    """Classify every company listed in column A of an uploaded workbook.

    For each non-blank cell in column A of the active sheet, looks up the
    stock code, runs ``main``, and appends ``[company_name, group,
    related_url_list, unrelated_url_list, other_url_list]`` as a CSV row to
    ``_CLASSIFICATION_RESULT_PATH``.

    Args:
        file: the uploaded workbook, as accepted by ``openpyxl.load_workbook``
            (a path or file-like object of an Excel workbook).

    Returns:
        An 8-tuple matching the Gradio outputs: a status message followed by
        the paths of the seven result CSV files.
    """
    stock_code_getter = StockCode_getter()
    # Return values are unused here; presumably init_csv is called for its
    # side effects on the result files — confirm against src/main.
    _, _, _, _, _, _ = init_csv()
    workbook = openpyxl.load_workbook(file)
    sheet = workbook.active
    for (cell,) in sheet.iter_rows(min_row=1, max_row=sheet.max_row,
                                   min_col=1, max_col=1):
        company_name = cell.value
        # Skip empty cells (e.g. blank or formatted-only rows counted in
        # max_row) instead of looking up a company called None / ''.
        if company_name is None or (
                isinstance(company_name, str) and not company_name.strip()):
            continue
        logger.info(f'User: {company_name}')
        stock_code = stock_code_getter.get_stock_code(company_name)
        group, related_url_list, unrelated_url_list, other_url_list, _ = main(
            company_name, stock_code)
        data = [company_name, group, related_url_list,
                unrelated_url_list, other_url_list]
        # Append per company so each finished row is on disk as soon as it
        # is classified; a distinct name avoids shadowing the `file` argument.
        with open(_CLASSIFICATION_RESULT_PATH, mode='a', newline='') as csv_file:
            csv.writer(csv_file).writerow(data)
    return ('csvをダウンロードしてください', _CLASSIFICATION_RESULT_PATH,
            'data/Group1-1_result.csv', 'data/Group1-2_result.csv',
            'data/Group2_result.csv', 'data/Group3_result.csv',
            'data/Group4_result.csv', 'data/Group5_result.csv')
def answer(input, file):
    """Route a Gradio submission to single- or multi-company scoring.

    Args:
        input: text typed into the textbox. A non-blank value (after
            stripping whitespace) selects single-company mode. The name
            shadows the builtin but is kept for caller compatibility.
        file: uploaded workbook, used when ``input`` is blank.

    Returns:
        The 8-tuple produced by ``company_scoring`` or
        ``company_list_scoring``.

    Raises:
        gr.Error: if neither a company name nor a file was provided, so the
            UI shows a clear message instead of company_list_scoring passing
            None to openpyxl.load_workbook.
    """
    company_name = (input or '').strip()
    if company_name:
        return company_scoring(company_name)
    if file is None:
        raise gr.Error('企業名を入力するか、ファイルをアップロードしてください')
    return company_list_scoring(file)
# Labels of the downloadable result files, in the same order as the file
# paths returned after the status message by company_list_scoring.
_RESULT_FILE_LABELS = (
    "Classification_result.csv",
    "Group1-1_result.csv",
    "Group1-2_result.csv",
    "Group2_result.csv",
    "Group3_result.csv",
    "Group4_result.csv",
    "Group5_result.csv",
)

# Kept at module level so tooling that imports this file (e.g. Gradio's
# reload mode) can find `demo`.
demo = gr.Interface(
    fn=answer,
    inputs=[
        "textbox",
        # The upload is read with openpyxl.load_workbook, which reads Excel
        # workbooks (.xlsx), not CSV files — label it accordingly.
        gr.File(label="input.xlsx"),
    ],
    outputs=["textbox"] + [gr.File(label=label) for label in _RESULT_FILE_LABELS],
)

# Launch only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    demo.launch(show_api=False, server_name="0.0.0.0")