Update audiocraft/models/musicgen.py
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -78,11 +78,14 @@ class MusicGen:
|
|
78 |
self.generation_params: dict = {}
|
79 |
self.set_generation_params(duration=15) # 15 seconds by default
|
80 |
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
|
|
81 |
if self.device.type == 'cpu':
|
82 |
self.autocast = TorchAutocast(enabled=False)
|
83 |
else:
|
84 |
self.autocast = TorchAutocast(
|
85 |
enabled=True, device_type=self.device.type, dtype=torch.bfloat16)
|
|
|
|
|
86 |
|
87 |
@property
|
88 |
def frame_rate(self) -> float:
|
@@ -173,7 +176,7 @@ class MusicGen:
|
|
173 |
"""Override the default progress callback."""
|
174 |
self._progress_callback = progress_callback
|
175 |
|
176 |
-
def generate_unconditional(self, num_samples: int, progress: bool =
|
177 |
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
178 |
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
179 |
"""Generate samples in an unconditional manner.
|
@@ -189,7 +192,7 @@ class MusicGen:
|
|
189 |
return self.generate_audio(tokens), tokens
|
190 |
return self.generate_audio(tokens)
|
191 |
|
192 |
-
def generate(self, descriptions: tp.List[str], progress: bool =
|
193 |
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
194 |
"""Generate samples conditioned on text.
|
195 |
|
@@ -205,7 +208,7 @@ class MusicGen:
|
|
205 |
return self.generate_audio(tokens)
|
206 |
|
207 |
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
208 |
-
melody_sample_rate: int, progress: bool =
|
209 |
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
210 |
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
211 |
"""Generate samples conditioned on text and melody.
|
@@ -244,7 +247,7 @@ class MusicGen:
|
|
244 |
|
245 |
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
246 |
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
247 |
-
progress: bool =
|
248 |
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
249 |
"""Generate samples conditioned on audio prompts.
|
250 |
|
@@ -328,7 +331,7 @@ class MusicGen:
|
|
328 |
return attributes, prompt_tokens
|
329 |
|
330 |
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
331 |
-
prompt_tokens: tp.Optional[torch.Tensor], progress: bool =
|
332 |
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
333 |
|
334 |
Args:
|
|
|
78 |
self.generation_params: dict = {}
|
79 |
self.set_generation_params(duration=15) # 15 seconds by default
|
80 |
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
81 |
+
'''
|
82 |
if self.device.type == 'cpu':
|
83 |
self.autocast = TorchAutocast(enabled=False)
|
84 |
else:
|
85 |
self.autocast = TorchAutocast(
|
86 |
enabled=True, device_type=self.device.type, dtype=torch.bfloat16)
|
87 |
+
'''
|
88 |
+
self.autocast = TorchAutocast(enabled=False)
|
89 |
|
90 |
@property
|
91 |
def frame_rate(self) -> float:
|
|
|
176 |
"""Override the default progress callback."""
|
177 |
self._progress_callback = progress_callback
|
178 |
|
179 |
+
def generate_unconditional(self, num_samples: int, progress: bool = True,
|
180 |
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
181 |
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
182 |
"""Generate samples in an unconditional manner.
|
|
|
192 |
return self.generate_audio(tokens), tokens
|
193 |
return self.generate_audio(tokens)
|
194 |
|
195 |
+
def generate(self, descriptions: tp.List[str], progress: bool = True, return_tokens: bool = False) \
|
196 |
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
197 |
"""Generate samples conditioned on text.
|
198 |
|
|
|
208 |
return self.generate_audio(tokens)
|
209 |
|
210 |
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
211 |
+
melody_sample_rate: int, progress: bool = True,
|
212 |
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
213 |
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
214 |
"""Generate samples conditioned on text and melody.
|
|
|
247 |
|
248 |
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
249 |
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
250 |
+
progress: bool = True, return_tokens: bool = False) \
|
251 |
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
252 |
"""Generate samples conditioned on audio prompts.
|
253 |
|
|
|
331 |
return attributes, prompt_tokens
|
332 |
|
333 |
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
334 |
+
prompt_tokens: tp.Optional[torch.Tensor], progress: bool = True) -> torch.Tensor:
|
335 |
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
336 |
|
337 |
Args:
|