K00B404 commited on
Commit
4fe8b5e
·
verified ·
1 Parent(s): 709936e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -101,7 +101,7 @@ class UNetWrapper:
101
 
102
  def save_checkpoint(self, save_path):
103
  """Save checkpoint with model, optimizer, and scheduler states."""
104
- save_dict = {
105
  'model_state_dict': self.model.state_dict(),
106
  'optimizer_state_dict': self.optimizer.state_dict(),
107
  'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
@@ -112,7 +112,7 @@ class UNetWrapper:
112
  'epoch': self.epoch,
113
  'loss': self.loss
114
  }
115
- torch.save(save_dict, save_path)
116
  print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
117
 
118
  def load_checkpoint(self, checkpoint_path):
@@ -162,8 +162,8 @@ pipeline_tag: image-to-image
162
 
163
  ## Model Description
164
  Custom UNet model for Pix2Pix image translation.
165
- - **Image Size:** {save_dict['model_config']['img_size']}
166
- - **Model Type:** {"big" if big else "small"}_UNet ({save_dict['model_config']['img_size']})
167
 
168
  ## Usage
169
 
@@ -188,8 +188,8 @@ model.eval()
188
  # Save and upload README
189
  with open("README.md", "w") as f:
190
  f.write(f"# Pix2Pix UNet Model\n\n"
191
- f"- **Image Size:** {save_dict['model_config']['img_size']}\n"
192
- f"- **Model Type:** {'big' if big else 'small'}_UNet ({save_dict['model_config']['img_size']})\n"
193
  f"## Model Architecture\n{str(self.model)}")
194
 
195
  self.api.upload_file(
 
101
 
102
  def save_checkpoint(self, save_path):
103
  """Save checkpoint with model, optimizer, and scheduler states."""
104
+ self.save_dict = {
105
  'model_state_dict': self.model.state_dict(),
106
  'optimizer_state_dict': self.optimizer.state_dict(),
107
  'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
 
112
  'epoch': self.epoch,
113
  'loss': self.loss
114
  }
115
+ torch.save(self.save_dict, save_path)
116
  print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
117
 
118
  def load_checkpoint(self, checkpoint_path):
 
162
 
163
  ## Model Description
164
  Custom UNet model for Pix2Pix image translation.
165
+ - **Image Size:** {self.save_dict['model_config']['img_size']}
166
+ - **Model Type:** {"big" if big else "small"}_UNet ({self.save_dict['model_config']['img_size']})
167
 
168
  ## Usage
169
 
 
188
  # Save and upload README
189
  with open("README.md", "w") as f:
190
  f.write(f"# Pix2Pix UNet Model\n\n"
191
+ f"- **Image Size:** {self.save_dict['model_config']['img_size']}\n"
192
+ f"- **Model Type:** {'big' if big else 'small'}_UNet ({self.save_dict['model_config']['img_size']})\n"
193
  f"## Model Architecture\n{str(self.model)}")
194
 
195
  self.api.upload_file(