# Streamlit app: HS-code lookup — BM25 retrieval over a product-description
# corpus, with a Groq-hosted LLM extracting the 6-digit Harmonized System code.
import os
import re
from operator import itemgetter

import bm25s
import pandas as pd
import streamlit as st
from langchain.docstore.document import Document
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_groq import ChatGroq
def load_data():
    """Build and return a BM25 retriever over the HS-code document corpus.

    Reads one document per row from ``cleaned_list.csv`` (the file has no
    header row), then tokenizes and indexes the corpus so callers can run
    lexical retrieval against free-text product descriptions.

    Returns:
        bm25s.BM25: an indexed retriever whose ``retrieve`` method yields
        the original document strings (the corpus is attached so results
        come back as text, not token ids).
    """
    df = pd.read_csv("cleaned_list.csv", header=None)
    df.columns = ['document']
    # to_list() already yields a plain list of strings; no copy needed.
    corpus = df['document'].to_list()
    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))
    return retriever
# NOTE: the legacy v1 flow (BM25 retrieval + regex hs_code extraction +
# zero-padded lookup against hscode_main.csv) was removed; it is superseded
# by the retrieval + LLM chain below.
def load_model():
    """Build the prompt → LLM chain that maps a description to an HS code.

    Returns:
        A LangChain runnable; invoke it with
        ``{'context': <retrieved documents>, 'description': <user text>}``
        and read the 6-digit code from the response's ``.content``.
    """
    # Plain (non-f) string: {context} and {description} are LangChain
    # template variables, so no brace escaping is needed.
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """
            Extract the appropriate 6-digit HS Code based on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
            Only return the HS Code as a 6-digit number.
            Example: 123456
            Context: {context}
            Description: {description}
            Answer:
            """
        )
    ])
    # SECURITY: API keys must never be hard-coded in source (the previously
    # committed keys should be revoked). Read from the environment instead;
    # on Streamlit Cloud, set GROQ_API_KEY via the app's secrets.
    api_key = os.environ.get("GROQ_API_KEY")
    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key)
    chain = prompt | llm
    return chain
def process_input(sentence):
    """Retrieve the top-15 BM25 matches for *sentence* as LangChain Documents.

    Uses the retriever cached in ``st.session_state``; each raw matched
    string is wrapped in a ``Document`` so it can feed the prompt's context.
    """
    hits, _ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=15)
    # hits[0] holds the matches for the single query we issued.
    return [Document(hit) for hit in hits[0]]
# Build the retriever and LLM chain once per session: Streamlit re-runs this
# script on every interaction, so cache both in session_state to avoid
# re-indexing the corpus / recreating the model on each rerun.
if st.session_state.get('retriever') is None:
    st.session_state.retriever = load_data()
if st.session_state.get('chain') is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")
if sentence:  # empty string means no input yet
    documents = process_input(sentence)
    hscode = st.session_state.chain.invoke({'context': documents, 'description': sentence})
    st.write("answer:", hscode.content)