Skip to content

Commit

Permalink
TENSOR: Prune numbers real
Browse files Browse the repository at this point in the history
* Real and mypy don't play nice python/mypy#3186
* This allows partial typing support of HOSVD
  • Loading branch information
ntjohnson1 committed Mar 16, 2023
1 parent b34e10f commit 55ea66e
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import warnings
from itertools import permutations
from math import factorial
from numbers import Real
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -181,9 +180,9 @@ def collapse(
self,
dims: Optional[np.ndarray] = None,
fun: Union[
Literal["sum"], Callable[[np.ndarray], Union[Real, np.ndarray]]
Literal["sum"], Callable[[np.ndarray], Union[float, np.ndarray]]
] = "sum",
) -> Union[Real, np.ndarray, tensor]:
) -> Union[float, np.ndarray, tensor]:
"""
Collapse tensor along specified dimensions.
Expand Down Expand Up @@ -389,7 +388,7 @@ def full(self) -> tensor:
"""
return ttb.tensor.from_data(self.data)

def innerprod(self, other: Union[tensor, ttb.sptensor, ttb.ktensor]) -> Real:
def innerprod(self, other: Union[tensor, ttb.sptensor, ttb.ktensor]) -> float:
"""
Efficient inner product with a tensor
Expand Down Expand Up @@ -542,7 +541,7 @@ def issymmetric(
return bool((all_diffs == 0).all())
return bool((all_diffs == 0).all()), all_diffs, all_perms

def logical_and(self, B: Union[Real, tensor]) -> tensor:
def logical_and(self, B: Union[float, tensor]) -> tensor:
"""
Logical and for tensors
Expand Down Expand Up @@ -578,7 +577,7 @@ def logical_not(self) -> tensor:
"""
return ttb.tensor.from_data(np.logical_not(self.data))

def logical_or(self, other: Union[Real, tensor]) -> tensor:
def logical_or(self, other: Union[float, tensor]) -> tensor:
"""
Logical or for tensors
Expand All @@ -598,7 +597,7 @@ def tensor_or(x, y):

return ttb.tt_tenfun(tensor_or, self, other)

def logical_xor(self, other: Union[Real, tensor]) -> tensor:
def logical_xor(self, other: Union[float, tensor]) -> tensor:
"""
Logical xor for tensors
Expand Down Expand Up @@ -867,7 +866,7 @@ def reshape(self, shape: Tuple[int, ...]) -> tensor:

return ttb.tensor.from_data(np.reshape(self.data, shape, order="F"), shape)

def squeeze(self) -> Union[tensor, np.ndarray, Real]:
def squeeze(self) -> Union[tensor, np.ndarray, float]:
"""
Removes singleton dimensions from a tensor
Expand Down Expand Up @@ -1029,7 +1028,7 @@ def symmetrize(
def ttm(
self,
matrix: Union[np.ndarray, List[np.ndarray]],
dims: Optional[Union[Real, np.ndarray]] = None,
dims: Optional[Union[float, np.ndarray]] = None,
transpose: bool = False,
) -> tensor:
"""
Expand All @@ -1046,7 +1045,7 @@ def ttm(
dims = np.arange(self.ndims)
elif isinstance(dims, list):
dims = np.array(dims)
elif isinstance(dims, Real):
elif isinstance(dims, (float, int, np.generic)):
dims = np.array([dims])

if isinstance(matrix, list):
Expand Down

0 comments on commit 55ea66e

Please sign in to comment.