Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 318 additions & 4 deletions element_interface/caiman_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pathlib
from datetime import datetime

import re
import caiman as cm
import h5py
import numpy as np
Expand All @@ -18,6 +18,312 @@


class CaImAn:
    """Loader class for CaImAn analysis results.

    A top-level aggregator of one or more sets of CaImAn results (e.g. a
    multi-plane analysis in which every plane was processed separately).
    Each individual result set is parsed by ``_CaImAn`` under the hood.

    Attributes:
        planes (dict): plane index (int) -> ``_CaImAn`` instance, ordered by
            plane index.
        creation_time: earliest ``creation_time`` among the loaded planes.
        curation_time: latest ``curation_time`` among the loaded planes.
        is3D: ``params.motion["is3D"]`` of the only result set, or ``False``
            when several per-plane result sets were found.
        is_multiplane (bool): True when more than one result set was found.
    """

    def __init__(self, caiman_dir: str):
        """Initialize CaImAn loader class.

        Args:
            caiman_dir (str): path to the CaImAn output directory; it is
                searched recursively for ``*.hdf5`` result files.

        Raises:
            FileNotFoundError: ``caiman_dir`` does not exist.
            FileNotFoundError: no HDF5 file containing all required fields
                was found below ``caiman_dir``.
            NotImplementedError: several result sets were found and at least
                one of them is a 3D analysis.
        """
        # ---- Search and verify CaImAn output file exists ----
        caiman_dir = pathlib.Path(caiman_dir)
        if not caiman_dir.exists():
            raise FileNotFoundError("CaImAn directory not found: {}".format(caiman_dir))

        # A set, so that a directory holding several valid HDF5 files is
        # still loaded only once.
        caiman_subdirs = set()
        for fp in caiman_dir.rglob("*.hdf5"):
            with h5py.File(fp, "r") as h5f:
                if all(s in h5f for s in _required_hdf5_fields):
                    caiman_subdirs.add(fp.parent)

        if not caiman_subdirs:
            raise FileNotFoundError(
                "No CaImAn analysis output file found at {}"
                " containg all required fields ({})".format(
                    caiman_dir, _required_hdf5_fields
                )
            )

        # Extract CaImAn results from all planes, sorted by plane index
        _planes_caiman = {}
        for idx, caiman_subdir in enumerate(sorted(caiman_subdirs)):
            pln_cm = _CaImAn(caiman_subdir.as_posix())
            pln_idx_match = re.search(r"pln(\d+)_.*", caiman_subdir.stem)
            # Convert to int: keeps the keys homogeneous with the fallback
            # `idx`, sorts numerically ("10" after "2") and yields a numeric
            # plane index for the per-plane loader.
            pln_idx = int(pln_idx_match.group(1)) if pln_idx_match else idx
            pln_cm.plane_idx = pln_idx
            _planes_caiman[pln_idx] = pln_cm
        self.planes = {k: _planes_caiman[k] for k in sorted(_planes_caiman)}

        self.creation_time = min(
            p.creation_time for p in self.planes.values()
        )  # earliest file creation time
        self.curation_time = max(
            p.curation_time for p in self.planes.values()
        )  # most recent curation time

        # Is this a 3D CaImAn analysis or multiple 2D per-plane analyses?
        if len(self.planes) > 1:
            # More than one result set is taken to mean multiple 2D per-plane
            # analyses, so none of them may itself be a 3D analysis.
            # Truthiness (rather than `is False`) also accepts NumPy booleans.
            if any(p.params.motion["is3D"] for p in self.planes.values()):
                raise NotImplementedError(
                    "Unable to load CaImAn results mixed between 3D and multi-plane analysis"
                )
            self.is3D = False
            self.is_multiplane = True
        else:
            self.is3D = list(self.planes.values())[0].params.motion["is3D"]
            self.is_multiplane = False

        # Lazily-populated caches backing the properties below
        self._motion_correction = None
        self._masks = None
        self._ref_image = None
        self._mean_image = None
        self._max_proj_image = None
        self._correlation_map = None

    @property
    def is_pw_rigid(self):
        """The ``pw_rigid`` motion parameter shared by all planes."""
        pw_rigid = set(p.params.motion["pw_rigid"] for p in self.planes.values())
        assert (
            len(pw_rigid) == 1
        ), "Unable to load CaImAn results mixed between rigid and pw_rigid motion correction"
        return pw_rigid.pop()

    @property
    def motion_correction(self):
        """Motion correction results, computed on first access and cached.

        Returns the output of ``extract_pw_rigid_mc`` when ``is_pw_rigid`` is
        truthy, otherwise the output of ``extract_rigid_mc``.
        """
        if self._motion_correction is None:
            self._motion_correction = (
                self.extract_pw_rigid_mc()
                if self.is_pw_rigid
                else self.extract_rigid_mc()
            )
        return self._motion_correction

    def extract_rigid_mc(self):
        """Collect the rigid motion correction shifts of all planes.

        Returns:
            dict: with keys ``x_shifts``, ``y_shifts``, ``z_shifts``,
            ``x_std``, ``y_std``, ``z_std`` and ``outlier_frames``. With a
            single plane the x/y shifts are that plane's ``shifts_rig``
            columns 0 and 1; with several planes they are stacked row-wise,
            one row per plane. ``z_shifts`` is column 2 of ``shifts_rig``
            for a single 3D result and all zeros otherwise (``z_std`` is
            then NaN). ``outlier_frames`` is always None.
        """
        # -- rigid motion correction --
        x_per_plane = [
            pln_cm.motion_correction["shifts_rig"][:, 0]
            for pln_cm in self.planes.values()
        ]
        y_per_plane = [
            pln_cm.motion_correction["shifts_rig"][:, 1]
            for pln_cm in self.planes.values()
        ]
        if len(x_per_plane) == 1:
            x_shifts, y_shifts = x_per_plane[0], y_per_plane[0]
        else:
            x_shifts, y_shifts = np.vstack(x_per_plane), np.vstack(y_per_plane)

        rigid_correction = {
            "x_shifts": x_shifts,
            "y_shifts": y_shifts,
            # nanstd without an axis reduces over all planes and frames
            "x_std": np.nanstd(x_shifts),
            "y_std": np.nanstd(y_shifts),
        }

        if self.is3D and not self.is_multiplane:
            pln_cm = list(self.planes.values())[0]
            z_shifts = pln_cm.motion_correction["shifts_rig"][:, 2]
            rigid_correction["z_shifts"] = z_shifts
            rigid_correction["z_std"] = np.nanstd(z_shifts)
        else:
            rigid_correction["z_shifts"] = np.full_like(x_shifts, 0)
            rigid_correction["z_std"] = np.nan

        rigid_correction["outlier_frames"] = None

        return rigid_correction

    def extract_pw_rigid_mc(self):
        """Collect the piece-wise rigid motion correction of all planes.

        Returns:
            tuple: ``(nonrigid_correction, nonrigid_blocks)``.
            ``nonrigid_correction`` is a dict describing the block grid
            (``block_height``, ``block_width``, ``block_depth``,
            ``block_count_x``, ``block_count_y``, ``block_count_z``,
            ``outlier_frames``), taken from the first plane.
            ``nonrigid_blocks`` maps a block id to a dict with the block
            coordinates, per-frame shifts and their standard deviations.
            Block ids run consecutively across planes.
        """
        # -- piece-wise rigid motion correction --
        nonrigid_correction, nonrigid_blocks = {}, {}
        for pln_idx, pln_cm in enumerate(self.planes.values()):
            mc = pln_cm.motion_correction
            if pln_idx == 0:
                motion_params = pln_cm.params.motion
                nonrigid_correction = {
                    "block_height": (
                        motion_params["strides"][0] + motion_params["overlaps"][0]
                    ),
                    "block_width": (
                        motion_params["strides"][1] + motion_params["overlaps"][1]
                    ),
                    "block_depth": 1,
                    "block_count_x": len(set(mc["coord_shifts_els"][:, 0])),
                    "block_count_y": len(set(mc["coord_shifts_els"][:, 2])),
                    "block_count_z": len(self.planes),
                    "outlier_frames": None,
                }

            # Blocks already collected from previous planes; added to the
            # plane-local index to obtain an id unique across planes.
            block_offset = len(nonrigid_blocks)
            for local_id in range(len(mc["x_shifts_els"][0, :])):
                # Data of this plane must be indexed with the plane-local
                # index; only the dictionary key / "block_id" is global.
                block_id = local_id + block_offset
                block_x = np.arange(*mc["coord_shifts_els"][local_id, 0:2])
                x_shifts = mc["x_shifts_els"][:, local_id]
                y_shifts = mc["y_shifts_els"][:, local_id]
                if self.is3D:
                    block_z = np.arange(*mc["coord_shifts_els"][local_id, 4:6])
                    z_shifts = mc["z_shifts_els"][:, local_id]
                    z_std = np.nanstd(z_shifts)
                else:
                    # 2D planes: the z coordinate is the position of the
                    # plane in the iteration order, and there is no z shift.
                    block_z = np.full_like(block_x, pln_idx)
                    z_shifts = np.full_like(x_shifts, 0)
                    z_std = np.nan

                nonrigid_blocks[block_id] = {
                    "block_id": block_id,
                    "block_x": block_x,
                    "block_y": np.arange(*mc["coord_shifts_els"][local_id, 2:4]),
                    "block_z": block_z,
                    "x_shifts": x_shifts,
                    "y_shifts": y_shifts,
                    "z_shifts": z_shifts,
                    "x_std": np.nanstd(x_shifts),
                    "y_std": np.nanstd(y_shifts),
                    "z_std": z_std,
                }

        if not self.is_multiplane and self.is3D:
            pln_cm = list(self.planes.values())[0]
            nonrigid_correction["block_depth"] = (
                pln_cm.params.motion["strides"][2] + pln_cm.params.motion["overlaps"][2]
            )
            nonrigid_correction["block_count_z"] = len(
                set(pln_cm.motion_correction["coord_shifts_els"][:, 4])
            )

        return nonrigid_correction, nonrigid_blocks

    @property
    def masks(self):
        """Masks of all planes as one list, computed on first access.

        Every mask dict of every plane is copied and extended with:
        ``mask_id`` (the plane-local id offset by the number of masks of the
        preceding planes), ``orig_mask_id`` (the plane-local id) and
        ``accepted`` (whether the plane-local id is listed in that plane's
        ``cnmf.estimates.idx_components``; False when that is None).
        """
        if self._masks is None:
            all_masks = []
            for pln_idx in sorted(self.planes):
                pln_cm = self.planes[pln_idx]
                accepted_ids = pln_cm.cnmf.estimates.idx_components
                mask_count = len(all_masks)  # increment mask id from all "plane"
                all_masks.extend(
                    [
                        {
                            **m,
                            "mask_id": m["mask_id"] + mask_count,
                            "orig_mask_id": m["mask_id"],
                            "accepted": (
                                accepted_ids is not None
                                and m["mask_id"] in accepted_ids
                            ),
                        }
                        for m in pln_cm.masks
                    ]
                )

            self._masks = all_masks
        return self._masks

    @property
    def alignment_channel(self):
        return 0  # hard-code to channel index 0

    @property
    def segmentation_channel(self):
        return 0  # hard-code to channel index 0

    # -- image property --

    def _get_image(self, img_type):
        """Load the summary image ``img_type`` from the motion correction data.

        Args:
            img_type (str): key of the image inside each plane's
                ``motion_correction`` (e.g. ``"average_image"``).

        Returns:
            The transposed image for a single 3D result, the image with a
            new leading axis for a single 2D result, or the per-plane images
            stacked along the third axis for a multi-plane result.
        """
        if not self.is_multiplane:
            pln_cm = list(self.planes.values())[0]
            img_ = (
                pln_cm.motion_correction[img_type].transpose()
                if self.is3D
                else pln_cm.motion_correction[img_type][...][np.newaxis, ...]
            )
        else:
            # np.dstack needs a real sequence; recent NumPy versions reject
            # a generator with a TypeError.
            # NOTE(review): planes end up on the LAST axis here but on the
            # FIRST axis in the single-plane 2D case - confirm which layout
            # downstream consumers expect.
            img_ = np.dstack(
                [
                    pln_cm.motion_correction[img_type][...]
                    for pln_cm in self.planes.values()
                ]
            )
        return img_

    @property
    def ref_image(self):
        """The "reference_image" summary image, cached after first access."""
        if self._ref_image is None:
            self._ref_image = self._get_image("reference_image")
        return self._ref_image

    @property
    def mean_image(self):
        """The "average_image" summary image, cached after first access."""
        if self._mean_image is None:
            self._mean_image = self._get_image("average_image")
        return self._mean_image

    @property
    def max_proj_image(self):
        """The "max_image" summary image, cached after first access."""
        if self._max_proj_image is None:
            self._max_proj_image = self._get_image("max_image")
        return self._max_proj_image

    @property
    def correlation_map(self):
        """The "correlation_image" summary image, cached after first access."""
        if self._correlation_map is None:
            self._correlation_map = self._get_image("correlation_image")
        return self._correlation_map


class _CaImAn:
"""Parse the CaImAn output file

[CaImAn results doc](https://caiman.readthedocs.io/en/master/Getting_Started.html#result-variables-for-2p-batch-analysis)
Expand Down Expand Up @@ -54,6 +360,7 @@ class CaImAn:
motion_correction: h5f "motion_correction" property
params: cnmf.params
segmentation_channel: hard-coded to 0
plane_idx: N/A if `is3D` else hard-coded to 0
"""

def __init__(self, caiman_dir: str):
Expand Down Expand Up @@ -89,13 +396,20 @@ def __init__(self, caiman_dir: str):
self.params = self.cnmf.params

self.h5f = h5py.File(self.caiman_fp, "r")
self.motion_correction = self.h5f["motion_correction"]
self.plane_idx = None if self.params.motion["is3D"] else 0
self._motion_correction = None
self._masks = None

# ---- Metainfo ----
self.creation_time = datetime.fromtimestamp(os.stat(self.caiman_fp).st_ctime)
self.curation_time = datetime.fromtimestamp(os.stat(self.caiman_fp).st_ctime)

@property
def motion_correction(self):
if self._motion_correction is None:
self._motion_correction = self.h5f["motion_correction"]
return self._motion_correction

@property
def masks(self):
if self._masks is None:
Expand Down Expand Up @@ -139,7 +453,7 @@ def extract_masks(self) -> dict:
else:
xpix, ypix = np.unravel_index(ind, self.cnmf.dims, order="F")
center_x, center_y = comp_contour["CoM"].astype(int)
center_z = 0
center_z = self.plane_idx
zpix = np.full(len(weights), center_z)

masks.append(
Expand All @@ -161,7 +475,7 @@ def extract_masks(self) -> dict:
return masks


def _process_scanimage_tiff(scan_filenames, output_dir="./"):
def _process_scanimage_tiff(scan_filenames, output_dir="./", split_depths=False):
"""
Read ScanImage TIFF - reshape into volumetric data based on scanning depths/channels
Save new TIFF files for each channel - with shape (frame x height x width x depth)
Expand Down
Loading