"""Adaptive filtering algorithms: LMS, NLMS, and RLS."""

import numpy as np


class LMS:
    """Least Mean Squares adaptive filter.

    The simplest adaptive algorithm. Updates filter coefficients by stepping
    in the direction of the instantaneous gradient of the squared error:
    ``w <- w + mu * e * x_buf``.

    Parameters
    ----------
    n_taps : int
        Number of filter coefficients.
    mu : float
        Step size (learning rate). Must satisfy 0 < mu < 2 / (n_taps * Px)
        where Px is the input signal power, for convergence.

    Attributes
    ----------
    w : ndarray, shape (n_taps,)
        Current filter coefficients, initialised to zero.
    """

    def __init__(self, n_taps: int, mu: float = 0.01):
        if n_taps < 1:
            raise ValueError("n_taps must be at least 1")
        # ``not mu > 0`` also rejects NaN, which ``mu <= 0`` would accept.
        if not mu > 0:
            raise ValueError("Step size mu must be positive")
        self.n_taps = n_taps
        self.mu = mu
        self.w = np.zeros(n_taps)
        # Tap-delay line: index 0 holds the newest input sample.
        self._x_buf = np.zeros(n_taps)

    def update(self, x: float, d: float) -> tuple[float, float]:
        """Process one sample.

        Parameters
        ----------
        x : float
            Input sample.
        d : float
            Desired (reference) signal sample.

        Returns
        -------
        y : float
            Filter output.
        e : float
            Error signal (d - y).
        """
        buf = self._x_buf
        # Shift the delay line in place (NumPy handles the overlapping
        # source/destination); avoids np.roll's per-sample allocation.
        buf[1:] = buf[:-1]
        buf[0] = x
        y = self.w @ buf
        e = d - y
        self.w += self.mu * e * buf
        return y, e

    def run(self, x: np.ndarray, d: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Filter an entire signal.

        Parameters
        ----------
        x : array_like
            Input signal.
        d : array_like
            Desired signal (same length as x).

        Returns
        -------
        y : ndarray
            Filter output.
        e : ndarray
            Error signal.

        Raises
        ------
        ValueError
            If ``x`` and ``d`` differ in length. Checked before any sample is
            processed so the filter state is left untouched.
        """
        n = len(x)
        if len(d) != n:
            raise ValueError("x and d must have the same length")
        y = np.zeros(n)
        e = np.zeros(n)
        for i, (xi, di) in enumerate(zip(x, d)):
            y[i], e[i] = self.update(xi, di)
        return y, e


class NLMS:
    """Normalised Least Mean Squares adaptive filter.

    Normalises the step size by the input power, giving faster and more
    stable convergence than plain LMS, especially for non-stationary signals.

    Parameters
    ----------
    n_taps : int
        Number of filter coefficients.
    mu : float
        Step size, typically 0 < mu < 2. The effective step size is
        mu / (eps + ||x||^2), so convergence is independent of signal level.
    eps : float
        Non-negative regularisation constant to avoid division by zero.

    Attributes
    ----------
    w : ndarray, shape (n_taps,)
        Current filter coefficients, initialised to zero.
    """

    def __init__(self, n_taps: int, mu: float = 0.5, eps: float = 1e-8):
        if n_taps < 1:
            raise ValueError("n_taps must be at least 1")
        # ``not mu > 0`` also rejects NaN, which ``mu <= 0`` would accept.
        if not mu > 0:
            raise ValueError("Step size mu must be positive")
        # A negative eps could cancel ||x||^2 and make the normaliser zero
        # or negative, turning the update unstable.
        if not eps >= 0:
            raise ValueError("Regularisation eps must be non-negative")
        self.n_taps = n_taps
        self.mu = mu
        self.eps = eps
        self.w = np.zeros(n_taps)
        # Tap-delay line: index 0 holds the newest input sample.
        self._x_buf = np.zeros(n_taps)

    def update(self, x: float, d: float) -> tuple[float, float]:
        """Process one sample. See LMS.update for interface."""
        buf = self._x_buf
        # In-place shift of the delay line; avoids np.roll's allocation.
        buf[1:] = buf[:-1]
        buf[0] = x
        y = self.w @ buf
        e = d - y
        norm = buf @ buf + self.eps
        self.w += (self.mu / norm) * e * buf
        return y, e

    def run(self, x: np.ndarray, d: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Filter an entire signal. See LMS.run for interface.

        Raises
        ------
        ValueError
            If ``x`` and ``d`` differ in length. Checked before any sample is
            processed so the filter state is left untouched.
        """
        n = len(x)
        if len(d) != n:
            raise ValueError("x and d must have the same length")
        y = np.zeros(n)
        e = np.zeros(n)
        for i, (xi, di) in enumerate(zip(x, d)):
            y[i], e[i] = self.update(xi, di)
        return y, e


class RLS:
    """Recursive Least Squares adaptive filter.

    Converges much faster than LMS/NLMS at the cost of O(n_taps^2) computation
    per sample. Uses an exponential forgetting factor to track non-stationary
    environments.

    Parameters
    ----------
    n_taps : int
        Number of filter coefficients.
    lam : float
        Forgetting factor, typically 0.95 to 1.0. Values close to 1 give
        longer memory; smaller values track changes faster.
    delta : float
        Positive initialisation value for the inverse correlation matrix
        (P = delta * I).

    Attributes
    ----------
    w : ndarray, shape (n_taps,)
        Current filter coefficients, initialised to zero.
    P : ndarray, shape (n_taps, n_taps)
        Current inverse correlation matrix estimate, kept symmetric.
    """

    def __init__(self, n_taps: int, lam: float = 0.99, delta: float = 100.0):
        if n_taps < 1:
            raise ValueError("n_taps must be at least 1")
        if not 0 < lam <= 1:
            raise ValueError("Forgetting factor lam must be in (0, 1]")
        # delta <= 0 makes P zero or indefinite: with P = 0 the gain is always
        # zero and the filter never adapts.
        if not delta > 0:
            raise ValueError("delta must be positive")
        self.n_taps = n_taps
        self.lam = lam
        self.delta = delta
        self.w = np.zeros(n_taps)
        self.P = delta * np.eye(n_taps)
        # Tap-delay line: index 0 holds the newest input sample.
        self._x_buf = np.zeros(n_taps)

    def update(self, x: float, d: float) -> tuple[float, float]:
        """Process one sample. See LMS.update for interface."""
        buf = self._x_buf
        # In-place shift of the delay line; avoids np.roll's allocation.
        buf[1:] = buf[:-1]
        buf[0] = x
        y = self.w @ buf
        e = d - y  # a priori error
        # Gain vector
        Px = self.P @ buf
        k = Px / (self.lam + buf @ Px)
        # Update inverse correlation matrix. P is symmetric, so x^T P equals
        # (P x)^T and the already-computed Px can be reused.
        P = (self.P - np.outer(k, Px)) / self.lam
        # Re-symmetrise so round-off cannot accumulate into an asymmetric
        # (and eventually indefinite) P, a known numerical failure of RLS.
        self.P = 0.5 * (P + P.T)
        # Update coefficients
        self.w += k * e
        return y, e

    def run(self, x: np.ndarray, d: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Filter an entire signal. See LMS.run for interface.

        Raises
        ------
        ValueError
            If ``x`` and ``d`` differ in length. Checked before any sample is
            processed so the filter state is left untouched.
        """
        n = len(x)
        if len(d) != n:
            raise ValueError("x and d must have the same length")
        y = np.zeros(n)
        e = np.zeros(n)
        for i, (xi, di) in enumerate(zip(x, d)):
            y[i], e[i] = self.update(xi, di)
        return y, e


def identify_system(unknown_ir: np.ndarray, n_samples: int = 5000,
                    algorithm: str = "nlms", snr_db: float = 30.0,
                    seed: int = 42,
                    **kwargs) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Identify an unknown system (impulse response) using adaptive filtering.

    Drives the unknown system and an adaptive filter with the same white noise
    input. The adaptive filter converges to the unknown impulse response.

    Parameters
    ----------
    unknown_ir : array_like
        Impulse response of the system to identify.
    n_samples : int
        Number of samples to process.
    algorithm : str
        "lms", "nlms", or "rls" (case-insensitive).
    snr_db : float
        Signal-to-noise ratio at the system output (dB).
    seed : int
        Seed for the random generator that produces both the input signal and
        the measurement noise. The default (42) reproduces the previous
        fixed-seed behaviour.
    **kwargs
        Additional keyword arguments passed to the adaptive filter constructor.

    Returns
    -------
    w_final : ndarray
        Final adaptive filter coefficients (estimate of unknown_ir).
    e : ndarray
        Error signal over time.
    w_history : ndarray
        Coefficient history, shape (n_samples, n_taps).

    Raises
    ------
    ValueError
        If ``algorithm`` is not one of the supported names (checked before any
        data is generated), or if the filter constructor rejects its arguments.
    """
    unknown_ir = np.asarray(unknown_ir, dtype=float)
    n_taps = len(unknown_ir)

    # Select and build the filter first so bad arguments fail fast. This does
    # not consume random numbers, so the generated data is unaffected.
    classes = {"lms": LMS, "nlms": NLMS, "rls": RLS}
    key = algorithm.lower()
    if key not in classes:
        raise ValueError(f"Unknown algorithm: {algorithm}")
    filt = classes[key](n_taps, **kwargs)

    # White noise input
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)

    # Desired signal: unknown system output + noise scaled to the target SNR
    d_clean = np.convolve(x, unknown_ir)[:n_samples]
    noise_power = np.var(d_clean) * 10 ** (-snr_db / 10)
    d = d_clean + rng.normal(0, np.sqrt(noise_power), n_samples)

    # Run adaptation, recording the coefficients after every update.
    w_history = np.zeros((n_samples, n_taps))
    e = np.zeros(n_samples)
    for i, (xi, di) in enumerate(zip(x, d)):
        _, e[i] = filt.update(xi, di)
        w_history[i] = filt.w  # row assignment copies the values

    return filt.w, e, w_history
