diff options
author | Kim van der Riet <kpvdr@apache.org> | 2013-02-28 16:14:30 +0000 |
---|---|---|
committer | Kim van der Riet <kpvdr@apache.org> | 2013-02-28 16:14:30 +0000 |
commit | 9c73ef7a5ac10acd6a50d5d52bd721fc2faa5919 (patch) | |
tree | 2a890e1df09e5b896a9b4168a7b22648f559a1f2 /cpp/src/qpid/sys | |
parent | 172d9b2a16cfb817bbe632d050acba7e31401cd2 (diff) | |
download | qpid-python-asyncstore.tar.gz |
Update from trunk r1375509 through r1450773asyncstore
git-svn-id: https://svn.apache.org/repos/asf/qpid/branches/asyncstore@1451244 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'cpp/src/qpid/sys')
58 files changed, 1189 insertions, 2026 deletions
diff --git a/cpp/src/qpid/sys/AggregateOutput.cpp b/cpp/src/qpid/sys/AggregateOutput.cpp index fc95f46fb9..ebc5689ce5 100644 --- a/cpp/src/qpid/sys/AggregateOutput.cpp +++ b/cpp/src/qpid/sys/AggregateOutput.cpp @@ -32,8 +32,6 @@ void AggregateOutput::abort() { control.abort(); } void AggregateOutput::activateOutput() { control.activateOutput(); } -void AggregateOutput::giveReadCredit(int32_t credit) { control.giveReadCredit(credit); } - namespace { // Clear the busy flag and notify waiting threads in destructor. struct ScopedBusy { @@ -51,6 +49,7 @@ bool AggregateOutput::doOutput() { while (!tasks.empty()) { OutputTask* t=tasks.front(); tasks.pop_front(); + taskSet.erase(t); bool didOutput; { // Allow concurrent call to addOutputTask. @@ -59,7 +58,9 @@ bool AggregateOutput::doOutput() { didOutput = t->doOutput(); } if (didOutput) { - tasks.push_back(t); + if (taskSet.insert(t).second) { + tasks.push_back(t); + } return true; } } @@ -68,12 +69,15 @@ bool AggregateOutput::doOutput() { void AggregateOutput::addOutputTask(OutputTask* task) { Mutex::ScopedLock l(lock); - tasks.push_back(task); + if (taskSet.insert(task).second) { + tasks.push_back(task); + } } void AggregateOutput::removeOutputTask(OutputTask* task) { Mutex::ScopedLock l(lock); while (busy) lock.wait(); + taskSet.erase(task); tasks.erase(std::remove(tasks.begin(), tasks.end(), task), tasks.end()); } @@ -81,6 +85,7 @@ void AggregateOutput::removeAll() { Mutex::ScopedLock l(lock); while (busy) lock.wait(); + taskSet.clear(); tasks.clear(); } diff --git a/cpp/src/qpid/sys/AggregateOutput.h b/cpp/src/qpid/sys/AggregateOutput.h index d7c0ff29e3..e9dbd5a4cc 100644 --- a/cpp/src/qpid/sys/AggregateOutput.h +++ b/cpp/src/qpid/sys/AggregateOutput.h @@ -28,6 +28,7 @@ #include <algorithm> #include <deque> +#include <set> namespace qpid { namespace sys { @@ -44,9 +45,11 @@ namespace sys { class QPID_COMMON_CLASS_EXTERN AggregateOutput : public OutputTask, public OutputControl { typedef std::deque<OutputTask*> TaskList; + typedef std::set<OutputTask*> TaskSet; Monitor lock; TaskList tasks; + TaskSet taskSet; bool busy; OutputControl& control; @@ -56,7 +59,6 @@ class QPID_COMMON_CLASS_EXTERN AggregateOutput : public OutputTask, public Outpu // These may be called concurrently with any function. QPID_COMMON_EXTERN void abort(); QPID_COMMON_EXTERN void activateOutput(); - QPID_COMMON_EXTERN void giveReadCredit(int32_t); QPID_COMMON_EXTERN void addOutputTask(OutputTask* t); // These functions must not be called concurrently with each other. diff --git a/cpp/src/qpid/sys/AsynchIO.h b/cpp/src/qpid/sys/AsynchIO.h index b2eaaac9de..679665f8ad 100644 --- a/cpp/src/qpid/sys/AsynchIO.h +++ b/cpp/src/qpid/sys/AsynchIO.h @@ -21,9 +21,11 @@ * */ -#include "qpid/sys/IntegerTypes.h" #include "qpid/CommonImportExport.h" +#include "qpid/sys/IntegerTypes.h" +#include "qpid/sys/SecuritySettings.h" + #include <string.h> #include <boost/function.hpp> @@ -56,6 +58,7 @@ class AsynchConnector { public: typedef boost::function1<void, const Socket&> ConnectedCallback; typedef boost::function3<void, const Socket&, int, const std::string&> FailedCallback; + typedef boost::function1<void, AsynchConnector&> RequestCallback; // Call create() to allocate a new AsynchConnector object with the // specified poller, addressing, and callbacks. @@ -70,6 +73,7 @@ public: FailedCallback failCb); virtual void start(boost::shared_ptr<Poller> poller) = 0; virtual void stop() {}; + virtual void requestCallback(RequestCallback) = 0; protected: AsynchConnector() {} virtual ~AsynchConnector() {} @@ -155,11 +159,11 @@ public: virtual void notifyPendingWrite() = 0; virtual void queueWriteClose() = 0; virtual bool writeQueueEmpty() = 0; - virtual void startReading() = 0; - virtual void stopReading() = 0; virtual void requestCallback(RequestCallback) = 0; virtual BufferBase* getQueuedBuffer() = 0; + virtual SecuritySettings getSecuritySettings() = 0; + protected: // Derived class manages lifetime; must be constructed using the // static create() method. Deletes not allowed from outside. diff --git a/cpp/src/qpid/sys/AsynchIOHandler.cpp b/cpp/src/qpid/sys/AsynchIOHandler.cpp index 2e117a3fb7..cf08b482e6 100644 --- a/cpp/src/qpid/sys/AsynchIOHandler.cpp +++ b/cpp/src/qpid/sys/AsynchIOHandler.cpp @@ -51,15 +51,15 @@ struct ProtocolTimeoutTask : public sys::TimerTask { } }; -AsynchIOHandler::AsynchIOHandler(const std::string& id, ConnectionCodec::Factory* f) : +AsynchIOHandler::AsynchIOHandler(const std::string& id, ConnectionCodec::Factory* f, bool isClient0, bool nodict0) : identifier(id), aio(0), factory(f), codec(0), reads(0), readError(false), - isClient(false), - readCredit(InfiniteCredit) + isClient(isClient0), + nodict(nodict0) {} AsynchIOHandler::~AsynchIOHandler() { @@ -97,25 +97,20 @@ void AsynchIOHandler::abort() { if (!readError) { aio->requestCallback(boost::bind(&AsynchIOHandler::eof, this, _1)); } + aio->queueWriteClose(); } void AsynchIOHandler::activateOutput() { aio->notifyPendingWrite(); } -// Input side -void AsynchIOHandler::giveReadCredit(int32_t credit) { - // Check whether we started in the don't about credit state - if (readCredit.boolCompareAndSwap(InfiniteCredit, credit)) - return; - // TODO In theory should be able to use an atomic operation before taking the lock - // but in practice there seems to be an unexplained race in that case - ScopedLock<Mutex> l(creditLock); - if (readCredit.fetchAndAdd(credit) != 0) - return; - assert(readCredit.get() >= 0); - if (readCredit.get() != 0) - aio->startReading(); +namespace { + SecuritySettings getSecuritySettings(AsynchIO* aio, bool nodict) + { + SecuritySettings settings = aio->getSecuritySettings(); + settings.nodict = nodict; + return settings; + } } void AsynchIOHandler::readbuff(AsynchIO& , AsynchIO::BufferBase* buff) { @@ -123,26 +118,6 @@ void AsynchIOHandler::readbuff(AsynchIO& , AsynchIO::BufferBase* buff) { return; } - // Check here for read credit - if (readCredit.get() != InfiniteCredit) { - if (readCredit.get() == 0) { - // FIXME aconway 2009-10-01: Workaround to avoid "false wakeups". - // readbuff is sometimes called with no credit. - // This should be fixed somewhere else to avoid such calls. - aio->unread(buff); - return; - } - // TODO In theory should be able to use an atomic operation before taking the lock - // but in practice there seems to be an unexplained race in that case - ScopedLock<Mutex> l(creditLock); - if (--readCredit == 0) { - assert(readCredit.get() >= 0); - if (readCredit.get() == 0) { - aio->stopReading(); - } - } - } - ++reads; size_t decoded = 0; if (codec) { // Already initiated @@ -168,13 +143,16 @@ void AsynchIOHandler::readbuff(AsynchIO& , AsynchIO::BufferBase* buff) { QPID_LOG(debug, "RECV [" << identifier << "]: INIT(" << protocolInit << ")"); try { - codec = factory->create(protocolInit.getVersion(), *this, identifier, SecuritySettings()); + codec = factory->create(protocolInit.getVersion(), *this, identifier, getSecuritySettings(aio, nodict)); if (!codec) { //TODO: may still want to revise this... //send valid version header & close connection. write(framing::ProtocolInitiation(framing::highestProtocolVersion)); readError = true; aio->queueWriteClose(); + } else { + //read any further data that may already have been sent + decoded += codec->decode(buff->bytes+buff->dataStart+in.getPosition(), buff->dataCount-in.getPosition()); } } catch (const std::exception& e) { QPID_LOG(error, e.what()); @@ -223,7 +201,7 @@ void AsynchIOHandler::nobuffs(AsynchIO&) { void AsynchIOHandler::idle(AsynchIO&){ if (isClient && codec == 0) { - codec = factory->create(*this, identifier, SecuritySettings()); + codec = factory->create(*this, identifier, getSecuritySettings(aio, nodict)); write(framing::ProtocolInitiation(codec->getVersion())); // We've just sent the protocol negotiation so we can cancel the timeout for that // This is not ideal, because we've not received anything yet, but heartbeats will diff --git a/cpp/src/qpid/sys/AsynchIOHandler.h b/cpp/src/qpid/sys/AsynchIOHandler.h index fd0bc140e5..d93e24fd4c 100644 --- a/cpp/src/qpid/sys/AsynchIOHandler.h +++ b/cpp/src/qpid/sys/AsynchIOHandler.h @@ -51,24 +51,19 @@ class AsynchIOHandler : public OutputControl { uint32_t reads; bool readError; bool isClient; - AtomicValue<int32_t> readCredit; - static const int32_t InfiniteCredit = -1; - Mutex creditLock; + bool nodict; boost::intrusive_ptr<sys::TimerTask> timeoutTimerTask; void write(const framing::ProtocolInitiation&); public: - QPID_COMMON_EXTERN AsynchIOHandler(const std::string& id, qpid::sys::ConnectionCodec::Factory* f ); + QPID_COMMON_EXTERN AsynchIOHandler(const std::string& id, qpid::sys::ConnectionCodec::Factory* f, bool isClient, bool nodict); QPID_COMMON_EXTERN ~AsynchIOHandler(); QPID_COMMON_EXTERN void init(AsynchIO* a, Timer& timer, uint32_t maxTime); - QPID_COMMON_INLINE_EXTERN void setClient() { isClient = true; } - // Output side QPID_COMMON_EXTERN void abort(); QPID_COMMON_EXTERN void activateOutput(); - QPID_COMMON_EXTERN void giveReadCredit(int32_t credit); // Input side QPID_COMMON_EXTERN void readbuff(AsynchIO& aio, AsynchIOBufferBase* buff); diff --git a/cpp/src/qpid/sys/ClusterSafe.cpp b/cpp/src/qpid/sys/ClusterSafe.cpp deleted file mode 100644 index dd37615145..0000000000 --- a/cpp/src/qpid/sys/ClusterSafe.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#include "ClusterSafe.h" -#include "qpid/log/Statement.h" -#include "qpid/sys/Thread.h" -#include <stdlib.h> - -namespace qpid { -namespace sys { - -namespace { -bool inCluster = false; -QPID_TSS bool inContext = false; -} - -bool isClusterSafe() { return !inCluster || inContext; } - -void assertClusterSafe() { - if (!isClusterSafe()) { - QPID_LOG(critical, "Modified cluster state outside of cluster context"); - ::abort(); - } -} - -ClusterSafeScope::ClusterSafeScope() { - save = inContext; - inContext = true; -} - -ClusterSafeScope::~ClusterSafeScope() { - assert(inContext); - inContext = save; -} - -ClusterUnsafeScope::ClusterUnsafeScope() { - save = inContext; - inContext = false; -} - -ClusterUnsafeScope::~ClusterUnsafeScope() { - assert(!inContext); - inContext = save; -} - -void enableClusterSafe() { inCluster = true; } - -}} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/ClusterSafe.h b/cpp/src/qpid/sys/ClusterSafe.h deleted file mode 100644 index 27e4eb46a5..0000000000 --- a/cpp/src/qpid/sys/ClusterSafe.h +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef QPID_SYS_CLUSTERSAFE_H -#define QPID_SYS_CLUSTERSAFE_H - -/* - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#include "qpid/CommonImportExport.h" - -namespace qpid { -namespace sys { - -/** - * Assertion to add to code that modifies clustered state. - * - * In a non-clustered broker this is a no-op. - * - * In a clustered broker, checks that it is being called - * in a context where it is safe to modify clustered state. - * If not it aborts the process as this is a serious bug. - * - * This function is in the common library rather than the cluster - * library because it is called by code in the broker library. - */ -QPID_COMMON_EXTERN void assertClusterSafe(); - -/** - * In a non-clustered broker, returns true. - * - * In a clustered broker returns true if we are in a context where it - * is safe to modify cluster state. - * - * This function is in the common library rather than the cluster - * library because it is called by code in the broker library. - */ -QPID_COMMON_EXTERN bool isClusterSafe(); - -/** - * Mark a scope as cluster safe. Sets isClusterSafe in constructor and resets - * to previous value in destructor. - */ -class ClusterSafeScope { - public: - ClusterSafeScope(); - ~ClusterSafeScope(); - private: - bool save; -}; - -/** - * Mark a scope as cluster unsafe. Clears isClusterSafe in constructor and resets - * to previous value in destructor. - */ -class ClusterUnsafeScope { - public: - QPID_COMMON_EXTERN ClusterUnsafeScope(); - QPID_COMMON_EXTERN ~ClusterUnsafeScope(); - private: - bool save; -}; - -/** - * Enable cluster-safe assertions. By default they are no-ops. - * Called by cluster code. - */ -void enableClusterSafe(); - -}} // namespace qpid::sys - -#endif /*!QPID_SYS_CLUSTERSAFE_H*/ diff --git a/cpp/src/qpid/sys/Codec.h b/cpp/src/qpid/sys/Codec.h index ace721fbcc..e398403e47 100644 --- a/cpp/src/qpid/sys/Codec.h +++ b/cpp/src/qpid/sys/Codec.h @@ -42,7 +42,7 @@ class Codec /** Encode into buffer, return number of bytes encoded */ - virtual std::size_t encode(const char* buffer, std::size_t size) = 0; + virtual std::size_t encode(char* buffer, std::size_t size) = 0; /** Return true if we have data to encode */ virtual bool canEncode() = 0; diff --git a/cpp/src/qpid/sys/ConnectionOutputHandlerPtr.h b/cpp/src/qpid/sys/ConnectionOutputHandlerPtr.h index 95a08d15ae..53d56ad716 100644 --- a/cpp/src/qpid/sys/ConnectionOutputHandlerPtr.h +++ b/cpp/src/qpid/sys/ConnectionOutputHandlerPtr.h @@ -45,7 +45,6 @@ class ConnectionOutputHandlerPtr : public ConnectionOutputHandler size_t getBuffered() const { return next->getBuffered(); } void abort() { next->abort(); } void activateOutput() { next->activateOutput(); } - void giveReadCredit(int32_t credit) { next->giveReadCredit(credit); } void send(framing::AMQFrame& f) { next->send(f); } private: diff --git a/cpp/src/qpid/sys/FileSysDir.h b/cpp/src/qpid/sys/FileSysDir.h index ffe7823f0a..7432fe39c9 100755 --- a/cpp/src/qpid/sys/FileSysDir.h +++ b/cpp/src/qpid/sys/FileSysDir.h @@ -54,6 +54,15 @@ class FileSysDir void mkdir(void); + typedef void Callback(const std::string&); + + /** + * Call the Callback function for every regular file in the directory + * + * @param cb Callback function that receives the full path to the file + */ + void forEachFile(Callback cb) const; + std::string getPath () { return dirPath; } }; diff --git a/cpp/src/qpid/sys/OutputControl.h b/cpp/src/qpid/sys/OutputControl.h index eae99beb0f..0d801e9d16 100644 --- a/cpp/src/qpid/sys/OutputControl.h +++ b/cpp/src/qpid/sys/OutputControl.h @@ -1,3 +1,6 @@ +#ifndef QPID_SYS_OUTPUT_CONTROL_H +#define QPID_SYS_OUTPUT_CONTROL_H + /* * * Licensed to the Apache Software Foundation (ASF) under one @@ -21,9 +24,6 @@ #include "qpid/sys/IntegerTypes.h" -#ifndef _OutputControl_ -#define _OutputControl_ - namespace qpid { namespace sys { @@ -33,11 +33,10 @@ namespace sys { virtual ~OutputControl() {} virtual void abort() = 0; virtual void activateOutput() = 0; - virtual void giveReadCredit(int32_t credit) = 0; }; } } -#endif +#endif /*!QPID_SYS_OUTPUT_CONTROL_H*/ diff --git a/cpp/src/qpid/sys/ProtocolFactory.h b/cpp/src/qpid/sys/ProtocolFactory.h index 4d198a92da..236398c111 100644 --- a/cpp/src/qpid/sys/ProtocolFactory.h +++ b/cpp/src/qpid/sys/ProtocolFactory.h @@ -42,10 +42,10 @@ class ProtocolFactory : public qpid::SharedObject<ProtocolFactory> virtual void accept(boost::shared_ptr<Poller>, ConnectionCodec::Factory*) = 0; virtual void connect( boost::shared_ptr<Poller>, + const std::string& name, const std::string& host, const std::string& port, ConnectionCodec::Factory* codec, ConnectFailedCallback failed) = 0; - virtual bool supports(const std::string& /*capability*/) { return false; } }; inline ProtocolFactory::~ProtocolFactory() {} diff --git a/cpp/src/qpid/sys/RdmaIOPlugin.cpp b/cpp/src/qpid/sys/RdmaIOPlugin.cpp index b491d28d0a..51cc0ed109 100644 --- a/cpp/src/qpid/sys/RdmaIOPlugin.cpp +++ b/cpp/src/qpid/sys/RdmaIOPlugin.cpp @@ -23,6 +23,7 @@ #include "qpid/Plugin.h" #include "qpid/broker/Broker.h" +#include "qpid/broker/NameGenerator.h" #include "qpid/framing/AMQP_HighestVersion.h" #include "qpid/log/Statement.h" #include "qpid/sys/rdma/RdmaIO.h" @@ -67,7 +68,6 @@ class RdmaIOHandler : public OutputControl { void close(); void abort(); void activateOutput(); - void giveReadCredit(int32_t credit); void initProtocolOut(); // Input side @@ -83,7 +83,7 @@ class RdmaIOHandler : public OutputControl { }; RdmaIOHandler::RdmaIOHandler(Rdma::Connection::intrusive_ptr c, qpid::sys::ConnectionCodec::Factory* f) : - identifier(c->getFullName()), + identifier(broker::QPID_NAME_PREFIX+c->getFullName()), factory(f), codec(0), readError(false), @@ -199,10 +199,6 @@ void RdmaIOHandler::full(Rdma::AsynchIO&) { QPID_LOG(debug, "Rdma: buffer full [" << identifier << "]"); } -// TODO: Dummy implementation of read throttling -void RdmaIOHandler::giveReadCredit(int32_t) { -} - // The logic here is subtly different from TCP as RDMA is message oriented // so we define that an RDMA message is a frame - in this case there is no putting back // of any message remainder - there shouldn't be any. And what we read here can't be @@ -250,7 +246,7 @@ class RdmaIOProtocolFactory : public ProtocolFactory { public: RdmaIOProtocolFactory(int16_t port, int backlog); void accept(Poller::shared_ptr, ConnectionCodec::Factory*); - void connect(Poller::shared_ptr, const string& host, const std::string& port, ConnectionCodec::Factory*, ConnectFailedCallback); + void connect(Poller::shared_ptr, const std::string& name, const string& host, const std::string& port, ConnectionCodec::Factory*, ConnectFailedCallback); uint16_t getPort() const; @@ -371,6 +367,7 @@ void RdmaIOProtocolFactory::connected(Poller::shared_ptr poller, Rdma::Connectio void RdmaIOProtocolFactory::connect( Poller::shared_ptr poller, + const std::string& /*name*/, const std::string& host, const std::string& port, ConnectionCodec::Factory* f, ConnectFailedCallback failed) diff --git a/cpp/src/qpid/sys/SecurityLayer.h b/cpp/src/qpid/sys/SecurityLayer.h index 52bc40e352..317ada16de 100644 --- a/cpp/src/qpid/sys/SecurityLayer.h +++ b/cpp/src/qpid/sys/SecurityLayer.h @@ -33,8 +33,12 @@ namespace sys { class SecurityLayer : public Codec { public: + SecurityLayer(int ssf_) : ssf(ssf_) {} + int getSsf() const { return ssf; } virtual void init(Codec*) = 0; virtual ~SecurityLayer() {} + private: + int ssf; }; }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/SecuritySettings.h b/cpp/src/qpid/sys/SecuritySettings.h index bfcd08fd0f..d595cad660 100644 --- a/cpp/src/qpid/sys/SecuritySettings.h +++ b/cpp/src/qpid/sys/SecuritySettings.h @@ -21,6 +21,8 @@ * under the License. * */ +#include <string> + namespace qpid { namespace sys { diff --git a/cpp/src/qpid/sys/Socket.h b/cpp/src/qpid/sys/Socket.h index defec4879c..38183bd5fd 100644 --- a/cpp/src/qpid/sys/Socket.h +++ b/cpp/src/qpid/sys/Socket.h @@ -22,7 +22,6 @@ * */ -#include "qpid/sys/IOHandle.h" #include "qpid/sys/IntegerTypes.h" #include "qpid/CommonImportExport.h" #include <string> @@ -31,45 +30,43 @@ namespace qpid { namespace sys { class Duration; +class IOHandle; class SocketAddress; -class QPID_COMMON_CLASS_EXTERN Socket : public IOHandle +class Socket { public: - /** Create a socket wrapper for descriptor. */ - QPID_COMMON_EXTERN Socket(); + virtual ~Socket() {}; - /** Create a new Socket which is the same address family as this one */ - QPID_COMMON_EXTERN Socket* createSameTypeSocket() const; + virtual operator const IOHandle&() const = 0; /** Set socket non blocking */ - void setNonblocking() const; + virtual void setNonblocking() const = 0; - QPID_COMMON_EXTERN void setTcpNoDelay() const; + virtual void setTcpNoDelay() const = 0; - QPID_COMMON_EXTERN void connect(const std::string& host, const std::string& port) const; - QPID_COMMON_EXTERN void connect(const SocketAddress&) const; + virtual void connect(const SocketAddress&) const = 0; + virtual void finishConnect(const SocketAddress&) const = 0; - QPID_COMMON_EXTERN void close() const; + virtual void close() const = 0; /** Bind to a port and start listening. *@param port 0 means choose an available port. *@param backlog maximum number of pending connections. *@return The bound port. */ - QPID_COMMON_EXTERN int listen(const std::string& host = "", const std::string& port = "0", int backlog = 10) const; - QPID_COMMON_EXTERN int listen(const SocketAddress&, int backlog = 10) const; + virtual int listen(const SocketAddress&, int backlog = 10) const = 0; /** * Returns an address (host and port) for the remote end of the * socket */ - QPID_COMMON_EXTERN std::string getPeerAddress() const; + virtual std::string getPeerAddress() const = 0; /** * Returns an address (host and port) for the local end of the * socket */ - QPID_COMMON_EXTERN std::string getLocalAddress() const; + virtual std::string getLocalAddress() const = 0; /** * Returns the full address of the connection: local and remote host and port. @@ -80,31 +77,24 @@ public: * Returns the error code stored in the socket. This may be used * to determine the result of a non-blocking connect. */ - QPID_COMMON_EXTERN int getError() const; + virtual int getError() const = 0; /** Accept a connection from a socket that is already listening * and has an incoming connection */ - QPID_COMMON_EXTERN Socket* accept() const; + virtual Socket* accept() const = 0; - // TODO The following are raw operations, maybe they need better wrapping? - QPID_COMMON_EXTERN int read(void *buf, size_t count) const; - QPID_COMMON_EXTERN int write(const void *buf, size_t count) const; + virtual int read(void *buf, size_t count) const = 0; + virtual int write(const void *buf, size_t count) const = 0; -private: - /** Create socket */ - void createSocket(const SocketAddress&) const; - -public: - /** Construct socket with existing handle */ - Socket(IOHandlePrivate*); - -protected: - mutable std::string localname; - mutable std::string peername; - mutable bool nonblocking; - mutable bool nodelay; + /* Transport security related: */ + virtual int getKeyLen() const = 0; + virtual std::string getClientAuthId() const = 0; }; +/** Make the default socket for whatever platform we are executing on + */ +QPID_COMMON_EXTERN Socket* createSocket(); + }} #endif /*!_sys_Socket_h*/ diff --git a/cpp/src/qpid/sys/SocketAddress.h b/cpp/src/qpid/sys/SocketAddress.h index dcca109d94..a4da5cca79 100644 --- a/cpp/src/qpid/sys/SocketAddress.h +++ b/cpp/src/qpid/sys/SocketAddress.h @@ -44,11 +44,12 @@ public: QPID_COMMON_EXTERN bool nextAddress(); QPID_COMMON_EXTERN std::string asString(bool numeric=true) const; + QPID_COMMON_EXTERN std::string getHost() const; QPID_COMMON_EXTERN void setAddrInfoPort(uint16_t port); QPID_COMMON_EXTERN static std::string asString(::sockaddr const * const addr, size_t addrlen); QPID_COMMON_EXTERN static uint16_t getPort(::sockaddr const * const addr); - + private: std::string host; diff --git a/cpp/src/qpid/sys/SslPlugin.cpp b/cpp/src/qpid/sys/SslPlugin.cpp index 069e97758e..a40da24eb8 100644 --- a/cpp/src/qpid/sys/SslPlugin.cpp +++ b/cpp/src/qpid/sys/SslPlugin.cpp @@ -22,19 +22,19 @@ #include "qpid/sys/ProtocolFactory.h" #include "qpid/Plugin.h" -#include "qpid/sys/ssl/check.h" -#include "qpid/sys/ssl/util.h" -#include "qpid/sys/ssl/SslHandler.h" +#include "qpid/broker/Broker.h" +#include "qpid/broker/NameGenerator.h" +#include "qpid/log/Statement.h" #include "qpid/sys/AsynchIOHandler.h" #include "qpid/sys/AsynchIO.h" -#include "qpid/sys/ssl/SslIo.h" +#include "qpid/sys/ssl/util.h" #include "qpid/sys/ssl/SslSocket.h" -#include "qpid/broker/Broker.h" -#include "qpid/log/Statement.h" +#include "qpid/sys/SocketAddress.h" +#include "qpid/sys/SystemInfo.h" +#include "qpid/sys/Poller.h" #include <boost/bind.hpp> -#include <memory> - +#include <boost/ptr_container/ptr_vector.hpp> namespace qpid { namespace sys { @@ -64,38 +64,32 @@ struct SslServerOptions : ssl::SslOptions } }; -template <class T> -class SslProtocolFactoryTmpl : public ProtocolFactory { - private: - - typedef SslAcceptorTmpl<T> SslAcceptor; - +class SslProtocolFactory : public ProtocolFactory { + boost::ptr_vector<Socket> listeners; + boost::ptr_vector<AsynchAcceptor> acceptors; Timer& brokerTimer; uint32_t maxNegotiateTime; + uint16_t listeningPort; const bool tcpNoDelay; - T listener; - const uint16_t listeningPort; - std::auto_ptr<SslAcceptor> acceptor; bool nodict; public: - SslProtocolFactoryTmpl(const SslServerOptions&, int backlog, bool nodelay, Timer& timer, uint32_t maxTime); + SslProtocolFactory(const qpid::broker::Broker::Options& opts, const SslServerOptions& options, + Timer& timer); void accept(Poller::shared_ptr, ConnectionCodec::Factory*); - void connect(Poller::shared_ptr, const std::string& host, const std::string& port, + void connect(Poller::shared_ptr, const std::string& name, const std::string& host, const std::string& port, ConnectionCodec::Factory*, - boost::function2<void, int, std::string> failed); + ConnectFailedCallback); uint16_t getPort() const; - bool supports(const std::string& capability); private: - void established(Poller::shared_ptr, const Socket&, ConnectionCodec::Factory*, - bool isClient); + void establishedIncoming(Poller::shared_ptr, const Socket&, ConnectionCodec::Factory*); + void establishedOutgoing(Poller::shared_ptr, const Socket&, ConnectionCodec::Factory*, const std::string&); + void establishedCommon(AsynchIOHandler*, Poller::shared_ptr , const Socket&); + void connectFailed(const Socket&, int, const std::string&, ConnectFailedCallback); }; -typedef SslProtocolFactoryTmpl<SslSocket> SslProtocolFactory; -typedef SslProtocolFactoryTmpl<SslMuxSocket> SslMuxProtocolFactory; - // Static instance to initialise plugin static struct SslPlugin : public Plugin { @@ -124,7 +118,7 @@ static struct SslPlugin : public Plugin { } } } - + void initialize(Target& target) { QPID_LOG(trace, "Initialising SSL plugin"); broker::Broker* broker = dynamic_cast<broker::Broker*>(&target); @@ -139,19 +133,16 @@ static struct SslPlugin : public Plugin { const broker::Broker::Options& opts = broker->getOptions(); - ProtocolFactory::shared_ptr protocol(options.multiplex ? - static_cast<ProtocolFactory*>(new SslMuxProtocolFactory(options, - opts.connectionBacklog, - opts.tcpNoDelay, - broker->getTimer(), opts.maxNegotiateTime)) : - static_cast<ProtocolFactory*>(new SslProtocolFactory(options, - opts.connectionBacklog, - opts.tcpNoDelay, - broker->getTimer(), opts.maxNegotiateTime))); - QPID_LOG(notice, "Listening for " << - (options.multiplex ? "SSL or TCP" : "SSL") << - " connections on TCP port " << - protocol->getPort()); + ProtocolFactory::shared_ptr protocol( + static_cast<ProtocolFactory*>(new SslProtocolFactory(opts, options, broker->getTimer()))); + + if (protocol->getPort()!=0 ) { + QPID_LOG(notice, "Listening for " << + (options.multiplex ? "SSL or TCP" : "SSL") << + " connections on TCP/TCP6 port " << + protocol->getPort()); + } + broker->registerProtocolFactory("ssl", protocol); } catch (const std::exception& e) { QPID_LOG(error, "Failed to initialise SSL plugin: " << e.what()); @@ -161,99 +152,133 @@ static struct SslPlugin : public Plugin { } } sslPlugin; -template <class T> -SslProtocolFactoryTmpl<T>::SslProtocolFactoryTmpl(const SslServerOptions& options, int backlog, bool nodelay, Timer& timer, uint32_t maxTime) : +namespace { + // Expand list of Interfaces and addresses to a list of addresses + std::vector<std::string> expandInterfaces(const std::vector<std::string>& interfaces) { + std::vector<std::string> addresses; + // If there are no specific interfaces listed use a single "" to listen on every interface + if (interfaces.empty()) { + addresses.push_back(""); + return addresses; + } + for (unsigned i = 0; i < interfaces.size(); ++i) { + const std::string& interface = interfaces[i]; + if (!(SystemInfo::getInterfaceAddresses(interface, addresses))) { + // We don't have an interface of that name - + // Check for IPv6 ('[' ']') brackets and remove them + // then pass to be looked up directly + if (interface[0]=='[' && interface[interface.size()-1]==']') { + addresses.push_back(interface.substr(1, interface.size()-2)); + } else { + addresses.push_back(interface); + } + } + } + return addresses; + } +} + +SslProtocolFactory::SslProtocolFactory(const qpid::broker::Broker::Options& opts, const SslServerOptions& options, + Timer& timer) : brokerTimer(timer), - maxNegotiateTime(maxTime), - tcpNoDelay(nodelay), listeningPort(listener.listen(options.port, backlog, options.certName, options.clientAuth)), + maxNegotiateTime(opts.maxNegotiateTime), + tcpNoDelay(opts.tcpNoDelay), nodict(options.nodict) -{} - -void SslEstablished(Poller::shared_ptr poller, const qpid::sys::SslSocket& s, - ConnectionCodec::Factory* f, bool isClient, - Timer& timer, uint32_t maxTime, bool tcpNoDelay, bool nodict) { - qpid::sys::ssl::SslHandler* async = new qpid::sys::ssl::SslHandler(s.getFullAddress(), f, nodict); - - if (tcpNoDelay) { - s.setTcpNoDelay(tcpNoDelay); - QPID_LOG(info, "Set TCP_NODELAY on connection to " << s.getPeerAddress()); +{ + std::vector<std::string> addresses = expandInterfaces(opts.listenInterfaces); + if (addresses.empty()) { + // We specified some interfaces, but couldn't find addresses for them + QPID_LOG(warning, "SSL: No specified network interfaces found: Not Listening"); + listeningPort = 0; } - if (isClient) { - async->setClient(); + for (unsigned i = 0; i<addresses.size(); ++i) { + QPID_LOG(debug, "Using interface: " << addresses[i]); + SocketAddress sa(addresses[i], boost::lexical_cast<std::string>(options.port)); + + // We must have at least one resolved address + QPID_LOG(info, "Listening to: " << sa.asString()) + Socket* s = options.multiplex ? + new SslMuxSocket(options.certName, options.clientAuth) : + new SslSocket(options.certName, options.clientAuth); + uint16_t lport = s->listen(sa, opts.connectionBacklog); + QPID_LOG(debug, "Listened to: " << lport); + listeners.push_back(s); + + listeningPort = lport; + + // Try any other resolved addresses + while (sa.nextAddress()) { + // Hack to ensure that all listening connections are on the same port + sa.setAddrInfoPort(listeningPort); + QPID_LOG(info, "Listening to: " << sa.asString()) + Socket* s = options.multiplex ? + new SslMuxSocket(options.certName, options.clientAuth) : + new SslSocket(options.certName, options.clientAuth); + uint16_t lport = s->listen(sa, opts.connectionBacklog); + QPID_LOG(debug, "Listened to: " << lport); + listeners.push_back(s); + } } - - qpid::sys::ssl::SslIO* aio = new qpid::sys::ssl::SslIO(s, - boost::bind(&qpid::sys::ssl::SslHandler::readbuff, async, _1, _2), - boost::bind(&qpid::sys::ssl::SslHandler::eof, async, _1), - boost::bind(&qpid::sys::ssl::SslHandler::disconnect, async, _1), - boost::bind(&qpid::sys::ssl::SslHandler::closedSocket, async, _1, _2), - boost::bind(&qpid::sys::ssl::SslHandler::nobuffs, async, _1), - boost::bind(&qpid::sys::ssl::SslHandler::idle, async, _1)); - - async->init(aio,timer, maxTime); - aio->start(poller); -} - -template <> -void SslProtocolFactory::established(Poller::shared_ptr poller, const Socket& s, - ConnectionCodec::Factory* f, bool isClient) { - const SslSocket *sslSock = dynamic_cast<const SslSocket*>(&s); - - SslEstablished(poller, *sslSock, f, isClient, brokerTimer, maxNegotiateTime, tcpNoDelay, nodict); } -template <class T> -uint16_t SslProtocolFactoryTmpl<T>::getPort() const { - return listeningPort; // Immutable no need for lock. +void SslProtocolFactory::establishedIncoming(Poller::shared_ptr poller, const Socket& s, + ConnectionCodec::Factory* f) { + AsynchIOHandler* async = new AsynchIOHandler(broker::QPID_NAME_PREFIX+s.getFullAddress(), f, false, false); + establishedCommon(async, poller, s); } -template <class T> -void SslProtocolFactoryTmpl<T>::accept(Poller::shared_ptr poller, - ConnectionCodec::Factory* fact) { - acceptor.reset( - new SslAcceptor(listener, - boost::bind(&SslProtocolFactoryTmpl<T>::established, - this, poller, _1, fact, false))); - acceptor->start(poller); +void SslProtocolFactory::establishedOutgoing(Poller::shared_ptr poller, const Socket& s, + ConnectionCodec::Factory* f, const std::string& name) { + AsynchIOHandler* async = new AsynchIOHandler(name, f, true, false); + establishedCommon(async, poller, s); } -template <> -void SslMuxProtocolFactory::established(Poller::shared_ptr poller, const Socket& s, - ConnectionCodec::Factory* f, bool isClient) { - const SslSocket *sslSock = dynamic_cast<const SslSocket*>(&s); - - if (sslSock) { - SslEstablished(poller, *sslSock, f, isClient, brokerTimer, maxNegotiateTime, tcpNoDelay, nodict); - return; - } - - AsynchIOHandler* async = new AsynchIOHandler(s.getFullAddress(), f); - +void SslProtocolFactory::establishedCommon(AsynchIOHandler* async, Poller::shared_ptr poller, const Socket& s) { if (tcpNoDelay) { s.setTcpNoDelay(); QPID_LOG(info, "Set TCP_NODELAY on connection to " << s.getPeerAddress()); } - if (isClient) { - async->setClient(); - } - AsynchIO* aio = AsynchIO::create - (s, - boost::bind(&AsynchIOHandler::readbuff, async, _1, _2), - boost::bind(&AsynchIOHandler::eof, async, _1), - boost::bind(&AsynchIOHandler::disconnect, async, _1), - boost::bind(&AsynchIOHandler::closedSocket, async, _1, _2), - boost::bind(&AsynchIOHandler::nobuffs, async, _1), - boost::bind(&AsynchIOHandler::idle, async, _1)); + AsynchIO* aio = AsynchIO::create( + s, + boost::bind(&AsynchIOHandler::readbuff, async, _1, _2), + boost::bind(&AsynchIOHandler::eof, async, _1), + boost::bind(&AsynchIOHandler::disconnect, async, _1), + boost::bind(&AsynchIOHandler::closedSocket, async, _1, _2), + boost::bind(&AsynchIOHandler::nobuffs, async, _1), + boost::bind(&AsynchIOHandler::idle, async, _1)); async->init(aio, brokerTimer, maxNegotiateTime); aio->start(poller); } -template <class T> -void SslProtocolFactoryTmpl<T>::connect( +uint16_t SslProtocolFactory::getPort() const { + return listeningPort; // Immutable no need for lock. +} + +void SslProtocolFactory::accept(Poller::shared_ptr poller, + ConnectionCodec::Factory* fact) { + for (unsigned i = 0; i<listeners.size(); ++i) { + acceptors.push_back( + AsynchAcceptor::create(listeners[i], + boost::bind(&SslProtocolFactory::establishedIncoming, this, poller, _1, fact))); + acceptors[i].start(poller); + } +} + +void SslProtocolFactory::connectFailed( + const Socket& s, int ec, const std::string& emsg, + ConnectFailedCallback failedCb) +{ + failedCb(ec, emsg); + s.close(); + delete &s; +} + +void SslProtocolFactory::connect( Poller::shared_ptr poller, + const std::string& name, const std::string& host, const std::string& port, ConnectionCodec::Factory* fact, ConnectFailedCallback failed) @@ -264,31 +289,23 @@ void SslProtocolFactoryTmpl<T>::connect( // shutdown. The allocated SslConnector frees itself when it // is no longer needed. - qpid::sys::ssl::SslSocket* socket = new qpid::sys::ssl::SslSocket(); - new SslConnector(*socket, poller, host, port, - boost::bind(&SslProtocolFactoryTmpl<T>::established, this, poller, _1, fact, true), - failed); -} - -namespace -{ -const std::string SSL = "ssl"; -} - -template <> -bool SslProtocolFactory::supports(const std::string& capability) -{ - std::string s = capability; - transform(s.begin(), s.end(), s.begin(), tolower); - return s == SSL; -} - -template <> -bool SslMuxProtocolFactory::supports(const std::string& capability) -{ - std::string s = capability; - transform(s.begin(), s.end(), s.begin(), tolower); - return s == SSL || s == "tcp"; + Socket* socket = new qpid::sys::ssl::SslSocket(); + try { + AsynchConnector* c = AsynchConnector::create( + *socket, + host, + port, + boost::bind(&SslProtocolFactory::establishedOutgoing, + this, poller, _1, fact, name), + boost::bind(&SslProtocolFactory::connectFailed, + this, _1, _2, _3, failed)); + c->start(poller); + } catch (std::exception&) { + // TODO: Design question - should we do the error callback and also throw? + int errCode = socket->getError(); + connectFailed(*socket, errCode, strError(errCode), failed); + throw; + } } }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/TCPIOPlugin.cpp b/cpp/src/qpid/sys/TCPIOPlugin.cpp index ed7cc3748d..1ef8708cd0 100644 --- a/cpp/src/qpid/sys/TCPIOPlugin.cpp +++ b/cpp/src/qpid/sys/TCPIOPlugin.cpp @@ -20,15 +20,17 @@ */ #include "qpid/sys/ProtocolFactory.h" -#include "qpid/sys/AsynchIOHandler.h" -#include "qpid/sys/AsynchIO.h" #include "qpid/Plugin.h" +#include "qpid/broker/Broker.h" +#include "qpid/broker/NameGenerator.h" +#include "qpid/log/Statement.h" +#include "qpid/sys/AsynchIOHandler.h" +#include "qpid/sys/AsynchIO.h" #include "qpid/sys/Socket.h" #include "qpid/sys/SocketAddress.h" +#include "qpid/sys/SystemInfo.h" #include "qpid/sys/Poller.h" -#include "qpid/broker/Broker.h" -#include "qpid/log/Statement.h" #include <boost/bind.hpp> #include <boost/ptr_container/ptr_vector.hpp> @@ -47,20 +49,19 @@ class AsynchIOProtocolFactory : public ProtocolFactory { const bool tcpNoDelay; public: - AsynchIOProtocolFactory(const std::string& host, const std::string& port, - int backlog, bool nodelay, - Timer& timer, uint32_t maxTime, - bool shouldListen); + AsynchIOProtocolFactory(const qpid::broker::Broker::Options& opts, Timer& timer, bool shouldListen); void accept(Poller::shared_ptr, ConnectionCodec::Factory*); - void connect(Poller::shared_ptr, const std::string& host, const std::string& port, + void connect(Poller::shared_ptr, const std::string& name, + const std::string& host, const std::string& port, ConnectionCodec::Factory*, ConnectFailedCallback); uint16_t getPort() const; private: - void established(Poller::shared_ptr, const Socket&, ConnectionCodec::Factory*, - bool isClient); + void establishedIncoming(Poller::shared_ptr, const Socket&, ConnectionCodec::Factory*); + void establishedOutgoing(Poller::shared_ptr, const Socket&, ConnectionCodec::Factory*, const std::string&); + void establishedCommon(AsynchIOHandler*, Poller::shared_ptr , const Socket&); void connectFailed(const Socket&, int, const std::string&, ConnectFailedCallback); }; @@ -93,14 +94,9 @@ static class TCPIOPlugin : public Plugin { bool shouldListen = !sslMultiplexEnabled(); ProtocolFactory::shared_ptr protocolt( - new AsynchIOProtocolFactory( - "", boost::lexical_cast<std::string>(opts.port), - opts.connectionBacklog, - opts.tcpNoDelay, - broker->getTimer(), opts.maxNegotiateTime, - shouldListen)); - - if (shouldListen) { + new AsynchIOProtocolFactory(opts, broker->getTimer(),shouldListen)); + + if (shouldListen && protocolt->getPort()!=0 ) { QPID_LOG(notice, "Listening on TCP/TCP6 port " << protocolt->getPort()); } @@ -109,54 +105,93 @@ static class TCPIOPlugin : public Plugin { } } tcpPlugin; -AsynchIOProtocolFactory::AsynchIOProtocolFactory(const std::string& host, const std::string& port, - int backlog, bool nodelay, - Timer& timer, uint32_t maxTime, - bool shouldListen) : +namespace { + // Expand list of Interfaces and addresses to a list of addresses + std::vector<std::string> expandInterfaces(const std::vector<std::string>& interfaces) { + std::vector<std::string> addresses; + // If there are no specific interfaces listed use a single "" to listen on every interface + if (interfaces.empty()) { + addresses.push_back(""); + return addresses; + } + for (unsigned i = 0; i < interfaces.size(); ++i) { + const std::string& interface = interfaces[i]; + if (!(SystemInfo::getInterfaceAddresses(interface, addresses))) { + // We don't have an interface of that name - + // Check for IPv6 ('[' ']') brackets and remove them + // then pass to be looked up directly + if (interface[0]=='[' && interface[interface.size()-1]==']') { + addresses.push_back(interface.substr(1, interface.size()-2)); + } else { + addresses.push_back(interface); + } + } + } + return addresses; + } +} + +AsynchIOProtocolFactory::AsynchIOProtocolFactory(const qpid::broker::Broker::Options& opts, Timer& timer, bool shouldListen) : brokerTimer(timer), - maxNegotiateTime(maxTime), - tcpNoDelay(nodelay) + maxNegotiateTime(opts.maxNegotiateTime), + tcpNoDelay(opts.tcpNoDelay) { if (!shouldListen) { - listeningPort = boost::lexical_cast<uint16_t>(port); + listeningPort = boost::lexical_cast<uint16_t>(opts.port); return; } - SocketAddress sa(host, port); - - // We must have at least one resolved address - QPID_LOG(info, "Listening to: " << sa.asString()) - Socket* s = new Socket; - uint16_t lport = s->listen(sa, backlog); - QPID_LOG(debug, "Listened to: " << lport); - listeners.push_back(s); + std::vector<std::string> addresses = expandInterfaces(opts.listenInterfaces); + if (addresses.empty()) { + // We specified some interfaces, but couldn't find addresses for them + QPID_LOG(warning, "TCP/TCP6: No specified network interfaces found: Not Listening"); + listeningPort = 0; + } - listeningPort = lport; + for (unsigned i = 0; i<addresses.size(); ++i) { + QPID_LOG(debug, "Using interface: " << addresses[i]); + SocketAddress sa(addresses[i], boost::lexical_cast<std::string>(opts.port)); - // Try any other resolved addresses - while (sa.nextAddress()) { - // Hack to ensure that all listening connections are on the same port - sa.setAddrInfoPort(listeningPort); + // We must have at least one resolved address QPID_LOG(info, "Listening to: " << sa.asString()) - Socket* s = new Socket; - uint16_t lport = s->listen(sa, backlog); + Socket* s = createSocket(); + uint16_t lport = s->listen(sa, opts.connectionBacklog); QPID_LOG(debug, "Listened to: " << lport); listeners.push_back(s); + + listeningPort = lport; + + // Try any other resolved addresses + while (sa.nextAddress()) { + // Hack to ensure that all listening connections are on the same port + sa.setAddrInfoPort(listeningPort); + QPID_LOG(info, "Listening to: " << sa.asString()) + Socket* s = createSocket(); + uint16_t lport = s->listen(sa, opts.connectionBacklog); + QPID_LOG(debug, "Listened to: " << lport); + listeners.push_back(s); + } } +} +void AsynchIOProtocolFactory::establishedIncoming(Poller::shared_ptr poller, const Socket& s, + ConnectionCodec::Factory* f) { + AsynchIOHandler* async = new AsynchIOHandler(broker::QPID_NAME_PREFIX+s.getFullAddress(), f, false, false); + establishedCommon(async, poller, s); } -void AsynchIOProtocolFactory::established(Poller::shared_ptr poller, const Socket& s, - ConnectionCodec::Factory* f, bool isClient) { - AsynchIOHandler* async = new AsynchIOHandler(s.getFullAddress(), f); +void AsynchIOProtocolFactory::establishedOutgoing(Poller::shared_ptr poller, const Socket& s, + ConnectionCodec::Factory* f, const std::string& name) { + AsynchIOHandler* async = new AsynchIOHandler(name, f, true, false); + establishedCommon(async, poller, s); +} +void AsynchIOProtocolFactory::establishedCommon(AsynchIOHandler* async, Poller::shared_ptr poller, const Socket& s) { if (tcpNoDelay) { s.setTcpNoDelay(); QPID_LOG(info, "Set TCP_NODELAY on connection to " << s.getPeerAddress()); } - if (isClient) - async->setClient(); AsynchIO* aio = AsynchIO::create (s, boost::bind(&AsynchIOHandler::readbuff, async, _1, _2), @@ -179,7 +214,7 @@ void AsynchIOProtocolFactory::accept(Poller::shared_ptr poller, for (unsigned i = 0; i<listeners.size(); ++i) { acceptors.push_back( AsynchAcceptor::create(listeners[i], - boost::bind(&AsynchIOProtocolFactory::established, this, poller, _1, fact, false))); + boost::bind(&AsynchIOProtocolFactory::establishedIncoming, this, poller, _1, fact))); acceptors[i].start(poller); } } @@ -195,6 +230,7 @@ void AsynchIOProtocolFactory::connectFailed( void AsynchIOProtocolFactory::connect( Poller::shared_ptr poller, + const std::string& name, const std::string& host, const std::string& port, ConnectionCodec::Factory* fact, ConnectFailedCallback failed) @@ -204,14 +240,14 @@ void AsynchIOProtocolFactory::connect( // upon connection failure or by the AsynchIO upon connection // shutdown. The allocated AsynchConnector frees itself when it // is no longer needed. - Socket* socket = new Socket(); + Socket* socket = createSocket(); try { AsynchConnector* c = AsynchConnector::create( *socket, host, port, - boost::bind(&AsynchIOProtocolFactory::established, - this, poller, _1, fact, true), + boost::bind(&AsynchIOProtocolFactory::establishedOutgoing, + this, poller, _1, fact, name), boost::bind(&AsynchIOProtocolFactory::connectFailed, this, _1, _2, _3, failed)); c->start(poller); diff --git a/cpp/src/qpid/sys/Timer.cpp b/cpp/src/qpid/sys/Timer.cpp index 973c6bd8b7..f8eef2c9ec 100644 --- a/cpp/src/qpid/sys/Timer.cpp +++ b/cpp/src/qpid/sys/Timer.cpp @@ -96,18 +96,13 @@ void TimerTask::cancel() { state = CANCELLED; } -void TimerTask::setFired() { - // Set nextFireTime to just before now, making readyToFire() true. - nextFireTime = AbsTime(sys::now(), Duration(-1)); -} - - +// TODO AStitcher 21/08/09 The threshholds for emitting warnings are a little arbitrary Timer::Timer() : active(false), late(50 * TIME_MSEC), overran(2 * TIME_MSEC), lateCancel(500 * TIME_MSEC), - warn(5 * TIME_SEC) + warn(60 * TIME_SEC) { start(); } @@ -133,7 +128,6 @@ public: } }; -// TODO AStitcher 21/08/09 The threshholds for emitting warnings are a little arbitrary void Timer::run() { Monitor::ScopedLock l(monitor); @@ -151,10 +145,6 @@ void Timer::run() { TimerTaskCallbackScope s(*t); if (s) { - { - Monitor::ScopedUnlock u(monitor); - drop(t); - } if (delay > lateCancel) { QPID_LOG(debug, t->name << " cancelled timer woken up " << delay / TIME_MSEC << "ms late"); @@ -171,8 +161,8 @@ void Timer::run() if (!tasks.empty()) { overrun = Duration(tasks.top()->nextFireTime, end); } - bool warningsEnabled; - QPID_LOG_TEST(warning, warningsEnabled); + bool warningsEnabled; // TimerWarning enabled + QPID_LOG_TEST(debug, warningsEnabled); // TimerWarning emitted at debug level if (warningsEnabled) { if (overrun > overran) { if (delay > overran) // if delay is significant to an overrun. @@ -235,9 +225,6 @@ void Timer::fire(boost::intrusive_ptr<TimerTask> t) { } } -// Provided for subclasses: called when a task is droped. -void Timer::drop(boost::intrusive_ptr<TimerTask>) {} - bool operator<(const intrusive_ptr<TimerTask>& a, const intrusive_ptr<TimerTask>& b) { diff --git a/cpp/src/qpid/sys/Timer.h b/cpp/src/qpid/sys/Timer.h index 5731b8d977..5045009609 100644 --- a/cpp/src/qpid/sys/Timer.h +++ b/cpp/src/qpid/sys/Timer.h @@ -67,10 +67,6 @@ class TimerTask : public RefCounted { std::string getName() const { return name; } - // Move the nextFireTime so readyToFire is true. - // Used by the cluster, where tasks are fired on cluster events, not on local time. - QPID_COMMON_EXTERN void setFired(); - protected: // Must be overridden with callback virtual void fire() = 0; @@ -99,7 +95,7 @@ class Timer : private Runnable { protected: QPID_COMMON_EXTERN virtual void fire(boost::intrusive_ptr<TimerTask> task); - QPID_COMMON_EXTERN virtual void drop(boost::intrusive_ptr<TimerTask> task); + // Allow derived classes to change the late/overran thresholds. Duration late; Duration overran; diff --git a/cpp/src/qpid/sys/TimerWarnings.cpp b/cpp/src/qpid/sys/TimerWarnings.cpp index 85e26da54a..00fb0d9db6 100644 --- a/cpp/src/qpid/sys/TimerWarnings.cpp +++ b/cpp/src/qpid/sys/TimerWarnings.cpp @@ -56,18 +56,18 @@ void TimerWarnings::log() { std::string task = i->first; TaskStats& stats = i->second; if (stats.lateDelay.count) - QPID_LOG(info, task << " task late " + QPID_LOG(debug, task << " task late " << stats.lateDelay.count << " times by " << stats.lateDelay.average()/TIME_MSEC << "ms on average."); if (stats.overranOverrun.count) - QPID_LOG(info, task << " task overran " + QPID_LOG(debug, task << " task overran " << stats.overranOverrun.count << " times by " << stats.overranOverrun.average()/TIME_MSEC << "ms (taking " << stats.overranTime.average() << "ns) on average."); if (stats.lateAndOverranOverrun.count) - QPID_LOG(info, task << " task late and overran " + QPID_LOG(debug, task << " task late and overran " << stats.lateAndOverranOverrun.count << " times: late " << stats.lateAndOverranDelay.average()/TIME_MSEC << "ms, overran " << stats.lateAndOverranOverrun.average()/TIME_MSEC << "ms (taking " diff --git a/cpp/src/qpid/sys/alloca.h b/cpp/src/qpid/sys/alloca.h deleted file mode 100644 index b3f59b7c3f..0000000000 --- a/cpp/src/qpid/sys/alloca.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef QPID_SYS_ALLOCA_H -#define QPID_SYS_ALLOCA_H - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#if (defined(_WINDOWS) || defined (WIN32)) -# include <malloc.h> - -# if defined(_MSC_VER) -# ifdef alloc -# undef alloc -# endif -# define alloc _alloc -# ifdef alloca -# undef alloca -# endif -# define alloca _alloca -# endif -#endif -#if !defined _WINDOWS && !defined WIN32 -# include <alloca.h> -#endif - -#endif /*!QPID_SYS_ALLOCA_H*/ diff --git a/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.cpp b/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.cpp index 29b91f3e7a..79d9d08a59 100644 --- a/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.cpp +++ b/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.cpp @@ -29,8 +29,8 @@ namespace qpid { namespace sys { namespace cyrus { -CyrusSecurityLayer::CyrusSecurityLayer(sasl_conn_t* c, uint16_t maxFrameSize) : - conn(c), decrypted(0), decryptedSize(0), encrypted(0), encryptedSize(0), codec(0), maxInputSize(0), +CyrusSecurityLayer::CyrusSecurityLayer(sasl_conn_t* c, uint16_t maxFrameSize, int ssf) : + SecurityLayer(ssf), conn(c), decrypted(0), decryptedSize(0), encrypted(0), encryptedSize(0), codec(0), maxInputSize(0), decodeBuffer(maxFrameSize), encodeBuffer(maxFrameSize), encoded(0) { const void* value(0); @@ -68,7 +68,7 @@ size_t CyrusSecurityLayer::decode(const char* input, size_t size) return size; } -size_t CyrusSecurityLayer::encode(const char* buffer, size_t size) +size_t CyrusSecurityLayer::encode(char* buffer, size_t size) { size_t processed = 0;//records how many bytes have been written to buffer do { @@ -92,12 +92,12 @@ size_t CyrusSecurityLayer::encode(const char* buffer, size_t size) //can't fit all encrypted data in the buffer we've //been given, copy in what we can and hold on to the //rest until the next call - ::memcpy(const_cast<char*>(buffer + processed), encrypted, remaining); + ::memcpy(buffer + processed, encrypted, remaining); processed += remaining; encrypted += remaining; encryptedSize -= remaining; } else { - ::memcpy(const_cast<char*>(buffer + processed), encrypted, encryptedSize); + ::memcpy(buffer + processed, encrypted, encryptedSize); processed += encryptedSize; encrypted = 0; encryptedSize = 0; diff --git a/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.h b/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.h index 1645cf1a58..ae86ba5569 100644 --- a/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.h +++ b/cpp/src/qpid/sys/cyrus/CyrusSecurityLayer.h @@ -37,9 +37,9 @@ namespace cyrus { class CyrusSecurityLayer : public qpid::sys::SecurityLayer { public: - CyrusSecurityLayer(sasl_conn_t*, uint16_t maxFrameSize); + CyrusSecurityLayer(sasl_conn_t*, uint16_t maxFrameSize, int ssf); size_t decode(const char* buffer, size_t size); - size_t encode(const char* buffer, size_t size); + size_t encode(char* buffer, size_t size); bool canEncode(); void init(qpid::sys::Codec*); private: diff --git a/cpp/src/qpid/sys/epoll/EpollPoller.cpp b/cpp/src/qpid/sys/epoll/EpollPoller.cpp index c23403c66d..6fdf99637f 100644 --- a/cpp/src/qpid/sys/epoll/EpollPoller.cpp +++ b/cpp/src/qpid/sys/epoll/EpollPoller.cpp @@ -20,7 +20,6 @@ */ #include "qpid/sys/Poller.h" -#include "qpid/sys/IOHandle.h" #include "qpid/sys/Mutex.h" #include "qpid/sys/AtomicCount.h" #include "qpid/sys/DeletionManager.h" @@ -64,12 +63,12 @@ class PollerHandlePrivate { }; ::__uint32_t events; - const IOHandlePrivate* ioHandle; + const IOHandle* ioHandle; PollerHandle* pollerHandle; FDStat stat; Mutex lock; - PollerHandlePrivate(const IOHandlePrivate* h, PollerHandle* p) : + PollerHandlePrivate(const IOHandle* h, PollerHandle* p) : events(0), ioHandle(h), pollerHandle(p), @@ -77,7 +76,7 @@ class PollerHandlePrivate { } int fd() const { - return toFd(ioHandle); + return ioHandle->fd; } bool isActive() const { @@ -138,7 +137,7 @@ class PollerHandlePrivate { }; PollerHandle::PollerHandle(const IOHandle& h) : - impl(new PollerHandlePrivate(h.impl, this)) + impl(new PollerHandlePrivate(&h, this)) {} PollerHandle::~PollerHandle() { @@ -385,6 +384,7 @@ void PollerPrivate::resetMode(PollerHandlePrivate& eh) { int rc = ::epoll_ctl(epollFd, EPOLL_CTL_MOD, eh.fd(), &epe); // If something has closed the fd in the meantime try adding it back if (rc ==-1 && errno == ENOENT) { + eh.setIdle(); // Reset our handle as if starting from scratch rc = ::epoll_ctl(epollFd, EPOLL_CTL_ADD, eh.fd(), &epe); } QPID_POSIX_CHECK(rc); diff --git a/cpp/src/qpid/sys/posix/AsynchIO.cpp b/cpp/src/qpid/sys/posix/AsynchIO.cpp index 31355627cd..353a55f50c 100644 --- a/cpp/src/qpid/sys/posix/AsynchIO.cpp +++ b/cpp/src/qpid/sys/posix/AsynchIO.cpp @@ -143,6 +143,7 @@ class AsynchConnector : public qpid::sys::AsynchConnector, private: void connComplete(DispatchHandle& handle); + void requestedCall(RequestCallback rCb); private: ConnectedCallback connCallback; @@ -158,6 +159,7 @@ public: FailedCallback failCb); void start(Poller::shared_ptr poller); void stop(); + void requestCallback(RequestCallback rCb); }; AsynchConnector::AsynchConnector(const Socket& s, @@ -191,11 +193,30 @@ void AsynchConnector::stop() stopWatch(); } +void AsynchConnector::requestCallback(RequestCallback callback) { + // TODO creating a function object every time isn't all that + // efficient - if this becomes heavily used do something better (what?) + assert(callback); + DispatchHandle::call(boost::bind(&AsynchConnector::requestedCall, this, callback)); +} + +void AsynchConnector::requestedCall(RequestCallback callback) { + assert(callback); + callback(*this); +} + void AsynchConnector::connComplete(DispatchHandle& h) { int errCode = socket.getError(); if (errCode == 0) { h.stopWatch(); + try { + socket.finishConnect(sa); + } catch (const std::exception& e) { + failCallback(socket, 0, e.what()); + DispatchHandle::doDelete(); + return; + } connCallback(socket); } else { // Retry while we cause an immediate exception @@ -247,10 +268,9 @@ public: virtual void notifyPendingWrite(); virtual void queueWriteClose(); virtual bool writeQueueEmpty(); - virtual void startReading(); - virtual void stopReading(); virtual void requestCallback(RequestCallback); virtual BufferBase* getQueuedBuffer(); + virtual SecuritySettings getSecuritySettings(); private: ~AsynchIO(); @@ -282,13 +302,6 @@ private: * thread processing this handle. */ volatile bool writePending; - /** - * This records whether we've been reading is flow controlled: - * it's safe as a simple boolean as the only way to be stopped - * is in calls only allowed in the callback context, the only calls - * checking it are also in calls only allowed in callback context. - */ - volatile bool readingStopped; }; AsynchIO::AsynchIO(const Socket& s, @@ -307,8 +320,7 @@ AsynchIO::AsynchIO(const Socket& s, idleCallback(iCb), socket(s), queuedClose(false), - writePending(false), - readingStopped(false) { + writePending(false) { s.setNonblocking(); } @@ -344,7 +356,7 @@ void AsynchIO::queueReadBuffer(BufferBase* buff) { bool queueWasEmpty = bufferQueue.empty(); bufferQueue.push_back(buff); - if (queueWasEmpty && !readingStopped) + if (queueWasEmpty) DispatchHandle::rewatchRead(); } @@ -354,7 +366,7 @@ void AsynchIO::unread(BufferBase* buff) { bool queueWasEmpty = bufferQueue.empty(); bufferQueue.push_front(buff); - if (queueWasEmpty && !readingStopped) + if (queueWasEmpty) DispatchHandle::rewatchRead(); } @@ -386,17 +398,6 @@ bool AsynchIO::writeQueueEmpty() { return writeQueue.empty(); } -// This can happen outside the callback context -void AsynchIO::startReading() { - readingStopped = false; - DispatchHandle::rewatchRead(); -} - -void AsynchIO::stopReading() { - readingStopped = true; - DispatchHandle::unwatchRead(); -} - void AsynchIO::requestCallback(RequestCallback callback) { // TODO creating a function object every time isn't all that // efficient - if this becomes heavily used do something better (what?) @@ -429,11 +430,6 @@ AsynchIO::BufferBase* AsynchIO::getQueuedBuffer() { * to put it in and reading is not stopped by flow control. */ void AsynchIO::readable(DispatchHandle& h) { - if (readingStopped) { - // We have been flow controlled. - QPID_PROBE1(asynchio_read_flowcontrolled, &h); - return; - } AbsTime readStartTime = AbsTime::now(); size_t total = 0; int readCalls = 0; @@ -455,12 +451,6 @@ void AsynchIO::readable(DispatchHandle& h) { total += rc; readCallback(*this, buff); - if (readingStopped) { - // We have been flow controlled. - QPID_PROBE4(asynchio_read_finished_flowcontrolled, &h, duration, total, readCalls); - break; - } - if (rc != readCount) { // If we didn't fill the read buffer then time to stop reading QPID_PROBE4(asynchio_read_finished_done, &h, duration, total, readCalls); @@ -626,6 +616,13 @@ void AsynchIO::close(DispatchHandle& h) { } } +SecuritySettings AsynchIO::getSecuritySettings() { + SecuritySettings settings; + settings.ssf = socket.getKeyLen(); + settings.authid = socket.getClientAuthId(); + return settings; +} + } // namespace posix AsynchAcceptor* AsynchAcceptor::create(const Socket& s, diff --git a/cpp/src/qpid/sys/posix/Socket.cpp b/cpp/src/qpid/sys/posix/BSDSocket.cpp index 77ae1af60c..7c31b13ae9 100644 --- a/cpp/src/qpid/sys/posix/Socket.cpp +++ b/cpp/src/qpid/sys/posix/BSDSocket.cpp @@ -19,7 +19,7 @@ * */ -#include "qpid/sys/Socket.h" +#include "qpid/sys/posix/BSDSocket.h" #include "qpid/sys/SocketAddress.h" #include "qpid/sys/posix/check.h" @@ -67,25 +67,41 @@ uint16_t getLocalPort(int fd) } } -Socket::Socket() : - IOHandle(new IOHandlePrivate), +BSDSocket::BSDSocket() : + fd(-1), + handle(new IOHandle), nonblocking(false), nodelay(false) {} -Socket::Socket(IOHandlePrivate* h) : - IOHandle(h), +Socket* createSocket() +{ + return new BSDSocket; +} + +BSDSocket::BSDSocket(int fd0) : + fd(fd0), + handle(new IOHandle(fd)), nonblocking(false), nodelay(false) {} -void Socket::createSocket(const SocketAddress& sa) const +BSDSocket::~BSDSocket() +{} + +BSDSocket::operator const IOHandle&() const +{ + return *handle; +} + +void BSDSocket::createSocket(const SocketAddress& sa) const { - int& socket = impl->fd; - if (socket != -1) Socket::close(); + int& socket = fd; + if (socket != -1) BSDSocket::close(); int s = ::socket(getAddrInfo(sa).ai_family, getAddrInfo(sa).ai_socktype, 0); if (s < 0) throw QPID_POSIX_ERROR(errno); socket = s; + *handle = IOHandle(s); try { if (nonblocking) setNonblocking(); @@ -98,50 +114,31 @@ void Socket::createSocket(const SocketAddress& sa) const } catch (std::exception&) { ::close(s); socket = -1; + *handle = IOHandle(); throw; } } -Socket* Socket::createSameTypeSocket() const { - int& socket = impl->fd; - // Socket currently has no actual socket attached - if (socket == -1) - return new Socket; - - ::sockaddr_storage sa; - ::socklen_t salen = sizeof(sa); - QPID_POSIX_CHECK(::getsockname(socket, (::sockaddr*)&sa, &salen)); - int s = ::socket(sa.ss_family, SOCK_STREAM, 0); // Currently only work with SOCK_STREAM - if (s < 0) throw QPID_POSIX_ERROR(errno); - return new Socket(new IOHandlePrivate(s)); -} - -void Socket::setNonblocking() const { - int& socket = impl->fd; +void BSDSocket::setNonblocking() const { + int& socket = fd; nonblocking = true; if (socket != -1) { QPID_POSIX_CHECK(::fcntl(socket, F_SETFL, O_NONBLOCK)); } } -void Socket::setTcpNoDelay() const +void BSDSocket::setTcpNoDelay() const { - int& socket = impl->fd; + int& socket = fd; nodelay = true; if (socket != -1) { int flag = 1; - int result = ::setsockopt(impl->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(flag)); + int result = ::setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(flag)); QPID_POSIX_CHECK(result); } } -void Socket::connect(const std::string& host, const std::string& port) const -{ - SocketAddress sa(host, port); - connect(sa); -} - -void Socket::connect(const SocketAddress& addr) const +void BSDSocket::connect(const SocketAddress& addr) const { // The display name for an outbound connection needs to be the name that was specified // for the address rather than a resolved IP address as we don't know which of @@ -154,7 +151,7 @@ void Socket::connect(const SocketAddress& addr) const createSocket(addr); - const int& socket = impl->fd; + const int& socket = fd; // TODO the correct thing to do here is loop on failure until you've used all the returned addresses if ((::connect(socket, getAddrInfo(addr).ai_addr, getAddrInfo(addr).ai_addrlen) < 0) && (errno != EINPROGRESS)) { @@ -165,11 +162,6 @@ void Socket::connect(const SocketAddress& addr) const // remote port (which is unoccupied) as the port to bind the local // end of the socket, resulting in a "circular" connection. // - // This seems like something the OS should prevent but I have - // confirmed that sporadic hangs in - // cluster_tests.LongTests.test_failover on RHEL5 are caused by - // such a circular connection. - // // Raise an error if we see such a connection, since we know there is // no listener on the peer address. // @@ -179,26 +171,25 @@ void Socket::connect(const SocketAddress& addr) const } } +void BSDSocket::finishConnect(const SocketAddress&) const +{ +} + void -Socket::close() const +BSDSocket::close() const { - int& socket = impl->fd; + int& socket = fd; if (socket == -1) return; if (::close(socket) < 0) throw QPID_POSIX_ERROR(errno); socket = -1; + *handle = IOHandle(); } -int Socket::listen(const std::string& host, const std::string& port, int backlog) const -{ - SocketAddress sa(host, port); - return listen(sa, backlog); -} - -int Socket::listen(const SocketAddress& sa, int backlog) const +int BSDSocket::listen(const SocketAddress& sa, int backlog) const { createSocket(sa); - const int& socket = impl->fd; + const int& socket = fd; int yes=1; QPID_POSIX_CHECK(::setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes))); @@ -210,11 +201,11 @@ int Socket::listen(const SocketAddress& sa, int backlog) const return getLocalPort(socket); } -Socket* Socket::accept() const +Socket* BSDSocket::accept() const { - int afd = ::accept(impl->fd, 0, 0); + int afd = ::accept(fd, 0, 0); if ( afd >= 0) { - Socket* s = new Socket(new IOHandlePrivate(afd)); + BSDSocket* s = new BSDSocket(afd); s->localname = localname; return s; } @@ -223,41 +214,51 @@ Socket* Socket::accept() const else throw QPID_POSIX_ERROR(errno); } -int Socket::read(void *buf, size_t count) const +int BSDSocket::read(void *buf, size_t count) const { - return ::read(impl->fd, buf, count); + return ::read(fd, buf, count); } -int Socket::write(const void *buf, size_t count) const +int BSDSocket::write(const void *buf, size_t count) const { - return ::write(impl->fd, buf, count); + return ::write(fd, buf, count); } -std::string Socket::getPeerAddress() const +std::string BSDSocket::getPeerAddress() const { if (peername.empty()) { - peername = getName(impl->fd, false); + peername = getName(fd, false); } return peername; } -std::string Socket::getLocalAddress() const +std::string BSDSocket::getLocalAddress() const { if (localname.empty()) { - localname = getName(impl->fd, true); + localname = getName(fd, true); } return localname; } -int Socket::getError() const +int BSDSocket::getError() const { int result; socklen_t rSize = sizeof (result); - if (::getsockopt(impl->fd, SOL_SOCKET, SO_ERROR, &result, &rSize) < 0) + if (::getsockopt(fd, SOL_SOCKET, SO_ERROR, &result, &rSize) < 0) throw QPID_POSIX_ERROR(errno); return result; } +int BSDSocket::getKeyLen() const +{ + return 0; +} + +std::string BSDSocket::getClientAuthId() const +{ + return std::string(); +} + }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/posix/BSDSocket.h b/cpp/src/qpid/sys/posix/BSDSocket.h new file mode 100644 index 0000000000..862d36c1b9 --- /dev/null +++ b/cpp/src/qpid/sys/posix/BSDSocket.h @@ -0,0 +1,113 @@ +#ifndef QPID_SYS_BSDSOCKET_H +#define QPID_SYS_BSDSOCKET_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "qpid/sys/Socket.h" +#include "qpid/sys/IntegerTypes.h" +#include "qpid/CommonImportExport.h" +#include <string> + +#include <boost/scoped_ptr.hpp> + +namespace qpid { +namespace sys { + +class Duration; +class IOHandle; +class SocketAddress; + +namespace ssl { +class SslMuxSocket; +} + +class QPID_COMMON_CLASS_EXTERN BSDSocket : public Socket +{ +public: + /** Create a socket wrapper for descriptor. */ + QPID_COMMON_EXTERN BSDSocket(); + QPID_COMMON_EXTERN ~BSDSocket(); + + QPID_COMMON_EXTERN operator const IOHandle&() const; + + /** Set socket non blocking */ + QPID_COMMON_EXTERN virtual void setNonblocking() const; + + QPID_COMMON_EXTERN virtual void setTcpNoDelay() const; + + QPID_COMMON_EXTERN virtual void connect(const SocketAddress&) const; + QPID_COMMON_EXTERN virtual void finishConnect(const SocketAddress&) const; + + QPID_COMMON_EXTERN virtual void close() const; + + /** Bind to a port and start listening. + *@return The bound port number + */ + QPID_COMMON_EXTERN virtual int listen(const SocketAddress&, int backlog = 10) const; + + /** + * Returns an address (host and port) for the remote end of the + * socket + */ + QPID_COMMON_EXTERN std::string getPeerAddress() const; + /** + * Returns an address (host and port) for the local end of the + * socket + */ + QPID_COMMON_EXTERN std::string getLocalAddress() const; + + /** + * Returns the error code stored in the socket. This may be used + * to determine the result of a non-blocking connect. + */ + QPID_COMMON_EXTERN int getError() const; + + /** Accept a connection from a socket that is already listening + * and has an incoming connection + */ + QPID_COMMON_EXTERN virtual Socket* accept() const; + + // TODO The following are raw operations, maybe they need better wrapping? + QPID_COMMON_EXTERN virtual int read(void *buf, size_t count) const; + QPID_COMMON_EXTERN virtual int write(const void *buf, size_t count) const; + + QPID_COMMON_EXTERN int getKeyLen() const; + QPID_COMMON_EXTERN std::string getClientAuthId() const; + +protected: + /** Create socket */ + void createSocket(const SocketAddress&) const; + + mutable int fd; + mutable boost::scoped_ptr<IOHandle> handle; + mutable std::string localname; + mutable std::string peername; + mutable bool nonblocking; + mutable bool nodelay; + + /** Construct socket with existing handle */ + BSDSocket(int fd); + friend class qpid::sys::ssl::SslMuxSocket; // Needed for this constructor +}; + +}} +#endif /*!QPID_SYS_BSDSOCKET_H*/ diff --git a/cpp/src/qpid/sys/posix/FileSysDir.cpp b/cpp/src/qpid/sys/posix/FileSysDir.cpp index 22dc487e74..cec580164d 100755 --- a/cpp/src/qpid/sys/posix/FileSysDir.cpp +++ b/cpp/src/qpid/sys/posix/FileSysDir.cpp @@ -18,6 +18,7 @@ #include "qpid/sys/FileSysDir.h" #include "qpid/sys/StrError.h" +#include "qpid/log/Statement.h" #include "qpid/Exception.h" #include <sys/types.h> @@ -25,6 +26,8 @@ #include <fcntl.h> #include <cerrno> #include <unistd.h> +#include <dirent.h> +#include <stdlib.h> namespace qpid { namespace sys { @@ -51,4 +54,27 @@ void FileSysDir::mkdir(void) throw Exception ("Can't create directory: " + dirPath); } +void FileSysDir::forEachFile(Callback cb) const { + + ::dirent** namelist; + + int n = scandir(dirPath.c_str(), &namelist, 0, alphasort); + if (n == -1) throw Exception (strError(errno) + ": Can't scan directory: " + dirPath); + + for (int i = 0; i<n; ++i) { + std::string fullpath = dirPath + "/" + namelist[i]->d_name; + // Filter out non files/stat problems etc. + struct ::stat s; + // Can't throw here without leaking memory, so just do nothing with + // entries for which stat() fails. + if (!::stat(fullpath.c_str(), &s)) { + if (S_ISREG(s.st_mode)) { + cb(fullpath); + } + } + ::free(namelist[i]); + } + ::free(namelist); +} + }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/posix/IOHandle.cpp b/cpp/src/qpid/sys/posix/IOHandle.cpp index 9c049ee1de..d3f502a63c 100644 --- a/cpp/src/qpid/sys/posix/IOHandle.cpp +++ b/cpp/src/qpid/sys/posix/IOHandle.cpp @@ -19,26 +19,11 @@ * */ -#include "qpid/sys/IOHandle.h" - #include "qpid/sys/posix/PrivatePosix.h" namespace qpid { namespace sys { -int toFd(const IOHandlePrivate* h) -{ - return h->fd; -} - NullIOHandle DummyIOHandle; -IOHandle::IOHandle(IOHandlePrivate* h) : - impl(h) -{} - -IOHandle::~IOHandle() { - delete impl; -} - }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/posix/PollableCondition.cpp b/cpp/src/qpid/sys/posix/PollableCondition.cpp index abff8a5be8..aa129faf20 100644 --- a/cpp/src/qpid/sys/posix/PollableCondition.cpp +++ b/cpp/src/qpid/sys/posix/PollableCondition.cpp @@ -21,7 +21,6 @@ #include "qpid/sys/PollableCondition.h" #include "qpid/sys/DispatchHandle.h" -#include "qpid/sys/IOHandle.h" #include "qpid/sys/posix/PrivatePosix.h" #include "qpid/Exception.h" @@ -58,14 +57,14 @@ PollableConditionPrivate::PollableConditionPrivate( const sys::PollableCondition::Callback& cb, sys::PollableCondition& parent, const boost::shared_ptr<sys::Poller>& poller -) : IOHandle(new sys::IOHandlePrivate), cb(cb), parent(parent) +) : cb(cb), parent(parent) { int fds[2]; if (::pipe(fds) == -1) throw ErrnoException(QPID_MSG("Can't create PollableCondition")); - impl->fd = fds[0]; + fd = fds[0]; writeFd = fds[1]; - if (::fcntl(impl->fd, F_SETFL, O_NONBLOCK) == -1) + if (::fcntl(fd, F_SETFL, O_NONBLOCK) == -1) throw ErrnoException(QPID_MSG("Can't create PollableCondition")); if (::fcntl(writeFd, F_SETFL, O_NONBLOCK) == -1) throw ErrnoException(QPID_MSG("Can't create PollableCondition")); diff --git a/cpp/src/qpid/sys/posix/PosixPoller.cpp b/cpp/src/qpid/sys/posix/PosixPoller.cpp index eb0c3384d1..ae839b2e20 100644 --- a/cpp/src/qpid/sys/posix/PosixPoller.cpp +++ b/cpp/src/qpid/sys/posix/PosixPoller.cpp @@ -88,12 +88,12 @@ class PollerHandlePrivate { }; short events; - const IOHandlePrivate* ioHandle; + const IOHandle* ioHandle; PollerHandle* pollerHandle; FDStat stat; Mutex lock; - PollerHandlePrivate(const IOHandlePrivate* h, PollerHandle* p) : + PollerHandlePrivate(const IOHandle* h, PollerHandle* p) : events(0), ioHandle(h), pollerHandle(p), @@ -101,7 +101,7 @@ class PollerHandlePrivate { } int fd() const { - return toFd(ioHandle); + return ioHandle->fd; } bool isActive() const { @@ -162,7 +162,7 @@ class PollerHandlePrivate { }; PollerHandle::PollerHandle(const IOHandle& h) : - impl(new PollerHandlePrivate(h.impl, this)) + impl(new PollerHandlePrivate(&h, this)) {} PollerHandle::~PollerHandle() { diff --git a/cpp/src/qpid/sys/posix/SocketAddress.cpp b/cpp/src/qpid/sys/posix/SocketAddress.cpp index 344bd28669..cd23442226 100644 --- a/cpp/src/qpid/sys/posix/SocketAddress.cpp +++ b/cpp/src/qpid/sys/posix/SocketAddress.cpp @@ -102,6 +102,11 @@ std::string SocketAddress::asString(bool numeric) const return asString(ai.ai_addr, ai.ai_addrlen); } +std::string SocketAddress::getHost() const +{ + return host; +} + bool SocketAddress::nextAddress() { bool r = currentAddrInfo->ai_next != 0; if (r) diff --git a/cpp/src/qpid/sys/posix/SystemInfo.cpp b/cpp/src/qpid/sys/posix/SystemInfo.cpp index cfd2c64aee..ea7f521f2b 100755 --- a/cpp/src/qpid/sys/posix/SystemInfo.cpp +++ b/cpp/src/qpid/sys/posix/SystemInfo.cpp @@ -21,7 +21,6 @@ #include "qpid/log/Statement.h" #include "qpid/sys/SystemInfo.h" #include "qpid/sys/posix/check.h" -#include <set> #include <arpa/inet.h> #include <sys/ioctl.h> #include <sys/utsname.h> @@ -33,6 +32,7 @@ #include <iostream> #include <fstream> #include <sstream> +#include <map> #include <netdb.h> #include <string.h> @@ -77,84 +77,70 @@ inline bool isLoopback(const ::sockaddr* addr) { } } -void SystemInfo::getLocalIpAddresses (uint16_t port, - std::vector<Address> &addrList) { - ::ifaddrs* ifaddr = 0; - QPID_POSIX_CHECK(::getifaddrs(&ifaddr)); - for (::ifaddrs* ifap = ifaddr; ifap != 0; ifap = ifap->ifa_next) { - if (ifap->ifa_addr == 0) continue; - if (isLoopback(ifap->ifa_addr)) continue; - int family = ifap->ifa_addr->sa_family; - switch (family) { - case AF_INET6: { - // Ignore link local addresses as: - // * The scope id is illegal in URL syntax - // * Clients won't be able to use a link local address - // without adding their own (potentially different) scope id - sockaddr_in6* sa6 = (sockaddr_in6*)((void*)ifap->ifa_addr); - if (IN6_IS_ADDR_LINKLOCAL(&sa6->sin6_addr)) break; - // Fallthrough - } - case AF_INET: { - char dispName[NI_MAXHOST]; - int rc = ::getnameinfo( - ifap->ifa_addr, - (family == AF_INET) - ? sizeof(struct sockaddr_in) - : sizeof(struct sockaddr_in6), - dispName, sizeof(dispName), - 0, 0, NI_NUMERICHOST); - if (rc != 0) { - throw QPID_POSIX_ERROR(rc); - } - string addr(dispName); - addrList.push_back(Address(TCP, addr, port)); - break; - } - default: - continue; +namespace { + inline socklen_t sa_len(::sockaddr* sa) + { + switch (sa->sa_family) { + case AF_INET: + return sizeof(struct sockaddr_in); + case AF_INET6: + return sizeof(struct sockaddr_in6); + default: + return sizeof(struct sockaddr_storage); } } - ::freeifaddrs(ifaddr); - if (addrList.empty()) { - addrList.push_back(Address(TCP, LOOPBACK, port)); + inline bool isInetOrInet6(::sockaddr* sa) { + switch (sa->sa_family) { + case AF_INET: + case AF_INET6: + return true; + default: + return false; + } + } + typedef std::map<std::string, std::vector<std::string> > InterfaceInfo; + std::map<std::string, std::vector<std::string> > cachedInterfaces; + + void cacheInterfaceInfo() { + // Get interface info + ::ifaddrs* interfaceInfo; + QPID_POSIX_CHECK( ::getifaddrs(&interfaceInfo) ); + + char name[NI_MAXHOST]; + for (::ifaddrs* info = interfaceInfo; info != 0; info = info->ifa_next) { + + // Only use IPv4/IPv6 interfaces + if (!isInetOrInet6(info->ifa_addr)) continue; + + int rc=::getnameinfo(info->ifa_addr, sa_len(info->ifa_addr), + name, sizeof(name), 0, 0, + NI_NUMERICHOST); + if (rc >= 0) { + std::string address(name); + cachedInterfaces[info->ifa_name].push_back(address); + } else { + throw qpid::Exception(QPID_MSG(gai_strerror(rc))); + } + } + ::freeifaddrs(interfaceInfo); } } -namespace { -struct AddrInfo { - struct addrinfo* ptr; - AddrInfo(const std::string& host) : ptr(0) { - ::addrinfo hints; - ::memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; // Allow both IPv4 and IPv6 - if (::getaddrinfo(host.c_str(), NULL, &hints, &ptr) != 0) - ptr = 0; - } - ~AddrInfo() { if (ptr) ::freeaddrinfo(ptr); } -}; +bool SystemInfo::getInterfaceAddresses(const std::string& interface, std::vector<std::string>& addresses) { + if ( cachedInterfaces.empty() ) cacheInterfaceInfo(); + InterfaceInfo::iterator i = cachedInterfaces.find(interface); + if ( i==cachedInterfaces.end() ) return false; + std::copy(i->second.begin(), i->second.end(), std::back_inserter(addresses)); + return true; } -bool SystemInfo::isLocalHost(const std::string& host) { - std::vector<Address> myAddrs; - getLocalIpAddresses(0, myAddrs); - std::set<string> localHosts; - for (std::vector<Address>::const_iterator i = myAddrs.begin(); i != myAddrs.end(); ++i) - localHosts.insert(i->host); - // Resolve host - AddrInfo ai(host); - if (!ai.ptr) return false; - for (struct addrinfo *res = ai.ptr; res != NULL; res = res->ai_next) { - if (isLoopback(res->ai_addr)) return true; - // Get string form of IP addr - char addr[NI_MAXHOST] = ""; - int error = ::getnameinfo(res->ai_addr, res->ai_addrlen, addr, NI_MAXHOST, NULL, 0, - NI_NUMERICHOST | NI_NUMERICSERV); - if (error) return false; - if (localHosts.find(addr) != localHosts.end()) return true; +void SystemInfo::getInterfaceNames(std::vector<std::string>& names ) { + if ( cachedInterfaces.empty() ) cacheInterfaceInfo(); + + for (InterfaceInfo::const_iterator i = cachedInterfaces.begin(); i!=cachedInterfaces.end(); ++i) { + names.push_back(i->first); } - return false; } void SystemInfo::getSystemId (std::string &osName, @@ -205,4 +191,11 @@ string SystemInfo::getProcessName() return value; } +// Always true. Only Windows has exception cases. +bool SystemInfo::threadSafeShutdown() +{ + return true; +} + + }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/rdma/rdma_wrap.cpp b/cpp/src/qpid/sys/rdma/rdma_wrap.cpp index efe454c5be..889ee9ff75 100644 --- a/cpp/src/qpid/sys/rdma/rdma_wrap.cpp +++ b/cpp/src/qpid/sys/rdma/rdma_wrap.cpp @@ -105,7 +105,7 @@ namespace Rdma { } QueuePair::QueuePair(boost::shared_ptr< ::rdma_cm_id > i) : - qpid::sys::IOHandle(new qpid::sys::IOHandlePrivate), + handle(new qpid::sys::IOHandle), pd(allocPd(i->verbs)), cchannel(mkCChannel(i->verbs)), scq(mkCq(i->verbs, DEFAULT_CQ_ENTRIES, 0, cchannel.get())), @@ -113,7 +113,7 @@ namespace Rdma { outstandingSendEvents(0), outstandingRecvEvents(0) { - impl->fd = cchannel->fd; + handle->fd = cchannel->fd; // Set cq context to this QueuePair object so we can find // ourselves again @@ -163,6 +163,11 @@ namespace Rdma { // The buffers vectors automatically deletes all the buffers we've allocated } + QueuePair::operator qpid::sys::IOHandle&() const + { + return *handle; + } + // Create buffers to use for writing void QueuePair::createSendBuffers(int sendBufferCount, int bufferSize, int reserved) { @@ -359,11 +364,11 @@ namespace Rdma { // Wrap the passed in rdma_cm_id with a Connection // this basically happens only on connection request Connection::Connection(::rdma_cm_id* i) : - qpid::sys::IOHandle(new qpid::sys::IOHandlePrivate), + handle(new qpid::sys::IOHandle), id(mkId(i)), context(0) { - impl->fd = id->channel->fd; + handle->fd = id->channel->fd; // Just overwrite the previous context as it will // have come from the listening connection @@ -372,12 +377,12 @@ namespace Rdma { } Connection::Connection() : - qpid::sys::IOHandle(new qpid::sys::IOHandlePrivate), + handle(new qpid::sys::IOHandle), channel(mkEChannel()), id(mkId(channel.get(), this, RDMA_PS_TCP)), context(0) { - impl->fd = channel->fd; + handle->fd = channel->fd; } Connection::~Connection() { @@ -385,6 +390,11 @@ namespace Rdma { id->context = 0; } + Connection::operator qpid::sys::IOHandle&() const + { + return *handle; + } + void Connection::ensureQueuePair() { assert(id.get()); diff --git a/cpp/src/qpid/sys/rdma/rdma_wrap.h b/cpp/src/qpid/sys/rdma/rdma_wrap.h index 8e3429027b..5f84793a5b 100644 --- a/cpp/src/qpid/sys/rdma/rdma_wrap.h +++ b/cpp/src/qpid/sys/rdma/rdma_wrap.h @@ -28,6 +28,7 @@ #include "qpid/sys/Mutex.h" #include <boost/shared_ptr.hpp> +#include <boost/scoped_ptr.hpp> #include <boost/intrusive_ptr.hpp> #include <boost/ptr_container/ptr_deque.hpp> @@ -116,9 +117,10 @@ namespace Rdma { // Wrapper for a queue pair - this has the functionality for // putting buffers on the receive queue and for sending buffers // to the other end of the connection. - class QueuePair : public qpid::sys::IOHandle, public qpid::RefCounted { + class QueuePair : public qpid::RefCounted { friend class Connection; + boost::scoped_ptr< qpid::sys::IOHandle > handle; boost::shared_ptr< ::ibv_pd > pd; boost::shared_ptr< ::ibv_mr > smr; boost::shared_ptr< ::ibv_mr > rmr; @@ -139,6 +141,8 @@ namespace Rdma { public: typedef boost::intrusive_ptr<QueuePair> intrusive_ptr; + operator qpid::sys::IOHandle&() const; + // Create a buffers to use for writing void createSendBuffers(int sendBufferCount, int dataSize, int headerSize); @@ -195,7 +199,8 @@ namespace Rdma { // registered buffers can't be shared between different connections // (this can only happen between connections on the same controller in any case, // so needs careful management if used) - class Connection : public qpid::sys::IOHandle, public qpid::RefCounted { + class Connection : public qpid::RefCounted { + boost::scoped_ptr< qpid::sys::IOHandle > handle; boost::shared_ptr< ::rdma_event_channel > channel; boost::shared_ptr< ::rdma_cm_id > id; QueuePair::intrusive_ptr qp; @@ -216,6 +221,8 @@ namespace Rdma { public: typedef boost::intrusive_ptr<Connection> intrusive_ptr; + operator qpid::sys::IOHandle&() const; + static intrusive_ptr make(); static intrusive_ptr find(::rdma_cm_id* i); diff --git a/cpp/src/qpid/sys/solaris/SystemInfo.cpp b/cpp/src/qpid/sys/solaris/SystemInfo.cpp index e5856f55e6..0e754e048b 100755 --- a/cpp/src/qpid/sys/solaris/SystemInfo.cpp +++ b/cpp/src/qpid/sys/solaris/SystemInfo.cpp @@ -60,31 +60,6 @@ bool SystemInfo::getLocalHostname(Address &address) { static const string LOCALHOST("127.0.0.1"); static const string TCP("tcp"); -void SystemInfo::getLocalIpAddresses(uint16_t port, - std::vector<Address> &addrList) { - int s = socket(PF_INET, SOCK_STREAM, 0); - for (int i=1;;i++) { - struct lifreq ifr; - ifr.lifr_index = i; - if (::ioctl(s, SIOCGIFADDR, &ifr) < 0) { - break; - } - struct sockaddr *sa = static_cast<struct sockaddr *>((void *) &ifr.lifr_addr); - if (sa->sa_family != AF_INET) { - // TODO: Url parsing currently can't cope with IPv6 addresses, defer for now - break; - } - struct sockaddr_in *sin = static_cast<struct sockaddr_in *>((void *)sa); - std::string addr(inet_ntoa(sin->sin_addr)); - if (addr != LOCALHOST) - addrList.push_back(Address(TCP, addr, port)); - } - if (addrList.empty()) { - addrList.push_back(Address(TCP, LOCALHOST, port)); - } - close (s); -} - void SystemInfo::getSystemId(std::string &osName, std::string &nodeName, std::string &release, @@ -126,4 +101,10 @@ string SystemInfo::getProcessName() return value; } +// Always true. Only Windows has exception cases. +bool SystemInfo::threadSafeShutdown() +{ + return true; +} + }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/ssl/SslHandler.cpp b/cpp/src/qpid/sys/ssl/SslHandler.cpp deleted file mode 100644 index eeb8c26a76..0000000000 --- a/cpp/src/qpid/sys/ssl/SslHandler.cpp +++ /dev/null @@ -1,217 +0,0 @@ -/* - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ -#include "qpid/sys/ssl/SslHandler.h" -#include "qpid/sys/ssl/SslIo.h" -#include "qpid/sys/ssl/SslSocket.h" -#include "qpid/sys/Timer.h" -#include "qpid/framing/AMQP_HighestVersion.h" -#include "qpid/framing/ProtocolInitiation.h" -#include "qpid/log/Statement.h" - -#include <boost/bind.hpp> - -namespace qpid { -namespace sys { -namespace ssl { - - -struct ProtocolTimeoutTask : public sys::TimerTask { - SslHandler& handler; - std::string id; - - ProtocolTimeoutTask(const std::string& i, const Duration& timeout, SslHandler& h) : - TimerTask(timeout, "ProtocolTimeout"), - handler(h), - id(i) - {} - - void fire() { - // If this fires it means that we didn't negotiate the connection in the timeout period - // Schedule closing the connection for the io thread - QPID_LOG(error, "Connection " << id << " No protocol received closing"); - handler.abort(); - } -}; - -SslHandler::SslHandler(std::string id, ConnectionCodec::Factory* f, bool _nodict) : - identifier(id), - aio(0), - factory(f), - codec(0), - readError(false), - isClient(false), - nodict(_nodict) -{} - -SslHandler::~SslHandler() { - if (codec) - codec->closed(); - if (timeoutTimerTask) - timeoutTimerTask->cancel(); - delete codec; -} - -void SslHandler::init(SslIO* a, Timer& timer, uint32_t maxTime) { - aio = a; - - // Start timer for this connection - timeoutTimerTask = new ProtocolTimeoutTask(identifier, maxTime*TIME_MSEC, *this); - timer.add(timeoutTimerTask); - - // Give connection some buffers to use - aio->createBuffers(); -} - -void SslHandler::write(const framing::ProtocolInitiation& data) -{ - QPID_LOG(debug, "SENT [" << identifier << "]: INIT(" << data << ")"); - SslIO::BufferBase* buff = aio->getQueuedBuffer(); - assert(buff); - framing::Buffer out(buff->bytes, buff->byteCount); - data.encode(out); - buff->dataCount = data.encodedSize(); - aio->queueWrite(buff); -} - -void SslHandler::abort() { - // Don't disconnect if we're already disconnecting - if (!readError) { - aio->requestCallback(boost::bind(&SslHandler::eof, this, _1)); - } -} -void SslHandler::activateOutput() { - aio->notifyPendingWrite(); -} - -void SslHandler::giveReadCredit(int32_t) { - // FIXME aconway 2008-12-05: not yet implemented. -} - -// Input side -void SslHandler::readbuff(SslIO& , SslIO::BufferBase* buff) { - if (readError) { - return; - } - size_t decoded = 0; - if (codec) { // Already initiated - try { - decoded = codec->decode(buff->bytes+buff->dataStart, buff->dataCount); - }catch(const std::exception& e){ - QPID_LOG(error, e.what()); - readError = true; - aio->queueWriteClose(); - } - }else{ - framing::Buffer in(buff->bytes+buff->dataStart, buff->dataCount); - framing::ProtocolInitiation protocolInit; - if (protocolInit.decode(in)) { - // We've just got the protocol negotiation so we can cancel the timeout for that - timeoutTimerTask->cancel(); - - decoded = in.getPosition(); - QPID_LOG(debug, "RECV [" << identifier << "]: INIT(" << protocolInit << ")"); - try { - codec = factory->create(protocolInit.getVersion(), *this, identifier, getSecuritySettings(aio)); - if (!codec) { - //TODO: may still want to revise this... - //send valid version header & close connection. - write(framing::ProtocolInitiation(framing::highestProtocolVersion)); - readError = true; - aio->queueWriteClose(); - } - } catch (const std::exception& e) { - QPID_LOG(error, e.what()); - readError = true; - aio->queueWriteClose(); - } - } - } - // TODO: unreading needs to go away, and when we can cope - // with multiple sub-buffers in the general buffer scheme, it will - if (decoded != size_t(buff->dataCount)) { - // Adjust buffer for used bytes and then "unread them" - buff->dataStart += decoded; - buff->dataCount -= decoded; - aio->unread(buff); - } else { - // Give whole buffer back to aio subsystem - aio->queueReadBuffer(buff); - } -} - -void SslHandler::eof(SslIO&) { - QPID_LOG(debug, "DISCONNECTED [" << identifier << "]"); - if (codec) codec->closed(); - aio->queueWriteClose(); -} - -void SslHandler::closedSocket(SslIO&, const SslSocket& s) { - // If we closed with data still to send log a warning - if (!aio->writeQueueEmpty()) { - QPID_LOG(warning, "CLOSING [" << identifier << "] unsent data (probably due to client disconnect)"); - } - delete &s; - aio->queueForDeletion(); - delete this; -} - -void SslHandler::disconnect(SslIO& a) { - // treat the same as eof - eof(a); -} - -// Notifications -void SslHandler::nobuffs(SslIO&) { -} - -void SslHandler::idle(SslIO&){ - if (isClient && codec == 0) { - codec = factory->create(*this, identifier, getSecuritySettings(aio)); - write(framing::ProtocolInitiation(codec->getVersion())); - // We've just sent the protocol negotiation so we can cancel the timeout for that - // This is not ideal, because we've not received anything yet, but heartbeats will - // be active soon - timeoutTimerTask->cancel(); - return; - } - if (codec == 0) return; - if (!codec->canEncode()) { - return; - } - SslIO::BufferBase* buff = aio->getQueuedBuffer(); - if (buff) { - size_t encoded=codec->encode(buff->bytes, buff->byteCount); - buff->dataCount = encoded; - aio->queueWrite(buff); - } - if (codec->isClosed()) - aio->queueWriteClose(); -} - -SecuritySettings SslHandler::getSecuritySettings(SslIO* aio) -{ - SecuritySettings settings = aio->getSecuritySettings(); - settings.nodict = nodict; - return settings; -} - - -}}} // namespace qpid::sys::ssl diff --git a/cpp/src/qpid/sys/ssl/SslHandler.h b/cpp/src/qpid/sys/ssl/SslHandler.h deleted file mode 100644 index 14814b0281..0000000000 --- a/cpp/src/qpid/sys/ssl/SslHandler.h +++ /dev/null @@ -1,85 +0,0 @@ -#ifndef QPID_SYS_SSL_SSLHANDLER_H -#define QPID_SYS_SSL_SSLHANDLER_H - -/* - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#include "qpid/sys/ConnectionCodec.h" -#include "qpid/sys/OutputControl.h" - -#include <boost/intrusive_ptr.hpp> - -namespace qpid { - -namespace framing { - class ProtocolInitiation; -} - -namespace sys { - -class Timer; -class TimerTask; - -namespace ssl { - -class SslIO; -struct SslIOBufferBase; -class SslSocket; - -class SslHandler : public OutputControl { - std::string identifier; - SslIO* aio; - ConnectionCodec::Factory* factory; - ConnectionCodec* codec; - bool readError; - bool isClient; - bool nodict; - boost::intrusive_ptr<sys::TimerTask> timeoutTimerTask; - - void write(const framing::ProtocolInitiation&); - qpid::sys::SecuritySettings getSecuritySettings(SslIO* aio); - - public: - SslHandler(std::string id, ConnectionCodec::Factory* f, bool nodict); - ~SslHandler(); - void init(SslIO* a, Timer& timer, uint32_t maxTime); - - void setClient() { isClient = true; } - - // Output side - void abort(); - void activateOutput(); - void giveReadCredit(int32_t); - - // Input side - void readbuff(SslIO& aio, SslIOBufferBase* buff); - void eof(SslIO& aio); - void disconnect(SslIO& aio); - - // Notifications - void nobuffs(SslIO& aio); - void idle(SslIO& aio); - void closedSocket(SslIO& aio, const SslSocket& s); -}; - -}}} // namespace qpid::sys::ssl - -#endif /*!QPID_SYS_SSL_SSLHANDLER_H*/ diff --git a/cpp/src/qpid/sys/ssl/SslIo.cpp b/cpp/src/qpid/sys/ssl/SslIo.cpp deleted file mode 100644 index bbfb703170..0000000000 --- a/cpp/src/qpid/sys/ssl/SslIo.cpp +++ /dev/null @@ -1,470 +0,0 @@ -/* - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#include "qpid/sys/ssl/SslIo.h" -#include "qpid/sys/ssl/SslSocket.h" -#include "qpid/sys/ssl/check.h" - -#include "qpid/sys/Time.h" -#include "qpid/sys/posix/check.h" -#include "qpid/log/Statement.h" - -// TODO The basic algorithm here is not really POSIX specific and with a bit more abstraction -// could (should) be promoted to be platform portable -#include <unistd.h> -#include <sys/socket.h> -#include <signal.h> -#include <errno.h> -#include <string.h> - -#include <boost/bind.hpp> - -namespace qpid { -namespace sys { -namespace ssl { - -namespace { - -/* - * Make *process* not generate SIGPIPE when writing to closed - * pipe/socket (necessary as default action is to terminate process) - */ -void ignoreSigpipe() { - ::signal(SIGPIPE, SIG_IGN); -} - -/* - * We keep per thread state to avoid locking overhead. The assumption is that - * on average all the connections are serviced by all the threads so the state - * recorded in each thread is about the same. If this turns out not to be the - * case we could rebalance the info occasionally. - */ -__thread int threadReadTotal = 0; -__thread int threadReadCount = 0; -__thread int threadWriteTotal = 0; -__thread int threadWriteCount = 0; -__thread int64_t threadMaxIoTimeNs = 2 * 1000000; // start at 2ms -} - -/* - * Asynch Acceptor - */ - -template <class T> -SslAcceptorTmpl<T>::SslAcceptorTmpl(const T& s, Callback callback) : - acceptedCallback(callback), - handle(s, boost::bind(&SslAcceptorTmpl<T>::readable, this, _1), 0, 0), - socket(s) { - - s.setNonblocking(); - ignoreSigpipe(); -} - -template <class T> -SslAcceptorTmpl<T>::~SslAcceptorTmpl() -{ - handle.stopWatch(); -} - -template <class T> -void SslAcceptorTmpl<T>::start(Poller::shared_ptr poller) { - handle.startWatch(poller); -} - -/* - * We keep on accepting as long as there is something to accept - */ -template <class T> -void SslAcceptorTmpl<T>::readable(DispatchHandle& h) { - Socket* s; - do { - errno = 0; - // TODO: Currently we ignore the peers address, perhaps we should - // log it or use it for connection acceptance. - try { - s = socket.accept(); - if (s) { - acceptedCallback(*s); - } else { - break; - } - } catch (const std::exception& e) { - QPID_LOG(error, "Could not accept socket: " << e.what()); - } - } while (true); - - h.rewatch(); -} - -// Explicitly instantiate the templates we need -template class SslAcceptorTmpl<SslSocket>; -template class SslAcceptorTmpl<SslMuxSocket>; - -/* - * Asynch Connector - */ - -SslConnector::SslConnector(const SslSocket& s, - Poller::shared_ptr poller, - std::string hostname, - std::string port, - ConnectedCallback connCb, - FailedCallback failCb) : - DispatchHandle(s, - 0, - boost::bind(&SslConnector::connComplete, this, _1), - boost::bind(&SslConnector::connComplete, this, _1)), - connCallback(connCb), - failCallback(failCb), - socket(s) -{ - //TODO: would be better for connect to be performed on a - //non-blocking socket, but that doesn't work at present so connect - //blocks until complete - try { - socket.connect(hostname, port); - socket.setNonblocking(); - startWatch(poller); - } catch(std::exception& e) { - failure(-1, std::string(e.what())); - } -} - -void SslConnector::connComplete(DispatchHandle& h) -{ - int errCode = socket.getError(); - - h.stopWatch(); - if (errCode == 0) { - connCallback(socket); - DispatchHandle::doDelete(); - } else { - // TODO: This need to be fixed as strerror isn't thread safe - failure(errCode, std::string(::strerror(errCode))); - } -} - -void SslConnector::failure(int errCode, std::string message) -{ - if (failCallback) - failCallback(errCode, message); - - socket.close(); - delete &socket; - - DispatchHandle::doDelete(); -} - -/* - * Asynch reader/writer - */ -SslIO::SslIO(const SslSocket& s, - ReadCallback rCb, EofCallback eofCb, DisconnectCallback disCb, - ClosedCallback cCb, BuffersEmptyCallback eCb, IdleCallback iCb) : - - DispatchHandle(s, - boost::bind(&SslIO::readable, this, _1), - boost::bind(&SslIO::writeable, this, _1), - boost::bind(&SslIO::disconnected, this, _1)), - readCallback(rCb), - eofCallback(eofCb), - disCallback(disCb), - closedCallback(cCb), - emptyCallback(eCb), - idleCallback(iCb), - socket(s), - queuedClose(false), - writePending(false) { - - s.setNonblocking(); -} - -SslIO::~SslIO() { -} - -void SslIO::queueForDeletion() { - DispatchHandle::doDelete(); -} - -void SslIO::start(Poller::shared_ptr poller) { - DispatchHandle::startWatch(poller); -} - -void SslIO::createBuffers(uint32_t size) { - // Allocate all the buffer memory at once - bufferMemory.reset(new char[size*BufferCount]); - - // Create the Buffer structs in a vector - // And push into the buffer queue - buffers.reserve(BufferCount); - for (uint32_t i = 0; i < BufferCount; i++) { - buffers.push_back(BufferBase(&bufferMemory[i*size], size)); - queueReadBuffer(&buffers[i]); - } -} - -void SslIO::queueReadBuffer(BufferBase* buff) { - assert(buff); - buff->dataStart = 0; - buff->dataCount = 0; - bufferQueue.push_back(buff); - DispatchHandle::rewatchRead(); -} - -void SslIO::unread(BufferBase* buff) { - assert(buff); - if (buff->dataStart != 0) { - memmove(buff->bytes, buff->bytes+buff->dataStart, buff->dataCount); - buff->dataStart = 0; - } - bufferQueue.push_front(buff); - DispatchHandle::rewatchRead(); -} - -void SslIO::queueWrite(BufferBase* buff) { - assert(buff); - // If we've already closed the socket then throw the write away - if (queuedClose) { - bufferQueue.push_front(buff); - return; - } else { - writeQueue.push_front(buff); - } - writePending = false; - DispatchHandle::rewatchWrite(); -} - -void SslIO::notifyPendingWrite() { - writePending = true; - DispatchHandle::rewatchWrite(); -} - -void SslIO::queueWriteClose() { - queuedClose = true; - DispatchHandle::rewatchWrite(); -} - -void SslIO::requestCallback(RequestCallback callback) { - // TODO creating a function object every time isn't all that - // efficient - if this becomes heavily used do something better (what?) - assert(callback); - DispatchHandle::call(boost::bind(&SslIO::requestedCall, this, callback)); -} - -void SslIO::requestedCall(RequestCallback callback) { - assert(callback); - callback(*this); -} - -/** Return a queued buffer if there are enough - * to spare - */ -SslIO::BufferBase* SslIO::getQueuedBuffer() { - // Always keep at least one buffer (it might have data that was "unread" in it) - if (bufferQueue.size()<=1) - return 0; - BufferBase* buff = bufferQueue.back(); - assert(buff); - buff->dataStart = 0; - buff->dataCount = 0; - bufferQueue.pop_back(); - return buff; -} - -/* - * We keep on reading as long as we have something to read and a buffer to put - * it in - */ -void SslIO::readable(DispatchHandle& h) { - AbsTime readStartTime = AbsTime::now(); - do { - // (Try to) get a buffer - if (!bufferQueue.empty()) { - // Read into buffer - BufferBase* buff = bufferQueue.front(); - assert(buff); - bufferQueue.pop_front(); - errno = 0; - int readCount = buff->byteCount-buff->dataCount; - int rc = socket.read(buff->bytes + buff->dataCount, readCount); - if (rc > 0) { - buff->dataCount += rc; - threadReadTotal += rc; - - readCallback(*this, buff); - if (rc != readCount) { - // If we didn't fill the read buffer then time to stop reading - break; - } - - // Stop reading if we've overrun our timeslot - if (Duration(readStartTime, AbsTime::now()) > threadMaxIoTimeNs) { - break; - } - - } else { - // Put buffer back (at front so it doesn't interfere with unread buffers) - bufferQueue.push_front(buff); - assert(buff); - - // Eof or other side has gone away - if (rc == 0 || errno == ECONNRESET) { - eofCallback(*this); - h.unwatchRead(); - break; - } else if (errno == EAGAIN) { - // We have just put a buffer back so we know - // we can carry on watching for reads - break; - } else { - // Report error then just treat as a socket disconnect - QPID_LOG(error, "Error reading socket: " << getErrorString(PR_GetError())); - eofCallback(*this); - h.unwatchRead(); - break; - } - } - } else { - // Something to read but no buffer - if (emptyCallback) { - emptyCallback(*this); - } - // If we still have no buffers we can't do anything more - if (bufferQueue.empty()) { - h.unwatchRead(); - break; - } - - } - } while (true); - - ++threadReadCount; - return; -} - -/* - * We carry on writing whilst we have data to write and we can write - */ -void SslIO::writeable(DispatchHandle& h) { - AbsTime writeStartTime = AbsTime::now(); - do { - // See if we've got something to write - if (!writeQueue.empty()) { - // Write buffer - BufferBase* buff = writeQueue.back(); - writeQueue.pop_back(); - errno = 0; - assert(buff->dataStart+buff->dataCount <= buff->byteCount); - int rc = socket.write(buff->bytes+buff->dataStart, buff->dataCount); - if (rc >= 0) { - threadWriteTotal += rc; - - // If we didn't write full buffer put rest back - if (rc != buff->dataCount) { - buff->dataStart += rc; - buff->dataCount -= rc; - writeQueue.push_back(buff); - break; - } - - // Recycle the buffer - queueReadBuffer(buff); - - // Stop writing if we've overrun our timeslot - if (Duration(writeStartTime, AbsTime::now()) > threadMaxIoTimeNs) { - break; - } - } else { - // Put buffer back - writeQueue.push_back(buff); - if (errno == ECONNRESET || errno == EPIPE) { - // Just stop watching for write here - we'll get a - // disconnect callback soon enough - h.unwatchWrite(); - break; - } else if (errno == EAGAIN) { - // We have just put a buffer back so we know - // we can carry on watching for writes - break; - } else { - QPID_LOG(error, "Error writing to socket: " << getErrorString(PR_GetError())); - h.unwatchWrite(); - break; - } - } - } else { - // If we're waiting to close the socket then can do it now as there is nothing to write - if (queuedClose) { - close(h); - break; - } - // Fd is writable, but nothing to write - if (idleCallback) { - writePending = false; - idleCallback(*this); - } - // If we still have no buffers to write we can't do anything more - if (writeQueue.empty() && !writePending && !queuedClose) { - h.unwatchWrite(); - // The following handles the case where writePending is - // set to true after the test above; in this case its - // possible that the unwatchWrite overwrites the - // desired rewatchWrite so we correct that here - if (writePending) - h.rewatchWrite(); - break; - } - } - } while (true); - - ++threadWriteCount; - return; -} - -void SslIO::disconnected(DispatchHandle& h) { - // If we've already queued close do it instead of disconnected callback - if (queuedClose) { - close(h); - } else if (disCallback) { - disCallback(*this); - h.unwatch(); - } -} - -/* - * Close the socket and callback to say we've done it - */ -void SslIO::close(DispatchHandle& h) { - h.stopWatch(); - socket.close(); - if (closedCallback) { - closedCallback(*this, socket); - } -} - -SecuritySettings SslIO::getSecuritySettings() { - SecuritySettings settings; - settings.ssf = socket.getKeyLen(); - settings.authid = socket.getClientAuthId(); - return settings; -} - -}}} diff --git a/cpp/src/qpid/sys/ssl/SslIo.h b/cpp/src/qpid/sys/ssl/SslIo.h deleted file mode 100644 index f3112bfa65..0000000000 --- a/cpp/src/qpid/sys/ssl/SslIo.h +++ /dev/null @@ -1,193 +0,0 @@ -#ifndef _sys_ssl_SslIO -#define _sys_ssl_SslIO -/* - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - */ - -#include "qpid/sys/DispatchHandle.h" -#include "qpid/sys/SecuritySettings.h" - -#include <boost/function.hpp> -#include <boost/shared_array.hpp> -#include <deque> - -namespace qpid { -namespace sys { - -class Socket; - -namespace ssl { - -class SslSocket; - -/* - * Asynchronous ssl acceptor: accepts connections then does a callback - * with the accepted fd - */ -template <class T> -class SslAcceptorTmpl { -public: - typedef boost::function1<void, const Socket&> Callback; - -private: - Callback acceptedCallback; - qpid::sys::DispatchHandle handle; - const T& socket; - -public: - SslAcceptorTmpl(const T& s, Callback callback); - ~SslAcceptorTmpl(); - void start(qpid::sys::Poller::shared_ptr poller); - -private: - void readable(qpid::sys::DispatchHandle& handle); -}; - -/* - * Asynchronous ssl connector: starts the process of initiating a - * connection and invokes a callback when completed or failed. - */ -class SslConnector : private qpid::sys::DispatchHandle { -public: - typedef boost::function1<void, const SslSocket&> ConnectedCallback; - typedef boost::function2<void, int, std::string> FailedCallback; - -private: - ConnectedCallback connCallback; - FailedCallback failCallback; - const SslSocket& socket; - -public: - SslConnector(const SslSocket& socket, - Poller::shared_ptr poller, - std::string hostname, - std::string port, - ConnectedCallback connCb, - FailedCallback failCb = 0); - -private: - void connComplete(DispatchHandle& handle); - void failure(int, std::string); -}; - -struct SslIOBufferBase { - char* bytes; - int32_t byteCount; - int32_t dataStart; - int32_t dataCount; - - SslIOBufferBase(char* const b, const int32_t s) : - bytes(b), - byteCount(s), - dataStart(0), - dataCount(0) - {} - - virtual ~SslIOBufferBase() - {} -}; - -/* - * Asychronous reader/writer: - * Reader accepts buffers to read into; reads into the provided buffers - * and then does a callback with the buffer and amount read. Optionally it can callback - * when there is something to read but no buffer to read it into. - * - * Writer accepts a buffer and queues it for writing; can also be given - * a callback for when writing is "idle" (ie fd is writable, but nothing to write) - * - * The class is implemented in terms of DispatchHandle to allow it to be deleted by deleting - * the contained DispatchHandle - */ -class SslIO : private qpid::sys::DispatchHandle { -public: - typedef SslIOBufferBase BufferBase; - - typedef boost::function2<void, SslIO&, BufferBase*> ReadCallback; - typedef boost::function1<void, SslIO&> EofCallback; - typedef boost::function1<void, SslIO&> DisconnectCallback; - typedef boost::function2<void, SslIO&, const SslSocket&> ClosedCallback; - typedef boost::function1<void, SslIO&> BuffersEmptyCallback; - typedef boost::function1<void, SslIO&> IdleCallback; - typedef boost::function1<void, SslIO&> RequestCallback; - - SslIO(const SslSocket& s, - ReadCallback rCb, EofCallback eofCb, DisconnectCallback disCb, - ClosedCallback cCb = 0, BuffersEmptyCallback eCb = 0, IdleCallback iCb = 0); -private: - ReadCallback readCallback; - EofCallback eofCallback; - DisconnectCallback disCallback; - ClosedCallback closedCallback; - BuffersEmptyCallback emptyCallback; - IdleCallback idleCallback; - const SslSocket& socket; - std::deque<BufferBase*> bufferQueue; - std::deque<BufferBase*> writeQueue; - std::vector<BufferBase> buffers; - boost::shared_array<char> bufferMemory; - bool queuedClose; - /** - * This flag is used to detect and handle concurrency between - * calls to notifyPendingWrite() (which can be made from any thread) and - * the execution of the writeable() method (which is always on the - * thread processing this handle. - */ - volatile bool writePending; - -public: - /* - * Size of IO buffers - this is the maximum possible frame size + 1 - */ - const static uint32_t MaxBufferSize = 65536; - - /* - * Number of IO buffers allocated - I think the code can only use 2 - - * 1 for reading and 1 for writing, allocate 4 for safety - */ - const static uint32_t BufferCount = 4; - - void queueForDeletion(); - - void start(qpid::sys::Poller::shared_ptr poller); - void createBuffers(uint32_t size = MaxBufferSize); - void queueReadBuffer(BufferBase* buff); - void unread(BufferBase* buff); - void queueWrite(BufferBase* buff); - void notifyPendingWrite(); - void queueWriteClose(); - bool writeQueueEmpty() { return writeQueue.empty(); } - void requestCallback(RequestCallback); - BufferBase* getQueuedBuffer(); - - qpid::sys::SecuritySettings getSecuritySettings(); - -private: - ~SslIO(); - void readable(qpid::sys::DispatchHandle& handle); - void writeable(qpid::sys::DispatchHandle& handle); - void disconnected(qpid::sys::DispatchHandle& handle); - void requestedCall(RequestCallback); - void close(qpid::sys::DispatchHandle& handle); -}; - -}}} - -#endif // _sys_ssl_SslIO diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp index 30234bb686..a328e49c13 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.cpp +++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp @@ -20,6 +20,7 @@ */ #include "qpid/sys/ssl/SslSocket.h" +#include "qpid/sys/SocketAddress.h" #include "qpid/sys/ssl/check.h" #include "qpid/sys/ssl/util.h" #include "qpid/Exception.h" @@ -52,28 +53,6 @@ namespace sys { namespace ssl { namespace { -std::string getService(int fd, bool local) -{ - ::sockaddr_storage name; // big enough for any socket address - ::socklen_t namelen = sizeof(name); - - int result = -1; - if (local) { - result = ::getsockname(fd, (::sockaddr*)&name, &namelen); - } else { - result = ::getpeername(fd, (::sockaddr*)&name, &namelen); - } - - QPID_POSIX_CHECK(result); - - char servName[NI_MAXSERV]; - if (int rc=::getnameinfo((::sockaddr*)&name, namelen, 0, 0, - servName, sizeof(servName), - NI_NUMERICHOST | NI_NUMERICSERV) != 0) - throw QPID_POSIX_ERROR(rc); - return servName; -} - const std::string DOMAIN_SEPARATOR("@"); const std::string DC_SEPARATOR("."); const std::string DC("DC"); @@ -101,14 +80,18 @@ std::string getDomainFromSubject(std::string subject) } return domain; } - } -SslSocket::SslSocket() : socket(0), prototype(0) +SslSocket::SslSocket(const std::string& certName, bool clientAuth) : + nssSocket(0), certname(certName), prototype(0) { - impl->fd = ::socket (PF_INET, SOCK_STREAM, 0); - if (impl->fd < 0) throw QPID_POSIX_ERROR(errno); - socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd)); + //configure prototype socket: + prototype = SSL_ImportFD(0, PR_NewTCPSocket()); + + if (clientAuth) { + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); + } } /** @@ -116,25 +99,44 @@ SslSocket::SslSocket() : socket(0), prototype(0) * returned from accept. Because we use posix accept rather than * PR_Accept, we have to reset the handshake. */ -SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0) +SslSocket::SslSocket(int fd, PRFileDesc* model) : BSDSocket(fd), nssSocket(0), prototype(0) { - socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); - NSS_CHECK(SSL_ResetHandshake(socket, true)); + nssSocket = SSL_ImportFD(model, PR_ImportTCPSocket(fd)); + NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_TRUE)); } void SslSocket::setNonblocking() const { + if (!nssSocket) { + BSDSocket::setNonblocking(); + return; + } PRSocketOptionData option; option.option = PR_SockOpt_Nonblocking; option.value.non_blocking = true; - PR_SetSocketOption(socket, &option); + PR_SetSocketOption(nssSocket, &option); } -void SslSocket::connect(const std::string& host, const std::string& port) const +void SslSocket::setTcpNoDelay() const { - std::stringstream namestream; - namestream << host << ":" << port; - connectname = namestream.str(); + if (!nssSocket) { + BSDSocket::setTcpNoDelay(); + return; + } + PRSocketOptionData option; + option.option = PR_SockOpt_NoDelay; + option.value.no_delay = true; + PR_SetSocketOption(nssSocket, &option); +} + +void SslSocket::connect(const SocketAddress& addr) const +{ + BSDSocket::connect(addr); +} + +void SslSocket::finishConnect(const SocketAddress& addr) const +{ + nssSocket = SSL_ImportFD(0, PR_ImportTCPSocket(fd)); void* arg; // Use the connection's cert-name if it has one; else use global cert-name @@ -145,75 +147,48 @@ void SslSocket::connect(const std::string& host, const std::string& port) const } else { arg = const_cast<char*>(SslOptions::global.certName.c_str()); } - NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg)); - NSS_CHECK(SSL_SetURL(socket, host.data())); - - char hostBuffer[PR_NETDB_BUF_SIZE]; - PRHostEnt hostEntry; - PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry)); - PRNetAddr address; - int value = PR_EnumerateHostEnt(0, &hostEntry, boost::lexical_cast<PRUint16>(port), &address); - if (value < 0) { - throw Exception(QPID_MSG("Error getting address for host: " << ErrorString())); - } else if (value == 0) { - throw Exception(QPID_MSG("Could not resolve address for host.")); - } - PR_CHECK(PR_Connect(socket, &address, PR_INTERVAL_NO_TIMEOUT)); - NSS_CHECK(SSL_ForceHandshake(socket)); + NSS_CHECK(SSL_GetClientAuthDataHook(nssSocket, NSS_GetClientAuthData, arg)); + + url = addr.getHost(); + NSS_CHECK(SSL_SetURL(nssSocket, url.data())); + + NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_FALSE)); + NSS_CHECK(SSL_ForceHandshake(nssSocket)); } void SslSocket::close() const { - if (impl->fd > 0) { - PR_Close(socket); - impl->fd = -1; + if (!nssSocket) { + BSDSocket::close(); + return; + } + if (fd > 0) { + PR_Close(nssSocket); + fd = -1; } } -int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, bool clientAuth) const +int SslSocket::listen(const SocketAddress& sa, int backlog) const { - //configure prototype socket: - prototype = SSL_ImportFD(0, PR_NewTCPSocket()); - if (clientAuth) { - NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); - NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); - } - //get certificate and key (is this the correct way?) - CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(certName.c_str()), 0); - if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << certName << "'")); + std::string cName( (certname == "") ? "localhost.localdomain" : certname); + CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(cName.c_str()), 0); + if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << cName << "'")); SECKEYPrivateKey *key = PK11_FindKeyByAnyCert(cert, 0); if (!key) throw Exception(QPID_MSG("Failed to retrieve private key from certificate")); NSS_CHECK(SSL_ConfigSecureServer(prototype, cert, key, NSS_FindCertKEAType(cert))); SECKEY_DestroyPrivateKey(key); CERT_DestroyCertificate(cert); - //bind and listen - const int& socket = impl->fd; - int yes=1; - QPID_POSIX_CHECK(setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes))); - struct sockaddr_in name; - name.sin_family = AF_INET; - name.sin_port = htons(port); - name.sin_addr.s_addr = 0; - if (::bind(socket, (struct sockaddr*)&name, sizeof(name)) < 0) - throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno))); - if (::listen(socket, backlog) < 0) - throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno))); - - socklen_t namelen = sizeof(name); - if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0) - throw QPID_POSIX_ERROR(errno); - - return ntohs(name.sin_port); + return BSDSocket::listen(sa, backlog); } -SslSocket* SslSocket::accept() const +Socket* SslSocket::accept() const { QPID_LOG(trace, "Accepting SSL connection."); - int afd = ::accept(impl->fd, 0, 0); + int afd = ::accept(fd, 0, 0); if ( afd >= 0) { - return new SslSocket(new IOHandlePrivate(afd), prototype); + return new SslSocket(afd, prototype); } else if (errno == EAGAIN) { return 0; } else { @@ -297,17 +272,22 @@ static bool isSslStream(int afd) { return isSSL2Handshake || isSSL3Handshake; } +SslMuxSocket::SslMuxSocket(const std::string& certName, bool clientAuth) : + SslSocket(certName, clientAuth) +{ +} + Socket* SslMuxSocket::accept() const { - int afd = ::accept(impl->fd, 0, 0); + int afd = ::accept(fd, 0, 0); if (afd >= 0) { QPID_LOG(trace, "Accepting connection with optional SSL wrapper."); if (isSslStream(afd)) { QPID_LOG(trace, "Accepted SSL connection."); - return new SslSocket(new IOHandlePrivate(afd), prototype); + return new SslSocket(afd, prototype); } else { QPID_LOG(trace, "Accepted Plaintext connection."); - return new Socket(new IOHandlePrivate(afd)); + return new BSDSocket(afd); } } else if (errno == EAGAIN) { return 0; @@ -318,32 +298,12 @@ Socket* SslMuxSocket::accept() const int SslSocket::read(void *buf, size_t count) const { - return PR_Read(socket, buf, count); + return PR_Read(nssSocket, buf, count); } int SslSocket::write(const void *buf, size_t count) const { - return PR_Write(socket, buf, count); -} - -uint16_t SslSocket::getLocalPort() const -{ - return std::atoi(getService(impl->fd, true).c_str()); -} - -uint16_t SslSocket::getRemotePort() const -{ - return atoi(getService(impl->fd, true).c_str()); -} - -void SslSocket::setTcpNoDelay(bool nodelay) const -{ - if (nodelay) { - PRSocketOptionData option; - option.option = PR_SockOpt_NoDelay; - option.value.no_delay = true; - PR_SetSocketOption(socket, &option); - } + return PR_Write(nssSocket, buf, count); } void SslSocket::setCertName(const std::string& name) @@ -359,7 +319,7 @@ int SslSocket::getKeyLen() const int keySize = 0; SECStatus rc; - rc = SSL_SecurityStatus( socket, + rc = SSL_SecurityStatus( nssSocket, &enabled, NULL, NULL, @@ -374,7 +334,7 @@ int SslSocket::getKeyLen() const std::string SslSocket::getClientAuthId() const { std::string authId; - CERTCertificate* cert = SSL_PeerCertificate(socket); + CERTCertificate* cert = SSL_PeerCertificate(nssSocket); if (cert) { authId = CERT_GetCommonName(&(cert->subject)); /* diff --git a/cpp/src/qpid/sys/ssl/SslSocket.h b/cpp/src/qpid/sys/ssl/SslSocket.h index eabadcbe23..fc97059cfd 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.h +++ b/cpp/src/qpid/sys/ssl/SslSocket.h @@ -23,7 +23,7 @@ */ #include "qpid/sys/IOHandle.h" -#include "qpid/sys/Socket.h" +#include "qpid/sys/posix/BSDSocket.h" #include <nspr.h> #include <string> @@ -37,55 +37,54 @@ class Duration; namespace ssl { -class SslSocket : public qpid::sys::Socket +class SslSocket : public qpid::sys::BSDSocket { public: - /** Create a socket wrapper for descriptor. */ - SslSocket(); + /** Create a socket wrapper for descriptor. + *@param certName name of certificate to use to identify the socket + */ + SslSocket(const std::string& certName = "", bool clientAuth = false); /** Set socket non blocking */ void setNonblocking() const; /** Set tcp-nodelay */ - void setTcpNoDelay(bool nodelay) const; + void setTcpNoDelay() const; /** Set SSL cert-name. Allows the cert-name to be set per * connection, overriding global cert-name settings from * NSSInit().*/ void setCertName(const std::string& certName); - void connect(const std::string& host, const std::string& port) const; + void connect(const SocketAddress&) const; + void finishConnect(const SocketAddress&) const; void close() const; /** Bind to a port and start listening. *@param port 0 means choose an available port. *@param backlog maximum number of pending connections. - *@param certName name of certificate to use to identify the server *@return The bound port. */ - int listen(uint16_t port = 0, int backlog = 10, const std::string& certName = "localhost.localdomain", bool clientAuth = false) const; + int listen(const SocketAddress&, int backlog = 10) const; /** * Accept a connection from a socket that is already listening * and has an incoming connection */ - SslSocket* accept() const; + virtual Socket* accept() const; // TODO The following are raw operations, maybe they need better wrapping? int read(void *buf, size_t count) const; int write(const void *buf, size_t count) const; - uint16_t getLocalPort() const; - uint16_t getRemotePort() const; - int getKeyLen() const; std::string getClientAuthId() const; protected: - mutable std::string connectname; - mutable PRFileDesc* socket; + mutable PRFileDesc* nssSocket; std::string certname; + mutable std::string url; /** * 'model' socket, with configuration to use when importing @@ -94,13 +93,14 @@ protected: */ mutable PRFileDesc* prototype; - SslSocket(IOHandlePrivate* ioph, PRFileDesc* model); - friend class SslMuxSocket; + SslSocket(int fd, PRFileDesc* model); + friend class SslMuxSocket; // Needed for this constructor }; class SslMuxSocket : public SslSocket { public: + SslMuxSocket(const std::string& certName = "", bool clientAuth = false); Socket* accept() const; }; diff --git a/cpp/src/qpid/sys/ssl/util.cpp b/cpp/src/qpid/sys/ssl/util.cpp index 3078e894df..de5d638b09 100644 --- a/cpp/src/qpid/sys/ssl/util.cpp +++ b/cpp/src/qpid/sys/ssl/util.cpp @@ -31,8 +31,6 @@ #include <iostream> #include <fstream> -#include <boost/filesystem/operations.hpp> -#include <boost/filesystem/path.hpp> namespace qpid { namespace sys { @@ -82,15 +80,14 @@ SslOptions SslOptions::global; char* readPasswordFromFile(PK11SlotInfo*, PRBool retry, void*) { const std::string& passwordFile = SslOptions::global.certPasswordFile; - if (retry || passwordFile.empty() || !boost::filesystem::exists(passwordFile)) { - return 0; - } else { - std::ifstream file(passwordFile.c_str()); - std::string password; - file >> password; - return PL_strdup(password.c_str()); - } -} + if (retry || passwordFile.empty()) return 0; + std::ifstream file(passwordFile.c_str()); + if (!file) return 0; + + std::string password; + file >> password; + return PL_strdup(password.c_str()); +} void initNSS(const SslOptions& options, bool server) { diff --git a/cpp/src/qpid/sys/windows/AsynchIO.cpp b/cpp/src/qpid/sys/windows/AsynchIO.cpp index 355acbe0e6..b36ee9f941 100644 --- a/cpp/src/qpid/sys/windows/AsynchIO.cpp +++ b/cpp/src/qpid/sys/windows/AsynchIO.cpp @@ -24,6 +24,8 @@ #include "qpid/sys/AsynchIO.h" #include "qpid/sys/Mutex.h" #include "qpid/sys/Socket.h" +#include "qpid/sys/windows/WinSocket.h" +#include "qpid/sys/SocketAddress.h" #include "qpid/sys/Poller.h" #include "qpid/sys/Thread.h" #include "qpid/sys/Time.h" @@ -50,8 +52,8 @@ namespace { * The function pointers for AcceptEx and ConnectEx need to be looked up * at run time. */ -const LPFN_ACCEPTEX lookUpAcceptEx(const qpid::sys::Socket& s) { - SOCKET h = toSocketHandle(s); +const LPFN_ACCEPTEX lookUpAcceptEx(const qpid::sys::IOHandle& io) { + SOCKET h = io.fd; GUID guidAcceptEx = WSAID_ACCEPTEX; DWORD dwBytes = 0; LPFN_ACCEPTEX fnAcceptEx; @@ -93,12 +95,14 @@ private: AsynchAcceptor::Callback acceptedCallback; const Socket& socket; + const SOCKET wSocket; const LPFN_ACCEPTEX fnAcceptEx; }; AsynchAcceptor::AsynchAcceptor(const Socket& s, Callback callback) : acceptedCallback(callback), socket(s), + wSocket(IOHandle(s).fd), fnAcceptEx(lookUpAcceptEx(s)) { s.setNonblocking(); @@ -121,8 +125,8 @@ void AsynchAcceptor::restart(void) { this, socket); BOOL status; - status = fnAcceptEx(toSocketHandle(socket), - toSocketHandle(*result->newSocket), + status = fnAcceptEx(wSocket, + IOHandle(*result->newSocket).fd, result->addressBuffer, 0, AsynchAcceptResult::SOCKADDRMAXLEN, @@ -133,16 +137,30 @@ void AsynchAcceptor::restart(void) { } +Socket* createSameTypeSocket(const Socket& sock) { + SOCKET socket = IOHandle(sock).fd; + // Socket currently has no actual socket attached + if (socket == INVALID_SOCKET) + return new WinSocket; + + ::sockaddr_storage sa; + ::socklen_t salen = sizeof(sa); + QPID_WINSOCK_CHECK(::getsockname(socket, (::sockaddr*)&sa, &salen)); + SOCKET s = ::socket(sa.ss_family, SOCK_STREAM, 0); // Currently only work with SOCK_STREAM + if (s == INVALID_SOCKET) throw QPID_WINDOWS_ERROR(WSAGetLastError()); + return new WinSocket(s); +} + AsynchAcceptResult::AsynchAcceptResult(AsynchAcceptor::Callback cb, AsynchAcceptor *acceptor, - const Socket& listener) + const Socket& lsocket) : callback(cb), acceptor(acceptor), - listener(toSocketHandle(listener)), - newSocket(listener.createSameTypeSocket()) { + listener(IOHandle(lsocket).fd), + newSocket(createSameTypeSocket(lsocket)) { } void AsynchAcceptResult::success(size_t /*bytesTransferred*/) { - ::setsockopt (toSocketHandle(*newSocket), + ::setsockopt (IOHandle(*newSocket).fd, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, (char*)&listener, @@ -180,6 +198,7 @@ public: ConnectedCallback connCb, FailedCallback failCb = 0); void start(Poller::shared_ptr poller); + void requestCallback(RequestCallback rCb); }; AsynchConnector::AsynchConnector(const Socket& sock, @@ -195,7 +214,7 @@ AsynchConnector::AsynchConnector(const Socket& sock, void AsynchConnector::start(Poller::shared_ptr) { try { - socket.connect(hostname, port); + socket.connect(SocketAddress(hostname, port)); socket.setNonblocking(); connCallback(socket); } catch(std::exception& e) { @@ -205,6 +224,13 @@ void AsynchConnector::start(Poller::shared_ptr) } } +// This can never be called in the current windows code as connect +// is blocking and requestCallback only makes sense if connect is +// non-blocking with the results returned via a poller callback. +void AsynchConnector::requestCallback(RequestCallback rCb) +{ +} + } // namespace windows AsynchAcceptor* AsynchAcceptor::create(const Socket& s, @@ -260,8 +286,6 @@ public: virtual void notifyPendingWrite(); virtual void queueWriteClose(); virtual bool writeQueueEmpty(); - virtual void startReading(); - virtual void stopReading(); virtual void requestCallback(RequestCallback); /** @@ -272,6 +296,8 @@ public: */ virtual BufferBase* getQueuedBuffer(); + virtual SecuritySettings getSecuritySettings(void); + private: ReadCallback readCallback; EofCallback eofCallback; @@ -319,6 +345,12 @@ private: void close(void); /** + * startReading initiates reading, readComplete() is + * called when the read completes. + */ + void startReading(); + + /** * readComplete is called when a read request is complete. * * @param result Results of the operation. @@ -362,7 +394,7 @@ class CallbackHandle : public IOHandle { public: CallbackHandle(AsynchIoResult::Completer completeCb, AsynchIO::RequestCallback reqCb = 0) : - IOHandle(new IOHandlePrivate (INVALID_SOCKET, completeCb, reqCb)) + IOHandle(INVALID_SOCKET, completeCb, reqCb) {} }; @@ -515,7 +547,7 @@ void AsynchIO::startReading() { DWORD bytesReceived = 0, flags = 0; InterlockedIncrement(&opsInProgress); readInProgress = true; - int status = WSARecv(toSocketHandle(socket), + int status = WSARecv(IOHandle(socket).fd, const_cast<LPWSABUF>(result->getWSABUF()), 1, &bytesReceived, &flags, @@ -537,15 +569,6 @@ void AsynchIO::startReading() { return; } -// stopReading was added to prevent a race condition with read-credit on Linux. -// It may or may not be required on windows. -// -// AsynchIOHandler::readbuff() calls stopReading() inside the same -// critical section that protects startReading() in -// AsynchIOHandler::giveReadCredit(). -// -void AsynchIO::stopReading() {} - // Queue the specified callback for invocation from an I/O thread. void AsynchIO::requestCallback(RequestCallback callback) { // This method is generally called from a processing thread; transfer @@ -613,7 +636,7 @@ void AsynchIO::startWrite(AsynchIO::BufferBase* buff) { buff, buff->dataCount); DWORD bytesSent = 0; - int status = WSASend(toSocketHandle(socket), + int status = WSASend(IOHandle(socket).fd, const_cast<LPWSABUF>(result->getWSABUF()), 1, &bytesSent, 0, @@ -639,6 +662,13 @@ void AsynchIO::close(void) { notifyClosed(); } +SecuritySettings AsynchIO::getSecuritySettings() { + SecuritySettings settings; + settings.ssf = socket.getKeyLen(); + settings.authid = socket.getClientAuthId(); + return settings; +} + void AsynchIO::readComplete(AsynchReadResult *result) { int status = result->getStatus(); size_t bytes = result->getTransferred(); @@ -683,7 +713,8 @@ void AsynchIO::writeComplete(AsynchWriteResult *result) { else { // An error... if it's a connection close, ignore it - it will be // noticed and handled on a read completion any moment now. - // What to do with real error??? Save the Buffer? + // What to do with real error??? Save the Buffer? TBD. + queueReadBuffer(buff); // All done; back to the pool } } diff --git a/cpp/src/qpid/sys/windows/FileSysDir.cpp b/cpp/src/qpid/sys/windows/FileSysDir.cpp index 88f1637d48..e090747715 100644 --- a/cpp/src/qpid/sys/windows/FileSysDir.cpp +++ b/cpp/src/qpid/sys/windows/FileSysDir.cpp @@ -24,6 +24,9 @@ #include <sys/stat.h> #include <direct.h> #include <errno.h> +#include <windows.h> +#include <strsafe.h> + namespace qpid { namespace sys { @@ -50,4 +53,36 @@ void FileSysDir::mkdir(void) throw Exception ("Can't create directory: " + dirPath); } +void FileSysDir::forEachFile(Callback cb) const { + + WIN32_FIND_DATAA findFileData; + char szDir[MAX_PATH]; + size_t dirPathLength; + HANDLE hFind = INVALID_HANDLE_VALUE; + + // create dirPath+"\*" in szDir + StringCchLength (dirPath.c_str(), MAX_PATH, &dirPathLength); + + if (dirPathLength > (MAX_PATH - 3)) { + throw Exception ("Directory path is too long: " + dirPath); + } + + StringCchCopy(szDir, MAX_PATH, dirPath.c_str()); + StringCchCat(szDir, MAX_PATH, TEXT("\\*")); + + // Special work for first file + hFind = FindFirstFileA(szDir, &findFileData); + if (INVALID_HANDLE_VALUE == hFind) { + return; + } + + // process everything that isn't a directory + do { + if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { + std::string fileName(findFileData.cFileName); + cb(fileName); + } + } while (FindNextFile(hFind, &findFileData) != 0); +} + }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/windows/IOHandle.cpp b/cpp/src/qpid/sys/windows/IOHandle.cpp index 250737cb99..19a1c44875 100755 --- a/cpp/src/qpid/sys/windows/IOHandle.cpp +++ b/cpp/src/qpid/sys/windows/IOHandle.cpp @@ -19,24 +19,11 @@ * */ -#include "qpid/sys/IOHandle.h" #include "qpid/sys/windows/IoHandlePrivate.h" #include <windows.h> namespace qpid { namespace sys { -SOCKET toFd(const IOHandlePrivate* h) -{ - return h->fd; -} - -IOHandle::IOHandle(IOHandlePrivate* h) : - impl(h) -{} - -IOHandle::~IOHandle() { - delete impl; -} }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/windows/IoHandlePrivate.h b/cpp/src/qpid/sys/windows/IoHandlePrivate.h index 5943db5cc7..4529ad93ec 100755 --- a/cpp/src/qpid/sys/windows/IoHandlePrivate.h +++ b/cpp/src/qpid/sys/windows/IoHandlePrivate.h @@ -38,15 +38,14 @@ namespace sys { // completer from an I/O thread. If the callback mechanism is used, there // can be a RequestCallback set - this carries the callback object through // from AsynchIO::requestCallback() through to the I/O completion processing. -class IOHandlePrivate { - friend QPID_COMMON_EXTERN SOCKET toSocketHandle(const Socket& s); - static IOHandlePrivate* getImpl(const IOHandle& h); - +class IOHandle { public: - IOHandlePrivate(SOCKET f = INVALID_SOCKET, - windows::AsynchIoResult::Completer cb = 0, - AsynchIO::RequestCallback reqCallback = 0) : - fd(f), event(cb), cbRequest(reqCallback) + IOHandle(SOCKET f = INVALID_SOCKET, + windows::AsynchIoResult::Completer cb = 0, + AsynchIO::RequestCallback reqCallback = 0) : + fd(f), + event(cb), + cbRequest(reqCallback) {} SOCKET fd; @@ -54,8 +53,6 @@ public: AsynchIO::RequestCallback cbRequest; }; -QPID_COMMON_EXTERN SOCKET toSocketHandle(const Socket& s); - }} #endif /* _sys_windows_IoHandlePrivate_h */ diff --git a/cpp/src/qpid/sys/windows/IocpPoller.cpp b/cpp/src/qpid/sys/windows/IocpPoller.cpp index c81cef87b0..ecb33c5517 100755 --- a/cpp/src/qpid/sys/windows/IocpPoller.cpp +++ b/cpp/src/qpid/sys/windows/IocpPoller.cpp @@ -22,7 +22,7 @@ #include "qpid/sys/Poller.h" #include "qpid/sys/Mutex.h" #include "qpid/sys/Dispatcher.h" - +#include "qpid/sys/IOHandle.h" #include "qpid/sys/windows/AsynchIoResult.h" #include "qpid/sys/windows/IoHandlePrivate.h" #include "qpid/sys/windows/check.h" @@ -55,7 +55,7 @@ class PollerHandlePrivate { }; PollerHandle::PollerHandle(const IOHandle& h) : - impl(new PollerHandlePrivate(toSocketHandle(static_cast<const Socket&>(h)), h.impl->event, h.impl->cbRequest)) + impl(new PollerHandlePrivate(h.fd, h.event, h.cbRequest)) {} PollerHandle::~PollerHandle() { diff --git a/cpp/src/qpid/sys/windows/PollableCondition.cpp b/cpp/src/qpid/sys/windows/PollableCondition.cpp index bb637be0a6..3e2a5fb36c 100644 --- a/cpp/src/qpid/sys/windows/PollableCondition.cpp +++ b/cpp/src/qpid/sys/windows/PollableCondition.cpp @@ -52,14 +52,14 @@ private: PollableCondition& parent; boost::shared_ptr<sys::Poller> poller; LONG isSet; + LONG isDispatching; }; PollableConditionPrivate::PollableConditionPrivate(const sys::PollableCondition::Callback& cb, sys::PollableCondition& parent, const boost::shared_ptr<sys::Poller>& poller) - : IOHandle(new sys::IOHandlePrivate(INVALID_SOCKET, - boost::bind(&PollableConditionPrivate::dispatch, this, _1))), - cb(cb), parent(parent), poller(poller), isSet(0) + : IOHandle(INVALID_SOCKET, boost::bind(&PollableConditionPrivate::dispatch, this, _1)), + cb(cb), parent(parent), poller(poller), isSet(0), isDispatching(0) { } @@ -78,7 +78,12 @@ void PollableConditionPrivate::poke() void PollableConditionPrivate::dispatch(windows::AsynchIoResult *result) { delete result; // Poller::monitorHandle() allocates this + // If isDispatching is already set, just return. Else, enter. + if (::InterlockedCompareExchange(&isDispatching, 1, 0) == 1) + return; cb(parent); + LONG oops = ::InterlockedDecrement(&isDispatching); // Result must be 0 + assert(!oops); if (isSet) poke(); } diff --git a/cpp/src/qpid/sys/windows/QpidDllMain.h b/cpp/src/qpid/sys/windows/QpidDllMain.h new file mode 100644 index 0000000000..74eaf0256a --- /dev/null +++ b/cpp/src/qpid/sys/windows/QpidDllMain.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +/* + * Include this file once in each DLL that relies on SystemInfo.h: + * threadSafeShutdown(). Note that Thread.cpp has a more elaborate + * DllMain, that also provides this functionality separately. + * + * Teardown is in the reverse order of the DLL dependencies used + * during the load phase. The calls to DllMain and the static + * destructors are from the same thread, so no locking is necessary + * and there is no downside to an invocation of DllMain by multiple + * Qpid DLLs. + */ + +#ifdef _DLL + +#include <qpid/ImportExport.h> +#include <windows.h> + +namespace qpid { +namespace sys { +namespace windows { + +QPID_IMPORT bool processExiting; +QPID_IMPORT bool libraryUnloading; + +}}} // namespace qpid::sys::SystemInfo + + +BOOL APIENTRY DllMain(HMODULE hm, DWORD reason, LPVOID reserved) { + switch (reason) { + case DLL_PROCESS_ATTACH: + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + break; + + case DLL_PROCESS_DETACH: + // Remember how the process is terminating this DLL. + if (reserved != NULL) { + qpid::sys::windows::processExiting = true; + // Danger: all threading suspect, including indirect use of malloc or locks. + // Think twice before adding more functionality here. + return TRUE; + } + else { + qpid::sys::windows::libraryUnloading = true; + } + break; + } + return TRUE; +} + + +#endif diff --git a/cpp/src/qpid/sys/windows/SslAsynchIO.cpp b/cpp/src/qpid/sys/windows/SslAsynchIO.cpp index d263f00ab3..e48c799b29 100644 --- a/cpp/src/qpid/sys/windows/SslAsynchIO.cpp +++ b/cpp/src/qpid/sys/windows/SslAsynchIO.cpp @@ -209,18 +209,6 @@ bool SslAsynchIO::writeQueueEmpty() { return aio->writeQueueEmpty(); } -/* - * Initiate a read operation. AsynchIO::readComplete() will be - * called when the read is complete and data is available. - */ -void SslAsynchIO::startReading() { - aio->startReading(); -} - -void SslAsynchIO::stopReading() { - aio->stopReading(); -} - // Queue the specified callback for invocation from an I/O thread. void SslAsynchIO::requestCallback(RequestCallback callback) { aio->requestCallback(callback); @@ -241,11 +229,15 @@ AsynchIO::BufferBase* SslAsynchIO::getQueuedBuffer() { return sslBuff; } -unsigned int SslAsynchIO::getSslKeySize() { +SecuritySettings SslAsynchIO::getSecuritySettings() { SecPkgContext_KeyInfo info; memset(&info, 0, sizeof(info)); ::QueryContextAttributes(&ctxtHandle, SECPKG_ATTR_KEY_INFO, &info); - return info.KeySize; + + SecuritySettings settings; + settings.ssf = info.KeySize; + settings.authid = std::string(); + return settings; } void SslAsynchIO::negotiationDone() { diff --git a/cpp/src/qpid/sys/windows/SslAsynchIO.h b/cpp/src/qpid/sys/windows/SslAsynchIO.h index e9d9e8d629..2f6842b135 100644 --- a/cpp/src/qpid/sys/windows/SslAsynchIO.h +++ b/cpp/src/qpid/sys/windows/SslAsynchIO.h @@ -77,12 +77,9 @@ public: virtual void notifyPendingWrite(); virtual void queueWriteClose(); virtual bool writeQueueEmpty(); - virtual void startReading(); - virtual void stopReading(); virtual void requestCallback(RequestCallback); virtual BufferBase* getQueuedBuffer(); - - QPID_COMMON_EXTERN unsigned int getSslKeySize(); + virtual SecuritySettings getSecuritySettings(void); protected: CredHandle credHandle; diff --git a/cpp/src/qpid/sys/windows/SystemInfo.cpp b/cpp/src/qpid/sys/windows/SystemInfo.cpp index cef78dcc60..fb58d53b81 100755 --- a/cpp/src/qpid/sys/windows/SystemInfo.cpp +++ b/cpp/src/qpid/sys/windows/SystemInfo.cpp @@ -25,7 +25,8 @@ #include "qpid/sys/SystemInfo.h" #include "qpid/sys/IntegerTypes.h" -#include "qpid/Exception.h"
+#include "qpid/Exception.h" +#include "qpid/log/Statement.h" #include <assert.h> #include <winsock2.h> @@ -66,39 +67,10 @@ bool SystemInfo::getLocalHostname (Address &address) { static const std::string LOCALHOST("127.0.0.1"); static const std::string TCP("tcp"); -void SystemInfo::getLocalIpAddresses (uint16_t port, - std::vector<Address> &addrList) { - enum { MAX_URL_INTERFACES = 100 }; - - SOCKET s = socket (PF_INET, SOCK_STREAM, 0); - if (s != INVALID_SOCKET) { - INTERFACE_INFO interfaces[MAX_URL_INTERFACES]; - DWORD filledBytes = 0; - WSAIoctl (s, - SIO_GET_INTERFACE_LIST, - 0, - 0, - interfaces, - sizeof (interfaces), - &filledBytes, - 0, - 0); - unsigned int interfaceCount = filledBytes / sizeof (INTERFACE_INFO); - for (unsigned int i = 0; i < interfaceCount; ++i) { - if (interfaces[i].iiFlags & IFF_UP) { - std::string addr(inet_ntoa(interfaces[i].iiAddress.AddressIn.sin_addr)); - if (addr != LOCALHOST) - addrList.push_back(Address(TCP, addr, port)); - } - } - closesocket (s); - } -} - -bool SystemInfo::isLocalHost(const std::string& candidateHost) { - // FIXME aconway 2012-05-03: not implemented. - assert(0); - throw Exception("Not implemented: isLocalHost"); +// Null function which always fails to find an network interface name +bool SystemInfo::getInterfaceAddresses(const std::string&, std::vector<std::string>&) +{ + return false; } void SystemInfo::getSystemId (std::string &osName, @@ -208,4 +180,29 @@ std::string SystemInfo::getProcessName() return name; } + +#ifdef _DLL +namespace windows { +// set from one or more Qpid DLLs: i.e. in DllMain with DLL_PROCESS_DETACH +QPID_EXPORT bool processExiting = false; +QPID_EXPORT bool libraryUnloading = false; +} +#endif + +bool SystemInfo::threadSafeShutdown() +{ +#ifdef _DLL + if (!windows::processExiting && !windows::libraryUnloading) { + // called before exit() or FreeLibrary(), or by a DLL without + // a participating DllMain. + QPID_LOG(warning, "invalid query for shutdown state"); + throw qpid::Exception(QPID_MSG("Unable to determine shutdown state.")); + } + return !windows::processExiting; +#else + // Not a DLL: shutdown can only be by exit() or return from main(). + return false; +#endif +} + }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/windows/Thread.cpp b/cpp/src/qpid/sys/windows/Thread.cpp index 23b0033be4..b342c9da1d 100755 --- a/cpp/src/qpid/sys/windows/Thread.cpp +++ b/cpp/src/qpid/sys/windows/Thread.cpp @@ -27,6 +27,7 @@ #include "qpid/sys/Thread.h" #include "qpid/sys/Runnable.h" #include "qpid/sys/windows/check.h" +#include "qpid/sys/SystemInfo.h" #include <process.h> #include <windows.h> @@ -274,8 +275,17 @@ Thread Thread::current() { #ifdef _DLL +namespace qpid { +namespace sys { +namespace windows { + +extern bool processExiting; +extern bool libraryUnloading; + +}}} // namespace qpid::sys::SystemInfo + // DllMain: called possibly many times in a process lifetime if dll -// loaded and freed repeatedly . Be mindful of Windows loader lock +// loaded and freed repeatedly. Be mindful of Windows loader lock // and other DllMain restrictions. BOOL APIENTRY DllMain(HMODULE hm, DWORD reason, LPVOID reserved) { @@ -290,10 +300,12 @@ BOOL APIENTRY DllMain(HMODULE hm, DWORD reason, LPVOID reserved) { if (reserved != NULL) { // process exit(): threads are stopped arbitrarily and // possibly in an inconsistent state. Not even threadLock - // can be trusted. All static destructors have been - // called at this point and any resources this unit knows - // about will be released as part of process tear down by - // the OS. Accordingly, do nothing. + // can be trusted. All static destructors for this unit + // are pending and face the same unsafe environment. + // Any resources this unit knows about will be released as + // part of process tear down by the OS. Accordingly, skip + // any clean up tasks. + qpid::sys::windows::processExiting = true; return TRUE; } else { @@ -301,6 +313,7 @@ BOOL APIENTRY DllMain(HMODULE hm, DWORD reason, LPVOID reserved) { // encouraged to clean up to avoid leaks. Mostly we just // want any straggler threads to finish and notify // threadsDone as the last thing they do. + qpid::sys::windows::libraryUnloading = true; while (1) { { ScopedCriticalSection l(threadLock); diff --git a/cpp/src/qpid/sys/windows/Socket.cpp b/cpp/src/qpid/sys/windows/WinSocket.cpp index a4374260cc..b2d2d79c63 100644 --- a/cpp/src/qpid/sys/windows/Socket.cpp +++ b/cpp/src/qpid/sys/windows/WinSocket.cpp @@ -19,18 +19,12 @@ * */ -#include "qpid/sys/Socket.h" +#include "qpid/sys/windows/WinSocket.h" #include "qpid/sys/SocketAddress.h" #include "qpid/sys/windows/check.h" #include "qpid/sys/windows/IoHandlePrivate.h" - -// Ensure we get all of winsock2.h -#ifndef _WIN32_WINNT -#define _WIN32_WINNT 0x0501 -#endif - -#include <winsock2.h> +#include "qpid/sys/SystemInfo.h" namespace qpid { namespace sys { @@ -67,7 +61,8 @@ public: } ~WinSockSetup() { - WSACleanup(); + if (SystemInfo::threadSafeShutdown()) + WSACleanup(); } public: @@ -106,22 +101,32 @@ uint16_t getLocalPort(int fd) } } // namespace -Socket::Socket() : - IOHandle(new IOHandlePrivate), +WinSocket::WinSocket() : + handle(new IOHandle), nonblocking(false), nodelay(false) {} -Socket::Socket(IOHandlePrivate* h) : - IOHandle(h), +Socket* createSocket() +{ + return new WinSocket; +} + +WinSocket::WinSocket(SOCKET fd) : + handle(new IOHandle(fd)), nonblocking(false), nodelay(false) {} -void Socket::createSocket(const SocketAddress& sa) const +WinSocket::operator const IOHandle&() const { - SOCKET& socket = impl->fd; - if (socket != INVALID_SOCKET) Socket::close(); + return *handle; +} + +void WinSocket::createSocket(const SocketAddress& sa) const +{ + SOCKET& socket = handle->fd; + if (socket != INVALID_SOCKET) WinSocket::close(); SOCKET s = ::socket (getAddrInfo(sa).ai_family, getAddrInfo(sa).ai_socktype, @@ -139,39 +144,19 @@ void Socket::createSocket(const SocketAddress& sa) const } } -Socket* Socket::createSameTypeSocket() const { - SOCKET& socket = impl->fd; - // Socket currently has no actual socket attached - if (socket == INVALID_SOCKET) - return new Socket; - - ::sockaddr_storage sa; - ::socklen_t salen = sizeof(sa); - QPID_WINSOCK_CHECK(::getsockname(socket, (::sockaddr*)&sa, &salen)); - SOCKET s = ::socket(sa.ss_family, SOCK_STREAM, 0); // Currently only work with SOCK_STREAM - if (s == INVALID_SOCKET) throw QPID_WINDOWS_ERROR(WSAGetLastError()); - return new Socket(new IOHandlePrivate(s)); -} - -void Socket::setNonblocking() const { +void WinSocket::setNonblocking() const { u_long nonblock = 1; - QPID_WINSOCK_CHECK(ioctlsocket(impl->fd, FIONBIO, &nonblock)); -} - -void Socket::connect(const std::string& host, const std::string& port) const -{ - SocketAddress sa(host, port); - connect(sa); + QPID_WINSOCK_CHECK(ioctlsocket(handle->fd, FIONBIO, &nonblock)); } void -Socket::connect(const SocketAddress& addr) const +WinSocket::connect(const SocketAddress& addr) const { peername = addr.asString(false); createSocket(addr); - const SOCKET& socket = impl->fd; + const SOCKET& socket = handle->fd; int err; WSASetLastError(0); if ((::connect(socket, getAddrInfo(addr).ai_addr, getAddrInfo(addr).ai_addrlen) != 0) && @@ -180,44 +165,43 @@ Socket::connect(const SocketAddress& addr) const } void -Socket::close() const +WinSocket::finishConnect(const SocketAddress&) const { - SOCKET& socket = impl->fd; +} + +void +WinSocket::close() const +{ + SOCKET& socket = handle->fd; if (socket == INVALID_SOCKET) return; QPID_WINSOCK_CHECK(closesocket(socket)); socket = INVALID_SOCKET; } -int Socket::write(const void *buf, size_t count) const +int WinSocket::write(const void *buf, size_t count) const { - const SOCKET& socket = impl->fd; + const SOCKET& socket = handle->fd; int sent = ::send(socket, (const char *)buf, count, 0); if (sent == SOCKET_ERROR) return -1; return sent; } -int Socket::read(void *buf, size_t count) const +int WinSocket::read(void *buf, size_t count) const { - const SOCKET& socket = impl->fd; + const SOCKET& socket = handle->fd; int received = ::recv(socket, (char *)buf, count, 0); if (received == SOCKET_ERROR) return -1; return received; } -int Socket::listen(const std::string& host, const std::string& port, int backlog) const -{ - SocketAddress sa(host, port); - return listen(sa, backlog); -} - -int Socket::listen(const SocketAddress& addr, int backlog) const +int WinSocket::listen(const SocketAddress& addr, int backlog) const { createSocket(addr); - const SOCKET& socket = impl->fd; + const SOCKET& socket = handle->fd; BOOL yes=1; QPID_WINSOCK_CHECK(setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char *)&yes, sizeof(yes))); @@ -229,48 +213,48 @@ int Socket::listen(const SocketAddress& addr, int backlog) const return getLocalPort(socket); } -Socket* Socket::accept() const +Socket* WinSocket::accept() const { - SOCKET afd = ::accept(impl->fd, 0, 0); + SOCKET afd = ::accept(handle->fd, 0, 0); if (afd != INVALID_SOCKET) - return new Socket(new IOHandlePrivate(afd)); + return new WinSocket(afd); else if (WSAGetLastError() == EAGAIN) return 0; else throw QPID_WINDOWS_ERROR(WSAGetLastError()); } -std::string Socket::getPeerAddress() const +std::string WinSocket::getPeerAddress() const { if (peername.empty()) { - peername = getName(impl->fd, false); + peername = getName(handle->fd, false); } return peername; } -std::string Socket::getLocalAddress() const +std::string WinSocket::getLocalAddress() const { if (localname.empty()) { - localname = getName(impl->fd, true); + localname = getName(handle->fd, true); } return localname; } -int Socket::getError() const +int WinSocket::getError() const { int result; socklen_t rSize = sizeof (result); - QPID_WINSOCK_CHECK(::getsockopt(impl->fd, SOL_SOCKET, SO_ERROR, (char *)&result, &rSize)); + QPID_WINSOCK_CHECK(::getsockopt(handle->fd, SOL_SOCKET, SO_ERROR, (char *)&result, &rSize)); return result; } -void Socket::setTcpNoDelay() const +void WinSocket::setTcpNoDelay() const { - SOCKET& socket = impl->fd; + SOCKET& socket = handle->fd; nodelay = true; if (socket != INVALID_SOCKET) { int flag = 1; - int result = setsockopt(impl->fd, + int result = setsockopt(handle->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, @@ -279,14 +263,14 @@ void Socket::setTcpNoDelay() const } } -inline IOHandlePrivate* IOHandlePrivate::getImpl(const qpid::sys::IOHandle &h) +int WinSocket::getKeyLen() const { - return h.impl; + return 0; } -SOCKET toSocketHandle(const Socket& s) +std::string WinSocket::getClientAuthId() const { - return IOHandlePrivate::getImpl(s)->fd; + return std::string(); } }} // namespace qpid::sys diff --git a/cpp/src/qpid/sys/windows/WinSocket.h b/cpp/src/qpid/sys/windows/WinSocket.h new file mode 100644 index 0000000000..bee6a58e7a --- /dev/null +++ b/cpp/src/qpid/sys/windows/WinSocket.h @@ -0,0 +1,118 @@ +#ifndef QPID_SYS_WINDOWS_BSDSOCKET_H +#define QPID_SYS_WINDOWS_BSDSOCKET_H + +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "qpid/sys/Socket.h" +#include "qpid/sys/IntegerTypes.h" +#include "qpid/CommonImportExport.h" +#include <string> + +#include <boost/scoped_ptr.hpp> + +// Ensure we get all of winsock2.h +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0501 +#endif + +#include <winsock2.h> + +namespace qpid { +namespace sys { + +namespace windows { +Socket* createSameTypeSocket(const Socket&); +} + +class Duration; +class IOHandle; +class SocketAddress; + +class QPID_COMMON_CLASS_EXTERN WinSocket : public Socket +{ +public: + /** Create a socket wrapper for descriptor. */ + QPID_COMMON_EXTERN WinSocket(); + + QPID_COMMON_EXTERN operator const IOHandle&() const; + + /** Set socket non blocking */ + QPID_COMMON_EXTERN virtual void setNonblocking() const; + + QPID_COMMON_EXTERN virtual void setTcpNoDelay() const; + + QPID_COMMON_EXTERN virtual void connect(const SocketAddress&) const; + QPID_COMMON_EXTERN virtual void finishConnect(const SocketAddress&) const; + + QPID_COMMON_EXTERN virtual void close() const; + + /** Bind to a port and start listening. + *@return The bound port number + */ + QPID_COMMON_EXTERN virtual int listen(const SocketAddress&, int backlog = 10) const; + + /** + * Returns an address (host and port) for the remote end of the + * socket + */ + QPID_COMMON_EXTERN std::string getPeerAddress() const; + /** + * Returns an address (host and port) for the local end of the + * socket + */ + QPID_COMMON_EXTERN std::string getLocalAddress() const; + + /** + * Returns the error code stored in the socket. This may be used + * to determine the result of a non-blocking connect. + */ + QPID_COMMON_EXTERN int getError() const; + + /** Accept a connection from a socket that is already listening + * and has an incoming connection + */ + QPID_COMMON_EXTERN virtual Socket* accept() const; + + // TODO The following are raw operations, maybe they need better wrapping? + QPID_COMMON_EXTERN virtual int read(void *buf, size_t count) const; + QPID_COMMON_EXTERN virtual int write(const void *buf, size_t count) const; + + QPID_COMMON_EXTERN int getKeyLen() const; + QPID_COMMON_EXTERN std::string getClientAuthId() const; + +protected: + /** Create socket */ + void createSocket(const SocketAddress&) const; + + mutable boost::scoped_ptr<IOHandle> handle; + mutable std::string localname; + mutable std::string peername; + mutable bool nonblocking; + mutable bool nodelay; + + /** Construct socket with existing handle */ + friend Socket* qpid::sys::windows::createSameTypeSocket(const Socket&); + WinSocket(SOCKET fd); +}; + +}} +#endif /*!QPID_SYS_WINDOWS_BSDSOCKET_H*/ |