1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package org.jboss.netty.handler.codec.http.websocketx;
55
56 import org.jboss.netty.buffer.ChannelBuffer;
57 import org.jboss.netty.buffer.ChannelBuffers;
58 import org.jboss.netty.channel.Channel;
59 import org.jboss.netty.channel.ChannelFutureListener;
60 import org.jboss.netty.channel.ChannelHandlerContext;
61 import org.jboss.netty.handler.codec.frame.CorruptedFrameException;
62 import org.jboss.netty.handler.codec.frame.TooLongFrameException;
63 import org.jboss.netty.handler.codec.replay.ReplayingDecoder;
64 import org.jboss.netty.logging.InternalLogger;
65 import org.jboss.netty.logging.InternalLoggerFactory;
66
67
68
69
70
71 public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State> {
72
73 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
74
75 private static final byte OPCODE_CONT = 0x0;
76 private static final byte OPCODE_TEXT = 0x1;
77 private static final byte OPCODE_BINARY = 0x2;
78 private static final byte OPCODE_CLOSE = 0x8;
79 private static final byte OPCODE_PING = 0x9;
80 private static final byte OPCODE_PONG = 0xA;
81
82 private UTF8Output fragmentedFramesText;
83 private int fragmentedFramesCount;
84
85 private final long maxFramePayloadLength;
86 private boolean frameFinalFlag;
87 private int frameRsv;
88 private int frameOpcode;
89 private long framePayloadLength;
90 private ChannelBuffer framePayload;
91 private int framePayloadBytesRead;
92 private ChannelBuffer maskingKey;
93
94 private final boolean allowExtensions;
95 private final boolean maskedPayload;
96 private boolean receivedClosingHandshake;
97
98 public enum State {
99 FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
100 }
101
102
103
104
105
106
107
108
109
110
111 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions) {
112 this(maskedPayload, allowExtensions, Long.MAX_VALUE);
113 }
114
115
116
117
118
119
120
121
122
123
124
125
126
127 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
128 super(State.FRAME_START);
129 this.maskedPayload = maskedPayload;
130 this.allowExtensions = allowExtensions;
131 this.maxFramePayloadLength = maxFramePayloadLength;
132 }
133
134 @Override
135 protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, State state)
136 throws Exception {
137
138
139 if (receivedClosingHandshake) {
140 buffer.skipBytes(actualReadableBytes());
141 return null;
142 }
143
144 switch (state) {
145 case FRAME_START:
146 framePayloadBytesRead = 0;
147 framePayloadLength = -1;
148 framePayload = null;
149
150
151 byte b = buffer.readByte();
152 frameFinalFlag = (b & 0x80) != 0;
153 frameRsv = (b & 0x70) >> 4;
154 frameOpcode = b & 0x0F;
155
156 if (logger.isDebugEnabled()) {
157 logger.debug("Decoding WebSocket Frame opCode=" + frameOpcode);
158 }
159
160
161 b = buffer.readByte();
162 boolean frameMasked = (b & 0x80) != 0;
163 int framePayloadLen1 = b & 0x7F;
164
165 if (frameRsv != 0 && !allowExtensions) {
166 protocolViolation(channel, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
167 return null;
168 }
169
170 if (maskedPayload && !frameMasked) {
171 protocolViolation(channel, "unmasked client to server frame");
172 return null;
173 }
174 if (frameOpcode > 7) {
175
176
177 if (!frameFinalFlag) {
178 protocolViolation(channel, "fragmented control frame");
179 return null;
180 }
181
182
183 if (framePayloadLen1 > 125) {
184 protocolViolation(channel, "control frame with payload length > 125 octets");
185 return null;
186 }
187
188
189 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
190 protocolViolation(channel, "control frame using reserved opcode " + frameOpcode);
191 return null;
192 }
193
194
195
196
197 if (frameOpcode == 8 && framePayloadLen1 == 1) {
198 protocolViolation(channel, "received close control frame with payload len 1");
199 return null;
200 }
201 } else {
202
203 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
204 protocolViolation(channel, "data frame using reserved opcode " + frameOpcode);
205 return null;
206 }
207
208
209 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
210 protocolViolation(channel, "received continuation data frame outside fragmented message");
211 return null;
212 }
213
214
215 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
216 protocolViolation(channel, "received non-continuation data frame while inside fragmented message");
217 return null;
218 }
219 }
220
221
222 if (framePayloadLen1 == 126) {
223 framePayloadLength = buffer.readUnsignedShort();
224 if (framePayloadLength < 126) {
225 protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
226 return null;
227 }
228 } else if (framePayloadLen1 == 127) {
229 framePayloadLength = buffer.readLong();
230
231
232
233 if (framePayloadLength < 65536) {
234 protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
235 return null;
236 }
237 } else {
238 framePayloadLength = framePayloadLen1;
239 }
240
241 if (framePayloadLength > maxFramePayloadLength) {
242 protocolViolation(channel, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
243 return null;
244 }
245 if (logger.isDebugEnabled()) {
246 logger.debug("Decoding WebSocket Frame length=" + framePayloadLength);
247 }
248
249 checkpoint(State.MASKING_KEY);
250 case MASKING_KEY:
251 if (maskedPayload) {
252 maskingKey = buffer.readBytes(4);
253 }
254 checkpoint(State.PAYLOAD);
255 case PAYLOAD:
256
257
258 int rbytes = actualReadableBytes();
259 ChannelBuffer payloadBuffer = null;
260
261 long willHaveReadByteCount = framePayloadBytesRead + rbytes;
262
263
264
265 if (willHaveReadByteCount == framePayloadLength) {
266
267 payloadBuffer = buffer.readBytes(rbytes);
268 } else if (willHaveReadByteCount < framePayloadLength) {
269
270
271 payloadBuffer = buffer.readBytes(rbytes);
272 if (framePayload == null) {
273 framePayload = channel.getConfig().getBufferFactory().getBuffer(toFrameLength(framePayloadLength));
274 }
275 framePayload.writeBytes(payloadBuffer);
276 framePayloadBytesRead += rbytes;
277
278
279 return null;
280 } else if (willHaveReadByteCount > framePayloadLength) {
281
282
283 payloadBuffer = buffer.readBytes(toFrameLength(framePayloadLength - framePayloadBytesRead));
284 }
285
286
287
288 checkpoint(State.FRAME_START);
289
290
291 if (framePayload == null) {
292 framePayload = payloadBuffer;
293 } else {
294 framePayload.writeBytes(payloadBuffer);
295 }
296
297
298 if (maskedPayload) {
299 unmask(framePayload);
300 }
301
302
303
304 if (frameOpcode == OPCODE_PING) {
305 return new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
306 }
307 if (frameOpcode == OPCODE_PONG) {
308 return new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
309 }
310 if (frameOpcode == OPCODE_CLOSE) {
311 checkCloseFrameBody(channel, framePayload);
312 receivedClosingHandshake = true;
313 return new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
314 }
315
316
317
318 String aggregatedText = null;
319 if (frameFinalFlag) {
320
321
322 if (frameOpcode != OPCODE_PING) {
323 fragmentedFramesCount = 0;
324
325
326 if (frameOpcode == OPCODE_TEXT || fragmentedFramesText != null) {
327
328 checkUTF8String(channel, framePayload.array());
329
330
331
332 aggregatedText = fragmentedFramesText.toString();
333
334 fragmentedFramesText = null;
335 }
336 }
337 } else {
338
339
340 if (fragmentedFramesCount == 0) {
341
342 fragmentedFramesText = null;
343 if (frameOpcode == OPCODE_TEXT) {
344 checkUTF8String(channel, framePayload.array());
345 }
346 } else {
347
348 if (fragmentedFramesText != null) {
349 checkUTF8String(channel, framePayload.array());
350 }
351 }
352
353
354 fragmentedFramesCount++;
355 }
356
357
358 if (frameOpcode == OPCODE_TEXT) {
359 return new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
360 } else if (frameOpcode == OPCODE_BINARY) {
361 return new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
362 } else if (frameOpcode == OPCODE_CONT) {
363 return new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload, aggregatedText);
364 } else {
365 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + frameOpcode);
366 }
367 case CORRUPT:
368
369
370 buffer.readByte();
371 return null;
372 default:
373 throw new Error("Shouldn't reach here.");
374 }
375 }
376
377 private void unmask(ChannelBuffer frame) {
378 byte[] bytes = frame.array();
379 for (int i = 0; i < bytes.length; i++) {
380 frame.setByte(i, frame.getByte(i) ^ maskingKey.getByte(i % 4));
381 }
382 }
383
384 private void protocolViolation(Channel channel, String reason) throws CorruptedFrameException {
385 checkpoint(State.CORRUPT);
386 if (channel.isConnected()) {
387 channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
388 }
389 throw new CorruptedFrameException(reason);
390 }
391
392 private static int toFrameLength(long l) throws TooLongFrameException {
393 if (l > Integer.MAX_VALUE) {
394 throw new TooLongFrameException("Length:" + l);
395 } else {
396 return (int) l;
397 }
398 }
399
400 private void checkUTF8String(Channel channel, byte[] bytes) throws CorruptedFrameException {
401 try {
402
403
404
405
406
407
408
409 if (fragmentedFramesText == null) {
410 fragmentedFramesText = new UTF8Output(bytes);
411 } else {
412 fragmentedFramesText.write(bytes);
413 }
414 } catch (UTF8Exception ex) {
415 protocolViolation(channel, "invalid UTF-8 bytes");
416 }
417 }
418
419 protected void checkCloseFrameBody(Channel channel, ChannelBuffer buffer) throws CorruptedFrameException {
420 if (buffer == null || buffer.capacity() == 0) {
421 return;
422 }
423 if (buffer.capacity() == 1) {
424 protocolViolation(channel, "Invalid close frame body");
425 }
426
427
428 int idx = buffer.readerIndex();
429 buffer.readerIndex(0);
430
431
432 int statusCode = buffer.readShort();
433 if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
434 || statusCode >= 1012 && statusCode <= 2999) {
435 protocolViolation(channel, "Invalid close frame status code: " + statusCode);
436 }
437
438
439 if (buffer.readableBytes() > 0) {
440 byte[] b = new byte[buffer.readableBytes()];
441 buffer.readBytes(b);
442 try {
443 new UTF8Output(b);
444 } catch (UTF8Exception ex) {
445 protocolViolation(channel, "Invalid close frame reason text. Invalid UTF-8 bytes");
446 }
447 }
448
449
450 buffer.readerIndex(idx);
451 }
452 }