Update audiocraft/models/musicgen.py
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -78,14 +78,11 @@ class MusicGen:
|
|
78 |
self.generation_params: dict = {}
|
79 |
self.set_generation_params(duration=15) # 15 seconds by default
|
80 |
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
81 |
-
'''
|
82 |
if self.device.type == 'cpu':
|
83 |
self.autocast = TorchAutocast(enabled=False)
|
84 |
else:
|
85 |
self.autocast = TorchAutocast(
|
86 |
enabled=True, device_type=self.device.type, dtype=torch.bfloat16)
|
87 |
-
'''
|
88 |
-
self.autocast = TorchAutocast(enabled=False)
|
89 |
|
90 |
@property
|
91 |
def frame_rate(self) -> float:
|
|
|
78 |
self.generation_params: dict = {}
|
79 |
self.set_generation_params(duration=15) # 15 seconds by default
|
80 |
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
|
|
81 |
if self.device.type == 'cpu':
|
82 |
self.autocast = TorchAutocast(enabled=False)
|
83 |
else:
|
84 |
self.autocast = TorchAutocast(
|
85 |
enabled=True, device_type=self.device.type, dtype=torch.bfloat16)
|
|
|
|
|
86 |
|
87 |
@property
|
88 |
def frame_rate(self) -> float:
|