001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataOutputStream;
021import java.io.EOFException;
022import java.io.IOException;
023import java.net.Socket;
024import java.net.URI;
025import java.net.UnknownHostException;
026import java.nio.ByteBuffer;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import javax.net.SocketFactory;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLEngine;
032import javax.net.ssl.SSLEngineResult;
033import javax.net.ssl.SSLParameters;
034
035import org.apache.activemq.thread.TaskRunnerFactory;
036import org.apache.activemq.util.IOExceptionSupport;
037import org.apache.activemq.util.ServiceStopper;
038import org.apache.activemq.wireformat.WireFormat;
039
040/**
041 * This transport initializes the SSLEngine and reads the first command before
042 * handing off to the detected transport.
043 *
044 */
045public class AutoInitNioSSLTransport extends NIOSSLTransport {
046
047    public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
048        super(wireFormat, socketFactory, remoteLocation, localLocation);
049    }
050
051    public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
052        super(wireFormat, socket, null, null, null);
053    }
054
055    @Override
056    public void setSslContext(SSLContext sslContext) {
057        this.sslContext = sslContext;
058    }
059
060    public ByteBuffer getInputBuffer() {
061        return this.inputBuffer;
062    }
063
064    @Override
065    protected void initializeStreams() throws IOException {
066        NIOOutputStream outputStream = null;
067        try {
068            channel = socket.getChannel();
069            channel.configureBlocking(false);
070
071            if (sslContext == null) {
072                sslContext = SSLContext.getDefault();
073            }
074
075            String remoteHost = null;
076            int remotePort = -1;
077
078            try {
079                URI remoteAddress = new URI(this.getRemoteAddress());
080                remoteHost = remoteAddress.getHost();
081                remotePort = remoteAddress.getPort();
082            } catch (Exception e) {
083            }
084
085            // initialize engine, the initial sslSession we get will need to be
086            // updated once the ssl handshake process is completed.
087            if (remoteHost != null && remotePort != -1) {
088                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
089            } else {
090                sslEngine = sslContext.createSSLEngine();
091            }
092
093            if (verifyHostName) {
094                SSLParameters sslParams = new SSLParameters();
095                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
096                sslEngine.setSSLParameters(sslParams);
097            }
098
099            sslEngine.setUseClientMode(false);
100            if (enabledCipherSuites != null) {
101                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
102            }
103
104            if (enabledProtocols != null) {
105                sslEngine.setEnabledProtocols(enabledProtocols);
106            }
107
108            if (wantClientAuth) {
109                sslEngine.setWantClientAuth(wantClientAuth);
110            }
111
112            if (needClientAuth) {
113                sslEngine.setNeedClientAuth(needClientAuth);
114            }
115
116            sslSession = sslEngine.getSession();
117
118            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
119            inputBuffer.clear();
120
121            outputStream = new NIOOutputStream(channel);
122            outputStream.setEngine(sslEngine);
123            this.dataOut = new DataOutputStream(outputStream);
124            this.buffOut = outputStream;
125            sslEngine.beginHandshake();
126            handshakeStatus = sslEngine.getHandshakeStatus();
127            doHandshake();
128           // detectReadyState();
129        } catch (Exception e) {
130            try {
131                if(outputStream != null) {
132                    outputStream.close();
133                }
134                super.closeStreams();
135            } catch (Exception ex) {}
136            throw new IOException(e);
137        }
138    }
139
140    @Override
141    protected void doOpenWireInit() throws Exception {
142
143    }
144
145
146    @Override
147    protected void finishHandshake() throws Exception {
148        if (handshakeInProgress) {
149            handshakeInProgress = false;
150            nextFrameSize = -1;
151
152            // Once handshake completes we need to ask for the now real sslSession
153            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
154            // cipher suite.
155            sslSession = sslEngine.getSession();
156
157        }
158    }
159
160    public SSLEngine getSslSession() {
161        return this.sslEngine;
162    }
163
164    private volatile byte[] readData;
165
166    private final AtomicInteger readSize = new AtomicInteger();
167
168    public byte[] getReadData() {
169        return readData != null ? readData : new byte[0];
170    }
171
172    public AtomicInteger getReadSize() {
173        return readSize;
174    }
175
176    @Override
177    public void serviceRead() {
178        try {
179            if (handshakeInProgress) {
180                doHandshake();
181            }
182
183            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
184            plain.position(plain.limit());
185
186            while (true) {
187                if (!plain.hasRemaining()) {
188                    int readCount = secureRead(plain);
189
190                    // channel is closed, cleanup
191                    if (readCount == -1) {
192                        onException(new EOFException());
193                        break;
194                    }
195
196                    receiveCounter += readCount;
197                    readSize.addAndGet(readCount);
198                }
199
200                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
201                    processCommand(plain);
202                    //we have received enough bytes to detect the protocol
203                    if (receiveCounter >= 8) {
204                        break;
205                    }
206                }
207            }
208        } catch (IOException e) {
209            onException(e);
210        } catch (Throwable e) {
211            onException(IOExceptionSupport.create(e));
212        }
213    }
214
215    @Override
216    protected void processCommand(ByteBuffer plain) throws Exception {
217        ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
218        if (readData != null) {
219            newBuffer.put(readData);
220        }
221        newBuffer.put(plain);
222        newBuffer.flip();
223        readData = newBuffer.array();
224    }
225
226
227    @Override
228    public void doStart() throws Exception {
229        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
230        // no need to init as we can delay that until demand (eg in doHandshake)
231        connect();
232    }
233
234
235    @Override
236    protected void doStop(ServiceStopper stopper) throws Exception {
237        if (taskRunnerFactory != null) {
238            taskRunnerFactory.shutdownNow();
239            taskRunnerFactory = null;
240        }
241    }
242
243
244}