summaryrefslogtreecommitdiff
path: root/passlib/utils
diff options
context:
space:
mode:
authorEli Collins <elic@assurancetechnologies.com>2011-06-17 00:36:46 -0400
committerEli Collins <elic@assurancetechnologies.com>2011-06-17 00:36:46 -0400
commite23ee714f2606fdb24e071bf481c76442e0a1aec (patch)
tree61a72a6baf375a6b5ff08feeef4b709fd7eeca31 /passlib/utils
parent3f7c43556e647df2b2993fc4cf47a87933e276cc (diff)
downloadpasslib-e23ee714f2606fdb24e071bf481c76442e0a1aec.tar.gz
rest of utils now py3 compat
* kdfs, md4, and utils proper * updated UTs * added to_native_str helper * added some UTs for new to_bytes / to_unicode etc methods
Diffstat (limited to 'passlib/utils')
-rw-r--r--passlib/utils/__init__.py97
-rw-r--r--passlib/utils/md4.py12
-rw-r--r--passlib/utils/pbkdf2.py56
3 files changed, 105 insertions, 60 deletions
diff --git a/passlib/utils/__init__.py b/passlib/utils/__init__.py
index bebf7a8..8064aa0 100644
--- a/passlib/utils/__init__.py
+++ b/passlib/utils/__init__.py
@@ -38,6 +38,7 @@ __all__ = [
#bytes<->unicode
'to_bytes',
'to_unicode',
+ 'to_native_str',
'is_same_codec',
#byte manipulation
@@ -290,6 +291,38 @@ def to_unicode(source, source_encoding="utf-8", errname="value"):
raise TypeError("%s must be unicode or %s-encoded bytes, not %s" %
(errname, source_encoding, type(source)))
+def to_native_str(source, encoding="utf-8", errname="value"):
+ """take in unicode or bytes, return native string
+
+ python 2: encodes unicode using specified encoding, leaves bytes alone.
+ python 3: decodes bytes using specified encoding, leaves unicode alone.
+
+ :raises TypeError: if source is not unicode or bytes.
+
+ :arg source: source bytes/unicode to process
+ :arg encoding: encoding to use when encoding unicode / decoding bytes
+ :param errname: optional name of variable/noun to reference when raising errors
+
+ :returns: :class:`str` instance
+ """
+ assert encoding
+ if isinstance(source, bytes):
+ # Py2k #
+ return source
+ # Py3k #
+ #return source.decode(encoding)
+ # end Py3k #
+
+ elif isinstance(source, unicode):
+ # Py2k #
+ return source.encode(encoding)
+ # Py3k #
+ #return source
+ # end Py3k #
+
+ else:
+ raise TypeError("%s must be unicode or bytes, not %s" % (errname, type(source)))
+
#--------------------------------------------------
#support utils
#--------------------------------------------------
@@ -321,6 +354,7 @@ BEMPTY = b('')
#helpers for joining / extracting elements
bjoin = BEMPTY.join
+ujoin = u''.join
def belem_join(elems):
"""takes series of bytes elements, returns bytes.
@@ -509,7 +543,7 @@ def genseed(value=None):
#if value is rng, extract a bunch of bits from it's state
if hasattr(value, "getrandbits"):
value = value.getrandbits(256)
- text = "%s %s %s %.15f %s" % (
+ text = u"%s %s %s %.15f %s" % (
value,
#if user specified a seed value (eg current rng state), mix it in
@@ -523,11 +557,11 @@ def genseed(value=None):
time.time(),
#the current time, to whatever precision os uses
- os.urandom(16) if has_urandom else 0,
+ os.urandom(16).decode("latin-1") if has_urandom else 0,
#if urandom available, might as well mix some bytes in.
)
#hash it all up and return it as int
- return long(sha256(text).hexdigest(), 16)
+ return long(sha256(text.encode("utf-8")).hexdigest(), 16)
if has_urandom:
rng = random.SystemRandom()
@@ -548,18 +582,28 @@ def getrandbytes(rng, count):
##if meth:
## return meth(count)
- #XXX: break into chunks for large number of bits?
if not count:
- return ''
- value = rng.getrandbits(count<<3)
- buf = StringIO()
- for i in xrange(count):
- buf.write(chr(value & 0xff))
- value //= 0xff
- return buf.getvalue()
+ return BEMPTY
+ def helper():
+ #XXX: break into chunks for large number of bits?
+ value = rng.getrandbits(count<<3)
+ i = 0
+ while i < count:
+ # Py2k #
+ yield chr(value & 0xff)
+ # Py3k #
+ #yield value & 0xff
+ # end Py3k #
+ value >>= 3
+ i += 1
+ # Py2k #
+ return bjoin(helper())
+ # Py3k #
+ #return bytes(helper())
+ # end Py3k #
def getrandstr(rng, charset, count):
- """return character string containg *count* number of chars, whose elements are drawn from specified charset, using specified rng"""
+ """return string containing *count* number of chars/bytes, whose elements are drawn from specified charset, using specified rng"""
#check alphabet & count
if count < 0:
raise ValueError("count must be >= 0")
@@ -570,16 +614,25 @@ def getrandstr(rng, charset, count):
return charset * count
#get random value, and write out to buffer
- #XXX: break into chunks for large number of letters?
- value = rng.randrange(0, letters**count)
- buf = StringIO()
- for i in xrange(count):
- buf.write(charset[value % letters])
- value //= letters
- assert value == 0
- return buf.getvalue()
-
-def generate_password(size=10, charset=u'2346789ABCDEFGHJKMNPQRTUVWXYZabcdefghjkmnpqrstuvwxyz'):
+ def helper():
+ #XXX: break into chunks for large number of letters?
+ value = rng.randrange(0, letters**count)
+ i = 0
+ while i < count:
+ yield charset[value % letters]
+ value //= letters
+ i += 1
+
+ if isinstance(charset, unicode):
+ return ujoin(helper())
+ else:
+ # Py2k #
+ return bjoin(helper())
+ # Py3k #
+ #return bytes(helper())
+ # end Py3k #
+
+def generate_password(size=10, charset='2346789ABCDEFGHJKMNPQRTUVWXYZabcdefghjkmnpqrstuvwxyz'):
"""generate random password using given length & chars
:param size:
diff --git a/passlib/utils/md4.py b/passlib/utils/md4.py
index d81afe5..40a48e4 100644
--- a/passlib/utils/md4.py
+++ b/passlib/utils/md4.py
@@ -15,6 +15,7 @@ from binascii import hexlify
import struct
from warnings import warn
#site
+from passlib.utils import b, bytes, to_native_str
#local
__all__ = [ "md4" ]
#=========================================================================
@@ -71,7 +72,7 @@ class md4(object):
def __init__(self, content=None):
self._count = 0
self._state = [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476]
- self._buf = ''
+ self._buf = b('')
if content:
self.update(content)
@@ -173,6 +174,8 @@ class md4(object):
orig[i] = (orig[i]+state[i]) & MASK_32
def update(self, content):
+ if not isinstance(content, bytes):
+ raise TypeError("expected bytes")
buf = self._buf
if buf:
content = buf + content
@@ -205,7 +208,7 @@ class md4(object):
# then last 8 bytes = msg length in bits
buf = self._buf
msglen = self._count*512 + len(buf)*8
- block = buf + '\x80' + '\x00' * ((119-len(buf)) % 64) + \
+ block = buf + b('\x80') + b('\x00') * ((119-len(buf)) % 64) + \
struct.pack("<2I", msglen & MASK_32, (msglen>>32) & MASK_32)
if len(block) == 128:
self._process(block[:64])
@@ -220,7 +223,8 @@ class md4(object):
return out
def hexdigest(self):
- return hexlify(self.digest())
+ return to_native_str(hexlify(self.digest()), "latin-1")
+
#=========================================================================
#eoc
#=========================================================================
@@ -252,7 +256,7 @@ if _has_native_md4():
#overwrite md4 class w/ hashlib wrapper
def md4(content=None):
"wrapper for hashlib.new('md4')"
- return hashlib.new('md4', content or '')
+ return hashlib.new('md4', content or b(''))
else:
del hashlib
diff --git a/passlib/utils/pbkdf2.py b/passlib/utils/pbkdf2.py
index adb071f..7bc6935 100644
--- a/passlib/utils/pbkdf2.py
+++ b/passlib/utils/pbkdf2.py
@@ -8,7 +8,6 @@ maybe rename to "kdf" since it's getting more key derivation functions added.
#=================================================================================
#core
from binascii import unhexlify
-from cStringIO import StringIO
import hashlib
import hmac
import logging; log = logging.getLogger(__name__)
@@ -21,7 +20,7 @@ try:
except ImportError:
_EVP = None
#pkg
-from passlib.utils import xor_bytes
+from passlib.utils import xor_bytes, to_bytes, native_str, b, bytes
#local
__all__ = [
"hmac_sha1",
@@ -30,6 +29,12 @@ __all__ = [
"pbkdf2",
]
+# Py2k #
+from cStringIO import StringIO as BytesIO
+# Py3k #
+#from io import BytesIO
+# end Py3k #
+
#=================================================================================
#quick hmac_sha1 implementation used various places
#=================================================================================
@@ -40,12 +45,12 @@ def hmac_sha1(key, msg):
if _EVP:
#default *should* be sha1, which saves us a wrapper function, but might as well check.
try:
- result = _EVP.hmac('x','y')
+ result = _EVP.hmac(b('x'),b('y'))
except ValueError: #pragma: no cover
#this is probably not a good sign if it happens.
warn("PassLib: M2Crypt.EVP.hmac() unexpected threw value error during passlib startup test")
else:
- if result == ',\x1cb\xe0H\xa5\x82M\xfb>\xd6\x98\xef\x8e\xf9oQ\x85\xa3i':
+ if result == b(',\x1cb\xe0H\xa5\x82M\xfb>\xd6\x98\xef\x8e\xf9oQ\x85\xa3i'):
hmac_sha1 = _EVP.hmac
#=================================================================================
@@ -56,7 +61,7 @@ def _get_hmac_prf(digest):
#check if m2crypto is present and supports requested digest
if _EVP:
try:
- result = _EVP.hmac('x', 'y', digest)
+ result = _EVP.hmac(b('x'), b('y'), digest)
except ValueError:
pass
else:
@@ -132,14 +137,14 @@ def get_prf(name):
global _prf_cache
if name in _prf_cache:
return _prf_cache[name]
- if isinstance(name, str):
+ if isinstance(name, native_str):
if name.startswith("hmac-") or name.startswith("hmac_"):
retval = _get_hmac_prf(name[5:])
else:
raise ValueError("unknown prf algorithm: %r" % (name,))
elif callable(name):
#assume it's a callable, use it directly
- digest_size = len(name('x','y'))
+ digest_size = len(name(b('x'),b('y')))
retval = (name, digest_size)
else:
raise TypeError("prf must be string or callable")
@@ -175,19 +180,11 @@ def pbkdf1(secret, salt, rounds, keylen, hash="sha1", encoding="utf8"):
than the digest size of the specified hash.
"""
- #prepare secret
- if isinstance(secret, unicode):
- secret = secret.encode(encoding)
- elif not isinstance(secret, str):
- raise TypeError("secret must be str or unicode")
-
- #prepare salt
- if isinstance(salt, unicode):
- salt = salt.encode(encoding)
- elif not isinstance(salt, str):
- raise TypeError("salt must be str or unicode")
+ #prepare secret & salt
+ secret = to_bytes(secret, encoding, errname="secret")
+ salt = to_bytes(salt, encoding, errname="salt")
- #preprare rounds
+ #prepare rounds
if not isinstance(rounds, (int, long)):
raise TypeError("rounds must be an integer")
if rounds < 1:
@@ -198,7 +195,7 @@ def pbkdf1(secret, salt, rounds, keylen, hash="sha1", encoding="utf8"):
raise ValueError("keylen must be at least 0")
#resolve hash
- if isinstance(hash, str):
+ if isinstance(hash, native_str):
#check for builtin hash
hf = getattr(hashlib, hash, None)
if hf is None:
@@ -244,20 +241,11 @@ def pbkdf2(secret, salt, rounds, keylen, prf="hmac-sha1", encoding="utf8"):
:returns:
raw bytes of generated key
"""
+ #prepare secret & salt
+ secret = to_bytes(secret, encoding, errname="secret")
+ salt = to_bytes(salt, encoding, errname="salt")
- #prepare secret
- if isinstance(secret, unicode):
- secret = secret.encode(encoding)
- elif not isinstance(secret, str):
- raise TypeError("secret must be str or unicode")
-
- #prepare salt
- if isinstance(salt, unicode):
- salt = salt.encode(encoding)
- elif not isinstance(salt, str):
- raise TypeError("salt must be str or unicode")
-
- #preprare rounds
+ #prepare rounds
if not isinstance(rounds, (int, long)):
raise TypeError("rounds must be an integer")
if rounds < 1:
@@ -284,7 +272,7 @@ def pbkdf2(secret, salt, rounds, keylen, prf="hmac-sha1", encoding="utf8"):
raise ValueError("key length to long")
#build up key from blocks
- out = StringIO()
+ out = BytesIO()
write = out.write
for i in xrange(1,bcount+1):
block = tmp = encode_block(secret, salt + pack(">L", i))