Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -137,13 +137,15 @@ def predict(text, speaker):
|
|
137 |
|
138 |
### ### ###
|
139 |
example = dataset['test'][11]
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
144 |
|
145 |
-
|
146 |
-
# spectrogram = model.generate_speech(inputs["input_ids"], speaker_embedding)
|
147 |
with torch.no_grad():
|
148 |
speech = vocoder(spectrogram)
|
149 |
# speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
|
|
|
137 |
|
138 |
### ### ###
|
139 |
example = dataset['test'][11]
|
140 |
+
speaker_embedding = torch.tensor(example["speaker_embeddings"]).unsqueeze(0).to(device)
|
141 |
+
|
142 |
+
# Ensure the speaker_embedding has the correct dimensions
|
143 |
+
if speaker_embedding.dim() == 2:
|
144 |
+
speaker_embedding = speaker_embedding.unsqueeze(1).expand(-1, inputs["input_ids"].size(1), -1)
|
145 |
+
elif speaker_embedding.dim() == 3:
|
146 |
+
speaker_embedding = speaker_embedding.expand(-1, inputs["input_ids"].size(1), -1)
|
147 |
|
148 |
+
spectrogram = model.generate_speech(inputs["input_ids"].to(device), speaker_embedding)
|
|
|
149 |
with torch.no_grad():
|
150 |
speech = vocoder(spectrogram)
|
151 |
# speech = model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
|