Update app.py
app.py CHANGED
@@ -101,7 +101,7 @@ class UNetWrapper:
 
     def save_checkpoint(self, save_path):
         """Save checkpoint with model, optimizer, and scheduler states."""
-        save_dict = {
+        self.save_dict = {
             'model_state_dict': self.model.state_dict(),
             'optimizer_state_dict': self.optimizer.state_dict(),
             'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
@@ -112,7 +112,7 @@ class UNetWrapper:
             'epoch': self.epoch,
             'loss': self.loss
         }
-        torch.save(save_dict, save_path)
+        torch.save(self.save_dict, save_path)
         print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
 
     def load_checkpoint(self, checkpoint_path):
@@ -162,8 +162,8 @@ pipeline_tag: image-to-image
 
 ## Model Description
 Custom UNet model for Pix2Pix image translation.
-- **Image Size:** {save_dict['model_config']['img_size']}
-- **Model Type:** {"big" if big else "small"}_UNet ({save_dict['model_config']['img_size']})
+- **Image Size:** {self.save_dict['model_config']['img_size']}
+- **Model Type:** {"big" if big else "small"}_UNet ({self.save_dict['model_config']['img_size']})
 
 ## Usage
 
@@ -188,8 +188,8 @@ model.eval()
         # Save and upload README
         with open("README.md", "w") as f:
             f.write(f"# Pix2Pix UNet Model\n\n"
-                    f"- **Image Size:** {save_dict['model_config']['img_size']}\n"
-                    f"- **Model Type:** {'big' if big else 'small'}_UNet ({save_dict['model_config']['img_size']})\n"
+                    f"- **Image Size:** {self.save_dict['model_config']['img_size']}\n"
+                    f"- **Model Type:** {'big' if big else 'small'}_UNet ({self.save_dict['model_config']['img_size']})\n"
                     f"## Model Architecture\n{str(self.model)}")
 
         self.api.upload_file(
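For context, a minimal sketch (not the Space's actual code; the class shape, `img_size`, and the `write_model_card` helper are assumptions) of why this commit promotes the local `save_dict` to `self.save_dict`: the model-card text built later reads `model_config` from the same dictionary, which is out of scope if it stays local to `save_checkpoint`.

```python
import torch


class UNetWrapper:
    """Reduced sketch of the wrapper; only the parts touched by this commit."""

    def __init__(self, model, optimizer, scheduler=None, img_size=256):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.img_size = img_size
        self.epoch = 0
        self.loss = 0.0

    def save_checkpoint(self, save_path):
        """Save checkpoint with model, optimizer, and scheduler states."""
        # Stored on the instance so later steps (README generation, upload)
        # can still read the config after the checkpoint is written.
        self.save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'model_config': {'img_size': self.img_size},  # assumed key; the README block reads it
            'epoch': self.epoch,
            'loss': self.loss,
        }
        torch.save(self.save_dict, save_path)
        print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")

    def write_model_card(self, big=False, path="README.md"):
        # Hypothetical helper mirroring the README block in the diff; with a
        # plain local `save_dict` this lookup would raise a NameError here.
        img_size = self.save_dict['model_config']['img_size']
        with open(path, "w") as f:
            f.write(f"# Pix2Pix UNet Model\n\n"
                    f"- **Image Size:** {img_size}\n"
                    f"- **Model Type:** {'big' if big else 'small'}_UNet ({img_size})\n"
                    f"## Model Architecture\n{str(self.model)}")
```

Calling `save_checkpoint(...)` before `write_model_card(...)` is what makes the second step safe; in the other order the attribute would not exist yet and the lookup would raise `AttributeError`.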