# tien314's picture
# Update app.py
# d0a4cfc verified
import streamlit as st
import bm25s
from operator import itemgetter
import os
import re
import pandas as pd
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.docstore.document import Document
@st.cache_data
def load_data():
    """Load the HS-code corpus from ``cleaned_list.csv`` and build a BM25 retriever.

    The CSV is expected to have a single, headerless column where each row is
    one document (presumably "hs_code: ... description" text — confirm against
    the data file). Cached by Streamlit so the index is built only once per
    session.

    Returns:
        bm25s.BM25: a retriever already indexed over the corpus.
    """
    df = pd.read_csv("cleaned_list.csv", header=None)
    df.columns = ['document']
    # to_list() already yields a plain list; no comprehension needed.
    corpus = df['document'].to_list()
    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))
    return retriever
# def extract_hscode(text):
# match = re.search(r'hs_code:\s*(\d+)', text)
# if match:
# return match.group(1)
# return None
# df2 = pd.read_csv("hscode_main.csv")
# new_col = [len(str(code))for code in df2['hs_code'].to_list()]
# df2['len'] = new_col
# new_hscode = [str(code) for code in df2['hs_code']]
# for i in range(len(new_col)):
# if new_col[i]==5:
# new_hscode[i] = '0'+ new_hscode[i]
# df2['hs_code'] = new_hscode
# df2=df2.drop(columns='len')
# if 'retriever' not in st.session_state:
# st.session_state.retriever = None
# if st.session_state.retriever is None:
# st.session_state.retriever = load_data()
# sentence = st.text_input("please enter description:")
# if sentence !='':
# results,_ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=5)
# doc = [d for d in results]
# hscodes = [extract_hscode(item) for item in doc[0]]
# for code in hscodes:
# if len(code)==5:
# code = '0'+ code
# filter_df = df2[df2['hs_code']==code]
# answer = filter_df['description'].iloc[0]
# st.write("Hscode:",code)
# st.write("Description:",answer.lower())
def load_model():
    """Build the prompt -> Groq-LLM chain used to extract a 6-digit HS code.

    The prompt takes two template variables:
      * ``context``     -- retrieved documents supporting the answer
      * ``description`` -- the user-entered product description

    Returns:
        A runnable LangChain pipeline (prompt piped into ChatGroq) whose
        ``invoke({'context': ..., 'description': ...})`` result carries the
        model reply in ``.content``.
    """
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            f"""
            Extract the appropriate 6-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
            Only return the HS Code as a 6-digit number .
            Example: 123456
            Context: {{context}}
            Description: {{description}}
            Answer:
            """
        )
    ])
    # SECURITY FIX: API keys must never be hard-coded in source (the previous
    # literal keys are compromised and should be revoked). Read the key from
    # the environment instead; `os` is already imported at the top of the file.
    api_key = os.environ.get("GROQ_API_KEY")
    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key)
    chain = prompt | llm
    return chain
def process_input(sentence):
    """Retrieve the top-15 BM25 matches for *sentence* as Document objects.

    Uses the retriever stored in ``st.session_state.retriever``; the second
    return value of ``retrieve`` (scores) is discarded.
    """
    hits, _ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=15)
    # hits[0] holds the result set for the single query.
    return [Document(hit) for hit in hits[0]]
# --- one-time initialisation of session-scoped resources -------------------
if st.session_state.get('retriever') is None:
    st.session_state.retriever = load_data()
if st.session_state.get('chain') is None:
    st.session_state.chain = load_model()

# --- UI: ask for a product description and emit the model's HS-code answer -
sentence = st.text_input("please enter description:")
if sentence != '':
    documents = process_input(sentence)
    hscode = st.session_state.chain.invoke({'context': documents, 'description': sentence})
    st.write("answer:", hscode.content)