K00B404 commited on
Commit
cc04230
·
verified ·
1 Parent(s): 5010115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -25,14 +25,20 @@ LR = 0.0002
25
  dataset_id = "K00B404/pix2pix_flux_set"
26
  model_repo_id = "K00B404/pix2pix_flux"
27
 
 
28
  class Pix2PixDataset(torch.utils.data.Dataset):
29
  def __init__(self, ds):
 
30
  self.originals = [x for x in ds["train"] if x['label'] == 'original']
31
  self.targets = [x for x in ds["train"] if x['label'] == 'target']
32
 
33
- # Ensure original and target images match by their index
34
  assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
35
 
 
 
 
 
36
  def __len__(self):
37
  return len(self.originals)
38
 
 
25
  dataset_id = "K00B404/pix2pix_flux_set"
26
  model_repo_id = "K00B404/pix2pix_flux"
27
 
28
+ # Create dataset and dataloader
29
  class Pix2PixDataset(torch.utils.data.Dataset):
30
  def __init__(self, ds):
31
+ # Filter dataset for 'original' and 'target' images
32
  self.originals = [x for x in ds["train"] if x['label'] == 'original']
33
  self.targets = [x for x in ds["train"] if x['label'] == 'target']
34
 
35
+ # Ensure the number of original and target images match
36
  assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
37
 
38
+ # Debug: Print dataset size
39
+ print(f"Number of original images: {len(self.originals)}")
40
+ print(f"Number of target images: {len(self.targets)}")
41
+
42
  def __len__(self):
43
  return len(self.originals)
44