// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "ConnectedMachine.h"

namespace Torch {

/* Builds an empty connected machine: no machines and no links.
   A first (empty) layer is opened right away, so addMachine()/addFCL()
   can be called immediately after construction. */
ConnectedMachine::ConnectedMachine()
{
  // Topology tables, grown on demand by addLayer() and addMachine().
  n_machines_on_layer = NULL;
  machines = NULL;
  links = NULL;
  alpha_links = NULL;
  n_layers = 0;

  // Nothing is "current" until addLayer()/addMachine() run.
  current_machine = -1;
  current_layer = -1;

  // Scratch buffer used by backward(); allocated in allocateMemory().
  current_alpha_offset = 0;
  alpha_buff_size = 0;
  alpha_buff = NULL;

  addLayer();
}

/* Adds `machine` as a fully-connected layer.
   If the current layer already holds machines, a new layer is opened first.
   The machine is then connected, in order, to every machine of the
   previous layer (when there is one). */
void ConnectedMachine::addFCL(GradientMachine *machine)
{
  const bool layer_in_use = (n_machines_on_layer[current_layer] != 0);
  if(layer_in_use)
    addLayer();

  addMachine(machine);

  // First layer: nothing to connect to.
  if(n_layers <= 1)
    return;

  const int prev = current_layer - 1;
  for(int k = 0; k < n_machines_on_layer[prev]; k++)
    connectOn(machines[prev][k]);
}

/* Opens a new, empty layer and makes it the current layer.
   Raises an error if the previous layer received no machine.
   n_outputs is reset because it only accounts for the machines of the
   current (eventually the last) layer. */
void ConnectedMachine::addLayer()
{
  n_outputs = 0;

  if((n_layers > 0) && (n_machines_on_layer[n_layers-1] == 0))
    error("ConnectedMachine: one layer without any machine !?!");

  // Grow every per-layer table by one slot.
  const int n_new = n_layers + 1;
  machines = (GradientMachine ***)xrealloc((void *)machines, n_new*sizeof(GradientMachine **));
  links = (List ***)xrealloc((void *)links, n_new*sizeof(List **));
  alpha_links = (List ***)xrealloc((void *)alpha_links, n_new*sizeof(List **));
  n_machines_on_layer = (int *)xrealloc((void *)n_machines_on_layer, n_new*sizeof(int));

  // The fresh layer has no machine and no link yet.
  const int fresh = n_layers;
  links[fresh] = NULL;
  alpha_links[fresh] = NULL;
  machines[fresh] = NULL;
  n_machines_on_layer[fresh] = 0;

  current_machine = -1;
  current_layer = fresh;
  n_layers = n_new;
}

/* Appends `machine` to the current layer and makes it the current machine,
   i.e. the one that subsequent connectOn() calls attach inputs to.

   - On the first layer, all machines must share the same input size, which
     becomes the input size of this ConnectedMachine.
   - n_outputs accumulates the output sizes of the current layer's machines
     (addLayer() resets it, so it ends up describing the last layer).
   - alpha_buff_size is grown so that alpha_buff, allocated in
     allocateMemory(), is large enough for every use in backward(). */
void ConnectedMachine::addMachine(GradientMachine *machine)
{
  const int idx = n_machines_on_layer[current_layer];
  const int new_size = idx + 1;

  machines[current_layer] = (GradientMachine **)xrealloc((void *)(machines[current_layer]),
                                                         new_size*sizeof(GradientMachine *));
  machines[current_layer][idx] = machine;

  // Input links of the new machine (filled by connectOn()).
  links[current_layer] = (List **)xrealloc((void *)(links[current_layer]),
                                           new_size*sizeof(List *));
  links[current_layer][idx] = NULL;

  // Gradient sources of the new machine (filled when later machines connect on it).
  alpha_links[current_layer] = (List **)xrealloc((void *)(alpha_links[current_layer]),
                                                 new_size*sizeof(List *));
  alpha_links[current_layer][idx] = NULL;

  current_machine = idx;
  n_machines_on_layer[current_layer] = new_size;

  if(current_layer == 0)
  {
    if(new_size > 1)
    {
      if(machine->n_inputs != n_inputs)
        error("ConnectedMachine: trying to connect machine of different input size at the first layer");
    }
    else
      n_inputs = machine->n_inputs;
  }

  n_outputs += machine->n_outputs;

  // backward() zeroes and accumulates alpha_buff[0 .. n_outputs-1] for every
  // machine that is not on the last layer, so the buffer must cover output
  // sizes too, not only input sizes. A hidden machine whose outputs are not
  // consumed by any later machine would otherwise overflow the buffer.
  if(machine->n_inputs > alpha_buff_size)
    alpha_buff_size = machine->n_inputs;
  if(machine->n_outputs > alpha_buff_size)
    alpha_buff_size = machine->n_outputs;

  // connectOn() fills the new machine's beta from its beginning.
  current_alpha_offset = 0;
}

/* Connects the outputs of `machine` to the inputs of the current machine.
   `machine` must already live on a layer before the current one.

   The outputs of `machine` are appended to the current machine's input
   links. The matching slice of the current machine's beta, starting at
   current_alpha_offset, is registered in alpha_links as a gradient source
   for `machine`. */
void ConnectedMachine::connectOn(GradientMachine *machine)
{
  if(current_machine < 0)
    error("ConnectedMachine: no machine to connect");

  // Locate `machine` (first match) among the machines of the previous layers.
  int src_layer = -1;
  int src_index = -1;
  for(int l = 0; (src_layer < 0) && (l < current_layer); l++)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
    {
      if(machines[l][m] != machine)
        continue;
      src_layer = l;
      src_index = m;
      break;
    }
  }

  if(src_layer < 0)
    error("ConnectedMachine: cannot connect your machine");

  GradientMachine *dest = machines[current_layer][current_machine];
  addToList(&links[current_layer][current_machine], machine->outputs);
  addToList(&alpha_links[src_layer][src_index], 1, dest->beta+current_alpha_offset);
  current_alpha_offset += machine->n_outputs;
}

void ConnectedMachine::checkInternalLinks()
{
  for(int l = 1; l < n_layers; l++)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
    {
      List *xliens = links[l][m];
      int xn_liens = 0;
      while(xliens)
      {
        xn_liens += xliens->n;
        xliens = xliens->next;
      }
      if(machines[l][m]->n_inputs != xn_liens)
        error("ConnectedMachine: incorrect number of inputs for machine [%d %d]", l, m);
    }
  }
}

int ConnectedMachine::numberOfParams()
{
  int n = 0;
  for(int l = 0; l < n_layers; l++)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
      n += machines[l][m]->numberOfParams();
  }

  return(n);
}

/* Forward pass, layer by layer and machine by machine.
   First-layer machines read `inputs` directly; later machines read the
   outputs they were connected to through connectOn(). */
void ConnectedMachine::forward(List *inputs)
{
  for(int l = 0; l < n_layers; l++)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
    {
      List *machine_inputs = (l == 0) ? inputs : links[l][m];
      machines[l][m]->forward(machine_inputs);
    }
  }
}

/* Backward pass.
   `alpha` is consumed as consecutive slices, one per last-layer machine,
   of that machine's n_outputs. Machines are then processed from the
   before-last layer down to the first. For each of them, the gradient with
   respect to its outputs is the sum of the beta slices registered in
   alpha_links by connectOn(), accumulated into alpha_buff. Finally, beta is
   the element-wise sum of the first-layer machines' betas. */
void ConnectedMachine::backward(List *inputs, real *alpha)
{
  const int last = n_layers - 1;

  // Last layer: gradients come straight from `alpha`.
  real *alpha_ptr = alpha;
  for(int m = 0; m < n_machines_on_layer[last]; m++)
  {
    GradientMachine *mach = machines[last][m];
    mach->backward((last > 0) ? links[last][m] : inputs, alpha_ptr);
    alpha_ptr += mach->n_outputs;
  }

  // Remaining layers, top to bottom.
  for(int l = last - 1; l >= 0; l--)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
    {
      GradientMachine *mach = machines[l][m];
      const int n_out = mach->n_outputs;

      for(int i = 0; i < n_out; i++)
        alpha_buff[i] = 0;

      for(List *node = alpha_links[l][m]; node; node = node->next)
      {
        real *src = (real *)node->ptr;
        for(int i = 0; i < n_out; i++)
          alpha_buff[i] += src[i];
      }

      mach->backward((l == 0) ? inputs : links[l][m], alpha_buff);
    }
  }

  // beta = sum of the first-layer machines' betas.
  for(int i = 0; i < n_inputs; i++)
    beta[i] = 0;

  for(int k = 0; k < n_machines_on_layer[0]; k++)
  {
    real *src = machines[0][k]->beta;
    for(int i = 0; i < n_inputs; i++)
      beta[i] += src[i];
  }
}

void ConnectedMachine::allocateMemory()
{
  checkInternalLinks();

  n_params = numberOfParams();
  for(int l = 0; l < n_layers; l++)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
    {
      addToList(&params, machines[l][m]->params);
      addToList(&der_params, machines[l][m]->der_params);
    }
  }

  beta = (real *)xalloc(sizeof(real)*n_inputs);

  for(int m = 0; m < n_machines_on_layer[n_layers-1]; m++)
    addToList(&outputs, machines[n_layers-1][m]->outputs);

  alpha_buff = (real *)xalloc(sizeof(real)*alpha_buff_size);
}


/* Releases the lists and buffers built by allocateMemory().
   Guarded by is_free, so repeated calls do nothing. */
void ConnectedMachine::freeMemory()
{
  if(!is_free)
  {
    freeList(&params);
    freeList(&der_params);
    free(beta);
    freeList(&outputs);
    free(alpha_buff);

    is_free = true;
  }
}

void ConnectedMachine::reset()
{
  for(int i = 0; i < n_layers; i++)
  {
    for(int m = 0; m < n_machines_on_layer[i]; m++)
      machines[i][m]->reset();
  }
}

void ConnectedMachine::iterInitialize()
{
  for(int i = 0; i < n_layers; i++)
  {
    for(int m = 0; m < n_machines_on_layer[i]; m++)
      machines[i][m]->iterInitialize();
  }  
}

void ConnectedMachine::loadFILE(FILE *file)
{
  for(int i = 0; i < n_layers; i++)
  {
    for(int m = 0; m < n_machines_on_layer[i]; m++)
      machines[i][m]->loadFILE(file);
  }
}

void ConnectedMachine::saveFILE(FILE *file)
{
  for(int i = 0; i < n_layers; i++)
  {
    for(int m = 0; m < n_machines_on_layer[i]; m++)
      machines[i][m]->saveFILE(file);
  }
}

/* Frees the per-machine link lists and the per-layer topology tables, then
   whatever allocateMemory() created (via freeMemory()).
   The sub-machines themselves are not deleted: only the arrays holding
   pointers to them are freed. */
ConnectedMachine::~ConnectedMachine()
{
  for(int l = 0; l < n_layers; l++)
  {
    for(int m = 0; m < n_machines_on_layer[l]; m++)
    {
      freeList(&links[l][m]);
      freeList(&alpha_links[l][m]);
    }

    free(links[l]);
    free(alpha_links[l]);
    free(machines[l]);
  }

  free(machines);
  free(n_machines_on_layer);
  free(alpha_links);
  free(links);

  freeMemory();
}

}

