/***************************************************************************
 *
 * This file is covered by a dual licence. You can choose whether you
 * want to use it according to the terms of the GNU GPL version 2, or
 * under the terms of Zorp Professional Firewall System EULA located
 * on the Zorp installation CD.
 *
 * $Id: connect.c,v 1.52 2004/10/05 14:06:37 chaoron Exp $
 *
 * Author  : Bazsi
 * Auditor :
 * Last audited version:
 * Notes:
 *
 ***************************************************************************/

#include <zorp/connect.h>
#include <zorp/io.h>
#include <zorp/log.h>
#include <zorp/socketsource.h>
#include <zorp/socket.h>
#include <zorp/error.h>

#include <sys/types.h>
#ifdef HAVE_UNISTD_H
  #include <unistd.h>
#endif
#include <fcntl.h>

#ifdef G_OS_WIN32
#  include <winsock2.h>
#include <io.h>
#define close _close 
#else
#  include <netinet/tcp.h>
#  include <netinet/in.h>
#  include <sys/poll.h>
#endif

/**
 * ZIORealConnect:
 * 
 * The ZIOConnect interface can be used to connect a socket to a given
 * destination, using the given local address. The connection establishment
 * is started asynchronously, and a callback is invoked upon completion.
 *
 * ZIORealConnect is the private implementation behind a ZIOConnect handle.
 * The functions in this file cast ZIOConnect * to ZIORealConnect * and
 * back, so @super must stay the first member.
 **/
typedef struct _ZIORealConnect 
{
  /* public part: super.fd is the socket, super.local the local address */
  ZIOConnect super;
  /* address to connect to; a reference is held and dropped on free */
  ZSockAddr *remote;
  
  /* we use a reference to our GSource, as using source_id would cause a race */
  GSource *watch;
  /* connection timeout (see z_io_connect_set_timeout()); -1 by default */
  gint timeout;
  /* user callback; reset to NULL once it has been invoked */
  ZConnectFunc callback;
  /* opaque data passed to callback and destroy_notify */
  gpointer user_data;
  /* called on user_data by z_io_connect_source_destroy_cb() */
  GDestroyNotify destroy_notify;
  /* reference count, modified while holding lock */
  gint refcnt;
  /* protects refcnt and the watch/callback handoff against z_io_connect_cancel() */
  GStaticRecMutex lock;
  /* context the watch is attached to (NULL unless z_io_connect_start_in_context() was used) */
  GMainContext *context;
  /* set by z_io_connect_start_block() so the callback runs without a watch */
  gboolean blocking;
  /* passed through to z_bind(), z_connect() and z_getsockname() */
  guint32 sock_flags;
  /* private copy of the session id, used for logging */
  gchar *session_id;
} ZIORealConnect;

/** 
 * z_io_connect_connected:
 * @timed_out: specifies whether the operation timed out
 * @data: user data passed by socket source, assumed to point to ZIOConnect instance
 *
 * Private callback function, registered to a #ZSocketSource to be called
 * when the socket becomes writeable, i.e. when the connection attempt has
 * finished. It is also called directly by z_io_connect_start_block().
 *
 * On success the socket is switched back to blocking mode, keepalive is
 * enabled and the fd is handed over to the user callback (self->super.fd is
 * reset so that it is not closed on free). On failure the callback gets
 * fd == -1 and a GError; the socket stays in self->super.fd and is closed
 * when @self is freed.
 *
 * If the connection was cancelled (no watch and not a blocking connect) the
 * callback is not invoked, and a successfully connected fd is closed here.
 *
 * Returns: always returns FALSE to indicate that polling the socket should end
 **/
static gboolean 
z_io_connect_connected(gboolean timed_out, gpointer data)
{
  ZIORealConnect *self = (ZIORealConnect *) data;
  int error_num = 0;
  socklen_t errorlen = sizeof(error_num);
  ZConnectFunc callback;
  GError *err = NULL;
  gint fd;
  
  z_enter();
  
  if (!self->callback)
    {
      /* the callback has already been invoked */
      z_leave();
      return FALSE;
    }
  fd = self->super.fd;
  if (timed_out)
    {
      error_num = ETIMEDOUT;
    }
  else if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)(&error_num), &errorlen) == -1)
    {
      /*LOG
        This message indicates that getsockopt(SOL_SOCKET, SO_ERROR)
        failed for the given fd. This system call should never fail,
        so if you see this please report it to the Zorp QA.
       */
      z_log(self->session_id, CORE_ERROR, 0, "getsockopt(SOL_SOCKET, SO_ERROR) failed for connecting socket, ignoring; fd='%d', error='%s'", self->super.fd, g_strerror(errno));
    }
  if (error_num)
    {
      char buf[MAX_SOCKADDR_STRING];
      
      /*LOG
        This message indicates that the connection to the remote end
        failed for the given reason. It is likely that the remote end
        is unreachable.
       */
      z_log(self->session_id, CORE_ERROR, 1, "Connection to remote end failed; remote='%s', error='%s'", z_sockaddr_format(self->remote, buf, sizeof(buf)), g_strerror(error_num));
      
      /* the socket stays in self->super.fd and is closed when we are freed */
      fd = -1;
    }
  else
    {
#ifdef G_OS_WIN32
      WSAEventSelect(fd, 0, 0);
#endif
      z_fd_set_nonblock(fd, 0);
      z_fd_set_keepalive(fd, 1);
      
      /* the fd is handed over to the callback: don't close it when freed */
      self->super.fd = -1;
    }
  
  g_static_rec_mutex_lock(&self->lock);
  
  if (self->watch || self->blocking)
    {
      if (error_num)
        {
          /* never use a non-literal string as the format argument */
          g_set_error(&err, 0, error_num, "%s", g_strerror(error_num));
        }
      
      callback = self->callback;
      self->callback = NULL;
      callback(fd, err, self->user_data);
      
      g_clear_error(&err);
    }
  else
    {
      /*LOG
        This message reports that the connection was cancelled, and
        no further action is taken.
       */
      z_log(self->session_id, CORE_DEBUG, 6, "Connection cancelled, not calling callback; fd='%d'", fd);
      /* fd is -1 on failure; that socket is closed by z_io_connect_free() */
      if (fd != -1)
        close(fd);
    }
  g_static_rec_mutex_unlock(&self->lock);

  z_leave();
  
  /* don't poll again, and destroy associated source */
  return FALSE;
}

/**
 * z_io_connect_source_destroy_cb:
 * @self: ZIOConnect instance
 *
 * Destroy notify registered for the callback data of our #ZSocketSource.
 * It is also called directly once a blocking connect has finished. It
 * first hands user_data to the user-supplied destroy_notify (when one is
 * set) and then drops the reference held on behalf of the source.
 **/
static void
z_io_connect_source_destroy_cb(ZIORealConnect *self)
{
  GDestroyNotify notify = self->destroy_notify;
  ZIOConnect *conn = &self->super;

  if (notify != NULL)
    notify(self->user_data);

  z_io_connect_unref(conn);
}

/**
 * z_io_connect_start_internal:
 * @s: ZIOConnect instance
 *
 * Common part of the z_io_connect_start*() functions: logs the attempt,
 * initiates the connect on self->super.fd (EINPROGRESS is accepted) and
 * queries the local address the socket is bound to.
 *
 * When the local address can be queried, self->super.local is replaced by
 * it, and an additional reference is returned to the caller.
 *
 * Returns: a new reference to the local address, or NULL if the connect
 * could not be initiated or the local address could not be queried
 **/
static ZSockAddr *
z_io_connect_start_internal(ZIOConnect *s)
{
  ZIORealConnect *self = (ZIORealConnect *) s;
  ZSockAddr *local = NULL;
  gchar buf1[MAX_SOCKADDR_STRING], buf2[MAX_SOCKADDR_STRING];

  z_enter();
  /*LOG
    This message reports that a new connection is initiated
    from/to the given addresses.
   */
  z_log(self->session_id, CORE_DEBUG, 7, "Initiating connection; from='%s', to='%s'",
        self->super.local ? z_sockaddr_format(self->super.local, buf1, sizeof(buf1)) : "NULL",
        z_sockaddr_format(self->remote, buf2, sizeof(buf2)));

  if (z_connect(self->super.fd, self->remote, self->sock_flags) != G_IO_STATUS_NORMAL && !z_errno_is(EINPROGRESS))
    {
      /* save errno before the formatting helpers below get a chance to clobber it */
      int saved_errno = errno;

      /*LOG
        This message indicates that the connection to the remote end
        failed for the given reason. It is likely that the remote end
        is unreachable.
       */
      z_log(self->session_id, CORE_ERROR, 2, "Connection to remote end failed; local='%s', remote='%s', error='%s'", 
            self->super.local ? z_sockaddr_format(self->super.local, buf1, sizeof(buf1)) : "NULL", 
            z_sockaddr_format(self->remote, buf2, sizeof(buf2)), g_strerror(saved_errno));
      
      z_leave();
      return NULL;
    }

  if (z_getsockname(self->super.fd, &local, self->sock_flags) == G_IO_STATUS_NORMAL)
    {
      ZSockAddr *old_local;
      
      /* super.local held the address requested in z_io_connect_new() (may
       * be NULL); we now have the exact value */
      old_local = self->super.local;
      self->super.local = NULL;
      z_sockaddr_unref(old_local);
      self->super.local = local;
      z_sockaddr_ref(local);
    }
  z_leave();
  return local;
}

/**
 * z_io_connect_start:
 * @s: ZIOConnect instance
 *
 * Start initiating the connection asynchronously. A #ZSocketSource is
 * attached to self->context, and z_io_connect_connected() runs when the
 * attempt finishes or times out. The source holds a reference to @s,
 * which is released by z_io_connect_source_destroy_cb().
 *
 * Returns: a new reference to the local address we were bound to, or NULL
 * on failure (including when called a second time)
 **/
ZSockAddr *
z_io_connect_start(ZIOConnect *s)
{  
  ZIORealConnect *self = (ZIORealConnect *) s;
  ZSockAddr *local;
  
  z_enter();
  
  if (self->watch)
    {
      /*LOG
        This message indicates that the connection was started twice.
        Please report this error to the Balabit QA Team (devel@balabit.com).
       */
      z_log(self->session_id, CORE_ERROR, 3, "Internal error, z_io_connect_start was called twice;");
      z_leave();
      return NULL;
    }

  local = z_io_connect_start_internal(s);
  if (local)
    {
      /* reference owned by the source, dropped by z_io_connect_source_destroy_cb() */
      z_io_connect_ref(s);
      self->watch = z_socket_source_new(self->super.fd, Z_SOCKEVENT_CONNECT, self->timeout);
      
      g_source_set_callback(self->watch, (GSourceFunc) z_io_connect_connected, self, (GDestroyNotify) z_io_connect_source_destroy_cb);
      if (g_source_attach(self->watch, self->context) == 0)
        {
          GSource *watch = self->watch;

          /*LOG
            This message indicates that the connection can not be initiated. It is 
            likely that some resource is not available.
           */
          z_log(self->session_id, CORE_ERROR, 3, "Error attaching source to context; fd='%d', context='%p'", self->super.fd, self->context);

          /* Detach before unreffing, so a free triggered from the destroy
           * notify never touches the dying source. Dropping the last
           * reference of the unattached source makes GLib run the callback
           * destroy notify (z_io_connect_source_destroy_cb), which already
           * releases the reference taken above -- unreffing @s here as well
           * would drop it twice.
           * NOTE(review): this also runs destroy_notify on user_data. */
          self->watch = NULL;
          g_source_unref(watch);
          z_sockaddr_unref(local);
          local = NULL;
        }
    }
  
  z_leave();
  return local;
}

/**
 * z_io_connect_start_block:
 * @s: ZIOConnect instance
 *
 * Initiate the connection and block until it either succeeds or fails.
 * Instead of returning the result, the user callback is called with it.
 *
 * Returns: the local address we were bound to
 **/
 
#ifndef G_OS_WIN32

ZSockAddr *
z_io_connect_start_block(ZIOConnect *s)
{  
  ZIORealConnect *self = (ZIORealConnect *) s;
  ZSockAddr *local;
  int res;
  time_t timeout_target = 0, timeout_left;

  z_enter();
  local = z_io_connect_start_internal(s);
  if (local)
    {
      struct pollfd pfd;
      /* a negative timeout (the default is -1) means: wait without deadline */
      gboolean has_deadline = self->timeout >= 0;
      
      /* balanced by z_io_connect_source_destroy_cb() below */
      z_io_connect_ref(s);

      pfd.fd = self->super.fd;
      pfd.events = POLLOUT;
      pfd.revents = 0;
      /* self->timeout is treated as seconds here */
      if (has_deadline)
        timeout_target = time(NULL) + self->timeout;
      do
        {
          int poll_timeout = -1;

          if (has_deadline)
            {
              /* recomputed on every retry so EINTR does not extend the deadline */
              timeout_left = timeout_target - time(NULL);
              if (timeout_left < 0)
                timeout_left = 0;
              /* poll() wants milliseconds in an int; clamp to avoid overflow */
              poll_timeout = timeout_left > G_MAXINT / 1000 ? G_MAXINT : (int) (timeout_left * 1000);
            }
          res = poll(&pfd, 1, poll_timeout);
        }
      while (res == -1 && errno == EINTR);
      self->blocking = TRUE;
      if (res >= 0)
        {
          z_io_connect_connected(res == 0, s);
        }
      /* NOTE(review): when poll() fails with anything other than EINTR the
       * user callback is never invoked -- confirm callers cope with that. */
      z_io_connect_source_destroy_cb(self);
    }
  z_leave();
  return local;
}

#else

ZSockAddr *
z_io_connect_start_block(ZIOConnect *s)
{  
  ZIORealConnect *self = (ZIORealConnect *) s;
  ZSockAddr *local;
  int res;

  z_enter();
  local = z_io_connect_start_internal(s);
  if (local)
    {
      fd_set wf, ef;
      TIMEVAL to;
      TIMEVAL *top = NULL;

      /* balanced by z_io_connect_source_destroy_cb() below, as in the POSIX variant */
      z_io_connect_ref(s);

      /* A negative timeout (the default is -1) means: wait without deadline.
       * The value is treated as seconds, like in the POSIX variant; the old
       * code put timeout * 1000 into tv_usec, which is an invalid timeval
       * for values >= 1000 and for -1. */
      if (self->timeout >= 0)
        {
          to.tv_sec = self->timeout;
          to.tv_usec = 0;
          top = &to;
        }

      do
        {
          /* select() modifies its sets, so rebuild them on every attempt and
           * never pass the same set twice. Winsock reports a completed
           * connect in the write set and a failed one in the except set. */
          FD_ZERO(&wf);
          FD_ZERO(&ef);
          FD_SET(self->super.fd, &wf);
          FD_SET(self->super.fd, &ef);
          res = select(0, NULL, &wf, &ef, top);
        }
      while (res == -1 && errno == EINTR);
      self->blocking = TRUE;
      if (res >= 0)
        {
          z_io_connect_connected(res == 0, s);
        }
      z_io_connect_source_destroy_cb(self);
    }
  z_leave();
  return local;
}

#endif

/**
 * z_io_connect_start_in_context:
 * @s: ZIOConnect instance
 * @context: GMainContext to use for polling (NULL selects the default context)
 *
 * Same as z_io_connect_start() but using the context specified in context.
 * A reference to @context is kept until @s is freed.
 *
 * Returns: the local address we were bound to
 **/
ZSockAddr *
z_io_connect_start_in_context(ZIOConnect *s, GMainContext *context)
{
  ZIORealConnect *self = (ZIORealConnect *) s;
  ZSockAddr *res;

  z_enter();
  /* While a watch is attached to the current context, leave it alone;
   * z_io_connect_start() reports the repeated start below. */
  if (!self->watch)
    {
      /* ref the new context before dropping the old one: they may be the same */
      if (context)
        g_main_context_ref(context);
      if (self->context)
        g_main_context_unref(self->context);
      self->context = context;
    }
  res = z_io_connect_start(s);
  z_leave();  
  return res;
}

/**
 * z_io_connect_cancel:
 * @s: ZIOConnect instance
 *
 * Cancel connection after _start was called. It is guaranteed that no
 * user callbacks will be called after z_io_connect_cancel() returns.
 *
 * The watch is detached from @s while holding self->lock, but the source is
 * destroyed only after the lock has been released. Another thread may be
 * holding the GMainContext lock inside GLib while waiting for our lock (for
 * example when the client aborts at the very moment the connection fails),
 * so destroying the source under our lock could deadlock.
 **/
void
z_io_connect_cancel(ZIOConnect *s)
{
  ZIORealConnect *self = (ZIORealConnect *) s;
  GSource *pending;

  z_enter();

  g_static_rec_mutex_lock(&self->lock);
  pending = self->watch;
  self->watch = NULL;
  g_static_rec_mutex_unlock(&self->lock);

  if (pending != NULL)
    {
      g_source_destroy(pending);
      g_source_unref(pending);
    }

  z_leave();
}

/**
 * z_io_connect_set_timeout:
 * @s: ZIOConnect instance
 * @timeout: connection timeout; a negative value (the default is -1) means none
 *
 * Set the connection timeout: establishing the connection may not take
 * longer than @timeout. The value is read when the connection is started,
 * so set it before calling one of the z_io_connect_start*() functions.
 *
 * NOTE(review): this used to be documented as milliseconds. However, the
 * POSIX z_io_connect_start_block() adds the value to time(NULL), i.e. it
 * treats it as seconds, and z_io_connect_start() hands it unchanged to
 * z_socket_source_new(). Confirm the intended unit.
 **/
void
z_io_connect_set_timeout(ZIOConnect *s,
			 gint timeout)
{
  ((ZIORealConnect *) s)->timeout = timeout;
}

/**
 * z_io_connect_set_destroy_notify:
 * @s: ZIOConnect instance
 * @notify: function receiving the user_data given to z_io_connect_new()
 *
 * Register a callback that releases user_data. It is invoked through
 * z_io_connect_source_destroy_cb(), i.e. when the socket source watching the
 * connection is destroyed (and after a blocking connect that calls it).
 * It is not called from z_io_connect_free().
 **/
void 
z_io_connect_set_destroy_notify(ZIOConnect *s, GDestroyNotify notify)
{
  ZIORealConnect *real = (ZIORealConnect *) s;

  real->destroy_notify = notify;
}
		 
/**
 * z_io_connect_new:  
 * @session_id: session id used for logging (copied, may be NULL)
 * @local: local address to bind to, or NULL to leave the socket unbound
 * @remote: remote address to connect to; its family selects the socket family
 * @sock_flags: flags passed through to z_bind(), z_connect() and z_getsockname()
 * @tos: Type of Service value applied with z_fd_set_our_tos()
 * @callback: function to call when the connection attempt finishes.
 * @user_data: opaque pointer to pass to callback.
 *
 * This function creates a new ZIOConnect instance owning a non-blocking
 * stream socket, optionally bound to @local. The connection itself is
 * initiated by one of the z_io_connect_start*() functions.
 *
 * Returns: The allocated instance with a reference count of 1, or NULL if
 * the socket could not be created, bound or switched to non-blocking mode.
 **/
ZIOConnect *
z_io_connect_new(const gchar *session_id,
                 ZSockAddr *local, 
                 ZSockAddr *remote,
                 guint32 sock_flags,
                 gint tos,
		 ZConnectFunc callback,
		 gpointer user_data)
{
  ZIORealConnect *self = g_new0(ZIORealConnect, 1);
  gchar buf[MAX_SOCKADDR_STRING];
  int saved_errno;

  z_enter();
  /* initialize the mutex explicitly instead of relying on g_new0() zeroing */
  g_static_rec_mutex_init(&self->lock);
  self->refcnt = 1;
  self->super.local = z_sockaddr_ref(local);
  self->remote = z_sockaddr_ref(remote);
  self->session_id = session_id ? g_strdup(session_id) : NULL;
  self->callback = callback;
  self->user_data = user_data;
  self->timeout = -1;
  self->sock_flags = sock_flags;
  self->super.fd = socket(remote->sa.sa_family, SOCK_STREAM, 0);
  if (self->super.fd == -1)
    {
      /* save errno before z_sockaddr_format() can clobber it */
      saved_errno = errno;
      
      /*LOG
        This message indicates that Zorp failed to create a socket for
        establishing a connection with the indicated remote endpoint.
       */
      z_log(self->session_id, CORE_ERROR, 1, "Creating socket for connecting failed; family='%d', type='SOCK_STREAM', remote='%s', error='%s'", remote->sa.sa_family, z_sockaddr_format(self->remote, buf, sizeof(buf)), g_strerror(saved_errno));
      z_io_connect_unref((ZIOConnect *) self);
      z_leave();
      return NULL;
    }
  if (local && z_bind(self->super.fd, local, self->sock_flags) != G_IO_STATUS_NORMAL)
    {
      saved_errno = errno;

      /*LOG
        This message indicates that the socket created for the connection
        could not be bound to the requested local address.
       */
      z_log(self->session_id, CORE_ERROR, 1, "Error binding socket; local='%s', error='%s'", z_sockaddr_format(local, buf, sizeof(buf)), g_strerror(saved_errno));
      z_io_connect_unref((ZIOConnect *) self);
      z_leave();
      return NULL;
    }
  z_fd_set_our_tos(self->super.fd, tos);
  if (!z_fd_set_nonblock(self->super.fd, TRUE))
    {
      /* z_fd_set_nonblock sends log message for failure */
      
      z_io_connect_unref((ZIOConnect *) self);
      z_leave();
      return NULL;
    }
  z_leave();
  return (ZIOConnect *) self;
}

/**
 * z_io_connect_free:
 * @s: ZIOConnect instance to free
 *
 * This function is called by z_io_connect_unref() when the reference count
 * of @s goes down to zero. It destroys a leftover watch, closes the socket
 * if it was not handed over to the user, releases the addresses, the main
 * context reference, the mutex and the session id, and frees @s.
 **/
static void 
z_io_connect_free(ZIOConnect *s)
{
  ZIORealConnect *self = (ZIORealConnect *) s;
  
  z_enter();
  self->callback = NULL;
  if (self->watch)
    {
      /* self->watch might still be present when the destruction of this
       * object is done by the FALSE return value of our callback.
       * 1) connected returns FALSE
       * 2) GSource calls our destroy notify, which drops the reference 
       *    held by the source
       * 3) when our ref_cnt goes to 0, this function is called, but 
       *    self->watch might still be present 
       *
       * Otherwise the circular reference is broken by _cancel or right in 
       * _start when the error occurs.
       */
      g_source_destroy(self->watch);
      g_source_unref(self->watch);
      self->watch = NULL;
    } 
  /* close the fd only once no watch can poll it any more */
  if (self->super.fd != -1)
    close(self->super.fd);
  z_sockaddr_unref(self->super.local);
  z_sockaddr_unref(self->remote);
  if (self->context)
    g_main_context_unref(self->context);
  
  /* the mutex is embedded in the structure being freed, so release it */
  g_static_rec_mutex_free(&self->lock);
  g_free(self->session_id);
  g_free(self);
  z_leave();
}

/**
 * z_io_connect_ref:
 * @s: ZIOConnect instance
 *
 * Take an additional reference on @s. The counter is updated while holding
 * the instance lock; @s must still be alive (non-zero count).
 **/
void
z_io_connect_ref(ZIOConnect *s)
{
  ZIORealConnect *rc = (ZIORealConnect *) s;

  z_enter();
  g_static_rec_mutex_lock(&rc->lock);
  g_assert(rc->refcnt != 0);
  rc->refcnt += 1;
  g_static_rec_mutex_unlock(&rc->lock);
  z_leave();
}

/**
 * z_io_connect_unref:
 * @s: ZIOConnect instance
 *
 * Decrement the reference count of @s and free it if the count goes down
 * to zero. The counter is checked and updated while holding the lock.
 **/
void
z_io_connect_unref(ZIOConnect *s)
{
  ZIORealConnect *self = (ZIORealConnect *) s;
  gboolean last_ref;
  
  z_enter();
  g_static_rec_mutex_lock(&self->lock);
  /* asserted under the lock, like in z_io_connect_ref() */
  g_assert(self->refcnt);
  last_ref = (--self->refcnt == 0);
  g_static_rec_mutex_unlock(&self->lock);

  /* free outside the lock: the mutex is part of the object being freed */
  if (last_ref)
    z_io_connect_free(s);
  z_leave();
}

/**
 * z_io_connect_get_session:
 * @s: ZIOConnect instance
 *
 * Returns: the session id given to z_io_connect_new(). This is a private
 * copy owned by @s (NULL if none was given); the caller must not free it.
 **/
G_CONST_RETURN gchar *
z_io_connect_get_session(ZIOConnect *s)
{
  const ZIORealConnect *real = (const ZIORealConnect *) s;
  const gchar *id;

  z_enter();
  id = real->session_id;
  z_leave();
  return id;
}

