BifrostTitan committed
Commit 5811564 · verified · 1 Parent(s): 6c43e81

Update app.py

Files changed (1)
  1. app.py +9 -6
app.py CHANGED
@@ -15,11 +15,8 @@ from zeta.optim import StableAdamWUnfused
 import gradio as gr
 import os
 import subprocess
-
-def install(package):
-    subprocess.check_call([os.sys.executable, "-m", "pip", "install", "--ignore-installed", package])
-install("zetascale=2.8.2")
-install("git+https://github.com/shumingma/transformers.git#egg=transformers")
+os.system("pip install git+https://github.com/shumingma/transformers.git")
+
 # Suppress TorchDynamo errors (this will fallback to eager mode)
 import torch._dynamo
 torch._dynamo.config.suppress_errors = True
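For reference, the dropped helper shelled out via os.sys.executable; below is a minimal sketch of the same runtime-install idea using the standard sys.executable and subprocess.check_call, which raises on a non-zero pip exit where os.system only returns a status code. This is a hypothetical alternative, not part of this commit:

import subprocess
import sys

def install(package):
    # check_call raises CalledProcessError if pip exits non-zero,
    # so a broken install fails loudly instead of surfacing later at import time.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("git+https://github.com/shumingma/transformers.git")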
@@ -66,8 +63,12 @@ transformers.utils.logging.enable_explicit_format()
 model_id = "microsoft/bitnet-b1.58-2B-4T-bf16"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 hf_save_dir = "./bitnet"
-model = AutoModelForCausalLM.from_pretrained(model_id)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
+)
+device = model.device
 if torch.cuda.is_available():
     print("CUDA is available. Using GPU:", torch.cuda.get_device_name(0))
 else:
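With device_map="auto" (which requires the accelerate package), transformers decides where the weights go, and model.device then reports the device of the model's first parameters. A minimal sketch of how that device handle would be used downstream, with a hypothetical prompt that is not taken from this file:

import torch

prompt = "Hello"  # hypothetical input, not from this file
inputs = tokenizer(prompt, return_tensors="pt")
# Move the tokenized batch to wherever from_pretrained placed the weights.
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))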
@@ -124,6 +125,7 @@ val_dataset = torch.utils.data.Subset(processed_dataset, list(range(split_idx, l
 # ---------------------------------------------------------------------------------
 # Collate function for DataLoader
 # ---------------------------------------------------------------------------------
+@spaces.GPU
 def sft_collate_fn(batch):
     """
     Collate a list of examples by padding them to the maximum sequence length in the batch.
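The diff shows only the decorator and docstring of sft_collate_fn (@spaces.GPU is the Hugging Face Spaces ZeroGPU decorator and assumes an import spaces elsewhere in the file; note that collation normally runs on CPU inside the DataLoader). A rough sketch of what a pad-to-longest collate function like the one described typically looks like; the field names and pad_token_id default are assumptions, not the file's actual implementation:

import torch
from torch.nn.utils.rnn import pad_sequence

def sft_collate_fn_sketch(batch, pad_token_id=0):
    # Assumed layout: each example is {"input_ids": 1-D LongTensor} (hypothetical).
    input_ids = [example["input_ids"] for example in batch]
    # Pad every sequence up to the longest one in this batch.
    padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    attention_mask = (padded != pad_token_id).long()
    # Causal-LM SFT: labels mirror the inputs; -100 makes the loss skip padding.
    labels = padded.clone()
    labels[attention_mask == 0] = -100
    return {"input_ids": padded, "attention_mask": attention_mask, "labels": labels}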
@@ -159,6 +161,7 @@ optim = StableAdamWUnfused(model.parameters(), lr=LEARNING_RATE)
 # ---------------------------------------------------------------------------------
 # Define training function for Gradio UI
 # ---------------------------------------------------------------------------------
+@spaces.GPU
 def train_model():
     """
     Runs a training loop for a fixed number of batches and returns training logs.
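Per its docstring, train_model runs a fixed number of batches and returns training logs. A minimal sketch consistent with that description, reusing the model, device, optim (StableAdamWUnfused), train_dataset, and sft_collate_fn names from the surrounding file; the batch size and step count are invented defaults, not values from this commit:

from torch.utils.data import DataLoader

def train_model_sketch(num_batches=50, batch_size=2):
    loader = DataLoader(train_dataset, batch_size=batch_size,
                        shuffle=True, collate_fn=sft_collate_fn)
    logs = []
    model.train()
    for step, batch in enumerate(loader):
        if step >= num_batches:
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        # transformers causal-LM models return the cross-entropy loss
        # directly when labels are passed in.
        loss = model(**batch).loss
        optim.zero_grad()
        loss.backward()
        optim.step()
        logs.append(f"step {step}: loss {loss.item():.4f}")
    return "\n".join(logs)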
 