Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Nov 8, 2021
1 parent 076d12d commit 1284889
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
10 changes: 7 additions & 3 deletions mlflow/models/utils.py
Expand Up @@ -60,9 +60,13 @@ def _is_scalar(x):
return np.isscalar(x) or x is None

def _is_tensor(x):
return isinstance(x, np.ndarray) or (
isinstance(x, dict) and all([isinstance(ary, np.ndarray) for ary in x.values()])
) or isinstance(x, (csr_matrix, csc_matrix))
return (
isinstance(x, np.ndarray)
or (
isinstance(x, dict) and all([isinstance(ary, np.ndarray) for ary in x.values()])
)
or isinstance(x, (csr_matrix, csc_matrix))
)

def _handle_tensor_input(input_tensor: Union[np.ndarray, dict, csr_matrix, csc_matrix]):
if isinstance(input_tensor, dict):
Expand Down
2 changes: 2 additions & 0 deletions mlflow/types/utils.py
Expand Up @@ -31,6 +31,7 @@ def _get_tensor_shape(data: np.ndarray, variable_dimension: Optional[int] = 0) -
:return: tuple : Shape of the inputted data (including a variable dimension)
"""
from scipy.sparse import csr_matrix, csc_matrix

if not isinstance(data, (np.ndarray, csr_matrix, csc_matrix)):
raise TypeError("Expected numpy.ndarray or csc/csr matrix, got '{}'.".format(type(data)))
variable_input_data_shape = data.shape
Expand Down Expand Up @@ -101,6 +102,7 @@ def _infer_schema(data: Any) -> Schema:
:return: Schema
"""
from scipy.sparse import csr_matrix, csc_matrix

if isinstance(data, dict):
res = []
for name in data.keys():
Expand Down
2 changes: 0 additions & 2 deletions mlflow/utils/autologging_utils/__init__.py
Expand Up @@ -174,8 +174,6 @@ def resolve_input_example_and_signature(
model_signature = infer_model_signature(input_example)
except Exception as e:
model_signature_user_msg = "Failed to infer model signature: " + str(e)
import traceback
traceback.print_exc()

if log_input_example and input_example_user_msg is not None:
logger.warning(input_example_user_msg)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -71,6 +71,7 @@ def package_files(directory):
"Flask",
"gunicorn; platform_system != 'Windows'",
"numpy",
"scipy",
"pandas",
"prometheus-flask-exporter",
"querystring_parser",
Expand Down

0 comments on commit 1284889

Please sign in to comment.