Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cosine similarity dim checks #66214

Merged
merged 2 commits into from Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 3 additions & 6 deletions aten/src/ATen/native/Distance.cpp
Expand Up @@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
}

Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ",
x1.ndimension(), " and x2 has ", x2.ndimension());
TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ",
x1.size(dim), " and x2 has ", x2.size(dim));
auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes());
auto commonDtype = at::result_type(x1, x2);
TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);
Tensor x1_ = x1.to(commonDtype);
Tensor x2_ = x2.to(commonDtype);
Tensor x1_ = x1.to(commonDtype).expand(common_size);
Tensor x2_ = x2.to(commonDtype).expand(common_size);
// Follow scipy impl to improve numerical precision
// Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
Tensor w12 = at::sum(x1_ * x2_, dim);
Expand Down
6 changes: 0 additions & 6 deletions test/test_nn.py
Expand Up @@ -9704,12 +9704,6 @@ def test_cosine_similarity(self):
self.assertEqual(input1.grad, torch.zeros_like(input1))
self.assertEqual(input2.grad, input1 * 1e8)

# Check error when inputs are not the same shape
input1 = torch.randn(2, 2, 1)
input2 = torch.randn(2, 1, 3)
with self.assertRaises(RuntimeError):
F.cosine_similarity(input1, input2)

# Check type promotion, issue #61454
input = torch.tensor(12.)
out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
Expand Down
14 changes: 6 additions & 8 deletions torch/nn/functional.py
Expand Up @@ -4256,7 +4256,10 @@ def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6,
r"""
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor

Returns cosine similarity between x1 and x2, computed along dim.
Returns cosine similarity between ``x1`` and ``x2``, computed along ``dim``. ``x1`` and ``x2`` must be broadcastable
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
squeezed (see :func:`torch.squeeze`), resulting in the
output tensor having one fewer dimension.

.. math ::
\text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
Expand All @@ -4265,16 +4268,11 @@ def pairwise_distance(x1: Tensor, x2: Tensor, p: float = 2.0, eps: float = 1e-6,

Args:
x1 (Tensor): First input.
x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`,
and broadcastable with x1 at other dimensions).
dim (int, optional): Dimension of vectors. Default: 1
x2 (Tensor): Second input.
dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
eps (float, optional): Small value to avoid division by zero.
Default: 1e-8

Shape:
- Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
- Output: :math:`(\ast_1, \ast_2)`

Example::

>>> input1 = torch.randn(100, 128)
Expand Down
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Expand Up @@ -1256,6 +1256,8 @@ def generator():
yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
# Test for Broadcasting
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})

return list(generator())

Expand Down