This is the official implementation of the NeurIPS 2022 paper MCVD: Masked Conditional Video Diffusion for Prediction, Generation, and Interpolation. In this paper, we devise a general-purpose model for video prediction (forward and backward), unconditional generation, and interpolation with Masked Conditional Video Diffusion (MCVD) models. Please see our website for more details. This repo is based on the code from https://github.com/ermongroup/ncsnv2.
If you find the code/idea useful for your research, please cite:
@inproceedings{voleti2022MCVD,
author = {Voleti, Vikram and Jolicoeur-Martineau, Alexia and Pal, Christopher},
title = {MCVD: Masked Conditional Video Diffusion for Prediction, Generation, and Interpolation},
url = {https://arxiv.org/abs/2205.09853},
booktitle = {(NeurIPS) Advances in Neural Information Processing Systems},
year = {2022}
}
The models from our paper were trained with 1 to 4 GPUs (requiring from 32GB to 160GB of RAM). Models can be scaled to fewer or more GPUs by changing the following parameters:
- model.ngf and model.n_head_channels (doubling ngf and n_head_channels approximately doubles the memory demand)
- model.num_res_blocks (number of sequential residual layers per block)
- model.ch_mult=[1,2,3,4,4,4] will use 6 resblocks instead of the default 4 (model.ch_mult=[1,2,3,4])
- training.batch_size (doubling the batch size increases the memory demand by approximately 50%)
- SPATIN models can be scaled through model.spade_dim (128 -> 512 increases memory demand by 2x, 128 -> 1024 increases memory demand by 4x); it should be scaled proportionally to the number of past and future frames for best results. In practice, we find that SPATIN models often need a very large spade_dim to be competitive, so we recommend that regular users stick to concatenation.
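As a rough planning aid, the rules of thumb above can be turned into a back-of-the-envelope estimate. The sketch below is not part of this repo: `relative_memory` is a hypothetical helper, the baseline values (ngf=192, batch size 64) are only illustrative, and it assumes that the two stated rules extrapolate smoothly and multiply together, which we have not measured.

```python
import math


def relative_memory(ngf, batch_size, base_ngf=192, base_batch_size=64):
    """Estimate memory demand relative to a baseline configuration.

    Assumptions (extrapolated from the rules of thumb, not measured):
      * doubling ngf (together with n_head_channels) doubles memory,
        so memory is treated as linear in ngf;
      * doubling the batch size adds about 50%, so each doubling
        multiplies memory by 1.5.
    """
    ngf_factor = ngf / base_ngf
    batch_doublings = math.log2(batch_size / base_batch_size)
    batch_factor = 1.5 ** batch_doublings
    return ngf_factor * batch_factor


if __name__ == "__main__":
    for ngf, bs in [(192, 64), (96, 64), (192, 32), (64, 64)]:
        factor = relative_memory(ngf, bs)
        print(f"ngf={ngf:4d} batch_size={bs:3d} -> {factor:.2f}x baseline")
```

Treat the result as an order-of-magnitude guide only; actual memory use also depends on resolution, number of frames, and the other parameters listed above.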
# if using conda (ignore otherwise)
conda create --name vid python=3.8
# # (Optional) If your machine has a GCC/G++ version < 5:
# conda install -c conda-forge gxx=8.5.0 # (should be executed before the installation of pytorch, torchvision, and torchaudio)
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt # install all requirements
The experiments to reproduce the paper can be found in /example_scripts/final/training_scripts.sh and /example_scripts/final/sampling_scripts.sh.
We also provide a small notebook demo for sampling from SMMNIST: https://github.com/voletiv/mcvd-pytorch/blob/master/MCVD_demo_SMMNIST.ipynb.
The checkpoints used for the experiments and their results can be found here: https://drive.google.com/drive/u/1/folders/15pDq2ziTv3n5SlrGhGM0GVqwIZXgebyD
The model configurations are available at /configs. To override any existing configuration from a config file, you can simply use the --config_mod argument in the command line. For example:
--config_mod training.snapshot_freq=50000 sampling.subsample=100 sampling.clip_before=True sampling.max_data_iter=1 model.version=DDPM model.arch=unetmore model.num_res_blocks=2
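To make the dotted `section.key=value` syntax concrete, here is a minimal sketch of how such overrides could be applied to a nested config. It is an illustration only, not the parsing code used by main.py: `apply_config_mods` is a hypothetical helper, and the choice to interpret values as Python literals is an assumption.

```python
import ast


def apply_config_mods(config, mods):
    """Apply 'section.key=value' overrides to a nested dict, in place."""
    for mod in mods:
        dotted_key, _, raw_value = mod.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections on demand.
            node = node.setdefault(name, {})
        try:
            # Numbers, booleans, and lists are read as Python literals.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything else (e.g. DDPM, unetmore) stays a plain string.
            value = raw_value
        node[leaf] = value
    return config


if __name__ == "__main__":
    config = {"training": {"snapshot_freq": 1000}, "sampling": {}, "model": {}}
    apply_config_mods(config, [
        "training.snapshot_freq=50000",
        "sampling.clip_before=True",
        "model.version=DDPM",
        "model.num_res_blocks=2",
    ])
    print(config)
```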
The important config options are:
training.batch_size=64 # training batch size
sampling.batch_size=200 # sampling batch size
sampling.subsample=100 # how many diffusion steps to take (1000 is best but is slower, 100 is faster)
sampling.max_data_iter=1000 # how many mini-batches of the test to go through at the maximum (set to 1 for training and a large value for sampling)
model.ngf=192 # number of channels (controls model size)
model.n_head_channels=192 # number of channels per self-attention head (should ideally be greater than or equal to model.ngf, otherwise you may get a size mismatch error)
model.spade=True # if True uses space-time adaptive normalization instead of concatenation
model.spade_dim=128 # number of channels in space-time adaptive normalization; worth increasing, especially if conditioning on a large number of frames
sampling.num_frames_pred=16 # number of frames to predict (autoregressively)
data.num_frames=4 # number of current frames
data.num_frames_cond=4 # number of previous frames
data.num_frames_future=4 # number of future frames
data.prob_mask_cond=0.50 # probability of masking the previous frames (allows predicting current frames with no past frames)
data.prob_mask_future=0.50 # probability of masking the future frames (allows predicting current frames with no future frames)
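Because sampling.num_frames_pred frames are predicted autoregressively, it can help to see how many sampling passes a setting implies. The sketch below assumes each pass produces data.num_frames frames and that the last pass is truncated to fit; `autoregressive_schedule` is a hypothetical helper and the actual sampler in this repo may organize the passes differently.

```python
import math


def autoregressive_schedule(num_frames_pred, num_frames):
    """Return one (start, end) frame range per sampling pass.

    Assumes blockwise prediction of `num_frames` frames per pass,
    with the final block cut off at `num_frames_pred`.
    """
    num_passes = math.ceil(num_frames_pred / num_frames)
    schedule = []
    for i in range(num_passes):
        start = i * num_frames
        end = min(start + num_frames, num_frames_pred)
        schedule.append((start, end))
    return schedule


if __name__ == "__main__":
    # Values from the config list: num_frames_pred=16, num_frames=4.
    for start, end in autoregressive_schedule(16, 4):
        print(f"pass predicts frames {start}..{end - 1}")
```

Since every pass runs the full set of diffusion steps, sampling time grows roughly with the number of passes times sampling.subsample.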
When data.num_frames_future > 0, data.num_frames_cond > 0, data.prob_mask_cond=0.50, and data.prob_mask_future=0.50, one can do video prediction (forward and backward), generation, and interpolation.
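The link between the two masking probabilities and the four tasks can be sketched as follows. This is an illustration of the idea, not the training code from this repo: `sample_task` is a hypothetical helper, and it assumes the past and future masks are drawn independently.

```python
import random
from collections import Counter


def sample_task(prob_mask_cond=0.5, prob_mask_future=0.5, rng=random):
    """Draw one masking pattern and name the task it corresponds to."""
    mask_past = rng.random() < prob_mask_cond
    mask_future = rng.random() < prob_mask_future
    if mask_past and mask_future:
        task = "unconditional generation"  # nothing is visible
    elif mask_future:
        task = "forward prediction"        # only past frames are visible
    elif mask_past:
        task = "backward prediction"       # only future frames are visible
    else:
        task = "interpolation"             # past and future are both visible
    return mask_past, mask_future, task


if __name__ == "__main__":
    rng = random.Random(0)
    counts = Counter(sample_task(rng=rng)[2] for _ in range(10000))
    for task, n in counts.items():
        print(f"{task}: {n}")
```

With both probabilities at 0.50, each of the four tasks is seen about a quarter of the time under this assumption, which is why a single model can handle all of them.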
You can train on Stochastic Moving MNIST with 1 GPU (if you run into memory issues, use model.ngf=64) using:
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/smmnist_DDPM_big5.yml --data_path /my/data/path/to/datasets --exp smmnist_cat --ni
Log files will be saved in <exp>/logs/smmnist_cat. This folder contains stdout, metric plots, and video samples over time.
You can train on Cityscapes with 4 GPUs using:
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --config configs/cityscapes_big_spade.yml --data_path /my/data/path/to/datasets --exp exp_city_spade --ni
You can look at stdout or the metric plots in <exp>/logs/smmnist_cat to determine which checkpoint provides the best metrics. Then, you can sample 25 frames using the chosen checkpoint (e.g., 250k) of the previous SMMNIST model by running main.py with the --video_gen option:
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/smmnist_DDPM_big5.yml --data_path /my/data/path/to/datasets --exp smmnist_cat --ni --config_mod sampling.max_data_iter=1000 sampling.num_frames_pred=25 sampling.preds_per_test=10 sampling.subsample=100 model.version=DDPM --ckpt 250000 --video_gen -v videos_250k_DDPM_1000_nfp_pred25
Results will be saved in <exp>/video_samples/videos_250k_DDPM_1000_nfp_pred25.
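If you want to sample from several checkpoints, it may be convenient to assemble the command programmatically. The sketch below only builds the argument list from the flags shown in the sampling command above; `build_sampling_command` is a hypothetical helper that is not part of this repo, and running the command (e.g. with subprocess) is left to you.

```python
import shlex


def build_sampling_command(ckpt, video_dir, config, data_path, exp, config_mods=None):
    """Build the main.py argument list for sampling from one checkpoint."""
    args = [
        "python", "main.py",
        "--config", config,
        "--data_path", data_path,
        "--exp", exp,
        "--ni",
    ]
    if config_mods:
        # Each override becomes one 'section.key=value' token.
        args += ["--config_mod"] + [f"{k}={v}" for k, v in config_mods.items()]
    args += ["--ckpt", str(ckpt), "--video_gen", "-v", video_dir]
    return args


if __name__ == "__main__":
    mods = {"sampling.num_frames_pred": 25, "sampling.subsample": 100}
    for ckpt in (200000, 250000):  # example checkpoint steps
        cmd = build_sampling_command(
            ckpt,
            f"videos_{ckpt // 1000}k",
            "configs/smmnist_DDPM_big5.yml",
            "/my/data/path/to/datasets",
            "smmnist_cat",
            mods,
        )
        print(shlex.join(cmd))
```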
You can use the above option to sample videos from any pretrained MCVD model.
We tried a few options that did not help, but we left them in the code. Some of these options might be broken; we make no guarantees, so use them at your own risk.
model.gamma=True # Gamma noise from https://arxiv.org/abs/2106.07582
training.L1=True # L1 loss
model.cond_emb=True # Embedding for whether we mask (1) or we don't mask (0)
output_all_frames=True # Option to output/predict all frames, not just current frames
noise_in_cond=True # Diffusion noise also in conditioning frames
one_frame_at_a_time=True # Autoregressive one image at a time instead of blockwise
model.version=FPNDM # F-PNDM from https://arxiv.org/abs/2202.09778
Note that this code can be used to generate images by setting data.num_frames=0, data.num_frames_cond=0, data.num_frames_future=0.
Many unused options also exist which are from the original code by https://github.com/ermongroup/ncsnv2, mostly applicable only to images.
The code will do it for you!
Code will download https://download.pytorch.org/models/alexnet-owt-7be5be79.pth and move it into:
models/weights/v0.1/alex.pth
The code will do it for you!
Code will download the i3D model pretrained on Kinetics-400 from "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB!199&authkey=AApKdFHPXzWLNyI". Use models/fvd/convert_tf_pretrained.py to make i3d_pretrained_400.pt.
The script will automatically download the PyTorch MNIST dataset, which will be used to generate Stochastic Moving MNIST dynamically.
Download the hdf5 dataset:
gdown --fuzzy https://drive.google.com/file/d/1d2UfHV6RhSrwdDAlCFY3GymtFPpmh_8X/view?usp=sharing
How the data was processed:
- Download KTH dataset to /path/to/KTH:
sh kth_download.sh /path/to/KTH
- Convert 64x64 images to HDF5 format:
python datasets/kth_convert.py --kth_dir '/path/to/KTH' --image_size 64 --out_dir '/path/to/KTH64_h5' --force_h5 False
Download the hdf5 dataset:
gdown --fuzzy https://drive.google.com/file/d/1-R_srAOy5ZcylGXVernqE4WLCe6N4_wq/view?usp=sharing
How the data was processed:
- Download BAIR Robotic Push dataset to /path/to/BAIR:
sh bair_dowload.sh /path/to/BAIR
- Convert it to HDF5 format, and save in /path/to/BAIR_h5:
python datasets/bair_convert.py --bair_dir '/path/to/BAIR' --out_dir '/path/to/BAIR_h5'
Download the hdf5 dataset:
gdown --fuzzy https://drive.google.com/file/d/1oP7n-FUfa9ifsMn6JHNS9depZfftvrXx/view?usp=sharing
How the data was processed:
MAKE SURE YOU HAVE ~657GB SPACE! 324GB for the zip file, and 333GB for the unzipped image files
- Download Cityscapes video dataset (leftImg8bit_sequence_trainvaltest.zip (324GB)):
sh cityscapes_download.sh username password
using your username and password that you created on https://www.cityscapes-dataset.com/
- Convert it to HDF5 format, and save in /path/to/Cityscapes<image_size>_h5:
python datasets/cityscapes_convert.py --leftImg8bit_sequence_dir '/path/to/Cityscapes/leftImg8bit_sequence' --image_size 64 --out_dir '/path/to/Cityscapes64_h5'
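Given the disk-space warning above, you may want to check free space before starting the download. This is a small optional sketch using only the Python standard library; `has_free_space` is a hypothetical helper that is not part of this repo, and it assumes the quoted sizes are decimal gigabytes (10^9 bytes).

```python
import shutil


def has_free_space(path, required_gb):
    """Return True if the filesystem containing `path` has enough free space."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 1000**3


if __name__ == "__main__":
    target = "."      # replace with the directory you will download into
    needed_gb = 657   # zip file plus unzipped images, as stated above
    if has_free_space(target, needed_gb):
        print(f"OK: at least {needed_gb}GB free at {target!r}")
    else:
        free_gb = shutil.disk_usage(target).free / 1000**3
        print(f"Not enough space: {free_gb:.0f}GB free, {needed_gb}GB needed")
```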
Download the hdf5 dataset:
gdown --fuzzy https://drive.google.com/file/d/13yaJkKtmDsgtaEvuXKSvbix5usea6TJy/view?usp=sharing
How the data was processed:
MAKE SURE YOU HAVE ~657GB SPACE! 324GB for the zip file, and 333GB for the unzipped image files
- Download Cityscapes video dataset (leftImg8bit_sequence_trainvaltest.zip (324GB)):
sh cityscapes_download.sh /path/to/download/to username password
using your username and password that you created on https://www.cityscapes-dataset.com/
- Convert it to HDF5 format, and save in /path/to/Cityscapes<image_size>_h5:
python datasets/cityscapes_convert.py --leftImg8bit_sequence_dir '/path/to/Cityscapes/leftImg8bit_sequence' --image_size 128 --out_dir '/path/to/Cityscapes128_h5'
Download the hdf5 dataset:
gdown --fuzzy https://drive.google.com/file/d/1bDqhhfKYrdbIIOZeJcWHWjSyFQwmO1t-/view?usp=sharing
How the data was processed:
MAKE SURE YOU HAVE ~20GB SPACE! 6.5GB for the zip file, and 8GB for the unzipped image files
- Download UCF-101 video dataset (UCF101.rar (6.5GB)):
sh cityscapes_download.sh /download/dir
- Convert it to HDF5 format, and save in /path/to/UCF101_h5:
python datasets/ucf101_convert.py --out_dir /path/to/UCF101_h5 --ucf_dir /download/dir/UCF-101 --splits_dir /download/dir/ucfTrainTestlist