Skip to content

Commit

Permalink
Fix the new API.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 14, 2020
1 parent 3b43863 commit b9fc867
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
17 changes: 13 additions & 4 deletions python-package/xgboost/data.py
Expand Up @@ -539,19 +539,27 @@ def _to_data_type(dtype: str, name: str):
return dtype_map[dtype]


def _validate_meta_shape(data):
if hasattr(data, 'shape'):
assert len(data.shape) == 1 or (
len(data.shape) == 2 and
(data.shape[1] == 0 or data.shape[1] == 1))


def _meta_from_numpy(data, field, dtype, handle):
data = _maybe_np_slice(data, dtype)
interface = data.__array_interface__
assert interface.get('mask', None) is None, 'Masked array is not supported'
size = data.shape[0]

c_type = _to_data_type(str(data.dtype), field)
data = interface['data']
data = ctypes.c_void_p(data[0])
ptr = interface['data'][0]
ptr = ctypes.c_void_p(ptr)
_check_call(_LIB.XGDMatrixSetDenseInfo(
handle,
c_str(field),
data,
size,
ptr,
c_bst_ulong(size),
c_type
))

Expand Down Expand Up @@ -603,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
'''Dispatch for meta info.'''
handle = matrix.handle
_validate_meta_shape(data)
if data is None:
return
if _is_list(data):
Expand Down
14 changes: 2 additions & 12 deletions tests/python/test_dmatrix.py
Expand Up @@ -96,18 +96,13 @@ def test_np_view(self):
assert (from_view == from_array).all()

def test_slice(self):
# 2887052510386eb7d12e09c859529a4f2b01ba35b847c807f95cbaddbb5eea7a
X = rng.randn(100, 100)
# 85a40616f6748d98e59baf2961b655e0ebb2fc8ac298fc638a173e434073a0f9
y = rng.randint(low=0, high=3, size=100)
# ab7b1162766f953935c244dd73d663ac0b79e4bd0a914a64e5b0fe66ae55b7ef
# failed:
# 33d8eee9310c7f6ff57f0c876b5977f39f7e450eb7c71513cc4c133c57921a6
d = xgb.DMatrix(X, y)
np.testing.assert_equal(d.get_label(), y.astype(np.float32))

# fw = rng.uniform(size=100).astype(np.float32)
# d.set_info(feature_weights=fw)
fw = rng.uniform(size=100).astype(np.float32)
d.set_info(feature_weights=fw)

eval_res_0 = {}
booster = xgb.train(
Expand All @@ -127,12 +122,7 @@ def test_slice(self):
d.set_base_margin(predt)

ridxs = [1, 2, 3, 4, 5, 6]
# failed:
# f3459dc754e30ff09e91c9660789cef53d998d6256f8ee81cc9c304cc54fbf40
# passed:
# ab7b1162766f953935c244dd73d663ac0b79e4bd0a914a64e5b0fe66ae55b7ef
sliced = d.slice(ridxs)
# np.testing.assert_equal(sliced.get_float_info('feature_weights'), fw)

sliced_margin = sliced.get_float_info('base_margin')
assert sliced_margin.shape[0] == len(ridxs) * 3
Expand Down

0 comments on commit b9fc867

Please sign in to comment.