Skip to content

Commit

Permalink
Merge pull request dipy#2717 from samcoveney/fixfitmask
Browse files Browse the repository at this point in the history
fixed bug for non-linear fitting with masks
  • Loading branch information
skoudoro authored Jan 25, 2023
2 parents a2cb9ad + d25fc5a commit 1f5eaaa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion dipy/reconst/dti.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ def fit(self, data, mask=None):
dti_params[mask, :] = params_in_mask
if self.return_S0_hat:
S0_params = np.zeros(data.shape[:-1])
S0_params[mask] = model_S0
S0_params[mask] = model_S0.squeeze()

return TensorFit(self, dti_params, model_S0=S0_params)

Expand Down
49 changes: 25 additions & 24 deletions dipy/reconst/tests/test_dti.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,27 +474,10 @@ def test_all_zeros():

def test_mask():
data, gtab = dsi_voxels()
dm = dti.TensorModel(gtab, 'LS')
mask = np.zeros(data.shape[:-1], dtype=bool)
mask[0, 0, 0] = True
dtifit = dm.fit(data)
dtifit_w_mask = dm.fit(data, mask=mask)
# Without a mask it has some value
assert not np.isnan(dtifit.fa[0, 0, 0])
# Where mask is False, evals, evecs and fa should all be 0
npt.assert_array_equal(dtifit_w_mask.evals[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.evecs[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.fa[~mask], 0)
# Except for the one voxel that was selected by the mask:
npt.assert_almost_equal(dtifit_w_mask.fa[0, 0, 0], dtifit.fa[0, 0, 0])

# Test with returning S0_hat
dm = dti.TensorModel(gtab, 'LS', return_S0_hat=True)
mask = np.zeros(data.shape[:-1], dtype=bool)
mask[0, 0, 0] = True
for mask_more in [True, False]:
if mask_more:
mask[0, 0, 1] = True
for fit_type in ['LS', 'NLLS']:
dm = dti.TensorModel(gtab, fit_type)
mask = np.zeros(data.shape[:-1], dtype=bool)
mask[0, 0, 0] = True
dtifit = dm.fit(data)
dtifit_w_mask = dm.fit(data, mask=mask)
# Without a mask it has some value
Expand All @@ -503,11 +486,29 @@ def test_mask():
npt.assert_array_equal(dtifit_w_mask.evals[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.evecs[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.fa[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.S0_hat[~mask], 0)
# Except for the one voxel that was selected by the mask:
npt.assert_almost_equal(dtifit_w_mask.fa[0, 0, 0], dtifit.fa[0, 0, 0])
npt.assert_almost_equal(dtifit_w_mask.S0_hat[0, 0, 0],
dtifit.S0_hat[0, 0, 0])

# Test with returning S0_hat
dm = dti.TensorModel(gtab, fit_type, return_S0_hat=True)
mask = np.zeros(data.shape[:-1], dtype=bool)
mask[0, 0, 0] = True
for mask_more in [True, False]:
if mask_more:
mask[0, 0, 1] = True
dtifit = dm.fit(data)
dtifit_w_mask = dm.fit(data, mask=mask)
# Without a mask it has some value
assert not np.isnan(dtifit.fa[0, 0, 0])
# Where mask is False, evals, evecs and fa should all be 0
npt.assert_array_equal(dtifit_w_mask.evals[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.evecs[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.fa[~mask], 0)
npt.assert_array_equal(dtifit_w_mask.S0_hat[~mask], 0)
# Except for the one voxel that was selected by the mask:
npt.assert_almost_equal(dtifit_w_mask.fa[0, 0, 0], dtifit.fa[0, 0, 0])
npt.assert_almost_equal(dtifit_w_mask.S0_hat[0, 0, 0],
dtifit.S0_hat[0, 0, 0])


def test_nnls_jacobian_fucn():
Expand Down

0 comments on commit 1f5eaaa

Please sign in to comment.