Source code for objwatch.mp_handls
# MIT License
# Copyright (c) 2025 aeeeeeep
from types import FunctionType
from typing import Callable, Optional, Union
from .utils.logger import log_error, log_info
[docs]
class MPHandls:
    """
    Handles multi-process initialization and synchronization
    using specified multi-process frameworks.

    Supported frameworks:
    - 'torch.distributed': PyTorch's distributed environment for multi-GPU support.
    - 'multiprocessing': Python's built-in multiprocessing for parallel processing.

    Other framework names are dispatched to a method named
    ``_check_init_<framework>`` on this object (e.g. defined in a subclass),
    which is expected to set ``initialized``, ``index`` and ``sync_fn`` the way
    the built-in checks do.

    Manages process synchronization and provides the index of the current process.
    """

    def __init__(self, framework: Optional[str] = None) -> None:
        """
        Initializes the handler with the specified framework.

        Args:
            framework (Optional[str]): The multi-process framework to use, or
                None to disable multi-process handling.

        Raises:
            ValueError: If ``framework`` is neither a built-in framework nor has a
                matching ``_check_init_<framework>`` method.
        """
        self.framework: Optional[str] = framework
        self.initialized: bool = False
        self.index: Optional[int] = None
        self.sync_fn: Optional[Union[FunctionType, Callable]] = None
        self._check_initialized()

    def _check_initialized(self) -> None:
        """
        Verifies if the selected multi-process framework is initialized.

        Supports built-in frameworks and allows for custom framework extensions
        through ``_check_init_<framework>`` methods.

        Raises:
            ValueError: If the framework is unknown and no custom check method exists.
        """
        if self.framework is None:
            pass
        elif self.framework == 'torch.distributed':
            self._check_init_torch()
        elif self.framework == 'multiprocessing':
            self._check_init_multiprocessing()
        else:
            # Check for custom framework extension
            custom_method_name = f"_check_init_{self.framework}"
            if hasattr(self, custom_method_name):
                custom_method = getattr(self, custom_method_name)
                custom_method()
            else:
                log_error(f"Invalid framework: {self.framework}")
                raise ValueError(f"Invalid framework: {self.framework}")

    def sync(self) -> None:
        """
        Synchronizes processes across all indices.

        Does nothing when the framework is not initialized or when no
        synchronization function was configured for it.
        """
        if self.initialized and self.sync_fn is not None:
            self.sync_fn()

    def get_index(self) -> Optional[int]:
        """
        Returns the index of the current process.

        Returns:
            Optional[int]: The index of the current process (e.g., rank in distributed
            environments). Returns None if the framework is not initialized.
        """
        return self.index

    def is_initialized(self) -> bool:
        """
        Checks if the multi-process framework has been initialized.

        If a framework was given but not yet detected as initialized, the check
        is re-run, so a framework initialized after construction is picked up lazily.

        Returns:
            bool: True if the framework is initialized, False otherwise.
        """
        if self.framework is not None and not self.initialized:
            self._check_initialized()
        return self.initialized

    def _check_init_torch(self) -> None:
        """
        Checks if the PyTorch distributed environment is initialized for multi-GPU support.

        If initialized, sets the current process index (the rank) and uses
        ``torch.distributed.barrier`` as the synchronization function.
        """
        import torch.distributed as dist

        # The module object is always truthy, so it cannot serve as a guard.
        # PyTorch builds without distributed support expose only is_available(),
        # so it must be checked before calling is_initialized().
        if dist.is_available() and dist.is_initialized():
            self.initialized = True
            self.index = dist.get_rank()
            self.sync_fn = dist.barrier
            log_info(f"torch.distributed initialized. index: {self.index}")

    def _check_init_multiprocessing(self) -> None:
        """
        Checks if Python's built-in multiprocessing is initialized.

        The current process counts as a worker when its name is not
        "MainProcess" and it carries a multiprocessing identity. In that case the
        index is derived from that identity. No process-wide barrier is
        configured, so ``sync()`` is a no-op for this framework.
        """
        import multiprocessing

        current_process = multiprocessing.current_process()
        if current_process.name == "MainProcess":
            return

        identity = current_process._identity
        if not identity:
            # e.g. a main process whose name was changed: it has an empty
            # identity tuple, so there is no worker index to report.
            return

        self.initialized = True
        self.index = identity[0] - 1  # Adjusting index (starts from 1)
        self.sync_fn = None
        log_info(f"multiprocessing initialized. index: {self.index}")