001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.SocketTimeoutException; 026import java.net.URI; 027import java.net.UnknownHostException; 028import java.nio.ByteBuffer; 029import java.nio.channels.SelectionKey; 030import java.nio.channels.Selector; 031import java.security.cert.X509Certificate; 032 033import javax.net.SocketFactory; 034import javax.net.ssl.SSLContext; 035import javax.net.ssl.SSLEngine; 036import javax.net.ssl.SSLEngineResult; 037import javax.net.ssl.SSLEngineResult.HandshakeStatus; 038import javax.net.ssl.SSLParameters; 039import javax.net.ssl.SSLPeerUnverifiedException; 040import javax.net.ssl.SSLSession; 041 042import org.apache.activemq.command.ConnectionInfo; 043import org.apache.activemq.openwire.OpenWireFormat; 044import org.apache.activemq.thread.TaskRunnerFactory; 045import org.apache.activemq.util.IOExceptionSupport; 046import org.apache.activemq.util.ServiceStopper; 047import org.apache.activemq.wireformat.WireFormat; 048import org.slf4j.Logger; 049import org.slf4j.LoggerFactory; 050 051public class NIOSSLTransport extends NIOTransport { 052 053 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 054 055 protected boolean needClientAuth; 056 protected boolean wantClientAuth; 057 protected String[] enabledCipherSuites; 058 protected String[] enabledProtocols; 059 protected boolean verifyHostName = false; 060 061 protected SSLContext sslContext; 062 protected SSLEngine sslEngine; 063 protected SSLSession sslSession; 064 065 protected volatile boolean handshakeInProgress = false; 066 protected SSLEngineResult.Status status = null; 067 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 068 protected TaskRunnerFactory taskRunnerFactory; 069 070 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 071 super(wireFormat, socketFactory, remoteLocation, localLocation); 072 } 073 074 public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer, 075 ByteBuffer inputBuffer) throws IOException { 076 super(wireFormat, socket, initBuffer); 077 this.sslEngine = engine; 078 if (engine != null) { 079 this.sslSession = engine.getSession(); 080 } 081 this.inputBuffer = inputBuffer; 082 } 083 084 public void setSslContext(SSLContext sslContext) { 085 this.sslContext = sslContext; 086 } 087 088 volatile boolean hasSslEngine = false; 089 090 @Override 091 protected void initializeStreams() throws IOException { 092 if (sslEngine != null) { 093 hasSslEngine = true; 094 } 095 NIOOutputStream outputStream = null; 096 try { 097 channel = socket.getChannel(); 098 channel.configureBlocking(false); 099 100 if (sslContext == null) { 101 sslContext = SSLContext.getDefault(); 102 } 103 104 String remoteHost = null; 105 int remotePort = -1; 106 107 try { 108 URI remoteAddress = new URI(this.getRemoteAddress()); 109 remoteHost = remoteAddress.getHost(); 110 remotePort = remoteAddress.getPort(); 111 } catch (Exception e) { 112 } 113 114 // initialize engine, the initial sslSession we get will need to be 115 // updated once the ssl handshake process is completed. 116 if (!hasSslEngine) { 117 if (remoteHost != null && remotePort != -1) { 118 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 119 } else { 120 sslEngine = sslContext.createSSLEngine(); 121 } 122 123 if (verifyHostName) { 124 SSLParameters sslParams = new SSLParameters(); 125 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 126 sslEngine.setSSLParameters(sslParams); 127 } 128 129 sslEngine.setUseClientMode(false); 130 if (enabledCipherSuites != null) { 131 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 132 } 133 134 if (enabledProtocols != null) { 135 sslEngine.setEnabledProtocols(enabledProtocols); 136 } 137 138 if (wantClientAuth) { 139 sslEngine.setWantClientAuth(wantClientAuth); 140 } 141 142 if (needClientAuth) { 143 sslEngine.setNeedClientAuth(needClientAuth); 144 } 145 146 sslSession = sslEngine.getSession(); 147 148 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 149 inputBuffer.clear(); 150 } 151 152 outputStream = new NIOOutputStream(channel); 153 outputStream.setEngine(sslEngine); 154 this.dataOut = new DataOutputStream(outputStream); 155 this.buffOut = outputStream; 156 157 //If the sslEngine was not passed in, then handshake 158 if (!hasSslEngine) { 159 sslEngine.beginHandshake(); 160 } 161 handshakeStatus = sslEngine.getHandshakeStatus(); 162 if (!hasSslEngine) { 163 doHandshake(); 164 } 165 166 // if (hasSslEngine) { 167 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 168 @Override 169 public void onSelect(SelectorSelection selection) { 170 serviceRead(); 171 } 172 173 @Override 174 public void onError(SelectorSelection selection, Throwable error) { 175 if (error instanceof IOException) { 176 onException((IOException) error); 177 } else { 178 onException(IOExceptionSupport.create(error)); 179 } 180 } 181 }); 182 doInit(); 183 184 } catch (Exception e) { 185 try { 186 if(outputStream != null) { 187 outputStream.close(); 188 } 189 super.closeStreams(); 190 } catch (Exception ex) {} 191 throw new IOException(e); 192 } 193 } 194 195 protected void doInit() throws Exception { 196 197 } 198 199 protected void doOpenWireInit() throws Exception { 200 //Do this later to let wire format negotiation happen 201 if (initBuffer != null && this.wireFormat instanceof OpenWireFormat) { 202 initBuffer.buffer.flip(); 203 if (initBuffer.buffer.hasRemaining()) { 204 nextFrameSize = -1; 205 receiveCounter += initBuffer.readSize; 206 processCommand(initBuffer.buffer); 207 processCommand(initBuffer.buffer); 208 initBuffer.buffer.clear(); 209 } 210 } 211 } 212 213 protected void finishHandshake() throws Exception { 214 if (handshakeInProgress) { 215 handshakeInProgress = false; 216 nextFrameSize = -1; 217 218 // Once handshake completes we need to ask for the now real sslSession 219 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 220 // cipher suite. 221 sslSession = sslEngine.getSession(); 222 223 // listen for events telling us when the socket is readable. 224 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 225 @Override 226 public void onSelect(SelectorSelection selection) { 227 serviceRead(); 228 } 229 230 @Override 231 public void onError(SelectorSelection selection, Throwable error) { 232 if (error instanceof IOException) { 233 onException((IOException) error); 234 } else { 235 onException(IOExceptionSupport.create(error)); 236 } 237 } 238 }); 239 } 240 } 241 242 @Override 243 public void serviceRead() { 244 try { 245 if (handshakeInProgress) { 246 doHandshake(); 247 } 248 249 doOpenWireInit(); 250 251 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 252 plain.position(plain.limit()); 253 254 while (true) { 255 if (!plain.hasRemaining()) { 256 257 int readCount = secureRead(plain); 258 259 if (readCount == 0) { 260 break; 261 } 262 263 // channel is closed, cleanup 264 if (readCount == -1) { 265 onException(new EOFException()); 266 selection.close(); 267 break; 268 } 269 270 receiveCounter += readCount; 271 } 272 273 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 274 processCommand(plain); 275 } 276 } 277 } catch (IOException e) { 278 onException(e); 279 } catch (Throwable e) { 280 onException(IOExceptionSupport.create(e)); 281 } 282 } 283 284 protected void processCommand(ByteBuffer plain) throws Exception { 285 286 // Are we waiting for the next Command or are we building on the current one 287 if (nextFrameSize == -1) { 288 289 // We can get small packets that don't give us enough for the frame size 290 // so allocate enough for the initial size value and 291 if (plain.remaining() < Integer.SIZE) { 292 if (currentBuffer == null) { 293 currentBuffer = ByteBuffer.allocate(4); 294 } 295 296 // Go until we fill the integer sized current buffer. 297 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 298 currentBuffer.put(plain.get()); 299 } 300 301 // Didn't we get enough yet to figure out next frame size. 302 if (currentBuffer.hasRemaining()) { 303 return; 304 } else { 305 currentBuffer.flip(); 306 nextFrameSize = currentBuffer.getInt(); 307 } 308 309 } else { 310 311 // Either we are completing a previous read of the next frame size or its 312 // fully contained in plain already. 313 if (currentBuffer != null) { 314 315 // Finish the frame size integer read and get from the current buffer. 316 while (currentBuffer.hasRemaining()) { 317 currentBuffer.put(plain.get()); 318 } 319 320 currentBuffer.flip(); 321 nextFrameSize = currentBuffer.getInt(); 322 323 } else { 324 nextFrameSize = plain.getInt(); 325 } 326 } 327 328 if (wireFormat instanceof OpenWireFormat) { 329 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); 330 if (nextFrameSize > maxFrameSize) { 331 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 332 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 333 } 334 } 335 336 // now we got the data, lets reallocate and store the size for the marshaler. 337 // if there's more data in plain, then the next call will start processing it. 338 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 339 currentBuffer.putInt(nextFrameSize); 340 341 } else { 342 // If its all in one read then we can just take it all, otherwise take only 343 // the current frame size and the next iteration starts a new command. 344 if (currentBuffer != null) { 345 if (currentBuffer.remaining() >= plain.remaining()) { 346 currentBuffer.put(plain); 347 } else { 348 byte[] fill = new byte[currentBuffer.remaining()]; 349 plain.get(fill); 350 currentBuffer.put(fill); 351 } 352 353 // Either we have enough data for a new command or we have to wait for some more. 354 if (currentBuffer.hasRemaining()) { 355 return; 356 } else { 357 currentBuffer.flip(); 358 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 359 doConsume(command); 360 nextFrameSize = -1; 361 currentBuffer = null; 362 } 363 } 364 } 365 } 366 367 protected int secureRead(ByteBuffer plain) throws Exception { 368 369 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 370 int bytesRead = channel.read(inputBuffer); 371 372 if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { 373 return 0; 374 } 375 376 if (bytesRead == -1) { 377 sslEngine.closeInbound(); 378 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 379 return -1; 380 } 381 } 382 } 383 384 plain.clear(); 385 386 inputBuffer.flip(); 387 SSLEngineResult res; 388 do { 389 res = sslEngine.unwrap(inputBuffer, plain); 390 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 391 && res.bytesProduced() == 0); 392 393 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 394 finishHandshake(); 395 } 396 397 status = res.getStatus(); 398 handshakeStatus = res.getHandshakeStatus(); 399 400 // TODO deal with BUFFER_OVERFLOW 401 402 if (status == SSLEngineResult.Status.CLOSED) { 403 sslEngine.closeInbound(); 404 return -1; 405 } 406 407 inputBuffer.compact(); 408 plain.flip(); 409 410 return plain.remaining(); 411 } 412 413 protected void doHandshake() throws Exception { 414 handshakeInProgress = true; 415 Selector selector = null; 416 SelectionKey key = null; 417 boolean readable = true; 418 try { 419 while (true) { 420 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); 421 switch (handshakeStatus) { 422 case NEED_UNWRAP: 423 if (readable) { 424 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 425 } 426 if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 427 long now = System.currentTimeMillis(); 428 if (selector == null) { 429 selector = Selector.open(); 430 key = channel.register(selector, SelectionKey.OP_READ); 431 } else { 432 key.interestOps(SelectionKey.OP_READ); 433 } 434 int keyCount = selector.select(this.getSoTimeout()); 435 if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { 436 throw new SocketTimeoutException("Timeout during handshake"); 437 } 438 readable = key.isReadable(); 439 } 440 break; 441 case NEED_TASK: 442 Runnable task; 443 while ((task = sslEngine.getDelegatedTask()) != null) { 444 task.run(); 445 } 446 break; 447 case NEED_WRAP: 448 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 449 break; 450 case FINISHED: 451 case NOT_HANDSHAKING: 452 finishHandshake(); 453 return; 454 } 455 } 456 } finally { 457 if (key!=null) try {key.cancel();} catch (Exception ignore) {} 458 if (selector!=null) try {selector.close();} catch (Exception ignore) {} 459 } 460 } 461 462 @Override 463 protected void doStart() throws Exception { 464 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 465 // no need to init as we can delay that until demand (eg in doHandshake) 466 super.doStart(); 467 } 468 469 @Override 470 protected void doStop(ServiceStopper stopper) throws Exception { 471 if (taskRunnerFactory != null) { 472 taskRunnerFactory.shutdownNow(); 473 taskRunnerFactory = null; 474 } 475 if (channel != null) { 476 channel.close(); 477 channel = null; 478 } 479 super.doStop(stopper); 480 } 481 482 /** 483 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 484 * 485 * @param command 486 * The Command coming in. 487 */ 488 @Override 489 public void doConsume(Object command) { 490 if (command instanceof ConnectionInfo) { 491 ConnectionInfo connectionInfo = (ConnectionInfo) command; 492 connectionInfo.setTransportContext(getPeerCertificates()); 493 } 494 super.doConsume(command); 495 } 496 497 /** 498 * @return peer certificate chain associated with the ssl socket 499 */ 500 public X509Certificate[] getPeerCertificates() { 501 502 X509Certificate[] clientCertChain = null; 503 try { 504 if (sslEngine.getSession() != null) { 505 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 506 } 507 } catch (SSLPeerUnverifiedException e) { 508 if (LOG.isTraceEnabled()) { 509 LOG.trace("Failed to get peer certificates.", e); 510 } 511 } 512 513 return clientCertChain; 514 } 515 516 public boolean isNeedClientAuth() { 517 return needClientAuth; 518 } 519 520 public void setNeedClientAuth(boolean needClientAuth) { 521 this.needClientAuth = needClientAuth; 522 } 523 524 public boolean isWantClientAuth() { 525 return wantClientAuth; 526 } 527 528 public void setWantClientAuth(boolean wantClientAuth) { 529 this.wantClientAuth = wantClientAuth; 530 } 531 532 public String[] getEnabledCipherSuites() { 533 return enabledCipherSuites; 534 } 535 536 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 537 this.enabledCipherSuites = enabledCipherSuites; 538 } 539 540 public String[] getEnabledProtocols() { 541 return enabledProtocols; 542 } 543 544 public void setEnabledProtocols(String[] enabledProtocols) { 545 this.enabledProtocols = enabledProtocols; 546 } 547 548 public boolean isVerifyHostName() { 549 return verifyHostName; 550 } 551 552 public void setVerifyHostName(boolean verifyHostName) { 553 this.verifyHostName = verifyHostName; 554 } 555}