diff options
Diffstat (limited to 'memcache.py')
-rw-r--r-- | memcache.py | 147 |
1 files changed, 77 insertions, 70 deletions
diff --git a/memcache.py b/memcache.py index e0a2662..11da6c1 100644 --- a/memcache.py +++ b/memcache.py @@ -59,15 +59,11 @@ import zlib import six -if six.PY2: - # With Python 2, the faster C implementation has to be imported explicitly. - import cPickle as pickle -else: - import pickle +import pickle def cmemcache_hash(key): - return (((binascii.crc32(key) & 0xffffffff) >> 16) & 0x7fff) or 1 + return ((binascii.crc32(key) & 0xffffffff) >> 16) & 0x7fff serverHashFunction = cmemcache_hash @@ -130,7 +126,7 @@ class Client(threading.local): @group Integers: incr, decr @group Removal: delete, delete_multi @sort: __init__, set_servers, forget_dead_hosts, disconnect_all, - debuglog,\ set, set_multi, add, replace, get, get_multi, + debuglog, set, set_multi, add, replace, get, get_multi, incr, decr, delete, delete_multi """ _FLAG_PICKLE = 1 << 0 @@ -253,12 +249,11 @@ class Client(threading.local): return key def _encode_cmd(self, cmd, key, headers, noreply, *args): - cmd_bytes = cmd.encode('utf-8') if six.PY3 else cmd + cmd_bytes = cmd.encode('utf-8') fullcmd = [cmd_bytes, b' ', key] if headers: - if six.PY3: - headers = headers.encode('utf-8') + headers = headers.encode('utf-8') fullcmd.append(b' ') fullcmd.append(headers) @@ -326,11 +321,13 @@ class Client(threading.local): serverData = {} data.append((name, serverData)) readline = s.readline - while 1: + while True: line = readline() - if not line or line.decode('ascii').strip() == 'END': + if line: + line = line.decode('ascii') + if not line or line.strip() == 'END': break - stats = line.decode('ascii').split(' ', 2) + stats = line.split(' ', 2) serverData[stats[1]] = stats[2] return data @@ -350,8 +347,10 @@ class Client(threading.local): data.append((name, serverData)) s.send_cmd('stats slabs') readline = s.readline - while 1: + while True: line = readline() + if line: + line = line.decode('ascii') if not line or line.strip() == 'END': break item = line.split(' ', 2) @@ -366,6 +365,11 @@ class Client(threading.local): serverData[slab[0]][slab[1]] = item[2] return data + def quit_all(self) -> None: + '''Send a "quit" command to all servers and wait for the connection to close.''' + for s in self.servers: + s.quit() + def get_slabs(self): data = [] for s in self.servers: @@ -381,7 +385,7 @@ class Client(threading.local): data.append((name, serverData)) s.send_cmd('stats items') readline = s.readline - while 1: + while True: line = readline() if not line or line.strip() == 'END': break @@ -520,18 +524,36 @@ class Client(threading.local): rc = 0 return rc - def delete(self, key, time=None, noreply=False): + def delete(self, key, noreply=False): '''Deletes a key from the memcache. @return: Nonzero on success. - @param time: number of seconds any subsequent set / update commands - should fail. Defaults to None for no delay. @param noreply: optional parameter instructs the server to not send the reply. @rtype: int ''' - return self._deletetouch([b'DELETED', b'NOT_FOUND'], "delete", key, - time, noreply) + key = self._encode_key(key) + if self.do_check_key: + self.check_key(key) + server, key = self._get_server(key) + if not server: + return 0 + self._statlog('delete') + fullcmd = self._encode_cmd('delete', key, None, noreply) + + try: + server.send_cmd(fullcmd) + if noreply: + return 1 + line = server.readline() + if line and line.strip() in [b'DELETED', b'NOT_FOUND']: + return 1 + self.debuglog('delete expected DELETED or NOT_FOUND, got: %r' % (line,)) + except socket.error as msg: + if isinstance(msg, tuple): + msg = msg[1] + server.mark_dead(msg) + return 0 def touch(self, key, time=0, noreply=False): '''Updates the expiration time of a key in memcache. @@ -546,31 +568,23 @@ class Client(threading.local): reply. @rtype: int ''' - return self._deletetouch([b'TOUCHED'], "touch", key, time, noreply) - - def _deletetouch(self, expected, cmd, key, time=0, noreply=False): key = self._encode_key(key) if self.do_check_key: self.check_key(key) server, key = self._get_server(key) if not server: return 0 - self._statlog(cmd) - if time is not None: - headers = str(time) - else: - headers = None - fullcmd = self._encode_cmd(cmd, key, headers, noreply) + self._statlog('touch') + fullcmd = self._encode_cmd('touch', key, str(time), noreply) try: server.send_cmd(fullcmd) if noreply: return 1 line = server.readline() - if line and line.strip() in expected: + if line and line.strip() in [b'TOUCHED']: return 1 - self.debuglog('%s expected %s, got: %r' - % (cmd, b' or '.join(expected), line)) + self.debuglog('touch expected TOUCHED, got: %r' % (line,)) except socket.error as msg: if isinstance(msg, tuple): msg = msg[1] @@ -796,24 +810,18 @@ class Client(threading.local): key = self._encode_key(key) if not isinstance(key, six.binary_type): # set_multi supports int / long keys. - key = str(key) - if six.PY3: - key = key.encode('utf8') + key = str(key).encode('utf8') bytes_orig_key = key # Gotta pre-mangle key before hashing to a # server. Returns the mangled key. server, key = self._get_server( (serverhash, key_prefix + key)) - - orig_key = orig_key[1] else: key = self._encode_key(orig_key) if not isinstance(key, six.binary_type): # set_multi supports int / long keys. - key = str(key) - if six.PY3: - key = key.encode('utf8') + key = str(key).encode('utf8') bytes_orig_key = key server, key = self._get_server(key_prefix + key) @@ -972,16 +980,7 @@ class Client(threading.local): val = val.encode('utf-8') elif val_type == int: flags |= Client._FLAG_INTEGER - val = '%d' % val - if six.PY3: - val = val.encode('ascii') - # force no attempt to compress this silly string. - min_compress_len = 0 - elif six.PY2 and isinstance(val, long): # noqa: F821 - flags |= Client._FLAG_LONG - val = str(val) - if six.PY3: - val = val.encode('ascii') + val = ('%d' % val).encode('ascii') # force no attempt to compress this silly string. min_compress_len = 0 else: @@ -1008,8 +1007,7 @@ class Client(threading.local): val = comp_val # silently do not store if value length exceeds maximum - if (self.server_max_value_length != 0 and - len(val) > self.server_max_value_length): + if (self.server_max_value_length != 0 and len(val) > self.server_max_value_length): return 0 return (flags, len(val), val) @@ -1064,7 +1062,7 @@ class Client(threading.local): server.mark_dead(msg) return 0 - def _get(self, cmd, key): + def _get(self, cmd, key, default=None): key = self._encode_key(key) if self.do_check_key: self.check_key(key) @@ -1076,7 +1074,7 @@ class Client(threading.local): self._statlog(cmd) try: - cmd_bytes = cmd.encode('utf-8') if six.PY3 else cmd + cmd_bytes = cmd.encode('utf-8') fullcmd = b''.join((cmd_bytes, b' ', key)) server.send_cmd(fullcmd) rkey = flags = rlen = cas_id = None @@ -1093,7 +1091,7 @@ class Client(threading.local): ) if not rkey: - return None + return default try: value = self._recv_value(server, flags, rlen) finally: @@ -1118,12 +1116,12 @@ class Client(threading.local): server.mark_dead(msg) return None - def get(self, key): + def get(self, key, default=None): '''Retrieves a key from the memcache. @return: The value or None. ''' - return self._get('get', key) + return self._get('get', key, default) def gets(self, key): '''Retrieves a key from the memcache. Used in conjunction with 'cas'. @@ -1269,10 +1267,7 @@ class Client(threading.local): elif flags & Client._FLAG_INTEGER: val = int(buf) elif flags & Client._FLAG_LONG: - if six.PY3: - val = int(buf) - else: - val = long(buf) # noqa: F821 + val = int(buf) elif flags & Client._FLAG_PICKLE: try: file = BytesIO(buf) @@ -1305,8 +1300,8 @@ class Client(threading.local): key = key[1] if key is None: raise Client.MemcachedKeyNoneError("Key is None") - if key is '': - if key_extra_len is 0: + if key == '': + if key_extra_len == 0: raise Client.MemcachedKeyNoneError("Key is empty") # key is empty but there is some other component to key @@ -1315,8 +1310,7 @@ class Client(threading.local): if not isinstance(key, six.binary_type): raise Client.MemcachedKeyTypeError("Key must be a binary string") - if (self.server_max_key_length != 0 and - len(key) + key_extra_len > self.server_max_key_length): + if (self.server_max_key_length != 0 and len(key) + key_extra_len > self.server_max_key_length): raise Client.MemcachedKeyLengthError( "Key length is > %s" % self.server_max_key_length ) @@ -1468,11 +1462,8 @@ class _Host(object): def expect(self, text, raise_exception=False): line = self.readline(raise_exception) if self.debug and line != text: - if six.PY3: - text = text.decode('utf8') - log_line = line.decode('utf8', 'replace') - else: - log_line = line + text = text.decode('utf8') + log_line = line.decode('utf8', 'replace') self.debuglog("while expecting %r, got unexpected response %r" % (text, log_line)) return line @@ -1489,6 +1480,22 @@ class _Host(object): self.buffer = buf[rlen:] return buf[:rlen] + def quit(self) -> None: + '''Send a "quit" command to remote server and wait for connection to close.''' + if self.socket: + self.send_cmd('quit') + + # We can't close the local socket until the remote end processes the quit + # command and sends us a FIN packet. When that happens, socket.recv() + # will stop blocking and return an empty string. If we try to close the + # socket before then, the OS will think we're initiating the connection + # close and will put the socket into TIME_WAIT. + self.socket.recv(1) + + # At this point, socket should be in CLOSE_WAIT. Closing the socket should + # release the port back to the OS. + self.close_socket() + def flush(self): self.send_cmd('flush_all') self.expect(b'OK') |