Spaces:
Runtime error
Runtime error
Li
commited on
Commit
·
d3fbc73
1
Parent(s):
f407227
“update”
Browse files
app.py
CHANGED
@@ -92,6 +92,9 @@ def generate(
|
|
92 |
all_ids = set(range(flamingo.lang_encoder.lm_head.out_features))
|
93 |
bad_words_ids = list(all_ids - set(loc_token_ids))
|
94 |
bad_words_ids = [[b] for b in bad_words_ids]
|
|
|
|
|
|
|
95 |
min_loc_token_id = min(loc_token_ids)
|
96 |
max_loc_token_id = max(loc_token_ids)
|
97 |
image_ori = image
|
@@ -103,9 +106,11 @@ def generate(
|
|
103 |
if idx == 1:
|
104 |
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}<|#loc#|>"]
|
105 |
bad_words_ids = None
|
|
|
106 |
else:
|
107 |
prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
|
108 |
-
bad_words_ids =
|
|
|
109 |
encodings = tokenizer(
|
110 |
prompt,
|
111 |
padding="longest",
|
@@ -122,7 +127,7 @@ def generate(
|
|
122 |
model=flamingo,
|
123 |
batch_images=batch_images,
|
124 |
attention_mask=attention_mask,
|
125 |
-
max_generation_length=
|
126 |
min_generation_length=4,
|
127 |
num_beams=1,
|
128 |
length_penalty=1.0,
|
|
|
92 |
all_ids = set(range(flamingo.lang_encoder.lm_head.out_features))
|
93 |
bad_words_ids = list(all_ids - set(loc_token_ids))
|
94 |
bad_words_ids = [[b] for b in bad_words_ids]
|
95 |
+
loc_word_ids = list(set(loc_token_ids))
|
96 |
+
loc_word_ids = [[b] for b in loc_word_ids]
|
97 |
+
|
98 |
min_loc_token_id = min(loc_token_ids)
|
99 |
max_loc_token_id = max(loc_token_ids)
|
100 |
image_ori = image
|
|
|
106 |
if idx == 1:
|
107 |
prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}<|#loc#|>"]
|
108 |
bad_words_ids = None
|
109 |
+
max_generation_length = 5
|
110 |
else:
|
111 |
prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
|
112 |
+
bad_words_ids = loc_word_ids
|
113 |
+
max_generation_length = 100
|
114 |
encodings = tokenizer(
|
115 |
prompt,
|
116 |
padding="longest",
|
|
|
127 |
model=flamingo,
|
128 |
batch_images=batch_images,
|
129 |
attention_mask=attention_mask,
|
130 |
+
max_generation_length=max_generation_length,
|
131 |
min_generation_length=4,
|
132 |
num_beams=1,
|
133 |
length_penalty=1.0,
|