Computing embeddings on 1xH100 works, 2xH100 fails

#3
by davidcj - opened

Hello, thanks for releasing Gemma3!

I'm trying to compute and save embeddings for various strings using the model. I'm attempting to do this on 2xH100s.

If I run the following (see below for the contents of inference.py), I can get the embeddings just fine, and I have observed that the model is split between gpu0 and the CPU.

export CUDA_VISIBLE_DEVICES=0
uv run python inference.py --hf_model_dir=$MODEL_DIR --outdir=$OUTDIR

However, if I run the exact same code with export CUDA_VISIBLE_DEVICES=0,1, I observe that the model is indeed split between gpu0 and gpu1, and the forward pass fails (stack trace below). Do you know what might be going on? Thanks!
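
For reference, this is how I checked the split after loading (a quick snippet run right after from_pretrained, separate from inference.py; with device_map="auto", accelerate records the module-to-device assignment in model.hf_device_map):

from collections import Counter

# Inspect how accelerate sharded the model across devices.
print(model.hf_device_map)
# Summarize how many modules landed on each device (e.g. 0, 1, "cpu").
print(Counter(model.hf_device_map.values()))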

Traceback (most recent call last):
  File "/home/davidcj/projects/myproject/inference.py", line 103, in <module>
    main()
    ~~~~^^
  File "/home/davidcj/projects/myproject/inference.py", line 89, in main
    final_hidden_states = get_gemma_embeddings(
                          ~~~~~~~~~~~~~~~~~~~~^
        smiles, model, processor, accelerator
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ).to(torch.float32)
    ^
  File "/home/davidcj/projects/myproject/inference.py", line 44, in get_gemma_embeddings
    outputs = model.generate(
        **inputs,
    ...<2 lines>...
        max_new_tokens=1,
    )
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/generation/utils.py", line 2250, in generate
    result = self._sample(
        input_ids,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/generation/utils.py", line 3238, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1352, in forward
    outputs = self.language_model(
        attention_mask=causal_mask,
    ...<9 lines>...
        **lm_kwargs,
    )
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 976, in forward
    outputs = self.model(
        input_ids=input_ids,
    ...<9 lines>...
        **loss_kwargs,
    )
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 754, in forward
    layer_outputs = decoder_layer(
        hidden_states,
    ...<9 lines>...
        **flash_attn_kwargs,
    )
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 443, in forward
    hidden_states, self_attn_weights = self.self_attn(
                                       ~~~~~~~~~~~~~~^
        hidden_states=hidden_states,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<7 lines>...
        **kwargs,
        ^^^^^^^^^
    )
    ^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 347, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
                               ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/cache_utils.py", line 1734, in update
    return update_fn(
        cache_position,
    ...<5 lines>...
        k_out.shape[2],
    )
  File "/home/davidcj/projects/myproject/.venv/lib/python3.13/site-packages/transformers/cache_utils.py", line 1703, in _static_update
    k_out[:, :, cache_position] = key_states
    ~~~~~^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
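
The assert fires inside the static cache update (k_out[:, :, cache_position] = key_states). One way to get a synchronous, more precise stack trace is to force blocking kernel launches; a debugging sketch, assuming it runs before torch initializes CUDA:

import os

# Force synchronous kernel launches so the Python traceback points at
# the kernel that actually triggered the device-side assert. This must
# be set before CUDA is initialized.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"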

Contents of inference.py:

"""Example of gemma inference"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from accelerate import Accelerator
import pandas as pd
import argparse
import os


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--hf_model_dir",
        type=str,
        default="./hf_llama_32_11B_vision",
        help="Path to HuggingFace model directory",
    )
    parser.add_argument("--outdir", type=str, help="Path to output directory")
    parser.add_argument(
        "--precision", type=int, default=32, help="FP precision to use for model"
    )  # note: parsed but not currently used below

    return parser.parse_args()


def get_llama_embeddings(prompt, model, tokenizer, accelerator):
    """Get embeddings from llama model"""
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(accelerator.device) for key, value in inputs.items()}
    outputs = model(**inputs, output_hidden_states=True)
    final_hidden_states = outputs.hidden_states[-1]

    return final_hidden_states


def get_gemma_embeddings(prompt, model, processor, accelerator):
    """Get embeddings from gemma model"""
    inputs = processor(text=prompt, images=None, return_tensors="pt")
    inputs = {key: value.to(accelerator.device) for key, value in inputs.items()}
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            output_hidden_states=True,
            return_dict_in_generate=True,
            max_new_tokens=1,
        )
        # hidden_states is indexed [generation step][layer]; take the
        # last layer of the last (and only) generated step.
        final_hidden_states = outputs["hidden_states"][-1][-1]

    return final_hidden_states


def main():
    args = parse_args()

    accelerator = Accelerator()

    ckpt_dir = args.hf_model_dir
    outdir = args.outdir
    df = pd.read_csv("mydata.csv")

    if "gemma" not in ckpt_dir.lower():
        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
        model = AutoModelForCausalLM.from_pretrained(ckpt_dir, device_map="auto")
    else:
        model = Gemma3ForConditionalGeneration.from_pretrained(
            ckpt_dir, device_map="auto"
        ).eval()
        processor = AutoProcessor.from_pretrained(ckpt_dir)
    
    out = {}
    for i, row in df.iterrows():
        mystr = row['input_str']

        if "gemma" not in ckpt_dir.lower():
            # llama models
            final_hidden_states = get_llama_embeddings(
                mystr, model, tokenizer, accelerator
            )
        else:
            # Gemma models
            final_hidden_states = get_gemma_embeddings(
                mystr, model, processor, accelerator
            ).to(torch.float32)

        out[row["ID"]] = final_hidden_states.cpu().detach()

    os.makedirs(outdir, exist_ok=True)

    with open(os.path.join(outdir, "embeddings.pt"), "wb") as f:
        torch.save(out, f)


if __name__ == "__main__":
    main()
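
For completeness, a single-GPU fallback sketch (untested; assumes the checkpoint fits on one H100 in bf16) that pins every module to gpu0 so generate() never has to move cache tensors across GPUs:

import torch
from transformers import Gemma3ForConditionalGeneration

# Workaround sketch, not a fix: place the entire model on cuda:0.
model = Gemma3ForConditionalGeneration.from_pretrained(
    ckpt_dir,                    # same --hf_model_dir as above
    device_map={"": 0},          # pin everything to cuda:0
    torch_dtype=torch.bfloat16,  # assumption: bf16 so it fits on one card
).eval()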
