Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: nn.Parameter return type identified as Tensor instead of nn.Parameter #125106

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/typing/pass/creation_ops.py
Expand Up @@ -2,6 +2,10 @@
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY

from typing_extensions import assert_type


if TEST_NUMPY:
import numpy as np

Expand Down Expand Up @@ -117,3 +121,7 @@
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
torch.heaviside(inp, values)

# Parameter
p = torch.nn.Parameter(torch.empty(1))
assert_type(p, torch.nn.Parameter)
10 changes: 5 additions & 5 deletions tools/pyi/gen_pyi.py
Expand Up @@ -1064,14 +1064,14 @@ def replace_special_case(hint: str) -> str:
"new_tensor": [
f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..."
],
"__new__": ["def __new__(self, *args, **kwargs) -> Tensor: ..."],
"__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."],
randolf-scholz marked this conversation as resolved.
Show resolved Hide resolved
# new and __init__ have the same signatures differ only in return type
# Adapted from legacy_tensor_ctor and legacy_tensor_new
"new": [
f"def new(self, *args: Any, {DEVICE_PARAM}) -> Tensor: ...",
"def new(self, storage: Storage) -> Tensor: ...",
"def new(self, other: Tensor) -> Tensor: ...",
f"def new(self, size: _size, *, {DEVICE_PARAM}) -> Tensor: ...",
f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...",
"def new(cls, storage: Storage) -> Self: ...",
"def new(cls, other: Tensor) -> Self: ...",
f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...",
],
"__init__": [
f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...",
Expand Down
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
Expand Up @@ -29,7 +29,7 @@ from typing import (
overload,
runtime_checkable,
)
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Self

import numpy

Expand Down