Skip to content

Commit

Permalink
Fix: nn.Parameter return type identified as Tensor instead of `nn…
Browse files Browse the repository at this point in the history
….Parameter` (pytorch#125106)

Fixes pytorch#125105

Pull Request resolved: pytorch#125106
Approved by: https://github.com/ezyang, https://github.com/albanD
  • Loading branch information
randolf-scholz authored and petrex committed May 3, 2024
1 parent 8ec96e6 commit 738284c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
8 changes: 8 additions & 0 deletions test/typing/pass/creation_ops.py
Expand Up @@ -2,6 +2,10 @@
# flake8: noqa
import torch
from torch.testing._internal.common_utils import TEST_NUMPY

from typing_extensions import assert_type


if TEST_NUMPY:
import numpy as np

Expand Down Expand Up @@ -117,3 +121,7 @@
inp = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])
torch.heaviside(inp, values)

# Parameter
p = torch.nn.Parameter(torch.empty(1))
assert_type(p, torch.nn.Parameter)
10 changes: 5 additions & 5 deletions tools/pyi/gen_pyi.py
Expand Up @@ -1064,14 +1064,14 @@ def replace_special_case(hint: str) -> str:
"new_tensor": [
f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..."
],
"__new__": ["def __new__(self, *args, **kwargs) -> Tensor: ..."],
"__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."],
# new and __init__ have the same signatures differ only in return type
# Adapted from legacy_tensor_ctor and legacy_tensor_new
"new": [
f"def new(self, *args: Any, {DEVICE_PARAM}) -> Tensor: ...",
"def new(self, storage: Storage) -> Tensor: ...",
"def new(self, other: Tensor) -> Tensor: ...",
f"def new(self, size: _size, *, {DEVICE_PARAM}) -> Tensor: ...",
f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...",
"def new(cls, storage: Storage) -> Self: ...",
"def new(cls, other: Tensor) -> Self: ...",
f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...",
],
"__init__": [
f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...",
Expand Down
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
Expand Up @@ -29,7 +29,7 @@ from typing import (
overload,
runtime_checkable,
)
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Self

import numpy

Expand Down

0 comments on commit 738284c

Please sign in to comment.