Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clustering at the beginning of template creation #87

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions fmralign/template_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nilearn._utils.masker_validation import check_embedded_masker
from sklearn.base import BaseEstimator, TransformerMixin
from fmralign.pairwise_alignment import PairwiseAlignment
from fmralign._utils import _make_parcellation


def _rescaled_euclidean_mean(imgs, masker, scale_average=False):
Expand Down Expand Up @@ -92,6 +93,7 @@ def _create_template(
alignment_method,
n_pieces,
clustering,
clustering_index,
n_bags,
masker,
memory,
Expand Down Expand Up @@ -128,19 +130,37 @@ def _create_template(
template_history: list of list of 3D Niimgs
List of the intermediate templates computed at the end of each iteration
"""
# Use all contrasts for clustering
imgs_concat = concat_imgs(imgs)
if clustering_index is None:
clustering_index = np.arange(imgs_concat.shape[-1])
# Assert that the clustering index contains only valid indices
if not all(0 <= i < imgs_concat.shape[-1] for i in clustering_index):
raise ValueError("clustering_index should only contain valid indices.")
labels = _make_parcellation(
imgs_concat,
clustering_index,
clustering,
n_pieces,
masker,
verbose=verbose,
)
labels_img = masker.inverse_transform(labels)

aligned_imgs = imgs
template_history = []
for iter in range(n_iter):
template = _rescaled_euclidean_mean(aligned_imgs, masker, scale_template)
template = _rescaled_euclidean_mean(
aligned_imgs, masker, scale_template
)
if 0 < iter < n_iter - 1:
template_history.append(template)
aligned_imgs = _align_images_to_template(
imgs,
template,
alignment_method,
n_pieces,
clustering,
labels_img,
n_bags,
masker,
memory,
Expand Down Expand Up @@ -240,6 +260,7 @@ def __init__(
alignment_method="identity",
n_pieces=1,
clustering="kmeans",
clustering_index=None,
scale_template=False,
n_iter=2,
save_template=None,
Expand Down Expand Up @@ -278,6 +299,9 @@ def __init__(
passed to nilearn.regions.parcellations
If 3D Niimg, image used as predefined clustering,
n_bags and n_pieces are then ignored.
clustering_index: list of integers
Clustering is performed on a subset of the data chosen randomly
in timeframes. This index carries this subset.
scale_template: boolean, default False
rescale template after each inference so that it keeps
the same norm as the average of training images.
Expand Down Expand Up @@ -337,6 +361,7 @@ def __init__(
self.alignment_method = alignment_method
self.n_pieces = n_pieces
self.clustering = clustering
self.clustering_index = clustering_index
self.n_iter = n_iter
self.scale_template = scale_template
self.save_template = save_template
Expand Down Expand Up @@ -404,6 +429,7 @@ def fit(self, imgs):
self.alignment_method,
self.n_pieces,
self.clustering,
self.clustering_index,
self.n_bags,
self.masker_,
self.memory,
Expand Down Expand Up @@ -473,7 +499,9 @@ def transform(self, imgs, train_index, test_index):
"greater index in train_index or test_index."
)

fitted_mappings = Parallel(self.n_jobs, prefer="threads", verbose=self.verbose)(
fitted_mappings = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
)(
delayed(_map_template_to_image)(
img,
train_index,
Expand All @@ -491,7 +519,9 @@ def transform(self, imgs, train_index, test_index):
for img in imgs
)

predicted_imgs = Parallel(self.n_jobs, prefer="threads", verbose=self.verbose)(
predicted_imgs = Parallel(
self.n_jobs, prefer="threads", verbose=self.verbose
)(
delayed(_predict_from_template_and_mapping)(
self.template, test_index, mapping
)
Expand Down
Loading