Early Exiting for Image Captioning Models

It’s all about speed!

But first: WHAT IS EARLY EXITING?

As described in the paper “BranchyNet” (Teerapittayanon et al.), these are neural networks with exit branches at specific locations throughout the network.

Source: https://arxiv.org/pdf/1709.01686v1.pdf

What’s the main idea of these branches then?

Suppose you have 100 layers in a neural network. After an input image enters the network, it normally has to pass through all 100 layers to generate an output. But you may realize that some samples can be classified accurately after passing through only n layers (say, 5 layers). In that case, feed-forward passes through the remaining layers waste computational resources. So, we exit at the n-th layer through a branch.
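To make this concrete, here is a minimal PyTorch sketch of the idea; the layer sizes, branch placement, entropy-based confidence check, and threshold are all illustrative choices, not taken from the paper:

```python
import torch.nn as nn
import torch.nn.functional as F

class TinyBranchyNet(nn.Module):
    """A toy classifier with one early-exit branch (all sizes illustrative)."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
        self.branch1 = nn.Linear(16, num_classes)    # early-exit head
        self.block2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        self.final_head = nn.Linear(32, num_classes) # main exit

    def forward(self, x, tau=0.5):
        h = self.block1(x)
        # Early branch: pool, classify, and check how confident we are.
        logits1 = self.branch1(h.mean(dim=(2, 3)))
        p1 = F.softmax(logits1, dim=-1)
        entropy = -(p1 * p1.clamp_min(1e-9).log()).sum(dim=-1)
        if entropy.max() < tau:   # confident enough -> exit early
            return logits1
        h = self.block2(h)        # otherwise, keep going deeper
        return self.final_head(h.mean(dim=(2, 3)))
```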

This is especially useful in large networks that have hundreds of layers.

It even allows large networks like AlexNet to be deployed for real-time inference. But when it comes to image captioning models, there’s a catch.

Image Captioning Models

Image captioning models are different. They are essentially **Encoder-Decoder models**. I won’t go deep into the Encoder-Decoder architecture, but at a surface level, the encoder creates a representation (an intermediate vector) from the input, and the decoder predicts a token (word) at each timestep.

Source: https://openaccess.thecvf.com/content/CVPR2022/papers/Fei_DeeCap_Dynamic_Early_Exiting_for_Efficient_Image_Captioning_CVPR_2022_paper.pdf
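For intuition, here is what a bare-bones encoder-decoder captioner looks like in code; the architecture and dimensions below are purely illustrative, not DeeCap’s actual model:

```python
import torch
import torch.nn as nn

class TinyCaptioner(nn.Module):
    """Toy encoder-decoder captioner (all dimensions illustrative)."""
    def __init__(self, vocab_size=10000, d=512):
        super().__init__()
        # Encoder: image -> intermediate representation.
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, d))
        self.embed = nn.Embedding(vocab_size, d)
        # Decoder: predicts one token (word) per timestep.
        self.decoder = nn.GRUCell(d, d)
        self.head = nn.Linear(d, vocab_size)

    def forward(self, image, tokens):
        h = self.encoder(image)             # representation of the input
        logits = []
        for t in range(tokens.size(1)):     # one prediction per timestep
            h = self.decoder(self.embed(tokens[:, t]), h)
            logits.append(self.head(h))
        return torch.stack(logits, dim=1)   # [batch, time, vocab]
```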

Speaking of these encoder-decoder models, can we also apply early-exit branches to them? Yes, but with some modifications.

These “modifications” are made possible thanks to Fei, Zhengcong, et al. and their research work “DeeCap: Dynamic Early Exiting for Efficient Image Captioning”.

In this paper, the authors observed that conventional early exiting does not work for encoder-decoder networks: at shallow depths, the representations produced by the decoder layers lack high-level semantic information and are therefore insufficient to predict tokens accurately.

That is, let’s say we exit early after n decoder layers. The representation produced by those n layers is not sufficient for accurate prediction; we still need the outputs of the deeper decoder layers after them.

The workaround? Try to replicate the outputs of the deeper decoder layers!

Source: https://openaccess.thecvf.com/content/CVPR2022/papers/Fei_DeeCap_Dynamic_Early_Exiting_for_Efficient_Image_Captioning_CVPR_2022_paper.pdf

First, let us assume that we are going to exit early at layer m. This means that for every layer k with k ≤ m, we have done the forward pass, and for every layer k with k > m, we have to replicate the representations.

This replication is done with the help of imitation learning. Say we want to replicate the representation of layer k, where k > m. The k-th imitation network, paired with the k-th layer, takes the true hidden state h_m as input and outputs the imitated representation:

ĥ_k = MLP_k(h_m)

where MLP_k is a multi-layer perceptron. These MLPs also need to be trained so that they output good imitation representations. This is done using cosine similarity as the main loss function:

L_imitation = Σ (1 − cos(ĥ_k, h_k)), summed over all imitated layers k > m
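Here is a hedged PyTorch sketch of one imitation network and the cosine-similarity training signal; the MLP sizes are my own assumptions:

```python
import torch.nn as nn
import torch.nn.functional as F

class ImitationNetwork(nn.Module):
    """The k-th imitation network: an MLP that maps the last truly computed
    hidden state h_m to an imitated deep state (sizes illustrative)."""
    def __init__(self, d=512, hidden=1024):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, h_m):
        return self.mlp(h_m)   # the imitated representation for layer k

def imitation_loss(h_imitated, h_true):
    """Push each imitated state toward the true deep representation h_k;
    using 1 - cos(imitated, true) is my reading of the cosine loss."""
    return (1.0 - F.cosine_similarity(h_imitated, h_true, dim=-1)).mean()
```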

So now we have a way to approximate the outputs of the deeper decoder layers via imitation learning, and a way to train the MLPs that make it possible.

But we are still a few steps away from achieving early exiting. Now that we have representations for all the decoder layers, we can group them into two kinds: h-shallow, the representations produced by the actual feed-forward pass (layers up to m), and h-deep, the representations produced by the imitation networks (layers beyond m).

Now, we need to aggregate these h-shallow and h-deep representations. This aggregation is done with the help of a fusion strategy. The different fusion strategies we can use are listed below (two of them are sketched in code after the list):

  1. Average Strategy
  2. Concatenation Strategy
  3. Attention-Pooling Strategy
  4. Sequential Strategy
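As promised, here is what the average and attention-pooling strategies might look like over a stack of per-layer hidden states; the parameterization is an assumption on my part:

```python
import torch
import torch.nn as nn

class FusionSketch(nn.Module):
    """Two illustrative fusion strategies over hidden states of shape
    [batch, num_layers, d]."""
    def __init__(self, d=512):
        super().__init__()
        self.score = nn.Linear(d, 1)   # scorer for attention pooling

    def forward(self, hidden_states):
        # Average strategy: a uniform mean over the layers.
        h_avg = hidden_states.mean(dim=1)
        # Attention-pooling strategy: learned per-layer weights.
        weights = torch.softmax(self.score(hidden_states), dim=1)  # [B, L, 1]
        h_attn = (weights * hidden_states).sum(dim=1)              # [B, d]
        return h_avg, h_attn
```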

Now that we have aggregated the h-shallow representations and the h-deep representations separately, we need to know how to merge the two.

This calls for a Gate Decision Mechanism.

So far we know these facts:

  1. The h-deep representations are produced by imitation networks.
  2. The h-shallow representations are produced by the actual feed-forward mechanism.

Since the h-shallow representations come from layers that have actually seen the input data, they are more likely to be reliable. So, we need a balance factor to control how much information we take from h-shallow versus h-deep. This balance factor α is produced by a learned gate that looks at both representations.

With this balancing factor, we can finally represent the final merged information as zₘ = α · h-shallow + (1 − α) · h-deep.

After applying the softmax function to a vocabulary-sized projection of zₘ, we can calculate the output class distribution pₘ.
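Here is a hedged sketch of the full gate decision step; computing α as a sigmoid over the concatenated representations is my reading, not necessarily the paper’s exact parameterization:

```python
import torch
import torch.nn as nn

class GateDecision(nn.Module):
    """Balance the aggregated shallow and imitated-deep representations
    (the gate's exact form is an illustrative assumption)."""
    def __init__(self, d=512, vocab_size=10000):
        super().__init__()
        self.gate = nn.Linear(2 * d, 1)
        self.classifier = nn.Linear(d, vocab_size)

    def forward(self, h_shallow, h_deep):
        # Balance factor alpha computed from both representations.
        alpha = torch.sigmoid(self.gate(torch.cat([h_shallow, h_deep], dim=-1)))
        # Merged information z_m, weighted by alpha.
        z = alpha * h_shallow + (1.0 - alpha) * h_deep
        # Output class distribution p_m.
        p = torch.softmax(self.classifier(z), dim=-1)
        return p, alpha
```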

But we also need to learn the balancing factor dynamically, which contributes its own loss term during training.

Thus, the final training objective combines the per-layer prediction losses with the imitation and balance-factor terms.

That’s a lot of equations!

Now, we can finally talk about training and inference!

Remember that we calculated pₘ previously? By computing the entropy of pₘ as H(pₘ), we measure the prediction confidence for the current token at layer m (low entropy means high confidence). We then compare this entropy against a manually selected threshold τ: once H(pₘ) falls below τ, we can exit!
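In code, the exit check might look like the following; τ = 0.1 is just an arbitrary example value:

```python
import torch

def should_exit(p_m, tau=0.1):
    """Exit at decoder layer m once the entropy H(p_m) of the predicted
    token distribution falls below the threshold tau (low entropy means
    high confidence)."""
    entropy = -(p_m * p_m.clamp_min(1e-9).log()).sum(dim=-1)
    return bool((entropy < tau).all())
```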

Also, note that the shallow layers’ weights will be updated more frequently because they receive more update signals than the deep layers. So, when calculating the cross-entropy loss of each layer, the per-layer losses are reweighted to compensate.
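As a hedged illustration, one common scheme (an assumption here, not necessarily the paper’s exact formula) weights each layer’s loss in proportion to its depth, so deeper layers contribute more:

```python
def reweighted_loss(per_layer_losses):
    """Weight layer m's cross-entropy loss by its depth, so the deeper
    layers (which receive fewer update signals) count for more.
    Depth-proportional weights are an illustrative choice."""
    weights = [m + 1 for m in range(len(per_layer_losses))]
    total = sum(w * loss for w, loss in zip(weights, per_layer_losses))
    return total / sum(weights)
```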

On top of this, the authors argue that updating all parameters at every time step would damage the well-trained features during fine-tuning. So they freeze each layer’s parameters with a probability p that decreases linearly from 1 to 0 with depth.
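A minimal sketch of that freezing schedule, assuming each layer is frozen independently at every training step:

```python
import torch

def apply_random_freezing(layers):
    """Freeze each layer's parameters with a probability that decreases
    linearly from 1 (bottom layer) to 0 (top layer) with depth."""
    n = max(len(layers) - 1, 1)
    for k, layer in enumerate(layers):
        p = 1.0 - k / n                    # freezing probability at depth k
        frozen = torch.rand(1).item() < p
        for param in layer.parameters():
            param.requires_grad = not frozen
```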

And that’s how we add early-exit to encoder-decoder based models!

References:

  1. Fei, Zhengcong, et al. “DeeCap: Dynamic Early Exiting for Efficient Image Captioning.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
  2. Teerapittayanon, Surat, Bradley McDanel, and Hsiang-Tsung Kung. “BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks.” 2016 23rd International Conference on Pattern Recognition (ICPR). IEEE, 2016.