yyuri's picture
Upload app.py
fc12e9a verified
import gradio as gr
from src.get_stock_code import StockCode_getter
from main import main, init_csv
import csv
import openpyxl
from src.myLogger import set_logger
# Application-wide logger built by the project's set_logger helper
# ("my_app", INFO level); company_scoring and company_list_scoring log through it.
logger = set_logger("my_app", level="INFO")
# Single-company input: classify and score the one company typed into the textbox.
def company_scoring(company_name):
    """Classify and score a single company and format the result as CSV-style text.

    Args:
        company_name: Company name entered in the Gradio textbox.

    Returns:
        An 8-tuple matching the Gradio outputs: the formatted text (a header line
        followed by one result line) and seven ``None`` placeholders for the file
        outputs, which are only populated by the multi-company path.
    """
    logger.info(f'企業名: {company_name}')
    stock_code_getter = StockCode_getter()
    header1_1, header1_2, header2, header3, header4, header5 = init_csv()
    stock_code = stock_code_getter.get_stock_code(company_name)
    group, related_url_list, _, _, scoring_result = main(company_name, stock_code)

    # Header columns per group label returned by main(); any other label
    # (including 'Group 5') falls back to header5.
    headers_by_group = {
        'Group 1-1': header1_1,
        'Group 1-2': header1_2,
        'Group 2': header2,
        'Group 3': header3,
        'Group 4': header4,
    }
    header = ', '.join(headers_by_group.get(group, header5))

    if group in ('Group 1-2', 'Group 5'):
        # These groups are reported without scores, so scoring_result is not
        # read here; this avoids indexing it when it may not hold 10 entries.
        result = company_name + ',' + group + ',' + related_url_list
    else:
        # Every individual score followed by a comma.
        scoring_result_text = ''.join(str(score) + ',' for score in scoring_result)
        # NOTE(review): the total deliberately mirrors the original indices
        # 0-4 and 6-9, i.e. index 5 is excluded -- confirm this is intended.
        total = sum(scoring_result[i] for i in (*range(0, 5), *range(6, 10)))
        result = (company_name + ',' + group + ',' + scoring_result_text
                  + str(total) + ',' + related_url_list)

    output = header + '\n' + result
    return output, None, None, None, None, None, None, None
# Multiple-company input: classify every company listed in an uploaded workbook.
def company_list_scoring(file):
    """Classify each company named in column A of the uploaded Excel workbook.

    One row per company is appended to ``data/Classification_result.csv``:
    name, group, related URLs, unrelated URLs, other URLs (as returned by main()).

    Args:
        file: The uploaded workbook, passed straight to ``openpyxl.load_workbook``.

    Returns:
        An 8-tuple for the Gradio outputs: a status message followed by the
        paths of the classification CSV and the six per-group result CSVs.
    """
    stock_code_getter = StockCode_getter()
    # Called for its effect on the CSV files; the returned headers are unused here.
    init_csv()
    path = 'data/Classification_result.csv'
    workbook = openpyxl.load_workbook(file)
    sheet = workbook.active
    for (company_name,) in sheet.iter_rows(min_col=1, max_col=1, values_only=True):
        # Blank cells (e.g. trailing empty rows) carry no company to look up.
        if company_name is None or str(company_name).strip() == '':
            continue
        logger.info(f'User: {company_name}')
        stock_code = stock_code_getter.get_stock_code(company_name)
        group, related_url_list, unrelated_url_list, other_url_list, _ = main(company_name, stock_code)
        data = [company_name, group, related_url_list, unrelated_url_list, other_url_list]
        # Reopened per row so each result is flushed to disk as soon as it is ready.
        # Named csv_file so the `file` parameter is not shadowed.
        with open(path, mode='a', newline='') as csv_file:
            csv.writer(csv_file).writerow(data)
    return ('csvをダウンロードしてください', path,
            'data/Group1-1_result.csv', 'data/Group1-2_result.csv',
            'data/Group2_result.csv', 'data/Group3_result.csv',
            'data/Group4_result.csv', 'data/Group5_result.csv')
def answer(input, file):
    """Route a Gradio submission to the single- or multi-company path.

    Args:
        input: Text from the company-name textbox; empty means "use the file".
        file: The uploaded workbook, or None when nothing was uploaded.

    Returns:
        The 8-tuple expected by the interface's outputs (text + seven files).
    """
    if not (input and input.strip()):
        if file is None:
            # Nothing to process: avoid calling openpyxl.load_workbook(None).
            return ('企業名を入力するか、ファイルをアップロードしてください',
                    None, None, None, None, None, None, None)
        return company_list_scoring(file)
    return company_scoring(input)
# Gradio UI: a company-name textbox plus an Excel upload for batch mode.
# answer() always returns 8 values, one per output component below.
demo = gr.Interface(fn=answer,
                    inputs=["textbox",
                            # The upload is read with openpyxl.load_workbook, which
                            # handles Excel workbooks (.xlsx), not CSV files.
                            gr.File(label="input.xlsx")],
                    outputs=["textbox",
                             gr.File(label="Classification_result.csv"),
                             gr.File(label="Group1-1_result.csv"),
                             gr.File(label="Group1-2_result.csv"),
                             gr.File(label="Group2_result.csv"),
                             gr.File(label="Group3_result.csv"),
                             gr.File(label="Group4_result.csv"),
                             gr.File(label="Group5_result.csv")])
# Binds to all network interfaces; the auto-generated API page is hidden.
demo.launch(show_api=False, server_name="0.0.0.0")