Update README.md
Browse files
README.md
CHANGED
@@ -124,33 +124,45 @@ def encode_prompt(
|
|
124 |
prompt=""
|
125 |
):
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
padding
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
text_encoder_output = prior_pipe.text_encoder(
|
153 |
-
input_ids, attention_mask=
|
154 |
)
|
155 |
|
156 |
prompt_embeds = text_encoder_output.hidden_states[-1].reshape(1,-1,1280)
|
@@ -169,7 +181,7 @@ quality_prompt = "very aesthetic, best quality, newest"
|
|
169 |
negative_prompt = "very displeasing, displeasing, worst quality, bad quality, low quality, realistic, monochrome, comic, sketch, oldest, early, artist name, signature, blurry, simple background, upside down, interlocked fingers,"
|
170 |
num_images_per_prompt=1
|
171 |
|
172 |
-
# Encode prompts and quality prompts separately, don't use attention masks
|
173 |
# pipe, device, num_images_per_prompt, prompt
|
174 |
empty_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt="")
|
175 |
|
|
|
124 |
prompt=""
|
125 |
):
|
126 |
|
127 |
+
if prompt == "":
|
128 |
+
text_inputs = prior_pipe.tokenizer(
|
129 |
+
prompt,
|
130 |
+
padding="max_length",
|
131 |
+
max_length=77,
|
132 |
+
truncation=False,
|
133 |
+
return_tensors="pt",
|
134 |
+
)
|
135 |
+
input_ids = text_inputs.input_ids
|
136 |
+
attention_mask=None
|
137 |
+
else:
|
138 |
+
text_inputs = prior_pipe.tokenizer(
|
139 |
+
prompt,
|
140 |
+
padding="longest",
|
141 |
+
truncation=False,
|
142 |
+
return_tensors="pt",
|
143 |
+
)
|
144 |
+
chunk = []
|
145 |
+
padding = []
|
146 |
+
max_len = 75
|
147 |
+
start_token = text_inputs.input_ids[:,0].unsqueeze(0)
|
148 |
+
end_token = text_inputs.input_ids[:,-1].unsqueeze(0)
|
149 |
+
raw_input_ids = text_inputs.input_ids[:,1:-1]
|
150 |
+
prompt_len = len(raw_input_ids[0])
|
151 |
+
last_lenght = prompt_len % max_len
|
152 |
+
|
153 |
+
for i in range(int((prompt_len - last_lenght) / max_len)):
|
154 |
+
chunk.append(torch.cat([start_token, raw_input_ids[:,i*max_len:(i+1)*max_len], end_token], dim=1))
|
155 |
+
for i in range(max_len - last_lenght):
|
156 |
+
padding.append(text_inputs.input_ids[:,-1])
|
157 |
+
|
158 |
+
last_chunk = torch.cat([raw_input_ids[:,prompt_len-last_lenght:], torch.tensor([padding])], dim=1)
|
159 |
+
chunk.append(torch.cat([start_token, last_chunk, end_token], dim=1))
|
160 |
+
input_ids = torch.cat(chunk, dim=0)
|
161 |
+
attention_mask = torch.ones(input_ids.shape, device=device, dtype=torch.int64)
|
162 |
+
attention_mask[-1,last_lenght+1:] = 0
|
163 |
+
|
164 |
text_encoder_output = prior_pipe.text_encoder(
|
165 |
+
input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
166 |
)
|
167 |
|
168 |
prompt_embeds = text_encoder_output.hidden_states[-1].reshape(1,-1,1280)
|
|
|
181 |
negative_prompt = "very displeasing, displeasing, worst quality, bad quality, low quality, realistic, monochrome, comic, sketch, oldest, early, artist name, signature, blurry, simple background, upside down, interlocked fingers,"
|
182 |
num_images_per_prompt=1
|
183 |
|
184 |
+
# Encode prompts and quality prompts separately, long prompt support and don't use attention masks for empty prompts:
|
185 |
# pipe, device, num_images_per_prompt, prompt
|
186 |
empty_prompt_embeds, _ = encode_prompt(pipe.prior_pipe, device, num_images_per_prompt, prompt="")
|
187 |
|