|
from pathlib import Path

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_groq import ChatGroq
|
|
|
""" |
|
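Gradio app that answers questions about the SOP creation research program, drawing on its call-for-proposals guidelines and Q&A, using retrieval-augmented generation (RAG): a local vector store supplies context and a Groq-hosted LLM writes the answer. (Template note: for `huggingface_hub` Inference API support, see https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference — this app does not use that API.)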
|
""" |
|
|
|
# Groq-hosted chat model used to generate answers.
model_name = "llama-3.3-70b-versatile"

# Resolve the vector-store path relative to this script rather than the
# current working directory, so the app starts correctly from any directory.
_APP_DIR = Path(__file__).resolve().parent
# Serialized InMemoryVectorStore (presumably written with
# InMemoryVectorStore.dump — confirm against the indexing script).
_VECTOR_STORE_PATH = _APP_DIR / "sop_vector_store"
# Number of chunks the retriever returns per question.
_TOP_K = 4

# Query embeddings must come from the same model that embedded the stored
# documents, otherwise similarity scores are meaningless.
embeddings = HuggingFaceEmbeddings(
    model_name="pkshatech/GLuCoSE-base-ja",
)

vector_store = InMemoryVectorStore.load(str(_VECTOR_STORE_PATH), embeddings)

retriever = vector_store.as_retriever(search_kwargs={"k": _TOP_K})
|
|
|
|
|
def fetch_response(groq_api_key, user_input):
    """Answer ``user_input`` with RAG over the SOP vector store.

    A Groq chat model is created on every call because the API key is
    supplied by the user through the UI. The module-level ``retriever``
    fetches relevant chunks, which the stuff-documents chain inserts into
    the system prompt's ``{context}`` slot.

    Args:
        groq_api_key: Groq API key entered by the user.
        user_input: The user's question.

    Returns:
        ``[answer, source1, source2]``: the model's answer followed by the
        text of the first two retrieved documents. A source slot is ``""``
        when fewer than two documents were retrieved.

    Raises:
        gr.Error: If the question is empty or whitespace-only, so Gradio shows
            a readable message instead of spending an LLM call on it.
    """
    if not (user_input or "").strip():
        raise gr.Error("Please enter a question.")

    chat = ChatGroq(
        # Strip whitespace that often sneaks in when pasting a key; leave
        # None/"" untouched so ChatGroq's own fallback behaviour is unchanged.
        api_key=groq_api_key.strip() if groq_api_key else groq_api_key,
        model_name=model_name,
    )

    # System prompt (Japanese): "You are a helpful assistant. Answer based on
    # the contents of the manual." Retrieved chunks are filled into {context}.
    system_prompt = (
        "あなたは便利なアシスタントです。"
        "マニュアルの内容から回答してください。"
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )

    question_answer_chain = create_stuff_documents_chain(chat, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    response = rag_chain.invoke({"input": user_input})

    # Show the plain text of the top two retrieved chunks (not the Document
    # repr), padded with "" so the output always fills all three textboxes
    # even when fewer than two documents came back.
    sources = [doc.page_content for doc in response.get("context", [])[:2]]
    sources += [""] * (2 - len(sources))
    return [response["answer"], *sources]
|
|
|
|
|
""" |
|
The UI below is a custom `gr.Blocks` layout rather than a `gr.ChatInterface`. For component and layout customization options, see the Gradio docs (ChatInterface page: https://www.gradio.app/docs/chatinterface).
|
""" |
|
with gr.Blocks() as demo:
    gr.Markdown('''# SOP事業マスター \n
SOP作成研究に関して、公募要領やQAを参考にRAGを使って回答します。
''')

    with gr.Row():
        # The key is a secret: mask it on screen instead of showing plain text.
        api_key = gr.Textbox(label="Groq API key", type="password")

    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="User Input")
            submit = gr.Button("Submit")
        answer = gr.Textbox(label="Answer")

    with gr.Row():
        with gr.Column():
            source1 = gr.Textbox(label="回答ソース1")
        with gr.Column():
            source2 = gr.Textbox(label="回答ソース2")

    # Same handler wiring for the button and for pressing Enter in the
    # question box; fetch_response returns [answer, source1, source2].
    _handler_io = dict(
        inputs=[api_key, user_input],
        outputs=[answer, source1, source2],
    )
    submit.click(fetch_response, **_handler_io)
    user_input.submit(fetch_response, **_handler_io)


if __name__ == "__main__":
    demo.launch()
|
|