parent
da68d2957b
commit
21dc6257b8
@ -0,0 +1,694 @@
|
||||
'''
|
||||
The MIT License (MIT)
|
||||
Copyright (c) 2013 Dave P.
|
||||
'''
|
||||
import sys
|
||||
VER = sys.version_info[0]
|
||||
if VER >= 3:
|
||||
import socketserver
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from io import StringIO, BytesIO
|
||||
else:
|
||||
import SocketServer
|
||||
from BaseHTTPServer import BaseHTTPRequestHandler
|
||||
from StringIO import StringIO
|
||||
|
||||
import hashlib
|
||||
import base64
|
||||
import socket
|
||||
import struct
|
||||
import ssl
|
||||
import errno
|
||||
import codecs
|
||||
from collections import deque
|
||||
from select import select
|
||||
|
||||
__all__ = ['WebSocket',
|
||||
'SimpleWebSocketServer',
|
||||
'SimpleSSLWebSocketServer']
|
||||
|
||||
def _check_unicode(val):
|
||||
if VER >= 3:
|
||||
return isinstance(val, str)
|
||||
else:
|
||||
return isinstance(val, unicode)
|
||||
|
||||
class HTTPRequest(BaseHTTPRequestHandler):
|
||||
def __init__(self, request_text):
|
||||
if VER >= 3:
|
||||
self.rfile = BytesIO(request_text)
|
||||
else:
|
||||
self.rfile = StringIO(request_text)
|
||||
self.raw_requestline = self.rfile.readline()
|
||||
self.error_code = self.error_message = None
|
||||
self.parse_request()
|
||||
|
||||
_VALID_STATUS_CODES = [1000, 1001, 1002, 1003, 1007, 1008,
|
||||
1009, 1010, 1011, 3000, 3999, 4000, 4999]
|
||||
|
||||
HANDSHAKE_STR = (
|
||||
"HTTP/1.1 101 Switching Protocols\r\n"
|
||||
"Upgrade: WebSocket\r\n"
|
||||
"Connection: Upgrade\r\n"
|
||||
"Sec-WebSocket-Accept: %(acceptstr)s\r\n\r\n"
|
||||
)
|
||||
|
||||
GUID_STR = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
|
||||
STREAM = 0x0
|
||||
TEXT = 0x1
|
||||
BINARY = 0x2
|
||||
CLOSE = 0x8
|
||||
PING = 0x9
|
||||
PONG = 0xA
|
||||
|
||||
HEADERB1 = 1
|
||||
HEADERB2 = 3
|
||||
LENGTHSHORT = 4
|
||||
LENGTHLONG = 5
|
||||
MASK = 6
|
||||
PAYLOAD = 7
|
||||
|
||||
MAXHEADER = 65536
|
||||
MAXPAYLOAD = 33554432
|
||||
|
||||
class WebSocket(object):
|
||||
|
||||
def __init__(self, server, sock, address):
|
||||
self.server = server
|
||||
self.client = sock
|
||||
self.address = address
|
||||
|
||||
self.handshaked = False
|
||||
self.headerbuffer = bytearray()
|
||||
self.headertoread = 2048
|
||||
|
||||
self.fin = 0
|
||||
self.data = bytearray()
|
||||
self.opcode = 0
|
||||
self.hasmask = 0
|
||||
self.maskarray = None
|
||||
self.length = 0
|
||||
self.lengtharray = None
|
||||
self.index = 0
|
||||
self.request = None
|
||||
self.usingssl = False
|
||||
|
||||
self.frag_start = False
|
||||
self.frag_type = BINARY
|
||||
self.frag_buffer = None
|
||||
self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict')
|
||||
self.closed = False
|
||||
self.sendq = deque()
|
||||
|
||||
self.state = HEADERB1
|
||||
|
||||
# restrict the size of header and payload for security reasons
|
||||
self.maxheader = MAXHEADER
|
||||
self.maxpayload = MAXPAYLOAD
|
||||
|
||||
def handleMessage(self):
|
||||
"""
|
||||
Called when websocket frame is received.
|
||||
To access the frame data call self.data.
|
||||
|
||||
If the frame is Text then self.data is a unicode object.
|
||||
If the frame is Binary then self.data is a bytearray object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def handleConnected(self):
|
||||
"""
|
||||
Called when a websocket client connects to the server.
|
||||
"""
|
||||
pass
|
||||
|
||||
def handleClose(self):
|
||||
"""
|
||||
Called when a websocket server gets a Close frame from a client.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handlePacket(self):
|
||||
if self.opcode == CLOSE:
|
||||
pass
|
||||
elif self.opcode == STREAM:
|
||||
pass
|
||||
elif self.opcode == TEXT:
|
||||
pass
|
||||
elif self.opcode == BINARY:
|
||||
pass
|
||||
elif self.opcode == PONG or self.opcode == PING:
|
||||
if len(self.data) > 125:
|
||||
raise Exception('control frame length can not be > 125')
|
||||
else:
|
||||
# unknown or reserved opcode so just close
|
||||
raise Exception('unknown opcode')
|
||||
|
||||
if self.opcode == CLOSE:
|
||||
status = 1000
|
||||
reason = u''
|
||||
length = len(self.data)
|
||||
|
||||
if length == 0:
|
||||
pass
|
||||
elif length >= 2:
|
||||
status = struct.unpack_from('!H', self.data[:2])[0]
|
||||
reason = self.data[2:]
|
||||
|
||||
if status not in _VALID_STATUS_CODES:
|
||||
status = 1002
|
||||
|
||||
if len(reason) > 0:
|
||||
try:
|
||||
reason = reason.decode('utf8', errors='strict')
|
||||
except:
|
||||
status = 1002
|
||||
else:
|
||||
status = 1002
|
||||
|
||||
self.close(status, reason)
|
||||
return
|
||||
|
||||
elif self.fin == 0:
|
||||
if self.opcode != STREAM:
|
||||
if self.opcode == PING or self.opcode == PONG:
|
||||
raise Exception('control messages can not be fragmented')
|
||||
|
||||
self.frag_type = self.opcode
|
||||
self.frag_start = True
|
||||
self.frag_decoder.reset()
|
||||
|
||||
if self.frag_type == TEXT:
|
||||
self.frag_buffer = []
|
||||
utf_str = self.frag_decoder.decode(self.data, final = False)
|
||||
if utf_str:
|
||||
self.frag_buffer.append(utf_str)
|
||||
else:
|
||||
self.frag_buffer = bytearray()
|
||||
self.frag_buffer.extend(self.data)
|
||||
|
||||
else:
|
||||
if self.frag_start is False:
|
||||
raise Exception('fragmentation protocol error')
|
||||
|
||||
if self.frag_type == TEXT:
|
||||
utf_str = self.frag_decoder.decode(self.data, final = False)
|
||||
if utf_str:
|
||||
self.frag_buffer.append(utf_str)
|
||||
else:
|
||||
self.frag_buffer.extend(self.data)
|
||||
|
||||
else:
|
||||
if self.opcode == STREAM:
|
||||
if self.frag_start is False:
|
||||
raise Exception('fragmentation protocol error')
|
||||
|
||||
if self.frag_type == TEXT:
|
||||
utf_str = self.frag_decoder.decode(self.data, final = True)
|
||||
self.frag_buffer.append(utf_str)
|
||||
self.data = u''.join(self.frag_buffer)
|
||||
else:
|
||||
self.frag_buffer.extend(self.data)
|
||||
self.data = self.frag_buffer
|
||||
|
||||
self.handleMessage()
|
||||
|
||||
self.frag_decoder.reset()
|
||||
self.frag_type = BINARY
|
||||
self.frag_start = False
|
||||
self.frag_buffer = None
|
||||
|
||||
elif self.opcode == PING:
|
||||
self._sendMessage(False, PONG, self.data)
|
||||
|
||||
elif self.opcode == PONG:
|
||||
pass
|
||||
|
||||
else:
|
||||
if self.frag_start is True:
|
||||
raise Exception('fragmentation protocol error')
|
||||
|
||||
if self.opcode == TEXT:
|
||||
try:
|
||||
self.data = self.data.decode('utf8', errors='strict')
|
||||
except Exception as exp:
|
||||
raise Exception('invalid utf-8 payload')
|
||||
|
||||
self.handleMessage()
|
||||
|
||||
|
||||
def _handleData(self):
|
||||
# do the HTTP header and handshake
|
||||
if self.handshaked is False:
|
||||
|
||||
data = self.client.recv(self.headertoread)
|
||||
if not data:
|
||||
raise Exception('remote socket closed')
|
||||
|
||||
else:
|
||||
# accumulate
|
||||
self.headerbuffer.extend(data)
|
||||
|
||||
if len(self.headerbuffer) >= self.maxheader:
|
||||
raise Exception('header exceeded allowable size')
|
||||
|
||||
# indicates end of HTTP header
|
||||
if b'\r\n\r\n' in self.headerbuffer:
|
||||
self.request = HTTPRequest(self.headerbuffer)
|
||||
|
||||
# handshake rfc 6455
|
||||
try:
|
||||
key = self.request.headers['Sec-WebSocket-Key']
|
||||
k = key.encode('ascii') + GUID_STR.encode('ascii')
|
||||
k_s = base64.b64encode(hashlib.sha1(k).digest()).decode('ascii')
|
||||
hStr = HANDSHAKE_STR % {'acceptstr': k_s}
|
||||
self.sendq.append((BINARY, hStr.encode('ascii')))
|
||||
self.handshaked = True
|
||||
self.handleConnected()
|
||||
except Exception as e:
|
||||
raise Exception('handshake failed: %s', str(e))
|
||||
|
||||
# else do normal data
|
||||
else:
|
||||
data = self.client.recv(8192)
|
||||
if not data:
|
||||
raise Exception("remote socket closed")
|
||||
|
||||
if VER >= 3:
|
||||
for d in data:
|
||||
self._parseMessage(d)
|
||||
else:
|
||||
for d in data:
|
||||
self._parseMessage(ord(d))
|
||||
|
||||
def close(self, status = 1000, reason = u''):
|
||||
"""
|
||||
Send Close frame to the client. The underlying socket is only closed
|
||||
when the client acknowledges the Close frame.
|
||||
|
||||
status is the closing identifier.
|
||||
reason is the reason for the close.
|
||||
"""
|
||||
try:
|
||||
if self.closed is False:
|
||||
close_msg = bytearray()
|
||||
close_msg.extend(struct.pack("!H", status))
|
||||
if _check_unicode(reason):
|
||||
close_msg.extend(reason.encode('utf-8'))
|
||||
else:
|
||||
close_msg.extend(reason)
|
||||
|
||||
self._sendMessage(False, CLOSE, close_msg)
|
||||
|
||||
finally:
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _sendBuffer(self, buff):
|
||||
size = len(buff)
|
||||
tosend = size
|
||||
already_sent = 0
|
||||
|
||||
while tosend > 0:
|
||||
try:
|
||||
# i should be able to send a bytearray
|
||||
sent = self.client.send(buff[already_sent:])
|
||||
if sent == 0:
|
||||
raise RuntimeError('socket connection broken')
|
||||
|
||||
already_sent += sent
|
||||
tosend -= sent
|
||||
|
||||
except socket.error as e:
|
||||
# if we have full buffers then wait for them to drain and try again
|
||||
if e.errno in [errno.EAGAIN, errno.EWOULDBLOCK]:
|
||||
return buff[already_sent:]
|
||||
else:
|
||||
raise e
|
||||
|
||||
return None
|
||||
|
||||
def sendFragmentStart(self, data):
|
||||
"""
|
||||
Send the start of a data fragment stream to a websocket client.
|
||||
Subsequent data should be sent using sendFragment().
|
||||
A fragment stream is completed when sendFragmentEnd() is called.
|
||||
|
||||
If data is a unicode object then the frame is sent as Text.
|
||||
If the data is a bytearray object then the frame is sent as Binary.
|
||||
"""
|
||||
opcode = BINARY
|
||||
if _check_unicode(data):
|
||||
opcode = TEXT
|
||||
self._sendMessage(True, opcode, data)
|
||||
|
||||
def sendFragment(self, data):
|
||||
"""
|
||||
see sendFragmentStart()
|
||||
|
||||
If data is a unicode object then the frame is sent as Text.
|
||||
If the data is a bytearray object then the frame is sent as Binary.
|
||||
"""
|
||||
self._sendMessage(True, STREAM, data)
|
||||
|
||||
def sendFragmentEnd(self, data):
|
||||
"""
|
||||
see sendFragmentEnd()
|
||||
|
||||
If data is a unicode object then the frame is sent as Text.
|
||||
If the data is a bytearray object then the frame is sent as Binary.
|
||||
"""
|
||||
self._sendMessage(False, STREAM, data)
|
||||
|
||||
def sendMessage(self, data):
|
||||
"""
|
||||
Send websocket data frame to the client.
|
||||
|
||||
If data is a unicode object then the frame is sent as Text.
|
||||
If the data is a bytearray object then the frame is sent as Binary.
|
||||
"""
|
||||
opcode = BINARY
|
||||
if _check_unicode(data):
|
||||
opcode = TEXT
|
||||
self._sendMessage(False, opcode, data)
|
||||
|
||||
|
||||
def _sendMessage(self, fin, opcode, data):
|
||||
|
||||
payload = bytearray()
|
||||
|
||||
b1 = 0
|
||||
b2 = 0
|
||||
if fin is False:
|
||||
b1 |= 0x80
|
||||
b1 |= opcode
|
||||
|
||||
if _check_unicode(data):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
length = len(data)
|
||||
payload.append(b1)
|
||||
|
||||
if length <= 125:
|
||||
b2 |= length
|
||||
payload.append(b2)
|
||||
|
||||
elif length >= 126 and length <= 65535:
|
||||
b2 |= 126
|
||||
payload.append(b2)
|
||||
payload.extend(struct.pack("!H", length))
|
||||
|
||||
else:
|
||||
b2 |= 127
|
||||
payload.append(b2)
|
||||
payload.extend(struct.pack("!Q", length))
|
||||
|
||||
if length > 0:
|
||||
payload.extend(data)
|
||||
|
||||
self.sendq.append((opcode, payload))
|
||||
|
||||
|
||||
def _parseMessage(self, byte):
|
||||
# read in the header
|
||||
if self.state == HEADERB1:
|
||||
|
||||
self.fin = byte & 0x80
|
||||
self.opcode = byte & 0x0F
|
||||
self.state = HEADERB2
|
||||
|
||||
self.index = 0
|
||||
self.length = 0
|
||||
self.lengtharray = bytearray()
|
||||
self.data = bytearray()
|
||||
|
||||
rsv = byte & 0x70
|
||||
if rsv != 0:
|
||||
raise Exception('RSV bit must be 0')
|
||||
|
||||
elif self.state == HEADERB2:
|
||||
mask = byte & 0x80
|
||||
length = byte & 0x7F
|
||||
|
||||
if self.opcode == PING and length > 125:
|
||||
raise Exception('ping packet is too large')
|
||||
|
||||
if mask == 128:
|
||||
self.hasmask = True
|
||||
else:
|
||||
self.hasmask = False
|
||||
|
||||
if length <= 125:
|
||||
self.length = length
|
||||
|
||||
# if we have a mask we must read it
|
||||
if self.hasmask is True:
|
||||
self.maskarray = bytearray()
|
||||
self.state = MASK
|
||||
else:
|
||||
# if there is no mask and no payload we are done
|
||||
if self.length <= 0:
|
||||
try:
|
||||
self._handlePacket()
|
||||
finally:
|
||||
self.state = self.HEADERB1
|
||||
self.data = bytearray()
|
||||
|
||||
# we have no mask and some payload
|
||||
else:
|
||||
#self.index = 0
|
||||
self.data = bytearray()
|
||||
self.state = PAYLOAD
|
||||
|
||||
elif length == 126:
|
||||
self.lengtharray = bytearray()
|
||||
self.state = LENGTHSHORT
|
||||
|
||||
elif length == 127:
|
||||
self.lengtharray = bytearray()
|
||||
self.state = LENGTHLONG
|
||||
|
||||
|
||||
elif self.state == LENGTHSHORT:
|
||||
self.lengtharray.append(byte)
|
||||
|
||||
if len(self.lengtharray) > 2:
|
||||
raise Exception('short length exceeded allowable size')
|
||||
|
||||
if len(self.lengtharray) == 2:
|
||||
self.length = struct.unpack_from('!H', self.lengtharray)[0]
|
||||
|
||||
if self.hasmask is True:
|
||||
self.maskarray = bytearray()
|
||||
self.state = MASK
|
||||
else:
|
||||
# if there is no mask and no payload we are done
|
||||
if self.length <= 0:
|
||||
try:
|
||||
self._handlePacket()
|
||||
finally:
|
||||
self.state = HEADERB1
|
||||
self.data = bytearray()
|
||||
|
||||
# we have no mask and some payload
|
||||
else:
|
||||
#self.index = 0
|
||||
self.data = bytearray()
|
||||
self.state = PAYLOAD
|
||||
|
||||
elif self.state == LENGTHLONG:
|
||||
|
||||
self.lengtharray.append(byte)
|
||||
|
||||
if len(self.lengtharray) > 8:
|
||||
raise Exception('long length exceeded allowable size')
|
||||
|
||||
if len(self.lengtharray) == 8:
|
||||
self.length = struct.unpack_from('!Q', self.lengtharray)[0]
|
||||
|
||||
if self.hasmask is True:
|
||||
self.maskarray = bytearray()
|
||||
self.state = MASK
|
||||
else:
|
||||
# if there is no mask and no payload we are done
|
||||
if self.length <= 0:
|
||||
try:
|
||||
self._handlePacket()
|
||||
finally:
|
||||
self.state = HEADERB1
|
||||
self.data = bytearray()
|
||||
|
||||
# we have no mask and some payload
|
||||
else:
|
||||
#self.index = 0
|
||||
self.data = bytearray()
|
||||
self.state = PAYLOAD
|
||||
|
||||
# MASK STATE
|
||||
elif self.state == MASK:
|
||||
self.maskarray.append(byte)
|
||||
|
||||
if len(self.maskarray) > 4:
|
||||
raise Exception('mask exceeded allowable size')
|
||||
|
||||
if len(self.maskarray) == 4:
|
||||
# if there is no mask and no payload we are done
|
||||
if self.length <= 0:
|
||||
try:
|
||||
self._handlePacket()
|
||||
finally:
|
||||
self.state = HEADERB1
|
||||
self.data = bytearray()
|
||||
|
||||
# we have no mask and some payload
|
||||
else:
|
||||
#self.index = 0
|
||||
self.data = bytearray()
|
||||
self.state = PAYLOAD
|
||||
|
||||
# PAYLOAD STATE
|
||||
elif self.state == PAYLOAD:
|
||||
if self.hasmask is True:
|
||||
self.data.append( byte ^ self.maskarray[self.index % 4] )
|
||||
else:
|
||||
self.data.append( byte )
|
||||
|
||||
# if length exceeds allowable size then we except and remove the connection
|
||||
if len(self.data) >= self.maxpayload:
|
||||
raise Exception('payload exceeded allowable size')
|
||||
|
||||
# check if we have processed length bytes; if so we are done
|
||||
if (self.index+1) == self.length:
|
||||
try:
|
||||
self._handlePacket()
|
||||
finally:
|
||||
#self.index = 0
|
||||
self.state = HEADERB1
|
||||
self.data = bytearray()
|
||||
else:
|
||||
self.index += 1
|
||||
|
||||
|
||||
class SimpleWebSocketServer(object):
|
||||
def __init__(self, host, port, websocketclass, selectInterval = 0.1):
|
||||
self.websocketclass = websocketclass
|
||||
self.serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.serversocket.bind((host, port))
|
||||
self.serversocket.listen(5)
|
||||
self.selectInterval = selectInterval
|
||||
self.connections = {}
|
||||
self.listeners = [self.serversocket]
|
||||
|
||||
def _decorateSocket(self, sock):
|
||||
return sock
|
||||
|
||||
def _constructWebSocket(self, sock, address):
|
||||
return self.websocketclass(self, sock, address)
|
||||
|
||||
def close(self):
|
||||
self.serversocket.close()
|
||||
|
||||
for desc, conn in self.connections.items():
|
||||
conn.close()
|
||||
conn.handleClose()
|
||||
|
||||
|
||||
def serveforever(self):
|
||||
while True:
|
||||
writers = []
|
||||
for fileno in self.listeners:
|
||||
if fileno == self.serversocket:
|
||||
continue
|
||||
client = self.connections[fileno]
|
||||
if client.sendq:
|
||||
writers.append(fileno)
|
||||
|
||||
if self.selectInterval:
|
||||
rList, wList, xList = select(self.listeners, writers, self.listeners, self.selectInterval)
|
||||
else:
|
||||
rList, wList, xList = select(self.listeners, writers, self.listeners)
|
||||
|
||||
for ready in wList:
|
||||
client = self.connections[ready]
|
||||
try:
|
||||
while client.sendq:
|
||||
opcode, payload = client.sendq.popleft()
|
||||
remaining = client._sendBuffer(payload)
|
||||
if remaining is not None:
|
||||
client.sendq.appendleft((opcode, remaining))
|
||||
break
|
||||
else:
|
||||
if opcode == CLOSE:
|
||||
raise Exception('received client close')
|
||||
|
||||
except Exception as n:
|
||||
client.client.close()
|
||||
client.handleClose()
|
||||
del self.connections[ready]
|
||||
self.listeners.remove(ready)
|
||||
|
||||
for ready in rList:
|
||||
if ready == self.serversocket:
|
||||
try:
|
||||
sock, address = self.serversocket.accept()
|
||||
newsock = self._decorateSocket(sock)
|
||||
newsock.setblocking(0)
|
||||
fileno = newsock.fileno()
|
||||
self.connections[fileno] = self._constructWebSocket(newsock, address)
|
||||
self.listeners.append(fileno)
|
||||
except Exception as n:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
else:
|
||||
if ready not in self.connections:
|
||||
continue
|
||||
client = self.connections[ready]
|
||||
try:
|
||||
client._handleData()
|
||||
except Exception as n:
|
||||
client.client.close()
|
||||
client.handleClose()
|
||||
del self.connections[ready]
|
||||
self.listeners.remove(ready)
|
||||
|
||||
for failed in xList:
|
||||
if failed == self.serversocket:
|
||||
self.close()
|
||||
raise Exception('server socket failed')
|
||||
else:
|
||||
if failed not in self.connections:
|
||||
continue
|
||||
client = self.connections[failed]
|
||||
client.client.close()
|
||||
client.handleClose()
|
||||
del self.connections[failed]
|
||||
self.listeners.remove(failed)
|
||||
|
||||
|
||||
class SimpleSSLWebSocketServer(SimpleWebSocketServer):
|
||||
|
||||
def __init__(self, host, port, websocketclass, certfile,
|
||||
keyfile, version = ssl.PROTOCOL_TLSv1, selectInterval = 0.1):
|
||||
|
||||
SimpleWebSocketServer.__init__(self, host, port,
|
||||
websocketclass, selectInterval)
|
||||
|
||||
self.context = ssl.SSLContext(version)
|
||||
self.context.load_cert_chain(certfile, keyfile)
|
||||
|
||||
def close(self):
|
||||
super(SimpleSSLWebSocketServer, self).close()
|
||||
|
||||
def _decorateSocket(self, sock):
|
||||
sslsock = self.context.wrap_socket(sock, server_side=True)
|
||||
return sslsock
|
||||
|
||||
def _constructWebSocket(self, sock, address):
|
||||
ws = self.websocketclass(self, sock, address)
|
||||
ws.usingssl = True
|
||||
return ws
|
||||
|
||||
def serveforever(self):
|
||||
super(SimpleSSLWebSocketServer, self).serveforever()
|
||||
Loading…
Reference in new issue