"""Matched filtering: optimal detection of known signals in noise."""

import numpy as np
from scipy import signal as sig


def matched_filter(template: np.ndarray, received: np.ndarray,
                   normalize: bool = True) -> np.ndarray:
    """Detect a known template in a received signal by direct correlation.

    A matched filter is the time-reversed, conjugated template; applying it
    is the same as cross-correlating the received signal with the template.
    For a known template in additive white noise this maximises output SNR.
    Both inputs are cast to float, so only real-valued signals are supported
    (the conjugation is then a no-op).

    Parameters
    ----------
    template : array_like
        The known signal to detect.
    received : array_like
        The received signal (template + noise + possibly shifted copies).
    normalize : bool
        If True, divide the output by the template energy (sum of squares),
        so a perfect unit-amplitude match peaks at 1.0. The division is
        skipped when that energy is not positive (e.g. an all-zero template).

    Returns
    -------
    output : ndarray
        Full-mode correlation of length ``len(received) + len(template) - 1``.
        Index ``i`` corresponds to lag ``i - (len(template) - 1)``; peaks mark
        detected template locations.
    """
    h = np.asarray(template, dtype=float)
    x = np.asarray(received, dtype=float)
    correlation = np.correlate(x, h, mode='full')
    if not normalize:
        return correlation
    template_energy = np.sum(h ** 2)
    # `> 0` (rather than `<= 0` for the skip) deliberately leaves a NaN energy
    # un-applied, matching the guard's intent of normalizing only by a
    # meaningful positive energy.
    return correlation / template_energy if template_energy > 0 else correlation


def matched_filter_fft(template: np.ndarray, received: np.ndarray,
                       normalize: bool = True) -> np.ndarray:
    """Matched filter using FFT-based correlation (efficient for long signals).

    Computes the cross-correlation of ``received`` with ``template`` for
    non-negative lags only. Both signals are zero-padded to a power-of-two
    length of at least ``len(received) + len(template) - 1`` so the
    frequency-domain (circular) correlation has no wrap-around in the
    returned lags.

    Parameters
    ----------
    template : array_like
        The known signal to detect. Cast to float (real-valued only).
    received : array_like
        The received signal. Cast to float (real-valued only).
    normalize : bool
        If True, normalize by template energy (sum of squares), so a perfect
        unit-amplitude match peaks at 1.0. Skipped when the energy is not
        positive (e.g. an all-zero template).

    Returns
    -------
    output : ndarray
        Matched filter output (same length as received). ``output[k]`` is the
        correlation at a lag of ``k`` samples. If either input is empty, an
        all-zero array of length ``len(received)`` is returned.

    Notes
    -----
    Returns output aligned to lag 0 (same length as received). Note:
    ``matched_filter()`` with mode='full' returns len(received)+len(template)-1
    samples starting from negative lags. To compare: ``fft_out[k]`` corresponds
    to ``direct_out[k + len(template) - 1]``.
    """
    template = np.asarray(template, dtype=float)
    received = np.asarray(received, dtype=float)
    n = len(received)
    m = len(template)
    if n == 0 or m == 0:
        # Nothing to correlate: the correlation is identically zero, and the
        # FFT sizing below would be degenerate (zero or too few points).
        return np.zeros(n)
    # Smallest power of two >= n + m - 1 (the linear-correlation length),
    # computed with exact integer arithmetic instead of float log2.
    nfft = 1 << (n + m - 2).bit_length()
    # Inputs are real, so the half-spectrum rfft/irfft pair yields the same
    # correlation as complex fft/ifft at roughly half the cost.
    # conj(H) * X in frequency corresponds to correlation in time.
    spectrum = np.conj(np.fft.rfft(template, nfft)) * np.fft.rfft(received, nfft)
    output = np.fft.irfft(spectrum, nfft)[:n]
    if normalize:
        energy = np.sum(template ** 2)
        if energy > 0:
            output = output / energy
    return output


def make_chirp(f0: float, f1: float, duration: float, fs: float,
               method: str = 'linear') -> tuple[np.ndarray, np.ndarray]:
    """Generate a frequency-modulated chirp sampled at ``fs``.

    The signal has ``int(duration * fs)`` samples (the product is truncated,
    not rounded). The sweep is defined so that ``f1`` is reached at time
    ``duration``. That lies one sample period after the last returned
    sample when ``duration * fs`` is a whole number.

    Parameters
    ----------
    f0 : float
        Start frequency in Hz.
    f1 : float
        End frequency in Hz (reached at t = duration).
    duration : float
        Chirp duration in seconds.
    fs : float
        Sample rate in Hz.
    method : str
        Sweep method forwarded to ``scipy.signal.chirp``: 'linear',
        'quadratic', 'logarithmic', or 'hyperbolic'.

    Returns
    -------
    t : ndarray
        Sample times ``k / fs`` for ``k = 0 .. n-1``.
    chirp : ndarray
        Chirp signal evaluated at ``t``.
    """
    num_samples = int(duration * fs)
    times = np.arange(num_samples) / fs
    waveform = sig.chirp(times, f0=f0, t1=duration, f1=f1, method=method)
    return times, waveform


def simulate_bat_echo(fs: float = 250000, f0: float = 100000, f1: float = 30000,
                      chirp_duration: float = 0.002, target_distance: float = 2.0,
                      snr_db: float = -5.0,
                      rng: np.random.Generator | None = None,
                      *, speed_of_sound: float = 343.0,
                      attenuation: float = 0.1,
                      ) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """Simulate a bat echolocation scenario.

    A bat emits a frequency-modulated chirp and listens for echoes from a
    target at a known distance. The received signal contains a single
    attenuated, delayed copy of the chirp buried in white Gaussian noise.

    Parameters
    ----------
    fs : float
        Sample rate in Hz. Should exceed ``2 * max(f0, f1)`` (Nyquist) to
        avoid aliasing; this is not checked.
    f0 : float
        Start frequency of the chirp (Hz).
    f1 : float
        End frequency of the chirp (Hz).
    chirp_duration : float
        Duration of the emitted chirp (seconds).
    target_distance : float
        Distance to the target (meters). Must be non-negative.
    snr_db : float
        SNR of the echo in dB (can be negative — echo buried in noise).
        Note: this is the global SNR (signal power averaged over the full
        buffer, including zero-padded regions). The local SNR within the
        echo segment is higher.
    rng : Generator, optional
        NumPy random generator for the noise. If omitted, a new generator
        seeded with 42 is created, so repeated calls return identical noise.
    speed_of_sound : float, keyword-only
        Propagation speed in m/s used to convert distance into round-trip
        delay. Must be positive. Defaults to 343.0.
    attenuation : float, keyword-only
        Amplitude factor applied to the echo (default 0.1). Because the noise
        power is set relative to the echo power, this rescales the whole
        received signal without changing its SNR.

    Returns
    -------
    t_full : ndarray
        Time vector for the full recording.
    transmitted : ndarray
        The transmitted chirp (zero-padded to full length).
    received : ndarray
        The received signal (echo + noise).
    delay_true : float
        True round-trip delay in seconds. The echo in ``received`` starts at
        sample ``int(delay_true * fs)``, i.e. the delay truncated to whole
        samples.

    Raises
    ------
    ValueError
        If ``target_distance`` is negative or ``speed_of_sound`` is not
        positive.
    """
    if target_distance < 0:
        raise ValueError(
            f"target_distance must be non-negative, got {target_distance}")
    if not speed_of_sound > 0:
        raise ValueError(
            f"speed_of_sound must be positive, got {speed_of_sound}")
    if rng is None:
        rng = np.random.default_rng(42)

    delay = 2 * target_distance / speed_of_sound
    delay_samples = int(delay * fs)

    # Transmitted chirp. Its length is taken from the generated signal rather
    # than recomputed, so the buffer sizes always agree with make_chirp.
    _, chirp = make_chirp(f0, f1, chirp_duration, fs)
    n_chirp = len(chirp)

    # Total recording length: enough for chirp + echo + some margin
    n_total = delay_samples + 2 * n_chirp

    transmitted = np.zeros(n_total)
    transmitted[:n_chirp] = chirp

    # Echo: attenuated and delayed copy
    echo = np.zeros(n_total)
    echo[delay_samples:delay_samples + n_chirp] = attenuation * chirp

    # Add noise at the requested global SNR (echo power averaged over the
    # whole buffer, including the silent regions).
    echo_power = np.sum(echo ** 2) / n_total
    noise_power = echo_power * 10 ** (-snr_db / 10)
    noise = rng.normal(0, np.sqrt(noise_power), n_total)

    received = echo + noise
    t_full = np.arange(n_total) / fs

    return t_full, transmitted, received, delay


def detect_echo(template: np.ndarray, received: np.ndarray,
                fs: float) -> tuple[float, np.ndarray]:
    """Estimate the echo delay as the lag of the strongest matched-filter response.

    The received signal is correlated with the template (via the FFT-based
    matched filter, normalized by template energy). The delay is the lag, in
    samples, where the absolute correlation is largest, divided by ``fs``.
    The estimate is therefore quantized to multiples of ``1 / fs``.

    Parameters
    ----------
    template : array_like
        The transmitted chirp.
    received : array_like
        The received signal.
    fs : float
        Sample rate.

    Returns
    -------
    delay : float
        Estimated delay in seconds (time of peak correlation).
    mf_output : ndarray
        Matched filter output.
    """
    correlation = matched_filter_fft(template, received, normalize=True)
    # Use the magnitude so a strong negative correlation also counts as a peak.
    magnitude = np.abs(correlation)
    best_lag = np.argmax(magnitude)
    return best_lag / fs, correlation
