Skip to content

Commit

Permalink
Fix CLI ranking demo. (#6439)
Browse files Browse the repository at this point in the history
Save model at final round.
  • Loading branch information
trivialfis committed Nov 28, 2020
1 parent b0036b3 commit f4ff1c5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
16 changes: 7 additions & 9 deletions demo/rank/mq2008.conf
Expand Up @@ -5,9 +5,9 @@ objective="rank:pairwise"

# Tree Booster Parameters
# step size shrinkage
eta = 0.1
eta = 0.1
# minimum loss reduction required to make a further partition
gamma = 1.0
gamma = 1.0
# minimum sum of instance weight(hessian) needed in a child
min_child_weight = 0.1
# maximum depth of a tree
Expand All @@ -17,12 +17,10 @@ max_depth = 6
# the number of round to do boosting
num_round = 4
# 0 means do not save any model except the final round model
save_period = 0
save_period = 0
# The path of training data
data = "mq2008.train"
data = "mq2008.train"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "mq2008.vali"
# The path of test data
test:data = "mq2008.test"


eval[test] = "mq2008.vali"
# The path of test data
test:data = "mq2008.test"
2 changes: 1 addition & 1 deletion src/cli_main.cc
Expand Up @@ -268,7 +268,7 @@ class CLI {
// always save final round
if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) &&
param_.model_out != CLIParam::kNull && rabit::GetRank() == 0) {
rabit::GetRank() == 0) {
std::ostringstream os;
if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
Expand Down
35 changes: 32 additions & 3 deletions tests/python/test_cli.py
Expand Up @@ -22,6 +22,7 @@ class TestCLI:
model_out = {model_out}
test_path = {test_path}
name_pred = {name_pred}
model_dir = {model_dir}
num_round = 10
data = {data_path}
Expand Down Expand Up @@ -59,7 +60,8 @@ def test_cli_model(self):
model_in='NULL',
model_out=model_out_cli,
test_path='NULL',
name_pred='NULL')
name_pred='NULL',
model_dir='NULL')
with open(config_path, 'w') as fd:
fd.write(train_conf)

Expand All @@ -73,7 +75,8 @@ def test_cli_model(self):
model_in=model_out_cli,
model_out='NULL',
test_path=data_path,
name_pred=predict_out)
name_pred=predict_out,
model_dir='NULL')
with open(config_path, 'w') as fd:
fd.write(predict_conf)

Expand Down Expand Up @@ -145,7 +148,8 @@ def test_cli_model_json(self):
model_in='NULL',
model_out=model_out_cli,
test_path='NULL',
name_pred='NULL')
name_pred='NULL',
model_dir='NULL')
with open(config_path, 'w') as fd:
fd.write(train_conf)

Expand All @@ -154,3 +158,28 @@ def test_cli_model_json(self):
model = json.load(fd)

assert model['learner']['gradient_booster']['name'] == 'gbtree'

def test_cli_save_model(self):
'''Test save on final round'''
exe = self.get_exe()
data_path = "{root}/demo/data/agaricus.txt.train?format=libsvm".format(
root=self.PROJECT_ROOT)
seed = 1994

with tempfile.TemporaryDirectory() as tmpdir:
model_out_cli = os.path.join(tmpdir, '0010.model')
config_path = os.path.join(tmpdir, 'test_load_cli_model.conf')

train_conf = self.template.format(data_path=data_path,
seed=seed,
task='train',
model_in='NULL',
model_out='NULL',
test_path='NULL',
name_pred='NULL',
model_dir=tmpdir)
with open(config_path, 'w') as fd:
fd.write(train_conf)

subprocess.run([exe, config_path])
assert os.path.exists(model_out_cli)

0 comments on commit f4ff1c5

Please sign in to comment.