9  Day8 seq2seq with attention

9.1 Attention

  • Attention enhances Seq2Seq models by allowing the decoder to selectively focus on different parts of the input sequence at each decoding step.
  • Instead of relying on a fixed-length context vector from the encoder, the decoder computes a dynamic context vector by weighting the encoder’s hidden states based on their relevance to the current decoding step.

9.1.1 Fixed-length context vector

  • Information bottleneck
    • Long input sequences result in the loss of critical details since all information is compressed into a single vector.
  • Vanishing context
    • For longer sequences, earlier input tokens are increasingly poorly represented in the compressed context vector.
  • Difficulty in aligning input and output
    • There’s no direct alignment between specific input tokens and the corresponding output tokens.
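The bottleneck can be made concrete with a minimal sketch, assuming a toy tanh RNN encoder with random, untrained weights (the function name `encode` and all sizes are illustrative and not part of the course code). Whatever the input length, the decoder would receive a vector of the same size.

```python
import numpy as np

rng = np.random.default_rng(0)
D_in, D_h = 4, 5  # illustrative input and hidden sizes

# Random weights of a toy tanh RNN encoder (hypothetical, untrained)
W_x = rng.normal(size=(D_h, D_in))
W_h = rng.normal(size=(D_h, D_h))

def encode(inputs):
    """Run the toy RNN over the inputs and return only the final hidden state."""
    h = np.zeros(D_h)
    for x in inputs:
        h = np.tanh(W_x @ x + W_h @ h)
    return h

short_context = encode(rng.normal(size=(3, D_in)))    # 3 input steps
long_context = encode(rng.normal(size=(300, D_in)))   # 300 input steps

# Both sequences are compressed into a vector of the same size
print(short_context.shape, long_context.shape)
```

A 300-step input and a 3-step input end up in the same five numbers here, which is the situation attention is meant to relieve.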

9.1.2 Attention Mechanism

The attention mechanism addresses these issues by dynamically focusing on different parts of the input sequence for each output token.

  • Dynamic Context
    • Instead of relying on a single fixed-length vector, the context vector is dynamically computed by “attending” to specific encoder hidden states.
  • Better Handling of Long Sequences
    • Attention allows the decoder to retrieve relevant information from any part of the input, regardless of sequence length.
  • Alignment Between Input and Output
    • Attention naturally creates alignments, making it easier to map input tokens to corresponding output tokens.

9.1.3 Example

<SOS>MAGKL...<EOS>
<SOS>MIAEE...<EOS>
<SOS>MNNQK...<EOS>
<SOS>MFHAE...<EOS>
...
  • Attention weight = how strongly each encoder hidden state correlates with the current decoder state (at the first step, the state after reading <SOS>)
  • Context vector (attention value) = sum of the encoder hidden states, each multiplied by its attention weight
  • Final decoder representation = concatenation of the context vector and the decoder hidden state
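A minimal numeric sketch of these three quantities, assuming two encoder hidden states and made-up values chosen so that the arithmetic is easy to follow (all numbers and variable names are illustrative):

```python
import numpy as np

# Two encoder hidden states (rows) and one decoder hidden state; values are made up
encoder_states = np.array([[1.0, 0.0],
                           [0.0, 1.0]])
decoder_state = np.array([2.0, 0.0])

# Attention weights: softmax over the dot-product scores
scores = encoder_states @ decoder_state           # one score per encoder state
weights = np.exp(scores) / np.exp(scores).sum()   # non-negative, sums to 1

# Context vector: weighted sum of the encoder hidden states
context = weights @ encoder_states

# Representation used for prediction: context vector concatenated with the decoder state
combined = np.concatenate([context, decoder_state])

print("Attention weights:", weights)
print("Context vector:", context)
print("Combined shape:", combined.shape)
```

The first encoder state points in the same direction as the decoder state, so it receives the larger weight and dominates the context vector.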

9.1.4 Attention Mechanism Steps (Dot product attention)

  • Alignment Scores
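    • Compute alignment scores between the decoder’s most recent hidden state and each encoder hidden state.
    • \(h_i\) is the \(i\)-th encoder hidden state, \(s_{t-1}\) is the decoder’s hidden state from the previous step, and \(W\) is a learnable weight matrix; with \(W = I\) the score reduces to the plain dot product \(h_i^\top s_{t-1}\). \[ e_{t,i} = h_i^\top W s_{t-1} \]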
  • Attention Weights
    • Normalize alignment scores with softmax to get attention weights. \[ \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{T_x} \exp(e_{t,j})} \]
  • Context Vector
    • Compute the context vector as the weighted sum of encoder hidden states. \[ c_t = \sum_{i=1}^{T_x} \alpha_{t,i} h_i \]
  • Combine Context and Decoder State
    • Use the context vector \(c_t\) and the decoder’s current hidden state to predict the next token.

import numpy as np

# Example encoder outputs (keys and values)
L = 4  # Sequence length
D_h = 5  # Hidden dimension

# Simulated encoder hidden states (Keys and Values)
encoder_outputs = np.random.randn(L, D_h)  # Shape: [L, D_h]

# Example decoder hidden state (Query)
decoder_hidden = np.random.randn(D_h)  # Shape: [D_h]

# Compute dot-product scores
# Query (D_h,) @ Keys (L, D_h)^T -> (L,)
attention_scores = np.dot(decoder_hidden, encoder_outputs.T)  # Shape: [L]

# Scale scores by sqrt of hidden dimension
scaled_attention_scores = attention_scores / np.sqrt(D_h)  # Shape: [L]

# Compute Attention Weights
# Apply softmax to scaled scores
attention_weights = np.exp(scaled_attention_scores - np.max(scaled_attention_scores))  # Subtract max to stabilize softmax
attention_weights /= np.sum(attention_weights, axis=-1, keepdims=True)  # Normalize, Shape: [L]

# Compute Context Vector
# Attention Weights (L,) @ Values (L, D_h) -> (D_h,)
context_vector = np.dot(attention_weights, encoder_outputs)  # Shape: [D_h]

# Output
print("Encoder Outputs (Keys and Values):", encoder_outputs)
print("Decoder Hidden State (Query):", decoder_hidden)
print("Attention Scores:", attention_scores)
print("Scaled Attention Scores:", scaled_attention_scores)
print("Attention Weights (Softmax):", attention_weights)
print("Context Vector:", context_vector)
Encoder Outputs (Keys and Values): [[ 0.19573384 -0.49648183  0.1824915   0.10484739 -1.58569094]
 [ 0.64890387  0.0082949   1.41048922 -0.47920967 -0.7364587 ]
 [ 0.78422808 -0.44381825  0.67551003  1.23725628  0.2371081 ]
 [ 0.51214269  0.40668439 -0.18241258  0.28852113  0.2699827 ]]
Decoder Hidden State (Query): [ 0.40266189 -0.79767357  0.01391088  1.08446925  0.81412344]
Attention Scores: [-0.69986078 -0.84496306  2.21399944  0.41197298]
Scaled Attention Scores: [-0.31298726 -0.37787897  0.99013065  0.18423992]
Attention Weights (Softmax): [0.1377016  0.12904966 0.50684584 0.2264029 ]
Context Vector: [ 0.62412702 -0.20016879  0.50823324  0.64501603 -0.13208979]
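The score formula in the steps above includes a weight matrix \(W\), whereas the NumPy example uses a scaled dot product without one. The following is a minimal sketch of the \(W\)-based score, assuming a random matrix in place of trained parameters (the helper name `attention` and the sizes are illustrative). Passing the identity matrix recovers the plain dot-product score.

```python
import numpy as np

rng = np.random.default_rng(0)
T_x, D_h = 4, 5  # illustrative sequence length and hidden size

H = rng.normal(size=(T_x, D_h))   # encoder hidden states h_i, one per row
s_prev = rng.normal(size=D_h)     # decoder hidden state s_{t-1}
W = rng.normal(size=(D_h, D_h))   # learnable in practice; random here (assumption)

def attention(H, s, W):
    """Return attention weights and context vector for scores e_i = h_i^T W s."""
    e = H @ W @ s                  # alignment scores, shape [T_x]
    alpha = np.exp(e - e.max())    # numerically stable softmax
    alpha /= alpha.sum()
    c = alpha @ H                  # context vector, shape [D_h]
    return alpha, c

alpha, c = attention(H, s_prev, W)
alpha_dot, c_dot = attention(H, s_prev, np.eye(D_h))  # W = I: plain dot-product scores

print("Weights with W:", alpha)
print("Weights with W = I:", alpha_dot)
print("Context vector shape:", c.shape)
```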

9.1.5 Simulation data

  • For each source sequence, insert two motifs (RGD and ATL) into a random amino acid sequence
  • Target k-mers are generated by substitution: R -> K and G -> A for the RGD motif, A -> V and T -> C for the ATL motif; all other k-mers are reversed
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Parameters
k = 3  # Length of k-mers
num_sequences = 100  # Number of full amino acid sequences
sequence_length = 20  # Length of each amino acid sequence (increased to accommodate two motifs)
batch_size = 10  # Batch size for DataLoader

# Define amino acids
amino_acids = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]

# Generate random amino acid sequences
def generate_random_sequences_with_motifs(num_sequences, sequence_length):
    sequences = []
    for _ in range(num_sequences):
        seq = list(np.random.choice(amino_acids, size=sequence_length))
        # Insert biologically meaningful patterns (e.g., conserved motifs)
        if len(seq) > 10:
            seq[5:8] = ["R", "G", "D"]  # Motif 1
            seq[12:15] = ["A", "T", "L"]  # Motif 2
        sequences.append("".join(seq))
    return sequences

random_sequences = generate_random_sequences_with_motifs(num_sequences, sequence_length)

# Extract k-mers from sequences
def generate_kmers(sequence, k):
    return [sequence[i:i + k] for i in range(len(sequence) - k + 1)]

source_kmers = [generate_kmers(seq, k) for seq in random_sequences]

# Define a biological transformation for the target sequences
def transform_with_evolution_two_motifs(kmers):
    transformed_kmers = []
    for kmer in kmers:
        if "RGD" in kmer:
            transformed_kmers.append(kmer.replace("R", "K").replace("G", "A"))  # Transformation for Motif 1
        elif "ATL" in kmer:
            transformed_kmers.append(kmer.replace("A", "V").replace("T", "C"))  # Transformation for Motif 2
        else:
            transformed_kmers.append(kmer[::-1])  # Reverse non-conserved k-mers
    return transformed_kmers

target_kmers_with_evolution = [transform_with_evolution_two_motifs(kmers) for kmers in source_kmers]

# Build vocabulary for k-mers
unique_kmers = set(kmer for seq in source_kmers + target_kmers_with_evolution for kmer in seq)
vocab = {kmer: idx for idx, kmer in enumerate(unique_kmers)}
vocab_size = len(vocab)

# Encode k-mers into indices
def encode_kmers(kmers, vocab):
    return [vocab[kmer] for kmer in kmers]

encoded_source_sequences = [encode_kmers(kmers, vocab) for kmers in source_kmers]
encoded_target_with_evolution = [encode_kmers(kmers, vocab) for kmers in target_kmers_with_evolution]

# Collate function for padding
def collate_batch(batch):
    source_batch, target_batch = zip(*batch)
    src_lengths = [len(seq) for seq in source_batch]
    tgt_lengths = [len(seq) for seq in target_batch]

    max_src_len = max(src_lengths)
    max_tgt_len = max(tgt_lengths)

    padded_src = torch.zeros(len(source_batch), max_src_len, dtype=torch.long)
    padded_tgt = torch.zeros(len(target_batch), max_tgt_len, dtype=torch.long)

    for i, seq in enumerate(source_batch):
        padded_src[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    for i, seq in enumerate(target_batch):
        padded_tgt[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)

    return padded_src, src_lengths, padded_tgt, tgt_lengths

# Create datasets and dataloaders
class Seq2SeqKmerDataset(Dataset):
    def __init__(self, source_sequences, target_sequences):
        self.source_sequences = source_sequences
        self.target_sequences = target_sequences

    def __len__(self):
        return len(self.source_sequences)

    def __getitem__(self, idx):
        return self.source_sequences[idx], self.target_sequences[idx]

# Dataset
dataset_with_evolution = Seq2SeqKmerDataset(encoded_source_sequences, encoded_target_with_evolution)

# Dataloader
dataloader_with_evolution = DataLoader(dataset_with_evolution, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

# Example batch
for batch in dataloader_with_evolution:
    src, src_lengths, tgt, tgt_lengths = batch
    print("Source batch shape:", src.shape)
    print("Target batch shape:", tgt.shape)
    print("Source lengths:", src_lengths)
    print("Target lengths:", tgt_lengths)
    break

print("Vocabulary size:", vocab_size)
print("Initial sequence:", random_sequences[0])
print("Source k-mers:", source_kmers[0])
print("Target k-mers:", target_kmers_with_evolution[0])
Source batch shape: torch.Size([10, 18])
Target batch shape: torch.Size([10, 18])
Source lengths: [18, 18, 18, 18, 18, 18, 18, 18, 18, 18]
Target lengths: [18, 18, 18, 18, 18, 18, 18, 18, 18, 18]
Vocabulary size: 2105
Initial sequence: PQFSNRGDTVFRATLDASTH
Source k-mers: ['PQF', 'QFS', 'FSN', 'SNR', 'NRG', 'RGD', 'GDT', 'DTV', 'TVF', 'VFR', 'FRA', 'RAT', 'ATL', 'TLD', 'LDA', 'DAS', 'AST', 'STH']
Target k-mers: ['FQP', 'SFQ', 'NSF', 'RNS', 'GRN', 'KAD', 'TDG', 'VTD', 'FVT', 'RFV', 'ARF', 'TAR', 'VCL', 'DLT', 'ADL', 'SAD', 'TSA', 'HTS']
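The vocabulary above contains only k-mers, so index 0 is an ordinary k-mer even though `collate_batch` also uses 0 for padding. One possible variant, assuming reserved `<PAD>`, `<SOS>` and `<EOS>` tokens, is sketched below on toy data. The names `SPECIAL_TOKENS`, `build_vocab` and `encode_with_specials` are hypothetical and not part of the course code.

```python
SPECIAL_TOKENS = ["<PAD>", "<SOS>", "<EOS>"]  # assumed reserved tokens; <PAD> gets index 0

def build_vocab(kmer_sequences):
    """Map special tokens to the first indices and k-mers to the following ones."""
    # Sorting makes the index assignment reproducible across runs
    unique_kmers = sorted({kmer for seq in kmer_sequences for kmer in seq})
    return {token: idx for idx, token in enumerate(SPECIAL_TOKENS + unique_kmers)}

def encode_with_specials(kmers, vocab):
    """Encode a k-mer sequence and wrap it in <SOS> ... <EOS>."""
    return [vocab["<SOS>"]] + [vocab[k] for k in kmers] + [vocab["<EOS>"]]

toy_kmers = [["RGD", "GDT"], ["ATL", "TLD"]]  # toy data for illustration only
toy_vocab = build_vocab(toy_kmers)
encoded = encode_with_specials(toy_kmers[0], toy_vocab)

print(toy_vocab)
print(encoded)
```

With such a layout, padding positions can be excluded from the loss through `ignore_index=0` without discarding a real k-mer, and the decoder can start from an explicit `<SOS>` index.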
import torch
import torch.nn as nn
import torch.nn.functional as F


class DotProductAttention(nn.Module):
    def __init__(self):
        super(DotProductAttention, self).__init__()

    def forward(self, hidden, encoder_outputs):
        # Ensure hidden is [batch_size, hidden_dim * 2]
        if hidden.dim() == 2:  # [batch_size, hidden_dim * 2]
            hidden = hidden.unsqueeze(1)  # [batch_size, 1, hidden_dim * 2]

        # Compute dot-product scores
        scores = torch.bmm(encoder_outputs, hidden.transpose(1, 2)).squeeze(2)  # [batch_size, seq_len]

        # Compute attention weights
        attention_weights = F.softmax(scores, dim=1)  # [batch_size, seq_len]

        # Compute context vector as weighted sum of encoder outputs
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_dim * 2]

        return context_vector.squeeze(1), attention_weights  # [batch_size, hidden_dim * 2], [batch_size, seq_len]


class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Seq2SeqEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, x, lengths):
        # Embed the input sequences
        embedded = self.embedding(x)  # [batch_size, seq_len, embedding_dim]

        # Pack padded sequences for LSTM
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)

        # Bidirectional LSTM
        packed_output, (hidden, cell) = self.lstm(packed_embedded)

        # Unpack LSTM outputs
        encoder_outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        # Concatenate hidden states from both directions
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)  # [batch_size, hidden_dim * 2]

        return encoder_outputs, hidden

class Seq2SeqDecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Seq2SeqDecoderWithAttention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # The decoder state must have the same size as the bidirectional encoder
        # outputs (hidden_dim * 2); otherwise the dot-product scores are undefined
        self.lstm = nn.LSTM(embedding_dim + hidden_dim * 2, hidden_dim * 2, batch_first=True)
        self.fc = nn.Linear(hidden_dim * 4, vocab_size)  # LSTM output + context vector
        self.attention = DotProductAttention()

    def forward(self, input_token, hidden, cell, encoder_outputs):
        # hidden, cell: [batch_size, hidden_dim * 2]
        # Embed the input token
        embedded = self.embedding(input_token).unsqueeze(1)  # [batch_size, 1, embedding_dim]

        # Compute attention context vector from the current decoder state
        context_vector, attention_weights = self.attention(hidden, encoder_outputs)

        # Concatenate embedded input and context vector
        lstm_input = torch.cat((embedded, context_vector.unsqueeze(1)), dim=2)  # [batch_size, 1, embedding_dim + hidden_dim * 2]

        # Pass through the LSTM, carrying the decoder state from the previous step
        output, (hidden, cell) = self.lstm(lstm_input, (hidden.unsqueeze(0), cell.unsqueeze(0)))
        hidden, cell = hidden.squeeze(0), cell.squeeze(0)  # [batch_size, hidden_dim * 2]

        # Predict the next token
        prediction = self.fc(torch.cat((output.squeeze(1), context_vector), dim=1))  # [batch_size, vocab_size]

        return prediction, hidden, cell, attention_weights



class Seq2SeqWithAttention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = Seq2SeqEncoder(vocab_size, embedding_dim, hidden_dim)
        self.decoder = Seq2SeqDecoderWithAttention(vocab_size, embedding_dim, hidden_dim)

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        encoder_outputs, hidden = self.encoder(src, src_lengths)
        cell = torch.zeros_like(hidden)  # The encoder returns no cell state, so start from zeros

        batch_size = tgt.size(0)
        tgt_len = tgt.size(1)
        vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=src.device)
        attentions = []  # Attention weights of every decoding step

        # The first target k-mer serves as the start token (the vocabulary has no dedicated <SOS>)
        input_token = tgt[:, 0]

        for t in range(1, tgt_len):
            output, hidden, cell, attention_weights = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            attentions.append(attention_weights)
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            input_token = tgt[:, t] if teacher_force else output.argmax(1)

        # [batch_size, tgt_len - 1, src_len]
        return outputs, torch.stack(attentions, dim=1)

# Model parameters
embedding_dim = 16
hidden_dim = 32
learning_rate = 0.001
epochs = 10

# Initialize model, optimizer, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2SeqWithAttention(vocab_size, embedding_dim, hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# All sequences have the same length here, so no padding occurs; index 0 is a
# real k-mer in this vocabulary and must not be ignored by the loss
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0

    for batch in dataloader_with_evolution:
        src, src_lengths, tgt, tgt_lengths = batch
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()
        outputs, _ = model(src, src_lengths, tgt)

        # Reshape outputs and compute loss
        outputs = outputs[:, 1:].reshape(-1, vocab_size)  # Skip position 0 (start input, never predicted)
        tgt = tgt[:, 1:].reshape(-1)  # Skip position 0 (start input, never predicted)
        loss = criterion(outputs, tgt)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss:.4f}")
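Dot-product attention only works when the decoder state (the query) and the encoder outputs (the keys) share the same last dimension, so a decoder paired with a bidirectional encoder needs a state of size `hidden_dim * 2`. A minimal shape check, assuming random tensors in place of real activations and a hypothetical helper `dot_scores`:

```python
import torch

batch_size, seq_len, hidden_dim = 10, 18, 32

# A bidirectional encoder yields outputs of size hidden_dim * 2 per position
encoder_outputs = torch.randn(batch_size, seq_len, hidden_dim * 2)

def dot_scores(query, keys):
    """Dot-product scores between one query per batch item and all keys."""
    return torch.bmm(keys, query.unsqueeze(2)).squeeze(2)  # [batch_size, seq_len]

# Query of size hidden_dim * 2: shapes match
good_query = torch.randn(batch_size, hidden_dim * 2)
good_scores = dot_scores(good_query, encoder_outputs)
print("Scores shape:", good_scores.shape)

# Query of size hidden_dim: torch.bmm rejects the mismatched shapes
bad_query = torch.randn(batch_size, hidden_dim)
try:
    dot_scores(bad_query, encoder_outputs)
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
print("Mismatch detected:", mismatch_detected)
```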
import matplotlib.pyplot as plt
import numpy as np

def plot_attention(attention_weights, input_sequence, output_sequence):
    # Rows are decoding steps, columns are input positions
    weights = np.atleast_2d(attention_weights.cpu().detach().numpy())
    # The first target token is only fed as the start input, so keep just the
    # labels of the steps that were actually decoded
    output_sequence = output_sequence[-weights.shape[0]:]
    plt.figure(figsize=(10, 8))
    plt.imshow(weights, cmap="viridis", aspect="auto")
    plt.colorbar()
    plt.xticks(range(len(input_sequence)), input_sequence, rotation=90)
    plt.yticks(range(len(output_sequence)), output_sequence)
    plt.title("Attention Weights")
    plt.xlabel("Input Sequence")
    plt.ylabel("Output Sequence")
    plt.show()

# Example visualization during inference
idx_to_kmer = {idx: kmer for kmer, idx in vocab.items()}  # Reverse lookup: index -> k-mer

src_example, src_len_example, tgt_example, tgt_len_example = next(iter(dataloader_with_evolution))
src_example = src_example.to(device)
tgt_example = tgt_example.to(device)

model.eval()
with torch.no_grad():
    outputs, attention_weights = model(src_example, src_len_example, tgt_example, teacher_forcing_ratio=0)

plot_attention(
    attention_weights[0],  # Attention weights for the first sequence in the batch
    [idx_to_kmer[idx] for idx in src_example[0].tolist()],
    [idx_to_kmer[idx] for idx in tgt_example[0].tolist()],
)
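An attention heatmap can also be read numerically by taking, for each output step, the input position with the largest weight. The sketch below assumes a small hand-written attention matrix and illustrative k-mer labels rather than weights from a trained model.

```python
import numpy as np

# Hypothetical attention matrix: rows = output steps, columns = input positions
attention = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
])
input_tokens = ["SNR", "NRG", "RGD"]    # illustrative source k-mers
output_tokens = ["RNS", "GRN", "KAD"]   # illustrative target k-mers

# For each output step, the input position that receives the largest weight
aligned = attention.argmax(axis=1)

for row, (out_tok, idx) in enumerate(zip(output_tokens, aligned)):
    print(f"{out_tok} <- {input_tokens[idx]} (weight {attention[row, idx]:.2f})")
```

A mostly diagonal pattern like this one would suggest that each output k-mer is predicted mainly from the input k-mer at the same position.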

9.2 Transformer