"""Particle Swarm Optimization for filter design.

PSO is a population-based, gradient-free optimiser inspired by the collective
motion of bird flocks and fish schools (Kennedy and Eberhart, 1995). A swarm
of candidate solutions ("particles") moves through the search space, each
particle pulled toward the best position it has personally found and the best
position found by the whole swarm. The balance of those two pulls, plus an
inertia term, lets the swarm explore broadly and then converge. The inertia
weight ``w`` used here was not in the original swarm; it was added later by
Shi and Eberhart (1998) to tune the exploration/convergence trade-off.

Where gradient methods like LMS shine on smooth, convex error surfaces (an FIR
filter's mean-squared error is a single bowl), they stall in local minima on
the multimodal surfaces that arise in IIR filter design. PSO needs no
gradient and no convexity, so it is a natural tool there.

This module provides a compact PSO that needs nothing beyond NumPy, an IIR
magnitude-matching fitness (evaluated with SciPy's ``freqz``), and a
finite-difference gradient-descent baseline for the head-to-head comparison on
the topic page.
"""

import numpy as np
from scipy import signal


# ---------------------------------------------------------------------------
# Core PSO
# ---------------------------------------------------------------------------

def _evaluate_swarm(fitness, positions):
    """Evaluate ``fitness`` on every row of ``positions``.

    NaN results are replaced by ``+inf`` so that an undefined evaluation is
    simply the worst possible score instead of something that breaks the
    ``<`` / ``argmin`` bookkeeping of the swarm.
    """
    values = np.array([fitness(p) for p in positions], dtype=float)
    values[np.isnan(values)] = np.inf
    return values


def pso(fitness, bounds, n_particles=30, n_iters=100,
        w=0.7, c1=1.5, c2=1.5, rng=None):
    """Minimise ``fitness`` over a box with particle swarm optimization.

    Velocity update for each particle ``i``::

        v_i <- w*v_i + c1*r1*(pbest_i - x_i) + c2*r2*(gbest - x_i)
        x_i <- x_i + v_i

    with ``r1, r2`` fresh uniform random numbers each step, ``pbest_i`` the
    particle's own best position, and ``gbest`` the swarm's best.

    Parameters
    ----------
    fitness : callable
        ``fitness(x) -> float`` to be minimised; ``x`` has shape ``(D,)``.
        A NaN return value is treated as ``+inf`` (worst possible), so a
        single undefined evaluation cannot become the swarm best.
    bounds : array_like, shape (D, 2)
        ``(low, high)`` per dimension. Particles are clipped to the box.
    n_particles : int, optional
        Swarm size (default 30). Must be at least 1.
    n_iters : int, optional
        Number of iterations (default 100).
    w : float, optional
        Inertia weight (default 0.7); higher = more exploration.
    c1, c2 : float, optional
        Cognitive (personal) and social (swarm) acceleration (default 1.5).
    rng : numpy.random.Generator, optional
        Seeded generator for reproducibility. A stochastic optimiser needs a
        seed to be repeatable. A fresh unseeded generator is used when omitted.

    Returns
    -------
    gbest : ndarray, shape (D,)
        Best position found.
    gbest_f : float
        Fitness at ``gbest``.
    history : ndarray, shape (n_iters + 1,)
        Best fitness of the initial swarm followed by the best fitness after
        each iteration (for convergence curves). Non-increasing.

    Raises
    ------
    ValueError
        If ``bounds`` is not of shape ``(D, 2)``, if any ``low > high`` (or a
        bound is NaN), or if ``n_particles < 1``.
    """
    rng = np.random.default_rng() if rng is None else rng
    bounds = np.asarray(bounds, dtype=float)
    if bounds.ndim != 2 or bounds.shape[1] != 2:
        raise ValueError("bounds must have shape (D, 2)")
    lo, hi = bounds[:, 0], bounds[:, 1]
    # ``not all(lo <= hi)`` (rather than ``any(lo > hi)``) also rejects NaN.
    if not np.all(lo <= hi):
        raise ValueError("each bound must satisfy low <= high")
    if n_particles < 1:
        raise ValueError("n_particles must be at least 1")
    span = hi - lo
    d = len(bounds)

    # Positions start uniformly in the box; velocities start at up to 10 % of
    # the box width per dimension, in either direction.
    x = rng.uniform(lo, hi, size=(n_particles, d))
    v = 0.1 * rng.uniform(-span, span, size=(n_particles, d))

    pbest = x.copy()
    pbest_f = _evaluate_swarm(fitness, x)
    g = int(np.argmin(pbest_f))
    gbest = pbest[g].copy()
    gbest_f = float(pbest_f[g])
    history = [gbest_f]

    for _ in range(n_iters):
        # Independent random weights per particle and per dimension.
        r1 = rng.random((n_particles, d))
        r2 = rng.random((n_particles, d))
        v = w * v + c1 * r1 * (pbest - x) + c2 * r2 * (gbest - x)
        x = np.clip(x + v, lo, hi)

        f = _evaluate_swarm(fitness, x)
        improved = f < pbest_f
        pbest[improved] = x[improved]
        pbest_f[improved] = f[improved]

        g = int(np.argmin(pbest_f))
        if pbest_f[g] < gbest_f:
            gbest = pbest[g].copy()
            gbest_f = float(pbest_f[g])
        history.append(gbest_f)

    return gbest, gbest_f, np.array(history)


# ---------------------------------------------------------------------------
# Gradient-descent baseline (for the LMS-style contrast)
# ---------------------------------------------------------------------------

def gradient_descent(fitness, x0, bounds, lr=1e-3, n_iters=2000, eps=1e-5):
    """Finite-difference gradient descent, the local-search baseline.

    Stands in for a gradient method (LMS and friends). On a multimodal
    surface it converges to whatever basin ``x0`` starts in, which is the
    point of contrast with the global swarm search.

    Each iteration costs ``D + 1`` fitness evaluations: ``D`` forward-difference
    probes plus one evaluation at the new iterate, which is recorded in the
    history and reused as the base value of the next gradient estimate. The
    reuse assumes ``fitness`` is deterministic.

    Parameters
    ----------
    fitness : callable
        Objective to minimise, ``fitness(x) -> float``.
    x0 : array_like, shape (D,)
        Starting point; clipped to the box before the first evaluation.
    bounds : array_like, shape (D, 2)
        Box constraints; the iterate is clipped to the box.
    lr : float, optional
        Learning rate / step size (default 1e-3).
    n_iters : int, optional
        Iterations (default 2000).
    eps : float, optional
        Finite-difference step for the gradient estimate (default 1e-5).
        Must be non-zero.

    Returns
    -------
    x : ndarray, shape (D,)
        Final iterate.
    fx : float
        Fitness at the final iterate.
    history : ndarray, shape (n_iters + 1,)
        Fitness at the starting point followed by the fitness after each
        iteration.

    Raises
    ------
    ValueError
        If ``eps`` is zero (the difference quotient would be undefined).
    """
    if eps == 0:
        raise ValueError("eps must be non-zero")
    bounds = np.asarray(bounds, dtype=float)
    lo, hi = bounds[:, 0], bounds[:, 1]
    x = np.clip(np.asarray(x0, dtype=float), lo, hi)
    d = len(x)
    fx = float(fitness(x))
    history = [fx]
    step = np.zeros(d)  # reused probe offset: eps in one coordinate at a time

    for _ in range(n_iters):
        grad = np.zeros(d)
        for j in range(d):
            step[j] = eps
            # Forward difference; ``fx`` is the value already computed for
            # the current iterate, so it is not evaluated again here.
            grad[j] = (fitness(x + step) - fx) / eps
            step[j] = 0.0
        x = np.clip(x - lr * grad, lo, hi)
        fx = float(fitness(x))
        history.append(fx)

    return x, history[-1], np.array(history)


# ---------------------------------------------------------------------------
# IIR filter design as an optimization problem
# ---------------------------------------------------------------------------

def iir_is_stable(a1, a2):
    """Report whether the denominator ``1 + a1 z^-1 + a2 z^-2`` is stable.

    Both poles of a second-order section sit strictly inside the unit circle
    exactly when the coefficients fall in the "stability triangle"::

        |a2| < 1    and    |a1| < 1 + a2

    (the second-order Schur-Cohn / Jury conditions). Points on the edge of
    the triangle count as unstable.
    """
    pole_product_ok = abs(a2) < 1.0
    return pole_product_ok and abs(a1) < a2 + 1.0


def iir_magnitude_fitness(coeffs, freqs, target_mag, fs, unstable_penalty=10.0):
    """Score one biquad by how closely its magnitude response hits a target.

    A candidate is the five-vector ``[b0, b1, b2, a1, a2]`` with the leading
    denominator coefficient fixed at ``a0 = 1``. Stable candidates are scored
    by the mean-squared difference between ``|H(f)|`` and ``target_mag`` on
    the grid ``freqs``. Unstable candidates are never evaluated; they receive
    the flat ``unstable_penalty`` instead, which acts as a wall keeping the
    search inside the stable region.

    Parameters
    ----------
    coeffs : array_like, shape (5,)
        ``[b0, b1, b2, a1, a2]``.
    freqs : ndarray
        Frequencies in Hz at which the response is compared.
    target_mag : ndarray
        Desired linear (not dB) magnitude at each entry of ``freqs``.
    fs : float
        Sample rate in Hz.
    unstable_penalty : float, optional
        Value returned for an unstable candidate (default 10.0).

    Returns
    -------
    float
        Mean-squared magnitude error, or ``unstable_penalty`` when the
        denominator fails the stability test.
    """
    b0, b1, b2, a1, a2 = coeffs
    if iir_is_stable(a1, a2):
        _, response = signal.freqz([b0, b1, b2], [1.0, a1, a2],
                                   worN=freqs, fs=fs)
        squared_error = (np.abs(response) - target_mag) ** 2
        return float(np.mean(squared_error))
    return unstable_penalty


def design_iir_pso(freqs, target_mag, fs, n_particles=40, n_iters=200,
                   rng=None):
    """Fit a single biquad to a target magnitude response with PSO.

    The swarm searches over ``[b0, b1, b2, a1, a2]`` and is scored by
    ``iir_magnitude_fitness`` with its default instability penalty.

    Parameters
    ----------
    freqs : ndarray
        Frequencies in Hz.
    target_mag : ndarray
        Target linear magnitude at each entry of ``freqs``.
    fs : float
        Sample rate in Hz.
    n_particles, n_iters : int, optional
        Swarm size (default 40) and iteration count (default 200).
    rng : numpy.random.Generator, optional
        Seeded generator, forwarded to ``pso``.

    Returns
    -------
    sos : ndarray, shape (1, 6)
        The best candidate laid out as ``[b0, b1, b2, 1, a1, a2]``.
    err : float
        Fitness of that candidate as reported by the swarm.
    history : ndarray
        Best-fitness convergence curve from ``pso``.
    """
    # Search box: every numerator tap and a1 range over [-2, 2]; a2 is
    # limited to [-1, 1].
    numerator_box = [(-2.0, 2.0)] * 3
    denominator_box = [(-2.0, 2.0), (-1.0, 1.0)]
    search_box = numerator_box + denominator_box

    best, err, history = pso(
        lambda c: iir_magnitude_fitness(c, freqs, target_mag, fs),
        search_box,
        n_particles=n_particles,
        n_iters=n_iters,
        rng=rng,
    )

    # Splice the fixed a0 = 1 between the numerator and denominator taps.
    sos = np.array([[*best[:3], 1.0, *best[3:]]])
    return sos, err, history
