import abc
import socket
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, List, Optional

from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE

if ENABLE_MULTI_DEVICE:
    from mpi4py.futures import MPICommExecutor, MPIPoolExecutor

from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size
class MPINodeState:
    """Per-node global state shared between successive tasks on an MPI node.

    Example::

        def task():
            if MPINodeState.state is None:
                MPINodeState.state = 0
            MPINodeState.state += 1
            return MPINodeState.state

        n_workers = 4
        with MPIPoolExecutor(max_workers=n_workers) as executor:
            for i in range(2):
                futures = [executor.submit(task) for i in range(n_workers)]

    which yields:
        - [1, 1, 1, 1]
        - [2, 2, 2, 2]
    """

    # Shared payload; None means "not yet initialized on this node".
    state = None

    @staticmethod
    def is_initialized() -> bool:
        """Return True once some task has assigned a value to ``state``."""
        return not (MPINodeState.state is None)
def external_mpi_comm_available(model_world_size: int) -> bool:
    """Return True when this process was already launched under an external
    MPI communicator (e.g. ``mpirun -np 4 python script.py``) whose world size
    matches ``model_world_size``, so no MPIPoolExecutor spawning is needed.
    """
    if not ENABLE_MULTI_DEVICE:
        return False
    return model_world_size > 1 and mpi_world_size() == model_world_size
def need_spawn_mpi_workers(model_world_size: int) -> bool:
    """Return True when only a single MPI rank exists but the model requires
    more, meaning worker processes must be spawned by this process."""
    if not ENABLE_MULTI_DEVICE:
        return False
    return model_world_size > 1 and mpi_world_size() == 1
class MpiSession(abc.ABC):
    """Abstract interface for dispatching one task to a group of MPI workers.

    Concrete implementations decide where the workers come from (a spawned
    MPIPoolExecutor, or ranks launched externally by mpirun).
    """

    @abc.abstractmethod
    def submit(self, task: Callable[..., Any], *args, **kwargs) -> List[Future]:
        """Dispatch *task* to every worker; return one Future per worker."""
        raise NotImplementedError()

    @abc.abstractmethod
    def submit_sync(self, task: Callable[..., Any], *args, **kwargs) -> List[Any]:
        """Dispatch *task* to every worker and block until all results arrive."""
        raise NotImplementedError()

    @abc.abstractmethod
    def shutdown(self):
        """Release the executor resources backing the session."""
        raise NotImplementedError()
class MpiPoolSession(MpiSession):
    """MpiSession that spawns its own MPI workers via an MPIPoolExecutor."""

    def __init__(self, n_workers: int):
        self.n_workers = n_workers
        self.mpi_pool: Optional[MPIPoolExecutor] = None
        self._start_mpi_pool()

    def submit(self, task: Callable[..., Any], *args, **kwargs) -> List[Future]:
        """Dispatch *task* once per worker; return the per-worker Futures."""
        return [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers)
        ]

    def submit_sync(self, task: Callable[..., Any], *args, **kwargs) -> List[Any]:
        """Dispatch *task* once per worker and block for all the results."""
        futures = [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers)
        ]
        return [future.result() for future in futures]

    def shutdown(self):
        """Shut down the worker pool; safe to call more than once."""
        # getattr-guard: __del__ may invoke this even when __init__ raised
        # before ``mpi_pool`` was assigned.
        if getattr(self, 'mpi_pool', None) is not None:
            self.mpi_pool.shutdown()
            self.mpi_pool = None

    def _start_mpi_pool(self):
        assert not self.mpi_pool, 'MPI session already started'
        # Propagate sys.path so spawned workers can import the same modules.
        self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers,
                                        path=sys.path)

    def __del__(self):
        self.shutdown()

    def __reduce__(self):
        # Live executor state cannot be transferred between processes.
        raise TypeError('cannot pickle MPI session')
class MpiCommSession(MpiSession):
    """MpiSession bound to an externally-launched MPI communicator (mpirun).

    Rank 0 runs the session while the remaining ranks serve as workers via an
    MPICommExecutor. Because MPICommExecutor excludes rank 0 from its worker
    pool, rank 0's share of each task is run on a local thread pool.

    Raises:
        ValueError: if ``n_workers`` is not positive or does not match the
            number of ranks launched by mpirun.
        RuntimeError: if called from a rank other than 0, or if no external
            MPI communicator is available.
    """

    def __init__(self, n_workers: int = 1):
        if n_workers <= 0:
            # Fixed message: the check rejects zero, so "positive" is correct.
            raise ValueError(
                f'n_workers must be positive, but got {n_workers}')
        if n_workers != mpi_world_size():
            raise ValueError(
                f'n_workers must be equal to the number of processes launched by mpirun, got {n_workers} vs {mpi_world_size()}'
            )
        if mpi_rank() != 0:
            raise RuntimeError('only rank 0 can start multi-node session')
        if not external_mpi_comm_available(n_workers):
            raise RuntimeError('The LLM instance should be launched by mpirun.')

        self.n_workers = n_workers
        self.thread_pool: Optional[ThreadPoolExecutor] = None
        self.mpi_pool: Optional[MPIPoolExecutor] = None
        self._start_mpi_pool()

    def submit(self, task: Callable[..., Any], *args, **kwargs) -> List[Future]:
        """Dispatch *task* to all ranks; returns [rank0_future, *worker_futures]."""
        assert self.mpi_pool is not None, 'MPI session not started'
        # Trick: The MPICommExecutor excludes rank0 from workers, thus an
        # extra task dispatching to rank0 is needed.
        worker_futures = [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers - 1)
        ]
        # A trick to wait for rank0 to be ready, or the collective tasks will hang
        # TODO[chunweiy]: Remove this trick for reducing normal tasks latencies
        time.sleep(4)
        rank0_future = self.thread_pool.submit(task, *args, **kwargs)
        return [rank0_future] + worker_futures

    def submit_sync(self, task: Callable[..., Any], *args, **kwargs) -> List[Any]:
        """Dispatch *task* to all ranks and block for the results."""
        futures = self.submit(task, *args, **kwargs)
        return [future.result() for future in futures]

    def shutdown(self):
        """Shut down both the MPI executor and the rank-0 thread pool."""
        # getattr-guard: __del__ may invoke this even when __init__ raised
        # before the attributes were assigned.
        if getattr(self, 'mpi_pool', None) is not None:
            self.mpi_pool.shutdown()
            self.mpi_pool = None
        # Previously leaked: the local thread pool used for rank0's tasks.
        if getattr(self, 'thread_pool', None) is not None:
            self.thread_pool.shutdown()
            self.thread_pool = None

    def _start_mpi_pool(self):
        assert not self.mpi_pool, 'MPI session already started'
        # max_workers=2 leaves headroom for overlapping rank-0 submissions.
        self.thread_pool = ThreadPoolExecutor(max_workers=2)
        # NOTE(review): the MPICommExecutor context manager is entered but its
        # __exit__ is never called; shutdown() releases the executor instead.
        comm_executor = MPICommExecutor(mpi_comm())
        self.mpi_pool = comm_executor.__enter__()

    def __del__(self):
        self.shutdown()

    def __reduce__(self):
        # MPI sessions hold live communicator state and must not be pickled.
        raise TypeError('cannot pickle MPI session')
def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port and return its number."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 makes the kernel pick an ephemeral free port.
        sock.bind(('', 0))
        _, port = sock.getsockname()
        return port
    finally:
        sock.close()