import os
from queue import Empty
from threading import Thread
from typing import Iterator

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Simple request counter, incremented on every call to `generate`.
total_count = 0

# Load the model and tokenizer once at import time.
model_id = "beyoru/Neeru-RL2"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False

SYSTEM_PROMPT = """You are a helpful assistant with access to the following functions. Use them if required -
[
  {
    "type": "function",
    "function": {
      "name": "get_weather",
      "description": "Get the current weather for a location",
      "parameters": {
        "type": "object",
        "properties": {
          "location": {"type": "string", "description": "The city and state"},
          "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The unit of temperature to return"}
        },
        "required": ["location"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "get_search",
      "description": "Search the web for a query",
      "parameters": {
        "type": "object",
        "properties": {
          "query": {"type": "string", "description": "The search query"}
        },
        "required": ["query"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "get_warehouse_info",
      "description": "Get warehouse information based on multiple parameters",
      "parameters": {
        "type": "object",
        "properties": {
          "warehouse_id": {"type": "string", "description": "Unique identifier of the warehouse"},
          "location": {"type": "string", "description": "Location of the warehouse"},
          "status": {"type": "string", "enum": ["active", "inactive", "under maintenance"], "description": "Operational status of the warehouse"},
          "capacity": {"type": "integer", "description": "Total storage capacity of the warehouse"},
          "current_stock": {"type": "integer", "description": "Current stock available in the warehouse"},
          "manager": {"type": "string", "description": "Name of the warehouse manager"},
          "contact": {"type": "string", "description": "Contact details of the warehouse"},
          "operating_hours": {"type": "string", "description": "Operating hours of the warehouse"}
        },
        "required": ["warehouse_id", "location", "status"]
      }
    }
  }
].

You are given a question and a set of possible functions. Based on the question, you will need to make one or more function calls to achieve the purpose.
If none of the functions can be used, point it out. If the given question lacks the parameters required by the function, also point it out.
For each function call, return a json object with function name and arguments: [{{"name": "", "arguments": }}]
Respond in the following format:
...
...
"""
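# Illustrative sketch of the reply format the prompt above asks for; the query and the
# argument values below are made up for this comment and are not part of the demo:
#
#   User:      What's the weather in Hanoi, in celsius?
#   Assistant: [{"name": "get_weather", "arguments": {"location": "Hanoi", "unit": "celsius"}}]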
""" def generate( message: str, chat_history: list[tuple[str, str]], system_prompt: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1, ) -> Iterator[str]: global total_count total_count += 1 print(total_count) # Only run `nvidia-smi` if a GPU is available if torch.cuda.is_available(): os.system("nvidia-smi") conversation = [] if system_prompt: conversation.append({"role": "system", "content": SYSTEM_PROMPT}) for user, assistant in chat_history: conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) conversation.append({"role": "user", "content": message}) input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( {"input_ids": input_ids}, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, # Set to True for sampling-based generation top_p=top_p, top_k=top_k, num_beams=1, repetition_penalty=repetition_penalty, eos_token_id=32021 ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] try: for text in streamer: outputs.append(text) yield "".join(outputs).replace("<|EOT|>", "") except Empty: print("Streamer did not produce output in time.") yield "Error: Streamer timed out without generating output." chat_interface = gr.ChatInterface( fn=generate, additional_inputs=[ gr.Textbox(label="System prompt", lines=6), gr.Slider( label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=512, ), gr.Slider( label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9, ), gr.Slider( label="Top-k", minimum=1, maximum=1000, step=1, value=50, ), gr.Slider( label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1, ), ], stop_btn=gr.Button("Stop"), examples=[ ["implement snake game using pygame"], ["Can you explain briefly to me what is the Python programming language?"], ["write a program to find the factorial of a number"], ], ) with gr.Blocks() as demo: chat_interface.render() if __name__ == "__main__": demo.queue(max_size=20).launch()