alvanlii committed on
Commit
00d336f
·
1 Parent(s): de65f8b

Use Gradio state instead of an instance attribute for the chat history

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -23,7 +23,6 @@ class ChatBotCheese:
23
  model_ckpt_path = hf_hub_download("alvanlii/fromage", "pretrained_ckpt.pth.tar")
24
  self.model = models.load_fromage(MODEL_DIR, model_ckpt_path)
25
  self.curr_image = None
26
- self.chat_history = ''
27
 
28
  def add_image(self, state, image_in):
29
  state = state + [(f"![](/file={image_in.name})", "Ok, now type your message")]
@@ -35,14 +34,14 @@ class ChatBotCheese:
35
  image_pil.save(file_name)
36
  return file_name
37
 
38
- def chat(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
39
- # model_outputs = ["heyo", []]
40
- self.chat_history += f'Q: {input_text} \nA:'
41
  if self.curr_image is not None:
42
- model_outputs = self.model.generate_for_images_and_texts([self.curr_image, self.chat_history], num_words=num_words, max_num_rets=num_ims, ret_scale_factor=ret_scale_factor, temperature=temp)
43
  else:
44
- model_outputs = self.model.generate_for_images_and_texts([self.chat_history], max_num_rets=num_ims, num_words=num_words, ret_scale_factor=ret_scale_factor, temperature=temp)
45
- self.chat_history += ' '.join([s for s in model_outputs if type(s) == str]) + '\n'
46
 
47
  im_names = []
48
  if len(model_outputs) > 1:
@@ -52,11 +51,10 @@ class ChatBotCheese:
52
  for im_name in im_names:
53
  response += f'<img src="/file={im_name}">'
54
  state.append((input_text, response.replace("[RET]", "")))
55
- self.curr_image = None
56
- return state, state
57
 
58
  def reset(self):
59
- self.chat_history = ""
60
  self.curr_image = None
61
  return [], []
62
 
@@ -66,7 +64,7 @@ class ChatBotCheese:
66
  """
67
  ### FROMAGe: Grounding Language Models to Images for Multimodal Generation
68
  Jing Yu Koh, Ruslan Salakhutdinov, Daniel Fried <br/>
69
- [Paper](https://arxiv.org/abs/2301.13823) [Github](https://github.com/kohjingyu/fromage) <br/>
70
  This is an unofficial Gradio demo for the paper FROMAGe <br/>
71
  - Instructions (in order):
72
  - [Optional] Upload an image (the button with a photo emoji)
@@ -83,6 +81,7 @@ class ChatBotCheese:
83
 
84
  chatbot = gr.Chatbot(elem_id="chatbot")
85
  gr_state = gr.State([])
 
86
 
87
  with gr.Row():
88
  with gr.Column(scale=0.85):
@@ -104,7 +103,7 @@ class ChatBotCheese:
104
  gr.Image("example_3.png", label="Example 3")
105
 
106
 
107
- txt.submit(self.chat, [txt, gr_state, gr_ret_scale_factor, gr_num_ims, gr_num_words, gr_temp], [gr_state, chatbot])
108
  txt.submit(lambda :"", None, txt)
109
  btn.upload(self.add_image, [gr_state, btn], [gr_state, chatbot])
110
  reset_btn.click(self.reset, [], [gr_state, chatbot])
@@ -119,5 +118,4 @@ def main():
119
  cheddar.main()
120
 
121
  if __name__ == "__main__":
122
- cheddar = ChatBotCheese()
123
- cheddar.main()
 
23
  model_ckpt_path = hf_hub_download("alvanlii/fromage", "pretrained_ckpt.pth.tar")
24
  self.model = models.load_fromage(MODEL_DIR, model_ckpt_path)
25
  self.curr_image = None
 
26
 
27
  def add_image(self, state, image_in):
28
  state = state + [(f"![](/file={image_in.name})", "Ok, now type your message")]
 
34
  image_pil.save(file_name)
35
  return file_name
36
 
37
+ def chat(self, input_text, state, ret_scale_factor, num_ims, num_words, temp, chat_state):
38
+ chat_state.append(f'Q: {input_text} \nA:')
39
+ chat_history = " ".join(chat_state)
40
  if self.curr_image is not None:
41
+ model_outputs = self.model.generate_for_images_and_texts([self.curr_image, chat_history], num_words=num_words, max_num_rets=num_ims, ret_scale_factor=ret_scale_factor, temperature=temp)
42
  else:
43
+ model_outputs = self.model.generate_for_images_and_texts([chat_history], max_num_rets=num_ims, num_words=num_words, ret_scale_factor=ret_scale_factor, temperature=temp)
44
+ chat_state.append(' '.join([s for s in model_outputs if type(s) == str]) + '\n')
45
 
46
  im_names = []
47
  if len(model_outputs) > 1:
 
51
  for im_name in im_names:
52
  response += f'<img src="/file={im_name}">'
53
  state.append((input_text, response.replace("[RET]", "")))
54
+ # self.curr_image = None
55
+ return state, state, chat_state
56
 
57
  def reset(self):
 
58
  self.curr_image = None
59
  return [], []
60
 
 
64
  """
65
  ### FROMAGe: Grounding Language Models to Images for Multimodal Generation
66
  Jing Yu Koh, Ruslan Salakhutdinov, Daniel Fried <br/>
67
+ [Paper](https://arxiv.org/abs/2301.13823) [Github](https://github.com/kohjingyu/fromage) [Official Demo](https://huggingface.co/spaces/jykoh/fromage) <br/>
68
  This is an unofficial Gradio demo for the paper FROMAGe <br/>
69
  - Instructions (in order):
70
  - [Optional] Upload an image (the button with a photo emoji)
 
81
 
82
  chatbot = gr.Chatbot(elem_id="chatbot")
83
  gr_state = gr.State([])
84
+ gr_chat_state = gr.State([])
85
 
86
  with gr.Row():
87
  with gr.Column(scale=0.85):
 
103
  gr.Image("example_3.png", label="Example 3")
104
 
105
 
106
+ txt.submit(self.chat, [txt, gr_state, gr_ret_scale_factor, gr_num_ims, gr_num_words, gr_temp, gr_chat_state], [gr_state, chatbot, gr_chat_state])
107
  txt.submit(lambda :"", None, txt)
108
  btn.upload(self.add_image, [gr_state, btn], [gr_state, chatbot])
109
  reset_btn.click(self.reset, [], [gr_state, chatbot])
 
118
  cheddar.main()
119
 
120
  if __name__ == "__main__":
121
+ main()