Showing 3 changed files with 126 additions and 0 deletions.
Dockerfile
@@ -0,0 +1,9 @@
FROM fastdotai/fastai@sha256:c36b43104474006d8f8cd2a65f740bfd505693c670644c1d2dbedb5a6fb2de8a
RUN pip install -U fire==0.4.0 pandas==1.3.5 google-cloud-pubsub==2.13.0 google-cloud-storage==1.35.0 gcsfs==2022.5.0
WORKDIR /app
COPY train.py .
COPY test_model.py .
COPY gcp_utils.py .

ENTRYPOINT ["python", "train.py"]
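The image pins a fastai base, installs the GCP client libraries, and makes python train.py the entrypoint. train.py (later in this commit) expects GCS prefixes to be mounted under /gcs/, which matches Vertex AI custom training; assuming that is the intended target (an assumption, the commit itself does not name the service), submitting this image could look roughly like the sketch below. The project ID, region, bucket, image URI, machine type, and accelerator are all placeholders, not values from the commit.

# Hypothetical submission of this image as a Vertex AI custom container job.
# Everything identifying a project, bucket, or registry is a placeholder.
from google.cloud import aiplatform

aiplatform.init(project="my-project",                 # placeholder
                location="europe-west1",               # placeholder
                staging_bucket="gs://my-staging-bucket")

job = aiplatform.CustomContainerTrainingJob(
    display_name="fastai-train",
    container_uri="europe-west1-docker.pkg.dev/my-project/ml/fastai-train:latest",  # placeholder
)

# args are forwarded to the container ENTRYPOINT (python train.py), where
# python-fire maps each flag onto a keyword argument of train_evaluate.
job.run(
    args=[
        "--job_dir=gs://my-bucket/training_jobs",                                    # placeholder
        "--training_dataset_path=gs://my-bucket/data/train.csv",                     # placeholder
        "--training_images_path=gs://my-bucket/data/images",                         # placeholder
        "--model_version=v1",                                                        # placeholder
        "--model_hyperparms_path=gs://my-bucket/config/model_hyperparams_test.json", # placeholder
    ],
    replica_count=1,
    machine_type="n1-standard-8",        # placeholder
    accelerator_type="NVIDIA_TESLA_T4",  # placeholder
    accelerator_count=1,
)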
10 changes: 10 additions & 0 deletions
fastai_training_job/train_job_image/model_hyperparams_test.json
@@ -0,0 +1,10 @@
{
    "TAG_COLUMN": "tag",
    "RESIZE_VALUE": 64,
    "METHOD": "squish",
    "FREEZE_EPOCHS": 1,
    "BASE_LR": 0.002,
    "EPOCHS": 0,
    "ARCH_RESNET": "resnet34",
    "TEST_THR_DEFAULT": 0.5
}
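A brief sketch of how train.py (below) consumes these test values; the role of TEST_THR_DEFAULT is an assumption, since it is not referenced anywhere in this commit.

import json

# Load the test hyperparameters the same way train.py does.
with open("fastai_training_job/train_job_image/model_hyperparams_test.json") as f:
    hparams = json.load(f)

# TAG_COLUMN, RESIZE_VALUE and METHOD configure the DataBlock (label column, Resize transform);
# ARCH_RESNET is eval'd into the resnet34 constructor passed to cnn_learner;
# EPOCHS, BASE_LR and FREEZE_EPOCHS parameterize learn.fine_tune.
# With EPOCHS = 0 and FREEZE_EPOCHS = 1, fine_tune runs one frozen epoch and skips the
# unfrozen phase, which keeps this config cheap enough for a quick pipeline test.
assert hparams["EPOCHS"] == 0 and hparams["FREEZE_EPOCHS"] == 1

# TEST_THR_DEFAULT (0.5) is not used in train.py; presumably it is the default prediction
# threshold consumed by test_model.py (an assumption, that file is not part of this commit).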
train.py
@@ -0,0 +1,107 @@
import json
import os
import time

import fire
import numpy as np
import pandas as pd
from fastai.vision.all import *

from gcp_utils import publish_metrics

timestamp = time.strftime("%Y%m%d-%H%M%S")


def train_evaluate(job_dir: str = None,
                   training_dataset_path: str = None,
                   training_images_path: str = None,
                   model_version: str = None,
                   model_hyperparms_path: str = None):

    if torch.cuda.is_available():
        print('GPU available')

    # GCS gets mounted when the instance starts using gcsfuse. All prefixes exist under /gcs/
    job_dir, training_images_path, model_hyperparms_path = [f.replace(
        'gs://', '/gcs/') if f.startswith('gs') else f for f in [job_dir, training_images_path, model_hyperparms_path]]

    # create subdir/prefix to export results
    job_subdir_export = f'{job_dir}/{model_version}_{timestamp}'
    os.makedirs(job_subdir_export)

    # read hyperparams values
    with open(model_hyperparms_path) as model_hparams_json:
        model_hparams_dict = json.load(model_hparams_json)
    print(f'hyperparameters: {model_hparams_dict}')

    # read training data csv
    print('Reading csv ...')
    df_train = pd.read_csv(training_dataset_path)

    print('Loading and transforming images...')
    # define datablock
    img_block_01 = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                             splitter=ColSplitter('is_valid'),
                             get_x=ColReader(
                                 'image_id', pref=f'{training_images_path}/', suff=''),
                             get_y=ColReader(
                                 model_hparams_dict['TAG_COLUMN'], label_delim=','),
                             item_tfms=Resize(
                                 model_hparams_dict['RESIZE_VALUE'], method=model_hparams_dict['METHOD']),
                             batch_tfms=aug_transforms(do_flip=False))

    # load images and transform
    img_dls_01 = img_block_01.dataloaders(df_train)

    # train
    print('Starting training...')
    fscore = F1ScoreMulti()
    learn_01 = cnn_learner(dls=img_dls_01,
                           arch=eval(model_hparams_dict['ARCH_RESNET']),
                           metrics=[accuracy_multi, fscore])
    learn_01.fine_tune(
        model_hparams_dict['EPOCHS'],
        base_lr=model_hparams_dict['BASE_LR'],
        freeze_epochs=model_hparams_dict['FREEZE_EPOCHS'],
        cbs=[SaveModelCallback(monitor='accuracy_multi',
                               fname='model', with_opt=True)]
    )
    print('Training completed.')

    # dict to add metrics
    metrics_log = {}

    # get final model metrics
    # here accuracy_multi is maximized (can be changed to any other metric like valid_loss)
    accuracy_multi_train = learn_01.recorder.metrics[0].value.item()
    metrics_log['accuracy_multi_train'] = accuracy_multi_train

    # save metrics
    print("Saving metrics")
    with open(f'{job_subdir_export}/metrics_log.json', "w") as f:
        json.dump(metrics_log, f)
    print(metrics_log)

    # pickle model ------
    model_name = f'model_{model_version}'

    pickle_model_path = f'{job_subdir_export}/{model_name}.pkl'
    print(f'Saving pickle model to {pickle_model_path}')
    learn_01.export(pickle_model_path)

    # jit model ----------
    # save for native torch import later
    dummy_inp = torch.randn(
        [1, 3, model_hparams_dict['RESIZE_VALUE'], model_hparams_dict['RESIZE_VALUE']],
        device=next(learn_01.model.parameters()).device)  # dummy batch on the model's device

    jit_model_path = f'{job_subdir_export}/{model_name}.pt'
    print(f'Saving jit model to {jit_model_path}')
    torch.jit.save(torch.jit.trace(learn_01.model, dummy_inp),
                   jit_model_path)

    # save vocab with jit model so tag names can be recovered at inference time
    np.save(f'{job_subdir_export}/vocab.npy', learn_01.dls.vocab)


if __name__ == "__main__":
    fire.Fire(train_evaluate)
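For completeness, a hedged sketch of how the artifacts exported above could be consumed afterwards. The export directory and image path are placeholders (they would follow the {job_dir}/{model_version}_{timestamp} layout written by train.py), and the 0.5 threshold simply mirrors TEST_THR_DEFAULT from the test config.

# Hypothetical post-training usage of the exported artifacts; all paths are placeholders.
import numpy as np
import torch
from fastai.vision.all import load_learner, PILImage

export_dir = "training_jobs/v1_20240101-000000"  # placeholder for {job_dir}/{model_version}_{timestamp}

# 1) fastai pickle export: a full Learner with its transforms, usable via predict().
learn = load_learner(f"{export_dir}/model_v1.pkl")
preds = learn.predict(PILImage.create("some_image.jpg"))  # placeholder image

# 2) TorchScript export: plain PyTorch, no fastai dependency at inference time.
jit_model = torch.jit.load(f"{export_dir}/model_v1.pt", map_location="cpu")
jit_model.eval()
with torch.no_grad():
    # randn stands in for a preprocessed batch; a real image would need the same
    # resize (RESIZE_VALUE = 64) and normalization used during training.
    logits = jit_model(torch.randn(1, 3, 64, 64))
probs = torch.sigmoid(logits)  # multi-label head, so sigmoid gives per-tag probabilities

# 3) The vocab saved next to the TorchScript model maps output columns back to tag names.
vocab = np.load(f"{export_dir}/vocab.npy", allow_pickle=True)
tags = [t for t, p in zip(vocab, probs[0]) if p > 0.5]  # 0.5 mirrors TEST_THR_DEFAULT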