Bug 710345: Upgrade pywebsocket to v606 (RFC 6455). r=mcmanus

This commit is contained in:
Jason Duell 2011-12-20 00:20:00 -08:00
parent 6dfec8a582
commit 2a58edc446
14 changed files with 464 additions and 175 deletions

View File

@ -65,6 +65,12 @@ Installation:
PythonOption mod_pywebsocket.allow_draft75 On
If you want to allow handlers whose canonical path is not under the root
directory (i.e. symbolic link is in root directory but its target is not),
configure as follows:
PythonOption mod_pywebsocket.allow_handlers_outside_root_dir On
Example snippet of httpd.conf:
(mod_pywebsocket is in /websock_lib, WebSocket handlers are in
/websock_handlers, port is 80 for ws, 443 for wss.)

View File

@ -111,14 +111,9 @@ class StreamBase(object):
bytes = self._request.connection.read(length)
if not bytes:
# MOZILLA: Patrick McManus found we needed this for Python 2.5 to
# work. Not sure which tests he meant: I found that
# content/base/test/test_websocket* all worked fine with 2.5 with
# the original Google code. JDuell
#raise ConnectionTerminatedException(
# 'Receiving %d byte failed. Peer (%r) closed connection' %
# (length, (self._request.connection.remote_addr,)))
raise ConnectionTerminatedException('connection terminated: read failed')
raise ConnectionTerminatedException(
'Receiving %d byte failed. Peer (%r) closed connection' %
(length, (self._request.connection.remote_addr,)))
return bytes
def _write(self, bytes):

View File

@ -298,9 +298,11 @@ class Stream(StreamBase):
'Mask bit on the received frame did\'nt match masking '
'configuration for received frames')
# The spec doesn't disallow putting a value in 0x0-0xFFFF into the
# 8-octet extended payload length field (or 0x0-0xFD in 2-octet field).
# So, we don't check the range of extended_payload_length.
# The Hybi-13 and later specs disallow putting a value in 0x0-0xFFFF
# into the 8-octet extended payload length field (or 0x0-0xFD in
# 2-octet field).
valid_length_encoding = True
length_encoding_bytes = 1
if payload_length == 127:
extended_payload_length = self.receive_bytes(8)
payload_length = struct.unpack(
@ -308,10 +310,23 @@ class Stream(StreamBase):
if payload_length > 0x7FFFFFFFFFFFFFFF:
raise InvalidFrameException(
'Extended payload length >= 2^63')
if self._request.ws_version >= 13 and payload_length < 0x10000:
valid_length_encoding = False
length_encoding_bytes = 8
elif payload_length == 126:
extended_payload_length = self.receive_bytes(2)
payload_length = struct.unpack(
'!H', extended_payload_length)[0]
if self._request.ws_version >= 13 and payload_length < 126:
valid_length_encoding = False
length_encoding_bytes = 2
if not valid_length_encoding:
self._logger.warning(
'Payload length is not encoded using the minimal number of '
'bytes (%d is encoded using %d bytes)',
payload_length,
length_encoding_bytes)
if mask == 1:
masking_nonce = self.receive_bytes(4)

View File

@ -41,9 +41,16 @@ VERSION_HYBI07 = 7
VERSION_HYBI08 = 8
VERSION_HYBI09 = 8
VERSION_HYBI10 = 8
VERSION_HYBI11 = 8
VERSION_HYBI12 = 8
VERSION_HYBI13 = 13
VERSION_HYBI14 = 13
VERSION_HYBI15 = 13
VERSION_HYBI16 = 13
VERSION_HYBI17 = 13
# Constants indicating WebSocket protocol latest version.
VERSION_HYBI_LATEST = VERSION_HYBI10
VERSION_HYBI_LATEST = VERSION_HYBI13
# Port numbers
DEFAULT_WEB_SOCKET_PORT = 80
@ -95,10 +102,17 @@ STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED = 1003
STATUS_TOO_LARGE = 1004
STATUS_CODE_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSE = 1006
STATUS_INVALID_UTF8 = 1007
STATUS_INVALID_FRAME_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_MANDATORY_EXT = 1010
# HTTP status codes
HTTP_STATUS_BAD_REQUEST = 400
HTTP_STATUS_FORBIDDEN = 403
HTTP_STATUS_NOT_FOUND = 404
def is_control_opcode(opcode):

View File

@ -54,13 +54,14 @@ _PASSIVE_CLOSING_HANDSHAKE_HANDLER_NAME = (
class DispatchException(Exception):
"""Exception in dispatching WebSocket request."""
def __init__(self, name, status=404):
def __init__(self, name, status=common.HTTP_STATUS_NOT_FOUND):
super(DispatchException, self).__init__(name)
self.status = status
def _default_passive_closing_handshake_handler(request):
"""Default web_socket_passive_closing_handshake handler."""
return common.STATUS_NORMAL, ''
@ -76,15 +77,21 @@ def _normalize_path(path):
"""
path = path.replace('\\', os.path.sep)
# MOZILLA: do not normalize away symlinks in mochitest
#path = os.path.realpath(path)
path = os.path.realpath(path)
path = path.replace('\\', '/')
return path
def _create_path_to_resource_converter(base_dir):
"""Returns a function that converts the path of a WebSocket handler source
file to a resource string by removing the path to the base directory from
its head, removing _SOURCE_SUFFIX from its tail, and replacing path
separators in it with '/'.
Args:
base_dir: the path to the base directory.
"""
base_dir = _normalize_path(base_dir)
base_len = len(base_dir)
@ -93,7 +100,9 @@ def _create_path_to_resource_converter(base_dir):
def converter(path):
if not path.endswith(_SOURCE_SUFFIX):
return None
path = _normalize_path(path)
# _normalize_path must not be used because resolving symlink breaks
# following path check.
path = path.replace('\\', '/')
if not path.startswith(base_dir):
return None
return path[base_len:-suffix_len]
@ -169,7 +178,9 @@ class Dispatcher(object):
This class maintains a map from resource name to handlers.
"""
def __init__(self, root_dir, scan_dir=None):
def __init__(
self, root_dir, scan_dir=None,
allow_handlers_outside_root_dir=True):
"""Construct an instance.
Args:
@ -181,6 +192,8 @@ class Dispatcher(object):
root_dir is used as scan_dir. scan_dir can be useful
in saving scan time when root_dir contains many
subdirectories.
allow_handlers_outside_root_dir: Scans handler files even if their
canonical path is not under root_dir.
"""
self._logger = util.get_class_logger(self)
@ -193,7 +206,8 @@ class Dispatcher(object):
os.path.realpath(root_dir)):
raise DispatchException('scan_dir:%s must be a directory under '
'root_dir:%s.' % (scan_dir, root_dir))
self._source_handler_files_in_dir(root_dir, scan_dir)
self._source_handler_files_in_dir(
root_dir, scan_dir, allow_handlers_outside_root_dir)
def add_resource_path_alias(self,
alias_resource_path, existing_resource_path):
@ -247,7 +261,7 @@ class Dispatcher(object):
_DO_EXTRA_HANDSHAKE_HANDLER_NAME,
request.ws_resource),
e)
raise handshake.HandshakeException(e, 403)
raise handshake.HandshakeException(e, common.HTTP_STATUS_FORBIDDEN)
def transfer_data(self, request):
"""Let a handler transfer_data with a WebSocket client.
@ -288,8 +302,9 @@ class Dispatcher(object):
self._logger.debug('%s', e)
request.ws_stream.close_connection(common.STATUS_UNSUPPORTED)
except stream.InvalidUTF8Exception, e:
self._logger_debug('%s', e)
request.ws_stream.close_connection(common.STATUS_INVALID_UTF8)
self._logger.debug('%s', e)
request.ws_stream.close_connection(
common.STATUS_INVALID_FRAME_PAYLOAD)
except msgutil.ConnectionTerminatedException, e:
self._logger.debug('%s', e)
except Exception, e:
@ -322,22 +337,44 @@ class Dispatcher(object):
handler_suite = self._handler_suite_map.get(resource)
if handler_suite and fragment:
raise DispatchException('Fragment identifiers MUST NOT be used on '
'WebSocket URIs', 400);
'WebSocket URIs',
common.HTTP_STATUS_BAD_REQUEST)
return handler_suite
def _source_handler_files_in_dir(self, root_dir, scan_dir):
def _source_handler_files_in_dir(
self, root_dir, scan_dir, allow_handlers_outside_root_dir):
"""Source all the handler source files in the scan_dir directory.
The resource path is determined relative to root_dir.
"""
# We build a map from resource to handler code assuming that there's
# only one path from root_dir to scan_dir and it can be obtained by
# comparing realpath of them.
# Here we cannot use abspath. See
# https://bugs.webkit.org/show_bug.cgi?id=31603
convert = _create_path_to_resource_converter(root_dir)
for path in _enumerate_handler_file_paths(scan_dir):
scan_realpath = os.path.realpath(scan_dir)
root_realpath = os.path.realpath(root_dir)
for path in _enumerate_handler_file_paths(scan_realpath):
if (not allow_handlers_outside_root_dir and
(not os.path.realpath(path).startswith(root_realpath))):
self._logger.debug(
'Canonical path of %s is not under root directory' %
path)
continue
try:
handler_suite = _source_handler_file(open(path).read())
except DispatchException, e:
self._source_warnings.append('%s: %s' % (path, e))
continue
resource = convert(path)
if resource is None:
self._logger.debug(
'Path to resource conversion on %s failed' % path)
else:
self._handler_suite_map[convert(path)] = handler_suite

View File

@ -36,6 +36,7 @@ _available_processors = {}
class ExtensionProcessorInterface(object):
def get_extension_response(self):
return None
@ -131,7 +132,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
return response
def setup_stream_options(self, stream_options):
class _OutgoingFilter(object):
def __init__(self, parent):
self._parent = parent
@ -139,6 +142,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
self._parent._outgoing_filter(frame)
class _IncomingFilter(object):
def __init__(self, parent):
self._parent = parent

View File

@ -36,12 +36,15 @@ successfully established.
import logging
from mod_pywebsocket import common
from mod_pywebsocket.handshake import draft75
from mod_pywebsocket.handshake import hybi00
from mod_pywebsocket.handshake import hybi
# Export AbortedByUserException and HandshakeException symbol from this module.
# Export AbortedByUserException, HandshakeException, and VersionException
# symbol from this module.
from mod_pywebsocket.handshake._base import AbortedByUserException
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import VersionException
_LOGGER = logging.getLogger(__name__)
@ -62,7 +65,7 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False):
handshake.
"""
_LOGGER.debug('Opening handshake resource: %r', request.uri)
_LOGGER.debug('Client\'s opening handshake resource: %r', request.uri)
# To print mimetools.Message as escaped one-line string, we converts
# headers_in to dict object. Without conversion, if we use %r, it just
# prints the type and address, and if we use %s, it prints the original
@ -76,7 +79,7 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False):
# header values. While MpTable_Type doesn't have such __str__ but just
# __repr__ which formats itself as well as dictionary object.
_LOGGER.debug(
'Opening handshake request headers: %r', dict(request.headers_in))
'Client\'s opening handshake headers: %r', dict(request.headers_in))
handshakers = []
handshakers.append(
@ -88,21 +91,26 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False):
('IETF Hixie 75', draft75.Handshaker(request, dispatcher, strict)))
for name, handshaker in handshakers:
_LOGGER.info('Trying %s protocol', name)
_LOGGER.debug('Trying %s protocol', name)
try:
handshaker.do_handshake()
_LOGGER.info('Established (%s protocol)', name)
return
except HandshakeException, e:
_LOGGER.info(
_LOGGER.debug(
'Failed to complete opening handshake as %s protocol: %r',
name, e)
if e.status:
raise e
except AbortedByUserException, e:
raise
except VersionException, e:
raise
# TODO(toyoshim): Add a test to cover the case all handshakers fail.
raise HandshakeException(
'Failed to complete opening handshake for all available protocols')
'Failed to complete opening handshake for all available protocols',
status=common.HTTP_STATUS_BAD_REQUEST)
# vi:sts=4 sw=4 et

View File

@ -61,6 +61,22 @@ class HandshakeException(Exception):
self.status = status
class VersionException(Exception):
"""This exception will be raised when a version of client request does not
match with version the server supports.
"""
def __init__(self, name, supported_versions=''):
"""Construct an instance.
Args:
supported_version: a str object to show supported hybi versions.
(e.g. '8, 13')
"""
super(VersionException, self).__init__(name)
self.supported_versions = supported_versions
def get_default_port(is_secure):
if is_secure:
return common.DEFAULT_WEB_SOCKET_SECURE_PORT
@ -200,7 +216,7 @@ def parse_token_list(data):
return token_list
def _parse_extension_param(state, definition):
def _parse_extension_param(state, definition, allow_quoted_string):
param_name = http_header_util.consume_token(state)
if param_name is None:
@ -214,7 +230,11 @@ def _parse_extension_param(state, definition):
http_header_util.consume_lwses(state)
if allow_quoted_string:
# TODO(toyoshim): Add code to validate that parsed param_value is token
param_value = http_header_util.consume_token_or_quoted_string(state)
else:
param_value = http_header_util.consume_token(state)
if param_value is None:
raise HandshakeException(
'No valid parameter value found on the right-hand side of '
@ -223,7 +243,7 @@ def _parse_extension_param(state, definition):
definition.add_parameter(param_name, param_value)
def _parse_extension(state):
def _parse_extension(state, allow_quoted_string):
extension_token = http_header_util.consume_token(state)
if extension_token is None:
return None
@ -239,7 +259,7 @@ def _parse_extension(state):
http_header_util.consume_lwses(state)
try:
_parse_extension_param(state, extension)
_parse_extension_param(state, extension, allow_quoted_string)
except HandshakeException, e:
raise HandshakeException(
'Failed to parse Sec-WebSocket-Extensions header: '
@ -249,7 +269,7 @@ def _parse_extension(state):
return extension
def parse_extensions(data):
def parse_extensions(data, allow_quoted_string=False):
"""Parses Sec-WebSocket-Extensions header value returns a list of
common.ExtensionParameter objects.
@ -260,7 +280,7 @@ def parse_extensions(data):
extension_list = []
while True:
extension = _parse_extension(state)
extension = _parse_extension(state, allow_quoted_string)
if extension is not None:
extension_list.append(extension)

View File

@ -53,6 +53,7 @@ from mod_pywebsocket.handshake._base import parse_extensions
from mod_pywebsocket.handshake._base import parse_token_list
from mod_pywebsocket.handshake._base import validate_mandatory_header
from mod_pywebsocket.handshake._base import validate_subprotocol
from mod_pywebsocket.handshake._base import VersionException
from mod_pywebsocket.stream import Stream
from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket import util
@ -60,6 +61,16 @@ from mod_pywebsocket import util
_BASE64_REGEX = re.compile('^[+/0-9A-Za-z]*=*$')
# Defining aliases for values used frequently.
_VERSION_HYBI08 = common.VERSION_HYBI08
_VERSION_HYBI08_STRING = str(_VERSION_HYBI08)
_VERSION_LATEST = common.VERSION_HYBI_LATEST
_VERSION_LATEST_STRING = str(_VERSION_LATEST)
_SUPPORTED_VERSIONS = [
_VERSION_LATEST,
_VERSION_HYBI08,
]
def compute_accept(key):
"""Computes value for the Sec-WebSocket-Accept header from value of the
@ -130,7 +141,7 @@ class Handshaker(object):
unused_host = get_mandatory_header(self._request, common.HOST_HEADER)
self._check_version()
self._request.ws_version = self._check_version()
# This handshake must be based on latest hybi. We are responsible to
# fallback to HTTP on handshake failure as latest hybi handshake
@ -151,7 +162,6 @@ class Handshaker(object):
util.hexify(accept_binary))
self._logger.debug('IETF HyBi protocol')
self._request.ws_version = common.VERSION_HYBI_LATEST
# Setup extension processors.
@ -212,29 +222,42 @@ class Handshaker(object):
'request any subprotocol')
self._send_handshake(accept)
self._logger.debug('Sent opening handshake response')
except HandshakeException, e:
if not e.status:
# Fallback to 400 bad request by default.
e.status = 400
e.status = common.HTTP_STATUS_BAD_REQUEST
raise e
def _get_origin(self):
origin = self._request.headers_in.get(
common.SEC_WEBSOCKET_ORIGIN_HEADER)
if self._request.ws_version is _VERSION_HYBI08:
origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
else:
origin_header = common.ORIGIN_HEADER
origin = self._request.headers_in.get(origin_header)
if origin is None:
self._logger.debug('Client request does not have origin header')
self._request.ws_origin = origin
def _check_version(self):
unused_value = validate_mandatory_header(
self._request, common.SEC_WEBSOCKET_VERSION_HEADER,
str(common.VERSION_HYBI_LATEST), fail_status=426)
version = get_mandatory_header(self._request,
common.SEC_WEBSOCKET_VERSION_HEADER)
if version == _VERSION_HYBI08_STRING:
return _VERSION_HYBI08
if version == _VERSION_LATEST_STRING:
return _VERSION_LATEST
if version.find(',') >= 0:
raise HandshakeException(
'Multiple versions (%r) are not allowed for header %s' %
(version, common.SEC_WEBSOCKET_VERSION_HEADER),
status=common.HTTP_STATUS_BAD_REQUEST)
raise VersionException(
'Unsupported version %r for header %s' %
(version, common.SEC_WEBSOCKET_VERSION_HEADER),
supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))
def _set_protocol(self):
self._request.ws_protocol = None
# MOZILLA
self._request.sts = None
# /MOZILLA
protocol_header = self._request.headers_in.get(
common.SEC_WEBSOCKET_PROTOCOL_HEADER)
@ -255,8 +278,12 @@ class Handshaker(object):
self._request.ws_requested_extensions = None
return
if self._request.ws_version is common.VERSION_HYBI08:
allow_quoted_string=False
else:
allow_quoted_string=True
self._request.ws_requested_extensions = parse_extensions(
extensions_header)
extensions_header, allow_quoted_string=allow_quoted_string)
self._logger.debug(
'Extensions requested: %r',
@ -264,6 +291,11 @@ class Handshaker(object):
self._request.ws_requested_extensions))
def _validate_key(self, key):
if key.find(',') >= 0:
raise HandshakeException('Request has multiple %s header lines or '
'contains illegal character \',\': %r' %
(common.SEC_WEBSOCKET_KEY_HEADER, key))
# Validate
key_is_valid = False
try:
@ -319,16 +351,12 @@ class Handshaker(object):
response.append(format_header(
common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
format_extensions(self._request.ws_extensions)))
# MOZILLA: Add HSTS header if requested to
if self._request.sts is not None:
response.append(format_header("Strict-Transport-Security",
self._request.sts))
# /MOZILLA
response.append('\r\n')
raw_response = ''.join(response)
self._logger.debug('Opening handshake response: %r', raw_response)
self._request.connection.write(raw_response)
self._logger.debug('Sent server\'s opening handshake: %r',
raw_response)
# vi:sts=4 sw=4 et

View File

@ -107,8 +107,6 @@ class Handshaker(object):
self._send_handshake()
self._logger.debug('Sent opening handshake response')
def _set_resource(self):
self._request.ws_resource = self._request.uri
@ -138,7 +136,8 @@ class Handshaker(object):
draft = self._request.headers_in.get(common.SEC_WEBSOCKET_DRAFT_HEADER)
if draft is not None and draft != '0':
raise HandshakeException('Illegal value for %s: %s' %
(common.SEC_WEBSOCKET_DRAFT_HEADER, draft))
(common.SEC_WEBSOCKET_DRAFT_HEADER,
draft))
self._logger.debug('IETF HyBi 00 protocol')
self._request.ws_version = common.VERSION_HYBI00
@ -229,8 +228,9 @@ class Handshaker(object):
response.append(self._request.ws_challenge_md5)
raw_response = ''.join(response)
self._logger.debug('Opening handshake response: %r', raw_response)
self._request.connection.write(raw_response)
self._logger.debug('Sent server\'s opening handshake: %r',
raw_response)
# vi:sts=4 sw=4 et

View File

@ -39,6 +39,7 @@ import logging
from mod_python import apache
from mod_pywebsocket import common
from mod_pywebsocket import dispatch
from mod_pywebsocket import handshake
from mod_pywebsocket import util
@ -52,9 +53,21 @@ _PYOPT_HANDLER_ROOT = 'mod_pywebsocket.handler_root'
# The default is the root directory.
_PYOPT_HANDLER_SCAN = 'mod_pywebsocket.handler_scan'
# PythonOption to allow handlers whose canonical path is
# not under the root directory. It's disallowed by default.
# Set this option with value of 'yes' to allow.
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT = (
'mod_pywebsocket.allow_handlers_outside_root_dir')
# Map from values to their meanings. 'Yes' and 'No' are allowed just for
# compatibility.
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION = {
'off': False, 'no': False, 'on': True, 'yes': True}
# PythonOption to specify to allow draft75 handshake.
# The default is None (Off)
_PYOPT_ALLOW_DRAFT75 = 'mod_pywebsocket.allow_draft75'
# Map from values to their meanings.
_PYOPT_ALLOW_DRAFT75_DEFINITION = {'off': False, 'on': True}
class ApacheLogHandler(logging.Handler):
@ -70,15 +83,20 @@ class ApacheLogHandler(logging.Handler):
def __init__(self, request=None):
logging.Handler.__init__(self)
self.log_error = apache.log_error
self._log_error = apache.log_error
if request is not None:
self.log_error = request.log_error
self._log_error = request.log_error
# Time and level will be printed by Apache.
self._formatter = logging.Formatter('%(name)s: %(message)s')
def emit(self, record):
apache_level = apache.APLOG_DEBUG
if record.levelno in ApacheLogHandler._LEVELS:
apache_level = ApacheLogHandler._LEVELS[record.levelno]
msg = self._formatter.format(record)
# "server" parameter must be passed to have "level" parameter work.
# If only "level" parameter is passed, nothing shows up on Apache's
# log. However, at this point, we cannot get the server object of the
@ -99,28 +117,57 @@ class ApacheLogHandler(logging.Handler):
# methods call request.log_error indirectly. When request is
# _StandaloneRequest, the methods call Python's logging facility which
# we create in standalone.py.
self.log_error(record.getMessage(), apache_level, apache.main_server)
self._log_error(msg, apache_level, apache.main_server)
_LOGGER = logging.getLogger('mod_pywebsocket')
def _configure_logging():
logger = logging.getLogger()
# Logs are filtered by Apache based on LogLevel directive in Apache
# configuration file. We must just pass logs for all levels to
# ApacheLogHandler.
_LOGGER.setLevel(logging.DEBUG)
_LOGGER.addHandler(ApacheLogHandler())
logger.setLevel(logging.DEBUG)
logger.addHandler(ApacheLogHandler())
_configure_logging()
_LOGGER = logging.getLogger(__name__)
def _parse_option(name, value, definition):
if value is None:
return False
meaning = definition.get(value.lower())
if meaning is None:
raise Exception('Invalid value for PythonOption %s: %r' %
(name, value))
return meaning
def _create_dispatcher():
_HANDLER_ROOT = apache.main_server.get_options().get(
_PYOPT_HANDLER_ROOT, None)
if not _HANDLER_ROOT:
_LOGGER.info('Initializing Dispatcher')
options = apache.main_server.get_options()
handler_root = options.get(_PYOPT_HANDLER_ROOT, None)
if not handler_root:
raise Exception('PythonOption %s is not defined' % _PYOPT_HANDLER_ROOT,
apache.APLOG_ERR)
_HANDLER_SCAN = apache.main_server.get_options().get(
_PYOPT_HANDLER_SCAN, _HANDLER_ROOT)
dispatcher = dispatch.Dispatcher(_HANDLER_ROOT, _HANDLER_SCAN)
handler_scan = options.get(_PYOPT_HANDLER_SCAN, handler_root)
allow_handlers_outside_root = _parse_option(
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT,
options.get(_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT),
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION)
dispatcher = dispatch.Dispatcher(
handler_root, handler_scan, allow_handlers_outside_root)
for warning in dispatcher.source_warnings():
apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING)
return dispatcher
@ -140,33 +187,54 @@ def headerparserhandler(request):
handshake_is_done = False
try:
allowDraft75 = apache.main_server.get_options().get(
_PYOPT_ALLOW_DRAFT75, None)
handshake.do_handshake(
request, _dispatcher, allowDraft75=allowDraft75)
handshake_is_done = True
request.log_error(
'mod_pywebsocket: resource: %r' % request.ws_resource,
apache.APLOG_DEBUG)
request._dispatcher = _dispatcher
_dispatcher.transfer_data(request)
# Fallback to default http handler for request paths for which
# we don't have request handlers.
if not _dispatcher.get_handler_suite(request.uri):
request.log_error('No handler for resource: %r' % request.uri,
apache.APLOG_INFO)
request.log_error('Fallback to Apache', apache.APLOG_INFO)
return apache.DECLINED
except dispatch.DispatchException, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_WARNING)
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
if not handshake_is_done:
return e.status
except handshake.AbortedByUserException, e:
try:
allow_draft75 = _parse_option(
_PYOPT_ALLOW_DRAFT75,
apache.main_server.get_options().get(_PYOPT_ALLOW_DRAFT75),
_PYOPT_ALLOW_DRAFT75_DEFINITION)
try:
handshake.do_handshake(
request, _dispatcher, allowDraft75=allow_draft75)
except handshake.VersionException, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
request.err_headers_out.add(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
return apache.HTTP_BAD_REQUEST
except handshake.HandshakeException, e:
# Handshake for ws/wss failed.
# The request handling fallback into http/https.
# Send http response with error status.
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
return e.status
handshake_is_done = True
request._dispatcher = _dispatcher
_dispatcher.transfer_data(request)
except handshake.AbortedByUserException, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
except Exception, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_WARNING)
# DispatchException can also be thrown if something is wrong in
# pywebsocket code. It's caught here, then.
request.log_error('mod_pywebsocket: %s\n%s' %
(e, util.get_stack_trace()),
apache.APLOG_ERR)
# Unknown exceptions before handshake mean Apache must handle its
# request with another handler.
if not handshake_is_done:
return apache.DECLINE
return apache.DECLINED
# Set assbackwards to suppress response header generation by Apache.
request.assbackwards = 1
return apache.DONE # Return DONE such that no other handlers are invoked.

View File

@ -52,6 +52,7 @@ def _is_ctl(c):
class ParsingState(object):
def __init__(self, data):
self.data = data
self.head = 0
@ -209,7 +210,7 @@ def quote_if_necessary(s):
result.append(c)
if quote:
return '"' + ''.join(result) + '"';
return '"' + ''.join(result) + '"'
else:
return ''.join(result)
@ -251,4 +252,12 @@ def parse_uri(uri):
return parsed.hostname, port, path
try:
urlparse.uses_netloc.index('ws')
except ValueError, e:
# urlparse in Python2.5.1 doesn't have 'ws' and 'wss' entries.
urlparse.uses_netloc.append('ws')
urlparse.uses_netloc.append('wss')
# vi:sts=4 sw=4 et

View File

@ -218,6 +218,7 @@ class DeflateRequest(object):
class _Deflater(object):
def __init__(self, window_bits):
self._logger = get_class_logger(self)
@ -233,6 +234,7 @@ class _Deflater(object):
class _Inflater(object):
def __init__(self):
self._logger = get_class_logger(self)
@ -390,6 +392,10 @@ class DeflateConnection(object):
self._deflater = _Deflater(zlib.MAX_WBITS)
self._inflater = _Inflater()
def get_remote_addr(self):
return self._connection.remote_addr
remote_addr = property(get_remote_addr)
def put_bytes(self, bytes):
self.write(bytes)

View File

@ -65,6 +65,7 @@ import BaseHTTPServer
import CGIHTTPServer
import SimpleHTTPServer
import SocketServer
import httplib
import logging
import logging.handlers
import optparse
@ -74,6 +75,8 @@ import select
import socket
import sys
import threading
import time
_HAS_OPEN_SSL = False
try:
@ -99,13 +102,6 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128
_MAX_MEMORIZED_LINES = 1024
def _print_warnings_if_any(dispatcher):
warnings = dispatcher.source_warnings()
if warnings:
for warning in warnings:
logging.warning('mod_pywebsocket: %s' % warning)
class _StandaloneConnection(object):
"""Mimic mod_python mp_conn."""
@ -165,6 +161,7 @@ class _StandaloneRequest(object):
self._request_handler = request_handler
self.connection = _StandaloneConnection(request_handler)
self._use_tls = use_tls
self.headers_in = request_handler.headers
def get_uri(self):
"""Getter to mimic request.uri."""
@ -178,12 +175,6 @@ class _StandaloneRequest(object):
return self._request_handler.command
method = property(get_method)
def get_headers_in(self):
"""Getter to mimic request.headers_in."""
return self._request_handler.headers
headers_in = property(get_headers_in)
def is_https(self):
"""Mimic request.is_https()."""
@ -216,6 +207,8 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
if necessary.
"""
self._logger = util.get_class_logger(self)
self.request_queue_size = options.request_queue_size
self.__ws_is_shut_down = threading.Event()
self.__ws_serving = False
@ -235,8 +228,16 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
self.server_name, self.server_port = self.server_address
self._sockets = []
if not self.server_name:
# On platforms that doesn't support IPv6, the first bind fails.
# On platforms that supports IPv6
# - If it binds both IPv4 and IPv6 on call with AF_INET6, the
# first bind succeeds and the second fails (we'll see 'Address
# already in use' error).
# - If it binds only IPv6 on call with AF_INET6, both call are
# expected to succeed to listen both protocol.
addrinfo_array = [
(self.address_family, self.socket_type, '', '', '')]
(socket.AF_INET6, socket.SOCK_STREAM, '', '', ''),
(socket.AF_INET, socket.SOCK_STREAM, '', '', '')]
else:
addrinfo_array = socket.getaddrinfo(self.server_name,
self.server_port,
@ -244,12 +245,12 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
socket.SOCK_STREAM,
socket.IPPROTO_TCP)
for addrinfo in addrinfo_array:
logging.info('Create socket on: %r', addrinfo)
self._logger.info('Create socket on: %r', addrinfo)
family, socktype, proto, canonname, sockaddr = addrinfo
try:
socket_ = socket.socket(family, socktype)
except Exception, e:
logging.info('Skip by failure: %r', e)
self._logger.info('Skip by failure: %r', e)
continue
if self.websocket_server_options.use_tls:
ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
@ -265,11 +266,22 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
sockets bind.
"""
for socket_, addrinfo in self._sockets:
logging.info('Bind on: %r', addrinfo)
failed_sockets = []
for socketinfo in self._sockets:
socket_, addrinfo = socketinfo
self._logger.info('Bind on: %r', addrinfo)
if self.allow_reuse_address:
socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
socket_.bind(self.server_address)
except Exception, e:
self._logger.info('Skip by failure: %r', e)
socket_.close()
failed_sockets.append(socketinfo)
for socketinfo in failed_sockets:
self._sockets.remove(socketinfo)
def server_activate(self):
"""Override SocketServer.TCPServer.server_activate to enable multiple
@ -280,11 +292,11 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
for socketinfo in self._sockets:
socket_, addrinfo = socketinfo
logging.info('Listen on: %r', addrinfo)
self._logger.info('Listen on: %r', addrinfo)
try:
socket_.listen(self.request_queue_size)
except Exception, e:
logging.info('Skip by failure: %r', e)
self._logger.info('Skip by failure: %r', e)
socket_.close()
failed_sockets.append(socketinfo)
@ -298,23 +310,23 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
for socketinfo in self._sockets:
socket_, addrinfo = socketinfo
logging.info('Close on: %r', addrinfo)
self._logger.info('Close on: %r', addrinfo)
socket_.close()
def fileno(self):
"""Override SocketServer.TCPServer.fileno."""
logging.critical('Not supported: fileno')
self._logger.critical('Not supported: fileno')
return self._sockets[0][0].fileno()
def handle_error(self, rquest, client_address):
"""Override SocketServer.handle_error."""
logging.error(
('Exception in processing request from: %r' % (client_address,)) +
'\n' + util.get_stack_trace())
# Note: client_address is a tuple. To match it against %r, we need the
# trailing comma.
self._logger.error(
'Exception in processing request from: %r\n%s',
client_address,
util.get_stack_trace())
# Note: client_address is a tuple.
def serve_forever(self, poll_interval=0.5):
"""Override SocketServer.BaseServer.serve_forever."""
@ -325,8 +337,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
if hasattr(self, '_handle_request_noblock'):
handle_request = self._handle_request_noblock
else:
logging.warning('mod_pywebsocket: fallback to blocking request '
'handler')
self._logger.warning('Fallback to blocking request handler')
try:
while self.__ws_serving:
r, w, e = select.select(
@ -349,6 +360,9 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
"""CGIHTTPRequestHandler specialized for WebSocket."""
# Use httplib.HTTPMessage instead of mimetools.Message.
MessageClass = httplib.HTTPMessage
def setup(self):
"""Override SocketServer.StreamRequestHandler.setup to wrap rfile
with MemorizingFile.
@ -370,6 +384,8 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
max_memorized_lines=_MAX_MEMORIZED_LINES)
def __init__(self, request, client_address, server):
self._logger = util.get_class_logger(self)
self._options = server.websocket_server_options
# Overrides CGIHTTPServerRequestHandler.cgi_directories.
@ -378,10 +394,6 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
if self._options.is_executable_method is not None:
self.is_executable = self._options.is_executable_method
self._request = _StandaloneRequest(self, self._options.use_tls)
_print_warnings_if_any(self._options.dispatcher)
# This actually calls BaseRequestHandler.__init__.
CGIHTTPServer.CGIHTTPRequestHandler.__init__(
self, request, client_address, server)
@ -406,70 +418,77 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
return False
host, port, resource = http_header_util.parse_uri(self.path)
if resource is None:
logging.info('mod_pywebsocket: invalid uri %r' % self.path)
self._logger.info('Invalid URI: %r', self.path)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True
server_options = self.server.websocket_server_options
if host is not None:
validation_host = server_options.validation_host
if validation_host is not None and host != validation_host:
logging.info('mod_pywebsocket: invalid host %r '
'(expected: %r)' % (host, validation_host))
self._logger.info('Invalid host: %r (expected: %r)',
host,
validation_host)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True
if port is not None:
validation_port = server_options.validation_port
if validation_port is not None and port != validation_port:
logging.info('mod_pywebsocket: invalid port %r '
'(expected: %r)' % (port, validation_port))
self._logger.info('Invalid port: %r (expected: %r)',
port,
validation_port)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True
self.path = resource
request = _StandaloneRequest(self, self._options.use_tls)
try:
# Fallback to default http handler for request paths for which
# we don't have request handlers.
if not self._options.dispatcher.get_handler_suite(self.path):
logging.info('No handlers for request: %s' % self.path)
self._logger.info('No handler for resource: %r',
self.path)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True
except dispatch.DispatchException, e:
self._logger.info('%s', e)
self.send_error(e.status)
return False
# If any Exceptions without except clause setup (including
# DispatchException) is raised below this point, it will be caught
# and logged by WebSocketServer.
try:
try:
handshake.do_handshake(
self._request,
request,
self._options.dispatcher,
allowDraft75=self._options.allow_draft75,
strict=self._options.strict)
except handshake.AbortedByUserException, e:
logging.info('mod_pywebsocket: %s' % e)
except handshake.VersionException, e:
self._logger.info('%s', e)
self.send_response(common.HTTP_STATUS_BAD_REQUEST)
self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
self.end_headers()
return False
try:
self._request._dispatcher = self._options.dispatcher
self._options.dispatcher.transfer_data(self._request)
except dispatch.DispatchException, e:
logging.warning('mod_pywebsocket: %s' % e)
return False
except handshake.AbortedByUserException, e:
logging.info('mod_pywebsocket: %s' % e)
except Exception, e:
# Catch exception in transfer_data.
# In this case, handshake has been successful, so just log
# the exception and return False.
logging.info('mod_pywebsocket: %s' % e)
logging.info(
'mod_pywebsocket: %s' % util.get_stack_trace())
except dispatch.DispatchException, e:
logging.warning('mod_pywebsocket: %s' % e)
self.send_error(e.status)
except handshake.HandshakeException, e:
# Handshake for ws(s) failed. Assume http(s).
logging.info('mod_pywebsocket: %s' % e)
# Handshake for ws(s) failed.
self._logger.info('%s', e)
self.send_error(e.status)
except Exception, e:
logging.warning('mod_pywebsocket: %s' % e)
logging.warning('mod_pywebsocket: %s' % util.get_stack_trace())
return False
request._dispatcher = self._options.dispatcher
self._options.dispatcher.transfer_data(request)
except handshake.AbortedByUserException, e:
self._logger.info('%s', e)
return False
def log_request(self, code='-', size='-'):
"""Override BaseHTTPServer.log_request."""
logging.info('"%s" %s %s',
self._logger.info('"%s" %s %s',
self.requestline, str(code), str(size))
def log_error(self, *args):
@ -477,8 +496,9 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
# Despite the name, this method is for warnings than for errors.
# For example, HTTP status code is logged by this method.
logging.warning('%s - %s' %
(self.address_string(), (args[0] % args[1:])))
self._logger.warning('%s - %s',
self.address_string(),
args[0] % args[1:])
def is_cgi(self):
"""Test whether self.path corresponds to a CGI script.
@ -544,8 +564,9 @@ def _alias_handlers(dispatcher, websock_handlers_map_file):
fp.close()
def _main():
def _build_option_parser():
parser = optparse.OptionParser()
parser.add_option('-H', '--server-host', '--server_host',
dest='server_host',
default='',
@ -576,6 +597,13 @@ def _main():
default=None,
help=('WebSocket handlers scan directory. '
'Must be a directory under websock_handlers.'))
parser.add_option('--allow-handlers-outside-root-dir',
'--allow_handlers_outside_root_dir',
dest='allow_handlers_outside_root_dir',
action='store_true',
default=False,
help=('Scans WebSocket handlers even if their canonical '
'path is not under websock_handlers.'))
parser.add_option('-d', '--document-root', '--document_root',
dest='document_root', default='.',
help='Document root directory.')
@ -599,6 +627,15 @@ def _main():
choices=['debug', 'info', 'warning', 'warn', 'error',
'critical'],
help='Log level.')
parser.add_option('--thread-monitor-interval-in-sec',
'--thread_monitor_interval_in_sec',
dest='thread_monitor_interval_in_sec',
type='int', default=-1,
help=('If positive integer is specified, run a thread '
'monitor to show the status of server threads '
'periodically in the specified inteval in '
'second. If non-positive integer is specified, '
'disable the thread monitor.'))
parser.add_option('--log-max', '--log_max', dest='log_max', type='int',
default=_DEFAULT_LOG_MAX_BYTES,
help='Log maximum bytes')
@ -613,7 +650,39 @@ def _main():
parser.add_option('-q', '--queue', dest='request_queue_size', type='int',
default=_DEFAULT_REQUEST_QUEUE_SIZE,
help='request queue size')
options = parser.parse_args()[0]
return parser
class ThreadMonitor(threading.Thread):
daemon = True
def __init__(self, interval_in_sec):
threading.Thread.__init__(self, name='ThreadMonitor')
self._logger = util.get_class_logger(self)
self._interval_in_sec = interval_in_sec
def run(self):
while True:
thread_name_list = []
for thread in threading.enumerate():
thread_name_list.append(thread.name)
self._logger.info(
"%d active threads: %s",
threading.active_count(),
', '.join(thread_name_list))
time.sleep(self._interval_in_sec)
def _main(args=None):
parser = _build_option_parser()
options, args = parser.parse_args(args=args)
if args:
logging.critical('Unrecognized positional arguments: %r', args)
sys.exit(1)
os.chdir(options.document_root)
@ -653,14 +722,24 @@ def _main():
options.scan_dir = options.websock_handlers
try:
if options.thread_monitor_interval_in_sec > 0:
# Run a thread monitor to show the status of server threads for
# debugging.
ThreadMonitor(options.thread_monitor_interval_in_sec).start()
# Share a Dispatcher among request handlers to save time for
# instantiation. Dispatcher can be shared because it is thread-safe.
options.dispatcher = dispatch.Dispatcher(options.websock_handlers,
options.scan_dir)
options.dispatcher = dispatch.Dispatcher(
options.websock_handlers,
options.scan_dir,
options.allow_handlers_outside_root_dir)
if options.websock_handlers_map_file:
_alias_handlers(options.dispatcher,
options.websock_handlers_map_file)
_print_warnings_if_any(options.dispatcher)
warnings = options.dispatcher.source_warnings()
if warnings:
for warning in warnings:
logging.warning('mod_pywebsocket: %s' % warning)
server = WebSocketServer(options)
server.serve_forever()