Spaces:
Running
Running
fix style
Browse files- tools/train/train.py +1 -1
tools/train/train.py
CHANGED
|
@@ -33,6 +33,7 @@ import jax.numpy as jnp
|
|
| 33 |
import numpy as np
|
| 34 |
import optax
|
| 35 |
import transformers
|
|
|
|
| 36 |
from datasets import Dataset
|
| 37 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
| 38 |
from flax.core.frozen_dict import freeze
|
|
@@ -44,7 +45,6 @@ from jax.experimental.pjit import pjit
|
|
| 44 |
from tqdm import tqdm
|
| 45 |
from transformers import HfArgumentParser
|
| 46 |
|
| 47 |
-
import wandb
|
| 48 |
from dalle_mini.data import Dataset
|
| 49 |
from dalle_mini.model import (
|
| 50 |
DalleBart,
|
|
|
|
| 33 |
import numpy as np
|
| 34 |
import optax
|
| 35 |
import transformers
|
| 36 |
+
import wandb
|
| 37 |
from datasets import Dataset
|
| 38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
| 39 |
from flax.core.frozen_dict import freeze
|
|
|
|
| 45 |
from tqdm import tqdm
|
| 46 |
from transformers import HfArgumentParser
|
| 47 |
|
|
|
|
| 48 |
from dalle_mini.data import Dataset
|
| 49 |
from dalle_mini.model import (
|
| 50 |
DalleBart,
|