import os
import time
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.language_models.llms import LLM
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Initialize embeddings and FAISS database
def load_embeddings():
    return HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")

def load_faiss_db():
    embeddings = load_embeddings()
    return FAISS.load_local("ipc_embed_db", embeddings, allow_dangerous_deserialization=True)
# Load embeddings and FAISS database
embeddings = load_embeddings()
db = load_faiss_db()
db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
# Define prompt template
prompt_template = """
<s>[INST]
As a legal chatbot specializing in the Indian Penal Code, you are tasked with providing highly accurate and contextually appropriate responses. Ensure your answers meet these criteria:
- Respond in a bullet-point format to clearly delineate distinct aspects of the legal query.
- Each point should accurately reflect the breadth of the legal provision in question, avoiding over-specificity unless directly relevant to the user's query.
- Clarify the general applicability of the legal rules or sections mentioned, highlighting any common misconceptions or frequently misunderstood aspects.
- Limit responses to essential information that directly addresses the user's question, providing concise yet comprehensive explanations.
- Avoid assuming specific contexts or details not provided in the query, focusing on delivering universally applicable legal interpretations unless otherwise specified.
- Conclude with a brief summary that captures the essence of the legal discussion and corrects any common misinterpretations related to the topic.

CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
- [Detail the first key aspect of the law, ensuring it reflects general application]
- [Provide a concise explanation of how the law is typically interpreted or applied]
- [Correct a common misconception or clarify a frequently misunderstood aspect]
- [Detail any exceptions to the general rule, if applicable]
- [Include any additional relevant information that directly relates to the user's query]
</s>[INST]
"""

prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question', 'chat_history'])
# Load the InLegalBERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
model = AutoModelForSequenceClassification.from_pretrained("law-ai/InLegalBERT")
# Function to get the model's response
def get_inlegalbert_response(question):
    inputs = tokenizer(question, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    # Decode the index of the highest-scoring classification logit back into a token
    response = tokenizer.decode(torch.argmax(logits, dim=-1))
    return response
# Define a wrapper for the model as a LangChain-compatible LLM
class InLegalBERTWrapper(LLM):
    @property
    def _llm_type(self):
        return "inlegalbert"

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        return get_inlegalbert_response(prompt)

llm = InLegalBERTWrapper()
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True),
    retriever=db_retriever,
    combine_docs_chain_kwargs={'prompt': prompt}
)
def extract_answer(full_response):
    answer_start = full_response.find("Response:")
    if answer_start != -1:
        answer_start += len("Response:")
        return full_response[answer_start:].strip()
    return full_response
def chat(messages):
    if "messages" not in messages:
        messages["messages"] = []
    # The latest user message is appended to state by user_input below
    input_prompt = messages["messages"][-1]["content"]
    result = qa.invoke({"question": input_prompt})
    answer = extract_answer(result["answer"])
    messages["messages"].append({"role": "assistant", "content": answer})
    # gr.Chatbot expects (user, assistant) pairs, so pair up consecutive messages
    history = messages["messages"]
    pairs = [(history[i]["content"], history[i + 1]["content"]) for i in range(0, len(history) - 1, 2)]
    return pairs, messages
with gr.Blocks() as demo:
    gr.Markdown("## Stat.ai Legal Assistant")
    chatbot = gr.Chatbot()
    state = gr.State({"messages": []})
    msg = gr.Textbox(placeholder="Ask Stat.ai")
    # Store the user's message in state and clear the textbox
    def user_input(message, history):
        history["messages"].append({"role": "user", "content": message})
        return "", history

    msg.submit(user_input, [msg, state], [msg, state], queue=False).then(
        chat, state, [chatbot, state]
    )
if __name__ == "__main__":
    demo.launch()