import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread
import spaces

# --- 1. Model and Processor Setup ---
# This part is loaded only once when the script starts.
try:
    model_id = "bharatgenai/patram-7b-instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load processor and model
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # Use float16 for lower memory usage on GPU
        device_map="auto",          # Automatically places weights on available GPUs
        trust_remote_code=True
    )
    print("Model and processor loaded successfully.")

    # --- Define and apply the chat template ---
    # This is crucial for multi-turn conversation.
    chat_template = """{% for message in messages -%}
    {%- if (loop.index % 2 == 1 and message['role'] != 'user') or (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif -%}
    {{ message['role'].capitalize() + ': ' + message['content'] }}
    {%- if not loop.last -%}
        {{ ' ' }}
    {%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{ ' Assistant:' }}
{%- endif %}"""
    processor.tokenizer.chat_template = chat_template

except Exception as e:
    print(f"Error during model loading: {e}")
    # Exit if the model can't be loaded, as the app is unusable without it.
    exit()


# --- 2. Gradio Chatbot Logic with Streaming ---
@spaces.GPU
def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
    """
    Generator function that handles the chat logic with streaming.
    It yields the updated chatbot display after each new chunk of text.
    """
    # Check if an image has been uploaded
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        yield chatbot_display, messages_list
        return  # Stop the generator

    # Append the user's message to the conversation history and display
    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, ""))  # Empty slot for the streaming response

    try:
        # Use the processor's tokenizer to apply the chat template
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess the image together with the formatted prompt
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

        # Set up the streamer
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define generation configuration
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            temperature=0.6,
            stop_strings=["<|endoftext|>", "User:"]  # Stop strings to prevent over-generation
        )

        # Run generation in a separate thread so the streamer can be consumed here.
        # generate_from_batch takes the preprocessed batch as its first argument
        # (mirroring the model card's usage example); stop_strings also requires
        # the tokenizer to be passed through to generate().
        generation_kwargs = dict(
            streamer=streamer,
            generation_config=generation_config,
            tokenizer=processor.tokenizer
        )
        thread = Thread(target=model.generate_from_batch, args=(inputs,), kwargs=generation_kwargs)
        thread.start()

        # Yield incremental updates to the Gradio UI
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            chatbot_display[-1] = (user_message, full_response)
            yield chatbot_display, messages_list

        # After the loop, generation is complete.
        # Add the final full response to the messages list for context.
        messages_list.append({"role": "assistant", "content": full_response})
        yield chatbot_display, messages_list  # Yield the final state

    except Exception as e:
        print(f"Error during streaming inference: {e}")
        error_message = f"Sorry, an error occurred: {e}"
        chatbot_display[-1] = (user_message, error_message)
        yield chatbot_display, messages_list


# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Streaming Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The response will stream in real time.")

    # State variable to hold the conversation history
    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
                avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
            # Pressing Enter in the textbox submits the message; there is no
            # separate send button.

    # --- Event Listeners ---
    # Define the action for submitting a message (via the Enter key)
    submit_action = user_textbox.submit(
        fn=process_chat_streaming,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list],
        # queue=False  # Set queue to False for faster interaction with streaming
    )

    # Chain the action to also clear the textbox after submission
    submit_action.then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[user_textbox],
        queue=False
    )

    # Define the action for the clear button
    clear_btn.click(
        fn=lambda: ([], [], None, ""),  # Return empty/default values for each output
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )

if __name__ == "__main__":
    demo.launch(debug=True, mcp_server=True)