#ifndef LINEAR_HH
#define LINEAR_HH

#include <assert.h>
#include <stdio.h>

// generic linear algebra classes
// Copyright (C) 2002  Bruce J. Bell
//
// I have converted back to hardcoded "scalar" and "accumulator" classes
// They are still generic WRT dimension...
//
// -- intended for homogeneous operations
// --  class accumulator implements all incremental multiply-accumulate
//     operations  (instead of scalar itself)
// -- conversion from accumulator to scalar uses provided template argument
//    "SHIFT" to divide accumulated value by 2**SHIFT
// -- the above means this class is appropriate for homogeneous
//    operations, or for general vector operations, depending on whether
//    your use of the SHIFT parameter is consistent.
// -- the idea is to explicitly instantiate the operations,
//    using the SHIFT parameter to explicitly control range operations


#include "arithmetic.hh"
// file-scope shorthand for the accumulator type nested in class scalar
// (scalar is presumably provided by arithmetic.hh, included above)
typedef scalar::accumulator accumulator;

// declare templates first
// (vect's member templates take tens2 arguments, so tens2 has to be
// declared before the definition of vect)
template<int N> class vect;
template<int ROWS, int COLS> class tens2;


template<int N>
class vect {
  // generic "real" vector template: N coefficients of type scalar
  // (concrete class, to be passed by reference)
  //
  // (not optimized, but convenient for first approximation)
  //
  // notation:  in general,
  // i should be the temporary index variable over val (0 <= i < N)
  // when using template index size K, corresponding index is k (0 <= k < K)

private: // small concrete class, template friends problem
  scalar val[N];  // coefficient values

public: // generic ctors, dtor, set methods
  vect(void) {}  // default ctor (elements are default-constructed scalars)
  // use default copy ctor
  // use default dtor
  // use default assignment operator

public: // access methods
  // for "safe" and foreign access to elements
  // (every accessor asserts 0 <= i < N)
  //
  // for const element access, use getval() instead of getref()
  scalar getval(int i) const {
    assert(i>=0 && i<N);
    return val[i];
  }
  scalar operator[](int i) const { return getval(i); }
  scalar& getref(int i) {
    assert(i>=0 && i<N);
    return val[i];
  }
  scalar& operator[](int i) { return getref(i); }
  void setval(int i, scalar e) {
    assert(i>=0 && i<N);
    val[i]= e;
  }

  // true iff every element reports iszero()
  bool iszero(void) const {
    for(int i=0; i<N; i++) {
      if( ! getval(i).iszero() )  return false;
    }
    return true;
  }

public: // basic vector operations (update/replace *this)
  void zero(void) {
    for(int i=0; i<N; i++)  getref(i).zero();
  }
  // *this= -x  (element by element)
  void setneg(const vect<N>& x) {
    for(int i=0; i<N; i++)  getref(i).setneg(x.getval(i));
  }
  void negate(void) {
    for(int i=0; i<N; i++)  getref(i).negate();
  }

  // straight scaling automatically divides by 2**(scalar.bits-1)
  void scale(scalar s) {
    for(int i=0; i<N; i++) {
      accumulator a(getval(i));
      a.mul(s);
      setval(i, a.convert(2,0)); // convert product back to element
    }
  }
  // *this= s * x, with the same implicit scaling as scale()
  //
  // x is taken by reference (no copy per call);  x may be *this, because
  // each element is read before the same element is written
  void setscale(scalar s, const vect<N>& x) {
    for(int i=0; i<N; i++) {
      accumulator a(x.getval(i));
      a.mul(s);
      setval(i, a.convert(2,0)); // convert product back to element
    }
  }

  // straight addition automatically divides by 2
  // (outside bulk operations, try adding vectors in binary trees)
  void add(const vect<N>& x) {
    for(int i=0; i<N; i++) {
      accumulator a(getval(i));
      a.add(x.getval(i));
      setval(i, a.convert(1,1)); // convert sum back to element
    }
  }
  void sub(const vect<N>& x) {
    for(int i=0; i<N; i++) {
      accumulator a(getval(i));
      a.sub(x.getval(i));
      setval(i, a.convert(1,1)); // convert difference back to element
    }
  }
  // no per-vector incremental mac -- use bulk operations instead...

public: // other operations that result in vectors (update/replace *this)
  template<int K,int SHIFTMORE>
  void xform_col(const tens2<N,K>& T, const vect<K>& x) {
    // linear transform of column vector:  T_ij x_j = ( T|x> )_i
    // x is a column vector (of R^K)
    // T is a N-row x K-column matrix
    // result, a column vector (of R^N), is stored in *this
    //
    // SHIFTMORE should usually be least_integer_gt( log2(K) )
    //
    // x must not be *this (elements of *this are overwritten while x is
    // still being read).  the addresses are compared as void pointers,
    // because vect<K> and vect<N> are unrelated types when K != N
    assert(static_cast<const void*>(&x) != static_cast<const void*>(this));

    for(int i=0; i<N; i++) {
      accumulator a;
      for(int j=0; j<K; j++)
        a.mac(T.getval(i,j), x.getval(j));
      setval(i, a.convert(2,SHIFTMORE));
    }
    // (no incremental vector mac -- use above instead...)
  }

  template<int K, int SHIFTMORE>
  void xform_row(const vect<K>& y, const tens2<K,N>& T) {
    // linear transform of row vector:  T_ij y_i = ( <y|T )_j
    // y is a row vector (of R^K)
    // T is a K-row x N-column matrix
    // result, a row vector (of R^N), is stored in *this
    //
    // SHIFTMORE should usually be least_integer_gt( log2(K) )
    //
    // y must not be *this;  compared as void pointers (see xform_col)
    assert(static_cast<const void*>(&y) != static_cast<const void*>(this));

    for(int j=0; j<N; j++) {
      accumulator a;
      for(int i=0; i<K; i++)
        a.mac(y.getval(i),T.getval(i,j));
      setval(j, a.convert(2,SHIFTMORE));
    }
  }

public:  // operations that result in scalars (or accumulators...)
  template<int SHIFTMORE>
  accumulator inner(const vect<N>& x) const { // inner product
    // returns the raw multiply-accumulate sum over all N element pairs
    //
    // SHIFTMORE should usually be least_integer_gt( log2(N) )
    // (it is not applied inside this function -- the result is returned
    // unconverted;  it cannot be deduced from the arguments, so callers
    // must supply it explicitly:  v.inner<S>(x) )
    accumulator a;
    for(int i=0; i<N; i++)  a.mac(getval(i),x.getval(i));
    return a;
  }
  accumulator norm_2(void) const { // norm squared
    // inner() needs an explicit SHIFTMORE argument;  the value does not
    // affect the returned accumulator, so 0 is used
    return inner<0>(*this);
  }

public:  // standard debug operations
  // use C stdio, not C++ streams
  void dump(FILE *f) const {  // output object state in human-readable format
    // use dump() for debugging and for cheezy character output
    fprintf(f,"vect[ ");
    for(int i=0; i<N; i++) {
      getval(i).dump(f);
      fprintf(f," ");
    }
    fprintf(f,"]");
  }
  bool ok(void) const {  // silently check & return consistency of object state
    // use ok() to determine if object is properly constructed
    // (every element is checked, even after the first failure)
    bool flag=true;
    for(int i=0; i<N; i++)  flag= (getval(i).ok() && flag);
    return flag;
  }
  bool analyze(FILE *f) const {  // verbosely check & return consistency flag
    // use analyze() in asserts for debugging purposes
    // analyze() may be more thorough in its consistency check than ok()
    bool flag=true;

    fprintf(f,"vect[%d]:\n", N);
    for(int i=0; i<N; i++) {
      fprintf(f,"  %d: ",i);
      // call analyze() first so that every element is reported
      flag= (getval(i).analyze(f) && flag);
    }
    return flag;
  }

}; // end template<N> class vect


// generic "real" matrix (second-rank tensor) template
// (not optimized, but convenient for first approximation)
template<int ROWS, int COLS> class tens2 {
  // notational conventions:
  //  - Einstein notation with no vector/covector distinction.
  //    Because each index position may refer to a different
  //    vector space, there is no inherent correspondence
  //    between indexed spaces.  Thus, all indices may as well
  //    be represented as subscripts.
  //  - matrix operations, often expressed in "bra-ket" notation for emphasis:
  //    column vectors can be indicated as |x>
  //    row vectors can be indicated as <y|
  //    matrix multiplying column vector -- Tx can be written as T|x>
  //    matrix multiplying row vector -- yT can be written as <y|T
  //    complete bilinear function of x,y -- yTx can be written as <y|T|x>
  //
  // first index (i) binds ROWS-dimensional vector space (row vectors)
  //   [indexes rows (=== elements of the component columns), of the tensor]
  // second index (j) binds COLS-dimensional vector space (column vectors)
  //   [indexes component columns, of the tensor]
  // i.e., T_ij x_j = ( Tx )_i = ( T|x> )_i
  //   (x is a column vector;
  //    i indexes row of T and of result;
  //    j indexes column of T and row of x)
  // and   T_ij y_i = ( yT )_j = ( <y|T )_j
  //   (y is a row vector;
  //    i indexes row of T and column of y;
  //    j indexes column of T and of result)
  //
  // in general, temporary index variables i and j are used consistently
  // with above (T_ij) notation:
  // 0 <= i < ROWS  (i indexes T row  [and element of component column])
  // 0 <= j < COLS  (j indexes T column [and column element of tensor])
  //
  // function arguments are *always* given in (ROWS,COLS) order
  // so explicit indexing functions are compatible with above notation
  // (e.g., in T.getval(i,j):  i is row index of T;  j is column index of T)
  //
  // note also: due to matrix-multiplication conventions, the vector spaces
  // associated with the tensor have the length of the matrix components with
  // differing orientation.
  //
  // i.e., compatible column vect's have COLS elements (to match matrix rows)
  // while the matrix columns have ROWS elements (to match row vect's)
  //
  // aliasing checks in this class compare addresses as void pointers,
  // because tens2 instantiations with different dimensions are unrelated
  // types whose pointers cannot be compared directly

private: // small concrete class, template friends problem...
  vect<ROWS> col[COLS];  // elements stored as column vectors

public: // generic ctors, dtor, set methods
  tens2(void) {}
  // use default copy ctor
  // use default dtor
  // use default assignment operator

public: // access methods
  // intended for "safe" and foreign access to elements
  //
  // note that indices are supplied in (row,column) format
  // to make dimension ordering consistent with subscript notation
  //
  // for const element access, use getval() instead of getref()
  //
  // (for now, there is no column access operator[];
  // one should be added iff it is necessary...)

  scalar getval(int i, int j) const {
    assert(i>=0 && i<ROWS);
    assert(j>=0 && j<COLS);
    return col[j].getval(i);
  }
  vect<ROWS>& colref(int j) {
    assert(j>=0 && j<COLS);
    return col[j];
  }
  const vect<ROWS>& colref(int j) const {
    assert(j>=0 && j<COLS);
    return col[j];
  }
  scalar& getref(int i, int j) {
    assert(i>=0 && i<ROWS);
    assert(j>=0 && j<COLS);
    // go through vect's public accessor:  vect::val is private and
    // tens2 is not a friend of vect
    return col[j].getref(i);
  }
  void setcol(int j, const vect<ROWS>& x) {
    assert(j>=0 && j<COLS);
    col[j]= x;
  }
  void setval(int i, int j, scalar e) {
    assert(i>=0 && i<ROWS);
    assert(j>=0 && j<COLS);
    col[j].setval(i, e);
  }

  // true iff every column reports iszero()
  bool iszero(void) const {
    for(int j=0; j<COLS; j++) {
      if( !col[j].iszero() ) return false;
    }
    return true;
  }

public: // vector-like matrix operations (replace/update *this)
  // each operation applies the vect operation of the same name
  // to every column
  void zero(void) {
    for(int j=0; j<COLS; j++)  colref(j).zero();
  }
  void setneg(const tens2<ROWS,COLS>& x) {
    for(int j=0; j<COLS; j++)  colref(j).setneg(x.colref(j));
  }
  void negate(void) {
    for(int j=0; j<COLS; j++)  colref(j).negate();
  }
  void scale(scalar s) {
    for(int j=0; j<COLS; j++)  colref(j).scale(s);
  }
  void setscale(scalar s, const tens2<ROWS,COLS>& x) {
    for(int j=0; j<COLS; j++)  colref(j).setscale(s, x.colref(j));
  }
  void add(const tens2<ROWS,COLS>& x) {
    for(int j=0; j<COLS; j++)  colref(j).add(x.colref(j));
  }
  void sub(const tens2<ROWS,COLS>& x) {
    for(int j=0; j<COLS; j++)  colref(j).sub(x.colref(j));
  }
  // no per-matrix incremental mac -- use bulk operations instead...

public: // operations that result in second-rank tensors (replace/update *this)
  void set_transpose(const tens2<COLS,ROWS>& A) {
    // set *this to transpose of A:  T_ij = A_ji
    // A is a COLS-row x ROWS-column matrix, and must not be *this
    assert(static_cast<const void*>(&A) != static_cast<const void*>(this));
    for(int j=0; j<COLS; j++)
      for(int i=0; i<ROWS; i++)
        setval(i,j, A.getval(j,i));
  }
  void transpose(void) {
    // transpose *this in place
    assert(ROWS==COLS); // must be square
    // swap each element below the diagonal with its mirror image
    for(int j=0; j<COLS; j++)
      for(int i=0; i<j; i++) {
        scalar tmp= getval(j,i);
        setval(j,i, getval(i,j));
        setval(i,j, tmp);
      }
  }

  void outer(const vect<ROWS>& x, const vect<COLS>& y) {
    // *this= outer product of two vectors:  T_ij = x_i y_j = |x><y|
    // x is a column vector (of R^ROWS), indexed by the row index i
    // y is a row vector (of R^COLS), indexed by the column index j
    //
    // note that,  unlike matrix multiplication, the vector lengths
    // are the same as the corresponding matrix components
    //
    // the arguments are given in (ROWS,COLS) order, to maintain the
    // tensor index order convention used throughout this class
    //
    // column j of the result is x scaled by y_j (see vect::setscale)
    for(int j=0; j<COLS; j++)
      colref(j).setscale(y.getval(j), x);
  }
  // no per-matrix incremental mac_outer -- use bulk operations instead...

  template<int K,int SHIFTMORE>
  void prod_matrix(const tens2<ROWS,K>& A, const tens2<K,COLS>& B) {
    assert(static_cast<const void*>(&A) != static_cast<const void*>(this));
    assert(static_cast<const void*>(&B) != static_cast<const void*>(this));
    // *this= tensor product: T_ij = A_ik B_kj
    //
    // <y|T|x> = T_ij x_j y_i = A_ik B_kj x_j y_i = (y_i A_ik) (B_kj x_j)
    // = (<y|A)_k (B|x>)_k = <y|AB|x>
    // so, in conventional matrix notation, T = AB
    //
    // sum over columns of A and rows of B
    // temporary index variable k is sum index (0 <= k < K)
    //
    // SHIFTMORE should usually be least_integer_gt( log2(K) )

    for(int j=0; j<COLS; j++)
      for(int i=0; i<ROWS; i++) {
        accumulator a;  // initializes to 0
        for(int k=0; k<K; k++)
          a.mac(A.getval(i,k), B.getval(k,j));
        setval(i,j, a.convert(2,SHIFTMORE));
      }
  }

  template<int K,int SHIFTMORE>
  void prod_transB(const tens2<ROWS,K>& A, const tens2<COLS,K>& B) {
    assert(static_cast<const void*>(&A) != static_cast<const void*>(this));
    assert(static_cast<const void*>(&B) != static_cast<const void*>(this));
    // *this= tensor product: T_ij = A_ik B_jk
    //
    // <y|T|x> = T_ij x_j y_i = A_ik B_jk x_j y_i = (A_ik y_i) (B_jk x_j)
    // = (<y|A) (B^t|x>) = <y|A B^t|x>
    // so, in conventional matrix notation, T= A B^t
    // (transpose B before matrix multiply)
    //
    // SHIFTMORE should usually be least_integer_gt( log2(K) )

    for(int j=0; j<COLS; j++)
      for(int i=0; i<ROWS; i++) {
        accumulator a;  // initializes to 0
        for(int k=0; k<K; k++)
          a.mac(A.getval(i,k), B.getval(j,k));
        setval(i,j, a.convert(2,SHIFTMORE));
      }
  }

  template<int K,int SHIFTMORE>
  void prod_transA(const tens2<K,ROWS>& A, const tens2<K,COLS>& B) {
    assert(static_cast<const void*>(&A) != static_cast<const void*>(this));
    assert(static_cast<const void*>(&B) != static_cast<const void*>(this));
    // *this= tensor product: T_ij = A_ki B_kj
    //
    // <y|T|x> = T_ij x_j y_i = A_ki B_kj x_j y_i = (A_ki y_i) (B_kj x_j)
    // = (<y|A^t) (B|x>) = <y|A^t B|x>
    // so, in conventional matrix notation, T= A^t B
    //
    // SHIFTMORE should usually be least_integer_gt( log2(K) )

    for(int j=0; j<COLS; j++)
      for(int i=0; i<ROWS; i++) {
        accumulator a;  // initializes to 0
        for(int k=0; k<K; k++)
          a.mac(A.getval(k,i), B.getval(k,j));
        setval(i,j, a.convert(2,SHIFTMORE));
      }
  }

  template <int K,int SHIFTMORE>
  void prod_transAB(const tens2<K,ROWS> &A, const tens2<COLS,K> &B) {
    assert(static_cast<const void*>(&A) != static_cast<const void*>(this));
    assert(static_cast<const void*>(&B) != static_cast<const void*>(this));
    // tensor product: T_ij = A_ki B_jk
    // result is stored in *this
    //
    // <y|T|x> = T_ij x_j y_i = A_ki B_jk x_j y_i = (A_ki y_i) (B_jk x_j)
    // = (<y|A^t) (B^t|x>) = <y|A^t B^t|x>
    // so, in conventional matrix notation, T= A^t B^t
    //
    // SHIFTMORE should usually be least_integer_gt( log2(K) )

    for(int j=0; j<COLS; j++)
      for(int i=0; i<ROWS; i++) {
        accumulator a;  // initializes to 0
        for(int k=0; k<K; k++)
          a.mac(A.getval(k,i), B.getval(j,k));
        setval(i,j, a.convert(2,SHIFTMORE));
      }
  }

public: // operations that result in scalars
  template<int SHIFTMORE>
  scalar bilinear(const vect<ROWS>& y, const vect<COLS>& x) const {
    // bilinear function of two vectors:  T_ij x_j y_i = <y|T|x>
    // x is a column vector (of R^COLS)
    // y is a row vector (of R^ROWS)
    // to make the dimension ordering consistent, y is the first argument.
    //
    // SHIFTMORE should usually be least_integer_gt( log2(COLS*ROWS) )
    //
    // (this operation may demand unusual dynamic range from accumulator...)
    //
    // vect::inner() takes a template argument that cannot be deduced, so
    // it is supplied explicitly here;  inner() returns the unconverted
    // accumulator whatever the value, so 0 is used.  the "template"
    // keyword is required because the type of y depends on ROWS.
    //
    // NOTE(review): this relies on accumulator::mac() accepting an
    // accumulator as its second operand -- confirm against arithmetic.hh

    accumulator a;
    for(int j=0; j<COLS; j++)
      a.mac(x.getval(j), y.template inner<0>(colref(j)));
    return a.convert(3,SHIFTMORE);
  }

public:  // standard debug operations
  // use C stdio, not C++ streams
  void dump(FILE *f) const {  // output object state in human-readable format
    // use dump() for debugging and for cheezy character output
    //
    // prints one header line of column labels,
    // then one line of integer element values per row
    fprintf(f,"\ntens2: [");
    for(int j=0; j<COLS; j++) {
      fprintf(f,"   col=%-3d ", j);
    }
    fprintf(f," ]\n");
    for(int i=0; i<ROWS; i++) {
      fprintf(f,"        ");
      for(int j=0; j<COLS; j++) {
        //colref(j).dump(f);
        //fprintf(f," ");
        fprintf(f," %6d    ", getval(i,j).toint());
      }
      fprintf(f,"\n");
    }
  }
  bool ok(void) const {  // silently check & return consistency of object state
    // use ok() to determine if object is properly constructed
    for(int j=0; j<COLS; j++)
      if(!col[j].ok()) return false;
    return true;
  }
  bool analyze(FILE *f) const {  // verbosely check & return consistency flag
    // use analyze() in asserts for debugging purposes
    // analyze() may be more thorough in its consistency check than ok()
    bool flag=true;

    fprintf(f,"tens2[%d,%d]:\n",ROWS,COLS);
    for(int j=0; j<COLS; j++) {
      fprintf(f," %d: ",j);
      // call analyze() first so that every column is reported
      flag= (colref(j).analyze(f) && flag);
    }
    return flag;
  }

}; // end template<ROWS,COLS> class tens2




#endif // LINEAR_HH
