How do I replicate scipy.signal.correlate2d(x, h) with arbitrarily sized x and h? ifft2(fft2(x) * conj(fft2(h))) gives bad results. I've read related Q&As, but they either do circular cross-correlation or do convolution, which doesn't easily translate.

OverLordGoldDragon

1 Answer

I've replicated scipy.signal.correlate2d for all modes: 'full', 'same', 'valid'. Interestingly, scipy's cross-correlation agrees with my philosophy of being defined "backwards". This means we can't simply run convolve logic with a conjugated + flipped kernel, except in 'full' output mode (with correct padding).
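The 'full' exception can be checked against scipy directly. A minimal sketch, with array shapes picked arbitrarily for illustration; the kernel is even-sized along both axes so that the 'same' mismatch shows up:

```python
import numpy as np
from scipy.signal import convolve2d, correlate2d

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 6)) + 1j * rng.standard_normal((5, 6))
h = rng.standard_normal((4, 2)) + 1j * rng.standard_normal((4, 2))

# conjugated + flipped kernel
h_cf = np.conj(h)[::-1, ::-1]

# 'full': convolution with the conjugated + flipped kernel matches
full_match = np.allclose(correlate2d(x, h, mode='full'),
                         convolve2d(x, h_cf, mode='full'))
# 'same': the two outputs are cropped at different offsets for even sizes
same_match = np.allclose(correlate2d(x, h, mode='same'),
                         convolve2d(x, h_cf, mode='same'))
print(full_match, same_match)
```

For odd-sized kernels the two 'same' outputs happen to coincide, which is why the difference is easy to miss.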

The filter/template $h$ sweeps the input $x$ in a non-commutative manner, and that extends to the output shape. The idea is that we seek similarities of the template with the input at every point of the input:

  • 'same': unpadding is such that out.shape == x.shape, and the filter overlaps the input by at least half its size.
  • 'valid': here the idea mimics convolve, except when $h$ is larger than $x$. Then scipy does something I can't understand, which differs from its 1D correlate: it swaps the inputs to convolve faster, but then instead of flipping the output and conjugating, it only flips. I've not investigated, just reproduced it.
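The modes above can be viewed as crops of the 'full' output. A minimal sketch with arbitrary example shapes, using the same offsets as the function further down (`h // 2` for 'same', `min - 1` for 'valid'); the last lines reproduce the swap-and-flip behavior for a larger template:

```python
import numpy as np
from scipy.signal import correlate2d

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 7))
h = rng.standard_normal((4, 3))  # one even-sized axis, one odd-sized axis

full = correlate2d(x, h, mode='full')
(xh, xw), (hh, hw) = x.shape, h.shape

# 'same': start at h // 2 along each axis, keep x.shape samples
same = full[hh // 2:hh // 2 + xh, hw // 2:hw // 2 + xw]
# 'valid' (x at least as large as h): start at h - 1, keep x - h + 1 samples
valid = full[hh - 1:xh, hw - 1:xw]

same_ok = np.allclose(same, correlate2d(x, h, mode='same'))
valid_ok = np.allclose(valid, correlate2d(x, h, mode='valid'))

# template larger than the input on every axis: scipy swaps, then only flips
big = rng.standard_normal((8, 9))
swapped = correlate2d(x, big, mode='valid')
flip_ok = np.allclose(swapped, correlate2d(big, x, mode='valid')[::-1, ::-1])
print(same_ok, valid_ok, flip_ok, swapped.shape)
```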

I've only implemented zero-padding. Non-zero padding is tricky and more complicated if it is to remain performant.

Extra speedups

  1. next_fast_len: FFTs are done with fast FFT lengths instead of naive padding
  2. workers: parallelized FFTs, a scipy.fft feature
  3. Don't explicitly pad, instead take bigger FFTs
  4. inplace: operate in-place where possible instead of allocating new arrays
  5. reusables: if $x$ keeps the same shape and $h$ doesn't change, we cache the computation the calls have in common
  6. real: if $x$ and $h$ are real-valued, skip conj
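Several of these speedups fit in a few lines. A rough sketch for 'full' mode with real inputs, not the Github implementation; the shapes and the loop over three inputs are made up for illustration:

```python
import numpy as np
from scipy.fft import fft2, ifft2, next_fast_len
from scipy.signal import correlate2d

rng = np.random.default_rng(0)
h = rng.standard_normal((9, 6))
x_shape = (60, 45)

# 1. fast FFT lengths, at least as long as the 'full' output
full_shape = tuple(a + b - 1 for a, b in zip(x_shape, h.shape))
fast_shape = tuple(next_fast_len(n) for n in full_shape)

# 5. + 6. template FFT computed once; `h` is real, so flip without conj
hf = fft2(h[::-1, ::-1], fast_shape)

matches = []
for _ in range(3):
    x = rng.standard_normal(x_shape)
    # 3. passing the target shape zero-pads implicitly; 2. `workers` threads
    xf = fft2(x, fast_shape, workers=-1)
    xf *= hf  # 4. multiply in place
    out = ifft2(xf, workers=-1).real[:full_shape[0], :full_shape[1]]
    matches.append(np.allclose(out, correlate2d(x, h, mode='full')))
print(matches)
```

Only the FFT of `x`, the product, and the inverse FFT are repeated per input; everything that depends on `h` alone is computed once.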

Instead of conjugating the fft, I conjugate then flip $h$, which is faster for real inputs but otherwise slower. Conjugating the fft is also trickier to implement efficiently.
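To illustrate the two routes, here is a small sketch for 'full' mode with complex inputs (shapes are arbitrary). It also suggests why the unpadded formula from the question misbehaves: conjugating the FFT gives circular lags, so both FFTs need a common padded shape, and the negative lags have to be rolled to the front before cropping.

```python
import numpy as np
from scipy.fft import fft2, ifft2, next_fast_len
from scipy.signal import correlate2d

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 9)) + 1j * rng.standard_normal((6, 9))
h = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))

full = tuple(a + b - 1 for a, b in zip(x.shape, h.shape))
s = tuple(next_fast_len(n) for n in full)

# Route 1: conjugate + flip `h` before the FFT
out_flip = ifft2(fft2(x, s) * fft2(np.conj(h)[::-1, ::-1], s))
out_flip = out_flip[:full[0], :full[1]]

# Route 2: conjugate the FFT of `h`; negative lags wrap around to the end,
# so roll them to the front by (h.shape - 1) before cropping
c = ifft2(fft2(x, s) * np.conj(fft2(h, s)))
c = np.roll(c, (h.shape[0] - 1, h.shape[1] - 1), axis=(0, 1))
out_conj = c[:full[0], :full[1]]

ref = correlate2d(x, h, mode='full')
print(np.allclose(out_flip, ref), np.allclose(out_conj, ref))
```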

Code + testing

Available at Github.

Just the function

Not up to date; I'll only be updating the Github code.

import numpy as np
import scipy.signal
from scipy.fft import next_fast_len, fft2, ifft2

def cross_correlate_2d(x, h, mode='same', real=True, get_reusables=False):
    """2D cross-correlation, replicating `scipy.signal.correlate2d`.

    `reusables` are passed in as `h`.
    Set `get_reusables=True` to return `out, reusables`.
    """
    # check if `h` is reusables
    if not isinstance(h, tuple):
        # fetch shapes, check inputs
        xs, hs = x.shape, h.shape
        h_not_smaller = all(hs[i] >= xs[i] for i in (0, 1))
        x_not_smaller = all(xs[i] >= hs[i] for i in (0, 1))
        if mode == 'valid' and not (h_not_smaller or x_not_smaller):
            raise ValueError(
                "For `mode='valid'`, every axis in `x` must be at least "
                "as long as in `h`, or vice versa. Got x:{}, h:{}".format(
                    str(xs), str(hs)))

        # swap if needed
        swap = bool(mode == 'valid' and not x_not_smaller)
        if swap:
            xadj, hadj = h, x
        else:
            xadj, hadj = x, h
        xs, hs = xadj.shape, hadj.shape

        # compute pad quantities
        full_len_h = xs[0] + hs[0] - 1
        full_len_w = xs[1] + hs[1] - 1
        padded_len_h = next_fast_len(full_len_h)
        padded_len_w = next_fast_len(full_len_w)
        padded_shape = (padded_len_h, padded_len_w)

        # compute unpad indices
        if mode == 'full':
            offset_h, offset_w = 0, 0
            len_h, len_w = full_len_h, full_len_w
        elif mode == 'same':
            len_h, len_w = xs
            offset_h, offset_w = [g//2 for g in hs]
        elif mode == 'valid':
            ax_pairs = ((xs[0], hs[0]), (xs[1], hs[1]))
            len_h, len_w = [max(g) - min(g) + 1 for g in ax_pairs]
            offset_h, offset_w = [min(g) - 1 for g in ax_pairs]
        else:
            raise ValueError("`mode` must be 'full', 'same', or 'valid'; "
                             "got {!r}".format(mode))
        unpad_h = slice(offset_h, offset_h + len_h)
        unpad_w = slice(offset_w, offset_w + len_w)

        # handle filter / template: flip (and conjugate if complex), then FFT
        if real:
            hadj = hadj[::-1, ::-1]
        else:
            hadj = np.conj(hadj)[::-1, ::-1]
        hf = fft2(hadj, padded_shape)
    else:
        # `h` holds the `reusables` from an earlier call
        hf, swap, padded_shape, unpad_h, unpad_w = h
        if swap:
            # with swapped inputs `hf` is the FFT of the earlier `x`, not of
            # the template, so it can't be applied to a new `x`
            raise ValueError(
                "`reusables` can't be used when the inputs were swapped "
                "(`mode='valid'` with `h` larger than `x`).")
        xadj = x

    # FFT convolution
    out = ifft2(fft2(xadj, padded_shape) * hf)
    if real:
        out = out.real

    # unpad, unswap
    out = out[unpad_h, unpad_w]
    if swap:
        out = out[::-1, ::-1]

    # pack reusables
    if get_reusables:
        reusables = (hf, swap, padded_shape, unpad_h, unpad_w)

    # return
    return ((out, reusables) if get_reusables else
            out)

OverLordGoldDragon