# MIT License# Copyright (c) 2025 aeeeeeepfromtypesimportFunctionTypefromtypingimportCallable,Optional,Unionfrom.utils.loggerimportlog_error,log_info
class MPHandls:
    """
    Handles multi-process initialization and synchronization using specified
    multi-process frameworks.

    Supported frameworks:
        - 'torch.distributed': PyTorch's distributed environment for multi-GPU support.
        - 'multiprocessing': Python's built-in multiprocessing for parallel processing.

    Any other framework name ``<name>`` is accepted if the handler (typically a
    subclass) defines a method ``_check_init_<name>`` that sets ``initialized``,
    ``index`` and ``sync_fn``.

    Manages process synchronization and provides the index of the current
    process. Initialization is re-checked lazily by ``is_initialized``,
    ``sync`` and ``get_index``. A handler created before the framework is set
    up therefore picks up the framework state once it becomes available.
    """

    def __init__(self, framework: Optional[str] = None) -> None:
        """
        Initializes the handler with the specified framework.

        Args:
            framework (Optional[str]): The multi-process framework to use, or
                None for no framework (the handler then never initializes and
                ``sync`` is a no-op).

        Raises:
            ValueError: If ``framework`` is neither a built-in framework nor
                has a matching ``_check_init_<framework>`` method.
        """
        self.framework: Optional[str] = framework
        self.initialized: bool = False
        self.index: Optional[int] = None
        self.sync_fn: Optional[Union[FunctionType, Callable]] = None
        self._check_initialized()

    def _check_initialized(self) -> None:
        """
        Verifies if the selected multi-process framework is initialized.

        Dispatches to the built-in checks or to a custom
        ``_check_init_<framework>`` method.

        Raises:
            ValueError: If no check exists for ``self.framework``.
        """
        if self.framework is None:
            return
        if self.framework == 'torch.distributed':
            self._check_init_torch()
        elif self.framework == 'multiprocessing':
            self._check_init_multiprocessing()
        else:
            # Check for custom framework extension
            custom_method_name = f"_check_init_{self.framework}"
            if hasattr(self, custom_method_name):
                custom_method = getattr(self, custom_method_name)
                custom_method()
            else:
                log_error(f"Invalid framework: {self.framework}")
                raise ValueError(f"Invalid framework: {self.framework}")

    def sync(self) -> None:
        """
        Synchronizes processes across all indices.

        Initialization is re-checked first, so synchronization takes effect
        even if the framework was set up after this handler was created. Does
        nothing if the framework is not initialized or provides no
        synchronization function (as with 'multiprocessing').
        """
        if self.is_initialized() and self.sync_fn is not None:
            self.sync_fn()

    def get_index(self) -> Optional[int]:
        """
        Returns the index of the current process.

        Returns:
            Optional[int]: The index of the current process (e.g., rank in
            distributed environments). Returns None if the framework is not
            initialized.
        """
        # Re-check lazily so the index is reported once the framework is up.
        return self.index if self.is_initialized() else None

    def is_initialized(self) -> bool:
        """
        Checks if the multi-process framework has been initialized.

        If a framework is configured but not yet marked initialized, the check
        is repeated on each call.

        Returns:
            bool: True if the framework is initialized, False otherwise.
        """
        if self.framework is not None and not self.initialized:
            self._check_initialized()
        return self.initialized

    def _check_init_torch(self) -> None:
        """
        Checks if the PyTorch distributed environment is initialized for
        multi-GPU support.

        If initialized, sets the current process index to the rank reported by
        ``torch.distributed.get_rank`` and uses ``torch.distributed.barrier`` as
        the synchronization function. The handler stays uninitialized when
        PyTorch is not installed or was built without distributed support.
        """
        try:
            import torch.distributed as dist
        except ImportError:
            # Without PyTorch there is no distributed environment to join.
            return
        # When is_available() is False, torch.distributed does not expose its
        # other APIs (is_initialized, get_rank, barrier may be missing).
        if dist.is_available() and dist.is_initialized():
            self.initialized = True
            self.index = dist.get_rank()
            self.sync_fn = dist.barrier
            log_info(f"torch.distributed initialized. index: {self.index}")

    def _check_init_multiprocessing(self) -> None:
        """
        Checks if Python's built-in multiprocessing is initialized.

        The handler counts as initialized only inside a child process. It then
        sets the index from the process's 1-based identity, converted to
        0-based. No synchronization function is set, so ``sync`` is a no-op.
        """
        import multiprocessing
        current_process = multiprocessing.current_process()
        # _identity is empty in the main process, even if the main process has
        # been renamed. Checking it avoids an IndexError below.
        identity = current_process._identity
        if current_process.name != "MainProcess" and identity:
            self.initialized = True
            self.index = identity[0] - 1  # Adjusting index (starts from 1)
            self.sync_fn = None
            log_info(f"multiprocessing initialized. index: {self.index}")