objwatch.wrappers.tensor_shape_wrapper module

objwatch.wrappers.tensor_shape_wrapper.process_tensor_item(seq: List[Any]) → List[Any] | None

Process a sequence and extract the shape of each item, provided every item is a torch.Tensor.

Parameters:

seq (List[Any]) – The sequence to process.

Returns:

A list of the items' tensor shapes, or None if the sequence does not consist entirely of torch.Tensor objects.

Return type:

Optional[List[Any]]
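The documented contract can be sketched as follows. This is a minimal sketch, not objwatch's implementation: to run without PyTorch it uses a hypothetical FakeTensor stand-in (carrying only a shape) in place of torch.Tensor, and it assumes an empty sequence counts as "not applicable" and returns None.

```python
from typing import Any, List, Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = tuple(shape)


def process_tensor_item_sketch(seq: List[Any]) -> Optional[List[Any]]:
    # Assumption: an empty sequence is "not applicable", so return None.
    if not seq:
        return None
    # Shapes are returned only when *every* item is a tensor.
    if all(isinstance(item, FakeTensor) for item in seq):
        return [item.shape for item in seq]
    return None


print(process_tensor_item_sketch([FakeTensor(2, 3), FakeTensor(4)]))  # [(2, 3), (4,)]
print(process_tensor_item_sketch([FakeTensor(2, 3), 1.0]))            # None
```

A single non-tensor item is enough to make the whole result None, which matches the "all items" condition stated above.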

class objwatch.wrappers.tensor_shape_wrapper.TensorShapeWrapper

Bases: ABCWrapper

TensorShapeWrapper extends ABCWrapper to log the shapes of torch.Tensor objects.
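The relationship to ABCWrapper follows Python's abstract-base-class pattern. The sketch below assumes, without this page confirming it, that ABCWrapper is an abc.ABC declaring wrap_call, wrap_return, and wrap_upd as abstract methods; ABCWrapperSketch, IncompleteWrapper, ShapeWrapperSketch, and the message formats are all hypothetical.

```python
import abc
from typing import Any, Tuple


class ABCWrapperSketch(abc.ABC):
    """Hypothetical stand-in for objwatch's ABCWrapper."""

    @abc.abstractmethod
    def wrap_call(self, func_name: str, frame: Any) -> str: ...

    @abc.abstractmethod
    def wrap_return(self, func_name: str, result: Any) -> str: ...

    @abc.abstractmethod
    def wrap_upd(self, old_value: Any, current_value: Any) -> Tuple[str, str]: ...


class IncompleteWrapper(ABCWrapperSketch):
    # Overrides only one of the three abstract methods, so it stays abstract.
    def wrap_call(self, func_name: str, frame: Any) -> str:
        return func_name


class ShapeWrapperSketch(ABCWrapperSketch):
    # Overriding all three abstract methods makes the class instantiable.
    def wrap_call(self, func_name: str, frame: Any) -> str:
        return f"call {func_name}"

    def wrap_return(self, func_name: str, result: Any) -> str:
        return f"return {func_name}: {result!r}"

    def wrap_upd(self, old_value: Any, current_value: Any) -> Tuple[str, str]:
        return repr(old_value), repr(current_value)


try:
    IncompleteWrapper()
except TypeError as exc:
    print("cannot instantiate:", exc)

print(ShapeWrapperSketch().wrap_upd(1, 2))  # ('1', '2')
```

The abstract methods turn a missing override into a TypeError at instantiation time instead of a failure in the middle of tracing.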

__init__()

wrap_call(func_name: str, frame: FrameType) → str

Format the function call information, including tensor shapes if applicable.

Parameters:
  • func_name (str) – Name of the function being called.

  • frame (FrameType) – The current stack frame.

Returns:

Formatted call message.

Return type:

str
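A sketch of how a call message could be built from a stack frame, assuming (this page does not confirm it) that the wrapper reads the called function's arguments from the frame and summarizes tensors by their shape. FakeTensor, describe, wrap_call_sketch, and the message format are hypothetical; inspect.getargvalues and sys._getframe are standard-library calls.

```python
import inspect
import sys
from types import FrameType
from typing import Any, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = tuple(shape)


def describe(value: Any) -> str:
    # Tensors are summarized by shape; everything else by repr().
    if isinstance(value, FakeTensor):
        return f"Tensor(shape={value.shape})"
    return repr(value)


def wrap_call_sketch(func_name: str, frame: FrameType) -> str:
    # getargvalues returns the frame's named parameters in order,
    # plus a mapping of its local variables.
    args_info = inspect.getargvalues(frame)
    parts = [f"{name}={describe(args_info.locals[name])}" for name in args_info.args]
    return f"{func_name}({', '.join(parts)})"


def matmul(a: Any, b: Any) -> str:
    # Inside the call, sys._getframe() is roughly the frame a trace function receives.
    return wrap_call_sketch("matmul", sys._getframe())


print(matmul(FakeTensor(2, 3), FakeTensor(3, 4)))
# matmul(a=Tensor(shape=(2, 3)), b=Tensor(shape=(3, 4)))
```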

wrap_return(func_name: str, result: Any) → str

Format the function return information, including tensor shapes if applicable.

Parameters:
  • func_name (str) – Name of the function returning.

  • result (Any) – The result returned by the function.

Returns:

Formatted return message.

Return type:

str
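A minimal sketch of a return message, assuming a tensor result is reported by its shape rather than its values. FakeTensor stands in for torch.Tensor, and wrap_return_sketch and its "returned" wording are hypothetical, not objwatch's output format.

```python
from typing import Any, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = tuple(shape)


def wrap_return_sketch(func_name: str, result: Any) -> str:
    # Summarize a tensor result by its shape; fall back to repr() otherwise.
    if isinstance(result, FakeTensor):
        shown = f"Tensor(shape={result.shape})"
    else:
        shown = repr(result)
    return f"{func_name} returned {shown}"


print(wrap_return_sketch("zeros", FakeTensor(2, 2)))  # zeros returned Tensor(shape=(2, 2))
print(wrap_return_sketch("count", 3))                 # count returned 3
```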

wrap_upd(old_value: Any, current_value: Any) → Tuple[str, str]

Format the update information of a variable, including tensor shapes if applicable.

Parameters:
  • old_value (Any) – The old value of the variable.

  • current_value (Any) – The new value of the variable.

Returns:

Formatted old and new values.

Return type:

Tuple[str, str]
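Because wrap_upd returns the old and new values as a pair of strings, both sides can go through one shared formatter. In this sketch FakeTensor, format_value, and wrap_upd_sketch are hypothetical, and the "old -> new" log line is only an illustration of how a caller might use the pair.

```python
from typing import Any, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = tuple(shape)


def format_value(value: Any) -> str:
    # Tensors are summarized by shape; everything else by repr().
    if isinstance(value, FakeTensor):
        return f"Tensor(shape={value.shape})"
    return repr(value)


def wrap_upd_sketch(old_value: Any, current_value: Any) -> Tuple[str, str]:
    # Format both sides the same way so the caller can log "old -> new".
    return format_value(old_value), format_value(current_value)


old, new = wrap_upd_sketch(FakeTensor(4, 8), FakeTensor(4, 16))
print(f"{old} -> {new}")  # Tensor(shape=(4, 8)) -> Tensor(shape=(4, 16))
```

A shape change such as (4, 8) to (4, 16) is visible at a glance, without dumping either tensor's contents.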

_format_value(value: Any, is_return: bool = False) → str

Format a value as a string, including tensor shapes if applicable.

Parameters:
  • value (Any) – The value to format.

  • is_return (bool) – Whether the value is a function's return value.

Returns:

Formatted value string.

Return type:

str
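One plausible shape for a formatter like this combines a single-tensor check with the process_tensor_item contract for lists. This is a sketch under assumptions: FakeTensor, shapes_if_all_tensors, format_value_sketch, and the "[Tensor shapes: ...]" text are hypothetical, and is_return is accepted but ignored because its effect is not described here.

```python
from typing import Any, List, Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape: Tuple[int, ...] = tuple(shape)


def shapes_if_all_tensors(seq: List[Any]) -> Optional[List[Any]]:
    # Mirrors the documented process_tensor_item contract.
    if seq and all(isinstance(item, FakeTensor) for item in seq):
        return [item.shape for item in seq]
    return None


def format_value_sketch(value: Any, is_return: bool = False) -> str:
    # is_return is kept for signature parity but unused in this sketch.
    if isinstance(value, FakeTensor):
        return f"Tensor(shape={value.shape})"
    if isinstance(value, list):
        shapes = shapes_if_all_tensors(value)
        if shapes is not None:
            return f"[Tensor shapes: {shapes}]"
    # Anything else, including mixed lists, falls back to repr().
    return repr(value)


print(format_value_sketch([FakeTensor(1, 3), FakeTensor(1, 3)]))  # [Tensor shapes: [(1, 3), (1, 3)]]
print(format_value_sketch([1, 2]))                                # [1, 2]
```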
