Spaces: Running
style: isort
Browse files
tools/train/distributed_shampoo.py
CHANGED
|
@@ -36,13 +36,13 @@ import itertools
|
|
| 36 |
from typing import Any, List, NamedTuple
|
| 37 |
|
| 38 |
import chex
|
| 39 |
- from flax import struct
|
| 40 |
import jax
|
| 41 |
- from jax import lax
|
| 42 |
import jax.experimental.pjit as pjit
|
| 43 |
import jax.numpy as jnp
|
| 44 |
import numpy as np
|
| 45 |
import optax
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
# pylint:disable=no-value-for-parameter
|
|
|
|
| 36 |
from typing import Any, List, NamedTuple
|
| 37 |
|
| 38 |
import chex
|
|
|
|
| 39 |
import jax
|
|
|
|
| 40 |
import jax.experimental.pjit as pjit
|
| 41 |
import jax.numpy as jnp
|
| 42 |
import numpy as np
|
| 43 |
import optax
|
| 44 |
+ from flax import struct
|
| 45 |
+ from jax import lax
|
| 46 |
|
| 47 |
|
| 48 |
# pylint:disable=no-value-for-parameter
|
tools/train/train.py
CHANGED
|
@@ -34,6 +34,7 @@ import optax
|
|
| 34 |
import transformers
|
| 35 |
import wandb
|
| 36 |
from datasets import Dataset
|
|
|
|
| 37 |
from flax import jax_utils, traverse_util
|
| 38 |
from flax.jax_utils import unreplicate
|
| 39 |
from flax.serialization import from_bytes, to_bytes
|
|
@@ -45,8 +46,6 @@ from transformers import AutoTokenizer, HfArgumentParser
|
|
| 45 |
from dalle_mini.data import Dataset
|
| 46 |
from dalle_mini.model import DalleBart, DalleBartConfig
|
| 47 |
|
| 48 |
- from distributed_shampoo import distributed_shampoo, GraftingType
|
| 49 |
-
|
| 50 |
logger = logging.getLogger(__name__)
|
| 51 |
|
| 52 |
|
|
|
|
| 34 |
import transformers
|
| 35 |
import wandb
|
| 36 |
from datasets import Dataset
|
| 37 |
+ from distributed_shampoo import GraftingType, distributed_shampoo
|
| 38 |
from flax import jax_utils, traverse_util
|
| 39 |
from flax.jax_utils import unreplicate
|
| 40 |
from flax.serialization import from_bytes, to_bytes
|
|
|
|
| 46 |
from dalle_mini.data import Dataset
|
| 47 |
from dalle_mini.model import DalleBart, DalleBartConfig
|
| 48 |
|
|
|
|
|
|
|
| 49 |
logger = logging.getLogger(__name__)
|
| 50 |
|
| 51 |
|