tien314 committed
Commit cadbe26 · verified · 1 Parent(s): 2b4d4e2

Create app.py

Files changed (1)
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+import streamlit as st
+import os
+import bm25s
+from bm25s.hf import BM25HF
+from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
+from langchain_core.documents import Document
+from langchain_groq import ChatGroq
+
+
+@st.cache_resource
+def load_data():
+    # Load the prebuilt BM25 index and its corpus from the Hugging Face Hub.
+    retriever = BM25HF.load_from_hub("tien314/hs8", load_corpus=True, mmap=True)
+    return retriever
+
+
+def load_model():
+    prompt = ChatPromptTemplate.from_messages([
+        HumanMessagePromptTemplate.from_template(
+            """
+            Extract the appropriate 8-digit HS Code based on the product description and the retrieved documents by thoroughly analyzing their details and utilizing a reliable and up-to-date HS Code database for accurate results.
+            Only return the HS Code as an 8-digit number.
+            Example: 12345678
+            Context: {context}
+            Description: {description}
+            Answer:
+            """
+        )
+    ])
+
+    # device = "cuda" if torch.cuda.is_available() else "cpu"
+    # llm = OllamaLLM(model="gemma2", temperature=0, device=device)
+
+    # Read the Groq API key from the environment (set GROQ_API_KEY before running).
+    api_key = os.environ.get("GROQ_API_KEY")
+    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key)
+    chain = prompt | llm
+    return chain
+
+
+def process_input(sentence):
+    # Retrieve the 15 best-matching corpus entries and wrap them as Documents.
+    docs, _ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=15)
+    documents = []
+    for doc in docs[0]:
+        documents.append(Document(doc['text']))
+    return documents
+
+
+if 'retriever' not in st.session_state:
+    st.session_state.retriever = None
+
+if 'chain' not in st.session_state:
+    st.session_state.chain = None
+
+if st.session_state.retriever is None:
+    st.session_state.retriever = load_data()
+
+if st.session_state.chain is None:
+    st.session_state.chain = load_model()
+
+sentence = st.text_input("Please enter a product description:")
+
+if sentence != '':
+    documents = process_input(sentence)
+    hscode = st.session_state.chain.invoke({'context': documents, 'description': sentence})
+    st.write("Answer:", hscode.content)
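app.py is a small Streamlit app: it BM25-retrieves candidate entries from the tien314/hs8 index on the Hugging Face Hub and asks a Groq-hosted Llama 3.1 model to return the 8-digit HS code for a product description. A minimal sketch of exercising just the retrieval step outside Streamlit, assuming the packages imported above are installed; the query string is a made-up placeholder:

import bm25s
from bm25s.hf import BM25HF

# Load the same index the app uses and print the top 5 corpus entries
# for an example query.
retriever = BM25HF.load_from_hub("tien314/hs8", load_corpus=True, mmap=True)
docs, scores = retriever.retrieve(bm25s.tokenize("stainless steel kitchen sink"), k=5)
for doc, score in zip(docs[0], scores[0]):
    print(round(float(score), 2), doc['text'])

With GROQ_API_KEY exported, the full app can be launched with: streamlit run app.py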