KingNish committed on
Commit e44dcbf · verified · 1 Parent(s): b19004c

Create app.py

Files changed (1)
  app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread

# --- 1. Model and Processor Setup ---
# This part is loaded only once when the script starts.

try:
    model_id = "bharatgenai/patram-7b-instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load processor and model
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # Use float16 for less memory usage on GPU
        device_map="auto",          # Automatically uses available GPUs
        trust_remote_code=True
    )
    print("Model and processor loaded successfully.")

    # --- Define and apply the chat template ---
    # This is crucial for multi-turn conversation.
    chat_template = """{% for message in messages -%}
    {%- if (loop.index % 2 == 1 and message['role'] != 'user') or
           (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif -%}
    {{ message['role'].capitalize() + ': ' + message['content'] }}
    {%- if not loop.last -%}
        {{ ' ' }}
    {%- endif %}
    {%- endfor -%}
    {%- if add_generation_prompt -%}
        {{ ' Assistant:' }}
    {%- endif %}"""
    processor.tokenizer.chat_template = chat_template

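    # For illustration (example history, not part of the original commit):
    # with messages
    #   [{"role": "user", "content": "What is in this image?"},
    #    {"role": "assistant", "content": "A dog."},
    #    {"role": "user", "content": "What breed is it?"}]
    # and add_generation_prompt=True, this template renders the single string:
    #   "User: What is in this image? Assistant: A dog. User: What breed is it? Assistant:"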
except Exception as e:
    print(f"Error during model loading: {e}")
    # Exit if model can't be loaded, as the app is unusable.
    exit()

# --- 2. Gradio Chatbot Logic with Streaming ---

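# Gradio runs generator callbacks incrementally: each `yield` pushes new
# values into the output components listed on the event listener, which is
# what produces the live "typing" effect in the Chatbot below.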
def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
    """
    This generator function handles the chat logic with streaming.
    It yields the updated chatbot display at each step.
    """
    # Check if an image has been uploaded
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        yield chatbot_display, messages_list
        return  # Stop the generator

    # Append user's message to the conversation history and display
    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, ""))  # Add an empty spot for the streaming response

    try:
        # Use the processor to apply the chat template
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess image and the entire formatted prompt
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

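        # `processor.process` comes from the model's remote code and returns
        # unbatched tensors, so `.unsqueeze(0)` adds the batch dimension of 1
        # that generation expects. Note the image is re-encoded on every turn,
        # which keeps the code simple at the cost of some per-turn latency.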
        # Setup the streamer
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define generation configuration
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            temperature=0.6,
            stop_strings=["<|endoftext|>", "User:"]  # Add stop strings to prevent over-generation
        )

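        # Caveat: tokens are streamed as soon as they are generated, so the
        # "User:" stop string may briefly appear at the tail of the streamed
        # text before generation halts; trim it from the response if that
        # matters for your UI.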
        # Build the arguments for the background generation thread. The batch
        # is passed positionally, matching the model card's
        # `generate_from_batch(inputs, ...)` call, and the tokenizer is
        # forwarded because `generate()` requires it whenever `stop_strings`
        # is set.
        generation_kwargs = dict(
            streamer=streamer,
            generation_config=generation_config,
            tokenizer=processor.tokenizer
        )

        # Run generation in a separate thread
        thread = Thread(target=model.generate_from_batch, args=(inputs,), kwargs=generation_kwargs)
        thread.start()
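        # The worker thread pushes decoded text fragments into the streamer's
        # internal queue; iterating over the streamer below blocks until each
        # fragment arrives, so the UI can be updated token by token.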

        # Yield updates to the Gradio UI
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            chatbot_display[-1] = (user_message, full_response)
            yield chatbot_display, messages_list

        # After the loop, the generation is complete.
        # Add the final full response to the messages list for context.
        messages_list.append({"role": "assistant", "content": full_response})
        yield chatbot_display, messages_list  # Yield the final state

    except Exception as e:
        print(f"Error during streaming inference: {e}")
        error_message = f"Sorry, an error occurred: {e}"
        # Drop the unanswered user turn so the template's user/assistant
        # alternation check doesn't fail on the next message.
        messages_list.pop()
        chatbot_display[-1] = (user_message, error_message)
        yield chatbot_display, messages_list

# --- 3. Gradio Interface Definition ---

with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Streaming Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The response will stream in real-time.")

    # State variable to hold the conversation history
    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")

        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
                avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
            # There is no separate submit button; pressing Enter in the
            # textbox is what submits a message.

    # --- Event Listeners ---

    # Define the action for submitting a message (via the Enter key)
    submit_action = user_textbox.submit(
        fn=process_chat_streaming,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list],
        # Streaming handlers are generators and must run through the queue,
        # so queuing is deliberately left enabled for this event.
    )

    # Chain the action to also clear the textbox after submission
    submit_action.then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[user_textbox],
        queue=False
    )

    # Define the action for the clear button
    clear_btn.click(
        fn=lambda: ([], [], None, ""),  # Reset chat display, history, image, and textbox
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )

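# `demo.queue()` enables Gradio's request queue, which the generator-based
# streaming handler above relies on; `share=True` additionally exposes a
# public tunnel URL for the demo.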
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)