KingNish committed on
Commit
7ee9a0d
·
verified ·
1 Parent(s): 3838d2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -8
app.py CHANGED
@@ -6,7 +6,6 @@ import spaces
6
  import threading
7
 
8
  # --- 1. Model and Processor Setup ---
9
-
10
  model_id = "bharatgenai/patram-7b-instruct"
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
@@ -21,6 +20,11 @@ model = AutoModelForCausalLM.from_pretrained(
21
  )
22
  print("Model and processor loaded successfully.")
23
 
 
 
 
 
 
24
  # --- Define and apply a more flexible chat template ---
25
  chat_template = """{% for message in messages %}
26
  {{ message['role'].capitalize() }}: {{ message['content'] }}
@@ -49,8 +53,11 @@ def generate_response(user_message, messages_list, image_pil, max_new_tokens, to
49
  add_generation_prompt=True
50
  )
51
 
52
- # Preprocess image and the entire formatted prompt
53
- inputs = processor.process(images=[image_pil], text=prompt)
 
 
 
54
  inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
55
  inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
56
 
@@ -93,9 +100,6 @@ def process_chat(user_message, chatbot_display, messages_list, image_pil, max_ne
93
  """
94
  This function handles the chat logic for a single turn with streaming.
95
  """
96
- if image_pil is None:
97
- chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
98
- return chatbot_display, messages_list, ""
99
 
100
  # Append user's message to the chatbot display list
101
  chatbot_display.append((user_message, ""))
@@ -113,7 +117,7 @@ def process_chat(user_message, chatbot_display, messages_list, image_pil, max_ne
113
 
114
  def clear_chat():
115
  """Resets the chat, history, and image."""
116
- return [], [], None, "", 256, 0.9, 50, 0.6
117
 
118
  # --- 3. Gradio Interface Definition ---
119
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
@@ -127,6 +131,12 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutra
127
  with gr.Row():
128
  with gr.Column(scale=1):
129
  image_input_render = gr.Image(type="pil", label="Upload Image")
 
 
 
 
 
 
130
  clear_btn = gr.Button("🗑️ Clear Chat and Image")
131
  with gr.Accordion("Generation Parameters", open=False):
132
  max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=256, step=32, label="Max New Tokens")
@@ -149,6 +159,20 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutra
149
  )
150
  submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  # --- Event Listeners ---
153
 
154
  # Define the action for submitting a message (via button or enter key)
@@ -167,7 +191,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutra
167
  clear_btn.click(
168
  fn=clear_chat,
169
  inputs=[],
170
- outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature],
171
  queue=False
172
  )
173
 
 
6
  import threading
7
 
8
  # --- 1. Model and Processor Setup ---
 
9
  model_id = "bharatgenai/patram-7b-instruct"
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  print(f"Using device: {device}")
 
20
  )
21
  print("Model and processor loaded successfully.")
22
 
23
+ # Default system prompt
24
+ DEFAULT_SYSTEM_PROMPT = """You are Patram, a helpful AI assistant created by BharatGenAI. You are designed to analyze images and answer questions about them.
25
+ Think step by step before providing your answers. Be detailed, accurate, and helpful in your responses.
26
+ You can understand both text and image inputs to provide comprehensive answers to user queries."""
27
+
28
  # --- Define and apply a more flexible chat template ---
29
  chat_template = """{% for message in messages %}
30
  {{ message['role'].capitalize() }}: {{ message['content'] }}
 
53
  add_generation_prompt=True
54
  )
55
 
56
+ if image_pil:
57
+ # Preprocess image and the entire formatted prompt
58
+ inputs = processor.process(images=[image_pil], text=prompt)
59
+ else:
60
+ inputs = processor.process(text=prompt)
61
  inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
62
  inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
63
 
 
100
  """
101
  This function handles the chat logic for a single turn with streaming.
102
  """
 
 
 
103
 
104
  # Append user's message to the chatbot display list
105
  chatbot_display.append((user_message, ""))
 
117
 
118
  def clear_chat():
119
  """Resets the chat, history, and image."""
120
+ return [], [], None, "", 256, 0.9, 50, 0.6, DEFAULT_SYSTEM_PROMPT
121
 
122
  # --- 3. Gradio Interface Definition ---
123
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
 
131
  with gr.Row():
132
  with gr.Column(scale=1):
133
  image_input_render = gr.Image(type="pil", label="Upload Image")
134
+ system_prompt = gr.Textbox(
135
+ label="System Prompt",
136
+ value=DEFAULT_SYSTEM_PROMPT,
137
+ interactive=True,
138
+ lines=5
139
+ )
140
  clear_btn = gr.Button("🗑️ Clear Chat and Image")
141
  with gr.Accordion("Generation Parameters", open=False):
142
  max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=256, step=32, label="Max New Tokens")
 
159
  )
160
  submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
161
 
162
+ # Initialize messages_list with system prompt
163
+ demo.load(
164
+ fn=lambda: [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}],
165
+ inputs=None,
166
+ outputs=messages_list
167
+ )
168
+
169
+ # Update messages_list when system prompt changes
170
+ system_prompt.change(
171
+ fn=lambda system_prompt: [{"role": "system", "content": system_prompt}],
172
+ inputs=system_prompt,
173
+ outputs=messages_list
174
+ )
175
+
176
  # --- Event Listeners ---
177
 
178
  # Define the action for submitting a message (via button or enter key)
 
191
  clear_btn.click(
192
  fn=clear_chat,
193
  inputs=[],
194
+ outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature, system_prompt],
195
  queue=False
196
  )
197