Update README.md
Browse files
README.md
CHANGED
@@ -153,10 +153,16 @@ def encode_prompt(
|
|
153 |
input_ids, attention_mask=None, output_hidden_states=True
|
154 |
)
|
155 |
|
156 |
-
start_embed = text_encoder_output.hidden_states[-1][
|
157 |
-
end_embed = text_encoder_output.hidden_states[-1][
|
158 |
prompt_embeds = text_encoder_output.hidden_states[-1][:,1:-1].reshape(1,-1,1280)
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
|
161 |
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
162 |
|
|
|
153 |
input_ids, attention_mask=None, output_hidden_states=True
|
154 |
)
|
155 |
|
156 |
+
start_embed = text_encoder_output.hidden_states[-1][0][0].unsqueeze(0).unsqueeze(0)
|
157 |
+
end_embed = text_encoder_output.hidden_states[-1][0][0].unsqueeze(0).unsqueeze(0)
|
158 |
prompt_embeds = text_encoder_output.hidden_states[-1][:,1:-1].reshape(1,-1,1280)
|
159 |
+
|
160 |
+
padding = []
|
161 |
+
for i in range((max_len + 1) - (prompt_embeds.shape[1] % 77)):
|
162 |
+
padding.append(end_embed)
|
163 |
+
padding = torch.cat(padding, dim=1)
|
164 |
+
|
165 |
+
prompt_embeds = torch.cat([start_embed, prompt_embeds, padding], dim=1)
|
166 |
prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
|
167 |
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
168 |
|