Source code for biotuner.surrogates

"""
Surrogate signal generation and comparison for BiotunerGroup analysis.

Module type: Functions

Surrogates are signal controls that preserve specific statistical properties
(spectrum, amplitude distribution) while destroying others (phase relationships,
nonlinear structure). Comparing real vs. surrogate BiotunerGroup metrics lets
you assess whether observed harmonicity is above chance.

Typical usage
-------------
>>> from biotuner.biotuner_group import BiotunerGroup
>>> from biotuner.surrogates import surrogate_group, plot_surrogate_distributions
>>>
>>> bt = BiotunerGroup(data, sf=1000)
>>> bt.compute_peaks(peaks_function='EMD').compute_metrics()
>>>
>>> surr = surrogate_group(bt, surr_type='AAFT')
>>> surr_pink = surrogate_group(bt, surr_type='pink')
>>>
>>> plot_surrogate_distributions(bt, {'AAFT': surr, 'pink': surr_pink}, metric='harmsim')
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from typing import Dict, List, Optional

from biotuner.biotuner_utils import (
    AAFT_surrogates,
    butter_bandpass_filter,
    UnivariateSurrogatesTFT,
    phaseScrambleTS,
)

SURROGATE_TYPES = ('AAFT', 'TFT', 'phase', 'shuffle', 'white', 'pink', 'brown', 'blue')


# ---------------------------------------------------------------------------
# Low-level signal generation
# ---------------------------------------------------------------------------

def generate_surrogate(
    data: np.ndarray,
    surr_type: str = 'pink',
    low_cut: float = 0.5,
    high_cut: float = 150.0,
    sf: int = 1000,
    TFT_freq: int = 5,
) -> np.ndarray:
    """Build one surrogate time series from a 1D signal.

    Every surrogate type except ``'TFT'`` is passed through
    ``butter_bandpass_filter(..., low_cut, high_cut, sf, 4)`` before being
    returned; ``'TFT'`` surrogates are returned exactly as produced by
    ``UnivariateSurrogatesTFT``.

    Parameters
    ----------
    data : ndarray, shape (n_samples,)
        Original 1D signal.
    surr_type : str, default='pink'
        One of ``SURROGATE_TYPES``:

        * ``'AAFT'``    - Amplitude-Adjusted Fourier Transform surrogate
          (preserves amplitude distribution and power spectrum).
        * ``'TFT'``     - Truncated Fourier Transform surrogate
          (preserves low-frequency structure below *TFT_freq* Hz).
        * ``'phase'``   - Phase-scrambled surrogate (preserves the power
          spectrum, destroys phase relationships).
        * ``'shuffle'`` - Random permutation of the samples.
        * ``'white'``   - Power-law noise with exponent β=0.
        * ``'pink'``    - Power-law noise with exponent β=1.
        * ``'brown'``   - Power-law noise with exponent β=2.
        * ``'blue'``    - Power-law noise with exponent β=-1.
    low_cut : float, default=0.5
        Lower band-pass cutoff (Hz). Unused for ``'TFT'``.
    high_cut : float, default=150.0
        Upper band-pass cutoff (Hz). Unused for ``'TFT'``.
    sf : int, default=1000
        Sampling frequency (Hz) handed to the band-pass filter.
        Unused for ``'TFT'``.
    TFT_freq : int, default=5
        Corner frequency passed as ``fc`` to the TFT generator.

    Returns
    -------
    surrogate : ndarray
        Surrogate signal derived from (or sized like) *data*.

    Raises
    ------
    ValueError
        If *surr_type* is not listed in ``SURROGATE_TYPES``.
    ImportError
        If a colored-noise type is requested and the optional
        ``colorednoise`` package is not installed.
    """
    if surr_type not in SURROGATE_TYPES:
        raise ValueError(f"surr_type must be one of {SURROGATE_TYPES}, got '{surr_type}'")

    # TFT is the single type that bypasses the band-pass stage.
    if surr_type == 'TFT':
        return UnivariateSurrogatesTFT(data, 1, fc=TFT_freq)

    if surr_type == 'AAFT':
        # The signal is stacked with a sample-index row and only the first
        # row of the helper's output is kept.
        sample_idx = np.arange(len(data))
        unfiltered = AAFT_surrogates(np.stack((data, sample_idx)))[0]
    elif surr_type == 'phase':
        # Trim to the input length in case the scrambler returns more samples.
        unfiltered = phaseScrambleTS(data)[:len(data)]
    elif surr_type == 'shuffle':
        # Shuffle a copy so the caller's array is left untouched.
        unfiltered = data.copy()
        np.random.shuffle(unfiltered)
    else:
        # Remaining types are power-law noises; the dependency is optional,
        # so it is imported only when actually needed.
        try:
            import colorednoise as cn
        except ImportError:
            raise ImportError(
                "The 'colorednoise' package is required for colored noise surrogates.\n"
                "Install it with: pip install colorednoise"
            )
        exponents = {'white': 0, 'pink': 1, 'brown': 2, 'blue': -1}
        unfiltered = cn.powerlaw_psd_gaussian(exponents[surr_type], len(data))

    # Trailing 4 is presumably the Butterworth filter order - confirm against
    # biotuner_utils.butter_bandpass_filter.
    return butter_bandpass_filter(unfiltered, low_cut, high_cut, sf, 4)
def generate_surrogate_data(
    data: np.ndarray,
    surr_type: str = 'pink',
    low_cut: float = 0.5,
    high_cut: float = 150.0,
    sf: int = 1000,
    **kwargs,
) -> np.ndarray:
    """Generate surrogate signals for a 2D or 3D array of time series.

    The last axis is treated as time; every 1D slice along it is replaced
    independently by the output of :func:`generate_surrogate`.

    Parameters
    ----------
    data : array-like, shape (n, n_samples) or (n, m, n_samples)
        Original signals.
    surr_type : str, default='pink'
        Surrogate type. See :func:`generate_surrogate`.
    low_cut, high_cut : float
        Bandpass filter cutoffs (Hz).
    sf : int, default=1000
        Sampling frequency (Hz).
    **kwargs :
        Extra keyword arguments forwarded to :func:`generate_surrogate`
        (e.g. ``TFT_freq``).

    Returns
    -------
    surrogate_data : ndarray
        Same shape as *data*. Floating-point and complex inputs keep their
        dtype; any other dtype (integer, bool, ...) yields ``float64`` so
        that the real-valued surrogates are not truncated on assignment.

    Raises
    ------
    ValueError
        If *data* is not 2D or 3D.
    """
    data = np.asarray(data)
    if data.ndim not in (2, 3):
        raise ValueError("data must be 2D or 3D")

    # Writing float surrogates into a copy of integer input would silently
    # truncate them, so non-inexact inputs get a float64 output buffer.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    surrogate_data = np.empty(data.shape, dtype=out_dtype)

    # Walk every index over the leading axes; this covers the 2D and 3D
    # cases with one loop, and each slice of the buffer is overwritten.
    for idx in np.ndindex(*data.shape[:-1]):
        surrogate_data[idx] = generate_surrogate(
            data[idx], surr_type, low_cut, high_cut, sf, **kwargs
        )
    return surrogate_data
# ---------------------------------------------------------------------------
# BiotunerGroup-level surrogate creation
# ---------------------------------------------------------------------------
def surrogate_group(
    bt_group,
    surr_type: str = 'AAFT',
    low_cut: float = 0.5,
    high_cut: float = 150.0,
    recompute: bool = True,
    n_jobs: int = 1,
    **peaks_kwargs,
):
    """Build a surrogate counterpart of an existing BiotunerGroup.

    Surrogate signals are generated from ``bt_group.data`` and wrapped in a
    new :class:`~biotuner.biotuner_group.BiotunerGroup` configured with the
    same ``sf``, ``axis_labels``, ``store_objects`` and ``biotuner_kwargs``.
    When *recompute* is true, the analysis steps recorded in
    ``bt_group._computed_methods`` are replayed on the new group.

    Parameters
    ----------
    bt_group : BiotunerGroup
        Reference group whose data and configuration are mirrored.
    surr_type : str, default='AAFT'
        Surrogate type. See :func:`generate_surrogate`.
    low_cut, high_cut : float
        Bandpass filter cutoffs (Hz) used during surrogate generation.
    recompute : bool, default=True
        If ``True``, re-run on the surrogate group whichever of the peaks,
        metrics, dissonance-curve and harmonic-entropy steps were already
        computed on *bt_group*.
    n_jobs : int, default=1
        Forwarded to ``compute_peaks``, ``compute_metrics`` and
        ``compute_diss_curve``.
    **peaks_kwargs :
        Forwarded only to
        :meth:`~biotuner.biotuner_group.BiotunerGroup.compute_peaks`
        (e.g. ``min_freq=1``, ``max_freq=60``). They are not passed to the
        surrogate generator.

    Returns
    -------
    surr_group : BiotunerGroup
        New group backed by surrogate data.

    Examples
    --------
    >>> surr = surrogate_group(bt, surr_type='AAFT', min_freq=1, max_freq=60)
    >>> fig = plot_surrogate_distributions(bt, {'AAFT': surr}, metric='harmsim')
    """
    # Imported here rather than at module level (presumably to avoid a
    # circular import with biotuner_group - confirm).
    from biotuner.biotuner_group import BiotunerGroup

    surrogate_signals = generate_surrogate_data(
        bt_group.data, surr_type, low_cut, high_cut, bt_group.sf
    )
    mirrored = BiotunerGroup(
        surrogate_signals,
        sf=bt_group.sf,
        axis_labels=bt_group.axis_labels,
        store_objects=bt_group.store_objects,
        **bt_group.biotuner_kwargs,
    )

    if not recompute:
        return mirrored

    already_run = bt_group._computed_methods
    # Ordered replay table: (step recorded on the reference, how to redo it).
    replay_steps = (
        ('peaks_extraction',
         lambda: mirrored.compute_peaks(n_jobs=n_jobs, **peaks_kwargs)),
        ('compute_peaks_metrics',
         lambda: mirrored.compute_metrics(n_jobs=n_jobs)),
        ('compute_diss_curve',
         lambda: mirrored.compute_diss_curve(n_jobs=n_jobs)),
        ('compute_harmonic_entropy',
         lambda: mirrored.compute_harmonic_entropy()),
    )
    for step_name, rerun in replay_steps:
        if step_name in already_run:
            rerun()

    return mirrored
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def plot_surrogate_distributions(
    real_group,
    surrogate_groups: Dict[str, object],
    metric: str = 'harmsim',
    colors: Optional[List[str]] = None,
    figsize: tuple = (11, 6),
    title: Optional[str] = None,
    show: bool = True,
) -> plt.Figure:
    """Plot metric distributions for real data vs. surrogate conditions.

    One KDE curve is drawn per condition. For each surrogate condition an
    independent t-test against the real group is computed, and conditions
    with p < 0.05 are marked with an asterisk in the legend.

    Parameters
    ----------
    real_group : BiotunerGroup
        Original data group (analysis pipeline already computed).
    surrogate_groups : dict
        ``{label: BiotunerGroup}`` for each surrogate condition.
    metric : str, default='harmsim'
        Metric column to plot (must exist in every group's ``results``).
    colors : list of str, optional
        Colors assigned in order (real first, then surrogates in dict
        order). If there are more conditions than colors, the colors are
        reused cyclically so that no condition is dropped. ``None`` or an
        empty list selects a preset palette.
    figsize : tuple, default=(11, 6)
        Figure size.
    title : str, optional
        Custom figure title. Defaults to a descriptive auto-title.
    show : bool, default=True
        Call ``plt.show()`` at the end.

    Returns
    -------
    fig : matplotlib.Figure

    Raises
    ------
    ValueError
        If *metric* is missing from the ``results`` of any group.
    """
    _default_colors = ['cyan', 'deeppink', 'gold', 'limegreen', 'orange', 'violet', 'red']
    if colors is None or len(colors) == 0:
        colors = _default_colors

    all_groups: Dict[str, object] = {'real': real_group, **surrogate_groups}
    data_map: Dict[str, np.ndarray] = {}
    for label, grp in all_groups.items():
        # summary() is expected to populate ``results`` when it has not been
        # built yet - presumably; confirm against BiotunerGroup.
        if grp.results is None:
            grp.summary()
        if metric not in grp.results.columns:
            raise ValueError(
                f"Metric '{metric}' not found in group '{label}'. "
                f"Available: {list(grp.results.select_dtypes(include=[float, int]).columns)}"
            )
        data_map[label] = grp.results[metric].dropna().values

    real_vals = data_map['real']

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_facecolor('#1a1a2e')
    fig.patch.set_facecolor('#1a1a2e')

    for position, (label, vals) in enumerate(data_map.items()):
        legend_label = label
        if label != 'real':
            # NOTE(review): both samples are truncated to a common length
            # before the test. ttest_ind accepts unequal sizes, so this
            # discards data; kept as-is to preserve existing results.
            n = min(len(vals), len(real_vals))
            _, p = stats.ttest_ind(real_vals[:n], vals[:n], nan_policy='omit')
            if p < 0.05:
                legend_label = f'{label} *'

        # Colors wrap around so conditions beyond the palette length are
        # still drawn, and the label is attached to the curve itself so the
        # legend stays correct even if a curve cannot be drawn.
        sns.kdeplot(
            vals,
            ax=ax,
            color=colors[position % len(colors)],
            linewidth=2.5,
            label=legend_label,
        )

    ax.legend(fontsize=12, framealpha=0.85, facecolor='#2d2d4e', labelcolor='white')
    ax.set_xlabel(metric, fontsize=13, color='white')
    ax.set_ylabel('Density', fontsize=13, color='white')
    ax.grid(color='white', linestyle='-.', linewidth=0.5, alpha=0.3)
    ax.tick_params(colors='white')

    if title is None:
        title = f'Real vs. surrogate distributions — {metric}'
    ax.set_title(title, fontsize=15, color='white', fontweight='bold')

    plt.tight_layout()
    if show:
        plt.show()
    return fig