K00B404 commited on
Commit
1695057
·
verified ·
1 Parent(s): 911fc05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -67,8 +67,20 @@ class UNetWrapper:
67
  api = HfApi()
68
  # Create a repository if it doesn't exist
69
  create_repo(self.repo_id, exist_ok=True,token=self.token)
 
 
 
 
 
 
 
 
 
 
 
 
70
  # Push the model's state dict to the Hugging Face Hub
71
- self.model.save_pretrained(self.repo_id,token=self.token) # You may need to implement this method
72
 
73
 
74
  # Training function
@@ -94,7 +106,7 @@ def train_model(epochs):
94
 
95
  criterion = nn.L1Loss()
96
  optimizer = optim.Adam(model.parameters(), lr=LR)
97
-
98
  # Training loop
99
  for epoch in range(epochs):
100
  for i, (original, target) in enumerate(dataloader):
@@ -110,7 +122,12 @@ def train_model(epochs):
110
  optimizer.step()
111
 
112
  if i % 100 == 0:
113
- print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")
 
 
 
 
 
114
 
115
  # Return trained model
116
  return model
 
67
  api = HfApi()
68
  # Create a repository if it doesn't exist
69
  create_repo(self.repo_id, exist_ok=True,token=self.token)
70
+ pth_name = 'model_weights.pth'
71
+ torch.save(self.model.state_dict(), pth_name)
72
+ from huggingface_hub import upload_file
73
+
74
+ upload_file(
75
+ path_or_fileobj=pth_name,
76
+ path_in_repo=pth_name,
77
+ repo_id=self.repo_id,,
78
+ token=self.token,
79
+
80
+ )
81
+ #api.upload_file(repo_id=self.repo_id, path_in_repo=pth_name, path_or_fileobj=pth_name)
82
  # Push the model's state dict to the Hugging Face Hub
83
+ #self.model.save_pretrained(self.repo_id,token=self.token) # You may need to implement this method
84
 
85
 
86
  # Training function
 
106
 
107
  criterion = nn.L1Loss()
108
  optimizer = optim.Adam(model.parameters(), lr=LR)
109
+ output_text = []
110
  # Training loop
111
  for epoch in range(epochs):
112
  for i, (original, target) in enumerate(dataloader):
 
122
  optimizer.step()
123
 
124
  if i % 100 == 0:
125
+ status=f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
126
+ print(status)
127
+ output_text.append(status)
128
+
129
+ # Here you could also use a delay to simulate training time
130
+ yield "\n".join(output_text) # Send output to Gradio
131
 
132
  # Return trained model
133
  return model