Update app.py
Browse files
app.py
CHANGED
|
@@ -15,11 +15,8 @@ from zeta.optim import StableAdamWUnfused
|
|
| 15 |
import gradio as gr
|
| 16 |
import os
|
| 17 |
import subprocess
|
|
|
|
| 18 |
|
| 19 |
-
def install(package):
|
| 20 |
-
subprocess.check_call([os.sys.executable, "-m", "pip", "install", "--ignore-installed", package])
|
| 21 |
-
install("zetascale=2.8.2")
|
| 22 |
-
install("git+https://github.com/shumingma/transformers.git#egg=transformers")
|
| 23 |
# Suppress TorchDynamo errors (this will fallback to eager mode)
|
| 24 |
import torch._dynamo
|
| 25 |
torch._dynamo.config.suppress_errors = True
|
|
@@ -66,8 +63,12 @@ transformers.utils.logging.enable_explicit_format()
|
|
| 66 |
model_id = "microsoft/bitnet-b1.58-2B-4T-bf16"
|
| 67 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 68 |
hf_save_dir = "./bitnet"
|
| 69 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
if torch.cuda.is_available():
|
| 72 |
print("CUDA is available. Using GPU:", torch.cuda.get_device_name(0))
|
| 73 |
else:
|
|
@@ -124,6 +125,7 @@ val_dataset = torch.utils.data.Subset(processed_dataset, list(range(split_idx, l
|
|
| 124 |
# ---------------------------------------------------------------------------------
|
| 125 |
# Collate function for DataLoader
|
| 126 |
# ---------------------------------------------------------------------------------
|
|
|
|
| 127 |
def sft_collate_fn(batch):
|
| 128 |
"""
|
| 129 |
Collate a list of examples by padding them to the maximum sequence length in the batch.
|
|
@@ -159,6 +161,7 @@ optim = StableAdamWUnfused(model.parameters(), lr=LEARNING_RATE)
|
|
| 159 |
# ---------------------------------------------------------------------------------
|
| 160 |
# Define training function for Gradio UI
|
| 161 |
# ---------------------------------------------------------------------------------
|
|
|
|
| 162 |
def train_model():
|
| 163 |
"""
|
| 164 |
Runs a training loop for a fixed number of batches and returns training logs.
|
|
|
|
| 15 |
import gradio as gr
|
| 16 |
import os
|
| 17 |
import subprocess
|
| 18 |
+
os.system("pip install git+https://github.com/shumingma/transformers.git")
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Suppress TorchDynamo errors (this will fallback to eager mode)
|
| 21 |
import torch._dynamo
|
| 22 |
torch._dynamo.config.suppress_errors = True
|
|
|
|
| 63 |
model_id = "microsoft/bitnet-b1.58-2B-4T-bf16"
|
| 64 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 65 |
hf_save_dir = "./bitnet"
|
| 66 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 67 |
+
model_id,
|
| 68 |
+
torch_dtype=torch.bfloat16,
|
| 69 |
+
device_map="auto"
|
| 70 |
+
)
|
| 71 |
+
device = model.device
|
| 72 |
if torch.cuda.is_available():
|
| 73 |
print("CUDA is available. Using GPU:", torch.cuda.get_device_name(0))
|
| 74 |
else:
|
|
|
|
| 125 |
# ---------------------------------------------------------------------------------
|
| 126 |
# Collate function for DataLoader
|
| 127 |
# ---------------------------------------------------------------------------------
|
| 128 |
+
@spaces.GPU
|
| 129 |
def sft_collate_fn(batch):
|
| 130 |
"""
|
| 131 |
Collate a list of examples by padding them to the maximum sequence length in the batch.
|
|
|
|
| 161 |
# ---------------------------------------------------------------------------------
|
| 162 |
# Define training function for Gradio UI
|
| 163 |
# ---------------------------------------------------------------------------------
|
| 164 |
+
@spaces.GPU
|
| 165 |
def train_model():
|
| 166 |
"""
|
| 167 |
Runs a training loop for a fixed number of batches and returns training logs.
|