Skip to content

Commit

Permalink
Cache the model.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 29, 2022
1 parent 03adff6 commit 721c4ce
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions tests/python/test_model_compatibility.py
Expand Up @@ -102,34 +102,38 @@ def run_scikit_model_check(name, path):

@pytest.mark.skipif(**tm.no_sklearn())
def test_model_compatibility():
'''Test model compatibility, can only be run on CI as others don't
"""Test model compatibility, can only be run on CI as others don't
have the credentials.
'''
"""
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, 'models')
path = os.path.join(path, "models")

zip_path, _ = urllib.request.urlretrieve('https://xgboost-ci-jenkins-artifacts.s3-us-west-2' +
'.amazonaws.com/xgboost_model_compatibility_test.zip')
with zipfile.ZipFile(zip_path, 'r') as z:
z.extractall(path)
if not os.path.exists(path):
zip_path, _ = urllib.request.urlretrieve(
"https://xgboost-ci-jenkins-artifacts.s3-us-west-2"
+ ".amazonaws.com/xgboost_model_compatibility_test.zip"
)
with zipfile.ZipFile(zip_path, "r") as z:
z.extractall(path)

models = [
os.path.join(root, f) for root, subdir, files in os.walk(path)
os.path.join(root, f)
for root, subdir, files in os.walk(path)
for f in files
if f != 'version'
if f != "version"
]
assert models

for path in models:
name = os.path.basename(path)
if name.startswith('xgboost-'):
if name.startswith("xgboost-"):
booster = xgboost.Booster(model_file=path)
run_booster_check(booster, name)
# Do full serialization.
booster = copy.copy(booster)
run_booster_check(booster, name)
elif name.startswith('xgboost_scikit'):
elif name.startswith("xgboost_scikit"):
run_scikit_model_check(name, path)
else:
assert False

0 comments on commit 721c4ce

Please sign in to comment.