// recursive_ops.cc: implement recursive operations for images
//   - this should be more cache-friendly than a simple dot-product sum
//   - may do pairwise accumulation, for better accuracy w/ inexact arithmetic

#include "recursive_ops.hh"

template<class T>
void recursiveA_ops::my_trans_copy(const plane_ro<T> &A, const plane<T> &B) {
  // set B to transpose of A
  //   A: read-only source; Imax rows x Jmax cols
  //   B: destination; must be Jmax rows x Imax cols
  // Recursively halves the longer dimension of A (and the matching,
  // transposed dimension of B) until a single element remains, which
  // is then copied across.
  assert(A.ok() && B.ok());
  const uint Imax= A.get_nrows();
  const uint Jmax= A.get_ncols();
  assert(B.get_ncols()==Imax && B.get_nrows()==Jmax);
  assert(Imax>0 && Jmax>0); // additional invariant

  if(Imax==1 && Jmax==1) { // actually copy data
    B.get_ref(0,0)= A.get_val(0,0);
    return;
  }

  plane_ro<T> A1,A2;
  plane<T> B1,B2;
  if(Imax > Jmax) { // split Imax: rows of A <-> columns of B
    assert(Imax>1);
    uint Inew= Imax>>1;
    A.split_byrows_ro(Inew,A1,A2);
    B.split_bycols(Inew,B1,B2);
    my_trans_copy(A1,B1);
    my_trans_copy(A2,B2);
  } else { // Jmax >= Imax; split Jmax: columns of A <-> rows of B
    assert(Jmax>1); // holds because the 1x1 case returned above
    uint Jnew= Jmax>>1;
    A.split_bycols_ro(Jnew,A1,A2);
    B.split_byrows(Jnew,B1,B2);
    my_trans_copy(A1,B1);
    my_trans_copy(A2,B2);
  }
} // end implementation:  recursiveA_ops::my_trans_copy()

template<class T>
void recursiveA_ops::my_trans_self(const plane<T> &M) {
  // transpose M in place
  assert(M.ok());
  const uint Size= M.get_nrows();
  assert(M.get_ncols()==Size);
  assert(Size > 0); // additional invariant

  if(Size==1) return; // on diagonal; need not do anything

  assert(Size>1);
  plane<T> M11,M12,M21,M22;
  uint Half= Size>>1;
  M.split_quad(Half,Half,M11,M12,M21,M22);
  my_trans_self(M11);
  my_trans_swap(M12,M21);
  my_trans_self(M22);
} // end implementation:  recursiveA_ops::my_trans_self()

template<class T>
void recursiveA_ops::my_trans_swap(const plane<T> &A, const plane<T> &B) {
  // swap A and B, transposing data [used by trans_self()]
  //   A: Imax rows x Jmax cols
  //   B: must be Jmax rows x Imax cols
  // On return A holds the transpose of B's old contents and vice versa.
  // Recursively halves the longer dimension of A (and the matching,
  // transposed dimension of B) until single elements are exchanged.
  assert(A.ok() && B.ok());
  const uint Imax= A.get_nrows();
  const uint Jmax= A.get_ncols();
  assert(B.get_ncols()==Imax && B.get_nrows()==Jmax);
  assert(Imax>0 && Jmax>0); // additional invariant

  if(Imax==1 && Jmax==1) { // actually swap data
    T tmp= A.get_val(0,0);
    A.get_ref(0,0)= B.get_val(0,0);
    B.get_ref(0,0)= tmp;
    return;
  }

  plane<T> A1,A2,B1,B2;
  if(Imax > Jmax) { // split Imax: rows of A <-> columns of B
    assert(Imax>1);
    uint Inew= Imax>>1;
    A.split_byrows(Inew,A1,A2);
    B.split_bycols(Inew,B1,B2);
    my_trans_swap(A1,B1);
    my_trans_swap(A2,B2);
  } else { // Jmax >= Imax; split Jmax: columns of A <-> rows of B
    assert(Jmax>1); // holds because the 1x1 case returned above
    uint Jnew= Jmax>>1;
    A.split_bycols(Jnew,A1,A2);
    B.split_byrows(Jnew,B1,B2);
    my_trans_swap(A1,B1);
    my_trans_swap(A2,B2);
  }
} // end implementation:  recursiveA_ops::my_trans_swap()


template<class T>
void recursiveA_ops::my_xform_accum_row(const line_ro<T> &in,
					const plane_ro<T> &M,
					const line<T> &out) {
  // multiply row vector (in) by row-major matrix (M)
  //   -> add to row vector (out)
  // (accumulates sum iteratively)
  assert(in.ok() && M.ok() && out.ok());
  const uint Imax= M.get_nrows();
  const uint Jmax= M.get_ncols();
  assert(in.get_size()==Imax && out.get_size()==Jmax);
  assert(Imax>0 && Jmax>0); // additional invariant

  if(Imax==1 && Jmax==1) { // accumulate sum
    out[0]+= in[0] * M.get_val(0,0);
    return;
  }

  if(Imax > Jmax) { // split Imax
    assert(Imax>1);

    plane_ro<T> M1,M2;
    line_ro<T> in1,in2;

    uint Inew= Imax>>1;
    in.split_ro(Inew,in1,in2);
    M.split_byrows_ro(Inew,M1,M2);
    my_xform_accum_row(in1,M1,out);
    my_xform_accum_row(in2,M2,out);
  } else { // Jmax >= Imax; split Jmax
    assert(Jmax>1);

    plane_ro<T> M1,M2;
    line<T> out1,out2;

    uint Jnew= Jmax>>1;
    M.split_bycols_ro(Jnew,M1,M2);
    out.split(Jnew,out1,out2);
    my_xform_accum_row(in,M1,out1);
    my_xform_accum_row(in,M2,out2);
  }
} // end implementation:  recursiveA_ops::my_xform_accum_row()

template<class T>
void recursiveA_ops::my_xform_accum_col(const plane_ro<T> &M,
					const line_ro<T> &in,
					const line<T> &out) {
  // multiply row-major matrix (M) by column vector (in)
  //   -> add to column vector (out)
  //   M  : read-only; Imax rows x Jmax cols
  //   in : read-only; size must equal Jmax
  //   out: accumulated into; size must equal Imax
  // Recursively halves the larger dimension of M until a single element
  // is left, whose product is added to the matching out entry.
  assert(M.ok() && in.ok() && out.ok());
  const uint Imax= M.get_nrows();
  const uint Jmax= M.get_ncols();
  assert(in.get_size()==Jmax && out.get_size()==Imax);
  assert(Imax>0 && Jmax>0); // additional invariant

  if(Imax==1 && Jmax==1) { // accumulate sum
    out[0]+= M.get_val(0,0) * in[0];
    return;
  }

  if(Imax > Jmax) { // split Imax
    // each half of M's rows writes to its own half of out
    assert(Imax>1);

    plane_ro<T> M1,M2;
    line<T> out1,out2;

    uint Inew= Imax>>1;
    M.split_byrows_ro(Inew,M1,M2);
    out.split(Inew,out1,out2);
    my_xform_accum_col(M1,in,out1);
    my_xform_accum_col(M2,in,out2);
  } else { // Jmax >= Imax; split Jmax
    // both halves of M's columns accumulate into the whole of out
    assert(Jmax>1); // holds because the 1x1 case returned above

    plane_ro<T> M1,M2;
    line_ro<T> in1,in2;

    uint Jnew= Jmax>>1;
    in.split_ro(Jnew,in1,in2);
    M.split_bycols_ro(Jnew,M1,M2);
    my_xform_accum_col(M1,in1,out);
    my_xform_accum_col(M2,in2,out);
  }
} // end implementation:  recursiveA_ops::my_xform_accum_col()

template<class T>
void recursiveA_ops::my_mmult_accum_AB(const plane_ro<T> &A,
				       const plane_ro<T> &B,
				       const plane<T> &C) {
  // multiply 2 row-major matrices (A,B) -> add to row-major matrix (C)
  //   A: read-only; Imax rows x Kmax cols
  //   B: read-only; Kmax rows x Jmax cols
  //   C: accumulated into; Imax rows x Jmax cols
  // Recursively halves the largest of the three dimensions until all
  // are 1, at which point a single product is added to C.
  assert(A.ok() && B.ok() && C.ok());
  const uint Imax= A.get_nrows();
  const uint Jmax= B.get_ncols();
  const uint Kmax= A.get_ncols();
  assert(C.get_nrows()==Imax &&
	 C.get_ncols()==Jmax &&
	 B.get_nrows()==Kmax );
  assert(Imax>0 && Jmax>0 && Kmax>0); // additional invariant

  if(Imax==1 && Jmax==1 && Kmax==1) { // accumulate sum
    C.get_ref(0,0)+= A.get_val(0,0) * B.get_val(0,0);
    return;
  }

  if(Kmax > Imax && Kmax > Jmax) { // split Kmax
    // K is the shared (inner) dimension: the columns of A and the
    // rows of B.  Both partial products have the full shape of C.
    plane_ro<T> A1,A2,B1,B2;
    assert(Kmax>1);
    uint Knew= Kmax>>1;
    A.split_bycols_ro(Knew,A1,A2);
    B.split_byrows_ro(Knew,B1,B2);
    my_mmult_accum_AB(A1,B1,C);
    my_mmult_accum_AB(A2,B2,C); // accumulate to same array...
  } else if(Imax > Jmax) { // split Imax
    // [note: Imax >= Kmax]
    plane_ro<T> A1,A2;
    plane<T> C1,C2;
    assert(Imax>1);
    uint Inew= Imax>>1;
    A.split_byrows_ro(Inew,A1,A2);
    C.split_byrows(Inew,C1,C2);
    my_mmult_accum_AB(A1,B,C1);
    my_mmult_accum_AB(A2,B,C2);
  } else { // split Jmax
    // [note: Jmax >= Imax,Kmax]
    plane_ro<T> B1,B2;
    plane<T> C1,C2;
    assert(Jmax>1);
    uint Jnew= Jmax>>1;
    B.split_bycols_ro(Jnew,B1,B2);
    C.split_bycols(Jnew,C1,C2);
    my_mmult_accum_AB(A,B1,C1);
    my_mmult_accum_AB(A,B2,C2);
  }
} // end implementation:  recursiveA_ops::my_mmult_accum_AB()

// end file recursive_ops.cc
