Update generator.py
Browse files- generator.py +1 -0
generator.py
CHANGED
@@ -177,6 +177,7 @@ def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
|
|
177 |
model = Model(model_args).to(device=device, dtype=torch.bfloat16)
|
178 |
state_dict = torch.load(ckpt_path)
|
179 |
model.load_state_dict(state_dict)
|
|
|
180 |
|
181 |
generator = Generator(model)
|
182 |
return generator
|
|
|
177 |
model = Model(model_args).to(device=device, dtype=torch.bfloat16)
|
178 |
state_dict = torch.load(ckpt_path)
|
179 |
model.load_state_dict(state_dict)
|
180 |
+
model.decoder = torch.compile(model.decoder, fullgraph=True, mode='max-autotune')
|
181 |
|
182 |
generator = Generator(model)
|
183 |
return generator
|