Bi-Directional Cross Attention vs. Standard Attention

TL;DR
The core difference lies in their scope and input sources: 
Multi-Head Self-Attention (MHSA) is a mechanism for a single sequence to contextualize itself, while Bi-Directional Cross Attention (Bi-CA) is a mechanism for two different sequences to exchange information with each other.

Good morning! Welcome to the Machine Learning Knowledge Playground! Let’s make today a GREAT day by learning some machine learning.
Today we’re going to talk about “Bi-directional Cross Attention” – a mechanism first introduced in the paper “xxx”.

Throughout this post, you’ll learn the following:

  1. What is Bi-directional Cross Attention (BDCA)?
  2. How is BDCA different from standard Multi-Head Self-Attention (MHSA)?
  3. What are the pros & cons of BDCA?
  4. When should you use BDCA?
  5. Implementation of BDCA in Python

Alright, buckle up and let’s get started! 🚀

Abkhazia by Daniil Silantev. Source: https://unsplash.com/photos/green-leafed-tree-jrbyQId7KxU

The specific term “Bi-Directional Cross Attention” is most prominently associated with the 2024 paper “Perceiving Longer Sequences With Bi-Directional Cross-Attention Transformers” by Hiller, Ehinger, and Drummond. However, the concept has appeared in other papers with similar names, such as the 2022 “Domain Adaptation via Bidirectional Cross-Attention Transformer” (BCAT) and a 2024 paper on “Figurative Language via Visual Entailment” (FigCLIP). 

Motivation

Limitations of Transformers:

  1. Transformers rely heavily on large amounts of labeled training data, which are difficult to obtain in many real-world applications.

To address this label-scarcity problem, Domain Adaptation (DA) [36] has been proposed to transfer the knowledge learned from a source domain with ample labeled data to a target domain that has only unlabeled data.

The core idea of DA is to learn a domain-invariant feature representation that is both transferable (it narrows the discrepancy between domains) and discriminative (it supports classification in the target domain). Over the past decades, many DA methods have been proposed for this goal. They fall into two main categories:

  1. Domain alignment methods [28,18,2]
  2. Adversarial learning methods [6,31]

Motivation in the original BCAT paper:

Transformers have been adapted to vision in models such as the Vision Transformer (ViT) [4], Data-efficient image Transformers (DeiT) [27], and the Swin Transformer (Swin). Convolutional Neural Networks (CNNs) act on local receptive fields of images. Transformers, by contrast, use the self-attention mechanism to model long-range dependencies among visual features across the whole image. Thanks to these advantages in context modeling, vision transformers have achieved excellent performance on various vision tasks, such as image classification [17,27,9], object detection [1], dense prediction [24], and video understanding [7,21].

Several works [34,33,35,20] apply transformers to DA problems. Some of them [34,20,35] apply vision transformers directly but ignore the specific properties of DA problems.

To make vision transformers more suitable for DA tasks, Xu et al. [33] propose the Cross-Domain Transformer (CDTrans). It is a weight-sharing triple-branch transformer that uses both self-attention and cross-attention for feature learning and domain alignment.
However, CDTrans only considers one-directional cross-attention, from the source domain to the target domain. It ignores cross-attention from the target domain to the source domain.

Furthermore, during training, CDTrans requires each mini-batch to contain source and target images from the same class. This makes it harder to determine accurate pseudo labels for the unlabeled target data, and it restricts the model's applications.

To remedy these limitations, the authors propose a Bidirectional Cross-Attention Transformer (BCAT) for learning domain-invariant feature representations.

In BCAT, bidirectional cross-attention is used to enhance the transferability of vision transformers.
Bidirectional cross-attention naturally fits knowledge transfer between the source and target domains in both directions, and it enables implicit feature mixup between the domains.
BCAT combines bidirectional cross-attention with self-attention in quadruple transformer blocks to learn two augmented feature representations. These blocks can holistically attend to both intra- and inter-domain features and blur the boundary between the two domains.
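To make the "quadruple" idea concrete, here is a minimal NumPy sketch of four attention paths:

* source-to-source and target-to-target self-attention (intra-domain)
* source-queries-target and target-queries-source cross-attention (inter-domain)

This is a hypothetical illustration, not BCAT's actual block. The following are our own assumptions:

* single-head attention
* one set of projection weights shared by all four paths
* the helper names `attend` and `quadruple_attention`

The sketch also does not reproduce how BCAT merges the four outputs into its two augmented representations.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attend(queries_from, keys_values_from, w_q, w_k, w_v):
    """Single-head attention (hypothetical helper): self-attention when both
    inputs are the same array, cross-attention when they come from different domains."""
    q = queries_from @ w_q
    k = keys_values_from @ w_k
    v = keys_values_from @ w_v
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))   # (L_queries, L_keys)
    return weights @ v                                  # (L_queries, d)

def quadruple_attention(src, tgt, weights):
    """Hypothetical sketch: four attention paths that share one set of weights."""
    return {
        "src_self": attend(src, src, *weights),    # intra-domain: source attends to source
        "tgt_self": attend(tgt, tgt, *weights),    # intra-domain: target attends to target
        "src_to_tgt": attend(src, tgt, *weights),  # inter-domain: source queries target
        "tgt_to_src": attend(tgt, src, *weights),  # inter-domain: target queries source
    }

rng = np.random.default_rng(0)
d = 8
src_tokens = rng.standard_normal((5, d))   # e.g. 5 patch tokens from a source image
tgt_tokens = rng.standard_normal((7, d))   # e.g. 7 patch tokens from a target image
shared_weights = tuple(rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
paths = quadruple_attention(src_tokens, tgt_tokens, shared_weights)
for name, feats in paths.items():
    print(name, feats.shape)
# src_self (5, 8) / tgt_self (7, 8) / src_to_tgt (5, 8) / tgt_to_src (7, 8)
```

Each cross-attention output keeps the query side's sequence length. So each domain receives a representation of its own tokens that is enriched with information from the other domain.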

By minimizing the Maximum Mean Discrepancy (MMD) [8] between the learned feature representations of the two domains, BCAT reduces the domain gap and learns domain-invariant feature representations.
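For intuition on the MMD term, here is a minimal NumPy sketch of the biased estimate of squared MMD between two sets of feature vectors. The following are our own assumptions:

* the Gaussian (RBF) kernel
* the bandwidth `sigma=4.0`, roughly the square root of the 16-dimensional toy feature size
* the function names

This post does not state which kernel BCAT uses. The sketch only computes the quantity; it does not train anything with it.

```python
import numpy as np

def gaussian_kernel(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray:
    # Pairwise squared Euclidean distances between rows of x (n, d) and y (m, d)
    sq_dists = (np.sum(x**2, axis=1)[:, None]
                + np.sum(y**2, axis=1)[None, :]
                - 2 * x @ y.T)
    # Clip tiny negative values caused by floating-point round-off
    return np.exp(-np.maximum(sq_dists, 0.0) / (2 * sigma**2))   # (n, m)

def mmd2(source_feats: np.ndarray, target_feats: np.ndarray, sigma: float) -> float:
    """Biased estimate of squared MMD with a Gaussian (RBF) kernel."""
    k_ss = gaussian_kernel(source_feats, source_feats, sigma).mean()
    k_tt = gaussian_kernel(target_feats, target_feats, sigma).mean()
    k_st = gaussian_kernel(source_feats, target_feats, sigma).mean()
    return float(k_ss + k_tt - 2 * k_st)

rng = np.random.default_rng(0)
source = rng.standard_normal((200, 16))                 # toy "source domain" features
target_same = rng.standard_normal((200, 16))            # same distribution as source
target_shifted = rng.standard_normal((200, 16)) + 1.0   # mean-shifted "target domain"
print(mmd2(source, target_same, sigma=4.0))     # close to 0
print(mmd2(source, target_shifted, sigma=4.0))  # clearly larger
```

A small value means the two feature distributions look alike under the kernel. Minimizing it as a training loss pulls source and target features toward each other.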
In summary, the paper's contributions are three-fold:

  1. A quadruple transformer block that combines self-attention and cross-attention to learn augmented feature representations for both the source and target domains.
  2. BCAT, built on the quadruple transformer block, which learns domain-invariant feature representations in the DA setting.
  3. Experiments in which BCAT outperforms state-of-the-art baseline methods on four benchmark datasets.

1. What is Bi-directional Cross Attention (BDCA)?

2. Differences between MHSA and BDCA

Below is an explanation generated by Chrome AI (Gemini):
Multi-Head Self-Attention (MHSA)
  • Function: Allows each element in a single input sequence to attend to all other elements within the same sequence.
    It builds context-aware representations by calculating relationships among tokens from a single source.
  • Input/Output:
    • Queries (Q), Keys (K), and Values (V) all originate from the same input data (e.g., all words in one sentence).
  • “Multi-Head”: 
    • This component of the name refers to performing the attention mechanism multiple times in parallel, each with different learned linear transformations (heads), allowing the model to focus on various aspects and relationships within the data simultaneously.
  • Usage:
    • It is the primary building block of the Transformer encoder and is crucial for tasks like understanding the context of a sentence in models like BERT. 
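To see what "Q, K and V all come from the same input" looks like in code, here is a minimal NumPy sketch of multi-head self-attention. It handles a single example and uses random weights, with no batch axis and no output projection. The function name `multi_head_self_attention` and the toy sizes are our own illustrative choices. The full batched version appears in the implementation section.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, num_heads):
    """x: (L, d_model). Q, K and V are all projected from the SAME sequence x."""
    L, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):
        # (L, d_model) -> (num_heads, L, head_dim)
        return t.reshape(L, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    # Every token attends to every token of the same sequence: (num_heads, L, L)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    out = weights @ v                                   # (num_heads, L, head_dim)
    # Concatenate the heads back into (L, d_model)
    return out.transpose(1, 0, 2).reshape(L, d_model), weights

rng = np.random.default_rng(0)
L, d_model, num_heads = 5, 16, 4
x = rng.standard_normal((L, d_model))
w_q, w_k, w_v = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(3))
out, weights = multi_head_self_attention(x, w_q, w_k, w_v, num_heads)
print(out.shape, weights.shape)  # (5, 16) (4, 5, 5)
```

Each head produces its own square (L, L) attention map over the same sequence. This is exactly the "various aspects and relationships" idea described above.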
Bi-Directional Cross Attention (Bi-CA)
  • Function: 
    • Facilitates information exchange and integration between two different input sequences. “Bi-directional” means that information flows both ways (Sequence A attends to Sequence B, and Sequence B attends to Sequence A), creating mutually enriched feature maps.
  • Input/Output: 
    • Queries (Q) come from one source sequence, while Keys (K) and Values (V) come from a different source sequence. In the “bi-directional” variant, the process is repeated with the roles of the sequences reversed.
  • Usage: 
    • Highly effective in multimodal tasks (e.g., image-text matching, where image features query text features and vice versa) and in question answering, where the importance of elements in one input depends on information from another. Sequence-to-sequence models (e.g., machine translation) also rely on cross attention, although there it usually runs in one direction only: the decoder queries the encoder.
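The two directions can be sketched in a few lines of NumPy. This is a minimal single-head illustration built on our own assumptions:

* each direction gets its own random projection weights
* there is no batch axis
* `cross_attention`, `bidirectional_cross_attention` and `make_params` are hypothetical names

Notice that the attention maps are rectangular, (L_a, L_b) and (L_b, L_a), because the two sequences can have different lengths.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cross_attention(query_seq, context_seq, w_q, w_k, w_v):
    """Queries from query_seq (L_q, d); keys and values from context_seq (L_kv, d)."""
    q = query_seq @ w_q                                # (L_q, d)
    k = context_seq @ w_k                              # (L_kv, d)
    v = context_seq @ w_v                              # (L_kv, d)
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (L_q, L_kv)
    return weights @ v, weights                        # (L_q, d), (L_q, L_kv)

def bidirectional_cross_attention(a, b, params_a_to_b, params_b_to_a):
    # Direction 1: sequence A queries sequence B
    a_enriched, attn_ab = cross_attention(a, b, *params_a_to_b)
    # Direction 2: the roles are reversed, sequence B queries sequence A
    b_enriched, attn_ba = cross_attention(b, a, *params_b_to_a)
    return a_enriched, b_enriched, attn_ab, attn_ba

def make_params(rng, d):
    # One (w_q, w_k, w_v) triple of random projection matrices
    return tuple(rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

rng = np.random.default_rng(0)
d = 8
image_tokens = rng.standard_normal((6, d))   # e.g. 6 image patches
text_tokens = rng.standard_normal((3, d))    # e.g. 3 words
a_out, b_out, attn_ab, attn_ba = bidirectional_cross_attention(
    image_tokens, text_tokens, make_params(rng, d), make_params(rng, d))
print(a_out.shape, b_out.shape, attn_ab.shape, attn_ba.shape)  # (6, 8) (3, 8) (6, 3) (3, 6)
```

Some variants share projection weights between the two directions. Keeping them separate here simply makes the two roles explicit.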

Summary of Differences

| Feature | Multi-Head Self-Attention (MHSA) | Bi-Directional Cross Attention (Bi-CA) |
|---|---|---|
| Input sources | Single sequence | Two different sequences |
| Information flow | Within the same sequence (internal contextualization) | Between two distinct sequences (inter-modal/source fusion) |
| Queries, Keys, Values | All derived from the same input source | Q from one source, K and V from another source (and vice versa for bi-directional) |
| Primary goal | To build a rich, contextualized representation of a single input set | To fuse and refine features across heterogeneous information sources |
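
Both mechanisms are built on the same scaled dot-product attention. They differ only in where Q, K, and V come from:

\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V\]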

where d_k is the dimension of each query/key vector, and softmax converts a vector of logits into a probability distribution (its entries sum to 1.0):

\[\sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}, \qquad \mathbf{z} \in \mathbb{R}^n\]
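import numpy as np

def softmax(arr: np.ndarray) -> np.ndarray:
    # Subtract the max for numerical stability; keepdims=True keeps shape (..., 1) for broadcasting
    exp = np.exp(arr - np.max(arr, axis=-1, keepdims=True))
    return exp / np.sum(exp, axis=-1, keepdims=True)  # same shape as arr, e.g. (B, num_heads, L, L)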
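Why divide by sqrt(d_k)? Suppose the entries of q and k are independent, with zero mean and unit variance. Then the dot product q·k has variance d_k. For large d_k, the logits grow large and push the softmax toward nearly one-hot outputs with very small gradients.

The quick NumPy experiment below checks this on toy random vectors, using a hypothetical helper `dot_product_variances`. The raw variance should grow with d_k, while the scaled variance should stay near 1.

```python
import numpy as np

def dot_product_variances(d_k: int, n_samples: int = 2000, seed: int = 0):
    rng = np.random.default_rng(seed)
    # Queries and keys with independent zero-mean, unit-variance entries
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=-1)   # n_samples raw dot products q.k
    scaled = dots / np.sqrt(d_k)    # apply the 1/sqrt(d_k) factor from the formula
    return dots.var(), scaled.var()

for d_k in (4, 64, 512):
    raw_var, scaled_var = dot_product_variances(d_k)
    # Expect raw_var to grow roughly like d_k while scaled_var stays close to 1
    print(f"d_k={d_k:4d}  var(q.k)={raw_var:7.1f}  var(q.k/sqrt(d_k))={scaled_var:.2f}")
```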

3. Pros & Cons of BDCA

4. When should you use BDCA?

5. Implementation

Now, let’s dive into the most interesting part: coding!
First, let’s review an MHSA implementation in Python (using NumPy only):

# Implement transformer with numpy
import numpy as np
import math

d_model = 64 # embedding dimension (Paper: 512)
num_heads = 8 # number of attention heads (Paper: 8)
dropout_rate = 0.1 # randomly turn off 10% of the neurons during training. This forces the network to learn along different paths (regularization),
                    # which helps the model generalize instead of memorizing particular examples.
batch_size = 30     # Pass multiple examples (sequences) at the same time.
                    # These examples together constitute a "batch".
                    # Faster and more stable training:
                    # the loss is computed and gradients are updated only after seeing all 30 examples
                    #      -> mini-batch gradient descent
max_sequence_length = 200 # Longest sequence that can be passed through the encoder at a time.
                          # If a sentence is shorter than max_sequence_length, we pad it with 0s
                          # (e.g. "My name is Goro" padded to max_len=8).
ffn_hidden = 10   # Feed Forward Network hidden size: expands each token from d_model neurons to ffn_hidden
                  # to "learn additional information" (Paper: 2048)
num_layers = 5    # Number of encoder units in the architecture (Nx; Paper: 6)

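def softmax(arr: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max for numerical stability (the result is unchanged).
    # keepdims=True keeps the denominator shape (B, num_heads, L, 1) so it broadcasts.
    exp = np.exp(arr - np.max(arr, axis=-1, keepdims=True))
    return exp / np.sum(exp, axis=-1, keepdims=True) # (B, num_heads, L, L)

# Test softmax: the probabilities along the last axis should sum to 1
def test_softmax():
    arr = np.array([-1.0, 2.0, 4.0])
    res = softmax(arr)
    # Compare floats with a tolerance instead of exact equality
    assert np.isclose(np.sum(res, axis=-1), 1.0)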


def scaled_dot_product_attention(q, k, v, mask=None) -> np.ndarray:
    """This is the main operation of the attention mechanism.
       A   = softmax(QK^T / sqrt(d_k))
       out = A @ V

        B: batch_size
        N: num_heads
        L_q: query length, L_k: key/value length (both equal L in self-attention)
        d_k: per-head dimension (E // N, where E is the embedding dimension, e.g. 512)
        q: (B, N, L_q, d_k)
        k: (B, N, L_k, d_k)
        v: (B, N, L_k, d_k)
        mask: additive mask broadcastable to (B, N, L_q, L_k); 0 = keep, large negative = block

        out: (B, N, L_q, d_k)
    """
    d_k = q.shape[-1] # NumPy arrays use .shape, NOT .size()
    # Swap the last two axes of k: (B, N, L_k, d_k) -> (B, N, d_k, L_k)
    scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k) # (B, N, L_q, L_k)
    if mask is not None:
        scaled += mask
    attention = softmax(scaled)   # (B, N, L_q, L_k)
    out = np.matmul(attention, v) # (B, N, L_q, d_k)
    return out

# Implement Linear layer with NumPy
class LinearLayer:
    def __init__(self, input_size, output_size):
        # Random weights scaled by 1/sqrt(input_size) keep the output variance ~1; biases start at zero
        self.weights = np.random.randn(input_size, output_size) / np.sqrt(input_size)
        self.bias = np.zeros(output_size)

    def forward(self, input_data):
        # Linear transformation: output = input @ weights + bias (works for (..., input_size) inputs)
        self.input_data = input_data # Store input for backward pass
        return input_data @ self.weights + self.bias

    def backward(self, output_gradient, learning_rate):
        # Flatten leading (batch, sequence) axes so the gradients are plain 2-D matrix products
        x_2d = self.input_data.reshape(-1, self.input_data.shape[-1])     # (B*L, input_size)
        grad_2d = output_gradient.reshape(-1, output_gradient.shape[-1])  # (B*L, output_size)

        # Compute gradients
        weights_gradient = x_2d.T @ grad_2d                # (input_size, output_size)
        bias_gradient = np.sum(grad_2d, axis=0)            # (output_size,)
        input_gradient = output_gradient @ self.weights.T  # same shape as input_data

        # Update weights and biases (plain gradient descent)
        self.weights -= learning_rate * weights_gradient
        self.bias -= learning_rate * bias_gradient

        return input_gradient

class MultiHeadAttention:
    def __init__(self, input_dim, d_model, num_heads):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.input_dim = input_dim
        self.d_model = d_model               # e.g. 512 in the paper (64 here)
        self.num_heads = num_heads           # e.g. 8
        self.head_dim = d_model // num_heads # e.g. 64 in the paper (8 here)
        self.qkv_layer = LinearLayer(input_dim, 3 * d_model)
        self.linear_layer = LinearLayer(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, sequence_length, input_dim = x.shape     # (B, L, input_dim), e.g. 30 x 200 x 64
        qkv = self.qkv_layer.forward(x)  # First linear layer forms q, k, v from input x: (B, L, 3*d_model)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim) # (B, L, N, 3*head_dim)
        qkv = np.transpose(qkv, (0, 2, 1, 3))                # (B, N, L, 3*head_dim)
        q, k, v = np.split(qkv, 3, axis=-1)                  # each (B, N, L, head_dim)
        values = scaled_dot_product_attention(q, k, v, mask) # (B, N, L, head_dim)
        values = np.transpose(values, (0, 2, 1, 3))          # (B, L, N, head_dim)
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim) # Concatenate heads: (B, L, d_model)
        out = self.linear_layer.forward(values)              # (B, L, d_model)
        return out

class LayerNormalization:
    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape = parameters_shape # embedding dimension E (d_model), e.g. 512 in the paper
        self.eps = eps # Avoid division by zero error.
        self.gamma = np.ones(self.parameters_shape) # (E,) learnable scale
        self.beta = np.zeros(self.parameters_shape) # (E,) learnable shift
    
    def forward(self, inputs): # inputs shape: (B, L, E), e.g. 30, 200, 64
        mean = np.mean(inputs, axis=-1, keepdims=True)             # (B, L, 1) - one mean per token
        var = np.mean(((inputs-mean)**2), axis=-1, keepdims=True)  # (B, L, 1) - one variance per token
        std = np.sqrt(var + self.eps)                              # (B, L, 1)
        y = (inputs - mean) / std                                  # (B, L, E) (broadcasting mean & std)
        out = self.gamma * y + self.beta                           # (B, L, E)
        return out



if __name__ == "__main__":
    # Create an encoder object
    # encoder = Encoder(d_model, ffn_hidden, num_heads, dropout_rate, num_layers)
    multi_head_self_attn = MultiHeadAttention(d_model, d_model, num_heads)
    embedding_dim = d_model
    x = np.random.randn(batch_size, max_sequence_length, embedding_dim)
    print(f"Input np array x shape: {x.shape}") # (30, 200, 64)
    out = multi_head_self_attn.forward(x)
    print(f"MultiHead Self-Attention output shape: {out.shape}") # (30, 200, 64)

    # Residual connection + LayerNorm over the embedding dimension (gamma/beta have shape (d_model,))
    layer_norm = LayerNormalization(d_model)
    out_normed = layer_norm.forward(x + out)
    print(f"After layernorm: {out_normed.shape}") # (30, 200, 64)

# Positional Encoding

# Encoder

# Decoder

# Transformer
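The `mask` argument of `scaled_dot_product_attention` is additive: positions that should be ignored get a large negative number before the softmax. Here is a sketch of how a padding mask could be built for sentences padded with 0, as described in the `max_sequence_length` comment. A few notes on the assumptions:

* The token ids are made up.
* `padding_mask` and `masked_attention` are hypothetical helpers.
* We use -1e9 rather than -inf, so a fully padded row cannot produce NaNs.

The mask has shape (B, 1, 1, L) and broadcasts to (B, num_heads, L, L).

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def padding_mask(token_ids: np.ndarray, pad_id: int = 0) -> np.ndarray:
    """Additive mask: 0.0 for real tokens, -1e9 for padded key positions.
    token_ids: (B, L) -> mask: (B, 1, 1, L), broadcastable to (B, num_heads, L, L)."""
    return np.where(token_ids == pad_id, -1e9, 0.0)[:, None, None, :]

def masked_attention(q, k, v, mask):
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])  # (B, N, L, L)
    weights = softmax(scores + mask)  # exp(-1e9) underflows, so padded key columns get 0 weight
    return weights @ v, weights

# "My name is Goro" padded to max_len=8 (made-up token ids; 0 = padding)
token_ids = np.array([[11, 42, 7, 99, 0, 0, 0, 0]])
rng = np.random.default_rng(0)
B, N, L, head_dim = 1, 2, token_ids.shape[1], 4
q, k, v = (rng.standard_normal((B, N, L, head_dim)) for _ in range(3))
out, weights = masked_attention(q, k, v, padding_mask(token_ids))
print(weights[0, 0, 0].round(3))  # the last four weights are 0.0
```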

Now, let’s look at a Bi-directional Cross Attention Block:
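Here is a minimal NumPy sketch of such a block, in the same style as the MHSA code above. Sequence A queries sequence B, and sequence B queries sequence A. Each direction is followed by a residual connection and LayerNorm.

The following are our own assumptions:

* the class names
* the random initialization
* the parameter-free LayerNorm
* the use of two independent cross-attention modules

Published variants differ in details such as whether the two directions share weights. Treat this as an illustration rather than any paper's exact block.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize over the embedding dimension (no learnable gamma/beta in this sketch)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

class MultiHeadCrossAttention:
    """Queries come from one sequence; keys and values come from another."""
    def __init__(self, d_model: int, num_heads: int, rng: np.random.Generator):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads, self.head_dim = num_heads, d_model // num_heads
        scale = 1.0 / np.sqrt(d_model)
        self.w_q, self.w_k, self.w_v, self.w_o = (
            rng.standard_normal((d_model, d_model)) * scale for _ in range(4))

    def _split_heads(self, x: np.ndarray) -> np.ndarray:
        B, L, _ = x.shape  # (B, L, d_model) -> (B, N, L, head_dim)
        return x.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)

    def forward(self, x_query: np.ndarray, x_context: np.ndarray) -> np.ndarray:
        q = self._split_heads(x_query @ self.w_q)      # (B, N, L_q, head_dim)
        k = self._split_heads(x_context @ self.w_k)    # (B, N, L_kv, head_dim)
        v = self._split_heads(x_context @ self.w_v)    # (B, N, L_kv, head_dim)
        scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(self.head_dim)  # (B, N, L_q, L_kv)
        out = softmax(scores) @ v                      # (B, N, L_q, head_dim)
        B, _, L_q, _ = out.shape
        out = out.transpose(0, 2, 1, 3).reshape(B, L_q, -1)  # concatenate heads
        return out @ self.w_o                          # (B, L_q, d_model)

class BiDirectionalCrossAttentionBlock:
    """A attends to B and B attends to A, each followed by residual + LayerNorm."""
    def __init__(self, d_model: int, num_heads: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.a_to_b = MultiHeadCrossAttention(d_model, num_heads, rng)  # A queries B
        self.b_to_a = MultiHeadCrossAttention(d_model, num_heads, rng)  # B queries A

    def forward(self, a: np.ndarray, b: np.ndarray):
        # Both directions read the ORIGINAL a and b, so the update is symmetric
        a_new = layer_norm(a + self.a_to_b.forward(a, b))
        b_new = layer_norm(b + self.b_to_a.forward(b, a))
        return a_new, b_new

block = BiDirectionalCrossAttentionBlock(d_model=64, num_heads=8)
rng = np.random.default_rng(1)
a = rng.standard_normal((2, 10, 64))   # e.g. batch of 2, 10 image-patch tokens each
b = rng.standard_normal((2, 4, 64))    # e.g. batch of 2, 4 text tokens each
a_out, b_out = block.forward(a, b)
print(a_out.shape, b_out.shape)        # (2, 10, 64) (2, 4, 64)
```

Both directions read the original a and b, so the order in which the two updates are computed does not matter. Each sequence keeps its own length while absorbing information from the other.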

TODO: Code Snippet