diff options
author | Kim van der Riet <kpvdr@apache.org> | 2013-02-28 16:14:30 +0000 |
---|---|---|
committer | Kim van der Riet <kpvdr@apache.org> | 2013-02-28 16:14:30 +0000 |
commit | 9c73ef7a5ac10acd6a50d5d52bd721fc2faa5919 (patch) | |
tree | 2a890e1df09e5b896a9b4168a7b22648f559a1f2 /cpp/src/qpid/sys/ssl/SslSocket.cpp | |
parent | 172d9b2a16cfb817bbe632d050acba7e31401cd2 (diff) | |
download | qpid-python-asyncstore.tar.gz |
Update from trunk r1375509 through r1450773asyncstore
git-svn-id: https://svn.apache.org/repos/asf/qpid/branches/asyncstore@1451244 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'cpp/src/qpid/sys/ssl/SslSocket.cpp')
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.cpp | 182 |
1 files changed, 71 insertions, 111 deletions
diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp index 30234bb686..a328e49c13 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.cpp +++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp @@ -20,6 +20,7 @@ */ #include "qpid/sys/ssl/SslSocket.h" +#include "qpid/sys/SocketAddress.h" #include "qpid/sys/ssl/check.h" #include "qpid/sys/ssl/util.h" #include "qpid/Exception.h" @@ -52,28 +53,6 @@ namespace sys { namespace ssl { namespace { -std::string getService(int fd, bool local) -{ - ::sockaddr_storage name; // big enough for any socket address - ::socklen_t namelen = sizeof(name); - - int result = -1; - if (local) { - result = ::getsockname(fd, (::sockaddr*)&name, &namelen); - } else { - result = ::getpeername(fd, (::sockaddr*)&name, &namelen); - } - - QPID_POSIX_CHECK(result); - - char servName[NI_MAXSERV]; - if (int rc=::getnameinfo((::sockaddr*)&name, namelen, 0, 0, - servName, sizeof(servName), - NI_NUMERICHOST | NI_NUMERICSERV) != 0) - throw QPID_POSIX_ERROR(rc); - return servName; -} - const std::string DOMAIN_SEPARATOR("@"); const std::string DC_SEPARATOR("."); const std::string DC("DC"); @@ -101,14 +80,18 @@ std::string getDomainFromSubject(std::string subject) } return domain; } - } -SslSocket::SslSocket() : socket(0), prototype(0) +SslSocket::SslSocket(const std::string& certName, bool clientAuth) : + nssSocket(0), certname(certName), prototype(0) { - impl->fd = ::socket (PF_INET, SOCK_STREAM, 0); - if (impl->fd < 0) throw QPID_POSIX_ERROR(errno); - socket = SSL_ImportFD(0, PR_ImportTCPSocket(impl->fd)); + //configure prototype socket: + prototype = SSL_ImportFD(0, PR_NewTCPSocket()); + + if (clientAuth) { + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); + NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); + } } /** @@ -116,25 +99,44 @@ SslSocket::SslSocket() : socket(0), prototype(0) * returned from accept. Because we use posix accept rather than * PR_Accept, we have to reset the handshake. */ -SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0) +SslSocket::SslSocket(int fd, PRFileDesc* model) : BSDSocket(fd), nssSocket(0), prototype(0) { - socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); - NSS_CHECK(SSL_ResetHandshake(socket, true)); + nssSocket = SSL_ImportFD(model, PR_ImportTCPSocket(fd)); + NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_TRUE)); } void SslSocket::setNonblocking() const { + if (!nssSocket) { + BSDSocket::setNonblocking(); + return; + } PRSocketOptionData option; option.option = PR_SockOpt_Nonblocking; option.value.non_blocking = true; - PR_SetSocketOption(socket, &option); + PR_SetSocketOption(nssSocket, &option); } -void SslSocket::connect(const std::string& host, const std::string& port) const +void SslSocket::setTcpNoDelay() const { - std::stringstream namestream; - namestream << host << ":" << port; - connectname = namestream.str(); + if (!nssSocket) { + BSDSocket::setTcpNoDelay(); + return; + } + PRSocketOptionData option; + option.option = PR_SockOpt_NoDelay; + option.value.no_delay = true; + PR_SetSocketOption(nssSocket, &option); +} + +void SslSocket::connect(const SocketAddress& addr) const +{ + BSDSocket::connect(addr); +} + +void SslSocket::finishConnect(const SocketAddress& addr) const +{ + nssSocket = SSL_ImportFD(0, PR_ImportTCPSocket(fd)); void* arg; // Use the connection's cert-name if it has one; else use global cert-name @@ -145,75 +147,48 @@ void SslSocket::connect(const std::string& host, const std::string& port) const } else { arg = const_cast<char*>(SslOptions::global.certName.c_str()); } - NSS_CHECK(SSL_GetClientAuthDataHook(socket, NSS_GetClientAuthData, arg)); - NSS_CHECK(SSL_SetURL(socket, host.data())); - - char hostBuffer[PR_NETDB_BUF_SIZE]; - PRHostEnt hostEntry; - PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry)); - PRNetAddr address; - int value = PR_EnumerateHostEnt(0, &hostEntry, boost::lexical_cast<PRUint16>(port), &address); - if (value < 0) { - throw Exception(QPID_MSG("Error getting address for host: " << ErrorString())); - } else if (value == 0) { - throw Exception(QPID_MSG("Could not resolve address for host.")); - } - PR_CHECK(PR_Connect(socket, &address, PR_INTERVAL_NO_TIMEOUT)); - NSS_CHECK(SSL_ForceHandshake(socket)); + NSS_CHECK(SSL_GetClientAuthDataHook(nssSocket, NSS_GetClientAuthData, arg)); + + url = addr.getHost(); + NSS_CHECK(SSL_SetURL(nssSocket, url.data())); + + NSS_CHECK(SSL_ResetHandshake(nssSocket, PR_FALSE)); + NSS_CHECK(SSL_ForceHandshake(nssSocket)); } void SslSocket::close() const { - if (impl->fd > 0) { - PR_Close(socket); - impl->fd = -1; + if (!nssSocket) { + BSDSocket::close(); + return; + } + if (fd > 0) { + PR_Close(nssSocket); + fd = -1; } } -int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, bool clientAuth) const +int SslSocket::listen(const SocketAddress& sa, int backlog) const { - //configure prototype socket: - prototype = SSL_ImportFD(0, PR_NewTCPSocket()); - if (clientAuth) { - NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUEST_CERTIFICATE, PR_TRUE)); - NSS_CHECK(SSL_OptionSet(prototype, SSL_REQUIRE_CERTIFICATE, PR_TRUE)); - } - //get certificate and key (is this the correct way?) - CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(certName.c_str()), 0); - if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << certName << "'")); + std::string cName( (certname == "") ? "localhost.localdomain" : certname); + CERTCertificate *cert = PK11_FindCertFromNickname(const_cast<char*>(cName.c_str()), 0); + if (!cert) throw Exception(QPID_MSG("Failed to load certificate '" << cName << "'")); SECKEYPrivateKey *key = PK11_FindKeyByAnyCert(cert, 0); if (!key) throw Exception(QPID_MSG("Failed to retrieve private key from certificate")); NSS_CHECK(SSL_ConfigSecureServer(prototype, cert, key, NSS_FindCertKEAType(cert))); SECKEY_DestroyPrivateKey(key); CERT_DestroyCertificate(cert); - //bind and listen - const int& socket = impl->fd; - int yes=1; - QPID_POSIX_CHECK(setsockopt(socket,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(yes))); - struct sockaddr_in name; - name.sin_family = AF_INET; - name.sin_port = htons(port); - name.sin_addr.s_addr = 0; - if (::bind(socket, (struct sockaddr*)&name, sizeof(name)) < 0) - throw Exception(QPID_MSG("Can't bind to port " << port << ": " << strError(errno))); - if (::listen(socket, backlog) < 0) - throw Exception(QPID_MSG("Can't listen on port " << port << ": " << strError(errno))); - - socklen_t namelen = sizeof(name); - if (::getsockname(socket, (struct sockaddr*)&name, &namelen) < 0) - throw QPID_POSIX_ERROR(errno); - - return ntohs(name.sin_port); + return BSDSocket::listen(sa, backlog); } -SslSocket* SslSocket::accept() const +Socket* SslSocket::accept() const { QPID_LOG(trace, "Accepting SSL connection."); - int afd = ::accept(impl->fd, 0, 0); + int afd = ::accept(fd, 0, 0); if ( afd >= 0) { - return new SslSocket(new IOHandlePrivate(afd), prototype); + return new SslSocket(afd, prototype); } else if (errno == EAGAIN) { return 0; } else { @@ -297,17 +272,22 @@ static bool isSslStream(int afd) { return isSSL2Handshake || isSSL3Handshake; } +SslMuxSocket::SslMuxSocket(const std::string& certName, bool clientAuth) : + SslSocket(certName, clientAuth) +{ +} + Socket* SslMuxSocket::accept() const { - int afd = ::accept(impl->fd, 0, 0); + int afd = ::accept(fd, 0, 0); if (afd >= 0) { QPID_LOG(trace, "Accepting connection with optional SSL wrapper."); if (isSslStream(afd)) { QPID_LOG(trace, "Accepted SSL connection."); - return new SslSocket(new IOHandlePrivate(afd), prototype); + return new SslSocket(afd, prototype); } else { QPID_LOG(trace, "Accepted Plaintext connection."); - return new Socket(new IOHandlePrivate(afd)); + return new BSDSocket(afd); } } else if (errno == EAGAIN) { return 0; @@ -318,32 +298,12 @@ Socket* SslMuxSocket::accept() const int SslSocket::read(void *buf, size_t count) const { - return PR_Read(socket, buf, count); + return PR_Read(nssSocket, buf, count); } int SslSocket::write(const void *buf, size_t count) const { - return PR_Write(socket, buf, count); -} - -uint16_t SslSocket::getLocalPort() const -{ - return std::atoi(getService(impl->fd, true).c_str()); -} - -uint16_t SslSocket::getRemotePort() const -{ - return atoi(getService(impl->fd, true).c_str()); -} - -void SslSocket::setTcpNoDelay(bool nodelay) const -{ - if (nodelay) { - PRSocketOptionData option; - option.option = PR_SockOpt_NoDelay; - option.value.no_delay = true; - PR_SetSocketOption(socket, &option); - } + return PR_Write(nssSocket, buf, count); } void SslSocket::setCertName(const std::string& name) @@ -359,7 +319,7 @@ int SslSocket::getKeyLen() const int keySize = 0; SECStatus rc; - rc = SSL_SecurityStatus( socket, + rc = SSL_SecurityStatus( nssSocket, &enabled, NULL, NULL, @@ -374,7 +334,7 @@ int SslSocket::getKeyLen() const std::string SslSocket::getClientAuthId() const { std::string authId; - CERTCertificate* cert = SSL_PeerCertificate(socket); + CERTCertificate* cert = SSL_PeerCertificate(nssSocket); if (cert) { authId = CERT_GetCommonName(&(cert->subject)); /* |