diff --git a/splunklib/binding.py b/splunklib/binding.py index bb4cfeb7..ce2c7ce0 100644 --- a/splunklib/binding.py +++ b/splunklib/binding.py @@ -30,6 +30,8 @@ import logging import socket import ssl +import sys +import time from base64 import b64encode from contextlib import contextmanager from datetime import datetime @@ -452,6 +454,12 @@ class Context(object): :type splunkToken: ``string`` :param headers: List of extra HTTP headers to send (optional). :type headers: ``list`` of 2-tuples. + :param retires: Number of retries for each HTTP connection (optional, the default is 0). + NOTE THAT THIS MAY INCREASE THE NUMBER OF ROUND TRIP CONNECTIONS TO THE SPLUNK SERVER AND BLOCK THE + CURRENT THREAD WHILE RETRYING. + :type retries: ``int`` + :param retryDelay: How long to wait between connection attempts if `retries` > 0 (optional, defaults to 10s). + :type retryDelay: ``int`` (in seconds) :param handler: The HTTP request handler (optional). :returns: A ``Context`` instance. @@ -469,7 +477,8 @@ class Context(object): """ def __init__(self, handler=None, **kwargs): self.http = HttpLib(handler, kwargs.get("verify", False), key_file=kwargs.get("key_file"), - cert_file=kwargs.get("cert_file"), context=kwargs.get("context")) # Default to False for backward compat + cert_file=kwargs.get("cert_file"), context=kwargs.get("context"), # Default to False for backward compat + retries=kwargs.get("retries", 0), retryDelay=kwargs.get("retryDelay", 10)) self.token = kwargs.get("token", _NoAuthenticationToken) if self.token is None: # In case someone explicitly passes token=None self.token = _NoAuthenticationToken @@ -1153,12 +1162,14 @@ class HttpLib(object): If using the default handler, SSL verification can be disabled by passing verify=False. """ - def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None): + def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None, retries=0, retryDelay=10): if custom_handler is None: self.handler = handler(verify=verify, key_file=key_file, cert_file=cert_file, context=context) else: self.handler = custom_handler self._cookies = {} + self.retries = retries + self.retryDelay = retryDelay def delete(self, url, headers=None, **kwargs): """Sends a DELETE request to a URL. @@ -1272,7 +1283,16 @@ def request(self, url, message, **kwargs): its structure). :rtype: ``dict`` """ - response = self.handler(url, message, **kwargs) + while True: + try: + response = self.handler(url, message, **kwargs) + break + except Exception: + if self.retries <= 0: + raise + else: + time.sleep(self.retryDelay) + self.retries -= 1 response = record(response) if 400 <= response.status: raise HTTPError(response) diff --git a/splunklib/client.py b/splunklib/client.py index e0043196..ab276c3e 100644 --- a/splunklib/client.py +++ b/splunklib/client.py @@ -323,6 +323,11 @@ def connect(**kwargs): :type username: ``string`` :param `password`: The password for the Splunk account. :type password: ``string`` + :param retires: Number of retries for each HTTP connection (optional, the default is 0). + NOTE THAT THIS MAY INCREASE THE NUMBER OF ROUND TRIP CONNECTIONS TO THE SPLUNK SERVER. + :type retries: ``int`` + :param retryDelay: How long to wait between connection attempts if `retries` > 0 (optional, defaults to 10s). + :type retryDelay: ``int`` (in seconds) :param `context`: The SSLContext that can be used when setting verify=True (optional) :type context: ``SSLContext`` :return: An initialized :class:`Service` connection. @@ -391,6 +396,11 @@ class Service(_BaseService): :param `password`: The password, which is used to authenticate the Splunk instance. :type password: ``string`` + :param retires: Number of retries for each HTTP connection (optional, the default is 0). + NOTE THAT THIS MAY INCREASE THE NUMBER OF ROUND TRIP CONNECTIONS TO THE SPLUNK SERVER. + :type retries: ``int`` + :param retryDelay: How long to wait between connection attempts if `retries` > 0 (optional, defaults to 10s). + :type retryDelay: ``int`` (in seconds) :return: A :class:`Service` instance. **Example**:: diff --git a/tests/test_service.py b/tests/test_service.py index 127ce75f..2aaed448 100755 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -36,13 +36,13 @@ def test_capabilities(self): capabilities = self.service.capabilities self.assertTrue(isinstance(capabilities, list)) self.assertTrue(all([isinstance(c, str) for c in capabilities])) - self.assertTrue('change_own_password' in capabilities) # This should always be there... + self.assertTrue('change_own_password' in capabilities) # This should always be there... def test_info(self): info = self.service.info keys = ["build", "cpu_arch", "guid", "isFree", "isTrial", "licenseKeys", - "licenseSignature", "licenseState", "master_guid", "mode", - "os_build", "os_name", "os_version", "serverName", "version"] + "licenseSignature", "licenseState", "master_guid", "mode", + "os_build", "os_name", "os_version", "serverName", "version"] for key in keys: self.assertTrue(key in list(info.keys())) @@ -74,25 +74,25 @@ def test_app_namespace(self): def test_owner_wildcard(self): kwargs = self.opts.kwargs.copy() - kwargs.update({ 'app': "search", 'owner': "-" }) + kwargs.update({'app': "search", 'owner': "-"}) service_ns = client.connect(**kwargs) service_ns.apps.list() def test_default_app(self): kwargs = self.opts.kwargs.copy() - kwargs.update({ 'app': None, 'owner': "admin" }) + kwargs.update({'app': None, 'owner': "admin"}) service_ns = client.connect(**kwargs) service_ns.apps.list() def test_app_wildcard(self): kwargs = self.opts.kwargs.copy() - kwargs.update({ 'app': "-", 'owner': "admin" }) + kwargs.update({'app': "-", 'owner': "admin"}) service_ns = client.connect(**kwargs) service_ns.apps.list() def test_user_namespace(self): kwargs = self.opts.kwargs.copy() - kwargs.update({ 'app': "search", 'owner': "admin" }) + kwargs.update({'app': "search", 'owner': "admin"}) service_ns = client.connect(**kwargs) service_ns.apps.list() @@ -114,7 +114,7 @@ def test_parse_fail(self): def test_restart(self): service = client.connect(**self.opts.kwargs) self.service.restart(timeout=300) - service.login() # Make sure we are awake + service.login() # Make sure we are awake def test_read_outputs_with_type(self): name = testlib.tmpname() @@ -138,7 +138,7 @@ def test_splunk_version(self): for p in v: self.assertTrue(isinstance(p, int) and p >= 0) - for version in [(4,3,3), (5,), (5,0,1)]: + for version in [(4, 3, 3), (5,), (5, 0, 1)]: with self.fake_splunk_version(version): self.assertEqual(version, self.service.splunk_version) @@ -167,7 +167,7 @@ def _create_unauthenticated_service(self): 'scheme': self.opts.kwargs['scheme'] }) - #To check the HEC event endpoint using Endpoint instance + # To check the HEC event endpoint using Endpoint instance def test_hec_event(self): import json service_hec = client.connect(host='localhost', scheme='https', port=8088, @@ -175,7 +175,7 @@ def test_hec_event(self): event_collector_endpoint = client.Endpoint(service_hec, "/services/collector/event") msg = {"index": "main", "event": "Hello World"} response = event_collector_endpoint.post("", body=json.dumps(msg)) - self.assertEqual(response.status,200) + self.assertEqual(response.status, 200) class TestCookieAuthentication(unittest.TestCase): @@ -287,6 +287,7 @@ def test_login_with_multiple_cookies(self): service2.login() self.assertEqual(service2.apps.get().status, 200) + class TestSettings(testlib.SDKTestCase): def test_read_settings(self): settings = self.service.settings @@ -316,6 +317,7 @@ def test_update_settings(self): self.assertEqual(updated, original) self.restartSplunk() + class TestTrailing(unittest.TestCase): template = '/servicesNS/boris/search/another/path/segment/that runs on' @@ -329,7 +331,8 @@ def test_no_args_is_identity(self): self.assertEqual(self.template, client._trailing(self.template)) def test_trailing_with_one_arg_works(self): - self.assertEqual('boris/search/another/path/segment/that runs on', client._trailing(self.template, 'ervicesNS/')) + self.assertEqual('boris/search/another/path/segment/that runs on', + client._trailing(self.template, 'ervicesNS/')) def test_trailing_with_n_args_works(self): self.assertEqual( @@ -337,11 +340,12 @@ def test_trailing_with_n_args_works(self): client._trailing(self.template, 'servicesNS/', '/', '/') ) + class TestEntityNamespacing(testlib.SDKTestCase): def test_proper_namespace_with_arguments(self): entity = self.service.apps['search'] - self.assertEqual((None,None,"global"), entity._proper_namespace(sharing="global")) - self.assertEqual((None,"search","app"), entity._proper_namespace(sharing="app", app="search")) + self.assertEqual((None, None, "global"), entity._proper_namespace(sharing="global")) + self.assertEqual((None, "search", "app"), entity._proper_namespace(sharing="app", app="search")) self.assertEqual( ("admin", "search", "user"), entity._proper_namespace(sharing="user", app="search", owner="admin") @@ -360,6 +364,7 @@ def test_proper_namespace_with_service_namespace(self): self.service.namespace.sharing) self.assertEqual(namespace, entity._proper_namespace()) + if __name__ == "__main__": try: import unittest2 as unittest diff --git a/tests/testlib.py b/tests/testlib.py index 61be722e..ae3246a2 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -21,6 +21,7 @@ import sys from splunklib import six + # Run the test suite on the SDK without installing it. sys.path.insert(0, '../') sys.path.insert(0, '../examples') @@ -28,6 +29,7 @@ import splunklib.client as client from time import sleep from datetime import datetime, timedelta + try: import unittest2 as unittest except ImportError: @@ -43,17 +45,21 @@ import time import logging + logging.basicConfig( filename='test.log', level=logging.DEBUG, format="%(asctime)s:%(levelname)s:%(message)s") + class NoRestartRequiredError(Exception): pass + class WaitTimedOutError(Exception): pass + def to_bool(x): if x == '1': return True @@ -64,7 +70,7 @@ def to_bool(x): def tmpname(): - name = 'delete-me-' + str(os.getpid()) + str(time.time()).replace('.','-') + name = 'delete-me-' + str(os.getpid()) + str(time.time()).replace('.', '-') return name @@ -77,7 +83,7 @@ def wait(predicate, timeout=60, pause_time=0.5): logging.debug("wait timed out after %d seconds", timeout) raise WaitTimedOutError sleep(pause_time) - logging.debug("wait finished after %s seconds", datetime.now()-start) + logging.debug("wait finished after %s seconds", datetime.now() - start) class SDKTestCase(unittest.TestCase): @@ -94,7 +100,7 @@ def assertEventuallyTrue(self, predicate, timeout=30, pause_time=0.5, logging.debug("wait timed out after %d seconds", timeout) self.fail(timeout_message) sleep(pause_time) - logging.debug("wait finished after %s seconds", datetime.now()-start) + logging.debug("wait finished after %s seconds", datetime.now() - start) def check_content(self, entity, **kwargs): for k, v in six.iteritems(kwargs): @@ -163,12 +169,11 @@ def fake_splunk_version(self, version): finally: self.service._splunk_version = original_version - def install_app_from_collection(self, name): collectionName = 'sdkappcollection' if collectionName not in self.service.apps: raise ValueError("sdk-test-application not installed in splunkd") - appPath = self.pathInApp(collectionName, ["build", name+".tar"]) + appPath = self.pathInApp(collectionName, ["build", name + ".tar"]) kwargs = {"update": True, "name": appPath, "filename": True} try: @@ -233,7 +238,7 @@ def restartSplunk(self, timeout=240): @classmethod def setUpClass(cls): cls.opts = parse([], {}, ".env") - + cls.opts.kwargs.update({'retries': 3}) # Before we start, make sure splunk doesn't need a restart. service = client.connect(**cls.opts.kwargs) if service.restart_required: @@ -241,6 +246,7 @@ def setUpClass(cls): def setUp(self): unittest.TestCase.setUp(self) + self.opts.kwargs.update({'retries': 3}) self.service = client.connect(**self.opts.kwargs) # If Splunk is in a state requiring restart, go ahead # and restart. That way we'll be sane for the rest of