From ccf5e39f44ef7835cb5c5395ec698be4df2ce6f3 Mon Sep 17 00:00:00 2001 From: Alan Hamlett Date: Wed, 24 May 2017 22:54:58 -0700 Subject: [PATCH] only use NTLM proxy after trying non-NTLM proxy --- tests/test_main.py | 40 +++++++++++ tests/test_offlinequeue.py | 6 +- tests/test_proxy.py | 66 ++++++++++++++--- wakatime/main.py | 143 ++++++++++++++++++++++++++----------- 4 files changed, 199 insertions(+), 56 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 6c10430..dd9de8c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -91,6 +91,10 @@ class MainTestCase(utils.TestCase): self.assertEquals(stats, json.loads(self.patched['wakatime.offlinequeue.Queue.push'].call_args[0][1])) self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + def test_400_response(self): response = Response() response.status_code = 400 @@ -118,6 +122,10 @@ class MainTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + def test_401_response(self): response = Response() response.status_code = 401 @@ -164,6 +172,10 @@ class MainTestCase(utils.TestCase): self.assertEquals(stats, json.loads(self.patched['wakatime.offlinequeue.Queue.push'].call_args[0][1])) self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + @log_capture() def test_500_response_without_offline_logging(self, logs): logging.disable(logging.NOTSET) @@ -207,6 +219,10 @@ class MainTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + @log_capture() def test_requests_exception(self, logs): logging.disable(logging.NOTSET) @@ -264,6 +280,10 @@ class MainTestCase(utils.TestCase): self.assertEquals(stats, json.loads(self.patched['wakatime.offlinequeue.Queue.push'].call_args[0][1])) self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + @log_capture() def test_requests_exception_without_offline_logging(self, logs): logging.disable(logging.NOTSET) @@ -298,6 +318,10 @@ class MainTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + @log_capture() def test_invalid_api_key(self, logs): logging.disable(logging.NOTSET) @@ -328,6 +352,8 @@ class MainTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_not_called() + def test_nonascii_hostname(self): response = Response() response.status_code = 201 @@ -428,6 +454,10 @@ class MainTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_called_once_with() + headers = self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].call_args[0][0].headers + expected_tz = u(bytes('\xab', 'utf-16') if is_py3 else '\xab').encode('utf-8') + self.assertEquals(headers.get('TimeZone'), expected_tz) + def test_tzlocal_exception(self): response = Response() response.status_code = 201 @@ -537,6 +567,10 @@ class MainTestCase(utils.TestCase): self.assertEquals(stats, json.loads(self.patched['wakatime.offlinequeue.Queue.push'].call_args[0][1])) self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) + @log_capture() def test_unhandled_exception(self, logs): logging.disable(logging.NOTSET) @@ -563,6 +597,8 @@ class MainTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() self.patched['wakatime.session_cache.SessionCache.get'].assert_not_called() + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_not_called() + def test_large_file_skips_lines_count(self): response = Response() response.status_code = 0 @@ -611,3 +647,7 @@ class MainTestCase(utils.TestCase): self.assertEquals(heartbeat[key], val) self.assertEquals(stats, json.loads(self.patched['wakatime.offlinequeue.Queue.push'].call_args[0][1])) self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with( + ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True, + ) diff --git a/tests/test_offlinequeue.py b/tests/test_offlinequeue.py index 1cca921..0387f51 100644 --- a/tests/test_offlinequeue.py +++ b/tests/test_offlinequeue.py @@ -18,10 +18,6 @@ from wakatime.constants import ( ) from wakatime.packages.requests.models import Response from . import utils -try: - from mock import call -except ImportError: - from unittest.mock import call try: from .packages import simplejson as json except (ImportError, SyntaxError): @@ -393,5 +389,5 @@ class OfflineQueueTestCase(utils.TestCase): self.assertIn(exception_msg, output[0]) self.patched['wakatime.session_cache.SessionCache.get'].assert_called_once_with() - self.patched['wakatime.session_cache.SessionCache.delete'].assert_has_calls([call(), call()]) + self.patched['wakatime.session_cache.SessionCache.delete'].assert_called_once_with() self.patched['wakatime.session_cache.SessionCache.save'].assert_not_called() diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 877c15b..60cc942 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -10,14 +10,14 @@ import shutil import sys from testfixtures import log_capture from wakatime.compat import u -from wakatime.constants import SUCCESS +from wakatime.constants import API_ERROR, SUCCESS from wakatime.packages.requests.models import Response from . import utils try: - from mock import ANY + from mock import ANY, call except ImportError: - from unittest.mock import ANY + from unittest.mock import ANY, call class ProxyTestCase(utils.TestCase): @@ -113,9 +113,9 @@ class ProxyTestCase(utils.TestCase): self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with(ANY, cert=None, proxies={'https': proxy}, stream=False, timeout=60, verify=True) - def test_ntlm_proxy(self): + def test_ntlm_proxy_used_after_trying_normal_proxy(self): response = Response() - response.status_code = 201 + response.status_code = 400 self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].return_value = response with utils.TemporaryDirectory() as tempdir: @@ -127,18 +127,60 @@ class ProxyTestCase(utils.TestCase): args = ['--file', entity, '--config', config, '--proxy', proxy] retval = execute(args) - self.assertEquals(retval, SUCCESS) + self.assertEquals(retval, API_ERROR) self.assertEquals(sys.stdout.getvalue(), '') self.assertEquals(sys.stderr.getvalue(), '') - self.patched['wakatime.session_cache.SessionCache.get'].assert_called_once_with() - self.patched['wakatime.session_cache.SessionCache.delete'].assert_not_called() - self.patched['wakatime.session_cache.SessionCache.save'].assert_called_once_with(ANY) + self.patched['wakatime.session_cache.SessionCache.get'].assert_has_calls([call(), call()]) + self.patched['wakatime.session_cache.SessionCache.delete'].assert_called_once_with() + self.patched['wakatime.session_cache.SessionCache.save'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() - self.patched['wakatime.offlinequeue.Queue.pop'].assert_called_once_with() + self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() - self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_called_once_with(ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True) + expected_calls = [ + call(ANY, cert=None, proxies={'https': proxy}, stream=False, timeout=60, verify=True), + call(ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True), + ] + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_has_calls(expected_calls) + + @log_capture() + def test_ntlm_proxy_used_after_normal_proxy_raises_exception(self, logs): + logging.disable(logging.NOTSET) + + ex_msg = 'after exception, should still try ntlm proxy' + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].side_effect = RuntimeError(ex_msg) + + with utils.TemporaryDirectory() as tempdir: + + entity = 'tests/samples/codefiles/emptyfile.txt' + shutil.copy(entity, os.path.join(tempdir, 'emptyfile.txt')) + entity = os.path.realpath(os.path.join(tempdir, 'emptyfile.txt')) + proxy = 'domain\\user:pass' + config = 'tests/samples/configs/good_config.cfg' + args = ['--file', entity, '--config', config, '--proxy', proxy] + + retval = execute(args) + + self.assertEquals(retval, API_ERROR) + self.assertEquals(sys.stdout.getvalue(), '') + self.assertEquals(sys.stderr.getvalue(), '') + + log_output = u("\n").join([u(' ').join(x) for x in logs.actual()]) + self.assertIn(ex_msg, log_output) + + self.patched['wakatime.session_cache.SessionCache.get'].assert_has_calls([call(), call()]) + self.patched['wakatime.session_cache.SessionCache.delete'].assert_called_once_with() + self.patched['wakatime.session_cache.SessionCache.save'].assert_not_called() + + self.patched['wakatime.offlinequeue.Queue.push'].assert_called_once_with(ANY, ANY, None) + self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + + expected_calls = [ + call(ANY, cert=None, proxies={'https': proxy}, stream=False, timeout=60, verify=True), + call(ANY, cert=None, proxies={}, stream=False, timeout=60, verify=True), + ] + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_has_calls(expected_calls) @log_capture() def test_invalid_proxy(self, logs): @@ -174,3 +216,5 @@ class ProxyTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + + self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].assert_not_called() diff --git a/wakatime/main.py b/wakatime/main.py index 064fec9..109bf7f 100644 --- a/wakatime/main.py +++ b/wakatime/main.py @@ -63,7 +63,7 @@ def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None, entity=None, timestamp=None, is_write=None, plugin=None, offline=None, entity_type='file', hidefilenames=None, proxy=None, nosslverify=None, api_url=None, timeout=None, - **kwargs): + use_ntlm_proxy=False, **kwargs): """Sends heartbeat as POST request to WakaTime api server. Returns `SUCCESS` when heartbeat was sent, otherwise returns an @@ -135,9 +135,10 @@ def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None, session_cache = SessionCache() session = session_cache.get() + should_try_ntlm = False proxies = {} if proxy: - if '\\' in proxy: + if use_ntlm_proxy: from .packages.requests_ntlm import HttpNtlmAuth username = proxy.rsplit(':', 1) password = '' @@ -146,38 +147,80 @@ def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None, username = username[0] session.auth = HttpNtlmAuth(username, password, session) else: + should_try_ntlm = '\\' in proxy proxies['https'] = proxy - # log time to api + # send request to api response = None try: response = session.post(api_url, data=request_body, headers=headers, proxies=proxies, timeout=timeout, verify=not nosslverify) except RequestException: - exception_data = { - sys.exc_info()[0].__name__: u(sys.exc_info()[1]), - } - if log.isEnabledFor(logging.DEBUG): - exception_data['traceback'] = traceback.format_exc() - if offline: - queue = Queue() - queue.push(data, json.dumps(stats), plugin) - if log.isEnabledFor(logging.DEBUG): - log.warn(exception_data) + if should_try_ntlm: + return send_heartbeat( + project=project, + entity=entity, + timestamp=timestamp, + branch=branch, + hostname=hostname, + stats=stats, + key=key, + is_write=is_write, + plugin=plugin, + offline=offline, + hidefilenames=hidefilenames, + entity_type=entity_type, + proxy=proxy, + api_url=api_url, + timeout=timeout, + use_ntlm_proxy=True, + ) else: - log.error(exception_data) + exception_data = { + sys.exc_info()[0].__name__: u(sys.exc_info()[1]), + } + if log.isEnabledFor(logging.DEBUG): + exception_data['traceback'] = traceback.format_exc() + if offline: + queue = Queue() + queue.push(data, json.dumps(stats), plugin) + if log.isEnabledFor(logging.DEBUG): + log.warn(exception_data) + else: + log.error(exception_data) except: # delete cached session when requests raises unknown exception - exception_data = { - sys.exc_info()[0].__name__: u(sys.exc_info()[1]), - 'traceback': traceback.format_exc(), - } - if offline: - queue = Queue() - queue.push(data, json.dumps(stats), plugin) - log.warn(exception_data) - session_cache.delete() + if should_try_ntlm: + return send_heartbeat( + project=project, + entity=entity, + timestamp=timestamp, + branch=branch, + hostname=hostname, + stats=stats, + key=key, + is_write=is_write, + plugin=plugin, + offline=offline, + hidefilenames=hidefilenames, + entity_type=entity_type, + proxy=proxy, + api_url=api_url, + timeout=timeout, + use_ntlm_proxy=True, + ) + else: + exception_data = { + sys.exc_info()[0].__name__: u(sys.exc_info()[1]), + 'traceback': traceback.format_exc(), + } + if offline: + queue = Queue() + queue.push(data, json.dumps(stats), plugin) + log.warn(exception_data) + session_cache.delete() + return API_ERROR else: code = response.status_code if response is not None else None @@ -188,32 +231,52 @@ def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None, }) session_cache.save(session) return SUCCESS - if offline: - if code != 400: - queue = Queue() - queue.push(data, json.dumps(stats), plugin) - if code == 401: + if should_try_ntlm: + return send_heartbeat( + project=project, + entity=entity, + timestamp=timestamp, + branch=branch, + hostname=hostname, + stats=stats, + key=key, + is_write=is_write, + plugin=plugin, + offline=offline, + hidefilenames=hidefilenames, + entity_type=entity_type, + proxy=proxy, + api_url=api_url, + timeout=timeout, + use_ntlm_proxy=True, + ) + else: + if offline: + if code != 400: + queue = Queue() + queue.push(data, json.dumps(stats), plugin) + if code == 401: + log.error({ + 'response_code': code, + 'response_content': content, + }) + session_cache.delete() + return AUTH_ERROR + elif log.isEnabledFor(logging.DEBUG): + log.warn({ + 'response_code': code, + 'response_content': content, + }) + else: log.error({ 'response_code': code, 'response_content': content, }) - session_cache.delete() - return AUTH_ERROR - elif log.isEnabledFor(logging.DEBUG): - log.warn({ - 'response_code': code, - 'response_content': content, - }) else: log.error({ 'response_code': code, 'response_content': content, }) - else: - log.error({ - 'response_code': code, - 'response_content': content, - }) session_cache.delete() return API_ERROR