ENH Adds FeatureHasher support to pypy (#23023)

thomasjpfan committed Apr 8, 2022
1 parent 5a23a85 commit b2632de
Showing 12 changed files with 80 additions and 74 deletions.
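
This change set deletes the PyPy-specific fallbacks and ships the hashing transform as a C++ extension, so the estimator runs identically on CPython and PyPy. For orientation, a minimal sketch of the public API this commit unblocks (example data and n_features chosen for illustration):

from sklearn.feature_extraction import FeatureHasher

# FeatureHasher applies the hashing trick: feature names are hashed to
# column indices, so no vocabulary needs to be kept in memory.
hasher = FeatureHasher(n_features=8)
X = hasher.transform([{"dog": 1, "cat": 2}, {"dog": 2, "run": 5}])
print(X.shape)  # (2, 8); X is a scipy.sparse matrix
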
2 changes: 0 additions & 2 deletions doc/conftest.py
@@ -126,8 +126,6 @@ def pytest_runtest_setup(item):
setup_working_with_text_data()
elif fname.endswith("modules/compose.rst") or is_index:
setup_compose()
elif IS_PYPY and fname.endswith("modules/feature_extraction.rst"):
raise SkipTest("FeatureHasher is not compatible with PyPy")
elif fname.endswith("datasets/loading_other_datasets.rst"):
setup_loading_other_datasets()
elif fname.endswith("modules/impute.rst"):
3 changes: 3 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -474,6 +474,9 @@ Changelog
:mod:`sklearn.feature_extraction`
.................................

- |Feature| :class:`feature_extraction.FeatureHasher` now supports PyPy.
:pr:`23023` by `Thomas Fan`_.

- |Fix| :class:`feature_extraction.FeatureHasher` now validates input parameters
in `transform` instead of `__init__`. :pr:`21573` by
:user:`Hannah Bohle <hhnnhh>` and :user:`Maren Westermann <marenwestermann>`.
11 changes: 1 addition & 10 deletions sklearn/conftest.py
@@ -118,17 +118,8 @@ def pytest_collection_modifyitems(config, items):
dataset_fetchers[name]()

for item in items:
# FeatureHasher is not compatible with PyPy
if (
item.name.endswith(("_hash.FeatureHasher", "text.HashingVectorizer"))
and platform.python_implementation() == "PyPy"
):
marker = pytest.mark.skip(
reason="FeatureHasher is not compatible with PyPy"
)
item.add_marker(marker)
# Known failure with GradientBoostingClassifier on ARM64
elif (
if (
item.name.endswith("GradientBoostingClassifier")
and platform.machine() == "aarch64"
):
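
The deleted branch above attached the PyPy skip at collection time. For reference, a minimal self-contained sketch of that pytest pattern (hypothetical test name, not the scikit-learn suite):

import platform

import pytest


def pytest_collection_modifyitems(config, items):
    # Add a skip marker while tests are collected; pytest honors it at run time.
    for item in items:
        if (
            item.name.endswith("FeatureHasher")
            and platform.python_implementation() == "PyPy"
        ):
            item.add_marker(pytest.mark.skip(reason="not supported on PyPy"))
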
13 changes: 1 addition & 12 deletions sklearn/feature_extraction/_hash.py
@@ -6,19 +6,8 @@
import numpy as np
import scipy.sparse as sp

from ..utils import IS_PYPY
from ..base import BaseEstimator, TransformerMixin

if not IS_PYPY:
from ._hashing_fast import transform as _hashing_transform
else:

def _hashing_transform(*args, **kwargs):
raise NotImplementedError(
"FeatureHasher is not compatible with PyPy (see "
"https://github.com/scikit-learn/scikit-learn/issues/11540 "
"for the status updates)."
)
from ._hashing_fast import transform as _hashing_transform


def _iteritems(d):
41 changes: 17 additions & 24 deletions sklearn/feature_extraction/_hashing_fast.pyx
@@ -3,13 +3,15 @@

import sys
import array
from cpython cimport array
cimport cython
from libc.stdlib cimport abs
from libcpp.vector cimport vector

cimport numpy as np
import numpy as np

from ..utils._typedefs cimport INT32TYPE_t, INT64TYPE_t
from ..utils.murmurhash cimport murmurhash3_bytes_s32
from ..utils._vector_sentinel cimport vector_to_nd_array

np.import_array()

@@ -25,19 +27,12 @@ def transform(raw_X, Py_ssize_t n_features, dtype,
For constructing a scipy.sparse.csr_matrix.
"""
assert n_features > 0

cdef np.int32_t h
cdef INT32TYPE_t h
cdef double value

cdef array.array indices
cdef array.array indptr
indices = array.array("i")
indices_array_dtype = "q"
indices_np_dtype = np.longlong

indptr = array.array(indices_array_dtype, [0])
cdef vector[INT32TYPE_t] indices
cdef vector[INT64TYPE_t] indptr
indptr.push_back(0)

# Since the values array must honor the caller-provided dtype, we grow it
# ourselves. Use a Py_ssize_t capacity for safety.
@@ -65,13 +60,12 @@ def transform(raw_X, Py_ssize_t n_features, dtype,

h = murmurhash3_bytes_s32(<bytes>f, seed)

array.resize_smart(indices, len(indices) + 1)
if h == -2147483648:
# abs(-2**31) is undefined behavior because h is a `np.int32`
# The following is defined such that it is equal to: abs(-2**31) % n_features
indices[len(indices) - 1] = (2147483647 - (n_features - 1)) % n_features
indices.push_back((2147483647 - (n_features - 1)) % n_features)
else:
indices[len(indices) - 1] = abs(h) % n_features
indices.push_back(abs(h) % n_features)
# improve inner product preservation in the hashed space
if alternate_sign:
value *= (h >= 0) * 2 - 1
@@ -84,16 +78,15 @@ def transform(raw_X, Py_ssize_t n_features, dtype,
# references to the arrays due to Cython's error checking
values = np.resize(values, capacity)

array.resize_smart(indptr, len(indptr) + 1)
indptr[len(indptr) - 1] = size
indptr.push_back(size)

indices_a = np.frombuffer(indices, dtype=np.int32)
indptr_a = np.frombuffer(indptr, dtype=indices_np_dtype)
indices_array = vector_to_nd_array(&indices)
indptr_array = vector_to_nd_array(&indptr)

if indptr[len(indptr) - 1] > np.iinfo(np.int32).max: # = 2**31 - 1
if indptr_array[indptr_array.shape[0] - 1] > np.iinfo(np.int32).max:  # = 2**31 - 1
# both indices and indptr have the same dtype in CSR arrays
indices_a = indices_a.astype(np.int64, copy=False)
indices_array = indices_array.astype(np.int64, copy=False)
else:
indptr_a = indptr_a.astype(np.int32, copy=False)
indptr_array = indptr_array.astype(np.int32, copy=False)

return (indices_a, indptr_a, values[:size])
return (indices_array, indptr_array, values[:size])
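
A note on the INT32_MIN branch retained above: abs(-2**31) is undefined for a signed 32-bit integer because +2**31 is not representable, so the code derives the same residue from INT32_MAX instead. The two expressions agree because 2**31 - n_features and 2**31 differ by exactly one multiple of n_features. A pure-Python sanity check of that identity (illustration only, not part of the commit):

INT32_MAX = 2**31 - 1  # largest value a signed 32-bit integer can hold

for n_features in (1, 2, 7, 1000, 2**20, INT32_MAX):
    safe = (INT32_MAX - (n_features - 1)) % n_features  # int32-safe form used above
    assert safe == abs(-2**31) % n_features  # exact in Python's big ints
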
15 changes: 7 additions & 8 deletions sklearn/feature_extraction/setup.py
@@ -1,5 +1,4 @@
import os
import platform


def configuration(parent_package="", top_path=None):
@@ -11,13 +10,13 @@ def configuration(parent_package="", top_path=None):
if os.name == "posix":
libraries.append("m")

if platform.python_implementation() != "PyPy":
config.add_extension(
"_hashing_fast",
sources=["_hashing_fast.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
)
config.add_extension(
"_hashing_fast",
sources=["_hashing_fast.pyx"],
include_dirs=[numpy.get_include()],
language="c++",
libraries=libraries,
)
config.add_subpackage("tests")

return config
7 changes: 1 addition & 6 deletions sklearn/feature_extraction/tests/test_feature_hasher.py
@@ -3,9 +3,7 @@
import pytest

from sklearn.feature_extraction import FeatureHasher
from sklearn.utils._testing import fails_if_pypy

pytestmark = fails_if_pypy
from sklearn.feature_extraction._hashing_fast import transform as _hashing_transform


def test_feature_hasher_dicts():
@@ -47,9 +45,6 @@ def test_feature_hasher_strings():

def test_hashing_transform_seed():
# check the influence of the seed when computing the hashes
# import is here to avoid importing on pypy
from sklearn.feature_extraction._hashing_fast import transform as _hashing_transform

raw_X = [
["foo", "bar", "baz", "foo".encode("ascii")],
["bar".encode("ascii"), "baz", "quux"],
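
With the extension importable on every supported interpreter, the seed test can use the module-level import shown above instead of a function-local one. A rough sketch of driving the private transform directly; the alternate_sign and seed keywords are assumed from the function body, since the hunk above shows only the first three parameters:

import numpy as np

from sklearn.feature_extraction._hashing_fast import transform as _hashing_transform

raw_X = [["foo", "bar", "baz"], ["bar", "baz", "quux"]]

def pairs():
    # transform consumes an iterable of iterables of (feature, value) pairs
    return (((f, 1) for f in x) for x in raw_X)

indices_0, indptr_0, values_0 = _hashing_transform(pairs(), 2**7, np.float64, seed=0)
indices_1, _, _ = _hashing_transform(pairs(), 2**7, np.float64, seed=1)
# Different seeds should map the same features to different columns
# (with overwhelming probability).
assert not np.array_equal(indices_0, indices_1)
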
10 changes: 0 additions & 10 deletions sklearn/utils/__init__.py
@@ -1210,16 +1210,6 @@ def is_abstract(c):
classes = [
(name, est_cls) for name, est_cls in classes if not name.startswith("_")
]

# TODO: Remove when FeatureHasher is implemented in PYPY
# Skips FeatureHasher for PYPY
if IS_PYPY and "feature_extraction" in modname:
classes = [
(name, est_cls)
for name, est_cls in classes
if name != "FeatureHasher"
]

all_classes.extend(classes)

all_classes = set(all_classes)
4 changes: 4 additions & 0 deletions sklearn/utils/_typedefs.pxd
@@ -7,7 +7,11 @@ ctypedef np.float64_t DTYPE_t # WARNING: should match DTYPE in typedefs.pyx
cdef enum:
DTYPECODE = np.NPY_FLOAT64
ITYPECODE = np.NPY_INTP
INT32TYPECODE = np.NPY_INT32
INT64TYPECODE = np.NPY_INT64

# Index/integer type.
# WARNING: ITYPE_t must be a signed integer type or you will have a bad time!
ctypedef np.intp_t ITYPE_t # WARNING: should match ITYPE in typedefs.pyx
ctypedef np.int32_t INT32TYPE_t # WARNING: should match INT32TYPE in typedefs.pyx
ctypedef np.int64_t INT64TYPE_t # WARNING: should match INT64TYPE in typedefs.pyx
2 changes: 2 additions & 0 deletions sklearn/utils/_typedefs.pyx
@@ -14,6 +14,8 @@ np.import_array()
#cdef ITYPE_t[:] idummy_view = <ITYPE_t[:1]> &idummy
#ITYPE = np.asarray(idummy_view).dtype
ITYPE = np.intp # WARNING: this should match ITYPE_t in typedefs.pxd
INT32TYPE = np.int32 # WARNING: this should match INT32TYPE_t in typedefs.pxd
INT64TYPE = np.int64 # WARNING: this should match INT64TYPE_t in typedefs.pxd

#cdef DTYPE_t ddummy
#cdef DTYPE_t[:] ddummy_view = <DTYPE_t[:1]> &ddummy
4 changes: 3 additions & 1 deletion sklearn/utils/_vector_sentinel.pxd
@@ -1,10 +1,12 @@
cimport numpy as np

from libcpp.vector cimport vector
from ..utils._typedefs cimport ITYPE_t, DTYPE_t
from ..utils._typedefs cimport ITYPE_t, DTYPE_t, INT32TYPE_t, INT64TYPE_t

ctypedef fused vector_typed:
vector[DTYPE_t]
vector[ITYPE_t]
vector[INT32TYPE_t]
vector[INT64TYPE_t]

cdef np.ndarray vector_to_nd_array(vector_typed * vect_ptr)
42 changes: 41 additions & 1 deletion sklearn/utils/_vector_sentinel.pyx
@@ -2,14 +2,18 @@ from cython.operator cimport dereference as deref
from cpython.ref cimport Py_INCREF
cimport numpy as np

from ._typedefs cimport DTYPECODE, ITYPECODE
from ._typedefs cimport DTYPECODE, ITYPECODE, INT32TYPECODE, INT64TYPECODE

np.import_array()


cdef StdVectorSentinel _create_sentinel(vector_typed * vect_ptr):
if vector_typed is vector[DTYPE_t]:
return StdVectorSentinelFloat64.create_for(vect_ptr)
elif vector_typed is vector[INT32TYPE_t]:
return StdVectorSentinelInt32.create_for(vect_ptr)
elif vector_typed is vector[INT64TYPE_t]:
return StdVectorSentinelInt64.create_for(vect_ptr)
else:
return StdVectorSentinelIntP.create_for(vect_ptr)

@@ -64,6 +68,42 @@ cdef class StdVectorSentinelIntP(StdVectorSentinel):
return ITYPECODE


cdef class StdVectorSentinelInt32(StdVectorSentinel):
cdef vector[INT32TYPE_t] vec

@staticmethod
cdef StdVectorSentinel create_for(vector[INT32TYPE_t] * vect_ptr):
# This initializes the object directly without calling __init__
# See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa
cdef StdVectorSentinelInt32 sentinel = StdVectorSentinelInt32.__new__(StdVectorSentinelInt32)
sentinel.vec.swap(deref(vect_ptr))
return sentinel

cdef void* get_data(self):
return self.vec.data()

cdef int get_typenum(self):
return INT32TYPECODE


cdef class StdVectorSentinelInt64(StdVectorSentinel):
cdef vector[INT64TYPE_t] vec

@staticmethod
cdef StdVectorSentinel create_for(vector[INT64TYPE_t] * vect_ptr):
# This initializes the object directly without calling __init__
# See: https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#instantiation-from-existing-c-c-pointers # noqa
cdef StdVectorSentinelInt64 sentinel = StdVectorSentinelInt64.__new__(StdVectorSentinelInt64)
sentinel.vec.swap(deref(vect_ptr))
return sentinel

cdef void* get_data(self):
return self.vec.data()

cdef int get_typenum(self):
return INT64TYPECODE


cdef np.ndarray vector_to_nd_array(vector_typed * vect_ptr):
cdef:
np.npy_intp size = deref(vect_ptr).size()