improve traceback logging and non-utf8 handling

This commit is contained in:
Alan Hamlett 2016-09-01 11:49:12 +02:00
parent c08288eefd
commit fd322ba3b6
12 changed files with 101 additions and 68 deletions

View file

@ -1,8 +1,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from wakatime.compat import is_py3, u
from wakatime.main import execute from wakatime.main import execute
from wakatime.packages import requests from wakatime.packages import requests
from wakatime.packages.requests.models import Response
import logging import logging
import os import os
@ -10,8 +12,6 @@ import tempfile
import time import time
import sys import sys
from testfixtures import log_capture from testfixtures import log_capture
from wakatime.compat import u
from wakatime.packages.requests.models import Response
from . import utils from . import utils
@ -139,9 +139,9 @@ class LoggingTestCase(utils.TestCase):
self.assertEquals(sys.stdout.getvalue(), '') self.assertEquals(sys.stdout.getvalue(), '')
self.assertEquals(sys.stderr.getvalue(), '') self.assertEquals(sys.stderr.getvalue(), '')
output = u("\n").join([u(' ').join(x) for x in logs.actual()]) log_output = u("\n").join([u(' ').join(x) for x in logs.actual()])
self.assertIn(u('WakaTime DEBUG Traceback (most recent call last):'), output) self.assertIn(u('WakaTime DEBUG Traceback (most recent call last):'), log_output)
self.assertIn(u('Exception: FooBar'), output) self.assertIn(u('Exception: FooBar'), log_output)
@log_capture() @log_capture()
def test_exception_traceback_not_logged_normally(self, logs): def test_exception_traceback_not_logged_normally(self, logs):
@ -164,5 +164,23 @@ class LoggingTestCase(utils.TestCase):
self.assertEquals(sys.stdout.getvalue(), '') self.assertEquals(sys.stdout.getvalue(), '')
self.assertEquals(sys.stderr.getvalue(), '') self.assertEquals(sys.stderr.getvalue(), '')
output = u("\n").join([u(' ').join(x) for x in logs.actual()]) log_output = u("\n").join([u(' ').join(x) for x in logs.actual()])
self.assertEquals(u(''), output) self.assertEquals(u(''), log_output)
@log_capture()
def test_can_log_invalid_utf8(self, logs):
logging.disable(logging.NOTSET)
data = bytes('\xab', 'utf-16') if is_py3 else '\xab'
with self.assertRaises(UnicodeDecodeError):
data.decode('utf8')
logger = logging.getLogger('WakaTime')
logger.error(data)
found = False
for msg in list(logs.actual())[0]:
if u(msg) == u(data):
found = True
self.assertTrue(found)

View file

@ -867,8 +867,12 @@ class MainTestCase(utils.TestCase):
shutil.copy(entity, os.path.join(tempdir, 'emptyfile.txt')) shutil.copy(entity, os.path.join(tempdir, 'emptyfile.txt'))
entity = os.path.realpath(os.path.join(tempdir, 'emptyfile.txt')) entity = os.path.realpath(os.path.join(tempdir, 'emptyfile.txt'))
timezone = tzlocal.get_localzone() class TZ(object):
timezone.zone = 'tz汉语' if is_py3 else 'tz\xe6\xb1\x89\xe8\xaf\xad' @property
def zone(self):
return 'tz汉语' if is_py3 else 'tz\xe6\xb1\x89\xe8\xaf\xad'
timezone = TZ()
with utils.mock.patch('wakatime.packages.tzlocal.get_localzone') as mock_getlocalzone: with utils.mock.patch('wakatime.packages.tzlocal.get_localzone') as mock_getlocalzone:
mock_getlocalzone.return_value = timezone mock_getlocalzone.return_value = timezone
@ -887,6 +891,40 @@ class MainTestCase(utils.TestCase):
headers = self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].call_args[0][0].headers headers = self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].call_args[0][0].headers
self.assertEquals(headers.get('TimeZone'), u(timezone.zone).encode('utf-8') if is_py3 else timezone.zone) self.assertEquals(headers.get('TimeZone'), u(timezone.zone).encode('utf-8') if is_py3 else timezone.zone)
def test_timezone_with_invalid_encoding(self):
response = Response()
response.status_code = 201
self.patched['wakatime.packages.requests.adapters.HTTPAdapter.send'].return_value = response
with utils.TemporaryDirectory() as tempdir:
entity = 'tests/samples/codefiles/emptyfile.txt'
shutil.copy(entity, os.path.join(tempdir, 'emptyfile.txt'))
entity = os.path.realpath(os.path.join(tempdir, 'emptyfile.txt'))
class TZ(object):
@property
def zone(self):
return bytes('\xab', 'utf-16') if is_py3 else '\xab'
timezone = TZ()
with self.assertRaises(UnicodeDecodeError):
timezone.zone.decode('utf8')
with utils.mock.patch('wakatime.packages.tzlocal.get_localzone') as mock_getlocalzone:
mock_getlocalzone.return_value = timezone
config = 'tests/samples/configs/has_everything.cfg'
args = ['--file', entity, '--config', config, '--timeout', '15']
retval = execute(args)
self.assertEquals(retval, SUCCESS)
self.patched['wakatime.session_cache.SessionCache.get'].assert_called_once_with()
self.patched['wakatime.session_cache.SessionCache.delete'].assert_not_called()
self.patched['wakatime.session_cache.SessionCache.save'].assert_called_once_with(ANY)
self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called()
self.patched['wakatime.offlinequeue.Queue.pop'].assert_called_once_with()
def test_tzlocal_exception(self): def test_tzlocal_exception(self):
response = Response() response = Response()
response.status_code = 201 response.status_code = 201

View file

@ -31,7 +31,7 @@ if is_py2: # pragma: nocover
try: try:
return unicode(text) return unicode(text)
except: except:
return text return text.decode('utf-8', 'replace')
open = codecs.open open = codecs.open
basestring = basestring basestring = basestring
@ -52,7 +52,7 @@ elif is_py3: # pragma: nocover
try: try:
return str(text) return str(text)
except: except:
return text return text.decode('utf-8', 'replace')
open = open open = open
basestring = (str, bytes) basestring = (str, bytes)

View file

@ -12,7 +12,6 @@
import logging import logging
import re import re
import sys import sys
import traceback
from ..compat import u, open, import_module from ..compat import u, open, import_module
from ..exceptions import NotYetImplemented from ..exceptions import NotYetImplemented
@ -120,7 +119,7 @@ class DependencyParser(object):
except AttributeError: except AttributeError:
log.debug('Module {0} is missing class {1}'.format(module.__name__, class_name)) log.debug('Module {0} is missing class {1}'.format(module.__name__, class_name))
except ImportError: except ImportError:
log.debug(traceback.format_exc()) log.traceback(logging.DEBUG)
def parse(self): def parse(self):
if self.parser: if self.parser:

View file

@ -25,23 +25,6 @@ except (ImportError, SyntaxError): # pragma: nocover
import json import json
class CustomEncoder(json.JSONEncoder):
def encode(self, obj):
try:
return super(CustomEncoder, self).encode(obj)
except UnicodeDecodeError:
obj = u(obj)
return super(CustomEncoder, self).encode(obj)
def default(self, obj):
try:
return super(CustomEncoder, self).default(obj)
except TypeError:
obj = u(obj)
return super(CustomEncoder, self).default(obj)
class JsonFormatter(logging.Formatter): class JsonFormatter(logging.Formatter):
def setup(self, timestamp, is_write, entity, version, plugin, verbose, def setup(self, timestamp, is_write, entity, version, plugin, verbose,
@ -58,33 +41,28 @@ class JsonFormatter(logging.Formatter):
data = OrderedDict([ data = OrderedDict([
('now', self.formatTime(record, self.datefmt)), ('now', self.formatTime(record, self.datefmt)),
]) ])
data['version'] = self.version data['version'] = u(self.version)
data['plugin'] = self.plugin if self.plugin:
data['plugin'] = u(self.plugin)
data['time'] = self.timestamp data['time'] = self.timestamp
if self.verbose: if self.verbose:
data['caller'] = record.pathname data['caller'] = u(record.pathname)
data['lineno'] = record.lineno data['lineno'] = record.lineno
if self.is_write:
data['is_write'] = self.is_write data['is_write'] = self.is_write
data['file'] = self.entity data['file'] = u(self.entity)
if not self.is_write:
del data['is_write']
data['level'] = record.levelname data['level'] = record.levelname
data['message'] = record.getMessage() if self.warnings else record.msg data['message'] = u(record.getMessage() if self.warnings else record.msg)
if not self.plugin: return json.dumps(data)
del data['plugin']
return CustomEncoder().encode(data)
def traceback(self, lvl=None):
logger = logging.getLogger('WakaTime')
if not lvl:
lvl = logger.getEffectiveLevel()
logger.log(lvl, traceback.format_exc())
def traceback_formatter(*args, **kwargs): def formatException(self, exc_info):
level = kwargs.get('level', args[0] if len(args) else None) raise RuntimeError('Use traceback method instead.')
if level:
level = level.lower()
if level == 'warn' or level == 'warning':
logging.getLogger('WakaTime').warning(traceback.format_exc())
elif level == 'debug':
logging.getLogger('WakaTime').debug(traceback.format_exc())
else:
logging.getLogger('WakaTime').error(traceback.format_exc())
def set_log_level(logger, args): def set_log_level(logger, args):
@ -117,7 +95,7 @@ def setup_logging(args, version):
logger.addHandler(handler) logger.addHandler(handler)
# add custom traceback logging method # add custom traceback logging method
logger.traceback = traceback_formatter logger.traceback = formatter.traceback
warnings_formatter = JsonFormatter(datefmt='%Y/%m/%d %H:%M:%S %z') warnings_formatter = JsonFormatter(datefmt='%Y/%m/%d %H:%M:%S %z')
warnings_formatter.setup( warnings_formatter.setup(

View file

@ -545,6 +545,6 @@ def execute(argv=None):
return retval return retval
except: except:
log.traceback() log.traceback(logging.ERROR)
print(traceback.format_exc()) print(traceback.format_exc())
return UNKNOWN_ERROR return UNKNOWN_ERROR

View file

@ -80,7 +80,7 @@ class Queue(object):
try: try:
conn, c = self.connect() conn, c = self.connect()
except sqlite3.Error: except sqlite3.Error:
log.traceback('debug') log.traceback(logging.DEBUG)
return None return None
loop = True loop = True
while loop and tries > -1: while loop and tries > -1:
@ -118,11 +118,11 @@ class Queue(object):
} }
loop = False loop = False
except sqlite3.Error: # pragma: nocover except sqlite3.Error: # pragma: nocover
log.traceback('debug') log.traceback(logging.DEBUG)
sleep(wait) sleep(wait)
tries -= 1 tries -= 1
try: try:
conn.close() conn.close()
except sqlite3.Error: # pragma: nocover except sqlite3.Error: # pragma: nocover
log.traceback('debug') log.traceback(logging.DEBUG)
return heartbeat return heartbeat

View file

@ -44,9 +44,9 @@ class Git(BaseProject):
with open(head, 'r', encoding=sys.getfilesystemencoding()) as fh: with open(head, 'r', encoding=sys.getfilesystemencoding()) as fh:
return self._get_branch_from_head_file(fh.readline()) return self._get_branch_from_head_file(fh.readline())
except: except:
log.traceback('warn') log.traceback(logging.WARNING)
except IOError: # pragma: nocover except IOError: # pragma: nocover
log.traceback('warn') log.traceback(logging.WARNING)
return u('master') return u('master')
def _project_base(self): def _project_base(self):

View file

@ -42,9 +42,9 @@ class Mercurial(BaseProject):
with open(branch_file, 'r', encoding=sys.getfilesystemencoding()) as fh: with open(branch_file, 'r', encoding=sys.getfilesystemencoding()) as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1]) return u(fh.readline().strip().rsplit('/', 1)[-1])
except: except:
log.traceback('warn') log.traceback(logging.WARNING)
except IOError: # pragma: nocover except IOError: # pragma: nocover
log.traceback('warn') log.traceback(logging.WARNING)
return u('default') return u('default')
def _find_hg_config_dir(self, path): def _find_hg_config_dir(self, path):

View file

@ -41,9 +41,9 @@ class WakaTimeProjectFile(BaseProject):
self._project_name = u(fh.readline().strip()) self._project_name = u(fh.readline().strip())
self._project_branch = u(fh.readline().strip()) self._project_branch = u(fh.readline().strip())
except: except:
log.traceback('warn') log.traceback(logging.WARNING)
except IOError: # pragma: nocover except IOError: # pragma: nocover
log.traceback('warn') log.traceback(logging.WARNING)
return True return True
return False return False

View file

@ -57,7 +57,7 @@ class SessionCache(object):
conn.commit() conn.commit()
conn.close() conn.close()
except: # pragma: nocover except: # pragma: nocover
log.traceback('debug') log.traceback(logging.DEBUG)
def get(self): def get(self):
@ -72,7 +72,7 @@ class SessionCache(object):
try: try:
conn, c = self.connect() conn, c = self.connect()
except: except:
log.traceback('debug') log.traceback(logging.DEBUG)
return requests.session() return requests.session()
session = None session = None
@ -83,12 +83,12 @@ class SessionCache(object):
if row is not None: if row is not None:
session = pickle.loads(row[0]) session = pickle.loads(row[0])
except: # pragma: nocover except: # pragma: nocover
log.traceback('debug') log.traceback(logging.DEBUG)
try: try:
conn.close() conn.close()
except: # pragma: nocover except: # pragma: nocover
log.traceback('debug') log.traceback(logging.DEBUG)
return session if session is not None else requests.session() return session if session is not None else requests.session()
@ -105,4 +105,4 @@ class SessionCache(object):
conn.commit() conn.commit()
conn.close() conn.close()
except: except:
log.traceback('debug') log.traceback(logging.DEBUG)

View file

@ -236,5 +236,5 @@ def get_file_head(file_name):
with open(file_name, 'r', encoding=sys.getfilesystemencoding()) as fh: with open(file_name, 'r', encoding=sys.getfilesystemencoding()) as fh:
text = fh.read(512000) # pragma: nocover text = fh.read(512000) # pragma: nocover
except: except:
log.traceback('debug') log.traceback(logging.DEBUG)
return text return text