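"""Gradio demo of an Indian-law legal assistant.

Offers two answer paths: token streaming straight from a remote
text-generation-inference (TGI) server, and a retrieval-augmented
(RetrievalQA) path over a local FAISS index when "factual" mode is enabled.
"""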
import gradio as gr
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from text_generation import Client

# Client for the remote text-generation-inference (TGI) server hosting the model.
client = Client("http://20.83.177.108:8080")


def run_generation_stream(user_text, factual, max_new_tokens, temperature):
    # Stream a completion token by token from the TGI server (no retrieval).
    # `factual` is accepted to match the Gradio inputs list but is unused here.
    print('called stream')

    if len(user_text.strip()) == 0:
        print('blank')
        gr.Warning('Please enter a question to continue')
        return

    user_text = f"""You are an expert legal assistant with extensive knowledge about Indian law. Your task is to respond to the given query in a consice and factually correct manner. Also mention the relevant sections of the law wherever applicable.
    ### Input: {user_text}
    ### Response: """

    text = ""
    for response in client.generate_stream(user_text, max_new_tokens=max_new_tokens, repetition_penalty=1.05, temperature=temperature):
        if not response.token.special:
            text += response.token.text
            yield text

    return text


def reset_textbox():
    return gr.update(value='')


model_name = "BAAI/bge-base-en"
# Normalize embeddings to unit length so similarity scores behave like cosine similarity.
encode_kwargs = {'normalize_embeddings': True}

model_norm = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)


vectordb = FAISS.load_local('faissdb', embeddings=model_norm)
retriever = vectordb.as_retriever(
    search_type='similarity', search_kwargs={"k": 5})
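# NOTE (assumption): 'faissdb' is a FAISS index built offline with the same BGE
# embeddings; a minimal sketch of how such an index could be created:
#
#   docs = ...  # chunked legal documents (loader + text splitter of your choice)
#   FAISS.from_documents(docs, model_norm).save_local('faissdb')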


# Prompt note: for most cases, refer to the Indian Penal Code (IPC) and the Code of Criminal Procedure (CrPC).
prompt_template = """You are an expert legal assistant with extensive knowledge about Indian law. Your task is to respond to the given query in a factually correct and consise manner unless asked for a detailed explanation. Assume the query is asked by a common man unless explicitly specified otherwise, therefore no special acts or laws like ones for railway , army , police would apply to them. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Response:"""


PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
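# PROMPT is injected into the RetrievalQA chain below via chain_type_kwargs, so the
# retrieved chunks fill {context} and the user's query fills {question}.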


def run_generation(query, factual, max_tokens, temperature):
    # Retrieval-augmented path: answer with the RetrievalQA chain over the FAISS index.
    print('called non stream')

    llm = HuggingFaceTextGenInference(
        inference_server_url="http://20.83.177.108:8080/",
        max_new_tokens=max_tokens,
        top_k=10,
        top_p=0.95,
        typical_p=0.95,
        temperature=temperature,
        streaming=bool(factual),
        # repetition_penalty=1.1,
    )

    qa_chain = RetrievalQA.from_chain_type(llm=llm,
                                           chain_type_kwargs={
                                               "prompt": PROMPT},
                                           retriever=retriever,
                                           return_source_documents=True,
                                           )
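    # Note: chain_type defaults to "stuff", so all retrieved chunks are concatenated
    # into {context} of a single prompt; source documents are returned but not shown in the UI.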

    # text = ""
    # if factual:
    #     response = llm(query, callbacks=[StreamingStdOutCallbackHandler()])
    #     print(response)
    #     # text += response
    #     yield response

    # else:
    llm_response = qa_chain(query)
    print(llm_response['result'])
    return llm_response['result']


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="What is the punishment for taking dowry. explain in detail.",
                label="Question"
            )
            model_output = gr.Textbox(
                label="AI Response", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=250, step=10, interactive=True, label="Maximum number of tokens to generate",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True, label="Randomness (0 to 1; lower values are less random)",
            )
            factual = gr.Checkbox(
                label='Turn on to get factually correct answers')

    # user_text.submit(run_generation, [
    #                  user_text, top_p, temperature, top_k, max_new_tokens], model_output)
    # button_submit.click(run_generation, [
    #                     user_text, top_p, temperature, top_k, max_new_tokens], model_output)

    # user_text.submit(run_generation, [
    #     user_text, factual, max_new_tokens, temperature], model_output)

    # Choose the answer path per request. Reading factual.value here would only
    # capture the checkbox's initial state at build time, so dispatch at call time.
    def dispatch(user_text, factual, max_new_tokens, temperature):
        if factual:
            # Retrieval-augmented answer from the RetrievalQA chain.
            yield run_generation(user_text, factual, max_new_tokens, temperature)
        else:
            # Token-by-token streaming straight from the TGI server.
            yield from run_generation_stream(user_text, factual, max_new_tokens, temperature)

    button_submit.click(dispatch, [
                        user_text, factual, max_new_tokens, temperature], model_output)

    demo.queue(max_size=32).launch()