import streamlit as st
import pandas as pd

# 모델 준비하기
from transformers import RobertaForSequenceClassification, AutoTokenizer
import numpy as np
import pandas as pd
import torch
import os

# Optional Streamlit theme settings kept for reference. They are written in TOML
# syntax, so they would presumably live under [theme] in .streamlit/config.toml
# rather than in this script — TODO confirm whether they are still wanted.
# [theme]
# base="dark"
# primaryColor="purple"

# Page title: "Korean Standard Industrial Classification automatic coding service".
PAGE_TITLE = '한국표준산업분류 자동코딩 서비스'
st.header(PAGE_TITLE)

# Cache the loaded model so it is not reloaded on every Streamlit rerun.
# Use the *resource* cache (st.cache_resource on newer Streamlit, falling back to
# st.experimental_singleton on older releases). The data cache st.experimental_memo
# pickles the return value and hands every rerun a fresh copy, which is slow and
# wasteful for a large torch model and tokenizer.
_cache_model = getattr(st, 'cache_resource', None) or st.experimental_singleton


@_cache_model
def md_loading():
    """Load the tokenizer, the fine-tuned classifier and the lookup tables.

    Returns:
        tuple: ``(tokenizer, model, label_tbl, loc_tbl)`` where

        - ``tokenizer`` is the ``klue/roberta-base`` tokenizer;
        - ``model`` is a ``RobertaForSequenceClassification`` with 495 labels whose
          weights are loaded from ``./upsampling_20.bin`` onto the CPU, in eval mode;
        - ``label_tbl`` is the array stored in ``./label_table.npy``. It is indexed by
          the model's class index, presumably mapping class index -> classification
          code (TODO confirm);
        - ``loc_tbl`` is the DataFrame read from ``./kisc_table.csv``. The app reads
          its '코드' (code) and '항목명' (item name) columns.
    """
    tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
    model = RobertaForSequenceClassification.from_pretrained('klue/roberta-base', num_labels=495)

    model_checkpoint = 'upsampling_20.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)

    # map_location lets the checkpoint load on a CPU-only host regardless of the
    # device it was saved from.
    state_dict = torch.load(output_model_file, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference only: make the eval mode (dropout disabled) explicit.
    model.eval()

    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')

    return tokenizer, model, label_tbl, loc_tbl

# Load the tokenizer, model and lookup tables (cached by md_loading's decorator).
tokenizer, model, label_tbl, loc_tbl = md_loading()


# Text boxes for the five descriptive fields, created in display order.
# Commas are stripped from every answer because ', ' is the field separator used
# in md_input below, and the button handler splits md_input on ','.
business, business_work, work_department, work_position, what_do_i = [
    st.text_input(field_label).replace(',', '')
    for field_label in ('사업체명', '사업체 하는일', '근무부서', '직책', '내가 하는 일')
]

# The single string fed to the model: the five fields joined with ', '.
md_input = ', '.join((business, business_work, work_department, work_position, what_do_i))

# 버튼
if st.button('확인'):
    ## 버튼 클릭 시 수행사항
    ### 모델 실행
    query_tokens = md_input.split(',')

    input_ids = np.zeros(shape=[1, 64])
    attention_mask = np.zeros(shape=[1, 64])

    seq = '[CLS] '
    try:
        for i in range(5):
            seq += query_tokens[i] + ' '
    except:
        None

    tokens = tokenizer.tokenize(seq)
    ids = tokenizer.convert_tokens_to_ids(tokens)

    length = len(ids)
    if length > 64:
        length = 64

    for i in range(length):
        input_ids[0, i] = ids[i]
        attention_mask[0, i] = 1

    input_ids = torch.from_numpy(input_ids).type(torch.long)
    attention_mask = torch.from_numpy(attention_mask).type(torch.long)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
    logits = outputs.logits

    # # 단독 예측 시
    # arg_idx = torch.argmax(logits, dim=1)
    # print('arg_idx:', arg_idx)

    # num_ans = label_tbl[arg_idx]
    # str_ans = loc_tbl['항목명'][loc_tbl['코드'] == num_ans].values

    # 상위 k번째까지 예측 시
    k = 10
    topk_idx = torch.topk(logits.flatten(), k).indices    

    num_ans_topk = label_tbl[topk_idx]
    str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == k] for k in num_ans_topk]

    # print(num_ans, str_ans)
    # print(num_ans_topk)

    # print('사업체명:', query_tokens[0])
    # print('사업체 하는일:', query_tokens[1])
    # print('근무부서:', query_tokens[2])
    # print('직책:', query_tokens[3])
    # print('내가 하는일:', query_tokens[4])
    # print('산업코드 및 분류:', num_ans, str_ans)

    # ans = ''
    # ans1, ans2, ans3 = '', '', ''

    ## 모델 결과값 출력
    # st.write("산업코드 및 분류:", num_ans, str_ans[0])
    # st.write("세분류 코드")
    # for i in range(k):
    #     st.write(str(i+1) + '순위:', num_ans_topk[i], str_ans_topk[i].iloc[0])

    # print(num_ans)
    # print(str_ans, type(str_ans))

    str_ans_topk_list = []
    for i in range(k):
        str_ans_topk_list.append(str_ans_topk[i].iloc[0])

    # print(str_ans_topk_list)

    ans_topk_df = pd.DataFrame({
        'NO': range(1, k+1),
        '세분류 코드': num_ans_topk,
        '세분류 명칭': str_ans_topk_list
    })
    ans_topk_df = ans_topk_df.set_index('NO')

    st.dataframe(ans_topk_df)