plenoptic.tools package

Submodules

plenoptic.tools.conv module

plenoptic.tools.conv.blur_downsample(x, n_scales=1, filtname='binom5', scale_filter=True)

Correlate with a binomial coefficient filter and downsample by 2

Parameters:
  • x (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.

  • n_scales (int, optional. Should be non-negative.) – Apply the blur and downsample procedure recursively n_scales times. Defaults to 1.

  • filtname (str, optional) – Name of the filter. See pt.named_filter for options. Defaults to “binom5”.

  • scale_filter (bool, optional) – If true (default), the filter sums to 1 (i.e., it does not affect the DC component of the signal). If false, the filter sums to 2.

plenoptic.tools.conv.correlate_downsample(image, filt, padding_mode='reflect')

Correlate with a filter and downsample by 2

Parameters:
  • image (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.

  • filt (2-D torch.Tensor) – The filter to correlate with the input image

  • padding_mode (string, optional) – One of “constant”, “reflect”, “replicate”, “circular”. The option “constant” means padding with zeros.
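To make the filter scaling and the stride-2 step concrete, here is a minimal NumPy sketch of blur_downsample and correlate_downsample for a single 2-D image. It is not plenoptic's implementation (which operates on 4-D torch tensors): the function names are hypothetical, and it assumes the pyrtools convention that the 1-D “binom5” taps are sqrt(2) * [1, 4, 6, 4, 1] / 16, so that their outer product sums to 2.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def binom5_filter(scale_filter=True):
    """2-D binomial filter: sums to 1 if scale_filter, else to 2."""
    # Assumption: 1-D "binom5" taps sum to sqrt(2), as in pyrtools,
    # so their outer product sums to 2.
    taps = np.sqrt(2) * np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16
    filt = np.outer(taps, taps)
    return filt / 2 if scale_filter else filt


def correlate_downsample_np(image, filt, pad_mode="reflect"):
    """Correlate a 2-D image with filt, then keep every other row and column."""
    kh, kw = filt.shape
    # Pad so that a stride-1 correlation would preserve the image size.
    pads = ((kh // 2, (kh - 1) // 2), (kw // 2, (kw - 1) // 2))
    padded = np.pad(image, pads, mode=pad_mode)
    windows = sliding_window_view(padded, filt.shape)  # (H, W, kh, kw)
    same = np.einsum("ijkl,kl->ij", windows, filt)
    return same[::2, ::2]


def blur_downsample_np(image, n_scales=1, scale_filter=True):
    """Apply the blur-and-downsample step n_scales times."""
    filt = binom5_filter(scale_filter)
    for _ in range(n_scales):
        image = correlate_downsample_np(image, filt)
    return image


print(blur_downsample_np(np.ones((8, 8))).shape)  # (4, 4)
```

Because the scaled filter sums to 1 and reflect padding of a constant image is constant, a flat 8x8 input comes back as a flat 4x4 array of ones; with scale_filter=False every value doubles.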

plenoptic.tools.conv.same_padding(x, kernel_size, stride=(1, 1), dilation=(1, 1), pad_mode='circular')

Pad a tensor so that 2D convolution will result in output with same dims.

Return type:

Tensor
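As a rough NumPy sketch of the idea (a hypothetical helper, not plenoptic's code): for a stride-1 “same” correlation, each spatial axis needs dilation * (kernel_size - 1) pixels of padding in total, split between the two sides. The sketch ignores stride and assumes that, for even kernels, the extra pixel goes at the end. NumPy's 'wrap', 'reflect', 'edge', and 'constant' modes correspond to torch's 'circular', 'reflect', 'replicate', and zero padding.

```python
import numpy as np


def same_padding_np(x, kernel_size, dilation=(1, 1), pad_mode="wrap"):
    """Pad the last two dims of x so a stride-1 'valid' correlation keeps their size."""
    pads = [(0, 0)] * (x.ndim - 2)  # leave batch/channel dims untouched
    for k, d in zip(kernel_size, dilation):
        total = d * (k - 1)  # extra rows/cols consumed by a dilated kernel
        before = total // 2  # assumption: odd leftover goes at the end
        pads.append((before, total - before))
    return np.pad(x, pads, mode=pad_mode)
```

For a (3, 4) kernel with dilation (1, 2), a 6x7 image grows to 8x13; a valid correlation with the effective 3x7 kernel then returns 6x7 again.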

plenoptic.tools.conv.upsample_blur(x, odd, filtname='binom5', scale_filter=True)

Upsample by 2 and convolve with a binomial coefficient filter

Parameters:
  • x (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.

  • odd (tuple, list or numpy.ndarray) – This should contain two integers of value 0 or 1, which determine whether the output height and width should be even (0) or odd (1).

  • filtname (str, optional) – Name of the filter. See pt.named_filter for options. Defaults to “binom5”.

  • scale_filter (bool, optional) – If true (default), the filter sums to 4 (i.e., it multiplies the signal by 4 before the blurring operation). If false, the filter sums to 2.

plenoptic.tools.conv.upsample_convolve(image, odd, filt, padding_mode='reflect')

Upsample by 2 and convolve with a filter

Parameters:
  • image (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.

  • odd (tuple, list or numpy.ndarray) – This should contain two integers of value 0 or 1, which determine whether the output height and width should be even (0) or odd (1).

  • filt (2-D torch.Tensor) – The filter to convolve with the upsampled image

  • padding_mode (string, optional) – One of “constant”, “reflect”, “replicate”, “circular”. The option “constant” means padding with zeros.
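The following NumPy sketch illustrates what odd and the sum-to-4 scaling of upsample_blur do for a single 2-D image: zeros are inserted between samples (three of every four output pixels start at zero, which is why the filter is scaled up), and the result is convolved with the filter. The function names are hypothetical, it uses zero padding at the border for simplicity rather than plenoptic's boundary handling, and it assumes pyrtools' “binom5” taps, sqrt(2) * [1, 4, 6, 4, 1] / 16.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def upsample_convolve_np(image, odd, filt):
    """Zero-insert a 2-D image by 2, then convolve with filt (zero boundary)."""
    h, w = image.shape
    # odd[i] == 1 gives an odd output size along that axis: 2*n - 1 instead of 2*n.
    out_h, out_w = 2 * h - odd[0], 2 * w - odd[1]
    up = np.zeros((out_h, out_w), dtype=float)
    up[::2, ::2] = image
    kh, kw = filt.shape
    pads = ((kh // 2, (kh - 1) // 2), (kw // 2, (kw - 1) // 2))
    # Simplification: zero ("constant") padding at the border.
    padded = np.pad(up, pads, mode="constant")
    windows = sliding_window_view(padded, filt.shape)
    # Convolution is correlation with the kernel flipped along both axes.
    return np.einsum("ijkl,kl->ij", windows, filt[::-1, ::-1])


def upsample_blur_np(image, odd, scale_filter=True):
    """Upsample with the binom5 filter, scaled to sum to 4 (or 2)."""
    # Assumption: 1-D binom5 taps sum to sqrt(2), as in pyrtools.
    taps = np.sqrt(2) * np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16
    filt = np.outer(taps, taps)  # sums to 2
    if scale_filter:
        filt = filt * 2  # sums to 4, compensating for the inserted zeros
    return upsample_convolve_np(image, odd, filt)


print(upsample_blur_np(np.ones((4, 5)), odd=(0, 1)).shape)  # (8, 9)
```

Away from the zero-padded border, a constant input is reproduced exactly, since the even-indexed and odd-indexed filter taps each sum to 1 along every axis.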

plenoptic.tools.convergence module

Functions that check for optimization convergence/stabilization.

The functions herein generally differ in what they are checking for convergence: loss, pixel change, etc.

They should probably be able to accept the following arguments, in this order (they can accept more):

  • synth: an OptimizedSynthesis object to check.

  • stop_criterion: the value used as criterion / tolerance that our convergence target is compared against.

  • stop_iters_to_check: how many iterations back to check for convergence.

They must return a single bool: True if we’ve reached convergence, False if not.

plenoptic.tools.convergence.coarse_to_fine_enough(synth, i, ctf_iters_to_check)

Check whether we’ve synthesized all scales and done so for at least ctf_iters_to_check iterations

This is meant to be paired with another convergence check, such as loss_convergence.

Parameters:
  • synth (Metamer) – The Metamer object to check.

  • i (int) – The current iteration (0-indexed).

  • ctf_iters_to_check (int) – Minimum number of iterations coarse-to-fine must run at each scale. If self.coarse_to_fine is False, then this is ignored.

Returns:

ctf_enough – Whether we’ve been doing coarse-to-fine synthesis for long enough.

Return type:

bool

plenoptic.tools.convergence.loss_convergence(synth, stop_criterion, stop_iters_to_check)

Check whether the loss has stabilized and, if so, return True.

The logic is:

  • Have we been synthesizing for stop_iters_to_check iterations? If not, return False.

  • If so, is abs(synth.losses[-1] - synth.losses[-stop_iters_to_check]) < stop_criterion? If yes, return True; if no, return False.

Parameters:
  • synth (OptimizedSynthesis) – The OptimizedSynthesis object to check.

  • stop_criterion (float) – If the loss over the past stop_iters_to_check iterations has changed by less than stop_criterion, we terminate synthesis.

  • stop_iters_to_check (int) – How many iterations back to check in order to see if the loss has stopped decreasing (for stop_criterion).

Returns:

loss_stabilized – Whether the loss has stabilized or not.

Return type:

bool
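As a sketch of the check described above, assuming only the synth.losses list is consulted and that “synthesizing for stop_iters_to_check iterations” means more than stop_iters_to_check losses have been recorded. The function name is hypothetical, and SimpleNamespace stands in for a real OptimizedSynthesis object.

```python
from types import SimpleNamespace


def loss_convergence_sketch(synth, stop_criterion, stop_iters_to_check):
    """Return True if the loss changed by less than stop_criterion over the window."""
    losses = synth.losses
    # Assumption: "long enough" means more than stop_iters_to_check losses recorded.
    if len(losses) > stop_iters_to_check:
        change = abs(losses[-1] - losses[-stop_iters_to_check])
        return change < stop_criterion
    return False


# Hypothetical stand-in for an OptimizedSynthesis object: only .losses is used.
synth = SimpleNamespace(losses=[5.0, 4.0, 3.0, 2.9999])
print(loss_convergence_sketch(synth, stop_criterion=1e-3, stop_iters_to_check=2))  # True
```

Any function with this (synth, stop_criterion, stop_iters_to_check) signature and a single bool return follows the module's convention for convergence checks.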

plenoptic.tools.convergence.pixel_change_convergence(synth, stop_criterion, stop_iters_to_check)

Check whether the pixel change norm has stabilized and, if so, return True.

The logic is:

  • Have we been synthesizing for stop_iters_to_check iterations? If not, return False.

  • If so, is (synth.pixel_change_norm[-stop_iters_to_check:] < stop_criterion).all()? If yes, return True; if no, return False.

Parameters:
  • synth (OptimizedSynthesis) – The OptimizedSynthesis object to check.

  • stop_criterion (float) – If the pixel change norm has been less than stop_criterion for all of the past stop_iters_to_check iterations, we terminate synthesis.

  • stop_iters_to_check (int) – How many iterations back to check in order to see if the pixel change norm has stopped decreasing (for stop_criterion).

Returns:

Whether the pixel change norm has stabilized or not.

Return type:

bool
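A corresponding sketch for the pixel-change check, under similar assumptions: a hypothetical function name, a SimpleNamespace stand-in whose pixel_change_norm is a sequence of floats, and “long enough” meaning more than stop_iters_to_check recorded norms. Unlike the loss check, every norm in the window must be below stop_criterion, not just the net change between its endpoints.

```python
from types import SimpleNamespace

import numpy as np


def pixel_change_convergence_sketch(synth, stop_criterion, stop_iters_to_check):
    """True if every recent pixel-change norm is below stop_criterion."""
    norms = np.asarray(synth.pixel_change_norm)
    # Assumption: require more than stop_iters_to_check recorded norms.
    if len(norms) > stop_iters_to_check:
        return bool((norms[-stop_iters_to_check:] < stop_criterion).all())
    return False


synth = SimpleNamespace(pixel_change_norm=[1.0, 0.5, 1e-4, 1e-5])
print(pixel_change_convergence_sketch(synth, 1e-3, 2))  # True
```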

plenoptic.tools.data module

plenoptic.tools.data.convert_float_to_int(im, dtype=numpy.uint8)

Convert image from float to 8- or 16-bit image

We work with float images that lie between 0 and 1, but for saving them (either as png or in a numpy array), we want to convert them to 8- or 16-bit integers. This function does that by multiplying the image by the max value for the target dtype (255 for 8-bit, 65535 for 16-bit) and then converting it to the proper type.

We’ll raise an exception if the max is higher than 1, in which case we have no idea what to do.

Parameters:
  • im (ndarray) – The image to convert

  • dtype ({np.uint8, np.uint16}) – The target data type.

Returns:

im – The converted image, now with dtype=dtype

Return type:

ndarray
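The conversion described above amounts to a scale and a cast. This NumPy sketch (hypothetical name, not necessarily plenoptic's exact code) uses np.iinfo to look up the target maximum. Note that astype truncates toward zero rather than rounding, so 0.5 maps to 127 for uint8.

```python
import numpy as np


def convert_float_to_int_sketch(im, dtype=np.uint8):
    """Scale a float image in [0, 1] to the full range of an unsigned integer dtype."""
    if im.max() > 1:
        raise ValueError(f"image max is {im.max()}, expected values in [0, 1]")
    # np.iinfo gives 255 for uint8 and 65535 for uint16.
    return (im * np.iinfo(dtype).max).astype(dtype)


print(convert_float_to_int_sketch(np.array([0.0, 0.5, 1.0])))  # [  0 127 255]
```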

plenoptic.tools.data.load_images(paths, as_gray=True)

Correctly load in images

Our models and synthesis methods expect their inputs to be 4d float32 images: (batch, channel, height, width), where the batch dimension contains multiple images and the channel dimension contains something like RGB color channels. This function helps you get your inputs into that format. It accepts either a single file, a list of files, or a single directory containing images. It loads them, normalizes them to lie between 0 and 1, converts them to float32, optionally converts them to grayscale, makes them tensors, and gets them into the right shape.

Parameters:
  • paths (Union[str, List[str]]) – A str or list of strs. If a list, must contain paths of image files. If a str, can either be the path of a single image file or of a single directory. If a directory, we try to load every file it contains (using imageio.imread) and skip those we cannot (thus, for efficiency, you should not point this to a directory with lots of non-image files). This is NOT recursive.

  • as_gray (bool) – Whether to convert the images into grayscale or not after loading them. If False, we do nothing. If True, we call skimage.color.rgb2gray on them.

Returns:

images – 4d tensor containing the images.

Return type:

Tensor
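A sketch of the normalization and reshaping steps only, assuming the images have already been read into NumPy arrays (file reading with imageio and grayscale conversion with skimage are omitted). The helper name is hypothetical, it assumes integer images are scaled by their dtype's maximum, and it returns a NumPy array; torch.from_numpy would turn the result into a tensor.

```python
import numpy as np


def stack_as_batch(images, as_gray=True):
    """Normalize images to [0, 1] float32 and stack to (batch, channel, h, w)."""
    out = []
    for im in images:
        if np.issubdtype(im.dtype, np.integer):
            # Assumption: scale by the dtype's max (255 for uint8, 65535 for uint16).
            im = im / np.iinfo(im.dtype).max
        out.append(np.asarray(im, dtype=np.float32))
    batch = np.stack(out)  # all images must share a shape
    if as_gray:
        return batch[:, np.newaxis]  # (batch, h, w) -> (batch, 1, h, w)
    return batch.transpose(0, 3, 1, 2)  # (batch, h, w, c) -> (batch, c, h, w)
```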

plenoptic.tools.data.make_synthetic_stimuli(size=256, requires_grad=True)

Make a set of basic stimuli, useful for developing and debugging models

Parameters:
  • size (int) – The stimuli will have torch.Size([size, size]).

  • requires_grad (bool) – Whether to initialize the stimuli with gradients.

Returns:

stimuli – Tensor of shape [11, 1, size, size]. The set of basic stimuli: [impulse, step_edge, ramp, bar, curv_edge, sine_grating, square_grating, polar_angle, angular_sine, zone_plate, fractal]

Return type:

Tensor

plenoptic.tools.data.polar_angle(size, phase=0.0, origin=None, device=None)

Make polar angle matrix (in radians).

Compute a matrix of given size containing samples of the polar angle (in radians, CW from the X-axis, ranging from -pi to pi), relative to given phase, about the given origin pixel.

Parameters:
  • size (Union[int, Tuple[int, int]]) – If an int, we assume the image should be of dimensions (size, size). if a tuple, must be a 2-tuple of ints specifying the dimensions

  • phase (float) – The phase of the polar angle function (in radians, clockwise from the X-axis)

  • origin (Union[int, Tuple[float, float], None]) – The center of the image. if an int, we assume the origin is at (origin, origin). if a tuple, must be a 2-tuple of ints specifying the origin (where (0, 0) is the upper left). if None, we assume the origin lies at the center of the matrix, (size+1)/2.

  • device (Optional[device]) – The device to create this tensor on.

Returns:

res – The polar angle matrix

Return type:

Tensor

plenoptic.tools.data.polar_radius(size, exponent=1.0, origin=None, device=None)

Make distance-from-origin (r) matrix

Compute a matrix of given size containing samples of a radial ramp function, raised to given exponent, centered at given origin.

Parameters:
  • size (Union[int, Tuple[int, int]]) – If an int, we assume the image should be of dimensions (size, size). if a tuple, must be a 2-tuple of ints specifying the dimensions.

  • exponent (float) – The exponent of the radial ramp function.

  • origin (Union[int, Tuple[int, int], None]) – The center of the image. if an int, we assume the origin is at (origin, origin). if a tuple, must be a 2-tuple of ints specifying the origin (where (0, 0) is the upper left). if None, we assume the origin lies at the center of the matrix, (size+1)/2.

  • device (Union[str, device, None]) – The device to create this tensor on.

Returns:

res – The polar radius matrix.

Return type:

Tensor
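The polar angle and polar radius matrices can be sketched with NumPy as below (hypothetical names, no device handling). Assumptions: origin is given as (row, column); pixel coordinates are 0-indexed, so the default center is ((height - 1) / 2, (width - 1) / 2), the same point as (size + 1) / 2 in 1-indexed coordinates; and the angle is shifted by phase and wrapped back into [-pi, pi) only when phase is nonzero. Because row indices grow downward, arctan2 measures angles clockwise on screen.

```python
import numpy as np


def _offsets(size, origin=None):
    """x (column) and y (row) offsets of each pixel from origin, 0-indexed."""
    h, w = (size, size) if isinstance(size, int) else size
    if origin is None:
        # Geometric center; equals (size + 1) / 2 in 1-indexed coordinates.
        origin = ((h - 1) / 2, (w - 1) / 2)
    elif isinstance(origin, (int, float)):
        origin = (origin, origin)
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    return cols - origin[1], rows - origin[0]


def polar_angle_np(size, phase=0.0, origin=None):
    x, y = _offsets(size, origin)
    # Rows grow downward, so arctan2(y, x) increases clockwise on screen.
    theta = np.arctan2(y, x)
    if phase != 0:
        # Measure relative to phase and wrap back into [-pi, pi).
        theta = np.mod(theta - phase + np.pi, 2 * np.pi) - np.pi
    return theta


def polar_radius_np(size, exponent=1.0, origin=None):
    x, y = _offsets(size, origin)
    # (x^2 + y^2)^(exponent / 2) is the distance r raised to exponent.
    return (x**2 + y**2) ** (exponent / 2)
```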

plenoptic.tools.data.to_numpy(x, squeeze=False)

cast tensor to numpy in the most conservative way possible

Parameters:
  • x (Union[Tensor, ndarray]) – Tensor to be converted to numpy.ndarray on CPU.

  • squeeze (bool) – If True, remove all singleton (dummy) dimensions of the tensor.

Returns:

Converted tensor as numpy.ndarray on CPU.

Return type:

ndarray

plenoptic.tools.display module

various helpful utilities for plotting or displaying information

plenoptic.tools.display.animshow(video, framerate=2.0, repeat=False, vrange='indep1', zoom=1, title='', col_wrap=None, ax=None, cmap=None, plot_complex='rectangular', batch_idx=None, channel_idx=None, as_rgb=False, **kwargs)

Animate video(s) correctly.

This function animates videos correctly, making sure that each element in the tensor corresponds to a pixel or an integer number of pixels, to avoid aliasing (NOTE: this guarantee only holds for the saved animation, assuming video compression doesn’t interfere; it should generally hold in notebooks as well, but will fail if, e.g., your video is 2000 pixels wide on a monitor 1000 pixels wide, since the notebook handles the rescaling in a way we can’t control).

This function returns the matplotlib FuncAnimation object. In order to view it in a Jupyter notebook, use the plenoptic.convert_anim_to_html(anim) function. In order to save, use anim.save(filename) (note that for this you’ll need the appropriate writer installed and on your path, e.g., ffmpeg, imagemagick, etc.).

Parameters:
  • video (torch.Tensor or list) – The videos to display. Tensors should be 5d (batch, channel, time, height, width). List of tensors should be used for tensors of different height and width: all videos will automatically be rescaled so they’re displayed at the same height and width, thus, their heights and widths must be scalar multiples of each other. Videos must all have the same number of frames as well.

  • framerate (float) – Temporal resolution of the video, in Hz (frames per second).

  • repeat (bool) – whether to loop the animation or just play it once

  • vrange (tuple or str) –

    If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and maximum value of the colormap, respectively. If a string:

    • ‘auto0’: all images have the same vmin/vmax, which have the same absolute value, and come from the minimum or maximum across all images, whichever has the larger absolute value.

    • ‘auto’/‘auto1’: all images have the same vmin/vmax, which are the minimum/maximum values across all images.

    • ‘auto2’: all images have the same vmin/vmax, which are the mean (across all images) minus/plus 2 std dev (across all images).

    • ‘auto3’: all images have the same vmin/vmax, chosen so as to map the 10th/90th percentile values to the 10th/90th percentile of the display intensity range. For example, vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile values.

    • ‘indep0’: each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value.

    • ‘indep1’: each image has an independent vmin/vmax, which are their minimum/maximum values.

    • ‘indep2’: each image has an independent vmin/vmax, which are their mean minus/plus 2 std dev.

    • ‘indep3’: each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities.

  • zoom (float) – ratio of display pixels to image pixels. If >1, must be an integer. If <1, must be 1/d where d is a divisor of the size of the largest image.

  • title (str, list, or None, optional) –

    Title for the plot. In addition to the specified title, we add a subtitle giving the plotted range and dimensionality (with zoom).

    • If str, we put the same title on every plot.

    • If list, all values must be str, and the list must be the same length as video, assigning each title to the corresponding video.

    • If None, no title will be printed (and the subtitle will be removed).

  • col_wrap (int or None, optional) – number of axes to have in each row. If None, will fit all axes in a single row.

  • ax (matplotlib.pyplot.axis or None, optional) – If None, we make the appropriate figure. Otherwise, we resize the axes so that they are the appropriate number of pixels. This is done by shrinking the bbox: if the bbox is already too small, this will raise an Exception, so first define a large enough figure using either pyrtools.make_figure or plt.figure.

  • cmap (matplotlib colormap, optional) – colormap to use when showing these images

  • plot_complex ({'rectangular', 'polar', 'logpolar'}) –

    specifies handling of complex values.

    • ’rectangular’: plot real and imaginary components as separate images

    • ’polar’: plot amplitude and phase as separate images

    • ’logpolar’: plot log_2 amplitude and phase as separate images

    for any other value, we raise a warning and default to rectangular.

  • batch_idx (int or None, optional) – Which element from the batch dimension to plot. If None, we plot all.

  • channel_idx (int or None, optional) – Which element from the channel dimension to plot. If None, we plot all. Note if this is an int, then as_rgb=True will fail, because we restrict the channels.

  • as_rgb (bool, optional) – Whether to consider the channels as encoding RGB(A) values. If True, we attempt to plot the image in color, so your tensor must have 3 (or 4 if you want the alpha channel) elements in the channel dimension, or this will raise an Exception. If False, we plot each channel as a separate grayscale image.

  • kwargs – Passed to ax.imshow

Returns:

anim – The animation object. In order to view, must convert to HTML or save.

Return type:

matplotlib.animation.FuncAnimation

Notes

By default, we use the ffmpeg backend, which requires that you have ffmpeg installed and on your path (https://ffmpeg.org/download.html). To use a different writer, set the matplotlib rcParams: matplotlib.rcParams[‘animation.writer’] = writer; see https://matplotlib.org/stable/api/animation_api.html#writer-classes for more details.

For displaying in a jupyter notebook, ffmpeg appears to be required.
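If ffmpeg is not available, one option for saving (not notebook display) is matplotlib's Pillow writer, which writes GIFs without an external binary. The sketch below uses a small stand-in FuncAnimation in place of animshow's return value; setting matplotlib.rcParams['animation.writer'] = 'pillow' instead of passing writer= should have the same effect for anim.save.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend: no display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

# Stand-in for the FuncAnimation that animshow returns: 4 random 8x8 frames.
frames = np.random.rand(4, 8, 8)
fig, ax = plt.subplots()
image = ax.imshow(frames[0], cmap="gray", vmin=0, vmax=1)


def update(t):
    # Swap in frame t and return the modified artist.
    image.set_data(frames[t])
    return [image]


anim = FuncAnimation(fig, update, frames=len(frames), interval=500)
# Pillow is a matplotlib dependency, so this writer needs no ffmpeg.
path = os.path.join(tempfile.mkdtemp(), "anim.gif")
anim.save(path, writer="pillow", fps=2)
plt.close(fig)
```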

plenoptic.tools.display.clean_stem_plot(data, ax=None, title='', ylim=None, xvals=None, **kwargs)

convenience wrapper for plotting stem plots

This plots the data, baseline, cleans up the axis, and sets the title

Should not be called by users directly, but is a helper function for the various plot_representation() functions

By default, stem plot would have a baseline that covers the entire range of the data. We want to be able to break that up visually (so there’s a line from 0 to 9, from 10 to 19, etc), and passing xvals separately allows us to do that. If you want the default stem plot behavior, leave xvals as None.

Parameters:
  • data (np.ndarray) – The data to plot (as a stem plot)

  • ax (matplotlib.pyplot.axis or None, optional) – The axis to plot the data on. If None, we plot on the current axis

  • title (str or None, optional) – The title to put on the axis if not None. If None, we don’t call ax.set_title (useful if you want to avoid changing the title on an existing plot)

  • ylim (tuple, False, or None, optional) – If a tuple, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, do not change y-limits.

  • xvals (tuple or None, optional) – A 2-tuple of lists, containing the start (xvals[0]) and stop (xvals[1]) x values for plotting. If None, we use the default stem plot behavior.

  • kwargs – passed to ax.stem

Returns:

ax – The axis with the plot

Return type:

matplotlib.pyplot.axis

Examples

We allow for breaks in the baseline value if we want to visually break up the plot, as we see below.

```python
import plenoptic as po
import numpy as np
import matplotlib.pyplot as plt

# if ylim=None, as in this example, the minimum y-value will get
# set to 0, so we want to make sure our values are all positive
y = np.abs(np.random.randn(55))
y[15:20] = np.nan
y[35:40] = np.nan
# we want to draw the baseline from 0 to 14, 20 to 34, and 40 to
# 54, everywhere that we have non-NaN values for y
xvals = ([0, 20, 40], [14, 34, 54])
po.tools.display.clean_stem_plot(y, xvals=xvals)
plt.show()
```

If you don’t care about breaking up the x-axis, you can simply use the default xvals (None). In this case, this function will just clean up the plot a little bit.

```python
import plenoptic as po
import numpy as np
import matplotlib.pyplot as plt

# if ylim=None, as in this example, the minimum y-value will get
# set to 0, so we want to make sure our values are all positive
y = np.abs(np.random.randn(55))
po.tools.display.clean_stem_plot(y)
plt.show()
```

plenoptic.tools.display.clean_up_axes(ax, ylim=None, spines_to_remove=['top', 'right', 'bottom'], axes_to_remove=['x'])

Clean up an axis, as desired when making a stem plot of the representation

Parameters:
  • ax (matplotlib.pyplot.axis) – The axis to clean up.

  • ylim (tuple, False, or None) – If a tuple, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, we do nothing.

  • spines_to_remove (list) – Some combination of ‘top’, ‘right’, ‘bottom’, and ‘left’. The spines we remove from the axis.

  • axes_to_remove (list) – Some combination of ‘x’, ‘y’. The axes to set as invisible.

Returns:

ax – The cleaned-up axis

Return type:

matplotlib.pyplot.axis
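A minimal matplotlib sketch of the spine and axis removal (hypothetical name). It uses tuples as defaults rather than the lists in the documented signature, which avoids sharing a mutable default between calls, and it does not reproduce the ylim=None adjustment: ylim is applied only when it is a tuple.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def clean_up_axes_sketch(ax, ylim=None, spines_to_remove=("top", "right", "bottom"),
                         axes_to_remove=("x",)):
    """Hide the listed spines and axes; apply ylim only when it is a tuple."""
    # Simplification: the documented ylim=None behavior is not reproduced.
    if isinstance(ylim, tuple):
        ax.set_ylim(ylim)
    for spine in spines_to_remove:
        ax.spines[spine].set_visible(False)
    if "x" in axes_to_remove:
        ax.xaxis.set_visible(False)
    if "y" in axes_to_remove:
        ax.yaxis.set_visible(False)
    return ax
```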

plenoptic.tools.display.convert_anim_to_html(anim)

convert a matplotlib animation object to HTML (for display)

This is a simple little wrapper function that allows the animation to be displayed in a Jupyter notebook

Parameters:

anim (matplotlib.animation.FuncAnimation) – The animation object to convert to HTML

plenoptic.tools.display.imshow(image, vrange='indep1', zoom=None, title='', col_wrap=None, ax=None, cmap=None, plot_complex='rectangular', batch_idx=None, channel_idx=None, as_rgb=False, **kwargs)

Show image(s) correctly.

This function shows images correctly, making sure that each element in the tensor corresponds to a pixel or an integer number of pixels, to avoid aliasing (NOTE: this guarantee only holds for the saved image; it should generally hold in notebooks as well, but will fail if, e.g., you plot an image that’s 2000 pixels wide on a monitor 1000 pixels wide, since the notebook handles the rescaling in a way we can’t control).

Parameters:
  • image (torch.Tensor or list) – The images to display. Tensors should be 4d (batch, channel, height, width). List of tensors should be used for tensors of different height and width: all images will automatically be rescaled so they’re displayed at the same height and width, thus, their heights and widths must be scalar multiples of each other.

  • vrange (tuple or str) –

    If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and maximum value of the colormap, respectively. If a string:

    • ‘auto0’: all images have the same vmin/vmax, which have the same absolute value, and come from the minimum or maximum across all images, whichever has the larger absolute value.

    • ‘auto’/‘auto1’: all images have the same vmin/vmax, which are the minimum/maximum values across all images.

    • ‘auto2’: all images have the same vmin/vmax, which are the mean (across all images) minus/plus 2 std dev (across all images).

    • ‘auto3’: all images have the same vmin/vmax, chosen so as to map the 10th/90th percentile values to the 10th/90th percentile of the display intensity range. For example, vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile values.

    • ‘indep0’: each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value.

    • ‘indep1’: each image has an independent vmin/vmax, which are their minimum/maximum values.

    • ‘indep2’: each image has an independent vmin/vmax, which are their mean minus/plus 2 std dev.

    • ‘indep3’: each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities.

  • zoom (float or None) – ratio of display pixels to image pixels. If >1, must be an integer. If <1, must be 1/d where d is a divisor of the size of the largest image. If None, we try to determine the best zoom.

  • title (str, list, or None, optional) –

    Title for the plot. In addition to the specified title, we add a subtitle giving the plotted range and dimensionality (with zoom).

    • If str, we put the same title on every plot.

    • If list, all values must be str, and the list must be the same length as image, assigning each title to the corresponding image.

    • If None, no title will be printed (and the subtitle will be removed).

  • col_wrap (int or None, optional) – number of axes to have in each row. If None, will fit all axes in a single row.

  • ax (matplotlib.pyplot.axis or None, optional) – If None, we make the appropriate figure. Otherwise, we resize the axes so that they are the appropriate number of pixels. This is done by shrinking the bbox: if the bbox is already too small, this will raise an Exception, so first define a large enough figure using either make_figure or plt.figure.

  • cmap (matplotlib colormap, optional) – colormap to use when showing these images

  • plot_complex ({'rectangular', 'polar', 'logpolar'}) –

    specifies handling of complex values.

    • ’rectangular’: plot real and imaginary components as separate images

    • ’polar’: plot amplitude and phase as separate images

    • ’logpolar’: plot log_2 amplitude and phase as separate images

    for any other value, we raise a warning and default to rectangular.

  • batch_idx (int or None, optional) – Which element from the batch dimension to plot. If None, we plot all.

  • channel_idx (int or None, optional) – Which element from the channel dimension to plot. If None, we plot all. Note if this is an int, then as_rgb=True will fail, because we restrict the channels.

  • as_rgb (bool, optional) – Whether to consider the channels as encoding RGB(A) values. If True, we attempt to plot the image in color, so your tensor must have 3 (or 4 if you want the alpha channel) elements in the channel dimension, or this will raise an Exception. If False, we plot each channel as a separate grayscale image.

  • kwargs – Passed to ax.imshow

Returns:

fig – figure containing the plotted images

Return type:

PyrFigure
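The vrange strings listed above follow two patterns: an ‘auto’ prefix pools pixel values across all images, an ‘indep’ prefix treats each image separately, and the trailing digit picks the rule. This NumPy sketch (hypothetical names, not the library's own implementation) computes one (vmin, vmax) pair per image. For rule 3, mapping the 10th/90th percentiles to 10%/90% of the display range means the full range is (p90 - p10) / 0.8, so each end extends by 0.1 / 0.8 = 1/8 of the percentile span.

```python
import numpy as np


def _limits(values, method):
    """vmin/vmax for pooled 1-D pixel values under rule 0, 1, 2, or 3."""
    if method == "0":  # symmetric about 0, set by the largest absolute value
        m = np.abs(values).max()
        return -m, m
    if method == "1":  # min and max
        return values.min(), values.max()
    if method == "2":  # mean -/+ 2 standard deviations
        mu, sd = values.mean(), values.std()
        return mu - 2 * sd, mu + 2 * sd
    if method == "3":  # 10th/90th percentiles land at 10%/90% of the display range
        p10, p90 = np.percentile(values, [10, 90])
        margin = (p90 - p10) / 8  # 0.1 / 0.8 of the percentile span
        return p10 - margin, p90 + margin
    raise ValueError(f"unknown vrange rule {method!r}")


def vranges(images, vrange="indep1"):
    """Return one (vmin, vmax) pair per image."""
    if isinstance(vrange, tuple):
        return [vrange] * len(images)
    if vrange == "auto":
        vrange = "auto1"
    kind, method = vrange[:-1], vrange[-1]
    if kind == "auto":  # one pair shared by all images
        pooled = np.concatenate([np.ravel(im) for im in images])
        return [_limits(pooled, method)] * len(images)
    if kind == "indep":  # each image gets its own pair
        return [_limits(np.ravel(im), method) for im in images]
    raise ValueError(f"unknown vrange {vrange!r}")
```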

plenoptic.tools.display.plot_representation(model=None, data=None, ax=None, figsize=(5, 5), ylim=False, batch_idx=0, title='', as_rgb=False)

Helper function for plotting model representation

We plot data on ax using the model.plot_representation method, if model has one; otherwise, we default to a function that makes sense based on the shape of data.

All of these arguments are optional, but at least some of them need to be set:

  • If model is None, we fall back to a type of plot based on the shape of data. If it looks image-like, we’ll use plenoptic.imshow, and if it looks vector-like, we’ll use plenoptic.clean_stem_plot. If it’s a dictionary, we’ll assume each key, value pair gives the title and data to plot on a separate sub-plot.

  • If data is None, we can only do something if model.plot_representation has some default behavior when data=None; this is probably to plot its own representation attribute. Thus, this will raise an Exception if both model and data are None, because we have no idea what to plot then.

  • If ax is None, we create a one-subplot figure using figsize. If ax is not None, figsize is ignored.

  • If ylim is None, we call rescale_ylim, which sets the axes’ y-limits to be (-y_max, y_max), where y_max=np.abs(data).max(). If it’s False, we do nothing.

Parameters:
  • model (torch.nn.Module or None, optional) – A differentiable model that tells us how to plot data. See above for behavior if None.

  • data (array_like, dict, or None, optional) – The data to plot. See above for behavior if None.

  • ax (matplotlib.pyplot.axis or None, optional) – The axis to plot on. See above for behavior if None.

  • figsize (tuple, optional) – The size of the figure to create. Ignored if ax is not None.

  • ylim (tuple, None, or False, optional) – If not None, the y-limits to use for this plot. See above for behavior if None. If False, we do nothing.

  • batch_idx (int, optional) – Which index to take from the batch dimension

  • title (str, optional) – The title to put above this axis. If you want no title, pass the empty string ('')

  • as_rgb (bool, optional) – The representation can be image-like with multiple channels, and we have no way to determine whether it should be represented as an RGB image or not, so the user must set this flag to tell us. It will be ignored if the representation doesn’t look image-like or if the model has its own plot_representation() method. Otherwise, it will be passed to po.imshow(); see that method’s docstring for details.

Returns:

axes – List of created axes.

Return type:

list

plenoptic.tools.display.pyrshow(pyr_coeffs, vrange='indep1', zoom=1, show_residuals=True, cmap=None, plot_complex='rectangular', batch_idx=0, channel_idx=0, **kwargs)

Display steerable pyramid coefficients in orderly fashion.

This function uses imshow to show the coefficients of the steerable pyramid, such that each scale shows up on a single row, with each orientation in a given column.

Note that, unlike imshow, we can only show one batch element or channel at a time.

Parameters:
  • pyr_coeffs (dict) – pyramid coefficients in the standard dictionary format as returned by SteerablePyramidFreq.forward()

  • vrange (tuple or str) –

    If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and maximum value of the colormap, respectively. If a string:

    • ‘auto0’: all images have the same vmin/vmax, which have the same absolute value, and come from the minimum or maximum across all images, whichever has the larger absolute value.

    • ‘auto’/‘auto1’: all images have the same vmin/vmax, which are the minimum/maximum values across all images.

    • ‘auto2’: all images have the same vmin/vmax, which are the mean (across all images) minus/plus 2 std dev (across all images).

    • ‘auto3’: all images have the same vmin/vmax, chosen so as to map the 10th/90th percentile values to the 10th/90th percentile of the display intensity range. For example, vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile values.

    • ‘indep0’: each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value.

    • ‘indep1’: each image has an independent vmin/vmax, which are their minimum/maximum values.

    • ‘indep2’: each image has an independent vmin/vmax, which are their mean minus/plus 2 std dev.

    • ‘indep3’: each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities.

  • zoom (float) – ratio of display pixels to image pixels. If >1, must be an integer. If <1, must be 1/d where d is a divisor of the size of the largest image.

  • show_residuals (bool) – whether to display the residual bands (lowpass, highpass depending on the pyramid type)

  • cmap (matplotlib colormap, optional) – colormap to use when showing these images

  • plot_complex ({'rectangular', 'polar', 'logpolar'}) –

    specifies handling of complex values.

    • ’rectangular’: plot real and imaginary components as separate images

    • ’polar’: plot amplitude and phase as separate images

    • ’logpolar’: plot log_2 amplitude and phase as separate images

    for any other value, we raise a warning and default to rectangular.

  • batch_idx (int, optional) – Which element from the batch dimension to plot.

  • channel_idx (int, optional) – Which element from the channel dimension to plot.

  • kwargs – Passed on to pyrtools.pyrshow

Returns:

fig – the figure displaying the coefficients.

Return type:

PyrFigure

plenoptic.tools.display.rescale_ylim(axes, data)

rescale y-limits nicely

We take the axes and set their limits to be (-y_max, y_max), where y_max=np.abs(data).max()

Parameters:
  • axes (list) – A list of matplotlib axes to rescale

  • data (array_like or dict) – The data to use when rescaling (or a dictionary of those values)
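A sketch of the rescaling with matplotlib (hypothetical name). For dict input, it assumes y_max is the largest absolute value across all of the dictionary's values.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def rescale_ylim_sketch(axes, data):
    """Set every axis to (-y_max, y_max), with y_max the largest absolute value in data."""
    if isinstance(data, dict):
        # Assumption: for a dict, take the max over all of its values.
        y_max = max(np.abs(np.asarray(v)).max() for v in data.values())
    else:
        y_max = np.abs(np.asarray(data)).max()
    for ax in axes:
        ax.set_ylim((-y_max, y_max))
```

Symmetric limits keep zero centered, so representations with both signs can be compared across subplots at a glance.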

plenoptic.tools.display.update_plot(axes, data, model=None, batch_idx=0)

Update the information in some axes.

This is used for creating an animation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It assumes the plot has been created by something like plot_representation, which initializes all the artists.

We can update stem plots, lines (as returned by plt.plot), scatter plots, or images (RGB, RGBA, or grayscale).

There are two modes for this:

  • single axis: axes is a single axis, which may contain multiple artists (all of the same type) to update. data should be a Tensor with multiple channels (one per artist in the same order) or be a dictionary whose keys give the label(s) of the corresponding artist(s) and whose values are Tensors.

  • multiple axes: axes is a list of axes, each of which contains a single artist to update (artists can be different types). data should be a Tensor with multiple channels (one per axis in the same order) or a dictionary with the same number of keys as axes, which we can iterate through in order, and whose values are Tensors.

In all cases, data Tensors should be 3d (if the plot we’re updating is a line or stem plot) or 4d (if it’s an image or scatter plot).

RGB(A) images are special, since we store that info along the channel dimension, so they only work with single-axis mode (which will only have a single artist, because that’s how imshow works).

If you have multiple axes, each with multiple artists you want to update, that’s too complicated for us, and so you should write a model.update_plot() function which handles that.

If model is set, we try to call model.update_plot() (which must also return artists). If model doesn’t have an update_plot method, then we try to figure out how to update the axes ourselves, based on the shape of the data.

Parameters:
  • axes (list or matplotlib.pyplot.axis) – The axis or list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.

  • data (torch.Tensor or dict) – The new data to plot.

  • model (torch.nn.Module or None, optional) – A differentiable model that tells us how to plot data. See above for behavior if None.

  • batch_idx (int, optional) – Which index to take from the batch dimension

Returns:

artists – A list of the artists used to update the information on the plots

Return type:

list

plenoptic.tools.display.update_stem(stem_container, ydata)[source]

Update the information in a stem plot

We update the information in a single stem plot to match that given by ydata. We update the position of the markers and the lines connecting them to the baseline, but we don’t change the baseline at all and assume that the xdata does not change.

Parameters:
  • stem_container (matplotlib.container.StemContainer) – Single container for the artists created in a plt.stem plot. It can be treated like a namedtuple (markerline, stemlines, baseline). In order to get this from an axis ax, try ax.containers[0] (obviously if you have more than one container in that axis, it may not be the first one).

  • ydata (array_like) – The new y-data to show on the plot. Importantly, must be the same length as the existing y-data.

Returns:

stem_container – The StemContainer containing the updated artists.

Return type:

matplotlib.container.StemContainer

plenoptic.tools.external module

tools to deal with data from outside plenoptic

For example, pre-existing synthesized images

plenoptic.tools.external.plot_MAD_results(original_image, noise_levels=None, results_dir=None, ssim_images_dir=None, zoom=3, vrange='indep1', **kwargs)[source]

plot original MAD results, provided by Zhou Wang

Plot the results of original MAD Competition, as provided in .mat files. The figure created shows the results for one reference image and multiple noise levels. The reference image is plotted on the first row, followed by a separate row for each noise level, which will show the initial (noisy) image and the four synthesized images, with their respective losses for the two metrics (MSE and SSIM).

We also return a DataFrame that contains the losses, noise levels, and original image name for each plotted noise level.

This code can probably be adapted to other uses, but requires that all images are the same size and assumes they’re all 64 x 64 pixels.

Parameters:
  • original_image ({samp1, samp2, samp3, samp4, samp5, samp6, samp7, samp8, samp9, samp10}) – which of the sample images to plot

  • noise_levels (list or None, optional) – which noise levels to plot. if None, will plot all. If a list, elements must be 2**i where i is in [1, 10]

  • results_dir (None or str, optional) – path to the results directory containing the results.mat files. If None, we call po.data.fetch_data to download (requires optional dependency pooch).

  • ssim_images_dir (None or str, optional) – path to the directory containing the .tif images used in SSIM paper. If None, we call po.data.fetch_data to download (requires optional dependency pooch).

  • zoom (int, optional) – amount to zoom each image, passed to pyrtools.imshow

  • vrange (str, optional) – in addition to the values accepted by pyrtools.imshow, we also accept ‘row0/1/2/3’, which is the same as ‘auto0/1/2/3’, except that we do it on a per-row basis (all images with same noise level)

  • kwargs – passed to pyrtools.imshow. Note that we call imshow separately on each image and so any argument that relies on imshow having access to all images will probably not work as expected

Returns:

  • fig (pyrtools.tools.display.Figure) – figure containing the images

  • results (dict) – dictionary containing the errors for each noise level. To convert to a well-structured pandas DataFrame, run pd.DataFrame(results).T

plenoptic.tools.optim module

Tools related to optimization such as more objective functions.

plenoptic.tools.optim.l2_norm(synth_rep, ref_rep, **kwargs)[source]

l2-norm of the difference between ref_rep and synth_rep

Parameters:
  • synth_rep (Tensor) – The first tensor to compare, model representation of the synthesized image.

  • ref_rep (Tensor) – The second tensor to compare, model representation of the reference image. must be same size as synth_rep.

  • kwargs – Ignored, only present to absorb extra arguments.

Returns:

The L2-norm of the difference between ref_rep and synth_rep.

Return type:

loss
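A minimal sketch of this loss in torch, assuming the norm is taken over all elements at once; l2_norm_sketch is a hypothetical name, not plenoptic's implementation:

```python
import torch


def l2_norm_sketch(synth_rep: torch.Tensor, ref_rep: torch.Tensor) -> torch.Tensor:
    # vector_norm with dim=None treats every element as one long vector.
    return torch.linalg.vector_norm(ref_rep - synth_rep)


ref = torch.tensor([[[[3.0, 0.0]]]])
synth = torch.tensor([[[[0.0, 4.0]]]])
loss = l2_norm_sketch(synth, ref)  # sqrt(3**2 + 4**2) = 5
```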

plenoptic.tools.optim.mse(synth_rep, ref_rep, **kwargs)[source]

return the MSE between synth_rep and ref_rep

For two tensors, \(x\) and \(y\), with \(n\) values each:

\[MSE = \frac{1}{n}\sum_{i=1}^n (x_i - y_i)^2\]

The two images must have a float dtype.

Parameters:
  • synth_rep (Tensor) – The first tensor to compare, model representation of the synthesized image

  • ref_rep (Tensor) – The second tensor to compare, model representation of the reference image. Must be same size as synth_rep.

  • kwargs – Ignored, only present to absorb extra arguments

Returns:

The mean-squared error between synth_rep and ref_rep

Return type:

loss
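The formula above translates directly to torch. This is a sketch (mse_sketch is a hypothetical name), checked against torch.nn.functional.mse_loss:

```python
import torch
import torch.nn.functional as F


def mse_sketch(synth_rep, ref_rep):
    # Average of the squared elementwise differences over all n values.
    return torch.mean((synth_rep - ref_rep) ** 2)


x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0, 0.0, 0.0])
loss = mse_sketch(x, y)  # (0 + 4 + 9) / 3
```

The float-dtype requirement matters in practice: torch.mean raises an error on integer tensors, so integer images should be converted (e.g. with .float()) first.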

plenoptic.tools.optim.penalize_range(synth_img, allowed_range=(0.0, 1.0), **kwargs)[source]

penalize values outside of allowed_range

Instead of clamping values to exactly fall in a range, this provides a ‘softer’ way of doing it, by imposing a quadratic penalty on any values outside the allowed_range. All values within the allowed_range have a penalty of 0.

Parameters:
  • synth_img (Tensor) – The tensor to penalize. the synthesized image.

  • allowed_range (Tuple[float, float]) – 2-tuple of values giving the (min, max) allowed values

  • kwargs – Ignored, only present to absorb extra arguments

Returns:

Penalty for values outside range

Return type:

penalty
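A minimal sketch of such a quadratic penalty, assuming the squared distances outside the range are summed over all elements (the exact reduction is our assumption; penalize_range_sketch is hypothetical):

```python
import torch


def penalize_range_sketch(synth_img, allowed_range=(0.0, 1.0)):
    low, high = allowed_range
    # Distance below the minimum (zero for in-range values).
    below = torch.clamp(low - synth_img, min=0)
    # Distance above the maximum (zero for in-range values).
    above = torch.clamp(synth_img - high, min=0)
    # Quadratic penalty on both excursions.
    return (below ** 2).sum() + (above ** 2).sum()


img = torch.tensor([-0.5, 0.2, 0.9, 1.5])
penalty = penalize_range_sketch(img)  # 0.5**2 + 0.5**2 = 0.5
```

Unlike clamping, this keeps in-range pixels untouched and gives out-of-range pixels a gradient that pulls them back toward the range, which is why it suits gradient-based synthesis.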

plenoptic.tools.optim.relative_MSE(synth_rep, ref_rep, **kwargs)[source]

Squared l2-norm of the difference between reference representation and synthesized representation relative to the squared l2-norm of the reference representation:

\[\frac{\|x - \hat{x}\|_2^2}{\|x\|_2^2}\]

where \(x\) is ref_rep and \(\hat{x}\) is synth_rep.

Parameters:
  • synth_rep (Tensor) – The first tensor to compare, model representation of the synthesized image.

  • ref_rep (Tensor) – The second tensor to compare, model representation of the reference image. must be same size as synth_rep.

  • kwargs – Ignored, only present to absorb extra arguments

Returns:

Ratio of the squared l2-norm of the difference between ref_rep and synth_rep to the squared l2-norm of ref_rep

Return type:

loss
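A short sketch of the ratio (relative_mse_sketch is a hypothetical name). Because it divides by the reference energy, scaling both representations by the same constant leaves the value unchanged:

```python
import torch


def relative_mse_sketch(synth_rep, ref_rep):
    # Squared l2-norm of the error over the squared l2-norm of the reference.
    err = torch.linalg.vector_norm(ref_rep - synth_rep) ** 2
    return err / torch.linalg.vector_norm(ref_rep) ** 2


ref = torch.tensor([3.0, 4.0])      # ||ref||^2 = 25
synth = torch.tensor([3.0, 3.0])    # ||ref - synth||^2 = 1
loss = relative_mse_sketch(synth, ref)  # 1 / 25 = 0.04
```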

plenoptic.tools.optim.set_seed(seed=None)[source]

Set the seed.

We call both torch.manual_seed() and np.random.seed().

Parameters:

seed (Optional[int]) – The seed to set. If None, do nothing.

Return type:

None
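A sketch of the documented behavior (set_seed_sketch is hypothetical): reseeding both generators makes torch and NumPy draws repeat.

```python
import numpy as np
import torch


def set_seed_sketch(seed=None):
    # Mirror the documented behavior: do nothing when seed is None.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)


set_seed_sketch(0)
first = (torch.rand(3), np.random.rand(3))
set_seed_sketch(0)
second = (torch.rand(3), np.random.rand(3))
```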

plenoptic.tools.signal module

plenoptic.tools.signal.add_noise(img, noise_mse)[source]

Add normally distributed noise to an image

This adds normally-distributed noise to an image so that the resulting noisy version has the specified mean-squared error.

Parameters:
  • img (Tensor) – The image to make noisy.

  • noise_mse (Union[float, List[float]]) – The target MSE value / variance of the noise. More than one value is allowed.

Returns:

The noisy image. If noise_mse contains only one element, this will be the same size as img. Else, each separate value from noise_mse will be along the batch dimension.

Return type:

noisy_img
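A minimal sketch of one way to hit the target MSE exactly, assuming zero-mean Gaussian noise that is rescaled per image so its mean square equals each requested level. add_noise_sketch is a hypothetical name, not plenoptic's implementation, and it assumes img has a batch size of 1:

```python
import torch


def add_noise_sketch(img, noise_mse):
    # One target level per output batch element, shaped for broadcasting.
    levels = torch.as_tensor(noise_mse, dtype=img.dtype).reshape(-1, 1, 1, 1)
    noise = torch.randn(levels.shape[0], *img.shape[1:], dtype=img.dtype)
    noise = noise - noise.mean(dim=(-2, -1), keepdim=True)
    # Rescale so each noise image's mean square equals its target exactly.
    noise = noise * torch.sqrt(levels / noise.pow(2).mean(dim=(-2, -1), keepdim=True))
    # img (batch 1) broadcasts against the stacked noise images.
    return img + noise


torch.manual_seed(0)
img = torch.rand(1, 1, 32, 32)
noisy = add_noise_sketch(img, [0.1, 1.0])  # shape (2, 1, 32, 32)
```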

plenoptic.tools.signal.autocorrelation(x)[source]

Compute the autocorrelation of x.

Parameters:

x (Tensor) – N-dimensional tensor. We assume the last two dimensions are height and width and compute the autocorrelation over these dimensions (independently for every other dimension).

Returns:

Autocorrelation of x

Return type:

ac

Notes

  • By the Einstein-Wiener-Khinchin theorem, the autocorrelation of a wide-sense stationary (WSS) process is the inverse Fourier transform of its energy spectral density (ESD), which is itself the product of FT(x(t)) and FT(x(-t)). In other words, the autocorrelation is the convolution of the signal x with its time-reversed copy x(-t), which corresponds to multiplying the Fourier transform by its complex conjugate (i.e., squaring its magnitude) in the frequency domain. This approach is computationally more efficient than brute force (n log(n) vs n^2).

  • By Cauchy-Schwarz, the autocorrelation attains its maximum at the center location (i.e., no shift); that maximum value is the signal’s variance (assuming that the input signal is mean centered).
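The FFT route described in the notes can be sketched as follows. autocorrelation_sketch is a hypothetical name; we assume circular boundary conditions and normalize by the number of pixels so that the zero-shift value is the variance:

```python
import torch


def autocorrelation_sketch(x):
    # Mean-center each image so the zero-shift value equals the variance.
    x = x - x.mean(dim=(-2, -1), keepdim=True)
    h, w = x.shape[-2:]
    spectrum = torch.fft.fft2(x)
    # |FT(x)|^2 = FT(x) * conj(FT(x)): the energy spectral density.
    energy = spectrum.abs() ** 2
    # Inverse FFT gives the circular autocorrelation; normalize by pixel count.
    ac = torch.fft.ifft2(energy).real / (h * w)
    # Move the zero-shift term from index (0, 0) to the center.
    return torch.fft.fftshift(ac, dim=(-2, -1))


torch.manual_seed(0)
x = torch.randn(2, 3, 16, 16)
ac = autocorrelation_sketch(x)  # zero shift sits at ac[..., 8, 8]
```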

plenoptic.tools.signal.center_crop(x, output_size)[source]

Crop out the center of a signal.

If x has an even number of elements along either of its final two dimensions (height and width), we round up.

Parameters:
  • x (Tensor) – N-dimensional tensor, we assume the last two dimensions are height and width.

  • output_size (int) – The size of the output. Note that we only support a single number, so both dimensions are cropped identically

Returns:

Tensor whose last two dimensions have each been cropped to output_size

Return type:

cropped

plenoptic.tools.signal.expand(x, factor)[source]

Expand a signal by a factor.

We do this in the frequency domain: pasting the Fourier contents of x in the center of a larger empty tensor, and then taking the inverse FFT.

Parameters:
  • x (Tensor) – The signal for expansion.

  • factor (float) – Factor by which to resize image. Must be larger than 1 and factor * x.shape[-2:] must give integer values

Returns:

The expanded signal

Return type:

expanded

See also

shrink

The inverse operation
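A minimal sketch of Fourier-domain expansion, assuming zero padding around the centered spectrum and a rescaling that preserves the mean intensity. expand_sketch is hypothetical and ignores the even-size Nyquist subtleties a careful implementation handles:

```python
import torch


def expand_sketch(x, factor):
    h, w = x.shape[-2:]
    new_h, new_w = int(h * factor), int(w * factor)
    # Centered spectrum: DC sits at index (h // 2, w // 2).
    spec = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    big = torch.zeros(*x.shape[:-2], new_h, new_w, dtype=spec.dtype)
    # Paste so the DC term lands at the center of the larger array.
    top, left = new_h // 2 - h // 2, new_w // 2 - w // 2
    big[..., top:top + h, left:left + w] = spec
    out = torch.fft.ifft2(torch.fft.ifftshift(big, dim=(-2, -1)))
    # ifft2 divides by new_h * new_w instead of h * w, so rescale the amplitude.
    return out.real * (new_h * new_w) / (h * w)


const = torch.full((1, 1, 8, 8), 0.7)
big = expand_sketch(const, 2)  # shape (1, 1, 16, 16), still 0.7 everywhere
```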

plenoptic.tools.signal.interpolate1d(x_new, Y, X)[source]

One-dimensional linear interpolation.

Returns the one-dimensional piecewise linear interpolant to a function with given discrete data points (X, Y), evaluated at x_new.

Note: this function is just a wrapper around np.interp().

Parameters:
  • x_new (Tensor) – The x-coordinates at which to evaluate the interpolated values.

  • Y (Union[Tensor, ndarray]) – The y-coordinates of the data points.

  • X (Union[Tensor, ndarray]) – The x-coordinates of the data points, same length as Y.

Return type:

Interpolated values of shape identical to x_new.
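A sketch of what a thin wrapper around np.interp can look like (interpolate1d_sketch and _to_numpy are hypothetical helpers). Note that np.interp expects X to be increasing and takes its arguments in the order (x, xp, fp):

```python
import numpy as np
import torch


def _to_numpy(a):
    # np.interp needs NumPy arrays; move tensors to the CPU and drop autograd.
    if torch.is_tensor(a):
        return a.detach().cpu().numpy()
    return np.asarray(a)


def interpolate1d_sketch(x_new, Y, X):
    # Evaluate the piecewise-linear interpolant through (X, Y) at x_new.
    out = np.interp(_to_numpy(x_new).ravel(), _to_numpy(X), _to_numpy(Y))
    # Restore the shape of x_new.
    return torch.as_tensor(out).reshape(x_new.shape)


X = torch.tensor([0.0, 1.0, 2.0])
Y = torch.tensor([0.0, 10.0, 20.0])
vals = interpolate1d_sketch(torch.tensor([[0.5, 1.5]]), Y, X)  # [[5., 15.]]
```

Because this sketch round-trips through NumPy, its output carries no autograd history, so gradients do not flow back through it.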

plenoptic.tools.signal.make_disk(img_size, outer_radius=None, inner_radius=None)[source]

Create a circular mask with softened edges for an image.

All values within inner_radius will be 1, and all values from inner_radius to outer_radius will decay smoothly to 0.

Parameters:
  • img_size (Union[int, Tuple[int, int], Size]) – Size of image in pixels.

  • outer_radius (Optional[float]) – Total radius of disk. Values from inner_radius to outer_radius will decay smoothly to zero.

  • inner_radius (Optional[float]) – Radius of inner disk. All elements from the origin to inner_radius will be set to 1.

Returns:

Tensor mask with torch.Size(img_size).

Return type:

mask
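A minimal sketch of a soft disk, assuming a raised-cosine ramp between the two radii and, when the radii are omitted, defaults chosen purely for illustration. Both the ramp shape and the defaults are our assumptions; make_disk_sketch is hypothetical:

```python
import math

import torch


def make_disk_sketch(img_size, outer_radius=None, inner_radius=None):
    h, w = (img_size, img_size) if isinstance(img_size, int) else tuple(img_size)
    if outer_radius is None:
        outer_radius = (min(h, w) - 1) / 2  # assumption: disk touches the nearest edge
    if inner_radius is None:
        inner_radius = outer_radius / 2     # assumption: ramp over the outer half
    # Distance of each pixel from the image center.
    ys = torch.arange(h) - (h - 1) / 2
    xs = torch.arange(w) - (w - 1) / 2
    r = torch.sqrt(ys[:, None] ** 2 + xs[None, :] ** 2)
    # Raised-cosine ramp from 1 at inner_radius down to 0 at outer_radius.
    t = ((r - inner_radius) / (outer_radius - inner_radius)).clamp(0, 1)
    return 0.5 * (1 + torch.cos(math.pi * t))


mask = make_disk_sketch(33, outer_radius=16, inner_radius=8)
```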

plenoptic.tools.signal.maximum(x, dim=None, keepdim=False)[source]

Compute maximum in torch over any dim or combination of axes in tensor.

Parameters:
  • x (Tensor) – Input tensor

  • dim (Optional[List[int]]) – Dimensions over which you would like to compute the maximum

  • keepdim (bool) – Keep original dimensions of tensor when returning result

Returns:

Maximum value of x.

Return type:

max_x

plenoptic.tools.signal.minimum(x, dim=None, keepdim=False)[source]

Compute minimum in torch over any axis or combination of axes in tensor.

Parameters:
  • x (Tensor) – Input tensor.

  • dim (Optional[List[int]]) – Dimensions over which you would like to compute the minimum.

  • keepdim (bool) – Keep original dimensions of tensor when returning result.

Returns:

Minimum value of x.

Return type:

min_x
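Recent torch versions can reduce over several dimensions in one call with torch.amax and torch.amin. A sketch of both functions under that approach (the _sketch names are hypothetical):

```python
import torch


def maximum_sketch(x, dim=None, keepdim=False):
    if dim is None:
        return x.max()
    # amax accepts a tuple of dimensions and reduces over all of them at once.
    return torch.amax(x, dim=tuple(dim), keepdim=keepdim)


def minimum_sketch(x, dim=None, keepdim=False):
    if dim is None:
        return x.min()
    return torch.amin(x, dim=tuple(dim), keepdim=keepdim)


x = torch.arange(24.0).reshape(2, 3, 4)
per_batch_max = maximum_sketch(x, [1, 2])  # tensor([11., 23.])
per_batch_min = minimum_sketch(x, [1, 2])  # tensor([ 0., 12.])
```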

plenoptic.tools.signal.modulate_phase(x, phase_factor=2.0)[source]

Modulate the phase of a complex signal.

Doubling the phase of a complex signal allows you to, for example, take the correlation between steerable pyramid coefficients at two adjacent spatial scales.

Parameters:
  • x (Tensor) – Complex tensor whose phase will be modulated.

  • phase_factor (float) – Multiplicative factor to change phase by.

Returns:

Phase-modulated complex tensor.

Return type:

x_mod
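A minimal sketch: keep the amplitude and scale the phase angle (modulate_phase_sketch is hypothetical). Doubling is the usual choice because a coefficient at a scale tuned to twice the frequency advances its phase twice as fast:

```python
import torch


def modulate_phase_sketch(x, phase_factor=2.0):
    # Split into amplitude and phase, scale the phase, and recombine.
    amplitude, phase = x.abs(), x.angle()
    return torch.polar(amplitude, phase * phase_factor)


x = torch.polar(torch.tensor([2.0]), torch.tensor([0.3]))
x_mod = modulate_phase_sketch(x)  # amplitude 2, phase 0.6
```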

plenoptic.tools.signal.polar_to_rectangular(amplitude, phase)[source]

Polar to rectangular coordinate transform

Parameters:
  • amplitude (Tensor) – Tensor containing the amplitude (aka. complex modulus). Must be > 0.

  • phase (Tensor) – Tensor containing the phase

Return type:

Complex tensor.

plenoptic.tools.signal.raised_cosine(width=1, position=0, values=(0, 1))[source]

Return a lookup table containing a “raised cosine” soft threshold function.

Y = VALUES(1) + (VALUES(2)-VALUES(1)) * cos^2( PI/2 * (X - POSITION + WIDTH)/WIDTH )

This lookup table is suitable for use by interpolate1d.

Parameters:
  • width (float) – The width of the region over which the transition occurs.

  • position (float) – The location of the center of the threshold.

  • values (Tuple[float, float]) – 2-tuple specifying the values to the left and right of the transition.

Return type:

Tuple[ndarray, ndarray]

Returns:

  • X – The x values of this raised cosine.

  • Y – The y values of this raised cosine.
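The lookup-table idea can be sketched with NumPy, assuming the transition is centered on position and spans width, as the parameter descriptions state, and that the table is sampled at a fixed number of points. raised_cosine_sketch and its size argument are hypothetical, and the exact sampling plenoptic uses may differ. np.interp, which interpolate1d wraps, then evaluates the table and holds the end values constant outside the transition:

```python
import numpy as np


def raised_cosine_sketch(width=1.0, position=0.0, values=(0.0, 1.0), size=256):
    # Angles sweep a quarter period of cos^2: from -pi/2 (value 0) to 0 (value 1).
    theta = np.linspace(-np.pi / 2, 0, size)
    Y = values[0] + (values[1] - values[0]) * np.cos(theta) ** 2
    # Map the angles onto [position - width/2, position + width/2].
    X = position + width * (theta / (np.pi / 2) + 0.5)
    return X, Y


X, Y = raised_cosine_sketch(width=2.0, position=1.0, values=(0.0, 1.0))
samples = np.interp([-5.0, 1.0, 5.0], X, Y)  # about [0, 0.5, 1]
```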

plenoptic.tools.signal.rectangular_to_polar(x)[source]

Rectangular to polar coordinate transform

Parameters:

x (Tensor) – Complex tensor.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • amplitude – Tensor containing the amplitude (aka. complex modulus).

  • phase – Tensor containing the phase.
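Both transforms map onto torch built-ins; a round-trip sketch (the _sketch names are hypothetical):

```python
import torch


def rectangular_to_polar_sketch(x):
    # Amplitude is the complex modulus; phase is the argument in (-pi, pi].
    return x.abs(), x.angle()


def polar_to_rectangular_sketch(amplitude, phase):
    # amplitude * exp(i * phase)
    return torch.polar(amplitude, phase)


z = torch.tensor([1 + 1j, -2 + 0.5j])
amp, ph = rectangular_to_polar_sketch(z)
z_back = polar_to_rectangular_sketch(amp, ph)  # recovers z up to float error
```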

plenoptic.tools.signal.rescale(x, a=0.0, b=1.0)[source]

Linearly rescale the dynamic range of the input x to [a,b].

Return type:

Tensor
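A sketch of the affine map that sends min(x) to a and max(x) to b (rescale_sketch is hypothetical). A constant input would divide by zero, so real code should guard against it:

```python
import torch


def rescale_sketch(x, a=0.0, b=1.0):
    # Shift so the minimum is 0, scale to unit range, then map onto [a, b].
    lo, hi = x.min(), x.max()
    return a + (x - lo) * (b - a) / (hi - lo)


x = torch.tensor([2.0, 4.0, 6.0])
unit = rescale_sketch(x)             # tensor([0.0, 0.5, 1.0])
signed = rescale_sketch(x, -1.0, 1.0)  # tensor([-1., 0., 1.])
```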

plenoptic.tools.signal.shrink(x, factor)[source]

Shrink a signal by a factor.

We do this in the frequency domain: cropping out the center of the Fourier transform of x, putting it in a new tensor, and taking the IFFT.

Parameters:
  • x (Tensor) – The signal to shrink.

  • factor (int) – Factor by which to resize image. Must be larger than 1 and x.shape[-2:] / factor must give integer values

Returns:

The shrunk signal

Return type:

expanded

See also

expand

The inverse operation
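A minimal sketch of Fourier-domain shrinking, assuming we keep the central low-frequency block of the centered spectrum and rescale so the mean intensity is preserved. shrink_sketch is hypothetical:

```python
import torch


def shrink_sketch(x, factor):
    h, w = x.shape[-2:]
    new_h, new_w = h // factor, w // factor
    spec = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    # Keep only the central (low-frequency) block, aligned so DC stays centered.
    top, left = h // 2 - new_h // 2, w // 2 - new_w // 2
    small = spec[..., top:top + new_h, left:left + new_w]
    out = torch.fft.ifft2(torch.fft.ifftshift(small, dim=(-2, -1)))
    # ifft2 now divides by new_h * new_w, so rescale to preserve amplitude.
    return out.real * (new_h * new_w) / (h * w)


const = torch.full((1, 1, 16, 16), 0.3)
small = shrink_sketch(const, 2)  # shape (1, 1, 8, 8), still 0.3 everywhere
```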

plenoptic.tools.signal.steer(basis, angle, harmonics=None, steermtx=None, return_weights=False, even_phase=True)[source]

Steer BASIS to the specified ANGLE.

Parameters:
  • basis (Tensor) – Array whose columns are vectorized rotated copies of a steerable function, or the responses of a set of steerable filters.

  • angle (Union[ndarray, Tensor, float]) – Scalar or column vector the size of the basis. Specifies the angle(s) (in radians) to steer to

  • harmonics (Optional[List[int]]) – A list of harmonic numbers indicating the angular harmonic content of the basis. if None (default), N even or odd low frequencies, as for derivative filters

  • steermtx (Union[Tensor, ndarray, None]) – Matrix which maps the filters onto Fourier series components (ordered [cos0 cos1 sin1 cos2 sin2 … sinN]). See steer_to_harmonics_mtx function for more details. If None (default), assumes cosine phase harmonic components, and filter positions at 2pi*n/N.

  • return_weights (bool) – Whether to return the weights or not.

  • even_phase (bool) – Specifies whether the harmonics are cosine or sine phase aligned about those positions.

Returns:

  • res – The resteered basis.

  • steervect – The weights used to resteer the basis. only returned if return_weights is True.
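The steering principle is easiest to see for a first-harmonic basis: two Gaussian-derivative filters at 0 and pi/2, vectorized as columns, combine with weights cos(angle) and sin(angle). This illustrates the idea only and does not reproduce the general harmonics/steermtx handling of steer; gaussian_derivative_basis and steer_first_order are hypothetical:

```python
import math

import torch


def gaussian_derivative_basis(size=9, sigma=1.5):
    # Hypothetical helper: first derivatives of a Gaussian along x (0 rad)
    # and y (pi/2 rad), vectorized as the two columns of the basis.
    coords = torch.arange(size, dtype=torch.float64) - (size - 1) / 2
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    g = torch.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2))
    return torch.stack([(-xx * g).flatten(), (-yy * g).flatten()], dim=1)


def steer_first_order(basis, angle):
    # For a first-harmonic basis, the steering weights are cos and sin of the angle.
    weights = torch.tensor([math.cos(angle), math.sin(angle)], dtype=basis.dtype)
    return basis @ weights


basis = gaussian_derivative_basis()
steered = steer_first_order(basis, 0.7)  # derivative along direction 0.7 rad
```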

plenoptic.tools.stats module

plenoptic.tools.stats.kurtosis(x, mean=None, var=None, dim=None, keepdim=False)[source]

Sample estimate of the tailedness of x (presence of outliers).

The kurtosis of a univariate normal distribution is 3.

Smaller than 3: platykurtic (e.g., uniform distribution).

Greater than 3: leptokurtic (e.g., Laplace distribution).

Parameters:
  • x (Tensor) – The input tensor.

  • mean (Union[float, Tensor, None]) – Reuse a precomputed mean.

  • var (Union[float, Tensor, None]) – Reuse a precomputed variance.

  • dim (Union[int, List[int], None]) – The dimension or dimensions to reduce.

  • keepdim (bool) – Whether the output tensor has dim retained or not.

Returns:

The kurtosis tensor.

Return type:

out
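A sketch of the estimator as the fourth central moment divided by the squared uncorrected variance (kurtosis_sketch is hypothetical). On large samples it lands near the reference values quoted above:

```python
import torch


def kurtosis_sketch(x, dim=None, keepdim=False):
    # Reduce over every dimension when dim is None.
    dims = tuple(range(x.ndim)) if dim is None else dim
    centered = x - x.mean(dim=dims, keepdim=True)
    # Uncorrected variance (divide by n), then the normalized fourth moment.
    var = centered.pow(2).mean(dim=dims, keepdim=keepdim)
    return centered.pow(4).mean(dim=dims, keepdim=keepdim) / var.pow(2)


torch.manual_seed(0)
k_normal = kurtosis_sketch(torch.randn(200_000))   # approximately 3
k_uniform = kurtosis_sketch(torch.rand(200_000))   # approximately 1.8
```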

plenoptic.tools.stats.skew(x, mean=None, var=None, dim=None, keepdim=False)[source]

Sample estimate of x asymmetry about its mean

Parameters:
  • x (Tensor) – The input tensor

  • mean (Union[float, Tensor, None]) – Reuse a precomputed mean

  • var (Union[float, Tensor, None]) – Reuse a precomputed variance

  • dim (Union[int, List[int], None]) – The dimension or dimensions to reduce.

  • keepdim (bool) – Whether the output tensor has dim retained or not.

Returns:

The skewness tensor.

Return type:

out
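A sketch of the skewness estimator as the third central moment over the uncorrected variance raised to 1.5 (skew_sketch is hypothetical). For values [0, 0, 0, 3], a scaled Bernoulli with p = 0.25, the result is (1 - 2p) / sqrt(p(1 - p)) = 2/sqrt(3):

```python
import torch


def skew_sketch(x, dim=None, keepdim=False):
    # Reduce over every dimension when dim is None.
    dims = tuple(range(x.ndim)) if dim is None else dim
    centered = x - x.mean(dim=dims, keepdim=True)
    var = centered.pow(2).mean(dim=dims, keepdim=keepdim)
    # Third central moment normalized by the standard deviation cubed.
    return centered.pow(3).mean(dim=dims, keepdim=keepdim) / var.pow(1.5)


s = skew_sketch(torch.tensor([0.0, 0.0, 0.0, 3.0]))  # 2 / sqrt(3), about 1.1547
```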

plenoptic.tools.stats.variance(x, mean=None, dim=None, keepdim=False)[source]

Calculate sample variance.

Note that this is the uncorrected, or sample, variance, corresponding to torch.var(*, correction=0)

Parameters:
  • x (Tensor) – The input tensor

  • mean (Union[float, Tensor, None]) – Reuse a precomputed mean

  • dim (Union[int, List[int], None]) – The dimension or dimensions to reduce.

  • keepdim (bool) – Whether the output tensor has dim retained or not.

Returns:

The variance tensor.

Return type:

out
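The difference between the uncorrected and the default (Bessel-corrected) variance in torch, shown on four values whose squared deviations sum to 5:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
# Uncorrected variance: mean squared deviation, dividing by n.
manual = ((x - x.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)  # 5 / 4
builtin = torch.var(x, dim=-1, correction=0)                      # 5 / 4
# torch's default correction=1 divides by n - 1 instead.
corrected = torch.var(x, dim=-1)                                  # 5 / 3
```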

plenoptic.tools.straightness module

plenoptic.tools.straightness.deviation_from_line(sequence, normalize=True)[source]

Compute the deviation of sequence from the straight line between its endpoints.

Project each point of the path sequence onto the line defined by the anchor points, and measure the two sides of a right triangle:

  • from the projected point to the first anchor point (aka. distance along line)

  • from the projected point to the corresponding point on the path sequence (aka. distance from line).

Parameters:
  • sequence (Tensor) – sequence of signals of shape (T, channel, height, width)

  • normalize (bool) – use the distance between the anchor points as a unit of measurement

Return type:

Tuple[Tensor, Tensor]

Returns:

  • dist_along_line – sequence of T Euclidean distances along the line

  • dist_from_line – sequence of T Euclidean distances to the line
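The projection described above can be sketched by flattening each frame to a vector (deviation_from_line_sketch is hypothetical). For points that already lie on the line, the distance from the line is zero and the normalized distance along it is the interpolation fraction:

```python
import torch


def deviation_from_line_sketch(sequence, normalize=True):
    # Flatten each frame to a vector: shape (T, D).
    seq = sequence.reshape(sequence.shape[0], -1)
    start, stop = seq[0], seq[-1]
    length = torch.linalg.vector_norm(stop - start)
    unit = (stop - start) / length
    offsets = seq - start
    # Distance along the line: scalar projection onto the unit direction.
    along = offsets @ unit
    # Distance from the line: norm of the component orthogonal to the line.
    from_line = torch.linalg.vector_norm(offsets - along[:, None] * unit, dim=1)
    if normalize:
        along, from_line = along / length, from_line / length
    return along, from_line


torch.manual_seed(0)
a, b = torch.rand(1, 1, 4, 4), torch.rand(1, 1, 4, 4)
t = torch.linspace(0, 1, 5).view(-1, 1, 1, 1)
along, from_line = deviation_from_line_sketch(a + t * (b - a))
```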

plenoptic.tools.straightness.make_straight_line(start, stop, n_steps)[source]

make a straight line between start and stop with n_steps transitions.

Parameters:
  • start (Tensor) – Images of shape (1, channel, height, width), the anchor points between which a line will be made.

  • stop (Tensor) – Images of shape (1, channel, height, width), the anchor points between which a line will be made.

  • n_steps (int) – Number of steps (i.e., transitions) to create between the two anchor points. Must be positive.

Returns:

Tensor of shape (n_steps+1, channel, height, width)

Return type:

straight
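A sketch using linear interpolation with broadcasting (make_straight_line_sketch is hypothetical); n_steps transitions give n_steps + 1 frames, endpoints included:

```python
import torch


def make_straight_line_sketch(start, stop, n_steps):
    # Interpolation weights 0, 1/n_steps, ..., 1, one per output frame.
    t = torch.linspace(0, 1, n_steps + 1, dtype=start.dtype).view(-1, 1, 1, 1)
    # (n_steps + 1, 1, 1, 1) broadcasts against (1, channel, height, width).
    return start + t * (stop - start)


start, stop = torch.zeros(1, 1, 4, 4), torch.ones(1, 1, 4, 4)
line = make_straight_line_sketch(start, stop, 4)  # shape (5, 1, 4, 4)
```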

plenoptic.tools.straightness.sample_brownian_bridge(start, stop, n_steps, max_norm=1)[source]

Sample a Brownian bridge between start and stop made up of n_steps

Parameters:
  • start (Tensor) – signal of shape (1, channel, height, width), the anchor points between which a random path will be sampled (like pylons on which the bridge will rest)

  • stop (Tensor) – signal of shape (1, channel, height, width), the anchor points between which a random path will be sampled (like pylons on which the bridge will rest)

  • n_steps (int) – number of steps on the bridge

  • max_norm (float) – controls variability of the bridge by setting how far (in l2 norm) it veers from the straight line interpolation at the midpoint between pylons. Each component of the bridge will reach a maximal variability with std = max_norm / sqrt(d), where d is the dimension of the signal (i.e., d = C*H*W). Must be non-negative.

Returns:

sequence of shape (n_steps+1, channel, height, width): a Brownian bridge across the two pylons

Return type:

bridge
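A minimal sketch of the standard construction, assuming a Gaussian random walk pinned at both ends via B(t) = W(t) - t W(1), added to the straight line. A bridge with scale sigma has variance sigma^2 t(1 - t), which peaks at t = 1/2 as sigma^2 / 4, so choosing sigma = 2 max_norm / sqrt(d) gives the per-component midpoint std stated above. sample_brownian_bridge_sketch is hypothetical:

```python
import torch


def sample_brownian_bridge_sketch(start, stop, n_steps, max_norm=1.0):
    d = start.numel()  # signal dimension C * H * W
    # Midpoint std is sigma / 2, so this makes it max_norm / sqrt(d).
    sigma = 2 * max_norm / d ** 0.5
    t = torch.linspace(0, 1, n_steps + 1, dtype=start.dtype).view(-1, 1, 1, 1)
    # Random walk with independent Gaussian increments of variance sigma^2 / n_steps.
    steps = torch.randn(n_steps, *start.shape[1:], dtype=start.dtype) * sigma / n_steps ** 0.5
    walk = torch.cat([torch.zeros_like(start), steps.cumsum(dim=0)], dim=0)
    # Pin the far end so the deviation is zero at both t = 0 and t = 1.
    bridge = walk - t * walk[-1:]
    return start + t * (stop - start) + bridge


torch.manual_seed(0)
start, stop = torch.zeros(1, 1, 8, 8), torch.ones(1, 1, 8, 8)
path = sample_brownian_bridge_sketch(start, stop, 10)  # shape (11, 1, 8, 8)
```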

plenoptic.tools.straightness.translation_sequence(image, n_steps=10)[source]

make a horizontal translation sequence on image

Parameters:
  • image (Tensor) – Base image of shape (1, channel, height, width)

  • n_steps (int) – Number of steps in the sequence. The length of the sequence is n_steps + 1. Must be positive.

Returns:

translation sequence of shape (n_steps+1, channel, height, width)

Return type:

sequence
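A minimal sketch, assuming each step is a one-pixel circular shift along the width dimension. The step size and boundary handling are our assumptions, not documented above; translation_sequence_sketch is hypothetical:

```python
import torch


def translation_sequence_sketch(image, n_steps=10):
    # Frame k is the image circularly shifted k pixels along the width axis.
    frames = [torch.roll(image, shifts=k, dims=-1) for k in range(n_steps + 1)]
    return torch.cat(frames, dim=0)


img = torch.arange(16.0).reshape(1, 1, 4, 4)
seq = translation_sequence_sketch(img, n_steps=3)  # shape (4, 1, 4, 4)
```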

plenoptic.tools.validate module

Functions to validate synthesis inputs.

plenoptic.tools.validate.remove_grad(model)[source]

Detach all parameters and buffers of model (in place).
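A sketch of in-place detaching with Tensor.detach_(), which removes each tensor from the autograd graph and sets requires_grad to False (remove_grad_sketch is hypothetical):

```python
import torch


def remove_grad_sketch(model):
    # detach_() works in place, so the model keeps the same tensor objects.
    for p in model.parameters():
        p.detach_()
    for b in model.buffers():
        b.detach_()


model = torch.nn.Sequential(torch.nn.Conv2d(1, 2, 3), torch.nn.BatchNorm2d(2))
remove_grad_sketch(model)  # no parameter requires a gradient afterwards
```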

plenoptic.tools.validate.validate_coarse_to_fine(model, image_shape=None, device='cpu')[source]

Determine whether a model can be used for coarse-to-fine synthesis.

In particular, this function checks the following (with associated errors):

  • Whether model has a scales attribute (AttributeError).

  • Whether model.forward accepts a scales keyword argument (TypeError).

  • Whether the output of model.forward changes shape when the scales keyword argument is set (ValueError).

Parameters:
  • model (Module) – The model to validate.

  • image_shape (Optional[Tuple[int, int, int, int]]) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for model, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16)

  • device (Union[str, device]) – Which device to place the test image on.

plenoptic.tools.validate.validate_input(input_tensor, no_batch=False, allowed_range=None)[source]

Determine whether input_tensor can be used for synthesis.

In particular, this function:

  • Checks if input_tensor has a float or complex dtype

  • Checks if input_tensor is 4d.

  • If no_batch is True, check whether input_tensor.shape[0] != 1

  • If allowed_range is not None, check whether all values of input_tensor lie within the specified range.

If any of the above fail, a ValueError is raised.

Parameters:
  • input_tensor (Tensor) – The tensor to validate.

  • no_batch (bool) – If True, raise a ValueError if the batch dimension of input_tensor is greater than 1.

  • allowed_range (Optional[Tuple[float, float]]) – If not None, ensure that all values of input_tensor lie within allowed_range.
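The checks listed above can be sketched as a sequence of guards (validate_input_sketch is hypothetical and simplified; for instance, it does not handle allowed_range for complex inputs):

```python
import torch


def validate_input_sketch(input_tensor, no_batch=False, allowed_range=None):
    if not (input_tensor.is_floating_point() or input_tensor.is_complex()):
        raise ValueError(f"Only float or complex dtypes are allowed, got {input_tensor.dtype}")
    if input_tensor.ndim != 4:
        raise ValueError(f"input_tensor must be 4d, got shape {tuple(input_tensor.shape)}")
    if no_batch and input_tensor.shape[0] != 1:
        raise ValueError("input_tensor batch dimension must be 1")
    if allowed_range is not None:
        low, high = allowed_range
        if input_tensor.min() < low or input_tensor.max() > high:
            raise ValueError(f"input_tensor values must lie within {allowed_range}")


# Passes silently: float, 4d, batch of 1, values in [0, 1).
validate_input_sketch(torch.rand(1, 1, 8, 8), no_batch=True, allowed_range=(0, 1))
```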

plenoptic.tools.validate.validate_metric(metric, image_shape=None, image_dtype=torch.float32, device='cpu')[source]

Determines whether a metric can be used for MADCompetition synthesis.

In particular, this function checks the following (with associated exceptions):

  • Whether metric is callable and accepts two 4d tensors as input (TypeError).

  • Whether metric returns a scalar when called with two 4d tensors as input (ValueError).

  • Whether metric returns a value less than 5e-7 when called with two identical 4d tensors as input (ValueError). (This threshold was chosen because 1-SSIM of two identical images is 5e-8 on GPU.)

Parameters:
  • metric (Union[Module, Callable[[Tensor, Tensor], Tensor]]) – The metric to validate.

  • image_shape (Optional[Tuple[int, int, int, int]]) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for metric, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16)

  • image_dtype (dtype) – What dtype to validate against.

  • device (Union[str, device]) – What device to place the test images on.
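A sketch of a metric that satisfies these requirements, plus a simplified checker that mirrors the three listed checks. Both mse_metric and check_metric_sketch are hypothetical illustrations, not plenoptic's validate_metric:

```python
import torch


def mse_metric(img1, img2):
    # Returns a single scalar and is exactly 0 for identical inputs.
    return torch.mean((img1 - img2) ** 2)


def check_metric_sketch(metric, image_shape=(1, 1, 16, 16)):
    if not callable(metric):
        raise TypeError("metric must be callable")
    img = torch.rand(image_shape)
    # A metric that cannot take two 4d tensors will raise here.
    value = metric(img, img.clone())
    if not torch.is_tensor(value) or value.numel() != 1:
        raise ValueError("metric must return a scalar")
    if value.item() >= 5e-7:
        raise ValueError("metric of two identical images must be below 5e-7")


check_metric_sketch(mse_metric)  # passes silently
```

A similarity index such as SSIM would need to be wrapped as 1 - SSIM to pass the last check.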

plenoptic.tools.validate.validate_model(model, image_shape=None, image_dtype=torch.float32, device='cpu')[source]

Determine whether model can be used for synthesis.

In particular, this function checks the following (with their associated errors raised):

  • If model adds a gradient to an input tensor, which implies that some of it is learnable (ValueError).

  • If model returns a tensor when given a tensor, failure implies that not all computations are done using torch (ValueError).

  • If model strips gradient from an input with gradient attached (ValueError).

  • If model casts an input tensor to something else and returns it to a tensor before returning it (ValueError).

  • If model changes the precision of the input tensor (TypeError).

  • If model does not return a 3d or 4d output when given a 4d input (ValueError).

  • If model changes the device of the input (RuntimeError).

Finally, we check if model is in training mode and raise a warning if so. Note that this is different from having learnable parameters; see the pytorch docs (https://pytorch.org/docs/stable/notes/autograd.html#locally-disable-grad-doc).

Parameters:
  • model (Module) – The model to validate.

  • image_shape (Optional[Tuple[int, int, int, int]]) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for model, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16)

  • image_dtype (dtype) – What dtype to validate against.

  • device (Union[str, device]) – What device to place the test image on.

See also

remove_grad

Helper function for detaching all parameters (in place).

Module contents