Spaces:
Running
Running
feat(train): local jax cache
Browse files- tools/train/train.py +1 -3
tools/train/train.py
CHANGED
|
@@ -57,9 +57,7 @@ from dalle_mini.model import (
|
|
| 57 |
set_partitions,
|
| 58 |
)
|
| 59 |
|
| 60 |
-
cc.initialize_cache(
|
| 61 |
-
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
| 62 |
-
)
|
| 63 |
|
| 64 |
logger = logging.getLogger(__name__)
|
| 65 |
|
|
|
|
| 57 |
set_partitions,
|
| 58 |
)
|
| 59 |
|
| 60 |
+
cc.initialize_cache("./jax_cache", max_cache_size_bytes=5 * 2**30)
|
|
|
|
|
|
|
| 61 |
|
| 62 |
logger = logging.getLogger(__name__)
|
| 63 |
|