From 876d0b94c37f252b08c81656386100fad18a8a46 Mon Sep 17 00:00:00 2001 From: Alan Conway Date: Wed, 21 Feb 2007 19:25:45 +0000 Subject: Thread safety fixes for race conditions on incoming messages. * cpp/lib/client/MessageListener.h: const correctness. * cpp/tests/*: MessageListener const change. * cpp/lib/broker/Content.h: Removed out-of-date FIXME comments. * cpp/lib/client/ClientChannel.h/ .cpp(): - added locking for consumers map and other member access. - refactored implementations of Basic get, deliver, return: most logic now encapsulted in IncomingMessage class. - fix channel close problems. * cpp/lib/client/ClientMessage.h/.cpp: - const correctness & API convenience fixes. - getMethod/setMethod/getHeader: for new IncomingMessage * cpp/lib/client/Connection.h/.cpp: - Fixes to channel closure. * cpp/lib/client/IncomingMessage.h/.cpp: - Encapsulate *all* incoming message handling for client. - Moved handling of BasicGetOk to IncomingMessage to fix race. - Thread safety fixes. * cpp/lib/client/ResponseHandler.h/.cpp: - added getResponse for ClientChannel. * cpp/lib/common/Exception.h: - added missing throwSelf implementations. - added ShutdownException as general purpose shut-down indicator. - added EmptyException as general purpose "empty" indicator. * cpp/lib/common/sys/Condition|Monitor|Mutex.h|.cpp: - Condition variable abstraction extracted from Monitor for situations where a single lock is associated with multiple conditions. * cpp/tests/ClientChannelTest.cpp: - Test incoming message transfer, get, consume etc. git-svn-id: https://svn.apache.org/repos/asf/incubator/qpid/branches/qpid.0-9@510161 13f79535-47bb-0310-9956-ffa450edef68 --- cpp/lib/client/IncomingMessage.cpp | 152 ++++++++++++++++++++++++++++++------- 1 file changed, 124 insertions(+), 28 deletions(-) (limited to 'cpp/lib/client/IncomingMessage.cpp') diff --git a/cpp/lib/client/IncomingMessage.cpp b/cpp/lib/client/IncomingMessage.cpp index c1f6ca880f..07f94ceb64 100644 --- a/cpp/lib/client/IncomingMessage.cpp +++ b/cpp/lib/client/IncomingMessage.cpp @@ -19,58 +19,154 @@ * */ #include +#include "framing/AMQHeaderBody.h" +#include "framing/AMQContentBody.h" +#include "BasicGetOkBody.h" +#include "BasicReturnBody.h" +#include "BasicDeliverBody.h" #include #include -using namespace qpid::client; -using namespace qpid::framing; +namespace qpid { +namespace client { -IncomingMessage::IncomingMessage(BasicDeliverBody::shared_ptr intro) : delivered(intro){} -IncomingMessage::IncomingMessage(BasicReturnBody::shared_ptr intro): returned(intro){} -IncomingMessage::IncomingMessage(BasicGetOkBody::shared_ptr intro): response(intro){} +using namespace sys; +using namespace framing; -IncomingMessage::~IncomingMessage(){ +struct IncomingMessage::Guard: public Mutex::ScopedLock { + Guard(IncomingMessage* im) : Mutex::ScopedLock(im->lock) { + im->shutdownError.throwIf(); + } +}; + +IncomingMessage::IncomingMessage() { reset(); } + +void IncomingMessage::reset() { + state = &IncomingMessage::expectRequest; + endFn= &IncomingMessage::endRequest; + buildMessage = Message(); +} + +void IncomingMessage::startGet() { + Guard g(this); + if (state != &IncomingMessage::expectRequest) { + endGet(new QPID_ERROR(CLIENT_ERROR, "Message already in progress.")); + } + else { + state = &IncomingMessage::expectGetOk; + endFn = &IncomingMessage::endGet; + getError.reset(); + getState = GETTING; + } } -void IncomingMessage::setHeader(AMQHeaderBody::shared_ptr _header){ - this->header = _header; +bool IncomingMessage::waitGet(Message& msg) { + Guard g(this); + while (getState == GETTING && !shutdownError && !getError) + getReady.wait(lock); + shutdownError.throwIf(); + getError.throwIf(); + msg = getMessage; + return getState==GOT; } -void IncomingMessage::addContent(AMQContentBody::shared_ptr content){ - data.append(content->getData()); +Message IncomingMessage::waitDispatch() { + Guard g(this); + while(dispatchQueue.empty() && !shutdownError) + dispatchReady.wait(lock); + shutdownError.throwIf(); + + Message msg(dispatchQueue.front()); + dispatchQueue.pop(); + return msg; } -bool IncomingMessage::isComplete(){ - return header != 0 && header->getContentSize() == data.size(); +void IncomingMessage::add(BodyPtr body) { + Guard g(this); + shutdownError.throwIf(); + // Call the current state function. + (this->*state)(body); } -bool IncomingMessage::isReturn(){ - return returned; +void IncomingMessage::shutdown() { + Mutex::ScopedLock l(lock); + shutdownError.reset(new ShutdownException()); + getReady.notify(); + dispatchReady.notify(); } -bool IncomingMessage::isDelivery(){ - return delivered; +bool IncomingMessage::isShutdown() const { + Mutex::ScopedLock l(lock); + return shutdownError; } -bool IncomingMessage::isResponse(){ - return response; +// Common check for all the expect functions. Called in network thread. +template +boost::shared_ptr IncomingMessage::expectCheck(BodyPtr body) { + boost::shared_ptr ptr = boost::dynamic_pointer_cast(body); + if (!ptr) + throw QPID_ERROR(PROTOCOL_ERROR+504, "Unexpected frame type"); + return ptr; } -const string& IncomingMessage::getConsumerTag(){ - if(!isDelivery()) THROW_QPID_ERROR(CLIENT_ERROR, "Consumer tag only valid for delivery"); - return delivered->getConsumerTag(); +void IncomingMessage::expectGetOk(BodyPtr body) { + if (dynamic_cast(body.get())) + state = &IncomingMessage::expectHeader; + else if (dynamic_cast(body.get())) { + getState = EMPTY; + endGet(); + } + else + throw QPID_ERROR(PROTOCOL_ERROR+504, "Unexpected frame type"); } -u_int64_t IncomingMessage::getDeliveryTag(){ - if(!isDelivery()) THROW_QPID_ERROR(CLIENT_ERROR, "Delivery tag only valid for delivery"); - return delivered->getDeliveryTag(); +void IncomingMessage::expectHeader(BodyPtr body) { + AMQHeaderBody::shared_ptr header = expectCheck(body); + buildMessage.header = header; + state = &IncomingMessage::expectContent; + checkComplete(); } -AMQHeaderBody::shared_ptr& IncomingMessage::getHeader(){ - return header; +void IncomingMessage::expectContent(BodyPtr body) { + AMQContentBody::shared_ptr content = expectCheck(body); + buildMessage.setData(buildMessage.getData() + content->getData()); + checkComplete(); +} + +void IncomingMessage::checkComplete() { + size_t declaredSize = buildMessage.header->getContentSize(); + size_t currentSize = buildMessage.getData().size(); + if (declaredSize == currentSize) + (this->*endFn)(0); + else if (declaredSize < currentSize) + (this->*endFn)(new QPID_ERROR( + PROTOCOL_ERROR, "Message content exceeds declared size.")); +} + +void IncomingMessage::expectRequest(BodyPtr body) { + AMQMethodBody::shared_ptr method = expectCheck(body); + buildMessage.setMethod(method); + state = &IncomingMessage::expectHeader; +} + +void IncomingMessage::endGet(Exception* ex) { + getError.reset(ex); + if (getState == GETTING) { + getMessage = buildMessage; + getState = GOT; + } + reset(); + getReady.notify(); } -std::string IncomingMessage::getData() const { - return data; +void IncomingMessage::endRequest(Exception* ex) { + ExceptionHolder eh(ex); + if (!eh) { + dispatchQueue.push(buildMessage); + reset(); + dispatchReady.notify(); + } + eh.throwIf(); } +}} // namespace qpid::client -- cgit v1.2.1