// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "StdDataSet.h"

namespace Torch {

// Creates an empty dataset.
// No data is attached and no normalization statistics exist until
// setInputs()/setTargets() and init() have been called.
StdDataSet::StdDataSet()
{
  // Data arrays are attached later by setInputs()/setTargets().
  all_inputs = all_targets = NULL;

  // Normalization statistics are allocated lazily by normalize().
  mean_i = stdv_i = NULL;
  mean_o = stdv_o = NULL;

  // Both normalizations are disabled unless the user turns them on.
  addBOption("normalize inputs", &norm_inputs, false, "normalize the inputs by mean/stdv");
  addBOption("normalize targets", &norm_targets, false, "normalize the targets by mean/stdv");
}

void StdDataSet::init()
{
  DataSet::init();
  normalize();
  
  addToList(&inputs, n_inputs, NULL);
}

// Rescales this dataset in place with the mean/stdv statistics of another
// StdDataSet, e.g. so a test set shares the scaling of a training set.
//
// Inputs are transformed as (x - data_norm->mean_i[d]) / data_norm->stdv_i[d].
// This happens only when data_norm has "normalize inputs" enabled and both
// datasets have the same n_inputs. Targets are handled the same way with
// mean_o/stdv_o. A dimension mismatch only emits a warning and leaves that part
// of the data untouched.
//
// Must be called after init(). data_norm must be non-NULL and initialized, so
// that its statistics buffers exist.
// NOTE(review): only data_norm's options are consulted. If this dataset also
// has normalization enabled, init() already rescaled it with its own
// statistics, and this call applies a second transform on top.
void StdDataSet::normalizeUsingDataSet(StdDataSet *data_norm)
{
  if(!is_already_initialized)
    error("StdDataSet: please, normalize *after* initialization");

  if(!data_norm)
    error("StdDataSet: the normalization dataset is NULL");

  if(data_norm->norm_inputs)
  {
    if(n_inputs == data_norm->n_inputs)
    {
      // The statistics are only allocated by data_norm->normalize() (i.e. its
      // init()). Without this check they would be dereferenced as NULL.
      if(n_inputs > 0 && (!data_norm->mean_i || !data_norm->stdv_i))
        error("StdDataSet: the normalization dataset must be initialized first");

      real *mean = data_norm->mean_i;
      real *stdv = data_norm->stdv_i;
      for(int i = 0; i < n_real_examples; i++)
      {
        real *row = all_inputs[i];
        for(int d = 0; d < n_inputs; d++)
          row[d] = (row[d]-mean[d])/stdv[d];
      }
    }
    else
      warning("StdDataSet: the normalization machine has not the good input size");
  }

  if(data_norm->norm_targets)
  {
    if(n_targets == data_norm->n_targets)
    {
      if(n_targets > 0 && (!data_norm->mean_o || !data_norm->stdv_o))
        error("StdDataSet: the normalization dataset must be initialized first");

      real *mean = data_norm->mean_o;
      real *stdv = data_norm->stdv_o;
      for(int i = 0; i < n_real_examples; i++)
      {
        real *row = all_targets[i];
        for(int d = 0; d < n_targets; d++)
          row[d] = (row[d]-mean[d])/stdv[d];
      }
    }
    else
      warning("StdDataSet: the normalization machine has not the good target size");
  }
}


void StdDataSet::normalize()
{
  if(norm_inputs)
  {
    if(!mean_i)
    {
      mean_i = (real *)xalloc(n_inputs*sizeof(real));
      stdv_i = (real *)xalloc(n_inputs*sizeof(real));
    }

    MSTDVNormalize(all_inputs, mean_i, stdv_i, n_real_examples, n_inputs);

    for(int i = 0; i < n_real_examples; i++)
    {
      for(int d = 0; d < n_inputs; d++)
        all_inputs[i][d] = (all_inputs[i][d]-mean_i[d])/stdv_i[d];
    }
  }

  if(norm_targets)
  {
    if(!mean_o)
    {
      mean_o = (real *)xalloc(n_targets*sizeof(real));
      stdv_o = (real *)xalloc(n_targets*sizeof(real));
    }

    MSTDVNormalize(all_targets, mean_o, stdv_o, n_real_examples, n_targets);

    for(int i = 0; i < n_real_examples; i++)
    {
      for(int d = 0; d < n_targets; d++)
        all_targets[i][d] = (all_targets[i][d]-mean_o[d])/stdv_o[d];
    }
  }
}

// Attaches the input matrix: n_examples_ rows of n_inputs_ values each.
// The array is not copied. Normalization later modifies it in place.
// Must be called before init(), which sizes the normalization buffers and the
// input list from these values.
void StdDataSet::setInputs(real **all_inputs_, int n_inputs_, int n_examples_)
{
  if(is_already_initialized)
    error("StdDataSet: please, call setInputs *before* initialization");

  n_real_examples = n_examples_;
  n_inputs = n_inputs_;
  all_inputs = all_inputs_;
}

// Attaches the target matrix: one row of n_targets_ values per example.
// The row count is the n_examples given to setInputs().
// The array is not copied. Must be called before init().
void StdDataSet::setTargets(real **all_targets_, int n_targets_)
{
  if(is_already_initialized)
    error("StdDataSet: please, call setTargets *before* initialization");

  n_targets = n_targets_;
  all_targets = all_targets_;
}

// Makes example t the current one.
// The `inputs` list pointer and the `targets` pointer are redirected to row t
// of the attached arrays. No data is copied.
void StdDataSet::setRealExample(int t)
{
  current_example = t;

  if(n_inputs > 0)
    inputs->ptr = all_inputs[t];

  if(n_targets > 0)
    targets = all_targets[t];
}

// Returns the dot product of the input rows of real examples i and j.
// Both rows have n_inputs values.
real StdDataSet::realRealDotProduct(int i, int j)
{
  real *xx = all_inputs[i];
  real *yy = all_inputs[j];

  // The loop index is named k so that it no longer shadows the parameter i.
  real z = 0;
  for(int k = 0; k < n_inputs; k++)
    z += xx[k] * yy[k];

  return(z);
}

// Returns the dot product of the input row of real example i with the n_inputs
// values pointed to by y->ptr.
real StdDataSet::realDotProduct(int i, List *y)
{
  real *xx = all_inputs[i];
  real *yy = (real *)y->ptr;

  // The loop index is named k so that it no longer shadows the parameter i.
  real z = 0;
  for(int k = 0; k < n_inputs; k++)
    z += xx[k] * yy[k];

  return(z);
}

// Returns the dot product of the first n_inputs values pointed to by x->ptr
// and by y->ptr.
real StdDataSet::dotProduct(List *x, List *y)
{
  const real *a = (const real *)x->ptr;
  const real *b = (const real *)y->ptr;

  real sum = 0;
  for(int k = 0; k < n_inputs; k++)
    sum += a[k] * b[k];

  return sum;
}

void StdDataSet::loadFILE(FILE *file)
{
  if(norm_inputs)
  {
    for(int i = 0; i < n_real_examples; i++)
    {
      for(int d = 0; d < n_inputs; d++)
        all_inputs[i][d] = all_inputs[i][d]*stdv_i[d]+mean_i[d];
    }

    xfread(mean_i, sizeof(real), n_inputs, file);
    xfread(stdv_i, sizeof(real), n_inputs, file);

    for(int i = 0; i < n_real_examples; i++)
    {
      for(int d = 0; d < n_inputs; d++)
        all_inputs[i][d] = (all_inputs[i][d]-mean_i[d])/stdv_i[d];
    }
  }

  if(norm_targets)
  {
    for(int i = 0; i < n_real_examples; i++)
    {
      for(int d = 0; d < n_targets; d++)
        all_targets[i][d] = all_targets[i][d]*stdv_o[d]+mean_o[d];
    }

    xfread(mean_o, sizeof(real), n_targets, file);
    xfread(stdv_o, sizeof(real), n_targets, file);

    for(int i = 0; i < n_real_examples; i++)
    {
      for(int d = 0; d < n_targets; d++)
        all_targets[i][d] = (all_targets[i][d]-mean_o[d])/stdv_o[d];
    }
  }
}

void StdDataSet::saveFILE(FILE *file)
{
  if(norm_inputs)
  {
    xfwrite(mean_i, sizeof(real), n_inputs, file);
    xfwrite(stdv_i, sizeof(real), n_inputs, file);
  }

  if(norm_targets)
  {
    xfwrite(mean_o, sizeof(real), n_targets, file);
    xfwrite(stdv_o, sizeof(real), n_targets, file);
  }
}

// Releases the normalization statistics and the input list.
// The attached data arrays are not owned and are not freed.
StdDataSet::~StdDataSet()
{
  // These stay NULL when the matching normalization was never computed;
  // free(NULL) is a no-op.
  free(stdv_o);
  free(stdv_i);
  free(mean_o);
  free(mean_i);

  // List created by init().
  freeList(&inputs);
}

}

