Multimodal Bottleneck Transformer for Depression Recognition

Introduction

Google researchers introduced a transformer-based fusion model named the Multimodal Bottleneck Transformer (MBT) for fusing visual and audio features. Traditional methods concatenate the embedding sequences of two or more modalities along the temporal dimension and attend over the result, which consumes a lot of resources (time, memory, and computation) due to the quadratic complexity of the attention mechanism. MBT instead introduces a small set of special tokens, called bottleneck tokens, which act as intermediaries that transfer information between the two modalities, so neither modality has to attend to the other's full concatenated sequence.
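To make the savings concrete, here is a back-of-the-envelope comparison with made-up sequence lengths: full attention over the concatenated sequence scales with (T_v + T_a)^2 pairwise scores, while bottleneck fusion scales with (T_v + B)^2 + (T_a + B)^2 for B bottleneck tokens.

# Rough attention-cost comparison (pairwise score count per layer, per head).
T_v, T_a, B = 500, 500, 4                      # illustrative lengths and bottleneck size

full_concat = (T_v + T_a) ** 2                 # attend over the concatenated sequence
bottleneck = (T_v + B) ** 2 + (T_a + B) ** 2   # each modality attends to itself + bottleneck

print(full_concat)   # 1000000
print(bottleneck)    # 508032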

MBT explanation from the Google AI Blog

Implementation

My implementation uses the PyTorch framework. Some code is borrowed from The Annotated Transformer, and a few functions are imported from timm.
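The code below assumes the following imports and helpers. clones and LayerNorm follow The Annotated Transformer. The EncoderLayer used later is assumed to be a standard transformer encoder block constructed as EncoderLayer(embed_dim, num_head, feed_forward_dim, dropout); its definition is not reproduced here. get_projection is also not shown in the original, so the version below is a minimal sketch inferred from the two project_type values used later ('minimal' and 'conv1d').

import copy

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

def clones(module, N):
    "Produce N identical (deep-copied) layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class LayerNorm(nn.Module):
    "LayerNorm as defined in The Annotated Transformer."
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

def get_projection(in_dim, out_dim, project_type):
    # Sketch only: 'minimal' -> a single linear layer; 'conv1d' -> a 1x1
    # temporal convolution (applied to (batch, channels, time) tensors).
    if project_type == 'conv1d':
        return nn.Conv1d(in_dim, out_dim, kernel_size=1)
    return nn.Linear(in_dim, out_dim)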

Define an MBT module

The model consists of two projection layers and two groups of encoder layers, one group for each modality, connected by a shared set of bottleneck tokens.

class MBT(nn.Module):
    def __init__(self, v_dim, a_dim, embed_dim, num_bottle_token=4, bottle_layer=1,
                 project_type='minimal', num_head=8, drop=.1, num_layers=4, feat_dim=128):
        super().__init__()
        self.num_layers = num_layers
        self.bottle_layer = bottle_layer          # index of the first fusion layer
        self.num_bottle_token = num_bottle_token
        self.project_type_conv1d = (project_type == 'conv1d')

        # Project each modality into the shared embedding space.
        self.audio_prj = get_projection(a_dim, embed_dim, project_type)
        self.video_prj = get_projection(v_dim, embed_dim, project_type)

        # One encoder stack per modality; clones() deep-copies the layer,
        # so the two stacks do not share weights.
        ff = embed_dim
        layer = EncoderLayer(embed_dim, num_head, ff, drop)
        self.video_layers = clones(layer, num_layers)
        self.audio_layers = clones(layer, num_layers)

        # Always-on mask entries for the CLS and bottleneck positions
        # (stored as parameters so they follow the module across devices).
        self.mask_cls = nn.Parameter(torch.ones(1, 1))
        self.mask_bot = nn.Parameter(torch.ones(1, num_bottle_token))

        # Learnable tokens: shared bottleneck tokens plus one CLS token per modality.
        self.bot_token = nn.Parameter(torch.zeros(1, num_bottle_token, embed_dim))
        self.acls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.vcls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.norma = LayerNorm(layer.size)
        self.normv = LayerNorm(layer.size)
        self.head = nn.Linear(embed_dim, feat_dim)  # projects the fused feature to feat_dim

        trunc_normal_(self.bot_token, std=.02)
        trunc_normal_(self.acls_token, std=.02)
        trunc_normal_(self.vcls_token, std=.02)

Define a forward function

    def forward(self, a, v, m):
        '''
            a : (batch_size, seq_len, a_dim) audio features
            v : (batch_size, seq_len, v_dim) visual features
            m : (batch_size, seq_len) padding mask (1 = valid frame, 0 = padding)
        '''
        B = a.shape[0]
        # Project both modalities into the shared embedding space.
        if self.project_type_conv1d:
            # Conv1d expects (batch, channels, time), so transpose around the projection.
            a = self.audio_prj(a.transpose(1, 2)).transpose(1, 2)
            v = self.video_prj(v.transpose(1, 2)).transpose(1, 2)
        else:
            a = self.audio_prj(a)
            v = self.video_prj(v)

        # Prepend a CLS token to each modality.
        acls_tokens = self.acls_token.expand(B, -1, -1)
        a = torch.cat((acls_tokens, a), dim=1)

        vcls_tokens = self.vcls_token.expand(B, -1, -1)
        v = torch.cat((vcls_tokens, v), dim=1)

        # Extend the padding mask to cover the CLS position.
        mask_cls = self.mask_cls.expand(B, -1)
        mask = torch.cat((mask_cls, m), dim=1)

        # Unimodal stage: the first bottle_layer layers see only their own modality.
        for i in range(self.bottle_layer):
            v = self.video_layers[i](v, mask)
            a = self.audio_layers[i](a, mask)

        # Extend the mask again to cover the bottleneck positions.
        mask_bot = self.mask_bot.expand(B, -1)
        mask = torch.cat((mask_bot, mask), dim=1)
        bot_token = self.bot_token.expand(B, -1, -1)

        # Fusion stage: both modalities attend to the shared bottleneck tokens.
        for i in range(self.bottle_layer, self.num_layers):
            a = torch.cat((bot_token, a), dim=1)
            v = torch.cat((bot_token, v), dim=1)

            v = self.video_layers[i](v, mask)
            a = self.audio_layers[i](a, mask)

            # Average the two modalities' updated bottleneck tokens, then strip
            # them off so the next layer receives the fused version.
            bot_token = (a[:, :self.num_bottle_token] + v[:, :self.num_bottle_token]) / 2
            a = a[:, self.num_bottle_token:]
            v = v[:, self.num_bottle_token:]

        # Fuse the two CLS tokens and project to the output feature dimension
        # (self.head was defined in __init__ for exactly this purpose).
        a = self.norma(a)
        v = self.normv(v)
        feat = (a[:, 0] + v[:, 0]) / 2

        return self.head(feat)
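
A quick smoke test with made-up dimensions, assuming EncoderLayer accepts a (batch, seq) mask as used in the forward above, to check that the shapes line up:

v_dim, a_dim, embed_dim = 256, 128, 64
model = MBT(v_dim, a_dim, embed_dim)

B, seq_len = 2, 50
a = torch.randn(B, seq_len, a_dim)   # audio features
v = torch.randn(B, seq_len, v_dim)   # visual features
m = torch.ones(B, seq_len)           # all frames valid
feat = model(a, v, m)
print(feat.shape)                    # torch.Size([2, 128]), i.e. (B, feat_dim)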