update
Browse files- app.py +2 -1
- inference.py +3 -2
app.py
CHANGED
|
@@ -87,4 +87,5 @@ gr.Interface(
|
|
| 87 |
['example/landscape/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', None],
|
| 88 |
['example/landscape/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', None],
|
| 89 |
]
|
| 90 |
-
).launch(
|
|
|
|
|
|
| 87 |
['example/landscape/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', None],
|
| 88 |
['example/landscape/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', None],
|
| 89 |
]
|
| 90 |
+
).launch()
|
| 91 |
+
# server_name="0.0.0.0", server_port=8080
|
inference.py
CHANGED
|
@@ -43,8 +43,9 @@ def auto_load_weight(weight, version=None, map_location=None):
|
|
| 43 |
"""Auto load Generator version from weight."""
|
| 44 |
project_dir = os.path.dirname(os.path.abspath(__file__))
|
| 45 |
cache_dir = os.path.join(project_dir, ".cache")
|
| 46 |
-
weight_name = os.path.basename(weight)
|
| 47 |
cached_weight = os.path.join(cache_dir, weight_name)
|
|
|
|
| 48 |
|
| 49 |
# Check if the cached weight file exists
|
| 50 |
if os.path.exists(cached_weight):
|
|
@@ -67,7 +68,7 @@ def auto_load_weight(weight, version=None, map_location=None):
|
|
| 67 |
version = RELEASED_WEIGHTS[weight_name][0]
|
| 68 |
return auto_load_weight(weight, version=version, map_location=map_location)
|
| 69 |
|
| 70 |
-
elif weight_name.startswith("
|
| 71 |
cls = GeneratorV2
|
| 72 |
elif weight_name.startswith("generatorv3"):
|
| 73 |
cls = GeneratorV3
|
|
|
|
| 43 |
"""Auto load Generator version from weight."""
|
| 44 |
project_dir = os.path.dirname(os.path.abspath(__file__))
|
| 45 |
cache_dir = os.path.join(project_dir, ".cache")
|
| 46 |
+
weight_name = os.path.basename(weight)#.lower()
|
| 47 |
cached_weight = os.path.join(cache_dir, weight_name)
|
| 48 |
+
print(project_dir, cache_dir, weight, weight_name, cached_weight)
|
| 49 |
|
| 50 |
# Check if the cached weight file exists
|
| 51 |
if os.path.exists(cached_weight):
|
|
|
|
| 68 |
version = RELEASED_WEIGHTS[weight_name][0]
|
| 69 |
return auto_load_weight(weight, version=version, map_location=map_location)
|
| 70 |
|
| 71 |
+
elif weight_name.startswith("GeneratorV2"):
|
| 72 |
cls = GeneratorV2
|
| 73 |
elif weight_name.startswith("generatorv3"):
|
| 74 |
cls = GeneratorV3
|