001 /**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.activemq.transport.nio;
019
020 import org.apache.activemq.command.Command;
021 import org.apache.activemq.openwire.OpenWireFormat;
022 import org.apache.activemq.thread.DefaultThreadPools;
023 import org.apache.activemq.util.IOExceptionSupport;
024 import org.apache.activemq.util.ServiceStopper;
025 import org.apache.activemq.wireformat.WireFormat;
026
027 import javax.net.SocketFactory;
028 import javax.net.ssl.*;
029 import java.io.DataInputStream;
030 import java.io.DataOutputStream;
031 import java.io.EOFException;
032 import java.io.IOException;
033 import java.net.Socket;
034 import java.net.URI;
035 import java.net.UnknownHostException;
036 import java.nio.ByteBuffer;
037
038 public class NIOSSLTransport extends NIOTransport {
039
040 protected boolean needClientAuth;
041 protected boolean wantClientAuth;
042 protected String[] enabledCipherSuites;
043
044 protected SSLContext sslContext;
045 protected SSLEngine sslEngine;
046 protected SSLSession sslSession;
047
048
049 protected boolean handshakeInProgress = false;
050 protected SSLEngineResult.Status status = null;
051 protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
052
053 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
054 super(wireFormat, socketFactory, remoteLocation, localLocation);
055 }
056
057 public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
058 super(wireFormat, socket);
059 }
060
061 public void setSslContext(SSLContext sslContext) {
062 this.sslContext = sslContext;
063 }
064
065 @Override
066 protected void initializeStreams() throws IOException {
067 try {
068 channel = socket.getChannel();
069 channel.configureBlocking(false);
070
071 if (sslContext == null) {
072 sslContext = SSLContext.getDefault();
073 }
074
075 // initialize engine
076 sslEngine = sslContext.createSSLEngine();
077 sslEngine.setUseClientMode(false);
078 if (enabledCipherSuites != null) {
079 sslEngine.setEnabledCipherSuites(enabledCipherSuites);
080 }
081 sslEngine.setNeedClientAuth(needClientAuth);
082 sslEngine.setWantClientAuth(wantClientAuth);
083
084 sslSession = sslEngine.getSession();
085
086 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
087 inputBuffer.clear();
088 currentBuffer = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
089
090 NIOOutputStream outputStream = new NIOOutputStream(channel);
091 outputStream.setEngine(sslEngine);
092 this.dataOut = new DataOutputStream(outputStream);
093 this.buffOut = outputStream;
094 sslEngine.beginHandshake();
095 handshakeStatus = sslEngine.getHandshakeStatus();
096 doHandshake();
097
098 } catch (Exception e) {
099 throw new IOException(e);
100 }
101
102 }
103
104 protected void finishHandshake() throws Exception {
105 if (handshakeInProgress) {
106 handshakeInProgress = false;
107 nextFrameSize = -1;
108
109 // listen for events telling us when the socket is readable.
110 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
111 public void onSelect(SelectorSelection selection) {
112 serviceRead();
113 }
114
115 public void onError(SelectorSelection selection, Throwable error) {
116 if (error instanceof IOException) {
117 onException((IOException) error);
118 } else {
119 onException(IOExceptionSupport.create(error));
120 }
121 }
122 });
123 }
124 }
125
126 protected void serviceRead() {
127 try {
128 if (handshakeInProgress) {
129 doHandshake();
130 }
131
132 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
133 plain.position(plain.limit());
134
135 while(true) {
136 if (!plain.hasRemaining()) {
137
138 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
139 plain.clear();
140 } else {
141 plain.compact();
142 }
143 int readCount = secureRead(plain);
144
145
146 if (readCount == 0)
147 break;
148
149 // channel is closed, cleanup
150 if (readCount== -1) {
151 onException(new EOFException());
152 selection.close();
153 break;
154 }
155 }
156
157 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
158 processCommand(plain);
159 }
160
161 }
162 } catch (IOException e) {
163 onException(e);
164 } catch (Throwable e) {
165 onException(IOExceptionSupport.create(e));
166 }
167 }
168
169 protected void processCommand(ByteBuffer plain) throws Exception {
170 nextFrameSize = plain.getInt();
171 if (wireFormat instanceof OpenWireFormat) {
172 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
173 if (nextFrameSize > maxFrameSize) {
174 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
175 }
176 }
177 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
178 currentBuffer.putInt(nextFrameSize);
179 if (currentBuffer.hasRemaining()) {
180 if (currentBuffer.remaining() >= plain.remaining()) {
181 currentBuffer.put(plain);
182 } else {
183 byte[] fill = new byte[currentBuffer.remaining()];
184 plain.get(fill);
185 currentBuffer.put(fill);
186 }
187 }
188
189 if (currentBuffer.hasRemaining()) {
190 return;
191 } else {
192 currentBuffer.flip();
193 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
194 doConsume((Command) command);
195 nextFrameSize = -1;
196 }
197 }
198
199 protected int secureRead(ByteBuffer plain) throws Exception {
200
201 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
202 int bytesRead = channel.read(inputBuffer);
203
204 if (bytesRead == -1) {
205 sslEngine.closeInbound();
206 if (inputBuffer.position() == 0 ||
207 status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
208 return -1;
209 }
210 }
211 }
212
213 plain.clear();
214
215 inputBuffer.flip();
216 SSLEngineResult res;
217 do {
218 res = sslEngine.unwrap(inputBuffer, plain);
219 } while (res.getStatus() == SSLEngineResult.Status.OK &&
220 res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
221 res.bytesProduced() == 0);
222
223 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
224 finishHandshake();
225 }
226
227 status = res.getStatus();
228 handshakeStatus = res.getHandshakeStatus();
229
230
231 //TODO deal with BUFFER_OVERFLOW
232
233 if (status == SSLEngineResult.Status.CLOSED) {
234 sslEngine.closeInbound();
235 return -1;
236 }
237
238 inputBuffer.compact();
239 plain.flip();
240
241 return plain.remaining();
242 }
243
244 protected void doHandshake() throws Exception {
245 handshakeInProgress = true;
246 while (true) {
247 switch (sslEngine.getHandshakeStatus()) {
248 case NEED_UNWRAP:
249 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
250 break;
251 case NEED_TASK:
252 Runnable task;
253 while ((task = sslEngine.getDelegatedTask()) != null) {
254 DefaultThreadPools.getDefaultTaskRunnerFactory().execute(task);
255 }
256 break;
257 case NEED_WRAP:
258 ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0));
259 break;
260 case FINISHED:
261 case NOT_HANDSHAKING:
262 finishHandshake();
263 return;
264 }
265 }
266 }
267
268 @Override
269 protected void doStop(ServiceStopper stopper) throws Exception {
270 if (channel != null) {
271 channel.close();
272 channel = null;
273 }
274 super.doStop(stopper);
275 }
276
277 public boolean isNeedClientAuth() {
278 return needClientAuth;
279 }
280
281 public void setNeedClientAuth(boolean needClientAuth) {
282 this.needClientAuth = needClientAuth;
283 }
284
285 public boolean isWantClientAuth() {
286 return wantClientAuth;
287 }
288
289 public void setWantClientAuth(boolean wantClientAuth) {
290 this.wantClientAuth = wantClientAuth;
291 }
292
293 public String[] getEnabledCipherSuites() {
294 return enabledCipherSuites;
295 }
296
297 public void setEnabledCipherSuites(String[] enabledCipherSuites) {
298 this.enabledCipherSuites = enabledCipherSuites;
299 }
300 }