import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread
import spaces

# --- 1. Model and Processor Setup ---
# This section runs only once, when the script starts.

try:
    model_id = "bharatgenai/patram-7b-instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load processor and model
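    # trust_remote_code=True is required because the repo ships custom processor/model code
    # (e.g. processor.process and model.generate_from_batch used later in this script).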
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16, # Use float16 for less memory usage on GPU
        device_map="auto",        # Automatically uses available GPUs
        trust_remote_code=True
    )
    print("Model and processor loaded successfully.")

    # --- Define and apply the chat template ---
    # This is crucial for multi-turn conversation
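    # The Jinja template below flattens the messages list into "User: ... Assistant: ..."
    # form, enforces strict role alternation, and appends " Assistant:" when a reply is expected.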
    chat_template = """{% for message in messages -%}
            {%- if (loop.index % 2 == 1 and message['role'] != 'user') or
              (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
            {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
            {%- endif -%}
            {{ message['role'].capitalize() + ': ' + message['content'] }}
            {%- if not loop.last -%}
            {{ ' ' }}
            {%- endif %}
            {%- endfor -%}
            {%- if add_generation_prompt -%}
            {{ ' Assistant:' }}
            {%- endif %}"""
    processor.tokenizer.chat_template = chat_template
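    # Example: a two-turn history renders as
    #   "User: What is this document? Assistant: It looks like an invoice. User: What is the total? Assistant:"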

except Exception as e:
    print(f"Error during model loading: {e}")
    # Exit if model can't be loaded, as the app is unusable.
    exit()
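
# Quick non-streaming sanity check (a sketch; not executed by the app). It assumes the
# Molmo-style `generate_from_batch(batch, generation_config, tokenizer=...)` signature
# that the remote code appears to provide; "example.jpg" is a hypothetical file.
#
#   img = Image.open("example.jpg")
#   prompt = "User: Describe this document. Assistant:"
#   batch = processor.process(images=[img], text=prompt)
#   batch = {k: v.to(device).unsqueeze(0) for k, v in batch.items()}
#   out = model.generate_from_batch(
#       batch,
#       GenerationConfig(max_new_tokens=128, stop_strings="<|endoftext|>"),
#       tokenizer=processor.tokenizer,
#   )
#   new_tokens = out[0, batch["input_ids"].size(1):]
#   print(processor.tokenizer.decode(new_tokens, skip_special_tokens=True))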

# --- 2. Gradio Chatbot Logic with Streaming ---
@spaces.GPU
def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
    """
    This generator function handles the chat logic with streaming.
    It yields the updated chatbot display at each step.
    """
    # Check if an image has been uploaded
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        yield chatbot_display, messages_list
        return # Stop the generator

    # Append user's message to the conversation history and display
    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, "")) # Add an empty spot for the streaming response
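    # messages_list is the raw role/content history fed to the chat template, while
    # chatbot_display holds the (user, assistant) pairs rendered by the UI.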

    try:
        # Use the processor to apply the chat template
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess image and the entire formatted prompt
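        # processor.process returns unbatched tensors; unsqueeze(0) below adds the batch
        # dimension before everything is moved onto the model's device.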
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

        # Setup the streamer
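        # TextIteratorStreamer exposes the decoded text as an iterator while tokens are produced;
        # skip_prompt=True keeps the echoed input prompt out of the streamed output.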
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define generation configuration
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            temperature=0.6,
            stop_strings=["<|endoftext|>", "User:"] # Add stop strings to prevent over-generation
        )

        # Create generation kwargs for the thread.
        # stop_strings needs the tokenizer available at generation time, so pass it along.
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            generation_config=generation_config,
            tokenizer=processor.tokenizer
        )

        # Run generation in a separate thread
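        # generate_from_batch blocks until generation finishes, so it runs on a worker thread
        # while this generator drains the streamer and updates the UI incrementally.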
        thread = Thread(target=model.generate_from_batch, kwargs=generation_kwargs)
        thread.start()

        # Yield updates to the Gradio UI
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            chatbot_display[-1] = (user_message, full_response)
            yield chatbot_display, messages_list

        # After the loop, the generation is complete.
        # Add the final full response to the messages list for context.
        messages_list.append({"role": "assistant", "content": full_response})
        yield chatbot_display, messages_list # Yield the final state

    except Exception as e:
        print(f"Error during streaming inference: {e}")
        error_message = f"Sorry, an error occurred: {e}"
        chatbot_display[-1] = (user_message, error_message)
        yield chatbot_display, messages_list

# --- 3. Gradio Interface Definition ---

with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# πŸ€– Patram-7B-Instruct Streaming Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The response will stream in real-time.")

    # State variables to hold conversation history
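    # gr.State is session-scoped, so each visitor gets an independent message history.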
    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("πŸ—‘οΈ Clear Chat and Image")

        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
                avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
                # There is no separate send button; pressing Enter in the textbox submits the message.

    # --- Event Listeners ---

    # Define the action for submitting a message (via enter key)
    submit_action = user_textbox.submit(
        fn=process_chat_streaming,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list],
        # Keep the default queue enabled; Gradio requires the queue to stream generator outputs.
    )

    # Chain the action to also clear the textbox after submission
    submit_action.then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[user_textbox],
        queue=False
    )

    # Define the action for the clear button
    clear_btn.click(
        fn=lambda: ([], [], None, ""),  # Function to return empty/default values
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )


if __name__ == "__main__":
    demo.launch(debug=True, mcp_server=True)