Spaces:
Runtime error
Runtime error
Update infer.py
Browse files
infer.py
CHANGED
|
@@ -63,7 +63,7 @@ def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history
|
|
| 63 |
# step1: initializing
|
| 64 |
model.to('cuda')
|
| 65 |
tokenizer_voila.to('cuda')
|
| 66 |
-
if ref_embs:
|
| 67 |
ref_embs = ref_embs.to('cuda')
|
| 68 |
ref_embs_mask = ref_embs_mask.to('cuda')
|
| 69 |
num_codebooks = model.config.num_codebooks
|
|
|
|
| 63 |
# step1: initializing
|
| 64 |
model.to('cuda')
|
| 65 |
tokenizer_voila.to('cuda')
|
| 66 |
+
if isinstance(ref_embs, torch.Tensor):
|
| 67 |
ref_embs = ref_embs.to('cuda')
|
| 68 |
ref_embs_mask = ref_embs_mask.to('cuda')
|
| 69 |
num_codebooks = model.config.num_codebooks
|