001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataOutputStream; 021import java.io.EOFException; 022import java.io.IOException; 023import java.net.Socket; 024import java.net.URI; 025import java.net.UnknownHostException; 026import java.nio.ByteBuffer; 027import java.util.concurrent.atomic.AtomicInteger; 028 029import javax.net.SocketFactory; 030import javax.net.ssl.SSLContext; 031import javax.net.ssl.SSLEngine; 032import javax.net.ssl.SSLEngineResult; 033import javax.net.ssl.SSLParameters; 034 035import org.apache.activemq.thread.TaskRunnerFactory; 036import org.apache.activemq.util.IOExceptionSupport; 037import org.apache.activemq.util.ServiceStopper; 038import org.apache.activemq.wireformat.WireFormat; 039 040/** 041 * This transport initializes the SSLEngine and reads the first command before 042 * handing off to the detected transport. 043 * 044 */ 045public class AutoInitNioSSLTransport extends NIOSSLTransport { 046 047 public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 048 super(wireFormat, socketFactory, remoteLocation, localLocation); 049 } 050 051 public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 052 super(wireFormat, socket, null, null, null); 053 } 054 055 @Override 056 public void setSslContext(SSLContext sslContext) { 057 this.sslContext = sslContext; 058 } 059 060 public ByteBuffer getInputBuffer() { 061 return this.inputBuffer; 062 } 063 064 @Override 065 protected void initializeStreams() throws IOException { 066 NIOOutputStream outputStream = null; 067 try { 068 channel = socket.getChannel(); 069 channel.configureBlocking(false); 070 071 if (sslContext == null) { 072 sslContext = SSLContext.getDefault(); 073 } 074 075 String remoteHost = null; 076 int remotePort = -1; 077 078 try { 079 URI remoteAddress = new URI(this.getRemoteAddress()); 080 remoteHost = remoteAddress.getHost(); 081 remotePort = remoteAddress.getPort(); 082 } catch (Exception e) { 083 } 084 085 // initialize engine, the initial sslSession we get will need to be 086 // updated once the ssl handshake process is completed. 087 if (remoteHost != null && remotePort != -1) { 088 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 089 } else { 090 sslEngine = sslContext.createSSLEngine(); 091 } 092 093 if (verifyHostName) { 094 SSLParameters sslParams = new SSLParameters(); 095 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 096 sslEngine.setSSLParameters(sslParams); 097 } 098 099 sslEngine.setUseClientMode(false); 100 if (enabledCipherSuites != null) { 101 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 102 } 103 104 if (enabledProtocols != null) { 105 sslEngine.setEnabledProtocols(enabledProtocols); 106 } 107 108 if (wantClientAuth) { 109 sslEngine.setWantClientAuth(wantClientAuth); 110 } 111 112 if (needClientAuth) { 113 sslEngine.setNeedClientAuth(needClientAuth); 114 } 115 116 sslSession = sslEngine.getSession(); 117 118 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 119 inputBuffer.clear(); 120 121 outputStream = new NIOOutputStream(channel); 122 outputStream.setEngine(sslEngine); 123 this.dataOut = new DataOutputStream(outputStream); 124 this.buffOut = outputStream; 125 sslEngine.beginHandshake(); 126 handshakeStatus = sslEngine.getHandshakeStatus(); 127 doHandshake(); 128 // detectReadyState(); 129 } catch (Exception e) { 130 try { 131 if(outputStream != null) { 132 outputStream.close(); 133 } 134 super.closeStreams(); 135 } catch (Exception ex) {} 136 throw new IOException(e); 137 } 138 } 139 140 @Override 141 protected void doOpenWireInit() throws Exception { 142 143 } 144 145 146 @Override 147 protected void finishHandshake() throws Exception { 148 if (handshakeInProgress) { 149 handshakeInProgress = false; 150 nextFrameSize = -1; 151 152 // Once handshake completes we need to ask for the now real sslSession 153 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 154 // cipher suite. 155 sslSession = sslEngine.getSession(); 156 157 } 158 } 159 160 public SSLEngine getSslSession() { 161 return this.sslEngine; 162 } 163 164 private volatile byte[] readData; 165 166 private final AtomicInteger readSize = new AtomicInteger(); 167 168 public byte[] getReadData() { 169 return readData != null ? readData : new byte[0]; 170 } 171 172 public AtomicInteger getReadSize() { 173 return readSize; 174 } 175 176 @Override 177 public void serviceRead() { 178 try { 179 if (handshakeInProgress) { 180 doHandshake(); 181 } 182 183 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 184 plain.position(plain.limit()); 185 186 while (true) { 187 if (!plain.hasRemaining()) { 188 int readCount = secureRead(plain); 189 190 // channel is closed, cleanup 191 if (readCount == -1) { 192 onException(new EOFException()); 193 break; 194 } 195 196 receiveCounter += readCount; 197 readSize.addAndGet(readCount); 198 } 199 200 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 201 processCommand(plain); 202 //we have received enough bytes to detect the protocol 203 if (receiveCounter >= 8) { 204 break; 205 } 206 } 207 } 208 } catch (IOException e) { 209 onException(e); 210 } catch (Throwable e) { 211 onException(IOExceptionSupport.create(e)); 212 } 213 } 214 215 @Override 216 protected void processCommand(ByteBuffer plain) throws Exception { 217 ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter); 218 if (readData != null) { 219 newBuffer.put(readData); 220 } 221 newBuffer.put(plain); 222 newBuffer.flip(); 223 readData = newBuffer.array(); 224 } 225 226 227 @Override 228 public void doStart() throws Exception { 229 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 230 // no need to init as we can delay that until demand (eg in doHandshake) 231 connect(); 232 } 233 234 235 @Override 236 protected void doStop(ServiceStopper stopper) throws Exception { 237 if (taskRunnerFactory != null) { 238 taskRunnerFactory.shutdownNow(); 239 taskRunnerFactory = null; 240 } 241 } 242 243 244}