!-----------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations         !
!   Copyright (C) 2000 - 2015  CP2K developers group                          !
!-----------------------------------------------------------------------------!

MODULE qs_fb_filter_matrix_methods

  USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                             get_atomic_kind
  USE cp_dbcsr_interface,              ONLY: &
       cp_dbcsr_create, cp_dbcsr_finalize, cp_dbcsr_get_info, &
       cp_dbcsr_get_stored_coordinates, cp_dbcsr_init, cp_dbcsr_put_block, &
       cp_dbcsr_row_block_sizes, cp_dbcsr_type, dbcsr_distribution_obj, &
       dbcsr_type_no_symmetry
  USE cp_para_types,                   ONLY: cp_para_env_type
  USE fermi_utils,                     ONLY: Fermi,&
                                             FermiFixed
  USE kinds,                           ONLY: default_string_length,&
                                             dp,&
                                             int_8
  USE message_passing,                 ONLY: mp_alltoall
  USE particle_types,                  ONLY: particle_type
  USE qs_fb_atomic_halo_types,         ONLY: fb_atomic_halo_create,&
                                             fb_atomic_halo_get,&
                                             fb_atomic_halo_list_get,&
                                             fb_atomic_halo_list_obj,&
                                             fb_atomic_halo_nullify,&
                                             fb_atomic_halo_obj,&
                                             fb_atomic_halo_release,&
                                             fb_atomic_halo_set
  USE qs_fb_atomic_matrix_methods,     ONLY: fb_atmatrix_calc_size,&
                                             fb_atmatrix_construct,&
                                             fb_atmatrix_construct_2,&
                                             fb_atmatrix_generate_com_pairs_2
  USE qs_fb_com_tasks_types,           ONLY: &
       TASK_COST, TASK_DEST, TASK_N_RECORDS, TASK_PAIR, TASK_SRC, &
       fb_com_atom_pairs_calc_buffer_sizes, fb_com_atom_pairs_create, &
       fb_com_atom_pairs_decode, fb_com_atom_pairs_distribute_blks, &
       fb_com_atom_pairs_gather_blks, fb_com_atom_pairs_get, &
       fb_com_atom_pairs_has_data, fb_com_atom_pairs_init, &
       fb_com_atom_pairs_nullify, fb_com_atom_pairs_obj, &
       fb_com_atom_pairs_release, fb_com_tasks_build_atom_pairs, &
       fb_com_tasks_create, fb_com_tasks_encode_pair, fb_com_tasks_nullify, &
       fb_com_tasks_obj, fb_com_tasks_release, fb_com_tasks_set, &
       fb_com_tasks_transpose_dest_src
  USE qs_fb_matrix_data_types,         ONLY: fb_matrix_data_add,&
                                             fb_matrix_data_create,&
                                             fb_matrix_data_has_data,&
                                             fb_matrix_data_nullify,&
                                             fb_matrix_data_obj,&
                                             fb_matrix_data_release
  USE qs_fb_trial_fns_types,           ONLY: fb_trial_fns_get,&
                                             fb_trial_fns_obj
  USE string_utilities,                ONLY: compress,&
                                             uppercase
#include "./base/base_uses.f90"

  IMPLICIT NONE

  PRIVATE

  CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'qs_fb_filter_matrix_methods'

  PUBLIC :: fb_fltrmat_build,&
            fb_fltrmat_build_2


CONTAINS


! *****************************************************************************
!> \brief Build the filter matrix, with MPI communications happening at each
!>        step. Less efficient on communication, but more efficient on
!>        memory usage (compared to fb_fltrmat_build_2)
!> \param H_mat : DBCSR system KS matrix
!> \param S_mat : DBCSR system overlap matrix
!> \param atomic_halos : list of all local atomic halos, each halo gives
!>                       one atomic matrix and contributes to one blk
!>                       col to the filter matrix
!> \param trial_fns : the trial functions to be used to shrink the
!>                     size of the new "filtered" basis
!> \param para_env : cp2k parallel environment
!> \param particle_set : set of all particles in the system
!> \param fermi_level : the fermi level used for defining the filter
!>                      function, which is a Fermi-Dirac distribution
!>                      function
!> \param filter_temp : the filter temperature used for defining the
!>                      filter function
!> \param name        : name given to the filter matrix
!> \param filter_mat  : DBCSR format filter matrix
!> \param tolerance   : anything less than tolerance is treated as zero
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_build(H_mat, &
                              S_mat, &
                              atomic_halos, &
                              trial_fns, &
                              para_env, &
                              particle_set, &
                              fermi_level, &
                              filter_temp, &
                              name, &
                              filter_mat, &
                              tolerance)
    TYPE(cp_dbcsr_type), POINTER             :: H_mat, S_mat
    TYPE(fb_atomic_halo_list_obj), &
      INTENT(IN)                             :: atomic_halos
    TYPE(fb_trial_fns_obj), INTENT(IN)       :: trial_fns
    TYPE(cp_para_env_type), POINTER          :: para_env
    TYPE(particle_type), DIMENSION(:), &
      POINTER                                :: particle_set
    REAL(KIND=dp), INTENT(IN)                :: fermi_level, filter_temp
    CHARACTER(LEN=*), INTENT(IN)             :: name
    TYPE(cp_dbcsr_type), POINTER             :: filter_mat
    REAL(KIND=dp), INTENT(IN)                :: tolerance

    CHARACTER(LEN=*), PARAMETER :: routineN = 'fb_fltrmat_build', &
      routineP = moduleN//':'//routineN

    CHARACTER(LEN=default_string_length)     :: name_string
    INTEGER                                  :: handle, iblkcol, ihalo, &
                                                ikind, max_nhalos, &
                                                nblkcols_total, nhalos
    INTEGER, DIMENSION(:), POINTER           :: col_blk_size, &
                                                dummy_halo_atoms, ntfns, &
                                                row_blk_size
    TYPE(atomic_kind_type), POINTER          :: atomic_kind
    TYPE(dbcsr_distribution_obj)             :: dbcsr_dist
    TYPE(fb_atomic_halo_obj)                 :: dummy_atomic_halo
    TYPE(fb_atomic_halo_obj), DIMENSION(:), &
      POINTER                                :: halos

    ! Allocates and creates filter_mat, fills it one atomic halo at a
    ! time through fb_fltrmat_add_blkcol, and finalises it.

    CALL timeset(routineN, handle)

    NULLIFY(halos, atomic_kind, ntfns, dummy_halo_atoms, row_blk_size, &
            col_blk_size)
    CALL fb_atomic_halo_nullify(dummy_atomic_halo)

    ! filter_mat must be disassociated on entry (i.e. brand new), since
    ! it is allocated by this routine
    CPASSERT(.NOT.ASSOCIATED(filter_mat))

    ! number of trial functions, indexed by atomic kind number
    CALL fb_trial_fns_get(trial_fns=trial_fns, &
                          nfunctions=ntfns)

    ! the row blocking and the distribution of the filter matrix are
    ! taken over from H
    CALL cp_dbcsr_get_info(H_mat, &
                           nblkcols_total=nblkcols_total, &
                           row_blk_size=row_blk_size, &
                           distribution=dbcsr_dist)

    ! each block column is looked up as an entry of particle_set below,
    ! so fail explicitly rather than index out of bounds
    CPASSERT(SIZE(particle_set)>=nblkcols_total)

    ! the width of a block column is the number of trial functions of
    ! the kind of the corresponding particle
    ALLOCATE(col_blk_size(nblkcols_total))
    DO iblkcol = 1, nblkcols_total
       atomic_kind => particle_set(iblkcol)%atomic_kind
       CALL get_atomic_kind(atomic_kind=atomic_kind, &
                            kind_number=ikind)
       col_blk_size(iblkcol) = ntfns(ikind)
    END DO

    name_string = name
    CALL compress(name_string)
    CALL uppercase(name_string)

    ! create the empty filter matrix; it is non-square and is always
    ! non-symmetric
    ALLOCATE(filter_mat)
    CALL cp_dbcsr_init(filter_mat)
    CALL cp_dbcsr_create(matrix=filter_mat, &
                         name=name_string, &
                         dist=dbcsr_dist, &
                         matrix_type=dbcsr_type_no_symmetry, &
                         row_blk_size=row_blk_size, &
                         col_blk_size=col_blk_size, &
                         nze=0)
    ! local scratch array, no longer needed once the matrix is created
    DEALLOCATE(col_blk_size)

    CALL fb_atomic_halo_list_get(atomic_halos=atomic_halos, &
                                 nhalos=nhalos, &
                                 max_nhalos=max_nhalos, &
                                 halos=halos)

    ! create an empty atomic halo, used for the iterations beyond nhalos
    ! NOTE(review): dummy_halo_atoms is not deallocated in this routine;
    ! presumably fb_atomic_halo_release frees it - confirm
    CALL fb_atomic_halo_create(dummy_atomic_halo)
    ALLOCATE(dummy_halo_atoms(0))
    CALL fb_atomic_halo_set(atomic_halo=dummy_atomic_halo, &
                            owner_atom=0, &
                            owner_id_in_halo=0, &
                            natoms=0, &
                            halo_atoms=dummy_halo_atoms, &
                            nelectrons=0, &
                            sorted=.TRUE.)

    ! construct the filter matrix block column by block column. The loop
    ! runs to max_nhalos rather than nhalos; once the local halos are
    ! exhausted the empty halo is passed instead, so that this call only
    ! sends data (presumably max_nhalos is the largest halo count over
    ! all processes - TODO confirm)
    DO ihalo = 1, max_nhalos
       IF (ihalo <= nhalos) THEN
          CALL fb_fltrmat_add_blkcol(H_mat, &
                                     S_mat, &
                                     halos(ihalo), &
                                     trial_fns, &
                                     para_env, &
                                     particle_set, &
                                     fermi_level, &
                                     filter_temp, &
                                     filter_mat, &
                                     tolerance)
       ELSE
          CALL fb_fltrmat_add_blkcol(H_mat, &
                                     S_mat, &
                                     dummy_atomic_halo, &
                                     trial_fns, &
                                     para_env, &
                                     particle_set, &
                                     fermi_level, &
                                     filter_temp, &
                                     filter_mat, &
                                     tolerance)
       END IF
    END DO

    ! finalise the filter matrix
    CALL cp_dbcsr_finalize(filter_mat)

    ! cleanup
    CALL fb_atomic_halo_release(dummy_atomic_halo)

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_build


! *****************************************************************************
!> \brief Build the filter matrix, with MPI communications grouped together.
!>        More efficient on communication, less efficient on memory (compared
!>        to fb_fltrmat_build)
!> \param H_mat : DBCSR system KS matrix
!> \param S_mat : DBCSR system overlap matrix
!> \param atomic_halos : list of all local atomic halos, each halo gives
!>                       one atomic matrix and contributes to one blk
!>                       col to the filter matrix
!> \param trial_fns : the trial functions to be used to shrink the
!>                     size of the new "filtered" basis
!> \param para_env : cp2k parallel environment
!> \param particle_set : set of all particles in the system
!> \param fermi_level : the fermi level used for defining the filter
!>                      function, which is a Fermi-Dirac distribution
!>                      function
!> \param filter_temp : the filter temperature used for defining the
!>                      filter function
!> \param name        : name given to the filter matrix
!> \param filter_mat  : DBCSR format filter matrix
!> \param tolerance   : anything less than tolerance is treated as zero
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_build_2(H_mat, &
                                S_mat, &
                                atomic_halos, &
                                trial_fns, &
                                para_env, &
                                particle_set, &
                                fermi_level, &
                                filter_temp, &
                                name, &
                                filter_mat, &
                                tolerance)
    TYPE(cp_dbcsr_type), POINTER             :: H_mat, S_mat
    TYPE(fb_atomic_halo_list_obj), &
      INTENT(IN)                             :: atomic_halos
    TYPE(fb_trial_fns_obj), INTENT(IN)       :: trial_fns
    TYPE(cp_para_env_type), POINTER          :: para_env
    TYPE(particle_type), DIMENSION(:), &
      POINTER                                :: particle_set
    REAL(KIND=dp), INTENT(IN)                :: fermi_level, filter_temp
    CHARACTER(LEN=*), INTENT(IN)             :: name
    TYPE(cp_dbcsr_type), POINTER             :: filter_mat
    REAL(KIND=dp), INTENT(IN)                :: tolerance

    CHARACTER(LEN=*), PARAMETER :: routineN = 'fb_fltrmat_build_2', &
      routineP = moduleN//':'//routineN

    CHARACTER(LEN=default_string_length)     :: name_string
    INTEGER :: handle, iblkcol, ihalo, ikind, natoms_global, natoms_in_halo, &
      nblkcols_total, nblks_recv, nhalos, nmax
    INTEGER, DIMENSION(:), POINTER           :: col_blk_size, ntfns, &
                                                row_blk_size
    TYPE(atomic_kind_type), POINTER          :: atomic_kind
    TYPE(dbcsr_distribution_obj)             :: dbcsr_dist
    TYPE(fb_atomic_halo_obj), DIMENSION(:), &
      POINTER                                :: halos
    TYPE(fb_com_atom_pairs_obj) :: atmatrix_blks_recv, atmatrix_blks_send, &
      filter_mat_blks_recv, filter_mat_blks_send
    TYPE(fb_matrix_data_obj)                 :: filter_mat_data, H_mat_data, &
                                                S_mat_data

    ! Allocates and creates filter_mat in three stages: gather the H and
    ! S blocks needed by the local halos, compute the filter matrix
    ! blocks halo by halo into a local store, then distribute that store
    ! into the DBCSR filter matrix and finalise it.

    CALL timeset(routineN, handle)

    NULLIFY(halos, atomic_kind, row_blk_size, col_blk_size, ntfns)

    ! filter_mat must be disassociated on entry (i.e. brand new), since
    ! it is allocated by this routine
    CPASSERT(.NOT.ASSOCIATED(filter_mat))

    ! get total number of atoms
    natoms_global = SIZE(particle_set)

    ! number of trial functions, indexed by atomic kind number
    CALL fb_trial_fns_get(trial_fns=trial_fns, &
                          nfunctions=ntfns)

    ! the row blocking and the distribution of the filter matrix are
    ! taken over from H
    CALL cp_dbcsr_get_info(H_mat, &
                           nblkcols_total=nblkcols_total, &
                           row_blk_size=row_blk_size, &
                           distribution=dbcsr_dist)

    ! each block column is looked up as an entry of particle_set below,
    ! so fail explicitly rather than index out of bounds
    CPASSERT(natoms_global>=nblkcols_total)

    ! the width of a block column is the number of trial functions of
    ! the kind of the corresponding particle
    ALLOCATE(col_blk_size(nblkcols_total))
    DO iblkcol = 1, nblkcols_total
       atomic_kind => particle_set(iblkcol)%atomic_kind
       CALL get_atomic_kind(atomic_kind=atomic_kind, &
                            kind_number=ikind)
       col_blk_size(iblkcol) = ntfns(ikind)
    END DO

    name_string = name
    CALL compress(name_string)
    CALL uppercase(name_string)

    ! create empty filter matrix (it is always non-symmetric as it is
    ! non-square)
    ALLOCATE(filter_mat)
    CALL cp_dbcsr_init(filter_mat)
    CALL cp_dbcsr_create(matrix=filter_mat, &
                         name=name_string, &
                         dist=dbcsr_dist, &
                         matrix_type=dbcsr_type_no_symmetry, &
                         row_blk_size=row_blk_size, &
                         col_blk_size=col_blk_size, &
                         nze=0)
    ! local scratch array, no longer needed once the matrix is created
    DEALLOCATE(col_blk_size)

    ! ----------------------------------------------------------------------
    ! Stage 1: get all the blocks required for constructing the atomic
    ! matrices, and store them in fb_matrix_data objects
    ! ----------------------------------------------------------------------
    CALL fb_matrix_data_nullify(H_mat_data)
    CALL fb_matrix_data_nullify(S_mat_data)
    CALL fb_com_atom_pairs_nullify(atmatrix_blks_send)
    CALL fb_com_atom_pairs_nullify(atmatrix_blks_recv)
    CALL fb_com_atom_pairs_create(atmatrix_blks_send)
    CALL fb_com_atom_pairs_create(atmatrix_blks_recv)

    ! H matrix
    CALL fb_atmatrix_generate_com_pairs_2(H_mat, &
                                          atomic_halos, &
                                          para_env, &
                                          atmatrix_blks_send, &
                                          atmatrix_blks_recv)
    CALL fb_com_atom_pairs_get(atom_pairs=atmatrix_blks_recv, &
                               npairs=nblks_recv)
    CALL fb_matrix_data_create(H_mat_data, &
                               nblks_recv, &
                               natoms_global)
    CALL fb_com_atom_pairs_gather_blks(H_mat, &
                                       atmatrix_blks_send, &
                                       atmatrix_blks_recv, &
                                       para_env, &
                                       H_mat_data)

    ! S matrix; the same send/recv pair objects are handed to the
    ! generator again
    CALL fb_atmatrix_generate_com_pairs_2(S_mat, &
                                          atomic_halos, &
                                          para_env, &
                                          atmatrix_blks_send, &
                                          atmatrix_blks_recv)
    CALL fb_com_atom_pairs_get(atom_pairs=atmatrix_blks_recv, &
                               npairs=nblks_recv)
    CALL fb_matrix_data_create(S_mat_data, &
                               nblks_recv, &
                               natoms_global)
    CALL fb_com_atom_pairs_gather_blks(S_mat, &
                                       atmatrix_blks_send, &
                                       atmatrix_blks_recv, &
                                       para_env, &
                                       S_mat_data)
    ! cleanup
    CALL fb_com_atom_pairs_release(atmatrix_blks_send)
    CALL fb_com_atom_pairs_release(atmatrix_blks_recv)

    ! ----------------------------------------------------------------------
    ! Stage 2: make the filter matrix blocks halo by halo and store them
    ! in a matrix_data_obj
    ! ----------------------------------------------------------------------
    CALL fb_matrix_data_nullify(filter_mat_data)
    CALL fb_atomic_halo_list_get(atomic_halos=atomic_halos, &
                                 nhalos=nhalos, &
                                 halos=halos)
    ! the store is created with the total number of atoms summed over
    ! all local halos
    nmax = 0
    DO ihalo = 1, nhalos
       CALL fb_atomic_halo_get(atomic_halo=halos(ihalo), &
                               natoms=natoms_in_halo)
       nmax = nmax + natoms_in_halo
    END DO
    CALL fb_matrix_data_create(filter_mat_data, &
                               nmax, &
                               natoms_global)
    DO ihalo = 1, nhalos
       CALL fb_fltrmat_add_blkcol_2(H_mat, &
                                    S_mat, &
                                    H_mat_data, &
                                    S_mat_data, &
                                    halos(ihalo), &
                                    trial_fns, &
                                    particle_set, &
                                    fermi_level, &
                                    filter_temp, &
                                    filter_mat_data, &
                                    tolerance)
    END DO
    ! clean up
    CALL fb_matrix_data_release(H_mat_data)
    CALL fb_matrix_data_release(S_mat_data)

    ! ----------------------------------------------------------------------
    ! Stage 3: distribute the relevant blocks from the matrix_data_obj to
    ! the DBCSR filter matrix
    ! ----------------------------------------------------------------------
    CALL fb_com_atom_pairs_nullify(filter_mat_blks_send)
    CALL fb_com_atom_pairs_nullify(filter_mat_blks_recv)
    CALL fb_com_atom_pairs_create(filter_mat_blks_send)
    CALL fb_com_atom_pairs_create(filter_mat_blks_recv)
    CALL fb_fltrmat_generate_com_pairs_2(filter_mat, &
                                         atomic_halos, &
                                         para_env, &
                                         filter_mat_blks_send, &
                                         filter_mat_blks_recv)
    CALL fb_com_atom_pairs_distribute_blks(filter_mat_data, &
                                           filter_mat_blks_send, &
                                           filter_mat_blks_recv, &
                                           para_env, &
                                           filter_mat)
    ! cleanup
    CALL fb_com_atom_pairs_release(filter_mat_blks_send)
    CALL fb_com_atom_pairs_release(filter_mat_blks_recv)
    CALL fb_matrix_data_release(filter_mat_data)

    ! finalise matrix
    CALL cp_dbcsr_finalize(filter_mat)

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_build_2


! *****************************************************************************
!> \brief Add the computed blocks of one block column to the filter matrix.
!>        This version is used by fb_fltrmat_build, for the case where MPI
!>        communications are done at each step.
!>        It does not finalise the filter matrix
!> \param H_mat : DBCSR system KS matrix
!> \param S_mat : DBCSR system overlap matrix
!> \param atomic_halo :  the halo that contributes to the blk
!>                       col of the filter matrix
!> \param trial_fns : the trial functions to be used to shrink the
!>                     size of the new "filtered" basis
!> \param para_env : cp2k parallel environment
!> \param particle_set : set of all particles in the system
!> \param fermi_level : the fermi level used for defining the filter
!>                      function, which is a Fermi-Dirac distribution
!>                      function
!> \param filter_temp : the filter temperature used for defining the
!>                      filter function
!> \param filter_mat  : DBCSR format filter matrix
!> \param tolerance   : anything smaller than tolerance is treated as zero
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_add_blkcol(H_mat, &
                                   S_mat, &
                                   atomic_halo, &
                                   trial_fns, &
                                   para_env, &
                                   particle_set, &
                                   fermi_level, &
                                   filter_temp, &
                                   filter_mat, &
                                   tolerance)
    TYPE(cp_dbcsr_type), POINTER             :: H_mat, S_mat
    TYPE(fb_atomic_halo_obj), INTENT(IN)     :: atomic_halo
    TYPE(fb_trial_fns_obj), INTENT(IN)       :: trial_fns
    TYPE(cp_para_env_type), POINTER          :: para_env
    TYPE(particle_type), DIMENSION(:), &
      POINTER                                :: particle_set
    REAL(KIND=dp), INTENT(IN)                :: fermi_level, filter_temp
    TYPE(cp_dbcsr_type), POINTER             :: filter_mat
    REAL(KIND=dp), INTENT(IN)                :: tolerance

    CHARACTER(LEN=*), PARAMETER :: routineN = 'fb_fltrmat_add_blkcol', &
      routineP = moduleN//':'//routineN

    INTEGER :: handle, handle_mpi, iatom_global, iatom_in_halo, ind, ipair, &
      ipe, itrial, jatom_global, jatom_in_halo, jkind, natoms_global, &
      natoms_in_halo, ncols_atmatrix, ncols_blk, nrows_atmatrix, nrows_blk, &
      numprocs, pe, recv_encode, send_encode, stat
    INTEGER(KIND=int_8), DIMENSION(:), &
      POINTER                                :: pairs_recv, pairs_send
    INTEGER, ALLOCATABLE, DIMENSION(:) :: atomic_H_blk_col_start, &
      atomic_H_blk_row_start, atomic_S_blk_col_start, atomic_S_blk_row_start, &
      col_block_size_data, ind_in_halo, recv_disps, recv_pair_count, &
      recv_pair_disps, recv_sizes, send_disps, send_pair_count, &
      send_pair_disps, send_sizes
    INTEGER, DIMENSION(:), POINTER           :: halo_atoms, ntfns, &
                                                row_block_size_data
    INTEGER, DIMENSION(:, :), POINTER        :: tfns
    REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: recv_buf, send_buf
    REAL(KIND=dp), ALLOCATABLE, &
      DIMENSION(:, :)                        :: atomic_filter_mat, atomic_H, &
                                                atomic_S
    TYPE(atomic_kind_type), POINTER          :: atomic_kind
    TYPE(fb_com_atom_pairs_obj)              :: com_pairs_recv, com_pairs_send

    CALL timeset(routineN, handle)


    NULLIFY(atomic_kind, halo_atoms, ntfns, pairs_send, pairs_recv, &
            row_block_size_data, tfns)
    CALL fb_com_atom_pairs_nullify(com_pairs_send)
    CALL fb_com_atom_pairs_nullify(com_pairs_recv)

    ! ----------------------------------------------------------------------
    ! Get communication buffers ready
    ! ----------------------------------------------------------------------

    ! generate send and recv atom pairs
    CALL fb_com_atom_pairs_create(com_pairs_send)
    CALL fb_com_atom_pairs_create(com_pairs_recv)
    CALL fb_fltrmat_generate_com_pairs(filter_mat, &
                                       atomic_halo, &
                                       para_env, &
                                       com_pairs_send, &
                                       com_pairs_recv)
    CALL fb_com_atom_pairs_get(atom_pairs=com_pairs_send, &
                               natoms_encode=send_encode, &
                               pairs=pairs_send)
    CALL fb_com_atom_pairs_get(atom_pairs=com_pairs_recv, &
                               natoms_encode=recv_encode, &
                               pairs=pairs_recv)

    ! get para_env info
    numprocs = para_env%num_pe
    ! me = para_env%mepos + 1   ! my process id, starting counting from 1

    ! obtain trial function information
    CALL fb_trial_fns_get(trial_fns=trial_fns, &
                          nfunctions=ntfns, &
                          functions=tfns)

    ! obtain row and col block size data for filter matrix
    row_block_size_data => cp_dbcsr_row_block_sizes(H_mat)
    natoms_global = SIZE(particle_set)
    ALLOCATE(col_block_size_data(natoms_global))
    DO jatom_global = 1, natoms_global
       atomic_kind => particle_set(jatom_global)%atomic_kind
       CALL get_atomic_kind(atomic_kind=atomic_kind, kind_number=jkind)
       col_block_size_data(jatom_global) = ntfns(jkind)
    END DO

    ! allocate temporary arrays for send
    ALLOCATE(send_sizes(numprocs))
    ALLOCATE(send_disps(numprocs))
    ALLOCATE(send_pair_count(numprocs))
    ALLOCATE(send_pair_disps(numprocs))
    ! setup send buffer sizes
    CALL fb_com_atom_pairs_calc_buffer_sizes(com_pairs_send, &
                                             numprocs, &
                                             row_block_size_data, &
                                             col_block_size_data, &
                                             send_sizes, &
                                             send_disps, &
                                             send_pair_count, &
                                             send_pair_disps)
    ! allocate send buffer
    ALLOCATE(send_buf(SUM(send_sizes)))

    ! allocate temporary array for recv
    ALLOCATE(recv_sizes(numprocs))
    ALLOCATE(recv_disps(numprocs))
    ALLOCATE(recv_pair_count(numprocs))
    ALLOCATE(recv_pair_disps(numprocs))
    ! setup recv buffer sizes
    CALL fb_com_atom_pairs_calc_buffer_sizes(com_pairs_recv, &
                                             numprocs, &
                                             row_block_size_data, &
                                             col_block_size_data, &
                                             recv_sizes, &
                                             recv_disps, &
                                             recv_pair_count, &
                                             recv_pair_disps)
    ! allocate recv buffer
    ALLOCATE(recv_buf(SUM(recv_sizes)))

    ! ----------------------------------------------------------------------
    ! Construct atomic filter matrix for this atomic_halo
    ! ----------------------------------------------------------------------

    CALL fb_atomic_halo_get(atomic_halo=atomic_halo, &
                            natoms=natoms_in_halo, &
                            halo_atoms=halo_atoms)

    ! construct atomic matrix for H for atomic_halo
    ALLOCATE(atomic_H_blk_row_start(natoms_in_halo + 1), &
             atomic_H_blk_col_start(natoms_in_halo + 1), &
             STAT=stat)
    CPASSERT(stat==0)
    CALL fb_atmatrix_calc_size(H_mat, &
                               atomic_halo, &
                               nrows_atmatrix, &
                               ncols_atmatrix, &
                               atomic_H_blk_row_start, &
                               atomic_H_blk_col_start)

    ALLOCATE(atomic_H(nrows_atmatrix,ncols_atmatrix))
    CALL fb_atmatrix_construct(H_mat, &
                               atomic_halo, &
                               para_env, &
                               atomic_H, &
                               atomic_H_blk_row_start, &
                               atomic_H_blk_col_start)

    ! construct atomic matrix for S for atomic_halo
    ALLOCATE(atomic_S_blk_row_start(natoms_in_halo + 1), &
             atomic_S_blk_col_start(natoms_in_halo + 1), &
             STAT=stat)
    CPASSERT(stat==0)
    CALL fb_atmatrix_calc_size(S_mat, &
                               atomic_halo, &
                               nrows_atmatrix, &
                               ncols_atmatrix, &
                               atomic_S_blk_row_start, &
                               atomic_S_blk_col_start)
    ALLOCATE(atomic_S(nrows_atmatrix,ncols_atmatrix))
    CALL fb_atmatrix_construct(S_mat, &
                               atomic_halo, &
                               para_env, &
                               atomic_S, &
                               atomic_S_blk_row_start, &
                               atomic_S_blk_col_start)

    ! construct the atomic filter matrix
    ALLOCATE(atomic_filter_mat(nrows_atmatrix,ncols_atmatrix))
    ! calculate atomic filter matrix only if it is non-zero sized
    IF (nrows_atmatrix > 0 .AND. ncols_atmatrix > 0) THEN
       CALL fb_fltrmat_build_atomic_fltrmat(atomic_H, &
                                            atomic_S, &
                                            fermi_level, &
                                            filter_temp, &
                                            atomic_filter_mat, &
                                            tolerance)
    END IF

    ! ----------------------------------------------------------------------
    ! Construct filter matrix blocks and add to the correct locations
    ! in send_buffer
    ! ----------------------------------------------------------------------

    ! preconstruct iatom_global to iatom_in_halo map
    ALLOCATE(ind_in_halo(natoms_global))
    ind_in_halo = 0
    DO iatom_in_halo = 1, natoms_in_halo
       iatom_global = halo_atoms(iatom_in_halo)
       ind_in_halo(iatom_global) = iatom_in_halo
    END DO

    ! initialise send buffer
    IF (SIZE(send_buf) > 0) send_buf = 0.0_dp
    ! assign values
    DO ipe = 1, numprocs
       send_sizes(ipe) = 0
       DO ipair = 1, send_pair_count(ipe)
          CALL fb_com_atom_pairs_decode(pairs_send(send_pair_disps(ipe) + ipair), &
                                        pe, iatom_global, jatom_global, &
                                        send_encode)
          iatom_in_halo = ind_in_halo(iatom_global)
          CPASSERT(iatom_in_halo>0)
          jatom_in_halo = ind_in_halo(jatom_global)
          CPASSERT(jatom_in_halo>0)
          atomic_kind => particle_set(jatom_global)%atomic_kind
          CALL get_atomic_kind(atomic_kind=atomic_kind, &
                               kind_number=jkind)
          nrows_blk = row_block_size_data(iatom_global)
          ncols_blk = ntfns(jkind)

          ! do it column-wise one trial function at a time
          DO itrial = 1, ntfns(jkind)
             ind = send_disps(ipe) + send_sizes(ipe) + (itrial-1) * nrows_blk
             CALL dgemv("N",                                           &
                         nrows_blk,                                    &
                         ncols_atmatrix,                               &
                         1.0_dp,                                       &
                         atomic_filter_mat(                            &
                           atomic_H_blk_row_start(iatom_in_halo) :     &
                           atomic_H_blk_row_start(iatom_in_halo+1)-1,  &
                           1 : ncols_atmatrix                          &
                         ),                                            &
                         nrows_blk,                                    &
                         atomic_S(                                     &
                           1 : nrows_atmatrix,                         &
                           atomic_S_blk_col_start(jatom_in_halo) +     &
                           tfns(itrial,jkind) - 1                      &
                         ),                                            &
                         1,                                            &
                         0.0_dp,                                       &
                         send_buf(ind + 1 : ind + nrows_blk),          &
                         1)
          END DO ! itrial
          send_sizes(ipe) = send_sizes(ipe) + nrows_blk * ncols_blk
       END DO ! ipair
    END DO  ! ipe

    DEALLOCATE(atomic_H)
    DEALLOCATE(atomic_H_blk_row_start)
    DEALLOCATE(atomic_S)
    DEALLOCATE(atomic_S_blk_row_start)
    DEALLOCATE(atomic_filter_mat)
    DEALLOCATE(ind_in_halo)

    ! ----------------------------------------------------------------------
    ! Do communication
    ! ----------------------------------------------------------------------

    CALL timeset("fb_fltrmat_add_blkcol_mpi", handle_mpi)

    CALL mp_alltoall(send_buf, send_sizes, send_disps, &
                     recv_buf, recv_sizes, recv_disps, &
                     para_env%group)

    CALL timestop(handle_mpi)

    DEALLOCATE(send_buf)
    DEALLOCATE(send_sizes)
    DEALLOCATE(send_disps)
    DEALLOCATE(send_pair_count)
    DEALLOCATE(send_pair_disps)

    ! ----------------------------------------------------------------------
    ! Unpack the recv buffer and add the blocks to correct parts of
    ! the DBCSR filter matrix
    ! ----------------------------------------------------------------------

    DO ipe = 1, numprocs
       recv_sizes(ipe) = 0
       DO ipair = 1, recv_pair_count(ipe)
          CALL fb_com_atom_pairs_decode(pairs_recv(recv_pair_disps(ipe) + ipair), &
                                        pe, iatom_global, jatom_global, &
                                        recv_encode)
          nrows_blk = row_block_size_data(iatom_global)
          ncols_blk = col_block_size_data(jatom_global)
          ind = recv_disps(ipe) + recv_sizes(ipe)
          CALL cp_dbcsr_put_block(filter_mat, &
                                  iatom_global, jatom_global, &
                                  recv_buf((ind+1) : (ind+nrows_blk*ncols_blk)))
          recv_sizes(ipe) = recv_sizes(ipe) + nrows_blk * ncols_blk
       END DO ! ipair
    END DO ! ipe

    ! cleanup rest of the temporary arrays
    DEALLOCATE(recv_buf)
    DEALLOCATE(recv_sizes)
    DEALLOCATE(recv_pair_count)
    DEALLOCATE(recv_pair_disps)

    CALL fb_com_atom_pairs_release(com_pairs_send)
    CALL fb_com_atom_pairs_release(com_pairs_recv)

    ! cannot finalise the matrix until all blocks has been added

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_add_blkcol


! *****************************************************************************
!> \brief Compute the blocks in one filter matrix column. This version is used
!>        by fb_fltrmat_build_2, where MPI communication is done collectively
!> \param H_mat : DBCSR system KS matrix
!> \param S_mat : DBCSR system overlap matrix
!> \param H_mat_data  :  local storage of the relevant H_mat matrix blocks
!> \param S_mat_data  :  local storage of the relevant S_mat matrix blocks
!> \param atomic_halo :  the halo that contributes to the blk
!>                       col of the filter matrix
!> \param trial_fns   :  trial functions data
!> \param particle_set : set of all particles in the system
!> \param fermi_level : the fermi level used for defining the filter
!>                      function, which is a Fermi-Dirac distribution
!>                      function
!> \param filter_temp : the filter temperature used for defining the
!>                      filter function
!> \param filter_mat_data : local storage for the computed filter matrix
!>                          blocks
!> \param tolerance : anything less than this is regarded as zero
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_add_blkcol_2(H_mat, &
                                     S_mat, &
                                     H_mat_data, &
                                     S_mat_data, &
                                     atomic_halo, &
                                     trial_fns, &
                                     particle_set, &
                                     fermi_level, &
                                     filter_temp, &
                                     filter_mat_data, &
                                     tolerance)
    ! Builds the atomic H and S matrices of atomic_halo from the locally
    ! stored blocks (H_mat_data, S_mat_data), turns them into the atomic
    ! filter matrix, and for every halo atom passes one block
    !   (filter matrix rows of the halo atom) * (S columns selected by the
    !    trial functions of the halo owner atom)
    ! to fb_matrix_data_add on filter_mat_data.
    !
    ! H_mat and S_mat are only used for block size information here
    ! (cp_dbcsr_row_block_sizes, fb_atmatrix_calc_size); the matrix
    ! elements themselves are taken from H_mat_data and S_mat_data.
    TYPE(cp_dbcsr_type), POINTER             :: H_mat, S_mat
    TYPE(fb_matrix_data_obj), INTENT(IN)     :: H_mat_data, S_mat_data
    TYPE(fb_atomic_halo_obj), INTENT(IN)     :: atomic_halo
    TYPE(fb_trial_fns_obj), INTENT(IN)       :: trial_fns
    TYPE(particle_type), DIMENSION(:), &
      POINTER                                :: particle_set
    REAL(KIND=dp), INTENT(IN)                :: fermi_level, filter_temp
    TYPE(fb_matrix_data_obj), INTENT(INOUT)  :: filter_mat_data
    REAL(KIND=dp), INTENT(IN)                :: tolerance

    CHARACTER(LEN=*), PARAMETER :: routineN = 'fb_fltrmat_add_blkcol_2', &
      routineP = moduleN//':'//routineN

    INTEGER :: handle, iatom_global, iatom_in_halo, itrial, jatom_global, &
      jatom_in_halo, jkind, natoms_in_halo, ncols_atmatrix, ncols_blk, &
      ncols_blk_max, nrows_atmatrix, nrows_blk, nrows_blk_max, stat
    INTEGER, ALLOCATABLE, DIMENSION(:) :: atomic_H_blk_col_start, &
      atomic_H_blk_row_start, atomic_S_blk_col_start, atomic_S_blk_row_start
    INTEGER, DIMENSION(:), POINTER           :: halo_atoms, ntfns, &
                                                row_block_size_data
    INTEGER, DIMENSION(:, :), POINTER        :: tfns
    LOGICAL                                  :: check_ok
    REAL(KIND=dp), ALLOCATABLE, &
      DIMENSION(:, :)                        :: atomic_filter_mat, atomic_H, &
                                                atomic_S, mat_blk
    TYPE(atomic_kind_type), POINTER          :: atomic_kind

    CALL timeset(routineN, handle)

    NULLIFY(atomic_kind, halo_atoms, ntfns, row_block_size_data, tfns)

    ! both local block stores must already hold data
    check_ok = fb_matrix_data_has_data(H_mat_data)
    CPASSERT(check_ok)
    check_ok = fb_matrix_data_has_data(S_mat_data)
    CPASSERT(check_ok)

    ! obtain trial function information: ntfns is indexed by atomic kind,
    ! tfns by (trial function, atomic kind)
    CALL fb_trial_fns_get(trial_fns=trial_fns, &
                          nfunctions=ntfns, &
                          functions=tfns)

    ! row block sizes of the filter matrix are those of H_mat; the column
    ! block size of the single block column built here is ntfns(jkind) of
    ! the halo owner atom, obtained further down
    row_block_size_data => cp_dbcsr_row_block_sizes(H_mat)

    ! halo description: member atoms, plus the owner atom both as global
    ! index and as position inside the halo
    CALL fb_atomic_halo_get(atomic_halo=atomic_halo, &
                            natoms=natoms_in_halo, &
                            halo_atoms=halo_atoms, &
                            owner_atom=jatom_global, &
                            owner_id_in_halo=jatom_in_halo)

    ! ----------------------------------------------------------------------
    ! Construct atomic filter matrix for this atomic_halo
    ! ----------------------------------------------------------------------

    ! construct atomic matrix for H for atomic_halo
    ALLOCATE(atomic_H_blk_row_start(natoms_in_halo + 1), &
             atomic_H_blk_col_start(natoms_in_halo + 1), &
             STAT=stat)
    CPASSERT(stat==0)
    CALL fb_atmatrix_calc_size(H_mat, &
                               atomic_halo, &
                               nrows_atmatrix, &
                               ncols_atmatrix, &
                               atomic_H_blk_row_start, &
                               atomic_H_blk_col_start)
    ALLOCATE(atomic_H(nrows_atmatrix,ncols_atmatrix))
    CALL fb_atmatrix_construct_2(H_mat_data, &
                                 atomic_halo, &
                                 atomic_H, &
                                 atomic_H_blk_row_start, &
                                 atomic_H_blk_col_start)

    ! construct atomic matrix for S for atomic_halo; from here on
    ! nrows_atmatrix and ncols_atmatrix hold the sizes reported for S_mat
    ALLOCATE(atomic_S_blk_row_start(natoms_in_halo + 1), &
             atomic_S_blk_col_start(natoms_in_halo + 1), &
             STAT=stat)
    CPASSERT(stat==0)
    CALL fb_atmatrix_calc_size(S_mat, &
                               atomic_halo, &
                               nrows_atmatrix, &
                               ncols_atmatrix, &
                               atomic_S_blk_row_start, &
                               atomic_S_blk_col_start)
    ALLOCATE(atomic_S(nrows_atmatrix,ncols_atmatrix))
    CALL fb_atmatrix_construct_2(S_mat_data, &
                                 atomic_halo, &
                                 atomic_S, &
                                 atomic_S_blk_row_start, &
                                 atomic_S_blk_col_start)

    ! construct the atomic filter matrix
    ALLOCATE(atomic_filter_mat(nrows_atmatrix,ncols_atmatrix))
    ! calculate atomic filter matrix only if it is non-zero sized
    IF (nrows_atmatrix > 0 .AND. ncols_atmatrix > 0) THEN
       CALL fb_fltrmat_build_atomic_fltrmat(atomic_H, &
                                            atomic_S, &
                                            fermi_level, &
                                            filter_temp, &
                                            atomic_filter_mat, &
                                            tolerance)
    END IF

    ! ----------------------------------------------------------------------
    ! Construct filter matrix block and add to filter_mat_data
    ! ----------------------------------------------------------------------

    ! The column atom (halo owner) is the same for every block, so its
    ! kind and number of trial functions are looked up once.  The lookup
    ! is skipped for an empty halo, in which case the loop below does not
    ! execute and neither jkind nor ncols_blk is referenced.
    IF (natoms_in_halo > 0) THEN
       atomic_kind => particle_set(jatom_global)%atomic_kind
       CALL get_atomic_kind(atomic_kind=atomic_kind, &
                            kind_number=jkind)
       ncols_blk = ntfns(jkind)
    END IF

    ! one scratch block large enough for any (row atom, column kind) pair
    nrows_blk_max = MAXVAL(row_block_size_data)
    ncols_blk_max = MAXVAL(ntfns)
    ALLOCATE(mat_blk(nrows_blk_max,ncols_blk_max))
    mat_blk(:,:) = 0.0_dp

    DO iatom_in_halo = 1, natoms_in_halo
       iatom_global = halo_atoms(iatom_in_halo)
       nrows_blk = row_block_size_data(iatom_global)

       ! do it column-wise one trial function at a time:
       !   mat_blk(1:nrows_blk, itrial) =
       !      atomic_filter_mat(rows of iatom, :) * atomic_S(:, col)
       ! with col the atomic_S column of trial function itrial of the
       ! owner atom.  The row section of atomic_filter_mat is passed as
       ! an nrows_blk x ncols_atmatrix matrix with leading dimension
       ! nrows_blk.
       DO itrial = 1, ncols_blk
          CALL dgemv("N",                                          &
                     nrows_blk,                                    &
                     ncols_atmatrix,                               &
                     1.0_dp,                                       &
                     atomic_filter_mat(                            &
                       atomic_H_blk_row_start(iatom_in_halo) :     &
                       atomic_H_blk_row_start(iatom_in_halo+1)-1,  &
                       1 : ncols_atmatrix                          &
                     ),                                            &
                     nrows_blk,                                    &
                     atomic_S(                                     &
                       1 : nrows_atmatrix,                         &
                       atomic_S_blk_col_start(jatom_in_halo) +     &
                       tfns(itrial,jkind) - 1                      &
                     ),                                            &
                     1,                                            &
                     0.0_dp,                                       &
                     mat_blk(                                      &
                       1 : nrows_blk,                              &
                       itrial),                                    &
                     1)
       END DO ! itrial

       ! hand the finished (iatom_global, jatom_global) block over
       CALL fb_matrix_data_add(filter_mat_data, &
                               iatom_global, &
                               jatom_global, &
                               mat_blk(1:nrows_blk, 1:ncols_blk))
    END DO ! iatom_in_halo
    DEALLOCATE(mat_blk)

    ! clean up
    DEALLOCATE(atomic_H)
    DEALLOCATE(atomic_H_blk_row_start)
    DEALLOCATE(atomic_H_blk_col_start)
    DEALLOCATE(atomic_S)
    DEALLOCATE(atomic_S_blk_row_start)
    DEALLOCATE(atomic_S_blk_col_start)
    DEALLOCATE(atomic_filter_mat)

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_add_blkcol_2


! *****************************************************************************
!> \brief generate the list of blocks (atom pairs) to be sent and received
!>        in order to construct the filter matrix for each atomic halo.
!>        This version is for use with fb_fltrmat_build, where MPI
!>        communications are done at each step
!> \param filter_mat : DBCSR formatted filter matrix
!> \param atomic_halo :  the halo that contributes to a blk
!>                       col of the filter matrix
!> \param para_env : cp2k parallel environment
!> \param atom_pairs_send : list of blocks to be sent
!> \param atom_pairs_recv : list of blocks to be received
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_generate_com_pairs(filter_mat, &
                                           atomic_halo, &
                                           para_env, &
                                           atom_pairs_send, &
                                           atom_pairs_recv)
    ! Sets up the send and receive atom pair lists for the single block
    ! column of the filter matrix that belongs to atomic_halo.  Every
    ! halo atom gives one (halo atom, owner atom) block, which is sent
    ! from this process to the process reported for that block by
    ! cp_dbcsr_get_stored_coordinates.
    TYPE(cp_dbcsr_type), POINTER             :: filter_mat
    TYPE(fb_atomic_halo_obj), INTENT(IN)     :: atomic_halo
    TYPE(cp_para_env_type), POINTER          :: para_env
    TYPE(fb_com_atom_pairs_obj), &
      INTENT(INOUT)                          :: atom_pairs_send, &
                                                atom_pairs_recv

    CHARACTER(LEN=*), PARAMETER :: &
      routineN = 'fb_fltrmat_generate_com_pairs', &
      routineP = moduleN//':'//routineN

    INTEGER :: col_atom, col_stored, handle, itask, my_pe, natoms_encode, &
      ntasks, owner_pe, row_atom, row_stored
    INTEGER(KIND=int_8), DIMENSION(:, :), &
      POINTER                                :: send_tasks
    INTEGER, DIMENSION(:), POINTER           :: halo_members
    TYPE(fb_com_tasks_obj)                   :: recv_list, send_list

    CALL timeset(routineN, handle)

    NULLIFY(send_tasks)
    CALL fb_com_tasks_nullify(send_list)
    CALL fb_com_tasks_nullify(recv_list)

    ! the two output lists are re-initialised when they already carry
    ! data and created otherwise
    IF (.NOT. fb_com_atom_pairs_has_data(atom_pairs_send)) THEN
       CALL fb_com_atom_pairs_create(atom_pairs_send)
    ELSE
       CALL fb_com_atom_pairs_init(atom_pairs_send)
    END IF
    IF (.NOT. fb_com_atom_pairs_has_data(atom_pairs_recv)) THEN
       CALL fb_com_atom_pairs_create(atom_pairs_recv)
    ELSE
       CALL fb_com_atom_pairs_init(atom_pairs_recv)
    END IF

    ! every task created here originates from the local process
    my_pe = para_env%mepos

    ! The number of block rows of the filter matrix is used as the
    ! number of atoms when encoding atom pairs.  filter_mat must
    ! therefore have been created before this routine is called; it
    ! does not need to contain any blocks yet.
    CALL cp_dbcsr_get_info(filter_mat, &
                           nblkrows_total=natoms_encode)

    ! This call handles one block column only.  Its column index is the
    ! owner atom of the halo and its row indices are the halo atoms, so
    ! there is exactly one send task per halo atom.  The filter matrix
    ! is treated as non-symmetric throughout.
    CALL fb_atomic_halo_get(atomic_halo=atomic_halo, &
                            owner_atom=col_atom, &
                            natoms=ntasks, &
                            halo_atoms=halo_members)

    ALLOCATE(send_tasks(TASK_N_RECORDS,ntasks))

    DO itask = 1, ntasks
       row_atom = halo_members(itask)
       ! query with copies, so that the pair encoded below always uses
       ! the original (row_atom, col_atom) indices
       row_stored = row_atom
       col_stored = col_atom
       CALL cp_dbcsr_get_stored_coordinates(filter_mat, &
                                            row_stored, &
                                            col_stored, &
                                            processor=owner_pe)
       send_tasks(TASK_SRC,itask) = my_pe
       send_tasks(TASK_DEST,itask) = owner_pe
       ! calculation of cost not implemented at the moment
       send_tasks(TASK_COST,itask) = 0
       CALL fb_com_tasks_encode_pair(send_tasks(TASK_PAIR,itask), &
                                     row_atom, col_atom, &
                                     natoms_encode)
    END DO ! itask

    ! wrap the raw task array into the send task list
    CALL fb_com_tasks_create(send_list)
    CALL fb_com_tasks_set(com_tasks=send_list, &
                          task_dim=TASK_N_RECORDS, &
                          ntasks=ntasks, &
                          nencode=natoms_encode, &
                          tasks=send_tasks)

    ! the receive task list is derived from the send task list
    CALL fb_com_tasks_create(recv_list)
    CALL fb_com_tasks_transpose_dest_src(recv_list, "<", send_list, &
                                         para_env)

    ! convert both task lists into atom pair lists
    CALL fb_com_tasks_build_atom_pairs(com_tasks=send_list, &
                                       atom_pairs=atom_pairs_send, &
                                       natoms_encode=natoms_encode, &
                                       send_or_recv="send", &
                                       symmetric=.FALSE.)
    CALL fb_com_tasks_build_atom_pairs(com_tasks=recv_list, &
                                       atom_pairs=atom_pairs_recv, &
                                       natoms_encode=natoms_encode, &
                                       send_or_recv="recv", &
                                       symmetric=.FALSE.)

    ! the task lists are no longer needed
    CALL fb_com_tasks_release(recv_list)
    CALL fb_com_tasks_release(send_list)

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_generate_com_pairs


! *****************************************************************************
!> \brief generate the list of blocks (atom pairs) to be sent and received
!>        in order to construct the filter matrix for each atomic halo.
!>        This version is for use with fb_fltrmat_build_2, where MPI
!>        communications are done collectively.
!> \param filter_mat  : DBCSR formatted filter matrix
!> \param atomic_halos : set of all local atomic halos contributing to the
!>                       filter matrix
!> \param para_env : cp2k parallel environment
!> \param atom_pairs_send : list of blocks to be sent
!> \param atom_pairs_recv : list of blocks to be received
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_generate_com_pairs_2(filter_mat, &
                                             atomic_halos, &
                                             para_env, &
                                             atom_pairs_send, &
                                             atom_pairs_recv)
    ! Sets up the send and receive atom pair lists for all block columns
    ! of the filter matrix that belong to the halos in atomic_halos.
    ! Every (halo atom, owner atom) combination gives one block, which is
    ! sent from this process to the process reported for that block by
    ! cp_dbcsr_get_stored_coordinates.
    TYPE(cp_dbcsr_type), POINTER             :: filter_mat
    TYPE(fb_atomic_halo_list_obj), &
      INTENT(IN)                             :: atomic_halos
    TYPE(cp_para_env_type), POINTER          :: para_env
    TYPE(fb_com_atom_pairs_obj), &
      INTENT(INOUT)                          :: atom_pairs_send, &
                                                atom_pairs_recv

    CHARACTER(LEN=*), PARAMETER :: &
      routineN = 'fb_fltrmat_generate_com_pairs_2', &
      routineP = moduleN//':'//routineN

    INTEGER :: col_atom, col_stored, handle, ihalo, imember, itask, my_pe, &
      natoms_encode, nhalos, nmembers, ntasks, owner_pe, row_atom, &
      row_stored
    INTEGER(KIND=int_8), DIMENSION(:, :), &
      POINTER                                :: send_tasks
    INTEGER, DIMENSION(:), POINTER           :: halo_members
    TYPE(fb_atomic_halo_obj), DIMENSION(:), &
      POINTER                                :: halo_list
    TYPE(fb_com_tasks_obj)                   :: recv_list, send_list

    CALL timeset(routineN, handle)

    NULLIFY(send_tasks)
    CALL fb_com_tasks_nullify(send_list)
    CALL fb_com_tasks_nullify(recv_list)

    ! the two output lists are re-initialised when they already carry
    ! data and created otherwise
    IF (.NOT. fb_com_atom_pairs_has_data(atom_pairs_send)) THEN
       CALL fb_com_atom_pairs_create(atom_pairs_send)
    ELSE
       CALL fb_com_atom_pairs_init(atom_pairs_send)
    END IF
    IF (.NOT. fb_com_atom_pairs_has_data(atom_pairs_recv)) THEN
       CALL fb_com_atom_pairs_create(atom_pairs_recv)
    ELSE
       CALL fb_com_atom_pairs_init(atom_pairs_recv)
    END IF

    ! every task created here originates from the local process
    my_pe = para_env%mepos

    ! The number of block rows of the filter matrix is used as the
    ! number of atoms when encoding atom pairs.  filter_mat must
    ! therefore have been created before this routine is called; it
    ! does not need to contain any blocks yet.
    CALL cp_dbcsr_get_info(filter_mat, &
                           nblkrows_total=natoms_encode)

    ! Each halo contributes one block column: the column index is the
    ! owner atom of the halo and the row indices are its halo atoms.
    ! The filter matrix is treated as non-symmetric throughout.
    CALL fb_atomic_halo_list_get(atomic_halos=atomic_halos, &
                                 nhalos=nhalos, &
                                 halos=halo_list)

    ! one send task per halo atom, summed over all halos
    ntasks = 0
    DO ihalo = 1, nhalos
       CALL fb_atomic_halo_get(atomic_halo=halo_list(ihalo), &
                               natoms=nmembers)
       ntasks = ntasks + nmembers
    END DO ! ihalo

    ALLOCATE(send_tasks(TASK_N_RECORDS,ntasks))

    ! fill the send tasks, halo by halo
    itask = 0
    DO ihalo = 1, nhalos
       CALL fb_atomic_halo_get(atomic_halo=halo_list(ihalo), &
                               owner_atom=col_atom, &
                               natoms=nmembers, &
                               halo_atoms=halo_members)
       DO imember = 1, nmembers
          itask = itask + 1
          row_atom = halo_members(imember)
          ! query with copies, so that the pair encoded below always
          ! uses the original (row_atom, col_atom) indices
          row_stored = row_atom
          col_stored = col_atom
          CALL cp_dbcsr_get_stored_coordinates(filter_mat, &
                                               row_stored, &
                                               col_stored, &
                                               processor=owner_pe)
          send_tasks(TASK_SRC,itask) = my_pe
          send_tasks(TASK_DEST,itask) = owner_pe
          ! calculation of cost not implemented at the moment
          send_tasks(TASK_COST,itask) = 0
          CALL fb_com_tasks_encode_pair(send_tasks(TASK_PAIR,itask), &
                                        row_atom, col_atom, &
                                        natoms_encode)
       END DO ! imember
    END DO ! ihalo

    ! wrap the raw task array into the send task list
    CALL fb_com_tasks_create(send_list)
    CALL fb_com_tasks_set(com_tasks=send_list, &
                          task_dim=TASK_N_RECORDS, &
                          ntasks=ntasks, &
                          nencode=natoms_encode, &
                          tasks=send_tasks)

    ! the receive task list is derived from the send task list
    CALL fb_com_tasks_create(recv_list)
    CALL fb_com_tasks_transpose_dest_src(recv_list, "<", send_list, &
                                         para_env)

    ! convert both task lists into atom pair lists
    CALL fb_com_tasks_build_atom_pairs(com_tasks=send_list, &
                                       atom_pairs=atom_pairs_send, &
                                       natoms_encode=natoms_encode, &
                                       send_or_recv="send", &
                                       symmetric=.FALSE.)
    CALL fb_com_tasks_build_atom_pairs(com_tasks=recv_list, &
                                       atom_pairs=atom_pairs_recv, &
                                       natoms_encode=natoms_encode, &
                                       send_or_recv="recv", &
                                       symmetric=.FALSE.)

    ! the task lists are no longer needed
    CALL fb_com_tasks_release(recv_list)
    CALL fb_com_tasks_release(send_list)

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_generate_com_pairs_2


! *****************************************************************************
!> \brief Build the atomic filter matrix for each atomic halo
!> \param atomic_H : atomic KS matrix
!> \param atomic_S : atomic overlap matrix
!> \param fermi_level : fermi level used to construct the Fermi-Dirac
!>                      filter function
!> \param filter_temp : temperature used to construct the Fermi-Dirac
!>                      filter function
!> \param atomic_filter_mat : the atomic filter matrix
!> \param tolerance : anything smaller than tolerance is treated as zero
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_build_atomic_fltrmat(atomic_H, &
                                             atomic_S, &
                                             fermi_level, &
                                             filter_temp, &
                                             atomic_filter_mat, &
                                             tolerance)
    ! Solves the generalised symmetric eigenproblem
    !   atomic_H * c = e * atomic_S * c
    ! with LAPACK dsygv, evaluates the filter function on the eigenvalues
    ! with fb_fltrmat_fermi_dirac_mu, and returns
    !   atomic_filter_mat = C * diag(filter_function) * C^T
    ! where the columns of C are the eigenvectors.  Entries whose absolute
    ! value is below tolerance are set to zero.
    !
    ! All three matrices must be square and of the same non-zero size;
    ! this is asserted, because a single dimension is passed to
    ! LAPACK/BLAS as every leading dimension.  Aborts when dsygv reports
    ! a non-zero info.
    REAL(KIND=dp), DIMENSION(:, :), &
      INTENT(IN)                             :: atomic_H, atomic_S
    REAL(KIND=dp), INTENT(IN)                :: fermi_level, filter_temp
    REAL(KIND=dp), DIMENSION(:, :), &
      INTENT(OUT)                            :: atomic_filter_mat
    REAL(KIND=dp), INTENT(IN)                :: tolerance

    CHARACTER(LEN=*), PARAMETER :: &
      routineN = 'fb_fltrmat_build_atomic_fltrmat', &
      routineP = moduleN//':'//routineN

    INTEGER                                  :: handle, handle_dgemm, &
                                                handle_dsygv, info, jj, &
                                                mat_dim, work_array_size
    LOGICAL                                  :: check_ok
    REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: eigenvalues, filter_function, &
                                                work
    REAL(KIND=dp), ALLOCATABLE, &
      DIMENSION(:, :)                        :: atomic_S_copy, eigenvectors, &
                                                filtered_eigenvectors

    CALL timeset(routineN, handle)

    ! This subroutine assumes atomic_filter_mat is not zero size, in
    ! other words, it really has to be constructed, instead of just
    ! being a dummy

    check_ok = SIZE(atomic_filter_mat, 1) > 0 .AND. &
               SIZE(atomic_filter_mat, 2) > 0
    CPASSERT(check_ok)

    mat_dim = SIZE(atomic_H, 1)

    ! mat_dim is handed to dsygv and dgemm as the order and as every
    ! leading dimension, so all matrices have to be mat_dim x mat_dim
    check_ok = SIZE(atomic_H, 2) == mat_dim .AND. &
               ALL(SHAPE(atomic_S) == mat_dim) .AND. &
               ALL(SHAPE(atomic_filter_mat) == mat_dim)
    CPASSERT(check_ok)

    ! initialise
    atomic_filter_mat = 0.0_dp

    ! diagonalise using LAPACK
    ALLOCATE(eigenvalues(mat_dim))
    ! dsygv overwrites both of its matrix arguments, so it works on
    ! copies: eigenvectors starts as a copy of atomic_H and receives the
    ! eigenvectors, atomic_S_copy is scratch
    ALLOCATE(atomic_S_copy(mat_dim,mat_dim))
    ALLOCATE(eigenvectors(mat_dim,mat_dim))

    CALL timeset("fb_atomic_filter_dsygv", handle_dsygv)

    ! workspace query (lwork = -1): only the optimal size is returned in
    ! work(1), the matrices are not referenced and need not be filled yet
    ALLOCATE(work(1))
    info = 0
    CALL dsygv(1,             &
               'V',           &
               'U',           &
               mat_dim,       &
               eigenvectors,  &
               mat_dim,       &
               atomic_S_copy, &
               mat_dim,       &
               eigenvalues,   &
               work,          &
               -1,            &
               info)
    ! work(1) is only meaningful when the query itself succeeded
    IF (info .NE. 0) THEN
       WRITE (*,*) "DSYGV ERROR MESSAGE: ", info
       CPABORT("DSYGV workspace query failed")
    END IF
    ! never go below the minimum workspace documented for dsygv
    work_array_size = MAX(1, 3*mat_dim - 1, NINT(work(1)))
    DEALLOCATE(work)
    ALLOCATE(work(work_array_size))

    ! do calculation
    atomic_S_copy(:,:) = atomic_S(:,:)
    eigenvectors(:,:) = atomic_H(:,:)
    info = 0
    CALL dsygv(1,               &
               'V',             &
               'U',             &
               mat_dim,         &
               eigenvectors,    &
               mat_dim,         &
               atomic_S_copy,   &
               mat_dim,         &
               eigenvalues,     &
               work,            &
               work_array_size, &
               info)
    ! check if diagonalisation is successful
    IF (info .NE. 0) THEN
       WRITE (*,*) "DSYGV ERROR MESSAGE: ", info
       CPABORT("DSYGV failed")
    END IF

    CALL timestop(handle_dsygv)

    DEALLOCATE(work)
    DEALLOCATE(atomic_S_copy)

    ! first get the filter function, one value per eigenvalue
    ALLOCATE(filter_function(mat_dim))
    filter_function = 0.0_dp
    CALL fb_fltrmat_fermi_dirac_mu(filter_function, &
                                   eigenvalues, &
                                   filter_temp, &
                                   fermi_level)
    DEALLOCATE(eigenvalues)

    ! eigenvectors now holds the eigenvectors column by column; scale
    ! column jj by filter_function(jj) to get C * diag(filter_function)
    ALLOCATE(filtered_eigenvectors(mat_dim,mat_dim))
    DO jj = 1, mat_dim
       filtered_eigenvectors(:,jj) = filter_function(jj) * eigenvectors(:,jj)
    END DO ! jj

    DEALLOCATE(filter_function)

    CALL timeset("fb_atomic_filter_dgemm", handle_dgemm)

    ! construct atomic filter matrix:
    ! atomic_filter_mat = filtered_eigenvectors * eigenvectors^T
    CALL dgemm("N",                   &
               "T",                   &
               mat_dim,               &
               mat_dim,               &
               mat_dim,               &
               1.0_dp,                &
               filtered_eigenvectors, &
               mat_dim,               &
               eigenvectors,          &
               mat_dim,               &
               0.0_dp,                &
               atomic_filter_mat,     &
               mat_dim)

    CALL timestop(handle_dgemm)

    ! treat everything whose magnitude is below tolerance as exactly
    ! zero, regardless of sign
    WHERE (ABS(atomic_filter_mat) < tolerance) atomic_filter_mat = 0.0_dp

    DEALLOCATE(filtered_eigenvectors)
    DEALLOCATE(eigenvectors)

    CALL timestop(handle)

  END SUBROUTINE fb_fltrmat_build_atomic_fltrmat


! *****************************************************************************
!> \brief get values of Fermi-Dirac distribution based on a given fermi
!>        level at a given set of energy eigenvalues
!> \param f : the Fermi-Dirac distribution function values
!> \param eigenvals : set of energy eigenvalues
!> \param T : temperature
!> \param mu : the fermi level
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_fermi_dirac_mu(f, eigenvals, T, mu)
    REAL(KIND=dp), DIMENSION(:), INTENT(OUT) :: f
    REAL(KIND=dp), DIMENSION(:), INTENT(IN)  :: eigenvals
    REAL(KIND=dp), INTENT(IN)                :: T
    REAL(KIND=dp), INTENT(IN)                :: mu

    CHARACTER(len=*), PARAMETER :: routineN = 'fb_fltrmat_fermi_dirac_mu', &
      routineP = moduleN//':'//routineN

    ! the maximum occupancy handed to Fermi is fixed at one, so that the
    ! resulting distribution function peaks at 1
    REAL(KIND=dp), PARAMETER                 :: unit_maxocc = 1.0_dp

    ! receivers for the two extra results Fermi hands back (the original
    ! names were ne and kTS); neither is passed on to the caller
    REAL(KIND=dp)                            :: electron_count, &
                                                entropy_term

    ! thin wrapper: evaluate the distribution at the given fermi level mu
    ! and temperature T, keeping only the occupation values f
    CALL Fermi(f, electron_count, entropy_term, &
               eigenvals, mu, T, unit_maxocc)
  END SUBROUTINE fb_fltrmat_fermi_dirac_mu


! *****************************************************************************
!> \brief get values of Fermi-Dirac distribution based on a given electron
!>        number at a given set of energy eigenvalues
!> \param f : the Fermi-Dirac distribution function values
!> \param eigenvals : set of energy eigenvalues
!> \param T : temperature
!> \param ne : number of electrons
!> \param maxocc : maximum occupancy per orbital
!> \author Lianheng Tong (LT) lianheng.tong@kcl.ac.uk
! *****************************************************************************
  SUBROUTINE fb_fltrmat_fermi_dirac_ne(f, eigenvals, T, ne, maxocc)
    REAL(KIND=dp), DIMENSION(:), INTENT(OUT) :: f
    REAL(KIND=dp), DIMENSION(:), INTENT(IN)  :: eigenvals
    REAL(KIND=dp), INTENT(IN)                :: T
    REAL(KIND=dp), INTENT(IN)                :: ne
    REAL(KIND=dp), INTENT(IN)                :: maxocc

    CHARACTER(len=*), PARAMETER :: routineN = 'fb_fltrmat_fermi_dirac_ne', &
      routineP = moduleN//':'//routineN

    ! receivers for the by-products of FermiFixed, not passed on to the caller:
    !   fermi_level  : the fermi level calculated for the requested ne
    !   entropy_term : the entropic contribution to the energy, i.e. -TS,
    !                  kT*[f ln f + (1-f) ln (1-f)]
    REAL(KIND=dp)                            :: entropy_term, &
                                                fermi_level

    ! thin wrapper: the electron number ne is the input here, and only
    ! the occupation values f are kept
    CALL FermiFixed(f, fermi_level, entropy_term, &
                    eigenvals, ne, T, maxocc)
  END SUBROUTINE fb_fltrmat_fermi_dirac_ne


END MODULE qs_fb_filter_matrix_methods
