diff options
| author | Gordon Sim <gsim@apache.org> | 2008-10-17 09:41:26 +0000 |
|---|---|---|
| committer | Gordon Sim <gsim@apache.org> | 2008-10-17 09:41:26 +0000 |
| commit | ed27e866fb3927257791591e00b9d9e90477e845 (patch) | |
| tree | 6550b389be9612f69337d449f315759679077843 /cpp/src/qpid | |
| parent | 5644e4fbfd777921b33874aed13c45d544c8a383 (diff) | |
| download | qpid-python-ed27e866fb3927257791591e00b9d9e90477e845.tar.gz | |
QPID-106: SSL support for c++ (broker and client), can be enabled/disabled explictly via --with-ssl/--without-ssl args to configure; by default will build the modules if dependencies are found. See SSL readme file for more details.
git-svn-id: https://svn.apache.org/repos/asf/incubator/qpid/trunk/qpid@705534 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'cpp/src/qpid')
| -rw-r--r-- | cpp/src/qpid/broker/Broker.cpp | 15 | ||||
| -rw-r--r-- | cpp/src/qpid/broker/Broker.h | 2 | ||||
| -rw-r--r-- | cpp/src/qpid/client/SslConnector.cpp | 389 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ProtocolFactory.h | 1 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/SslPlugin.cpp | 176 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/SslHandler.cpp | 177 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/SslHandler.h | 75 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/SslIo.cpp | 433 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/SslIo.h | 167 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.cpp | 279 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.h | 117 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/check.cpp | 70 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/check.h | 53 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/util.cpp | 119 | ||||
| -rw-r--r-- | cpp/src/qpid/sys/ssl/util.h | 50 |
15 files changed, 2116 insertions, 7 deletions
diff --git a/cpp/src/qpid/broker/Broker.cpp b/cpp/src/qpid/broker/Broker.cpp index 6a308ab64d..94c4449178 100644 --- a/cpp/src/qpid/broker/Broker.cpp +++ b/cpp/src/qpid/broker/Broker.cpp @@ -113,7 +113,7 @@ Broker::Options::Options(const std::string& name) : "Interval between attempts to purge any expired messages from queues") ("auth", optValue(auth, "yes|no"), "Enable authentication, if disabled all incoming connections will be trusted") ("realm", optValue(realm, "REALM"), "Use the given realm when performing authentication") - ("default-queue-limit", optValue(queueLimit, "BYTES"), "Default maximum size for queues (in bytes)") + ("default-queue-limit", optValue(queueLimit, "BYTES"), "Default maximum size for queues (in bytes)") ("tcp-nodelay", optValue(tcpNoDelay), "Set TCP_NODELAY on TCP connections"); } @@ -339,8 +339,6 @@ Manageable::status_t Broker::ManagementMethod (uint32_t methodId, QPID_LOG(error, "Transport '" << transport << "' not supported"); return Manageable::STATUS_NOT_IMPLEMENTED; } - QPID_LOG(info, "Connecting to " << hp.i_host << ":" << hp.i_port << " using '" << transport << "' as " << "'" << hp.i_username << "'"); - std::pair<Link::shared_ptr, bool> response = links.declare (hp.i_host, hp.i_port, transport, hp.i_durable, hp.i_authMechanism, hp.i_username, hp.i_password); @@ -372,9 +370,14 @@ boost::shared_ptr<ProtocolFactory> Broker::getProtocolFactory(const std::string& else return i->second; } -//TODO: should this allow choosing the port by transport name? -uint16_t Broker::getPort() const { - return getProtocolFactory()->getPort(); +uint16_t Broker::getPort(const std::string& name) const { + boost::shared_ptr<ProtocolFactory> factory + = getProtocolFactory(name.empty() ? TCP_TRANSPORT : name); + if (factory) { + return factory->getPort(); + } else { + throw Exception(QPID_MSG("No such transport: " << name)); + } } void Broker::registerProtocolFactory(const std::string& name, ProtocolFactory::shared_ptr protocolFactory) { diff --git a/cpp/src/qpid/broker/Broker.h b/cpp/src/qpid/broker/Broker.h index 089db69c6b..213bf63837 100644 --- a/cpp/src/qpid/broker/Broker.h +++ b/cpp/src/qpid/broker/Broker.h @@ -149,7 +149,7 @@ class Broker : public sys::Runnable, public Plugin::Target, * port, which will be different if the configured port is * 0. */ - virtual uint16_t getPort() const; + virtual uint16_t getPort(const std::string& name = TCP_TRANSPORT) const; /** * Run the broker. Implements Runnable::run() so the broker diff --git a/cpp/src/qpid/client/SslConnector.cpp b/cpp/src/qpid/client/SslConnector.cpp new file mode 100644 index 0000000000..8ae412ed09 --- /dev/null +++ b/cpp/src/qpid/client/SslConnector.cpp @@ -0,0 +1,389 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include "Connector.h" + +#include "Bounds.h" +#include "ConnectionImpl.h" +#include "ConnectionSettings.h" +#include "qpid/Options.h" +#include "qpid/log/Statement.h" +#include "qpid/sys/Time.h" +#include "qpid/framing/AMQFrame.h" +#include "qpid/sys/ssl/util.h" +#include "qpid/sys/ssl/SslIo.h" +#include "qpid/sys/ssl/SslSocket.h" +#include "qpid/sys/Dispatcher.h" +#include "qpid/sys/Poller.h" +#include "qpid/Msg.h" + +#include <iostream> +#include <map> +#include <boost/bind.hpp> +#include <boost/format.hpp> + +namespace qpid { +namespace client { + +using namespace qpid::sys; +using namespace qpid::sys::ssl; +using namespace qpid::framing; +using boost::format; +using boost::str; + + +class SslConnector : public Connector, private sys::Runnable +{ + struct Buff; + + /** Batch up frames for writing to aio. */ + class Writer : public framing::FrameHandler { + typedef sys::ssl::SslIOBufferBase BufferBase; + typedef std::vector<framing::AMQFrame> Frames; + + const uint16_t maxFrameSize; + sys::Mutex lock; + sys::ssl::SslIO* aio; + BufferBase* buffer; + Frames frames; + size_t lastEof; // Position after last EOF in frames + framing::Buffer encode; + size_t framesEncoded; + std::string identifier; + Bounds* bounds; + + void writeOne(); + void newBuffer(); + + public: + + Writer(uint16_t maxFrameSize, Bounds*); + ~Writer(); + void init(std::string id, sys::ssl::SslIO*); + void handle(framing::AMQFrame&); + void write(sys::ssl::SslIO&); + }; + + const uint16_t maxFrameSize; + framing::ProtocolVersion version; + bool initiated; + + sys::Mutex closedLock; + bool closed; + bool joined; + + sys::ShutdownHandler* shutdownHandler; + framing::InputHandler* input; + framing::InitiationHandler* initialiser; + framing::OutputHandler* output; + + Writer writer; + + sys::Thread receiver; + + sys::ssl::SslSocket socket; + + sys::ssl::SslIO* aio; + boost::shared_ptr<sys::Poller> poller; + + ~SslConnector(); + + void run(); + void handleClosed(); + bool closeInternal(); + + void readbuff(qpid::sys::ssl::SslIO&, qpid::sys::ssl::SslIOBufferBase*); + void writebuff(qpid::sys::ssl::SslIO&); + void writeDataBlock(const framing::AMQDataBlock& data); + void eof(qpid::sys::ssl::SslIO&); + + std::string identifier; + + ConnectionImpl* impl; + + void connect(const std::string& host, int port); + void init(); + void close(); + void send(framing::AMQFrame& frame); + + void setInputHandler(framing::InputHandler* handler); + void setShutdownHandler(sys::ShutdownHandler* handler); + sys::ShutdownHandler* getShutdownHandler() const; + framing::OutputHandler* getOutputHandler(); + const std::string& getIdentifier() const; + +public: + SslConnector(framing::ProtocolVersion pVersion, + const ConnectionSettings&, + ConnectionImpl*); +}; + +// Static constructor which registers connector here +namespace { + Connector* create(framing::ProtocolVersion v, const ConnectionSettings& s, ConnectionImpl* c) { + return new SslConnector(v, s, c); + } + + struct StaticInit { + StaticInit() { + try { + SslOptions options; + options.parse (0, 0, CONF_FILE, true); + initNSS(options); + Connector::registerFactory("ssl", &create); + } catch (const std::exception& e) { + QPID_LOG(error, "Failed to initialise SSL connector: " << e.what()); + } + }; + + ~StaticInit() { shutdownNSS(); } + } init; +} + +SslConnector::SslConnector(ProtocolVersion ver, + const ConnectionSettings& settings, + ConnectionImpl* cimpl) + : maxFrameSize(settings.maxFrameSize), + version(ver), + initiated(false), + closed(true), + joined(true), + shutdownHandler(0), + writer(maxFrameSize, cimpl), + aio(0), + impl(cimpl) +{ + QPID_LOG(debug, "SslConnector created for " << version); + //TODO: how do we want to handle socket configuration with ssl? + //settings.configureSocket(socket); +} + +SslConnector::~SslConnector() { + close(); +} + +void SslConnector::connect(const std::string& host, int port){ + Mutex::ScopedLock l(closedLock); + assert(closed); + socket.connect(host, port); + identifier = str(format("[%1% %2%]") % socket.getLocalPort() % socket.getPeerAddress()); + closed = false; + poller = Poller::shared_ptr(new Poller); + aio = new SslIO(socket, + boost::bind(&SslConnector::readbuff, this, _1, _2), + boost::bind(&SslConnector::eof, this, _1), + boost::bind(&SslConnector::eof, this, _1), + 0, // closed + 0, // nobuffs + boost::bind(&SslConnector::writebuff, this, _1)); + writer.init(identifier, aio); +} + +void SslConnector::init(){ + Mutex::ScopedLock l(closedLock); + assert(joined); + ProtocolInitiation init(version); + writeDataBlock(init); + joined = false; + receiver = Thread(this); +} + +bool SslConnector::closeInternal() { + Mutex::ScopedLock l(closedLock); + bool ret = !closed; + if (!closed) { + closed = true; + poller->shutdown(); + } + if (!joined && receiver.id() != Thread::current().id()) { + joined = true; + Mutex::ScopedUnlock u(closedLock); + receiver.join(); + } + return ret; +} + +void SslConnector::close() { + closeInternal(); +} + +void SslConnector::setInputHandler(InputHandler* handler){ + input = handler; +} + +void SslConnector::setShutdownHandler(ShutdownHandler* handler){ + shutdownHandler = handler; +} + +OutputHandler* SslConnector::getOutputHandler() { + return this; +} + +sys::ShutdownHandler* SslConnector::getShutdownHandler() const { + return shutdownHandler; +} + +const std::string& SslConnector::getIdentifier() const { + return identifier; +} + +void SslConnector::send(AMQFrame& frame) { + writer.handle(frame); +} + +void SslConnector::handleClosed() { + if (closeInternal() && shutdownHandler) + shutdownHandler->shutdown(); +} + +struct SslConnector::Buff : public SslIO::BufferBase { + Buff(size_t size) : SslIO::BufferBase(new char[size], size) {} + ~Buff() { delete [] bytes;} +}; + +SslConnector::Writer::Writer(uint16_t s, Bounds* b) : maxFrameSize(s), aio(0), buffer(0), lastEof(0), bounds(b) +{ +} + +SslConnector::Writer::~Writer() { delete buffer; } + +void SslConnector::Writer::init(std::string id, sys::ssl::SslIO* a) { + Mutex::ScopedLock l(lock); + identifier = id; + aio = a; + newBuffer(); +} +void SslConnector::Writer::handle(framing::AMQFrame& frame) { + Mutex::ScopedLock l(lock); + frames.push_back(frame); + if (frame.getEof()) {//or if we already have a buffers worth + lastEof = frames.size(); + aio->notifyPendingWrite(); + } + QPID_LOG(trace, "SENT " << identifier << ": " << frame); +} + +void SslConnector::Writer::writeOne() { + assert(buffer); + framesEncoded = 0; + + buffer->dataStart = 0; + buffer->dataCount = encode.getPosition(); + aio->queueWrite(buffer); + newBuffer(); +} + +void SslConnector::Writer::newBuffer() { + buffer = aio->getQueuedBuffer(); + if (!buffer) buffer = new Buff(maxFrameSize); + encode = framing::Buffer(buffer->bytes, buffer->byteCount); + framesEncoded = 0; +} + +// Called in IO thread. +void SslConnector::Writer::write(sys::ssl::SslIO&) { + Mutex::ScopedLock l(lock); + assert(buffer); + size_t bytesWritten(0); + for (size_t i = 0; i < lastEof; ++i) { + AMQFrame& frame = frames[i]; + uint32_t size = frame.encodedSize(); + if (size > encode.available()) writeOne(); + assert(size <= encode.available()); + frame.encode(encode); + ++framesEncoded; + bytesWritten += size; + } + frames.erase(frames.begin(), frames.begin()+lastEof); + lastEof = 0; + if (bounds) bounds->reduce(bytesWritten); + if (encode.getPosition() > 0) writeOne(); +} + +void SslConnector::readbuff(SslIO& aio, SslIO::BufferBase* buff) { + framing::Buffer in(buff->bytes+buff->dataStart, buff->dataCount); + + if (!initiated) { + framing::ProtocolInitiation protocolInit; + if (protocolInit.decode(in)) { + //TODO: check the version is correct + QPID_LOG(debug, "RECV " << identifier << " INIT(" << protocolInit << ")"); + } + initiated = true; + } + AMQFrame frame; + while(frame.decode(in)){ + QPID_LOG(trace, "RECV " << identifier << ": " << frame); + input->received(frame); + } + // TODO: unreading needs to go away, and when we can cope + // with multiple sub-buffers in the general buffer scheme, it will + if (in.available() != 0) { + // Adjust buffer for used bytes and then "unread them" + buff->dataStart += buff->dataCount-in.available(); + buff->dataCount = in.available(); + aio.unread(buff); + } else { + // Give whole buffer back to aio subsystem + aio.queueReadBuffer(buff); + } +} + +void SslConnector::writebuff(SslIO& aio_) { + writer.write(aio_); +} + +void SslConnector::writeDataBlock(const AMQDataBlock& data) { + SslIO::BufferBase* buff = new Buff(maxFrameSize); + framing::Buffer out(buff->bytes, buff->byteCount); + data.encode(out); + buff->dataCount = data.encodedSize(); + aio->queueWrite(buff); +} + +void SslConnector::eof(SslIO&) { + handleClosed(); +} + +// TODO: astitcher 20070908 This version of the code can never time out, so the idle processing +// will never be called +void SslConnector::run(){ + // Keep the connection impl in memory until run() completes. + boost::shared_ptr<ConnectionImpl> protect = impl->shared_from_this(); + assert(protect); + try { + Dispatcher d(poller); + + for (int i = 0; i < 32; i++) { + aio->queueReadBuffer(new Buff(maxFrameSize)); + } + + aio->start(poller); + d.run(); + aio->queueForDeletion(); + socket.close(); + } catch (const std::exception& e) { + QPID_LOG(error, e.what()); + handleClosed(); + } +} + + +}} // namespace qpid::client diff --git a/cpp/src/qpid/sys/ProtocolFactory.h b/cpp/src/qpid/sys/ProtocolFactory.h index 2a77adc9c4..56ab404d82 100644 --- a/cpp/src/qpid/sys/ProtocolFactory.h +++ b/cpp/src/qpid/sys/ProtocolFactory.h @@ -46,6 +46,7 @@ class ProtocolFactory : public qpid::SharedObject<ProtocolFactory> const std::string& host, int16_t port, ConnectionCodec::Factory* codec, ConnectFailedCallback failed) = 0; + virtual bool supports(const std::string& /*capability*/) { return false; } }; inline ProtocolFactory::~ProtocolFactory() {} diff --git a/cpp/src/qpid/sys/SslPlugin.cpp b/cpp/src/qpid/sys/SslPlugin.cpp new file mode 100644 index 0000000000..f5111fff6a --- /dev/null +++ b/cpp/src/qpid/sys/SslPlugin.cpp @@ -0,0 +1,176 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "ProtocolFactory.h" + +#include "qpid/Plugin.h" +#include "qpid/sys/ssl/check.h" +#include "qpid/sys/ssl/util.h" +#include "qpid/sys/ssl/SslHandler.h" +#include "qpid/sys/ssl/SslIo.h" +#include "qpid/sys/ssl/SslSocket.h" +#include "qpid/broker/Broker.h" +#include "qpid/log/Statement.h" + +#include <boost/bind.hpp> +#include <memory> + + +namespace qpid { +namespace sys { + +struct SslServerOptions : ssl::SslOptions +{ + uint16_t port; + bool clientAuth; + + SslServerOptions() : port(5673), + clientAuth(false) + { + addOptions() + ("ssl-port", optValue(port, "PORT"), "Port on which to listen for SSL connections") + ("ssl-require-client-authentication", optValue(clientAuth), + "Forces clients to authenticate in order to establish an SSL connection"); + } +}; + +class SslProtocolFactory : public ProtocolFactory { + const bool tcpNoDelay; + qpid::sys::ssl::SslSocket listener; + const uint16_t listeningPort; + std::auto_ptr<qpid::sys::ssl::SslAcceptor> acceptor; + + public: + SslProtocolFactory(const SslServerOptions&, int backlog, bool nodelay); + void accept(Poller::shared_ptr, ConnectionCodec::Factory*); + void connect(Poller::shared_ptr, const std::string& host, int16_t port, + ConnectionCodec::Factory*, + boost::function2<void, int, std::string> failed); + + uint16_t getPort() const; + std::string getHost() const; + bool supports(const std::string& capability); + + private: + void established(Poller::shared_ptr, const qpid::sys::ssl::SslSocket&, ConnectionCodec::Factory*, + bool isClient); +}; + +// Static instance to initialise plugin +static struct SslPlugin : public Plugin { + SslServerOptions options; + + Options* getOptions() { return &options; } + + ~SslPlugin() { ssl::shutdownNSS(); } + + void earlyInitialize(Target&) { + } + + void initialize(Target& target) { + broker::Broker* broker = dynamic_cast<broker::Broker*>(&target); + // Only provide to a Broker + if (broker) { + ssl::initNSS(options, true); + + const broker::Broker::Options& opts = broker->getOptions(); + ProtocolFactory::shared_ptr protocol(new SslProtocolFactory(options, + opts.connectionBacklog, opts.tcpNoDelay)); + QPID_LOG(info, "Listening for SSL connections on TCP port " << protocol->getPort()); + broker->registerProtocolFactory("ssl", protocol); + } + } +} sslPlugin; + +SslProtocolFactory::SslProtocolFactory(const SslServerOptions& options, int backlog, bool nodelay) : + tcpNoDelay(nodelay), listeningPort(listener.listen(options.port, backlog, options.certName, options.clientAuth)) +{} + +void SslProtocolFactory::established(Poller::shared_ptr poller, const qpid::sys::ssl::SslSocket& s, + ConnectionCodec::Factory* f, bool isClient) { + qpid::sys::ssl::SslHandler* async = new qpid::sys::ssl::SslHandler(s.getPeerAddress(), f); + + if (tcpNoDelay) { + s.setTcpNoDelay(tcpNoDelay); + QPID_LOG(info, "Set TCP_NODELAY on connection to " << s.getPeerAddress()); + } + + if (isClient) + async->setClient(); + qpid::sys::ssl::SslIO* aio = new qpid::sys::ssl::SslIO(s, + boost::bind(&qpid::sys::ssl::SslHandler::readbuff, async, _1, _2), + boost::bind(&qpid::sys::ssl::SslHandler::eof, async, _1), + boost::bind(&qpid::sys::ssl::SslHandler::disconnect, async, _1), + boost::bind(&qpid::sys::ssl::SslHandler::closedSocket, async, _1, _2), + boost::bind(&qpid::sys::ssl::SslHandler::nobuffs, async, _1), + boost::bind(&qpid::sys::ssl::SslHandler::idle, async, _1)); + + async->init(aio, 4); + aio->start(poller); +} + +uint16_t SslProtocolFactory::getPort() const { + return listeningPort; // Immutable no need for lock. +} + +std::string SslProtocolFactory::getHost() const { + return listener.getSockname(); +} + +void SslProtocolFactory::accept(Poller::shared_ptr poller, + ConnectionCodec::Factory* fact) { + acceptor.reset( + new qpid::sys::ssl::SslAcceptor(listener, + boost::bind(&SslProtocolFactory::established, this, poller, _1, fact, false))); + acceptor->start(poller); +} + +void SslProtocolFactory::connect( + Poller::shared_ptr poller, + const std::string& host, int16_t port, + ConnectionCodec::Factory* fact, + ConnectFailedCallback failed) +{ + // Note that the following logic does not cause a memory leak. + // The allocated Socket is freed either by the SslConnector + // upon connection failure or by the SslIoHandle upon connection + // shutdown. The allocated SslConnector frees itself when it + // is no longer needed. + + qpid::sys::ssl::SslSocket* socket = new qpid::sys::ssl::SslSocket(); + new qpid::sys::ssl::SslConnector (*socket, poller, host, port, + boost::bind(&SslProtocolFactory::established, this, poller, _1, fact, true), + failed); +} + +namespace +{ +const std::string SSL = "ssl"; +} + +bool SslProtocolFactory::supports(const std::string& capability) +{ + std::string s = capability; + transform(s.begin(), s.end(), s.begin(), tolower); + return s == SSL; +} + +}} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/ssl/SslHandler.cpp b/cpp/src/qpid/sys/ssl/SslHandler.cpp new file mode 100644 index 0000000000..4177ca294c --- /dev/null +++ b/cpp/src/qpid/sys/ssl/SslHandler.cpp @@ -0,0 +1,177 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include "SslHandler.h" + +#include "SslIo.h" +#include "SslSocket.h" +#include "qpid/framing/AMQP_HighestVersion.h" +#include "qpid/framing/ProtocolInitiation.h" +#include "qpid/log/Statement.h" + +namespace qpid { +namespace sys { +namespace ssl { + + +// Buffer definition +struct Buff : public SslIO::BufferBase { + Buff() : + SslIO::BufferBase(new char[65536], 65536) + {} + ~Buff() + { delete [] bytes;} +}; + +SslHandler::SslHandler(std::string id, ConnectionCodec::Factory* f) : + identifier(id), + aio(0), + factory(f), + codec(0), + readError(false), + isClient(false) +{} + +SslHandler::~SslHandler() { + if (codec) + codec->closed(); + delete codec; +} + +void SslHandler::init(SslIO* a, int numBuffs) { + aio = a; + + // Give connection some buffers to use + for (int i = 0; i < numBuffs; i++) { + aio->queueReadBuffer(new Buff); + } +} + +void SslHandler::write(const framing::ProtocolInitiation& data) +{ + QPID_LOG(debug, "SENT [" << identifier << "] INIT(" << data << ")"); + SslIO::BufferBase* buff = aio->getQueuedBuffer(); + if (!buff) + buff = new Buff; + framing::Buffer out(buff->bytes, buff->byteCount); + data.encode(out); + buff->dataCount = data.encodedSize(); + aio->queueWrite(buff); +} + +void SslHandler::activateOutput() { + aio->notifyPendingWrite(); +} + +// Input side +void SslHandler::readbuff(SslIO& , SslIO::BufferBase* buff) { + if (readError) { + return; + } + size_t decoded = 0; + if (codec) { // Already initiated + try { + decoded = codec->decode(buff->bytes+buff->dataStart, buff->dataCount); + }catch(const std::exception& e){ + QPID_LOG(error, e.what()); + readError = true; + aio->queueWriteClose(); + } + }else{ + framing::Buffer in(buff->bytes+buff->dataStart, buff->dataCount); + framing::ProtocolInitiation protocolInit; + if (protocolInit.decode(in)) { + decoded = in.getPosition(); + QPID_LOG(debug, "RECV [" << identifier << "] INIT(" << protocolInit << ")"); + try { + codec = factory->create(protocolInit.getVersion(), *this, identifier); + if (!codec) { + //TODO: may still want to revise this... + //send valid version header & close connection. + write(framing::ProtocolInitiation(framing::highestProtocolVersion)); + readError = true; + aio->queueWriteClose(); + } + } catch (const std::exception& e) { + QPID_LOG(error, e.what()); + readError = true; + aio->queueWriteClose(); + } + } + } + // TODO: unreading needs to go away, and when we can cope + // with multiple sub-buffers in the general buffer scheme, it will + if (decoded != size_t(buff->dataCount)) { + // Adjust buffer for used bytes and then "unread them" + buff->dataStart += decoded; + buff->dataCount -= decoded; + aio->unread(buff); + } else { + // Give whole buffer back to aio subsystem + aio->queueReadBuffer(buff); + } +} + +void SslHandler::eof(SslIO&) { + QPID_LOG(debug, "DISCONNECTED [" << identifier << "]"); + if (codec) codec->closed(); + aio->queueWriteClose(); +} + +void SslHandler::closedSocket(SslIO&, const SslSocket& s) { + // If we closed with data still to send log a warning + if (!aio->writeQueueEmpty()) { + QPID_LOG(warning, "CLOSING [" << identifier << "] unsent data (probably due to client disconnect)"); + } + delete &s; + aio->queueForDeletion(); + delete this; +} + +void SslHandler::disconnect(SslIO& a) { + // treat the same as eof + eof(a); +} + +// Notifications +void SslHandler::nobuffs(SslIO&) { +} + +void SslHandler::idle(SslIO&){ + if (isClient && codec == 0) { + codec = factory->create(*this, identifier); + write(framing::ProtocolInitiation(codec->getVersion())); + return; + } + if (codec == 0) return; + if (codec->canEncode()) { + // Try and get a queued buffer if not then construct new one + SslIO::BufferBase* buff = aio->getQueuedBuffer(); + if (!buff) buff = new Buff; + size_t encoded=codec->encode(buff->bytes, buff->byteCount); + buff->dataCount = encoded; + aio->queueWrite(buff); + } + if (codec->isClosed()) + aio->queueWriteClose(); +} + + +}}} // namespace qpid::sys::ssl diff --git a/cpp/src/qpid/sys/ssl/SslHandler.h b/cpp/src/qpid/sys/ssl/SslHandler.h new file mode 100644 index 0000000000..cce5ecf09b --- /dev/null +++ b/cpp/src/qpid/sys/ssl/SslHandler.h @@ -0,0 +1,75 @@ +#ifndef QPID_SYS_SSL_SSLHANDLER_H +#define QPID_SYS_SSL_SSLHANDLER_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "qpid/sys/ConnectionCodec.h" +#include "qpid/sys/OutputControl.h" + +namespace qpid { + +namespace framing { + class ProtocolInitiation; +} + +namespace sys { +namespace ssl { + +class SslIO; +class SslIOBufferBase; +class SslSocket; + +class SslHandler : public OutputControl { + std::string identifier; + SslIO* aio; + ConnectionCodec::Factory* factory; + ConnectionCodec* codec; + bool readError; + bool isClient; + + void write(const framing::ProtocolInitiation&); + + public: + SslHandler(std::string id, ConnectionCodec::Factory* f); + ~SslHandler(); + void init(SslIO* a, int numBuffs); + + void setClient() { isClient = true; } + + // Output side + void close(); + void activateOutput(); + + // Input side + void readbuff(SslIO& aio, SslIOBufferBase* buff); + void eof(SslIO& aio); + void disconnect(SslIO& aio); + + // Notifications + void nobuffs(SslIO& aio); + void idle(SslIO& aio); + void closedSocket(SslIO& aio, const SslSocket& s); +}; + +}}} // namespace qpid::sys::ssl + +#endif /*!QPID_SYS_SSL_SSLHANDLER_H*/ diff --git a/cpp/src/qpid/sys/ssl/SslIo.cpp b/cpp/src/qpid/sys/ssl/SslIo.cpp new file mode 100644 index 0000000000..9be75af47d --- /dev/null +++ b/cpp/src/qpid/sys/ssl/SslIo.cpp @@ -0,0 +1,433 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "SslIo.h" +#include "SslSocket.h" + +#include "qpid/sys/Time.h" +#include "qpid/sys/posix/check.h" +#include "qpid/log/Statement.h" + +// TODO The basic algorithm here is not really POSIX specific and with a bit more abstraction +// could (should) be promoted to be platform portable +#include <unistd.h> +#include <sys/socket.h> +#include <signal.h> +#include <errno.h> +#include <string.h> + +#include <boost/bind.hpp> + +using namespace qpid::sys; +using namespace qpid::sys::ssl; + +namespace { + +/* + * Make *process* not generate SIGPIPE when writing to closed + * pipe/socket (necessary as default action is to terminate process) + */ +void ignoreSigpipe() { + ::signal(SIGPIPE, SIG_IGN); +} + +/* + * We keep per thread state to avoid locking overhead. The assumption is that + * on average all the connections are serviced by all the threads so the state + * recorded in each thread is about the same. If this turns out not to be the + * case we could rebalance the info occasionally. + */ +__thread int threadReadTotal = 0; +__thread int threadMaxRead = 0; +__thread int threadReadCount = 0; +__thread int threadWriteTotal = 0; +__thread int threadWriteCount = 0; +__thread int64_t threadMaxReadTimeNs = 2 * 1000000; // start at 2ms +} + +/* + * Asynch Acceptor + */ + +SslAcceptor::SslAcceptor(const SslSocket& s, Callback callback) : + acceptedCallback(callback), + handle(s, boost::bind(&SslAcceptor::readable, this, _1), 0, 0), + socket(s) { + + s.setNonblocking(); + ignoreSigpipe(); +} + +void SslAcceptor::start(Poller::shared_ptr poller) { + handle.startWatch(poller); +} + +/* + * We keep on accepting as long as there is something to accept + */ +void SslAcceptor::readable(DispatchHandle& h) { + SslSocket* s; + do { + errno = 0; + // TODO: Currently we ignore the peers address, perhaps we should + // log it or use it for connection acceptance. + try { + s = socket.accept(0, 0); + if (s) { + acceptedCallback(*s); + } else { + break; + } + } catch (const std::exception& e) { + QPID_LOG(error, "Could not accept socket: " << e.what()); + } + } while (true); + + h.rewatch(); +} + +/* + * Asynch Connector + */ + +SslConnector::SslConnector(const SslSocket& s, + Poller::shared_ptr poller, + std::string hostname, + uint16_t port, + ConnectedCallback connCb, + FailedCallback failCb) : + DispatchHandle(s, + 0, + boost::bind(&SslConnector::connComplete, this, _1), + boost::bind(&SslConnector::connComplete, this, _1)), + connCallback(connCb), + failCallback(failCb), + socket(s) +{ + //TODO: would be better for connect to be performed on a + //non-blocking socket, but that doesn't work at present so connect + //blocks until complete + try { + socket.connect(hostname, port); + socket.setNonblocking(); + startWatch(poller); + } catch(std::exception& e) { + failure(-1, std::string(e.what())); + } +} + +void SslConnector::connComplete(DispatchHandle& h) +{ + int errCode = socket.getError(); + + h.stopWatch(); + if (errCode == 0) { + connCallback(socket); + DispatchHandle::doDelete(); + } else { + // TODO: This need to be fixed as strerror isn't thread safe + failure(errCode, std::string(::strerror(errCode))); + } +} + +void SslConnector::failure(int errCode, std::string message) +{ + if (failCallback) + failCallback(errCode, message); + + socket.close(); + delete &socket; + + DispatchHandle::doDelete(); +} + +/* + * Asynch reader/writer + */ +SslIO::SslIO(const SslSocket& s, + ReadCallback rCb, EofCallback eofCb, DisconnectCallback disCb, + ClosedCallback cCb, BuffersEmptyCallback eCb, IdleCallback iCb) : + + DispatchHandle(s, + boost::bind(&SslIO::readable, this, _1), + boost::bind(&SslIO::writeable, this, _1), + boost::bind(&SslIO::disconnected, this, _1)), + readCallback(rCb), + eofCallback(eofCb), + disCallback(disCb), + closedCallback(cCb), + emptyCallback(eCb), + idleCallback(iCb), + socket(s), + queuedClose(false), + writePending(false) { + + s.setNonblocking(); +} + +struct deleter +{ + template <typename T> + void operator()(T *ptr){ delete ptr;} +}; + +SslIO::~SslIO() { + std::for_each( bufferQueue.begin(), bufferQueue.end(), deleter()); + std::for_each( writeQueue.begin(), writeQueue.end(), deleter()); +} + +void SslIO::queueForDeletion() { + DispatchHandle::doDelete(); +} + +void SslIO::start(Poller::shared_ptr poller) { + DispatchHandle::startWatch(poller); +} + +void SslIO::queueReadBuffer(BufferBase* buff) { + assert(buff); + buff->dataStart = 0; + buff->dataCount = 0; + bufferQueue.push_back(buff); + DispatchHandle::rewatchRead(); +} + +void SslIO::unread(BufferBase* buff) { + assert(buff); + if (buff->dataStart != 0) { + memmove(buff->bytes, buff->bytes+buff->dataStart, buff->dataCount); + buff->dataStart = 0; + } + bufferQueue.push_front(buff); + DispatchHandle::rewatchRead(); +} + +void SslIO::queueWrite(BufferBase* buff) { + assert(buff); + // If we've already closed the socket then throw the write away + if (queuedClose) { + bufferQueue.push_front(buff); + return; + } else { + writeQueue.push_front(buff); + } + writePending = false; + DispatchHandle::rewatchWrite(); +} + +void SslIO::notifyPendingWrite() { + writePending = true; + DispatchHandle::rewatchWrite(); +} + +void SslIO::queueWriteClose() { + queuedClose = true; + DispatchHandle::rewatchWrite(); +} + +/** Return a queued buffer if there are enough + * to spare + */ +SslIO::BufferBase* SslIO::getQueuedBuffer() { + // Always keep at least one buffer (it might have data that was "unread" in it) + if (bufferQueue.size()<=1) + return 0; + BufferBase* buff = bufferQueue.back(); + assert(buff); + buff->dataStart = 0; + buff->dataCount = 0; + bufferQueue.pop_back(); + return buff; +} + +/* + * We keep on reading as long as we have something to read and a buffer to put + * it in + */ +void SslIO::readable(DispatchHandle& h) { + int readTotal = 0; + AbsTime readStartTime = AbsTime::now(); + do { + // (Try to) get a buffer + if (!bufferQueue.empty()) { + // Read into buffer + BufferBase* buff = bufferQueue.front(); + assert(buff); + bufferQueue.pop_front(); + errno = 0; + int readCount = buff->byteCount-buff->dataCount; + int rc = socket.read(buff->bytes + buff->dataCount, readCount); + if (rc > 0) { + buff->dataCount += rc; + threadReadTotal += rc; + readTotal += rc; + + readCallback(*this, buff); + if (rc != readCount) { + // If we didn't fill the read buffer then time to stop reading + break; + } + + // Stop reading if we've overrun our timeslot + if (Duration(readStartTime, AbsTime::now()) > threadMaxReadTimeNs) { + break; + } + + } else { + // Put buffer back (at front so it doesn't interfere with unread buffers) + bufferQueue.push_front(buff); + assert(buff); + + // Eof or other side has gone away + if (rc == 0 || errno == ECONNRESET) { + eofCallback(*this); + h.unwatchRead(); + break; + } else if (errno == EAGAIN) { + // We have just put a buffer back so we know + // we can carry on watching for reads + break; + } else { + // Report error then just treat as a socket disconnect + QPID_LOG(error, "Error reading socket: " << qpid::sys::strError(rc) << "(" << rc << ")" ); + eofCallback(*this); + h.unwatchRead(); + break; + } + } + } else { + // Something to read but no buffer + if (emptyCallback) { + emptyCallback(*this); + } + // If we still have no buffers we can't do anything more + if (bufferQueue.empty()) { + h.unwatchRead(); + break; + } + + } + } while (true); + + ++threadReadCount; + threadMaxRead = std::max(threadMaxRead, readTotal); + return; +} + +/* + * We carry on writing whilst we have data to write and we can write + */ +void SslIO::writeable(DispatchHandle& h) { + int writeTotal = 0; + do { + // See if we've got something to write + if (!writeQueue.empty()) { + // Write buffer + BufferBase* buff = writeQueue.back(); + writeQueue.pop_back(); + errno = 0; + assert(buff->dataStart+buff->dataCount <= buff->byteCount); + int rc = socket.write(buff->bytes+buff->dataStart, buff->dataCount); + if (rc >= 0) { + threadWriteTotal += rc; + writeTotal += rc; + + // If we didn't write full buffer put rest back + if (rc != buff->dataCount) { + buff->dataStart += rc; + buff->dataCount -= rc; + writeQueue.push_back(buff); + break; + } + + // Recycle the buffer + queueReadBuffer(buff); + + // If we've already written more than the max for reading then stop + // (this is to stop writes dominating reads) + if (writeTotal > threadMaxRead) + break; + } else { + // Put buffer back + writeQueue.push_back(buff); + if (errno == ECONNRESET || errno == EPIPE) { + // Just stop watching for write here - we'll get a + // disconnect callback soon enough + h.unwatchWrite(); + break; + } else if (errno == EAGAIN) { + // We have just put a buffer back so we know + // we can carry on watching for writes + break; + } else { + QPID_POSIX_CHECK(rc); + } + } + } else { + // If we're waiting to close the socket then can do it now as there is nothing to write + if (queuedClose) { + close(h); + break; + } + // Fd is writable, but nothing to write + if (idleCallback) { + writePending = false; + idleCallback(*this); + } + // If we still have no buffers to write we can't do anything more + if (writeQueue.empty() && !writePending && !queuedClose) { + h.unwatchWrite(); + // The following handles the case where writePending is + // set to true after the test above; in this case its + // possible that the unwatchWrite overwrites the + // desired rewatchWrite so we correct that here + if (writePending) + h.rewatchWrite(); + break; + } + } + } while (true); + + ++threadWriteCount; + return; +} + +void SslIO::disconnected(DispatchHandle& h) { + // If we've already queued close do it instead of disconnected callback + if (queuedClose) { + close(h); + } else if (disCallback) { + disCallback(*this); + h.unwatch(); + } +} + +/* + * Close the socket and callback to say we've done it + */ +void SslIO::close(DispatchHandle& h) { + h.stopWatch(); + socket.close(); + if (closedCallback) { + closedCallback(*this, socket); + } +} + diff --git a/cpp/src/qpid/sys/ssl/SslIo.h b/cpp/src/qpid/sys/ssl/SslIo.h new file mode 100644 index 0000000000..6875a92dea --- /dev/null +++ b/cpp/src/qpid/sys/ssl/SslIo.h @@ -0,0 +1,167 @@ +#ifndef _sys_ssl_SslIO +#define _sys_ssl_SslIO +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "qpid/sys/Dispatcher.h" + +#include <boost/function.hpp> +#include <deque> + +namespace qpid { +namespace sys { +namespace ssl { + +class SslSocket; + +/* + * Asynchronous ssl acceptor: accepts connections then does a callback + * with the accepted fd + */ +class SslAcceptor { +public: + typedef boost::function1<void, const SslSocket&> Callback; + +private: + Callback acceptedCallback; + qpid::sys::DispatchHandle handle; + const SslSocket& socket; + +public: + SslAcceptor(const SslSocket& s, Callback callback); + void start(qpid::sys::Poller::shared_ptr poller); + +private: + void readable(qpid::sys::DispatchHandle& handle); +}; + +/* + * Asynchronous ssl connector: starts the process of initiating a + * connection and invokes a callback when completed or failed. + */ +class SslConnector : private qpid::sys::DispatchHandle { +public: + typedef boost::function1<void, const SslSocket&> ConnectedCallback; + typedef boost::function2<void, int, std::string> FailedCallback; + +private: + ConnectedCallback connCallback; + FailedCallback failCallback; + const SslSocket& socket; + +public: + SslConnector(const SslSocket& socket, + Poller::shared_ptr poller, + std::string hostname, + uint16_t port, + ConnectedCallback connCb, + FailedCallback failCb = 0); + +private: + void connComplete(DispatchHandle& handle); + void failure(int, std::string); +}; + +struct SslIOBufferBase { + char* const bytes; + const int32_t byteCount; + int32_t dataStart; + int32_t dataCount; + + SslIOBufferBase(char* const b, const int32_t s) : + bytes(b), + byteCount(s), + dataStart(0), + dataCount(0) + {} + + virtual ~SslIOBufferBase() + {} +}; + +/* + * Asychronous reader/writer: + * Reader accepts buffers to read into; reads into the provided buffers + * and then does a callback with the buffer and amount read. Optionally it can callback + * when there is something to read but no buffer to read it into. + * + * Writer accepts a buffer and queues it for writing; can also be given + * a callback for when writing is "idle" (ie fd is writable, but nothing to write) + * + * The class is implemented in terms of DispatchHandle to allow it to be deleted by deleting + * the contained DispatchHandle + */ +class SslIO : private qpid::sys::DispatchHandle { +public: + typedef SslIOBufferBase BufferBase; + + typedef boost::function2<void, SslIO&, BufferBase*> ReadCallback; + typedef boost::function1<void, SslIO&> EofCallback; + typedef boost::function1<void, SslIO&> DisconnectCallback; + typedef boost::function2<void, SslIO&, const SslSocket&> ClosedCallback; + typedef boost::function1<void, SslIO&> BuffersEmptyCallback; + typedef boost::function1<void, SslIO&> IdleCallback; + +private: + ReadCallback readCallback; + EofCallback eofCallback; + DisconnectCallback disCallback; + ClosedCallback closedCallback; + BuffersEmptyCallback emptyCallback; + IdleCallback idleCallback; + const SslSocket& socket; + std::deque<BufferBase*> bufferQueue; + std::deque<BufferBase*> writeQueue; + bool queuedClose; + /** + * This flag is used to detect and handle concurrency between + * calls to notifyPendingWrite() (which can be made from any thread) and + * the execution of the writeable() method (which is always on the + * thread processing this handle. + */ + volatile bool writePending; + +public: + SslIO(const SslSocket& s, + ReadCallback rCb, EofCallback eofCb, DisconnectCallback disCb, + ClosedCallback cCb = 0, BuffersEmptyCallback eCb = 0, IdleCallback iCb = 0); + void queueForDeletion(); + + void start(qpid::sys::Poller::shared_ptr poller); + void queueReadBuffer(BufferBase* buff); + void unread(BufferBase* buff); + void queueWrite(BufferBase* buff); + void notifyPendingWrite(); + void queueWriteClose(); + bool writeQueueEmpty() { return writeQueue.empty(); } + BufferBase* getQueuedBuffer(); + +private: + ~SslIO(); + void readable(qpid::sys::DispatchHandle& handle); + void writeable(qpid::sys::DispatchHandle& handle); + void disconnected(qpid::sys::DispatchHandle& handle); + void close(qpid::sys::DispatchHandle& handle); +}; + +}}} + +#endif // _sys_ssl_SslIO diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp new file mode 100644 index 0000000000..597fbe57db --- /dev/null +++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp @@ -0,0 +1,279 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "SslSocket.h" +#include "check.h" +#include "util.h" +#include "qpid/Exception.h" +#include "qpid/sys/posix/check.h" +#include "qpid/sys/posix/PrivatePosix.h" + +#include <fcntl.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/errno.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <cstdlib> +#include <string.h> +#include <iostream> + +#include <nspr4/private/pprio.h> +#include <nss3/nss.h> +#include <nss3/pk11pub.h> +#include <nss3/ssl.h> +#include <nss3/key.h> + +#include <boost/format.hpp> + +namespace qpid { +namespace sys { +namespace ssl { + +namespace { +std::string getName(int fd, bool local, bool includeService = false) +{ + ::sockaddr_storage name; // big enough for any socket address + ::socklen_t namelen = sizeof(name); + + int result = -1; + if (local) { + result = ::getsockname(fd, (::sockaddr*)&name, &namelen); + } else { + result = ::getpeername(fd, (::sockaddr*)&name, &namelen); + } + + QPID_POSIX_CHECK(result); + + char servName[NI_MAXSERV]; + char dispName[NI_MAXHOST]; + if (includeService) { + if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), + servName, sizeof(servName), + NI_NUMERICHOST | NI_NUMERICSERV) != 0) + throw QPID_POSIX_ERROR(rc); + return std::string(dispName) + ":" + std::string(servName); + + } else { + if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), 0, 0, NI_NUMERICHOST) != 0) + throw QPID_POSIX_ERROR(rc); + return dispName; + } +} + +std::string getService(int fd, bool local) +{ + ::sockaddr_storage name; // big enough for any socket address + ::socklen_t namelen = sizeof(name); + + int result = -1; + if (local) { + result = ::getsockname(fd, (::sockaddr*)&name, &namelen); + } else { + result = ::getpeername(fd, (::sockaddr*)&name, &namelen); + } + + QPID_POSIX_CHECK(result); + + char servName[NI_MAXSERV]; + if (int rc=::getnameinfo((::sockaddr*)&name, namelen, 0, 0, + servName, sizeof(servName), + NI_NUMERICHOST | NI_NUMERICSERV) != 0) + throw QPID_POSIX_ERROR(rc); + return servName; +} + +} + +SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0) +{ + impl->fd = ::socket (PF_INET, SOCK_STREAM, 0); + if (impl->fd < 0) throw QPID_POSIX_ERROR(errno); + socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd)); +} + +/** + * This form of the constructor is used with the server-side sockets + * returned from accept. Because we use posix accept rather than + * PR_Accept, we have to reset the handshake. + */ +SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : IOHandle(ioph), socket(0), prototype(0) +{ + socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); + NSS_CHECK(SSL_ResetHandshake(socket, true)); + NSS_CHECK(SSL_ForceHandshake(socket)); +} + +void SslSocket::setNonblocking() const +{ + PRSocketOptionData option; + option.option = PR_SockOpt_Nonblocking; + option.value.non_blocking = true; + PR_SetSocketOption(socket, &option); +} + +void SslSocket::connect(const std::string& host, uint16_t port) const +{ + std::stringstream namestream; + namestream << host << ":" << port; + connectname = namestream.str(); + + void* arg = SslOptions::global.certName.empty() ? 0 : const_cast<char*>(SslOptions::global.certName.c_str()); + NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg)); + NSS_CHECK(SSL_SetURL(socket, host.data())); + + char hostBuffer[PR_NETDB_BUF_SIZE]; + PRHostEnt hostEntry; + PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry)); + PRNetAddr address; + int value = PR_EnumerateHostEnt(0, &hostEntry, port, &address); + if (value < 0) { + throw Exception(QPID_MSG("Error getting address for host: " << ErrorString())); + } else if (value == 0) { + throw Exception(QPID_MSG("Could not resolve address for host.")); + } + PR_CHECK(PR_Connect(socket, &address, PR_INTERVAL_NO_TIMEOUT)); + NSS_CHECK(SSL_ForceHandshake(socket)); +} + +void SslSocket::close() const +{ + if (impl->fd > 0) { + PR_Close(socket); + impl->fd = -1; + } +} + +int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, bool clientAuth) const +{ + //configure prototype socket: + prototype = SSL_ImportFD(0, PR_NewTCPSocket()); + if (clientAuth) { + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); + } + + //get certificate and key (is this the correct way?) + CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(certName.c_str()), 0); + if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << certName << "'")); + SECKEYPrivateKey *key = PK11_FindKeyByAnyCert(cert, 0); + if (!key) throw Exception(QPID_MSG("Failed to retrieve private key from certificate")); + NSS_CHECK(SSL_ConfigSecureServer(prototype, cert, key, NSS_FindCertKEAType(cert))); + SECKEY_DestroyPrivateKey(key); + CERT_DestroyCertificate(cert); + + //bind and listen + const int& socket = impl->fd; + int yes=1; + QPID_POSIX_CHECK(setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes))); + struct sockaddr_in name; + name.sin_family = AF_INET; + name.sin_port = htons(port); + name.sin_addr.s_addr = 0; + if (::bind(socket, (struct sockaddr*)&name, sizeof(name)) < 0) + throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno))); + if (::listen(socket, backlog) < 0) + throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno))); + + socklen_t namelen = sizeof(name); + if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0) + throw QPID_POSIX_ERROR(errno); + + return ntohs(name.sin_port); +} + +SslSocket* SslSocket::accept(struct sockaddr *addr, socklen_t *addrlen) const +{ + int afd = ::accept(impl->fd, addr, addrlen); + if ( afd >= 0) { + return new SslSocket(new IOHandlePrivate(afd), prototype); + } else if (errno == EAGAIN) { + return 0; + } else { + throw QPID_POSIX_ERROR(errno); + } +} + +int SslSocket::read(void *buf, size_t count) const +{ + return PR_Read(socket, buf, count); +} + +int SslSocket::write(const void *buf, size_t count) const +{ + return PR_Write(socket, buf, count); +} + +std::string SslSocket::getSockname() const +{ + return getName(impl->fd, true); +} + +std::string SslSocket::getPeername() const +{ + return getName(impl->fd, false); +} + +std::string SslSocket::getPeerAddress() const +{ + if (!connectname.empty()) + return connectname; + return getName(impl->fd, false, true); +} + +std::string SslSocket::getLocalAddress() const +{ + return getName(impl->fd, true, true); +} + +uint16_t SslSocket::getLocalPort() const +{ + return std::atoi(getService(impl->fd, true).c_str()); +} + +uint16_t SslSocket::getRemotePort() const +{ + return atoi(getService(impl->fd, true).c_str()); +} + +int SslSocket::getError() const +{ + int result; + socklen_t rSize = sizeof (result); + + if (::getsockopt(impl->fd, SOL_SOCKET, SO_ERROR, &result, &rSize) < 0) + throw QPID_POSIX_ERROR(errno); + + return result; +} + +void SslSocket::setTcpNoDelay(bool nodelay) const +{ + if (nodelay) { + PRSocketOptionData option; + option.option = PR_SockOpt_NoDelay; + option.value.no_delay = true; + PR_SetSocketOption(socket, &option); + } +} + +}}} // namespace qpid::sys::ssl diff --git a/cpp/src/qpid/sys/ssl/SslSocket.h b/cpp/src/qpid/sys/ssl/SslSocket.h new file mode 100644 index 0000000000..a82e9133e8 --- /dev/null +++ b/cpp/src/qpid/sys/ssl/SslSocket.h @@ -0,0 +1,117 @@ +#ifndef _sys_ssl_Socket_h +#define _sys_ssl_Socket_h + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "qpid/sys/IOHandle.h" +#include <nspr4/nspr.h> + +#include <string> + +struct sockaddr; + +namespace qpid { +namespace sys { + +class Duration; + +namespace ssl { + +class SslSocket : public qpid::sys::IOHandle +{ +public: + /** Create a socket wrapper for descriptor. */ + SslSocket(); + + /** Set socket non blocking */ + void setNonblocking() const; + + /** Set tcp-nodelay */ + void setTcpNoDelay(bool nodelay) const; + + void connect(const std::string& host, uint16_t port) const; + + void close() const; + + /** Bind to a port and start listening. + *@param port 0 means choose an available port. + *@param backlog maximum number of pending connections. + *@param certName name of certificate to use to identify the server + *@return The bound port. + */ + int listen(uint16_t port = 0, int backlog = 10, const std::string& certName = "localhost.localdomain", bool clientAuth = false) const; + + /** + * Accept a connection from a socket that is already listening + * and has an incoming connection + */ + SslSocket* accept(struct sockaddr *addr, socklen_t *addrlen) const; + + // TODO The following are raw operations, maybe they need better wrapping? + int read(void *buf, size_t count) const; + int write(const void *buf, size_t count) const; + + /** Returns the "socket name" ie the address bound to + * the near end of the socket + */ + std::string getSockname() const; + + /** Returns the "peer name" ie the address bound to + * the remote end of the socket + */ + std::string getPeername() const; + + /** + * Returns an address (host and port) for the remote end of the + * socket + */ + std::string getPeerAddress() const; + /** + * Returns an address (host and port) for the local end of the + * socket + */ + std::string getLocalAddress() const; + + uint16_t getLocalPort() const; + uint16_t getRemotePort() const; + + /** + * Returns the error code stored in the socket. This may be used + * to determine the result of a non-blocking connect. + */ + int getError() const; + +private: + mutable std::string connectname; + mutable PRFileDesc* socket; + /** + * 'model' socket, with configuration to use when importing + * accepted sockets for use as ssl sockets. Set on listen(), used + * in accept to pass through to newly created socket instances. + */ + mutable PRFileDesc* prototype; + + SslSocket(IOHandlePrivate* ioph, PRFileDesc* model); +}; + +}}} +#endif /*!_sys_ssl_Socket_h*/ diff --git a/cpp/src/qpid/sys/ssl/check.cpp b/cpp/src/qpid/sys/ssl/check.cpp new file mode 100644 index 0000000000..2f95ab71b8 --- /dev/null +++ b/cpp/src/qpid/sys/ssl/check.cpp @@ -0,0 +1,70 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include "check.h" +#include <nss3/secerr.h> +#include <nss3/sslerr.h> +#include <boost/format.hpp> + +using boost::format; +using boost::str; + +namespace qpid { +namespace sys { +namespace ssl { + +const std::string SSL_ERROR_BAD_CERT_DOMAIN_STR = + "Unable to communicate securely with peer: requested domain name does not match the server's certificate."; +const std::string SSL_ERROR_BAD_CERT_ALERT_STR = "SSL peer cannot verify your certificate."; +const std::string SEC_ERROR_BAD_DATABASE_STR = "Security library: bad database."; +const std::string SSL_ERROR_NO_CERTIFICATE_STR = "Unable to find the certificate or key necessary for authentication."; + +ErrorString::ErrorString() : code(PR_GetError()), buffer(new char[PR_GetErrorTextLength()]), used(PR_GetErrorText(buffer)) {} + +ErrorString::~ErrorString() +{ + delete[] buffer; +} + +std::string ErrorString::getString() const +{ + std::string msg = std::string(buffer, used); + if (!used) { + //seems most of the NSPR/NSS errors don't have text set for + //them, add a few specific ones in here. (TODO: more complete + //list?): + switch (code) { + case SSL_ERROR_BAD_CERT_DOMAIN: msg = SSL_ERROR_BAD_CERT_DOMAIN_STR; break; + case SSL_ERROR_BAD_CERT_ALERT: msg = SSL_ERROR_BAD_CERT_ALERT_STR; break; + case SEC_ERROR_BAD_DATABASE: msg = SEC_ERROR_BAD_DATABASE_STR; break; + case SSL_ERROR_NO_CERTIFICATE: msg = SSL_ERROR_NO_CERTIFICATE_STR; break; + } + } + return str(format("%1% [%2%]") % msg % code); +} + +std::ostream& operator<<(std::ostream& out, const ErrorString& err) +{ + out << err.getString(); + return out; +} + + +}}} // namespace qpid::sys::ssl diff --git a/cpp/src/qpid/sys/ssl/check.h b/cpp/src/qpid/sys/ssl/check.h new file mode 100644 index 0000000000..6217a39429 --- /dev/null +++ b/cpp/src/qpid/sys/ssl/check.h @@ -0,0 +1,53 @@ +#ifndef QPID_SYS_SSL_CHECK_H +#define QPID_SYS_SSL_CHECK_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include <iostream> +#include <string> +#include <nspr4/nspr.h> +#include <nss3/nss.h> + +namespace qpid { +namespace sys { +namespace ssl { + +class ErrorString +{ + public: + ErrorString(); + ~ErrorString(); + std::string getString() const; + private: + const int code; + char* const buffer; + const size_t used; +}; + +std::ostream& operator<<(std::ostream& out, const ErrorString& err); + +}}} // namespace qpid::sys::ssl + + +#define NSS_CHECK(value) if (value != SECSuccess) { throw Exception(QPID_MSG("Failed: " << qpid::sys::ssl::ErrorString())); } +#define PR_CHECK(value) if (value != PR_SUCCESS) { throw Exception(QPID_MSG("Failed: " << qpid::sys::ssl::ErrorString())); } + +#endif /*!QPID_SYS_SSL_CHECK_H*/ diff --git a/cpp/src/qpid/sys/ssl/util.cpp b/cpp/src/qpid/sys/ssl/util.cpp new file mode 100644 index 0000000000..63855d49ac --- /dev/null +++ b/cpp/src/qpid/sys/ssl/util.cpp @@ -0,0 +1,119 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include "util.h" +#include "check.h" +#include "qpid/Exception.h" +#include "qpid/sys/SystemInfo.h" + +#include <unistd.h> +#include <nspr4/nspr.h> +#include <nss3/nss.h> +#include <nss3/pk11pub.h> +#include <nss3/ssl.h> + +#include <iostream> +#include <fstream> +#include <boost/filesystem/operations.hpp> +#include <boost/filesystem/path.hpp> + +namespace qpid { +namespace sys { +namespace ssl { + +std::string defaultCertName() +{ + TcpAddress address; + if (SystemInfo::getLocalHostname(address)) { + return address.host; + } else { + return "localhost"; + } +} + +SslOptions::SslOptions() : qpid::Options("SSL Settings"), + certDbPath(CERT_DB), + certName(defaultCertName()), + exportPolicy(false) +{ + addOptions() + ("ssl-use-export-policy", optValue(exportPolicy), "Use NSS export policy") + ("ssl-cert-password-file", optValue(certPasswordFile, "PATH"), "File containing password to use for accessing certificate database") + ("ssl-cert-db", optValue(certDbPath, "PATH"), "Path to directory containing certificate database") + ("ssl-cert-name", optValue(certName, "NAME"), "Name of the certificate to use"); +} + +SslOptions& SslOptions::operator=(const SslOptions& o) +{ + certDbPath = o.certDbPath; + certName = o.certName; + certPasswordFile = o.certPasswordFile; + exportPolicy = o.exportPolicy; + return *this; +} + +char* promptForPassword(PK11SlotInfo*, PRBool retry, void*) +{ + if (retry) return 0; + //TODO: something else? + return PL_strdup(getpass("Please enter the password for accessing the certificate database:")); +} + +SslOptions SslOptions::global; + +char* readPasswordFromFile(PK11SlotInfo*, PRBool retry, void*) +{ + const std::string& passwordFile = SslOptions::global.certPasswordFile; + if (retry || passwordFile.empty() || !boost::filesystem::exists(passwordFile)) { + return 0; + } else { + std::ifstream file(passwordFile.c_str()); + std::string password; + file >> password; + return PL_strdup(password.c_str()); + } +} + +void initNSS(const SslOptions& options, bool server) +{ + SslOptions::global = options; + if (options.certPasswordFile.empty()) { + PK11_SetPasswordFunc(promptForPassword); + } else { + PK11_SetPasswordFunc(readPasswordFromFile); + } + NSS_CHECK(NSS_Init(options.certDbPath.c_str())); + if (options.exportPolicy) { + NSS_CHECK(NSS_SetExportPolicy()); + } else { + NSS_CHECK(NSS_SetDomesticPolicy()); + } + if (server) { + //use defaults for all args, TODO: may want to make this configurable + SSL_ConfigServerSessionIDCache(0, 0, 0, 0); + } +} + +void shutdownNSS() +{ + NSS_Shutdown(); +} + +}}} // namespace qpid::sys::ssl diff --git a/cpp/src/qpid/sys/ssl/util.h b/cpp/src/qpid/sys/ssl/util.h new file mode 100644 index 0000000000..f34adab7be --- /dev/null +++ b/cpp/src/qpid/sys/ssl/util.h @@ -0,0 +1,50 @@ +#ifndef QPID_SYS_SSL_UTIL_H +#define QPID_SYS_SSL_UTIL_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "qpid/Options.h" +#include <string> + +namespace qpid { +namespace sys { +namespace ssl { + +struct SslOptions : qpid::Options +{ + static SslOptions global; + + std::string certDbPath; + std::string certName; + std::string certPasswordFile; + bool exportPolicy; + + SslOptions(); + SslOptions& operator=(const SslOptions&); +}; + +void initNSS(const SslOptions& options, bool server = false); +void shutdownNSS(); + +}}} // namespace qpid::sys::ssl + +#endif /*!QPID_SYS_SSL_UTIL_H*/ |
