/*
 * N-Dimensional array template approximating the spec of
 * Norbert Nemec at
 * http://homepages.uni-regensburg.de/~nen10015/documents/D-multidimarray.html
 * see the small print at the bottom of the file for the "legal" stuff
 */

// generic column-oriented N-dimensional array of T
// Generic column-major N-dimensional array of T.
//
// Element (i0, ..., iN-1) is stored at
//   data[i0*stride[0] + i1*stride[1] + ... + iN-1*stride[N-1]].
// A value of this struct is a view: slicing produces new views that share
// the same data buffer. No bounds checking is performed.
//
// Indexing is written one dimension at a time, A[i][j][k]. Every [] except
// the last returns an intermediate IndexExpr/SliceExpr that records the
// indices or slice bounds seen so far. The template argument M of those
// structs is the number of dimensions that still have to be supplied.
// M == 1 is specialized because its [] completes the expression and
// therefore has a different return type:
//   - if every [] was an index, the result is an element of type T, and the
//     element can be assigned to;
//   - as soon as any [] is a slice, the whole expression becomes a slice and
//     the result is an NDArray view, which cannot be assigned to as in
//     A[0..2][3..5] = 10.
struct NDArray(T,int N) {
  T *data;          // first element of the view (may point into a larger buffer)
  uint[N] length;   // extent of each dimension
  uint[N] stride;   // element distance between neighbours along each dimension

  // Construct an array over `data` with extents dims[0], dims[1], ...,
  // e.g. NDArray!(double,3)(null, 2, 3, 4). Exactly N extents are expected.
  // If `data` is null, a fresh buffer of the required size is allocated.
  // Strides are column-major: stride[0] == 1 and
  // stride[k] == length[k-1]*stride[k-1].
  // The extents are taken as a typesafe variadic static array. Reading C
  // varargs through the address of the first argument is not portable:
  // many ABIs pass variadic arguments in registers.
  static NDArray opCall(T* data, uint[N] dims...) {
    NDArray res;
    res.length[] = dims[];
    res.stride[0] = 1;
    for (int k = 1; k < N; ++k) {
      res.stride[k] = res.length[k-1]*res.stride[k-1];
    }
    if (data is null) {
      data = (new T[res.stride[N-1]*res.length[N-1]]).ptr;
    }
    res.data = data;
    return res;
  }

  // Intermediate result of A[n0][n1]...[nN-2]. All N-1 leading dimensions
  // have been fixed by index expressions, so the next [] supplies the last
  // dimension. An index yields a T; a slice yields an NDArray.
  private struct IndexExpr(int M:1) {
    NDArray *x;         // the array being indexed
    uint[N-1] prev_n;   // indices already applied to dims 0..N-2

    // make an IndexExpr
    static .NDArray!(T,N).IndexExpr!(1) make(NDArray *x,
                                             uint[N-1] n) {
      .NDArray!(T,N).IndexExpr!(1) res;
      res.x = x;
      res.prev_n[] = n[];
      return res;
    }

    // data offset of element (prev_n[0], ..., prev_n[N-2], 0)
    private uint offset() {
      uint sum = 0;
      for (int k = 0; k < N-1; ++k) {
        sum += x.stride[k]*prev_n[k];
      }
      return sum;
    }

    // A*[n] where * are opIndex calls
    T opIndex(uint n) {
      return x.data[offset() + x.stride[N-1]*n];
    }

    // A*[n] = val where * are opIndex calls
    void opIndex(uint n, T val) {
      x.data[offset() + x.stride[N-1]*n] = val;
    }

    // A*[i..j] where * are opIndex calls. The result has length 1 in
    // dims 0..N-2 and length j-i in the last dimension.
    NDArray opSlice(uint i, uint j) {
      NDArray res;
      res.length[] = 1;
      res.stride[] = x.stride[];
      res.length[N-1] = j-i;
      res.data = &x.data[offset() + x.stride[N-1]*i];
      return res;
    }

    // A*[] where * are opIndex calls
    NDArray opSlice() {
      return opSlice(0,x.length[N-1]);
    }
  }

  // Intermediate result of a slicing expression once dims 0..N-2 are fixed.
  // Dimensions that were indexed rather than sliced are stored as
  // unit-length slices. The next [] supplies the last dimension and always
  // yields an NDArray.
  private struct SliceExpr(int M:1) {
    NDArray *x;            // the array being sliced
    uint[N-1] prev_i;      // start index of dims 0..N-2 (parent coordinates)
    uint[N-1] prev_j;      // end index (exclusive) of dims 0..N-2
    uint[N-1] prev_step;   // step of dims 0..N-2

    // make a SliceExpr whose next [] addresses the last dimension
    static .NDArray!(T,N).SliceExpr!(1) make(NDArray *x,
                                             uint[N-1] i,
                                             uint[N-1] j,
                                             uint[N-1] step) {
      .NDArray!(T,N).SliceExpr!(1) res;
      res.prev_i[] = i[];
      res.prev_j[] = j[];
      res.prev_step[] = step[];
      res.x = x;
      return res;
    }

    // A*[i..j].by(n)[*]: set the step of the most recently sliced dim (N-2)
    .NDArray!(T,N).SliceExpr!(1) by(uint step) {
      uint[N-1] astep;
      astep[] = prev_step[];
      astep[N-2] = step;
      return .NDArray!(T,N).SliceExpr!(1).make(x,prev_i,prev_j,astep);
    }

    // A*[i..j][i2..j2]
    NDArray opSlice(uint i, uint j) {
      NDArray res;
      uint sum = 0;
      for (int k = 0; k < N-1; ++k) {
        res.length[k] = nd_nelems(prev_j[k]-prev_i[k],prev_step[k]);
        res.stride[k] = prev_step[k]*x.stride[k];
        // Start indices are in the parent's coordinates, so scale them by
        // the parent stride, not by the stepped stride of the result.
        sum += x.stride[k]*prev_i[k];
      }
      res.length[N-1] = j-i;
      res.stride[N-1] = x.stride[N-1];
      res.data = &x.data[sum + x.stride[N-1]*i];
      return res;
    }

    // A*[i..j][]
    NDArray opSlice() {
      return opSlice(0,x.length[N-1]);
    }

    // A*[i..j][n]: a slice of length 1 in the last dimension
    NDArray opIndex(uint i) {
      return opSlice(i,i+1);
    }
  }

  // Intermediate slicing expression with M (> 1) dimensions still to be
  // supplied. The next [] addresses dimension N-M. With luck this can be
  // inlined.
  private struct SliceExpr(uint M) {
    NDArray *x;
    uint[N-1] prev_i;      // start indices of dims 0..N-M-1
    uint[N-1] prev_j;      // end indices (exclusive) of dims 0..N-M-1
    uint[N-1] prev_step;   // steps (1 unless changed by .by())

    // make a SliceExpr not in the last dimension
    static .NDArray!(T,N).SliceExpr!(M) make(NDArray *x,
                                             uint[N-1] i,
                                             uint[N-1] j,
                                             uint[N-1] step) {
      .NDArray!(T,N).SliceExpr!(M) res;
      res.prev_i[] = i[];
      res.prev_j[] = j[];
      res.prev_step[] = step[];
      res.x = x;
      return res;
    }

    // A*[i..j].by(n)*: set the step of the most recently sliced dim (N-M-1)
    .NDArray!(T,N).SliceExpr!(M) by(uint step) {
      uint[N-1] astep;
      astep[] = prev_step[];
      astep[N-M-1] = step;
      return .NDArray!(T,N).SliceExpr!(M).make(x,prev_i,prev_j,astep);
    }

    // A*[i..j][i2..j2]*
    .NDArray!(T,N).SliceExpr!(M-1) opSlice(uint i, uint j) {
      // work on copies so this expression object is left unchanged
      uint[N-1] ai;
      uint[N-1] aj;
      ai[] = prev_i[];
      aj[] = prev_j[];
      ai[N-M] = i;
      aj[N-M] = j;
      return .NDArray!(T,N).SliceExpr!(M-1).make(x,ai,aj,prev_step);
    }

    // A*[i..j][]*
    .NDArray!(T,N).SliceExpr!(M-1) opSlice() {
      return opSlice(0,x.length[N-M]);
    }

    // A*[i]*: once slicing has started, an index is a unit-length slice
    .NDArray!(T,N).SliceExpr!(M-1) opIndex(uint i) {
      return opSlice(i,i+1);
    }
  }


  // Intermediate indexing expression with M (> 1) dimensions still to be
  // supplied. Dims 0..N-M-1 have been indexed; the next [] addresses
  // dimension N-M. With luck this can be inlined.
  private struct IndexExpr(uint M) {
    NDArray *x;
    uint[N-1] prev_n;   // indices applied so far, in dims 0..N-M-1

    // make an IndexExpr
    static .NDArray!(T,N).IndexExpr!(M) make(NDArray *x,
                                             uint[N-1] n) {
      .NDArray!(T,N).IndexExpr!(M) res;
      res.x = x;
      res.prev_n[] = n[];
      return res;
    }

    // A*[n]*
    .NDArray!(T,N).IndexExpr!(M-1) opIndex(uint n) {
      uint[N-1] inds;
      inds[] = prev_n[];
      inds[N-M] = n;
      return .NDArray!(T,N).IndexExpr!(M-1).make(x,inds);
    }

    // A*[i..j]*: switch from indexing to slicing
    .NDArray!(T,N).SliceExpr!(M-1) opSlice(uint i,uint j) {
      uint[N-1] prev_i;
      uint[N-1] prev_j;
      uint[N-1] prev_step;
      // The N-M dimensions indexed so far (0..N-M-1) become unit-length
      // slices, so the rest of the chain can treat every dimension the same.
      for (int k = 0; k < N-M; ++k) {
        prev_i[k] = prev_n[k];
        prev_j[k] = prev_n[k]+1;
      }
      // this [] addresses dimension N-M
      prev_i[N-M] = i;
      prev_j[N-M] = j;
      prev_step[] = 1;
      return .NDArray!(T,N).SliceExpr!(M-1).make(x, prev_i, prev_j, prev_step);
    }

    // A*[]*
    .NDArray!(T,N).SliceExpr!(M-1) opSlice() {
      return opSlice(0,x.length[N-M]);
    }
  }

  // A[n]*. `data` is the first field, so its address is the address of this
  // struct.
  .NDArray!(T,N).IndexExpr!(N-1) opIndex(uint n) {
    uint[N-1] inds;
    inds[0] = n;
    return .NDArray!(T,N).IndexExpr!(N-1).make(cast(NDArray*)&data, inds);
  }

  // A[i..j]*
  .NDArray!(T,N).SliceExpr!(N-1) opSlice(uint i,uint j) {
    uint[N-1] prev_i;
    uint[N-1] prev_j;
    uint[N-1] prev_step;
    prev_i[0] = i;
    prev_j[0] = j;
    prev_step[] = 1;
    return .NDArray!(T,N).SliceExpr!(N-1).make(cast(NDArray*)&data,
                                               prev_i,prev_j,prev_step);
  }

  // A[]*
  .NDArray!(T,N).SliceExpr!(N-1) opSlice() {
    return opSlice(0,length[0]);
  }

  // A*.by(step): take every step-th element along the last dimension
  NDArray by(uint step) {
    NDArray res;
    res.length[] = length[];
    res.stride[] = stride[];
    res.length[N-1] = nd_nelems(length[N-1],step);
    res.stride[N-1] = step*stride[N-1];
    res.data = data;
    return res;
  }

  // number of elements in the array (product of all extents)
  uint nelems() {
    uint total = 1;
    for (int k = 0; k < N; k++) {
      total *= length[k];
    }
    return total;
  }
}

////////////////////////////////////////////////////////////
// The two-dimensional array is specialized separately for easier inlining

// Two-dimensional specialization of NDArray.
//
// Column-major: element (m, n) is stored at data[m*stride[0] + n*stride[1]].
// Indexing goes through the small IndexExpr/SliceExpr structs below, not
// through the recursive templates of the generic version. A value of this
// struct is a view; slices share the same data buffer. No bounds checking
// is performed.
struct NDArray(T,int N:2) {
  T *data;          // first element of the view
  uint[2] length;   // number of rows, number of columns
  uint[2] stride;   // element distance between rows, between columns

  // construct a new x0-by-x1 matrix with freshly allocated storage
  static NDArray opCall(uint x0, uint x1) {
    return opCall((new T[x0*x1]).ptr,x0,x1);
  }

  // wrap existing column-major storage of x0*x1 elements as a matrix
  static NDArray opCall(T* data, uint x0, uint x1) {
    NDArray res;
    res.length[0] = x0;
    res.length[1] = x1;
    res.stride[0] = 1;
    res.stride[1] = x0;
    res.data = data;
    return res;
  }

  // Private intermediate structure for A[m]*: the row index is fixed and the
  // next [] addresses the column.
  private struct IndexExpr {
    NDArray *x;
    uint m;

    // A[m][n] (no bounds check)
    T opIndex(uint n) {
      return x.data[m*x.stride[0] + n*x.stride[1]];
    }

    // A[m][n] = val (no bounds check)
    void opIndex(uint n,T val) {
      x.data[m*x.stride[0] + n*x.stride[1]] = val;
    }

    // A[m][i..j]: a 1-by-(j-i) view
    NDArray opSlice(uint i, uint j) {
      NDArray res;
      res.data = &x.data[m*x.stride[0] + i*x.stride[1]];
      res.length[0] = 1;
      res.length[1] = j-i;
      res.stride[0] = x.stride[0];
      res.stride[1] = x.stride[1];
      return res;
    }

    // A[m][]: the whole row m
    NDArray opSlice() {
      return opSlice(0,x.length[1]);
    }
  }

  // Private intermediate structure for A[i..j]*: rows i..j-1 (every step-th
  // one) are selected and the next [] addresses the columns.
  private struct SliceExpr {
    NDArray *x;
    uint i;
    uint j;
    uint step;

    // A[i..j][n]: an nrows-by-1 view of column n
    NDArray opIndex(uint n) {
      NDArray res;
      res.length[0] = nd_nelems(j-i,step);
      res.length[1] = 1;
      res.stride[0] = step*x.stride[0];
      res.stride[1] = x.stride[1];
      res.data = &x.data[i*x.stride[0] + n*x.stride[1]];
      return res;
    }

    // A[i..j][i2..j2]
    NDArray opSlice(uint i2, uint j2) {
      NDArray res;
      res.length[0] = nd_nelems(j-i,step);
      res.length[1] = j2-i2;
      res.stride[0] = step*x.stride[0];
      // The column stride is inherited unchanged. Multiplying it by the
      // row stride is only harmless when stride[0] == 1.
      res.stride[1] = x.stride[1];
      res.data = &x.data[i*x.stride[0] + x.stride[1]*i2];
      return res;
    }

    // A[i..j][]: all columns of the selected rows
    NDArray opSlice() {
      return opSlice(0,x.length[1]);
    }

    // A[i..j].by(step)*: select every step-th row of i..j-1
    SliceExpr by(uint step) {
      SliceExpr res;
      res.i = i;
      res.j = j;
      res.step = step;
      res.x = x;
      return res;
    }
  }

  // A[*][i2..j2].by(step): take every step-th column
  NDArray by(uint step) {
    NDArray res;
    res.length[0] = length[0];
    res.length[1] = nd_nelems(length[1],step);
    res.stride[0] = stride[0];
    res.stride[1] = step*stride[1];
    res.data = data;
    return res;
  }

  // A[m]*. `data` is the first field, so its address is the address of this
  // struct.
  IndexExpr opIndex(int m) {
    IndexExpr res;
    res.x = cast(NDArray*)&data;
    res.m = m;
    return res;
  }

  // A[i..j]*
  SliceExpr opSlice(uint i, uint j) {
    SliceExpr res;
    res.x = cast(NDArray*)&data;
    res.i = i;
    res.j = j;
    res.step = 1;
    return res;
  }

  // A[]*
  SliceExpr opSlice() {
    return opSlice(0,length[0]);
  }

  // number of elements in the array
  uint nelems() {
    return length[0]*length[1];
  }

  // subarray (nrows-by-1 view) for a single column n
  NDArray col(uint n) {
    NDArray res;
    res.length[0] = length[0];
    res.length[1] = 1;
    res.stride[0] = stride[0];
    res.stride[1] = stride[1];
    res.data = data+stride[1]*n;
    return res;
  }

  // subarray (1-by-ncols view) for a single row m
  NDArray row(uint m) {
    NDArray res;
    res.length[0] = 1;
    res.length[1] = length[1];
    res.stride[0] = stride[0];
    res.stride[1] = stride[1];
    res.data = data+stride[0]*m;
    return res;
  }

  // Reinterpret the storage as an x0-by-x1 matrix that shares the data.
  // Assumes the elements are stored contiguously in column-major order
  // (stride[0] == 1, stride[1] == length[0]) and that x0*x1 does not exceed
  // the number of stored elements. Neither assumption is checked.
  NDArray reshape(uint x0, uint x1) {
    NDArray res;
    res.length[0] = x0;
    res.length[1] = x1;
    // The strides must be computed for the new shape. If left at their
    // default of 0, every element would alias data[0].
    res.stride[0] = 1;
    res.stride[1] = x0;
    res.data = data;
    return res;
  }

}

// make Matrix a nice alias for 2D array
// Matrix!(T) names the two-dimensional array type NDArray!(T,2)
// (eponymous template: the alias carries the template's own name).
template Matrix(T)
{
  alias NDArray!(T, 2) Matrix;
}

// private helper function
// Number of indices selected when stepping through 0 .. n-1 with the given
// step, i.e. the length of the sequence 0, step, 2*step, ... that stays
// below n. Returns 0 for n == 0.
// step must be non-zero; the assert reports that misuse clearly instead of
// failing with an integer division by zero.
private uint nd_nelems(uint n, uint step) {
  assert(step != 0, "nd_nelems: step must be non-zero");
  if (n == 0)
    return 0;
  return (n - 1) / step + 1;
}

/*
 * Written by Ben Hinkle. This file is meant as an illustration of
 * n-d arrays using index and slicing expression templates with
 * specializations for 2D matrices for better performance.
 * The code is put into the public domain and can be used for
 * any purpose. No warranty etc etc...
 * Send comments to bhinkle4@juno.com
 */


