Disty0 commited on
Commit
9b16786
·
verified ·
1 Parent(s): de04131

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -3
README.md CHANGED
@@ -153,10 +153,16 @@ def encode_prompt(
153
  input_ids, attention_mask=None, output_hidden_states=True
154
  )
155
 
156
- start_embed = text_encoder_output.hidden_states[-1][:,0].unsqueeze(0)
157
- end_embed = text_encoder_output.hidden_states[-1][:,-1].unsqueeze(0)
158
  prompt_embeds = text_encoder_output.hidden_states[-1][:,1:-1].reshape(1,-1,1280)
159
- prompt_embeds = torch.cat([start_embed, prompt_embeds, end_embed], dim=1)
 
 
 
 
 
 
160
  prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
161
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
162
 
 
153
  input_ids, attention_mask=None, output_hidden_states=True
154
  )
155
 
156
+ start_embed = text_encoder_output.hidden_states[-1][0][0].unsqueeze(0).unsqueeze(0)
157
+ end_embed = text_encoder_output.hidden_states[-1][0][0].unsqueeze(0).unsqueeze(0)
158
  prompt_embeds = text_encoder_output.hidden_states[-1][:,1:-1].reshape(1,-1,1280)
159
+
160
+ padding = []
161
+ for i in range((max_len + 1) - (prompt_embeds.shape[1] % 77)):
162
+ padding.append(end_embed)
163
+ padding = torch.cat(padding, dim=1)
164
+
165
+ prompt_embeds = torch.cat([start_embed, prompt_embeds, padding], dim=1)
166
  prompt_embeds = prompt_embeds.to(dtype=prior_pipe.text_encoder.dtype, device=device)
167
  prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
168