"""Gabor filters and the Difference-of-Gaussians, a visual-cortex front end.

A 2-D Gabor filter is a Gaussian envelope multiplying an oriented sinusoid.
It is the function that achieves the joint space / spatial-frequency
uncertainty bound (the 2-D form of the Gabor limit), which is why the simple
cells of the primary visual cortex are well modelled by Gabor receptive
fields (Daugman, 1985). A bank of them, spanning orientations and scales,
tiles spatial frequency the way the cochlea's gammatone bank tiles audio
frequency.

This module wraps scikit-image's well-tested Gabor routines with a small,
importable API: single kernels, banks, magnitude responses, an energy
feature stack, and the center-surround Difference-of-Gaussians (DoG) that
models retinal/LGN receptive fields.
"""

import itertools

import numpy as np
from scipy import ndimage
from skimage.filters import gabor, gabor_kernel


# ---------------------------------------------------------------------------
# Single Gabor kernel
# ---------------------------------------------------------------------------

def gabor_kernel_2d(frequency, theta=0.0, bandwidth=1.0, offset=0.0,
                    sigma_x=None, sigma_y=None, n_stds=3):
    """Complex 2-D Gabor kernel.

    The kernel is ``g(x, y) = 1 / (2 pi sigma_x sigma_y) *
    exp(-(x'^2 / sigma_x^2 + y'^2 / sigma_y^2) / 2) *
    exp(i (2 pi f x' + offset))`` where ``(x', y')`` are the coordinates
    rotated by ``theta``. The real part is an even (cosine) filter, the
    imaginary part an odd (sine) filter.

    By default both envelope widths are derived from the octave
    ``bandwidth``, so ``sigma_x == sigma_y`` and the aspect ratio is
    gamma = 1 (a circular Gaussian). The general 2-D Gabor carries a gamma
    aspect-ratio term; pass ``sigma_x`` and/or ``sigma_y`` explicitly to
    obtain an elongated envelope (gamma = sigma_x / sigma_y). An explicit
    sigma overrides the bandwidth-derived width along that axis only.

    Parameters
    ----------
    frequency : float
        Spatial frequency of the carrier in cycles per pixel.
    theta : float, optional
        Orientation in radians (0 detects vertical bars).
    bandwidth : float, optional
        Spatial-frequency bandwidth in octaves; sets the envelope width for
        any axis whose sigma is not given explicitly.
    offset : float, optional
        Phase offset of the carrier in radians.
    sigma_x, sigma_y : float or None, optional
        Envelope standard deviations (pixels) along the rotated x' and y'
        axes. ``None`` (the default) derives them from ``bandwidth``.
    n_stds : float, optional
        Kernel half-extent, in standard deviations of the envelope.

    Returns
    -------
    ndarray (complex)
        The Gabor kernel.
    """
    # The new keywords are forwarded by name. Their defaults (None, None, 3)
    # are scikit-image's own, so existing calls build exactly the same kernel.
    return gabor_kernel(frequency, theta=theta, bandwidth=bandwidth,
                        sigma_x=sigma_x, sigma_y=sigma_y, n_stds=n_stds,
                        offset=offset)


# ---------------------------------------------------------------------------
# Filter bank
# ---------------------------------------------------------------------------

def gabor_bank(frequencies, orientations, bandwidth=1.0):
    """Build a Gabor filter bank over frequencies and orientations.

    Parameters
    ----------
    frequencies : array_like
        Carrier frequencies in cycles per pixel (the scales). Any finite
        iterable is accepted, including one-shot iterators.
    orientations : array_like
        Orientations in radians. Any finite iterable is accepted, including
        one-shot iterators.
    bandwidth : float, optional
        Octave bandwidth shared by all filters.

    Returns
    -------
    list of dict
        One entry per ``(frequency, theta)`` pair with keys
        ``'frequency'``, ``'theta'``, and ``'kernel'`` (complex ndarray).
        Entries are frequency-major: all orientations of the first
        frequency come first.
    """
    # itertools.product materialises both inputs before iterating. A
    # generator of orientations is therefore reused for every frequency
    # instead of being exhausted after the first one. Orientation varies
    # fastest, which is the same order as a nested frequency/orientation loop.
    return [
        {
            "frequency": float(f),
            "theta": float(theta),
            "kernel": gabor_kernel_2d(f, theta=theta, bandwidth=bandwidth),
        }
        for f, theta in itertools.product(frequencies, orientations)
    ]


def gabor_response(image, frequency, theta=0.0, bandwidth=1.0):
    """Phase-invariant Gabor magnitude of an image at one scale and angle.

    The image is filtered with the even (cosine) and odd (sine) halves of a
    complex Gabor kernel. The two responses form a quadrature pair, so the
    root of the sum of their squares does not depend on the local phase:
    it reports how much oriented structure at this scale sits at each
    pixel, whatever the edge polarity. This is the energy model of a
    complex cell.

    Parameters
    ----------
    image : ndarray, shape (H, W)
        Grayscale image.
    frequency : float
        Carrier frequency in cycles per pixel.
    theta : float, optional
        Orientation in radians.
    bandwidth : float, optional
        Octave bandwidth.

    Returns
    -------
    ndarray, shape (H, W)
        Per-pixel response magnitude (non-negative).
    """
    pixels = np.asarray(image, dtype=float)
    even, odd = gabor(pixels, frequency=frequency, theta=theta,
                      bandwidth=bandwidth)
    # Combine the quadrature pair; the sqrt gives the complex modulus.
    energy = np.square(even) + np.square(odd)
    return np.sqrt(energy)


def gabor_feature_stack(image, frequencies, orientations, bandwidth=1.0):
    """Stack of Gabor magnitude responses across the whole bank.

    Parameters
    ----------
    image : ndarray, shape (H, W)
        Grayscale image.
    frequencies, orientations : array_like
        Bank scales (cycles/pixel) and orientations (radians). Any finite
        iterable is accepted, including one-shot iterators.
    bandwidth : float, optional
        Octave bandwidth.

    Returns
    -------
    stack : ndarray, shape (C, H, W)
        One magnitude map per channel, ``C = len(frequencies) *
        len(orientations)``. If either input is empty, ``C == 0`` and an
        empty ``(0, H, W)`` float array is returned.
    params : list of tuple
        The ``(frequency, theta)`` pair for each channel, in order
        (frequency-major, orientation varying fastest).
    """
    image = np.asarray(image, dtype=float)
    # Build the channel list once. itertools.product consumes each input a
    # single time, so a generator of orientations is not exhausted after
    # the first frequency the way a nested loop would exhaust it.
    pairs = list(itertools.product(frequencies, orientations))
    params = [(float(f), float(theta)) for f, theta in pairs]
    if not pairs:
        # np.stack([]) raises; return an empty stack matching C == 0 instead.
        return np.empty((0,) + image.shape, dtype=float), params
    maps = [gabor_response(image, f, theta=theta, bandwidth=bandwidth)
            for f, theta in pairs]
    return np.stack(maps, axis=0), params


def dominant_orientation(image, frequency, orientations, bandwidth=1.0):
    """Orientation that maximises total Gabor energy in an image.

    Useful for textures or gratings with a single dominant orientation.

    Parameters
    ----------
    image : ndarray, shape (H, W)
        Grayscale image.
    frequency : float
        Carrier frequency in cycles per pixel.
    orientations : array_like
        Candidate orientations in radians. Any finite iterable is accepted,
        including one-shot iterators.
    bandwidth : float, optional
        Octave bandwidth.

    Returns
    -------
    float
        The orientation in ``orientations`` (radians) with the largest
        summed magnitude response. On a tie, the earliest candidate wins.

    Raises
    ------
    ValueError
        If ``orientations`` is empty.
    """
    # Materialise once. Indexing back into a generator would fail, and an
    # empty input deserves a clearer message than np.argmax's.
    candidates = list(orientations)
    if not candidates:
        raise ValueError("orientations must contain at least one angle.")
    # Convert the image once rather than on every gabor_response call.
    image = np.asarray(image, dtype=float)
    energies = [gabor_response(image, frequency, theta=t,
                               bandwidth=bandwidth).sum()
                for t in candidates]
    return float(candidates[int(np.argmax(energies))])


# ---------------------------------------------------------------------------
# Difference of Gaussians (center-surround receptive field)
# ---------------------------------------------------------------------------

def dog(image, sigma1, sigma2):
    """Difference of Gaussians: a center-surround / blob detector.

    Models the retinal and LGN receptive fields, where an excitatory centre
    is opposed by an inhibitory surround. Computed as the difference of two
    Gaussian blurs; ``sigma1 < sigma2`` gives the standard centre-surround
    (positive centre, negative surround).

    Parameters
    ----------
    image : ndarray, shape (H, W)
        Grayscale image.
    sigma1 : float
        Centre (narrow) Gaussian standard deviation in pixels. Must be
        non-negative; 0 leaves the centre term unblurred.
    sigma2 : float
        Surround (wide) Gaussian standard deviation in pixels.

    Returns
    -------
    ndarray, shape (H, W)
        The DoG response (signed).

    Raises
    ------
    ValueError
        If ``sigma1 >= sigma2`` (no centre-surround structure), or if either
        sigma is negative.
    """
    if sigma1 >= sigma2:
        raise ValueError(
            f"sigma1={sigma1} must be smaller than sigma2={sigma2} "
            "for a centre-surround Difference of Gaussians."
        )
    # A negative sigma would otherwise reach gaussian_filter and fail with an
    # obscure empty-filter-weights error. np.any also covers per-axis
    # sequence sigmas, which gaussian_filter accepts.
    if np.any(np.asarray(sigma1) < 0) or np.any(np.asarray(sigma2) < 0):
        raise ValueError(
            f"sigma1={sigma1} and sigma2={sigma2} must be non-negative."
        )
    image = np.asarray(image, dtype=float)
    return (ndimage.gaussian_filter(image, sigma1)
            - ndimage.gaussian_filter(image, sigma2))
