Computing embeddings on 1xH100 works, 2xH100 fails
#3
by
davidcj
- opened
Hello, thanks for releasing Gemma3!
I'm trying to compute and save embeddings for various strings using the model. I'm attempting to do this on 2xH100's.
If I run the following (see below for the contents of inference.py), I can get the embeddings just fine, and I have observed that the model is split between gpu0 and cpu.
export CUDA_VISIBLE_DEVICES=0
uv run python inference.py --hf_model_dir=$MODEL_DIR --outdir=$OUTDIR
However, if I run the same exact code with export CUDA_VISIBLE_DEVICES=0,1
I observe the model is indeed split between gpu0 and gpu1, and the forward pass fails (stack trace below). Do you know what might be going on? Thanks!
Traceback (most recent call last):
File "/home/davidcj/projects/myproject/inference.py", line 103, in <module>
main()
~~~~^^
File "/home/davidcj/projects/myproject/inference.py", line 89, in main
final_hidden_states = get_gemma_embeddings(
~~~~~~~~~~~~~~~~~~~~^
smiles, model, processor, accelerator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
).to(torch.float32)
^
File "/home/davidcj/projects/myproject/inference.py", line 44, in get_gemma_embeddings
outputs = model.generate(
**inputs,
...<2 lines>...
max_new_tokens=1,
)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/generation/utils.py", line 2250, in generate
result = self._sample(
input_ids,
...<5 lines>...
**model_kwargs,
)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/generation/utils.py", line 3238, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1352, in forward
outputs = self.language_model(
attention_mask=causal_mask,
...<9 lines>...
**lm_kwargs,
)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 976, in forward
outputs = self.model(
input_ids=input_ids,
...<9 lines>...
**loss_kwargs,
)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 754, in forward
layer_outputs = decoder_layer(
hidden_states,
...<9 lines>...
**flash_attn_kwargs,
)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 443, in forward
hidden_states, self_attn_weights = self.self_attn(
~~~~~~~~~~~~~~^
hidden_states=hidden_states,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<7 lines>...
**kwargs,
^^^^^^^^^
)
^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 347, in forward
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/cache_utils.py", line 1734, in update
return update_fn(
cache_position,
...<5 lines>...
k_out.shape[2],
)
File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/cache_utils.py", line 1703, in _static_update
k_out[:, :, cache_position] = key_states
~~~~~^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Contents of inference.py:
"""Example of gemma inference"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from accelerate import Accelerator
import pandas as pd
import argparse
import os
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to None, which
            makes argparse fall back to sys.argv[1:] — so the existing
            no-argument call site keeps working, while tests can pass an
            explicit list.

    Returns:
        argparse.Namespace with hf_model_dir, outdir and precision.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hf_model_dir",
        type=str,
        default="./hf_llama_32_11B_vision",
        help="Path to HuggingFace model directory",
    )
    parser.add_argument("--outdir", type=str, help="Path to output directory")
    parser.add_argument(
        "--precision", type=int, default=32, help="FP precision to use for model"
    )
    return parser.parse_args(argv)
def get_llama_embeddings(prompt, model, tokenizer, accelerator):
    """Get final-layer hidden states for *prompt* from a llama-style model.

    Args:
        prompt: Input text string.
        model: Causal LM whose forward accepts ``output_hidden_states``.
        tokenizer: Callable returning a dict of tensors for the prompt.
        accelerator: Accelerator providing the target ``.device``.

    Returns:
        The last layer's hidden states (``outputs.hidden_states[-1]``),
        still on the accelerator device.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move every input tensor onto the accelerator's device before the forward.
    inputs = {key: value.to(accelerator.device) for key, value in inputs.items()}
    # inference_mode() skips building the autograd graph — consistent with
    # get_gemma_embeddings and cheaper for pure embedding extraction.
    with torch.inference_mode():
        outputs = model(**inputs, output_hidden_states=True)
    final_hidden_states = outputs.hidden_states[-1]
    return final_hidden_states
def get_gemma_embeddings(prompt, model, processor, accelerator):
    """Get final-layer hidden states for *prompt* from a Gemma model.

    Runs a single-token generate pass with hidden-state output enabled and
    returns the last layer's hidden states of the last generation step.
    """
    encoded = processor(text=prompt, images=None, return_tensors="pt")
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(accelerator.device)
    with torch.inference_mode():
        generation = model.generate(
            **on_device,
            output_hidden_states=True,
            return_dict_in_generate=True,
            max_new_tokens=1,
        )
    # hidden_states holds one tuple per generated step; each step holds one
    # tensor per layer. Take the final layer of the final step.
    return generation["hidden_states"][-1][-1]
def main():
    """Load a model, embed every row of mydata.csv, and save the embeddings.

    The checkpoint directory name selects the code path: paths containing
    "gemma" use the Gemma processor/generate path; everything else is
    treated as a llama-style causal LM. Results are written as a dict of
    {row ID: float32 CPU tensor} to <outdir>/embeddings.pt.
    """
    args = parse_args()
    accelerator = Accelerator()
    ckpt_dir = args.hf_model_dir
    outdir = args.outdir
    df = pd.read_csv('mydata.csv')
    # Decide the model family once instead of re-testing the path per row.
    is_gemma = "gemma" in ckpt_dir.lower()
    if is_gemma:
        model = Gemma3ForConditionalGeneration.from_pretrained(
            ckpt_dir, device_map="auto"
        ).eval()
        processor = AutoProcessor.from_pretrained(ckpt_dir)
    else:
        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
        # .eval() for parity with the Gemma branch: disables dropout etc.
        model = AutoModelForCausalLM.from_pretrained(
            ckpt_dir, device_map="auto"
        ).eval()
    out = {}
    for _, row in df.iterrows():
        mystr = row['input_str']
        if is_gemma:
            final_hidden_states = get_gemma_embeddings(
                mystr, model, processor, accelerator
            ).to(torch.float32)
        else:
            final_hidden_states = get_llama_embeddings(
                mystr, model, tokenizer, accelerator
            )
        out[row["ID"]] = final_hidden_states.cpu().detach()
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, "embeddings.pt"), "wb") as f:
        torch.save(out, f)


if __name__ == "__main__":
    main()