In [1]:
from typing import Optional

import torch
from safetensors import safe_open
from safetensors.torch import load_file

def read_safetensor(filepath: str, key: Optional[str] = None, device: str = "cpu"):
    """
    Read a SafeTensor file using the safetensors library.

    Tensors are read through ``safe_open``, so when ``key`` is given only
    that one tensor is materialized.

    Args:
        filepath: Path to the SafeTensor file.
        key: Optional name of a single tensor to extract. If None, every
            tensor in the file is loaded.
        device: Device the tensors are loaded onto; forwarded to
            ``safe_open``. Defaults to "cpu".

    Returns:
        If key is None: dict mapping every key in the file to its tensor.
        If key is specified: the single tensor stored under that key.

    Raises:
        KeyError: If ``key`` is given but is not present in the file.
    """
    # safe_open reads tensors on demand via get_tensor, so the single-key
    # path never pulls the rest of the file into memory.
    with safe_open(filepath, framework="pt", device=device) as f:
        # Materialize once: used both for the membership test and the error text.
        keys = list(f.keys())

        if key is None:
            return {k: f.get_tensor(k) for k in keys}

        if key not in keys:
            raise KeyError(f"Key '{key}' not found. Available keys: {keys}")
        return f.get_tensor(key)

def print_tensor_info(tensor_dict, num_values: int = 5):
    """
    Print the key, shape, dtype and leading values of each tensor.

    Args:
        tensor_dict: Mapping from tensor name to tensor, e.g. the dict
            returned by ``read_safetensor(filepath)``.
        num_values: How many leading elements (in flattened order) to show
            for each tensor. Defaults to 5.
    """
    for key, tensor in tensor_dict.items():
        print(f"\nKey: {key}")
        print(f"Shape: {tensor.shape}")
        print(f"Dtype: {tensor.dtype}")
        # flatten() makes the preview work for tensors of any rank, including 0-d.
        print(f"First few values: {tensor.flatten()[:num_values]}")

# Example usage
if __name__ == "__main__":
    # NOTE: hard-coded path to a local checkpoint shard; edit before running.
    # (Not read from sys.argv: inside Jupyter, argv holds kernel flags.)
    filepath = "/data/seungah/flux_test_fp8/transformer/diffusion_pytorch_model-00001-of-00002.safetensors"

    # Example 1: List all tensors and their info
    print("Loading all tensors:")
    tensors = read_safetensor(filepath)
    print_tensor_info(tensors)
    # Release this copy before the full reload below; otherwise both
    # dicts of tensors are alive at the same time and peak memory doubles.
    del tensors

    # Example 2: Load specific tensor
    print("\nLoading specific tensor:")
    key = "single_transformer_blocks.1.attn.to_k.in_scale"  # replace with actual key name
    try:
        # Only the lookup can raise KeyError; keep the try body to just that.
        tensor = read_safetensor(filepath, key)
    except KeyError as e:
        print(f"Error: {e}")
    else:
        # Reuse print_tensor_info so the single-tensor output matches Example 1.
        print_tensor_info({key: tensor})

    # Alternative Method: Load entire file at once
    print("\nAlternative method - loading entire file:")
    tensors = load_file(filepath)
    print(f"Available keys: {list(tensors.keys())}")

Loading all tensors:

Key: context_embedder.bias
Shape: torch.Size([3072])
Dtype: torch.float32
First few values: tensor([ 0.0032, -0.0107,  0.0138, -0.0129,  0.0147])

Key: context_embedder.weight
Shape: torch.Size([3072, 4096])
Dtype: torch.float32
First few values: tensor([-0.0669, -0.0099, -0.0311,  0.0228, -0.0073])

Key: single_transformer_blocks.0.attn.norm_k.weight
Shape: torch.Size([128])
Dtype: torch.float32
First few values: tensor([1.2266, 1.2578, 1.2969, 1.2734, 1.2500])

Key: single_transformer_blocks.0.attn.norm_q.weight
Shape: torch.Size([128])
Dtype: torch.float32
First few values: tensor([1.2266, 1.2578, 1.2969, 1.2734, 1.2500])

Key: single_transformer_blocks.0.attn.to_k.bias
Shape: torch.Size([3072])
Dtype: torch.float32
First few values: tensor([-0.0947, -0.0981,  0.0498,  0.0422,  0.0525])

Key: single_transformer_blocks.0.attn.to_k.in_scale
Shape: torch.Size([1])
Dtype: torch.float32
First few values: tensor([1.])

Key: single_transformer_blocks.0.attn.to_k.weigh