diff options
Diffstat (limited to 'Lib/base64.py')
| -rwxr-xr-x | Lib/base64.py | 50 | 
1 files changed, 26 insertions, 24 deletions
| diff --git a/Lib/base64.py b/Lib/base64.py index 895d813f7e..4042f004fd 100755 --- a/Lib/base64.py +++ b/Lib/base64.py @@ -29,14 +29,16 @@ __all__ = [  bytes_types = (bytes, bytearray)  # Types acceptable as binary data - -def _translate(s, altchars): -    if not isinstance(s, bytes_types): -        raise TypeError("expected bytes, not %s" % s.__class__.__name__) -    translation = bytearray(range(256)) -    for k, v in altchars.items(): -        translation[ord(k)] = v[0] -    return s.translate(translation) +def _bytes_from_decode_data(s): +    if isinstance(s, str): +        try: +            return s.encode('ascii') +        except UnicodeEncodeError: +            raise ValueError('string argument should contain only ASCII characters') +    elif isinstance(s, bytes_types): +        return s +    else: +        raise TypeError("argument should be bytes or ASCII string, not %s" % s.__class__.__name__) @@ -61,7 +63,7 @@ def b64encode(s, altchars=None):              raise TypeError("expected bytes, not %s"                              % altchars.__class__.__name__)          assert len(altchars) == 2, repr(altchars) -        return _translate(encoded, {'+': altchars[0:1], '/': altchars[1:2]}) +        return encoded.translate(bytes.maketrans(b'+/', altchars))      return encoded @@ -79,14 +81,11 @@ def b64decode(s, altchars=None, validate=False):      discarded prior to the padding check.  If validate is True,      non-base64-alphabet characters in the input result in a binascii.Error.      """ -    if not isinstance(s, bytes_types): -        raise TypeError("expected bytes, not %s" % s.__class__.__name__) +    s = _bytes_from_decode_data(s)      if altchars is not None: -        if not isinstance(altchars, bytes_types): -            raise TypeError("expected bytes, not %s" -                            % altchars.__class__.__name__) +        altchars = _bytes_from_decode_data(altchars)          assert len(altchars) == 2, repr(altchars) -        s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'}) +        s = s.translate(bytes.maketrans(altchars, b'+/'))      if validate and not re.match(b'^[A-Za-z0-9+/]*={0,2}$', s):          raise binascii.Error('Non-base64 digit found')      return binascii.a2b_base64(s) @@ -109,6 +108,10 @@ def standard_b64decode(s):      """      return b64decode(s) + +_urlsafe_encode_translation = bytes.maketrans(b'+/', b'-_') +_urlsafe_decode_translation = bytes.maketrans(b'-_', b'+/') +  def urlsafe_b64encode(s):      """Encode a byte string using a url-safe Base64 alphabet. @@ -116,7 +119,7 @@ def urlsafe_b64encode(s):      returned.  The alphabet uses '-' instead of '+' and '_' instead of      '/'.      """ -    return b64encode(s, b'-_') +    return b64encode(s).translate(_urlsafe_encode_translation)  def urlsafe_b64decode(s):      """Decode a byte string encoded with the standard Base64 alphabet. @@ -128,7 +131,9 @@ def urlsafe_b64decode(s):      The alphabet uses '-' instead of '+' and '_' instead of '/'.      """ -    return b64decode(s, b'-_') +    s = _bytes_from_decode_data(s) +    s = s.translate(_urlsafe_decode_translation) +    return b64decode(s) @@ -211,8 +216,7 @@ def b32decode(s, casefold=False, map01=None):      the input is incorrectly padded or if there are non-alphabet      characters present in the input.      """ -    if not isinstance(s, bytes_types): -        raise TypeError("expected bytes, not %s" % s.__class__.__name__) +    s = _bytes_from_decode_data(s)      quanta, leftover = divmod(len(s), 8)      if leftover:          raise binascii.Error('Incorrect padding') @@ -220,10 +224,9 @@ def b32decode(s, casefold=False, map01=None):      # False, or the character to map the digit 1 (one) to.  It should be      # either L (el) or I (eye).      if map01 is not None: -        if not isinstance(map01, bytes_types): -            raise TypeError("expected bytes, not %s" % map01.__class__.__name__) +        map01 = _bytes_from_decode_data(map01)          assert len(map01) == 1, repr(map01) -        s = _translate(s, {b'0': b'O', b'1': map01}) +        s = s.translate(bytes.maketrans(b'01', b'O' + map01))      if casefold:          s = s.upper()      # Strip off pad characters from the right.  We need to count the pad @@ -292,8 +295,7 @@ def b16decode(s, casefold=False):      s were incorrectly padded or if there are non-alphabet characters      present in the string.      """ -    if not isinstance(s, bytes_types): -        raise TypeError("expected bytes, not %s" % s.__class__.__name__) +    s = _bytes_from_decode_data(s)      if casefold:          s = s.upper()      if re.search(b'[^0-9A-F]', s): | 
