# Add amplitude_scaling implementation #1485
Changes from all commits: 16db01f, 5c5be4a, 571b078, 4d5c5e2, 50f04e6, d6d69b4, e08fdab
New file (`@@ -0,0 +1,304 @@`):

```python
import numpy as np
from scipy.stats import linregress

from spikeinterface.core import ChannelSparsity, get_chunk_with_margin
from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs

from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift
from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension

class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension):
    """
    Computes amplitude scalings from WaveformExtractor.
    """

    extension_name = "amplitude_scalings"

    def __init__(self, waveform_extractor):
        BaseWaveformExtractorExtension.__init__(self, waveform_extractor)

        extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index")
        self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)

    def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after):
        params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after)
        return params

    def _select_extension_data(self, unit_ids):
        old_unit_ids = self.waveform_extractor.sorting.unit_ids
        unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids))

        spike_mask = np.in1d(self.spikes["unit_ind"], unit_inds)
        new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask]
        return dict(amplitude_scalings=new_amplitude_scalings)

    def _run(self, **job_kwargs):
        job_kwargs = fix_job_kwargs(job_kwargs)
        we = self.waveform_extractor
        recording = we.recording
        nbefore = we.nbefore
        nafter = we.nafter
        ms_before = self._params["ms_before"]
        ms_after = self._params["ms_after"]
        return_scaled = we._params["return_scaled"]

        if ms_before is not None:
            assert (
                ms_before <= we._params["ms_before"]
            ), f"`ms_before` must be smaller than or equal to the `ms_before` used in the WaveformExtractor: {we._params['ms_before']}"
        if ms_after is not None:
            assert (
                ms_after <= we._params["ms_after"]
            ), f"`ms_after` must be smaller than or equal to the `ms_after` used in the WaveformExtractor: {we._params['ms_after']}"

        cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore
        cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter

        if we.is_sparse():
            sparsity = we.sparsity
        elif self._params["sparsity"] is not None:
            sparsity = self._params["sparsity"]
        else:
            if self._params["max_dense_channels"] is not None:
                assert recording.get_num_channels() <= self._params["max_dense_channels"], (
                    "The number of channels exceeds `max_dense_channels`: provide a sparsity, "
                    "or set `max_dense_channels` to None to force dense computation."
                )
            sparsity = ChannelSparsity.create_dense(we)
        sparsity_inds = sparsity.unit_id_to_channel_indices
        all_templates = we.get_all_templates()

        # and run
        func = _amplitude_scalings_chunk
        init_func = _init_worker_amplitude_scalings
        n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None))
        job_kwargs["n_jobs"] = n_jobs
        init_args = (
            recording,
            self.spikes,
            all_templates,
            we.unit_ids,
            sparsity_inds,
            nbefore,
            nafter,
            cut_out_before,
            cut_out_after,
            return_scaled,
        )
        processor = ChunkRecordingExecutor(
            recording,
            func,
            init_func,
            init_args,
            handle_returns=True,
            job_name="extract amplitude scalings",
            **job_kwargs,
        )
        out = processor.run()
        (amp_scalings,) = zip(*out)
        amp_scalings = np.concatenate(amp_scalings)

        self._extension_data["amplitude_scalings"] = amp_scalings

    def get_data(self, outputs="concatenated"):
        """
        Get computed amplitude scalings.

        Parameters
        ----------
        outputs : str, optional
            'concatenated' or 'by_unit', by default 'concatenated'

        Returns
        -------
        amplitude_scalings : np.array or dict
            The amplitude scalings as an array (outputs='concatenated') or
            as a dict with units as keys and amplitude scalings as values.
        """
        we = self.waveform_extractor
        sorting = we.sorting

        if outputs == "concatenated":
            return self._extension_data["amplitude_scalings"]
        elif outputs == "by_unit":
            amplitudes_by_unit = []
            for segment_index in range(we.get_num_segments()):
                amplitudes_by_unit.append({})
                segment_mask = self.spikes["segment_ind"] == segment_index
                spikes_segment = self.spikes[segment_mask]
                amp_scalings_segment = self._extension_data["amplitude_scalings"][segment_mask]
                for unit_index, unit_id in enumerate(sorting.unit_ids):
                    unit_mask = spikes_segment["unit_ind"] == unit_index
                    amp_scalings = amp_scalings_segment[unit_mask]
                    amplitudes_by_unit[segment_index][unit_id] = amp_scalings
            return amplitudes_by_unit

    @staticmethod
    def get_extension_function():
        return compute_amplitude_scalings


WaveformExtractor.register_extension(AmplitudeScalingsCalculator)


def compute_amplitude_scalings(
    waveform_extractor,
    sparsity=None,
    max_dense_channels=16,
    ms_before=None,
    ms_after=None,
    load_if_exists=False,
    outputs="concatenated",
    **job_kwargs,
):
    """
    Computes the amplitude scalings from a WaveformExtractor.

    Parameters
    ----------
    waveform_extractor : WaveformExtractor
        The waveform extractor object.
    sparsity : ChannelSparsity, optional
        If waveforms are not sparse, sparsity is required if the number of channels is greater than
        `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used.
        By default None.
    max_dense_channels : int, optional
        Maximum number of channels to allow running without sparsity. To compute amplitude scalings using
        dense waveforms, set this to None, set sparsity to None, and pass dense waveforms as input.
        By default 16.
    ms_before : float, optional
        The cut out to apply before the spike peak to extract local waveforms.
        If None, the WaveformExtractor ms_before is used. By default None.
    ms_after : float, optional
        The cut out to apply after the spike peak to extract local waveforms.
        If None, the WaveformExtractor ms_after is used. By default None.
    load_if_exists : bool, default: False
        Whether to load precomputed amplitude scalings, if they already exist.
    outputs : str
        How the output should be returned:
        - 'concatenated'
        - 'by_unit'
    {}

    Returns
    -------
    amplitude_scalings : np.array or list of dict
        The amplitude scalings.
        - If 'concatenated', all amplitudes for all spikes and all units are concatenated.
        - If 'by_unit', amplitudes are returned as a list (for segments) of dictionaries (for units).
    """
    if load_if_exists and waveform_extractor.is_extension(AmplitudeScalingsCalculator.extension_name):
        sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name)
    else:
        sac = AmplitudeScalingsCalculator(waveform_extractor)
        sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after)
        sac.run(**job_kwargs)

    amps = sac.get_data(outputs=outputs)
    return amps


compute_amplitude_scalings.__doc__ = compute_amplitude_scalings.__doc__.format(_shared_job_kwargs_doc)


def _init_worker_amplitude_scalings(
    recording,
    spikes,
    all_templates,
    unit_ids,
    unit_ids_to_channel_indices,
    nbefore,
    nafter,
    cut_out_before,
    cut_out_after,
    return_scaled,
):
    # create a local dict per worker
    worker_ctx = {}
    worker_ctx["recording"] = recording
    worker_ctx["spikes"] = spikes
    worker_ctx["all_templates"] = all_templates
    worker_ctx["nbefore"] = nbefore
    worker_ctx["nafter"] = nafter
    worker_ctx["cut_out_before"] = cut_out_before
    worker_ctx["cut_out_after"] = cut_out_after
    worker_ctx["margin"] = max(nbefore, nafter)
    worker_ctx["return_scaled"] = return_scaled

    # construct handy unit_inds -> channel_inds
    worker_ctx["unit_inds_to_channel_indices"] = {
        unit_ind: unit_ids_to_channel_indices[unit_id] for unit_ind, unit_id in enumerate(unit_ids)
    }

    return worker_ctx


def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx):
    # recover variables of the worker
    spikes = worker_ctx["spikes"]
    recording = worker_ctx["recording"]
    all_templates = worker_ctx["all_templates"]
    unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"]
    nbefore = worker_ctx["nbefore"]
    cut_out_before = worker_ctx["cut_out_before"]
    cut_out_after = worker_ctx["cut_out_after"]
    margin = worker_ctx["margin"]
    return_scaled = worker_ctx["return_scaled"]

    i0 = np.searchsorted(spikes["segment_ind"], segment_index)
    i1 = np.searchsorted(spikes["segment_ind"], segment_index + 1)
    spikes_in_segment = spikes[i0:i1]

    i0 = np.searchsorted(spikes["segment_ind"], segment_index)
    i1 = np.searchsorted(spikes["segment_ind"], segment_index + 1)
    spikes_in_segment = spikes[i0:i1]
```

> **Review comment on lines +255 to +257 (Member):** duplicate

```python
    i0 = np.searchsorted(spikes_in_segment["sample_ind"], start_frame)
    i1 = np.searchsorted(spikes_in_segment["sample_ind"], end_frame)

    local_waveforms = []
    templates = []

    scalings = []

    if i0 != i1:
        local_spikes = spikes_in_segment[i0:i1]
        traces_with_margin, left, right = get_chunk_with_margin(
            recording._recording_segments[segment_index], start_frame, end_frame, channel_indices=None, margin=margin
        )

        # scale traces with margin to match scaling of templates
        if return_scaled and recording.has_scaled():
            gains = recording.get_property("gain_to_uV")
            offsets = recording.get_property("offset_to_uV")
            traces_with_margin = traces_with_margin.astype("float32") * gains + offsets

        # get all waveforms
        for spike in local_spikes:
            unit_index = spike["unit_ind"]
            sample_index = spike["sample_ind"]
            sparse_indices = unit_inds_to_channel_indices[unit_index]
            template = all_templates[unit_index][:, sparse_indices]
            template = template[nbefore - cut_out_before : nbefore + cut_out_after]
            sample_centered = sample_index - start_frame
            cut_out_start = left + sample_centered - cut_out_before
            cut_out_end = left + sample_centered + cut_out_after
            if sample_index - cut_out_before < 0:
                local_waveform = traces_with_margin[:cut_out_end, sparse_indices]
                template = template[cut_out_before - sample_index :]
            elif sample_index + cut_out_after > end_frame + right:
                local_waveform = traces_with_margin[cut_out_start:, sparse_indices]
                template = template[: -(sample_index + cut_out_after - end_frame)]
            else:
                local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices]
```

> **Review comment on lines +289 to +296:**
> **Member:** Why this when we have the margin?
> **Member (author):** This is a smaller cut-out to get a local waveform.
> **Member:** Yes, but doesn't the margin always ensure the correct length?
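For illustration, here is a minimal numeric sketch (made-up values, independent of the SpikeInterface API) of the case the margin cannot cover: a spike so close to the start of the recording that the left cut-out window has no samples to draw from, so the template is trimmed to match the shortened waveform.

```python
import numpy as np

# Hypothetical numbers: the chunk starts at the very beginning of the
# recording, so no left margin exists and `left` comes back as 0.
cut_out_before, cut_out_after = 20, 30
start_frame, left = 0, 0
sample_index = 10  # spike only 10 samples into the recording

traces_with_margin = np.random.randn(1000, 4)

# The full cut-out would need 20 samples before the peak, but only 10 exist,
# so the local waveform comes up short on the left ...
cut_out_end = left + (sample_index - start_frame) + cut_out_after
local_waveform = traces_with_margin[:cut_out_end]  # 40 samples instead of 50

# ... and the template is trimmed by the same amount so the shapes stay aligned.
template = np.random.randn(cut_out_before + cut_out_after, 4)
template = template[cut_out_before - sample_index :]  # also 40 samples
assert template.shape == local_waveform.shape
```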
```python
            assert template.shape == local_waveform.shape
            local_waveforms.append(local_waveform)
            templates.append(template)
            linregress_res = linregress(template.flatten(), local_waveform.flatten())
```

> **Review comment:**
> **Member:** Doesn't using linregress make it too slow?
> **Member (author):** No, it's actually quite fast :) Happy to discuss options.
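To make the discussion concrete, here is a small self-contained check (not part of the PR) that the linregress slope recovers a known scaling factor; the closed-form slope shown below is an alternative one could swap in if the scipy call ever became a bottleneck.

```python
import numpy as np
from scipy.stats import linregress

rng = np.random.default_rng(0)
template = rng.standard_normal(120)  # stand-in for a flattened template
waveform = 0.8 * template + 0.05 * rng.standard_normal(120)  # scaled + noise

# Slope of the least-squares fit: waveform ~ slope * template + intercept
slope = linregress(template, waveform)[0]
print(round(slope, 2))  # ~0.8

# Equivalent closed form (covariance / variance), avoiding the scipy call
t = template - template.mean()
w = waveform - waveform.mean()
print(round(np.dot(t, w) / np.dot(t, t), 2))  # same slope
```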
```python
            scalings.append(linregress_res[0])
        scalings = np.array(scalings)

    return (scalings,)
```
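For context, a hypothetical end-to-end usage sketch (not part of the diff); `toy_example` and `extract_waveforms` are existing SpikeInterface helpers, though their exact signatures may differ across versions.

```python
from spikeinterface.extractors import toy_example
from spikeinterface import extract_waveforms
from spikeinterface.postprocessing import compute_amplitude_scalings

# Toy recording/sorting pair for demonstration purposes.
recording, sorting = toy_example(duration=10, num_channels=4, seed=0)

we = extract_waveforms(recording, sorting, folder="waveforms_demo")

# One scaling factor per spike; values near 1 mean the spike amplitude
# matches its unit template.
scalings = compute_amplitude_scalings(we, outputs="concatenated", chunk_size=10000, n_jobs=1)
print(scalings.shape)
```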
New test file (`@@ -0,0 +1,57 @@`):

```python
import unittest
import numpy as np

from spikeinterface.postprocessing import AmplitudeScalingsCalculator

from spikeinterface.postprocessing.tests.common_extension_tests import (
    WaveformExtensionCommonTestSuite,
)


class AmplitudeScalingsExtensionTest(WaveformExtensionCommonTestSuite, unittest.TestCase):
    extension_class = AmplitudeScalingsCalculator
    extension_data_names = ["amplitude_scalings"]
    extension_function_kwargs_list = [
        dict(outputs="concatenated", chunk_size=10000, n_jobs=1),
        dict(outputs="concatenated", chunk_size=10000, n_jobs=1, ms_before=0.5, ms_after=0.5),
        dict(outputs="by_unit", chunk_size=10000, n_jobs=1),
        dict(outputs="concatenated", chunk_size=10000, n_jobs=-1),
        dict(outputs="concatenated", chunk_size=10000, n_jobs=2, ms_before=0.5, ms_after=0.5),
    ]

    def test_scaling_parallel(self):
        scalings1 = self.extension_class.get_extension_function()(
            self.we1,
            outputs="concatenated",
            chunk_size=10000,
            n_jobs=1,
        )
        scalings2 = self.extension_class.get_extension_function()(
            self.we1,
            outputs="concatenated",
            chunk_size=10000,
            n_jobs=2,
        )
        np.testing.assert_array_equal(scalings1, scalings2)

    def test_scaling_values(self):
        scalings1 = self.extension_class.get_extension_function()(
            self.we1,
            outputs="by_unit",
            chunk_size=10000,
            n_jobs=1,
        )
        # since these are GT spikes, the rounded median must be 1
        for u, scalings in scalings1[0].items():
            median_scaling = np.median(scalings)
            print(u, median_scaling)
            np.testing.assert_array_equal(np.round(median_scaling), 1)


if __name__ == "__main__":
    test = AmplitudeScalingsExtensionTest()
    test.setUp()
    test.test_scaling_values()
    test.test_scaling_parallel()
```
> **Review comment:**
> **Member:** The segment slicing could be done once in the worker init.
> **Member (author):** But the workers could span multiple segments, no?
> **Member:** Yes, but the first searchsorted can be done once for all chunks.
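A sketch of the suggested refactor (hypothetical, not part of this PR): since the spike vector is sorted by segment, the per-segment slices can be computed once in `_init_worker_amplitude_scalings` and reused by every chunk.

```python
import numpy as np

def precompute_segment_slices(spikes, num_segments):
    # spikes["segment_ind"] is sorted, so each segment's span can be found once.
    return [
        slice(
            np.searchsorted(spikes["segment_ind"], seg),
            np.searchsorted(spikes["segment_ind"], seg + 1),
        )
        for seg in range(num_segments)
    ]

# In the chunk function, the two searchsorted calls then reduce to a lookup:
# spikes_in_segment = spikes[worker_ctx["segment_slices"][segment_index]]
```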