Source code for objwatch.tracer

# MIT License
# Copyright (c) 2025 aeeeeeep

import builtins
import importlib
import importlib.util
import pkgutil
import sys
from functools import lru_cache
from types import FunctionType, FrameType, ModuleType
from typing import Optional, Union, Any, Dict, List, Set

from .wrappers import FunctionWrapper
from .events import EventType
from .event_handls import EventHandls, log_sequence_types
from .utils.logger import log_error, log_debug, log_warn, log_info
from .utils.weak import WeakTensorKeyDictionary

# PyTorch is optional: ``torch_available`` records whether it could be imported.
try:
    import torch
except ImportError:
    torch_available = False
else:
    torch_available = True


class Tracer:
    """
    Tracer class to monitor and trace function calls, returns, and variable updates within
    specified target modules. Supports multi-GPU environments with PyTorch.

    Tracing is installed for the calling thread with ``sys.settrace`` in :meth:`start` and removed
    in :meth:`stop`. Only frames whose code file name ends with one of the resolved target paths
    are reported.
    """

    def __init__(
        self,
        targets: List[Union[str, ModuleType]],
        exclude_targets: Optional[List[str]] = None,
        ranks: Optional[List[int]] = None,
        wrapper: Optional[FunctionWrapper] = None,
        output_xml: Optional[str] = None,
        with_locals: bool = False,
        with_globals: bool = False,
        with_module_path: bool = False,
    ) -> None:
        """
        Initialize the Tracer with configuration parameters.

        Args:
            targets (List[Union[str, ModuleType]]): Files (``.py`` paths), module names or module
                objects to monitor. Packages include all of their submodules.
            exclude_targets (Optional[List[str]]): Files or modules to exclude from monitoring.
            ranks (Optional[List[int]]): GPU ranks to track when using torch.distributed.
                Defaults to ``[0]`` when PyTorch is available.
            wrapper (Optional[FunctionWrapper]): Custom wrapper (a ``FunctionWrapper`` subclass or
                instance) to extend tracing and logging functionality.
            output_xml (Optional[str]): Path to the XML file for writing structured logs.
            with_locals (bool): Enable tracing and logging of local variables within functions.
            with_globals (bool): Enable tracing and logging of global variables across function calls.
            with_module_path (bool): Prepend the module path to function names in logs.
        """
        self.with_locals: bool = with_locals
        if self.with_locals:
            # Per-frame snapshot of local variables, and of the lengths of sequence-typed locals.
            self.tracked_locals: Dict[FrameType, Dict[str, Any]] = {}
            self.tracked_locals_lens: Dict[FrameType, Dict[str, int]] = {}

        self.with_globals: bool = with_globals
        if self.with_globals:
            # Snapshot of global variables keyed by variable name (shared by all traced modules).
            self.tracked_globals: Dict[str, Any] = {}
            self.tracked_globals_lens: Dict[str, int] = {}

        # Names that are never reported as global-variable changes. Use the ``builtins`` module
        # rather than ``__builtins__``: in CPython, ``__builtins__`` is the builtins *dict* inside
        # imported modules, so ``dir(__builtins__)`` would list dict methods, not builtin names.
        self.builtin_fields: Set[str] = set(dir(builtins)) | {
            'self',
            '__builtins__',
            '__name__',
            '__package__',
            '__loader__',
            '__spec__',
            '__file__',
            '__cached__',
        }

        self.with_module_path: bool = with_module_path

        # Process and determine the set of target files to monitor
        self.targets: Set[str] = self._process_targets(targets) - self._process_targets(exclude_targets)
        log_debug(f"Processed targets:\n{'>' * 10}\n" + "\n".join(self.targets) + f"\n{'<' * 10}")

        # Suffix tuple for str.endswith, plus a per-instance memo of file-name checks. A per-instance
        # dict is used instead of lru_cache on the method, which would keep every Tracer alive.
        self._target_suffixes = tuple(self.targets)
        self._filename_skip_cache: Dict[str, bool] = {}

        # Initialize tracking dictionaries for objects
        self.tracked_objects: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
        self.tracked_objects_lens: WeakTensorKeyDictionary = WeakTensorKeyDictionary()

        # Initialize event handlers with optional XML output
        self.event_handlers: EventHandls = EventHandls(output_xml=output_xml)

        # Handle multi-GPU support if PyTorch is available
        self.torch_available: bool = torch_available
        self.rank_info: str = ""
        self.current_rank: Optional[int] = None
        self.ranks: Set[int] = set(ranks if ranks is not None else [0]) if self.torch_available else set()

        # Load the function wrapper if provided
        self.function_wrapper: Optional[FunctionWrapper] = self.load_wrapper(wrapper)
        self.call_depth: int = 0

    def _process_targets(self, targets: Optional[Union[str, ModuleType, List[Union[str, ModuleType]]]]) -> Set[str]:
        """
        Resolve target modules or files into the set of file paths to monitor.

        A string ending in ``.py`` is taken verbatim as a file path. Any other string, or a module
        object, is resolved with ``importlib.util.find_spec``. Its origin file is added, and when the
        module is a package, the origin files of all of its submodules are added as well.

        Args:
            targets: A single target or a list of targets (module names, ``.py`` paths or module
                objects). ``None`` yields an empty set.

        Returns:
            Set[str]: Set of processed file paths to monitor.
        """
        processed: Set[str] = set()
        if targets is None:
            return processed
        if isinstance(targets, (str, ModuleType)):
            targets = [targets]

        for target in targets:
            if isinstance(target, str):
                if target.endswith('.py'):
                    processed.add(target)
                    continue
                target_name = target
            elif isinstance(target, ModuleType):
                target_name = target.__name__
            else:
                log_warn(f"Unsupported target type: {type(target)}. Only 'str' or 'ModuleType' are supported.")
                continue

            try:
                spec = importlib.util.find_spec(target_name)
            except (ImportError, ValueError) as e:
                # find_spec raises instead of returning None when a parent package is missing, or
                # when an already-imported module has no __spec__ (e.g. ``__main__``).
                log_warn(f"Module {target_name} could not be found or has no file associated. Error: {e}")
                continue

            if not (spec and spec.origin):
                log_warn(f"Module {target_name} could not be found or has no file associated.")
                continue

            processed.add(spec.origin)

            # Every ModuleSpec has a ``submodule_search_locations`` attribute; it is None for plain
            # modules. walk_packages(None) would scan every top-level module on sys.path.
            if spec.submodule_search_locations:
                for _, modname, _ in pkgutil.walk_packages(
                    spec.submodule_search_locations, prefix=target_name + '.'
                ):
                    # For each submodule, use find_spec to check its path
                    try:
                        sub_spec = importlib.util.find_spec(modname)
                    except Exception as e:
                        log_error(f"Submodule {modname} could not be imported. Error: {e}")
                        continue
                    if sub_spec and sub_spec.origin:
                        processed.add(sub_spec.origin)
        return processed

    def load_wrapper(self, wrapper: Optional[FunctionWrapper]) -> Optional[FunctionWrapper]:
        """
        Load a custom function wrapper if provided.

        Args:
            wrapper: A ``FunctionWrapper`` subclass (instantiated with no arguments), an already
                constructed ``FunctionWrapper`` instance, or ``None``.

        Returns:
            Optional[FunctionWrapper]: The initialized wrapper, or None when no usable wrapper
            was given.
        """
        if wrapper is None:
            return None
        if isinstance(wrapper, type) and issubclass(wrapper, FunctionWrapper):
            log_warn(f"wrapper '{wrapper.__name__}' loaded")
            return wrapper()
        if isinstance(wrapper, FunctionWrapper):
            log_warn(f"wrapper '{type(wrapper).__name__}' loaded")
            return wrapper
        log_warn(f"Unsupported wrapper {wrapper!r}: expected a FunctionWrapper subclass or instance.")
        return None

    def _get_function_info(self, frame: FrameType) -> Dict[str, Any]:
        """
        Extract information about the currently executing function.

        When the frame has a ``self`` local whose attribute of the same (bare) name carries this
        frame's code object, the frame is reported as a method of ``self``'s class. Any ``self``
        object is also snapshotted for attribute-change tracking (see ``_snapshot_object``).

        Args:
            frame (FrameType): The current stack frame.

        Returns:
            Dict[str, Any]: Keys ``func_name``, ``frame`` and ``is_method``, plus ``class_name``
            for methods.
        """
        code = frame.f_code
        func_name: str = code.co_name
        if self.with_module_path:
            module_name: str = frame.f_globals.get('__name__', '')
            if module_name:
                func_name = f"{module_name}.{func_name}"

        func_info: Dict[str, Any] = {'func_name': func_name, 'frame': frame, 'is_method': False}

        if 'self' not in frame.f_locals:
            return func_info

        obj = frame.f_locals['self']
        class_name: str = obj.__class__.__name__
        # Look the method up by its bare code name: ``func_name`` may carry a module prefix, which
        # is never a valid attribute name.
        try:
            method = getattr(obj, code.co_name, None)
        except Exception as e:
            log_error(f"Error occurred while getattr '{code.co_name}' from class '{class_name}': {e}")
            method = None

        if callable(method) and hasattr(method, '__code__') and method.__code__ == code:
            func_info['is_method'] = True
            func_info['class_name'] = class_name

        # Line events track ``self`` attributes for any frame with a ``self`` local, so give every
        # such object a baseline snapshot.
        self._snapshot_object(obj)
        return func_info

    def _snapshot_object(self, obj: Any) -> None:
        """
        Record the first-seen non-callable attributes of ``obj`` and the lengths of its
        sequence-typed attributes. Only objects with a ``__dict__`` whose class supports weak
        references are tracked. Already-tracked objects are left unchanged.

        Args:
            obj (Any): The object bound to ``self`` in the current frame.
        """
        if not (hasattr(obj, '__dict__') and hasattr(obj.__class__, '__weakref__')):
            return
        if obj in self.tracked_objects and obj in self.tracked_objects_lens:
            return
        attrs: Dict[str, Any] = {k: v for k, v in obj.__dict__.items() if not callable(v)}
        if obj not in self.tracked_objects:
            self.tracked_objects[obj] = attrs
        if obj not in self.tracked_objects_lens:
            self.tracked_objects_lens[obj] = {
                k: len(v) for k, v in attrs.items() if isinstance(v, log_sequence_types)
            }

    def _filename_not_endswith(self, filename: str) -> bool:
        """
        Check if the filename does not end with any of the target paths (memoized per instance).

        Args:
            filename (str): The filename to check.

        Returns:
            bool: True if the filename matches no target (the frame should be skipped).
        """
        skip = self._filename_skip_cache.get(filename)
        if skip is None:
            skip = not filename.endswith(self._target_suffixes)
            self._filename_skip_cache[filename] = skip
        return skip

    def _handle_change_type(
        self,
        lineno: int,
        class_name: str,
        key: str,
        old_value: Optional[Any],
        current_value: Any,
        old_value_len: Optional[int],
        current_value_len: Optional[int],
    ) -> None:
        """
        Report a change of one tracked attribute or variable.

        - Same object as before (identity): only an in-place length change is reportable. It is
          emitted as APD/POP when both lengths are known and ``determine_change_type`` says so.
        - Different object: the name was rebound, which is always reported as an update, even
          when the old and new values are sequences of different length.

        Args:
            lineno (int): Line number where the change occurred.
            class_name (str): Class name, or the marker used for locals ("_") / globals ("@").
            key (str): The key (variable or attribute) being tracked.
            old_value (Optional[Any]): The previously recorded value.
            current_value (Any): The current value.
            old_value_len (Optional[int]): Recorded length of the old value, if it was a sequence.
            current_value_len (Optional[int]): Length of the current value, if it is a sequence.
        """
        if old_value is current_value:
            if old_value_len is None or current_value_len is None:
                return
            change_type = self.event_handlers.determine_change_type(old_value_len, current_value_len)
            if change_type == EventType.APD:
                self.event_handlers.handle_apd(
                    lineno,
                    class_name,
                    key,
                    type(current_value),
                    old_value_len,
                    current_value_len,
                    self.call_depth,
                    self.rank_info,
                )
            elif change_type == EventType.POP:
                self.event_handlers.handle_pop(
                    lineno,
                    class_name,
                    key,
                    type(current_value),
                    old_value_len,
                    current_value_len,
                    self.call_depth,
                    self.rank_info,
                )
        else:
            self.event_handlers.handle_upd(
                lineno,
                class_name,
                key,
                old_value,
                current_value,
                self.call_depth,
                self.rank_info,
                self.function_wrapper,
            )

    def _track_object_change(self, frame: FrameType, lineno: int) -> None:
        """
        Handle changes in the attributes of the frame's ``self`` object and track updates.

        Args:
            frame (FrameType): The current stack frame (must have a ``self`` local).
            lineno (int): The line number where the change occurred.
        """
        obj = frame.f_locals['self']
        if obj not in self.tracked_objects:
            return
        class_name = obj.__class__.__name__
        old_attrs = self.tracked_objects[obj]
        old_attrs_lens = self.tracked_objects_lens[obj]
        current_attrs = {k: v for k, v in obj.__dict__.items() if not callable(v)}

        for key, current_value in current_attrs.items():
            is_current_seq = isinstance(current_value, log_sequence_types)
            current_value_len = len(current_value) if is_current_seq else None
            self._handle_change_type(
                lineno,
                class_name,
                key,
                old_attrs.get(key),
                current_value,
                old_attrs_lens.get(key),
                current_value_len,
            )
            old_attrs[key] = current_value
            # Keep lengths only for values that are currently sequences, so no stale length is
            # compared against a later value.
            if is_current_seq:
                old_attrs_lens[key] = current_value_len
            else:
                old_attrs_lens.pop(key, None)

    def _track_locals_change(self, frame: FrameType, lineno: int) -> None:
        """
        Handle changes in local variables and track updates.

        Newly bound locals are always reported as updates. Previously seen locals go through
        ``_handle_change_type``. Events are emitted in ``f_locals`` order, so logs are deterministic.

        Args:
            frame (FrameType): The current stack frame.
            lineno (int): The line number where the change occurred.
        """
        if frame not in self.tracked_locals:
            return
        old_locals = self.tracked_locals[frame]
        old_locals_lens = self.tracked_locals_lens[frame]
        current_locals = {k: v for k, v in frame.f_locals.items() if k != 'self' and not callable(v)}

        for var, current_local in current_locals.items():
            is_current_seq = isinstance(current_local, log_sequence_types)
            current_local_len = len(current_local) if is_current_seq else None
            if var in old_locals:
                self._handle_change_type(
                    lineno, "_", var, old_locals[var], current_local, old_locals_lens.get(var), current_local_len
                )
            else:
                self.event_handlers.handle_upd(
                    lineno,
                    class_name="_",
                    key=var,
                    old_value=None,
                    current_value=current_local,
                    call_depth=self.call_depth,
                    rank_info=self.rank_info,
                    function_wrapper=self.function_wrapper,
                )
            if is_current_seq:
                old_locals_lens[var] = current_local_len
            else:
                old_locals_lens.pop(var, None)

        self.tracked_locals[frame] = current_locals

    def _track_globals_change(self, frame: FrameType, lineno: int) -> None:
        """
        Handle changes in global variables of the frame's module and track updates.

        Args:
            frame (FrameType): The current stack frame.
            lineno (int): The line number where the change occurred.
        """
        # Snapshot the items: the globals dict can be mutated (e.g. by another thread) while the
        # event handlers run, which would break direct iteration.
        for key, current_value in list(frame.f_globals.items()):
            if key in self.builtin_fields:
                continue
            is_current_seq = isinstance(current_value, log_sequence_types)
            current_value_len = len(current_value) if is_current_seq else None
            self._handle_change_type(
                lineno,
                "@",
                key,
                self.tracked_globals.get(key),
                current_value,
                self.tracked_globals_lens.get(key),
                current_value_len,
            )
            self.tracked_globals[key] = current_value
            if is_current_seq:
                self.tracked_globals_lens[key] = current_value_len
            else:
                self.tracked_globals_lens.pop(key, None)

    def _track_changes(self, frame: FrameType, lineno: int) -> None:
        """Run every enabled change tracker (object attributes, locals, globals) for ``frame``."""
        if 'self' in frame.f_locals:
            self._track_object_change(frame, lineno)
        if self.with_locals:
            self._track_locals_change(frame, lineno)
        if self.with_globals:
            self._track_globals_change(frame, lineno)

    @staticmethod
    def _dist_initialized() -> bool:
        """
        Return True when ``torch.distributed`` is available and a process group is initialized.

        Call only when PyTorch imported successfully. ``is_available()`` is checked first because
        PyTorch builds without distributed support may not define ``is_initialized``.
        """
        dist = getattr(torch, 'distributed', None)
        return bool(dist is not None and dist.is_available() and dist.is_initialized())

    def trace_factory(self) -> FunctionType:  # noqa: C901
        """
        Create the tracing function to be used with sys.settrace.

        Returns:
            FunctionType: The trace function.
        """

        def trace_func(frame: FrameType, event: str, arg: Any) -> Optional[FunctionType]:
            """
            The trace function installed with sys.settrace; invoked for 'call', 'line', 'return'
            (and other) events of traced frames.

            Args:
                frame (FrameType): The current stack frame.
                event (str): The event type ('call', 'return', 'line', ...).
                arg (Any): The event argument (e.g. the return value for 'return').

            Returns:
                Optional[FunctionType]: ``trace_func`` to keep tracing the frame, or None to stop
                local tracing of a frame that will never be reported.
            """
            # Frames outside the target files are never reported. Returning None also prevents a
            # local trace from being installed, so their line events cost nothing.
            if self._filename_not_endswith(frame.f_code.co_filename):
                return None

            # Handle multi-GPU ranks if PyTorch is available
            if self.torch_available:
                if self.current_rank is None and self._dist_initialized():
                    self.current_rank = torch.distributed.get_rank()
                    self.rank_info = f"[Rank {self.current_rank}] "
                # Once known, the rank never changes, so untracked ranks can drop the frame.
                if self.current_rank is not None and self.current_rank not in self.ranks:
                    return None

            lineno = frame.f_lineno

            if event == "call":
                func_info = self._get_function_info(frame)
                self.event_handlers.handle_run(
                    lineno, func_info, self.function_wrapper, self.call_depth, self.rank_info
                )
                self.call_depth += 1

                # Track local variables if needed
                if self.with_locals:
                    local_vars: Dict[str, Any] = {
                        k: v for k, v in frame.f_locals.items() if k != 'self' and not callable(v)
                    }
                    self.tracked_locals[frame] = local_vars
                    self.tracked_locals_lens[frame] = {
                        k: len(v) for k, v in local_vars.items() if isinstance(v, log_sequence_types)
                    }

            elif event == "return":
                # A 'line' event fires *before* its line executes, so the effects of the last
                # executed line are only observable here; report them before the return itself.
                self._track_changes(frame, lineno)

                self.call_depth -= 1
                func_info = self._get_function_info(frame)
                self.event_handlers.handle_end(
                    lineno, func_info, self.function_wrapper, self.call_depth, self.rank_info, arg
                )

                # Clean up local tracking after function return
                if self.with_locals:
                    self.tracked_locals.pop(frame, None)
                    self.tracked_locals_lens.pop(frame, None)

            elif event == "line":
                self._track_changes(frame, lineno)

            return trace_func

        return trace_func

    def start(self) -> None:
        """
        Start tracing by installing the trace function for the calling thread. When a
        torch.distributed process group is initialized, all ranks synchronize on a barrier.
        """
        log_info("Starting tracing.")
        sys.settrace(self.trace_factory())
        if self.torch_available and self._dist_initialized():
            torch.distributed.barrier()

    def stop(self) -> None:
        """
        Stop the tracing process by removing the trace function and saving XML logs.
        """
        log_info("Stopping tracing.")
        sys.settrace(None)
        self.event_handlers.save_xml()