K00B404 commited on
Commit
3a06055
·
verified ·
1 Parent(s): a52c65b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -83
app.py CHANGED
@@ -90,20 +90,6 @@ class Pix2PixDataset(torch.utils.data.Dataset):
90
 
91
  class UNetWrapper:
92
 
93
- def push_to_hub(self, pth_name):
94
- """Push model checkpoint and metadata to the Hugging Face Hub."""
95
- try:
96
- self.api.upload_file(
97
- path_or_fileobj=pth_name,
98
- path_in_repo=pth_name,
99
- repo_id=self.repo_id,
100
- token=self.token,
101
- repo_type="model"
102
- )
103
- print(f"Model checkpoint successfully uploaded to {self.repo_id}")
104
- except Exception as e:
105
- print(f"Error uploading model: {e}")
106
-
107
 
108
 
109
 
@@ -146,52 +132,25 @@ class UNetWrapper:
146
  self.loss = checkpoint['loss']
147
  print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
148
 
149
-
150
- def push_to_hub(self):
151
  try:
152
- # Save model state and configuration
153
- save_dict = {
154
- 'model_state_dict': self.model.state_dict(),
155
- 'model_config': {
156
- 'big': isinstance(self.model, big_UNet),
157
- 'img_size': 1024 if isinstance(self.model, big_UNet) else 256
158
- },
159
- 'model_architecture': str(self.model),
160
- 'model_state':{
161
- 'epoch': self.epoch,
162
- 'loss': self.loss
163
- }
164
- }
165
-
166
- # Save model locally
167
- pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
168
- torch.save(save_dict, pth_name)
169
-
170
- # Create repo if it doesn't exist
171
- try:
172
- create_repo(
173
- repo_id=self.repo_id,
174
- token=self.token,
175
- exist_ok=True
176
- )
177
- except Exception as e:
178
- print(f"Repository creation note: {e}")
179
-
180
- """Push model checkpoint and metadata to the Hugging Face Hub."""
181
- try:
182
- self.api.upload_file(
183
- path_or_fileobj=pth_name,
184
- path_in_repo=pth_name,
185
- repo_id=self.repo_id,
186
- token=self.token,
187
- repo_type="model"
188
- )
189
- print(f"Model checkpoint successfully uploaded to {self.repo_id}")
190
- except Exception as e:
191
- print(f"Error uploading model: {e}")
192
 
193
- # Create and upload model card
194
- model_card = f"""---
195
  tags:
196
  - unet
197
  - pix2pix
@@ -230,31 +189,31 @@ model.eval()
230
  ## Model Architecture
231
 
232
  {str(self.model)} """
233
- rp(model_card)
234
-
235
- # Save and upload README
236
- with open("README.md", "w") as f:
237
- f.write(f"# Pix2Pix UNet Model\n\n"
238
- f"- **Image Size:** {save_dict['model_config']['img_size']}\n"
239
- f"- **Model Type:** {'big' if big else 'small'}_UNet ({save_dict['model_config']['img_size']})\n"
240
- f"## Model Architecture\n{str(self.model)}")
241
-
242
- self.api.upload_file(
243
- path_or_fileobj="README.md",
244
- path_in_repo="README.md",
245
- repo_id=self.repo_id,
246
- token=self.token,
247
- repo_type="model"
248
- )
249
-
250
- # Clean up local files
251
- os.remove(pth_name)
252
- os.remove("README.md")
253
-
254
- print(f"Model successfully uploaded to {self.repo_id}")
255
-
256
- except Exception as e:
257
- print(f"Error uploading model: {e}")
258
 
259
  def prepare_input(image, device='cpu'):
260
  """Prepare image for inference"""
 
90
 
91
  class UNetWrapper:
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
 
95
 
 
132
  self.loss = checkpoint['loss']
133
  print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
134
 
135
+ def push_to_hub(self, pth_name):
136
+ """Push model checkpoint and metadata to the Hugging Face Hub."""
137
  try:
138
+ self.api.upload_file(
139
+ path_or_fileobj=pth_name,
140
+ path_in_repo=pth_name,
141
+ repo_id=self.repo_id,
142
+ token=self.token,
143
+ repo_type="model"
144
+ )
145
+ print(f"Model checkpoint successfully uploaded to {self.repo_id}")
146
+ except Exception as e:
147
+ print(f"Error uploading model: {e}")
148
+
149
+
150
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ # Create and upload model card
153
+ model_card = f"""---
154
  tags:
155
  - unet
156
  - pix2pix
 
189
  ## Model Architecture
190
 
191
  {str(self.model)} """
192
+ rp(model_card)
193
+
194
+ # Save and upload README
195
+ with open("README.md", "w") as f:
196
+ f.write(f"# Pix2Pix UNet Model\n\n"
197
+ f"- **Image Size:** {save_dict['model_config']['img_size']}\n"
198
+ f"- **Model Type:** {'big' if big else 'small'}_UNet ({save_dict['model_config']['img_size']})\n"
199
+ f"## Model Architecture\n{str(self.model)}")
200
+
201
+ self.api.upload_file(
202
+ path_or_fileobj="README.md",
203
+ path_in_repo="README.md",
204
+ repo_id=self.repo_id,
205
+ token=self.token,
206
+ repo_type="model"
207
+ )
208
+
209
+ # Clean up local files
210
+ os.remove(pth_name)
211
+ os.remove("README.md")
212
+
213
+ print(f"Model successfully uploaded to {self.repo_id}")
214
+
215
+ except Exception as e:
216
+ print(f"Error uploading model: {e}")
217
 
218
  def prepare_input(image, device='cpu'):
219
  """Prepare image for inference"""