|  | from typing import * | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.distributed.rpc as rpc | 
					
						
						|  | from torch import Tensor | 
					
						
						|  | from torch._jit_internal import Future | 
					
						
						|  | from torch.distributed.rpc import RRef | 
					
						
						|  | from typing import Tuple | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Type parameter for the ``RRef[...]`` annotation on ``_remote_forward``'s
# ``module_rref`` argument. It is ``None`` here, so no module interface class is
# specified for this (non-scripted) instantiation.
module_interface_cls = None
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def forward_async(self, *args, **kwargs): | 
					
						
						|  | args = (self.module_rref, self.device, self.is_device_map_set, *args) | 
					
						
						|  | kwargs = {**kwargs} | 
					
						
						|  | return rpc.rpc_async( | 
					
						
						|  | self.module_rref.owner(), | 
					
						
						|  | _remote_forward, | 
					
						
						|  | args, | 
					
						
						|  | kwargs, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def forward(self, *args, **kwargs): | 
					
						
						|  | args = (self.module_rref, self.device, self.is_device_map_set, *args) | 
					
						
						|  | kwargs = {**kwargs} | 
					
						
						|  | ret_fut = rpc.rpc_async( | 
					
						
						|  | self.module_rref.owner(), | 
					
						
						|  | _remote_forward, | 
					
						
						|  | args, | 
					
						
						|  | kwargs, | 
					
						
						|  | ) | 
					
						
						|  | return ret_fut.wait() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | _generated_methods = [ | 
					
						
						|  | forward_async, | 
					
						
						|  | forward, | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _remote_forward( | 
					
						
						|  | module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs): | 
					
						
						|  | module = module_rref.local_value() | 
					
						
						|  | device = torch.device(device) | 
					
						
						|  |  | 
					
						
						|  | if device.type != "cuda": | 
					
						
						|  | return module.forward(*args, **kwargs) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | args = (*args,) | 
					
						
						|  | out_args: Tuple[()] = () | 
					
						
						|  | for arg in args: | 
					
						
						|  | arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,) | 
					
						
						|  | out_args = out_args + arg | 
					
						
						|  |  | 
					
						
						|  | kwargs = {**kwargs} | 
					
						
						|  | for k, v in kwargs.items(): | 
					
						
						|  | if isinstance(v, Tensor): | 
					
						
						|  | kwargs[k] = kwargs[k].to(device) | 
					
						
						|  |  | 
					
						
						|  | if is_device_map_set: | 
					
						
						|  | return module.forward(*out_args, **kwargs) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ret: Tuple[()] = () | 
					
						
						|  | for i in module.forward(*out_args, **kwargs): | 
					
						
						|  | i = (i.cpu(),) if isinstance(i, Tensor) else (i,) | 
					
						
						|  | ret = ret + i | 
					
						
						|  | return ret | 
					
						
						|  |  |