diff --git a/tests/test_arguments.py b/tests/test_arguments.py index 28ef204..757bc0a 100644 --- a/tests/test_arguments.py +++ b/tests/test_arguments.py @@ -11,6 +11,7 @@ import shutil import sys import uuid from testfixtures import log_capture +from wakatime.arguments import parseArguments from wakatime.compat import u from wakatime.constants import ( API_ERROR, @@ -620,3 +621,24 @@ class ArgumentsTestCase(utils.TestCase): self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() + + def test_uses_wakatime_home_env_variable(self): + with utils.TemporaryDirectory() as tempdir: + entity = 'tests/samples/codefiles/twolinefile.txt' + shutil.copy(entity, os.path.join(tempdir, 'twolinefile.txt')) + entity = os.path.realpath(os.path.join(tempdir, 'twolinefile.txt')) + key = str(uuid.uuid4()) + config = 'tests/samples/configs/good_config.cfg' + logfile = os.path.realpath(os.path.join(tempdir, '.wakatime.log')) + + args = ['--file', entity, '--key', key, '--config', config] + + with utils.mock.patch.object(sys, 'argv', ['wakatime'] + args): + args, configs = parseArguments() + self.assertEquals(args.logfile, None) + + with utils.mock.patch('os.environ.get') as mock_env: + mock_env.return_value = os.path.realpath(tempdir) + + args, configs = parseArguments() + self.assertEquals(args.logfile, logfile) diff --git a/tests/test_logging.py b/tests/test_logging.py index 37131d2..d9feb4b 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -101,7 +101,7 @@ class LoggingTestCase(utils.TestCase): config = 'tests/samples/configs/good_config.cfg' shutil.copy(config, os.path.join(tempdir, '.wakatime.cfg')) config = os.path.realpath(os.path.join(tempdir, '.wakatime.cfg')) - logfile = os.path.realpath(os.path.join(tempdir, '.wakatime.log')) + expected_logfile = os.path.realpath(os.path.join(tempdir, '.wakatime.log')) with utils.mock.patch('wakatime.main.os.environ.get') as mock_env: mock_env.return_value = tempdir @@ -116,8 +116,8 @@ class LoggingTestCase(utils.TestCase): self.assertEquals(sys.stderr.getvalue(), '') self.assertEquals(logging.WARNING, logging.getLogger('WakaTime').level) - actual_logfile = os.path.realpath(logging.getLogger('WakaTime').handlers[0].baseFilename) - self.assertEquals(logfile, actual_logfile) + logfile = os.path.realpath(logging.getLogger('WakaTime').handlers[0].baseFilename) + self.assertEquals(logfile, expected_logfile) logs.check() @log_capture() diff --git a/tests/test_offlinequeue.py b/tests/test_offlinequeue.py index 0387f51..dd9d40e 100644 --- a/tests/test_offlinequeue.py +++ b/tests/test_offlinequeue.py @@ -199,12 +199,14 @@ class OfflineQueueTestCase(utils.TestCase): class CustomResponse(Response): count = 0 + @property def status_code(self): if self.count > 2: return 401 self.count += 1 return 201 + @status_code.setter def status_code(self, value): pass @@ -254,12 +256,14 @@ class OfflineQueueTestCase(utils.TestCase): class CustomResponse(Response): count = 0 + @property def status_code(self): if self.count > 2: return 500 self.count += 1 return 201 + @status_code.setter def status_code(self, value): pass @@ -354,12 +358,24 @@ class OfflineQueueTestCase(utils.TestCase): saved_heartbeat = queue.pop() self.assertEquals(None, saved_heartbeat) - def test_get_db_file(self): + def test_uses_home_folder_by_default(self): queue = Queue() db_file = queue.get_db_file() expected = os.path.join(os.path.expanduser('~'), '.wakatime.db') self.assertEquals(db_file, expected) + def test_uses_wakatime_home_env_variable(self): + queue = Queue() + + with utils.TemporaryDirectory() as tempdir: + expected = os.path.realpath(os.path.join(tempdir, '.wakatime.db')) + + with utils.mock.patch('os.environ.get') as mock_env: + mock_env.return_value = os.path.realpath(tempdir) + + actual = queue.get_db_file() + self.assertEquals(actual, expected) + @log_capture() def test_heartbeat_saved_when_requests_raises_exception(self, logs): logging.disable(logging.NOTSET) diff --git a/tests/test_session_cache.py b/tests/test_session_cache.py index 59e79e0..ce96c78 100644 --- a/tests/test_session_cache.py +++ b/tests/test_session_cache.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- +import os from wakatime.session_cache import SessionCache from wakatime.logger import setup_logging from . import utils @@ -16,6 +17,7 @@ class SessionCacheTestCase(utils.TestCase): def setUp(self): super(SessionCacheTestCase, self).setUp() + class MockArgs(object): timestamp = 0 is_write = False @@ -30,30 +32,49 @@ class SessionCacheTestCase(utils.TestCase): def test_can_crud_session(self): with utils.NamedTemporaryFile() as fh: cache = SessionCache() - cache.DB_FILE = fh.name - session = cache.get() - session.headers.update({'x-test': 'abc'}) - cache.save(session) - session = cache.get() - self.assertEquals(session.headers.get('x-test'), 'abc') - cache.delete() - session = cache.get() - self.assertEquals(session.headers.get('x-test'), None) - - def test_handles_connection_exception(self): - with utils.NamedTemporaryFile() as fh: - cache = SessionCache() - cache.DB_FILE = fh.name - - with utils.mock.patch('wakatime.session_cache.SessionCache.connect') as mock_connect: - mock_connect.side_effect = OSError('') + with utils.mock.patch('wakatime.session_cache.SessionCache.get_db_file') as mock_dbfile: + mock_dbfile.return_value = fh.name session = cache.get() session.headers.update({'x-test': 'abc'}) cache.save(session) session = cache.get() - self.assertEquals(session.headers.get('x-test'), None) + self.assertEquals(session.headers.get('x-test'), 'abc') cache.delete() session = cache.get() self.assertEquals(session.headers.get('x-test'), None) + + def test_handles_connection_exception(self): + with utils.NamedTemporaryFile() as fh: + cache = SessionCache() + + with utils.mock.patch('wakatime.session_cache.SessionCache.get_db_file') as mock_dbfile: + mock_dbfile.return_value = fh.name + + with utils.mock.patch('wakatime.session_cache.SessionCache.connect') as mock_connect: + mock_connect.side_effect = OSError('') + + session = cache.get() + session.headers.update({'x-test': 'abc'}) + cache.save(session) + session = cache.get() + self.assertEquals(session.headers.get('x-test'), None) + cache.delete() + session = cache.get() + self.assertEquals(session.headers.get('x-test'), None) + + def test_uses_wakatime_home_env_variable(self): + with utils.TemporaryDirectory() as tempdir: + expected = os.path.realpath(os.path.join(os.path.expanduser('~'), '.wakatime.db')) + + cache = SessionCache() + actual = cache.get_db_file() + self.assertEquals(actual, expected) + + with utils.mock.patch('os.environ.get') as mock_env: + mock_env.return_value = os.path.realpath(tempdir) + + expected = os.path.realpath(os.path.join(tempdir, '.wakatime.db')) + actual = cache.get_db_file() + self.assertEquals(actual, expected) diff --git a/wakatime/offlinequeue.py b/wakatime/offlinequeue.py index bbeb97c..a7826ba 100644 --- a/wakatime/offlinequeue.py +++ b/wakatime/offlinequeue.py @@ -27,11 +27,14 @@ log = logging.getLogger('WakaTime') class Queue(object): - db_file = os.path.join(os.path.expanduser('~'), '.wakatime.db') + db_file = '.wakatime.db' table_name = 'heartbeat_1' def get_db_file(self): - return self.db_file + home = '~' + if os.environ.get('WAKATIME_HOME'): + home = os.environ.get('WAKATIME_HOME') + return os.path.join(os.path.expanduser(home), '.wakatime.db') def connect(self): conn = sqlite3.connect(self.get_db_file(), isolation_level=None) diff --git a/wakatime/session_cache.py b/wakatime/session_cache.py index 7a2ff28..80f5ea0 100644 --- a/wakatime/session_cache.py +++ b/wakatime/session_cache.py @@ -30,14 +30,21 @@ log = logging.getLogger('WakaTime') class SessionCache(object): - DB_FILE = os.path.join(os.path.expanduser('~'), '.wakatime.db') + db_file = '.wakatime.db' + table_name = 'session' + + def get_db_file(self): + home = '~' + if os.environ.get('WAKATIME_HOME'): + home = os.environ.get('WAKATIME_HOME') + return os.path.join(os.path.expanduser(home), '.wakatime.db') def connect(self): - conn = sqlite3.connect(self.DB_FILE, isolation_level=None) + conn = sqlite3.connect(self.get_db_file(), isolation_level=None) c = conn.cursor() - c.execute('''CREATE TABLE IF NOT EXISTS session ( + c.execute('''CREATE TABLE IF NOT EXISTS {0} ( value BLOB) - ''') + '''.format(self.table_name)) return (conn, c) def save(self, session): @@ -48,11 +55,11 @@ class SessionCache(object): return try: conn, c = self.connect() - c.execute('DELETE FROM session') + c.execute('DELETE FROM {0}'.format(self.table_name)) values = { 'value': sqlite3.Binary(pickle.dumps(session, protocol=2)), } - c.execute('INSERT INTO session VALUES (:value)', values) + c.execute('INSERT INTO {0} VALUES (:value)'.format(self.table_name), values) conn.commit() conn.close() except: # pragma: nocover @@ -76,7 +83,7 @@ class SessionCache(object): session = None try: c.execute('BEGIN IMMEDIATE') - c.execute('SELECT value FROM session LIMIT 1') + c.execute('SELECT value FROM {0} LIMIT 1'.format(self.table_name)) row = c.fetchone() if row is not None: session = pickle.loads(row[0]) @@ -98,7 +105,7 @@ class SessionCache(object): return try: conn, c = self.connect() - c.execute('DELETE FROM session') + c.execute('DELETE FROM {0}'.format(self.table_name)) conn.commit() conn.close() except: