|
from dataclasses import dataclass, field |
|
from typing import Dict, ForwardRef, List, Optional, Type, Union |
|
|
|
|
|
# Forward references to the plan-entry classes declared further down in this
# module. Each string has to match the class name exactly, otherwise resolving
# the aliases (e.g. through `typing.get_type_hints`) fails with a NameError.
ParamIdentifierType = ForwardRef("ParamId")
ContextParallelInputMetadataType = ForwardRef("CPInput")
ContextParallelOutputMetadataType = ForwardRef("CPOutput")

# Input plan: parameter identifier -> a single input spec or a list of input specs.
_ContextParallelInputType = Dict[
    ParamIdentifierType,
    Union[ContextParallelInputMetadataType, List[ContextParallelInputMetadataType]],
]
# Output plan: a list of output specs.
_ContextParallelOutputType = List[ContextParallelOutputMetadataType]
# A plan is either an input plan or an output plan.
ContextParallelModelPlan = Union[_ContextParallelInputType, _ContextParallelOutputType]
|
|
|
|
|
@dataclass(frozen=True)
class ParamId:
    """
    Identifies a single parameter of a method, by name and/or by position.

    At least one of `name` or `index` must be provided.

    Attributes:
        name (`str`, *optional*):
            The name of the parameter.
        index (`int`, *optional*):
            The index of the parameter in the method signature. Indexing starts at 0 (ignore
            the `self` parameter for instance methods).

    Raises:
        ValueError: If both `name` and `index` are `None`.
    """

    name: Optional[str] = None
    index: Optional[int] = None

    def __post_init__(self):
        # An identifier carrying neither a name nor an index cannot refer to any parameter.
        # `is None` is used on purpose so that index 0 and an empty name still count as provided.
        if self.name is None and self.index is None:
            raise ValueError("At least one of `name` or `index` must be provided.")
|
|
|
|
|
@dataclass(frozen=True)
class CPInput:
    """
    Immutable description of one input entry in a context-parallel plan.

    Attributes:
        split_dim (`int`):
            Required. Presumably the dimension along which the tensor is split; confirm against the code that
            consumes the plan.
        expected_dims (`int`, *optional*, defaults to `None`):
            Presumably the number of dimensions the tensor is expected to have; `None` leaves it unspecified.
            TODO confirm against the consumer.
        split_output (`bool`, defaults to `False`):
            Flag stored with the entry. Presumably selects splitting the output instead of the input; confirm
            against the consumer.
    """

    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False
|
|
|
|
|
@dataclass(frozen=True)
class CPOutput:
    """
    Immutable description of one output entry in a context-parallel plan.

    Attributes:
        gather_dim (`int`):
            Required. Presumably the dimension along which the tensor is gathered; confirm against the code that
            consumes the plan.
        expected_dims (`int`, *optional*, defaults to `None`):
            Presumably the number of dimensions the tensor is expected to have; `None` leaves it unspecified.
            TODO confirm against the consumer.
    """

    gather_dim: int
    expected_dims: Optional[int] = None
|
|
|
|
|
@dataclass
class TransformerMetadata:
    """
    Mutable container for the metadata attached to a transformer model class.

    Attributes:
        cp_plan (`Dict[str, ContextParallelModelPlan]`, defaults to an empty dict):
            Context-parallel plans keyed by string. Each instance gets its own fresh dict when no value is
            passed. The keys are presumably module names; confirm against the code that reads the plan.
    """

    # `default_factory` avoids sharing one mutable dict between instances.
    cp_plan: Dict[str, ContextParallelModelPlan] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
class TransformerRegistry:
    """
    Class-level registry that maps a model class to its `TransformerMetadata`.

    Entries live in the `_registry` dict stored on the class, and lookups use the model class itself as the key.
    """

    _registry: Dict[Type, TransformerMetadata] = {}

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerMetadata):
        """
        Store `metadata` under `model_class`.

        Args:
            model_class (`Type`):
                The class used as the registry key.
            metadata (`TransformerMetadata`):
                The metadata to associate with `model_class`. An existing entry for the same key is overwritten.
        """
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerMetadata:
        """
        Return the metadata registered for `model_class`.

        Args:
            model_class (`Type`):
                The class to look up.

        Returns:
            `TransformerMetadata`: The object previously passed to `register` for this key.

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        try:
            return cls._registry[model_class]
        except KeyError:
            # Callers expect a ValueError; the internal KeyError adds nothing, so drop it from the chain.
            raise ValueError(f"Model class {model_class} not registered.") from None
|
|