File size: 3,481 Bytes
28621b1
 
 
 
 
4ce10ee
e2982b0
3956fff
 
28621b1
 
 
7b04950
28621b1
 
 
 
ccf7fbf
28621b1
 
 
e2982b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28621b1
e2982b0
 
28621b1
 
e2982b0
28621b1
e2982b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0a4cfc
e2982b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0a4cfc
e2982b0
 
28621b1
 
 
e2982b0
 
 
28621b1
6415b25
28621b1
e2982b0
 
 
28621b1
 
 
e2982b0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
import bm25s
from operator import itemgetter
import os
import re
import pandas as pd
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.docstore.document import Document

@st.cache_resource
def load_data(path="cleaned_list.csv"):
    """Build a BM25 retriever over the header-less, single-column CSV at ``path``.

    Every row of the CSV is treated as one document. The retriever is built
    with ``corpus=`` so that ``retrieve`` hands back the document text itself
    (the callers wrap each hit in a ``Document``).

    ``st.cache_resource`` is used rather than ``st.cache_data``: the retriever
    is a heavyweight index object, and ``cache_data`` would pickle it and give
    every caller a freshly deserialized copy, while ``cache_resource`` builds
    it once per process and shares that single instance.

    Args:
        path: CSV file to index (default ``"cleaned_list.csv"``, as before).

    Returns:
        The indexed ``bm25s.BM25`` retriever.

    Raises:
        ValueError: if the CSV has more than one column (the single column
            name assignment below would not match).
    """
    df = pd.read_csv(path, header=None)
    df.columns = ["document"]
    corpus = df["document"].to_list()

    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))
    return retriever

# def extract_hscode(text):
#     match = re.search(r'hs_code:\s*(\d+)', text)
#     if match:
#         return match.group(1)
#     return None

# df2 = pd.read_csv("hscode_main.csv")
# new_col = [len(str(code))for code in df2['hs_code'].to_list()]
# df2['len'] = new_col

# new_hscode = [str(code) for code in df2['hs_code']]

# for i in range(len(new_col)):
#     if new_col[i]==5:
#         new_hscode[i] = '0'+ new_hscode[i]
# df2['hs_code'] = new_hscode
# df2=df2.drop(columns='len')

# if 'retriever' not in st.session_state:
#     st.session_state.retriever = None

# if st.session_state.retriever is None:
#     st.session_state.retriever = load_data()


# sentence = st.text_input("please enter description:")

# if sentence !='':
#     results,_ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=5)
#     doc = [d for d in results]
#     hscodes = [extract_hscode(item) for item in doc[0]]
#     for code in hscodes:
#         if len(code)==5:
#             code = '0'+ code

#         filter_df = df2[df2['hs_code']==code]
#         answer = filter_df['description'].iloc[0]
#         st.write("Hscode:",code)
#         st.write("Description:",answer.lower())

def load_model(model="llama-3.1-70b-versatile", temperature=0):
    """Build the prompt -> Groq chat-model chain that predicts a 6-digit HS code.

    The prompt takes two variables: ``context`` (retrieved reference text) and
    ``description`` (the user's product description).

    Args:
        model: Groq model name passed to ``ChatGroq``. NOTE(review): Groq
            retires model names over time; confirm this one is still served.
        temperature: Sampling temperature passed to ``ChatGroq`` (default 0).

    Returns:
        A LangChain runnable (``prompt | llm``); its ``invoke`` result exposes
        the model's answer via ``.content``.

    Raises:
        RuntimeError: if the ``GROQ_API_KEY`` environment variable is not set.
    """
    # Plain string, not an f-string: the single-brace placeholders are
    # LangChain template variables that are filled in at invoke time.
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
        """
        Extract the appropriate 6-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
        Only return the HS Code as a 6-digit number .
        Example: 123456
        Context: {context}
        Description: {description}
        Answer:
        """
        )
    ])

    # Never hard-code credentials in source. The keys that used to live here
    # were committed to the repository and must be revoked and rotated.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY is not set; provide it in the environment before "
            "starting the app."
        )

    llm = ChatGroq(model=model, temperature=temperature, api_key=api_key)
    return prompt | llm

def process_input(sentence, k=15):
    """Retrieve the top ``k`` BM25 matches for ``sentence`` as Documents.

    Queries the retriever stored in ``st.session_state.retriever``. That
    retriever must already be populated before this is called.

    Args:
        sentence: Free-text product description used as the search query.
        k: Number of documents to retrieve (default 15, as before).

    Returns:
        list[Document]: one ``Document`` per hit, in the order returned by the
        retriever, with the retrieved text as ``page_content``.
    """
    docs, _scores = st.session_state.retriever.retrieve(
        bm25s.tokenize(sentence), k=k
    )
    # ``retrieve`` returns one row of hits per query; exactly one query is sent.
    return [Document(page_content=doc) for doc in docs[0]]
    
# Build the retriever and the LLM chain once per browser session:
# st.session_state persists across Streamlit's top-to-bottom reruns.
if st.session_state.get("retriever") is None:
    st.session_state.retriever = load_data()

if st.session_state.get("chain") is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")

# Skip empty and whitespace-only input rather than sending a blank query to
# the retriever and the LLM.
query = sentence.strip()
if query:
    documents = process_input(query)
    # Feed the prompt the retrieved text itself, one document per line, instead
    # of the repr of a list of Document objects.
    context = "\n".join(doc.page_content for doc in documents)
    hscode = st.session_state.chain.invoke(
        {"context": context, "description": query}
    )
    st.write("answer:", hscode.content)