/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

/*
   This module contains the following operators:

      Trend      trend           Trend
*/

#include <cdi.h>

#include "field.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "cdo_options.h"
#include "cdo_task.h"
#include "field_trend.h"
#include "cdo_omp.h"
#include "datetime.h"
#include "pmlist.h"
#include "param_conversion.h"
#include "progress.h"
#include "field_functions.h"
#include "arithmetic.h"

static void
trendGetParameter(bool &tstepIsEqual)
{
  auto pargc = cdo_operator_argc();
  if (pargc)
    {
      const auto &pargv = cdo_get_oper_argv();

      KVList kvlist;
      kvlist.name = cdo_module_name();
      if (kvlist.parse_arguments(pargv) != 0) cdo_abort("Parse error!");
      if (Options::cdoVerbose) kvlist.print();

      for (const auto &kv : kvlist)
        {
          const auto &key = kv.key;
          if (kv.nvalues > 1) cdo_abort("Too many values for parameter key >%s<!", key);
          if (kv.nvalues < 1) cdo_abort("Missing value for parameter key >%s<!", key);
          const auto &value = kv.values[0];

          // clang-format off
          if (key == "equal") tstepIsEqual = parameter_to_bool(value);
          else cdo_abort("Invalid parameter key >%s<!", key);
          // clang-format on
        }
    }
}

/*
  CDO module "Trend" (operator "trend").

  Reads every timestep of input stream 0. For each record it accumulates sums
  against a time coordinate zj with calc_trend_sum(). At the end it writes the
  two result fields produced by calc_trend_param() to output streams 1 and 2,
  each containing a single timestep stamped with the last input verification
  date/time.
  NOTE(review): the two outputs are presumably the two parameters of a linear
  fit (intercept/slope) -- confirm in field_trend.h.
*/
class Trend : public Process
{
public:
  using Process::Process;
  inline static CdoModule module = {
    .name = "Trend",
    .operators = { { "trend", TrendHelp } },
    .aliases = {},
    .mode = EXPOSED,     // Module mode: 0:intern 1:extern
    .number = CDI_REAL,  // Allowed number type
    .constraints = { 1, 2, OnlyFirst },
  };
  inline static RegisterEntry<Trend> registration = RegisterEntry<Trend>(module);

  // Number of 2D accumulation buffers passed to calc_trend_sum()/calc_trend_param().
  static const int numWork = 5;

  CdoStreamID streamID1;  // input
  CdoStreamID streamID2;  // first result field
  CdoStreamID streamID3;  // second result field

  int taxisID1;
  int taxisID2;

  int maxRecords;

  // true: time coordinate is the timestep index (optionally checked for equal increments);
  // false: time coordinate comes from delta_time_step_0(). Set via "equal=<bool>".
  bool tstepIsEqual = true;

  VarList varList1;
  // Main-thread-only record (varID, levelID) table; write_output() uses it for the record order.
  std::vector<RecordInfo> recordList;
  size_t gridSizeMax{ 0 };

public:
  // Parse arguments, open all streams and define the single-timestep FLT64 output vlist.
  void
  init() override
  {
    trendGetParameter(tstepIsEqual);

    streamID1 = cdo_open_read(0);

    auto vlistID1 = cdo_stream_inq_vlist(streamID1);
    auto vlistID2 = vlistDuplicate(vlistID1);

    vlist_unpack(vlistID2);

    vlistDefNtsteps(vlistID2, 1);

    taxisID1 = vlistInqTaxis(vlistID1);
    taxisID2 = taxisDuplicate(taxisID1);
    vlistDefTaxis(vlistID2, taxisID2);

    varList1 = VarList(vlistID1);

    maxRecords = varList1.numRecords();
    recordList = std::vector<RecordInfo>(maxRecords);

    auto numVars = varList1.numVars();
    for (int varID = 0; varID < numVars; ++varID) vlistDefVarDatatype(vlistID2, varID, CDI_DATATYPE_FLT64);

    streamID2 = cdo_open_write(1);
    streamID3 = cdo_open_write(2);

    cdo_def_vlist(streamID2, vlistID2);
    cdo_def_vlist(streamID3, vlistID2);

    gridSizeMax = vlistGridsizeMax(vlistID1);
  }

  // Evaluate calc_trend_param() for every record in recordList and write one
  // timestep (index 0) to each of the two output streams.
  void
  write_output(const FieldVector3D &work)
  {
    Field field2, field3;
    field2.resize(gridSizeMax);
    field3.resize(gridSizeMax);

    cdo_def_timestep(streamID2, 0);
    cdo_def_timestep(streamID3, 0);

    for (int recID = 0; recID < maxRecords; ++recID)
      {
        auto [varID, levelID] = recordList[recID].get();

        const auto &var = varList1.vars[varID];
        field2.size = var.gridsize;
        field2.missval = var.missval;
        field3.size = var.gridsize;
        field3.missval = var.missval;

        calc_trend_param(work, field2, field3, varID, levelID);

        cdo_def_record(streamID2, varID, levelID);
        cdo_write_record(streamID2, field2.vec_d.data(), field_num_miss(field2));

        cdo_def_record(streamID3, varID, levelID);
        cdo_write_record(streamID3, field3.vec_d.data(), field_num_miss(field3));
      }
  }

  // Sequential variant: read each record and accumulate it immediately.
  void
  run_sync()
  {
    auto calendar = taxisInqCalendar(taxisID1);
    CheckTimeIncr checkTimeIncr;
    JulianDate julianDate0;
    CdiDateTime vDateTime{};
    double deltat1 = 0.0;
    auto numSteps = varList1.numSteps();
    cdo::Progress progress;
    Field field1;

    FieldVector3D work(numWork);
    for (auto &w : work) field2D_init(w, varList1, FIELD_VEC, 0);

    int tsID = 0;
    while (true)
      {
        auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
        if (nrecs == 0) break;

        vDateTime = taxisInqVdatetime(taxisID1);

        if (tstepIsEqual) check_time_increment(tsID, calendar, vDateTime, checkTimeIncr);
        auto zj = tstepIsEqual ? (double) tsID : delta_time_step_0(tsID, calendar, vDateTime, julianDate0, deltat1);

        for (int recID = 0; recID < nrecs; ++recID)
          {
            // numSteps may be unknown (<= 0); only report progress when it is known.
            if (numSteps > 0) progress.update((tsID + (recID + 1.0) / nrecs) / numSteps);

            auto [varID, levelID] = cdo_inq_record(streamID1);
            recordList[recID].set(varID, levelID);
            field1.init(varList1.vars[varID]);
            cdo_read_record(streamID1, field1);

            calc_trend_sum(work, field1, zj, varID, levelID);
          }

        tsID++;
      }

    taxisDefVdatetime(taxisID2, vDateTime);
    write_output(work);
  }

  // Accumulate every record listed in recordList from fields2D into work.
  static void
  records_calc_trend_sum(FieldVector3D &work, const FieldVector2D &fields2D, const std::vector<RecordInfo> &recordList,
                         double zj) noexcept
  {
    for (const auto &record : recordList)
      {
        auto [varID, levelID] = record.get();
        calc_trend_sum(work, fields2D[varID][levelID], zj, varID, levelID);
      }
  }

  // Overlapping variant: timestep t is read into one of two field buffers
  // while a background task accumulates timestep t-1 from the other buffer.
  void
  run_async()
  {
    auto calendar = taxisInqCalendar(taxisID1);
    CheckTimeIncr checkTimeIncr;
    JulianDate julianDate0;
    CdiDateTime vDateTime{};
    double deltat1 = 0.0;
    auto numSteps = varList1.numSteps();
    cdo::Progress progress;

    FieldVector3D work(numWork);
    for (auto &w : work) field2D_init(w, varList1, FIELD_VEC, 0);

    // Double buffer: index tsID % 2 is filled while the other one may still be in use by the task.
    FieldVector3D fields3D(2);
    field2D_init(fields3D[0], varList1, FIELD_VEC | FIELD_NAT);
    field2D_init(fields3D[1], varList1, FIELD_VEC | FIELD_NAT);

    bool useTask = true;
    auto task = useTask ? std::make_unique<cdo::Task>() : nullptr;

    int tsID = 0;
    while (true)
      {
        auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
        if (nrecs == 0) break;

        vDateTime = taxisInqVdatetime(taxisID1);

        if (tstepIsEqual) check_time_increment(tsID, calendar, vDateTime, checkTimeIncr);
        auto zj = tstepIsEqual ? (double) tsID : delta_time_step_0(tsID, calendar, vDateTime, julianDate0, deltat1);

        auto &fields2D = fields3D[tsID % 2];
        for (int recID = 0; recID < nrecs; ++recID)
          {
            // numSteps may be unknown (<= 0); only report progress when it is known.
            if (numSteps > 0) progress.update((tsID + (recID + 1.0) / nrecs) / numSteps);

            auto [varID, levelID] = cdo_inq_record(streamID1);
            recordList[recID].set(varID, levelID);
            cdo_read_record(streamID1, fields2D[varID][levelID]);
          }

        if (useTask && tsID > 0) task->wait();

        // The task gets its own copy of exactly the records read in this timestep:
        //  - recordList is overwritten by the next iteration's read loop while the
        //    task may still be running, so sharing it by reference is a data race;
        //  - entries beyond nrecs were not refreshed in this timestep (e.g. records
        //    present only in the first timestep) and must not be accumulated
        //    again; this matches run_sync(), which only sums records it read.
        std::vector<RecordInfo> stepRecords(recordList.begin(), recordList.begin() + nrecs);
        std::function<void()> records_calc_trend_sum_func
            = [&work, &stepFields = fields3D[tsID % 2], stepRecords = std::move(stepRecords), zj]() {
                records_calc_trend_sum(work, stepFields, stepRecords, zj);
              };

        if (useTask) { task->doAsync(records_calc_trend_sum_func); }
        else { records_calc_trend_sum_func(); }

        tsID++;
      }

    if (useTask) task->wait();

    taxisDefVdatetime(taxisID2, vDateTime);
    write_output(work);
  }

  // Use the overlapping read/compute variant when parallel reading is enabled.
  void
  run() override
  {
    auto runAsync = (Options::CDO_Parallel_Read > 0);
    if (runAsync)
      run_async();
    else
      run_sync();
  }

  void
  close() override
  {
    cdo_stream_close(streamID3);
    cdo_stream_close(streamID2);
    cdo_stream_close(streamID1);
  }
};
