# -*- coding: utf-8 -*-

# Author: Alejandro J. Cura <alecu@canonical.com>
#
# Copyright 2011 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Tests for the tcpactivation module."""

# twisted uses a different coding convention
# pylint: disable=C0103,W0232

# this test module accesses a few protected members (starting with _)
# pylint: disable=W0212

from twisted.internet import defer, protocol, reactor, task
from twisted.trial.unittest import TestCase

from ubuntu_sso.utils import tcpactivation
from ubuntu_sso.utils.tcpactivation import (
    ActivationClient,
    ActivationConfig,
    ActivationDetector,
    ActivationInstance,
    ActivationTimeoutError,
    AlreadyStartedError,
    NullProtocol,
    PortDetectFactory,
)

# Service name used for every ActivationConfig built in this module.
SAMPLE_SERVICE = "test_service_name"
# Server-side command line for the IRL test: re-runs this very file with
# "-server", which irl_test() dispatches to server_test().
SAMPLE_CMDLINE = ["python", __file__, "-server"]
# Localhost TCP port the tests probe and bind for real, so it must be free
# on the machine running the suite.
SAMPLE_PORT = 55555


class FakeServerProtocol(protocol.Protocol):
    """Echo protocol served by the fake test server."""

    def dataReceived(self, data):
        """Send every received chunk straight back to the peer."""
        outgoing = self.transport
        outgoing.write(data)


class FakeServerFactory(protocol.Factory):
    """A factory for the test server.

    Connections it accepts are handled by FakeServerProtocol, so the fake
    server simply echoes back whatever it receives.
    """

    # Protocol class used for the connections accepted by this factory.
    protocol = FakeServerProtocol


class FakeTransport(object):
    """Stand-in transport that only records whether it was closed."""

    # Class-level default; loseConnection() shadows it on the instance.
    connectionLost = False

    def loseConnection(self):
        """Flag this transport as disconnected instead of closing anything."""
        self.connectionLost = True


class AsyncSleepTestCase(TestCase):
    """Check when the deferred returned by async_sleep fires.

    The module's reactor is swapped for a task.Clock so that time can be
    advanced deterministically from the tests.
    """

    @defer.inlineCallbacks
    def setUp(self):
        """Patch in a fake clock and start a sleep running on it."""
        yield super(AsyncSleepTestCase, self).setUp()
        self.test_timeout = 5.0
        fake_clock = task.Clock()
        self.clock = fake_clock
        self.patch(tcpactivation, "reactor", fake_clock)
        self.d = tcpactivation.async_sleep(self.test_timeout)

    def test_async_sleep_not_fired_immediately(self):
        """Right after the call the sleep is still pending."""
        already_fired = self.d.called
        self.assertFalse(already_fired, "Must not be fired immediately.")

    def test_async_sleep_not_fired_in_a_bit(self):
        """Halfway through the delay the sleep is still pending."""
        halfway = self.test_timeout / 2
        self.clock.advance(halfway)
        self.assertFalse(self.d.called, "Must not be fired yet.")

    def test_async_sleep_fired_at_the_right_time(self):
        """Once the whole delay has elapsed the sleep has completed."""
        full_delay = self.test_timeout
        self.clock.advance(full_delay)
        self.assertTrue(self.d.called, "Must be fired by now.")


class NullProtocolTestCase(TestCase):
    """Behaviour of NullProtocol once a connection is established."""

    def test_drops_connection(self):
        """As soon as it is connected, the protocol hangs up."""
        null_proto = NullProtocol()
        null_proto.transport = FakeTransport()

        null_proto.connectionMade()

        dropped = null_proto.transport.connectionLost
        self.assertTrue(dropped, "the connection must be dropped.")


class PortDetectFactoryTestCase(TestCase):
    """Exercise how PortDetectFactory reports whether a port is listening."""

    timeout = 2

    @defer.inlineCallbacks
    def setUp(self):
        """Give every test a fresh factory."""
        yield super(PortDetectFactoryTestCase, self).setUp()
        self.factory = PortDetectFactory()

    def _simulate_connection(self):
        """Make the factory build a protocol, as on a successful connect."""
        self.factory.buildProtocol((tcpactivation.LOCALHOST, SAMPLE_PORT))

    @defer.inlineCallbacks
    def test_is_listening(self):
        """A built protocol means something is listening."""
        self._simulate_connection()
        listening = yield self.factory.is_listening()
        self.assertTrue(listening)

    @defer.inlineCallbacks
    def test_connection_lost(self):
        """A lost connection means nothing is listening."""
        self.factory.clientConnectionLost(None, "test reason")
        listening = yield self.factory.is_listening()
        self.assertFalse(listening)

    @defer.inlineCallbacks
    def test_connection_failed(self):
        """A failed connection means nothing is listening."""
        self.factory.clientConnectionFailed(None, "test reason")
        listening = yield self.factory.is_listening()
        self.assertFalse(listening)

    @defer.inlineCallbacks
    def test_connection_failed_then_lost(self):
        """A failure followed by a loss is tolerated and still False."""
        self.factory.clientConnectionFailed(None, "test reason")
        self.factory.clientConnectionLost(None, "test reason")
        listening = yield self.factory.is_listening()
        self.assertFalse(listening)

    @defer.inlineCallbacks
    def test_connection_works_then_lost(self):
        """A success followed by a loss is tolerated and still True."""
        self._simulate_connection()
        # Ask before the loss arrives; the answer must not be overwritten.
        pending = self.factory.is_listening()
        self.factory.clientConnectionLost(None, "test reason")
        listening = yield pending
        self.assertTrue(listening)


class ActivationConfigTestCase(TestCase):
    """Check that ActivationConfig keeps the values it is built with."""

    def test_initialization(self):
        """Every constructor argument is exposed as an attribute."""
        config = ActivationConfig(SAMPLE_SERVICE, SAMPLE_CMDLINE, SAMPLE_PORT)
        expectations = [
            (config.service_name, SAMPLE_SERVICE),
            (config.command_line, SAMPLE_CMDLINE),
            (config.port, SAMPLE_PORT),
        ]
        for actual, wanted in expectations:
            self.assertEqual(actual, wanted)


class ActivationDetectorTestCase(TestCase):
    """Tests for the ActivationDetector class."""

    # These tests talk to a real localhost port; fail fast like the other
    # network-touching test cases in this module instead of hanging.
    timeout = 2

    @defer.inlineCallbacks
    def setUp(self):
        """Initialize this test instance with a shared config."""
        yield super(ActivationDetectorTestCase, self).setUp()
        self.config = ActivationConfig(SAMPLE_SERVICE, SAMPLE_CMDLINE,
                                       SAMPLE_PORT)

    def test_initialization(self):
        """The constructor stores the config it is given."""
        ad = ActivationDetector(self.config)
        self.assertEqual(ad.config, self.config)

    @defer.inlineCallbacks
    def test_is_not_already_running(self):
        """Test the is_already_running method returns False."""
        # Assumes nothing on this machine is listening on SAMPLE_PORT.
        ad = ActivationDetector(self.config)
        result = yield ad.is_already_running()
        self.assertFalse(result, "It should not be already running.")

    @defer.inlineCallbacks
    def test_is_already_running(self):
        """The is_already_running method returns True if already started."""
        f = FakeServerFactory()
        # pylint: disable=E1101
        listener = reactor.listenTCP(SAMPLE_PORT, f,
                                     interface=tcpactivation.LOCALHOST)
        # Release the real port even if the assertion below fails.
        self.addCleanup(listener.stopListening)
        ad = ActivationDetector(self.config)
        result = yield ad.is_already_running()
        self.assertTrue(result, "It should be already running.")


class ActivationClientTestCase(TestCase):
    """Exercise ActivationClient with its collaborators stubbed out."""

    timeout = 2

    @defer.inlineCallbacks
    def setUp(self):
        """Build the config shared by every test."""
        yield super(ActivationClientTestCase, self).setUp()
        self.config = ActivationConfig(SAMPLE_SERVICE, SAMPLE_CMDLINE,
                                       SAMPLE_PORT)

    def _use_fake_clock(self):
        """Replace the module's reactor with a task.Clock and return it."""
        fake_clock = task.Clock()
        self.patch(tcpactivation, "reactor", fake_clock)
        return fake_clock

    def test_initialization(self):
        """The constructor keeps the config it is given."""
        client = ActivationClient(self.config)
        self.assertEqual(client.config, self.config)

    @defer.inlineCallbacks
    def test_do_get_active_port_running(self):
        """With the server already up, the configured port is returned."""
        client = ActivationClient(self.config)
        self.patch(client, "is_already_running", lambda: defer.succeed(True))
        port = yield client._do_get_active_port()
        self.assertEqual(port, SAMPLE_PORT)

    @defer.inlineCallbacks
    def test_do_get_active_port_not_running(self):
        """With no server up, one is spawned and the port is returned."""
        spawn_calls = []
        client = ActivationClient(self.config)

        def fake_spawn(*args):
            """Record the spawn request instead of starting a process."""
            spawn_calls.append(args)

        self.patch(client, "_spawn_server", fake_spawn)
        self.patch(client, "is_already_running",
                   lambda: defer.succeed(False))
        self.patch(client, "_wait_server_active",
                   lambda: defer.succeed(None))
        port = yield client._do_get_active_port()
        self.assertEqual(port, SAMPLE_PORT)
        self.assertEqual(len(spawn_calls), 1)

    def test_get_active_port_waits_classwide(self):
        """A second client waits until the first one has finished."""
        first_work = defer.Deferred()
        first = ActivationClient(self.config)
        second = ActivationClient(self.config)
        self.patch(first, "_do_get_active_port", lambda: first_work)
        self.patch(second, "_do_get_active_port",
                   lambda: defer.succeed(None))
        first.get_active_port()
        waiting = second.get_active_port()
        self.assertFalse(waiting.called,
                         "The second must wait for the first.")
        first_work.callback(SAMPLE_PORT)
        self.assertTrue(waiting.called,
                        "The second can fire after the first.")

    def test_wait_server_active(self):
        """Polling completes once the server is seen running."""
        client = ActivationClient(self.config)
        clock = self._use_fake_clock()
        self.patch(client, "is_already_running",
                   lambda: defer.succeed(False))

        waiting = client._wait_server_active()

        self.assertFalse(waiting.called,
                         "The deferred should not be fired yet.")
        clock.advance(tcpactivation.DELAY_BETWEEN_CHECKS)
        self.assertFalse(waiting.called,
                         "The deferred should not be fired yet.")
        self.patch(client, "is_already_running",
                   lambda: defer.succeed(True))
        clock.advance(tcpactivation.DELAY_BETWEEN_CHECKS)
        self.assertTrue(waiting.called,
                        "The deferred should be fired by now.")

    def test_wait_server_timeouts(self):
        """If the server takes too long to start then timeout."""
        client = ActivationClient(self.config)
        clock = self._use_fake_clock()
        self.patch(client, "is_already_running",
                   lambda: defer.succeed(False))
        waiting = client._wait_server_active()
        every_check = ([tcpactivation.DELAY_BETWEEN_CHECKS] *
                       tcpactivation.NUMBER_OF_CHECKS)
        clock.pump(every_check)
        return self.assertFailure(waiting, ActivationTimeoutError)

    def test_spawn_server(self):
        """Spawning the server launches exactly one process."""
        popen_calls = []
        client = ActivationClient(self.config)

        def fake_popen(*args, **kwargs):
            """Record the Popen call instead of running anything."""
            popen_calls.append((args, kwargs))

        self.patch(tcpactivation.subprocess, "Popen", fake_popen)
        client._spawn_server()
        self.assertEqual(len(popen_calls), 1)


class ActivationInstanceTestCase(TestCase):
    """Tests for the ActivationInstance class."""

    timeout = 2

    @defer.inlineCallbacks
    def setUp(self):
        """Initialize this test instance with a shared config."""
        yield super(ActivationInstanceTestCase, self).setUp()
        self.config = ActivationConfig(SAMPLE_SERVICE, SAMPLE_CMDLINE,
                                       SAMPLE_PORT)

    def test_initialization(self):
        """The constructor stores the config it is given."""
        ai = ActivationInstance(self.config)
        self.assertEqual(ai.config, self.config)

    @defer.inlineCallbacks
    def test_get_port(self):
        """Test the get_port method returns the configured port."""
        # Assumes nothing on this machine is listening on SAMPLE_PORT.
        ai = ActivationInstance(self.config)
        port = yield ai.get_port()
        self.assertEqual(port, SAMPLE_PORT)

    @defer.inlineCallbacks
    def test_get_port_fails_if_service_already_started(self):
        """The get_port method fails if service already started."""
        ai1 = ActivationInstance(self.config)
        port1 = yield ai1.get_port()
        f = FakeServerFactory()
        # pylint: disable=E1101
        listener = reactor.listenTCP(port1, f,
                                     interface=tcpactivation.LOCALHOST)
        # Release the real port even if the assertion below fails.
        self.addCleanup(listener.stopListening)
        ai2 = ActivationInstance(self.config)
        yield self.assertFailure(ai2.get_port(), AlreadyStartedError)


def server_test(config):
    """An IRL test of the server.

    Asks an ActivationInstance for its port. Each time the port is granted,
    an echo server starts listening on it and the port is requested again.
    Any failed request is reported as "already started" and the reactor is
    stopped three seconds later.
    """
    print("starting the server.")
    instance = ActivationInstance(config)

    def request_port():
        """Ask the instance for the port and wire up both outcomes."""
        port_d = instance.get_port()
        port_d.addCallback(on_port)
        port_d.addErrback(on_already_started)

    def on_port(port_number):
        """Listen on the granted port, then ask for the port once more."""
        print("got server port: %s" % (port_number,))

        # pylint: disable=E1101
        reactor.listenTCP(port_number, FakeServerFactory())

        request_port()

    def on_already_started(_failure):
        """Report the refusal and shut the reactor down shortly after."""
        print("already started!")
        # pylint: disable=E1101
        reactor.callLater(3, reactor.stop)

    request_port()
    # pylint: disable=E1101
    reactor.run()


def client_test(config):
    """An IRL test of the client.

    Asks an ActivationClient for the active port and prints it. If the
    request fails (e.g. the server never comes up), the failure is printed
    instead. In both cases the reactor is stopped so the process exits
    rather than hanging in reactor.run().
    """
    print("starting the client.")
    ac = ActivationClient(config)
    d = ac.get_active_port()

    def got_port(port_number):
        """The port number was found."""
        print("client got server port: %s" % (port_number,))
        # pylint: disable=E1101
        reactor.stop()

    def failed(failure):
        """Getting the port failed; report it and stop the reactor."""
        print("client failed: %s" % (failure.getErrorMessage(),))
        # pylint: disable=E1101
        reactor.stop()

    # Without an errback a failure (such as ActivationTimeoutError) would go
    # unhandled and reactor.run() below would never return.
    d.addCallbacks(got_port, failed)
    # pylint: disable=E1101
    reactor.run()


def irl_test():
    """Run one side of the manual client/server test.

    With "-server" among the command-line arguments this process acts as
    the server; otherwise it acts as the client. The client side presumably
    launches the server via SAMPLE_CMDLINE, which passes "-server".
    """
    import sys
    config = ActivationConfig(SAMPLE_SERVICE, SAMPLE_CMDLINE, SAMPLE_PORT)
    wants_server = "-server" in sys.argv[1:]
    runner = server_test if wants_server else client_test
    runner(config)

# Running this file directly performs the manual (IRL) client/server test.
if __name__ == "__main__":
    irl_test()
