Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spikeinterface/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
compute_center_of_mass,
)

from .amplitude_scalings import compute_amplitude_scalings, AmplitudeScalingsCalculator

from .alignsorting import align_sorting, AlignSortingExtractor

from .noise_level import compute_noise_levels, NoiseLevelsCalculator
304 changes: 304 additions & 0 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
import numpy as np
from scipy.stats import linregress

from spikeinterface.core import ChannelSparsity, get_chunk_with_margin
from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs

from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift
from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension


class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension):
    """
    Computes amplitude scalings from a WaveformExtractor.

    For each spike, the (sparse) template of its unit is regressed against the
    waveform cut out around the spike; the regression slope is the scaling.
    """

    extension_name = "amplitude_scalings"

    def __init__(self, waveform_extractor):
        BaseWaveformExtractorExtension.__init__(self, waveform_extractor)

        # spike vector sorted by (segment, sample); also carries each spike's
        # extremum channel index, used downstream by the chunk workers
        extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
        self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

    def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after):
        params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after)
        return params

    def _select_extension_data(self, unit_ids):
        # keep only the scalings whose spike belongs to one of the requested units
        old_unit_ids = self.waveform_extractor.sorting.unit_ids
        unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids))

        spike_mask = np.in1d(self.spikes["unit_ind"], unit_inds)
        new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask]
        return dict(amplitude_scalings=new_amplitude_scalings)

    def _run(self, **job_kwargs):
        """Compute one scaling per spike, in parallel over recording chunks."""
        job_kwargs = fix_job_kwargs(job_kwargs)
        we = self.waveform_extractor
        recording = we.recording
        nbefore = we.nbefore
        nafter = we.nafter
        ms_before = self._params["ms_before"]
        ms_after = self._params["ms_after"]
        return_scaled = we._params["return_scaled"]

        # the local cut-out cannot exceed the window used to build the templates
        if ms_before is not None:
            assert (
                ms_before <= we._params["ms_before"]
            ), f"`ms_before` cannot exceed the `ms_before` used in the WaveformExtractor: {we._params['ms_before']}"
        if ms_after is not None:
            assert (
                ms_after <= we._params["ms_after"]
            ), f"`ms_after` cannot exceed the `ms_after` used in the WaveformExtractor: {we._params['ms_after']}"

        cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore
        cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter

        # resolve sparsity: waveform extractor's own > explicit param > dense
        # (dense is only allowed when the channel count is small enough)
        if we.is_sparse():
            sparsity = we.sparsity
        elif self._params["sparsity"] is not None:
            sparsity = self._params["sparsity"]
        else:
            if self._params["max_dense_channels"] is not None:
                assert recording.get_num_channels() <= self._params["max_dense_channels"], (
                    "Too many channels to compute amplitude scalings with dense waveforms. "
                    "Use a sparse WaveformExtractor, pass a `sparsity`, or set `max_dense_channels=None`."
                )
            sparsity = ChannelSparsity.create_dense(we)
        sparsity_inds = sparsity.unit_id_to_channel_indices
        all_templates = we.get_all_templates()

        # and run
        func = _amplitude_scalings_chunk
        init_func = _init_worker_amplitude_scalings
        n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None))
        job_kwargs["n_jobs"] = n_jobs
        init_args = (
            recording,
            self.spikes,
            all_templates,
            we.unit_ids,
            sparsity_inds,
            nbefore,
            nafter,
            cut_out_before,
            cut_out_after,
            return_scaled,
        )
        processor = ChunkRecordingExecutor(
            recording,
            func,
            init_func,
            init_args,
            handle_returns=True,
            job_name="extract amplitude scalings",
            **job_kwargs,
        )
        out = processor.run()
        # each chunk returns a 1-tuple of scalings; concatenate in chunk order,
        # which matches the (segment, sample) order of self.spikes
        (amp_scalings,) = zip(*out)
        amp_scalings = np.concatenate(amp_scalings)

        self._extension_data["amplitude_scalings"] = amp_scalings

    def get_data(self, outputs="concatenated"):
        """
        Get computed amplitude scalings.
        Parameters
        ----------
        outputs : str, optional
            'concatenated' or 'by_unit', by default 'concatenated'
        Returns
        -------
        amplitude_scalings : np.array or list of dict
            The amplitude scalings as an array (outputs='concatenated') or
            as a list (one per segment) of dicts with units as keys and
            scalings as values (outputs='by_unit').
        """
        we = self.waveform_extractor
        sorting = we.sorting

        if outputs == "concatenated":
            return self._extension_data["amplitude_scalings"]
        elif outputs == "by_unit":
            amplitudes_by_unit = []
            for segment_index in range(we.get_num_segments()):
                amplitudes_by_unit.append({})
                segment_mask = self.spikes["segment_ind"] == segment_index
                spikes_segment = self.spikes[segment_mask]
                amp_scalings_segment = self._extension_data["amplitude_scalings"][segment_mask]
                for unit_index, unit_id in enumerate(sorting.unit_ids):
                    unit_mask = spikes_segment["unit_ind"] == unit_index
                    amp_scalings = amp_scalings_segment[unit_mask]
                    amplitudes_by_unit[segment_index][unit_id] = amp_scalings
            return amplitudes_by_unit

    @staticmethod
    def get_extension_function():
        return compute_amplitude_scalings


# register so the extension can be loaded/computed through the WaveformExtractor API
WaveformExtractor.register_extension(AmplitudeScalingsCalculator)


def compute_amplitude_scalings(
    waveform_extractor,
    sparsity=None,
    max_dense_channels=16,
    ms_before=None,
    ms_after=None,
    load_if_exists=False,
    outputs="concatenated",
    **job_kwargs,
):
    """
    Computes the amplitude scalings from a WaveformExtractor.

    Parameters
    ----------
    waveform_extractor: WaveformExtractor
        The waveform extractor object
    sparsity: ChannelSparsity, optional
        If waveforms are not sparse, sparsity is required if the number of channels is greater than
        `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used.
        By default None
    max_dense_channels: int, optional
        Maximum number of channels to allow running without sparsity. To compute amplitude scaling using
        dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input.
        By default 16
    ms_before : float, optional
        The cut out to apply before the spike peak to extract local waveforms.
        If None, the WaveformExtractor ms_before is used, by default None
    ms_after : float, optional
        The cut out to apply after the spike peak to extract local waveforms.
        If None, the WaveformExtractor ms_after is used, by default None
    load_if_exists : bool, default: False
        Whether to load precomputed amplitude scalings, if they already exist.
    outputs: str
        How the output should be returned:
            - 'concatenated'
            - 'by_unit'
    {}

    Returns
    -------
    amplitude_scalings: np.array or list of dict
        The amplitude scalings.
            - If 'concatenated' all amplitudes for all spikes and all units are concatenated
            - If 'by_unit', amplitudes are returned as a list (for segments) of dictionaries (for units)
    """
    if load_if_exists and waveform_extractor.is_extension(AmplitudeScalingsCalculator.extension_name):
        sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name)
    else:
        sac = AmplitudeScalingsCalculator(waveform_extractor)
        sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after)
        sac.run(**job_kwargs)

    amps = sac.get_data(outputs=outputs)
    return amps


# str.format returns a NEW string: the result must be assigned back, otherwise
# the shared job-kwargs documentation is never inserted into the docstring.
compute_amplitude_scalings.__doc__ = compute_amplitude_scalings.__doc__.format(_shared_job_kwargs_doc)


def _init_worker_amplitude_scalings(
recording,
spikes,
all_templates,
unit_ids,
unit_ids_to_channel_indices,
nbefore,
nafter,
cut_out_before,
cut_out_after,
return_scaled,
):
# create a local dict per worker
worker_ctx = {}
worker_ctx["recording"] = recording
worker_ctx["spikes"] = spikes
worker_ctx["all_templates"] = all_templates
worker_ctx["nbefore"] = nbefore
worker_ctx["nafter"] = nafter
worker_ctx["cut_out_before"] = cut_out_before
worker_ctx["cut_out_after"] = cut_out_after
worker_ctx["margin"] = max(nbefore, nafter)
worker_ctx["return_scaled"] = return_scaled

# construct handy unit_inds -> channel_inds
worker_ctx["unit_inds_to_channel_indices"] = {
unit_ind: unit_ids_to_channel_indices[unit_id] for unit_ind, unit_id in enumerate(unit_ids)
}

return worker_ctx


def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx):
    """
    Compute amplitude scalings for all spikes falling in one recording chunk.

    For each spike in [start_frame, end_frame) of the given segment, a local
    waveform is cut out of the margin-extended traces and regressed against the
    unit's sparse template; the regression slope is the scaling.

    Returns
    -------
    tuple
        A 1-element tuple holding the np.array of scalings (one per spike).
    """
    # recover variables of the worker
    spikes = worker_ctx["spikes"]
    recording = worker_ctx["recording"]
    all_templates = worker_ctx["all_templates"]
    unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"]
    nbefore = worker_ctx["nbefore"]
    cut_out_before = worker_ctx["cut_out_before"]
    cut_out_after = worker_ctx["cut_out_after"]
    margin = worker_ctx["margin"]
    return_scaled = worker_ctx["return_scaled"]

    # spikes is sorted by (segment, sample): restrict first to this segment...
    i0 = np.searchsorted(spikes["segment_ind"], segment_index)
    i1 = np.searchsorted(spikes["segment_ind"], segment_index + 1)
    spikes_in_segment = spikes[i0:i1]

    # ...then to this chunk
    i0 = np.searchsorted(spikes_in_segment["sample_ind"], start_frame)
    i1 = np.searchsorted(spikes_in_segment["sample_ind"], end_frame)

    scalings = []

    if i0 != i1:
        local_spikes = spikes_in_segment[i0:i1]
        traces_with_margin, left, right = get_chunk_with_margin(
            recording._recording_segments[segment_index], start_frame, end_frame, channel_indices=None, margin=margin
        )

        # scale traces with margin to match scaling of templates
        if return_scaled and recording.has_scaled():
            gains = recording.get_property("gain_to_uV")
            offsets = recording.get_property("offset_to_uV")
            traces_with_margin = traces_with_margin.astype("float32") * gains + offsets

        # get all waveforms
        for spike in local_spikes:
            unit_index = spike["unit_ind"]
            sample_index = spike["sample_ind"]
            sparse_indices = unit_inds_to_channel_indices[unit_index]
            template = all_templates[unit_index][:, sparse_indices]
            # cut the template down to the (possibly shorter) requested window
            template = template[nbefore - cut_out_before : nbefore + cut_out_after]
            # index of the spike sample inside traces_with_margin
            sample_centered = sample_index - start_frame
            cut_out_start = left + sample_centered - cut_out_before
            cut_out_end = left + sample_centered + cut_out_after
            if sample_index - cut_out_before < 0:
                # spike at the very start of the recording: the waveform is
                # truncated on the left, so truncate the template to match
                local_waveform = traces_with_margin[:cut_out_end, sparse_indices]
                template = template[cut_out_before - sample_index :]
            elif sample_index + cut_out_after > end_frame + right:
                # spike at the very end of the recording: waveform truncated on
                # the right; trim the template by the same number of samples
                # (the excess past end_frame + right, matching the condition)
                local_waveform = traces_with_margin[cut_out_start:, sparse_indices]
                template = template[: -(sample_index + cut_out_after - (end_frame + right))]
            else:
                local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices]
            assert template.shape == local_waveform.shape
            # scaling = slope of the regression of the waveform on the template
            linregress_res = linregress(template.flatten(), local_waveform.flatten())
            scalings.append(linregress_res[0])
    scalings = np.array(scalings)

    return (scalings,)
57 changes: 57 additions & 0 deletions src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest
import numpy as np

from spikeinterface.postprocessing import AmplitudeScalingsCalculator

from spikeinterface.postprocessing.tests.common_extension_tests import (
WaveformExtensionCommonTestSuite,
)


class AmplitudeScalingsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase):
    """Extension test suite for the amplitude_scalings postprocessing extension."""

    extension_class = AmplitudeScalingsCalculator
    extension_data_names = ["amplitude_scalings"]
    extension_function_kwargs_list = [
        dict(outputs="concatenated", chunk_size=10000, n_jobs=1),
        dict(outputs="concatenated", chunk_size=10000, n_jobs=1, ms_before=0.5, ms_after=0.5),
        dict(outputs="by_unit", chunk_size=10000, n_jobs=1),
        dict(outputs="concatenated", chunk_size=10000, n_jobs=-1),
        dict(outputs="concatenated", chunk_size=10000, n_jobs=2, ms_before=0.5, ms_after=0.5),
    ]

    def test_scaling_parallel(self):
        """Parallel and single-job computations must yield identical scalings."""
        scalings1 = self.extension_class.get_extension_function()(
            self.we1,
            outputs="concatenated",
            chunk_size=10000,
            n_jobs=1,
        )
        scalings2 = self.extension_class.get_extension_function()(
            self.we1,
            outputs="concatenated",
            chunk_size=10000,
            n_jobs=2,
        )
        np.testing.assert_array_equal(scalings1, scalings2)

    def test_scaling_values(self):
        """On ground-truth spikes, the median scaling of each unit rounds to 1."""
        scalings1 = self.extension_class.get_extension_function()(
            self.we1,
            outputs="by_unit",
            chunk_size=10000,
            n_jobs=1,
        )
        # since this is GT spikes, the rounded median must be 1
        for u, scalings in scalings1[0].items():
            median_scaling = np.median(scalings)
            np.testing.assert_array_equal(np.round(median_scaling), 1)


# run the two bespoke tests directly when invoked as a script
if __name__ == "__main__":
    test = AmplitudeScalingsExtensionTest()
    test.setUp()
    test.test_scaling_values()
    test.test_scaling_parallel()
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def test_parallel(self):
chunk_size=10000,
n_jobs=1,
)
# TODO : fix multi processing for spike amplitudes!!!!!!!
amplitudes2 = self.extension_class.get_extension_function()(
self.we1,
peak_sign="neg",
Expand All @@ -63,5 +62,3 @@ def test_parallel(self):
test = SpikeAmplitudesExtensionTest()
test.setUp()
test.test_extension()
# test.test_scaled()
# test.test_parallel()