Update app.py
app.py CHANGED
@@ -67,8 +67,20 @@ class UNetWrapper:
         api = HfApi()
         # Create a repository if it doesn't exist
         create_repo(self.repo_id, exist_ok=True,token=self.token)
+        pth_name = 'model_weights.pth'
+        torch.save(self.model.state_dict(), pth_name)
+        from huggingface_hub import upload_file
+
+        upload_file(
+            path_or_fileobj=pth_name,
+            path_in_repo=pth_name,
+            repo_id=self.repo_id,
+            token=self.token,
+
+        )
+        #api.upload_file(repo_id=self.repo_id, path_in_repo=pth_name, path_or_fileobj=pth_name)
         # Push the model's state dict to the Hugging Face Hub
-        self.model.save_pretrained(self.repo_id,token=self.token) # You may need to implement this method
+        #self.model.save_pretrained(self.repo_id,token=self.token) # You may need to implement this method
 
 
 # Training function
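For context: the hunk above saves the state dict to model_weights.pth and pushes it with huggingface_hub's upload_file. A minimal sketch of the read-back side follows; the repo id is a placeholder and the model class that would consume the state dict is not shown here.

import torch
from huggingface_hub import hf_hub_download

# Pull the uploaded .pth back down; the repo id below is a placeholder,
# not this Space's actual repo.
weights_path = hf_hub_download(
    repo_id="your-username/your-model-repo",
    filename="model_weights.pth",   # same name as pth_name in the commit
)

# Rebuild the state dict on CPU; load it into a matching model instance
# elsewhere with model.load_state_dict(state_dict).
state_dict = torch.load(weights_path, map_location="cpu")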
@@ -94,7 +106,7 @@ def train_model(epochs):
 
     criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
-
+    output_text = []
     # Training loop
     for epoch in range(epochs):
         for i, (original, target) in enumerate(dataloader):
@@ -110,7 +122,12 @@ def train_model(epochs):
             optimizer.step()
 
             if i % 100 == 0:
-
+                status=f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
+                print(status)
+                output_text.append(status)
+
+                # Here you could also use a delay to simulate training time
+                yield "\n".join(output_text) # Send output to Gradio
 
     # Return trained model
     return model
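The yield turns train_model into a generator, so Gradio can stream the accumulated log into the UI while training runs. Below is a minimal sketch of how a Blocks app could wire this up; the component names, labels, and default values are assumptions, not this Space's actual interface, and it assumes the train_model generator above is in scope.

import gradio as gr

# Assumed wiring: stream each yielded status string into a Textbox.
with gr.Blocks() as demo:
    epochs = gr.Number(value=5, label="Epochs", precision=0)
    log = gr.Textbox(label="Training log", lines=12)
    train_btn = gr.Button("Train")
    # train_model is the generator from the hunk above; each yield
    # replaces the Textbox contents with the re-joined output_text.
    train_btn.click(fn=train_model, inputs=epochs, outputs=log)

# Older Gradio releases need the queue enabled for generator outputs;
# newer ones enable it by default.
demo.queue().launch()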