This code implements the skeleton-based action segmentation MS-GCN model from "Automated freezing of gait assessment with marker-based motion capture and multi-stage spatial-temporal graph convolutional neural networks" and "Skeleton-based action segmentation with multi-stage spatial-temporal graph convolutional neural networks", arXiv 2022 (in review).
It was originally developed for freezing of gait (FOG) assessment on a proprietary dataset. Recently, we have also achieved high skeleton-based action segmentation performance on public datasets, e.g. HuGaDB, LARa version 1, PKU-MMD v2, and TUG.
Tested on Ubuntu 16.04 and PyTorch 1.10.1. Models were trained on an Nvidia Tesla K80.
The c3d data preparation script requires Biomechanical-Toolkit. For installation instructions, please refer to the following issue.
The datasets can be downloaded from:
- LARa: https://zenodo.org/record/3862782#.YizNT3pKjZs
- PKU-MMD: https://www.icst.pku.edu.cn/struct/Projects/PKUMMD.html
- HuGaDB: https://github.com/romanchereshnev/HuGaDB
- TUG: https://researchdata.ntu.edu.sg/dataset.xhtml?persistentId=doi:10.21979/N9/7VF22X
- FOG: not public
Alternatively, we provide a OneDrive link to download the input features and labels we used. Note, however, that further refinement of the features (e.g., with mmskeleton) and of the labels (e.g., removing background labels or filling gaps between actions that are implausibly short) will likely improve results. OneDrive link: Features and labels.
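As one example of the label refinement mentioned above, implausibly short segments can be absorbed into a neighboring segment. The sketch below is hypothetical: the function `merge_short_segments` and the `min_len` threshold are our own assumptions, not part of this repository. It relabels every run of identical frame labels shorter than `min_len` frames with the label of the preceding run, or of the following run if the short run is the first one.

```python
from itertools import groupby

import numpy as np


def merge_short_segments(labels, min_len):
    """Hypothetical label-refinement helper (not part of MS-GCN).

    Relabels every run of identical frame-wise labels that is shorter than
    ``min_len`` frames with the label of the preceding run, or with the label
    of the following run when the short run is the first one.
    """
    labels = np.asarray(labels).copy()

    # Run-length encode the sequence as [label, start_frame, length].
    runs, start = [], 0
    for lab, grp in groupby(labels.tolist()):
        length = sum(1 for _ in grp)
        runs.append([lab, start, length])
        start += length

    if len(runs) < 2:
        return labels  # a single segment has nothing to merge into

    for k, (_, s, length) in enumerate(runs):
        if length >= min_len:
            continue
        # Borrow the (possibly already updated) label of the previous run,
        # or of the next run for the very first segment.
        new_label = runs[k - 1][0] if k > 0 else runs[k + 1][0]
        labels[s:s + length] = new_label
        runs[k][0] = new_label
    return labels


frames = [0, 0, 0, 0, 1, 0, 0, 0, 2, 2, 2, 2]
print(merge_short_segments(frames, min_len=2).tolist())
# → [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2]
```

Merging into the previous segment is only one possible policy; a majority vote over both neighbors, or relabeling short runs as background, are equally plausible choices depending on the dataset.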
data_prep/
-- Data preparation scripts.
main.py
-- Main script. I suggest working with it interactively in an IDE. Please provide the dataset and train/predict arguments, e.g. --dataset=fog_example --action=train.
batch_gen.py
-- Batch loader.
label_eval.py
-- Compute metrics and save prediction results.
model.py
-- Train/predict script.
models/
-- Location for saving the trained models.
models/ms_gcn.py
-- The MS-GCN model.
models/net_utils/
-- Scripts to partition the graph for the various datasets. For more information about the partitioning, please refer to the section Graph representations. For more information about spatial-temporal graphs, please refer to ST-GCN.
data/
-- Location for the processed datasets. For more information, please refer to the 'FOG' example.
data/signals.
-- Scripts for computing the feature representations. Used for datasets that provide spatial features per joint, e.g. FOG, TUG, and PKU-MMD v2. For more information, please refer to the section Graph representations.
results/
-- Location for saving the results.
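To illustrate what a graph partitioning script in models/net_utils/ computes, here is a minimal sketch modeled on the "distance" partitioning strategy of ST-GCN: each joint's neighborhood is split into a self subset (hop 0) and a neighbor subset (hop 1), and the adjacency is normalized by node degree. The 9-joint edge list is a made-up lower-body chain and `distance_partitions` is a hypothetical name; the joints, edges, and partitioning strategy actually used here are described in the section Graph representations and may differ.

```python
import numpy as np

# Hypothetical 9-joint skeleton: joint 0 is a root node with two 4-joint
# chains (e.g. one per leg). This is NOT the joint layout used by MS-GCN.
NUM_JOINTS = 9
EDGES = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 5), (5, 6), (6, 7), (7, 8)]


def distance_partitions(num_nodes, edges, max_hop=1):
    """Return a (max_hop + 1, V, V) stack of degree-normalized adjacency
    matrices, one per hop distance (ST-GCN-style 'distance' partitioning)."""
    # Symmetric binary adjacency without self-loops.
    a = np.zeros((num_nodes, num_nodes))
    for i, j in edges:
        a[i, j] = a[j, i] = 1

    # Hop distance between every pair of joints (inf if beyond max_hop).
    hop = np.full((num_nodes, num_nodes), np.inf)
    reach = [np.linalg.matrix_power(a, d) > 0 for d in range(max_hop + 1)]
    for d in range(max_hop, -1, -1):
        hop[reach[d]] = d  # smaller hop distances overwrite larger ones

    # Union of all subsets (including self-loops), column-normalized: A D^-1.
    union = (hop <= max_hop).astype(float)
    deg = union.sum(axis=0)
    norm = union / deg  # divides column j by the degree of joint j

    # Split the normalized adjacency into one matrix per hop distance.
    parts = np.zeros((max_hop + 1, num_nodes, num_nodes))
    for d in range(max_hop + 1):
        mask = hop == d
        parts[d][mask] = norm[mask]
    return parts


A = distance_partitions(NUM_JOINTS, EDGES)
print(A.shape)                            # (2, 9, 9)
print(np.allclose(A.sum(axis=(0, 1)), 1.0))  # True: every column sums to 1
```

ST-GCN also defines a "spatial" strategy that further splits the neighbor subset by distance to the skeleton's center of gravity; each partition then gets its own set of convolution weights.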
After processing the dataset (the scripts are dataset-specific), place each processed dataset in the data/ folder. We provide an example for a motion capture dataset in c3d format. For this particular example, we extract 9 joints in 3D:
data_prep/read_frame.py
-- Import the joints and action labels from the c3d file and save both in a separate csv file.
data_prep/gen_data/
-- Import the csv, construct the input, and save it as npy for training. For more information about the input and label shape, please refer to the section Problem statement.
Data processing is not necessary if the downloaded features and labels are used.
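The csv-to-npy step performed by data_prep/gen_data/ can be sketched as follows. Everything about the file layout is an assumption for illustration: one csv row per frame with columns j0_x, j0_y, j0_z, ..., j8_z followed by an integer label column, and features stored channel-first as (C, T, V) = (3, frames, 9 joints), following the ST-GCN convention. The function `csv_to_npy` and its output file names are hypothetical; the exact input and label shapes this repository expects are given in the section Problem statement.

```python
import csv
from pathlib import Path

import numpy as np


def csv_to_npy(csv_path, out_dir, num_joints=9, num_channels=3):
    """Hypothetical converter: frame-wise joint csv -> (C, T, V) features + (T,) labels.

    Assumed csv layout (not the repository's actual format): a header row,
    then one row per frame holding num_joints * num_channels coordinates
    (joint-major: j0_x, j0_y, j0_z, j1_x, ...) followed by an integer label.
    """
    with open(csv_path, newline="") as f:
        rows = list(csv.reader(f))[1:]  # skip the header row

    data = np.array([[float(v) for v in row[:-1]] for row in rows])
    labels = np.array([int(row[-1]) for row in rows])

    # (T, V*C) -> (T, V, C) -> (C, T, V): channels first, then time, then joints.
    features = data.reshape(len(rows), num_joints, num_channels).transpose(2, 0, 1)

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = Path(csv_path).stem
    np.save(out_dir / f"{stem}_features.npy", features)
    np.save(out_dir / f"{stem}_labels.npy", labels)
    return features, labels
```

Keeping labels as a separate (T,) array makes per-frame supervision straightforward, since each time step of the feature tensor lines up with exactly one label.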
Please refer to the example in data/example/ for more information on how to structure the files for training/prediction.
Pre-trained models are provided for HuGaDB, PKU-MMD, and LARa. To reproduce the results from the paper:
- Download the datasets from their respective repositories.
- See the "Data" section for more information on how to prepare the datasets.
- Place the pre-trained models in models/, e.g. models/hugadb.
- Ensure that the correct graph representation is chosen in ms_gcn.
- Comment out features = get_features(features) in model (only for LARa and HuGaDB).
- Specify the correct sampling rate, e.g. a downsampling factor of 4 for LARa.
- Run main with the proper arguments to generate the per-sample predictions, e.g. --dataset=hugadb --action=predict.
- Run label_eval with the proper arguments, e.g. --dataset=hugadb.
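label_eval computes the evaluation metrics; their exact definitions live in that script. As a hedged illustration, assuming MS-TCN-style metrics (the code builds on MS-TCN), the sketch below computes two common ones: frame-wise accuracy, and the segmental edit score, i.e. a normalized Levenshtein distance between the ordered segment labels of the prediction and the ground truth. The function names and the `bg_class` parameter are hypothetical.

```python
from itertools import groupby

import numpy as np


def segments(labels, bg_class=()):
    """Collapse frame-wise labels into the ordered list of segment labels,
    dropping segments whose label is in bg_class."""
    return [lab for lab, _ in groupby(labels) if lab not in bg_class]


def edit_score(pred, gt, bg_class=()):
    """Segmental edit score in [0, 100]: 100 * (1 - Levenshtein / max length)."""
    p, y = segments(pred, bg_class), segments(gt, bg_class)
    m, n = len(p), len(y)
    if max(m, n) == 0:
        return 100.0
    d = np.zeros((m + 1, n + 1))
    d[:, 0] = np.arange(m + 1)  # cost of deleting all predicted segments
    d[0, :] = np.arange(n + 1)  # cost of inserting all ground-truth segments
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if p[i - 1] == y[j - 1] else 1
            d[i, j] = min(d[i - 1, j] + 1,         # deletion
                          d[i, j - 1] + 1,         # insertion
                          d[i - 1, j - 1] + cost)  # substitution / match
    return float((1 - d[m, n] / max(m, n)) * 100)


def frame_accuracy(pred, gt):
    """Percentage of frames whose predicted label matches the ground truth."""
    return float((np.asarray(pred) == np.asarray(gt)).mean() * 100)


gt = [0, 0, 1, 1, 1, 2, 2]
pred = [0, 0, 1, 2, 1, 2, 2]
print(round(frame_accuracy(pred, gt), 2))  # 85.71
print(round(edit_score(pred, gt), 2))      # 60.0
```

The example shows why both metrics are useful: a single misclassified frame barely lowers accuracy, but it splits one segment into three and is penalized heavily by the edit score, which targets over-segmentation.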
The MS-GCN model and code are based on ST-GCN and MS-TCN. We thank the authors for publicly releasing their code.