runner
Browse files- runner/inference.py +9 -8
runner/inference.py
CHANGED
@@ -214,14 +214,15 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
|
|
214 |
tos_url = URL[cache_name]
|
215 |
logger.info(f"Downloading data cache from\n {tos_url}...")
|
216 |
urllib.request.urlretrieve(tos_url, cache_path)
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
|
|
225 |
|
226 |
|
227 |
# checkpoint_path = configs.load_checkpoint_path
|
|
|
214 |
tos_url = URL[cache_name]
|
215 |
logger.info(f"Downloading data cache from\n {tos_url}...")
|
216 |
urllib.request.urlretrieve(tos_url, cache_path)
|
217 |
+
|
218 |
+
if not os.path.exists('./checkpoint.pt'):
|
219 |
+
# Google Drive file ID
|
220 |
+
file_id = '17zBIRed3xZM8ux0bq2hpf1oFC75Y7OEw'
|
221 |
+
# URL to download the file
|
222 |
+
url = f'https://drive.google.com/uc?id={file_id}'
|
223 |
+
|
224 |
+
# Download the file and save it as 'checkpoint.pt'
|
225 |
+
gdown.download(url, './checkpoint.pt', quiet=False)
|
226 |
|
227 |
|
228 |
# checkpoint_path = configs.load_checkpoint_path
|