diff --git a/test/test_basic.py b/test/test_basic.py index e39619f..99c12d8 100644 --- a/test/test_basic.py +++ b/test/test_basic.py @@ -1,363 +1,363 @@ # Copyright (C) 2008-2009 AG Projects. See LICENSE for details import pprint import sys import unittest from application import log from copy import copy from eventlib import api, proc from eventlib.coros import event from gnutls.crypto import X509PrivateKey, X509Certificate from gnutls.interfaces.twisted import X509Credentials from twisted.internet import reactor; del reactor # need to import this to let eventlib know to use a twisted based reactor from twisted.internet.error import ConnectionDone, ConnectionClosed from twisted.names.srvconnect import SRVConnector from msrplib.connect import DirectConnector, DirectAcceptor, RelayConnection, MSRPRelaySettings, ConnectBase, MSRPServer from msrplib.protocol import ContentTypeHeader, SuccessReportHeader, FailureReportHeader, URI from msrplib.session import GreenMSRPSession, MSRPSessionError, LocalResponse from msrplib.trafficlog import Logger from msrplib.transport import MSRPTransport X509Credentials.verify_peer=False class NoisySRVConnector(SRVConnector): def pickServer(self): host, port = SRVConnector.pickServer(self) print('Resolved _%s._%s.%s --> %s:%s' % (self.service, self.protocol, self.domain, host, port)) return host, port ConnectBase.SRVConnectorClass = NoisySRVConnector class TimeoutEvent(event): timeout = 10 def wait(self): with api.timeout(self.timeout): return event.wait(self) def _connect_msrp(local_event, remote_event, msrp, local_uri): full_local_path = msrp.prepare(local_uri) try: local_event.send(full_local_path) full_remote_path = remote_event.wait() result = msrp.complete(full_remote_path) assert isinstance(result, MSRPTransport), repr(result) return result finally: msrp.cleanup() class GreenMSRPSession_ZeroTimeout(GreenMSRPSession): RESPONSE_TIMEOUT = 0 class InjectedError(Exception): pass class TestBase(unittest.TestCase): PER_TEST_TIMEOUT = 30 client_relay = None client_logger = Logger(prefix='C ') server_relay = None server_logger = Logger(prefix='S ') debug = True use_tls = False server_credentials = None def get_client_uri(self): return URI(use_tls=self.use_tls) def get_server_uri(self): return URI(port=0, use_tls=self.use_tls, credentials=self.server_credentials) def get_connector(self): if self.client_relay is not None: return RelayConnection(self.client_relay, 'active', logger=self.client_logger) else: return DirectConnector(logger=self.client_logger) def get_acceptor(self): if self.server_relay is not None: return RelayConnection(self.server_relay, 'passive', logger=self.client_logger) else: return DirectAcceptor(logger=self.server_logger) def setup_two_endpoints(self): server_path = TimeoutEvent() client_path = TimeoutEvent() client = proc.spawn_link_exception(_connect_msrp, client_path, server_path, self.get_connector(), self.get_client_uri()) server = proc.spawn_link_exception(_connect_msrp, server_path, client_path, self.get_acceptor(), self.get_server_uri()) return client, server def setUp(self): print('\n%s.%s' % (self.__class__.__name__, self._testMethodName)) self.timer = api.exc_after(self.PER_TEST_TIMEOUT, api.TimeoutError('per test timeout expired')) def tearDown(self): self.timer.cancel() del self.timer def assertHeaderEqual(self, header, chunk1, chunk2): self.assertEqual(chunk1.headers[header].decoded, chunk2.headers[header].decoded) def assertSameData(self, chunk1, chunk2): try: self.assertHeaderEqual('Content-Type', chunk1, chunk2) self.assertEqual(chunk1.data, chunk2.data) self.assertEqual(chunk1.contflag, chunk2.contflag) except Exception: print('Error while comparing %r and %r' % (chunk1, chunk2)) raise def make_hello(self, msrptransport, success_report=None, failure_report=None): - chunk = msrptransport.make_send_request(data='hello') + chunk = msrptransport.make_send_request(data=b'hello') chunk.add_header(ContentTypeHeader('text/plain')) # because MSRPTransport does not send the responses, the relay must not either if success_report is not None: chunk.add_header(SuccessReportHeader(success_report)) if failure_report is not None: chunk.add_header(FailureReportHeader(failure_report)) return chunk def _test_write_chunk(self, sender, receiver): chunk = self.make_hello(sender, failure_report='no') sender.write(chunk.encode()) chunk_received = receiver.read_chunk() self.assertSameData(chunk, chunk_received) class TLSMixin(object): use_tls = True cert = X509Certificate(open('valid.crt').read()) key = X509PrivateKey(open('valid.key').read()) server_credentials = X509Credentials(cert, key) class MSRPTransportTest(TestBase): def test_write_chunk(self): client, server = proc.waitall(self.setup_two_endpoints()) self._test_write_chunk(client, server) self._test_write_chunk(server, client) #self.assertNoIncoming(0.1, client, server) client.loseConnection() server.loseConnection() def test_close_connection__read(self): client, server = proc.waitall(self.setup_two_endpoints()) client.loseConnection() self.assertRaises(ConnectionDone, server.read_chunk) self.assertRaises(ConnectionDone, server.write, self.make_hello(server).encode()) self.assertRaises(ConnectionDone, client.read_chunk) self.assertRaises(ConnectionDone, client.write, self.make_hello(client).encode()) # add test for chunking class MSRPTransportTest_TLS(TLSMixin, MSRPTransportTest): pass class MSRPSessionTest(TestBase): def _test_deliver_chunk(self, sender, receiver, chunk=None): if chunk is None: chunk = self.make_hello(sender.msrp) response = sender.deliver_chunk(chunk) assert response.code == 200, response chunk_received = receiver.receive_chunk() self.assertSameData(chunk, chunk_received) def test_deliver_chunk(self): client, server = proc.waitall(self.setup_two_endpoints()) client, server = GreenMSRPSession(client), GreenMSRPSession(server) self._test_deliver_chunk(client, server) self._test_deliver_chunk(server, client) #self.assertNoIncoming(0.1, client, server) client.shutdown() server.shutdown() def assertRaisesCode(self, exception, code, func, *args, **kwargs): try: func(*args, **kwargs) except exception as ex: self.assertEqual(ex.code, code) else: raise AssertionError('%r didnt raise %s' % (func, exception)) def test_deliver_chunk_success_report(self): client, server = proc.waitall(self.setup_two_endpoints()) client, server = GreenMSRPSession(client), GreenMSRPSession(server) chunk = self.make_hello(client.msrp, success_report='yes') self._test_deliver_chunk(client, server, chunk) report = client.receive_chunk() self.assertEqual(report.method, 'REPORT') self.assertEqual(report.message_id, chunk.message_id) self.assertEqual(report.byte_range, chunk.byte_range) client.shutdown() server.shutdown() def test_send_chunk_response_localtimeout(self): client, server = proc.waitall(self.setup_two_endpoints()) client, server = GreenMSRPSession_ZeroTimeout(client), GreenMSRPSession(server) x = self.make_hello(client.msrp) self.assertRaisesCode(LocalResponse, 408, client.deliver_chunk, x) y = server.receive_chunk() self.assertSameData(x, y) #self.assertNoIncoming(0.1, client, server) server.shutdown() client.shutdown() def test_close_connection__receive(self): client, server = proc.waitall(self.setup_two_endpoints()) assert isinstance(client, MSRPTransport), repr(client) client, server = GreenMSRPSession(client), GreenMSRPSession(server) client.shutdown() self.assertRaises(ConnectionDone, server.receive_chunk) self.assertRaises(MSRPSessionError, server.send_chunk, self.make_hello(server.msrp)) self.assertRaises(ConnectionDone, client.receive_chunk) self.assertRaises(MSRPSessionError, client.send_chunk, self.make_hello(client.msrp)) def test_reader_failed__receive(self): # if reader fails with an exception, receive_chunk raises that exception # send_chunk raises an error and the other party gets closed connection client, server = proc.waitall(self.setup_two_endpoints()) client, server = GreenMSRPSession(client), GreenMSRPSession(server) client.reader_job.kill(InjectedError("Killing client's reader_job")) self.assertRaises(InjectedError, client.receive_chunk) self.assertRaises(MSRPSessionError, client.send_chunk, self.make_hello(client.msrp)) self.assertRaises(ConnectionClosed, server.receive_chunk) self.assertRaises(MSRPSessionError, server.send_chunk, self.make_hello(server.msrp)) def test_reader_failed__send(self): client, server = proc.waitall(self.setup_two_endpoints()) client, server = GreenMSRPSession(client), GreenMSRPSession(server) client.reader_job.kill(InjectedError("Killing client's reader_job")) api.sleep(0.1) self.assertRaises(MSRPSessionError, client.send_chunk, self.make_hello(client.msrp)) self.assertRaises(InjectedError, client.receive_chunk) api.sleep(0.1) self.assertRaises(MSRPSessionError, server.send_chunk, self.make_hello(server.msrp)) self.assertRaises(ConnectionClosed, server.receive_chunk) class MSRPSessionTest_TLS(TLSMixin, MSRPSessionTest): pass class ServerTest(TestBase): server = None @classmethod def get_server(cls): if cls.server is None: cls.server = MSRPServer(logger=cls.server_logger) return cls.server def get_server_uri(self): return URI(port=28550, use_tls=self.use_tls, credentials=self.server_credentials) def test_2_servers_same_port(self): server = self.get_server() server_uri_1 = server.prepare(self.get_server_uri()) server_uri_2 = server.prepare(self.get_server_uri()) suri:gnutls.interfaces.twisted.X509Credentials = self.get_server_uri() assert len(server.ports)==1, server.ports assert len(list(server.ports.values())[0])==1, server.ports connector = self.get_connector() client1_full_local_path = connector.prepare() server_transport_event = TimeoutEvent() proc.spawn(server.complete, client1_full_local_path).link(server_transport_event) client1_transport = connector.complete(server_uri_1) server_transport = server_transport_event.wait() self._test_write_chunk(client1_transport, server_transport) self._test_write_chunk(server_transport, client1_transport) client1_transport.loseConnection() server_transport.loseConnection() client2_full_local_path = connector.prepare() server_transport_event = TimeoutEvent() proc.spawn(server.complete, client2_full_local_path).link(server_transport_event) client2_transport = connector.complete(server_uri_2) server_transport = server_transport_event.wait() self._test_write_chunk(client2_transport, server_transport) self._test_write_chunk(server_transport, client2_transport) client2_transport.loseConnection() server_transport.loseConnection() class ServerTest_TLS(TLSMixin, ServerTest): pass from optparse import OptionParser parser = OptionParser() parser.add_option('--domain') parser.add_option('--username') parser.add_option('--password') parser.add_option('--host') parser.add_option('--port', default=2855) parser.add_option('--log-client', action='store_true', default=False) parser.add_option('--log-server', action='store_true', default=False) parser.add_option('--debug', action='store_true', default=False) options, _args = parser.parse_args() log.Formatter.prefix_format = '' if options.debug: log.level.current = log.level.DEBUG TestBase.client_logger.log_traffic = options.log_client TestBase.server_logger.log_traffic = options.log_server relays = [] # SRV: if options.domain is not None: relays.append(MSRPRelaySettings(options.domain, options.username, options.password)) # explicit host: if options.host is not None: assert options.domain is not None relays.append(MSRPRelaySettings(options.domain, options.username, options.password, options.host, options.port)) configs = [] for relay in relays: configs.append({'server_relay': relay, 'client_relay': None}) configs.append({'server_relay': relay, 'client_relay': relay}) def get_config_name(config): result = [] for name, relay in list(config.items()): if relay is not None: x = name print(name, relay.host) if relay.host is None: x += '_srv' result.append(x) return '_'.join(result) def make_tests_for_other_configurations(TestClass): klass = TestClass.__name__ for config in configs: config_name = get_config_name(config) klass_name = klass + '_' + config_name while klass_name in globals(): klass_name += '_x' new_class = type(klass_name, (TestClass, ), copy(config)) print(klass_name) globals()[klass_name] = new_class if relays: print('Relays: ') pprint.pprint(relays) print() if configs: print('Configs: ') pprint.pprint(configs) print() make_tests_for_other_configurations(MSRPTransportTest) make_tests_for_other_configurations(MSRPSessionTest) if __name__=='__main__': test = unittest.defaultTestLoader.loadTestsFromModule(sys.modules['__main__']) testRunner = unittest.TextTestRunner().run(test)