diff options
| -rw-r--r-- | qpid/python/qpid/driver.py | 43 | ||||
| -rw-r--r-- | qpid/python/qpid/messaging.py | 13 | ||||
| -rw-r--r-- | qpid/python/qpid/sasl.py | 99 |
3 files changed, 142 insertions, 13 deletions
diff --git a/qpid/python/qpid/driver.py b/qpid/python/qpid/driver.py index 1e295dd42f..6ca555c9ff 100644 --- a/qpid/python/qpid/driver.py +++ b/qpid/python/qpid/driver.py @@ -17,7 +17,7 @@ # under the License. # -import address, compat, connection, socket, struct, sys, time +import address, compat, connection, sasl, socket, struct, sys, time from concurrency import synchronized from datatypes import RangedSet, Serial from exceptions import Timeout, VersionError @@ -171,6 +171,24 @@ class Driver: self._op_dec = OpDecoder() self._timeout = None + self._sasl = sasl.Client() + if self.connection.username: + self._sasl.setAttr("username", self.connection.username) + if self.connection.password: + self._sasl.setAttr("password", self.connection.password) + if self.connection.host: + self._sasl.setAttr("host", self.connection.host) + options = self.connection.options + if "service" in options: + self._sasl.setAttr("service", options["service"]) + if "min_ssf" in options: + self._sasl.setAttr("minssf", options["min_ssf"]) + if "max_ssf" in options: + self._sasl.setAttr("maxssf", options["max_ssf"]) + self._sasl.init() + self._sasl_encode = False + self._sasl_decode = False + for ssn in self.connection.sessions.values(): for m in ssn.acked + ssn.unacked + ssn.incoming: m._transfer_id = None @@ -210,6 +228,8 @@ class Driver: try: data = self._socket.recv(64*1024) if data: + if self._sasl_decode: + data = self._sasl.decode(data) rawlog.debug("READ[%s]: %r", self.log_id, data) else: rawlog.debug("ABORTED[%s]: %s", self.log_id, self._socket.getpeername()) @@ -287,7 +307,10 @@ class Driver: self._op_enc.write(op) self._seg_enc.write(*self._op_enc.read()) self._frame_enc.write(*self._seg_enc.read()) - self._buf += self._frame_enc.read() + bytes = self._frame_enc.read() + if self._sasl_encode: + bytes = self._sasl.encode(bytes) + self._buf += bytes def do_header(self, hdr): cli_major = 0; cli_minor = 10 @@ -297,11 +320,17 @@ class Driver: (cli_major, cli_minor, major, minor)) def do_connection_start(self, start): - # XXX: should we use some sort of callback for this? - r = "\0%s\0%s" % (self.connection.username, self.connection.password) - m = self.connection.mechanism + if self.connection.mechanisms: + mechs = [m for m in start.mechanisms if m in self.connection.mechanisms] + else: + mechs = start.mechanisms + mech, initial = self._sasl.start(" ".join(mechs)) self.write_op(ConnectionStartOk(client_properties=CLIENT_PROPERTIES, - mechanism=m, response=r)) + mechanism=mech, response=initial)) + + def do_connection_secure(self, secure): + resp = self._sasl.step(secure.challenge) + self.write_op(ConnectionSecureOk(response=resp)) def do_connection_tune(self, tune): # XXX: is heartbeat protocol specific? @@ -310,9 +339,11 @@ class Driver: self.write_op(ConnectionTuneOk(heartbeat=self.connection.heartbeat, channel_max=self.channel_max)) self.write_op(ConnectionOpen()) + self._sasl_encode = True def do_connection_open_ok(self, open_ok): self._connected = True + self._sasl_decode = True def connection_heartbeat(self, hrt): self.write_op(ConnectionHeartbeat()) diff --git a/qpid/python/qpid/messaging.py b/qpid/python/qpid/messaging.py index 9ec38ad45c..91d7bca703 100644 --- a/qpid/python/qpid/messaging.py +++ b/qpid/python/qpid/messaging.py @@ -75,8 +75,7 @@ class Connection: """ @static - def open(host, port=None, username="guest", password="guest", - mechanism="PLAIN", heartbeat=None, **options): + def open(host, port=None, username="guest", password="guest", **options): """ Creates an AMQP connection and connects it to the given host and port. @@ -87,12 +86,11 @@ class Connection: @rtype: Connection @return: a connected Connection """ - conn = Connection(host, port, username, password, mechanism, heartbeat, **options) + conn = Connection(host, port, username, password, **options) conn.connect() return conn - def __init__(self, host, port=None, username="guest", password="guest", - mechanism="PLAIN", heartbeat=None, **options): + def __init__(self, host, port=None, username="guest", password="guest", **options): """ Creates a connection. A newly created connection must be connected with the Connection.connect() method before it can be used. @@ -108,8 +106,9 @@ class Connection: self.port = default(port, AMQP_PORT) self.username = username self.password = password - self.mechanism = mechanism - self.heartbeat = heartbeat + self.mechanisms = options.get("mechanisms") + self.heartbeat = options.get("heartbeat") + self.options = options self.id = str(uuid4()) self.session_counter = 0 diff --git a/qpid/python/qpid/sasl.py b/qpid/python/qpid/sasl.py new file mode 100644 index 0000000000..6b00ddaa99 --- /dev/null +++ b/qpid/python/qpid/sasl.py @@ -0,0 +1,99 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import socket + +class SASLError(Exception): + pass + +class WrapperClient: + + def __init__(self): + self._cli = _Client() + + def setAttr(self, name, value): + status = self._cli.setAttr(str(name), str(value)) + if not status: + raise SASLError(self._cli.getError()) + + def init(self): + status = self._cli.init() + if not status: + raise SASLError(self._cli.getError()) + + def start(self, mechanisms): + status, mech, initial = self._cli.start(str(mechanisms)) + if status: + return mech, initial + else: + raise SASLError(self._cli.getError()) + + def step(self, challenge): + status, response = self._cli.step(challenge) + if status: + return response + else: + raise SASLError(self._cli.getError()) + + def encode(self, bytes): + status, result = self._cli.encode(bytes) + if status: + return result + else: + raise SASLError(self._cli.getError()) + + def decode(self, bytes): + status, result = self._cli.decode(bytes) + if status: + return result + else: + raise SASLError(self._cli.getError()) + +class PlainClient: + + def __init__(self): + self.attrs = {} + + def setAttr(self, name, value): + self.attrs[name] = value + + def init(self): + pass + + def start(self, mechanisms): + mechs = mechanisms.split() + if self.attrs.get("username") and self.attrs.get("password") and "PLAIN" in mechs: + return "PLAIN", "\0%s\0%s" % (self.attrs.get("username"), self.attrs.get("password")) + elif "ANONYMOUS" in mechs: + return "ANONYMOUS", "%s@%s" % (self.attrs.get("username"), socket.gethostname()) + + def step(self, challenge): + pass + + def encode(self, bytes): + return bytes + + def decode(self, bytes): + return bytes + +try: + from saslwrapper import Client as _Client + Client = WrapperClient +except ImportError: + Client = PlainClient |
