"""Power law noise generation, characterisation, and whitening."""

import numpy as np
from scipy.signal import lfilter


def generate_power_law_noise_psd(nr_of_samples: int, alpha: float) -> np.ndarray:
    """Generate power law noise by shaping the PSD of white noise.

    A white Gaussian sequence is transformed with a real FFT, each bin is
    scaled by an amplitude response proportional to f^(-alpha/2) (so that
    the PSD is proportional to f^(-alpha)), and the result is transformed
    back. Fast, but produces a finite block.

    Args:
        nr_of_samples: Length of the output sequence. Both even and odd
            lengths are supported; 0 yields an empty array.
        alpha: Power law exponent (0=white, 1=pink, 2=brownian).

    Returns:
        Power law noise sequence of exactly ``nr_of_samples`` samples with
        the specified spectral slope.
    """
    if nr_of_samples == 0:
        # The FFT routines reject zero-length input; an empty block is the
        # natural answer to a request for zero samples.
        return np.zeros(0)

    white_spectrum = np.fft.rfft(np.random.randn(nr_of_samples))
    f = np.fft.rfftfreq(nr_of_samples)
    # Shift every bin by half the frequency resolution so the DC bin does not
    # divide by zero; introduces negligible low-frequency bias. Using
    # 1 / nr_of_samples (== f[1]) keeps this valid for a single sample too.
    H = (f + 0.5 / nr_of_samples) ** (-alpha / 2)
    # Normalise the shaping filter to unit mean-square gain.
    H /= np.sqrt(np.mean(np.abs(H) ** 2))
    # Without an explicit n, irfft assumes an even length and would return
    # nr_of_samples - 1 samples for odd nr_of_samples.
    return np.fft.irfft(white_spectrum * H, n=nr_of_samples)


def kasdin_coefficients(
    alpha: float, min_magnitude: float = 0.01, *, max_length: int = 10_000
) -> list[float]:
    """Compute AR coefficients for 1/f^alpha noise using Kasdin's recurrence.

    Starting from ``a_0 = 1``, each coefficient follows from the previous one
    as ``a_k = (k - 1 - alpha / 2) * a_{k-1} / k``. The sequence is cut off
    at the first coefficient whose magnitude falls below ``min_magnitude``
    (that coefficient is not included) or once ``max_length`` coefficients
    have been produced, whichever happens first.

    Args:
        alpha: Power law exponent.
        min_magnitude: Truncate when coefficient magnitude drops below this.
        max_length: Upper bound on the number of returned coefficients,
            including the leading 1. Guards against very long (or endless)
            sequences when ``min_magnitude`` is tiny or non-positive.

    Returns:
        List of AR coefficients [1, theta_1, theta_2, ...]. The leading 1 is
        always present, even if ``max_length`` is smaller than 1.
    """
    coeffs = [1.0]
    for k in range(1, max_length):
        a = (k - 1 - alpha / 2) * coeffs[-1] / k
        if abs(a) < min_magnitude:
            break
        coeffs.append(a)
    return coeffs


def generate_power_law_noise_ar(
    nr_of_samples: int, alpha: float, min_magnitude: float = 0.01
) -> np.ndarray:
    """Generate power law noise with an autoregressive (all-pole) filter.

    White Gaussian noise is passed through a recursive filter whose
    denominator is the truncated coefficient sequence returned by
    ``kasdin_coefficients``. Because the filter is recursive, the approach
    lends itself to streaming applications.

    Args:
        nr_of_samples: Length of the output sequence.
        alpha: Power law exponent (0=white, 1=pink, 2=brownian).
        min_magnitude: AR coefficient truncation threshold.

    Returns:
        Power law noise sequence.
    """
    # Denominator polynomial of the all-pole shaping filter.
    denominator = kasdin_coefficients(alpha, min_magnitude)
    # White Gaussian excitation that drives the filter.
    excitation = np.random.randn(nr_of_samples)
    return lfilter([1], denominator, excitation)


def build_whitening_filter(
    alpha: float, min_magnitude: float = 0.01
) -> np.ndarray:
    """Return the FIR taps that whiten 1/f^alpha noise.

    Whitening inverts the AR generation model, so the coefficients that
    form the denominator of the all-pole generator are reused here as the
    numerator of an all-zero (FIR) filter.

    Args:
        alpha: Power law exponent.
        min_magnitude: Coefficient truncation threshold.

    Returns:
        FIR filter coefficients for whitening, as a 1-D array.
    """
    taps = kasdin_coefficients(alpha, min_magnitude)
    return np.array(taps)


def whiten(x: np.ndarray, alpha: float, min_magnitude: float = 0.01) -> np.ndarray:
    """Remove the spectral colouring from a power law noise signal.

    Builds the whitening FIR filter for the given exponent and runs the
    signal through it (the inverse of the AR generation filter).

    Args:
        x: Input signal (assumed to be 1/f^alpha noise).
        alpha: Power law exponent.
        min_magnitude: Coefficient truncation threshold.

    Returns:
        Whitened signal (approximately white noise if alpha is correct).
    """
    numerator = build_whitening_filter(alpha, min_magnitude)
    # A denominator of [1] makes lfilter act as a pure FIR filter.
    return lfilter(numerator, [1], x)


def acf(x: np.ndarray, max_lag: int) -> np.ndarray:
    """Compute the normalised autocorrelation function up to max_lag.

    The raw autocorrelation (no mean removal) is divided by its lag-0 value,
    so the first element of the result is 1 for any signal with non-zero
    energy.

    Args:
        x: Input signal (1-D).
        max_lag: Maximum lag to compute; must be non-negative.

    Returns:
        Array of exactly ``max_lag + 1`` normalised autocorrelation values
        for lags 0 to max_lag. Lags at or beyond the signal length have no
        overlapping samples and are reported as 0. An empty or all-zero
        signal yields all zeros.

    Raises:
        ValueError: If ``max_lag`` is negative.
    """
    if max_lag < 0:
        raise ValueError("max_lag must be non-negative")

    signal = np.asarray(x)
    n = len(signal)
    if n == 0:
        return np.zeros(max_lag + 1)

    # correlate mode='full' returns 2N-1 values; index N-1 is lag 0, so the
    # tail starting there holds lags 0..N-1.
    r = np.correlate(signal, signal, mode="full")[n - 1:]
    if r[0] == 0:
        # Zero-energy signal: normalisation is undefined, report zeros.
        return np.zeros(max_lag + 1)

    rho = r[: max_lag + 1] / r[0]
    missing = max_lag + 1 - len(rho)
    if missing > 0:
        # Requested lags beyond the signal length: pad so the result always
        # has max_lag + 1 entries.
        rho = np.concatenate([rho, np.zeros(missing, dtype=rho.dtype)])
    return rho


def estimate_lag1_correlation(x: np.ndarray, eta: float = 0.99) -> np.ndarray:
    """Recursively estimate the lag-1 autocorrelation coefficient.

    Uses frugal median/MAD estimation for the running statistics: the
    running median and the running median absolute deviation (MAD) are each
    nudged up or down by a step of MAD / 100 per sample. Every sample is
    centred on the running median, multiplied with the previous centred
    sample, normalised by the robust variance (1.4826 * MAD)^2 and
    exponentially averaged with forgetting factor ``eta``.

    The MAD is seeded from the first non-zero deviation from the median.
    (Seeding it from the very first sample would always give 0, freezing the
    step size at 0 and making every estimate a division by zero.) Until a
    scale is available the estimate is reported as 0; whenever the robust
    variance is zero the previous estimate is held.

    Works well for alpha < 1; biased for larger alpha due to
    long-range dependence violating the local stationarity assumption.

    Args:
        x: Input signal (1-D).
        eta: Forgetting factor (closer to 1 = longer memory).

    Returns:
        Float array of lag-1 autocorrelation estimates at each sample, with
        the same shape as ``x``. The first element is always 0.
    """
    k_mad = 1.4826  # scales a MAD to a Gaussian standard deviation
    # Work in floating point so integer input does not truncate the output.
    samples = np.asarray(x, dtype=float)
    r = np.zeros(samples.shape, dtype=float)

    med = None      # running median; seeded by the first sample
    mad = 0.0       # running MAD; 0 means "no scale estimate yet"
    r_prev = None   # smoothed correlation; None until the first valid product
    x_prev = None   # previous centred sample

    for i, xi in enumerate(samples):
        if med is None:
            med = xi

        if mad == 0.0:
            # No usable scale (start-up, or the scale collapsed to zero):
            # (re)seed the MAD from the current deviation.
            xi_centred = xi - med
            mad = abs(xi_centred)
        else:
            # Frugal updates; the step is proportional to the scale, so the
            # MAD changes by +/-1 % and stays strictly positive.
            step = mad / 100
            med = med + step if xi > med else med - step
            xi_centred = xi - med
            mad = mad + step if abs(xi_centred) > mad else mad - step

        if x_prev is not None:
            variance = (k_mad * mad) ** 2
            if variance > 0:
                product = xi_centred * x_prev / variance
                if r_prev is None:
                    r_prev = product
                else:
                    r_prev = eta * r_prev + (1 - eta) * product
            if r_prev is not None:
                r[i] = r_prev

        x_prev = xi_centred

    return r
