From 721c4ce58df725c1c272b1c5ef4ba56851203864 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 29 Jul 2022 15:13:43 +0800 Subject: [PATCH] Cache the model. --- tests/python/test_model_compatibility.py | 26 ++++++++++++++---------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/python/test_model_compatibility.py b/tests/python/test_model_compatibility.py index 6f9a184922ab..88549e1f2acb 100644 --- a/tests/python/test_model_compatibility.py +++ b/tests/python/test_model_compatibility.py @@ -102,34 +102,38 @@ def run_scikit_model_check(name, path): @pytest.mark.skipif(**tm.no_sklearn()) def test_model_compatibility(): - '''Test model compatibility, can only be run on CI as others don't + """Test model compatibility, can only be run on CI as others don't have the credentials. - ''' + """ path = os.path.dirname(os.path.abspath(__file__)) - path = os.path.join(path, 'models') + path = os.path.join(path, "models") - zip_path, _ = urllib.request.urlretrieve('https://xgboost-ci-jenkins-artifacts.s3-us-west-2' + - '.amazonaws.com/xgboost_model_compatibility_test.zip') - with zipfile.ZipFile(zip_path, 'r') as z: - z.extractall(path) + if not os.path.exists(path): + zip_path, _ = urllib.request.urlretrieve( + "https://xgboost-ci-jenkins-artifacts.s3-us-west-2" + + ".amazonaws.com/xgboost_model_compatibility_test.zip" + ) + with zipfile.ZipFile(zip_path, "r") as z: + z.extractall(path) models = [ - os.path.join(root, f) for root, subdir, files in os.walk(path) + os.path.join(root, f) + for root, subdir, files in os.walk(path) for f in files - if f != 'version' + if f != "version" ] assert models for path in models: name = os.path.basename(path) - if name.startswith('xgboost-'): + if name.startswith("xgboost-"): booster = xgboost.Booster(model_file=path) run_booster_check(booster, name) # Do full serialization. booster = copy.copy(booster) run_booster_check(booster, name) - elif name.startswith('xgboost_scikit'): + elif name.startswith("xgboost_scikit"): run_scikit_model_check(name, path) else: assert False