Implementation of MV-MWM in TensorFlow 2.
Multi-View Masked World Models (MV-MWM) is a reinforcement learning framework that (i) trains a multi-view masked autoencoder with view-masking and (ii) learns a world model for single-view, multi-view, and viewpoint-robust control.
Install dependencies
source dependency.sh
First, install the dependencies listed in the RLBench repository. Then install our customized RLBench, located in the rlbench_shaped_rewards directory.
cd ./rlbench_shaped_rewards
pip install -e .
To reproduce our experiments, please run the scripts below in the mvmwm directory.
source ./scripts/train_mvmwm_multi_view.sh {TASK} {USE_ROTATION} {GPU} {SEED}
# For instance,
source ./scripts/train_mvmwm_multi_view.sh rlbench_phone_on_base False 0 1
source ./scripts/train_mvmwm_multi_view.sh rlbench_stack_wine True 0 1
source ./scripts/train_mvmwm_single_view.sh {TASK} {USE_ROTATION} {GPU} {SEED}
# For instance,
source ./scripts/train_mvmwm_single_view.sh rlbench_phone_on_base False 0 1
source ./scripts/train_mvmwm_single_view.sh rlbench_stack_wine True 0 1
source ./scripts/train_mvmwm_viewpoint_robust.sh {TASK} {USE_ROTATION} {DIFFICULTY} {GPU} {SEED}
# For instance,
source ./scripts/train_mvmwm_viewpoint_robust.sh rlbench_phone_on_base_custom False medium 0 1
source ./scripts/train_mvmwm_viewpoint_robust.sh rlbench_stack_wine_custom True weak 0 1
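To repeat a run over several seeds, the commands can be generated rather than typed by hand. The sketch below is not part of the released scripts: print_multi_view_runs is a hypothetical helper, and the seed values are arbitrary. It only prints the commands (a dry run), using the argument order {TASK} {USE_ROTATION} {GPU} {SEED} from the multi-view usage line above.

```shell
# Hypothetical helper (not part of the repository): print one multi-view
# training command per seed. Nothing is executed; this is a dry run.
# Usage: print_multi_view_runs TASK USE_ROTATION GPU SEED [SEED ...]
print_multi_view_runs() (
  # The body runs in a subshell, so these variables do not leak to the caller.
  task="$1"
  use_rotation="$2"
  gpu="$3"
  shift 3
  for seed in "$@"; do
    echo "source ./scripts/train_mvmwm_multi_view.sh ${task} ${use_rotation} ${gpu} ${seed}"
  done
)

# Example: the phone-on-base task without rotation, on GPU 0, seeds 1 to 3.
print_multi_view_runs rlbench_phone_on_base False 0 1 2 3
```

Review the printed commands and then run them one at a time from the mvmwm directory. The same pattern should carry over to the single-view script by changing the script name, assuming its arguments keep the order shown in its usage line.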
This code might not perfectly reproduce the results in the paper, possibly due to human error in preparing and cleaning the code for release. Please let us know if you have any trouble reproducing our results. We will also try to conduct sanity-check experiments as soon as possible.