plenoptic.simulate.models package

Submodules

plenoptic.simulate.models.frontend module

The model architectures in this module are described in [1], [2]. frontend.OnOff() has optional pretrained filters that were reverse-engineered from a previously trained model; use them at your own discretion.

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

class plenoptic.simulate.models.frontend.LinearNonlinear(kernel_size, on_center=True, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', activation=<built-in function softplus>)[source]

Bases: Module

Linear-Nonlinear model: applies a difference-of-Gaussians filter followed by an activation function. The model is described in [1] and [2].

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (bool) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std over center_std. The surround Gaussian must be wider than the center Gaussian in order to be a proper Difference of Gaussians. surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear convolution.
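The linear-nonlinear cascade described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not plenoptic's implementation: the helper names (`gaussian_kernel`, `dog_kernel`, `filter2d`, `linear_nonlinear`), the odd kernel size, and the center/surround standard deviations (1.0 and 4.0) are assumptions, and NumPy's softplus stands in for the default activation.

```python
# Sketch only: a NumPy stand-in for a linear-nonlinear (DoG + softplus) model.
# All names and std values are illustrative assumptions, not plenoptic's API.
import numpy as np


def gaussian_kernel(size, std):
    """Isotropic 2D Gaussian on a size x size grid, normalized to sum to one."""
    ax = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    g = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()


def dog_kernel(size, center_std=1.0, surround_std=4.0, amplitude_ratio=1.25,
               width_ratio_limit=4.0, on_center=True):
    """Difference of Gaussians; std values are assumed for illustration."""
    # Keep the surround at least width_ratio_limit times wider than the center.
    surround_std = max(surround_std, width_ratio_limit * center_std)
    f = amplitude_ratio * gaussian_kernel(size, center_std) - gaussian_kernel(size, surround_std)
    f = f / f.sum()  # normalize so the kernel sums to one
    return f if on_center else -f


def filter2d(img, kernel, pad_mode="reflect"):
    """Same-size 2D cross-correlation (equal to convolution for symmetric kernels)."""
    kh, kw = kernel.shape  # odd kernel sizes assumed
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode=pad_mode)
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel.shape)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def softplus(x):
    # log(1 + exp(x)), computed in a numerically stable way
    return np.logaddexp(0.0, x)


def linear_nonlinear(img, kernel_size=11, on_center=True, pad_mode="reflect"):
    linear = filter2d(img, dog_kernel(kernel_size, on_center=on_center), pad_mode)  # linear stage
    return softplus(linear)  # pointwise nonlinearity


rng = np.random.default_rng(0)
img = rng.random((32, 32))
on = linear_nonlinear(img, on_center=True)
off = linear_nonlinear(img, on_center=False)
print(on.shape)  # (32, 32)
```

Because the kernel sums to one, a uniform image passes through the linear stage unchanged, so the response to it is just the softplus of the input level.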

center_surround

CenterSurround difference of Gaussians filter.

Type:

nn.Module

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Displays the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Displays the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Additional keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.frontend.LuminanceContrastGainControl(kernel_size, on_center=True, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', activation=<built-in function softplus>)[source]

Bases: Module

Linear center-surround filter followed by luminance and contrast gain control and an activation function. The model is described in [1] and [2].

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (bool) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std over center_std. The surround Gaussian must be wider than the center Gaussian in order to be a proper Difference of Gaussians. surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear convolution.
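One way to read the gain-control stages is as two divisive normalizations applied after the DoG filter: first by a Gaussian-weighted local mean of the input (luminance), then by a Gaussian-weighted local RMS of the result (contrast). The NumPy sketch below assumes that form, following the description in Berardino et al. (2017); the function name `lcgc`, the kernel widths, and the `luminance_scalar`/`contrast_scalar` values are illustrative assumptions, not plenoptic's defaults or learned values.

```python
# Sketch only: assumed divisive form of luminance + contrast gain control.
# Names and numeric values are hypothetical, not plenoptic's implementation.
import numpy as np


def gaussian_kernel(size, std):
    """Isotropic 2D Gaussian normalized to sum to one."""
    ax = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    g = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()


def filter2d(img, kernel, pad_mode="reflect"):
    """Same-size 2D filtering with padding (odd kernel sizes assumed)."""
    kh, kw = kernel.shape
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode=pad_mode)
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel.shape)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def lcgc(img, size=11, center_std=1.0, surround_std=4.0, amplitude_ratio=1.25,
         lum_std=3.0, con_std=3.0, luminance_scalar=2.0, contrast_scalar=2.0):
    # Linear stage: on-center DoG filter normalized to sum to one
    dog = amplitude_ratio * gaussian_kernel(size, center_std) - gaussian_kernel(size, surround_std)
    linear = filter2d(img, dog / dog.sum())
    # Luminance gain control: divide by the local mean luminance of the input
    lum = filter2d(img, gaussian_kernel(size, lum_std))
    lum_normed = linear / (1.0 + luminance_scalar * lum)
    # Contrast gain control: divide by the local RMS of the luminance-normalized signal
    con = np.sqrt(filter2d(lum_normed**2, gaussian_kernel(size, con_std)))
    con_normed = lum_normed / (1.0 + contrast_scalar * con)
    # Activation: softplus
    return np.logaddexp(0.0, con_normed)


rng = np.random.default_rng(0)
img = rng.random((32, 32))
out = lcgc(img)
print(out.shape)  # (32, 32)
```

The `1.0 +` offset in each denominator keeps the division well defined where local luminance or contrast is near zero; that offset is part of the assumed form rather than a documented detail.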

center_surround

Difference of Gaussians linear filter.

Type:

nn.Module

luminance

Gaussian convolutional kernel used to normalize the signal by local luminance.

Type:

nn.Module

contrast

Gaussian convolutional kernel used to normalize the signal by local contrast.

Type:

nn.Module

luminance_scalar

Scale factor for luminance normalization.

Type:

nn.Parameter

contrast_scalar

Scale factor for contrast normalization.

Type:

nn.Parameter

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Displays the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Displays the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Additional keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.frontend.LuminanceGainControl(kernel_size, on_center=True, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', activation=<built-in function softplus>)[source]

Bases: Module

Linear center-surround filter followed by luminance gain control and an activation function. The model is described in [1] and [2].

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (bool) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std over center_std. The surround Gaussian must be wider than the center Gaussian in order to be a proper Difference of Gaussians. surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear convolution.
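Luminance gain control divides the center-surround response by a local mean of the input, which makes the response roughly insensitive to an overall change in image brightness. The NumPy sketch below illustrates that effect; it assumes a divisive form `linear / (1 + luminance_scalar * lum)`, and the names `luminance_normalized` and `lgc`, the kernel widths, and the scalar value are hypothetical.

```python
# Sketch only: assumed divisive luminance gain control followed by softplus.
# Names, kernel widths and luminance_scalar are hypothetical, not plenoptic's.
import numpy as np


def gaussian_kernel(size, std):
    ax = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    g = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()  # normalized to sum to one


def filter2d(img, kernel, pad_mode="reflect"):
    kh, kw = kernel.shape  # odd sizes assumed
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode=pad_mode)
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel.shape)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def luminance_normalized(img, size=11, lum_std=3.0, luminance_scalar=2.0):
    # Linear stage: on-center DoG (center std 1, surround std 4), summing to one
    dog = 1.25 * gaussian_kernel(size, 1.0) - gaussian_kernel(size, 4.0)
    linear = filter2d(img, dog / dog.sum())
    # Luminance gain control: divide by the Gaussian-weighted local mean of the input
    lum = filter2d(img, gaussian_kernel(size, lum_std))
    return linear / (1.0 + luminance_scalar * lum)


def lgc(img, **kwargs):
    return np.logaddexp(0.0, luminance_normalized(img, **kwargs))  # softplus activation


rng = np.random.default_rng(0)
img = rng.random((32, 32))
# Brightening the image changes the normalized response less and less
for scale in (1, 10, 100):
    print(scale, np.round(luminance_normalized(scale * img)[16, 16], 3))
```

Without the division, scaling the input by 100 would scale the linear response by 100 as well; with it, the response saturates toward `linear / (luminance_scalar * lum)`.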

center_surround

Difference of Gaussians linear filter.

Type:

nn.Module

luminance

Gaussian convolutional kernel used to normalize the signal by local luminance.

Type:

nn.Module

luminance_scalar

Scale factor for luminance normalization.

Type:

nn.Parameter

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Displays the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Displays the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Additional keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.frontend.OnOff(kernel_size, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', pretrained=False, activation=<built-in function softplus>, apply_mask=False, cache_filt=False)[source]

Bases: Module

Two-channel on-off and off-on center-surround model with local contrast and luminance gain control.

This model is called OnOff in Berardino et al. (2017).

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std over center_std. The surround Gaussian must be wider than the center Gaussian in order to be a proper Difference of Gaussians. surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • pretrained (bool) – Whether to load the model parameters estimated from [1]. See Notes for details.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear and gain control operations.

  • apply_mask (bool) – Whether to apply a circular disk mask centered on the input image. This is useful for synthesis methods such as Eigendistortions, to ensure that the synthesized distortion does not appear in the periphery. See plenoptic.tools.signal.make_disk() for details on how the mask is created.

  • cache_filt (bool) – Whether to cache the filter, which avoids regenerating filt on each forward pass. The filter is cached to self._filt.
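The two-channel structure and the optional mask can be sketched as follows. This NumPy sketch covers only the linear on/off stage and the masking step and omits gain control; it assumes both channels share one set of Gaussian widths (the real channels may have separate parameters), and `disk_mask` and `on_off_linear` are hypothetical helpers, with `disk_mask` a hard-edged stand-in for plenoptic.tools.signal.make_disk(), whose exact profile may differ.

```python
# Sketch only: linear on/off channels plus a hypothetical hard-edged disk mask.
# Gain control is omitted; names and std values are assumptions.
import numpy as np


def gaussian_kernel(size, std):
    ax = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    g = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()


def filter2d(img, kernel, pad_mode="reflect"):
    kh, kw = kernel.shape  # odd sizes assumed
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode=pad_mode)
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel.shape)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def disk_mask(shape, radius=None):
    """Hypothetical hard-edged disk centered on the image (1 inside, 0 outside)."""
    h, w = shape
    radius = min(h, w) / 2 if radius is None else radius
    yy, xx = np.mgrid[:h, :w]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    return (((yy - cy) ** 2 + (xx - cx) ** 2) <= radius**2).astype(float)


def on_off_linear(img, size=11, apply_mask=False):
    dog = 1.25 * gaussian_kernel(size, 1.0) - gaussian_kernel(size, 4.0)
    on_kernel = dog / dog.sum()  # on-center / off-surround
    off_kernel = -on_kernel      # off-center / on-surround (shared widths assumed)
    resp = np.stack([filter2d(img, on_kernel), filter2d(img, off_kernel)])  # (2, H, W)
    if apply_mask:
        resp = resp * disk_mask(img.shape)  # zero out responses in the periphery
    return resp


rng = np.random.default_rng(0)
img = rng.random((32, 32))
resp = on_off_linear(img, apply_mask=True)
print(resp.shape)  # (2, 32, 32)
```

Masking the output this way means any gradient-based synthesis driven by the responses receives no signal from the image corners, which is the motivation given for apply_mask.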

Notes

These 12 parameters (standard deviations and scalar constants) were reverse-engineered from the model in [1], [2]. Use these pretrained weights at your own discretion.

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Displays the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Displays the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Additional keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

plenoptic.simulate.models.naive module

class plenoptic.simulate.models.naive.CenterSurround(kernel_size, on_center=True, width_ratio_limit=2.0, amplitude_ratio=1.25, center_std=1.0, surround_std=4.0, out_channels=1, pad_mode='reflect', cache_filt=False)[source]

Bases: Module

Center-surround difference-of-Gaussians (DoG) filter model. The filter can be either on-center/off-surround or off-center/on-surround.

The filter is constructed as:

$$
\begin{aligned}
f &= \mathrm{amplitude\_ratio} \cdot \mathrm{center} - \mathrm{surround} \\
f &\leftarrow \frac{f}{\sum f}
\end{aligned}
$$

The signs of the center and surround are determined by the on_center argument. The standard deviation of the surround Gaussian is constrained to be at least width_ratio_limit times that of the center Gaussian.
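The construction above (clamp the surround width, subtract the Gaussians, normalize, set the sign) can be sketched in NumPy. This is an illustration under assumptions, not plenoptic's code: `gaussian_kernel` and `center_surround_filter` are hypothetical names, each Gaussian is assumed to be normalized to sum to one on the kernel grid, and the sign for on_center=False is applied after normalization in this sketch, since dividing by the sum would otherwise cancel a sign flip.

```python
# Sketch only: DoG filter construction following the formula above.
# Hypothetical names; how plenoptic applies the sign internally is not shown here.
import numpy as np


def gaussian_kernel(size, std):
    """Isotropic 2D Gaussian on a size x size grid, normalized to sum to one."""
    ax = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    g = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()


def center_surround_filter(size, on_center=True, width_ratio_limit=2.0,
                           amplitude_ratio=1.25, center_std=1.0, surround_std=4.0,
                           normalize=True):
    # Enforce surround_std >= width_ratio_limit * center_std
    surround_std = max(surround_std, width_ratio_limit * center_std)
    center = gaussian_kernel(size, center_std)
    surround = gaussian_kernel(size, surround_std)
    f = amplitude_ratio * center - surround  # f = amplitude_ratio * center - surround
    if normalize:
        f = f / f.sum()  # f <- f / sum(f)
    # Flip the sign after normalization; otherwise the division would cancel it.
    return f if on_center else -f


on = center_surround_filter(11)
off = center_surround_filter(11, on_center=False)
print(round(float(on.sum()), 6), round(float(off.sum()), 6))
```

With both Gaussians summing to one, the unnormalized filter sums to amplitude_ratio - 1 (0.25 for the default 1.25), so amplitude_ratio must stay above 1 for the normalization step to be well defined.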

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (Union[bool, List[bool]]) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on). If a list of bools, its length must equal out_channels. If a single bool, all out_channels share the same polarity (all on-off or all off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std over center_std. The surround Gaussian must be wider than the center Gaussian in order to be a proper Difference of Gaussians. surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • center_std (Union[float, Tensor]) – Standard deviation of circular Gaussian for center.

  • surround_std (Union[float, Tensor]) – Standard deviation of the circular Gaussian for the surround. Must be at least width_ratio_limit times center_std.

  • out_channels (int) – Number of filters.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • cache_filt (bool) – Whether to cache the filter, which avoids regenerating filt on each forward pass. The filter is cached to self._filt.
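Two of the parameter behaviors above, broadcasting a single on_center value across out_channels and caching the filter, can be illustrated in a few lines of Python. Both `expand_on_center` and `CachedFilter` are hypothetical names for this sketch and are not part of plenoptic's API.

```python
# Sketch only: hypothetical helpers illustrating on_center broadcasting and
# filter caching; not plenoptic's implementation.
from typing import Callable, List, Optional, Sequence, Union

import numpy as np


def expand_on_center(on_center: Union[bool, Sequence[bool]], out_channels: int) -> List[bool]:
    """Broadcast a single bool to all channels, or validate a per-channel list."""
    if isinstance(on_center, bool):
        return [on_center] * out_channels
    on_center = list(on_center)
    if len(on_center) != out_channels:
        raise ValueError(f"len(on_center)={len(on_center)} must equal out_channels={out_channels}")
    return on_center


class CachedFilter:
    """Hypothetical stand-in for a module whose `filt` property can be cached."""

    def __init__(self, build: Callable[[], np.ndarray], cache_filt: bool = False):
        self._build = build
        self.cache_filt = cache_filt
        self._filt: Optional[np.ndarray] = None
        self.n_builds = 0  # counts how often the filter is regenerated

    @property
    def filt(self) -> np.ndarray:
        if self.cache_filt and self._filt is not None:
            return self._filt  # reuse the cached filter
        filt = self._build()
        self.n_builds += 1
        if self.cache_filt:
            self._filt = filt
        return filt


print(expand_on_center(True, 3))           # [True, True, True]
print(expand_on_center([True, False], 2))  # [True, False]

cached = CachedFilter(lambda: np.ones((3, 3)), cache_filt=True)
for _ in range(5):  # e.g. five forward passes
    cached.filt
print(cached.n_builds)  # 1
```

A cached filter is not rebuilt when center_std or surround_std change, so it would go stale if those parameters are being optimized; this is presumably why caching is opt-in.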

Attributes:
filt

Creates an on-center/off-surround or off-center/on-surround convolutional filter.

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

property filt: Tensor

Creates an on-center/off-surround or off-center/on-surround convolutional filter.

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.naive.Gaussian(kernel_size, std=3.0, pad_mode='reflect', out_channels=1, cache_filt=False)[source]

Bases: Module

Isotropic Gaussian convolutional filter. Kernel elements are normalized and sum to one.

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Size of convolutional kernel.

  • std (Union[float, Tensor]) – Standard deviation of the circularly symmetric Gaussian kernel.

  • pad_mode (str) – Padding mode argument to pass to torch.nn.functional.pad.

  • out_channels (int) – Number of filters with which to convolve.

  • cache_filt (bool) – Whether to cache the filter in self._filt, which avoids regenerating filt on each forward pass.
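A minimal sketch of an isotropic, normalized Gaussian convolution in plain PyTorch. The function gaussian_blur is hypothetical: it assumes an odd integer kernel_size and a single-channel input, and it is not plenoptic's implementation. It only illustrates the kernel normalization (elements sum to one), padding via torch.nn.functional.pad, and one copy of the kernel per output channel.

```python
import torch
import torch.nn.functional as F


def gaussian_blur(x, kernel_size=7, std=3.0, pad_mode="reflect", out_channels=1):
    """Hypothetical sketch: blur a (batch, 1, H, W) image with an isotropic Gaussian."""
    coords = torch.arange(kernel_size, dtype=x.dtype) - (kernel_size - 1) / 2
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    filt = torch.exp(-(xx**2 + yy**2) / (2 * std**2))
    filt = filt / filt.sum()  # kernel elements sum to one
    # One copy of the kernel per output channel: shape (out_channels, 1, k, k).
    filt = filt[None, None].repeat(out_channels, 1, 1, 1)
    pad = kernel_size // 2  # "same" output size for an odd kernel
    x = F.pad(x, (pad, pad, pad, pad), mode=pad_mode)
    return F.conv2d(x, filt)
```

Because the kernel sums to one, a constant image passes through unchanged, which is a quick sanity check for the normalization.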

Attributes:
filt

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x, **conv2d_kwargs)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

property filt

forward(x, **conv2d_kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.naive.Identity(name=None)

Bases: Module

Simple class that just returns a copy of the image.

We use this as a “dummy model” for metrics whose representation we do not have: the model returns the image unchanged, and the metric itself is implemented by changing the objective function.
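A minimal sketch of the "dummy model" pattern: with an identity model, the representation is the image itself, so the metric lives entirely in the objective function. IdentityModel and mse_objective are hypothetical stand-ins written for illustration, not plenoptic's classes; mean squared error is just one example objective.

```python
import torch
from torch import nn


class IdentityModel(nn.Module):
    """Hypothetical stand-in for naive.Identity: returns a copy of the image."""

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return img.clone()


def mse_objective(rep_a: torch.Tensor, rep_b: torch.Tensor) -> torch.Tensor:
    # With an identity "representation", this objective is plain pixel-wise MSE.
    return torch.mean((rep_a - rep_b) ** 2)


model = IdentityModel()
img_a = torch.zeros(1, 1, 4, 4)
img_b = torch.full((1, 1, 4, 4), 2.0)
loss = mse_objective(model(img_a), model(img_b))
print(loss.item())  # 4.0
```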

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(img)

Return a copy of the image

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

forward(img)

Return a copy of the image.

Parameters:

img (torch.Tensor) – The image to return

Returns:

img – a clone of the input image

Return type:

torch.Tensor

class plenoptic.simulate.models.naive.Linear(kernel_size=(3, 3), pad_mode='circular', default_filters=True)

Bases: Module

Simplistic linear convolutional model: it splits the input greyscale image into low and high frequencies.

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Convolutional kernel size.

  • pad_mode (str) – Mode with which to pad image using nn.functional.pad().

  • default_filters (bool) – If True, initialize the filters to a low-pass and a band-pass filter.
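A hedged sketch of a two-channel linear model in the same spirit. The filters are assumptions, not plenoptic's defaults: a 3x3 binomial low-pass and a zero-sum "image minus blur" filter that keeps the high frequencies (the documented default uses a band-pass instead). The circular padding matches the default pad_mode, and linear_two_channel is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def linear_two_channel(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: split a (batch, 1, H, W) image into 2 frequency channels."""
    b = torch.tensor([1.0, 2.0, 1.0], dtype=x.dtype)
    lowpass = torch.outer(b, b) / 16  # binomial low-pass, sums to one
    # Zero-sum filter (image minus its blur) that keeps the high frequencies.
    delta = torch.zeros(3, 3, dtype=x.dtype)
    delta[1, 1] = 1.0
    highpass = delta - lowpass
    weight = torch.stack([lowpass, highpass]).unsqueeze(1)  # (2, 1, 3, 3)
    x = F.pad(x, (1, 1, 1, 1), mode="circular")  # wrap-around padding
    return F.conv2d(x, weight)
```

On a constant image the first channel reproduces the input and the second is zero, since the second filter sums to zero.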

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

plenoptic.simulate.models.portilla_simoncelli module

Portilla-Simoncelli texture statistics.

The Portilla-Simoncelli (PS) texture statistics are a set of image statistics, first described in [1], that are proposed as a sufficient set of measurements for describing visual textures. That is, if two texture images have the same values for all PS texture stats, humans should consider them as members of the same family of textures.

class plenoptic.simulate.models.portilla_simoncelli.PortillaSimoncelli(image_shape, n_scales=4, n_orientations=4, spatial_corr_width=9)

Bases: Module

Portilla-Simoncelli texture statistics.

The Portilla-Simoncelli (PS) texture statistics are a set of image statistics, first described in [1], that are proposed as a sufficient set of measurements for describing visual textures. That is, if two texture images have the same values for all PS texture stats, humans should consider them as members of the same family of textures.

The PS stats are computed based on the steerable pyramid [2]. They consist of the local auto-correlations, cross-scale (within-orientation) correlations, and cross-orientation (within-scale) correlations of both the pyramid coefficients and the local energy (as computed by those coefficients). Additionally, they include the first four global moments (mean, variance, skew, and kurtosis) of the image and down-sampled versions of that image. See the paper and the accompanying notebook for more details.
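The first four global moments mentioned above can be sketched as follows. This computes population (biased) statistics over the spatial dimensions of a (batch, channel, height, width) tensor; whether plenoptic uses exactly these conventions is an assumption, and global_moments is a hypothetical name.

```python
import torch


def global_moments(image: torch.Tensor) -> torch.Tensor:
    """Mean, variance, skew and kurtosis over the two spatial dims -> (batch, channel, 4)."""
    mean = image.mean(dim=(-2, -1), keepdim=True)
    centered = image - mean
    var = centered.pow(2).mean(dim=(-2, -1), keepdim=True)
    # Skew and kurtosis are standardized third and fourth central moments.
    skew = centered.pow(3).mean(dim=(-2, -1), keepdim=True) / var.pow(1.5)
    kurtosis = centered.pow(4).mean(dim=(-2, -1), keepdim=True) / var.pow(2)
    # (batch, channel, 1, 4) -> (batch, channel, 4)
    return torch.cat([mean, var, skew, kurtosis], dim=-1).flatten(-2)
```

For example, a 2x2 checkerboard of +1 and -1 has mean 0, variance 1, skew 0 and kurtosis 1.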

Parameters:
  • image_shape (Tuple[int, int]) – Shape of input image.

  • n_scales (int) – The number of pyramid scales used to measure the statistics (default=4)

  • n_orientations (int) – The number of orientations used to measure the statistics (default=4)

  • spatial_corr_width (int) – The width of the spatial cross- and auto-correlation statistics

scales

The names of the unique scales of coefficients in the pyramid, used for coarse-to-fine metamer synthesis.

Type:

list

References

[1]

J Portilla and E P Simoncelli. A Parametric Texture Model based on Joint Statistics of Complex Wavelet Coefficients. Int’l Journal of Computer Vision. 40(1):49-71, October, 2000. http://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html http://www.cns.nyu.edu/~lcv/texture/

[2]

E P Simoncelli and W T Freeman, “The Steerable Pyramid: A Flexible Architecture for Multi-Scale Derivative Computation,” Second Int’l Conf on Image Processing, Washington, DC, Oct 1995.

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

convert_to_dict(representation_tensor)

Convert tensor of statistics to a dictionary.

convert_to_tensor(representation_dict)

Convert dictionary of statistics to a tensor.

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(image[, scales])

Generate the texture statistics representation of an image.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

plot_representation(data[, ax, figsize, ...])

Plot the representation in a human-viewable format: stem plots with data separated out by statistic type.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

remove_scales(representation_tensor, ...)

Remove statistics not associated with scales.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

update_plot(axes, data[, batch_idx])

Update the information in our representation plot.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

convert_to_dict(representation_tensor)

Convert tensor of statistics to a dictionary.

While the tensor representation is required by plenoptic’s synthesis objects, the dictionary representation is easier to manually inspect.

This dictionary will contain NaNs in its values: these are placeholders for the redundant statistics.

Parameters:

representation_tensor (Tensor) – 3d tensor of statistics.

Returns:

rep – Dictionary of representation, with informative keys.

See also

convert_to_tensor

Convert dictionary representation to tensor.

convert_to_tensor(representation_dict)

Convert dictionary of statistics to a tensor.

Parameters:

representation_dict (OrderedDict) – Dictionary of representation.

Returns:

3d tensor of statistics.

See also

convert_to_dict

Convert tensor representation to dictionary.
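The relationship between the two representations can be illustrated with a generic sketch: the dictionary keeps NaN placeholders for redundant statistics, while the tensor drops them. The helpers dict_to_tensor and tensor_to_dict, and the assumption that the NaN pattern is identical for every batch and channel, are hypothetical; plenoptic's internal layout is not reproduced here.

```python
from collections import OrderedDict

import torch


def dict_to_tensor(rep: OrderedDict) -> torch.Tensor:
    """Concatenate (batch, channel, ...) entries into (batch, channel, stats)."""
    stats = torch.cat([v.flatten(start_dim=2) for v in rep.values()], dim=-1)
    # NaN entries are placeholders for redundant statistics: drop them.
    # Assumes the NaN pattern is the same for every batch and channel.
    keep = ~torch.isnan(stats[0, 0])
    return stats[..., keep]


def tensor_to_dict(stats: torch.Tensor, template: OrderedDict) -> OrderedDict:
    """Inverse of dict_to_tensor, using `template` for keys, shapes and NaN slots."""
    batch_channel = stats.shape[:2]
    out = OrderedDict()
    start = 0
    for key, ref in template.items():
        ref_flat = ref[0, 0].flatten()
        real = ~torch.isnan(ref_flat)
        n_real = int(real.sum())
        # Start from all-NaN, then fill in the non-redundant statistics.
        block = torch.full((*batch_channel, ref_flat.numel()), float("nan"), dtype=stats.dtype)
        block[..., real] = stats[..., start:start + n_real]
        out[key] = block.reshape(*batch_channel, *ref.shape[2:])
        start += n_real
    return out
```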

forward(image, scales=None)

Generate the texture statistics representation of an image.

Note that separate batches and channels are analyzed in parallel.

Parameters:
  • image (Tensor) – A 4d tensor (batch, channel, height, width) containing the image(s) to analyze.

  • scales (Optional[List[Union[Literal['pixel_statistics'], int, Literal['residual_lowpass', 'residual_highpass']]]]) – Which scales to include in the returned representation. If None, we include all scales. Otherwise, this can contain a subset of the values present in this model’s scales attribute, and the returned tensor will then contain the subset corresponding to those scales.

Returns:

representation_tensor – 3d tensor of shape (batch, channel, stats) containing the measured texture statistics.

Raises:

ValueError – If image is not 4d or has a dtype other than float or complex.
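The documented ValueError conditions can be mirrored by a small validator. check_image is a hypothetical helper written for illustration; it is not part of plenoptic's API.

```python
import torch


def check_image(image: torch.Tensor) -> None:
    """Hypothetical validator for the two documented ValueError conditions."""
    if image.ndim != 4:
        raise ValueError(
            f"image must be 4d (batch, channel, height, width), got {image.ndim}d"
        )
    if not (image.is_floating_point() or image.is_complex()):
        raise ValueError(f"image must have a float or complex dtype, got {image.dtype}")
```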

plot_representation(data, ax=None, figsize=(15, 15), ylim=None, batch_idx=0, title=None)

Plot the representation in a human-viewable format: stem plots with data separated out by statistic type.

This plots the representation of a single batch element and averages over all channels in the representation.

We create the following axes:

  • pixels+var_highpass: marginal pixel statistics (first four moments, min, max) and variance of the residual highpass.

  • std+skew+kurtosis recon: the standard deviation, skew, and kurtosis of the reconstructed lowpass image at each scale.

  • magnitude_std: the standard deviation of the steerable pyramid coefficient magnitudes at each orientation and scale.

  • auto_correlation_reconstructed: the auto-correlation of the reconstructed lowpass image at each scale (summarized using Euclidean norm).

  • auto_correlation_magnitude: the auto-correlation of the pyramid coefficient magnitudes at each scale and orientation (summarized using Euclidean norm).

  • cross_orientation_correlation_magnitude: the cross-correlations between each orientation at each scale (summarized using Euclidean norm).

If self.n_scales > 1, we also have:

  • cross_scale_correlation_magnitude: the cross-correlations between the pyramid coefficient magnitude at one scale and the same orientation at the next-coarsest scale (summarized using Euclidean norm).

  • cross_scale_correlation_real: the cross-correlations between the real component of the pyramid coefficients and the real and imaginary components (at the same orientation) at the next-coarsest scale (summarized using Euclidean norm).
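The axis layout and data reduction described above can be summarized in a short sketch. The helper names plot_axes_names and select_for_plot are hypothetical; they only encode what the text states: six axes, two more when n_scales > 1, and one batch element averaged over channels.

```python
import torch

BASE_AXES = [
    "pixels+var_highpass",
    "std+skew+kurtosis recon",
    "magnitude_std",
    "auto_correlation_reconstructed",
    "auto_correlation_magnitude",
    "cross_orientation_correlation_magnitude",
]
CROSS_SCALE_AXES = ["cross_scale_correlation_magnitude", "cross_scale_correlation_real"]


def plot_axes_names(n_scales: int) -> list:
    # Cross-scale statistics only exist when there is more than one scale.
    return BASE_AXES + (CROSS_SCALE_AXES if n_scales > 1 else [])


def select_for_plot(data: torch.Tensor, batch_idx: int = 0) -> torch.Tensor:
    # (batch, channel, stats) -> (stats,): one batch element, averaged over channels.
    return data[batch_idx].mean(dim=0)
```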

Parameters:
  • data (Tensor) – The data to show on the plot. Should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

  • ax (Optional[Axes]) – Axes where we will plot the data. If a plt.Axes instance, will subdivide into 6 or 8 new axes (depending on self.n_scales). If None, we create a new figure.

  • figsize (Tuple[float, float]) – The size of the figure. Ignored if ax is not None.

  • ylim (Union[Tuple[float, float], Literal[False], None]) – If not None, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, do not change y-limits.

  • batch_idx (int) – Which index to take from the batch dimension (the first one)

  • title (string) – Title for the plot

Return type:

Tuple[Figure, List[Axes]]

Returns:

  • fig – Figure containing the plot

  • axes – List of 6 or 8 axes containing the plot (depending on self.n_scales)

remove_scales(representation_tensor, scales_to_keep)

Remove statistics not associated with scales.

For a given representation_tensor and a list of scales_to_keep, this method removes all statistics not associated with those scales.

Note that calling this method will always remove statistics.

Parameters:
  • representation_tensor (Tensor) – 3d tensor containing the measured representation statistics.

  • scales_to_keep (List[Union[Literal['pixel_statistics'], int, Literal['residual_lowpass', 'residual_highpass']]]) – Which scales to include in the returned representation. Can contain subset of values present in this model’s scales attribute, and the returned tensor will then contain the subset of the full representation corresponding to those scales.
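Conceptually, keeping a subset of scales amounts to boolean masking along the statistics dimension. In this sketch the per-statistic scale labels (stat_scales) and the helper keep_scales are hypothetical; how plenoptic tracks which statistic belongs to which scale is not shown here.

```python
import torch


def keep_scales(stats: torch.Tensor, stat_scales: list, scales_to_keep: list) -> torch.Tensor:
    """Keep only statistics whose (hypothetical) scale label is in scales_to_keep."""
    # One boolean per statistic along the last (stats) dimension.
    mask = torch.tensor([s in scales_to_keep for s in stat_scales])
    return stats[..., mask]
```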

Returns:

limited_representation_tensor – Representation tensor with some statistics removed.

update_plot(axes, data, batch_idx=0)

Update the information in our representation plot.

This is used for creating an animation of the representation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It relies on the fact that we’ve used plot_representation to create the plots we want to update and so know that they’re stem plots.

We take the axes containing the representation information (probably a subset of all the axes in the figure if we are also showing other information, as done by Metamer.animate) and get the representation in the form used for plotting. Since both are lists, we iterate through them together, updating each stem plot to the values in data as we go.

In order for this to be used by FuncAnimation, we need to return Artists, so we return a list of the relevant artists, the markerline and stemlines from the StemContainer.
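The update step described above can be sketched with plain matplotlib. update_stem is a hypothetical helper, not plenoptic's code; it assumes a matplotlib version in which ax.stem draws its stems as a single LineCollection (the case in recent releases) and a baseline at zero.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def update_stem(container, new_y):
    """Update a stem plot in place and return the artists FuncAnimation needs."""
    markerline, stemlines, _baseline = container
    x = np.asarray(markerline.get_xdata())
    markerline.set_ydata(new_y)
    # Each stem is a vertical segment from the baseline (assumed at 0) up to y.
    stemlines.set_segments([np.array([[xi, 0.0], [xi, yi]]) for xi, yi in zip(x, new_y)])
    return [markerline, stemlines]


fig, ax = plt.subplots()
container = ax.stem(np.arange(3), np.zeros(3))
artists = update_stem(container, np.array([1.0, 2.0, 3.0]))
plt.close(fig)
```

Returning the markerline and stemlines mirrors the description above: FuncAnimation (with blitting) needs the list of artists that changed.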

Currently, this averages over all channels in the representation.

Parameters:
  • axes (List[Axes]) – A list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.

  • batch_idx (int) – Which index to take from the batch dimension (the first one)

  • data (Tensor) – The data to show on the plot. Should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

Returns:

stem_artists – A list of the artists used to update the information on the stem plots.

Module contents