Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tls/_constructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from functools import partial

from construct import Array, Bytes, Pass, Struct, UBInt16, UBInt32, UBInt8
from construct import (Array, Bytes, Pass, Struct, Switch, UBInt16, UBInt32,
UBInt8)

from tls._common import enums

Expand Down Expand Up @@ -67,6 +68,22 @@
Array(lambda ctx: ctx.length, UBInt8("compression_methods"))
)

HostName = PrefixedBytes("hostname", UBInt16("length"))

ServerName = Struct(
"server_name",
EnumClass(UBInt8("name_type"), enums.NameType),
Switch(
"name",
lambda ctx: ctx.name_type,
{
enums.NameType.HOST_NAME: HostName
}
)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps EnumSwitch would go better here

Copy link
Copy Markdown
Member

@markrwilliams markrwilliams Nov 17, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed!

EDIT I just realized that comment wasn't clear -- sorry! I meant to say that it's OK to make this use EnumSwitch in this PR, but I'm happy to see this addressed in a separate PR that closes #106.

)

ServerNameList = TLSPrefixedArray("server_name_list", ServerName)

SignatureAndHashAlgorithm = Struct(
"algorithms",
EnumClass(UBInt8("hash"), enums.HashAlgorithm),
Expand All @@ -86,6 +103,7 @@
type_enum=enums.ExtensionType,
value_field="data",
value_choices={
enums.ExtensionType.SERVER_NAME: Opaque(ServerNameList),
enums.ExtensionType.SIGNATURE_ALGORITHMS: Opaque(
SupportedSignatureAlgorithms
),
Expand Down
4 changes: 4 additions & 0 deletions tls/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@

class UnsupportedCipherException(Exception):
pass


class UnsupportedExtensionException(Exception):
pass
34 changes: 34 additions & 0 deletions tls/hello_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from tls._common import enums

from tls.exceptions import UnsupportedExtensionException


@attr.s
class ProtocolVersion(object):
Expand All @@ -31,6 +33,15 @@ class Random(object):
random_bytes = attr.ib()


@attr.s
class ServerName(object):
"""
An object representing a ServerName struct.
"""
name_type = attr.ib()
name = attr.ib()


@attr.s
class ClientHello(object):
"""
Expand All @@ -42,8 +53,20 @@ class ClientHello(object):
cipher_suites = attr.ib()
compression_methods = attr.ib()
extensions = attr.ib()
allowed_extensions = frozenset([
enums.ExtensionType.SERVER_NAME,
enums.ExtensionType.MAX_FRAGMENT_LENGTH,
enums.ExtensionType.CLIENT_CERTIFICATE_URL,
enums.ExtensionType.SIGNATURE_ALGORITHMS,
# XXX Incomplete list, needs to be populated as we implement more
# extensions.
])

def as_bytes(self):
if any(extension.type not in self.allowed_extensions
for extension in self.extensions):
raise UnsupportedExtensionException

return _constructs.ClientHello.build(
Container(
version=Container(major=self.client_version.major,
Expand Down Expand Up @@ -72,6 +95,9 @@ def from_bytes(cls, bytes):
:return: ClientHello object.
"""
construct = _constructs.ClientHello.parse(bytes)
if any(extension.type not in cls.allowed_extensions
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an excellent catch! Should as_bytes also have this check, so that it's not possible to serialize a hello message that contains invalid extensions?

for extension in construct.extensions):
raise UnsupportedExtensionException
return ClientHello(
client_version=ProtocolVersion(
major=construct.version.major,
Expand Down Expand Up @@ -101,8 +127,13 @@ class ServerHello(object):
cipher_suite = attr.ib()
compression_method = attr.ib()
extensions = attr.ib()
allowed_extensions = frozenset([])

def as_bytes(self):
if any(extension.type not in self.allowed_extensions
for extension in self.extensions):
raise UnsupportedExtensionException

return _constructs.ServerHello.build(
Container(
version=Container(major=self.server_version.major,
Expand All @@ -128,6 +159,9 @@ def from_bytes(cls, bytes):
:return: ServerHello object.
"""
construct = _constructs.ServerHello.parse(bytes)
if any(extension.type not in cls.allowed_extensions
for extension in construct.extensions):
raise UnsupportedExtensionException
return ServerHello(
server_version=ProtocolVersion(
major=construct.version.major,
Expand Down
156 changes: 125 additions & 31 deletions tls/test/test_hello_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@

import pytest

from tls import _constructs

from tls._common import enums

from tls.ciphersuites import CipherSuites

from tls.exceptions import UnsupportedExtensionException

from tls.hello_message import ClientHello, ServerHello


Expand All @@ -20,7 +24,7 @@ class TestClientHello(object):
Tests for the parsing of ClientHello messages.
"""

no_extensions_packet = (
common_client_hello_data = (
b'\x03\x00' # client_version
b'\x01\x02\x03\x04' # random.gmt_unix_time
b'0123456789012345678901234567' # random.random_bytes
Expand All @@ -30,6 +34,9 @@ class TestClientHello(object):
b'\x00\x6B' # cipher_suites
b'\x01' # compression_methods length
b'\x00' # compression_methods
)

no_extensions_packet = common_client_hello_data + (
b'\x00\x00' # extensions.length
b'' # extensions.extension_type
b'' # extensions.extensions
Expand All @@ -51,16 +58,7 @@ class TestClientHello(object):
b'\x02\x02' # SHA1, DSA
)

extensions_packet = (
b'\x03\x00' # client_version
b'\x01\x02\x03\x04' # random.gmt_unix_time
b'0123456789012345678901234567' # random.random_bytes
b'\x00' # session_id.length
b'' # session_id.session_id
b'\x00\x02' # cipher_suites length
b'\x00\x6B' # cipher_suites
b'\x01' # compression_methods length
b'\x00' # compression_methods
extensions_packet = common_client_hello_data + (
b'\x00\x1a' # extensions length
) + supported_signature_list_extension_data

Expand Down Expand Up @@ -93,6 +91,19 @@ class TestClientHello(object):
b'' # extensions.extensions.extension_data
)

server_name_extension_data = (
b'\x00\x00' # Extension Type: Server Name
b'\x00\x0e' # Length
b'\x00\x0c' # Server Name Indication Length
b'\x00' # Server Name Type: host_name
b'\x00\x09' # Length of hostname data
b'localhost'
)

client_hello_packet_with_server_name_ext = common_client_hello_data + (
b'\x00\x12'
) + server_name_extension_data

def test_resumption_no_extensions(self):
"""
:func:`parse_client_hello` returns an instance of
Expand Down Expand Up @@ -173,19 +184,75 @@ def test_as_bytes_client_hello_cipher_suites(self):
record.as_bytes()
assert exc_info.value.args == ('invalid object', 0)

def test_client_hello_with_server_name_extension(self):
"""
:py:func:`tls.hello_message.ClientHello` parses a packet with a
server_name extension
"""
record = ClientHello.from_bytes(
self.client_hello_packet_with_server_name_ext
)
assert len(record.extensions) == 1
assert record.extensions[0].type == enums.ExtensionType.SERVER_NAME
assert len(record.extensions[0].data) == 1
server_name_list = record.extensions[0].data
assert server_name_list[0].name_type == enums.NameType.HOST_NAME
assert server_name_list[0].name == b'localhost'

def test_hello_from_bytes_with_unsupported_extension(self):
"""
:py:func:`tls.hello_message.ClientHello` does not parse a packet
with an unsupported extension, and raises an error.
"""
server_certificate_type_extension_data = (
b'\x00\x14' # Extension Type: Server Certificate Type
b'\x00\x00' # Length
b'' # Data
)

client_hello_packet = self.common_client_hello_data + (
b'\x00\x04'
) + server_certificate_type_extension_data

with pytest.raises(UnsupportedExtensionException):
ClientHello.from_bytes(
client_hello_packet
)

def test_as_bytes_unsupported_extension(self):
"""
:func:`ClientHello.as_bytes` fails to serialize a message that
contains invalid extensions
"""
extensions_data = (
b'\x00\x04'
b'\x00\x14' # Extension Type: Server Certificate Type
b'\x00\x00' # Length
b'' # Data
)

record = ClientHello.from_bytes(self.no_extensions_packet)
extensions = _constructs.Extensions.parse(extensions_data)
record.extensions = extensions
with pytest.raises(UnsupportedExtensionException):
record.as_bytes()


class TestServerHello(object):
"""
Tests for the parsing of ServerHello messages.
"""
no_extensions_packet = (
common_server_hello_data = (
b'\x03\x00' # server_version
b'\x01\x02\x03\x04' # random.gmt_unix_time
b'0123456789012345678901234567' # random.random_bytes
b'\x20' # session_id.length
b'01234567890123456789012345678901' # session_id
b'\x00\x6B' # cipher_suite
b'\x00' # compression_method
)

no_extensions_packet = common_server_hello_data + (
b'\x00\x00' # extensions.length
b'' # extensions.extension_type
b'' # extensions.extensions
Expand All @@ -207,14 +274,7 @@ class TestServerHello(object):
b'\x02\x02' # SHA1, DSA
)

extensions_packet = (
b'\x03\x00' # server_version
b'\x01\x02\x03\x04' # random.gmt_unix_time
b'0123456789012345678901234567' # random.random_bytes
b'\x20' # session_id.length
b'01234567890123456789012345678901' # session_id
b'\x00\x6B' # cipher_suite
b'\x00' # compression_method
extensions_packet = common_server_hello_data + (
b'\x00\x1a' # extensions length
) + supported_signature_list_extension_data

Expand All @@ -236,14 +296,11 @@ def test_parse_server_hello(self):

def test_parse_server_hello_extensions(self):
"""
:func:`parse_server_hello` returns an instance of :class:`ServerHello`
with extensions, when the extension bytes are present in the input.
:func:`parse_server_hello` fails to parse when
SIGNATURE_ALGORITHMS extension bytes are present in the packet
"""
record = ServerHello.from_bytes(self.extensions_packet)
assert len(record.extensions) == 1
assert (record.extensions[0].type ==
enums.ExtensionType.SIGNATURE_ALGORITHMS)
assert len(record.extensions[0].data) == 10
with pytest.raises(UnsupportedExtensionException):
ServerHello.from_bytes(self.extensions_packet)

def test_as_bytes_no_extensions(self):
"""
Expand All @@ -252,9 +309,46 @@ def test_as_bytes_no_extensions(self):
record = ServerHello.from_bytes(self.no_extensions_packet)
assert record.as_bytes() == self.no_extensions_packet

def test_as_bytes_with_extensions(self):
def test_server_hello_fails_with_server_name_extension(self):
"""
:func:`ServerHello.as_bytes` returns the bytes it was created with
:py:func:`tls.hello_message.ServerHello` does not parse a packet
with a server_name extension, and raises an error.
"""
record = ServerHello.from_bytes(self.extensions_packet)
assert record.as_bytes() == self.extensions_packet
server_name_extension_data = (
b'\x00\x00' # Extension Type: Server Name
b'\x00\x0e' # Length
b'\x00\x0c' # Server Name Indication Length
b'\x00' # Server Name Type: host_name
b'\x00\x09' # Length of hostname data
b'localhost'
)

server_hello_packet = self.common_server_hello_data + (
b'\x00\x12'
) + server_name_extension_data

with pytest.raises(UnsupportedExtensionException):
ServerHello.from_bytes(
server_hello_packet
)

def test_as_bytes_unsupported_extension(self):
"""
:func:`ServerHello.as_bytes` fails to serialize a message that
contains invalid extensions
"""
extensions_data = (
b'\x00\x12'
b'\x00\x00' # Extension Type: Server Name
b'\x00\x0e' # Length
b'\x00\x0c' # Server Name Indication Length
b'\x00' # Server Name Type: host_name
b'\x00\x09' # Length of hostname data
b'localhost'
)

record = ServerHello.from_bytes(self.no_extensions_packet)
extensions = _constructs.Extensions.parse(extensions_data)
record.extensions = extensions
with pytest.raises(UnsupportedExtensionException):
record.as_bytes()
8 changes: 4 additions & 4 deletions tls/test/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,12 @@ class TestHandshakeStructParsing(object):
b'01234567890123456789012345678901' # session_id
b'\x00\x6B' # cipher_suite
b'\x00' # compression_method
b'\x00\x1a' # extensions length
) + supported_signature_list_extension_data
b'\x00\x00' # extensions.length
)

server_hello_handshake_packet = (
b'\x02' # msg_type
b'\x00\x00\x62' # body length
b'\x00\x00\x48' # body length
) + server_hello_packet

certificate_packet = (
Expand Down Expand Up @@ -398,7 +398,7 @@ def test_parse_server_hello_in_handshake(self):
record = Handshake.from_bytes(self.server_hello_handshake_packet)
assert isinstance(record, Handshake)
assert record.msg_type == enums.HandshakeType.SERVER_HELLO
assert record.length == 98
assert record.length == 72
assert isinstance(record.body, ServerHello)

def test_parse_certificate_request_in_handshake(self):
Expand Down