diff options
| author | Robert Godfrey <rgodfrey@apache.org> | 2013-04-07 20:57:23 +0000 |
|---|---|---|
| committer | Robert Godfrey <rgodfrey@apache.org> | 2013-04-07 20:57:23 +0000 |
| commit | bffa6ec58c3ca61282eedd3882d175d544d428a8 (patch) | |
| tree | 608081b8548cda00e115ecf416ed16a231b74a02 /java | |
| parent | 1a45a3423c2c17968ff0018e5d2489853a6018f3 (diff) | |
| download | qpid-python-bffa6ec58c3ca61282eedd3882d175d544d428a8.tar.gz | |
QPID-4726: [Java Broker] AMQP 1.0 : Improve SASL support
git-svn-id: https://svn.apache.org/repos/asf/qpid/trunk/qpid@1465459 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'java')
5 files changed, 159 insertions, 134 deletions
diff --git a/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java b/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java index cdb2007b4a..0ef286e89e 100644 --- a/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java +++ b/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java @@ -21,7 +21,8 @@ package org.apache.qpid.amqp_1_0.transport; -import java.util.List; +import java.util.HashSet; +import java.util.Set; import org.apache.qpid.amqp_1_0.codec.DescribedTypeConstructorRegistry; import org.apache.qpid.amqp_1_0.codec.ValueWriter; import org.apache.qpid.amqp_1_0.framing.AMQFrame; @@ -59,12 +60,15 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { private static final short CONNECTION_CONTROL_CHANNEL = (short) 0; private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new byte[0]); + private static final Symbol SASL_PLAIN = Symbol.valueOf("PLAIN"); + private static final Symbol SASL_ANONYMOUS = Symbol.valueOf("ANONYMOUS"); + private static final Symbol SASL_EXTERNAL = Symbol.valueOf("EXTERNAL"); private final Container _container; private Principal _user; private static final short DEFAULT_CHANNEL_MAX = 255; - private static final int DEFAULT_MAX_FRAME = Integer.getInteger("amqp.max_frame_size",1<<15); + private static final int DEFAULT_MAX_FRAME = Integer.getInteger("amqp.max_frame_size", 1 << 15); private ConnectionState _state = ConnectionState.UNOPENED; @@ -75,20 +79,20 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour private SocketAddress _remoteAddress; // positioned by the *outgoing* channel - private SessionEndpoint[] _sendingSessions = new SessionEndpoint[DEFAULT_CHANNEL_MAX+1]; + private SessionEndpoint[] _sendingSessions = new SessionEndpoint[DEFAULT_CHANNEL_MAX + 1]; // positioned by the *incoming* channel - private SessionEndpoint[] _receivingSessions = new SessionEndpoint[DEFAULT_CHANNEL_MAX+1]; + private SessionEndpoint[] _receivingSessions = new SessionEndpoint[DEFAULT_CHANNEL_MAX + 1]; private boolean _closedForInput; private boolean _closedForOutput; private long _idleTimeout; private AMQPDescribedTypeRegistry _describedTypeRegistry = AMQPDescribedTypeRegistry.newInstance() - .registerTransportLayer() - .registerMessagingLayer() - .registerTransactionLayer() - .registerSecurityLayer(); + .registerTransportLayer() + .registerMessagingLayer() + .registerTransactionLayer() + .registerSecurityLayer(); private FrameOutputHandler<FrameBody> _frameOutputHandler; @@ -135,11 +139,11 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public synchronized void open() { - if(_requiresSASLClient) + if (_requiresSASLClient) { synchronized (getLock()) { - while(!_saslComplete) + while (!_saslComplete) { try { @@ -151,12 +155,12 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour } } } - if(!_authenticated) + if (!_authenticated) { throw new RuntimeException("Could not connect - authentication error"); } } - if(_state == ConnectionState.UNOPENED) + if (_state == ConnectionState.UNOPENED) { sendOpen(DEFAULT_CHANNEL_MAX, DEFAULT_MAX_FRAME); _state = ConnectionState.AWAITING_OPEN; @@ -172,8 +176,8 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { // todo assert connection state SessionEndpoint endpoint = new SessionEndpoint(this); - short channel = getFirstFreeChannel(); - if(channel != -1) + short channel = getFirstFreeChannel(); + if (channel != -1) { _sendingSessions[channel] = endpoint; endpoint.setSendingChannel(channel); @@ -244,8 +248,6 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour } - - private void closeSender() { setClosedForOutput(true); @@ -255,9 +257,9 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour short getFirstFreeChannel() { - for(int i = 0; i<_sendingSessions.length;i++) + for (int i = 0; i < _sendingSessions.length; i++) { - if(_sendingSessions[i]==null) + if (_sendingSessions[i] == null) { return (short) i; } @@ -276,22 +278,25 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { _channelMax = open.getChannelMax() == null ? DEFAULT_CHANNEL_MAX - : open.getChannelMax().shortValue() < DEFAULT_CHANNEL_MAX - ? DEFAULT_CHANNEL_MAX - : open.getChannelMax().shortValue(); + : open.getChannelMax().shortValue() < DEFAULT_CHANNEL_MAX + ? DEFAULT_CHANNEL_MAX + : open.getChannelMax().shortValue(); - UnsignedInteger remoteDesiredMaxFrameSize = open.getMaxFrameSize() == null ? UnsignedInteger.valueOf(DEFAULT_MAX_FRAME) : open.getMaxFrameSize(); + UnsignedInteger remoteDesiredMaxFrameSize = + open.getMaxFrameSize() == null ? UnsignedInteger.valueOf(DEFAULT_MAX_FRAME) : open.getMaxFrameSize(); - _maxFrameSize = (remoteDesiredMaxFrameSize.compareTo(_desiredMaxFrameSize) < 0 ? remoteDesiredMaxFrameSize : _desiredMaxFrameSize).intValue(); + _maxFrameSize = (remoteDesiredMaxFrameSize.compareTo(_desiredMaxFrameSize) < 0 + ? remoteDesiredMaxFrameSize + : _desiredMaxFrameSize).intValue(); _remoteContainerId = open.getContainerId(); - if(open.getIdleTimeOut() != null) + if (open.getIdleTimeOut() != null) { _idleTimeout = open.getIdleTimeOut().longValue(); } - switch(_state) + switch (_state) { case UNOPENED: sendOpen(_channelMax, _maxFrameSize); @@ -313,7 +318,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { setClosedForInput(true); _connectionEventListener.closeReceived(); - switch(_state) + switch (_state) { case UNOPENED: case AWAITING_OPEN: @@ -341,7 +346,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { Close close = new Close(); close.setError(error); - switch(_state) + switch (_state) { case UNOPENED: _state = ConnectionState.CLOSED; @@ -359,17 +364,17 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour } } - public synchronized void inputClosed() + public synchronized void inputClosed() { - if(!_closedForInput) + if (!_closedForInput) { _closedForInput = true; - for(int i = 0; i < _receivingSessions.length; i++) + for (int i = 0; i < _receivingSessions.length; i++) { - if(_receivingSessions[i] != null) + if (_receivingSessions[i] != null) { _receivingSessions[i].end(); - _receivingSessions[i]=null; + _receivingSessions[i] = null; } } @@ -395,8 +400,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour short myChannelId; - - if(begin.getRemoteChannel() != null) + if (begin.getRemoteChannel() != null) { myChannelId = begin.getRemoteChannel().shortValue(); SessionEndpoint endpoint; @@ -404,7 +408,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { endpoint = _sendingSessions[myChannelId]; } - catch(IndexOutOfBoundsException e) + catch (IndexOutOfBoundsException e) { final Error error = new Error(); error.setCondition(ConnectionError.FRAMING_ERROR); @@ -414,9 +418,9 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour connectionError(error); return; } - if(endpoint != null) + if (endpoint != null) { - if(_receivingSessions[channel] == null) + if (_receivingSessions[channel] == null) { _receivingSessions[channel] = endpoint; endpoint.setReceivingChannel(channel); @@ -446,16 +450,16 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { myChannelId = getFirstFreeChannel(); - if(myChannelId == -1) + if (myChannelId == -1) { // close any half open channel myChannelId = getFirstFreeChannel(); } - if(_receivingSessions[channel] == null) + if (_receivingSessions[channel] == null) { - SessionEndpoint endpoint = new SessionEndpoint(this,begin); + SessionEndpoint endpoint = new SessionEndpoint(this, begin); _receivingSessions[channel] = endpoint; _sendingSessions[myChannelId] = endpoint; @@ -483,15 +487,13 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour } - } - public synchronized void receiveEnd(short channel, End end) { SessionEndpoint endpoint = _receivingSessions[channel]; - if(endpoint != null) + if (endpoint != null) { _receivingSessions[channel] = null; @@ -551,18 +553,18 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public synchronized int send(short channel, FrameBody body, ByteBuffer payload) { - if(!_closedForOutput) + if (!_closedForOutput) { ValueWriter<FrameBody> writer = _describedTypeRegistry.getValueWriter(body); int size = writer.writeToBuffer(EMPTY_BYTE_BUFFER); ByteBuffer payloadDup = payload == null ? null : payload.duplicate(); int payloadSent = getMaxFrameSize() - (size + 9); - if(payloadSent < (payload == null ? 0 : payload.remaining())) + if (payloadSent < (payload == null ? 0 : payload.remaining())) { - if(body instanceof Transfer) + if (body instanceof Transfer) { - ((Transfer)body).setMore(Boolean.TRUE); + ((Transfer) body).setMore(Boolean.TRUE); } writer = _describedTypeRegistry.getValueWriter(body); @@ -571,9 +573,9 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour try { - payloadDup.limit(payloadDup.position()+payloadSent); + payloadDup.limit(payloadDup.position() + payloadSent); } - catch(NullPointerException npe) + catch (NullPointerException npe) { throw npe; } @@ -592,7 +594,6 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour } - public void invalidHeaderReceived() { // TODO @@ -606,7 +607,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public synchronized void protocolHeaderReceived(final byte major, final byte minorVersion, final byte revision) { - if(_requiresSASLServer && _state != ConnectionState.UNOPENED) + if (_requiresSASLServer && _state != ConnectionState.UNOPENED) { // TODO - bad stuff } @@ -618,7 +619,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public synchronized void handleError(final Error error) { - if(!closedForOutput()) + if (!closedForOutput()) { Close close = new Close(); close.setError(error); @@ -631,17 +632,17 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public synchronized void receive(final short channel, final Object frame) { - if(_logger.isLoggable(Level.FINE)) + if (_logger.isLoggable(Level.FINE)) { - _logger.fine("RECV["+ _remoteAddress + "|"+channel+"] : " + frame); + _logger.fine("RECV[" + _remoteAddress + "|" + channel + "] : " + frame); } - if(frame instanceof FrameBody) + if (frame instanceof FrameBody) { - ((FrameBody)frame).invoke(channel, this); + ((FrameBody) frame).invoke(channel, this); } - else if(frame instanceof SaslFrameBody) + else if (frame instanceof SaslFrameBody) { - ((SaslFrameBody)frame).invoke(this); + ((SaslFrameBody) frame).invoke(this); } } @@ -674,7 +675,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public synchronized void close() { - switch(_state) + switch (_state) { case AWAITING_OPEN: case OPEN: @@ -737,10 +738,11 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour { _saslComplete = true; _authenticated = true; + _user = _saslServerProvider.getAuthenticatedPrincipal(_saslServer); getLock().notifyAll(); } - if(_onSaslCompleteTask != null) + if (_onSaslCompleteTask != null) { _onSaslCompleteTask.run(); } @@ -766,7 +768,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour _authenticated = false; getLock().notifyAll(); } - if(_onSaslCompleteTask != null) + if (_onSaslCompleteTask != null) { _onSaslCompleteTask.run(); } @@ -776,19 +778,32 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public void receiveSaslMechanisms(final SaslMechanisms saslMechanisms) { - if(Arrays.asList(saslMechanisms.getSaslServerMechanisms()).contains(Symbol.valueOf("PLAIN"))) + SaslInit init = new SaslInit(); + init.setHostname(_remoteHostname); + + Set<Symbol> mechanisms = new HashSet<Symbol>(Arrays.asList(saslMechanisms.getSaslServerMechanisms())); + if (mechanisms.contains(SASL_PLAIN) && _password != null) { - SaslInit init = new SaslInit(); - init.setMechanism(Symbol.valueOf("PLAIN")); - init.setHostname(_remoteHostname); + + init.setMechanism(SASL_PLAIN); + byte[] usernameBytes = _user.getName().getBytes(Charset.forName("UTF-8")); byte[] passwordBytes = _password.getBytes(Charset.forName("UTF-8")); - byte[] initResponse = new byte[usernameBytes.length+passwordBytes.length+2]; - System.arraycopy(usernameBytes,0,initResponse,1,usernameBytes.length); - System.arraycopy(passwordBytes,0,initResponse,usernameBytes.length+2,passwordBytes.length); + byte[] initResponse = new byte[usernameBytes.length + passwordBytes.length + 2]; + System.arraycopy(usernameBytes, 0, initResponse, 1, usernameBytes.length); + System.arraycopy(passwordBytes, 0, initResponse, usernameBytes.length + 2, passwordBytes.length); init.setInitialResponse(new Binary(initResponse)); - _saslFrameOutput.send(new SASLFrame(init),null); + + } + else if (mechanisms.contains(SASL_ANONYMOUS)) + { + init.setMechanism(SASL_ANONYMOUS); + } + else if (mechanisms.contains(SASL_EXTERNAL)) + { + init.setMechanism(SASL_EXTERNAL); } + _saslFrameOutput.send(new SASLFrame(init), null); } public void receiveSaslChallenge(final SaslChallenge saslChallenge) @@ -798,65 +813,66 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour public void receiveSaslResponse(final SaslResponse saslResponse) { - final Binary responseBinary = saslResponse.getResponse(); - byte[] response = responseBinary == null ? new byte[0] : responseBinary.getArray(); + final Binary responseBinary = saslResponse.getResponse(); + byte[] response = responseBinary == null ? new byte[0] : responseBinary.getArray(); - try - { - - // Process response from the client - byte[] challenge = _saslServer.evaluateResponse(response != null ? response : new byte[0]); + try + { - if (_saslServer.isComplete()) - { - SaslOutcome outcome = new SaslOutcome(); - - outcome.setCode(SaslCode.OK); - _saslFrameOutput.send(new SASLFrame(outcome),null); - synchronized (getLock()) - { - _saslComplete = true; - _authenticated = true; - getLock().notifyAll(); - } - if(_onSaslCompleteTask != null) - { - _onSaslCompleteTask.run(); - } + // Process response from the client + byte[] challenge = _saslServer.evaluateResponse(response != null ? response : new byte[0]); - } - else - { - SaslChallenge challengeBody = new SaslChallenge(); - challengeBody.setChallenge(new Binary(challenge)); - _saslFrameOutput.send(new SASLFrame(challengeBody), null); + if (_saslServer.isComplete()) + { + SaslOutcome outcome = new SaslOutcome(); - } + outcome.setCode(SaslCode.OK); + _saslFrameOutput.send(new SASLFrame(outcome), null); + synchronized (getLock()) + { + _saslComplete = true; + _authenticated = true; + _user = _saslServerProvider.getAuthenticatedPrincipal(_saslServer); + getLock().notifyAll(); } - catch (SaslException e) + if (_onSaslCompleteTask != null) { - SaslOutcome outcome = new SaslOutcome(); + _onSaslCompleteTask.run(); + } - outcome.setCode(SaslCode.AUTH); - _saslFrameOutput.send(new SASLFrame(outcome),null); - synchronized (getLock()) - { - _saslComplete = true; - _authenticated = false; - getLock().notifyAll(); - } - if(_onSaslCompleteTask != null) - { - _onSaslCompleteTask.run(); - } + } + else + { + SaslChallenge challengeBody = new SaslChallenge(); + challengeBody.setChallenge(new Binary(challenge)); + _saslFrameOutput.send(new SASLFrame(challengeBody), null); - } + } } + catch (SaslException e) + { + SaslOutcome outcome = new SaslOutcome(); + + outcome.setCode(SaslCode.AUTH); + _saslFrameOutput.send(new SASLFrame(outcome), null); + synchronized (getLock()) + { + _saslComplete = true; + _authenticated = false; + getLock().notifyAll(); + } + if (_onSaslCompleteTask != null) + { + _onSaslCompleteTask.run(); + } + + } + } public void receiveSaslOutcome(final SaslOutcome saslOutcome) { - if(saslOutcome.getCode() == SaslCode.OK) + if (saslOutcome.getCode() == SaslCode.OK) { _saslFrameOutput.close(); synchronized (getLock()) @@ -865,7 +881,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour _authenticated = true; getLock().notifyAll(); } - if(_onSaslCompleteTask != null) + if (_onSaslCompleteTask != null) { _onSaslCompleteTask.run(); } @@ -904,22 +920,13 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour return _authenticated; } - public void initiateSASL() + public void initiateSASL(String[] mechanismNames) { SaslMechanisms mechanisms = new SaslMechanisms(); - final Enumeration<SaslServerFactory> saslServerFactories = Sasl.getSaslServerFactories(); - - SaslServerFactory f; ArrayList<Symbol> mechanismsList = new ArrayList<Symbol>(); - while(saslServerFactories.hasMoreElements()) + for (String name : mechanismNames) { - f = saslServerFactories.nextElement(); - final String[] mechanismNames = f.getMechanismNames(null); - for(String name : mechanismNames) - { - mechanismsList.add(Symbol.valueOf(name)); - } - + mechanismsList.add(Symbol.valueOf(name)); } mechanisms.setSaslServerMechanisms(mechanismsList.toArray(new Symbol[mechanismsList.size()])); _saslFrameOutput.send(new SASLFrame(mechanisms), null); diff --git a/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/SaslServerProvider.java b/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/SaslServerProvider.java index 1b08488673..abc92e8acf 100644 --- a/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/SaslServerProvider.java +++ b/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/SaslServerProvider.java @@ -20,10 +20,12 @@ package org.apache.qpid.amqp_1_0.transport; +import java.security.Principal; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; public interface SaslServerProvider { SaslServer getSaslServer(String mechanism, String fqdn) throws SaslException; + Principal getAuthenticatedPrincipal(SaslServer server); } diff --git a/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0.java b/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0.java index f6b8e1e5c9..ed9cd324b4 100755 --- a/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0.java +++ b/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0.java @@ -22,6 +22,7 @@ package org.apache.qpid.server.protocol; import java.net.SocketAddress; import java.nio.ByteBuffer; +import java.security.Principal; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -41,6 +42,7 @@ import org.apache.qpid.protocol.ServerProtocolEngine; import org.apache.qpid.server.model.Broker; import org.apache.qpid.server.protocol.v1_0.Connection_1_0; import org.apache.qpid.server.security.SubjectCreator; +import org.apache.qpid.server.security.auth.UsernamePrincipal; import org.apache.qpid.server.virtualhost.VirtualHost; import org.apache.qpid.transport.Sender; import org.apache.qpid.transport.network.NetworkConnection; @@ -170,6 +172,12 @@ public class ProtocolEngine_1_0_0 implements ServerProtocolEngine, FrameOutputHa { return subjectCreator.createSaslServer(mechanism, fqdn, null); } + + @Override + public Principal getAuthenticatedPrincipal(SaslServer server) + { + return new UsernamePrincipal(server.getAuthorizationID()); + } }; } diff --git a/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0_SASL.java b/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0_SASL.java index 3b02ef2e5b..124eb779d5 100644 --- a/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0_SASL.java +++ b/java/broker/src/main/java/org/apache/qpid/server/protocol/ProtocolEngine_1_0_0_SASL.java @@ -23,6 +23,7 @@ package org.apache.qpid.server.protocol; import java.io.PrintWriter; import java.net.SocketAddress; import java.nio.ByteBuffer; +import java.security.Principal; import java.util.logging.Level; import java.util.logging.Logger; import javax.security.sasl.SaslException; @@ -42,6 +43,7 @@ import org.apache.qpid.protocol.ServerProtocolEngine; import org.apache.qpid.server.model.Broker; import org.apache.qpid.server.protocol.v1_0.Connection_1_0; import org.apache.qpid.server.security.SubjectCreator; +import org.apache.qpid.server.security.auth.UsernamePrincipal; import org.apache.qpid.server.virtualhost.VirtualHost; import org.apache.qpid.transport.Sender; import org.apache.qpid.transport.network.NetworkConnection; @@ -162,7 +164,8 @@ public class ProtocolEngine_1_0_0_SASL implements ServerProtocolEngine, FrameOut Container container = new Container(_broker.getId().toString()); VirtualHost virtualHost = _broker.getVirtualHostRegistry().getVirtualHost((String)_broker.getAttribute(Broker.DEFAULT_VIRTUAL_HOST)); - _conn = new ConnectionEndpoint(container, asSaslServerProvider(_broker.getSubjectCreator(getLocalAddress()))); + SubjectCreator subjectCreator = _broker.getSubjectCreator(getLocalAddress()); + _conn = new ConnectionEndpoint(container, asSaslServerProvider(subjectCreator)); _conn.setRemoteAddress(getRemoteAddress()); _conn.setConnectionEventListener(new Connection_1_0(virtualHost, _conn, _connectionId)); _conn.setFrameOutputHandler(this); @@ -189,7 +192,7 @@ public class ProtocolEngine_1_0_0_SASL implements ServerProtocolEngine, FrameOut _sender.send(HEADER.duplicate()); _sender.flush(); - _conn.initiateSASL(); + _conn.initiateSASL(subjectCreator.getMechanisms().split(" ")); } @@ -201,7 +204,13 @@ public class ProtocolEngine_1_0_0_SASL implements ServerProtocolEngine, FrameOut @Override public SaslServer getSaslServer(String mechanism, String fqdn) throws SaslException { - return subjectCreator.createSaslServer(mechanism, fqdn, null); + return subjectCreator.createSaslServer(mechanism, fqdn, _network.getPeerPrincipal()); + } + + @Override + public Principal getAuthenticatedPrincipal(SaslServer server) + { + return new UsernamePrincipal(server.getAuthorizationID()); } }; } @@ -230,7 +239,7 @@ public class ProtocolEngine_1_0_0_SASL implements ServerProtocolEngine, FrameOut Binary bin = new Binary(data); RAW_LOGGER.fine("RECV[" + getRemoteAddress() + "] : " + bin.toString()); } - _readBytes += msg.remaining(); + _readBytes += msg.remaining(); switch(_state) { case A: @@ -392,7 +401,6 @@ public class ProtocolEngine_1_0_0_SASL implements ServerProtocolEngine, FrameOut RAW_LOGGER.fine("SEND[" + getRemoteAddress() + "] : " + bin.toString()); } - _sender.send(dup); _sender.flush(); diff --git a/java/broker/src/main/java/org/apache/qpid/server/security/auth/sasl/external/ExternalSaslServer.java b/java/broker/src/main/java/org/apache/qpid/server/security/auth/sasl/external/ExternalSaslServer.java index 509442b14b..475f74180e 100644 --- a/java/broker/src/main/java/org/apache/qpid/server/security/auth/sasl/external/ExternalSaslServer.java +++ b/java/broker/src/main/java/org/apache/qpid/server/security/auth/sasl/external/ExternalSaslServer.java @@ -61,7 +61,7 @@ public class ExternalSaslServer implements SaslServer public String getAuthorizationID() { - return null; + return getAuthenticatedPrincipal().getName(); } public byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException |
