#!/usr/bin/python

# Wrapper around cloud-localds and libvirt

# Copyright (C) 2012-3 Canonical Ltd.
# Author: Robie Basak <robie.basak@canonical.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Depends: distro-info, cloud-utils, python-libvirt, python-lxml
# also qemu-kvm (precise) or kvm (newer?)
# The import subcommand needs: qemu-utils (for qemu-img)

from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import codecs
import errno
import os
import shutil
import StringIO
import subprocess
import sys
import tempfile
import uuid
import yaml

import libvirt
from lxml import etree
from lxml.builder import E

from uvtool.libvirt import create_volume_from_fobj
import uvtool.libvirt.simplestreams

# libvirt domain XML template that compose_domain_xml() starts from; the
# "create --template" option overrides it.
DEFAULT_TEMPLATE = '/usr/share/uvtool/libvirt/template.xml'
# Name of the libvirt storage pool in which backing images are looked up and
# new instance disks and cloud-init datasource volumes are created.
POOL_NAME = 'uvtool'


def create_default_user_data(fobj, args):
    """Write some sensible default cloud-init user-data to the given file
    object.

    The generated #cloud-config:
    - sets the hostname from args.hostname;
    - sets manage_etc_hosts to 'localhost';
    - requests the avahi-daemon package;
    - authorizes ~/.ssh/id_rsa.pub for ssh, if that file exists;
    - when args.password is set, enables ssh password login with that
      password and disables its expiry.

    :param fobj: writable file-like object that receives the user-data.
    :param args: parsed command line namespace; hostname and password are
        read from it.
    :raises IOError: if the public key exists but cannot be read.

    """
    # expanduser() falls back to the password database when $HOME is unset,
    # rather than failing with KeyError as os.environ['HOME'] would.
    key_path = os.path.join(os.path.expanduser('~'), '.ssh', 'id_rsa.pub')
    try:
        f = open(key_path, 'rb')
    except IOError as e:
        if e.errno != errno.ENOENT:
            raise
        print(
            "Warning: ~/.ssh/id_rsa.pub not found; instance will be started " +
                "with no ssh access by default.",
            file=sys.stderr,
        )
        ssh_authorized_keys = []
    else:
        with f:
            ssh_authorized_keys = [f.read().strip()]

    data = {
        b'hostname': args.hostname.encode('ascii'),
        b'manage_etc_hosts': b'localhost',
        b'packages': b'avahi-daemon',
    }

    if ssh_authorized_keys:
        data[b'ssh_authorized_keys'] = ssh_authorized_keys

    if args.password:
        data[b'password'] = args.password.encode('utf-8')
        data[b'chpasswd'] = {b'expire': False}
        data[b'ssh_pwauth'] = True

    fobj.write("#cloud-config\n")
    fobj.write(yaml.dump(data))


def create_ds_image(temp_dir, hostname, user_data_fobj):
    """Build a cloud-init NoCloud data source image at temp_dir/ds.img.

    A fresh instance-id (from uuid1) is written to temp_dir/metadata. The
    full contents of user_data_fobj are copied to temp_dir/userdata. Then
    cloud-localds is run inside temp_dir to pack both files into ds.img.
    The metadata and userdata files are intermediate and may be deleted
    afterwards.

    :param temp_dir: scratch directory that receives all generated files.
    :param hostname: currently unused.
    :param user_data_fobj: readable file object holding the user-data.
    :raises subprocess.CalledProcessError: if cloud-localds fails.

    """
    metadata_path = os.path.join(temp_dir, 'metadata')
    userdata_path = os.path.join(temp_dir, 'userdata')
    localds_command = ['cloud-localds', 'ds.img', 'userdata', 'metadata']

    instance_id = str(uuid.uuid1())
    with codecs.open(metadata_path, 'w', encoding='ascii') as metadata_file:
        metadata_file.write("instance-id: %s\n" % instance_id)

    with open(userdata_path, 'wb') as userdata_file:
        userdata_file.write(user_data_fobj.read())

    # Relative file names above resolve against temp_dir via cwd.
    subprocess.check_call(localds_command, cwd=temp_dir)


def create_ds_volume(new_volume_name, hostname, user_data_fobj):
    """Create a new libvirt cloud-init datasource volume.

    The datasource image is built in a private scratch directory. It is
    uploaded into the POOL_NAME storage pool as new_volume_name, and the
    scratch directory is removed whether or not that succeeds.

    :returns: the libvirt volume object returned by create_volume_from_fobj.

    """
    scratch_dir = tempfile.mkdtemp(prefix='uvt-kvm-')
    try:
        create_ds_image(scratch_dir, hostname, user_data_fobj)
        image_path = os.path.join(scratch_dir, 'ds.img')
        with open(image_path, 'rb') as image_fobj:
            volume = create_volume_from_fobj(
                new_volume_name, image_fobj, pool_name=POOL_NAME)
    finally:
        shutil.rmtree(scratch_dir)
    return volume


def create_cow_volume(backing_volume_name, new_volume_name, new_volume_size,
        conn=None):
    """Create a qcow2 volume backed by a named volume in the uvtool pool.

    :param backing_volume_name: name of an existing volume in POOL_NAME.
    :param new_volume_name: name for the new copy-on-write volume.
    :param new_volume_size: capacity of the new volume in GiB.
    :param conn: optional libvirt connection; qemu:///system is opened when
        omitted.
    :raises RuntimeError: if the backing volume cannot be found.
    :returns: the new libvirt volume object.

    """
    connection = conn if conn is not None else libvirt.open('qemu:///system')

    storage_pool = connection.storagePoolLookupByName(POOL_NAME)
    try:
        backing_volume = storage_pool.storageVolLookupByName(
            backing_volume_name)
    except libvirt.libvirtError:
        raise RuntimeError("Cannot find volume %s" % backing_volume_name)

    backing_path = backing_volume.path()
    return create_cow_volume_by_path(
        backing_volume_path=backing_path,
        new_volume_name=new_volume_name,
        new_volume_size=new_volume_size,
        conn=connection,
    )

def create_cow_volume_by_path(backing_volume_path, new_volume_name,
        new_volume_size, conn=None):
    """Create a new libvirt qcow2 volume backed by an existing volume path.

    :param backing_volume_path: filesystem path of the qcow2 backing file.
    :param new_volume_name: name of the volume to create in POOL_NAME.
    :param new_volume_size: capacity of the new volume in GiB (unit 'G').
    :param conn: optional libvirt connection; qemu:///system is opened when
        omitted.
    :returns: the new libvirt volume object.

    """
    connection = conn if conn is not None else libvirt.open('qemu:///system')
    storage_pool = connection.storagePoolLookupByName(POOL_NAME)

    # Start with zero allocation; the backing file supplies existing data.
    target_element = E.target(E.format(type='qcow2'))
    backing_element = E.backingStore(
        E.path(backing_volume_path),
        E.format(type='qcow2'),
    )
    volume_element = E.volume(
        E.name(new_volume_name),
        E.allocation('0'),
        E.capacity(str(new_volume_size), unit='G'),
        target_element,
        backing_element,
    )
    return storage_pool.createXML(etree.tostring(volume_element), 0)


def compose_domain_xml(name, volumes, cpu=1, memory=512, unsafe_caching=False,
        template_path=DEFAULT_TEMPLATE, log_console_output=False, bridge=None):
    """Build libvirt domain XML for a new instance from a template file.

    :param name: domain name.
    :param volumes: libvirt volume objects to attach, in order, as disks
        vda, vdb, vdc, ... (at most 26).
    :param cpu: number of virtual CPUs.
    :param memory: memory size in MiB. It is written multiplied by 1024
        because libvirt's default unit for <memory> is KiB.
    :param unsafe_caching: if true, attach disks with cache='unsafe'.
    :param template_path: path of the domain XML template to start from.
    :param log_console_output: if true, replace serial devices with a stdio
        serial console (prints a security warning).
    :param bridge: if given, replace all interfaces with one bridged to it.
    :raises RuntimeError: if the template has no <domain> root or no
        <devices> element.
    :raises ValueError: if more than 26 volumes are given.
    :returns: the serialized domain XML.

    """
    tree = etree.parse(template_path)
    domain = tree.getroot()
    # The template is user-selectable (--template), so validate it with a
    # real exception rather than an assert that -O would strip.
    if domain.tag != 'domain':
        raise RuntimeError(
            "Template %s does not have a <domain> root element." %
            template_path)

    # Replace whatever the template specifies with the requested values.
    for tag, text in (
            ('name', name),
            ('vcpu', str(cpu)),
            ('currentMemory', str(memory * 1024)),
            ('memory', str(memory * 1024)),
            ):
        etree.strip_elements(domain, tag)
        etree.SubElement(domain, tag).text = text

    devices = domain.find('devices')
    if devices is None:
        raise RuntimeError(
            "Template %s has no <devices> element." % template_path)

    volumes = list(volumes)
    # Device names are generated as vd<letter>; refuse rather than silently
    # dropping volumes that cannot be named.
    if len(volumes) > 26:
        raise ValueError(
            "Cannot attach %d volumes; at most 26 are supported." %
            len(volumes))

    etree.strip_elements(devices, 'disk')
    for index, vol in enumerate(volumes):
        disk_device = 'vd' + chr(ord('a') + index)
        # Use the volume's own on-disk format (e.g. qcow2) as the driver type.
        disk_format_type = (
            etree.fromstring(vol.XMLDesc(0)).
            find('target').
            find('format').
            get('type')
            )
        if unsafe_caching:
            disk_driver = E.driver(
                name='qemu', type=disk_format_type, cache='unsafe')
        else:
            disk_driver = E.driver(name='qemu', type=disk_format_type)
        devices.append(
            E.disk(
                disk_driver,
                E.source(file=vol.path()),
                E.target(dev=disk_device),
                type='file',
                device='disk',
                )
            )

    if bridge:
        etree.strip_elements(devices, 'interface')
        devices.append(E.interface(E.source(bridge=bridge), type='bridge'))

    if log_console_output:
        print(
            "Warning: logging guest console output introduces a DoS " +
                "security problem on the host and should not be used in " +
                "production.",
            file=sys.stderr
        )
        etree.strip_elements(devices, 'serial')
        devices.append(E.serial(E.target(port='0'), type='stdio'))

    return etree.tostring(tree)


def get_base_image(filters):
    """Return the single simplestreams image matching the given filters.

    :param filters: filter expressions passed to simplestreams.query().
    :raises RuntimeError: if zero or more than one image matches.

    """
    matches = list(uvtool.libvirt.simplestreams.query(filters))
    if len(matches) == 1:
        return matches[0]
    if matches:
        raise RuntimeError(
            "Multiple images found that match filters %s." % repr(filters))
    raise RuntimeError(
        "No images found that match filters %s." % repr(filters))


def _best_effort_cleanup(action, description):
    """Run one rollback step, reporting (not raising) any failure it causes.

    Handling the failure inside this separate function means it neither
    aborts the remaining rollback steps nor changes which exception the
    caller's bare ``raise`` re-raises. That matters on Python 2, where
    handling a second exception in the caller's own frame would replace it.

    """
    try:
        action()
    except Exception as e:
        # repr() keeps the message ASCII-safe so that reporting the failure
        # cannot itself raise a UnicodeError on Python 2.
        print(
            "Warning: cleanup step failed (%s): %s" % (description, repr(e)),
            file=sys.stderr,
        )


def create(hostname, filters, user_data_fobj, memory=512, cpu=1, disk=2,
        unsafe_caching=False, template_path=DEFAULT_TEMPLATE,
        log_console_output=False, bridge=None, backing_image_file=None):
    """Create, define and start a new libvirt domain called hostname.

    Two volumes are created:
    - "<hostname>.qcow", a copy-on-write disk of `disk` GiB. It is backed by
      backing_image_file, or by the single simplestreams image matching
      filters when no backing file is given.
    - "<hostname>-ds.qcow", the cloud-init datasource built from
      user_data_fobj.

    If anything fails, a defined domain is undefined and the created
    volumes are deleted on a best-effort basis, and the original exception
    is re-raised.

    """
    # The same truthiness test is used here and below, so base_volume_name
    # is always bound whenever the create_cow_volume() branch needs it.
    if backing_image_file:
        base_volume_name = None
    else:
        base_volume_name = get_base_image(filters)
    undo_volume_creation = []
    try:
        # cow image names must end in ".qcow" so that the current Apparmor
        # profile for /usr/lib/libvirt/virt-aa-helper is able to read them,
        # determine their backing volumes, and generate a dynamic libvirt
        # profile that permits reading the backing volume. Once our pool
        # directory is added to the virt-aa-helper profile, this requirement
        # can be dropped.
        main_vol_name = "%s.qcow" % hostname
        if backing_image_file:
            main_vol = create_cow_volume_by_path(
                backing_image_file, main_vol_name, disk)
        else:
            main_vol = create_cow_volume(
                base_volume_name, main_vol_name, disk)
        undo_volume_creation.append((main_vol_name, main_vol))

        ds_vol_name = "%s-ds.qcow" % hostname
        ds_vol = create_ds_volume(ds_vol_name, hostname, user_data_fobj)
        undo_volume_creation.append((ds_vol_name, ds_vol))

        xml = compose_domain_xml(
            hostname, [main_vol, ds_vol],
            bridge=bridge,
            cpu=cpu,
            log_console_output=log_console_output,
            memory=memory,
            template_path=template_path,
            unsafe_caching=unsafe_caching,
        )
        conn = libvirt.open('qemu:///system')
        domain = conn.defineXML(xml)
        try:
            domain.create()
        except BaseException:
            _best_effort_cleanup(
                domain.undefine, "undefine domain %s" % hostname)
            raise
    except BaseException:
        # Catch everything (including KeyboardInterrupt) so that an
        # interrupted create does not leak volumes; always re-raised.
        for vol_name, vol in undo_volume_creation:
            _best_effort_cleanup(
                lambda vol=vol: vol.delete(0),
                "delete volume %s" % vol_name)
        raise


def delete_domain_volumes(conn, domain):
    """Delete every storage volume referenced by a domain's disks.

    Each <disk> under <devices> is resolved through its <source file=...>
    path, which is used as the libvirt volume key.

    :param conn: libvirt connection object
    :param domain: libvirt domain object

    """
    root = etree.fromstring(domain.XMLDesc(0))
    assert root.tag == 'domain'
    disk_elements = root.find('devices').iter('disk')
    for disk_element in disk_elements:
        source_path = disk_element.find('source').get('file')
        conn.storageVolLookupByKey(source_path).delete(0)


def destroy(hostname):
    """Stop (if running) and remove a domain together with its volumes.

    :param hostname: name of the libvirt domain on qemu:///system.

    """
    connection = libvirt.open('qemu:///system')
    dom = connection.lookupByName(hostname)

    current_state = dom.state(0)[0]
    is_shut_off = current_state == libvirt.VIR_DOMAIN_SHUTOFF
    if not is_shut_off:
        dom.destroy()

    delete_domain_volumes(connection, dom)
    dom.undefine()


def get_lts_series():
    """Return the output of ``distro-info --lts`` with surrounding
    whitespace stripped.

    :raises subprocess.CalledProcessError: if distro-info fails.

    """
    return subprocess.check_output(
        ['distro-info', '--lts'], close_fds=True).strip()


def main_create_get_user_data_fobj(args):
    """Pick the user-data file object to use for a new instance.

    The --user-data file is returned unchanged when one was given. Otherwise
    default user-data is generated into an in-memory buffer, which is
    rewound so that it can be read from the start.

    """
    if not args.user_data:
        generated = StringIO.StringIO()
        create_default_user_data(generated, args)
        generated.seek(0)
        return generated
    return args.user_data


def check_kvm_ok():
    """Run kvm-ok, if installed, to check for KVM support.

    :returns: (True, None) if kvm-ok succeeds or is not installed, otherwise
        (False, <kvm-ok stdout>).

    """
    try:
        kvm_ok_process = subprocess.Popen(
            ['kvm-ok'], shell=False, stdout=subprocess.PIPE, close_fds=True)
    except OSError as error:
        if error.errno == errno.ENOENT:
            # kvm-ok is optional; without it, assume KVM is usable.
            return True, None
        raise
    captured_stdout = kvm_ok_process.communicate()[0]
    if kvm_ok_process.returncode:
        return False, captured_stdout
    return True, None


def main_create(parser, args):
    """Handle the ``create`` subcommand.

    Rejects --password together with --user-data and warns when --password
    is used. Exits with status 1 if kvm-ok reports that KVM is unavailable.
    Otherwise creates and starts the instance.

    """
    if args.user_data and args.password:
        parser.error("--password cannot be used with --user-data.")
    if args.password:
        print(
            "Warning: using --password from the command line is " +
                "not secure and should be used for debugging only.",
            file=sys.stderr
        )

    kvm_ok, is_kvm_ok_output = check_kvm_ok()
    if not kvm_ok:
        print(
            "KVM not available. kvm-ok returned:", is_kvm_ok_output,
            sep="\n", end="", file=sys.stderr
        )
        # Signal the failure to the caller/shell instead of exiting 0.
        sys.exit(1)

    user_data_fobj = main_create_get_user_data_fobj(args)
    # libvirt resolves paths independently of our cwd, so pass it an
    # absolute path.
    if args.backing_image_file:
        abs_image_backing_file = os.path.abspath(args.backing_image_file)
    else:
        abs_image_backing_file = None
    create(
        args.hostname, args.filters, user_data_fobj,
        backing_image_file=abs_image_backing_file,
        bridge=args.bridge,
        cpu=args.cpu,
        disk=args.disk,
        log_console_output=args.log_console_output,
        memory=args.memory,
        template_path=args.template,
        unsafe_caching=args.unsafe_caching,
    )


def main_destroy(parser, args):
    """Handle the ``destroy`` subcommand: destroy each named instance in
    the order given on the command line."""
    hostnames = args.hostname
    for name in hostnames:
        destroy(name)


def main_import(parser, args):
    """Handle the ``import`` subcommand: upload a qcow2 image file as a new
    volume.

    The volume is created in POOL_NAME explicitly, matching
    create_ds_volume(). That is the pool create_cow_volume() searches for
    backing images, so an imported image is usable by ``create``.

    """
    with open(args.filename, 'rb') as f:
        create_volume_from_fobj(
            args.image_name, f, image_type='qcow2', pool_name=POOL_NAME)


def main_list(parser, args):
    """Handle the ``list`` subcommand: print every libvirt domain name, one
    per line.

    virsh is run directly rather than through a shell pipeline. In a
    pipeline the exit status would be that of the last command, which would
    hide virsh failures. Here a virsh failure raises
    subprocess.CalledProcessError.

    """
    # Hack for now. In time this should properly use the API and list
    # only instances created with this tool.
    output = subprocess.check_output(
        ['virsh', '-q', 'list', '--all'], universal_newlines=True)
    for line in output.splitlines():
        fields = line.split()
        # Rows look like "<id> <name> <state...>"; the name is column 2.
        # Blank lines (e.g. virsh's trailing one) carry no name and are
        # skipped.
        if len(fields) > 1:
            print(fields[1])


class DeveloperOptionAction(argparse.Action):
    """argparse action for --developer: turns on unsafe_caching and
    log_console_output in the parsed namespace."""

    def __call__(self, parser, namespace, values, option_string=None):
        for attribute in ('unsafe_caching', 'log_console_output'):
            setattr(namespace, attribute, True)


def main(args):
    """Parse the command line and dispatch to the selected subcommand.

    :param args: command line arguments without the program name, e.g.
        sys.argv[1:].

    """
    print(
        "Warning: this CLI is experimental and may change.",
        file=sys.stderr
    )
    parser = argparse.ArgumentParser()
    # On Python 3 argparse subcommands are optional by default, which would
    # leave args.func unset (AttributeError below). A dest is required for
    # argparse to name the missing argument in its error message.
    subparsers = parser.add_subparsers(dest='subcommand')
    subparsers.required = True

    create_subparser = subparsers.add_parser('create')
    create_subparser.set_defaults(func=main_create)
    create_subparser.add_argument(
        '--developer', '-d', nargs=0, action=DeveloperOptionAction)
    create_subparser.add_argument('--template', default=DEFAULT_TEMPLATE)
    create_subparser.add_argument('--memory', default=512, type=int)
    create_subparser.add_argument('--cpu', default=1, type=int)
    create_subparser.add_argument('--disk', default=8, type=int)
    create_subparser.add_argument('--bridge')
    create_subparser.add_argument('--unsafe-caching', action='store_true')
    create_subparser.add_argument(
        '--user-data', type=argparse.FileType('rb'))
    create_subparser.add_argument('--password')
    create_subparser.add_argument('--log-console-output', action='store_true')
    create_subparser.add_argument('--backing-image-file')
    create_subparser.add_argument('hostname')
    # No static default: the default filter needs distro-info, so it is
    # computed after parsing and only for "create" (see below).
    create_subparser.add_argument('filters', nargs='*', metavar='filter')

    destroy_subparser = subparsers.add_parser('destroy')
    destroy_subparser.set_defaults(func=main_destroy)
    destroy_subparser.add_argument('hostname', nargs='+')

    import_subparser = subparsers.add_parser('import')
    import_subparser.set_defaults(func=main_import)
    import_subparser.add_argument('image_name')
    import_subparser.add_argument('filename')

    list_subparser = subparsers.add_parser('list')
    list_subparser.set_defaults(func=main_list)

    args = parser.parse_args(args)
    if args.subcommand == 'create' and not args.filters:
        # Default to the current LTS release when no filters are given.
        args.filters = ["release=%s" % get_lts_series()]
    args.func(parser, args)


if __name__ == '__main__':
    # argparse wants only the arguments, not the program name in argv[0].
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
