"""Gammatone auditory filter bank.

A gammatone filter models the frequency response of a point on the basilar
membrane: a bandpass filter whose impulse response is a gamma-distribution
envelope multiplying a sinusoidal carrier. A bank of these, with centre
frequencies spaced on the ERB-rate scale, approximates the cochlea's
tonotopic frequency analysis.

The digital realisation here is Slaney's efficient 4th-order IIR design
(Slaney, 1993), obtained from ``scipy.signal.gammatone(..., 'iir')``. That
returns an 8th-order transfer function (with non-trivial numerator zeros,
so it is not all-pole), which factors into a **cascade of four biquad
sections** -- the same second-order building block covered in the
:doc:`biquad <../biquad/index>` topic. Processing therefore reuses the exact
SOS cascade machinery (``scipy.signal.sosfilt``), making the "math to metal"
path direct: each channel is four biquads.

ERB (equivalent rectangular bandwidth) and the ERB-rate scale follow
Glasberg & Moore (1990).
"""

import numpy as np
from scipy import signal


# ---------------------------------------------------------------------------
# ERB scale (Glasberg & Moore, 1990)
# ---------------------------------------------------------------------------

def erb(f):
    """Return the equivalent rectangular bandwidth (ERB) at ``f``.

    Implements the Glasberg & Moore (1990) linear approximation
    ``ERB(f) = 24.7 * (0.00437 * f + 1)``, where both ``f`` and the
    result are expressed in Hz.

    Parameters
    ----------
    f : float or array_like
        Centre frequency (or frequencies) in Hz. The input is converted
        to a floating-point NumPy array before the formula is applied.

    Returns
    -------
    float or ndarray
        ERB in Hz, with the same shape as ``f``.
    """
    min_bandwidth_hz = 24.7   # ERB at f = 0 Hz
    growth_per_hz = 0.00437   # relative bandwidth growth per Hz

    freq_hz = np.asarray(f, dtype=float)
    relative_width = growth_per_hz * freq_hz + 1.0
    return min_bandwidth_hz * relative_width


def hz_to_erb_rate(f):
    """Map a frequency in Hz onto the ERB-rate (ERB-number) scale.

    The ERB-rate value of a frequency is the number of equivalent
    rectangular bandwidths that fit below it, so equal increments on
    this scale correspond to equal distances along the basilar membrane.

    ``E(f) = 21.4 * log10(0.00437 * f + 1)``

    Parameters
    ----------
    f : float or array_like
        Frequency (or frequencies) in Hz.

    Returns
    -------
    float or ndarray
        ERB-rate position, with the same shape as ``f``.
    """
    freq_hz = np.asarray(f, dtype=float)
    log_argument = 0.00437 * freq_hz + 1.0
    return 21.4 * np.log10(log_argument)


def erb_rate_to_hz(e):
    """Map an ERB-rate position back to a frequency in Hz.

    This is the algebraic inverse of ``E(f) = 21.4 * log10(0.00437 * f + 1)``
    (see :func:`hz_to_erb_rate`):

    ``f(E) = (10 ** (E / 21.4) - 1) / 0.00437``

    Parameters
    ----------
    e : float or array_like
        Position (or positions) on the ERB-rate scale.

    Returns
    -------
    float or ndarray
        Frequency in Hz, with the same shape as ``e``.
    """
    exponent = np.asarray(e, dtype=float) / 21.4
    relative_width = 10.0 ** exponent
    return (relative_width - 1.0) / 0.00437


def erb_space(low, high, n_channels):
    """Return centre frequencies equally spaced on the ERB-rate scale.

    The two band edges are converted to ERB-rate, the interval between
    them is divided linearly (both endpoints included), and the grid is
    converted back to Hz. In Hz the resulting channels are packed
    densely at low frequencies and sparsely at high frequencies, which
    is how an auditory filter bank mirrors the cochlea's tonotopic map.

    Parameters
    ----------
    low, high : float
        Lowest and highest centre frequency in Hz.
    n_channels : int
        Number of filters in the bank.

    Returns
    -------
    ndarray, shape (n_channels,)
        Centre frequencies in Hz, running from ``low`` to ``high``
        (ascending when ``low < high``).
    """
    # Each edge is mapped on its own so the endpoints are evaluated as
    # plain scalars.
    rate_start = hz_to_erb_rate(low)
    rate_stop = hz_to_erb_rate(high)
    rate_grid = np.linspace(rate_start, rate_stop, n_channels)
    return erb_rate_to_hz(rate_grid)


# ---------------------------------------------------------------------------
# Single gammatone channel as a biquad cascade
# ---------------------------------------------------------------------------

def gammatone_sos(fc, fs, order=4):
    """Design a single gammatone channel as second-order sections.

    The transfer function comes from ``scipy.signal.gammatone`` with
    ``ftype="iir"`` (Slaney's 4th-order IIR gammatone) and is then
    converted to a biquad cascade with ``scipy.signal.tf2sos``.

    Parameters
    ----------
    fc : float
        Centre frequency in Hz. Must satisfy ``0 < fc < fs / 2``.
    fs : float
        Sample rate in Hz.
    order : int, optional
        Gammatone order. Only ``order=4`` is accepted; the value is
        validated here but is not forwarded to SciPy.

    Returns
    -------
    sos : ndarray, shape (4, 6)
        Second-order sections ``[b0, b1, b2, 1, a1, a2]`` per row, ready
        for :func:`scipy.signal.sosfilt`.

    Raises
    ------
    ValueError
        If ``fc`` is not strictly between 0 and the Nyquist frequency,
        or if ``order`` is anything other than 4.
    """
    nyquist = fs / 2
    # Keep the chained comparison: it is False for NaN, so NaN centre
    # frequencies are rejected along with out-of-range ones.
    fc_in_band = 0 < fc < nyquist
    if not fc_in_band:
        raise ValueError(
            f"Centre frequency fc={fc} Hz must lie in (0, fs/2={fs / 2} Hz)."
        )

    if order != 4:
        raise ValueError(
            f"order={order} is not supported; SciPy's IIR gammatone is fixed "
            "at Slaney's 4th order. Use order=4."
        )

    # ``order`` is deliberately left out of the SciPy call: the IIR design
    # is fixed at 4th order, so only the centre frequency and sample rate
    # are needed.
    numerator, denominator = signal.gammatone(fc, ftype="iir", fs=fs)
    sections = signal.tf2sos(numerator, denominator)
    return sections


def make_filterbank(fs, cfs, order=4):
    """Design one gammatone biquad cascade per centre frequency.

    Every entry of ``cfs`` is handed to :func:`gammatone_sos` in turn,
    so any ``ValueError`` raised there (centre frequency outside
    ``(0, fs / 2)``, or ``order != 4``) propagates to the caller.

    Parameters
    ----------
    fs : float
        Sample rate in Hz.
    cfs : array_like
        Centre frequencies in Hz (e.g. from :func:`erb_space`).
    order : int, optional
        Gammatone order (default 4).

    Returns
    -------
    list of ndarray
        One ``(4, 6)`` SOS array per centre frequency, in the same order
        as ``cfs``.
    """
    bank = []
    for centre_hz in cfs:
        channel_sos = gammatone_sos(centre_hz, fs, order=order)
        bank.append(channel_sos)
    return bank


def gammatone_filterbank(x, fs, cfs, order=4):
    """Split a signal into gammatone subbands.

    A biquad cascade is designed for every centre frequency with
    :func:`make_filterbank`, and each cascade is run over the whole
    input independently using ``scipy.signal.sosfilt``. The output
    models basilar-membrane motion at each tonotopic place and is the
    building block of a cochleagram.

    Parameters
    ----------
    x : array_like, shape (M,)
        Input signal.
    fs : float
        Sample rate in Hz.
    cfs : array_like, shape (C,)
        Channel centre frequencies in Hz.
    order : int, optional
        Gammatone order (default 4).

    Returns
    -------
    ndarray, shape (C, M)
        Subband signals, one row per channel, in the order of ``cfs``.

    Raises
    ------
    ValueError
        If ``x`` is not one-dimensional. Errors raised while designing
        the filters also propagate unchanged.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim != 1:
        raise ValueError(f"x must be a 1-D signal, got shape {x.shape}.")

    cascades = make_filterbank(fs, cfs, order=order)
    subbands = np.empty((len(cfs), x.size), dtype=float)
    # Iterating a 2-D array yields row views, so writing through ``row``
    # fills ``subbands`` in place.
    for row, sos in zip(subbands, cascades):
        row[:] = signal.sosfilt(sos, x)
    return subbands


def cochleagram(x, fs, cfs, frame_ms=20.0, hop_ms=10.0, order=4):
    """Compute a cochleagram: short-time per-channel level in dB.

    The signal is split into gammatone subbands with
    :func:`gammatone_filterbank`; each subband is then cut into frames
    and the RMS of every frame is taken. The result is the auditory
    analogue of a spectrogram, with ERB-spaced frequency resolution in
    place of the linear bins of an STFT.

    Parameters
    ----------
    x : array_like, shape (M,)
        Input signal.
    fs : float
        Sample rate in Hz.
    cfs : array_like, shape (C,)
        Channel centre frequencies in Hz.
    frame_ms : float, optional
        Analysis frame length in milliseconds (default 20).
    hop_ms : float, optional
        Hop between frames in milliseconds (default 10).
    order : int, optional
        Gammatone order (default 4).

    Returns
    -------
    coch : ndarray, shape (C, T)
        Frame RMS in dB relative to the largest cell of this call, so
        the maximum is 0 dB; values are floored at -80 dB. Because the
        reference is the per-call peak, results from separate calls are
        not comparable unless the input is pre-normalised.
    times : ndarray, shape (T,)
        Frame centre times in seconds.

    Notes
    -----
    Frame and hop lengths are rounded to whole samples and are never
    shorter than one sample. Only frames that start early enough to fit
    entirely inside the signal are used; a signal shorter than one frame
    yields a single, shorter frame starting at sample 0.
    """
    channels = gammatone_filterbank(x, fs, cfs, order=order)
    n_samples = channels.shape[1]

    # Millisecond durations -> whole samples, at least one sample each.
    frame_len = max(1, int(round(frame_ms * 1e-3 * fs)))
    hop_len = max(1, int(round(hop_ms * 1e-3 * fs)))

    # The bound is clamped to 1 so that at least one frame (start 0) is
    # always produced, even when the signal is shorter than a frame.
    start_bound = max(1, n_samples - frame_len + 1)
    frame_starts = np.arange(0, start_bound, hop_len)

    frame_rms = np.empty((len(cfs), len(frame_starts)), dtype=float)
    for col, begin in enumerate(frame_starts):
        window = channels[:, begin:begin + frame_len]
        mean_square = np.mean(np.square(window), axis=1)
        # The small offset keeps the RMS strictly positive for silence,
        # so the peak used for normalisation below is never zero.
        frame_rms[:, col] = np.sqrt(mean_square + 1e-12)

    # Normalise to the peak, then clamp at 1e-4 (i.e. -80 dB) before the log.
    relative_level = frame_rms / np.max(frame_rms)
    coch_db = 20.0 * np.log10(np.maximum(relative_level, 1e-4))

    frame_centres = (frame_starts + frame_len / 2) / fs
    return coch_db, frame_centres
