use WAKATIME_HOME env variable for session and offline caching

This commit is contained in:
Alan Hamlett 2017-10-29 11:16:23 -07:00
parent c8237c30c5
commit 6cd300a30b
6 changed files with 101 additions and 32 deletions

View file

@ -11,6 +11,7 @@ import shutil
import sys import sys
import uuid import uuid
from testfixtures import log_capture from testfixtures import log_capture
from wakatime.arguments import parseArguments
from wakatime.compat import u from wakatime.compat import u
from wakatime.constants import ( from wakatime.constants import (
API_ERROR, API_ERROR,
@ -620,3 +621,24 @@ class ArgumentsTestCase(utils.TestCase):
self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.push'].assert_not_called()
self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called() self.patched['wakatime.offlinequeue.Queue.pop'].assert_not_called()
def test_uses_wakatime_home_env_variable(self):
with utils.TemporaryDirectory() as tempdir:
entity = 'tests/samples/codefiles/twolinefile.txt'
shutil.copy(entity, os.path.join(tempdir, 'twolinefile.txt'))
entity = os.path.realpath(os.path.join(tempdir, 'twolinefile.txt'))
key = str(uuid.uuid4())
config = 'tests/samples/configs/good_config.cfg'
logfile = os.path.realpath(os.path.join(tempdir, '.wakatime.log'))
args = ['--file', entity, '--key', key, '--config', config]
with utils.mock.patch.object(sys, 'argv', ['wakatime'] + args):
args, configs = parseArguments()
self.assertEquals(args.logfile, None)
with utils.mock.patch('os.environ.get') as mock_env:
mock_env.return_value = os.path.realpath(tempdir)
args, configs = parseArguments()
self.assertEquals(args.logfile, logfile)

View file

@ -101,7 +101,7 @@ class LoggingTestCase(utils.TestCase):
config = 'tests/samples/configs/good_config.cfg' config = 'tests/samples/configs/good_config.cfg'
shutil.copy(config, os.path.join(tempdir, '.wakatime.cfg')) shutil.copy(config, os.path.join(tempdir, '.wakatime.cfg'))
config = os.path.realpath(os.path.join(tempdir, '.wakatime.cfg')) config = os.path.realpath(os.path.join(tempdir, '.wakatime.cfg'))
logfile = os.path.realpath(os.path.join(tempdir, '.wakatime.log')) expected_logfile = os.path.realpath(os.path.join(tempdir, '.wakatime.log'))
with utils.mock.patch('wakatime.main.os.environ.get') as mock_env: with utils.mock.patch('wakatime.main.os.environ.get') as mock_env:
mock_env.return_value = tempdir mock_env.return_value = tempdir
@ -116,8 +116,8 @@ class LoggingTestCase(utils.TestCase):
self.assertEquals(sys.stderr.getvalue(), '') self.assertEquals(sys.stderr.getvalue(), '')
self.assertEquals(logging.WARNING, logging.getLogger('WakaTime').level) self.assertEquals(logging.WARNING, logging.getLogger('WakaTime').level)
actual_logfile = os.path.realpath(logging.getLogger('WakaTime').handlers[0].baseFilename) logfile = os.path.realpath(logging.getLogger('WakaTime').handlers[0].baseFilename)
self.assertEquals(logfile, actual_logfile) self.assertEquals(logfile, expected_logfile)
logs.check() logs.check()
@log_capture() @log_capture()

View file

@ -199,12 +199,14 @@ class OfflineQueueTestCase(utils.TestCase):
class CustomResponse(Response): class CustomResponse(Response):
count = 0 count = 0
@property @property
def status_code(self): def status_code(self):
if self.count > 2: if self.count > 2:
return 401 return 401
self.count += 1 self.count += 1
return 201 return 201
@status_code.setter @status_code.setter
def status_code(self, value): def status_code(self, value):
pass pass
@ -254,12 +256,14 @@ class OfflineQueueTestCase(utils.TestCase):
class CustomResponse(Response): class CustomResponse(Response):
count = 0 count = 0
@property @property
def status_code(self): def status_code(self):
if self.count > 2: if self.count > 2:
return 500 return 500
self.count += 1 self.count += 1
return 201 return 201
@status_code.setter @status_code.setter
def status_code(self, value): def status_code(self, value):
pass pass
@ -354,12 +358,24 @@ class OfflineQueueTestCase(utils.TestCase):
saved_heartbeat = queue.pop() saved_heartbeat = queue.pop()
self.assertEquals(None, saved_heartbeat) self.assertEquals(None, saved_heartbeat)
def test_get_db_file(self): def test_uses_home_folder_by_default(self):
queue = Queue() queue = Queue()
db_file = queue.get_db_file() db_file = queue.get_db_file()
expected = os.path.join(os.path.expanduser('~'), '.wakatime.db') expected = os.path.join(os.path.expanduser('~'), '.wakatime.db')
self.assertEquals(db_file, expected) self.assertEquals(db_file, expected)
def test_uses_wakatime_home_env_variable(self):
queue = Queue()
with utils.TemporaryDirectory() as tempdir:
expected = os.path.realpath(os.path.join(tempdir, '.wakatime.db'))
with utils.mock.patch('os.environ.get') as mock_env:
mock_env.return_value = os.path.realpath(tempdir)
actual = queue.get_db_file()
self.assertEquals(actual, expected)
@log_capture() @log_capture()
def test_heartbeat_saved_when_requests_raises_exception(self, logs): def test_heartbeat_saved_when_requests_raises_exception(self, logs):
logging.disable(logging.NOTSET) logging.disable(logging.NOTSET)

View file

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os
from wakatime.session_cache import SessionCache from wakatime.session_cache import SessionCache
from wakatime.logger import setup_logging from wakatime.logger import setup_logging
from . import utils from . import utils
@ -16,6 +17,7 @@ class SessionCacheTestCase(utils.TestCase):
def setUp(self): def setUp(self):
super(SessionCacheTestCase, self).setUp() super(SessionCacheTestCase, self).setUp()
class MockArgs(object): class MockArgs(object):
timestamp = 0 timestamp = 0
is_write = False is_write = False
@ -30,7 +32,9 @@ class SessionCacheTestCase(utils.TestCase):
def test_can_crud_session(self): def test_can_crud_session(self):
with utils.NamedTemporaryFile() as fh: with utils.NamedTemporaryFile() as fh:
cache = SessionCache() cache = SessionCache()
cache.DB_FILE = fh.name
with utils.mock.patch('wakatime.session_cache.SessionCache.get_db_file') as mock_dbfile:
mock_dbfile.return_value = fh.name
session = cache.get() session = cache.get()
session.headers.update({'x-test': 'abc'}) session.headers.update({'x-test': 'abc'})
@ -44,7 +48,9 @@ class SessionCacheTestCase(utils.TestCase):
def test_handles_connection_exception(self): def test_handles_connection_exception(self):
with utils.NamedTemporaryFile() as fh: with utils.NamedTemporaryFile() as fh:
cache = SessionCache() cache = SessionCache()
cache.DB_FILE = fh.name
with utils.mock.patch('wakatime.session_cache.SessionCache.get_db_file') as mock_dbfile:
mock_dbfile.return_value = fh.name
with utils.mock.patch('wakatime.session_cache.SessionCache.connect') as mock_connect: with utils.mock.patch('wakatime.session_cache.SessionCache.connect') as mock_connect:
mock_connect.side_effect = OSError('') mock_connect.side_effect = OSError('')
@ -57,3 +63,18 @@ class SessionCacheTestCase(utils.TestCase):
cache.delete() cache.delete()
session = cache.get() session = cache.get()
self.assertEquals(session.headers.get('x-test'), None) self.assertEquals(session.headers.get('x-test'), None)
def test_uses_wakatime_home_env_variable(self):
with utils.TemporaryDirectory() as tempdir:
expected = os.path.realpath(os.path.join(os.path.expanduser('~'), '.wakatime.db'))
cache = SessionCache()
actual = cache.get_db_file()
self.assertEquals(actual, expected)
with utils.mock.patch('os.environ.get') as mock_env:
mock_env.return_value = os.path.realpath(tempdir)
expected = os.path.realpath(os.path.join(tempdir, '.wakatime.db'))
actual = cache.get_db_file()
self.assertEquals(actual, expected)

View file

@ -27,11 +27,14 @@ log = logging.getLogger('WakaTime')
class Queue(object): class Queue(object):
db_file = os.path.join(os.path.expanduser('~'), '.wakatime.db') db_file = '.wakatime.db'
table_name = 'heartbeat_1' table_name = 'heartbeat_1'
def get_db_file(self): def get_db_file(self):
return self.db_file home = '~'
if os.environ.get('WAKATIME_HOME'):
home = os.environ.get('WAKATIME_HOME')
return os.path.join(os.path.expanduser(home), '.wakatime.db')
def connect(self): def connect(self):
conn = sqlite3.connect(self.get_db_file(), isolation_level=None) conn = sqlite3.connect(self.get_db_file(), isolation_level=None)

View file

@ -30,14 +30,21 @@ log = logging.getLogger('WakaTime')
class SessionCache(object): class SessionCache(object):
DB_FILE = os.path.join(os.path.expanduser('~'), '.wakatime.db') db_file = '.wakatime.db'
table_name = 'session'
def get_db_file(self):
home = '~'
if os.environ.get('WAKATIME_HOME'):
home = os.environ.get('WAKATIME_HOME')
return os.path.join(os.path.expanduser(home), '.wakatime.db')
def connect(self): def connect(self):
conn = sqlite3.connect(self.DB_FILE, isolation_level=None) conn = sqlite3.connect(self.get_db_file(), isolation_level=None)
c = conn.cursor() c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS session ( c.execute('''CREATE TABLE IF NOT EXISTS {0} (
value BLOB) value BLOB)
''') '''.format(self.table_name))
return (conn, c) return (conn, c)
def save(self, session): def save(self, session):
@ -48,11 +55,11 @@ class SessionCache(object):
return return
try: try:
conn, c = self.connect() conn, c = self.connect()
c.execute('DELETE FROM session') c.execute('DELETE FROM {0}'.format(self.table_name))
values = { values = {
'value': sqlite3.Binary(pickle.dumps(session, protocol=2)), 'value': sqlite3.Binary(pickle.dumps(session, protocol=2)),
} }
c.execute('INSERT INTO session VALUES (:value)', values) c.execute('INSERT INTO {0} VALUES (:value)'.format(self.table_name), values)
conn.commit() conn.commit()
conn.close() conn.close()
except: # pragma: nocover except: # pragma: nocover
@ -76,7 +83,7 @@ class SessionCache(object):
session = None session = None
try: try:
c.execute('BEGIN IMMEDIATE') c.execute('BEGIN IMMEDIATE')
c.execute('SELECT value FROM session LIMIT 1') c.execute('SELECT value FROM {0} LIMIT 1'.format(self.table_name))
row = c.fetchone() row = c.fetchone()
if row is not None: if row is not None:
session = pickle.loads(row[0]) session = pickle.loads(row[0])
@ -98,7 +105,7 @@ class SessionCache(object):
return return
try: try:
conn, c = self.connect() conn, c = self.connect()
c.execute('DELETE FROM session') c.execute('DELETE FROM {0}'.format(self.table_name))
conn.commit() conn.commit()
conn.close() conn.close()
except: except: