[AI Accelerators] Update torch.nn.functional multi_head_attention_forward(). Add fastpath #88912

Closed
wants to merge 22 commits into from
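
For context, a minimal sketch (not taken from the PR) of the call pattern this change targets: batched self-attention through the functional API in inference mode, which, with the eligibility checks added here, may dispatch to the fused torch._native_multi_head_attention kernel. The sizes and variable names below are illustrative.

import torch

src_l, bsz, embed_size, nhead = 3, 5, 8, 2
mha = torch.nn.MultiheadAttention(embed_size, nhead).eval()
query = key = value = torch.rand(src_l, bsz, embed_size)  # (L, N, E), the functional API's layout

with torch.no_grad():  # grad-enabled or training-mode calls fall back to the slow path
    attn_out, attn_weights = torch.nn.functional.multi_head_attention_forward(
        query, key, value, embed_size, nhead,
        mha.in_proj_weight, mha.in_proj_bias,
        mha.bias_k, mha.bias_v, mha.add_zero_attn, mha.dropout,
        mha.out_proj.weight, mha.out_proj.bias,
        training=False)
print(attn_out.shape)  # torch.Size([3, 5, 8])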
61 changes: 61 additions & 0 deletions test/test_jit.py
@@ -15059,6 +15059,67 @@ def jit_multihead_attn_forward(query, # type: Tensor
# print(jit_out / py_out - 1)
self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)

def test_functional_multi_head_attn_fast_path(self):
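# The functional multi_head_attention_forward output should match both the
# nn.MultiheadAttention module and its TorchScript-compiled version.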
src_l = 3
bsz = 5
embed_size = 8
nhead = 2
query = key = value = torch.rand((src_l, bsz, embed_size))
multi_head_attn_nn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
multi_head_attn_nn = multi_head_attn_nn.eval()

with torch.no_grad():
fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
embed_size, nhead,
multi_head_attn_nn.in_proj_weight,
multi_head_attn_nn.in_proj_bias,
multi_head_attn_nn.bias_k,
multi_head_attn_nn.bias_v,
multi_head_attn_nn.add_zero_attn,
multi_head_attn_nn.dropout,
multi_head_attn_nn.out_proj.weight,
None,
training=False)[0]

query_fb = key_fb = value_fb = query.transpose(1, 0)
py_out = multi_head_attn_nn(query_fb, key_fb, value_fb)[0].transpose(1, 0)
mha = torch.jit.script(multi_head_attn_nn)
jit_out = mha(query_fb, key_fb, value_fb)[0].transpose(1, 0)
torch.testing.assert_close(fn_out, py_out)
torch.testing.assert_close(fn_out, jit_out)

def test_functional_multi_head_attn_fast_path_None(self):
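# Same functional call with fixed RNG seeds; only the output shape is compared against the module here.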
random_seed = 1  # any fixed seed works
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

np.random.seed(random_seed)

src_l = 3
bsz = 5
embed_size = 8
nhead = 2
query = key = value = torch.rand((src_l, bsz, embed_size))
multi_head_attn_nn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
multi_head_attn_nn = multi_head_attn_nn.eval()

with torch.no_grad():
fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
embed_size, nhead,
multi_head_attn_nn.in_proj_weight,
multi_head_attn_nn.in_proj_bias,
Contributor


What is this change? Is this to avoid the assertion that you added elsewhere to prevent bias from being None?

Projection bias can be None; see the test and the fix in this PR, per the issue raised in #88970.
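
For reference, a minimal illustrative sketch (not part of the PR) of the None-bias case referred to above: constructing the module with bias=False leaves both in_proj_bias and out_proj.bias as None, and such a call is expected to be accepted, routing to the non-fused implementation under the eligibility checks added here. The bias=False construction and the sizes are illustrative assumptions.

import torch

# Hypothetical repro: bias=False leaves in_proj_bias and out_proj.bias as None.
src_l, bsz, embed_size, nhead = 3, 5, 8, 2
mha = torch.nn.MultiheadAttention(embed_size, nhead, bias=False).eval()
query = key = value = torch.rand(src_l, bsz, embed_size)  # (L, N, E) layout

with torch.no_grad():
    out = torch.nn.functional.multi_head_attention_forward(
        query, key, value, embed_size, nhead,
        mha.in_proj_weight, mha.in_proj_bias,    # in_proj_bias is None here
        mha.bias_k, mha.bias_v, mha.add_zero_attn, mha.dropout,
        mha.out_proj.weight, mha.out_proj.bias,  # out_proj.bias is None here
        training=False)[0]
assert out.shape == (src_l, bsz, embed_size)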

multi_head_attn_nn.bias_k,
multi_head_attn_nn.bias_v,
multi_head_attn_nn.add_zero_attn,
multi_head_attn_nn.dropout,
multi_head_attn_nn.out_proj.weight,
None,
training=False)[0]

query_fb = key_fb = value_fb = query.transpose(1, 0)
py_out = multi_head_attn_nn(query_fb, key_fb, value_fb)[0].transpose(1, 0)
self.assertEqual(fn_out.shape, py_out.shape)

def test_torchscript_multi_head_attn_fast_path(self):
src_l = 3
bsz = 5
238 changes: 235 additions & 3 deletions torch/nn/functional.py
@@ -4864,6 +4864,10 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
# Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.

# Shape check.
assert query.dim() == 3 or query.dim() == 2, \
(f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")

is_batched = False
if query.dim() == 3:
# Batched Inputs
is_batched = True
@@ -4898,9 +4902,6 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
expected_shape = (num_heads, query.shape[0], key.shape[0])
assert attn_mask.shape == expected_shape, \
(f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
else:
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")

return is_batched

@@ -4992,6 +4993,237 @@ def multi_head_attention_forward(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""

why_not_fast_path = _can_use_fastpath(query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training,
key_padding_mask,
attn_mask)
if not why_not_fast_path:
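# torch._native_multi_head_attention expects batch-first (N, L, E) inputs, while this function
# takes (L, N, E); transpose here, preserving aliasing when query/key/value are the same tensor.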

if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

merged_mask, mask_type = _merge_masks(attn_mask, key_padding_mask, query, num_heads)
assert in_proj_weight is not None
assert in_proj_bias is not None
assert out_proj_bias is not None
attn_output, attn_output_weights = torch._native_multi_head_attention(
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
merged_mask,
need_weights,
average_attn_weights,
mask_type)
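# The native op returns batch-first output; transpose attn_output back to (L, N, E).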
return attn_output.transpose(1, 0), attn_output_weights

any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
f"The fast path was not hit because {why_not_fast_path}")

return _multi_head_attention_forward_impl(query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight,
q_proj_weight,
k_proj_weight,
v_proj_weight,
static_k,
static_v,
average_attn_weights)

def _merge_masks(attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
query: Tensor, num_heads: int,) -> Tuple[Optional[Tensor], Optional[int]]:
r"""
Determine mask type and combine masks if necessary. If only one mask is provided, that mask
and the corresponding mask type will be returned. If both masks are provided, they will both be
expanded and combined with logical ``or`` into a merged mask of shape
``(batch_size, num_heads, seq_len, seq_len)``, and mask type 2 will be returned.
Args:
attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
num_heads: number of attention heads
Returns:
merged_mask: merged mask
mask_type: merged mask type (0, 1, or 2)
"""
mask_type: Optional[int] = None
merged_mask: Optional[Tensor] = None
if attn_mask is not None:
mask_type = 0
merged_mask = attn_mask
if key_padding_mask is not None:
mask_type = 1
merged_mask = key_padding_mask
if (attn_mask is not None) and (key_padding_mask is not None):
# In this branch query can't be a nested tensor, so it has a shape
batch_size, seq_len, _ = query.shape
mask_type = 2
key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \
.expand(-1, num_heads, -1, -1)
attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, num_heads, -1, -1)
merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded)
return merged_mask, mask_type

def _can_use_fastpath(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
) -> str:
why_not_fast_path = ''
if torch.overrides.has_torch_function((
query,
key,
value,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,)):
why_not_fast_path = "some Tensor argument has_torch_function"
else:
is_batched = query.dim() == 3
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif in_proj_bias is not None and query.dtype != in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and in_proj_bias ({in_proj_bias.dtype}) don't match"
elif in_proj_weight is not None and query.dtype != in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and in_proj_weight ({in_proj_weight.dtype}) don't match"
elif training:
why_not_fast_path = "training is enabled"
elif bias_k is not None:
why_not_fast_path = "bias_k was not None"
elif bias_v is not None:
why_not_fast_path = "bias_v was not None"
elif dropout_p:
why_not_fast_path = f"dropout was {dropout_p}, required zero"
elif add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
is not supported with NestedTensor input"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"

if not why_not_fast_path:
tensor_args = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
)

# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif in_proj_weight is None:
why_not_fast_path = "in projection weight is None"
elif in_proj_bias is None:
why_not_fast_path = "in projection bias is None"
elif out_proj_bias is None:
why_not_fast_path = "out projecion bias is None"
elif not all([x is not None and (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
return why_not_fast_path

def _multi_head_attention_forward_impl(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Refer to the multi_head_attention_forward documentation. This is the implementation of multi-head attention without the fast path.
"""

tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
if has_torch_function(tens_ops):
return handle_torch_function(