How to replicate scipy.signal.correlate2d(x, h) with arbitrarily sized x and h?

`ifft2(fft2(x) * conj(fft2(h)))` gives bad results. I've read related Q&As, but they either do circular cross-correlation or do convolution, which doesn't easily translate.
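A small check of what the naive product actually computes. It assumes real inputs and $h$ no larger than $x$, so that `fft2(h, x.shape)` can zero-pad $h$ up to $x$'s shape (the question's `fft2(h)` only works when the shapes already match). Without padding to `xs + hs - 1` per axis, the FFT product is a circular cross-correlation: the linear 'full' result with its lags wrapped modulo $x$'s shape, which is why the edges come out wrong.

```python
import numpy as np
from scipy.fft import fft2, ifft2
from scipy.signal import correlate2d

rng = np.random.default_rng(1)
x = rng.standard_normal((6, 5))
h = rng.standard_normal((3, 2))

# naive: h is zero-padded only up to x's shape, so the product is circular
naive = ifft2(fft2(x) * np.conj(fft2(h, x.shape))).real

# fold the linear 'full' result modulo x's shape, which is what the wrap does
full = correlate2d(x, h, mode='full')
folded = np.zeros(x.shape)
for m in range(full.shape[0]):
    for n in range(full.shape[1]):
        # 'full' index m corresponds to lag m - (hs - 1); wrap it into [0, xs)
        folded[(m - h.shape[0] + 1) % x.shape[0],
               (n - h.shape[1] + 1) % x.shape[1]] += full[m, n]

print(np.allclose(naive, folded))
```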
I've replicated `scipy.signal.correlate2d` for all modes: `'full'`, `'same'`, `'valid'`.
Scipy's cross-correlation, interestingly, agrees with my philosophy of being defined "backwards". This means we can't simply run convolve logic with a conjugated + flipped kernel, except for 'full' output mode (with correct padding).
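A minimal sketch of the 'full' case described above, assuming zero padding only: zero-pad both arrays to `xs + hs - 1` along each axis and convolve $x$ with the conjugated, flipped $h$. The helper name `correlate2d_full_fft` is hypothetical; only `fft2`, `ifft2` and `correlate2d` are real SciPy functions.

```python
import numpy as np
from scipy.fft import fft2, ifft2
from scipy.signal import correlate2d


def correlate2d_full_fft(x, h):
    """'full'-mode 2D cross-correlation via zero-padded FFT convolution."""
    # a linear (non-circular) result needs xs + hs - 1 samples per axis
    full_shape = (x.shape[0] + h.shape[0] - 1, x.shape[1] + h.shape[1] - 1)
    # correlation with h == convolution with conj(h) flipped on both axes
    k = np.conj(h)[::-1, ::-1]
    # fft2(a, s) zero-pads `a` up to shape `s` before transforming
    return ifft2(fft2(x, full_shape) * fft2(k, full_shape))


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 7)) + 1j * rng.standard_normal((5, 7))
h = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))
print(np.allclose(correlate2d_full_fft(x, h), correlate2d(x, h, mode='full')))
```

In 'full' mode SciPy never swaps the inputs, so this also holds when $h$ is larger than $x$.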
The filter/template $h$ sweeps the input $x$, in a non-commutative manner, including in output shape. The idea is, we seek similarities of template with input at every point of input:
- `'same'`: unpadding is such that `out.shape == x.shape`, and the filter overlaps the input by at least half its size.
- `'valid'`: here the idea mimics `convolve`, except when $h$ is larger than $x$. Then `scipy` does something I can't understand, and it differs from its 1D `correlate`: it swaps the inputs to convolve faster, but then, instead of flipping the output and conjugating, it only flips. I've not investigated, just reproduced it.
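A sketch of the unpadding for real-valued inputs, slicing 'same' and 'valid' out of SciPy's own 'full' output. The offsets (`hs // 2` for 'same', `min(xs, hs) - 1` for 'valid') are the ones the function further down uses; `unpad_from_full` is a hypothetical helper. For complex inputs with a swapped 'valid' case, SciPy's result would additionally be the complex conjugate of this slice, per the "only flips" behaviour described above.

```python
import numpy as np
from scipy.signal import correlate2d


def unpad_from_full(full, xs, hs, mode):
    """Slice 'same' or 'valid' out of a 'full' correlate2d result (real inputs)."""
    if mode == 'same':
        # output keeps x's shape and starts hs // 2 samples into 'full'
        r0, c0 = hs[0] // 2, hs[1] // 2
        return full[r0:r0 + xs[0], c0:c0 + xs[1]]
    if mode == 'valid':
        # lags where the smaller array lies entirely inside the larger one;
        # symmetric in x and h, which matches scipy's swap for real inputs
        lo = [min(a, b) - 1 for a, b in zip(xs, hs)]
        hi = [max(a, b) for a, b in zip(xs, hs)]
        return full[lo[0]:hi[0], lo[1]:hi[1]]
    return full  # 'full'


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 7))
h = rng.standard_normal((4, 3))
full = correlate2d(x, h, mode='full')
for mode in ('same', 'valid'):
    print(mode, np.allclose(unpad_from_full(full, x.shape, h.shape, mode),
                            correlate2d(x, h, mode=mode)))
# the output shape follows the first argument: x and h are not interchangeable
print(correlate2d(x, h, mode='same').shape, correlate2d(h, x, mode='same').shape)
```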
I've only implemented zero-padding. Non-zero padding is tricky and more complicated if it is to remain performant.
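To illustrate why non-zero padding costs more, here is one straightforward (not performant) way to get it, assuming real inputs and 'full' mode: pad $x$ explicitly with the fill value by `hs - 1` on every side, then keep the 'valid' part of a zero-padded FFT correlation. It should match `correlate2d(..., boundary='fill', fillvalue=c)`, but it allocates a larger array and takes larger FFTs. `correlate2d_full_fill` is a hypothetical name.

```python
import numpy as np
from scipy.fft import fft2, ifft2
from scipy.signal import correlate2d


def correlate2d_full_fill(x, h, fillvalue):
    """'full' cross-correlation of real 2D arrays, with `x` padded by `fillvalue`."""
    ph, pw = h.shape[0] - 1, h.shape[1] - 1
    # explicit padding: every lag of 'full' now reads fill values, not zeros
    xp = np.pad(x, ((ph, ph), (pw, pw)), mode='constant', constant_values=fillvalue)
    full_shape = (xp.shape[0] + h.shape[0] - 1, xp.shape[1] + h.shape[1] - 1)
    k = h[::-1, ::-1]  # real template: flipping alone is enough
    full = ifft2(fft2(xp, full_shape) * fft2(k, full_shape)).real
    # the 'valid' part of the padded correlation is the 'full' result with fill
    return full[ph:ph + x.shape[0] + ph, pw:pw + x.shape[1] + pw]


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6))
h = rng.standard_normal((3, 2))
print(np.allclose(correlate2d_full_fill(x, h, 2.5),
                  correlate2d(x, h, mode='full', boundary='fill', fillvalue=2.5)))
```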
Extra speedups
- `next_fast_len`: FFTs are done with fast FFT lengths instead of naive padding
- `workers`: parallel FFTs, a `scipy.fft` feature
- Don't pad explicitly; instead take bigger FFTs
- `inplace`: operate in-place where possible instead of allocating new arrays
- `reusables`: if $x$ has the same shape and $h$ doesn't change, we effectively cache the computation they have in common
- `real`: if $x$ and $h$ are real-valued, skip `conj`
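A sketch combining the fast-length, no-explicit-padding, `workers` and `real` ideas for real inputs. It swaps in `rfft2`/`irfft2`, which the answer's function does not use; they exploit the symmetry of real-input spectra to roughly halve the FFT work. `next_fast_len(n, real=True)` picks a fast length for real transforms, and the `s` argument does the zero padding. `correlate2d_full_real` is a hypothetical name.

```python
import numpy as np
from scipy.fft import next_fast_len, rfft2, irfft2
from scipy.signal import correlate2d


def correlate2d_full_real(x, h, workers=None):
    """'full' cross-correlation of real 2D arrays with real FFTs."""
    full_shape = (x.shape[0] + h.shape[0] - 1, x.shape[1] + h.shape[1] - 1)
    # round each axis up to a length with only small prime factors;
    # the extra zeros do not change the first full_shape samples
    fast_shape = tuple(next_fast_len(n, real=True) for n in full_shape)
    k = h[::-1, ::-1]  # real template: no conj needed
    X = rfft2(x, fast_shape, workers=workers)
    K = rfft2(k, fast_shape, workers=workers)
    out = irfft2(X * K, fast_shape, workers=workers)
    return out[:full_shape[0], :full_shape[1]]


rng = np.random.default_rng(0)
x = rng.standard_normal((37, 53))
h = rng.standard_normal((9, 6))
out = correlate2d_full_real(x, h, workers=2)
print(out.shape, np.allclose(out, correlate2d(x, h, mode='full')))
```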
Instead of conjugating the FFT, I conjugate and then flip $h$, which is faster for real inputs but otherwise slower. Conjugating the FFT is also trickier to implement efficiently.
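A quick numerical check that both routes agree for complex inputs (`full_corr_both_ways` is a hypothetical helper). Conjugating the spectrum gives a circular cross-correlation whose negative lags wrap to the end of the array. It therefore needs an extra `np.roll` (or an equivalent phase ramp in the frequency domain) to line up with `correlate2d`'s 'full' indexing. Conjugating and flipping $h$ lands the output in place directly, which is one plausible reason for the extra difficulty mentioned above.

```python
import numpy as np
from scipy.fft import fft2, ifft2
from scipy.signal import correlate2d


def full_corr_both_ways(x, h):
    s = (x.shape[0] + h.shape[0] - 1, x.shape[1] + h.shape[1] - 1)
    # (a) conjugate + flip h in the spatial domain, then plain FFT convolution
    a = ifft2(fft2(x, s) * fft2(np.conj(h)[::-1, ::-1], s))
    # (b) conjugate h's spectrum instead: circular correlation whose negative
    #     lags wrap to the end of the array, so roll them back to the front
    b = ifft2(fft2(x, s) * np.conj(fft2(h, s)))
    b = np.roll(b, (h.shape[0] - 1, h.shape[1] - 1), axis=(0, 1))
    return a, b


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 5)) + 1j * rng.standard_normal((6, 5))
h = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))
a, b = full_corr_both_ways(x, h)
print(np.allclose(a, b), np.allclose(a, correlate2d(x, h, mode='full')))
```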
Code + testing
Available on GitHub.
Just the function
Not up to date; I'll only be updating the GitHub code.
```python
import numpy as np
from scipy.fft import next_fast_len, fft2, ifft2


def cross_correlate_2d(x, h, mode='same', real=True, get_reusables=False):
    """2D cross-correlation, replicating scipy.signal.correlate2d.

    `reusables` are passed in as `h`.
    Set `get_reusables=True` to return `out, reusables`.
    `real=True` assumes both `x` and `h` are real-valued (skips `conj`).
    """
    if mode not in ('full', 'same', 'valid'):
        raise ValueError("`mode` must be 'full', 'same' or 'valid', "
                         "got {!r}".format(mode))

    # check if `h` is reusables
    if not isinstance(h, tuple):
        # fetch shapes, check inputs
        xs, hs = x.shape, h.shape
        h_not_smaller = all(hs[i] >= xs[i] for i in (0, 1))
        x_not_smaller = all(xs[i] >= hs[i] for i in (0, 1))
        if mode == 'valid' and not (h_not_smaller or x_not_smaller):
            raise ValueError(
                "For `mode='valid'`, every axis in `x` must be at least "
                "as long as in `h`, or vice versa. Got x:{}, h:{}".format(
                    str(xs), str(hs)))

        # swap if needed (scipy does this for 'valid' when `h` is larger)
        swap = bool(mode == 'valid' and not x_not_smaller)
        if swap:
            xadj, hadj = h, x
        else:
            xadj, hadj = x, h
        xs, hs = xadj.shape, hadj.shape

        # compute pad quantities: a linear (non-circular) result needs
        # xs + hs - 1 samples per axis; round up to a fast FFT length
        full_len_h = xs[0] + hs[0] - 1
        full_len_w = xs[1] + hs[1] - 1
        padded_len_h = next_fast_len(full_len_h)
        padded_len_w = next_fast_len(full_len_w)
        padded_shape = (padded_len_h, padded_len_w)

        # compute unpad indices
        if mode == 'full':
            offset_h, offset_w = 0, 0
            len_h, len_w = full_len_h, full_len_w
        elif mode == 'same':
            len_h, len_w = xs
            offset_h, offset_w = [g // 2 for g in hs]
        else:  # 'valid'
            ax_pairs = ((xs[0], hs[0]), (xs[1], hs[1]))
            len_h, len_w = [max(g) - min(g) + 1 for g in ax_pairs]
            offset_h, offset_w = [min(g) - 1 for g in ax_pairs]
        unpad_h = slice(offset_h, offset_h + len_h)
        unpad_w = slice(offset_w, offset_w + len_w)

        # handle filter / template: correlating with `h` is convolving
        # with conj(h) flipped on both axes
        if real:
            hadj = hadj[::-1, ::-1]
        else:
            hadj = np.conj(hadj)[::-1, ::-1]
        hf = fft2(hadj, padded_shape)
    else:
        reusables = h
        (hf, swap, padded_shape, unpad_h, unpad_w) = reusables
        if swap:
            # `hf` was built from the previous `x`, and the original `h`
            # isn't available here, so the swapped case can't be reused
            raise ValueError("`reusables` can't be used when the inputs were "
                             "swapped (mode='valid' with `h` larger than `x`)")
        xadj = x

    # FFT convolution; `fft2(..., padded_shape)` zero-pads implicitly
    out = ifft2(fft2(xadj, padded_shape) * hf)
    if real:
        out = out.real

    # unpad, unswap
    out = out[unpad_h, unpad_w]
    if swap:
        out = out[::-1, ::-1]

    # pack reusables
    if get_reusables:
        reusables = (hf, swap, padded_shape, unpad_h, unpad_w)
    return (out, reusables) if get_reusables else out
```
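The `reusables` mechanism boils down to computing the template's spectrum once. Below is a stripped-down, standalone sketch of that idea for 'full' mode only. The class `CachedTemplateCorrelator` is hypothetical and is neither part of the function above nor of SciPy.

```python
import numpy as np
from scipy.fft import next_fast_len, fft2, ifft2
from scipy.signal import correlate2d


class CachedTemplateCorrelator:
    """Hypothetical helper: 'full' cross-correlation of many same-shaped
    inputs against one fixed template, reusing the template's FFT."""

    def __init__(self, h, x_shape):
        self.full_shape = (x_shape[0] + h.shape[0] - 1,
                           x_shape[1] + h.shape[1] - 1)
        self.fft_shape = tuple(next_fast_len(n) for n in self.full_shape)
        # conjugate + flip once, transform once
        self.hf = fft2(np.conj(h)[::-1, ::-1], self.fft_shape)

    def __call__(self, x):
        # one forward and one inverse FFT per input; drop the extra padding
        out = ifft2(fft2(x, self.fft_shape) * self.hf)
        return out[:self.full_shape[0], :self.full_shape[1]]


rng = np.random.default_rng(0)
h = rng.standard_normal((5, 4)) + 1j * rng.standard_normal((5, 4))
corr = CachedTemplateCorrelator(h, (32, 32))
frames = rng.standard_normal((3, 32, 32))
results = [corr(frame) for frame in frames]
print(all(np.allclose(r, correlate2d(f, h, mode='full'))
          for r, f in zip(results, frames)))
```

Typical use would be one instance per template and input shape, called on every frame of a stack, so only one forward and one inverse FFT run per frame.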