[AI Accelerators] Update torch.nn.functional multi_head_attention_forward(). Add pastpath #88912

Closed. Wants to merge 22 commits (changes shown from 1 commit).
4 changes: 2 additions & 2 deletions test/test_jit.py
@@ -15106,8 +15106,8 @@ def test_functional_multi_head_attn_fast_path_None(self):
         with torch.no_grad():
             fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
                                                                        embed_size, nhead,
-                                                                       None,
-                                                                       None,
+                                                                       multi_head_attn_nn.in_proj_weight,
+                                                                       multi_head_attn_nn.in_proj_bias,
Review comment (Contributor): What is this change? Is this to avoid the assertion that you added elsewhere to prevent bias from being None?

Reply: Projection bias can be None; see the test and fix in this PR, as per the issue raised in #88970.

                                                                        multi_head_attn_nn.bias_k,
                                                                        multi_head_attn_nn.bias_v,
                                                                        multi_head_attn_nn.add_zero_attn,
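For context on the exchange above, here is a minimal, hedged sketch of calling torch.nn.functional.multi_head_attention_forward with a None projection bias, the case the test name refers to. The shapes and the choice to also pass a None out_proj_bias are illustrative assumptions, not the exact test from this PR.

```python
import torch
import torch.nn.functional as F

embed_dim, num_heads, seq_len, batch = 8, 2, 4, 3
query = key = value = torch.rand(seq_len, batch, embed_dim)   # (L, N, E)

in_proj_weight = torch.nn.init.xavier_uniform_(torch.empty(3 * embed_dim, embed_dim))
out_proj_weight = torch.nn.init.xavier_uniform_(torch.empty(embed_dim, embed_dim))

with torch.no_grad():
    attn_out, attn_weights = F.multi_head_attention_forward(
        query, key, value, embed_dim, num_heads,
        in_proj_weight, None,      # in_proj_bias left as None (see #88970)
        None, None,                # bias_k, bias_v
        False, 0.0,                # add_zero_attn, dropout_p
        out_proj_weight, None)     # out_proj_bias also left as None
print(attn_out.shape)              # torch.Size([4, 3, 8])
```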
41 changes: 18 additions & 23 deletions torch/nn/functional.py
@@ -2524,8 +2524,6 @@ def group_norm(
"""
if has_torch_function_variadic(input, weight, bias):
return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps)
if input.dim() < 2:
raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}")
_verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))
return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
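As a hedged usage note on the hunk above: even without the explicit dimensionality check, F.group_norm still expects an input of at least (N, C, ...) with num_groups dividing C. The sketch below only illustrates the expected call shape.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 6, 4, 4)                 # (N, C, H, W) with C = 6
out = F.group_norm(x, num_groups=3,         # 3 groups of 2 channels each
                   weight=torch.ones(6), bias=torch.zeros(6), eps=1e-5)
print(out.shape)                            # torch.Size([2, 6, 4, 4])
```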

@@ -4993,30 +4991,20 @@ def multi_head_attention_forward(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
-    with torch.no_grad():
-        in_proj_weight_ = torch.nn.init.xavier_uniform_(torch.empty((3 * embed_dim_to_check, embed_dim_to_check),\
-                                                        device=query.device, dtype=query.dtype))\
-            if in_proj_weight is None or in_proj_weight.dim() != 2 else in_proj_weight
-    with torch.no_grad():
-        in_proj_bias_ = torch.zeros(3 * embed_dim_to_check, device=query.device, dtype=query.dtype) \
-            if in_proj_bias is None else in_proj_bias

-    with torch.no_grad():
-        out_proj_bias_ = torch.zeros(embed_dim_to_check, device=query.device, dtype=query.dtype) \
-            if out_proj_bias is None else out_proj_bias
     why_not_fast_path = _can_use_fastpath(query,
                                           key,
                                           value,
                                           embed_dim_to_check,
                                           num_heads,
-                                          in_proj_weight_,
-                                          in_proj_bias_,
+                                          in_proj_weight,
+                                          in_proj_bias,
                                           bias_k,
                                           bias_v,
                                           add_zero_attn,
                                           dropout_p,
                                           out_proj_weight,
-                                          out_proj_bias_,
+                                          out_proj_bias,
                                           training,
                                           key_padding_mask,
                                           attn_mask)
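The hunk above deletes the in-function defaulting of missing projection tensors. A caller that still wants dense projections could build them up front; the helper below is only a sketch of what the removed block used to do, and the name default_projections is made up for illustration.

```python
import torch

def default_projections(query, embed_dim, in_proj_weight=None,
                        in_proj_bias=None, out_proj_bias=None):
    # Mirrors the removed defaulting: Xavier weights, zero biases, on the
    # query's device and dtype. Hypothetical helper, not part of torch.
    with torch.no_grad():
        if in_proj_weight is None or in_proj_weight.dim() != 2:
            in_proj_weight = torch.nn.init.xavier_uniform_(
                torch.empty(3 * embed_dim, embed_dim,
                            device=query.device, dtype=query.dtype))
        if in_proj_bias is None:
            in_proj_bias = torch.zeros(3 * embed_dim,
                                       device=query.device, dtype=query.dtype)
        if out_proj_bias is None:
            out_proj_bias = torch.zeros(embed_dim,
                                        device=query.device, dtype=query.dtype)
    return in_proj_weight, in_proj_bias, out_proj_bias
```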
@@ -5038,10 +5026,10 @@ def multi_head_attention_forward(
             value,
             embed_dim_to_check,
             num_heads,
-            in_proj_weight_,
-            in_proj_bias_,
+            in_proj_weight,
+            in_proj_bias,
             out_proj_weight,
-            out_proj_bias_,
+            out_proj_bias,
             merged_mask,
             need_weights,
             average_attn_weights,
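merged_mask above stands for a key_padding_mask and an attn_mask combined into a single per-head mask. The function below is only an illustrative sketch of that idea (it assumes a float attn_mask and a bool key_padding_mask); it is not the library's merge_masks implementation.

```python
import torch

def merge_masks_sketch(attn_mask: torch.Tensor,         # (L, S), float (0 or -inf)
                       key_padding_mask: torch.Tensor,  # (N, S), bool, True = ignore
                       num_heads: int) -> torch.Tensor:
    N, S = key_padding_mask.shape
    L = attn_mask.shape[0]
    merged = attn_mask.view(1, 1, L, S).expand(N, num_heads, L, S).clone()
    return merged.masked_fill(key_padding_mask.view(N, 1, 1, S), float("-inf"))
```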
@@ -5116,14 +5104,14 @@ def _can_use_fastpath(
     value: Tensor,
     embed_dim_to_check: int,
     num_heads: int,
-    in_proj_weight: Tensor,
-    in_proj_bias: Tensor,
+    in_proj_weight: Optional[Tensor],
+    in_proj_bias: Optional[Tensor],
     bias_k: Optional[Tensor],
     bias_v: Optional[Tensor],
     add_zero_attn: bool,
     dropout_p: float,
     out_proj_weight: Tensor,
-    out_proj_bias: Tensor,
+    out_proj_bias: Optional[Tensor],
     training: bool = True,
     key_padding_mask: Optional[Tensor] = None,
     attn_mask: Optional[Tensor] = None,
Expand Down Expand Up @@ -5183,13 +5171,20 @@ def _can_use_fastpath(
         out_proj_weight,
         out_proj_bias,
     )

     # We have to use list comprehensions below because TorchScript does not support
     # generator expressions.
     if torch.overrides.has_torch_function(tensor_args):
         why_not_fast_path = "some Tensor argument has_torch_function"
-    elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
+    elif in_proj_weight is None:
+        why_not_fast_path = "in projection weight is None"
+    elif in_proj_bias is None:
+        why_not_fast_path = "in projection bias is None"
+    elif out_proj_bias is None:
+        why_not_fast_path = "out projection bias is None"
+    elif not all([x is not None and (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
         why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
-    elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
+    elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
         why_not_fast_path = ("grad is enabled and at least one of query or the "
                              "input/output projection weights or biases requires_grad")
     return why_not_fast_path
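On the list-comprehension comment in the hunk above, a hedged illustration of the TorchScript constraint: generator expressions do not compile under torch.jit.script, while the equivalent list comprehension does.

```python
import torch
from typing import List

@torch.jit.script
def all_cpu_or_cuda(tensors: List[torch.Tensor]) -> bool:
    # all(x.is_cuda or 'cpu' in str(x.device) for x in tensors) would not script;
    # the list-comprehension form below does.
    return all([x.is_cuda or 'cpu' in str(x.device) for x in tensors])
```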
2 changes: 1 addition & 1 deletion torch/nn/modules/activation.py
@@ -1060,7 +1060,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                 need_weights: bool = True, attn_mask: Optional[Tensor] = None,
                 average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
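Finally, a hedged usage sketch of the module-level forward touched above, showing the per-head attention weights described in the docstring earlier: with average_attn_weights=False the weights come back as (N, num_heads, L, S).

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)   # batch_first=False
q = k = v = torch.rand(4, 3, 8)                               # (L, N, E)
out, weights = mha(q, k, v, need_weights=True, average_attn_weights=False)
print(out.shape, weights.shape)   # torch.Size([4, 3, 8]) torch.Size([3, 2, 4, 4])
```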