Official repository for the paper:
ProtoASNet: Dynamic Prototypes for Inherently Interpretable and Uncertainty-Aware Aortic Stenosis Classification in Echocardiography
Hooman Vaseli*, Ang Nan Gu*, S. Neda Ahmadi Amiri*, Michael Y. Tsang*, Andrea Fung, Nima Kondori, Armin Saadat, Purang Abolmaesumi, Teresa S. M. Tsang
(*Equal Contribution)
Published in MICCAI 2023
Springer Link
arXiv Link
- Introduction
- Environment Setup
- Train and Test
- Local Explanation
- Description of Files and Folders
- Acknowledgement
- Citation
This work aims to detect the severity of Aortic Stenosis (AS) in B-mode echocardiography of the Parasternal Long-Axis and Short-Axis (PLAX and PSAX) views. Due to privacy restrictions, we cannot share the private dataset we experimented on. We also experimented on the public TMED-2 dataset; however, it applies only to the image-based models.
- Clone the repo
git clone https://github.com/hooman007/ProtoASNet.git
cd ProtoASNet
- Place your data in the data folder. For TMED-2 or your private dataset, you need to prepare your own dataset class; the existing code in src/data/ may be useful as a reference.
- If using Docker, set up the container by running docker_setup.sh on your server. Change the parameters according to your needs:
  - the name of the container: --name=your_container_name
  - the PyTorch image tag: find a suitable tag at https://hub.docker.com/r/pytorch/pytorch/tags based on your server. For example, we used pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime
- Python library dependencies can be installed using:
pip install --upgrade pip
pip install torch torchvision # if pytorch docker is not used
pip install pandas wandb tqdm seaborn torch-summary opencv-python jupyter jupyterlab imageio array2gif moviepy scikit-image scikit-learn torchmetrics termplotlib
pip install -e .
# sanity check
python -c "import torch; print(torch.__version__)"
python -c "import torch; print(torch.version.cuda)"
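As a starting point for your own dataset class, here is a minimal sketch of a map-style dataset, assuming a hypothetical CSV index with path and label columns. The class name EchoIndexDataset, the CSV layout, and the loader callable are illustrative assumptions, not the interface used in src/data/. PyTorch's DataLoader accepts any object implementing __len__ and __getitem__, so the sketch avoids importing torch; in practice you would subclass torch.utils.data.Dataset and load frames with, e.g., OpenCV.

```python
import csv
from pathlib import Path
from typing import Any, Callable, List, Tuple


class EchoIndexDataset:
    """Hypothetical map-style dataset driven by a CSV index.

    Expects a CSV file with a header row containing 'path' and 'label'
    columns (an assumption for this sketch). Paths are resolved relative
    to data_root.
    """

    def __init__(
        self,
        index_csv: str,
        data_root: str,
        loader: Callable[[Path], Any],
        label_names: List[str],
    ) -> None:
        self.data_root = Path(data_root)
        self.loader = loader
        # Map each class name to an integer index, e.g. for cross-entropy targets.
        self.label_to_idx = {name: i for i, name in enumerate(label_names)}
        with open(index_csv, newline="") as f:
            self.samples: List[Tuple[str, str]] = [
                (row["path"], row["label"]) for row in csv.DictReader(f)
            ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        rel_path, label = self.samples[idx]
        # The loader decides how a file becomes model input (image, clip, ...).
        data = self.loader(self.data_root / rel_path)
        return data, self.label_to_idx[label]
```

Keeping file decoding in a separate loader lets the same index class serve both image-based and video-based models.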
To train the model, cd to the project folder, then run python main.py with the arguments described below:
- --config_path="src/configs/<config-name>.yml": yaml file containing hyper-parameters for the model, experiment, loss objectives, dataset, and augmentations. All config files are stored in src/configs
- --run_name="<your run name>": the name used by wandb to show the training results
- --save_dir="logs/<path-to-save>": the folder to save all the trained model checkpoints, evaluations, and visualizations of the learned prototypes
- --eval_only=True: a flag that evaluates the trained model
- --eval_data_type="valid" or --eval_data_type="test": evaluates the model on the validation or test dataset, respectively. Only applied when the --eval_only flag is ON.
- --push_only=True: a flag to project (and then save the visualization of) the trained prototypes onto the nearest relevant extracted features of the training dataset. (This is done during training as well, but this flag lets you run it as a standalone step on any model checkpoint.)
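If you script several evaluations, the flags above can be assembled programmatically. A minimal sketch, assuming you launch main.py from the project folder with Python's subprocess module; the helper name build_eval_command is hypothetical, and it only uses flags documented in this README.

```python
import sys
from typing import List, Optional


def build_eval_command(
    config_path: str,
    run_name: str,
    save_dir: str,
    eval_data_type: str = "valid",
    checkpoint_path: Optional[str] = None,
) -> List[str]:
    """Return an argv list that evaluates a trained model with main.py."""
    # In list form no shell quoting is needed around the values.
    cmd = [
        sys.executable, "main.py",
        f"--config_path={config_path}",
        f"--run_name={run_name}",
        f"--save_dir={save_dir}",
        "--eval_only=True",
        f"--eval_data_type={eval_data_type}",
    ]
    if checkpoint_path is not None:
        # Hierarchical override of the model.checkpoint_path config entry.
        cmd.append(f"--model.checkpoint_path={checkpoint_path}")
    return cmd
```

For example, subprocess.run(build_eval_command(...), check=True) runs one evaluation, and shlex.join on the same list prints the equivalent bash command.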
Note: You can override any parameter in the config.yml file on the fly by passing it as an argument to the python call in bash. For hierarchical parameters, the format is --parent.child.child=value
Example for the model checkpoint path:
python main.py --config_path="src/configs/Ours_ProtoASNet_Video.yml" --run_name="ProtoASNet_test_run" --save_dir="logs/ProtoASNet/VideoBased_testrun_00" --model.checkpoint_path="logs/ProtoASNet/VideoBased_testrun_00/last.pth"
This command uses the last checkpoint saved in the VideoBased_testrun_00 folder.
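To illustrate how a --parent.child.child=value override maps onto the nested yaml structure, here is a minimal sketch that merges dotted overrides into a nested dict (such as one loaded with PyYAML's yaml.safe_load). It is a hypothetical re-implementation for illustration, not the repository's parser; the function names and the type-casting rule (ast.literal_eval with a fallback to the raw string) are assumptions.

```python
import ast
from typing import Any, Dict, List


def _parse_value(raw: str) -> Any:
    """Interpret numbers, booleans, lists, etc.; fall back to the raw string."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        return raw


def apply_overrides(config: Dict[str, Any], argv: List[str]) -> Dict[str, Any]:
    """Merge '--a.b.c=value' style arguments into a nested config dict in place."""
    for arg in argv:
        if not arg.startswith("--") or "=" not in arg:
            continue  # not an override of the form --key=value
        key, raw = arg[2:].split("=", 1)
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections that do not exist yet.
            node = node.setdefault(name, {})
        node[leaf] = _parse_value(raw)
    return config
```

Only the addressed leaf changes; sibling keys in the same section keep their yaml values.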
Note: The training/testing commands with finalized hyper-parameters and yaml config files for the models reported in the MICCAI 2023 paper (both our models and the baselines) are in the MICCAI2023_ProtoASNet_Deploy.sh script:
bash MICCAI2023_ProtoASNet_Deploy.sh
The important content saved in the save_dir folder is:
- model_best.pth: checkpoint of the best model based on a metric of interest (e.g. mean AUC or F1 score)
- last.pth: checkpoint of the model saved at the last epoch
- <epoch_num>push_f1-<meanf1>.pth: checkpoint saved after every prototype projection
- img/epoch-<epoch_num>_pushed: folder containing:
  - visualization of the projected prototypes
  - prototypes_info.pickle: stored dictionary containing:
    - prototypes_filenames: filenames of the source images
    - prototypes_src_imgs: source images in numpy
    - prototypes_gts: labels of the source images
    - prototypes_preds: predictions on the source images (how the model sees the source images)
    - prototypes_occurrence_maps: occurrence map corresponding to each prototype (where the model looks for each prototype)
    - prototypes_similarity_to_src_ROIs: similarity score of the prototype vector before projection to the ROI it is projected to
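A minimal sketch for inspecting these outputs: it picks the push checkpoint with the highest mean F1 from the <epoch_num>push_f1-<meanf1>.pth names and loads the matching prototypes_info.pickle. The function names are hypothetical, and the sketch assumes the epoch number is formatted identically in the checkpoint name and in the img/epoch-<epoch_num>_pushed folder name. Since the pickle stores numpy arrays, numpy must be installed to unpickle the real files.

```python
import pickle
import re
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

# Matches the "<epoch_num>push_f1-<meanf1>.pth" naming described above.
PUSH_CKPT_RE = re.compile(r"^(\d+)push_f1-(\d+(?:\.\d+)?)\.pth$")


def best_push_checkpoint(save_dir: str) -> Optional[Tuple[Path, str, float]]:
    """Return (path, epoch string, mean F1) of the push checkpoint with the highest F1."""
    best = None
    for path in Path(save_dir).glob("*push_f1-*.pth"):
        match = PUSH_CKPT_RE.match(path.name)
        if match is None:
            continue
        epoch, f1 = match.group(1), float(match.group(2))
        # Ties keep whichever matching file was found first.
        if best is None or f1 > best[2]:
            best = (path, epoch, f1)
    return best


def load_prototypes_info(save_dir: str, epoch: str) -> Dict[str, Any]:
    """Load the prototype dictionary saved after the push at the given epoch."""
    pickle_path = Path(save_dir) / "img" / f"epoch-{epoch}_pushed" / "prototypes_info.pickle"
    # Only unpickle files you created yourself: pickle can execute arbitrary code.
    with open(pickle_path, "rb") as f:
        return pickle.load(f)
```

Note that model_best.pth already tracks the best model by your chosen metric; this helper only ranks the push checkpoints among themselves.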
You can run the local explanation to explain a given image locally, showing how similar it is to the learnt prototypes and how the model arrived at its classification.
To explain all the data in the validation or test set, run the command below:
python explain.py --explain_locally=True --eval_data_type='val' --config_path="src/configs/<your config>.yml" --run_name="LocalExplain_<your name>" --wandb_mode="disabled" --save_dir="logs/<your run name>" --model.checkpoint_path="logs/<your run name>/model_best.pth"
Outputs are stored in the folder /path/to/saved/checkpoint/epoch_<##>/val with this format:
- local/filename/test_clip_AS-<AsLabel>.MP4: showing the input echo video
- local/filename/AS-<AsLabel>_<sim_score>_<prototype#>.png
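Because the similarity score and prototype index are encoded in each PNG filename, a sample's explanation folder can be ranked without opening the images. A minimal sketch, assuming the AS-<AsLabel>_<sim_score>_<prototype#>.png pattern above and a score that float() can parse; the function name rank_prototypes is hypothetical.

```python
import re
from pathlib import Path
from typing import List, Tuple

# Matches "AS-<AsLabel>_<sim_score>_<prototype#>.png"; the label may contain underscores.
PROTO_PNG_RE = re.compile(r"^AS-(.+)_([^_]+)_(\d+)\.png$")


def rank_prototypes(sample_dir: str) -> List[Tuple[float, int, str]]:
    """Return (similarity, prototype index, AS label) tuples, most similar first."""
    ranked = []
    for path in Path(sample_dir).glob("AS-*.png"):
        match = PROTO_PNG_RE.match(path.name)
        if match is None:
            continue
        label, score, proto = match.groups()
        try:
            ranked.append((float(score), int(proto), label))
        except ValueError:
            continue  # score field is not a number; skip this file
    ranked.sort(key=lambda item: item[0], reverse=True)
    return ranked
```

The input clip (test_clip_AS-<AsLabel>.MP4) does not start with "AS-", so the glob leaves it out.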
Once you run the system, it will contain the saved models, logs, and evaluation results (visualization of explanations, etc.).
When training is done for the first time, pretrained backbone models are saved here.
- agents/: folder containing agent classes for each of the architectures; contains the main framework for the training process
- configs/: folder containing the yaml files with hyper-parameters for the model, experiment, loss objectives, dataset, and augmentations
- data/: folder for dataset and dataloader classes
- loss/: folder for loss functions
- models/: folder for model architectures
- utils/: folder for utility scripts and local explanation
Some code is borrowed from ProtoPNet, and we developed the XProtoNet architecture based on their paper.
If you find this work useful in your research, please cite:
@InProceedings{10.1007/978-3-031-43987-2_36,
author="Vaseli, Hooman and Gu, Ang Nan and Ahmadi Amiri, S. Neda and Tsang, Michael Y. and Fung, Andrea and Kondori, Nima and Saadat, Armin and Abolmaesumi, Purang and Tsang, Teresa S. M.",
editor="Greenspan, Hayit and Madabhushi, Anant and Mousavi, Parvin and Salcudean, Septimiu
and Duncan, James and Syeda-Mahmood, Tanveer and Taylor, Russell",
title="ProtoASNet: Dynamic Prototypes for Inherently Interpretable and Uncertainty-Aware Aortic Stenosis Classification in Echocardiography",
booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2023",
year="2023",
publisher="Springer Nature Switzerland",
address="Cham",
pages="368--378",
isbn="978-3-031-43987-2"
}