plenoptic.tools package
Submodules
plenoptic.tools.conv module
- plenoptic.tools.conv.blur_downsample(x, n_scales=1, filtname='binom5', scale_filter=True)
Correlate with a binomial coefficient filter and downsample by 2
- Parameters:
x (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.
n_scales (int, optional) – Apply the blur-and-downsample procedure recursively n_scales times. Must be non-negative. Defaults to 1.
filtname (str, optional) – Name of the filter. See pyrtools.named_filter for options. Defaults to “binom5”.
scale_filter (bool, optional) – If True (default), the filter sums to 1 (i.e., it does not affect the DC component of the signal). If False, the filter sums to 2.
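The sketch below checks the filter sums stated for scale_filter using plain Python. It builds a 2-D blur filter as the outer product of a 1-D binomial kernel. It assumes that 'binom5' is the kernel [1, 4, 6, 4, 1]/16 scaled by √2, as in pyrtools' named_filter, and that halving the 2-D filter is what scale_filter does. binomial_kernel and blur_filter_2d are hypothetical helpers, not plenoptic functions.

```python
import math


def binomial_kernel(n_taps):
    """1-D binomial kernel (row n_taps - 1 of Pascal's triangle), summing to 1."""
    row = [math.comb(n_taps - 1, k) for k in range(n_taps)]
    total = sum(row)
    return [c / total for c in row]


def blur_filter_2d(n_taps=5, scale_filter=True):
    """Hypothetical 2-D blur filter: outer product of a sqrt(2)-scaled
    binomial kernel (assumed to match 'binom5').

    Each 1-D factor sums to sqrt(2), so the outer product sums to 2.
    Halving it (scale_filter=True) makes it sum to 1, so the DC
    component of the signal is unchanged.
    """
    f = [math.sqrt(2) * c for c in binomial_kernel(n_taps)]
    filt = [[a * b for b in f] for a in f]
    if scale_filter:
        filt = [[v / 2 for v in row] for row in filt]
    return filt


def filter_sum(filt):
    return sum(sum(row) for row in filt)


print(binomial_kernel(5))  # [0.0625, 0.25, 0.375, 0.25, 0.0625]
print(round(filter_sum(blur_filter_2d(scale_filter=True)), 12))   # 1.0
print(round(filter_sum(blur_filter_2d(scale_filter=False)), 12))  # 2.0
```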
- plenoptic.tools.conv.correlate_downsample(image, filt, padding_mode='reflect')
Correlate with a filter and downsample by 2
- Parameters:
image (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.
filt (2-D torch.Tensor) – The filter to correlate with the input image
padding_mode (string, optional) – One of “constant”, “reflect”, “replicate”, “circular”. The option “constant” means padding with zeros.
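This pure-Python sketch makes the correlate-then-downsample step concrete for a single-channel image stored as a list of rows. It makes three assumptions:

- Padding is 'same'-style reflect padding, with (kernel size - 1) // 2 samples placed before each position.
- Correlation runs with stride 2, keeping rows and columns 0, 2, 4 and so on.
- The image uses a nested-list layout.

plenoptic's function works on 4-D torch tensors. The helper names and the exact padding split are illustrative only.

```python
def reflect_index(i, n):
    """Map index i into [0, n) by mirror reflection that does not repeat
    the edge sample (the 'reflect' padding convention)."""
    period = 2 * (n - 1)
    if period == 0:
        return 0
    i = abs(i) % period
    return period - i if i >= n else i


def correlate_downsample_2d(image, filt):
    """Correlate a 2-D image (list of rows) with filt using reflect padding,
    keeping only every other row and column (stride 2)."""
    h, w = len(image), len(image[0])
    kh, kw = len(filt), len(filt[0])
    top, left = (kh - 1) // 2, (kw - 1) // 2  # padding before each sample
    out = []
    for r in range(0, h, 2):
        row = []
        for c in range(0, w, 2):
            acc = 0.0
            for i in range(kh):
                for j in range(kw):
                    rr = reflect_index(r + i - top, h)
                    cc = reflect_index(c + j - left, w)
                    acc += filt[i][j] * image[rr][cc]
            row.append(acc)
        out.append(row)
    return out
```

With a filter that sums to 1, a constant image stays constant while each spatial dimension shrinks to ceil(n / 2).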
- plenoptic.tools.conv.same_padding(x, kernel_size, stride=(1, 1), dilation=(1, 1), pad_mode='circular')
Pad a tensor so that a 2D convolution produces output with the same height and width as the input.
- Return type:
Tensor
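The sketch below shows how much padding 'same' output needs along one dimension. It uses the common "SAME" rule: pad until a convolution with the given stride yields ceil(size / stride) outputs. When the total is odd, the extra sample goes at the end. That split is an assumption and not necessarily plenoptic's choice. same_padding_amounts is a hypothetical helper.

```python
import math


def same_padding_amounts(size, kernel_size, stride=1, dilation=1):
    """Return (before, after) padding for one spatial dimension so that a
    convolution yields ceil(size / stride) outputs."""
    effective_k = dilation * (kernel_size - 1) + 1  # span of a dilated kernel
    n_out = math.ceil(size / stride)
    total = max((n_out - 1) * stride + effective_k - size, 0)
    before = total // 2
    return before, total - before


print(same_padding_amounts(10, 5))  # (2, 2)
print(same_padding_amounts(10, 4))  # (1, 2)
```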
- plenoptic.tools.conv.upsample_blur(x, odd, filtname='binom5', scale_filter=True)
Upsample by 2 and convolve with a binomial coefficient filter
- Parameters:
x (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.
odd (tuple, list or numpy.ndarray) – This should contain two integers of value 0 or 1, which determine whether the output height and width should be even (0) or odd (1).
filtname (str, optional) – Name of the filter. See pyrtools.named_filter for options. Defaults to “binom5”.
scale_filter (bool, optional) – If True (default), the filter sums to 4 (i.e., it multiplies the signal by 4 before the blurring operation). If False, the filter sums to 2.
- plenoptic.tools.conv.upsample_convolve(image, odd, filt, padding_mode='reflect')
Upsample by 2 and convolve with a filter
- Parameters:
image (torch.Tensor of shape (batch, channel, height, width)) – Image, or batch of images. Channels are treated in the same way as batches.
odd (tuple, list or numpy.ndarray) – This should contain two integers of value 0 or 1, which determine whether the output height and width should be even (0) or odd (1).
filt (2-D torch.Tensor) – The filter to convolve with the upsampled image
padding_mode (string, optional) – One of “constant”, “reflect”, “replicate”, “circular”. The option “constant” means padding with zeros.
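The 1-D sketch below shows why the upsampling filter carries extra gain. Upsampling inserts a zero between samples, so the output length is 2 * n - odd. The interpolating filter must then sum to 2 per dimension to preserve the signal level, which gives 4 for a separable 2-D filter.

The sketch substitutes the linear-interpolation kernel [0.5, 1, 0.5] for 'binom5' and uses zero padding at the edges. It also takes a single odd flag instead of a pair. upsample_convolve_1d is a hypothetical helper.

```python
def upsample_convolve_1d(x, odd, filt):
    """Zero-insert x to length 2 * len(x) - odd, then convolve with filt
    (zero padding at the edges)."""
    n_out = 2 * len(x) - odd
    up = [0.0] * n_out
    for i, v in enumerate(x):
        up[2 * i] = v  # original samples land on even indices
    half = len(filt) // 2
    out = []
    for n in range(n_out):
        acc = 0.0
        for k, f in enumerate(filt):
            idx = n + half - k  # convolution flips the filter
            if 0 <= idx < n_out:
                acc += f * up[idx]
        out.append(acc)
    return out


# filter sums to 2, so the zeros are filled in by interpolation
print(upsample_convolve_1d([0.0, 2.0, 4.0], 1, [0.5, 1.0, 0.5]))  # [0.0, 1.0, 2.0, 3.0, 4.0]
```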
plenoptic.tools.convergence module
Functions that check for optimization convergence/stabilization.
The functions herein generally differ in what they are checking for convergence: loss, pixel change, etc.
They should accept the following arguments, in this order (they can accept more):
- synth: an OptimizedSynthesis object to check.
- stop_criterion: the value used as criterion / tolerance that our convergence target is compared against.
- stop_iters_to_check: how many iterations back to check for convergence.
They must return a single bool: True if we’ve reached convergence, False if not.
- plenoptic.tools.convergence.coarse_to_fine_enough(synth, i, ctf_iters_to_check)
Check whether we’ve synthesized all scales and done so for at least ctf_iters_to_check iterations.
This is meant to be paired with another convergence check, such as loss_convergence.
- Parameters:
synth (Metamer) – The Metamer object to check.
i (int) – The current iteration (0-indexed).
ctf_iters_to_check (int) – Minimum number of iterations coarse-to-fine must run at each scale. If self.coarse_to_fine is False, then this is ignored.
- Returns:
ctf_enough – Whether we’ve been doing coarse-to-fine synthesis for long enough.
- Return type:
bool
- plenoptic.tools.convergence.loss_convergence(synth, stop_criterion, stop_iters_to_check)
Check whether the loss has stabilized and, if so, return True.
The decision proceeds as follows:
- If we have not yet been synthesizing for stop_iters_to_check iterations, return False.
- Otherwise, if abs(synth.losses[-1] - synth.losses[-stop_iters_to_check]) < stop_criterion, return True.
- Otherwise, return False.
- Parameters:
synth (OptimizedSynthesis) – The OptimizedSynthesis object to check.
stop_criterion (float) – If the loss over the past stop_iters_to_check iterations has changed by less than stop_criterion, we terminate synthesis.
stop_iters_to_check (int) – How many iterations back to check in order to see if the loss has stopped decreasing (for stop_criterion).
- Returns:
loss_stabilized – Whether the loss has stabilized or not.
- Return type:
bool
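The loss check can be sketched in a few lines of plain Python. This is an illustration, not plenoptic's code. It assumes synth exposes a losses sequence, and it assumes "enough iterations" means strictly more than stop_iters_to_check recorded losses (the exact threshold is an assumption). A SimpleNamespace stands in for a real OptimizedSynthesis object. loss_has_stabilized is a hypothetical name.

```python
from types import SimpleNamespace


def loss_has_stabilized(synth, stop_criterion, stop_iters_to_check):
    """Sketch of the loss_convergence decision logic."""
    losses = synth.losses
    if len(losses) > stop_iters_to_check:
        # compare the latest loss with the one stop_iters_to_check back
        change = abs(losses[-1] - losses[-stop_iters_to_check])
        return change < stop_criterion
    return False


synth = SimpleNamespace(losses=[1.0, 0.5, 0.4999, 0.4998])
print(loss_has_stabilized(synth, stop_criterion=1e-3, stop_iters_to_check=3))  # True
print(loss_has_stabilized(synth, stop_criterion=1e-5, stop_iters_to_check=3))  # False
```

Only the two endpoints of the window are compared, so a loss that oscillates inside the window can still count as stabilized.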
- plenoptic.tools.convergence.pixel_change_convergence(synth, stop_criterion, stop_iters_to_check)
Check whether the pixel change norm has stabilized and, if so, return True.
The decision proceeds as follows:
- If we have not yet been synthesizing for stop_iters_to_check iterations, return False.
- Otherwise, if (synth.pixel_change_norm[-stop_iters_to_check:] < stop_criterion).all(), return True.
- Otherwise, return False.
- Parameters:
synth (OptimizedSynthesis) – The OptimizedSynthesis object to check.
stop_criterion (float) – If the pixel change norm has been less than stop_criterion for all of the past stop_iters_to_check iterations, we terminate synthesis.
stop_iters_to_check (int) – How many iterations back to check in order to see if the pixel change norm has stopped decreasing (for stop_criterion).
- Returns:
loss_stabilized – Whether the pixel change norm has stabilized or not.
- Return type:
bool
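Unlike the loss check, this criterion requires every recent value to fall below the threshold, not just a small difference between two endpoints. The sketch below illustrates that logic in plain Python. It assumes synth.pixel_change_norm behaves like a sequence of floats; in plenoptic it is a tensor. It also assumes the same strict length threshold as the loss sketch. pixel_change_has_stabilized is a hypothetical name.

```python
from types import SimpleNamespace


def pixel_change_has_stabilized(synth, stop_criterion, stop_iters_to_check):
    """Sketch: True only if each of the last stop_iters_to_check norms
    is below stop_criterion."""
    norms = synth.pixel_change_norm
    if len(norms) > stop_iters_to_check:
        return all(n < stop_criterion for n in norms[-stop_iters_to_check:])
    return False


synth = SimpleNamespace(pixel_change_norm=[0.5, 0.01, 0.002, 0.001])
print(pixel_change_has_stabilized(synth, 0.05, 3))   # True
print(pixel_change_has_stabilized(synth, 0.005, 3))  # False
```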
plenoptic.tools.data module
- plenoptic.tools.data.convert_float_to_int(im, dtype=<class 'numpy.uint8'>)
Convert an image from float to an 8- or 16-bit integer image.
We work with float images that lie between 0 and 1. To save them (either as PNG or in a numpy array), we want to convert them to 8- or 16-bit integers. This function multiplies the image by the max value for the target dtype (255 for 8-bit, 65535 for 16-bit) and then converts it to the proper type.
We raise an exception if the image's max is higher than 1, since then we have no idea what to do.
- Parameters:
im (ndarray) – The image to convert.
dtype – The target data type, one of {np.uint8, np.uint16}.
- Returns:
im – The converted image, now with dtype=dtype.
- Return type:
ndarray
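The scaling rule described above amounts to a one-line formula: multiply by the dtype's max value, which is 2**bits - 1. The plain-Python sketch below applies it to a flat list of pixel values. It truncates with int(), which matches NumPy's astype for values in [0, 1]. Whether plenoptic truncates or rounds is an assumption here. float_to_int is a hypothetical helper.

```python
def float_to_int(values, bits=8):
    """Scale floats in [0, 1] to unsigned integers with `bits` bits."""
    if bits not in (8, 16):
        raise ValueError("bits must be 8 or 16")
    if max(values) > 1:
        raise ValueError("image max is above 1; don't know how to convert")
    max_val = 2 ** bits - 1  # 255 for 8-bit, 65535 for 16-bit
    return [int(v * max_val) for v in values]  # int() truncates toward zero


print(float_to_int([0.0, 0.5, 1.0]))           # [0, 127, 255]
print(float_to_int([0.0, 0.5, 1.0], bits=16))  # [0, 32767, 65535]
```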
- plenoptic.tools.data.load_images(paths, as_gray=True)
Correctly load in images.
Our models and synthesis methods expect their inputs to be 4d float32 images: (batch, channel, height, width). The batch dimension contains multiple images and the channel dimension contains, for example, RGB color channels. This function helps you get your inputs into that format. It accepts a single file, a list of files, or a single directory containing images. It loads them, normalizes them to lie between 0 and 1, converts them to float32, optionally converts them to grayscale, makes them tensors, and gets them into the right shape.
- Parameters:
paths (Union[str, List[str]]) – A str or list of strs.
- If a list, it must contain paths of image files.
- If a str, it can be either the path of a single image file or of a single directory.
- If a directory, we try to load every file it contains (using imageio.imread) and skip those we cannot. For efficiency, you should not point this to a directory with lots of non-image files. This is NOT recursive.
as_gray (bool) – Whether to convert the images into grayscale after loading them. If False, we do nothing. If True, we call skimage.color.rgb2gray on them.
- Returns:
images – 4d tensor containing the images.
- Return type:
Tensor
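The sketch below illustrates the normalize-and-reshape step described above without any imaging libraries. It assumes integer grayscale images given as lists of rows and divides by the dtype maximum (255 for 8-bit). It then nests the images as (batch, channel, height, width) with one channel. Stacking requires all images to share a size. The grayscale helper uses the BT.709 luminance weights that skimage.color.rgb2gray documents. to_bchw and rgb_to_gray are hypothetical names.

```python
def to_bchw(images, max_value=255):
    """Normalize integer grayscale images (lists of rows) to [0, 1] and nest
    them as (batch, channel, height, width) with a single channel."""
    shapes = {(len(im), len(im[0])) for im in images}
    if len(shapes) != 1:
        raise ValueError("all images must share the same height and width")
    return [[[[v / max_value for v in row] for row in im]] for im in images]


def rgb_to_gray(r, g, b):
    """Luminance with ITU-R BT.709 weights (as used by skimage's rgb2gray)."""
    return 0.2125 * r + 0.7154 * g + 0.0721 * b


batch = to_bchw([[[0, 255], [51, 102]], [[255, 0], [0, 255]]])
print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))  # 2 1 2 2
```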
- plenoptic.tools.data.make_synthetic_stimuli(size=256, requires_grad=True)
Make a set of basic stimuli, useful for developing and debugging models.
- Parameters:
size (int) – The stimuli will have torch.Size([size, size]).
requires_grad (bool) – Whether to initialize the stimuli with gradients.
- Returns:
stimuli – Tensor of shape [11, 1, size, size]. The set of basic stimuli: [impulse, step_edge, ramp, bar, curv_edge, sine_grating, square_grating, polar_angle, angular_sine, zone_plate, fractal]
- Return type:
Tensor
- plenoptic.tools.data.polar_angle(size, phase=0.0, origin=None, device=None)
Make a polar angle matrix (in radians).
Compute a matrix of given size containing samples of the polar angle about the given origin pixel, relative to the given phase. Angles are in radians, measured clockwise from the X-axis, and range from -pi to pi.
- Parameters:
size (Union[int, Tuple[int, int]]) – If an int, we assume the image should be of dimensions (size, size). If a tuple, must be a 2-tuple of ints specifying the dimensions.
phase (float) – The phase of the polar angle function (in radians, clockwise from the X-axis).
origin (Union[int, Tuple[float, float], None]) – The center of the image. If an int, we assume the origin is at (origin, origin). If a tuple, must be a 2-tuple of ints specifying the origin (where (0, 0) is the upper left). If None, we assume the origin lies at the center of the matrix, (size+1)/2.
device (Optional[device]) – The device to create this tensor on.
- Returns:
res – The polar angle matrix.
- Return type:
Tensor
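The pure-Python sketch below computes the polar angle as described. It assumes 1-indexed pixel coordinates, which makes the default origin (size+1)/2 the exact center, as in pyrtools. Because row indices grow downward, atan2(row, col) increases clockwise from the x-axis. Subtracting the phase and wrapping gives values in [-pi, pi). polar_angle_matrix is a hypothetical helper, not the plenoptic function.

```python
import math


def polar_angle_matrix(size, phase=0.0, origin=None):
    """Polar angle in radians, clockwise from the x-axis, wrapped to [-pi, pi)."""
    h, w = (size, size) if isinstance(size, int) else size
    if origin is None:
        origin = ((h + 1) / 2, (w + 1) / 2)  # center under 1-indexing
    elif isinstance(origin, (int, float)):
        origin = (origin, origin)
    res = []
    for r in range(1, h + 1):
        row = []
        for c in range(1, w + 1):
            angle = math.atan2(r - origin[0], c - origin[1])
            # subtract the phase, then wrap back into [-pi, pi)
            row.append((angle + math.pi - phase) % (2 * math.pi) - math.pi)
        res.append(row)
    return res
```

For a 3x3 matrix, the pixel below the center gets +pi/2 and the pixel above gets -pi/2, which confirms the clockwise convention.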
- plenoptic.tools.data.polar_radius(size, exponent=1.0, origin=None, device=None)
Make a distance-from-origin (r) matrix.
Compute a matrix of given size containing samples of a radial ramp function, raised to the given exponent, centered at the given origin.
- Parameters:
size (Union[int, Tuple[int, int]]) – If an int, we assume the image should be of dimensions (size, size). If a tuple, must be a 2-tuple of ints specifying the dimensions.
exponent (float) – The exponent of the radial ramp function.
origin (Union[int, Tuple[int, int], None]) – The center of the image. If an int, we assume the origin is at (origin, origin). If a tuple, must be a 2-tuple of ints specifying the origin (where (0, 0) is the upper left). If None, we assume the origin lies at the center of the matrix, (size+1)/2.
device (Union[str, device, None]) – The device to create this tensor on.
- Returns:
res – The polar radius matrix.
- Return type:
Tensor
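A matching sketch of the radial ramp uses the same 1-indexed convention as the polar angle sketch. Each entry is the Euclidean distance from the origin raised to exponent. It assumes exponent > 0, because Python raises ZeroDivisionError for 0.0 to a negative power at the origin pixel. polar_radius_matrix is a hypothetical helper.

```python
import math


def polar_radius_matrix(size, exponent=1.0, origin=None):
    """Distance from origin raised to `exponent` (assumes exponent > 0)."""
    h, w = (size, size) if isinstance(size, int) else size
    if origin is None:
        origin = ((h + 1) / 2, (w + 1) / 2)  # center under 1-indexing
    elif isinstance(origin, (int, float)):
        origin = (origin, origin)
    return [
        [math.hypot(r - origin[0], c - origin[1]) ** exponent
         for c in range(1, w + 1)]
        for r in range(1, h + 1)
    ]
```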
- plenoptic.tools.data.to_numpy(x, squeeze=False)
Cast a tensor to numpy in the most conservative way possible.
- Parameters:
x (Union[Tensor, ndarray]) – Tensor to be converted to numpy.ndarray on CPU.
squeeze (bool) – Whether to remove all singleton (size-1) dimensions of the tensor.
- Returns:
Converted tensor as numpy.ndarray on CPU.
- Return type:
ndarray
plenoptic.tools.display module
Various helpful utilities for plotting or displaying information.
- plenoptic.tools.display.animshow(video, framerate=2.0, repeat=False, vrange='indep1', zoom=1, title='', col_wrap=None, ax=None, cmap=None, plot_complex='rectangular', batch_idx=None, channel_idx=None, as_rgb=False, **kwargs)
Animate video(s) correctly.
This function animates videos correctly, making sure that each element in the tensor corresponds to a pixel or an integer number of pixels, to avoid aliasing.
NOTE: this guarantee only holds for the saved animation (assuming video compression doesn’t interfere). It should generally hold in notebooks as well, but will fail if, e.g., your video is 2000 pixels wide on a monitor 1000 pixels wide, because the notebook handles the rescaling in a way we can’t control.
This function returns the matplotlib FuncAnimation object. In order to view it in a Jupyter notebook, use the plenoptic.convert_anim_to_html(anim) function. In order to save, use anim.save(filename). For this, you’ll need the appropriate writer installed and on your path, e.g., ffmpeg, imagemagick, etc.
- Parameters:
video (torch.Tensor or list) – The videos to display. Tensors should be 5d (batch, channel, time, height, width). List of tensors should be used for tensors of different height and width: all videos will automatically be rescaled so they’re displayed at the same height and width, thus, their heights and widths must be scalar multiples of each other. Videos must all have the same number of frames as well.
framerate (float) – Temporal resolution of the video, in Hz (frames per second).
repeat (bool) – whether to loop the animation or just play it once
vrange (tuple or str) –
If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and maximum value of the colormap, respectively. If a string:
- 'auto0': all images have the same vmin/vmax, which have the same absolute value and come from the minimum or maximum across all images, whichever has the larger absolute value.
- 'auto/auto1': all images have the same vmin/vmax, which are the minimum/maximum values across all images.
- 'auto2': all images have the same vmin/vmax, which are the mean (across all images) minus/plus 2 std dev (across all images).
- 'auto3': all images have the same vmin/vmax, chosen so as to map the 10th/90th percentile values to the 10th/90th percentile of the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile.
- 'indep0': each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value.
- 'indep1': each image has an independent vmin/vmax, which are their minimum/maximum values.
- 'indep2': each image has an independent vmin/vmax, which are their mean minus/plus 2 std dev.
- 'indep3': each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities.
zoom (float) – Ratio of display pixels to image pixels. If >1, must be an integer. If <1, must be 1/d where d is a divisor of the size of the largest image.
title (str, list, or None, optional) – Title for the plot. In addition to the specified title, we add a subtitle giving the plotted range and dimensionality (with zoom).
- If str, we put the same title on every plot.
- If list, all values must be str, with one entry per video, assigning each title to the corresponding video.
- If None, no title will be printed (and the subtitle will be removed).
col_wrap (int or None, optional) – number of axes to have in each row. If None, will fit all axes in a single row.
ax (matplotlib.pyplot.axis or None, optional) – If None, we make the appropriate figure. Otherwise, we resize the axes so that they’re the appropriate number of pixels. This is done by shrinking the bbox. If the bbox is already too small, this will raise an Exception, so first define a large enough figure using either pyrtools.make_figure or plt.figure.
cmap (matplotlib colormap, optional) – colormap to use when showing these images
plot_complex ({'rectangular', 'polar', 'logpolar'}) – Specifies handling of complex values.
- 'rectangular': plot real and imaginary components as separate images
- 'polar': plot amplitude and phase as separate images
- 'logpolar': plot log_2 amplitude and phase as separate images
For any other value, we raise a warning and default to 'rectangular'.
batch_idx (int or None, optional) – Which element from the batch dimension to plot. If None, we plot all.
channel_idx (int or None, optional) – Which element from the channel dimension to plot. If None, we plot all. Note if this is an int, then as_rgb=True will fail, because we restrict the channels.
as_rgb (bool, optional) – Whether to consider the channels as encoding RGB(A) values. If True, we attempt to plot the image in color, so your tensor must have 3 (or 4 if you want the alpha channel) elements in the channel dimension, or this will raise an Exception. If False, we plot each channel as a separate grayscale image.
kwargs – Passed to ax.imshow
- Returns:
anim – The animation object. In order to view, must convert to HTML or save.
- Return type:
matplotlib.animation.FuncAnimation
Notes
By default, we use the ffmpeg backend, which requires that you have ffmpeg installed and on your path (https://ffmpeg.org/download.html). To use a different writer, set the matplotlib rcParams: matplotlib.rcParams['animation.writer'] = writer. See https://matplotlib.org/stable/api/animation_api.html#writer-classes for more details.
For displaying in a Jupyter notebook, ffmpeg appears to be required.
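The string options for vrange reduce to simple statistics over pixel values. The plain-Python sketch below covers the shared 'auto' family.

Assumptions:

- Each image is a flat list of values.
- 'auto2' uses the population standard deviation.
- Percentiles use linear interpolation (statistics.quantiles with method="inclusive", which matches NumPy's default).
- For 'auto3', vmax is the mirror of the documented vmin rule: the 90th percentile plus 1/8 of the 10th-to-90th spread.

With that 1/8 margin on each side, the 10th and 90th percentiles land at 10% and 90% of the display range. compute_vrange is a hypothetical helper.

```python
import statistics


def compute_vrange(images, mode):
    """Return a shared (vmin, vmax) over all images for an 'auto' mode."""
    values = [v for img in images for v in img]
    if mode == "auto0":
        m = max(abs(min(values)), abs(max(values)))
        return -m, m  # symmetric around zero
    if mode in ("auto", "auto1"):
        return min(values), max(values)
    if mode == "auto2":
        mean = statistics.fmean(values)
        sd = statistics.pstdev(values)
        return mean - 2 * sd, mean + 2 * sd
    if mode == "auto3":
        deciles = statistics.quantiles(values, n=10, method="inclusive")
        p10, p90 = deciles[0], deciles[-1]
        margin = (p90 - p10) / 8
        return p10 - margin, p90 + margin
    raise ValueError(f"unknown vrange mode: {mode}")


print(compute_vrange([[-3, 1], [2, 0]], "auto0"))  # (-3, 3)
print(compute_vrange([list(range(11))], "auto3"))  # (0.0, 10.0)
```

The 'indep' variants apply the same rule to each image on its own, e.g. [compute_vrange([img], "auto1") for img in images].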
- plenoptic.tools.display.clean_stem_plot(data, ax=None, title='', ylim=None, xvals=None, **kwargs)
Convenience wrapper for plotting stem plots.
This plots the data and baseline, cleans up the axis, and sets the title.
It should not be called by users directly; it is a helper function for the various plot_representation() functions.
By default, a stem plot has a baseline that covers the entire range of the data. We want to be able to break that up visually (so there’s a line from 0 to 9, from 10 to 19, etc.), and passing xvals separately allows us to do that. If you want the default stem plot behavior, leave xvals as None.
- Parameters:
data (np.ndarray) – The data to plot (as a stem plot)
ax (matplotlib.pyplot.axis or None, optional) – The axis to plot the data on. If None, we plot on the current axis
title (str or None, optional) – The title to put on the axis if not None. If None, we don’t call ax.set_title (useful if you want to avoid changing the title on an existing plot).
ylim (tuple, False, or None, optional) – If a tuple, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, do not change y-limits.
xvals (tuple or None, optional) – A 2-tuple of lists, containing the start (xvals[0]) and stop (xvals[1]) x values for plotting. If None, we use the default stem plot behavior.
kwargs – Passed to ax.stem.
- Returns:
ax – The axis with the plot
- Return type:
matplotlib.pyplot.axis
Examples
We allow for breaks in the baseline value if we want to visually break up the plot, as we see below.
```python
import plenoptic as po
import numpy as np
import matplotlib.pyplot as plt

# if ylim=None, as in this example, the minimum y-value will get
# set to 0, so we want to make sure our values are all positive
y = np.abs(np.random.randn(55))
y[15:20] = np.nan
y[35:40] = np.nan
# we want to draw the baseline from 0 to 14, 20 to 34, and 40 to
# 54, everywhere that we have non-NaN values for y
xvals = ([0, 20, 40], [14, 34, 54])
po.tools.display.clean_stem_plot(y, xvals=xvals)
plt.show()
```
If we don’t care about breaking up the x-axis, we can simply use the default xvals (None). In this case, this function will just clean up the plot a little bit.
```python
import plenoptic as po
import numpy as np
import matplotlib.pyplot as plt

# if ylim=None, as in this example, the minimum y-value will get
# set to 0, so we want to make sure our values are all positive
y = np.abs(np.random.randn(55))
po.tools.display.clean_stem_plot(y)
plt.show()
```
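In the first example, xvals was written out by hand to match the NaN gaps in y. The hypothetical helper below derives the same (starts, stops) pair automatically from the runs of non-NaN values. It uses plain Python and is not part of plenoptic.

```python
import math


def baseline_segments(y):
    """Return (starts, stops) of runs of non-NaN values in y, in the
    xvals format used by clean_stem_plot."""
    starts, stops = [], []
    in_run = False
    for i, v in enumerate(y):
        valid = not math.isnan(v)
        if valid and not in_run:
            starts.append(i)  # a run of valid values begins
            in_run = True
        elif not valid and in_run:
            stops.append(i - 1)  # the run ended on the previous index
            in_run = False
    if in_run:
        stops.append(len(y) - 1)
    return starts, stops


y = [float("nan") if 15 <= i < 20 or 35 <= i < 40 else 1.0 for i in range(55)]
print(baseline_segments(y))  # ([0, 20, 40], [14, 34, 54])
```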
- plenoptic.tools.display.clean_up_axes(ax, ylim=None, spines_to_remove=['top', 'right', 'bottom'], axes_to_remove=['x'])
Clean up an axis, as desired when making a stem plot of the representation
- Parameters:
ax (matplotlib.pyplot.axis) – The axis to clean up.
ylim (tuple, False, or None) – If a tuple, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, we do nothing.
spines_to_remove (list) – Some combination of ‘top’, ‘right’, ‘bottom’, and ‘left’. The spines we remove from the axis.
axes_to_remove (list) – Some combination of ‘x’, ‘y’. The axes to set as invisible.
- Returns:
ax – The cleaned-up axis
- Return type:
matplotlib.pyplot.axis
- plenoptic.tools.display.convert_anim_to_html(anim)
Convert a matplotlib animation object to HTML (for display).
This is a simple wrapper function that allows the animation to be displayed in a Jupyter notebook.
- Parameters:
anim (matplotlib.animation.FuncAnimation) – The animation object to convert to HTML
- plenoptic.tools.display.imshow(image, vrange='indep1', zoom=None, title='', col_wrap=None, ax=None, cmap=None, plot_complex='rectangular', batch_idx=None, channel_idx=None, as_rgb=False, **kwargs)
Show image(s) correctly.
This function shows images correctly, making sure that each element in the tensor corresponds to a pixel or an integer number of pixels, to avoid aliasing.
NOTE: this guarantee only holds for the saved image. It should generally hold in notebooks as well, but will fail if, e.g., you plot an image that’s 2000 pixels wide on a monitor 1000 pixels wide, because the notebook handles the rescaling in a way we can’t control.
- Parameters:
image (torch.Tensor or list) – The images to display. Tensors should be 4d (batch, channel, height, width). List of tensors should be used for tensors of different height and width: all images will automatically be rescaled so they’re displayed at the same height and width, thus, their heights and widths must be scalar multiples of each other.
vrange (tuple or str) –
If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and maximum value of the colormap, respectively. If a string:
- 'auto0': all images have the same vmin/vmax, which have the same absolute value and come from the minimum or maximum across all images, whichever has the larger absolute value.
- 'auto/auto1': all images have the same vmin/vmax, which are the minimum/maximum values across all images.
- 'auto2': all images have the same vmin/vmax, which are the mean (across all images) minus/plus 2 std dev (across all images).
- 'auto3': all images have the same vmin/vmax, chosen so as to map the 10th/90th percentile values to the 10th/90th percentile of the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile.
- 'indep0': each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value.
- 'indep1': each image has an independent vmin/vmax, which are their minimum/maximum values.
- 'indep2': each image has an independent vmin/vmax, which are their mean minus/plus 2 std dev.
- 'indep3': each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities.
zoom (float or None) – Ratio of display pixels to image pixels. If >1, must be an integer. If <1, must be 1/d where d is a divisor of the size of the largest image. If None, we try to determine the best zoom.
title (str, list, or None, optional) – Title for the plot. In addition to the specified title, we add a subtitle giving the plotted range and dimensionality (with zoom).
- If str, we put the same title on every plot.
- If list, all values must be str, with one entry per image, assigning each title to the corresponding image.
- If None, no title will be printed (and the subtitle will be removed).
col_wrap (int or None, optional) – number of axes to have in each row. If None, will fit all axes in a single row.
ax (matplotlib.pyplot.axis or None, optional) – If None, we make the appropriate figure. Otherwise, we resize the axes so that they’re the appropriate number of pixels. This is done by shrinking the bbox. If the bbox is already too small, this will raise an Exception, so first define a large enough figure using either make_figure or plt.figure.
cmap (matplotlib colormap, optional) – colormap to use when showing these images
plot_complex ({'rectangular', 'polar', 'logpolar'}) – Specifies handling of complex values.
- 'rectangular': plot real and imaginary components as separate images
- 'polar': plot amplitude and phase as separate images
- 'logpolar': plot log_2 amplitude and phase as separate images
For any other value, we raise a warning and default to 'rectangular'.
batch_idx (int or None, optional) – Which element from the batch dimension to plot. If None, we plot all.
channel_idx (int or None, optional) – Which element from the channel dimension to plot. If None, we plot all. Note if this is an int, then as_rgb=True will fail, because we restrict the channels.
as_rgb (bool, optional) – Whether to consider the channels as encoding RGB(A) values. If True, we attempt to plot the image in color, so your tensor must have 3 (or 4 if you want the alpha channel) elements in the channel dimension, or this will raise an Exception. If False, we plot each channel as a separate grayscale image.
kwargs – Passed to ax.imshow
- Returns:
fig – figure containing the plotted images
- Return type:
PyrFigure
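The zoom rule above can be expressed as a small check, written here in plain Python. This is a sketch under one assumption: "the size of the largest image" means both its height and width, so d must divide each. A small tolerance absorbs floating-point error in values like 1/3. is_valid_zoom is a hypothetical helper. The automatic choice made when zoom is None is not modeled.

```python
def is_valid_zoom(zoom, image_shape):
    """Check the documented zoom rule against an image (height, width)."""
    if zoom <= 0:
        return False
    if zoom >= 1:
        return abs(zoom - round(zoom)) < 1e-9  # must be an integer
    d = round(1 / zoom)
    if abs(1 / zoom - d) > 1e-9:  # zoom must be exactly 1/d
        return False
    return all(dim % d == 0 for dim in image_shape)


print(is_valid_zoom(0.25, (64, 64)))   # True
print(is_valid_zoom(1 / 3, (64, 64)))  # False
```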
- plenoptic.tools.display.plot_representation(model=None, data=None, ax=None, figsize=(5, 5), ylim=False, batch_idx=0, title='', as_rgb=False)
Helper function for plotting model representation.
We are trying to plot data on ax, using the model.plot_representation method if it has one. Otherwise, we default to a function that makes sense based on the shape of data.
All of these arguments are optional, but at least some of them need to be set:
- If model is None, we fall back to a type of plot based on the shape of data. If it looks image-like, we’ll use plenoptic.imshow, and if it looks vector-like, we’ll use plenoptic.clean_stem_plot. If it’s a dictionary, we’ll assume each key, value pair gives the title and data to plot on a separate sub-plot.
- If data is None, we can only do something if model.plot_representation has some default behavior when data=None; this is probably to plot its own representation attribute. Thus, this will raise an Exception if both model and data are None, because we have no idea what to plot then.
- If ax is None, we create a one-subplot figure using figsize. If ax is not None, we therefore ignore figsize.
- If ylim is None, we call rescale_ylim, which sets the axes’ y-limits to be (-y_max, y_max), where y_max=np.abs(data).max(). If it’s False, we do nothing.
- Parameters:
model (torch.nn.Module or None, optional) – A differentiable model that tells us how to plot data. See above for behavior if None.
data (array_like, dict, or None, optional) – The data to plot. See above for behavior if None.
ax (matplotlib.pyplot.axis or None, optional) – The axis to plot on. See above for behavior if None.
figsize (tuple, optional) – The size of the figure to create. Ignored if ax is not None.
ylim (tuple, None, or False, optional) – If a tuple, the y-limits to use for this plot. See above for behavior if None. If False, we do nothing.
batch_idx (int, optional) – Which index to take from the batch dimension.
title (str, optional) – The title to put above this axis. If you want no title, pass the empty string ('').
as_rgb (bool, optional) – The representation can be image-like with multiple channels, and we have no way to determine whether it should be represented as an RGB image or not, so the user must set this flag to tell us. It will be ignored if the representation doesn’t look image-like or if the model has its own plot_representation_error() method. Otherwise, it will be passed to po.imshow(); see that method’s docstring for details.
- Returns:
axes – List of created axes.
- Return type:
list
- plenoptic.tools.display.pyrshow(pyr_coeffs, vrange='indep1', zoom=1, show_residuals=True, cmap=None, plot_complex='rectangular', batch_idx=0, channel_idx=0, **kwargs)
Display steerable pyramid coefficients in orderly fashion.
This function uses imshow to show the coefficients of the steerable pyramid, such that each scale shows up on a single row, with each orientation in a given column.
Note that, unlike imshow, we can only show one batch or channel at a time.
- Parameters:
pyr_coeffs (dict) – Pyramid coefficients in the standard dictionary format as returned by SteerablePyramidFreq.forward().
vrange (tuple or str) –
If a 2-tuple, specifies the image values vmin/vmax that are mapped to the minimum and maximum value of the colormap, respectively. If a string:
- 'auto0': all images have the same vmin/vmax, which have the same absolute value and come from the minimum or maximum across all images, whichever has the larger absolute value.
- 'auto/auto1': all images have the same vmin/vmax, which are the minimum/maximum values across all images.
- 'auto2': all images have the same vmin/vmax, which are the mean (across all images) minus/plus 2 std dev (across all images).
- 'auto3': all images have the same vmin/vmax, chosen so as to map the 10th/90th percentile values to the 10th/90th percentile of the display intensity range. For example: vmin is the 10th percentile image value minus 1/8 times the difference between the 90th and 10th percentile.
- 'indep0': each image has an independent vmin/vmax, which have the same absolute value, which comes from either their minimum or maximum value, whichever has the larger absolute value.
- 'indep1': each image has an independent vmin/vmax, which are their minimum/maximum values.
- 'indep2': each image has an independent vmin/vmax, which are their mean minus/plus 2 std dev.
- 'indep3': each image has an independent vmin/vmax, chosen so that the 10th/90th percentile values map to the 10th/90th percentile intensities.
zoom (float) – Ratio of display pixels to image pixels. If >1, must be an integer. If <1, must be 1/d where d is a divisor of the size of the largest image.
show_residuals (bool) – whether to display the residual bands (lowpass, highpass depending on the pyramid type)
cmap (matplotlib colormap, optional) – colormap to use when showing these images
plot_complex ({'rectangular', 'polar', 'logpolar'}) – Specifies handling of complex values.
- 'rectangular': plot real and imaginary components as separate images
- 'polar': plot amplitude and phase as separate images
- 'logpolar': plot log_2 amplitude and phase as separate images
For any other value, we raise a warning and default to 'rectangular'.
batch_idx (int, optional) – Which element from the batch dimension to plot.
channel_idx (int, optional) – Which element from the channel dimension to plot.
kwargs – Passed on to pyrtools.pyrshow.
- Returns:
fig – the figure displaying the coefficients.
- Return type:
PyrFigure
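The layout described above gives each scale its own row and each orientation its own column. The sketch below maps coefficient keys to grid positions under these assumptions:

- Band keys are (scale, orientation) tuples with orientations numbered from 0.
- The residuals use the keys 'residual_highpass' and 'residual_lowpass'.
- The residuals sit on their own rows, highpass on top and lowpass at the bottom.

pyramid_grid is a hypothetical helper and does not reproduce pyrshow's actual drawing code.

```python
def pyramid_grid(keys):
    """Map pyramid-coefficient keys to (row, column) grid positions."""
    bands = sorted(k for k in keys if isinstance(k, tuple))
    scales = sorted({scale for scale, _ in bands})
    offset = 1 if "residual_highpass" in keys else 0
    layout = {}
    if offset:
        layout["residual_highpass"] = (0, 0)
    for scale, orientation in bands:
        # one row per scale, one column per orientation
        layout[(scale, orientation)] = (offset + scales.index(scale), orientation)
    if "residual_lowpass" in keys:
        layout["residual_lowpass"] = (offset + len(scales), 0)
    return layout


keys = ["residual_highpass", (0, 0), (0, 1), (1, 0), (1, 1), "residual_lowpass"]
print(pyramid_grid(keys)[(1, 1)])  # (2, 1)
```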
- plenoptic.tools.display.rescale_ylim(axes, data)
Rescale y-limits nicely.
We take the axes and set their limits to be (-y_max, y_max), where y_max=np.abs(data).max().
- Parameters:
axes (list) – A list of matplotlib axes to rescale.
data (array_like or dict) – The data to use when rescaling (or a dictionary of those values).
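The limit computation is small enough to show directly. This plain-Python sketch returns the symmetric pair (-y_max, y_max). For dict input, it assumes the maximum is taken over all values of all entries. symmetric_ylim is a hypothetical helper. In practice the pair would then be passed to each axis, e.g. with ax.set_ylim(*limits).

```python
def symmetric_ylim(data):
    """Return (-y_max, y_max) where y_max is the largest absolute value.

    data is a list of numbers or a dict whose values are lists of numbers.
    """
    if isinstance(data, dict):
        values = [v for vals in data.values() for v in vals]
    else:
        values = list(data)
    y_max = max(abs(v) for v in values)
    return -y_max, y_max


print(symmetric_ylim([1, -3, 2]))  # (-3, 3)
```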
- plenoptic.tools.display.update_plot(axes, data, model=None, batch_idx=0)[source]
Update the information in some axes.
This is used for creating an animation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It assumes the plot has been created by something like
plot_representation
, which initializes all the artists.We can update stem plots, lines (as returned by
plt.plot
), scatter plots, or images (RGB, RGBA, or grayscale).There are two modes for this:
single axis: axes is a single axis, which may contain multiple artists (all of the same type) to update. data should be a Tensor with multiple channels (one per artist in the same order) or be a dictionary whose keys give the label(s) of the corresponding artist(s) and whose values are Tensors.
multiple axes: axes is a list of axes, each of which contains a single artist to update (artists can be different types). data should be a Tensor with multiple channels (one per axis in the same order) or a dictionary with the same number of keys as axes, which we can iterate through in order, and whose values are Tensors.
In all cases, data Tensors should be 3d (if the plot we’re updating is a line or stem plot) or 4d (if it’s an image or scatter plot).
RGB(A) images are special, since we store that info along the channel dimension, so they only work with single-axis mode (which will only have a single artist, because that’s how imshow works).
If you have multiple axes, each with multiple artists you want to update, that’s too complicated for us, and so you should write a model.update_plot() function which handles that.
If model is set, we try to call model.update_plot() (which must also return artists). If model doesn’t have an update_plot method, then we try to figure out how to update the axes ourselves, based on the shape of the data.
- Parameters:
axes (list or matplotlib.pyplot.axis) – The axis or list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.
data (torch.Tensor or dict) – The new data to plot.
model (torch.nn.Module or None, optional) – A differentiable model that tells us how to plot data. See above for behavior if None.
batch_idx (int, optional) – Which index to take from the batch dimension.
- Returns:
artists – A list of the artists used to update the information on the plots
- Return type:
list
- plenoptic.tools.display.update_stem(stem_container, ydata)[source]
Update the information in a stem plot
We update the information in a single stem plot to match that given by ydata. We update the position of the markers and the lines connecting them to the baseline, but we leave the baseline unchanged and assume that the xdata does not change.
- Parameters:
stem_container (matplotlib.container.StemContainer) – Single container for the artists created in a plt.stem plot. It can be treated like a namedtuple (markerline, stemlines, baseline). In order to get this from an axis ax, try ax.containers[0] (obviously if you have more than one container in that axis, it may not be the first one).
ydata (array_like) – The new y-data to show on the plot. Importantly, must be the same length as the existing y-data.
- Returns:
stem_container – The StemContainer containing the updated artists.
- Return type:
matplotlib.container.StemContainer
plenoptic.tools.external module
Tools to deal with data from outside plenoptic, for example, pre-existing synthesized images.
- plenoptic.tools.external.plot_MAD_results(original_image, noise_levels=None, results_dir=None, ssim_images_dir=None, zoom=3, vrange='indep1', **kwargs)[source]
Plot original MAD results, provided by Zhou Wang.
Plot the results of the original MAD Competition, as provided in .mat files. The figure created shows the results for one reference image and multiple noise levels. The reference image is plotted on the first row, followed by a separate row for each noise level, which shows the initial (noisy) image and the four synthesized images, with their respective losses for the two metrics (MSE and SSIM).
We also return a dictionary that contains the losses, noise levels, and original image name for each plotted noise level, which can be converted to a pandas DataFrame.
This code can probably be adapted to other uses, but requires that all images are the same size and assumes they’re all 64 x 64 pixels.
- Parameters:
original_image ({samp1, samp2, samp3, samp4, samp5, samp6, samp7, samp8, samp9, samp10}) – Which of the sample images to plot.
noise_levels (list or None, optional) – Which noise levels to plot. If None, will plot all. If a list, elements must be 2**i where i is in [1, 10].
results_dir (None or str, optional) – path to the results directory containing the results.mat files. If None, we call po.data.fetch_data to download (requires optional dependency pooch).
ssim_images_dir (None or str, optional) – path to the directory containing the .tif images used in SSIM paper. If None, we call po.data.fetch_data to download (requires optional dependency pooch).
zoom (int, optional) – amount to zoom each image, passed to pyrtools.imshow
vrange (str, optional) – in addition to the values accepted by pyrtools.imshow, we also accept ‘row0/1/2/3’, which is the same as ‘auto0/1/2/3’, except that we do it on a per-row basis (all images with same noise level)
kwargs – passed to pyrtools.imshow. Note that we call imshow separately on each image and so any argument that relies on imshow having access to all images will probably not work as expected
- Returns:
fig (pyrtools.tools.display.Figure) – figure containing the images
results (dict) – dictionary containing the errors for each noise level. To convert to a well-structured pandas DataFrame, run pd.DataFrame(results).T
plenoptic.tools.optim module
Tools related to optimization such as more objective functions.
- plenoptic.tools.optim.l2_norm(synth_rep, ref_rep, **kwargs)[source]
l2-norm of the difference between ref_rep and synth_rep
- Parameters:
synth_rep (Tensor) – The first tensor to compare, model representation of the synthesized image.
ref_rep (Tensor) – The second tensor to compare, model representation of the reference image. Must be same size as synth_rep.
kwargs – Ignored, only present to absorb extra arguments.
- Returns:
The L2-norm of the difference between ref_rep and synth_rep.
- Return type:
loss
- plenoptic.tools.optim.mse(synth_rep, ref_rep, **kwargs)[source]
Return the MSE between synth_rep and ref_rep.
For two tensors, \(x\) and \(y\), with \(n\) values each:
\[MSE = \frac{1}{n}\sum_{i=1}^n (x_i - y_i)^2\]
The two images must have a float dtype.
- Parameters:
synth_rep (Tensor) – The first tensor to compare, model representation of the synthesized image.
ref_rep (Tensor) – The second tensor to compare, model representation of the reference image. Must be same size as synth_rep.
kwargs – Ignored, only present to absorb extra arguments.
- Returns:
The mean-squared error between synth_rep and ref_rep.
- Return type:
loss
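The two losses above can be sketched in a few lines of NumPy. This is an illustration of the formulas, not plenoptic's torch implementation. It assumes plain arrays and drops the **kwargs. It also shows how the two losses relate: the MSE equals the squared L2-norm divided by the number of elements.

```python
import numpy as np

def l2_norm(synth_rep, ref_rep):
    # Euclidean norm of the elementwise difference, taken over all elements
    return np.linalg.norm((ref_rep - synth_rep).ravel())

def mse(synth_rep, ref_rep):
    # Mean of the squared elementwise differences
    return np.mean((synth_rep - ref_rep) ** 2)

x = np.array([[0.0, 1.0], [2.0, 3.0]])
y = np.array([[1.0, 1.0], [2.0, 5.0]])
print(mse(x, y))      # (1 + 0 + 0 + 4) / 4 = 1.25
print(l2_norm(x, y))  # sqrt(5)
```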
- plenoptic.tools.optim.penalize_range(synth_img, allowed_range=(0.0, 1.0), **kwargs)[source]
Penalize values outside of allowed_range.
Instead of clamping values to exactly fall in a range, this provides a ‘softer’ way of doing it, by imposing a quadratic penalty on any values outside the allowed_range. All values within the allowed_range have a penalty of 0.
- Parameters:
synth_img (Tensor) – The tensor to penalize, the synthesized image.
allowed_range (Tuple[float, float]) – 2-tuple of values giving the (min, max) allowed values.
kwargs – Ignored, only present to absorb extra arguments.
- Returns:
Penalty for values outside range
- Return type:
penalty
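Here is a minimal NumPy sketch of the quadratic range penalty described above. It assumes the per-element penalties are summed into a single scalar. That reduction is our assumption and may differ from plenoptic's torch code.

```python
import numpy as np

def penalize_range(synth_img, allowed_range=(0.0, 1.0)):
    lo, hi = allowed_range
    # Signed distance below the minimum (0 for values >= lo)
    below = np.minimum(synth_img - lo, 0.0)
    # Distance above the maximum (0 for values <= hi)
    above = np.maximum(synth_img - hi, 0.0)
    # Quadratic penalty, summed over all elements (assumed reduction)
    return np.sum(below ** 2) + np.sum(above ** 2)

img = np.array([-0.5, 0.2, 0.9, 1.5])
print(penalize_range(img))  # 0.25 + 0.25 = 0.5
```

Values inside the range contribute nothing, so gradients only push on out-of-range pixels.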
- plenoptic.tools.optim.relative_MSE(synth_rep, ref_rep, **kwargs)[source]
Squared l2-norm of the difference between reference representation and synthesized representation relative to the squared l2-norm of the reference representation:
\[\frac{\|x - \hat{x}\|_2^2}{\|x\|_2^2}\]
where \(x\) is the reference representation and \(\hat{x}\) is the synthesized representation.
- Parameters:
synth_rep (Tensor) – The first tensor to compare, model representation of the synthesized image.
ref_rep (Tensor) – The second tensor to compare, model representation of the reference image. Must be same size as synth_rep.
kwargs – Ignored, only present to absorb extra arguments.
- Returns:
Ratio of the squared l2-norm of the difference between ref_rep and synth_rep to the squared l2-norm of ref_rep.
- Return type:
loss
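The relative MSE formula translates directly to NumPy. This sketch illustrates the ratio only and is not plenoptic's torch code. The reference representation is \(x\) (ref_rep) and the synthesized one is \(\hat{x}\) (synth_rep).

```python
import numpy as np

def relative_MSE(synth_rep, ref_rep):
    # Squared L2-norm of the difference, normalized by the squared L2-norm of the reference
    return np.sum((ref_rep - synth_rep) ** 2) / np.sum(ref_rep ** 2)

ref = np.array([3.0, 4.0])    # ||ref||^2 = 25
synth = np.array([3.0, 0.0])  # ||ref - synth||^2 = 16
print(relative_MSE(synth, ref))  # 16 / 25 = 0.64
```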
plenoptic.tools.signal module
- plenoptic.tools.signal.add_noise(img, noise_mse)[source]
Add normally distributed noise to an image
This adds normally-distributed noise to an image so that the resulting noisy version has the specified mean-squared error.
- Parameters:
img (Tensor) – The image to make noisy.
noise_mse (Union[float, List[float]]) – The target MSE value / variance of the noise. More than one value is allowed.
- Returns:
The noisy image. If noise_mse contains only one element, this will be the same size as img. Else, each separate value from noise_mse will be along the batch dimension.
- Return type:
noisy_img
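The NumPy sketch below shows one way to hit a target MSE exactly. We draw Gaussian noise, remove its mean, and rescale it so that its mean squared value equals each requested noise_mse. A separate noise image is stacked along the batch axis for each value. Several details are our own assumptions rather than plenoptic's API: img has a batch size of 1, the centering is done per noise image, and the rng argument is hypothetical.

```python
import numpy as np

def add_noise(img, noise_mse, rng=None):
    # img: array of shape (1, channel, height, width); batch size 1 is assumed
    rng = np.random.default_rng(0) if rng is None else rng  # hypothetical rng argument
    mses = np.atleast_1d(np.asarray(noise_mse, dtype=float)).reshape(-1, 1, 1, 1)
    # One noise image per requested MSE, stacked along the batch dimension
    noise = rng.standard_normal((mses.shape[0],) + img.shape[1:])
    noise -= noise.mean(axis=(1, 2, 3), keepdims=True)
    # Rescale so each noise image has exactly the requested mean squared value
    noise *= np.sqrt(mses / np.mean(noise ** 2, axis=(1, 2, 3), keepdims=True))
    # Broadcasting (1, C, H, W) + (n, C, H, W) gives (n, C, H, W)
    return img + noise

img = np.zeros((1, 1, 8, 8))
noisy = add_noise(img, [0.1, 1.0])
print(noisy.shape)  # (2, 1, 8, 8)
```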
- plenoptic.tools.signal.autocorrelation(x)[source]
Compute the autocorrelation of x.
- Parameters:
x (Tensor) – N-dimensional tensor. We assume the last two dimensions are height and width and compute the autocorrelation on these dimensions (independently on each other dimension).
- Returns:
Autocorrelation of x
- Return type:
ac
Notes
By the Einstein-Wiener-Khinchin theorem, the autocorrelation of a wide-sense stationary (WSS) process is the inverse Fourier transform of its energy spectral density (ESD), which itself is the product of FT(x(t)) and FT(x(-t)). In other words, the autocorrelation is the convolution of x(t) with x(-t), which corresponds to taking the squared magnitude in the frequency domain. This approach is computationally more efficient than brute force (n log(n) vs n^2).
By Cauchy-Schwarz, the autocorrelation attains its maximum at the center location (i.e., no shift). That maximum value is the signal’s variance (assuming that the input signal is mean centered).
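The FFT route in these notes can be sketched in NumPy. The sketch takes the inverse FFT of the squared magnitude spectrum over the last two axes, shifts zero lag to the center, and divides by the number of pixels. With that normalization, the center value equals the variance of a mean-centered signal. The normalization and centering are assumptions chosen to match the notes; this is not plenoptic's torch code.

```python
import numpy as np

def autocorrelation(x):
    # Squared magnitude of the 2D FFT over the last two axes (height, width)
    spectrum = np.abs(np.fft.fft2(x)) ** 2
    # Its inverse FFT is the (circular) autocorrelation
    ac = np.fft.ifft2(spectrum).real
    # Move zero shift to the center and normalize by the number of pixels
    return np.fft.fftshift(ac, axes=(-2, -1)) / (x.shape[-2] * x.shape[-1])

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
x -= x.mean()  # mean-center so the peak equals the variance
ac = autocorrelation(x)
print(ac[4, 4], x.var())  # center (zero shift) vs. variance: equal
```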
- plenoptic.tools.signal.center_crop(x, output_size)[source]
Crop out the center of a signal.
If x has an even number of elements on either of its final two dimensions, we round up.
- Parameters:
x (Tensor) – N-dimensional tensor; we assume the last two dimensions are height and width.
output_size (int) – The size of the output. Note that we only support a single number, so both dimensions are cropped identically.
- Returns:
Tensor whose last two dimensions have each been cropped to output_size
- Return type:
cropped
- plenoptic.tools.signal.expand(x, factor)[source]
Expand a signal by a factor.
We do this in the frequency domain: pasting the Fourier contents of x in the center of a larger empty tensor, and then taking the inverse FFT.
- Parameters:
x (Tensor) – The signal for expansion.
factor (float) – Factor by which to resize image. Must be larger than 1 and factor * x.shape[-2:] must give integer values.
- Returns:
The expanded signal
- Return type:
expanded
See also
shrink
The inverse operation
- plenoptic.tools.signal.interpolate1d(x_new, Y, X)[source]
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a function with given discrete data points (X, Y), evaluated at x_new.
Note: this function is just a wrapper around np.interp().
- Parameters:
x_new (Tensor) – The x-coordinates at which to evaluate the interpolated values.
Y (Union[Tensor, ndarray]) – The y-coordinates of the data points.
X (Union[Tensor, ndarray]) – The x-coordinates of the data points, same length as Y.
- Return type:
Interpolated values of shape identical to x_new.
- plenoptic.tools.signal.make_disk(img_size, outer_radius=None, inner_radius=None)[source]
Create a circular mask with softened edges for an image.
All values within inner_radius will be 1, and all values from inner_radius to outer_radius will decay smoothly to 0.
- Parameters:
img_size (Union[int, Tuple[int, int], Size]) – Size of image in pixels.
outer_radius (Optional[float]) – Total radius of disk. Values from inner_radius to outer_radius will decay smoothly to zero.
inner_radius (Optional[float]) – Radius of inner disk. All elements from the origin to inner_radius will be set to 1.
- Returns:
Tensor mask with torch.Size(img_size).
- Return type:
mask
- plenoptic.tools.signal.maximum(x, dim=None, keepdim=False)[source]
Compute maximum in torch over any dim or combination of axes in tensor.
- Parameters:
x (Tensor) – Input tensor.
dim (Optional[List[int]]) – Dimensions over which you would like to compute the maximum.
keepdim (bool) – Keep original dimensions of tensor when returning result.
- Returns:
Maximum value of x.
- Return type:
max_x
- plenoptic.tools.signal.minimum(x, dim=None, keepdim=False)[source]
Compute minimum in torch over any axis or combination of axes in tensor.
- Parameters:
x (Tensor) – Input tensor.
dim (Optional[List[int]]) – Dimensions over which you would like to compute the minimum.
keepdim (bool) – Keep original dimensions of tensor when returning result.
- Returns:
Minimum value of x.
- Return type:
min_x
- plenoptic.tools.signal.modulate_phase(x, phase_factor=2.0)[source]
Modulate the phase of a complex signal.
Doubling the phase of a complex signal allows you to, for example, take the correlation between steerable pyramid coefficients at two adjacent spatial scales.
- Parameters:
x (Tensor) – Complex tensor whose phase will be modulated.
phase_factor (float) – Multiplicative factor to change phase by.
- Returns:
Phase-modulated complex tensor.
- Return type:
x_mod
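Phase modulation keeps the amplitude of each complex value and multiplies its phase angle by phase_factor. This NumPy sketch shows the idea and is not plenoptic's torch code. Note that the resulting phase is reported wrapped to (-pi, pi].

```python
import numpy as np

def modulate_phase(x, phase_factor=2.0):
    # Keep the amplitude, scale the phase angle by phase_factor
    return np.abs(x) * np.exp(1j * phase_factor * np.angle(x))

z = 2.0 * np.exp(1j * np.pi / 6)
z_mod = modulate_phase(z)
print(np.abs(z_mod), np.angle(z_mod))  # amplitude 2, phase approximately pi / 3
```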
- plenoptic.tools.signal.polar_to_rectangular(amplitude, phase)[source]
Polar to rectangular coordinate transform
- Parameters:
amplitude (Tensor) – Tensor containing the amplitude (aka. complex modulus). Must be > 0.
phase (Tensor) – Tensor containing the phase.
- Return type:
Complex tensor.
- plenoptic.tools.signal.raised_cosine(width=1, position=0, values=(0, 1))[source]
Return a lookup table containing a “raised cosine” soft threshold function.
Y = VALUES(1) + (VALUES(2)-VALUES(1)) * cos^2( PI/2 * (X - POSITION + WIDTH)/WIDTH )
This lookup table is suitable for use by interpolate1d
- Parameters:
width (float) – The width of the region over which the transition occurs.
position (float) – The location of the center of the threshold.
values (Tuple[float, float]) – 2-tuple specifying the values to the left and right of the transition.
- Return type:
Tuple[ndarray, ndarray]
- Returns:
X – The x values of this raised cosine.
Y – The y values of this raised cosine.
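The sketch below builds a raised-cosine lookup table and evaluates it with piecewise linear interpolation, which is what interpolate1d wraps via np.interp. The parametrization is our own, derived from the parameter descriptions: a transition of the given width centered on position, going from values[0] on the left to values[1] on the right. The sampling density n is a hypothetical parameter. plenoptic's table may differ in sampling and offset.

```python
import numpy as np

def raised_cosine(width=1.0, position=0.0, values=(0.0, 1.0), n=256):
    # Sample the transition region [position - width/2, position + width/2]
    X = np.linspace(position - width / 2, position + width / 2, n)
    # cos^2 rises from 0 at the left edge to 1 at the right edge
    Y = values[0] + (values[1] - values[0]) * np.cos(
        np.pi / 2 * (X - position - width / 2) / width) ** 2
    return X, Y

def interpolate1d(x_new, Y, X):
    # Piecewise linear interpolation; points outside X take the end values
    return np.interp(x_new, X, Y)

X, Y = raised_cosine(width=2.0, position=1.0, values=(0.0, 1.0))
out = interpolate1d(np.array([-5.0, 1.0, 5.0]), Y, X)
print(out)  # approximately [0.0, 0.5, 1.0]
```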
- plenoptic.tools.signal.rectangular_to_polar(x)[source]
Rectangular to polar coordinate transform
- Parameters:
x (Tensor) – Complex tensor.
- Return type:
Tuple[Tensor, Tensor]
- Returns:
amplitude – Tensor containing the amplitude (aka. complex modulus).
phase – Tensor containing the phase.
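The two coordinate transforms are inverses of each other. This NumPy sketch illustrates them; it is not plenoptic's torch code. The amplitude is the complex modulus, and the phase is the angle in radians.

```python
import numpy as np

def polar_to_rectangular(amplitude, phase):
    # amplitude * e^{i * phase}
    return amplitude * np.exp(1j * phase)

def rectangular_to_polar(x):
    # Complex modulus and angle (radians, in (-pi, pi])
    return np.abs(x), np.angle(x)

z = np.array([1 + 1j, -2 + 0.5j, 3j])
amplitude, phase = rectangular_to_polar(z)
print(polar_to_rectangular(amplitude, phase))  # recovers z up to rounding
```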
- plenoptic.tools.signal.rescale(x, a=0.0, b=1.0)[source]
Linearly rescale the dynamic range of the input x to [a,b].
- Return type:
Tensor
- plenoptic.tools.signal.shrink(x, factor)[source]
Shrink a signal by a factor.
We do this in the frequency domain: cropping out the center of the Fourier transform of x, putting it in a new tensor, and taking the IFFT.
- Parameters:
x (Tensor) – The signal to shrink.
factor (int) – Factor by which to resize image. Must be larger than 1 and x.shape[-2:] / factor must give integer values.
- Returns:
The shrunk signal
- Return type:
expanded
See also
expand
The inverse operation
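expand and shrink are Fourier-domain inverses, and a NumPy sketch makes this concrete. expand zero-pads the centered spectrum, and shrink crops it. Each result is rescaled by factor**2 so that intensities (for example, the mean) are preserved. That scaling is our assumption; this is not plenoptic's torch code. The example uses odd sizes so the zero-padded spectrum stays conjugate-symmetric. With even sizes, the unpaired Nyquist row and column would make the .real step lossy.

```python
import numpy as np

def expand(x, factor):
    # Zero-pad the centered spectrum of x to factor times its size
    h, w = x.shape[-2:]
    H, W = int(factor * h), int(factor * w)
    spectrum = np.fft.fftshift(np.fft.fft2(x), axes=(-2, -1))
    padded = np.zeros(x.shape[:-2] + (H, W), dtype=complex)
    r0, c0 = H // 2 - h // 2, W // 2 - w // 2
    padded[..., r0:r0 + h, c0:c0 + w] = spectrum
    out = np.fft.ifft2(np.fft.ifftshift(padded, axes=(-2, -1)))
    # Rescale so that intensities (e.g. the mean) are preserved
    return (out * factor ** 2).real

def shrink(x, factor):
    # Keep only the central (low-frequency) part of the spectrum
    H, W = x.shape[-2:]
    h, w = int(H / factor), int(W / factor)
    spectrum = np.fft.fftshift(np.fft.fft2(x), axes=(-2, -1))
    r0, c0 = H // 2 - h // 2, W // 2 - w // 2
    cropped = spectrum[..., r0:r0 + h, c0:c0 + w]
    out = np.fft.ifft2(np.fft.ifftshift(cropped, axes=(-2, -1)))
    return (out / factor ** 2).real

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
big = expand(x, 2)
print(big.shape)  # (10, 10)
```

Because the padded spectrum adds only zero high frequencies, every other sample of the expanded signal reproduces x.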
- plenoptic.tools.signal.steer(basis, angle, harmonics=None, steermtx=None, return_weights=False, even_phase=True)[source]
Steer BASIS to the specified ANGLE.
- Parameters:
basis (Tensor) – Array whose columns are vectorized rotated copies of a steerable function, or the responses of a set of steerable filters.
angle (Union[ndarray, Tensor, float]) – Scalar or column vector the size of the basis. Specifies the angle(s) (in radians) to steer to.
harmonics (Optional[List[int]]) – A list of harmonic numbers indicating the angular harmonic content of the basis. If None (default), N even or odd low frequencies, as for derivative filters.
steermtx (Union[Tensor, ndarray, None]) – Matrix which maps the filters onto Fourier series components (ordered [cos0 cos1 sin1 cos2 sin2 … sinN]). See steer_to_harmonics_mtx function for more details. If None (default), assumes cosine phase harmonic components, and filter positions at 2pi*n/N.
return_weights (bool) – Whether to return the weights or not.
even_phase (bool) – Specifies whether the harmonics are cosine or sine phase aligned about those positions.
- Returns:
res – The resteered basis.
steervect – The weights used to resteer the basis. Only returned if return_weights is True.
plenoptic.tools.stats module
- plenoptic.tools.stats.kurtosis(x, mean=None, var=None, dim=None, keepdim=False)[source]
Sample estimate of x tailedness (presence of outliers).
The kurtosis of a univariate normal is 3.
Smaller than 3: platykurtic (e.g., uniform distribution).
Greater than 3: leptokurtic (e.g., Laplace distribution).
- Parameters:
x (Tensor) – The input tensor.
mean (Union[float, Tensor, None]) – Reuse a precomputed mean.
var (Union[float, Tensor, None]) – Reuse a precomputed variance.
dim (Union[int, List[int], None]) – The dimension or dimensions to reduce.
keepdim (bool) – Whether the output tensor has dim retained or not.
- Returns:
The kurtosis tensor.
- Return type:
out
- plenoptic.tools.stats.skew(x, mean=None, var=None, dim=None, keepdim=False)[source]
Sample estimate of x asymmetry about its mean
- Parameters:
x (Tensor) – The input tensor.
mean (Union[float, Tensor, None]) – Reuse a precomputed mean.
var (Union[float, Tensor, None]) – Reuse a precomputed variance.
dim (Union[int, List[int], None]) – The dimension or dimensions to reduce.
keepdim (bool) – Whether the output tensor has dim retained or not.
- Returns:
The skewness tensor.
- Return type:
out
- plenoptic.tools.stats.variance(x, mean=None, dim=None, keepdim=False)[source]
Calculate sample variance.
Note that this is the uncorrected, or sample, variance, corresponding to torch.var(*, correction=0)
- Parameters:
x (Tensor) – The input tensor.
mean (Union[float, Tensor, None]) – Reuse a precomputed mean.
dim (Union[int, List[int], None]) – The dimension or dimensions to reduce.
keepdim (bool) – Whether the output tensor has dim retained or not.
- Returns:
The variance tensor.
- Return type:
out
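The three statistics above can be written as central moments, sketched here in NumPy (not plenoptic's torch code). The variance is uncorrected (divide by n). Skewness is the third standardized moment, and kurtosis is the fourth, so a normal distribution gives 3. Parameter names mirror the documented signature. As an assumption of this sketch, a precomputed mean or var must broadcast against x (for example, computed with keepdims=True).

```python
import numpy as np

def variance(x, mean=None, dim=None, keepdim=False):
    # Uncorrected variance, like torch.var(..., correction=0)
    if mean is None:
        mean = np.mean(x, axis=dim, keepdims=True)
    return np.mean((x - mean) ** 2, axis=dim, keepdims=keepdim)

def _standardized_moment(x, order, mean, var, dim, keepdim):
    if mean is None:
        mean = np.mean(x, axis=dim, keepdims=True)
    if var is None:
        var = variance(x, mean, dim=dim, keepdim=True)
    moment = np.mean((x - mean) ** order, axis=dim, keepdims=True)
    out = moment / var ** (order / 2)
    # Drop the reduced axes unless keepdim is requested
    return out if keepdim else np.squeeze(out, axis=dim)

def skew(x, mean=None, var=None, dim=None, keepdim=False):
    return _standardized_moment(x, 3, mean, var, dim, keepdim)

def kurtosis(x, mean=None, var=None, dim=None, keepdim=False):
    return _standardized_moment(x, 4, mean, var, dim, keepdim)

x = np.array([1.0, 2.0, 3.0, 4.0])
print(variance(x), skew(x), kurtosis(x))  # 1.25, 0 (symmetric), 1.64 (platykurtic)
```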
plenoptic.tools.straightness module
- plenoptic.tools.straightness.deviation_from_line(sequence, normalize=True)[source]
Compute the deviation of sequence to the straight line between its endpoints.
Project each point of the path sequence onto the line defined by the anchor points, and measure the two sides of a right triangle: from the projected point to the first anchor point (aka. distance along line), and from the projected point to the corresponding point on the path sequence (aka. distance from line).
- Parameters:
sequence (Tensor) – Sequence of signals of shape (T, channel, height, width).
normalize (bool) – Use the distance between the anchor points as a unit of measurement.
- Return type:
Tuple[Tensor, Tensor]
- Returns:
dist_along_line – sequence of T Euclidean distances along the line
dist_from_line – sequence of T Euclidean distances to the line
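Here is a NumPy sketch of the projection described above. Each frame is flattened to a vector and projected onto the line through the first and last frames. The length along the line and the residual distance form the two legs of the right triangle. Two details are our own assumptions and may differ from plenoptic's torch code: the along-line distance is a signed projection, and both distances are divided by the anchor distance when normalize is True.

```python
import numpy as np

def deviation_from_line(sequence, normalize=True):
    # Flatten each of the T frames to a vector
    seq = sequence.reshape(sequence.shape[0], -1)
    start, stop = seq[0], seq[-1]
    line = stop - start
    line_length = np.linalg.norm(line)
    unit = line / line_length
    rel = seq - start                  # points relative to the first anchor
    dist_along = rel @ unit            # length of the projection onto the line
    projected = np.outer(dist_along, unit)
    dist_from = np.linalg.norm(rel - projected, axis=1)
    if normalize:
        # Measure in units of the distance between the anchor points
        dist_along = dist_along / line_length
        dist_from = dist_from / line_length
    return dist_along, dist_from

# Three 2-pixel "images": the middle one sits 0.5 above the line
seq = np.array([[0.0, 0.0], [0.5, 0.5], [1.0, 0.0]]).reshape(3, 1, 1, 2)
along, from_line = deviation_from_line(seq)
print(along, from_line)  # [0, 0.5, 1] and [0, 0.5, 0]
```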
- plenoptic.tools.straightness.make_straight_line(start, stop, n_steps)[source]
Make a straight line between start and stop with n_steps transitions.
- Parameters:
start (Tensor) – Images of shape (1, channel, height, width), the anchor points between which a line will be made.
stop (Tensor) – Images of shape (1, channel, height, width), the anchor points between which a line will be made.
n_steps (int) – Number of steps (i.e., transitions) to create between the two anchor points. Must be positive.
- Returns:
Tensor of shape (n_steps+1, channel, height, width)
- Return type:
straight
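A straight line between two images is a linear interpolation with n_steps + 1 evenly spaced weights from 0 to 1, sketched here in NumPy (not plenoptic's torch code). Broadcasting the (n_steps+1, 1, 1, 1) weights against the (1, channel, height, width) anchors yields the documented output shape.

```python
import numpy as np

def make_straight_line(start, stop, n_steps):
    # n_steps transitions need n_steps + 1 points, including both anchors
    t = np.linspace(0.0, 1.0, n_steps + 1).reshape(-1, 1, 1, 1)
    return (1 - t) * start + t * stop

start = np.zeros((1, 1, 2, 2))
stop = np.ones((1, 1, 2, 2))
straight = make_straight_line(start, stop, 4)
print(straight.shape)  # (5, 1, 2, 2)
```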
- plenoptic.tools.straightness.sample_brownian_bridge(start, stop, n_steps, max_norm=1)[source]
Sample a Brownian bridge between start and stop made up of n_steps steps.
- Parameters:
start (Tensor) – Signal of shape (1, channel, height, width), the anchor points between which a random path will be sampled (like pylons on which the bridge will rest).
stop (Tensor) – Signal of shape (1, channel, height, width), the anchor points between which a random path will be sampled (like pylons on which the bridge will rest).
n_steps (int) – Number of steps on the bridge.
max_norm (float) – Controls variability of the bridge by setting how far (in l2 norm) it veers from the straight line interpolation at the midpoint between pylons. Each component of the bridge will reach a maximal variability with std = max_norm / sqrt(d), where d is the dimension of the signal (i.e., d = C*H*W). Must be non-negative.
- Returns:
Sequence of shape (n_steps+1, channel, height, width): a Brownian bridge across the two pylons.
- Return type:
bridge
- plenoptic.tools.straightness.translation_sequence(image, n_steps=10)[source]
Make a horizontal translation sequence on image.
- Parameters:
image (Tensor) – Base image of shape (1, channel, height, width).
n_steps (int) – Number of steps in the sequence. The length of the sequence is n_steps + 1. Must be positive.
- Returns:
translation sequence of shape (n_steps+1, channel, height, width)
- Return type:
sequence
plenoptic.tools.validate module
Functions to validate synthesis inputs.
- plenoptic.tools.validate.remove_grad(model)[source]
Detach all parameters and buffers of model (in place).
- plenoptic.tools.validate.validate_coarse_to_fine(model, image_shape=None, device='cpu')[source]
Determine whether a model can be used for coarse-to-fine synthesis.
In particular, this function checks the following (with associated errors):
Whether model has a scales attribute (AttributeError).
Whether model.forward accepts a scales keyword argument (TypeError).
Whether the output of model.forward changes shape when the scales keyword argument is set (ValueError).
- Parameters:
model (Module) – The model to validate.
image_shape (Optional[Tuple[int, int, int, int]]) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for model, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16).
device (Union[str, device]) – Which device to place the test image on.
- plenoptic.tools.validate.validate_input(input_tensor, no_batch=False, allowed_range=None)[source]
Determine whether input_tensor can be used for synthesis.
In particular, this function:
Checks if input_tensor has a float or complex dtype.
Checks if input_tensor is 4d.
If no_batch is True, checks whether input_tensor.shape[0] != 1.
If allowed_range is not None, checks whether all values of input_tensor lie within the specified range.
If any of the above fail, a ValueError is raised.
- Parameters:
input_tensor (Tensor) – The tensor to validate.
no_batch (bool) – If True, raise a ValueError if the batch dimension of input_tensor is greater than 1.
allowed_range (Optional[Tuple[float, float]]) – If not None, ensure that all values of input_tensor lie within allowed_range.
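Here is a NumPy sketch of the four checks listed above, each raising ValueError. It illustrates the documented behavior only and is not plenoptic's torch code. Assumptions: the input is a NumPy array, the error messages are our own, and the range check is intended for real-valued input.

```python
import numpy as np

def validate_input(input_tensor, no_batch=False, allowed_range=None):
    # 1. float or complex dtype
    if not (np.issubdtype(input_tensor.dtype, np.floating)
            or np.issubdtype(input_tensor.dtype, np.complexfloating)):
        raise ValueError("input must have a float or complex dtype")
    # 2. exactly four dimensions: (batch, channel, height, width)
    if input_tensor.ndim != 4:
        raise ValueError("input must be 4d")
    # 3. optional single-batch requirement
    if no_batch and input_tensor.shape[0] != 1:
        raise ValueError("batch dimension must be 1")
    # 4. optional value range (intended for real-valued input)
    if allowed_range is not None:
        lo, hi = allowed_range
        if input_tensor.min() < lo or input_tensor.max() > hi:
            raise ValueError(f"values must lie within {allowed_range}")

validate_input(np.zeros((1, 1, 4, 4)), no_batch=True, allowed_range=(0.0, 1.0))
```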
- plenoptic.tools.validate.validate_metric(metric, image_shape=None, image_dtype=torch.float32, device='cpu')[source]
Determines whether a metric can be used for MADCompetition synthesis.
In particular, this function checks the following (with associated exceptions):
Whether metric is callable and accepts two 4d tensors as input (TypeError).
Whether metric returns a scalar when called with two 4d tensors as input (ValueError).
Whether metric returns a value less than 5e-7 when called with two identical 4d tensors as input (ValueError). (This threshold was chosen because 1-SSIM of two identical images is 5e-8 on GPU.)
- Parameters:
metric (Union[Module, Callable[[Tensor, Tensor], Tensor]]) – The metric to validate.
image_shape (Optional[Tuple[int, int, int, int]]) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for metric, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16).
image_dtype (dtype) – What dtype to validate against.
device (Union[str, device]) – What device to place the test images on.
- plenoptic.tools.validate.validate_model(model, image_shape=None, image_dtype=torch.float32, device='cpu')[source]
Determine whether model can be used for synthesis.
In particular, this function checks the following (with their associated errors raised):
If model adds a gradient to an input tensor, which implies that some of it is learnable (ValueError).
If model returns a tensor when given a tensor; failure implies that not all computations are done using torch (ValueError).
If model strips gradient from an input with gradient attached (ValueError).
If model casts an input tensor to something else and returns it to a tensor before returning it (ValueError).
If model changes the precision of the input tensor (TypeError).
If model returns a 3d or 4d output when given a 4d input (ValueError).
If model changes the device of the input (RuntimeError).
Finally, we check if model is in training mode and raise a warning if so. Note that this is different from having learnable parameters; see the pytorch docs (https://pytorch.org/docs/stable/notes/autograd.html#locally-disable-grad-doc).
- Parameters:
model (Module) – The model to validate.
image_shape (Optional[Tuple[int, int, int, int]]) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for model, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16).
image_dtype (dtype) – What dtype to validate against.
device (Union[str, device]) – What device to place test image on.
See also
remove_grad
Helper function for detaching all parameters (in place).