import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
from PIL import Image
import gradio as gr
import spaces

# --- 1. Model and Processor Setup ---
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load processor and model
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # Use float16 to halve memory usage on GPU
    device_map="auto",          # Automatically place weights on available GPUs
    trust_remote_code=True
)
print("Model and processor loaded successfully.")

# --- Define and apply the chat template ---
# The template enforces strict user/assistant turn alternation and renders each
# turn as "User: ..." / "Assistant: ...", the format the model expects.
chat_template = """{% for message in messages -%}
{%- if (loop.index % 2 == 1 and message['role'] != 'user') or (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
    {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['role'].capitalize() + ': ' + message['content'] }}
{%- if not loop.last -%}
{{ ' ' }}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ ' Assistant:' }}
{%- endif %}"""
processor.tokenizer.chat_template = chat_template

# --- 2. Gradio Chatbot Logic ---
@spaces.GPU
def process_chat(user_message, chatbot_display, messages_list, image_pil):
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        return chatbot_display, messages_list, ""

    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, None))

    try:
        prompt = processor.tokenizer.apply_chat_template(
            messages_list, tokenize=False, add_generation_prompt=True
        )

        # Preprocess the image and the entire formatted prompt
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
        # Ensure all float32 tensors match the model's float16 dtype
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

        # Generate output using the model's custom batch-generation method
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(
                max_new_tokens=512,
                do_sample=True,
                top_p=0.9,
                temperature=0.6,
                stop_strings="<|endoftext|>"
            ),
            tokenizer=processor.tokenizer
        )

        # Decode only the newly generated tokens (everything after the prompt)
        generated_tokens = output[0, inputs['input_ids'].size(1):]
        response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        messages_list.append({"role": "assistant", "content": response})
        chatbot_display[-1] = (user_message, response)
    except Exception as e:
        print(f"Error during inference: {e}")
        error_message = f"Sorry, an error occurred during processing: {e}"
        chatbot_display[-1] = (user_message, error_message)

    return chatbot_display, messages_list, ""

def clear_chat():
    """Resets the chat display, message history, image, and textbox."""
    return [], [], None, ""

# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")

    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)

    # --- Event Listeners ---
    user_textbox.submit(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox]
    )
    submit_btn.click(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox]
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)
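
# --- Usage note (editor's sketch; the package list below is an assumption) ---
# This script is written as a Hugging Face Spaces app: the `spaces.GPU` decorator
# requests GPU time on ZeroGPU hardware and is effectively a no-op when run
# outside Spaces, so the app should also run locally on a machine with enough
# GPU memory for the 7B model in float16. `mcp_server=True` additionally exposes
# the demo as an MCP server, which requires a recent Gradio release (installable
# as `gradio[mcp]`), and `device_map="auto"` requires the `accelerate` package.
# A minimal local run might look like:
#
#   pip install torch transformers accelerate pillow spaces "gradio[mcp]"
#   python app.py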