import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import re
import torch
# Model and tokenizer loading (done once at import time, outside the respond function)
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
    # Merge the LoRA adapter into the base weights so plain `generate` calls work
    peft_model = peft_model.merge_and_unload()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    tokenizer = None
    base_model = None
    peft_model = None
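
# Optional extra (a hedged sketch, not required on CPU Spaces): move the merged
# model to GPU when one is available. Assumes the model fits in device memory.
if peft_model is not None:
    peft_model = peft_model.to("cuda" if torch.cuda.is_available() else "cpu")
    peft_model.eval()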

def respond(message, history, system_message, max_tokens, temperature, top_p):
    global tokenizer, peft_model
    if tokenizer is None or peft_model is None:
        return "Model loading failed. Please check the logs."

    # Construct the prompt: system instructions followed by the tagged user turn
    prompt = system_message
    # Conversation history is currently ignored; re-enable this loop to include it:
    # for user_msg, assistant_msg in history:
    #     if user_msg:
    #         prompt += f"\nUser: {user_msg} output: "
    #     if assistant_msg:
    #         prompt += f"\nAssistant: {assistant_msg}"
    prompt += f"\n<user> input:{message} output:"

    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(peft_model.device)
    try:
        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # sampling must be on for temperature/top_p to take effect
        )
        # Decode the full sequence (prompt plus completion)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Generation error: {e}"

    def extract_user_content(text):
        """
        Extracts and returns the content between the <user> and </user> tags
        in the given text. Returns a retry message if no such span is found.
        """
        # Match everything between <user> and </user>, including across newlines
        pattern = re.compile(r"<user>(.*?)</user>", re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            # Return the captured group: the content inside the tags
            return match.group(1).strip()
        else:
            return "Retry to get the output; the model failed to generate the required output (this happens rarely 🤷‍♂️)"

    print(generated_text)
    result = extract_user_content(generated_text)
    print(result)
    return result
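
# Local smoke test (illustrative; the Hindi sentence and sampling values are
# made-up examples, not from the Space). Uncomment to exercise the pipeline
# without the UI:
# if tokenizer is not None and peft_model is not None:
#     print(respond("बैठक 15 मार्च 2023 को होगी", [], "Normalize dates.", 64, 0.95, 0.95))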

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="Take the user input in Hindi language and normalize specific entities, only including: dates (any format), currencies, and scientific units. Here's an example input and output of the task <Example> Example input: 2012–13 में रक्षा सेवाओं के लिए 1,93,407 करोड़ रुपए का प्रावधान किया गया था, जबकि 2011–2012 में यह राशि 1,64,415 करोड़ थी, Example output: ट्वेन्टी ट्वेल्व फिफ्टीन में रक्षा सेवाओं के लिए वन करोड़ निनेटी थ्री थाउजेंड फोर हंड्रेड सेवन करोड़ रुपए का प्रावधान किया गया था, जबकि ट्वेन्टी एलेवन ट्वेल्व में यह राशि वन करोड़ सिक्स्टी फोर थाउजेंड फोर हंड्रेड फिफ्टीन करोड़ थी </Example>. Understand the task and only provide the normalized output with utmost accuracy.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.95, step=0.1, label="Temperature"),
        # Top-p slider restored so `respond` receives a value for its top_p parameter
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()