Skip to content

Commit

Permalink
add training example
Browse files Browse the repository at this point in the history
  • Loading branch information
OmaymaS committed Jul 15, 2022
1 parent 72a599d commit 787db1e
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 0 deletions.
9 changes: 9 additions & 0 deletions fastai_training_job/train_job_image/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 1,9 @@

FROM fastdotai/fastai@sha256:c36b43104474006d8f8cd2a65f740bfd505693c670644c1d2dbedb5a6fb2de8a
RUN pip install -U fire==0.4.0 pandas==1.3.5 google-cloud-pubsub==2.13.0 google-cloud-storage==1.35.0 gcsfs==2022.5.0
WORKDIR /app
COPY train.py .
COPY test_model.py .
COPY gcp_utils.py .

ENTRYPOINT ["python", "train.py"]
10 changes: 10 additions & 0 deletions fastai_training_job/train_job_image/model_hyperparams_test.json
Original file line number Diff line number Diff line change
@@ -0,0 1,10 @@
{
"TAG_COLUMN": "tag",
"RESIZE_VALUE": 64,
"METHOD": "squish",
"FREEZE_EPOCHS": 1,
"BASE_LR": 0.002,
"EPOCHS": 0,
"ARCH_RESNET": "resnet34",
"TEST_THR_DEFAULT": 0.5
}
107 changes: 107 additions & 0 deletions fastai_training_job/train_job_image/train.py
Original file line number Diff line number Diff line change
@@ -0,0 1,107 @@
import json
import os
import time

import fire
import numpy as np
import pandas as pd
from fastai.vision.all import *

from gcp_utils import publish_metrics

timestamp = time.strftime("%Y%m%d-%H%M%S")


def train_evaluate(job_dir: str = None,
training_dataset_path: str = None,
training_images_path: str = None,
model_version: str = None,
model_hyperparms_path: str = None):

if torch.cuda.is_available():
print('GPU available')

# GCS gets mounted when the instance starts using gcsfuse. All prefixes exist under /gcs/
job_dir, training_images_path, model_hyperparms_path = [f.replace(
'gs://', '/gcs/') if f.startswith('gs') else f for f in [job_dir, training_images_path, model_hyperparms_path]]

# create subdir/prefix to export results
job_subdir_export = f'{job_dir}/{model_version}_{timestamp}'
os.makedirs(job_subdir_export)

# read hyperparams values
with open(model_hyperparms_path) as model_hparams_json:
model_hparams_dict = json.load(model_hparams_json)
print(f'hyperparameters: {model_hparams_dict}')

# read training data csv
print('Reading csv ...')
df_train = pd.read_csv(training_dataset_path)

print('Loading and transforming images...')
# define datablock
img_block_01 = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
splitter=ColSplitter('is_valid'),
get_x=ColReader(
'image_id', pref=f'{training_images_path}/', suff=''),
get_y=ColReader(
model_hparams_dict['TAG_COLUMN'], label_delim=','),
item_tfms=Resize(
model_hparams_dict['RESIZE_VALUE'], method=model_hparams_dict['METHOD']),
batch_tfms=aug_transforms(do_flip=False))

# load images and transform
img_dls_01 = img_block_01.dataloaders(df_train)

# train
print('Starting training...')
fscore = F1ScoreMulti()
learn_01 = cnn_learner(dls=img_dls_01,
arch=eval(model_hparams_dict['ARCH_RESNET']),
metrics=[accuracy_multi, fscore])
learn_01.fine_tune(
model_hparams_dict['EPOCHS'],
base_lr=model_hparams_dict['BASE_LR'],
freeze_epochs=model_hparams_dict['FREEZE_EPOCHS'],
cbs=[SaveModelCallback(monitor='accuracy_multi',
fname='model', with_opt=True)]
)
print('Training completed.')

# dict to add metrics
metrics_log = {}

# get final model metrics
# here accuracy_multi is maximized (can be changed to any other metric like valid_loss)
accuracy_multi_train = learn_01.recorder.metrics[0].value.item()
metrics_log['accuracy_multi_train'] = accuracy_multi_train

# save metrics
print("Saving metrics")
with open(f'{job_subdir_export}/metrics_log.json', "w") as f:
json.dump(metrics_log, f)
print(metrics_log)

# pickle model ------
model_name = f'model_{model_version}'

pickle_model_path = f'{job_subdir_export}/{model_name}.pkl'
print(f'Saving pickle model to {pickle_model_path}')
learn_01.export(f'{pickle_model_path}')

# jit model ----------
# save for native torch import later
dummy_inp = torch.randn(
[1, 3, model_hparams_dict['RESIZE_VALUE'], model_hparams_dict['RESIZE_VALUE']]) # dummy

jit_model_path = f'{job_subdir_export}/{model_name}.pt'
print(f'Saving jit model to {jit_model_path}')
torch.jit.save(torch.jit.trace(learn_01.model, dummy_inp),
f'{job_subdir_export}/model_{model_version}.pt')

# save vocab with jit model
vocab = np.save('models/vocab.npy', learn_01.dls.vocab)


if __name__ == "__main__":
fire.Fire(train_evaluate)

0 comments on commit 787db1e

Please sign in to comment.