Skip to content

Commit

Permalink
object detection tried
Browse files Browse the repository at this point in the history
  • Loading branch information
Goader committed May 14, 2023
1 parent a3d218d commit 6b3c76d
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,6 @@ splitted.zip

denoised
preprocessed

detection-dataset
detection.yaml
30 changes: 30 additions & 0 deletions cadml/models/retina_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
from torch import nn
from torchtyping import TensorType

from torchvision import models

from omegaconf import DictConfig


class RetinaNet(nn.Module):
def __init__(self, cfg: DictConfig):
super().__init__()

self.cfg = cfg
self.in_channels = cfg.model.in_channels
self.use_pretrained = cfg.model.use_pretrained

weights = models.detection.RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
self.preprocess = weights.transforms()
if self.use_pretrained:
self.model = models.detection.retinanet_resnet50_fpn_v2(weights=weights, num_classes=2)
else:
self.model = models.detection.retinanet_resnet50_fpn_v2(weights=None, num_classes=2)

# TODO

def forward(self, X: TensorType["batch_size", "channels", "height", "width"]) -> TensorType["batch_size"]:
X = X.repeat(1, 3, 1, 1) # increases the number of in_channels to 3
X = self.preprocess(X)
# TODO
44 changes: 44 additions & 0 deletions scripts/fine-tune-retina-net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import numpy as np
import torch

# Import Detectron2 libraries
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.config import get_cfg
from detectron2 import model_zoo

# Register your custom dataset
register_coco_instances("my_dataset_train", {}, "path/to/annotations.json", "path/to/image/directory")
register_coco_instances("my_dataset_val", {}, "path/to/annotations.json", "path/to/image/directory")

# Define the configuration for the RetinaNet model
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/retinanet_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ("my_dataset_val",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.MAX_ITER = 1000
cfg.MODEL.RETINANET.NUM_CLASSES = 2

# Define the trainer and start training
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

# Evaluate the model on the validation set
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

evaluator = COCOEvaluator("my_dataset_val", cfg, False, output_dir="./output/")
val_loader = build_detection_test_loader(cfg, "my_dataset_val")
inference_on_dataset(trainer.model, val_loader, evaluator)

# Fine-tune the model by adjusting the hyperparameters
cfg.SOLVER.MAX_ITER = 2000
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=True)
trainer.train()
50 changes: 50 additions & 0 deletions scripts/form-coco-dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import cv2

import coronaryx as cx


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('train_dataset_dir', help='Path to the dataset directory')
parser.add_argument('val_dataset_dir', help='Path to the dataset directory')
parser.add_argument('test_dataset_dir', help='Path to the dataset directory')
parser.add_argument('output_dir', help='Path to the output directory')
args = parser.parse_args()

output_dir = Path(args.output_dir)

yaml_conf = f"""
path: {output_dir}
train: 'images/train'
val: 'images/val'
test: 'images/test'
names:
0: stenosis
"""

with open('detection.yaml', 'w') as f:
f.write(yaml_conf)

for dataset_dir, part in zip([args.train_dataset_dir, args.val_dataset_dir, args.test_dataset_dir],
['train', 'val', 'test']):

dataset = cx.read_dataset(dataset_dir)

for item in dataset:
# create images directory
(output_dir / 'images' / part).mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(output_dir / 'images' / part / f'{item.name}.jpg'), item.scan)

# create labels directory
(output_dir / 'labels' / part).mkdir(parents=True, exist_ok=True)
with open(output_dir / 'labels' / part / f'{item.name}.txt', 'w') as f:
for roi in item.rois:
x1, x2 = roi.start_x / item.scan.shape[0], roi.end_x / item.scan.shape[0]
y1, y2 = roi.start_y / item.scan.shape[1], roi.end_y / item.scan.shape[1]

f.write(f'0 {(x1 + x2) / 2} {(y1 + y2) / 2} {x2 - x1} {y2 - y1}\n')

0 comments on commit 6b3c76d

Please sign in to comment.