Ali Mohsin commited on
Commit
25bdf34
·
1 Parent(s): 2c856cd

final prod

Browse files
.gitignore ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be added to the global gitignore or merged into this project gitignore. For a PyCharm
158
+ # project, it is recommended to include the following files:
159
+ # .idea/
160
+ # *.iml
161
+ # *.ipr
162
+ # *.iws
163
+ .idea/
164
+ *.iml
165
+ *.ipr
166
+ *.iws
167
+
168
+ # VS Code
169
+ .vscode/
170
+ *.code-workspace
171
+
172
+ # Sublime Text
173
+ *.sublime-project
174
+ *.sublime-workspace
175
+
176
+ # Vim
177
+ *.swp
178
+ *.swo
179
+ *~
180
+
181
+ # Emacs
182
+ *~
183
+ \#*\#
184
+ /.emacs.desktop
185
+ /.emacs.desktop.lock
186
+ *.elc
187
+ auto-save-list
188
+ tramp
189
+ .\#*
190
+
191
+ # macOS
192
+ .DS_Store
193
+ .AppleDouble
194
+ .LSOverride
195
+ Icon
196
+ ._*
197
+ .DocumentRevisions-V100
198
+ .fseventsd
199
+ .Spotlight-V100
200
+ .TemporaryItems
201
+ .Trashes
202
+ .VolumeIcon.icns
203
+ .com.apple.timemachine.donotpresent
204
+ .AppleDB
205
+ .AppleDesktop
206
+ Network Trash Folder
207
+ Temporary Items
208
+ .apdisk
209
+
210
+ # Windows
211
+ Thumbs.db
212
+ Thumbs.db:encryptable
213
+ ehthumbs.db
214
+ ehthumbs_vista.db
215
+ *.tmp
216
+ *.temp
217
+ *.bak
218
+ *.swp
219
+ *~.nib
220
+ local.properties
221
+ .settings/
222
+ .loadpath
223
+ .recommenders
224
+ .target/
225
+ .metadata
226
+ .factorypath
227
+ .buildpath
228
+ .project
229
+ .classpath
230
+ *.launch
231
+ .pydevproject
232
+ .cproject
233
+ .autotools
234
+ .factorypath
235
+ .buildpath
236
+ .target
237
+ .tern-project
238
+ .idea/
239
+ *.iml
240
+ *.ipr
241
+ *.iws
242
+ .settings/
243
+ .loadpath
244
+ .recommenders
245
+ .target/
246
+ .metadata
247
+ .factorypath
248
+ .buildpath
249
+ .project
250
+ .classpath
251
+ *.launch
252
+ .pydevproject
253
+ .cproject
254
+ .autotools
255
+ .factorypath
256
+ .buildpath
257
+ .target
258
+ .tern-project
259
+ .idea/
260
+ *.iml
261
+ *.ipr
262
+ *.iws
263
+ .settings/
264
+ .loadpath
265
+ .recommenders
266
+ .target/
267
+ .metadata
268
+ .factorypath
269
+ .buildpath
270
+ .project
271
+ .classpath
272
+ *.launch
273
+ .pydevproject
274
+ .cproject
275
+ .autotools
276
+ .factorypath
277
+ .buildpath
278
+ .target
279
+ .tern-project
280
+
281
+ # Linux
282
+ *~
283
+ .fuse_hidden*
284
+ .directory
285
+ .Trash-*
286
+ .nfs*
287
+
288
+ # Machine Learning / Deep Learning specific
289
+ # Model checkpoints and weights
290
+ *.pth
291
+ *.pt
292
+ *.ckpt
293
+ *.h5
294
+ *.hdf5
295
+ *.pb
296
+ *.pkl
297
+ *.pickle
298
+ *.joblib
299
+ *.model
300
+ *.weights
301
+ *.bin
302
+ *.safetensors
303
+
304
+ # Training logs and outputs
305
+ logs/
306
+ runs/
307
+ wandb/
308
+ tensorboard/
309
+ lightning_logs/
310
+ mlruns/
311
+ outputs/
312
+ checkpoints/
313
+ models/
314
+ experiments/
315
+ results/
316
+ artifacts/
317
+
318
+ # Data files (large datasets)
319
+ data/
320
+ datasets/
321
+ *.csv
322
+ *.tsv
323
+ *.json
324
+ *.jsonl
325
+ *.parquet
326
+ *.feather
327
+ *.arrow
328
+ *.h5
329
+ *.hdf5
330
+ *.npz
331
+ *.npy
332
+ *.mat
333
+ *.pkl
334
+ *.pickle
335
+
336
+ # Image files (if not needed in repo)
337
+ *.jpg
338
+ *.jpeg
339
+ *.png
340
+ *.gif
341
+ *.bmp
342
+ *.tiff
343
+ *.tif
344
+ *.webp
345
+ *.svg
346
+ *.ico
347
+
348
+ # Video files
349
+ *.mp4
350
+ *.avi
351
+ *.mov
352
+ *.wmv
353
+ *.flv
354
+ *.webm
355
+ *.mkv
356
+
357
+ # Audio files
358
+ *.mp3
359
+ *.wav
360
+ *.flac
361
+ *.aac
362
+ *.ogg
363
+ *.wma
364
+
365
+ # Archive files
366
+ *.zip
367
+ *.tar
368
+ *.tar.gz
369
+ *.tar.bz2
370
+ *.tar.xz
371
+ *.rar
372
+ *.7z
373
+ *.gz
374
+ *.bz2
375
+ *.xz
376
+
377
+ # Hugging Face specific
378
+ .cache/
379
+ huggingface/
380
+ transformers_cache/
381
+ datasets_cache/
382
+
383
+ # Jupyter notebook checkpoints
384
+ .ipynb_checkpoints/
385
+
386
+ # Temporary files
387
+ tmp/
388
+ temp/
389
+ .tmp/
390
+ .temp/
391
+
392
+ # Configuration files with secrets
393
+ .env
394
+ .env.local
395
+ .env.production
396
+ .env.staging
397
+ config.ini
398
+ secrets.json
399
+ credentials.json
400
+ *.key
401
+ *.pem
402
+ *.crt
403
+ *.p12
404
+ *.pfx
405
+
406
+ # IDE and editor files
407
+ .vscode/
408
+ .idea/
409
+ *.swp
410
+ *.swo
411
+ *~
412
+ .project
413
+ .pydevproject
414
+ .settings/
415
+
416
+ # OS generated files
417
+ .DS_Store
418
+ .DS_Store?
419
+ ._*
420
+ .Spotlight-V100
421
+ .Trashes
422
+ ehthumbs.db
423
+ Thumbs.db
424
+
425
+ # Project specific
426
+ # Exclude large model files and datasets
427
+ models/exports/
428
+ data/Polyvore/
429
+ *.pth
430
+ *.pt
431
+ *.ckpt
432
+
433
+ # Exclude generated files
434
+ __pycache__/
435
+ *.pyc
436
+ *.pyo
437
+ *.pyd
438
+ .Python
439
+ build/
440
+ develop-eggs/
441
+ dist/
442
+ downloads/
443
+ eggs/
444
+ .eggs/
445
+ lib/
446
+ lib64/
447
+ parts/
448
+ sdist/
449
+ var/
450
+ wheels/
451
+
452
+ # Exclude virtual environments
453
+ venv/
454
+ env/
455
+ ENV/
456
+ .venv/
457
+ .env/
458
+
459
+ # Exclude test outputs
460
+ .pytest_cache/
461
+ .coverage
462
+ htmlcov/
463
+ .tox/
464
+ .nox/
465
+
466
+ # Exclude documentation builds
467
+ docs/_build/
468
+ site/
469
+
470
+ # Exclude temporary files
471
+ *.tmp
472
+ *.temp
473
+ *.bak
474
+ *.swp
475
+ *~
476
+
477
+ # Exclude logs
478
+ *.log
479
+ logs/
480
+
481
+ # Exclude cache directories
482
+ .cache/
483
+ .pytest_cache/
484
+ .mypy_cache/
485
+ .dmypy.json
486
+ dmypy.json
487
+
488
+ # Exclude coverage reports
489
+ .coverage
490
+ .coverage.*
491
+ coverage.xml
492
+ *.cover
493
+ .hypothesis/
494
+
495
+ # Exclude profiling data
496
+ .prof
497
+
498
+ # Exclude Jupyter notebook checkpoints
499
+ .ipynb_checkpoints/
500
+
501
+ # Exclude IPython
502
+ profile_default/
503
+ ipython_config.py
504
+
505
+ # Exclude pyenv
506
+ .python-version
507
+
508
+ # Exclude pipenv
509
+ Pipfile.lock
510
+
511
+ # Exclude poetry
512
+ poetry.lock
513
+
514
+ # Exclude pdm
515
+ pdm.lock
516
+ .pdm.toml
517
+
518
+ # Exclude PEP 582
519
+ __pypackages__/
520
+
521
+ # Exclude Celery
522
+ celerybeat-schedule
523
+ celerybeat.pid
524
+
525
+ # Exclude SageMath
526
+ *.sage.py
527
+
528
+ # Exclude Spyder
529
+ .spyderproject
530
+ .spyproject
531
+
532
+ # Exclude Rope
533
+ .ropeproject
534
+
535
+ # Exclude mkdocs
536
+ /site
537
+
538
+ # Exclude mypy
539
+ .mypy_cache/
540
+ .dmypy.json
541
+ dmypy.json
542
+
543
+ # Exclude Pyre
544
+ .pyre/
545
+
546
+ # Exclude pytype
547
+ .pytype/
548
+
549
+ # Exclude Cython
550
+ cython_debug/
551
+
552
+ # Exclude PyCharm
553
+ .idea/
554
+ *.iml
555
+ *.ipr
556
+ *.iws
557
+
558
+ # Exclude VS Code
559
+ .vscode/
560
+ *.code-workspace
561
+
562
+ # Exclude Sublime Text
563
+ *.sublime-project
564
+ *.sublime-workspace
565
+
566
+ # Exclude Vim
567
+ *.swp
568
+ *.swo
569
+ *~
570
+
571
+ # Exclude Emacs
572
+ *~
573
+ \#*\#
574
+ /.emacs.desktop
575
+ /.emacs.desktop.lock
576
+ *.elc
577
+ auto-save-list
578
+ tramp
579
+ .\#*
580
+
581
+ # Exclude macOS
582
+ .DS_Store
583
+ .AppleDouble
584
+ .LSOverride
585
+ Icon
586
+ ._*
587
+ .DocumentRevisions-V100
588
+ .fseventsd
589
+ .Spotlight-V100
590
+ .TemporaryItems
591
+ .Trashes
592
+ .VolumeIcon.icns
593
+ .com.apple.timemachine.donotpresent
594
+ .AppleDB
595
+ .AppleDesktop
596
+ Network Trash Folder
597
+ Temporary Items
598
+ .apdisk
599
+
600
+ # Exclude Windows
601
+ Thumbs.db
602
+ Thumbs.db:encryptable
603
+ ehthumbs.db
604
+ ehthumbs_vista.db
605
+ *.tmp
606
+ *.temp
607
+ *.bak
608
+ *.swp
609
+ *~.nib
610
+ local.properties
611
+ .settings/
612
+ .loadpath
613
+ .recommenders
614
+ .target/
615
+ .metadata
616
+ .factorypath
617
+ .buildpath
618
+ .project
619
+ .classpath
620
+ *.launch
621
+ .pydevproject
622
+ .cproject
623
+ .autotools
624
+ .factorypath
625
+ .buildpath
626
+ .target
627
+ .tern-project
628
+ .idea/
629
+ *.iml
630
+ *.ipr
631
+ *.iws
632
+ .settings/
633
+ .loadpath
634
+ .recommenders
635
+ .target/
636
+ .metadata
637
+ .factorypath
638
+ .buildpath
639
+ .project
640
+ .classpath
641
+ *.launch
642
+ .pydevproject
643
+ .cproject
644
+ .autotools
645
+ .factorypath
646
+ .buildpath
647
+ .target
648
+ .tern-project
649
+
650
+ # Exclude Linux
651
+ *~
652
+ .fuse_hidden*
653
+ .directory
654
+ .Trash-*
655
+ .nfs*
656
+
657
+ # Exclude Machine Learning files
658
+ *.pth
659
+ *.pt
660
+ *.ckpt
661
+ *.h5
662
+ *.hdf5
663
+ *.pb
664
+ *.pkl
665
+ *.pickle
666
+ *.joblib
667
+ *.model
668
+ *.weights
669
+ *.bin
670
+ *.safetensors
671
+
672
+ # Exclude training outputs
673
+ logs/
674
+ runs/
675
+ wandb/
676
+ tensorboard/
677
+ lightning_logs/
678
+ mlruns/
679
+ outputs/
680
+ checkpoints/
681
+ models/
682
+ experiments/
683
+ results/
684
+ artifacts/
685
+
686
+ # Exclude data files
687
+ data/
688
+ datasets/
689
+ *.csv
690
+ *.tsv
691
+ *.json
692
+ *.jsonl
693
+ *.parquet
694
+ *.feather
695
+ *.arrow
696
+ *.h5
697
+ *.hdf5
698
+ *.npz
699
+ *.npy
700
+ *.mat
701
+ *.pkl
702
+ *.pickle
703
+
704
+ # Exclude media files
705
+ *.jpg
706
+ *.jpeg
707
+ *.png
708
+ *.gif
709
+ *.bmp
710
+ *.tiff
711
+ *.tif
712
+ *.webp
713
+ *.svg
714
+ *.ico
715
+ *.mp4
716
+ *.avi
717
+ *.mov
718
+ *.wmv
719
+ *.flv
720
+ *.webm
721
+ *.mkv
722
+ *.mp3
723
+ *.wav
724
+ *.flac
725
+ *.aac
726
+ *.ogg
727
+ *.wma
728
+
729
+ # Exclude archives
730
+ *.zip
731
+ *.tar
732
+ *.tar.gz
733
+ *.tar.bz2
734
+ *.tar.xz
735
+ *.rar
736
+ *.7z
737
+ *.gz
738
+ *.bz2
739
+ *.xz
740
+
741
+ # Exclude Hugging Face cache
742
+ .cache/
743
+ huggingface/
744
+ transformers_cache/
745
+ datasets_cache/
746
+
747
+ # Exclude temporary files
748
+ tmp/
749
+ temp/
750
+ .tmp/
751
+ .temp/
752
+
753
+ # Exclude secrets
754
+ .env
755
+ .env.local
756
+ .env.production
757
+ .env.staging
758
+ config.ini
759
+ secrets.json
760
+ credentials.json
761
+ *.key
762
+ *.pem
763
+ *.crt
764
+ *.p12
765
+ *.pfx
app.py CHANGED
@@ -152,9 +152,9 @@ def push_splits_to_hf(token, username):
152
  return "❌ Please provide HF token and username"
153
 
154
  try:
155
- from utils.hf_hub_integration import create_hf_integration
156
- hf = create_hf_integration(token)
157
- result = hf.upload_splits_to_hf()
158
 
159
  if result.get("success"):
160
  return f"✅ Successfully uploaded splits to {username}/Dressify-Helper"
@@ -169,9 +169,9 @@ def push_models_to_hf(token, username):
169
  return "❌ Please provide HF token and username"
170
 
171
  try:
172
- from utils.hf_hub_integration import create_hf_integration
173
- hf = create_hf_integration(token)
174
- result = hf.upload_models_to_hf()
175
 
176
  if result.get("success"):
177
  return f"✅ Successfully uploaded models to {username}/dressify-models"
@@ -186,9 +186,9 @@ def push_everything_to_hf(token, username):
186
  return "❌ Please provide HF token and username"
187
 
188
  try:
189
- from utils.hf_hub_integration import create_hf_integration
190
- hf = create_hf_integration(token)
191
- result = hf.upload_everything_to_hf()
192
 
193
  if result.get("success"):
194
  return f"✅ Successfully uploaded everything to HF Hub"
@@ -271,13 +271,15 @@ def _background_bootstrap():
271
  if not os.path.exists(resnet_ckpt):
272
  BOOT_STATUS = "training-resnet"
273
  subprocess.run([
274
- "python", "train_resnet.py", "--data_root", ds_root, "--epochs", "3",
 
275
  "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
276
  ], check=False)
277
  if not os.path.exists(vit_ckpt):
278
  BOOT_STATUS = "training-vit"
279
  subprocess.run([
280
- "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "3",
 
281
  "--export", os.path.join(export_dir, "vit_outfit_model.pth")
282
  ], check=False)
283
  service.reload_models()
@@ -600,9 +602,9 @@ def start_training_advanced(
600
  if hf_token:
601
  log_message += "📤 Auto-uploading artifacts to Hugging Face Hub...\n"
602
  try:
603
- from utils.hf_hub_integration import create_hf_integration
604
- hf = create_hf_integration(hf_token)
605
- result = hf.upload_everything_to_hf()
606
  if result.get("success"):
607
  log_message += "✅ Successfully uploaded to HF Hub!\n"
608
  log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
@@ -647,9 +649,10 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
647
 
648
  # Train ResNet first and wait for completion
649
  log_message += f"\n🚀 Starting ResNet training on {dataset_size} samples...\n"
650
- resnet_result = subprocess.run([
651
- "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
652
- "--batch_size", "8", "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
 
653
  ] + dataset_args, capture_output=True, text=True, check=False)
654
 
655
  if resnet_result.returncode == 0:
@@ -674,8 +677,9 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
674
 
675
  log_message += f"\n🚀 Starting ViT training on {dataset_size} samples...\n"
676
  vit_result = subprocess.run([
677
- "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
678
- "--batch_size", "8", "--export", os.path.join(export_dir, "vit_outfit_model.pth")
 
679
  ] + dataset_args, capture_output=True, text=True, check=False)
680
 
681
  if vit_result.returncode == 0:
@@ -692,9 +696,9 @@ def start_training_simple(dataset_size: str, res_epochs: int, vit_epochs: int):
692
  if hf_token:
693
  log_message += "\n📤 Auto-uploading artifacts to Hugging Face Hub...\n"
694
  try:
695
- from utils.hf_hub_integration import create_hf_integration
696
- hf = create_hf_integration(hf_token)
697
- result = hf.upload_everything_to_hf()
698
  if result.get("success"):
699
  log_message += "✅ Successfully uploaded to HF Hub!\n"
700
  log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
 
152
  return "❌ Please provide HF token and username"
153
 
154
  try:
155
+ from utils.hf_utils import HFModelManager
156
+ hf = HFModelManager(token=token, username=username)
157
+ result = hf.upload_model("splits", "Dressify-Helper")
158
 
159
  if result.get("success"):
160
  return f"✅ Successfully uploaded splits to {username}/Dressify-Helper"
 
169
  return "❌ Please provide HF token and username"
170
 
171
  try:
172
+ from utils.hf_utils import HFModelManager
173
+ hf = HFModelManager(token=token, username=username)
174
+ result = hf.upload_model("models", "dressify-models")
175
 
176
  if result.get("success"):
177
  return f"✅ Successfully uploaded models to {username}/dressify-models"
 
186
  return "❌ Please provide HF token and username"
187
 
188
  try:
189
+ from utils.hf_utils import HFModelManager
190
+ hf = HFModelManager(token=token, username=username)
191
+ result = hf.upload_model("everything", "dressify-complete")
192
 
193
  if result.get("success"):
194
  return f"✅ Successfully uploaded everything to HF Hub"
 
271
  if not os.path.exists(resnet_ckpt):
272
  BOOT_STATUS = "training-resnet"
273
  subprocess.run([
274
+ "python", "train_resnet.py", "--data_root", ds_root, "--epochs", "50",
275
+ "--batch_size", "16", "--lr", "1e-3", "--early_stopping_patience", "10",
276
  "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
277
  ], check=False)
278
  if not os.path.exists(vit_ckpt):
279
  BOOT_STATUS = "training-vit"
280
  subprocess.run([
281
+ "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "50",
282
+ "--batch_size", "16", "--lr", "5e-4", "--early_stopping_patience", "10",
283
  "--export", os.path.join(export_dir, "vit_outfit_model.pth")
284
  ], check=False)
285
  service.reload_models()
 
602
  if hf_token:
603
  log_message += "📤 Auto-uploading artifacts to Hugging Face Hub...\n"
604
  try:
605
+ from utils.hf_utils import HFModelManager
606
+ hf = HFModelManager(token=hf_token, username="Stylique")
607
+ result = hf.upload_model("everything", "dressify-complete")
608
  if result.get("success"):
609
  log_message += "✅ Successfully uploaded to HF Hub!\n"
610
  log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
 
649
 
650
  # Train ResNet first and wait for completion
651
  log_message += f"\n🚀 Starting ResNet training on {dataset_size} samples...\n"
652
+ resnet_result = subprocess.run([
653
+ "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", "50",
654
+ "--batch_size", "16", "--lr", "1e-3", "--early_stopping_patience", "10",
655
+ "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
656
  ] + dataset_args, capture_output=True, text=True, check=False)
657
 
658
  if resnet_result.returncode == 0:
 
677
 
678
  log_message += f"\n🚀 Starting ViT training on {dataset_size} samples...\n"
679
  vit_result = subprocess.run([
680
+ "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", "50",
681
+ "--batch_size", "16", "--lr", "5e-4", "--early_stopping_patience", "10",
682
+ "--export", os.path.join(export_dir, "vit_outfit_model.pth")
683
  ] + dataset_args, capture_output=True, text=True, check=False)
684
 
685
  if vit_result.returncode == 0:
 
696
  if hf_token:
697
  log_message += "\n📤 Auto-uploading artifacts to Hugging Face Hub...\n"
698
  try:
699
+ from utils.hf_utils import HFModelManager
700
+ hf = HFModelManager(token=hf_token, username="Stylique")
701
+ result = hf.upload_model("everything", "dressify-complete")
702
  if result.get("success"):
703
  log_message += "✅ Successfully uploaded to HF Hub!\n"
704
  log_message += "🔗 Models: https://huggingface.co/Stylique/dressify-models\n"
inference.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import torch
6
  import torch.nn as nn
7
  from PIL import Image
 
8
 
9
  from utils.transforms import build_inference_transform
10
  from models.resnet_embedder import ResNetItemEmbedder
@@ -40,7 +41,27 @@ class InferenceService:
40
  model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
41
  if strategy == "random":
42
  return model
43
- # prefer best if present
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
45
  if os.path.exists(best_path):
46
  ckpt_to_use = best_path
@@ -48,11 +69,9 @@ class InferenceService:
48
  ckpt_to_use = ckpt_path
49
  if os.path.exists(ckpt_to_use):
50
  state = torch.load(ckpt_to_use, map_location="cpu")
51
- # accept either full state_dict or {"state_dict": ...}
52
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
53
- missing, unexpected = model.load_state_dict(state_dict, strict=False)
54
- if len(unexpected) == 0:
55
- return model
56
  return model
57
 
58
  def _load_vit(self) -> nn.Module:
@@ -61,6 +80,27 @@ class InferenceService:
61
  model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
62
  if strategy == "random":
63
  return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
65
  ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
66
  if os.path.exists(ckpt_to_use):
@@ -118,32 +158,72 @@ class InferenceService:
118
  min_size, max_size = 4, 6
119
  ids = list(range(len(proc_items)))
120
 
121
- # Slot-aware pools from categories (best-effort)
122
  def cat_str(i: int) -> str:
123
  return (proc_items[i].get("category") or "").lower()
124
 
125
- uppers = [i for i in ids if any(k in cat_str(i) for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie"])]
126
- bottoms = [i for i in ids if any(k in cat_str(i) for k in ["pant", "trouser", "jean", "skirt", "short"])]
127
- shoes = [i for i in ids if any(k in cat_str(i) for k in ["shoe", "sneaker", "boot", "heel"])]
128
- accs = [i for i in ids if any(k in cat_str(i) for k in ["watch", "belt", "ring", "bracelet", "accessor", "bag", "hat"])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  candidates: List[List[int]] = []
131
  num_samples = max(num_outfits * 12, 24)
 
 
 
 
 
 
 
 
132
  for _ in range(num_samples):
133
  if uppers and bottoms and shoes and accs:
 
134
  subset = [
135
  int(rng.choice(uppers)),
136
  int(rng.choice(bottoms)),
137
  int(rng.choice(shoes)),
138
  int(rng.choice(accs)),
139
  ]
140
- # Optional: add one more random distinct item
 
141
  remain = list(set(ids) - set(subset))
142
- if remain and rng.random() < 0.5:
143
- subset.append(int(rng.choice(remain)))
 
 
 
 
 
144
  else:
 
145
  k = int(rng.integers(min_size, max_size + 1))
146
- subset = list(map(int, rng.choice(ids, size=k, replace=False).tolist()))
 
 
 
 
 
 
 
 
 
147
  candidates.append(subset)
148
 
149
  # 3) Score using ViT
 
5
  import torch
6
  import torch.nn as nn
7
  from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
 
10
  from utils.transforms import build_inference_transform
11
  from models.resnet_embedder import ResNetItemEmbedder
 
41
  model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
42
  if strategy == "random":
43
  return model
44
+
45
+ # Try to download from Hugging Face Hub first
46
+ try:
47
+ print("🌐 Attempting to download ResNet from Hugging Face Hub...")
48
+ hf_path = hf_hub_download(
49
+ repo_id="Stylique/dressify-models",
50
+ filename="resnet_item_embedder_best.pth",
51
+ local_dir="models/exports",
52
+ local_dir_use_symlinks=False
53
+ )
54
+ print(f"📥 Downloaded ResNet from HF Hub: {hf_path}")
55
+ state = torch.load(hf_path, map_location="cpu")
56
+ state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
57
+ model.load_state_dict(state_dict, strict=False)
58
+ return model
59
+ except Exception as e:
60
+ print(f"❌ Failed to download ResNet from HF Hub: {e}")
61
+ print("⚠️ WARNING: Using untrained ResNet model!")
62
+ print("🚨 Recommendations will not be meaningful without trained weights!")
63
+
64
+ # Fallback to local checkpoints
65
  best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
66
  if os.path.exists(best_path):
67
  ckpt_to_use = best_path
 
69
  ckpt_to_use = ckpt_path
70
  if os.path.exists(ckpt_to_use):
71
  state = torch.load(ckpt_to_use, map_location="cpu")
 
72
  state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
73
+ model.load_state_dict(state_dict, strict=False)
74
+ return model
 
75
  return model
76
 
77
  def _load_vit(self) -> nn.Module:
 
80
  model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
81
  if strategy == "random":
82
  return model
83
+
84
+ # Try to download from Hugging Face Hub first
85
+ try:
86
+ print("🌐 Attempting to download ViT from Hugging Face Hub...")
87
+ hf_path = hf_hub_download(
88
+ repo_id="Stylique/dressify-models",
89
+ filename="vit_outfit_model_best.pth",
90
+ local_dir="models/exports",
91
+ local_dir_use_symlinks=False
92
+ )
93
+ print(f"📥 Downloaded ViT from HF Hub: {hf_path}")
94
+ state = torch.load(hf_path, map_location="cpu")
95
+ state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
96
+ model.load_state_dict(state_dict, strict=False)
97
+ return model
98
+ except Exception as e:
99
+ print(f"❌ Failed to download ViT from HF Hub: {e}")
100
+ print("⚠️ WARNING: Using untrained ViT model!")
101
+ print("🚨 Recommendations will not be meaningful without trained weights!")
102
+
103
+ # Fallback to local checkpoints
104
  best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
105
  ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
106
  if os.path.exists(ckpt_to_use):
 
158
  min_size, max_size = 4, 6
159
  ids = list(range(len(proc_items)))
160
 
161
+ # Enhanced category-aware pools with diversity checks
162
  def cat_str(i: int) -> str:
163
  return (proc_items[i].get("category") or "").lower()
164
 
165
+ def get_category_type(cat: str) -> str:
166
+ """Map category to outfit slot type"""
167
+ if any(k in cat for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie", "sweater", "cardigan"]):
168
+ return "upper"
169
+ elif any(k in cat for k in ["pant", "trouser", "jean", "skirt", "short", "legging"]):
170
+ return "bottom"
171
+ elif any(k in cat for k in ["shoe", "sneaker", "boot", "heel", "sandal", "flat"]):
172
+ return "shoe"
173
+ elif any(k in cat for k in ["watch", "belt", "ring", "bracelet", "accessor", "bag", "hat", "scarf", "necklace"]):
174
+ return "accessory"
175
+ else:
176
+ return "other"
177
+
178
+ # Create category pools
179
+ uppers = [i for i in ids if get_category_type(cat_str(i)) == "upper"]
180
+ bottoms = [i for i in ids if get_category_type(cat_str(i)) == "bottom"]
181
+ shoes = [i for i in ids if get_category_type(cat_str(i)) == "shoe"]
182
+ accs = [i for i in ids if get_category_type(cat_str(i)) == "accessory"]
183
+ others = [i for i in ids if get_category_type(cat_str(i)) == "other"]
184
 
185
  candidates: List[List[int]] = []
186
  num_samples = max(num_outfits * 12, 24)
187
+
188
+ def has_category_diversity(subset: List[int]) -> bool:
189
+ """Check if subset has good category diversity"""
190
+ categories = [get_category_type(cat_str(i)) for i in subset]
191
+ unique_categories = set(categories)
192
+ # Require at least 3 different category types for good diversity
193
+ return len(unique_categories) >= 3
194
+
195
  for _ in range(num_samples):
196
  if uppers and bottoms and shoes and accs:
197
+ # Start with one item from each major category
198
  subset = [
199
  int(rng.choice(uppers)),
200
  int(rng.choice(bottoms)),
201
  int(rng.choice(shoes)),
202
  int(rng.choice(accs)),
203
  ]
204
+
205
+ # Add one more accessory or other item for variety
206
  remain = list(set(ids) - set(subset))
207
+ if remain and rng.random() < 0.7:
208
+ # Prefer accessories or other items
209
+ pref_items = [i for i in remain if get_category_type(cat_str(i)) in ["accessory", "other"]]
210
+ if pref_items:
211
+ subset.append(int(rng.choice(pref_items)))
212
+ else:
213
+ subset.append(int(rng.choice(remain)))
214
  else:
215
+ # Fallback: ensure category diversity
216
  k = int(rng.integers(min_size, max_size + 1))
217
+ attempts = 0
218
+ while attempts < 10: # Try to find diverse subset
219
+ subset = list(map(int, rng.choice(ids, size=k, replace=False).tolist()))
220
+ if has_category_diversity(subset):
221
+ break
222
+ attempts += 1
223
+ # If we can't find diverse subset, use what we have
224
+ if attempts >= 10:
225
+ subset = list(map(int, rng.choice(ids, size=k, replace=False).tolist()))
226
+
227
  candidates.append(subset)
228
 
229
  # 3) Score using ViT
train_resnet.py CHANGED
@@ -14,17 +14,20 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
14
  from data.polyvore import PolyvoreTripletDataset
15
  from models.resnet_embedder import ResNetItemEmbedder
16
  from utils.export import ensure_export_dir
 
17
  import json
18
 
19
 
20
  def parse_args() -> argparse.Namespace:
21
  p = argparse.ArgumentParser()
22
  p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
23
- p.add_argument("--epochs", type=int, default=20)
24
- p.add_argument("--batch_size", type=int, default=8)
25
  p.add_argument("--lr", type=float, default=1e-3)
26
  p.add_argument("--embedding_dim", type=int, default=512)
27
  p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
 
 
28
  return p.parse_args()
29
 
30
 
@@ -80,8 +83,12 @@ def main() -> None:
80
  export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
81
  best_loss = float("inf")
82
  history = []
 
 
 
83
 
84
  print(f"💾 Checkpoints will be saved to: {export_dir}")
 
85
 
86
  for epoch in range(args.epochs):
87
  model.train()
@@ -108,6 +115,14 @@ def main() -> None:
108
  loss.backward()
109
  optimizer.step()
110
 
 
 
 
 
 
 
 
 
111
  running_loss += loss.item()
112
  steps += 1
113
 
@@ -138,19 +153,53 @@ def main() -> None:
138
 
139
  history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
140
 
141
- if avg_loss < best_loss:
 
142
  best_loss = avg_loss
 
 
143
  best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")
144
  torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
145
- print(f"🏆 New best model saved: {best_path}")
 
 
 
 
 
 
 
 
146
 
147
- # Write metrics
148
  metrics_path = os.path.join(export_dir, "resnet_metrics.json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  with open(metrics_path, "w") as f:
150
- json.dump({"best_triplet_loss": best_loss, "history": history}, f)
151
 
152
- print(f"📊 Training completed! Best loss: {best_loss:.4f}")
153
- print(f"📈 Metrics saved to: {metrics_path}")
 
154
 
155
 
156
  if __name__ == "__main__":
 
14
  from data.polyvore import PolyvoreTripletDataset
15
  from models.resnet_embedder import ResNetItemEmbedder
16
  from utils.export import ensure_export_dir
17
+ from utils.advanced_metrics import AdvancedMetrics, calculate_triplet_metrics
18
  import json
19
 
20
 
21
def parse_args() -> argparse.Namespace:
    """Parse command-line options for ResNet item-embedder triplet training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--embedding_dim", type=int, default=512)
    parser.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
    parser.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience")
    parser.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    return parser.parse_args()
32
 
33
 
 
83
  export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
84
  best_loss = float("inf")
85
  history = []
86
+ patience_counter = 0
87
+ best_epoch = 0
88
+ metrics_collector = AdvancedMetrics()
89
 
90
  print(f"💾 Checkpoints will be saved to: {export_dir}")
91
+ print(f"🛑 Early stopping patience: {args.early_stopping_patience} epochs")
92
 
93
  for epoch in range(args.epochs):
94
  model.train()
 
115
  loss.backward()
116
  optimizer.step()
117
 
118
+ # Collect metrics
119
+ triplet_metrics = calculate_triplet_metrics(emb_a, emb_p, emb_n, margin=0.2)
120
+ metrics_collector.add_batch(
121
+ predictions=torch.ones(emb_a.size(0)), # Placeholder for compatibility
122
+ targets=torch.ones(emb_a.size(0)), # Placeholder for compatibility
123
+ embeddings=emb_a
124
+ )
125
+
126
  running_loss += loss.item()
127
  steps += 1
128
 
 
153
 
154
  history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
155
 
156
+ # Early stopping logic
157
+ if avg_loss < best_loss - args.min_delta:
158
  best_loss = avg_loss
159
+ best_epoch = epoch + 1
160
+ patience_counter = 0
161
  best_path = os.path.join(export_dir, "resnet_item_embedder_best.pth")
162
  torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss}, best_path)
163
+ print(f"🏆 New best model saved: {best_path} (loss: {avg_loss:.4f})")
164
+ else:
165
+ patience_counter += 1
166
+ print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
167
+
168
+ if patience_counter >= args.early_stopping_patience:
169
+ print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
170
+ print(f"🏆 Best model was at epoch {best_epoch} with loss {best_loss:.4f}")
171
+ break
172
 
173
+ # Write comprehensive metrics
174
  metrics_path = os.path.join(export_dir, "resnet_metrics.json")
175
+
176
+ # Get advanced metrics
177
+ advanced_metrics = metrics_collector.calculate_all_metrics()
178
+
179
+ final_metrics = {
180
+ "best_triplet_loss": best_loss,
181
+ "best_epoch": best_epoch,
182
+ "total_epochs": epoch + 1,
183
+ "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
184
+ "patience_counter": patience_counter,
185
+ "training_config": {
186
+ "epochs": args.epochs,
187
+ "batch_size": args.batch_size,
188
+ "learning_rate": args.lr,
189
+ "embedding_dim": args.embedding_dim,
190
+ "early_stopping_patience": args.early_stopping_patience,
191
+ "min_delta": args.min_delta
192
+ },
193
+ "history": history,
194
+ "advanced_metrics": advanced_metrics
195
+ }
196
+
197
  with open(metrics_path, "w") as f:
198
+ json.dump(final_metrics, f, indent=2)
199
 
200
+ print(f"📊 Training completed! Best loss: {best_loss:.4f} at epoch {best_epoch}")
201
+ print(f"📈 Comprehensive metrics saved to: {metrics_path}")
202
+ print(f"🔬 Advanced metrics: {advanced_metrics['summary']}")
203
 
204
 
205
  if __name__ == "__main__":
train_vit_triplet.py CHANGED
@@ -15,19 +15,22 @@ from data.polyvore import PolyvoreOutfitTripletDataset
15
  from models.vit_outfit import OutfitCompatibilityModel
16
  from models.resnet_embedder import ResNetItemEmbedder
17
  from utils.export import ensure_export_dir
 
18
  import json
19
 
20
 
21
  def parse_args() -> argparse.Namespace:
22
  p = argparse.ArgumentParser()
23
  p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
24
- p.add_argument("--epochs", type=int, default=30)
25
- p.add_argument("--batch_size", type=int, default=8)
26
  p.add_argument("--lr", type=float, default=5e-4)
27
  p.add_argument("--embedding_dim", type=int, default=512)
28
  p.add_argument("--triplet_margin", type=float, default=0.3)
29
  p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
30
  p.add_argument("--eval_every", type=int, default=1)
 
 
31
  return p.parse_args()
32
 
33
 
@@ -105,8 +108,12 @@ def main() -> None:
105
  export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
106
  best_loss = float("inf")
107
  hist = []
 
 
 
108
 
109
  print(f"💾 Checkpoints will be saved to: {export_dir}")
 
110
 
111
  for epoch in range(args.epochs):
112
  model.train()
@@ -145,6 +152,16 @@ def main() -> None:
145
  loss.backward()
146
  optimizer.step()
147
 
 
 
 
 
 
 
 
 
 
 
148
  running_loss += loss.item()
149
  steps += 1
150
 
@@ -210,25 +227,58 @@ def main() -> None:
210
  if val_loss is not None:
211
  print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
212
  hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
213
- if val_loss < best_loss:
 
 
214
  best_loss = val_loss
 
 
215
  best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
216
  torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
217
- print(f"🏆 New best model saved: {best_path}")
 
 
 
 
 
 
 
 
218
  else:
219
  print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
220
  hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})
221
 
222
- # Write metrics
223
  metrics_path = os.path.join(export_dir, "vit_metrics.json")
224
- payload = {"best_val_triplet_loss": best_loss if best_loss != float("inf") else None, "history": hist}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  with open(metrics_path, "w") as f:
226
- json.dump(payload, f)
227
 
228
- print(f"📊 Training completed!")
229
- if best_loss != float("inf"):
230
- print(f"🏆 Best validation loss: {best_loss:.4f}")
231
- print(f"📈 Metrics saved to: {metrics_path}")
232
 
233
 
234
  if __name__ == "__main__":
 
15
  from models.vit_outfit import OutfitCompatibilityModel
16
  from models.resnet_embedder import ResNetItemEmbedder
17
  from utils.export import ensure_export_dir
18
+ from utils.advanced_metrics import AdvancedMetrics, calculate_outfit_compatibility_metrics
19
  import json
20
 
21
 
22
def parse_args() -> argparse.Namespace:
    """Parse command-line options for ViT outfit-compatibility triplet training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore"))
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--embedding_dim", type=int, default=512)
    parser.add_argument("--triplet_margin", type=float, default=0.3)
    parser.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
    parser.add_argument("--eval_every", type=int, default=1)
    parser.add_argument("--early_stopping_patience", type=int, default=10, help="Early stopping patience")
    parser.add_argument("--min_delta", type=float, default=1e-4, help="Minimum change to qualify as improvement")
    return parser.parse_args()
35
 
36
 
 
108
  export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
109
  best_loss = float("inf")
110
  hist = []
111
+ patience_counter = 0
112
+ best_epoch = 0
113
+ metrics_collector = AdvancedMetrics()
114
 
115
  print(f"💾 Checkpoints will be saved to: {export_dir}")
116
+ print(f"🛑 Early stopping patience: {args.early_stopping_patience} epochs")
117
 
118
  for epoch in range(args.epochs):
119
  model.train()
 
152
  loss.backward()
153
  optimizer.step()
154
 
155
+ # Collect metrics
156
+ compatibility_metrics = calculate_outfit_compatibility_metrics(
157
+ torch.cat([ea, ep, en], dim=0),
158
+ torch.cat([torch.ones(ea.size(0)), torch.ones(ep.size(0)), torch.zeros(en.size(0))], dim=0)
159
+ )
160
+ metrics_collector.add_batch(
161
+ predictions=torch.cat([ea, ep, en], dim=0),
162
+ targets=torch.cat([torch.ones(ea.size(0)), torch.ones(ep.size(0)), torch.zeros(en.size(0))], dim=0)
163
+ )
164
+
165
  running_loss += loss.item()
166
  steps += 1
167
 
 
227
  if val_loss is not None:
228
  print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
229
  hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss), "val_triplet_loss": float(val_loss)})
230
+
231
+ # Early stopping logic
232
+ if val_loss < best_loss - args.min_delta:
233
  best_loss = val_loss
234
+ best_epoch = epoch + 1
235
+ patience_counter = 0
236
  best_path = os.path.join(export_dir, "vit_outfit_model_best.pth")
237
  torch.save({"state_dict": model.state_dict(), "epoch": epoch+1, "loss": avg_loss, "val_loss": val_loss}, best_path)
238
+ print(f"🏆 New best model saved: {best_path} (val_loss: {val_loss:.4f})")
239
+ else:
240
+ patience_counter += 1
241
+ print(f"⏳ No improvement for {patience_counter} epochs (best: {best_loss:.4f} at epoch {best_epoch})")
242
+
243
+ if patience_counter >= args.early_stopping_patience:
244
+ print(f"🛑 Early stopping triggered after {patience_counter} epochs without improvement")
245
+ print(f"🏆 Best model was at epoch {best_epoch} with val_loss {best_loss:.4f}")
246
+ break
247
  else:
248
  print(f"✅ Epoch {epoch+1}/{args.epochs} triplet_loss={avg_loss:.4f} saved -> {out_path}")
249
  hist.append({"epoch": epoch + 1, "triplet_loss": float(avg_loss)})
250
 
251
+ # Write comprehensive metrics
252
  metrics_path = os.path.join(export_dir, "vit_metrics.json")
253
+
254
+ # Get advanced metrics
255
+ advanced_metrics = metrics_collector.calculate_all_metrics()
256
+
257
+ final_metrics = {
258
+ "best_val_triplet_loss": best_loss if best_loss != float("inf") else None,
259
+ "best_epoch": best_epoch,
260
+ "total_epochs": epoch + 1,
261
+ "early_stopping_triggered": patience_counter >= args.early_stopping_patience,
262
+ "patience_counter": patience_counter,
263
+ "training_config": {
264
+ "epochs": args.epochs,
265
+ "batch_size": args.batch_size,
266
+ "learning_rate": args.lr,
267
+ "embedding_dim": args.embedding_dim,
268
+ "triplet_margin": args.triplet_margin,
269
+ "early_stopping_patience": args.early_stopping_patience,
270
+ "min_delta": args.min_delta
271
+ },
272
+ "history": hist,
273
+ "advanced_metrics": advanced_metrics
274
+ }
275
+
276
  with open(metrics_path, "w") as f:
277
+ json.dump(final_metrics, f, indent=2)
278
 
279
+ print(f"📊 Training completed! Best val_loss: {best_loss:.4f} at epoch {best_epoch}")
280
+ print(f"📈 Comprehensive metrics saved to: {metrics_path}")
281
+ print(f"🔬 Advanced metrics: {advanced_metrics['summary']}")
 
282
 
283
 
284
  if __name__ == "__main__":
utils/advanced_metrics.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced metrics calculation for outfit recommendation system.
3
+ Includes accuracy, precision, recall, F1 score, and other research-grade metrics.
4
+ """
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Any, Tuple
10
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
11
+ import json
12
+ from pathlib import Path
13
+
14
+
15
class AdvancedMetrics:
    """Accumulate model outputs across batches and compute research-grade metrics.

    Three independent views of quality are supported:
      * classification metrics (accuracy / precision / recall / F1 / AUC),
      * embedding-space metrics (norms, intra/inter-class separation),
      * outfit-score distribution statistics.
    Each calculate_* method returns an empty dict when nothing relevant was logged.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all accumulated state."""
        self.predictions = []    # per-sample predictions (numpy values/rows)
        self.targets = []        # per-sample ground-truth labels
        self.scores = []         # optional continuous scores used for AUC
        self.embeddings = []     # optional per-sample embedding vectors
        self.outfit_scores = []  # outfit compatibility scores (floats)

    def add_batch(self, predictions: torch.Tensor, targets: torch.Tensor,
                  scores: torch.Tensor = None, embeddings: torch.Tensor = None):
        """Add a batch of predictions and targets.

        Tensors are detached before conversion so values still attached to the
        autograd graph (e.g. embeddings straight out of a training forward
        pass) can be logged without raising — ``.numpy()`` on a tensor that
        requires grad is a RuntimeError.
        """
        self.predictions.extend(predictions.detach().cpu().numpy())
        self.targets.extend(targets.detach().cpu().numpy())

        if scores is not None:
            self.scores.extend(scores.detach().cpu().numpy())

        if embeddings is not None:
            self.embeddings.extend(embeddings.detach().cpu().numpy())

    def add_outfit_scores(self, outfit_scores: List[float]):
        """Add outfit compatibility scores."""
        self.outfit_scores.extend(outfit_scores)

    @staticmethod
    def _as_binary(values: np.ndarray) -> np.ndarray:
        """Return *values* as hard 0/1 labels.

        Arrays that already contain only 0s and 1s are kept as-is; anything
        else (probabilities, logits) is thresholded at 0.5.  The previous
        ``max() > 1`` check left continuous scores in (0, 1] unconverted,
        which made sklearn's classification metrics raise on them.
        """
        if np.isin(values, (0, 1)).all():
            return values.astype(int)
        return (values > 0.5).astype(int)

    def calculate_classification_metrics(self) -> Dict[str, float]:
        """Calculate classification metrics from accumulated predictions/targets."""
        if not self.predictions or not self.targets:
            return {}

        preds = self._as_binary(np.array(self.predictions))
        targets = self._as_binary(np.array(self.targets))

        accuracy = accuracy_score(targets, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            targets, preds, average='weighted', zero_division=0
        )
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            targets, preds, average='macro', zero_division=0
        )

        # AUC needs continuous scores and both classes present in the targets.
        auc = None
        if self.scores and len(np.unique(targets)) > 1:
            try:
                auc = roc_auc_score(targets, np.array(self.scores))
            except ValueError:
                auc = None

        return {
            "accuracy": float(accuracy),
            "precision_weighted": float(precision),
            "recall_weighted": float(recall),
            "f1_weighted": float(f1),
            "precision_macro": float(precision_macro),
            "recall_macro": float(recall_macro),
            "f1_macro": float(f1_macro),
            "auc": float(auc) if auc is not None else None
        }

    def calculate_embedding_metrics(self) -> Dict[str, float]:
        """Calculate embedding quality metrics (norms and class separation)."""
        if not self.embeddings:
            return {}

        embeddings = np.array(self.embeddings)
        norms = np.linalg.norm(embeddings, axis=1)

        avg_intra_class = 0.0
        avg_inter_class = 0.0
        separation_ratio = 0.0

        targets = np.asarray(self.targets).reshape(-1)
        # Class-separation statistics need exactly one label per embedding;
        # skip them when predictions/targets were logged without embeddings
        # (or vice versa), instead of crashing on a mismatched boolean mask.
        if len(targets) == len(embeddings) and len(targets) > 1:
            # Vectorized pairwise distances via the Gram matrix
            # (||a-b||^2 = ||a||^2 + ||b||^2 - 2ab); replaces the previous
            # O(n^2) Python double loops.  Each unordered pair is counted
            # once, which leaves the means unchanged.
            sq = np.sum(embeddings ** 2, axis=1)
            d2 = sq[:, None] + sq[None, :] - 2.0 * (embeddings @ embeddings.T)
            dists = np.sqrt(np.maximum(d2, 0.0))  # clamp tiny negatives from fp error
            same_class = targets[:, None] == targets[None, :]
            upper = np.triu_indices(len(targets), k=1)
            intra = dists[upper][same_class[upper]]
            inter = dists[upper][~same_class[upper]]

            avg_intra_class = float(np.mean(intra)) if intra.size else 0.0
            avg_inter_class = float(np.mean(inter)) if inter.size else 0.0
            # Higher ratio = tighter classes that sit further apart.
            separation_ratio = avg_inter_class / (avg_intra_class + 1e-8)

        return {
            "embedding_mean_norm": float(np.mean(norms)),
            "embedding_std_norm": float(np.std(norms)),
            "avg_intra_class_distance": float(avg_intra_class),
            "avg_inter_class_distance": float(avg_inter_class),
            "separation_ratio": float(separation_ratio)
        }

    def calculate_outfit_metrics(self) -> Dict[str, float]:
        """Summary statistics of the collected outfit compatibility scores."""
        if not self.outfit_scores:
            return {}

        scores = np.array(self.outfit_scores)

        return {
            "outfit_score_mean": float(np.mean(scores)),
            "outfit_score_std": float(np.std(scores)),
            "outfit_score_min": float(np.min(scores)),
            "outfit_score_max": float(np.max(scores)),
            "outfit_score_median": float(np.median(scores))
        }

    def calculate_all_metrics(self) -> Dict[str, Any]:
        """Calculate every available metric group plus counting summaries."""
        metrics = {
            "classification": self.calculate_classification_metrics(),
            "embeddings": self.calculate_embedding_metrics(),
            "outfits": self.calculate_outfit_metrics()
        }

        # Counts make it obvious which metric groups were actually populated.
        metrics["summary"] = {
            "total_predictions": len(self.predictions),
            "total_targets": len(self.targets),
            "total_scores": len(self.scores),
            "total_embeddings": len(self.embeddings),
            "total_outfit_scores": len(self.outfit_scores)
        }

        return metrics

    def save_metrics(self, filepath: str, additional_info: Dict[str, Any] = None):
        """Save all metrics to *filepath* as JSON and return the metrics dict.

        ``additional_info`` (e.g. model name, epoch) is embedded under the
        "additional_info" key when provided.  Parent directories are created
        as needed.
        """
        metrics = self.calculate_all_metrics()

        if additional_info:
            metrics["additional_info"] = additional_info

        Path(filepath).parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, 'w') as f:
            json.dump(metrics, f, indent=2)

        return metrics
193
+
194
+
195
def calculate_triplet_metrics(anchor_emb: torch.Tensor, positive_emb: torch.Tensor,
                              negative_emb: torch.Tensor, margin: float = 0.2) -> Dict[str, float]:
    """Diagnostic metrics for one batch of (anchor, positive, negative) embeddings.

    Returns the mean triplet loss at *margin*, the fraction of triplets where
    the positive sits closer to the anchor than the negative, the fraction of
    margin violations, and mean-distance statistics.
    """
    d_pos = F.pairwise_distance(anchor_emb, positive_emb, p=2)
    d_neg = F.pairwise_distance(anchor_emb, negative_emb, p=2)

    # A positive gap means the triplet violates the margin constraint.
    gap = d_pos - d_neg + margin
    mean_pos = d_pos.mean()
    mean_neg = d_neg.mean()

    return {
        "triplet_loss": float(F.relu(gap).mean()),
        "triplet_accuracy": float((d_pos < d_neg).float().mean()),
        "margin_violations": float((gap > 0).float().mean()),
        "positive_distance_mean": float(mean_pos),
        "negative_distance_mean": float(mean_neg),
        "distance_ratio": float(mean_neg / (mean_pos + 1e-8)),
    }
225
+
226
+
227
def calculate_outfit_compatibility_metrics(outfit_scores: torch.Tensor,
                                           labels: torch.Tensor) -> Dict[str, float]:
    """Binary-classification metrics for outfit compatibility scores.

    Scores are thresholded at 0.5 against the 0/1 *labels*; also reports AUC
    (when both classes are present) and how well the score distributions of
    compatible vs incompatible outfits separate.
    """
    scores_np = outfit_scores.cpu().numpy()
    labels_np = labels.cpu().numpy()

    # Hard decisions at the 0.5 threshold for accuracy/precision/recall/F1.
    pred_binary = (scores_np > 0.5).astype(int)
    accuracy = accuracy_score(labels_np, pred_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_np, pred_binary, average='weighted', zero_division=0
    )

    # AUC is only defined when both classes appear in the labels.
    auc = None
    if len(np.unique(labels_np)) > 1:
        try:
            auc = roc_auc_score(labels_np, scores_np)
        except ValueError:
            auc = None

    pos_scores = scores_np[labels_np == 1]
    neg_scores = scores_np[labels_np == 0]
    has_pos = len(pos_scores) > 0
    has_neg = len(neg_scores) > 0

    return {
        "compatibility_accuracy": float(accuracy),
        "compatibility_precision": float(precision),
        "compatibility_recall": float(recall),
        "compatibility_f1": float(f1),
        "compatibility_auc": float(auc) if auc is not None else None,
        "compatible_score_mean": float(np.mean(pos_scores)) if has_pos else 0,
        "incompatible_score_mean": float(np.mean(neg_scores)) if has_neg else 0,
        "score_separation": float(np.mean(pos_scores) - np.mean(neg_scores)) if has_pos and has_neg else 0
    }
265
+
266
+
267
if __name__ == "__main__":
    # Smoke-test the metrics collector on random data and write a report file.
    collector = AdvancedMetrics()

    raw_preds = torch.randn(100, 1)
    labels = torch.randint(0, 2, (100, 1)).float()
    probs = torch.sigmoid(raw_preds)
    vecs = torch.randn(100, 512)

    collector.add_batch(raw_preds, labels, probs, vecs)
    collector.add_outfit_scores(probs.flatten().tolist())

    report = collector.calculate_all_metrics()
    print("Calculated metrics:")
    print(json.dumps(report, indent=2))

    collector.save_metrics("test_metrics.json", {"model": "test", "epoch": 1})
287
+
utils/hf_utils.py CHANGED
@@ -130,6 +130,88 @@ class HFModelManager:
130
  except Exception as e:
131
  print(f"Failed to list repo files: {e}")
132
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
 
135
  def push_model_to_hub(
 
130
  except Exception as e:
131
  print(f"Failed to list repo files: {e}")
132
  return []
133
+
134
+ def upload_model(self, model_type: str, repo_name: str) -> Dict[str, Any]:
135
+ """Upload models or data to HF Hub based on type."""
136
+ try:
137
+ if model_type == "models":
138
+ # Upload model checkpoints
139
+ repo_id = f"{self.username}/{repo_name}"
140
+ self.create_model_repo(repo_name, private=False)
141
+
142
+ # Upload best model checkpoints
143
+ model_files = [
144
+ "models/exports/resnet_item_embedder_best.pth",
145
+ "models/exports/vit_outfit_model_best.pth",
146
+ "models/exports/resnet_metrics.json",
147
+ "models/exports/vit_metrics.json"
148
+ ]
149
+
150
+ uploaded_files = []
151
+ for file_path in model_files:
152
+ if os.path.exists(file_path):
153
+ success = self.push_checkpoint(file_path, repo_id, f"Upload {os.path.basename(file_path)}")
154
+ if success:
155
+ uploaded_files.append(os.path.basename(file_path))
156
+
157
+ return {"success": True, "uploaded_files": uploaded_files, "repo_id": repo_id}
158
+
159
+ elif model_type == "splits":
160
+ # Upload dataset splits
161
+ repo_id = f"{self.username}/{repo_name}"
162
+ try:
163
+ create_repo(
164
+ repo_id=repo_id,
165
+ repo_type="dataset",
166
+ private=False,
167
+ exist_ok=True
168
+ )
169
+ except Exception as e:
170
+ print(f"Note: Repo might already exist: {e}")
171
+
172
+ # Upload split files
173
+ split_files = [
174
+ "data/Polyvore/splits/train.json",
175
+ "data/Polyvore/splits/valid.json",
176
+ "data/Polyvore/splits/test.json",
177
+ "data/Polyvore/splits/outfit_triplets_train.json",
178
+ "data/Polyvore/splits/outfit_triplets_valid.json",
179
+ "data/Polyvore/splits/outfit_triplets_test.json"
180
+ ]
181
+
182
+ uploaded_files = []
183
+ for file_path in split_files:
184
+ if os.path.exists(file_path):
185
+ try:
186
+ upload_file(
187
+ path_or_fileobj=file_path,
188
+ path_in_repo=f"splits/{os.path.basename(file_path)}",
189
+ repo_id=repo_id,
190
+ repo_type="dataset",
191
+ commit_message=f"Upload {os.path.basename(file_path)}"
192
+ )
193
+ uploaded_files.append(os.path.basename(file_path))
194
+ except Exception as e:
195
+ print(f"Failed to upload {file_path}: {e}")
196
+
197
+ return {"success": True, "uploaded_files": uploaded_files, "repo_id": repo_id}
198
+
199
+ elif model_type == "everything":
200
+ # Upload everything
201
+ models_result = self.upload_model("models", "dressify-models")
202
+ splits_result = self.upload_model("splits", "Dressify-Helper")
203
+
204
+ return {
205
+ "success": models_result["success"] and splits_result["success"],
206
+ "models": models_result,
207
+ "splits": splits_result
208
+ }
209
+
210
+ else:
211
+ return {"success": False, "error": f"Unknown model type: {model_type}"}
212
+
213
+ except Exception as e:
214
+ return {"success": False, "error": str(e)}
215
 
216
 
217
  def push_model_to_hub(
utils/triplet_mining.py CHANGED
@@ -281,3 +281,4 @@ if __name__ == "__main__":
281
  print(f"Anchor indices: {anchors[:5]}")
282
  print(f"Positive indices: {positives[:5]}")
283
  print(f"Negative indices: {negatives[:5]}")
 
 
281
  print(f"Anchor indices: {anchors[:5]}")
282
  print(f"Positive indices: {positives[:5]}")
283
  print(f"Negative indices: {negatives[:5]}")
284
+