merge optimized multihead-attention and unoptimized multihead-attention
cwallenwein committed Aug 5, 2024
1 parent da3ae52 commit c3efe9b
Showing 2 changed files with 32 additions and 118 deletions.
145 changes: 29 additions & 116 deletions model/attention.py
@@ -15,49 +15,35 @@
class ScaledDotProductAttention(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()

self.config = config
assert config.n_heads == 1

self.query_weight = nn.Linear(config.d_model, config.d_head, bias=config.attention_bias)
self.key_weight = nn.Linear(config.d_model, config.d_head, bias=config.attention_bias)
self.value_weight = nn.Linear(config.d_model, config.d_head, bias=config.attention_bias)

self._init_weights()

def forward(
self,
x: Float[Tensor, "batch seq_len d_model"],
query: Float[Tensor, "batch n_heads seq_len d_head"],
key: Float[Tensor, "batch n_heads seq_len d_head"],
value: Float[Tensor, "batch n_heads seq_len d_head"],
attention_mask: Bool[Tensor, "batch seq_len"]
):
query = self.query_weight(x)
key = self.key_weight(x)
value = self.value_weight(x)

attention_score = einsum(
key, query,
"batch key_len d_head, batch query_len d_head -> batch query_len key_len"
query, key,
"batch n_heads query_len d_head, batch n_heads key_len d_head -> batch n_heads query_len key_len"
)
mask = torch.where(attention_mask, 0, float("inf"))[:, None, :]
mask = torch.where(attention_mask, 0, float("inf"))
# add dim for broadcasting (n_heads, query_len)
mask = mask[:, None, None, :]

attention_score -= mask
attention_probability = F.softmax(
attention_score / math.sqrt(self.config.d_model),
dim=2
dim=-1
)

# d_model == d_head -> correct output dimension
output = einsum(
attention_probability, value,
"batch query_len key_len, batch seq_len d_head -> batch query_len d_head"
"batch n_heads query_len key_len, batch n_heads seq_len d_head -> batch n_heads query_len d_head"
)

return output

def _init_weights(self):
init_xavier(linear=self.query_weight)
init_xavier(linear=self.key_weight)
init_xavier(linear=self.value_weight)


class MultiHeadAttention(nn.Module):
def __init__(self, config: BertConfig):
@@ -70,73 +56,8 @@ def __init__(self, config: BertConfig):
self.value_weight = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=config.attention_bias)
self.output_weight = nn.Linear(config.d_model, config.d_model, bias=config.attention_bias)

self._init_weights()

def forward(
self,
x: Float[Tensor, "batch seq_len d_model"],
attention_mask: Bool[Tensor, "batch seq_len"]
):
query = rearrange(
self.query_weight(x),
"batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
n_heads=self.config.n_heads
)
key = rearrange(
self.key_weight(x),
"batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
n_heads=self.config.n_heads
)
value = rearrange(
self.value_weight(x),
"batch seq_len (n_heads d_head) -> batch seq_len n_heads d_head",
n_heads=self.config.n_heads
)

attention_score = einsum(
query, key,
"batch query_len n_heads d_head, batch key_len n_heads d_head -> batch n_heads query_len key_len"
)

mask = torch.where(attention_mask, 0, float("inf"))[:, None, None, :]
attention_score -= mask

attention_probability = F.softmax(
attention_score / math.sqrt(self.config.d_head),
dim=3
)

attention_output = einsum(
attention_probability, value,
"batch n_heads query_len key_len, batch key_len n_heads d_head -> batch n_heads query_len d_head"
)

attention_output = rearrange(
attention_output,
"batch n_heads seq_len d_head -> batch seq_len (n_heads d_head)"
)

output = self.output_weight(attention_output)

return output

def _init_weights(self):
init_xavier(linear=self.query_weight)
init_xavier(linear=self.key_weight)
init_xavier(linear=self.value_weight)
init_xavier(linear=self.output_weight)


class MultiHeadAttentionOptimized(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()

self.config = config

self.query_weight = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=config.attention_bias)
self.key_weight = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=config.attention_bias)
self.value_weight = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=config.attention_bias)
self.output_weight = nn.Linear(config.d_model, config.d_model, bias=config.attention_bias)
if self.config.multi_head_attention_implementation == "default":
self.scaled_dot_product_attention = ScaledDotProductAttention(config=config)

self._init_weights()

@@ -161,17 +82,22 @@ def forward(
n_heads=self.config.n_heads
)

attention_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask[:, None, None, :]
)
if self.config.multi_head_attention_implementation == "default":
attention_output = self.scaled_dot_product_attention(
query,
key,
value,
attention_mask
)
elif self.config.multi_head_attention_implementation == "pytorch":
attention_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask[:, None, None, :]
)

attention_output = rearrange(
attention_output,
"batch n_heads seq_len d_head -> batch seq_len (n_heads d_head)"
)
attention_output = rearrange(attention_output, "batch n_heads seq_len d_head -> batch seq_len (n_heads d_head)")

output = self.output_weight(attention_output)

@@ -181,17 +107,4 @@ def _init_weights(self):
init_xavier(linear=self.query_weight)
init_xavier(linear=self.key_weight)
init_xavier(linear=self.value_weight)
init_xavier(linear=self.output_weight)


class MultiHeadAttentionBuilder:
def __init__(self, config: BertConfig):
self.config = config

def build(self):
if self.config.multi_head_attention_implementation == "default":
return MultiHeadAttention(self.config)
elif self.config.multi_head_attention_implementation == "pytorch":
return MultiHeadAttentionOptimized(self.config)
else:
raise ValueError("Unknown MultiHeadAttention implementation")
init_xavier(linear=self.output_weight)
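
A minimal usage sketch of the merged module (not part of the commit): multi_head_attention_implementation selects the kernel, with "default" routing through the hand-written ScaledDotProductAttention and "pytorch" calling F.scaled_dot_product_attention. Assumptions: the repository root is on the import path, the forward(x, attention_mask) signature of the pre-merge classes is unchanged, and Cfg below is a hypothetical stand-in for BertConfig that exposes only the fields the attention code reads.

# Usage sketch, not repository code. Cfg is a hypothetical stand-in for
# BertConfig; only the attributes read by model/attention.py are included.
from dataclasses import dataclass

import torch

from model.attention import MultiHeadAttention


@dataclass
class Cfg:
    d_model: int = 768
    n_heads: int = 12
    d_head: int = 64
    attention_bias: bool = True
    multi_head_attention_implementation: str = "default"


x = torch.randn(2, 16, 768)                           # (batch, seq_len, d_model)
attention_mask = torch.ones(2, 16, dtype=torch.bool)  # True = token is attended to

for impl in ("default", "pytorch"):
    attention = MultiHeadAttention(Cfg(multi_head_attention_implementation=impl))
    output = attention(x, attention_mask)
    print(impl, tuple(output.shape))  # both branches should return (2, 16, 768)

Both branches project with the same nn.Linear layers and only differ in how the scaled dot-product step is computed, so the output shape is identical either way.
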
5 changes: 3 additions & 2 deletions model/bert.py
@@ -1,6 +1,6 @@
from torch import nn
from model.config import BertConfig
from model.attention import MultiHeadAttentionBuilder
from model.attention import MultiHeadAttention
from model.embedding import BertEmbedding
from model.util import init_xavier
from model.gated_linear_unit import GatedLinearUnit2
@@ -11,6 +11,7 @@


class BertModelForPretraining(nn.Module):
# TODO: implement and test flash attention
# TODO: add an option for weight tying to the config
def __init__(self, config: BertConfig):
super().__init__()
@@ -97,7 +98,7 @@ def __init__(self, config: BertConfig):
GatedLinearUnit2(),
nn.Linear(config.feed_forward_intermediate_size // 2, config.d_model, bias=config.feed_forward_bias),
)
self.multi_head_attention = MultiHeadAttentionBuilder(config).build()
self.multi_head_attention = MultiHeadAttention(config)

self.layer_norm1 = nn.LayerNorm(
normalized_shape=config.d_model
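
The bert.py side of the merge is a call-site change: the encoder layer constructs MultiHeadAttention directly instead of going through the removed MultiHeadAttentionBuilder, and the implementation choice lives entirely inside the attention module. A small check of that dispatch, reusing the Cfg stand-in from the sketch above (again an assumption, not repository code):

from model.attention import MultiHeadAttention, ScaledDotProductAttention

# "default" builds the hand-written kernel as a submodule; "pytorch" does not,
# since __init__ only creates ScaledDotProductAttention for the default branch.
default_attention = MultiHeadAttention(Cfg(multi_head_attention_implementation="default"))
pytorch_attention = MultiHeadAttention(Cfg(multi_head_attention_implementation="pytorch"))

print(any(isinstance(m, ScaledDotProductAttention) for m in default_attention.modules()))  # True
print(any(isinstance(m, ScaledDotProductAttention) for m in pytorch_attention.modules()))  # False
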
