Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
a18ead3
1
Parent(s):
1d1d4f3
update
Browse files
utils.py
CHANGED
@@ -279,10 +279,13 @@ class MolecularGenerationModel():
|
|
279 |
input_length = batch['input_ids'].shape[1]
|
280 |
steps = 1024 - input_length
|
281 |
|
|
|
282 |
with torch.set_grad_enabled(False):
|
283 |
early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
|
284 |
for k in range(steps):
|
|
|
285 |
logits = self.model(**batch)['logits']
|
|
|
286 |
logits = logits[:, -1, :] / temperature
|
287 |
probs = F.softmax(logits, dim=-1)
|
288 |
ix = torch.multinomial(probs, num_samples=num_generations)
|
|
|
279 |
input_length = batch['input_ids'].shape[1]
|
280 |
steps = 1024 - input_length
|
281 |
|
282 |
+
print(self.model.device, "model_device")
|
283 |
with torch.set_grad_enabled(False):
|
284 |
early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
|
285 |
for k in range(steps):
|
286 |
+
print("batch", batch)
|
287 |
logits = self.model(**batch)['logits']
|
288 |
+
print("logits", logits)
|
289 |
logits = logits[:, -1, :] / temperature
|
290 |
probs = F.softmax(logits, dim=-1)
|
291 |
ix = torch.multinomial(probs, num_samples=num_generations)
|