upgrade wakatime cli to v4.1.1

This commit is contained in:
Alan Hamlett 2015-08-25 00:42:37 -07:00
parent b07b59e0c8
commit 7ea51d09ba
40 changed files with 622 additions and 448 deletions

View file

@ -1,7 +1,7 @@
__title__ = 'wakatime' __title__ = 'wakatime'
__description__ = 'Common interface to the WakaTime api.' __description__ = 'Common interface to the WakaTime api.'
__url__ = 'https://github.com/wakatime/wakatime' __url__ = 'https://github.com/wakatime/wakatime'
__version_info__ = ('4', '1', '0') __version_info__ = ('4', '1', '1')
__version__ = '.'.join(__version_info__) __version__ = '.'.join(__version_info__)
__author__ = 'Alan Hamlett' __author__ = 'Alan Hamlett'
__author_email__ = 'alan@wakatime.com' __author_email__ = 'alan@wakatime.com'

View file

@ -22,7 +22,7 @@ import traceback
import socket import socket
try: try:
import ConfigParser as configparser import ConfigParser as configparser
except ImportError: except ImportError: # pragma: nocover
import configparser import configparser
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@ -33,15 +33,18 @@ from .compat import u, open, is_py3
from .logger import setup_logging from .logger import setup_logging
from .offlinequeue import Queue from .offlinequeue import Queue
from .packages import argparse from .packages import argparse
from .packages import simplejson as json
from .packages.requests.exceptions import RequestException from .packages.requests.exceptions import RequestException
from .project import get_project_info from .project import get_project_info
from .session_cache import SessionCache from .session_cache import SessionCache
from .stats import get_file_stats from .stats import get_file_stats
try: try:
from .packages import tzlocal from .packages import simplejson as json # pragma: nocover
except: except (ImportError, SyntaxError):
from .packages import tzlocal3 as tzlocal import json # pragma: nocover
try:
from .packages import tzlocal # pragma: nocover
except: # pragma: nocover
from .packages import tzlocal3 as tzlocal # pragma: nocover
log = logging.getLogger('WakaTime') log = logging.getLogger('WakaTime')
@ -54,45 +57,6 @@ class FileAction(argparse.Action):
setattr(namespace, self.dest, values) setattr(namespace, self.dest, values)
def upgradeConfigFile(configFile):
"""For backwards-compatibility, upgrade the existing config file
to work with configparser and rename from .wakatime.conf to .wakatime.cfg.
"""
if os.path.isfile(configFile):
# if upgraded cfg file already exists, don't overwrite it
return
oldConfig = os.path.join(os.path.expanduser('~'), '.wakatime.conf')
try:
configs = {
'ignore': [],
}
with open(oldConfig, 'r', encoding='utf-8') as fh:
for line in fh.readlines():
line = line.split('=', 1)
if len(line) == 2 and line[0].strip() and line[1].strip():
if line[0].strip() == 'ignore':
configs['ignore'].append(line[1].strip())
else:
configs[line[0].strip()] = line[1].strip()
with open(configFile, 'w', encoding='utf-8') as fh:
fh.write("[settings]\n")
for name, value in configs.items():
if isinstance(value, list):
fh.write("%s=\n" % name)
for item in value:
fh.write(" %s\n" % item)
else:
fh.write("%s = %s\n" % (name, value))
os.remove(oldConfig)
except IOError:
pass
def parseConfigFile(configFile=None): def parseConfigFile(configFile=None):
"""Returns a configparser.SafeConfigParser instance with configs """Returns a configparser.SafeConfigParser instance with configs
read from the config file. Default location of the config file is read from the config file. Default location of the config file is
@ -102,8 +66,6 @@ def parseConfigFile(configFile=None):
if not configFile: if not configFile:
configFile = os.path.join(os.path.expanduser('~'), '.wakatime.cfg') configFile = os.path.join(os.path.expanduser('~'), '.wakatime.cfg')
upgradeConfigFile(configFile)
configs = configparser.SafeConfigParser() configs = configparser.SafeConfigParser()
try: try:
with open(configFile, 'r', encoding='utf-8') as fh: with open(configFile, 'r', encoding='utf-8') as fh:
@ -117,17 +79,12 @@ def parseConfigFile(configFile=None):
return configs return configs
def parseArguments(argv): def parseArguments():
"""Parse command line arguments and configs from ~/.wakatime.cfg. """Parse command line arguments and configs from ~/.wakatime.cfg.
Command line arguments take precedence over config file settings. Command line arguments take precedence over config file settings.
Returns instances of ArgumentParser and SafeConfigParser. Returns instances of ArgumentParser and SafeConfigParser.
""" """
try:
sys.argv
except AttributeError:
sys.argv = argv
# define supported command line arguments # define supported command line arguments
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Common interface for the WakaTime api.') description='Common interface for the WakaTime api.')
@ -189,7 +146,7 @@ def parseArguments(argv):
parser.add_argument('--version', action='version', version=__version__) parser.add_argument('--version', action='version', version=__version__)
# parse command line arguments # parse command line arguments
args = parser.parse_args(args=argv[1:]) args = parser.parse_args()
# use current unix epoch timestamp by default # use current unix epoch timestamp by default
if not args.timestamp: if not args.timestamp:
@ -267,7 +224,7 @@ def should_exclude(fileName, include, exclude):
msg=u(ex), msg=u(ex),
pattern=u(pattern), pattern=u(pattern),
)) ))
except TypeError: except TypeError: # pragma: nocover
pass pass
try: try:
for pattern in exclude: for pattern in exclude:
@ -280,7 +237,7 @@ def should_exclude(fileName, include, exclude):
msg=u(ex), msg=u(ex),
pattern=u(pattern), pattern=u(pattern),
)) ))
except TypeError: except TypeError: # pragma: nocover
pass pass
return False return False
@ -320,11 +277,8 @@ def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None,
'type': 'file', 'type': 'file',
} }
if hidefilenames and targetFile is not None and not notfile: if hidefilenames and targetFile is not None and not notfile:
data['entity'] = data['entity'].rsplit('/', 1)[-1].rsplit('\\', 1)[-1] extension = u(os.path.splitext(data['entity'])[1])
if len(data['entity'].strip('.').split('.', 1)) > 1: data['entity'] = u('HIDDEN{0}').format(extension)
data['entity'] = u('HIDDEN.{ext}').format(ext=u(data['entity'].strip('.').rsplit('.', 1)[-1]))
else:
data['entity'] = u('HIDDEN')
if stats.get('lines'): if stats.get('lines'):
data['lines'] = stats['lines'] data['lines'] = stats['lines']
if stats.get('language'): if stats.get('language'):
@ -425,11 +379,10 @@ def send_heartbeat(project=None, branch=None, hostname=None, stats={}, key=None,
return False return False
def main(argv=None): def main(argv):
if not argv: sys.argv = ['wakatime'] + argv
argv = sys.argv
args, configs = parseArguments(argv) args, configs = parseArguments()
if configs is None: if configs is None:
return 103 # config file parsing error return 103 # config file parsing error

View file

@ -22,7 +22,7 @@ sys.path.insert(0, package_folder)
# import local wakatime package # import local wakatime package
try: try:
import wakatime import wakatime
except TypeError: except (TypeError, ImportError):
# on Windows, non-ASCII characters in import path can be fixed using # on Windows, non-ASCII characters in import path can be fixed using
# the script path from sys.argv[0]. # the script path from sys.argv[0].
# More info at https://github.com/wakatime/wakatime/issues/32 # More info at https://github.com/wakatime/wakatime/issues/32
@ -32,4 +32,4 @@ except TypeError:
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(wakatime.main(sys.argv)) sys.exit(wakatime.main(sys.argv[1:]))

View file

@ -17,9 +17,11 @@ is_py2 = (sys.version_info[0] == 2)
is_py3 = (sys.version_info[0] == 3) is_py3 = (sys.version_info[0] == 3)
if is_py2: if is_py2: # pragma: nocover
def u(text): def u(text):
if text is None:
return None
try: try:
return text.decode('utf-8') return text.decode('utf-8')
except: except:
@ -31,18 +33,21 @@ if is_py2:
basestring = basestring basestring = basestring
elif is_py3: elif is_py3: # pragma: nocover
def u(text): def u(text):
if text is None:
return None
if isinstance(text, bytes): if isinstance(text, bytes):
return text.decode('utf-8') return text.decode('utf-8')
return str(text) return str(text)
open = open open = open
basestring = (str, bytes) basestring = (str, bytes)
try: try:
from importlib import import_module from importlib import import_module
except ImportError: except ImportError: # pragma: nocover
def _resolve_name(name, package, level): def _resolve_name(name, package, level):
"""Return the absolute name of the module to be imported.""" """Return the absolute name of the module to be imported."""
if not hasattr(package, 'rindex'): if not hasattr(package, 'rindex'):

View file

@ -10,6 +10,7 @@
""" """
import logging import logging
import sys
import traceback import traceback
from ..compat import u, open, import_module from ..compat import u, open, import_module
@ -53,8 +54,16 @@ class TokenParser(object):
def _extract_tokens(self): def _extract_tokens(self):
if self.lexer: if self.lexer:
try:
with open(self.source_file, 'r', encoding='utf-8') as fh: with open(self.source_file, 'r', encoding='utf-8') as fh:
return self.lexer.get_tokens_unprocessed(fh.read(512000)) return self.lexer.get_tokens_unprocessed(fh.read(512000))
except:
pass
try:
with open(self.source_file, 'r', encoding=sys.getfilesystemencoding()) as fh:
return self.lexer.get_tokens_unprocessed(fh.read(512000))
except:
pass
return [] return []
def _save_dependency(self, dep, truncate=False, separator=None, def _save_dependency(self, dep, truncate=False, separator=None,
@ -83,7 +92,7 @@ class DependencyParser(object):
self.lexer = lexer self.lexer = lexer
if self.lexer: if self.lexer:
module_name = self.lexer.__module__.split('.')[-1] module_name = self.lexer.__module__.rsplit('.', 1)[-1]
class_name = self.lexer.__class__.__name__.replace('Lexer', 'Parser', 1) class_name = self.lexer.__class__.__name__.replace('Lexer', 'Parser', 1)
else: else:
module_name = 'unknown' module_name = 'unknown'

View file

@ -13,12 +13,15 @@ import logging
import os import os
import sys import sys
from .packages import simplejson as json
from .compat import u from .compat import u
try: try:
from collections import OrderedDict from collections import OrderedDict # pragma: nocover
except ImportError: except ImportError:
from .packages.ordereddict import OrderedDict from .packages.ordereddict import OrderedDict # pragma: nocover
try:
from .packages import simplejson as json # pragma: nocover
except (ImportError, SyntaxError):
import json # pragma: nocover
class CustomEncoder(json.JSONEncoder): class CustomEncoder(json.JSONEncoder):

View file

@ -21,6 +21,8 @@ try:
except ImportError: except ImportError:
HAS_SQL = False HAS_SQL = False
from .compat import u
log = logging.getLogger('WakaTime') log = logging.getLogger('WakaTime')
@ -50,16 +52,16 @@ class Queue(object):
try: try:
conn, c = self.connect() conn, c = self.connect()
heartbeat = { heartbeat = {
'file': data.get('entity'), 'file': u(data.get('entity')),
'time': data.get('time'), 'time': data.get('time'),
'project': data.get('project'), 'project': u(data.get('project')),
'branch': data.get('branch'), 'branch': u(data.get('branch')),
'is_write': 1 if data.get('is_write') else 0, 'is_write': 1 if data.get('is_write') else 0,
'stats': stats, 'stats': u(stats),
'misc': misc, 'misc': u(misc),
'plugin': plugin, 'plugin': u(plugin),
} }
c.execute('INSERT INTO heartbeat VALUES (:file,:time,:project,:branch,:is_write,:stats,:misc,:plugin)', heartbeat) c.execute(u('INSERT INTO heartbeat VALUES (:file,:time,:project,:branch,:is_write,:stats,:misc,:plugin)'), heartbeat)
conn.commit() conn.commit()
conn.close() conn.close()
except sqlite3.Error: except sqlite3.Error:
@ -90,14 +92,14 @@ class Queue(object):
for row_name in ['file', 'time', 'project', 'branch', 'is_write']: for row_name in ['file', 'time', 'project', 'branch', 'is_write']:
if row[index] is not None: if row[index] is not None:
clauses.append('{0}=?'.format(row_name)) clauses.append('{0}=?'.format(row_name))
values.append(row[index]) values.append(u(row[index]))
else: else:
clauses.append('{0} IS NULL'.format(row_name)) clauses.append('{0} IS NULL'.format(row_name))
index += 1 index += 1
if len(values) > 0: if len(values) > 0:
c.execute('DELETE FROM heartbeat WHERE {0}'.format(' AND '.join(clauses)), values) c.execute(u('DELETE FROM heartbeat WHERE {0}').format(u(' AND ').join(clauses)), values)
else: else:
c.execute('DELETE FROM heartbeat WHERE {0}'.format(' AND '.join(clauses))) c.execute(u('DELETE FROM heartbeat WHERE {0}').format(u(' AND ').join(clauses)))
conn.commit() conn.commit()
if row is not None: if row is not None:
heartbeat = { heartbeat = {

View file

@ -61,7 +61,12 @@ considered public as object names -- the API of the formatter objects is
still considered an implementation detail.) still considered an implementation detail.)
""" """
__version__ = '1.2.1' __version__ = '1.3.0' # we use our own version number independent of the
# one in stdlib and we release this on pypi.
__external_lib__ = True # to make sure the tests really test THIS lib,
# not the builtin one in Python stdlib
__all__ = [ __all__ = [
'ArgumentParser', 'ArgumentParser',
'ArgumentError', 'ArgumentError',
@ -1045,9 +1050,13 @@ class _SubParsersAction(Action):
class _ChoicesPseudoAction(Action): class _ChoicesPseudoAction(Action):
def __init__(self, name, help): def __init__(self, name, aliases, help):
metavar = dest = name
if aliases:
metavar += ' (%s)' % ', '.join(aliases)
sup = super(_SubParsersAction._ChoicesPseudoAction, self) sup = super(_SubParsersAction._ChoicesPseudoAction, self)
sup.__init__(option_strings=[], dest=name, help=help) sup.__init__(option_strings=[], dest=dest, help=help,
metavar=metavar)
def __init__(self, def __init__(self,
option_strings, option_strings,
@ -1075,15 +1084,22 @@ class _SubParsersAction(Action):
if kwargs.get('prog') is None: if kwargs.get('prog') is None:
kwargs['prog'] = '%s %s' % (self._prog_prefix, name) kwargs['prog'] = '%s %s' % (self._prog_prefix, name)
aliases = kwargs.pop('aliases', ())
# create a pseudo-action to hold the choice help # create a pseudo-action to hold the choice help
if 'help' in kwargs: if 'help' in kwargs:
help = kwargs.pop('help') help = kwargs.pop('help')
choice_action = self._ChoicesPseudoAction(name, help) choice_action = self._ChoicesPseudoAction(name, aliases, help)
self._choices_actions.append(choice_action) self._choices_actions.append(choice_action)
# create the parser and add it to the map # create the parser and add it to the map
parser = self._parser_class(**kwargs) parser = self._parser_class(**kwargs)
self._name_parser_map[name] = parser self._name_parser_map[name] = parser
# make parser available under aliases also
for alias in aliases:
self._name_parser_map[alias] = parser
return parser return parser
def _get_subactions(self): def _get_subactions(self):

View file

@ -6,7 +6,7 @@
# / # /
""" """
requests HTTP library Requests HTTP library
~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~
Requests is an HTTP library, written in Python, for human beings. Basic GET Requests is an HTTP library, written in Python, for human beings. Basic GET
@ -42,8 +42,8 @@ is at <http://python-requests.org>.
""" """
__title__ = 'requests' __title__ = 'requests'
__version__ = '2.6.0' __version__ = '2.7.0'
__build__ = 0x020503 __build__ = 0x020700
__author__ = 'Kenneth Reitz' __author__ = 'Kenneth Reitz'
__license__ = 'Apache 2.0' __license__ = 'Apache 2.0'
__copyright__ = 'Copyright 2015 Kenneth Reitz' __copyright__ = 'Copyright 2015 Kenneth Reitz'

View file

@ -35,6 +35,7 @@ from .auth import _basic_auth_str
DEFAULT_POOLBLOCK = False DEFAULT_POOLBLOCK = False
DEFAULT_POOLSIZE = 10 DEFAULT_POOLSIZE = 10
DEFAULT_RETRIES = 0 DEFAULT_RETRIES = 0
DEFAULT_POOL_TIMEOUT = None
class BaseAdapter(object): class BaseAdapter(object):
@ -375,7 +376,7 @@ class HTTPAdapter(BaseAdapter):
if hasattr(conn, 'proxy_pool'): if hasattr(conn, 'proxy_pool'):
conn = conn.proxy_pool conn = conn.proxy_pool
low_conn = conn._get_conn(timeout=timeout) low_conn = conn._get_conn(timeout=DEFAULT_POOL_TIMEOUT)
try: try:
low_conn.putrequest(request.method, low_conn.putrequest(request.method,
@ -407,9 +408,6 @@ class HTTPAdapter(BaseAdapter):
# Then, reraise so that we can handle the actual exception. # Then, reraise so that we can handle the actual exception.
low_conn.close() low_conn.close()
raise raise
else:
# All is well, return the connection to the pool.
conn._put_conn(low_conn)
except (ProtocolError, socket.error) as err: except (ProtocolError, socket.error) as err:
raise ConnectionError(err, request=request) raise ConnectionError(err, request=request)

View file

@ -55,17 +55,18 @@ def request(method, url, **kwargs):
return response return response
def get(url, **kwargs): def get(url, params=None, **kwargs):
"""Sends a GET request. """Sends a GET request.
:param url: URL for the new :class:`Request` object. :param url: URL for the new :class:`Request` object.
:param params: (optional) Dictionary or bytes to be sent in the query string for the :class:`Request`.
:param \*\*kwargs: Optional arguments that ``request`` takes. :param \*\*kwargs: Optional arguments that ``request`` takes.
:return: :class:`Response <Response>` object :return: :class:`Response <Response>` object
:rtype: requests.Response :rtype: requests.Response
""" """
kwargs.setdefault('allow_redirects', True) kwargs.setdefault('allow_redirects', True)
return request('get', url, **kwargs) return request('get', url, params=params, **kwargs)
def options(url, **kwargs): def options(url, **kwargs):

View file

@ -103,7 +103,8 @@ class HTTPDigestAuth(AuthBase):
# XXX not implemented yet # XXX not implemented yet
entdig = None entdig = None
p_parsed = urlparse(url) p_parsed = urlparse(url)
path = p_parsed.path #: path is request-uri defined in RFC 2616 which should not be empty
path = p_parsed.path or "/"
if p_parsed.query: if p_parsed.query:
path += '?' + p_parsed.query path += '?' + p_parsed.query
@ -178,7 +179,7 @@ class HTTPDigestAuth(AuthBase):
# Consume content and release the original connection # Consume content and release the original connection
# to allow our new request to reuse the same one. # to allow our new request to reuse the same one.
r.content r.content
r.raw.release_conn() r.close()
prep = r.request.copy() prep = r.request.copy()
extract_cookies_to_jar(prep._cookies, r.request, r.raw) extract_cookies_to_jar(prep._cookies, r.request, r.raw)
prep.prepare_cookies(prep._cookies) prep.prepare_cookies(prep._cookies)

View file

@ -11,6 +11,7 @@ If you are packaging Requests, e.g., for a Linux distribution or a managed
environment, you can change the definition of where() to return a separately environment, you can change the definition of where() to return a separately
packaged CA bundle. packaged CA bundle.
""" """
import sys
import os.path import os.path
try: try:
@ -19,7 +20,9 @@ except ImportError:
def where(): def where():
"""Return the preferred certificate bundle.""" """Return the preferred certificate bundle."""
# vendored bundle inside Requests # vendored bundle inside Requests
return os.path.join(os.path.dirname(__file__), 'cacert.pem') is_py3 = (sys.version_info[0] == 3)
cacert = os.path.join(os.path.dirname(__file__), 'cacert.pem')
return cacert.encode('utf-8') if is_py3 else cacert
if __name__ == '__main__': if __name__ == '__main__':
print(where()) print(where())

View file

@ -6,6 +6,7 @@ Compatibility code to be able to use `cookielib.CookieJar` with requests.
requests.utils imports from here, so be careful with imports. requests.utils imports from here, so be careful with imports.
""" """
import copy
import time import time
import collections import collections
from .compat import cookielib, urlparse, urlunparse, Morsel from .compat import cookielib, urlparse, urlunparse, Morsel
@ -302,7 +303,7 @@ class RequestsCookieJar(cookielib.CookieJar, collections.MutableMapping):
"""Updates this jar with cookies from another CookieJar or dict-like""" """Updates this jar with cookies from another CookieJar or dict-like"""
if isinstance(other, cookielib.CookieJar): if isinstance(other, cookielib.CookieJar):
for cookie in other: for cookie in other:
self.set_cookie(cookie) self.set_cookie(copy.copy(cookie))
else: else:
super(RequestsCookieJar, self).update(other) super(RequestsCookieJar, self).update(other)
@ -359,6 +360,21 @@ class RequestsCookieJar(cookielib.CookieJar, collections.MutableMapping):
return new_cj return new_cj
def _copy_cookie_jar(jar):
if jar is None:
return None
if hasattr(jar, 'copy'):
# We're dealing with an instance of RequestsCookieJar
return jar.copy()
# We're dealing with a generic CookieJar instance
new_jar = copy.copy(jar)
new_jar.clear()
for cookie in jar:
new_jar.set_cookie(copy.copy(cookie))
return new_jar
def create_cookie(name, value, **kwargs): def create_cookie(name, value, **kwargs):
"""Make a cookie from underspecified parameters. """Make a cookie from underspecified parameters.
@ -399,11 +415,14 @@ def morsel_to_cookie(morsel):
expires = None expires = None
if morsel['max-age']: if morsel['max-age']:
expires = time.time() + morsel['max-age'] try:
expires = int(time.time() + int(morsel['max-age']))
except ValueError:
raise TypeError('max-age: %s must be integer' % morsel['max-age'])
elif morsel['expires']: elif morsel['expires']:
time_template = '%a, %d-%b-%Y %H:%M:%S GMT' time_template = '%a, %d-%b-%Y %H:%M:%S GMT'
expires = time.mktime( expires = int(time.mktime(
time.strptime(morsel['expires'], time_template)) - time.timezone time.strptime(morsel['expires'], time_template)) - time.timezone)
return create_cookie( return create_cookie(
comment=morsel['comment'], comment=morsel['comment'],
comment_url=bool(morsel['comment']), comment_url=bool(morsel['comment']),

View file

@ -15,7 +15,7 @@ from .hooks import default_hooks
from .structures import CaseInsensitiveDict from .structures import CaseInsensitiveDict
from .auth import HTTPBasicAuth from .auth import HTTPBasicAuth
from .cookies import cookiejar_from_dict, get_cookie_header from .cookies import cookiejar_from_dict, get_cookie_header, _copy_cookie_jar
from .packages.urllib3.fields import RequestField from .packages.urllib3.fields import RequestField
from .packages.urllib3.filepost import encode_multipart_formdata from .packages.urllib3.filepost import encode_multipart_formdata
from .packages.urllib3.util import parse_url from .packages.urllib3.util import parse_url
@ -30,7 +30,8 @@ from .utils import (
iter_slices, guess_json_utf, super_len, to_native_string) iter_slices, guess_json_utf, super_len, to_native_string)
from .compat import ( from .compat import (
cookielib, urlunparse, urlsplit, urlencode, str, bytes, StringIO, cookielib, urlunparse, urlsplit, urlencode, str, bytes, StringIO,
is_py2, chardet, json, builtin_str, basestring) is_py2, chardet, builtin_str, basestring)
from .compat import json as complexjson
from .status_codes import codes from .status_codes import codes
#: The set of HTTP status codes that indicate an automatically #: The set of HTTP status codes that indicate an automatically
@ -42,12 +43,11 @@ REDIRECT_STATI = (
codes.temporary_redirect, # 307 codes.temporary_redirect, # 307
codes.permanent_redirect, # 308 codes.permanent_redirect, # 308
) )
DEFAULT_REDIRECT_LIMIT = 30 DEFAULT_REDIRECT_LIMIT = 30
CONTENT_CHUNK_SIZE = 10 * 1024 CONTENT_CHUNK_SIZE = 10 * 1024
ITER_CHUNK_SIZE = 512 ITER_CHUNK_SIZE = 512
json_dumps = json.dumps
class RequestEncodingMixin(object): class RequestEncodingMixin(object):
@property @property
@ -149,8 +149,7 @@ class RequestEncodingMixin(object):
else: else:
fdata = fp.read() fdata = fp.read()
rf = RequestField(name=k, data=fdata, rf = RequestField(name=k, data=fdata, filename=fn, headers=fh)
filename=fn, headers=fh)
rf.make_multipart(content_type=ft) rf.make_multipart(content_type=ft)
new_fields.append(rf) new_fields.append(rf)
@ -207,17 +206,8 @@ class Request(RequestHooksMixin):
<PreparedRequest [GET]> <PreparedRequest [GET]>
""" """
def __init__(self, def __init__(self, method=None, url=None, headers=None, files=None,
method=None, data=None, params=None, auth=None, cookies=None, hooks=None, json=None):
url=None,
headers=None,
files=None,
data=None,
params=None,
auth=None,
cookies=None,
hooks=None,
json=None):
# Default empty dicts for dict params. # Default empty dicts for dict params.
data = [] if data is None else data data = [] if data is None else data
@ -296,8 +286,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.hooks = default_hooks() self.hooks = default_hooks()
def prepare(self, method=None, url=None, headers=None, files=None, def prepare(self, method=None, url=None, headers=None, files=None,
data=None, params=None, auth=None, cookies=None, hooks=None, data=None, params=None, auth=None, cookies=None, hooks=None, json=None):
json=None):
"""Prepares the entire request with the given parameters.""" """Prepares the entire request with the given parameters."""
self.prepare_method(method) self.prepare_method(method)
@ -306,6 +295,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.prepare_cookies(cookies) self.prepare_cookies(cookies)
self.prepare_body(data, files, json) self.prepare_body(data, files, json)
self.prepare_auth(auth, url) self.prepare_auth(auth, url)
# Note that prepare_auth must be last to enable authentication schemes # Note that prepare_auth must be last to enable authentication schemes
# such as OAuth to work on a fully prepared request. # such as OAuth to work on a fully prepared request.
@ -320,7 +310,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
p.method = self.method p.method = self.method
p.url = self.url p.url = self.url
p.headers = self.headers.copy() if self.headers is not None else None p.headers = self.headers.copy() if self.headers is not None else None
p._cookies = self._cookies.copy() if self._cookies is not None else None p._cookies = _copy_cookie_jar(self._cookies)
p.body = self.body p.body = self.body
p.hooks = self.hooks p.hooks = self.hooks
return p return p
@ -357,8 +347,10 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
raise InvalidURL(*e.args) raise InvalidURL(*e.args)
if not scheme: if not scheme:
raise MissingSchema("Invalid URL {0!r}: No schema supplied. " error = ("Invalid URL {0!r}: No schema supplied. Perhaps you meant http://{0}?")
"Perhaps you meant http://{0}?".format(url)) error = error.format(to_native_string(url, 'utf8'))
raise MissingSchema(error)
if not host: if not host:
raise InvalidURL("Invalid URL %r: No host supplied" % url) raise InvalidURL("Invalid URL %r: No host supplied" % url)
@ -424,7 +416,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
if json is not None: if json is not None:
content_type = 'application/json' content_type = 'application/json'
body = json_dumps(json) body = complexjson.dumps(json)
is_stream = all([ is_stream = all([
hasattr(data, '__iter__'), hasattr(data, '__iter__'),
@ -501,7 +493,15 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
self.prepare_content_length(self.body) self.prepare_content_length(self.body)
def prepare_cookies(self, cookies): def prepare_cookies(self, cookies):
"""Prepares the given HTTP cookie data.""" """Prepares the given HTTP cookie data.
This function eventually generates a ``Cookie`` header from the
given cookies using cookielib. Due to cookielib's design, the header
will not be regenerated if it already exists, meaning this function
can only be called once for the life of the
:class:`PreparedRequest <PreparedRequest>` object. Any subsequent calls
to ``prepare_cookies`` will have no actual effect, unless the "Cookie"
header is removed beforehand."""
if isinstance(cookies, cookielib.CookieJar): if isinstance(cookies, cookielib.CookieJar):
self._cookies = cookies self._cookies = cookies
@ -514,6 +514,10 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
def prepare_hooks(self, hooks): def prepare_hooks(self, hooks):
"""Prepares the given hooks.""" """Prepares the given hooks."""
# hooks can be passed as None to the prepare method and to this
# method. To prevent iterating over None, simply use an empty list
# if hooks is False-y
hooks = hooks or []
for event in hooks: for event in hooks:
self.register_hook(event, hooks[event]) self.register_hook(event, hooks[event])
@ -524,16 +528,8 @@ class Response(object):
""" """
__attrs__ = [ __attrs__ = [
'_content', '_content', 'status_code', 'headers', 'url', 'history',
'status_code', 'encoding', 'reason', 'cookies', 'elapsed', 'request'
'headers',
'url',
'history',
'encoding',
'reason',
'cookies',
'elapsed',
'request',
] ]
def __init__(self): def __init__(self):
@ -653,9 +649,10 @@ class Response(object):
If decode_unicode is True, content will be decoded using the best If decode_unicode is True, content will be decoded using the best
available encoding based on the response. available encoding based on the response.
""" """
def generate(): def generate():
try:
# Special case for urllib3. # Special case for urllib3.
if hasattr(self.raw, 'stream'):
try: try:
for chunk in self.raw.stream(chunk_size, decode_content=True): for chunk in self.raw.stream(chunk_size, decode_content=True):
yield chunk yield chunk
@ -665,7 +662,7 @@ class Response(object):
raise ContentDecodingError(e) raise ContentDecodingError(e)
except ReadTimeoutError as e: except ReadTimeoutError as e:
raise ConnectionError(e) raise ConnectionError(e)
except AttributeError: else:
# Standard file-like object. # Standard file-like object.
while True: while True:
chunk = self.raw.read(chunk_size) chunk = self.raw.read(chunk_size)
@ -796,14 +793,16 @@ class Response(object):
encoding = guess_json_utf(self.content) encoding = guess_json_utf(self.content)
if encoding is not None: if encoding is not None:
try: try:
return json.loads(self.content.decode(encoding), **kwargs) return complexjson.loads(
self.content.decode(encoding), **kwargs
)
except UnicodeDecodeError: except UnicodeDecodeError:
# Wrong UTF codec detected; usually because it's not UTF-8 # Wrong UTF codec detected; usually because it's not UTF-8
# but some other 8-bit codec. This is an RFC violation, # but some other 8-bit codec. This is an RFC violation,
# and the server didn't bother to tell us what codec *was* # and the server didn't bother to tell us what codec *was*
# used. # used.
pass pass
return json.loads(self.text, **kwargs) return complexjson.loads(self.text, **kwargs)
@property @property
def links(self): def links(self):
@ -829,10 +828,10 @@ class Response(object):
http_error_msg = '' http_error_msg = ''
if 400 <= self.status_code < 500: if 400 <= self.status_code < 500:
http_error_msg = '%s Client Error: %s' % (self.status_code, self.reason) http_error_msg = '%s Client Error: %s for url: %s' % (self.status_code, self.reason, self.url)
elif 500 <= self.status_code < 600: elif 500 <= self.status_code < 600:
http_error_msg = '%s Server Error: %s' % (self.status_code, self.reason) http_error_msg = '%s Server Error: %s for url: %s' % (self.status_code, self.reason, self.url)
if http_error_msg: if http_error_msg:
raise HTTPError(http_error_msg, response=self) raise HTTPError(http_error_msg, response=self)
@ -843,4 +842,7 @@ class Response(object):
*Note: Should not normally need to be called explicitly.* *Note: Should not normally need to be called explicitly.*
""" """
if not self._content_consumed:
return self.raw.close()
return self.raw.release_conn() return self.raw.release_conn()

View file

@ -1,107 +1,3 @@
"""
Copyright (c) Donald Stufft, pip, and individual contributors
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import from __future__ import absolute_import
import sys from . import urllib3
class VendorAlias(object):
    """Meta-path import hook that aliases ``<this package>.<name>`` imports.

    For every name in ``package_names`` the hook first tries to import the
    copy bundled under this package; if that fails it imports the top-level
    module of the same name instead and registers it in ``sys.modules``
    under the vendored (dotted) name.
    """

    def __init__(self, package_names):
        """Remember which sub-package names may be aliased.

        :param package_names: iterable of bare module names (no dots).
        """
        self._package_names = package_names
        self._vendor_name = __name__
        self._vendor_pkg = self._vendor_name + "."
        self._vendor_pkgs = [
            self._vendor_pkg + name for name in self._package_names
        ]

    def find_module(self, fullname, path=None):
        """Return this loader for any name below the vendor package.

        Returns ``None`` (implicitly) for every other name so the regular
        import machinery handles it.
        """
        if fullname.startswith(self._vendor_pkg):
            return self

    def load_module(self, name):
        """Import ``name`` from the vendored copy, else from the real module.

        :raises ImportError: if ``name`` is not below the vendor package, is
            not one of the configured packages, or cannot be imported under
            either the vendored or the real name.
        """
        # Ensure that this only works for the vendored name
        if not name.startswith(self._vendor_pkg):
            raise ImportError(
                "Cannot import %s, must be a subpackage of '%s'." % (
                    name, self._vendor_name,
                )
            )

        # ``name`` always carries the "<vendor>." prefix at this point, so it
        # can never equal the bare vendor name; only the per-package
        # prefixes need to be checked.
        if not any(name.startswith(pkg) for pkg in self._vendor_pkgs):
            raise ImportError(
                "Cannot import %s, must be one of %s." % (
                    name, self._vendor_pkgs
                )
            )

        # Check to see if we already have this item in sys.modules, if we do
        # then simply return that.
        if name in sys.modules:
            return sys.modules[name]

        # Check to see if we can import the vendor name
        try:
            # We do this dance here because we want to try and import this
            # module without hitting a recursion error because of a bunch of
            # VendorAlias instances on sys.meta_path
            real_meta_path = sys.meta_path[:]
            try:
                sys.meta_path = [
                    m for m in sys.meta_path
                    if not isinstance(m, VendorAlias)
                ]
                __import__(name)
                module = sys.modules[name]
            finally:
                # Re-add any additions to sys.meta_path that were made
                # during the import we just did, otherwise things like
                # requests.packages.urllib3.poolmanager will fail.
                for m in sys.meta_path:
                    if m not in real_meta_path:
                        real_meta_path.append(m)

                # Restore sys.meta_path with any new items.
                sys.meta_path = real_meta_path
        except ImportError:
            # We can't import the vendor name, so we'll try to import the
            # "real" name.
            real_name = name[len(self._vendor_pkg):]
            try:
                __import__(real_name)
                module = sys.modules[real_name]
            except ImportError:
                raise ImportError("No module named '%s'" % (name,))

        # If we've gotten here we've found the module we're looking for, either
        # as part of our vendored package, or as the real name, so we'll add
        # it to sys.modules as the vendored name so that we don't have to do
        # the lookup again.
        sys.modules[name] = module

        # Finally, return the loaded module
        return module
sys.meta_path.append(VendorAlias(["urllib3", "chardet"]))

View file

@ -4,7 +4,7 @@ urllib3 - Thread-safe connection pooling and re-using.
__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)' __author__ = 'Andrey Petrov (andrey.petrov@shazow.net)'
__license__ = 'MIT' __license__ = 'MIT'
__version__ = '1.10.2' __version__ = '1.10.4'
from .connectionpool import ( from .connectionpool import (
@ -55,9 +55,12 @@ def add_stderr_logger(level=logging.DEBUG):
del NullHandler del NullHandler
# Set security warning to always go off by default.
import warnings import warnings
warnings.simplefilter('always', exceptions.SecurityWarning) # SecurityWarning's always go off by default.
warnings.simplefilter('always', exceptions.SecurityWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter('default', exceptions.InsecurePlatformWarning,
append=True)
def disable_warnings(category=exceptions.HTTPWarning): def disable_warnings(category=exceptions.HTTPWarning):
""" """

View file

@ -227,20 +227,20 @@ class HTTPHeaderDict(dict):
# Need to convert the tuple to list for further extension # Need to convert the tuple to list for further extension
_dict_setitem(self, key_lower, [vals[0], vals[1], val]) _dict_setitem(self, key_lower, [vals[0], vals[1], val])
def extend(*args, **kwargs): def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object. """Generic import function for any type of header-like object.
Adapted version of MutableMapping.update in order to insert items Adapted version of MutableMapping.update in order to insert items
with self.add instead of self.__setitem__ with self.add instead of self.__setitem__
""" """
if len(args) > 2: if len(args) > 1:
raise TypeError("update() takes at most 2 positional " raise TypeError("extend() takes at most 1 positional "
"arguments ({} given)".format(len(args))) "arguments ({} given)".format(len(args)))
elif not args: other = args[0] if len(args) >= 1 else ()
raise TypeError("update() takes at least 1 argument (0 given)")
self = args[0]
other = args[1] if len(args) >= 2 else ()
if isinstance(other, Mapping): if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
elif isinstance(other, Mapping):
for key in other: for key in other:
self.add(key, other[key]) self.add(key, other[key])
elif hasattr(other, "keys"): elif hasattr(other, "keys"):
@ -304,17 +304,20 @@ class HTTPHeaderDict(dict):
return list(self.iteritems()) return list(self.iteritems())
@classmethod @classmethod
def from_httplib(cls, message, duplicates=('set-cookie',)): # Python 2 def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object.""" """Read headers from a Python 2 httplib message object."""
ret = cls(message.items()) # python2.7 does not expose a proper API for exporting multiheaders
# ret now contains only the last header line for each duplicate. # efficiently. This function re-reads raw lines from the message
# Importing with all duplicates would be nice, but this would # object and extracts the multiheaders properly.
# mean to repeat most of the raw parsing already done, when the headers = []
# message object was created. Extracting only the headers of interest
# separately, the cookies, should be faster and requires less for line in message.headers:
# extra code. if line.startswith((' ', '\t')):
for key in duplicates: key, value = headers[-1]
ret.discard(key) headers[-1] = (key, value + '\r\n' + line.rstrip())
for val in message.getheaders(key): continue
ret.add(key, val)
return ret key, value = line.split(':', 1)
headers.append((key, value.strip()))
return cls(headers)

View file

@ -260,3 +260,5 @@ if ssl:
# Make a copy for testing. # Make a copy for testing.
UnverifiedHTTPSConnection = HTTPSConnection UnverifiedHTTPSConnection = HTTPSConnection
HTTPSConnection = VerifiedHTTPSConnection HTTPSConnection = VerifiedHTTPSConnection
else:
HTTPSConnection = DummyConnection

View file

@ -735,7 +735,6 @@ class HTTPSConnectionPool(HTTPConnectionPool):
% (self.num_connections, self.host)) % (self.num_connections, self.host))
if not self.ConnectionCls or self.ConnectionCls is DummyConnection: if not self.ConnectionCls or self.ConnectionCls is DummyConnection:
# Platform-specific: Python without ssl
raise SSLError("Can't connect to HTTPS URL because the SSL " raise SSLError("Can't connect to HTTPS URL because the SSL "
"module is not available.") "module is not available.")

View file

@ -38,8 +38,6 @@ Module Variables
---------------- ----------------
:var DEFAULT_SSL_CIPHER_LIST: The list of supported SSL/TLS cipher suites. :var DEFAULT_SSL_CIPHER_LIST: The list of supported SSL/TLS cipher suites.
Default: ``ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:
ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+3DES:!aNULL:!MD5:!DSS``
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication .. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit) .. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
@ -85,22 +83,7 @@ _openssl_verify = {
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
} }
# A secure default. DEFAULT_SSL_CIPHER_LIST = util.ssl_.DEFAULT_CIPHERS
# Sources for more information on TLS ciphers:
#
# - https://wiki.mozilla.org/Security/Server_Side_TLS
# - https://www.ssllabs.com/projects/best-practices/index.html
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
#
# The general intent is:
# - Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM over any AES-CBC for better performance and security,
# - use 3DES as fallback which is secure but slow,
# - disable NULL authentication, MD5 MACs and DSS for security reasons.
DEFAULT_SSL_CIPHER_LIST = "ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:" + \
"ECDH+AES128:DH+AES:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+3DES:" + \
"!aNULL:!MD5:!DSS"
orig_util_HAS_SNI = util.HAS_SNI orig_util_HAS_SNI = util.HAS_SNI
@ -299,7 +282,9 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
try: try:
cnx.do_handshake() cnx.do_handshake()
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError:
select.select([sock], [], []) rd, _, _ = select.select([sock], [], [], sock.gettimeout())
if not rd:
raise timeout('select timed out')
continue continue
except OpenSSL.SSL.Error as e: except OpenSSL.SSL.Error as e:
raise ssl.SSLError('bad handshake', e) raise ssl.SSLError('bad handshake', e)

View file

@ -162,3 +162,8 @@ class SystemTimeWarning(SecurityWarning):
class InsecurePlatformWarning(SecurityWarning): class InsecurePlatformWarning(SecurityWarning):
"Warned when certain SSL configuration is not available on a platform." "Warned when certain SSL configuration is not available on a platform."
pass pass
class ResponseNotChunked(ProtocolError, ValueError):
    """Response needs to be chunked in order to read it as chunks."""
    pass

View file

@ -1,9 +1,15 @@
try:
import http.client as httplib
except ImportError:
import httplib
import zlib import zlib
import io import io
from socket import timeout as SocketTimeout from socket import timeout as SocketTimeout
from ._collections import HTTPHeaderDict from ._collections import HTTPHeaderDict
from .exceptions import ProtocolError, DecodeError, ReadTimeoutError from .exceptions import (
ProtocolError, DecodeError, ReadTimeoutError, ResponseNotChunked
)
from .packages.six import string_types as basestring, binary_type, PY3 from .packages.six import string_types as basestring, binary_type, PY3
from .connection import HTTPException, BaseSSLError from .connection import HTTPException, BaseSSLError
from .util.response import is_fp_closed from .util.response import is_fp_closed
@ -117,7 +123,17 @@ class HTTPResponse(io.IOBase):
if hasattr(body, 'read'): if hasattr(body, 'read'):
self._fp = body self._fp = body
if preload_content and not self._body: # Are we using the chunked-style of transfer encoding?
self.chunked = False
self.chunk_left = None
tr_enc = self.headers.get('transfer-encoding', '').lower()
# Don't incur the penalty of creating a list and then discarding it
encodings = (enc.strip() for enc in tr_enc.split(","))
if "chunked" in encodings:
self.chunked = True
# We certainly don't want to preload content when the response is chunked.
if not self.chunked and preload_content and not self._body:
self._body = self.read(decode_content=decode_content) self._body = self.read(decode_content=decode_content)
def get_redirect_location(self): def get_redirect_location(self):
@ -157,6 +173,35 @@ class HTTPResponse(io.IOBase):
""" """
return self._fp_bytes_read return self._fp_bytes_read
def _init_decoder(self):
"""
Set-up the _decoder attribute if necessar.
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
if self._decoder is None and content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
def _decode(self, data, decode_content, flush_decoder):
"""
Decode the data passed in and potentially flush the decoder.
"""
try:
if decode_content and self._decoder:
data = self._decoder.decompress(data)
except (IOError, zlib.error) as e:
content_encoding = self.headers.get('content-encoding', '').lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding, e)
if flush_decoder and decode_content and self._decoder:
buf = self._decoder.decompress(binary_type())
data += buf + self._decoder.flush()
return data
def read(self, amt=None, decode_content=None, cache_content=False): def read(self, amt=None, decode_content=None, cache_content=False):
""" """
Similar to :meth:`httplib.HTTPResponse.read`, but with two additional Similar to :meth:`httplib.HTTPResponse.read`, but with two additional
@ -178,12 +223,7 @@ class HTTPResponse(io.IOBase):
after having ``.read()`` the file object. (Overridden if ``amt`` is after having ``.read()`` the file object. (Overridden if ``amt`` is
set.) set.)
""" """
# Note: content-encoding value should be case-insensitive, per RFC 7230 self._init_decoder()
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
if self._decoder is None:
if content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
if decode_content is None: if decode_content is None:
decode_content = self.decode_content decode_content = self.decode_content
@ -232,17 +272,7 @@ class HTTPResponse(io.IOBase):
self._fp_bytes_read += len(data) self._fp_bytes_read += len(data)
try: data = self._decode(data, decode_content, flush_decoder)
if decode_content and self._decoder:
data = self._decoder.decompress(data)
except (IOError, zlib.error) as e:
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding, e)
if flush_decoder and decode_content and self._decoder:
buf = self._decoder.decompress(binary_type())
data += buf + self._decoder.flush()
if cache_content: if cache_content:
self._body = data self._body = data
@ -269,6 +299,10 @@ class HTTPResponse(io.IOBase):
If True, will attempt to decode the body based on the If True, will attempt to decode the body based on the
'content-encoding' header. 'content-encoding' header.
""" """
if self.chunked:
for line in self.read_chunked(amt, decode_content=decode_content):
yield line
else:
while not is_fp_closed(self._fp): while not is_fp_closed(self._fp):
data = self.read(amt=amt, decode_content=decode_content) data = self.read(amt=amt, decode_content=decode_content)
@ -351,3 +385,82 @@ class HTTPResponse(io.IOBase):
else: else:
b[:len(temp)] = temp b[:len(temp)] = temp
return len(temp) return len(temp)
def _update_chunk_length(self):
# First, we'll figure out length of a chunk and then
# we'll try to read it from socket.
if self.chunk_left is not None:
return
line = self._fp.fp.readline()
line = line.split(b';', 1)[0]
try:
self.chunk_left = int(line, 16)
except ValueError:
# Invalid chunked protocol response, abort.
self.close()
raise httplib.IncompleteRead(line)
def _handle_chunk(self, amt):
returned_chunk = None
if amt is None:
chunk = self._fp._safe_read(self.chunk_left)
returned_chunk = chunk
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
elif amt < self.chunk_left:
value = self._fp._safe_read(amt)
self.chunk_left = self.chunk_left - amt
returned_chunk = value
elif amt == self.chunk_left:
value = self._fp._safe_read(amt)
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
returned_chunk = value
else: # amt > self.chunk_left
returned_chunk = self._fp._safe_read(self.chunk_left)
self._fp._safe_read(2) # Toss the CRLF at the end of the chunk.
self.chunk_left = None
return returned_chunk
def read_chunked(self, amt=None, decode_content=None):
    """
    Similar to :meth:`HTTPResponse.read`, but with an additional
    parameter: ``decode_content``.

    This is a generator yielding the body piece by piece.

    :param amt:
        Maximum number of bytes to read from a chunk per iteration, or
        ``None`` to read each chunk whole.

    :param decode_content:
        If True, will attempt to decode the body based on the
        'content-encoding' header.

    :raises ResponseNotChunked: if the response is not chunked.
    """
    self._init_decoder()
    # FIXME: Rewrite this method and make it a class with a better structured logic.
    if not self.chunked:
        raise ResponseNotChunked("Response is not chunked. "
            "Header 'transfer-encoding: chunked' is missing.")

    if self._original_response and self._original_response._method.upper() == 'HEAD':
        # Don't bother reading the body of a HEAD request.
        # FIXME: Can we do this somehow without accessing private httplib _method?
        self._original_response.close()
        return

    while True:
        self._update_chunk_length()
        if self.chunk_left == 0:
            break
        chunk = self._handle_chunk(amt)
        # Do not flush here: a compressed stream may span several chunks,
        # and a zlib decompressor must not be fed again after flush().
        decoded = self._decode(chunk, decode_content=decode_content,
                               flush_decoder=False)
        if decoded:
            yield decoded

    if decode_content:
        # All chunks were seen; drain whatever the decoder still holds.
        tail = self._decode(b'', decode_content=decode_content,
                            flush_decoder=True)
        if tail:
            yield tail

    # Chunk content ends with \r\n: discard it.
    while True:
        line = self._fp.fp.readline()
        if not line:
            # Some sites may not end with '\r\n'.
            break
        if line == b'\r\n':
            break

    # We read everything; close the "file".
    if self._original_response:
        self._original_response.close()
    self.release_conn()

View file

@ -9,10 +9,10 @@ HAS_SNI = False
create_default_context = None create_default_context = None
import errno import errno
import ssl
import warnings import warnings
try: # Test for SSL features try: # Test for SSL features
import ssl
from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23 from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23
from ssl import HAS_SNI # Has SNI? from ssl import HAS_SNI # Has SNI?
except ImportError: except ImportError:
@ -25,14 +25,24 @@ except ImportError:
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000 OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
OP_NO_COMPRESSION = 0x20000 OP_NO_COMPRESSION = 0x20000
try: # A secure default.
from ssl import _DEFAULT_CIPHERS # Sources for more information on TLS ciphers:
except ImportError: #
_DEFAULT_CIPHERS = ( # - https://wiki.mozilla.org/Security/Server_Side_TLS
# - https://www.ssllabs.com/projects/best-practices/index.html
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
#
# The general intent is:
# - Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM over any AES-CBC for better performance and security,
# - use 3DES as fallback which is secure but slow,
# - disable NULL authentication, MD5 MACs and DSS for security reasons.
DEFAULT_CIPHERS = (
'ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:ECDH+HIGH:' 'ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:ECDH+HIGH:'
'DH+HIGH:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+HIGH:RSA+3DES:!aNULL:' 'DH+HIGH:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+HIGH:RSA+3DES:!aNULL:'
'!eNULL:!MD5' '!eNULL:!MD5'
) )
try: try:
from ssl import SSLContext # Modern SSL? from ssl import SSLContext # Modern SSL?
@ -40,7 +50,8 @@ except ImportError:
import sys import sys
class SSLContext(object): # Platform-specific: Python 2 & 3.1 class SSLContext(object): # Platform-specific: Python 2 & 3.1
supports_set_ciphers = sys.version_info >= (2, 7) supports_set_ciphers = ((2, 7) <= sys.version_info < (3,) or
(3, 2) <= sys.version_info)
def __init__(self, protocol_version): def __init__(self, protocol_version):
self.protocol = protocol_version self.protocol = protocol_version
@ -167,7 +178,7 @@ def resolve_ssl_version(candidate):
return candidate return candidate
def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED, def create_urllib3_context(ssl_version=None, cert_reqs=None,
options=None, ciphers=None): options=None, ciphers=None):
"""All arguments have the same meaning as ``ssl_wrap_socket``. """All arguments have the same meaning as ``ssl_wrap_socket``.
@ -204,6 +215,9 @@ def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
""" """
context = SSLContext(ssl_version or ssl.PROTOCOL_SSLv23) context = SSLContext(ssl_version or ssl.PROTOCOL_SSLv23)
# Setting the default here, as we may have no ssl module on import
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
if options is None: if options is None:
options = 0 options = 0
# SSLv2 is easily broken and is considered harmful and dangerous # SSLv2 is easily broken and is considered harmful and dangerous
@ -217,7 +231,7 @@ def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
context.options |= options context.options |= options
if getattr(context, 'supports_set_ciphers', True): # Platform-specific: Python 2.6 if getattr(context, 'supports_set_ciphers', True): # Platform-specific: Python 2.6
context.set_ciphers(ciphers or _DEFAULT_CIPHERS) context.set_ciphers(ciphers or DEFAULT_CIPHERS)
context.verify_mode = cert_reqs context.verify_mode = cert_reqs
if getattr(context, 'check_hostname', None) is not None: # Platform-specific: Python 3.2 if getattr(context, 'check_hostname', None) is not None: # Platform-specific: Python 3.2

View file

@ -15,6 +15,8 @@ class Url(namedtuple('Url', url_attrs)):
def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None, def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None,
query=None, fragment=None): query=None, fragment=None):
if path and not path.startswith('/'):
path = '/' + path
return super(Url, cls).__new__(cls, scheme, auth, host, port, path, return super(Url, cls).__new__(cls, scheme, auth, host, port, path,
query, fragment) query, fragment)

View file

@ -90,7 +90,7 @@ def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict):
class SessionRedirectMixin(object): class SessionRedirectMixin(object):
def resolve_redirects(self, resp, req, stream=False, timeout=None, def resolve_redirects(self, resp, req, stream=False, timeout=None,
verify=True, cert=None, proxies=None): verify=True, cert=None, proxies=None, **adapter_kwargs):
"""Receives a Response. Returns a generator of Responses.""" """Receives a Response. Returns a generator of Responses."""
i = 0 i = 0
@ -193,6 +193,7 @@ class SessionRedirectMixin(object):
cert=cert, cert=cert,
proxies=proxies, proxies=proxies,
allow_redirects=False, allow_redirects=False,
**adapter_kwargs
) )
extract_cookies_to_jar(self.cookies, prepared_request, resp.raw) extract_cookies_to_jar(self.cookies, prepared_request, resp.raw)
@ -560,10 +561,6 @@ class Session(SessionRedirectMixin):
# Set up variables needed for resolve_redirects and dispatching of hooks # Set up variables needed for resolve_redirects and dispatching of hooks
allow_redirects = kwargs.pop('allow_redirects', True) allow_redirects = kwargs.pop('allow_redirects', True)
stream = kwargs.get('stream') stream = kwargs.get('stream')
timeout = kwargs.get('timeout')
verify = kwargs.get('verify')
cert = kwargs.get('cert')
proxies = kwargs.get('proxies')
hooks = request.hooks hooks = request.hooks
# Get the appropriate adapter to use # Get the appropriate adapter to use
@ -591,12 +588,7 @@ class Session(SessionRedirectMixin):
extract_cookies_to_jar(self.cookies, request, r.raw) extract_cookies_to_jar(self.cookies, request, r.raw)
# Redirect resolving generator. # Redirect resolving generator.
gen = self.resolve_redirects(r, request, gen = self.resolve_redirects(r, request, **kwargs)
stream=stream,
timeout=timeout,
verify=verify,
cert=cert,
proxies=proxies)
# Resolve redirects if allowed. # Resolve redirects if allowed.
history = [resp for resp in gen] if allow_redirects else [] history = [resp for resp in gen] if allow_redirects else []

View file

@ -67,7 +67,7 @@ def super_len(o):
return len(o.getvalue()) return len(o.getvalue())
def get_netrc_auth(url): def get_netrc_auth(url, raise_errors=False):
"""Returns the Requests tuple auth for a given url from netrc.""" """Returns the Requests tuple auth for a given url from netrc."""
try: try:
@ -105,8 +105,9 @@ def get_netrc_auth(url):
return (_netrc[login_i], _netrc[2]) return (_netrc[login_i], _netrc[2])
except (NetrcParseError, IOError): except (NetrcParseError, IOError):
# If there was a parsing error or a permissions issue reading the file, # If there was a parsing error or a permissions issue reading the file,
# we'll just skip netrc auth # we'll just skip netrc auth unless explicitly asked to raise errors.
pass if raise_errors:
raise
# AppEngine hackiness. # AppEngine hackiness.
except (ImportError, AttributeError): except (ImportError, AttributeError):

View file

@ -5,9 +5,8 @@ interchange format.
:mod:`simplejson` exposes an API familiar to users of the standard library :mod:`simplejson` exposes an API familiar to users of the standard library
:mod:`marshal` and :mod:`pickle` modules. It is the externally maintained :mod:`marshal` and :mod:`pickle` modules. It is the externally maintained
version of the :mod:`json` library contained in Python 2.6, but maintains version of the :mod:`json` library contained in Python 2.6, but maintains
compatibility with Python 2.4 and Python 2.5 and (currently) has compatibility back to Python 2.5 and (currently) has significant performance
significant performance advantages, even without using the optional C advantages, even without using the optional C extension for speedups.
extension for speedups.
Encoding basic Python object hierarchies:: Encoding basic Python object hierarchies::
@ -98,7 +97,7 @@ Using simplejson.tool from the shell to validate and pretty-print::
Expecting property name: line 1 column 3 (char 2) Expecting property name: line 1 column 3 (char 2)
""" """
from __future__ import absolute_import from __future__ import absolute_import
__version__ = '3.6.5' __version__ = '3.8.0'
__all__ = [ __all__ = [
'dump', 'dumps', 'load', 'loads', 'dump', 'dumps', 'load', 'loads',
'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder',
@ -140,6 +139,7 @@ _default_encoder = JSONEncoder(
use_decimal=True, use_decimal=True,
namedtuple_as_object=True, namedtuple_as_object=True,
tuple_as_array=True, tuple_as_array=True,
iterable_as_array=False,
bigint_as_string=False, bigint_as_string=False,
item_sort_key=None, item_sort_key=None,
for_json=False, for_json=False,
@ -152,7 +152,8 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
encoding='utf-8', default=None, use_decimal=True, encoding='utf-8', default=None, use_decimal=True,
namedtuple_as_object=True, tuple_as_array=True, namedtuple_as_object=True, tuple_as_array=True,
bigint_as_string=False, sort_keys=False, item_sort_key=None, bigint_as_string=False, sort_keys=False, item_sort_key=None,
for_json=False, ignore_nan=False, int_as_string_bitcount=None, **kw): for_json=False, ignore_nan=False, int_as_string_bitcount=None,
iterable_as_array=False, **kw):
"""Serialize ``obj`` as a JSON formatted stream to ``fp`` (a """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a
``.write()``-supporting file-like object). ``.write()``-supporting file-like object).
@ -204,6 +205,10 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
If *tuple_as_array* is true (default: ``True``), If *tuple_as_array* is true (default: ``True``),
:class:`tuple` (and subclasses) will be encoded as JSON arrays. :class:`tuple` (and subclasses) will be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise. Note that this is still a rounding that happens in Javascript otherwise. Note that this is still a
@ -242,7 +247,7 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
check_circular and allow_nan and check_circular and allow_nan and
cls is None and indent is None and separators is None and cls is None and indent is None and separators is None and
encoding == 'utf-8' and default is None and use_decimal encoding == 'utf-8' and default is None and use_decimal
and namedtuple_as_object and tuple_as_array and namedtuple_as_object and tuple_as_array and not iterable_as_array
and not bigint_as_string and not sort_keys and not bigint_as_string and not sort_keys
and not item_sort_key and not for_json and not item_sort_key and not for_json
and not ignore_nan and int_as_string_bitcount is None and not ignore_nan and int_as_string_bitcount is None
@ -258,6 +263,7 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
default=default, use_decimal=use_decimal, default=default, use_decimal=use_decimal,
namedtuple_as_object=namedtuple_as_object, namedtuple_as_object=namedtuple_as_object,
tuple_as_array=tuple_as_array, tuple_as_array=tuple_as_array,
iterable_as_array=iterable_as_array,
bigint_as_string=bigint_as_string, bigint_as_string=bigint_as_string,
sort_keys=sort_keys, sort_keys=sort_keys,
item_sort_key=item_sort_key, item_sort_key=item_sort_key,
@ -276,7 +282,8 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
encoding='utf-8', default=None, use_decimal=True, encoding='utf-8', default=None, use_decimal=True,
namedtuple_as_object=True, tuple_as_array=True, namedtuple_as_object=True, tuple_as_array=True,
bigint_as_string=False, sort_keys=False, item_sort_key=None, bigint_as_string=False, sort_keys=False, item_sort_key=None,
for_json=False, ignore_nan=False, int_as_string_bitcount=None, **kw): for_json=False, ignore_nan=False, int_as_string_bitcount=None,
iterable_as_array=False, **kw):
"""Serialize ``obj`` to a JSON formatted ``str``. """Serialize ``obj`` to a JSON formatted ``str``.
If ``skipkeys`` is false then ``dict`` keys that are not basic types If ``skipkeys`` is false then ``dict`` keys that are not basic types
@ -324,6 +331,10 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
If *tuple_as_array* is true (default: ``True``), If *tuple_as_array* is true (default: ``True``),
:class:`tuple` (and subclasses) will be encoded as JSON arrays. :class:`tuple` (and subclasses) will be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If *bigint_as_string* is true (not the default), ints 2**53 and higher If *bigint_as_string* is true (not the default), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise. rounding that happens in Javascript otherwise.
@ -356,12 +367,11 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
""" """
# cached encoder # cached encoder
if ( if (not skipkeys and ensure_ascii and
not skipkeys and ensure_ascii and
check_circular and allow_nan and check_circular and allow_nan and
cls is None and indent is None and separators is None and cls is None and indent is None and separators is None and
encoding == 'utf-8' and default is None and use_decimal encoding == 'utf-8' and default is None and use_decimal
and namedtuple_as_object and tuple_as_array and namedtuple_as_object and tuple_as_array and not iterable_as_array
and not bigint_as_string and not sort_keys and not bigint_as_string and not sort_keys
and not item_sort_key and not for_json and not item_sort_key and not for_json
and not ignore_nan and int_as_string_bitcount is None and not ignore_nan and int_as_string_bitcount is None
@ -377,6 +387,7 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
use_decimal=use_decimal, use_decimal=use_decimal,
namedtuple_as_object=namedtuple_as_object, namedtuple_as_object=namedtuple_as_object,
tuple_as_array=tuple_as_array, tuple_as_array=tuple_as_array,
iterable_as_array=iterable_as_array,
bigint_as_string=bigint_as_string, bigint_as_string=bigint_as_string,
sort_keys=sort_keys, sort_keys=sort_keys,
item_sort_key=item_sort_key, item_sort_key=item_sort_key,

View file

@ -10,6 +10,7 @@
#define PyString_AS_STRING PyBytes_AS_STRING #define PyString_AS_STRING PyBytes_AS_STRING
#define PyString_FromStringAndSize PyBytes_FromStringAndSize #define PyString_FromStringAndSize PyBytes_FromStringAndSize
#define PyInt_Check(obj) 0 #define PyInt_Check(obj) 0
#define PyInt_CheckExact(obj) 0
#define JSON_UNICHR Py_UCS4 #define JSON_UNICHR Py_UCS4
#define JSON_InternFromString PyUnicode_InternFromString #define JSON_InternFromString PyUnicode_InternFromString
#define JSON_Intern_GET_SIZE PyUnicode_GET_SIZE #define JSON_Intern_GET_SIZE PyUnicode_GET_SIZE
@ -168,6 +169,7 @@ typedef struct _PyEncoderObject {
int use_decimal; int use_decimal;
int namedtuple_as_object; int namedtuple_as_object;
int tuple_as_array; int tuple_as_array;
int iterable_as_array;
PyObject *max_long_size; PyObject *max_long_size;
PyObject *min_long_size; PyObject *min_long_size;
PyObject *item_sort_key; PyObject *item_sort_key;
@ -660,8 +662,21 @@ encoder_stringify_key(PyEncoderObject *s, PyObject *key)
return _encoded_const(key); return _encoded_const(key);
} }
else if (PyInt_Check(key) || PyLong_Check(key)) { else if (PyInt_Check(key) || PyLong_Check(key)) {
if (!(PyInt_CheckExact(key) || PyLong_CheckExact(key))) {
/* See #118, do not trust custom str/repr */
PyObject *res;
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, key, NULL);
if (tmp == NULL) {
return NULL;
}
res = PyObject_Str(tmp);
Py_DECREF(tmp);
return res;
}
else {
return PyObject_Str(key); return PyObject_Str(key);
} }
}
else if (s->use_decimal && PyObject_TypeCheck(key, (PyTypeObject *)s->Decimal)) { else if (s->use_decimal && PyObject_TypeCheck(key, (PyTypeObject *)s->Decimal)) {
return PyObject_Str(key); return PyObject_Str(key);
} }
@ -2567,7 +2582,6 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
static int static int
encoder_init(PyObject *self, PyObject *args, PyObject *kwds) encoder_init(PyObject *self, PyObject *args, PyObject *kwds)
{ {
/* initialize Encoder object */
static char *kwlist[] = { static char *kwlist[] = {
"markers", "markers",
"default", "default",
@ -2582,30 +2596,32 @@ encoder_init(PyObject *self, PyObject *args, PyObject *kwds)
"use_decimal", "use_decimal",
"namedtuple_as_object", "namedtuple_as_object",
"tuple_as_array", "tuple_as_array",
"iterable_as_array"
"int_as_string_bitcount", "int_as_string_bitcount",
"item_sort_key", "item_sort_key",
"encoding", "encoding",
"for_json", "for_json",
"ignore_nan", "ignore_nan",
"Decimal", "Decimal",
"iterable_as_array",
NULL}; NULL};
PyEncoderObject *s; PyEncoderObject *s;
PyObject *markers, *defaultfn, *encoder, *indent, *key_separator; PyObject *markers, *defaultfn, *encoder, *indent, *key_separator;
PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo; PyObject *item_separator, *sort_keys, *skipkeys, *allow_nan, *key_memo;
PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array; PyObject *use_decimal, *namedtuple_as_object, *tuple_as_array, *iterable_as_array;
PyObject *int_as_string_bitcount, *item_sort_key, *encoding, *for_json; PyObject *int_as_string_bitcount, *item_sort_key, *encoding, *for_json;
PyObject *ignore_nan, *Decimal; PyObject *ignore_nan, *Decimal;
assert(PyEncoder_Check(self)); assert(PyEncoder_Check(self));
s = (PyEncoderObject *)self; s = (PyEncoderObject *)self;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOO:make_encoder", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOOOOOOOOOOOOOOOO:make_encoder", kwlist,
&markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator, &markers, &defaultfn, &encoder, &indent, &key_separator, &item_separator,
&sort_keys, &skipkeys, &allow_nan, &key_memo, &use_decimal, &sort_keys, &skipkeys, &allow_nan, &key_memo, &use_decimal,
&namedtuple_as_object, &tuple_as_array, &namedtuple_as_object, &tuple_as_array,
&int_as_string_bitcount, &item_sort_key, &encoding, &for_json, &int_as_string_bitcount, &item_sort_key, &encoding, &for_json,
&ignore_nan, &Decimal)) &ignore_nan, &Decimal, &iterable_as_array))
return -1; return -1;
Py_INCREF(markers); Py_INCREF(markers);
@ -2635,9 +2651,10 @@ encoder_init(PyObject *self, PyObject *args, PyObject *kwds)
s->use_decimal = PyObject_IsTrue(use_decimal); s->use_decimal = PyObject_IsTrue(use_decimal);
s->namedtuple_as_object = PyObject_IsTrue(namedtuple_as_object); s->namedtuple_as_object = PyObject_IsTrue(namedtuple_as_object);
s->tuple_as_array = PyObject_IsTrue(tuple_as_array); s->tuple_as_array = PyObject_IsTrue(tuple_as_array);
s->iterable_as_array = PyObject_IsTrue(iterable_as_array);
if (PyInt_Check(int_as_string_bitcount) || PyLong_Check(int_as_string_bitcount)) { if (PyInt_Check(int_as_string_bitcount) || PyLong_Check(int_as_string_bitcount)) {
static const unsigned int long_long_bitsize = SIZEOF_LONG_LONG * 8; static const unsigned int long_long_bitsize = SIZEOF_LONG_LONG * 8;
int int_as_string_bitcount_val = PyLong_AsLong(int_as_string_bitcount); int int_as_string_bitcount_val = (int)PyLong_AsLong(int_as_string_bitcount);
if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < long_long_bitsize) { if (int_as_string_bitcount_val > 0 && int_as_string_bitcount_val < long_long_bitsize) {
s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << int_as_string_bitcount_val); s->max_long_size = PyLong_FromUnsignedLongLong(1ULL << int_as_string_bitcount_val);
s->min_long_size = PyLong_FromLongLong(-1LL << int_as_string_bitcount_val); s->min_long_size = PyLong_FromLongLong(-1LL << int_as_string_bitcount_val);
@ -2800,7 +2817,20 @@ encoder_encode_float(PyEncoderObject *s, PyObject *obj)
} }
} }
/* Use a better float format here? */ /* Use a better float format here? */
if (PyFloat_CheckExact(obj)) {
return PyObject_Repr(obj); return PyObject_Repr(obj);
}
else {
/* See #118, do not trust custom str/repr */
PyObject *res;
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyFloat_Type, obj, NULL);
if (tmp == NULL) {
return NULL;
}
res = PyObject_Repr(tmp);
Py_DECREF(tmp);
return res;
}
} }
static PyObject * static PyObject *
@ -2840,7 +2870,21 @@ encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ss
rv = _steal_accumulate(rval, encoded); rv = _steal_accumulate(rval, encoded);
} }
else if (PyInt_Check(obj) || PyLong_Check(obj)) { else if (PyInt_Check(obj) || PyLong_Check(obj)) {
PyObject *encoded = PyObject_Str(obj); PyObject *encoded;
if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj)) {
encoded = PyObject_Str(obj);
}
else {
/* See #118, do not trust custom str/repr */
PyObject *tmp = PyObject_CallFunctionObjArgs((PyObject *)&PyLong_Type, obj, NULL);
if (tmp == NULL) {
encoded = NULL;
}
else {
encoded = PyObject_Str(tmp);
Py_DECREF(tmp);
}
}
if (encoded != NULL) { if (encoded != NULL) {
encoded = maybe_quote_bigint(s, encoded, obj); encoded = maybe_quote_bigint(s, encoded, obj);
if (encoded == NULL) if (encoded == NULL)
@ -2895,6 +2939,16 @@ encoder_listencode_obj(PyEncoderObject *s, JSON_Accu *rval, PyObject *obj, Py_ss
else { else {
PyObject *ident = NULL; PyObject *ident = NULL;
PyObject *newobj; PyObject *newobj;
if (s->iterable_as_array) {
newobj = PyObject_GetIter(obj);
if (newobj == NULL)
PyErr_Clear();
else {
rv = encoder_listencode_list(s, rval, newobj, indent_level);
Py_DECREF(newobj);
break;
}
}
if (s->markers != Py_None) { if (s->markers != Py_None) {
int has_key; int has_key;
ident = PyLong_FromVoidPtr(obj); ident = PyLong_FromVoidPtr(obj);

View file

@ -3,7 +3,8 @@
from __future__ import absolute_import from __future__ import absolute_import
import re import re
from operator import itemgetter from operator import itemgetter
from decimal import Decimal # Do not import Decimal directly to avoid reload issues
import decimal
from .compat import u, unichr, binary_type, string_types, integer_types, PY3 from .compat import u, unichr, binary_type, string_types, integer_types, PY3
def _import_speedups(): def _import_speedups():
try: try:
@ -123,7 +124,7 @@ class JSONEncoder(object):
use_decimal=True, namedtuple_as_object=True, use_decimal=True, namedtuple_as_object=True,
tuple_as_array=True, bigint_as_string=False, tuple_as_array=True, bigint_as_string=False,
item_sort_key=None, for_json=False, ignore_nan=False, item_sort_key=None, for_json=False, ignore_nan=False,
int_as_string_bitcount=None): int_as_string_bitcount=None, iterable_as_array=False):
"""Constructor for JSONEncoder, with sensible defaults. """Constructor for JSONEncoder, with sensible defaults.
If skipkeys is false, then it is a TypeError to attempt If skipkeys is false, then it is a TypeError to attempt
@ -178,6 +179,10 @@ class JSONEncoder(object):
If tuple_as_array is true (the default), tuple (and subclasses) will If tuple_as_array is true (the default), tuple (and subclasses) will
be encoded as JSON arrays. be encoded as JSON arrays.
If *iterable_as_array* is true (default: ``False``),
any object not in the above table that implements ``__iter__()``
will be encoded as a JSON array.
If bigint_as_string is true (not the default), ints 2**53 and higher If bigint_as_string is true (not the default), ints 2**53 and higher
or lower than -2**53 will be encoded as strings. This is to avoid the or lower than -2**53 will be encoded as strings. This is to avoid the
rounding that happens in Javascript otherwise. rounding that happens in Javascript otherwise.
@ -209,6 +214,7 @@ class JSONEncoder(object):
self.use_decimal = use_decimal self.use_decimal = use_decimal
self.namedtuple_as_object = namedtuple_as_object self.namedtuple_as_object = namedtuple_as_object
self.tuple_as_array = tuple_as_array self.tuple_as_array = tuple_as_array
self.iterable_as_array = iterable_as_array
self.bigint_as_string = bigint_as_string self.bigint_as_string = bigint_as_string
self.item_sort_key = item_sort_key self.item_sort_key = item_sort_key
self.for_json = for_json self.for_json = for_json
@ -311,6 +317,9 @@ class JSONEncoder(object):
elif o == _neginf: elif o == _neginf:
text = '-Infinity' text = '-Infinity'
else: else:
if type(o) != float:
# See #118, do not trust custom str/repr
o = float(o)
return _repr(o) return _repr(o)
if ignore_nan: if ignore_nan:
@ -334,7 +343,7 @@ class JSONEncoder(object):
self.namedtuple_as_object, self.tuple_as_array, self.namedtuple_as_object, self.tuple_as_array,
int_as_string_bitcount, int_as_string_bitcount,
self.item_sort_key, self.encoding, self.for_json, self.item_sort_key, self.encoding, self.for_json,
self.ignore_nan, Decimal) self.ignore_nan, decimal.Decimal, self.iterable_as_array)
else: else:
_iterencode = _make_iterencode( _iterencode = _make_iterencode(
markers, self.default, _encoder, self.indent, floatstr, markers, self.default, _encoder, self.indent, floatstr,
@ -343,7 +352,7 @@ class JSONEncoder(object):
self.namedtuple_as_object, self.tuple_as_array, self.namedtuple_as_object, self.tuple_as_array,
int_as_string_bitcount, int_as_string_bitcount,
self.item_sort_key, self.encoding, self.for_json, self.item_sort_key, self.encoding, self.for_json,
Decimal=Decimal) self.iterable_as_array, Decimal=decimal.Decimal)
try: try:
return _iterencode(o, 0) return _iterencode(o, 0)
finally: finally:
@ -382,11 +391,12 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
_use_decimal, _namedtuple_as_object, _tuple_as_array, _use_decimal, _namedtuple_as_object, _tuple_as_array,
_int_as_string_bitcount, _item_sort_key, _int_as_string_bitcount, _item_sort_key,
_encoding,_for_json, _encoding,_for_json,
_iterable_as_array,
## HACK: hand-optimized bytecode; turn globals into locals ## HACK: hand-optimized bytecode; turn globals into locals
_PY3=PY3, _PY3=PY3,
ValueError=ValueError, ValueError=ValueError,
string_types=string_types, string_types=string_types,
Decimal=Decimal, Decimal=None,
dict=dict, dict=dict,
float=float, float=float,
id=id, id=id,
@ -395,7 +405,10 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
list=list, list=list,
str=str, str=str,
tuple=tuple, tuple=tuple,
iter=iter,
): ):
if _use_decimal and Decimal is None:
Decimal = decimal.Decimal
if _item_sort_key and not callable(_item_sort_key): if _item_sort_key and not callable(_item_sort_key):
raise TypeError("item_sort_key must be None or callable") raise TypeError("item_sort_key must be None or callable")
elif _sort_keys and not _item_sort_key: elif _sort_keys and not _item_sort_key:
@ -412,6 +425,9 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
or or
_int_as_string_bitcount < 1 _int_as_string_bitcount < 1
) )
if type(value) not in integer_types:
# See #118, do not trust custom str/repr
value = int(value)
if ( if (
skip_quoting or skip_quoting or
(-1 << _int_as_string_bitcount) (-1 << _int_as_string_bitcount)
@ -501,6 +517,9 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
elif key is None: elif key is None:
key = 'null' key = 'null'
elif isinstance(key, integer_types): elif isinstance(key, integer_types):
if type(key) not in integer_types:
# See #118, do not trust custom str/repr
key = int(key)
key = str(key) key = str(key)
elif _use_decimal and isinstance(key, Decimal): elif _use_decimal and isinstance(key, Decimal):
key = str(key) key = str(key)
@ -634,6 +653,16 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
elif _use_decimal and isinstance(o, Decimal): elif _use_decimal and isinstance(o, Decimal):
yield str(o) yield str(o)
else: else:
while _iterable_as_array:
# Markers are not checked here because it is valid for
# an iterable to return self.
try:
o = iter(o)
except TypeError:
break
for chunk in _iterencode_list(o, _current_indent_level):
yield chunk
return
if markers is not None: if markers is not None:
markerid = id(o) markerid = id(o)
if markerid in markers: if markerid in markers:

View file

@ -62,6 +62,7 @@ def all_tests_suite():
'simplejson.tests.test_namedtuple', 'simplejson.tests.test_namedtuple',
'simplejson.tests.test_tool', 'simplejson.tests.test_tool',
'simplejson.tests.test_for_json', 'simplejson.tests.test_for_json',
'simplejson.tests.test_subclass',
])) ]))
suite = get_suite() suite = get_suite()
import simplejson import simplejson

View file

@ -0,0 +1,31 @@
import unittest
from StringIO import StringIO
import simplejson as json
def iter_dumps(obj, **kw):
    """Serialize ``obj`` by joining the chunks produced by ``iterencode``.

    All keyword arguments are forwarded to the ``json.JSONEncoder``
    constructor, so this helper accepts the same options as ``json.dumps``
    while exercising the incremental encoding path.
    """
    encoder = json.JSONEncoder(**kw)
    return ''.join(encoder.iterencode(obj))
def sio_dump(obj, **kw):
    """Serialize ``obj`` through the file-based ``json.dump`` API.

    The JSON text is written into an in-memory ``StringIO`` buffer and the
    buffer contents are returned, so callers can compare the output of the
    streaming ``dump`` path against ``dumps``.
    """
    sio = StringIO()
    # Write into the buffer: calling ``dumps`` here would discard the result
    # and leave the buffer empty, making every comparison trivially ''.
    json.dump(obj, sio, **kw)
    return sio.getvalue()
class TestIterable(unittest.TestCase):
    """Check the ``iterable_as_array`` option through each dump helper."""

    def test_iterable(self):
        """Iterators become JSON arrays only when ``iterable_as_array`` is true."""
        items = [1, 2, 3]
        for dumps in (json.dumps, iter_dumps, sio_dump):
            # Output for the list itself, and for what ``default=sum`` yields.
            expect = dumps(items)
            default_expect = dumps(sum(items))
            # Default is False
            self.assertRaises(TypeError, dumps, iter(items))
            self.assertRaises(
                TypeError, dumps, iter(items), iterable_as_array=False)
            self.assertEqual(
                expect, dumps(iter(items), iterable_as_array=True))
            # Ensure that the "default" gets called
            self.assertEqual(
                default_expect, dumps(iter(items), default=sum))
            self.assertEqual(
                default_expect,
                dumps(iter(items), iterable_as_array=False, default=sum))
            # Ensure that the "default" does not get called: the iterator
            # must be encoded as an array, i.e. match the plain list output.
            self.assertEqual(
                expect,
                dumps(iter(items), iterable_as_array=True, default=sum))

View file

@ -0,0 +1,37 @@
from unittest import TestCase
import simplejson as json
from decimal import Decimal
class AlternateInt(int):
    """``int`` subclass whose ``repr()`` and ``str()`` are not valid JSON."""

    def __repr__(self):
        return 'invalid json'

    # str() returns the same bogus text as repr().
    __str__ = __repr__
class AlternateFloat(float):
    """``float`` subclass whose ``repr()`` and ``str()`` are not valid JSON."""

    def __repr__(self):
        return 'invalid json'

    # str() returns the same bogus text as repr().
    __str__ = __repr__
# class AlternateDecimal(Decimal):
# def __repr__(self):
# return 'invalid json'
class TestSubclass(TestCase):
    """Numeric subclasses must be encoded by value, not by custom str/repr."""

    def test_int(self):
        """``int`` subclasses serialize as plain integers, as values and keys."""
        self.assertEqual(json.dumps(AlternateInt(1)), '1')
        self.assertEqual(json.dumps(AlternateInt(-1)), '-1')
        # Round-trip through loads: the key must come back as the string '1'.
        self.assertEqual(
            json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1})

    def test_float(self):
        """``float`` subclasses serialize as plain floats, as values and keys."""
        self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0')
        self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0')
        # Round-trip through loads: the key must come back as the string '1.0'.
        self.assertEqual(
            json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1})
# NOTE: Decimal subclasses are not supported as-is
# def test_decimal(self):
# self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0')
# self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0')

View file

@ -45,7 +45,3 @@ class TestTuples(unittest.TestCase):
self.assertEqual( self.assertEqual(
json.dumps(repr(t)), json.dumps(repr(t)),
sio.getvalue()) sio.getvalue())
class TestNamedTuple(unittest.TestCase):
def test_namedtuple_dump(self):
pass

View file

@ -11,6 +11,7 @@
import logging import logging
import os import os
import sys
from .base import BaseProject from .base import BaseProject
from ..compat import u, open from ..compat import u, open
@ -38,8 +39,14 @@ class Git(BaseProject):
try: try:
with open(head, 'r', encoding='utf-8') as fh: with open(head, 'r', encoding='utf-8') as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1]) return u(fh.readline().strip().rsplit('/', 1)[-1])
except UnicodeDecodeError:
try:
with open(head, 'r', encoding=sys.getfilesystemencoding()) as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1])
except:
log.exception("Exception:")
except IOError: except IOError:
pass log.exception("Exception:")
return None return None
def _project_base(self): def _project_base(self):

View file

@ -11,6 +11,7 @@
import logging import logging
import os import os
import sys
from .base import BaseProject from .base import BaseProject
from ..compat import u, open from ..compat import u, open
@ -36,8 +37,14 @@ class Mercurial(BaseProject):
try: try:
with open(branch_file, 'r', encoding='utf-8') as fh: with open(branch_file, 'r', encoding='utf-8') as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1]) return u(fh.readline().strip().rsplit('/', 1)[-1])
except UnicodeDecodeError:
try:
with open(branch_file, 'r', encoding=sys.getfilesystemencoding()) as fh:
return u(fh.readline().strip().rsplit('/', 1)[-1])
except:
log.exception("Exception:")
except IOError: except IOError:
pass log.exception("Exception:")
return u('default') return u('default')
def _find_hg_config_dir(self, path): def _find_hg_config_dir(self, path):

View file

@ -46,8 +46,8 @@ class Subversion(BaseProject):
'/usr/local/bin/svn', '/usr/local/bin/svn',
] ]
for location in locations: for location in locations:
with open(os.devnull, 'wb') as DEVNULL:
try: try:
with open(os.devnull, 'wb') as DEVNULL:
Popen([location, '--version'], stdout=DEVNULL, stderr=DEVNULL) Popen([location, '--version'], stdout=DEVNULL, stderr=DEVNULL)
self.binary_location = location self.binary_location = location
return location return location

View file

@ -13,6 +13,7 @@
import logging import logging
import os import os
import sys
from .base import BaseProject from .base import BaseProject
from ..compat import u, open from ..compat import u, open
@ -34,6 +35,13 @@ class WakaTimeProjectFile(BaseProject):
with open(self.config, 'r', encoding='utf-8') as fh: with open(self.config, 'r', encoding='utf-8') as fh:
self._project_name = u(fh.readline().strip()) self._project_name = u(fh.readline().strip())
self._project_branch = u(fh.readline().strip()) self._project_branch = u(fh.readline().strip())
except UnicodeDecodeError:
try:
with open(self.config, 'r', encoding=sys.getfilesystemencoding()) as fh:
self._project_name = u(fh.readline().strip())
self._project_branch = u(fh.readline().strip())
except:
log.exception("Exception:")
except IOError: except IOError:
log.exception("Exception:") log.exception("Exception:")

View file

@ -52,7 +52,7 @@ class SessionCache(object):
conn, c = self.connect() conn, c = self.connect()
c.execute('DELETE FROM session') c.execute('DELETE FROM session')
values = { values = {
'value': pickle.dumps(session), 'value': pickle.dumps(session, protocol=2),
} }
c.execute('INSERT INTO session VALUES (:value)', values) c.execute('INSERT INTO session VALUES (:value)', values)
conn.commit() conn.commit()

View file

@ -28,57 +28,20 @@ from pygments.util import ClassNotFound
log = logging.getLogger('WakaTime') log = logging.getLogger('WakaTime')
# extensions taking priority over lexer
EXTENSIONS = {
'j2': 'HTML',
'markdown': 'Markdown',
'md': 'Markdown',
'mdown': 'Markdown',
'twig': 'Twig',
}
# lexers to human readable languages
TRANSLATIONS = {
'CSS+Genshi Text': 'CSS',
'CSS+Lasso': 'CSS',
'HTML+Django/Jinja': 'HTML',
'HTML+Lasso': 'HTML',
'JavaScript+Genshi Text': 'JavaScript',
'JavaScript+Lasso': 'JavaScript',
'Perl6': 'Perl',
'RHTML': 'HTML',
}
# extensions for when no lexer is found
AUXILIARY_EXTENSIONS = {
'vb': 'VB.net',
}
def guess_language(file_name): def guess_language(file_name):
"""Guess lexer and language for a file. """Guess lexer and language for a file.
Returns (language, lexer) tuple where language is a unicode string. Returns (language, lexer) tuple where language is a unicode string.
""" """
language = get_language_from_extension(file_name)
if language:
return language, None
lexer = smart_guess_lexer(file_name) lexer = smart_guess_lexer(file_name)
if lexer:
language = None
# guess language from file extension
if file_name:
language = get_language_from_extension(file_name, EXTENSIONS)
# get language from lexer if we didn't have a hard-coded extension rule
if language is None and lexer:
language = u(lexer.name) language = u(lexer.name)
if language is None:
language = get_language_from_extension(file_name, AUXILIARY_EXTENSIONS)
if language is not None:
language = translate_language(language)
return language, lexer return language, lexer
@ -93,14 +56,14 @@ def smart_guess_lexer(file_name):
text = get_file_contents(file_name) text = get_file_contents(file_name)
lexer_1, accuracy_1 = guess_lexer_using_filename(file_name, text) lexer1, accuracy1 = guess_lexer_using_filename(file_name, text)
lexer_2, accuracy_2 = guess_lexer_using_modeline(text) lexer2, accuracy2 = guess_lexer_using_modeline(text)
if lexer_1: if lexer1:
lexer = lexer_1 lexer = lexer1
if (lexer_2 and accuracy_2 and if (lexer2 and accuracy2 and
(not accuracy_1 or accuracy_2 > accuracy_1)): (not accuracy1 or accuracy2 > accuracy1)):
lexer = lexer_2 lexer = lexer2
return lexer return lexer
@ -156,36 +119,35 @@ def guess_lexer_using_modeline(text):
return lexer, accuracy return lexer, accuracy
def get_language_from_extension(file_name, extension_map): def get_language_from_extension(file_name):
"""Returns a matching language for the given file_name using extension_map. """Returns a matching language for the given file extension.
""" """
extension = file_name.rsplit('.', 1)[-1] if len(file_name.rsplit('.', 1)) > 1 else None extension = os.path.splitext(file_name)[1].lower()
if extension == '.h':
if extension: directory = os.path.dirname(file_name)
if extension in extension_map: available_files = os.listdir(directory)
return extension_map[extension] available_extensions = zip(*map(os.path.splitext, available_files))[1]
if extension.lower() in extension_map: available_extensions = [ext.lower() for ext in available_extensions]
return extension_map[extension.lower()] if '.cpp' in available_extensions:
return 'C++'
if '.c' in available_extensions:
return 'C'
return None return None
def translate_language(language):
"""Turns Pygments lexer class name string into human-readable language.
"""
if language in TRANSLATIONS:
language = TRANSLATIONS[language]
return language
def number_lines_in_file(file_name): def number_lines_in_file(file_name):
lines = 0 lines = 0
try: try:
with open(file_name, 'r', encoding='utf-8') as fh: with open(file_name, 'r', encoding='utf-8') as fh:
for line in fh: for line in fh:
lines += 1 lines += 1
except:
try:
with open(file_name, 'r', encoding=sys.getfilesystemencoding()) as fh:
for line in fh:
lines += 1
except: except:
return None return None
return lines return lines
@ -223,5 +185,9 @@ def get_file_contents(file_name):
with open(file_name, 'r', encoding='utf-8') as fh: with open(file_name, 'r', encoding='utf-8') as fh:
text = fh.read(512000) text = fh.read(512000)
except: except:
pass try:
with open(file_name, 'r', encoding=sys.getfilesystemencoding()) as fh:
text = fh.read(512000)
except:
log.exception("Exception:")
return text return text