#ifndef IMAGE_OPS_HH
#define IMAGE_OPS_HH

#include <cassert>

#include "image.hh"


// abstract base class: image_ops<T>
// (uses line<T>/line_ro<T> as vectors and plane<T>/plane_ro<T> as matrices;
//  the *_ro variants appear only as inputs that are read, never written)

template<class T> class image_ops {
public: // ctors, etc.
  virtual ~image_ops(void) {}

public: // generic operations
  // NOTE: every signature here must match its overrider exactly (including
  // const-qualification of parameters); otherwise the derived method merely
  // hides this one, the pure virtual stays unoverridden, and the derived
  // class remains abstract.

  // --- line (vector) operations -------------------------------------
  // A is a read-only source; B (or M) is the destination, modified in
  // place.  Semantics as implemented by simple_ops<T>:
  virtual void fill_line(const line<T> &M, T val)=0;               // M(n)  = val
  virtual void copy_line(const line_ro<T> &A, const line<T> &B)=0; // B(n)  = A(n)
  virtual void add_line(const line_ro<T> &A, const line<T> &B)=0;  // B(n) += A(n)
  virtual void sub_line(const line_ro<T> &A, const line<T> &B)=0;  // B(n) -= A(n)
  virtual void mul_line(const line_ro<T> &A, const line<T> &B)=0;  // B(n) *= A(n)
  virtual void scale_line(const T &s, const line<T> &M)=0;         // M(n) *= s

  // --- plane (matrix) elementwise operations ------------------------
  virtual void fill_plane(const plane<T> &M, T val)=0;                // M(i,j)  = val
  virtual void copy_plane(const plane_ro<T> &A, const plane<T> &B)=0; // B(i,j)  = A(i,j)
  virtual void add_plane(const plane_ro<T> &A, const plane<T> &B)=0;  // B(i,j) += A(i,j)
  virtual void sub_plane(const plane_ro<T> &A, const plane<T> &B)=0;  // B(i,j) -= A(i,j)
  virtual void mul_plane(const plane_ro<T> &A, const plane<T> &B)=0;  // B(i,j) *= A(i,j)
  virtual void scale_plane(const T &s, const plane<T> &M)=0;          // M(i,j) *= s
  virtual void mac_plane(const T &s, const plane_ro<T> &A,
			 const plane<T> &B)=0;                       // B(i,j) += s*A(i,j)

  // --- transposition ------------------------------------------------
  virtual void trans_copy(const plane_ro<T> &A, const plane<T> &B)=0; // B = A^t
  virtual void trans_self(const plane<T> &M)=0;                       // M = M^t (square M)

  // --- vector/matrix products ---------------------------------------
  // out(j) = sum_i in(i) * M(i,j)   (row vector times matrix)
  virtual void xform_row(const line_ro<T> &in, const plane_ro<T> &M,
			 const line<T> &out)=0;
  // out(i) = sum_j M(i,j) * in(j)   (matrix times column vector)
  virtual void xform_col(const plane_ro<T> &M, const line_ro<T> &in,
			 const line<T> &out)=0;

  // --- matrix/matrix products (C is overwritten) --------------------
  virtual void mmult_AB(const plane_ro<T> &A, const plane_ro<T> &B,
			const plane<T> &C)=0;   // C = A   . B
  virtual void mmult_AtB(const plane_ro<T> &A, const plane_ro<T> &B,
			 const plane<T> &C)=0;  // C = A^t . B
  virtual void mmult_ABt(const plane_ro<T> &A, const plane_ro<T> &B,
			 const plane<T> &C)=0;  // C = A   . B^t
  virtual void mmult_AtBt(const plane_ro<T> &A, const plane_ro<T> &B,
			  const plane<T> &C)=0; // C = A^t . B^t

}; // end pure-virtual template class image_ops<T>



// general notation conventions (used by simple_ops<T> below):
// - index argument order is always (row, col)
// - for vector/matrix -> vector ops, matrix indices are (i,j)
// - for matrix -> matrix ops, output indices are (i,j)
// - for matrix/matrix -> matrix ops, output indices are (i,j);
//   the summation index is k

template<class T> class simple_ops: public image_ops<T> {
public: // generic operations

  // =================================================================
  // line (vector) operations
  //
  // In the two-operand forms A is only read and B is updated in place;
  // both must have the same size.  Each element is processed
  // independently of the others.
  // =================================================================

  // set every element of M to val
  virtual void fill_line(const line<T> &M, T val) {
    assert(M.ok());
    const uint Size= M.get_size();

    for(uint n=0; n<Size; n++) {
      M.get_ref(n)= val;
    }
  }

  // B(n) = A(n)
  virtual void copy_line(const line_ro<T> &A, const line<T> &B) {
    assert(A.ok() && B.ok());
    const uint Size= A.get_size();
    assert(B.get_size()==Size);

    for(uint n=0; n<Size; n++) {
      B.get_ref(n)= A.get_val(n);
    }
  }

  // B(n) += A(n)
  virtual void add_line(const line_ro<T> &A, const line<T> &B) {
    assert(A.ok() && B.ok());
    const uint Size= A.get_size();
    assert(B.get_size()==Size);

    for(uint n=0; n<Size; n++) {
      B.get_ref(n)+= A.get_val(n);
    }
  }

  // B(n) -= A(n)
  virtual void sub_line(const line_ro<T> &A, const line<T> &B) {
    assert(A.ok() && B.ok());
    const uint Size= A.get_size();
    assert(B.get_size()==Size);

    for(uint n=0; n<Size; n++) {
      B.get_ref(n)-= A.get_val(n);
    }
  }

  // B(n) *= A(n)   (elementwise product, not a dot product)
  virtual void mul_line(const line_ro<T> &A, const line<T> &B) {
    assert(A.ok() && B.ok());
    const uint Size= A.get_size();
    assert(B.get_size()==Size);

    for(uint n=0; n<Size; n++) {
      B.get_ref(n)*= A.get_val(n);
    }
  }

  // X(n) *= s   (multiply every element by a fixed value)
  virtual void scale_line(const T &s, const line<T> &X) {
    assert(X.ok());
    const uint Size= X.get_size();

    for(uint n=0; n<Size; n++) {
      X.get_ref(n)*= s;
    }
  }

  // B(n) += s * A(n)   (multiply & accumulate)
  virtual void mac_line(const T &s, const line_ro<T> &A, const line<T> &B) {
    assert(A.ok() && B.ok());
    const uint Size= A.get_size();
    assert(B.get_size()==Size);

    for(uint n=0; n<Size; n++) {
      B.get_ref(n)+= s * A.get_val(n);
    }
  }

  // =================================================================
  // plane (matrix) elementwise operations
  //
  // In the two-operand forms A is only read and B is updated in place;
  // both must have identical dimensions.  Elements are visited in
  // row-major order and each is processed independently.
  // =================================================================

  // set every element of M to val
  virtual void fill_plane(const plane<T> &M, T val) {
    assert(M.ok());
    const uint Imax= M.get_nrows();
    const uint Jmax= M.get_ncols();

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        M.get_ref(i,j)= val;
      }
    }
  }

  // B(i,j) = A(i,j)
  virtual void copy_plane(const plane_ro<T> &A, const plane<T> &B) {
    assert(A.ok() && B.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= A.get_ncols();
    assert(B.get_nrows()==Imax && B.get_ncols()==Jmax);

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        B.get_ref(i,j)= A.get_val(i,j);
      }
    }
  }

  // B(i,j) += A(i,j)
  virtual void add_plane(const plane_ro<T> &A, const plane<T> &B) {
    assert(A.ok() && B.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= A.get_ncols();
    assert(B.get_nrows()==Imax && B.get_ncols()==Jmax);

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        B.get_ref(i,j)+= A.get_val(i,j);
      }
    }
  }

  // B(i,j) -= A(i,j)
  virtual void sub_plane(const plane_ro<T> &A, const plane<T> &B) {
    assert(A.ok() && B.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= A.get_ncols();
    assert(B.get_nrows()==Imax && B.get_ncols()==Jmax);

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        B.get_ref(i,j)-= A.get_val(i,j);
      }
    }
  }

  // B(i,j) *= A(i,j)   (elementwise product, not a matrix product)
  virtual void mul_plane(const plane_ro<T> &A, const plane<T> &B) {
    assert(A.ok() && B.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= A.get_ncols();
    assert(B.get_nrows()==Imax && B.get_ncols()==Jmax);

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        B.get_ref(i,j)*= A.get_val(i,j);
      }
    }
  }

  // M(i,j) *= s
  virtual void scale_plane(const T& s, const plane<T> &M) {
    assert(M.ok());
    const uint Imax= M.get_nrows();
    const uint Jmax= M.get_ncols();

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        M.get_ref(i,j)*= s;
      }
    }
  }

  // B(i,j) += s * A(i,j)   (multiply & accumulate)
  virtual void mac_plane(const T &s, const plane_ro<T> &A, const plane<T> &B) {
    assert(A.ok() && B.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= A.get_ncols();
    assert(B.get_nrows()==Imax && B.get_ncols()==Jmax);

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        B.get_ref(i,j)+= s * A.get_val(i,j);
      }
    }
  }

  // =================================================================
  // transposition
  // =================================================================

  // B = A^t, i.e. B(i,j) = A(j,i).
  // If A is (R x C), B must be (C x R).
  // B must not share storage with A: an element of A may be overwritten
  // before it is read.  Use trans_self() for an in-place transpose.
  virtual void trans_copy(const plane_ro<T> &A, const plane<T> &B) {
    assert(A.ok() && B.ok());
    const uint Imax= A.get_ncols(); // number of rows of B
    const uint Jmax= A.get_nrows(); // number of columns of B
    assert(B.get_nrows()==Imax && B.get_ncols()==Jmax);

    // generate B in row-major order; this scans A in column-major order
    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        B.get_ref(i,j)= A.get_val(j,i);
      }
    }
  } // end function trans_copy()

  // M = M^t in place; M must be square.
  virtual void trans_self(const plane<T> &M) {
    assert(M.ok());
    const uint Size= M.get_nrows();
    assert(M.get_ncols()==Size);

    // swap each strictly-lower element with its mirror above the
    // diagonal; the diagonal itself is left untouched
    for(uint i=1; i<Size; i++) {
      for(uint j=0; j<i; j++) {
        T tmp= M.get_val(i,j);
        M.get_ref(i,j)= M.get_val(j,i);
        M.get_ref(j,i)= tmp;
      }
    }
  } // end function trans_self()

  // =================================================================
  // vector/matrix products
  //
  // out is overwritten one element at a time while `in` is still being
  // read, so out must not share storage with in.
  // =================================================================

  // out(j) = sum_i in(i) * M(i,j)   (row vector times matrix)
  // M is (Imax x Jmax); in has Imax elements, out has Jmax.
  virtual void xform_row(const line_ro<T> &in, const plane_ro<T> &M,
                         const line<T> &out) {
    assert(in.ok() && M.ok() && out.ok());
    const uint Imax= M.get_nrows();
    const uint Jmax= M.get_ncols();
    assert(in.get_size()==Imax && out.get_size()==Jmax);

    for(uint j=0; j<Jmax; j++) {   // access matrix (M) in column-major order
      T sum= 0;
      for(uint i=0; i<Imax; i++) { // scan vector (in) repeatedly
        sum+= in.get_val(i) * M.get_val(i,j);
      }
      out.get_ref(j)= sum;         // generate vector (out) elements sequentially
    }
  } // end function xform_row()

  // out(i) = sum_j M(i,j) * in(j)   (matrix times column vector)
  // M is (Imax x Jmax); in has Jmax elements, out has Imax.
  virtual void xform_col(const plane_ro<T> &M, const line_ro<T> &in,
                         const line<T> &out) {
    assert(M.ok() && in.ok() && out.ok());
    const uint Imax= M.get_nrows();
    const uint Jmax= M.get_ncols();
    assert(in.get_size()==Jmax && out.get_size()==Imax);

    for(uint i=0; i<Imax; i++) {   // scan matrix (M) in row-major order
      T sum= 0;
      for(uint j=0; j<Jmax; j++) { // scan vector (in) repeatedly
        sum+= M.get_val(i,j) * in.get_val(j);
      }
      out.get_ref(i)= sum;         // generate vector (out) elements sequentially
    }
  } // end function xform_col()

  // =================================================================
  // matrix/matrix products
  //
  // C is generated in row-major order, each element as a complete inner
  // sum over k, and is overwritten.  Because C(i,j) is written while A
  // and B are still being read for later elements, C must not share
  // storage with A or B.
  // =================================================================

  // C = A . B        C_ij = sum_k A_ik B_kj
  // A is (Imax x Kmax), B is (Kmax x Jmax), C is (Imax x Jmax).
  // Row i of A is scanned repeatedly; B is scanned in column-major order.
  virtual void mmult_AB(const plane_ro<T> &A, const plane_ro<T> &B,
                        const plane<T> &C) {
    assert(A.ok() && B.ok() && C.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= B.get_ncols();
    const uint Kmax= A.get_ncols();
    assert(C.get_nrows()==Imax &&
           C.get_ncols()==Jmax &&
           B.get_nrows()==Kmax );

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        T sum= 0;
        for(uint k=0; k<Kmax; k++) {
          sum+= A.get_val(i,k) * B.get_val(k,j);
        }
        C.get_ref(i,j)= sum;
      }
    }
  } // end function mmult_AB()

  // C = A^t . B      C_ij = sum_k A_ki B_kj
  // A is (Kmax x Imax), B is (Kmax x Jmax), C is (Imax x Jmax).
  // Column i of A is scanned repeatedly; B is scanned in column-major
  // order (this should be worst-case for cache behavior...).
  virtual void mmult_AtB(const plane_ro<T> &A, const plane_ro<T> &B,
                         const plane<T> &C) {
    assert(A.ok() && B.ok() && C.ok());
    const uint Imax= A.get_ncols();
    const uint Jmax= B.get_ncols();
    const uint Kmax= A.get_nrows();
    assert(C.get_nrows()==Imax &&
           C.get_ncols()==Jmax &&
           B.get_nrows()==Kmax );

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        T sum= 0;
        for(uint k=0; k<Kmax; k++) {
          sum+= A.get_val(k,i) * B.get_val(k,j);
        }
        C.get_ref(i,j)= sum;
      }
    }
  } // end function mmult_AtB()

  // C = A . B^t      C_ij = sum_k A_ik B_jk
  // A is (Imax x Kmax), B is (Jmax x Kmax), C is (Imax x Jmax).
  // Row i of A is scanned repeatedly; B is scanned in row-major order
  // (this should be best-case for cache behavior...).
  virtual void mmult_ABt(const plane_ro<T> &A, const plane_ro<T> &B,
                         const plane<T> &C) {
    assert(A.ok() && B.ok() && C.ok());
    const uint Imax= A.get_nrows();
    const uint Jmax= B.get_nrows();
    const uint Kmax= A.get_ncols();
    assert(C.get_nrows()==Imax &&
           C.get_ncols()==Jmax &&
           B.get_ncols()==Kmax );

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        T sum= 0;
        for(uint k=0; k<Kmax; k++) {
          sum+= A.get_val(i,k) * B.get_val(j,k);
        }
        C.get_ref(i,j)= sum;
      }
    }
  } // end function mmult_ABt

  // C = A^t . B^t    C_ij = sum_k A_ki B_jk
  // A is (Kmax x Imax), B is (Jmax x Kmax), C is (Imax x Jmax).
  // Column i of A is scanned repeatedly; B is scanned in row-major order.
  virtual void mmult_AtBt(const plane_ro<T> &A, const plane_ro<T> &B,
                          const plane<T> &C) {
    assert(A.ok() && B.ok() && C.ok());
    const uint Imax= A.get_ncols();
    const uint Jmax= B.get_nrows();
    const uint Kmax= A.get_nrows();
    assert(C.get_nrows()==Imax &&
           C.get_ncols()==Jmax &&
           B.get_ncols()==Kmax );

    for(uint i=0; i<Imax; i++) {
      for(uint j=0; j<Jmax; j++) {
        T sum= 0;
        for(uint k=0; k<Kmax; k++) {
          sum+= A.get_val(k,i) * B.get_val(j,k);
        }
        C.get_ref(i,j)= sum;
      }
    }
  } // end function mmult_AtBt()

}; // end template class simple_ops<T>



#endif // IMAGE_OPS_HH
