feiyang-cai commited on
Commit
a18ead3
·
1 Parent(s): 1d1d4f3
Files changed (1) hide show
  1. utils.py +3 -0
utils.py CHANGED
@@ -279,10 +279,13 @@ class MolecularGenerationModel():
279
  input_length = batch['input_ids'].shape[1]
280
  steps = 1024 - input_length
281
 
 
282
  with torch.set_grad_enabled(False):
283
  early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
284
  for k in range(steps):
 
285
  logits = self.model(**batch)['logits']
 
286
  logits = logits[:, -1, :] / temperature
287
  probs = F.softmax(logits, dim=-1)
288
  ix = torch.multinomial(probs, num_samples=num_generations)
 
279
  input_length = batch['input_ids'].shape[1]
280
  steps = 1024 - input_length
281
 
282
+ print(self.model.device, "model_device")
283
  with torch.set_grad_enabled(False):
284
  early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
285
  for k in range(steps):
286
+ print("batch", batch)
287
  logits = self.model(**batch)['logits']
288
+ print("logits", logits)
289
  logits = logits[:, -1, :] / temperature
290
  probs = F.softmax(logits, dim=-1)
291
  ix = torch.multinomial(probs, num_samples=num_generations)