The main benefits of Deformable DETR are:
- Reduces the memory and computational cost of DETR through deformable attention, where each attention module attends only to a small set of key sampling points around a reference point.
- Can achieve better performance than DETR (especially on small objects) with 10x fewer training epochs.
- Can use higher-resolution features, since you don’t need to attend to the full input.
Deformable Convolution can attend to sparse spatial locations with a convolution, but it lacks the element relation modeling of attention (all elements can relate to all other elements). This paper introduces Deformable Attention, which attends to a small set of sampling locations that act as a pre-filter for prominent key elements out of all the feature map pixels. This lets the model use multi-scale features without requiring an FPN.
Related Work
There are three main approaches to improving the complexity of attention:
- Restrict the attention pattern to be a fixed local window. This results in decreased complexity, but it loses global information.
- Learn data-dependent sparse attention. Deformable DETR uses this approach.
- Use the low-rank property in self-attention to reduce the complexity (use linear algebra to reduce the matrix multiplies to the most significant elements).
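To make the complexity argument concrete, here is a rough counting sketch. It compares the number of query-key pairs in dense self-attention over a flattened feature map with the number in a sparse scheme where each query attends to a fixed number of sampled keys. The function name and the example grid size are illustrative assumptions, and the count ignores heads, projections, and the cost of predicting the sampling locations.

```python
def attention_pairs(grid_height, grid_width, num_sampling_points):
    """Count query-key pairs for dense vs. sparse (sampled) attention.

    Hypothetical helper for illustration only: one head, one feature map.
    """
    num_queries = grid_height * grid_width
    dense_pairs = num_queries * num_queries            # every pixel attends to every pixel
    sparse_pairs = num_queries * num_sampling_points   # every pixel attends to K sampled points
    return dense_pairs, sparse_pairs


dense, sparse = attention_pairs(grid_height=32, grid_width=32, num_sampling_points=4)
print(dense, sparse, dense // sparse)
```

The dense count grows quadratically with the number of pixels, while the sparse count grows linearly, which is why higher-resolution feature maps become affordable.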
Deformable Attention
Deformable DETR Version of Deformable Attention
Deformable DETR was inspired by Deformable Convolution and modifies the attention module so that it learns to focus on a small, fixed-size set of sampling points predicted from the features of the query elements. It is only slightly slower than traditional convolution under the same FLOPs. The slowdown comes from accessing memory in a data-dependent, effectively random order (a conv layer always accesses memory in the same order), so memory optimization and caching are less effective.
The above diagram shows a single-scale feature map ($L = 1$) and 3 attention heads in green, blue, and yellow ($M = 3$). Each head uses $K$ sampling points.
Notation:
- $N_q$ is the number of query features ($z_q$) you have. For the encoder, this is $H \times W$ (you pass your image through a backbone that gives a downsampled feature map. You then flatten this feature map into a sequence of length $HW$. See DETR). For the decoder, this is the number of objects you want to detect ($N$).
- $K$ is the number of sampling points that each query feature attends to on the feature map. It should be much less than `grid_height * grid_width`.
- $z_q$ is the feature vector of query element $q$. For the encoder, this could be one pixel of the input feature map $x$. For the decoder, this could be an object query.
- $\hat{p}_q$ is the normalized coordinates of the reference point for each query element $q$. These are normalized from the reference point $p_q$. For the encoder, the reference point can be the pixel of the input feature map. For the decoder, this can be predicted from its object query embedding via a linear projection + sigmoid.
- $\{x^l\}_{l=1}^{L}$ are the input feature maps extracted by a CNN backbone at multiple scales. In the above diagram only a single feature map is used.
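The following is a minimal sketch of the idea for one query, one head, and one feature map, written with NumPy. It is not the paper's implementation: the weight matrices `w_offset`, `w_attn`, and `w_value` are hypothetical stand-ins for learned linear layers, offsets are assumed to be in pixel units, and sampling locations are clamped to the border for simplicity.

```python
import numpy as np


def bilinear_sample(feature_map, x, y):
    """Sample feature_map of shape (H, W, C) at continuous pixel coords (x, y)."""
    height, width, _ = feature_map.shape
    # Assumption: clamp out-of-range locations to the border.
    x = float(np.clip(x, 0, width - 1))
    y = float(np.clip(y, 0, height - 1))
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    x1, y1 = min(x0 + 1, width - 1), min(y0 + 1, height - 1)
    wx, wy = x - x0, y - y0
    top = (1 - wx) * feature_map[y0, x0] + wx * feature_map[y0, x1]
    bottom = (1 - wx) * feature_map[y1, x0] + wx * feature_map[y1, x1]
    return (1 - wy) * top + wy * bottom


def deformable_attention(query, reference_point, feature_map, w_offset, w_attn, w_value):
    """Single-head, single-scale deformable attention for one query.

    query: (C,); reference_point: normalized (x, y) in [0, 1];
    w_offset: (C, 2K); w_attn: (C, K); w_value: (C, C_out).
    """
    height, width, _ = feature_map.shape
    num_points = w_attn.shape[1]
    # Both the sampling offsets and the attention weights come from the query alone.
    offsets = (query @ w_offset).reshape(num_points, 2)
    logits = query @ w_attn
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()  # softmax over the K sampling points
    # Assumption: map the normalized reference point onto pixel coordinates.
    ref_x = reference_point[0] * (width - 1)
    ref_y = reference_point[1] * (height - 1)
    output = np.zeros(w_value.shape[1])
    for k in range(num_points):
        sampled = bilinear_sample(feature_map, ref_x + offsets[k, 0], ref_y + offsets[k, 1])
        output += weights[k] * (sampled @ w_value)
    return output


rng = np.random.default_rng(0)
channels, num_points = 8, 4
out = deformable_attention(
    query=rng.normal(size=channels),
    reference_point=(0.5, 0.5),
    feature_map=rng.normal(size=(16, 16, channels)),
    w_offset=rng.normal(size=(channels, 2 * num_points)),
    w_attn=rng.normal(size=(channels, num_points)),
    w_value=rng.normal(size=(channels, channels)),
)
print(out.shape)
```

Note that the query never computes a dot product with the keys: it only decides where to look and how much weight each sampled location gets, which is what keeps the cost proportional to the number of sampling points rather than to the size of the feature map.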
Bounding Box Prediction
You start with a reference point that is used as the initial guess of the box center. The reference point is predicted as a 2D normalized coordinate ($\hat{p}_q = (\hat{p}_{qx}, \hat{p}_{qy})$) using a linear projection + sigmoid that takes an object query as input.
A detection head then predicts the relative offsets $b_{q\{x,y,w,h\}} \in \mathbb{R}$ w.r.t. the reference point:
$$\hat{b}_q = \{\sigma(b_{qx} + \sigma^{-1}(\hat{p}_{qx})),\ \sigma(b_{qy} + \sigma^{-1}(\hat{p}_{qy})),\ \sigma(b_{qw}),\ \sigma(b_{qh})\}$$
where:
- $\sigma$ and $\sigma^{-1}$ are the sigmoid and inverse sigmoid functions
- $\hat{b}_q$ will be normalized coordinates of the form $\hat{b}_q \in [0, 1]^4$.
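As a rough illustration of this step, the sketch below decodes a box from a reference point and four raw head outputs: the center offsets are added in inverse-sigmoid (logit) space and squashed back with a sigmoid, while width and height are passed through a sigmoid directly. The function names and the `eps` clamp are assumptions made for the example, not taken from the paper.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def inverse_sigmoid(p, eps=1e-5):
    # Assumption: clamp away from 0 and 1 so the logit stays finite.
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def predict_box(reference_point, head_output):
    """Decode a normalized (cx, cy, w, h) box.

    reference_point: normalized (x, y) in [0, 1].
    head_output: four unbounded real numbers from the detection head.
    """
    ref_x, ref_y = reference_point
    b_x, b_y, b_w, b_h = head_output
    return (
        sigmoid(b_x + inverse_sigmoid(ref_x)),  # center x, shifted from the reference point
        sigmoid(b_y + inverse_sigmoid(ref_y)),  # center y, shifted from the reference point
        sigmoid(b_w),                           # width
        sigmoid(b_h),                           # height
    )


print(predict_box((0.3, 0.7), (0.5, -0.5, -2.0, -1.0)))
```

With all-zero center offsets the box center stays exactly at the reference point, so the head only has to learn corrections rather than absolute positions.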
Iterative Bounding Box Refinement
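Each decoder layer refines the bounding boxes based on the predictions from the previous layer. Suppose there are $D$ decoder layers (e.g. $D = 6$), then given a normalized bounding box $\hat{b}_q^{d-1}$ predicted by the $(d-1)$-th decoder layer, the $d$-th decoder layer will refine the box as:
$$\hat{b}_q^d = \{\sigma(\Delta b_{qx}^d + \sigma^{-1}(\hat{b}_{qx}^{d-1})),\ \sigma(\Delta b_{qy}^d + \sigma^{-1}(\hat{b}_{qy}^{d-1})),\ \sigma(\Delta b_{qw}^d + \sigma^{-1}(\hat{b}_{qw}^{d-1})),\ \sigma(\Delta b_{qh}^d + \sigma^{-1}(\hat{b}_{qh}^{d-1}))\}$$
where $d \in \{1, \dots, D\}$, $\Delta b_{q\{x,y,w,h\}}^d \in \mathbb{R}$ are predicted at the $d$-th decoder layer, and $\sigma$ and $\sigma^{-1}$ are the sigmoid and inverse sigmoid functions as before.
The initial box uses the reference point ($\hat{p}_q$) as the center with width = 0.1 and height = 0.1. Using the above notation this is:
$$\hat{b}_{qx}^0 = \hat{p}_{qx}, \quad \hat{b}_{qy}^0 = \hat{p}_{qy}, \quad \hat{b}_{qw}^0 = 0.1, \quad \hat{b}_{qh}^0 = 0.1$$
To stabilize training, the gradients only back propagate through $\Delta b_{q\{x,y,w,h\}}^d$ and are blocked at $\sigma^{-1}(\hat{b}_{q\{x,y,w,h\}}^{d-1})$.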
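The refinement loop can be sketched in plain Python as below. This is an illustrative sketch only: the per-layer offsets are passed in as a list rather than produced by real decoder layers, the helper names are made up for the example, and gradient blocking is indicated by a comment because plain Python has no autograd.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def inverse_sigmoid(p, eps=1e-5):
    # Assumption: clamp away from 0 and 1 so the logit stays finite.
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def refine_box(previous_box, deltas):
    """Apply one layer's offsets to all four box coordinates in logit space."""
    return tuple(sigmoid(d + inverse_sigmoid(b)) for b, d in zip(previous_box, deltas))


def iterative_refinement(reference_point, per_layer_deltas):
    """Start from a 0.1 x 0.1 box at the reference point and refine it layer by layer."""
    box = (reference_point[0], reference_point[1], 0.1, 0.1)
    for deltas in per_layer_deltas:
        # In an autograd framework, `box` would be detached here so that
        # gradients flow only through `deltas`, not through earlier layers.
        box = refine_box(box, deltas)
    return box


# Hypothetical offsets for three decoder layers: (dx, dy, dw, dh) per layer.
layer_deltas = [(0.2, -0.1, 1.0, 0.5), (0.05, 0.0, 0.3, 0.2), (0.0, 0.01, 0.1, 0.0)]
print(iterative_refinement((0.4, 0.6), layer_deltas))
```

Unlike the single-shot prediction, width and height are also refined relative to the previous estimate here, so each layer only needs to output a small correction.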
Two-Stage Deformable DETR
In the original DETR, object queries in the decoder are independent of the current image (they are learned during training and do not change depending on the inference image). Inspired by two-stage object detectors, the paper explores a variant of Deformable DETR that generates region proposals as the first stage. The generated region proposals are fed into the decoder as object queries for further refinement, forming a two-stage Deformable DETR.
In the first stage, given the output feature maps of the encoder, a detection head is applied to each pixel. The detection head consists of:
- A 3-layer FFN for bounding box regression.
- A linear projection for bounding box binary classification (this is similar to an objectness score - 0 means background and 1 means foreground).
Given the predicted bounding boxes from the first stage, the top scoring bounding boxes are picked as region proposals. In the second stage, these proposals are fed into the decoder as the initial boxes for iterative bounding box refinement.
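The proposal selection step amounts to a top-k over per-pixel objectness scores. A minimal sketch follows, assuming the first stage has already produced one box and one objectness logit per encoder pixel; the function name, the example values, and `top_k` are illustrative, not values from the paper.

```python
def select_proposals(boxes, objectness_logits, top_k):
    """Return the top_k boxes ranked by objectness logit (highest first)."""
    order = sorted(range(len(boxes)), key=lambda i: objectness_logits[i], reverse=True)
    return [boxes[i] for i in order[:top_k]]


# Hypothetical first-stage outputs for four encoder pixels: (cx, cy, w, h) boxes.
boxes = [(0.2, 0.2, 0.1, 0.1), (0.5, 0.5, 0.3, 0.2), (0.8, 0.3, 0.2, 0.4), (0.1, 0.9, 0.1, 0.1)]
logits = [-1.5, 2.0, 0.7, -3.0]
proposals = select_proposals(boxes, logits, top_k=2)
print(proposals)
```

Ranking by logit gives the same order as ranking by sigmoid probability, since the sigmoid is monotonic.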
Loss
Focal loss with loss weight of 2 is used for bounding box classification.
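For reference, a sigmoid focal loss with a loss weight of 2 could look like the sketch below. The `alpha` and `gamma` defaults are the commonly used focal loss values and are an assumption here, as is averaging over predictions; the notes above only state the loss type and its weight.

```python
import math


def sigmoid_focal_loss(logit, target, alpha=0.25, gamma=2.0):
    """Focal loss for one binary prediction; target is 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-logit))
    p_t = p if target == 1 else 1.0 - p            # probability assigned to the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma down-weights examples that are already classified well.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def classification_loss(logits, targets, loss_weight=2.0):
    """Weighted mean focal loss (assumption: normalize by number of predictions)."""
    total = sum(sigmoid_focal_loss(l, t) for l, t in zip(logits, targets))
    return loss_weight * total / max(len(logits), 1)


print(classification_loss([2.0, -1.0, 0.5], [1, 0, 0]))
```

A confidently correct prediction contributes almost nothing to the loss, which keeps the many easy background predictions from dominating training.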