Source code for objwatch.wrappers.tensor_shape_wrapper
# MIT License# Copyright (c) 2025 aeeeeeepfromtypesimportFrameTypefromtypingimportAny,List,Optional,Tuplefrom..constantsimportConstantsfrom..event_handlsimportEventHandlsfrom.abc_wrapperimportABCWrappertry:importtorchexceptImportError:torch=None# type: ignore
def process_tensor_item(seq: List[Any]) -> Optional[List[Any]]:
    """
    Collect the ``shape`` of every element when the sequence holds only tensors.

    Args:
        seq (List[Any]): Items to inspect.

    Returns:
        Optional[List[Any]]: The shapes, in input order, when torch is importable
        and every item is a ``torch.Tensor``; otherwise ``None``. Because the
        "all items are tensors" test is vacuously true for an empty sequence,
        an empty ``seq`` yields ``[]`` when torch is available.
    """
    if torch is None:
        # torch failed to import at module load; nothing can be a tensor.
        return None
    if any(not isinstance(item, torch.Tensor) for item in seq):
        return None
    return [item.shape for item in seq]
class TensorShapeWrapper(ABCWrapper):
    """
    TensorShapeWrapper extends ABCWrapper to log the shapes of torch.Tensor objects.

    Tensors are rendered by their ``shape`` instead of their contents. Other values
    fall back to element, sequence, or type-name formatting in ``_format_value``.
    """

    def wrap_call(self, func_name: str, frame: FrameType) -> str:
        """
        Format the function call information, including tensor shapes if applicable.

        Args:
            func_name (str): Name of the function being called (not used by this
                implementation).
            frame (FrameType): The current stack frame; arguments are extracted from it.

        Returns:
            str: Formatted call message, as produced by the inherited
            ``_format_args_kwargs`` (defined in ABCWrapper, not shown here).
        """
        args, kwargs = self._extract_args_kwargs(frame)
        return self._format_args_kwargs(args, kwargs)

    def wrap_return(self, func_name: str, result: Any) -> str:
        """
        Format the function return information, including tensor shapes if applicable.

        Args:
            func_name (str): Name of the function returning (not used by this
                implementation).
            result (Any): The result returned by the function.

        Returns:
            str: Formatted return message, as produced by the inherited
            ``_format_return`` (defined in ABCWrapper, not shown here).
        """
        return self._format_return(result)

    def wrap_upd(self, old_value: Any, current_value: Any) -> Tuple[str, str]:
        """
        Format the update information of a variable, including tensor shapes if applicable.

        Args:
            old_value (Any): The old value of the variable.
            current_value (Any): The new value of the variable.

        Returns:
            Tuple[str, str]: ``(old, new)``, each formatted with ``_format_value``.
        """
        return self._format_value(old_value), self._format_value(current_value)

    def _format_value(self, value: Any, is_return: bool = False) -> str:
        """
        Format a value into a string, logging tensor shapes if applicable.

        Formatting rules, in priority order:
          * ``torch.Tensor`` -> ``str(value.shape)`` (only when torch is importable);
          * ``Constants.LOG_ELEMENT_TYPES`` -> ``str(value)``;
          * ``Constants.LOG_SEQUENCE_TYPES`` -> ``EventHandls.format_sequence`` output
            (using ``self.format_sequence_func``), or ``(type)<TypeName>`` when that
            output is empty;
          * anything else -> ``(type)<value.__name__>``, falling back to
            ``(type)<TypeName>`` if ``__name__`` cannot be read.

        Args:
            value (Any): The value to format.
            is_return (bool): Flag indicating if the value is a return value. When set,
                sequence results are additionally wrapped in ``[...]``.

        Returns:
            str: Formatted value string.
        """
        # Guard on torch availability once; the return-path check below must not
        # touch ``torch.Tensor`` when torch failed to import (torch is None).
        is_tensor = torch is not None and isinstance(value, torch.Tensor)

        if is_tensor:
            formatted = f"{value.shape}"
        elif isinstance(value, Constants.LOG_ELEMENT_TYPES):
            formatted = f"{value}"
        elif isinstance(value, Constants.LOG_SEQUENCE_TYPES):
            formatted_sequence = EventHandls.format_sequence(value, func=self.format_sequence_func)
            if formatted_sequence:
                formatted = f"{formatted_sequence}"
            else:
                formatted = f"(type){type(value).__name__}"
        else:
            try:
                formatted = f"(type){value.__name__}"
            except Exception:
                # Best-effort: objects without a readable __name__ fall back to
                # their type name. Exception (not bare except) lets
                # KeyboardInterrupt/SystemExit propagate.
                formatted = f"(type){type(value).__name__}"

        if is_return:
            if is_tensor:
                return f"{value.shape}"
            if isinstance(value, Constants.LOG_SEQUENCE_TYPES) and formatted:
                return f"[{formatted}]"
        return formatted