File size: 4,259 Bytes
8e0ab8f
5534d9e
 
896a3c2
9c2dbe4
 
8e0ab8f
9c2dbe4
 
 
 
 
 
c5ecffb
9c2dbe4
 
 
896a3c2
9c2dbe4
3e348db
7c68a15
e5d4b35
058b8bc
c5ecffb
058b8bc
9c2dbe4
c5ecffb
b3af2b6
e5d4b35
 
b3af2b6
 
 
 
 
4419107
058b8bc
e5d4b35
ed49a22
058b8bc
9c2dbe4
6cb1f28
9c2dbe4
 
 
e223506
 
9c2dbe4
7c68a15
058b8bc
 
42698e2
 
058b8bc
ac3be44
7c68a15
 
 
 
 
e8088bd
7c68a15
 
 
 
 
 
 
14e4a65
42698e2
14e4a65
7c68a15
42698e2
 
058b8bc
e5d4b35
8e0ab8f
 
 
31b9fec
feb6b5f
c36f45d
8e0ab8f
 
 
 
c5ecffb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import InferenceClient
import re
import torch

# Model and tokenizer loading (outside the respond function)
# Loads the Sarvam-1 tokenizer and base model, attaches the fine-tuned
# text-normalization LoRA adapter, and merges the adapter weights into the
# base model. On any failure the globals are set to None so respond() can
# report a clean error at request time instead of crashing.
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    # Wrap the base model with the task-specific adapter weights.
    peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
    # merge_and_unload folds the LoRA weights into the base model so that
    # generation runs without the PEFT wrapper overhead.
    peft_model = peft_model.merge_and_unload()
    print("Model loaded successfully!")
except Exception as e:
    # Broad catch is deliberate: loading can fail in many ways (network,
    # disk, incompatible weights). None is the "load failed" sentinel
    # checked by respond().
    print(f"Error loading model: {e}")
    tokenizer = None
    base_model = None
    peft_model = None
    


def respond(message, history, system_message, max_tokens, temperature, top_p=0.95):
    """Generate a normalized-text reply for one chat turn.

    Parameters
    ----------
    message : str
        The latest user message (Hindi text to normalize).
    history : list
        Gradio chat history. Currently unused, but required by the
        ChatInterface callback signature.
    system_message : str
        Task instructions prepended to the prompt.
    max_tokens : int
        Maximum number of new tokens to generate.
    temperature : float
        Sampling temperature forwarded to ``generate``.
    top_p : float, optional
        Nucleus-sampling parameter. Given a default because the UI's
        ``additional_inputs`` provides only three extra values; without
        the default every call raised ``TypeError``. It is currently not
        forwarded to ``generate`` (sampling is disabled), matching the
        original behavior.

    Returns
    -------
    str
        The extracted model output, or a human-readable error message.
    """
    global tokenizer, peft_model

    # Model loading failed at import time — fail gracefully per request.
    if tokenizer is None or peft_model is None:
        return "Model loading failed. Please check the logs."

    # Build the prompt: system instructions followed by the user turn in
    # the format the adapter was fine-tuned on.
    prompt = f"{system_message}\n<user> input:{message} output:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    try:
        # Inference only: no_grad avoids building autograd state.
        with torch.no_grad():
            outputs = peft_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                # NOTE(review): top_p / do_sample are intentionally not
                # passed, preserving the original (greedy) behavior.
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Generation error: {e}"

    def extract_user_content(text):
        """Return the text between <user> and </user> tags in *text*.

        Falls back to a retry message when the model did not emit the
        expected tags.
        """
        pattern = re.compile(r'<user>(.*?)</user>', re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            return match.group(1).strip()
        return "Retry to get output, the model failed to generated required output(This occurs rarely🤷‍♂️)"

    print(generated_text)
    lines = extract_user_content(generated_text)
    print(lines)
    return lines

# Chat UI. The extra controls below map, in order, onto respond()'s
# (system_message, max_tokens, temperature, top_p) parameters. The original
# code supplied only three controls for four required parameters, so every
# chat turn raised TypeError — a Top-p slider is added to complete the set.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, Here's an example input and output of the task <Example> Exampleinput :  2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोइ़ थी, Exampleoutput: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>. Understand the task and Only provide the normalized output with atmost accuracy",label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.95, step=0.1, label="Temperature"),
        # Supplies respond()'s top_p argument (previously missing from the UI).
        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

# Script entry point: start the Gradio web server when run directly
# (guard keeps imports of this module side-effect free at launch level).
if __name__ == "__main__":
    demo.launch()