ford442 commited on
Commit
d119ef1
·
verified ·
1 Parent(s): 96f1b56

Update audiocraft/models/musicgen.py

Browse files
Files changed (1) hide show
  1. audiocraft/models/musicgen.py +8 -5
audiocraft/models/musicgen.py CHANGED
@@ -78,11 +78,14 @@ class MusicGen:
78
  self.generation_params: dict = {}
79
  self.set_generation_params(duration=15) # 15 seconds by default
80
  self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
 
81
  if self.device.type == 'cpu':
82
  self.autocast = TorchAutocast(enabled=False)
83
  else:
84
  self.autocast = TorchAutocast(
85
  enabled=True, device_type=self.device.type, dtype=torch.bfloat16)
 
 
86
 
87
  @property
88
  def frame_rate(self) -> float:
@@ -173,7 +176,7 @@ class MusicGen:
173
  """Override the default progress callback."""
174
  self._progress_callback = progress_callback
175
 
176
- def generate_unconditional(self, num_samples: int, progress: bool = False,
177
  return_tokens: bool = False) -> tp.Union[torch.Tensor,
178
  tp.Tuple[torch.Tensor, torch.Tensor]]:
179
  """Generate samples in an unconditional manner.
@@ -189,7 +192,7 @@ class MusicGen:
189
  return self.generate_audio(tokens), tokens
190
  return self.generate_audio(tokens)
191
 
192
- def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
193
  -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
194
  """Generate samples conditioned on text.
195
 
@@ -205,7 +208,7 @@ class MusicGen:
205
  return self.generate_audio(tokens)
206
 
207
  def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
208
- melody_sample_rate: int, progress: bool = False,
209
  return_tokens: bool = False) -> tp.Union[torch.Tensor,
210
  tp.Tuple[torch.Tensor, torch.Tensor]]:
211
  """Generate samples conditioned on text and melody.
@@ -244,7 +247,7 @@ class MusicGen:
244
 
245
  def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
246
  descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
247
- progress: bool = False, return_tokens: bool = False) \
248
  -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
249
  """Generate samples conditioned on audio prompts.
250
 
@@ -328,7 +331,7 @@ class MusicGen:
328
  return attributes, prompt_tokens
329
 
330
  def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
331
- prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
332
  """Generate discrete audio tokens given audio prompt and/or conditions.
333
 
334
  Args:
 
78
  self.generation_params: dict = {}
79
  self.set_generation_params(duration=15) # 15 seconds by default
80
  self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
81
+ '''
82
  if self.device.type == 'cpu':
83
  self.autocast = TorchAutocast(enabled=False)
84
  else:
85
  self.autocast = TorchAutocast(
86
  enabled=True, device_type=self.device.type, dtype=torch.bfloat16)
87
+ '''
88
+ self.autocast = TorchAutocast(enabled=False)
89
 
90
  @property
91
  def frame_rate(self) -> float:
 
176
  """Override the default progress callback."""
177
  self._progress_callback = progress_callback
178
 
179
+ def generate_unconditional(self, num_samples: int, progress: bool = True,
180
  return_tokens: bool = False) -> tp.Union[torch.Tensor,
181
  tp.Tuple[torch.Tensor, torch.Tensor]]:
182
  """Generate samples in an unconditional manner.
 
192
  return self.generate_audio(tokens), tokens
193
  return self.generate_audio(tokens)
194
 
195
+ def generate(self, descriptions: tp.List[str], progress: bool = True, return_tokens: bool = False) \
196
  -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
197
  """Generate samples conditioned on text.
198
 
 
208
  return self.generate_audio(tokens)
209
 
210
  def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
211
+ melody_sample_rate: int, progress: bool = True,
212
  return_tokens: bool = False) -> tp.Union[torch.Tensor,
213
  tp.Tuple[torch.Tensor, torch.Tensor]]:
214
  """Generate samples conditioned on text and melody.
 
247
 
248
  def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
249
  descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
250
+ progress: bool = True, return_tokens: bool = False) \
251
  -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
252
  """Generate samples conditioned on audio prompts.
253
 
 
331
  return attributes, prompt_tokens
332
 
333
  def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
334
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = True) -> torch.Tensor:
335
  """Generate discrete audio tokens given audio prompt and/or conditions.
336
 
337
  Args: