aboutsummaryrefslogtreecommitdiffstats
path: root/lib/python2.7/site-packages/Twisted-12.2.0-py2.7-linux-x86_64.egg/twisted/internet/test/test_newtls.py
blob: a196cb519efc38d854e563314b5a8ec4e452c4e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for L{twisted.internet._newtls}.
"""

from twisted.trial import unittest
from twisted.internet.test.reactormixins import ReactorBuilder, runProtocolsWithReactor
from twisted.internet.test.reactormixins import ConnectableProtocol
from twisted.internet.test.test_tls import SSLCreator, TLSMixin
from twisted.internet.test.test_tls import StartTLSClientCreator
from twisted.internet.test.test_tls import ContextGeneratingMixin
from twisted.internet.test.test_tcp import TCPCreator
try:
    from twisted.protocols import tls
    from twisted.internet import _newtls
except ImportError:
    _newtls = None


class BypassTLSTests(unittest.TestCase):
    """
    Tests for the L{_newtls._BypassTLS} class.
    """

    # _newtls is either an imported module (always truthy) or None when the
    # optional TLS dependencies failed to import.
    if _newtls is None:
        skip = "Couldn't import _newtls, perhaps pyOpenSSL is old or missing"

    def test_loseConnectionPassThrough(self):
        """
        C{_BypassTLS.loseConnection} delegates to the base class's
        C{loseConnection}: when called with no argument the base class's own
        default value is used, and an explicit argument is forwarded as-is.
        """
        defaultMarker = object()
        calls = []

        class RecordingTransport(object):
            def loseConnection(self, _connDone=defaultMarker):
                calls.append(_connDone)

        transport = RecordingTransport()
        bypass = _newtls._BypassTLS(RecordingTransport, transport)

        # Called bare, the base class's default argument must survive:
        bypass.loseConnection()
        self.assertEqual(calls, [defaultMarker])

        # Called with an argument, that argument must be passed through:
        explicit = object()
        bypass.loseConnection(explicit)
        self.assertEqual(calls, [defaultMarker, explicit])



class FakeProducer(object):
    """
    A stand-in producer whose C{pauseProducing}, C{resumeProducing} and
    C{stopProducing} methods all do nothing.  Tests register it only to see
    which object a transport ends up holding.
    """

    def pauseProducing(self):
        """
        Ignore the request to pause.
        """


    def resumeProducing(self):
        """
        Ignore the request to resume.
        """


    def stopProducing(self):
        """
        Ignore the request to stop.
        """



class ProducerProtocol(ConnectableProtocol):
    """
    Register a producer, unregister it, and verify the producer hooks up to
    innards of C{TLSMemoryBIOProtocol}.

    @ivar producer: The producer registered (as streaming) on connection.
    @ivar result: A list that receives, in order, the producer found behind
        the TLS protocol's C{_producer} wrapper right after registration, and
        the TLS protocol's C{_producer} attribute right after unregistration.
    """

    def __init__(self, producer, result):
        self.producer, self.result = producer, result


    def connectionMade(self):
        tlsProtocol = self.transport.protocol
        if not isinstance(tlsProtocol, tls.TLSMemoryBIOProtocol):
            # Either the test or the code have a bug...
            raise RuntimeError("TLSMemoryBIOProtocol not hooked up.")

        transport = self.transport

        # Registering must land the producer (behind one wrapper) on the
        # TLSMemoryBIOProtocol:
        transport.registerProducer(self.producer, True)
        self.result.append(tlsProtocol._producer._producer)

        # Unregistering must clear it from the TLSMemoryBIOProtocol again:
        transport.unregisterProducer()
        self.result.append(tlsProtocol._producer)

        transport.loseConnection()



class ProducerTestsMixin(ReactorBuilder, TLSMixin, ContextGeneratingMixin):
    """
    Test the new TLS code integrates C{TLSMemoryBIOProtocol} correctly.
    """

    # _newtls is either an imported module (always truthy) or None.
    if _newtls is None:
        skip = "Could not import twisted.internet._newtls"

    def _assertProducerPassedThrough(self, creatorFactory):
        """
        Run a plain L{ConnectableProtocol} against a L{ProducerProtocol} and
        check that the producer reached the C{TLSMemoryBIOProtocol} on
        registration and was removed from it on unregistration.

        @param creatorFactory: Zero-argument callable (e.g. a creator class)
            whose result is passed as the creator to
            L{runProtocolsWithReactor}.
        """
        recorded = []
        producer = FakeProducer()

        runProtocolsWithReactor(self, ConnectableProtocol(),
                                ProducerProtocol(producer, recorded),
                                creatorFactory())
        self.assertEqual(recorded, [producer, None])


    def test_producerSSLFromStart(self):
        """
        On transports that speak TLS from the very beginning,
        C{registerProducer} and C{unregisterProducer} are routed to the
        C{TLSMemoryBIOProtocol} rather than straight to the underlying
        transport.
        """
        self._assertProducerPassedThrough(SSLCreator)


    def test_producerAfterStartTLS(self):
        """
        On transports switched to TLS via C{startTLS}, C{registerProducer}
        and C{unregisterProducer} are routed to the C{TLSMemoryBIOProtocol}
        rather than straight to the underlying transport.
        """
        self._assertProducerPassedThrough(StartTLSClientCreator)


    def startTLSAfterRegisterProducer(self, streaming):
        """
        When a producer is registered, and then startTLS is called,
        the producer is re-registered with the C{TLSMemoryBIOProtocol}.

        @param streaming: The C{streaming} flag given to C{registerProducer}.
            A non-streaming producer sits one wrapper deeper than a
            streaming one.
        """
        clientContext = self.getClientContext()
        serverContext = self.getServerContext()
        found = []
        producer = FakeProducer()

        # Number of ._producer hops from the registered wrapper down to the
        # FakeProducer:
        #   streaming:      _ProducerMembrane -> producer
        #   non-streaming:  _ProducerMembrane -> _PullToPush -> producer
        hops = 1 if streaming else 2

        def unwrap(wrapper):
            for _ in range(hops):
                wrapper = wrapper._producer
            return wrapper

        class RegisterTLSProtocol(ConnectableProtocol):
            def connectionMade(self):
                self.transport.registerProducer(producer, streaming)
                self.transport.startTLS(serverContext)
                # Record what the TLSMemoryBIOProtocol and the underlying
                # transport each hold as their producer:
                found.append(unwrap(self.transport.protocol._producer))
                found.append(unwrap(self.transport.producer))
                self.transport.unregisterProducer()
                self.transport.loseConnection()

        class StartTLSProtocol(ConnectableProtocol):
            def connectionMade(self):
                self.transport.startTLS(clientContext)

        runProtocolsWithReactor(self, RegisterTLSProtocol(),
                                StartTLSProtocol(), TCPCreator())
        self.assertEqual(found, [producer, producer])


    def test_startTLSAfterRegisterProducerStreaming(self):
        """
        A streaming producer registered before C{startTLS} is moved over to
        the C{TLSMemoryBIOProtocol} once TLS starts.
        """
        self.startTLSAfterRegisterProducer(streaming=True)


    def test_startTLSAfterRegisterProducerNonStreaming(self):
        """
        A non-streaming producer registered before C{startTLS} is moved over
        to the C{TLSMemoryBIOProtocol} once TLS starts.
        """
        self.startTLSAfterRegisterProducer(streaming=False)


# Turn the mixin into concrete test case classes (presumably one per
# available reactor -- see ReactorBuilder.makeTestCaseClasses) and publish
# them as module-level names so test discovery can find them.
globals().update(ProducerTestsMixin.makeTestCaseClasses())