#include <limits.h>
#include <stddef.h>

#include "mex.h"
#include "mexutil.h"

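/*
 *  W = sequential_weight_update(W,dY,X,XX,A)
 *
 * Sequentially update each weight W(m,n) from the residual error dY and
 * the input X, correcting the residual error after every weight change.
 *
 * Written by Masa-aki Sato 2007/07/01
 */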

/*
 * Error correlation: inner product of two arrays over t samples.
 *
 *   *z = sum_{k=0}^{t-1} x[k] * y[k]
 *
 * Used for dYX = dY(n,:) * X(m,:)' with m and n held fixed.
 *
 *   x, y : input arrays; the first t elements of each are read
 *   z    : receives the sum (reset to 0 first, so it is 0 when t <= 0)
 *   t    : number of elements to accumulate
 */
void dyx_corr(double *x, double *y, double *z, int t)
{
	int k;

	/* Accumulate directly in the output cell. */
	*z = 0;
	for (k = 0; k < t; k++)
		*z += x[k] * y[k];
}

/*
 * Error update: subtract a scaled copy of x from y, in place.
 *
 *   y[k] = y[k] - w * x[k]   for k = 0 .. t-1
 *
 * Used for dY(n,t) = dY(n,t) - dW(n) * X(m,t) with m held fixed.
 *
 *   x : input array; the first t elements are read
 *   w : scalar factor applied to x
 *   y : array updated in place; the first t elements are rewritten
 *   t : number of elements to update (nothing happens when t <= 0)
 */
void dy_err(double *x, double w, double *y, int t)
{
	int k;

	for (k = 0; k < t; k++)
		y[k] -= x[k] * w;
}

/*
 * The gateway routine:  W = sequential_weight_update(W,dY,X,XX,A)
 *
 * Inputs (each must be a real, full, double array):
 *   prhs[0]  W  : M x N  weight
 *   prhs[1]  dY : T x N  residual error; UPDATED IN PLACE (see note below)
 *   prhs[2]  X  : T x M  input
 *   prhs[3]  XX : at least M elements, indexed by m (input variance
 *                 according to the original notes)
 *   prhs[4]  A  : at least M elements, indexed by m (weight precision
 *                 parameter alpha according to the original notes)
 * Output:
 *   plhs[0]     : M x N  updated weight
 *
 * For every column n of dY and, inside it, every column m of X (in that
 * order, each step seeing the residual left by the previous one):
 *   dYX       = sum_t dY(t,n) * X(t,m)
 *   dW        = (dYX - W(m,n)*A(m)) / (XX(m) + A(m))
 *   Wout(m,n) = W(m,n) + dW
 *   dY(:,n)   = dY(:,n) - dW * X(:,m)
 *
 * Raises a MATLAB error (mexErrMsgTxt) on a wrong argument count, a
 * non-double/complex/sparse input, or inconsistent dimensions.
 *
 * NOTE(review): dY is written through the pointer obtained from prhs[1],
 * i.e. the caller's array is modified.  This is kept exactly as in the
 * original because callers may depend on it -- confirm before changing it
 * to work on a private copy.
 */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])
{
	double *w, *Wout, *dy, *x, *xx, *a;
	double dw, dyx;
	size_t M, N, T, m, n;
	size_t x_cnt, y_cnt, w_cnt;
	int    i, T_int;

	/*  Check for proper number of arguments. */
	if (nrhs != 5)
		mexErrMsgTxt("5 inputs required.");
	if (nlhs != 1)
		mexErrMsgTxt("One output required.");

	/*  mxGetPr is only meaningful for real, full, double arrays. */
	for (i = 0; i < 5; i++) {
		if (!mxIsDouble(prhs[i]) || mxIsComplex(prhs[i]) || mxIsSparse(prhs[i]))
			mexErrMsgTxt("All inputs must be real, full, double arrays.");
	}

	/*  Get the dimensions of the matrix W */
	M = mxGetM(prhs[0]);
	N = mxGetN(prhs[0]);

	/*  Get the number of samples from dY and cross-check dY and X */
	T = mxGetM(prhs[1]);

	if (mxGetM(prhs[2]) != T)
		mexErrMsgTxt("T does not match.");
	if (mxGetN(prhs[1]) != N)
		mexErrMsgTxt("N does not match.");
	if (mxGetN(prhs[2]) != M)
		mexErrMsgTxt("M does not match.");

	/*  XX and A are indexed by m = 0 .. M-1 below; mxGetM * mxGetN is the
	    total element count, so this guards against reading past their end. */
	if (mxGetM(prhs[3]) * mxGetN(prhs[3]) < M)
		mexErrMsgTxt("XX must have at least M elements.");
	if (mxGetM(prhs[4]) * mxGetN(prhs[4]) < M)
		mexErrMsgTxt("A must have at least M elements.");

	/*  dyx_corr and dy_err take the sample count as an int. */
	if (T > (size_t)INT_MAX)
		mexErrMsgTxt("T is too large.");
	T_int = (int)T;

	/*  Create a pointer to the inputs */
	w  = mxGetPr(prhs[0]);
	dy = mxGetPr(prhs[1]);
	x  = mxGetPr(prhs[2]);
	xx = mxGetPr(prhs[3]);
	a  = mxGetPr(prhs[4]);

	/*  Create the M x N output matrix and get a pointer to its data. */
	plhs[0] = mxCreateDoubleMatrix(M, N, mxREAL);
	Wout = mxGetPr(plhs[0]);

	/*  Running offsets are kept in size_t so that T*M and T*N cannot
	    overflow an int for large inputs.
	    y_cnt : start of column n of dY
	    x_cnt : start of column m of X
	    w_cnt : linear index of W(m,n)
	*/
	y_cnt = 0;
	w_cnt = 0;

	for (n = 0; n < N; n++) {
		x_cnt = 0;

		for (m = 0; m < M; m++) {
			/* error correlation (m, n: fix)
			   dYX = dY(n,:) * X(m,:)';
			*/
			dyx_corr(x + x_cnt, dy + y_cnt, &dyx, T_int);

			/*  Weight update
			    dW = (dYX - W(n,m)*A(m)) ./ (XX(m) + A(m));
			*/
			dw = (dyx - w[w_cnt] * a[m]) / (xx[m] + a[m]);
			Wout[w_cnt] = w[w_cnt] + dw;

			/*  error update (in place on the caller's dY)
			    dY(n,t) = dY(n,t) - dW * X(m,t);
			*/
			dy_err(x + x_cnt, dw, dy + y_cnt, T_int);

			x_cnt += T;
			w_cnt++;
		}

		y_cnt += T;
	}
}
