import torch
import gradio as gr
from typing import List, Optional
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Initialize embeddings and the FAISS database
def load_embeddings():
    return HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")

def load_faiss_db():
    embeddings = load_embeddings()
    # The index is pickled on disk, so FAISS requires this flag; acceptable
    # here because "ipc_embed_db" is our own artifact, not an untrusted download.
    return FAISS.load_local("ipc_embed_db", embeddings, allow_dangerous_deserialization=True)
# Load embeddings and FAISS database
embeddings = load_embeddings()
db = load_faiss_db()
db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
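
# Optional sanity check, not wired into the app: prints the top chunks the
# retriever returns so you can confirm the index loaded correctly. The
# sample query is illustrative only.
def preview_retrieval(query="Punishment for murder under IPC Section 302"):
    for doc in db_retriever.get_relevant_documents(query):
        print(doc.page_content[:200], "\n---")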
# Define prompt template
prompt_template = """
<s>[INST]
As a legal chatbot specializing in the Indian Penal Code, you are tasked with providing highly accurate and contextually appropriate responses. Ensure your answers meet these criteria:
- Respond in a bullet-point format to clearly delineate distinct aspects of the legal query.
- Each point should accurately reflect the breadth of the legal provision in question, avoiding over-specificity unless directly relevant to the user's query.
- Clarify the general applicability of the legal rules or sections mentioned, highlighting any common misconceptions or frequently misunderstood aspects.
- Limit responses to essential information that directly addresses the user's question, providing concise yet comprehensive explanations.
- Avoid assuming specific contexts or details not provided in the query, focusing on delivering universally applicable legal interpretations unless otherwise specified.
- Conclude with a brief summary that captures the essence of the legal discussion and corrects any common misinterpretations related to the topic.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
- [Detail the first key aspect of the law, ensuring it reflects general application]
- [Provide a concise explanation of how the law is typically interpreted or applied]
- [Correct a common misconception or clarify a frequently misunderstood aspect]
- [Detail any exceptions to the general rule, if applicable]
- [Include any additional relevant information that directly relates to the user's query]
[/INST]
"""
prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question', 'chat_history'])
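# For reference, the combine-docs step renders the template with
# PromptTemplate.format; all three values below are placeholders:
#   prompt.format(context="<retrieved IPC text>", chat_history="",
#                 question="What is Section 420?")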
# Load the InLegalBERT tokenizer and model with a sequence-classification head
tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
model = AutoModelForSequenceClassification.from_pretrained("law-ai/InLegalBERT")
model.eval()  # inference only
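# Caveat: InLegalBERT is an encoder-only checkpoint, so transformers will
# warn that the classification head above is newly (randomly) initialized
# and model.config.num_labels defaults to 2; until that head is fine-tuned,
# its predictions are effectively arbitrary.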
# Get the model's response for a single question.
# The classification head outputs class logits, not token IDs, so the
# tokenizer cannot decode them into text; return the predicted class index
# as a string instead.
def get_inlegalbert_response(question):
    inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=-1).item()
    return str(predicted_class)
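# Illustrative call (the question is hypothetical); the return value is a
# class index such as "0" or "1", not prose:
#   get_inlegalbert_response("Is theft under IPC Section 378 cognizable?")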
# Wrap the model behind LangChain's LLM interface; a bare callable is not
# accepted by ConversationalRetrievalChain.from_llm. The wrapper uses the
# module-level model and tokenizer via get_inlegalbert_response.
class InLegalBERTWrapper(LLM):
    @property
    def _llm_type(self) -> str:
        return "inlegalbert"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, run_manager=None, **kwargs) -> str:
        return get_inlegalbert_response(prompt)

llm = InLegalBERTWrapper()
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True),
    retriever=db_retriever,
    combine_docs_chain_kwargs={'prompt': prompt},
)
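
# Optional smoke test for the chain (run manually; the question is
# illustrative):
def smoke_test_chain():
    result = qa.invoke({"question": "What does IPC Section 415 define as cheating?"})
    print(result["answer"])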
def extract_answer(full_response):
    # The prompt ends with "ANSWER:", so strip everything up to and including
    # that marker; fall back to the full response if it is absent.
    answer_start = full_response.find("ANSWER:")
    if answer_start != -1:
        answer_start += len("ANSWER:")
        return full_response[answer_start:].strip()
    return full_response
def chat(input_prompt, state):
    messages = state.setdefault("messages", [])
    messages.append({"role": "user", "content": input_prompt})
    result = qa.invoke({"question": input_prompt})
    answer = extract_answer(result["answer"])
    messages.append({"role": "assistant", "content": answer})
    # gr.Chatbot expects (user, assistant) pairs, so zip consecutive turns.
    pairs = [
        (messages[i]["content"], messages[i + 1]["content"])
        for i in range(0, len(messages) - 1, 2)
    ]
    return pairs, state
with gr.Blocks() as demo:
    gr.Markdown("## Stat.ai Legal Assistant")
    chatbot = gr.Chatbot()
    state = gr.State({"messages": []})
    msg = gr.Textbox(placeholder="Ask Stat.ai")

    # Run the chain first, then clear the textbox; clearing it before chat()
    # runs would hand the chain an empty string.
    msg.submit(chat, [msg, state], [chatbot, state]).then(
        lambda: "", None, msg, queue=False
    )
if __name__ == "__main__":
    demo.launch()