NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling
Junhyeok Lee, Seungu Han @ MINDsLab Inc., SNU
Official PyTorch Lightning implementation of NU-Wave.
Update: the LSD computation now uses torch.log10 instead of torch.log. The LSD values and the LSD formula in the paper are correct.
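A minimal sketch of the log-spectral distance (LSD), assuming the common definition over power spectrograms: for each frame, take the root-mean-square difference of log power across frequency bins, then average over frames. The function name lsd, the eps floor and the log_fn parameter are assumptions for illustration, not the repository's code, and the STFT step is left out.

```python
import math
from typing import Callable, Sequence


def lsd(
    ref_power: Sequence[Sequence[float]],
    est_power: Sequence[Sequence[float]],
    log_fn: Callable[[float], float] = math.log10,
    eps: float = 1e-8,
) -> float:
    """Log-spectral distance between two power spectrograms indexed [frame][bin].

    Per frame: sqrt(mean over bins of (log(P_ref) - log(P_est))^2).
    Result: mean of the per-frame values. eps avoids log(0) on silent bins.
    """
    if len(ref_power) != len(est_power) or not ref_power:
        raise ValueError("spectrograms must have the same, non-zero number of frames")
    frame_dists = []
    for ref_frame, est_frame in zip(ref_power, est_power):
        sq_diffs = [
            (log_fn(r + eps) - log_fn(e + eps)) ** 2
            for r, e in zip(ref_frame, est_frame)
        ]
        frame_dists.append(math.sqrt(sum(sq_diffs) / len(sq_diffs)))
    return sum(frame_dists) / len(frame_dists)


ref = [[1.0, 4.0], [2.0, 8.0]]
est = [[100.0 * p for p in frame] for frame in ref]  # 100x power = 20 dB louder
print(round(lsd(ref, est), 6))            # 2.0
print(round(lsd(ref, est, math.log), 6))  # 4.60517 (2 * ln 10)
```

Because ln(x) = ln(10) * log10(x), computing the same quantity with torch.log instead of torch.log10 scales every LSD value by ln(10), roughly 2.303.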
- PyTorch >= 1.7.0 (for nn.SiLU, the swish activation)
- Pytorch-Lightning==1.1.6
- The full requirements are listed in requirements.txt.
- We also provide a Docker setup (Dockerfile).
Before running the project, download the dataset and preprocess it into .pt files:
- Download the VCTK dataset.
- Remove speakers p280 and p315.
- Set the path of the downloaded dataset in the dir field of the data section in hparameter.yaml.
- Run utils/wav2pt.py:
$ python utils/wav2pt.py
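The speaker filtering and the audio-to-.pt path mapping in the steps above can be sketched in plain Python. This is a hypothetical helper (plan_conversion is not part of the repository), assuming the dir/speaker/file layout hinted at by the data section of hparameter.yaml and that each .pt file is written next to its source audio; the actual audio loading and tensor saving done by utils/wav2pt.py is not shown.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

# Speakers the README asks you to remove from VCTK.
EXCLUDED_SPEAKERS = {"p280", "p315"}


def plan_conversion(root: Path, pattern: str = "*.wav") -> List[Tuple[Path, Path]]:
    """Hypothetical helper: return (audio, .pt) path pairs for root/<speaker>/<pattern>.

    Speaker directories in EXCLUDED_SPEAKERS are skipped, and each output path
    reuses the source file name with a .pt suffix (an assumption).
    """
    pairs = []
    for spk_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        if spk_dir.name in EXCLUDED_SPEAKERS:
            continue
        for audio in sorted(spk_dir.glob(pattern)):
            pairs.append((audio, audio.with_suffix(".pt")))
    return pairs


if __name__ == "__main__":
    # Build a tiny fake corpus to show which files would be converted.
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for spk in ("p225", "p280", "p315"):
            (root / spk).mkdir()
            (root / spk / f"{spk}_001_mic1.wav").touch()
        for src, dst in plan_conversion(root):
            print(src.name, "->", dst.name)  # p225_001_mic1.wav -> p225_001_mic1.pt
```

Adjust pattern to match the audio format of your VCTK download.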
- Adjust hparameter.yaml, especially the train section.
train:
batch_size: 18 # Dependent on GPU memory size
lr: 0.00003
weight_decay: 0.00
num_workers: 64 # Dependent on CPU cores
gpus: 2 # number of GPUs
opt_eps: 1e-9
beta1: 0.5
beta2: 0.999
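The beta1, beta2 and opt_eps keys suggest that the train section configures an Adam-style optimizer. Assuming that, here is a sketch of mapping the section to optimizer keyword arguments; adam_kwargs is a hypothetical name. The float() casts are defensive: some YAML loaders, such as PyYAML with its default YAML 1.1 resolver, read a value like 1e-9 (no decimal point) as a string.

```python
from typing import Any, Dict


def adam_kwargs(train_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: map the train section to Adam-style keyword arguments.

    float() makes string values such as "1e-9" usable as numbers.
    """
    return {
        "lr": float(train_cfg["lr"]),
        "betas": (float(train_cfg["beta1"]), float(train_cfg["beta2"])),
        "eps": float(train_cfg["opt_eps"]),
        "weight_decay": float(train_cfg["weight_decay"]),
    }


# Values from the train section; opt_eps given as the string a YAML 1.1 loader may return.
train_cfg = {"lr": 0.00003, "weight_decay": 0.00, "opt_eps": "1e-9", "beta1": 0.5, "beta2": 0.999}
print(adam_kwargs(train_cfg))
```

With PyTorch, the result could be passed as torch.optim.Adam(model.parameters(), **adam_kwargs(train_cfg)), since Adam accepts lr, betas, eps and weight_decay.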
- To train on a single speaker, use VCTKSingleSpkDataset instead of VCTKMultiSpkDataset as the dataset in dataloader.py, and use batch_size=1 for the validation dataloader.
- Adjust the data section in hparameter.yaml.
data:
dir: '/DATA1/VCTK/VCTK-Corpus/wav48/p225' #dir/spk/format
format: '*mic1.pt'
cv_ratio: (223./231., 8./231., 0.00) #train/val/test
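cv_ratio gives the train/val/test fractions. The sketch below shows how such ratios could turn into file counts; split_by_ratio is hypothetical, and whether dataloader.py shuffles or sorts the files first is not covered here. Rounding keeps floating-point error from losing a file, and the test split takes whatever remains. The example ratios are fractions of 231, so a speaker directory holding 231 files would split into 223 / 8 / 0.

```python
from typing import List, Sequence, Tuple


def split_by_ratio(
    files: Sequence[str], ratios: Tuple[float, float, float]
) -> Tuple[List[str], List[str], List[str]]:
    """Hypothetical helper: split files into train/val/test by (train, val, test) fractions.

    Counts are rounded so float error (e.g. 231 * (223. / 231.)) cannot drop a file;
    the test split receives whatever remains after train and val.
    """
    n = len(files)
    n_train = round(n * ratios[0])
    n_val = round(n * ratios[1])
    train = list(files[:n_train])
    val = list(files[n_train:n_train + n_val])
    test = list(files[n_train + n_val:])
    return train, val, test


files = [f"p225_{i:03d}_mic1.pt" for i in range(1, 232)]  # 231 illustrative file names
train, val, test = split_by_ratio(files, (223. / 231., 8. / 231., 0.00))
print(len(train), len(val), len(test))  # 223 8 0
```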
- Run trainer.py:
$ python trainer.py
- To resume training from a checkpoint, check the argument parser:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=False, help="Resume Checkpoint epoch number")
parser.add_argument('-s', '--restart', action="store_true",
                    required=False, help="Significant change occurred, use this")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
args = parser.parse_args()
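To see how the resume flags parse, the same parser can be fed an explicit argument list. This sketch mirrors the trainer flags; build_trainer_parser is a hypothetical wrapper, the checkpoint number 100 is only an example, and how trainer.py locates the checkpoint file is not covered here.

```python
import argparse


def build_trainer_parser() -> argparse.ArgumentParser:
    """Rebuild the resume-related flags of the trainer parser."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int, required=False,
                        help="Resume Checkpoint epoch number")
    parser.add_argument('-s', '--restart', action="store_true", required=False,
                        help="Significant change occurred, use this")
    parser.add_argument('-e', '--ema', action="store_true", required=False,
                        help="Start from ema checkpoint")
    return parser


# Equivalent to: $ python trainer.py -r 100 -e
args = build_trainer_parser().parse_args(['-r', '100', '-e'])
print(args.resume_from, args.restart, args.ema)  # 100 False True
```

Without -r, resume_from stays None, so training starts from scratch unless a checkpoint number is given.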
- During training, the TensorBoard logger records the loss, spectrograms and audio. Launch TensorBoard with:
$ tensorboard --logdir=./tensorboard --bind_all
Run for_test.py or test.py:
$ python test.py -r {checkpoint_number} [-e] [--save]
or
$ python for_test.py -r {checkpoint_number} [-e] [--save]
Here -e (optional) starts from the EMA checkpoint and --save (optional) saves the output files.
Please check the argument parser:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=True, help="Resume Checkpoint epoch number")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
parser.add_argument('--save', action="store_true",
                    required=False, help="Save file")
args = parser.parse_args()
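The -e flag selects an EMA checkpoint, that is, weights kept as an exponential moving average of the training weights. Below is a minimal sketch of the standard EMA update, using plain lists in place of tensors; the decay of 0.999 and the name ema_update are assumptions, since this README does not state how NU-Wave computes its EMA.

```python
from typing import Dict, List


def ema_update(
    ema_params: Dict[str, List[float]],
    params: Dict[str, List[float]],
    decay: float = 0.999,  # assumed value, not taken from NU-Wave
) -> None:
    """Update EMA weights in place: ema <- decay * ema + (1 - decay) * param."""
    for name, values in params.items():
        shadow = ema_params[name]
        for i, v in enumerate(values):
            shadow[i] = decay * shadow[i] + (1.0 - decay) * v


ema = {"w": [0.0]}
ema_update(ema, {"w": [1.0]}, decay=0.9)  # ema["w"] is now about 0.1
ema_update(ema, {"w": [1.0]}, decay=0.9)  # ema["w"] is now about 0.19
print(ema["w"])
```

Averaging smooths out step-to-step noise in the weights, which is a common reason to evaluate EMA weights alongside the raw ones.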
Although we provide Lightning-style test code (test.py), it has a device dependency. We therefore recommend using for_test.py.
This implementation uses code from the following repositories:
- J. Ho's official DDPM implementation
- lucidrains' DDPM pytorch implementation
- ivanvovk's WaveGrad pytorch implementation
- lmnt-com's DiffWave pytorch implementation
This README and the webpage for the audio samples are inspired by:
- Tips for Publishing Research Code
- Audio samples webpage of DCA
- Cotatron
- Audio samples webpage of WaveGrad
The audio samples on our webpage are partially derived from:
- VCTK dataset (0.92): 46 hours of English speech from 108 speakers.
.
├── Dockerfile
├── dataloader.py # Dataloader for train/val(=test)
├── filters.py # Filter implementation
├── test.py # Test with lightning_loop.
├── for_test.py # Test with for_loop. Recommended due to device dependency of lightning
├── hparameter.yaml # Config
├── lightning_model.py # NU-Wave implementation. DDPM is based on ivanvovk's WaveGrad implementation
├── model.py # NU-Wave model based on lmnt-com's DiffWave implementation
├── requirement.txt # required libraries
├── sampling.py # Sampling a file
├── trainer.py # Lightning trainer
├── README.md
├── LICENSE
├── utils
│ ├── stft.py # STFT layer
│ ├── tblogger.py # Tensorboard Logger for lightning
│ └── wav2pt.py # Preprocessing
└── docs # For github.io
└─ ...
If this repository is useful for your research, please consider citing it! The BibTeX entry will be updated after the INTERSPEECH 2021 conference.
@article{lee2021nuwave,
title={NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling},
author={Lee, Junhyeok and Han, Seungu},
journal={arXiv preprint arXiv:2104.02321},
year={2021}
}
If you have any questions or inquiries, please contact Junhyeok Lee at [email protected]