
Conditional diffusion model with spatial attention and latent embedding for medical image segmentation


PyTorch implementation of "Conditional diffusion model with spatial attention and latent embedding for medical image segmentation" [MICCAI 2024]



Diffusion models have been used extensively for high-quality image and video generation tasks. In this paper, we propose a novel conditional diffusion model with spatial attention and latent embedding (cDAL) for medical image segmentation. In cDAL, a convolutional neural network (CNN) based discriminator is used at every time-step of the diffusion process to distinguish between the generated labels and the real ones. A spatial attention map is computed based on the features learned by the discriminator to help cDAL generate more accurate segmentations of discriminative regions in an input image. Additionally, we incorporate a random latent embedding into each layer of our model to significantly reduce the number of training and sampling time-steps, thereby making it much faster than other diffusion models for image segmentation.
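To illustrate the spatial-attention idea described above, here is a minimal sketch of turning a discriminator's feature maps into a normalized spatial attention map. This is an assumption-laden re-derivation for illustration only, not cDAL's actual implementation; the paper's exact formulation may differ.

```python
import torch
import torch.nn.functional as F

def spatial_attention_map(features: torch.Tensor) -> torch.Tensor:
    """Collapse a (B, C, H, W) discriminator feature tensor into a
    (B, 1, H, W) attention map: average over channels, then apply a
    spatial softmax so each map sums to 1.  Illustrative sketch only;
    not the repository's exact computation."""
    attn = features.mean(dim=1, keepdim=True)            # (B, 1, H, W)
    b, _, h, w = attn.shape
    attn = F.softmax(attn.view(b, -1), dim=1).view(b, 1, h, w)
    return attn

# Dummy discriminator features for demonstration.
feats = torch.randn(2, 64, 32, 32)
amap = spatial_attention_map(feats)
print(amap.shape)  # torch.Size([2, 1, 32, 32])
```

Such a map can be multiplied element-wise with intermediate generator features to emphasize discriminative regions.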

Architecture

Set up datasets

We trained cDAL on several datasets, including MoNuSeg2018, Chest X-Ray (CXR), and Hippocampus.

Training cDAL

We use the following scripts to train cDAL on each dataset. Use parameters_monu.json for MoNuSeg and parameters_lung.json for CXR.

To train the model on the Hippocampus dataset, use train_cDAL_hippo.py; the corresponding parameters are defined in the code.

To train on either MoNuSeg or CXR, use train_cDal_monu_and_lung.py. All necessary parameters are provided in parameters_monu.json and parameters_lung.json; these files can be loaded directly by the code, or you can modify the parameters in the code file.
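A minimal sketch of loading such a JSON parameter file before training. The key names in the demo dictionary are illustrative assumptions, not the repository's actual schema; check parameters_monu.json and parameters_lung.json for the real keys.

```python
import json
import os
import tempfile

def load_params(path: str) -> dict:
    """Load a training-parameter JSON file such as parameters_monu.json
    or parameters_lung.json into a plain dict."""
    with open(path) as f:
        return json.load(f)

# Demonstrate with a throwaway file so the sketch runs anywhere.
# The keys below (lr, batch_size, timesteps) are hypothetical.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"lr": 1e-4, "batch_size": 8, "timesteps": 4}, f)
    tmp_path = f.name

params = load_params(tmp_path)
print(params["timesteps"])  # 4
os.unlink(tmp_path)
```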

MoNuSeg

You can find the challenge website here; download the train and test sets from it.

CXR

This is the link to the Lung Segmentation from Chest X-Ray dataset. For image preprocessing, we followed the same standard as the dataset.

Hippocampus 3D

You can find the Hippocampus dataset at this link; it can also be downloaded directly from this Google Drive link.

Pretrained Checkpoints

We have released pretrained checkpoints for MoNuSeg and CXR in saved_models. Simply download the saved_models directory into the code directory. Use parameters_monu.json for MoNuSeg and parameters_lung.json for CXR.
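A hedged sketch of restoring one of these checkpoints with PyTorch. The assumption that the files hold a plain state_dict is mine; consult the sampling scripts for the repository's actual loading code.

```python
import torch
import torch.nn as nn

def load_checkpoint(model: nn.Module, path: str,
                    device: str = "cpu") -> nn.Module:
    """Load weights from a file such as saved_models/<name>.pth into
    `model` and switch it to eval mode.  Assumes the file stores a
    plain state_dict; adjust if the repository wraps it in a larger
    dict (e.g. with optimizer state)."""
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model
```

Usage would look like `load_checkpoint(netG, "saved_models/monu.pth")`, where both the variable and the file name are placeholders.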

Evaluation

After training, samples can be generated by calling sampling_monu_and_lung.py for the MoNuSeg and CXR datasets, or sampling_hippo.py for the Hippocampus dataset. Hippocampus uses the metrics_hippo.py file for evaluation, since its multi-class labels must be processed with one-hot encoding. We evaluated the models on a single NVIDIA Quadro RTX 6000 GPU.

We use the MONAI implementation to process the Hippocampus dataset and compute one-hot encodings. We also use the DDGAN implementation for our diffusion model and time-dependent discriminator.
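Since the Hippocampus labels are multi-class, metrics are computed per class after one-hot encoding. Below is a minimal sketch of a one-hot Dice score in plain PyTorch; the repository uses MONAI utilities for this, so treat this as an illustrative re-derivation, not the code actually used in metrics_hippo.py.

```python
import torch
import torch.nn.functional as F

def one_hot_dice(pred: torch.Tensor, target: torch.Tensor,
                 num_classes: int, eps: float = 1e-6) -> torch.Tensor:
    """Per-class Dice between integer label maps `pred` and `target`
    of shape (B, H, W).  Returns a (num_classes,) tensor of scores."""
    p = F.one_hot(pred, num_classes).float()     # (B, H, W, C)
    t = F.one_hot(target, num_classes).float()
    dims = (0, 1, 2)                             # sum over batch and space
    inter = (p * t).sum(dims)
    denom = p.sum(dims) + t.sum(dims)
    return (2 * inter + eps) / (denom + eps)

labels = torch.randint(0, 3, (2, 8, 8))          # dummy 3-class label maps
print(one_hot_dice(labels, labels, num_classes=3))  # all ones: perfect overlap
```

The `eps` term keeps the score at 1 for classes absent from both maps, a common convention that MONAI exposes as a configurable option.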


License

Please check the LICENSE file.

BibTeX

Cite our paper using the following BibTeX entry:
