import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread
import spaces
# --- 1. Model and Processor Setup ---
# This part runs only once, when the script starts.
try:
    model_id = "bharatgenai/patram-7b-instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load processor and model
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # float16 reduces GPU memory usage
        device_map="auto",          # automatically place weights on available GPUs
        trust_remote_code=True
    )
    print("Model and processor loaded successfully.")

    # --- Define and apply the chat template ---
    # This is crucial for multi-turn conversation.
    chat_template = """{% for message in messages -%}
{%- if (loop.index % 2 == 1 and message['role'] != 'user') or
       (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['role'].capitalize() + ': ' + message['content'] }}
{%- if not loop.last -%}
{{ ' ' }}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ ' Assistant:' }}
{%- endif %}"""
    processor.tokenizer.chat_template = chat_template
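    # For illustration, with a hypothetical three-message history
    #   [{"role": "user", "content": "What is in this image?"},
    #    {"role": "assistant", "content": "A stamped land record."},
    #    {"role": "user", "content": "Who signed it?"}]
    # and add_generation_prompt=True, the template renders the single string:
    #   "User: What is in this image? Assistant: A stamped land record. User: Who signed it? Assistant:"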
except Exception as e:
    print(f"Error during model loading: {e}")
    # Exit with a nonzero status if the model can't be loaded, as the app is
    # unusable without it.
    raise SystemExit(1)
# --- 2. Gradio Chatbot Logic with Streaming ---
# The `spaces` import is only useful if the handler is decorated; @spaces.GPU
# requests a GPU per call on ZeroGPU Spaces and is a no-op elsewhere.
@spaces.GPU
def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
| """ | |
| This generator function handles the chat logic with streaming. | |
| It yields the updated chatbot display at each step. | |
| """ | |
| # Check if an image has been uploaded | |
| if image_pil is None: | |
| chatbot_display.append((user_message, "Please upload an image first to start the conversation.")) | |
| yield chatbot_display, messages_list | |
| return # Stop the generator | |
| # Append user's message to the conversation history and display | |
| messages_list.append({"role": "user", "content": user_message}) | |
| chatbot_display.append((user_message, "")) # Add an empty spot for the streaming response | |
    try:
        # Apply the chat template to the full message history
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess the image and the entire formatted prompt
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
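        # Note (assumption about this model's remote code): processor.process()
        # returns unbatched tensors, so unsqueeze(0) adds the batch dimension
        # that generation expects.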
        # Set up the streamer
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
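        # TextIteratorStreamer is a thread-safe queue: generate() (running in
        # the worker thread below) pushes decoded text fragments into it, and
        # iterating over it here blocks until the next fragment arrives.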
        # Define the generation configuration
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            temperature=0.6,
            stop_strings=["<|endoftext|>", "User:"]  # stop strings prevent over-generation
        )
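        # transformers requires a tokenizer to be passed to generate() whenever
        # stop_strings is set, so it is forwarded in the kwargs below.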
        # Collect the generation kwargs for the worker thread. The model's
        # generate_from_batch expects the preprocessed batch as a single dict
        # argument (named `batch` in the remote code), not as unpacked tensors.
        generation_kwargs = dict(
            batch=inputs,
            streamer=streamer,
            generation_config=generation_config,
            tokenizer=processor.tokenizer  # needed for stop_strings
        )
        # Run generation in a separate thread so this generator can consume
        # the streamer concurrently
        thread = Thread(target=model.generate_from_batch, kwargs=generation_kwargs)
        thread.start()

        # Yield incremental updates to the Gradio UI
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            chatbot_display[-1] = (user_message, full_response)
            yield chatbot_display, messages_list

        # Generation is complete; store the final response so it becomes
        # context for the next turn.
        messages_list.append({"role": "assistant", "content": full_response})
        yield chatbot_display, messages_list  # yield the final state
    except Exception as e:
        print(f"Error during streaming inference: {e}")
        error_message = f"Sorry, an error occurred: {e}"
        chatbot_display[-1] = (user_message, error_message)
        yield chatbot_display, messages_list
# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Streaming Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The response streams in real time.")

    # State variable holding the conversation history
    messages_list = gr.State([])
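    # gr.State is session-scoped, so each visitor gets an independent
    # conversation history rather than sharing one global list.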
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
                avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
            )
    with gr.Row():
        user_textbox = gr.Textbox(
            placeholder="Type your question here...",
            show_label=False,
            scale=4,
            container=False
        )
    # Messages are submitted with the Enter key, wired up below via
    # user_textbox.submit().
    # --- Event Listeners ---
    # Submitting a message (Enter key) streams the model's reply.
    # Note: streaming generator events must run through Gradio's queue (the
    # default), so queue=False must not be set here.
    submit_action = user_textbox.submit(
        fn=process_chat_streaming,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list]
    )
    # Chain the action to clear the textbox after submission
    submit_action.then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[user_textbox],
        queue=False
    )
    # The clear button resets the chat display, history, image, and textbox
    clear_btn.click(
        fn=lambda: ([], [], None, ""),  # return empty/default values
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )
if __name__ == "__main__":
    demo.launch(debug=True, mcp_server=True)
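# Usage note: run `python app.py` to serve the UI locally. With
# mcp_server=True (supported in recent Gradio releases), the app also exposes
# its API as an MCP server for tool-using clients.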