upgrade wakatime-cli to v10.0.0

This commit is contained in:
Alan Hamlett 2017-11-08 23:12:05 -08:00
parent 729a4360ba
commit 02786a744e
11 changed files with 477 additions and 435 deletions

View file

@ -1,7 +1,7 @@
__title__ = 'wakatime' __title__ = 'wakatime'
__description__ = 'Common interface to the WakaTime api.' __description__ = 'Common interface to the WakaTime api.'
__url__ = 'https://github.com/wakatime/wakatime' __url__ = 'https://github.com/wakatime/wakatime'
__version_info__ = ('9', '0', '1') __version_info__ = ('10', '0', '0')
__version__ = '.'.join(__version_info__) __version__ = '.'.join(__version_info__)
__author__ = 'Alan Hamlett' __author__ = 'Alan Hamlett'
__author_email__ = 'alan@wakatime.com' __author_email__ = 'alan@wakatime.com'

177
packages/wakatime/api.py Normal file
View file

@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
"""
wakatime.api
~~~~~~~~~~~~
:copyright: (c) 2017 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
from __future__ import print_function
import base64
import logging
import sys
import traceback
from .compat import u, is_py3, json
from .constants import (
API_ERROR,
AUTH_ERROR,
SUCCESS,
UNKNOWN_ERROR,
)
from .offlinequeue import Queue
from .packages.requests.exceptions import RequestException
from .session_cache import SessionCache
from .utils import get_hostname, get_user_agent
from .packages import tzlocal
log = logging.getLogger('WakaTime')
try:
from .packages import requests
except ImportError:
log.traceback(logging.ERROR)
print(traceback.format_exc())
log.error('Please upgrade Python to the latest version.')
print('Please upgrade Python to the latest version.')
sys.exit(UNKNOWN_ERROR)
def send_heartbeats(heartbeats, args, configs, use_ntlm_proxy=False):
"""Send heartbeats to WakaTime API.
Returns `SUCCESS` when heartbeat was sent, otherwise returns an error code.
"""
if len(heartbeats) == 0:
return SUCCESS
api_url = args.api_url
if not api_url:
api_url = 'https://api.wakatime.com/api/v1/heartbeats.bulk'
log.debug('Sending heartbeats to api at %s' % api_url)
timeout = args.timeout
if not timeout:
timeout = 60
data = [h.sanitize().dict() for h in heartbeats]
log.debug(data)
# setup api request
request_body = json.dumps(data)
api_key = u(base64.b64encode(str.encode(args.key) if is_py3 else args.key))
auth = u('Basic {api_key}').format(api_key=api_key)
headers = {
'User-Agent': get_user_agent(args.plugin),
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': auth,
}
hostname = get_hostname(args)
if hostname:
headers['X-Machine-Name'] = u(hostname).encode('utf-8')
# add Olson timezone to request
try:
tz = tzlocal.get_localzone()
except:
tz = None
if tz:
headers['TimeZone'] = u(tz.zone).encode('utf-8')
session_cache = SessionCache()
session = session_cache.get()
should_try_ntlm = False
proxies = {}
if args.proxy:
if use_ntlm_proxy:
from .packages.requests_ntlm import HttpNtlmAuth
username = args.proxy.rsplit(':', 1)
password = ''
if len(username) == 2:
password = username[1]
username = username[0]
session.auth = HttpNtlmAuth(username, password, session)
else:
should_try_ntlm = '\\' in args.proxy
proxies['https'] = args.proxy
# send request to api
response, code = None, None
try:
response = session.post(api_url, data=request_body, headers=headers,
proxies=proxies, timeout=timeout,
verify=not args.nosslverify)
except RequestException:
if should_try_ntlm:
return send_heartbeats(heartbeats, args, configs, use_ntlm_proxy=True)
else:
exception_data = {
sys.exc_info()[0].__name__: u(sys.exc_info()[1]),
}
if log.isEnabledFor(logging.DEBUG):
exception_data['traceback'] = traceback.format_exc()
if args.offline:
queue = Queue(args, configs)
queue.push_many(heartbeats)
if log.isEnabledFor(logging.DEBUG):
log.warn(exception_data)
else:
log.error(exception_data)
except: # delete cached session when requests raises unknown exception
if should_try_ntlm:
return send_heartbeats(heartbeats, args, configs, use_ntlm_proxy=True)
else:
exception_data = {
sys.exc_info()[0].__name__: u(sys.exc_info()[1]),
'traceback': traceback.format_exc(),
}
if args.offline:
queue = Queue(args, configs)
queue.push_many(heartbeats)
log.warn(exception_data)
else:
code = response.status_code if response is not None else None
content = response.text if response is not None else None
if code == requests.codes.created or code == requests.codes.accepted:
log.debug({
'response_code': code,
})
session_cache.save(session)
return SUCCESS
if should_try_ntlm:
return send_heartbeats(heartbeats, args, configs, use_ntlm_proxy=True)
else:
if args.offline:
if code == 400:
log.error({
'response_code': code,
'response_content': content,
})
else:
if log.isEnabledFor(logging.DEBUG):
log.warn({
'response_code': code,
'response_content': content,
})
queue = Queue(args, configs)
queue.push_many(heartbeats)
else:
log.error({
'response_code': code,
'response_content': content,
})
session_cache.delete()
return AUTH_ERROR if code == 401 else API_ERROR

View file

@ -45,7 +45,7 @@ class StoreWithoutQuotes(argparse.Action):
setattr(namespace, self.dest, values) setattr(namespace, self.dest, values)
def parseArguments(): def parse_arguments():
"""Parse command line arguments and configs from ~/.wakatime.cfg. """Parse command line arguments and configs from ~/.wakatime.cfg.
Command line arguments take precedence over config file settings. Command line arguments take precedence over config file settings.
Returns instances of ArgumentParser and SafeConfigParser. Returns instances of ArgumentParser and SafeConfigParser.

View file

@ -9,6 +9,7 @@
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
import codecs import codecs
import sys import sys
@ -91,3 +92,9 @@ except ImportError: # pragma: nocover
name = _resolve_name(name[level:], package, level) name = _resolve_name(name[level:], package, level)
__import__(name) __import__(name)
return sys.modules[name] return sys.modules[name]
try:
from .packages import simplejson as json
except (ImportError, SyntaxError):
import json

View file

@ -34,11 +34,6 @@ Exit code used when there was an unhandled exception.
""" """
UNKNOWN_ERROR = 105 UNKNOWN_ERROR = 105
""" Malformed Heartbeat Error
Exit code used when the JSON input from `--extra-heartbeats` is malformed.
"""
MALFORMED_HEARTBEAT_ERROR = 106
""" Connection Error """ Connection Error
Exit code used when there was proxy or other problem connecting to the WakaTime Exit code used when there was proxy or other problem connecting to the WakaTime
API servers. API servers.

View file

@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
"""
wakatime.heartbeat
~~~~~~~~~~~~~~~~~~
:copyright: (c) 2017 Alan Hamlett.
:license: BSD, see LICENSE for more details.
"""
import os
import logging
import re
from .compat import u, json
from .project import get_project_info
from .stats import get_file_stats
from .utils import get_user_agent, should_exclude, format_file_path
log = logging.getLogger('WakaTime')
class Heartbeat(object):
"""Heartbeat data for sending to API or storing in offline cache."""
skip = False
args = None
configs = None
time = None
entity = None
type = None
is_write = None
project = None
branch = None
language = None
dependencies = None
lines = None
lineno = None
cursorpos = None
user_agent = None
def __init__(self, data, args, configs, _clone=None):
self.args = args
self.configs = configs
self.entity = data.get('entity')
self.time = data.get('time', data.get('timestamp'))
self.is_write = data.get('is_write')
self.user_agent = data.get('user_agent') or get_user_agent(args.plugin)
self.type = data.get('type', data.get('entity_type'))
if self.type not in ['file', 'domain', 'app']:
self.type = 'file'
if not _clone:
exclude = self._excluded_by_pattern()
if exclude:
self.skip = u('Skipping because matches exclude pattern: {pattern}').format(
pattern=u(exclude),
)
return
if self.type == 'file':
self.entity = format_file_path(self.entity)
if self.type == 'file' and not os.path.isfile(self.entity):
self.skip = u('File does not exist; ignoring this heartbeat.')
return
project, branch = get_project_info(configs, self, data)
self.project = project
self.branch = branch
stats = get_file_stats(self.entity,
entity_type=self.type,
lineno=data.get('lineno'),
cursorpos=data.get('cursorpos'),
plugin=args.plugin,
language=data.get('language'))
else:
self.project = data.get('project')
self.branch = data.get('branch')
stats = data
for key in ['language', 'dependencies', 'lines', 'lineno', 'cursorpos']:
if stats.get(key) is not None:
setattr(self, key, stats[key])
def update(self, attrs):
"""Return a copy of the current Heartbeat with updated attributes."""
data = self.dict()
data.update(attrs)
heartbeat = Heartbeat(data, self.args, self.configs, _clone=True)
heartbeat.skip = self.skip
return heartbeat
def sanitize(self):
"""Removes sensitive data including file names and dependencies.
Returns a Heartbeat.
"""
if not self.args.hidefilenames:
return self
if self.entity is None:
return self
if self.type != 'file':
return self
for pattern in self.args.hidefilenames:
try:
compiled = re.compile(pattern, re.IGNORECASE)
if compiled.search(self.entity):
sanitized = {}
sensitive = ['dependencies', 'lines', 'lineno', 'cursorpos', 'branch']
for key, val in self.items():
if key in sensitive:
sanitized[key] = None
else:
sanitized[key] = val
extension = u(os.path.splitext(self.entity)[1])
sanitized['entity'] = u('HIDDEN{0}').format(extension)
return self.update(sanitized)
except re.error as ex:
log.warning(u('Regex error ({msg}) for include pattern: {pattern}').format(
msg=u(ex),
pattern=u(pattern),
))
return self
def json(self):
return json.dumps(self.dict())
def dict(self):
return {
'time': self.time,
'entity': self.entity,
'type': self.type,
'is_write': self.is_write,
'project': self.project,
'branch': self.branch,
'language': self.language,
'dependencies': self.dependencies,
'lines': self.lines,
'lineno': self.lineno,
'cursorpos': self.cursorpos,
'user_agent': self.user_agent,
}
def items(self):
return self.dict().items()
def get_id(self):
return u('{h.time}-{h.type}-{h.project}-{h.branch}-{h.entity}-{h.is_write}').format(
h=self,
)
def _excluded_by_pattern(self):
return should_exclude(self.entity, self.args.include, self.args.exclude)
def __repr__(self):
return self.json()
def __bool__(self):
return not self.skip
def __nonzero__(self):
return self.__bool__()
def __getitem__(self, key):
return self.dict()[key]

View file

@ -11,387 +11,67 @@
from __future__ import print_function from __future__ import print_function
import base64
import logging import logging
import os import os
import re
import sys import sys
import traceback import traceback
import socket
pwd = os.path.dirname(os.path.abspath(__file__)) pwd = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(pwd)) sys.path.insert(0, os.path.dirname(pwd))
sys.path.insert(0, os.path.join(pwd, 'packages')) sys.path.insert(0, os.path.join(pwd, 'packages'))
from .__about__ import __version__ from .__about__ import __version__
from .arguments import parseArguments from .api import send_heartbeats
from .compat import u, is_py3 from .arguments import parse_arguments
from .compat import u, json
from .constants import ( from .constants import (
API_ERROR,
AUTH_ERROR,
SUCCESS, SUCCESS,
UNKNOWN_ERROR, UNKNOWN_ERROR,
MALFORMED_HEARTBEAT_ERROR,
) )
from .logger import setup_logging from .logger import setup_logging
log = logging.getLogger('WakaTime') log = logging.getLogger('WakaTime')
try: from .heartbeat import Heartbeat
from .packages import requests
except ImportError:
log.traceback(logging.ERROR)
print(traceback.format_exc())
log.error('Please upgrade Python to the latest version.')
print('Please upgrade Python to the latest version.')
sys.exit(UNKNOWN_ERROR)
from .offlinequeue import Queue from .offlinequeue import Queue
from .packages.requests.exceptions import RequestException
from .project import get_project_info
from .session_cache import SessionCache
from .stats import get_file_stats
from .utils import get_user_agent, should_exclude, format_file_path
try:
from .packages import simplejson as json # pragma: nocover
except (ImportError, SyntaxError): # pragma: nocover
import json
from .packages import tzlocal
def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None,
entity=None, timestamp=None, is_write=None, plugin=None,
offline=None, entity_type='file', hidefilenames=None,
proxy=None, nosslverify=None, api_url=None, timeout=None,
use_ntlm_proxy=False, **kwargs):
"""Sends heartbeat as POST request to WakaTime api server.
Returns `SUCCESS` when heartbeat was sent, otherwise returns an
error code constant.
"""
if not api_url:
api_url = 'https://api.wakatime.com/api/v1/heartbeats'
if not timeout:
timeout = 60
log.debug('Sending heartbeat to api at %s' % api_url)
data = {
'time': timestamp,
'entity': entity,
'type': entity_type,
}
if stats.get('lines'):
data['lines'] = stats['lines']
if stats.get('language'):
data['language'] = stats['language']
if stats.get('dependencies'):
data['dependencies'] = stats['dependencies']
if stats.get('lineno'):
data['lineno'] = stats['lineno']
if stats.get('cursorpos'):
data['cursorpos'] = stats['cursorpos']
if is_write:
data['is_write'] = is_write
if project:
data['project'] = project
if branch:
data['branch'] = branch
if hidefilenames and entity is not None and entity_type == 'file':
for pattern in hidefilenames:
try:
compiled = re.compile(pattern, re.IGNORECASE)
if compiled.search(entity):
extension = u(os.path.splitext(data['entity'])[1])
data['entity'] = u('HIDDEN{0}').format(extension)
# also delete any sensitive info when hiding file names
sensitive = ['dependencies', 'lines', 'lineno', 'cursorpos', 'branch']
for sensitiveKey in sensitive:
if sensitiveKey in data:
del data[sensitiveKey]
break
except re.error as ex:
log.warning(u('Regex error ({msg}) for include pattern: {pattern}').format(
msg=u(ex),
pattern=u(pattern),
))
log.debug(data)
# setup api request
request_body = json.dumps(data)
api_key = u(base64.b64encode(str.encode(key) if is_py3 else key))
auth = u('Basic {api_key}').format(api_key=api_key)
headers = {
'User-Agent': get_user_agent(plugin),
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': auth,
}
if hostname:
headers['X-Machine-Name'] = u(hostname).encode('utf-8')
# add Olson timezone to request
try:
tz = tzlocal.get_localzone()
except:
tz = None
if tz:
headers['TimeZone'] = u(tz.zone).encode('utf-8')
session_cache = SessionCache()
session = session_cache.get()
should_try_ntlm = False
proxies = {}
if proxy:
if use_ntlm_proxy:
from .packages.requests_ntlm import HttpNtlmAuth
username = proxy.rsplit(':', 1)
password = ''
if len(username) == 2:
password = username[1]
username = username[0]
session.auth = HttpNtlmAuth(username, password, session)
else:
should_try_ntlm = '\\' in proxy
proxies['https'] = proxy
# send request to api
response = None
try:
response = session.post(api_url, data=request_body, headers=headers,
proxies=proxies, timeout=timeout,
verify=not nosslverify)
except RequestException:
if should_try_ntlm:
return send_heartbeat(
project=project,
entity=entity,
timestamp=timestamp,
branch=branch,
hostname=hostname,
stats=stats,
key=key,
is_write=is_write,
plugin=plugin,
offline=offline,
hidefilenames=hidefilenames,
entity_type=entity_type,
proxy=proxy,
api_url=api_url,
timeout=timeout,
use_ntlm_proxy=True,
)
else:
exception_data = {
sys.exc_info()[0].__name__: u(sys.exc_info()[1]),
}
if log.isEnabledFor(logging.DEBUG):
exception_data['traceback'] = traceback.format_exc()
if offline:
queue = Queue()
queue.push(data, json.dumps(stats), plugin)
if log.isEnabledFor(logging.DEBUG):
log.warn(exception_data)
else:
log.error(exception_data)
except: # delete cached session when requests raises unknown exception
if should_try_ntlm:
return send_heartbeat(
project=project,
entity=entity,
timestamp=timestamp,
branch=branch,
hostname=hostname,
stats=stats,
key=key,
is_write=is_write,
plugin=plugin,
offline=offline,
hidefilenames=hidefilenames,
entity_type=entity_type,
proxy=proxy,
api_url=api_url,
timeout=timeout,
use_ntlm_proxy=True,
)
else:
exception_data = {
sys.exc_info()[0].__name__: u(sys.exc_info()[1]),
'traceback': traceback.format_exc(),
}
if offline:
queue = Queue()
queue.push(data, json.dumps(stats), plugin)
log.warn(exception_data)
else:
code = response.status_code if response is not None else None
content = response.text if response is not None else None
if code == requests.codes.created or code == requests.codes.accepted:
log.debug({
'response_code': code,
})
session_cache.save(session)
return SUCCESS
if should_try_ntlm:
return send_heartbeat(
project=project,
entity=entity,
timestamp=timestamp,
branch=branch,
hostname=hostname,
stats=stats,
key=key,
is_write=is_write,
plugin=plugin,
offline=offline,
hidefilenames=hidefilenames,
entity_type=entity_type,
proxy=proxy,
api_url=api_url,
timeout=timeout,
use_ntlm_proxy=True,
)
else:
if offline:
if code != 400:
queue = Queue()
queue.push(data, json.dumps(stats), plugin)
if code == 401:
log.error({
'response_code': code,
'response_content': content,
})
session_cache.delete()
return AUTH_ERROR
elif log.isEnabledFor(logging.DEBUG):
log.warn({
'response_code': code,
'response_content': content,
})
else:
log.error({
'response_code': code,
'response_content': content,
})
else:
log.error({
'response_code': code,
'response_content': content,
})
session_cache.delete()
return API_ERROR
def sync_offline_heartbeats(args, hostname):
"""Sends all heartbeats which were cached in the offline Queue."""
queue = Queue()
while True:
heartbeat = queue.pop()
if heartbeat is None:
break
status = send_heartbeat(
project=heartbeat['project'],
entity=heartbeat['entity'],
timestamp=heartbeat['time'],
branch=heartbeat['branch'],
hostname=hostname,
stats=json.loads(heartbeat['stats']),
key=args.key,
is_write=heartbeat['is_write'],
plugin=heartbeat['plugin'],
offline=args.offline,
hidefilenames=args.hidefilenames,
entity_type=heartbeat['type'],
proxy=args.proxy,
api_url=args.api_url,
timeout=args.timeout,
)
if status != SUCCESS:
if status == AUTH_ERROR:
return AUTH_ERROR
break
return SUCCESS
def process_heartbeat(args, configs, hostname, heartbeat):
exclude = should_exclude(heartbeat['entity'], args.include, args.exclude)
if exclude is not False:
log.debug(u('Skipping because matches exclude pattern: {pattern}').format(
pattern=u(exclude),
))
return SUCCESS
if heartbeat.get('entity_type') not in ['file', 'domain', 'app']:
heartbeat['entity_type'] = 'file'
if heartbeat['entity_type'] == 'file':
heartbeat['entity'] = format_file_path(heartbeat['entity'])
if heartbeat['entity_type'] != 'file' or os.path.isfile(heartbeat['entity']):
stats = get_file_stats(heartbeat['entity'],
entity_type=heartbeat['entity_type'],
lineno=heartbeat.get('lineno'),
cursorpos=heartbeat.get('cursorpos'),
plugin=args.plugin,
language=heartbeat.get('language'))
project = heartbeat.get('project') or heartbeat.get('alternate_project')
branch = None
if heartbeat['entity_type'] == 'file':
project, branch = get_project_info(configs, heartbeat)
heartbeat['project'] = project
heartbeat['branch'] = branch
heartbeat['stats'] = stats
heartbeat['hostname'] = hostname
heartbeat['timeout'] = args.timeout
heartbeat['key'] = args.key
heartbeat['plugin'] = args.plugin
heartbeat['offline'] = args.offline
heartbeat['hidefilenames'] = args.hidefilenames
heartbeat['proxy'] = args.proxy
heartbeat['nosslverify'] = args.nosslverify
heartbeat['api_url'] = args.api_url
return send_heartbeat(**heartbeat)
else:
log.debug('File does not exist; ignoring this heartbeat.')
return SUCCESS
def execute(argv=None): def execute(argv=None):
if argv: if argv:
sys.argv = ['wakatime'] + argv sys.argv = ['wakatime'] + argv
args, configs = parseArguments() args, configs = parse_arguments()
setup_logging(args, __version__) setup_logging(args, __version__)
try: try:
heartbeats = []
hostname = args.hostname or socket.gethostname() hb = Heartbeat(vars(args), args, configs)
if hb:
heartbeat = vars(args) heartbeats.append(hb)
retval = process_heartbeat(args, configs, hostname, heartbeat) else:
log.debug(hb.skip)
if args.extra_heartbeats: if args.extra_heartbeats:
try: try:
for heartbeat in json.loads(sys.stdin.readline()): for extra_data in json.loads(sys.stdin.readline()):
retval = process_heartbeat(args, configs, hostname, heartbeat) hb = Heartbeat(extra_data, args, configs)
except json.JSONDecodeError: if hb:
retval = MALFORMED_HEARTBEAT_ERROR heartbeats.append(hb)
else:
log.debug(hb.skip)
except json.JSONDecodeError as ex:
log.warning(u('Malformed extra heartbeats json: {msg}').format(
msg=u(ex),
))
retval = send_heartbeats(heartbeats, args, configs)
if retval == SUCCESS: if retval == SUCCESS:
retval = sync_offline_heartbeats(args, hostname) queue = Queue(args, configs)
offline_heartbeats = queue.pop_many()
if len(offline_heartbeats) > 0:
retval = send_heartbeats(offline_heartbeats, args, configs)
return retval return retval

View file

@ -14,77 +14,68 @@ import logging
import os import os
from time import sleep from time import sleep
from .compat import json
from .heartbeat import Heartbeat
try: try:
import sqlite3 import sqlite3
HAS_SQL = True HAS_SQL = True
except ImportError: # pragma: nocover except ImportError: # pragma: nocover
HAS_SQL = False HAS_SQL = False
from .compat import u
log = logging.getLogger('WakaTime') log = logging.getLogger('WakaTime')
class Queue(object): class Queue(object):
db_file = '.wakatime.db' db_file = '.wakatime.db'
table_name = 'heartbeat_1' table_name = 'heartbeat_2'
def get_db_file(self): args = None
home = '~' configs = None
if os.environ.get('WAKATIME_HOME'):
home = os.environ.get('WAKATIME_HOME') def __init__(self, args, configs):
return os.path.join(os.path.expanduser(home), '.wakatime.db') self.args = args
self.configs = configs
def connect(self): def connect(self):
conn = sqlite3.connect(self.get_db_file(), isolation_level=None) conn = sqlite3.connect(self._get_db_file(), isolation_level=None)
c = conn.cursor() c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS {0} ( c.execute('''CREATE TABLE IF NOT EXISTS {0} (
entity text, id text,
type text, heartbeat text)
time real,
project text,
branch text,
is_write integer,
stats text,
misc text,
plugin text)
'''.format(self.table_name)) '''.format(self.table_name))
return (conn, c) return (conn, c)
def push(self, data, stats, plugin, misc=None): def push(self, heartbeat):
if not HAS_SQL: # pragma: nocover if not HAS_SQL:
return return
try: try:
conn, c = self.connect() conn, c = self.connect()
heartbeat = { data = {
'entity': u(data.get('entity')), 'id': heartbeat.get_id(),
'type': u(data.get('type')), 'heartbeat': heartbeat.json(),
'time': data.get('time'),
'project': u(data.get('project')),
'branch': u(data.get('branch')),
'is_write': 1 if data.get('is_write') else 0,
'stats': u(stats),
'misc': u(misc),
'plugin': u(plugin),
} }
c.execute('INSERT INTO {0} VALUES (:entity,:type,:time,:project,:branch,:is_write,:stats,:misc,:plugin)'.format(self.table_name), heartbeat) c.execute('INSERT INTO {0} VALUES (:id,:heartbeat)'.format(self.table_name), data)
conn.commit() conn.commit()
conn.close() conn.close()
except sqlite3.Error: except sqlite3.Error:
log.traceback() log.traceback()
def pop(self): def pop(self):
if not HAS_SQL: # pragma: nocover if not HAS_SQL:
return None return None
tries = 3 tries = 3
wait = 0.1 wait = 0.1
heartbeat = None
try: try:
conn, c = self.connect() conn, c = self.connect()
except sqlite3.Error: except sqlite3.Error:
log.traceback(logging.DEBUG) log.traceback(logging.DEBUG)
return None return None
heartbeat = None
loop = True loop = True
while loop and tries > -1: while loop and tries > -1:
try: try:
@ -92,40 +83,43 @@ class Queue(object):
c.execute('SELECT * FROM {0} LIMIT 1'.format(self.table_name)) c.execute('SELECT * FROM {0} LIMIT 1'.format(self.table_name))
row = c.fetchone() row = c.fetchone()
if row is not None: if row is not None:
values = [] id = row[0]
clauses = [] heartbeat = Heartbeat(json.loads(row[1]), self.args, self.configs, _clone=True)
index = 0 c.execute('DELETE FROM {0} WHERE id=?'.format(self.table_name), [id])
for row_name in ['entity', 'type', 'time', 'project', 'branch', 'is_write']:
if row[index] is not None:
clauses.append('{0}=?'.format(row_name))
values.append(row[index])
else: # pragma: nocover
clauses.append('{0} IS NULL'.format(row_name))
index += 1
if len(values) > 0:
c.execute('DELETE FROM {0} WHERE {1}'.format(self.table_name, ' AND '.join(clauses)), values)
else: # pragma: nocover
c.execute('DELETE FROM {0} WHERE {1}'.format(self.table_name, ' AND '.join(clauses)))
conn.commit() conn.commit()
if row is not None:
heartbeat = {
'entity': row[0],
'type': row[1],
'time': row[2],
'project': row[3],
'branch': row[4],
'is_write': True if row[5] is 1 else False,
'stats': row[6],
'misc': row[7],
'plugin': row[8],
}
loop = False loop = False
except sqlite3.Error: # pragma: nocover except sqlite3.Error:
log.traceback(logging.DEBUG) log.traceback(logging.DEBUG)
sleep(wait) sleep(wait)
tries -= 1 tries -= 1
try: try:
conn.close() conn.close()
except sqlite3.Error: # pragma: nocover except sqlite3.Error:
log.traceback(logging.DEBUG) log.traceback(logging.DEBUG)
return heartbeat return heartbeat
def push_many(self, heartbeats):
for heartbeat in heartbeats:
self.push(heartbeat)
def pop_many(self, limit=None):
if limit is None:
limit = 100
heartbeats = []
count = 0
while limit == 0 or count < limit:
heartbeat = self.pop()
if not heartbeat:
break
heartbeats.append(heartbeat)
count += 1
return heartbeats
def _get_db_file(self):
home = '~'
if os.environ.get('WAKATIME_HOME'):
home = os.environ.get('WAKATIME_HOME')
return os.path.join(os.path.expanduser(home), '.wakatime.db')

View file

@ -33,7 +33,7 @@ REV_CONTROL_PLUGINS = [
] ]
def get_project_info(configs, heartbeat): def get_project_info(configs, heartbeat, data):
"""Find the current project and branch. """Find the current project and branch.
First looks for a .wakatime-project file. Second, uses the --project arg. First looks for a .wakatime-project file. Second, uses the --project arg.
@ -43,21 +43,27 @@ def get_project_info(configs, heartbeat):
Returns a project, branch tuple. Returns a project, branch tuple.
""" """
project_name, branch_name = None, None project_name, branch_name = heartbeat.project, heartbeat.branch
if heartbeat.type != 'file':
project_name = project_name or heartbeat.args.project or heartbeat.args.alternate_project
return project_name, branch_name
if project_name is None or branch_name is None:
for plugin_cls in CONFIG_PLUGINS: for plugin_cls in CONFIG_PLUGINS:
plugin_name = plugin_cls.__name__.lower() plugin_name = plugin_cls.__name__.lower()
plugin_configs = get_configs_for_plugin(plugin_name, configs) plugin_configs = get_configs_for_plugin(plugin_name, configs)
project = plugin_cls(heartbeat['entity'], configs=plugin_configs) project = plugin_cls(heartbeat.entity, configs=plugin_configs)
if project.process(): if project.process():
project_name = project_name or project.name() project_name = project_name or project.name()
branch_name = project.branch() branch_name = project.branch()
break break
if project_name is None: if project_name is None:
project_name = heartbeat.get('project') project_name = data.get('project') or heartbeat.args.project
if project_name is None or branch_name is None: if project_name is None or branch_name is None:
@ -66,14 +72,14 @@ def get_project_info(configs, heartbeat):
plugin_name = plugin_cls.__name__.lower() plugin_name = plugin_cls.__name__.lower()
plugin_configs = get_configs_for_plugin(plugin_name, configs) plugin_configs = get_configs_for_plugin(plugin_name, configs)
project = plugin_cls(heartbeat['entity'], configs=plugin_configs) project = plugin_cls(heartbeat.entity, configs=plugin_configs)
if project.process(): if project.process():
project_name = project_name or project.name() project_name = project_name or project.name()
branch_name = branch_name or project.branch() branch_name = branch_name or project.branch()
break break
if project_name is None: if project_name is None:
project_name = heartbeat.get('alternate_project') project_name = data.get('alternate_project') or heartbeat.args.alternate_project
return project_name, branch_name return project_name, branch_name

View file

@ -33,14 +33,8 @@ class SessionCache(object):
db_file = '.wakatime.db' db_file = '.wakatime.db'
table_name = 'session' table_name = 'session'
def get_db_file(self):
home = '~'
if os.environ.get('WAKATIME_HOME'):
home = os.environ.get('WAKATIME_HOME')
return os.path.join(os.path.expanduser(home), '.wakatime.db')
def connect(self): def connect(self):
conn = sqlite3.connect(self.get_db_file(), isolation_level=None) conn = sqlite3.connect(self._get_db_file(), isolation_level=None)
c = conn.cursor() c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS {0} ( c.execute('''CREATE TABLE IF NOT EXISTS {0} (
value BLOB) value BLOB)
@ -110,3 +104,9 @@ class SessionCache(object):
conn.close() conn.close()
except: except:
log.traceback(logging.DEBUG) log.traceback(logging.DEBUG)
def _get_db_file(self):
home = '~'
if os.environ.get('WAKATIME_HOME'):
home = os.environ.get('WAKATIME_HOME')
return os.path.join(os.path.expanduser(home), '.wakatime.db')

View file

@ -14,6 +14,7 @@ import platform
import logging import logging
import os import os
import re import re
import socket
import sys import sys
from .__about__ import __version__ from .__about__ import __version__
@ -48,7 +49,7 @@ def should_exclude(entity, include, exclude):
return False return False
def get_user_agent(plugin): def get_user_agent(plugin=None):
ver = sys.version_info ver = sys.version_info
python_version = '%d.%d.%d.%s.%d' % (ver[0], ver[1], ver[2], ver[3], ver[4]) python_version = '%d.%d.%d.%s.%d' % (ver[0], ver[1], ver[2], ver[3], ver[4])
user_agent = u('wakatime/{ver} ({platform}) Python{py_ver}').format( user_agent = u('wakatime/{ver} ({platform}) Python{py_ver}').format(
@ -77,3 +78,7 @@ def format_file_path(filepath):
except: # pragma: nocover except: # pragma: nocover
pass pass
return filepath return filepath
def get_hostname(args):
return args.hostname or socket.gethostname()