Add Relation DETR #34900
base: main
Conversation
Hi @xiuqhou! Congratulations on the paper, awesome work! And thanks for working on the transformers implementation! Feel free to ping me when it's ready for review or if you have any questions!
Hi @qubvel Thanks for your support! The code is now ready for review. I'd greatly appreciate it if you could take a look and share your feedback. Please let me know if there's anything that needs improvement.
Hi @qubvel, thank you very much for your careful review! 🤗 I have updated the code accordingly and left responses to your comments to clarify a few questions. Please let me know if there is anything that needs to be improved.
Hi @xiuqhou, Happy New Year! Thanks for addressing the comments, and sorry for the delay. I'm going to review it right now. 🤗
```python
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
if annotations is not None:
    prepared_images = []
    prepared_annotations = []
    for image, target in zip(images, annotations):
        target = self.prepare_annotation(
            image,
            target,
            format,
            input_data_format=input_data_format,
        )
        prepared_images.append(image)
        prepared_annotations.append(target)
    images = prepared_images
    annotations = prepared_annotations
    del prepared_images, prepared_annotations

# transformations
if do_resize:
    if annotations is not None:
        resized_images, resized_annotations = [], []
        for image, target in zip(images, annotations):
            orig_size = get_image_size(image, input_data_format)
            resized_image = self.resize(
                image, size=size, resample=resample, input_data_format=input_data_format
            )
            resized_annotation = self.resize_annotation(
                target, orig_size, get_image_size(resized_image, input_data_format)
            )
            resized_images.append(resized_image)
            resized_annotations.append(resized_annotation)
        images = resized_images
        annotations = resized_annotations
        del resized_images, resized_annotations
    else:
        images = [
            self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
            for image in images
        ]

if do_rescale:
    images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

if do_normalize:
    images = [
        self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
    ]

if do_convert_annotations and annotations is not None:
    annotations = [
        self.normalize_annotation(annotation, get_image_size(image, input_data_format))
        for annotation, image in zip(annotations, images)
    ]
```
Let's put this into a single loop:

```python
processed_images = []
processed_annotations = []
for i, image in enumerate(images):
    annotation = annotations[i] if annotations is not None else None
    if do_resize:
        image = self.resize(...)
        if annotation is not None:
            annotation = ...
    ...
    processed_images.append(image)
    if annotation is not None:
        processed_annotations.append(annotation)
```
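For reference, a fuller sketch of that single pass, reusing only the flags and helpers visible in the snippet above (`do_resize`, `do_rescale`, `do_normalize`, `do_convert_annotations`, `self.resize`, ...); treat it as an illustration rather than the final merged code:

```python
# Single-loop sketch of the transformations above; assumes the same locals
# (images, annotations, size, resample, rescale_factor, image_mean, image_std,
# input_data_format) that are in scope in the reviewed method.
processed_images = []
processed_annotations = []
for i, image in enumerate(images):
    annotation = annotations[i] if annotations is not None else None

    if do_resize:
        orig_size = get_image_size(image, input_data_format)
        image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
        if annotation is not None:
            annotation = self.resize_annotation(annotation, orig_size, get_image_size(image, input_data_format))

    if do_rescale:
        image = self.rescale(image, rescale_factor, input_data_format=input_data_format)

    if do_normalize:
        image = self.normalize(image, image_mean, image_std, input_data_format=input_data_format)

    if do_convert_annotations and annotation is not None:
        annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format))

    processed_images.append(image)
    if annotation is not None:
        processed_annotations.append(annotation)

images = processed_images
annotations = processed_annotations if annotations is not None else None
```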
Thanks, I did another round of review! It looks clean. There are mostly small comments on my side, please see them below. Next time I'm gonna try fine-tuning, and if everything is good, pass it to the core maintainers for approval to be merged 🤗
```python
    backbone_kwargs=backbone_kwargs,
)

assert backbone_features_format in ["channels_first", "channels_last"], (
```
Let's use `raise` instead of `assert`.
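For instance, something along these lines (a minimal sketch; the exact message wording is an assumption):

```python
if backbone_features_format not in ["channels_first", "channels_last"]:
    raise ValueError(
        "`backbone_features_format` should be one of 'channels_first' or 'channels_last', "
        f"but got {backbone_features_format}."
    )
```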
```python
if is_timm_available():
    pass
```
to remove
```python
init_reference_points: torch.FloatTensor = None
dec_outputs_class: torch.FloatTensor = None
dec_outputs_coord: torch.FloatTensor = None
enc_outputs_class: torch.FloatTensor = None
enc_outputs_coord: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
intermediate_hidden_states: torch.FloatTensor = None
intermediate_reference_points: torch.FloatTensor = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
```
Can we use an order similar to other models? e.g. `last_hidden_state` should be first.
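For example, one possible ordering with `last_hidden_state` first, roughly following other DETR-style outputs (the field set is copied from the snippet above; the exact order after the first field is a judgment call):

```python
last_hidden_state: torch.FloatTensor = None
init_reference_points: torch.FloatTensor = None
intermediate_hidden_states: torch.FloatTensor = None
intermediate_reference_points: torch.FloatTensor = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: torch.FloatTensor = None
enc_outputs_coord: torch.FloatTensor = None
dec_outputs_class: torch.FloatTensor = None
dec_outputs_coord: torch.FloatTensor = None
```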
```python
super().__init__()
self.in_channels = in_channels
self.post_layer_norm = post_layer_norm
if self.post_layer_norm:
```
Do we have pretrained checkpoints for both cases? Am I right that if `post_layer_norm=False`, this module will do nothing or just transpose? Let's just skip this layer in the parent module instead.
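A minimal sketch of skipping it in the parent module (all names here are hypothetical placeholders, not the PR's actual ones):

```python
# In the parent module's __init__: only build the norm block when it does real work.
# `RelationDetrMultiLevelNorm` and the attribute name are hypothetical.
self.multi_level_norm = RelationDetrMultiLevelNorm(in_channels) if post_layer_norm else None

# In the parent module's forward: the no-op case then disappears entirely.
if self.multi_level_norm is not None:
    multi_level_feats = self.multi_level_norm(multi_level_feats)
```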
```python
if self.post_layer_norm:
    if self.backbone_features_format == "channels_first":
        # convert (batch_size, channels, height, width) -> (batch_size, height, width, channels)
        multi_level_feats = [feat.permute(0, 2, 3, 1) for feat in multi_level_feats]

        for idx, feat in enumerate(multi_level_feats):
            multi_level_feats[idx] = self.norms[idx](feat)

        # convert (batch_size, height, width, channels) -> (batch_size, channels, height, width)
        multi_level_feats = [feat.permute(0, 3, 1, 2) for feat in multi_level_feats]
    else:
        for idx, feat in enumerate(multi_level_feats):
            multi_level_feats[idx] = self.norms[idx](feat)
```
Suggested change:

```diff
-if self.post_layer_norm:
-    if self.backbone_features_format == "channels_first":
-        # convert (batch_size, channels, height, width) -> (batch_size, height, width, channels)
-        multi_level_feats = [feat.permute(0, 2, 3, 1) for feat in multi_level_feats]
-        for idx, feat in enumerate(multi_level_feats):
-            multi_level_feats[idx] = self.norms[idx](feat)
-        # convert (batch_size, height, width, channels) -> (batch_size, channels, height, width)
-        multi_level_feats = [feat.permute(0, 3, 1, 2) for feat in multi_level_feats]
-    else:
-        for idx, feat in enumerate(multi_level_feats):
-            multi_level_feats[idx] = self.norms[idx](feat)
+if self.post_layer_norm and self.backbone_features_format == "channels_first":
+    # convert (batch_size, channels, height, width) -> (batch_size, height, width, channels)
+    multi_level_feats = [feat.permute(0, 2, 3, 1) for feat in multi_level_feats]
+    for idx, feat in enumerate(multi_level_feats):
+        multi_level_feats[idx] = self.norms[idx](feat)
+    # convert (batch_size, height, width, channels) -> (batch_size, channels, height, width)
+    multi_level_feats = [feat.permute(0, 3, 1, 2) for feat in multi_level_feats]
+elif self.post_layer_norm:
+    for idx, feat in enumerate(multi_level_feats):
+        multi_level_feats[idx] = self.norms[idx](feat)
```
```python
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
# _tied_weights_keys = [r"bbox_head\.[1-9]\d*", r"class_head\.[1-9]\d*"]
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
```
Please specify the layers; it should be any layer class names that contain residual connections, like `*EncoderLayer`, `*DecoderLayer`.
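For example (assuming the layer classes follow the usual `<ModelName>EncoderLayer` / `<ModelName>DecoderLayer` naming; adjust to the actual class names in the PR):

```python
_no_split_modules = ["RelationDetrEncoderLayer", "RelationDetrDecoderLayer"]
```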
```diff
@@ -0,0 +1,741 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
```
or maybe 2025 already
Suggested change:

```diff
-# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
```
Let's also add `# Copied from` for non-modified tests.
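For example, on an unchanged test (the source path and class here are assumptions; point it at whichever model the test was actually adapted from):

```python
# Copied from tests.models.deformable_detr.test_modeling_deformable_detr.DeformableDetrModelTest.test_config with DeformableDetr->RelationDetr
def test_config(self):
    ...
```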
```diff
@@ -0,0 +1,485 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
```
Suggested change:

```diff
-# Copyright 2022 HuggingFace Inc.
+# Copyright 2024 HuggingFace Inc.
```
We probably need head initialization for fine-tuning? See RT-DETR's `_init_weights` for an example.
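A rough sketch of what that could look like, modeled on RT-DETR's `_init_weights` (`class_head` matches the naming visible elsewhere in this PR; `initializer_bias_prior_prob`, `init_std`, and the class checks are assumptions to adapt). This would live as a method on the `PreTrainedModel` subclass:

```python
import math

from torch import nn


def _init_weights(self, module):
    # Sketch modeled on RT-DETR: bias the classification heads toward a small
    # foreground prior probability so fine-tuning starts stable. Names are assumptions.
    if isinstance(module, RelationDetrForObjectDetection):
        prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
        bias = float(-math.log((1 - prior_prob) / prior_prob))
        for layer in module.class_head:
            nn.init.constant_(layer.bias, bias)
    elif isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.init_std)
        if module.bias is not None:
            module.bias.data.zero_()
```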
What does this PR do?
This PR adds Relation DETR, as introduced in Relation DETR: Exploring Explicit Position Relation Prior for Object Detection. A checkpoint for Relation DETR (ResNet-50), converted from the original repo https://github.com/xiuqhou/Relation-DETR, has been uploaded to https://huggingface.co/xiuqhou/relation-detr-resnet50
Related issues in original repo:
xiuqhou/Relation-DETR#25
xiuqhou/Relation-DETR#21
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
TODO:
Who can review?
@amyeroberts @qubvel