Skip to content

Commit

Permalink
Restore unknown data support. (dmlc#6595)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 13, 2021
1 parent 89a00a5 commit d356b7a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 22 deletions.
21 changes: 1 addition & 20 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,25 +262,6 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)


def _convert_unknown_data(data, meta=None, meta_type=None):
if meta is not None:
try:
data = np.array(data, dtype=meta_type)
except Exception as e:
raise TypeError('Can not handle data from {}'.format(
type(data).__name__)) from e
else:
warnings.warn(
'Unknown data type: ' + str(type(data)) +
', coverting it to csr_matrix')
try:
data = scipy.sparse.csr_matrix(data)
except Exception as e:
raise TypeError('Can not initialize DMatrix from'
' {}'.format(type(data).__name__)) from e
return data


class DataIter:
'''The interface for user defined data iterator. Currently is only
supported by Device DMatrix.
Expand Down Expand Up @@ -542,7 +523,7 @@ def set_info(self, *,
if group is not None:
self.set_group(group)
if qid is not None:
dispatch_meta_backend(matrix=self, data=qid, name='qid')
self.set_uint_info('qid', qid)
if label_lower_bound is not None:
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:
Expand Down
23 changes: 23 additions & 0 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,24 @@ def _has_array_protocol(data):
return hasattr(data, '__array__')


def _convert_unknown_data(data):
warnings.warn(
f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
UserWarning
)
try:
import scipy
except ImportError:
return None

try:
data = scipy.sparse.csr_matrix(data)
except Exception: # pylint: disable=broad-except
return None

return data


def dispatch_data_backend(data, missing, threads,
feature_names, feature_types,
enable_categorical=False):
Expand Down Expand Up @@ -570,6 +588,11 @@ def dispatch_data_backend(data, missing, threads,
feature_types)
if _has_array_protocol(data):
pass

converted = _convert_unknown_data(data)
if converted:
return _from_scipy_csr(data, missing, feature_names, feature_types)

raise TypeError('Not supported type for data.' + str(type(data)))


Expand Down
5 changes: 3 additions & 2 deletions python-package/xgboost/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def get_host_ip(hostIP=None):
try:
hostIP = socket.gethostbyname(socket.getfqdn())
except gaierror:
logging.warning(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
logging.debug(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
)
hostIP = socket.gethostbyname(socket.gethostname())
if hostIP.startswith("127."):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
Expand Down
9 changes: 9 additions & 0 deletions tests/python/test_dmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,12 @@ def test_sparse_dmatrix_csc(self):
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
bst = xgb.train(param, dtrain, 5, watchlist)
bst.predict(dtrain)

def test_unknown_data(self):
class Data:
pass

with pytest.raises(TypeError):
with pytest.warns(UserWarning):
d = Data()
xgb.DMatrix(d)

0 comments on commit d356b7a

Please sign in to comment.