% function [beta,U,W,mx,my,r2,r2cross,Yt] = pls_v1(X,Y,k,Xt)
%
% Partial Least Squares (for only one output)
%
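% Inputs:
%
% X     - input data, n x dx (one sample per row)
% Y     - output data, n x 1 (only a single output column is allowed)
% k     - number of projections to be used (optional; omit or pass [] to
%         use k = dx)
% Xt    - test data, m x dx (optional; omit to skip test prediction)
%
% Outputs:
%
% beta    - regression coefficients, k x 1 (one per projection)
% U       - unit-norm projection directions, dx x k
% W       - input regression vectors used to deflate the input, dx x k
% mx      - mean of inputs, dx x 1
% my      - mean of the output (scalar)
% Yt      - prediction for Xt, m x 1
% r2      - coefficient of determination on the training data
% r2cross - r2 computed from PRESS (leave-one-out) residuals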
%
% Partial least squares is evaluated according to:
%
%    yp = SUM_k( beta(k) * s(k) ) + my,
%    where s(k) = U(:,k)' * d(k-1)
%          d(k) = d(k-1) - W(:,k)*s(k)
%          d(0) = (x-mx)
%
% 	Stefan Schaal, May 2005

function [beta,U,W,mx,my,r2,r2cross,Yt] = pls_v1(X,Y,k,Xt)


% ---- argument checks and defaults --------------------------------------
% X is n x dx (one sample per row); Y must be n x 1.
[n,dx]  = size(X);
[ny,dy] = size(Y);

if dy > 1,
  error('pls_v1 allows only one output');
end;
if ny ~= n,
  error('pls_v1: X and Y must have the same number of rows');
end;

% default: use as many projections as there are input dimensions
if nargin < 3 || isempty(k),
  k = dx;
end;

% ---- center the data ----------------------------------------------------
% mean(...,1) keeps the column-wise mean even when n == 1
mx = mean(X,1)';
my = mean(Y,1)';

Ys = Y;
X  = X - ones(n,1)*mx';
Y  = Y - ones(n,1)*my';

% ---- partial least squares ---------------------------------------------
% D and E hold the residual (deflated) input and output.
D = X;
E = Y;

U     = zeros(dx,k);   % unit-norm projection directions
W     = zeros(dx,k);   % regression of D onto each score, used to deflate D
S     = zeros(n,k);    % training scores s = D*u
beta  = zeros(k,1);    % regression of E onto each score
invss = zeros(k,1);    % 1/(s'*s) for each score

for i = 1:k,
  % projection direction: covariance of residual input with residual output
  U(:,i)   = D'*E;
  U(:,i)   = U(:,i)/norm(U(:,i));
  S(:,i)   = D*U(:,i);
  invss(i) = 1/(S(:,i)'*S(:,i));
  W(:,i)   = (invss(i)*S(:,i)'*D)';
  beta(i)  = invss(i)*S(:,i)'*E;

  % remove the part explained by this projection before the next one
  D = D - S(:,i)*W(:,i)';
  E = E - S(:,i)*beta(i);
end;

% ---- prediction for the test data --------------------------------------
% Yt is returned empty (0 x 1) when Xt is omitted or empty, so requesting
% all eight outputs never fails.
Yt = zeros(0,1);
if nargin > 3 && ~isempty(Xt),
  if size(Xt,2) ~= dx,
    error('pls_v1: Xt must have the same number of columns as X');
  end;
  Yt = pls_v1_predict(Xt,U,W,beta,mx,my);
end;

% ---- coefficient of determination on the training data -----------------
% Feeding the centered training inputs through the model reproduces the
% scores S computed above, so the training prediction is S*beta + my.
Yp   = S*beta + my;
res  = Ys - Yp;
MSE  = mean(res.^2);
varY = var(Ys,1);              % variance normalized by n
r2   = (varY-MSE)./varY;

% ---- leave-one-out (PRESS) cross-validation ----------------------------
% Leverage h(i) = sum_k S(i,k)^2/(S(:,k)'*S(:,k)). The shortcut res./(1-h)
% treats the projections as fixed, so it approximates true leave-one-out.
h         = (S.^2)*invss;
res_cross = res./(1-h);
MSE_cross = mean(res_cross.^2);
r2cross   = (varY-MSE_cross)./varY;


function Yq = pls_v1_predict(Xq,U,W,beta,mx,my)
% Apply a fitted PLS model to the rows of Xq (m x dx); returns m x 1.
% Uses the same projection/deflation sequence as the fitting loop.
m  = size(Xq,1);
D  = Xq - ones(m,1)*mx';
Yq = zeros(m,1);
for i = 1:numel(beta),
  s  = D*U(:,i);
  Yq = Yq + s*beta(i);
  D  = D - s*W(:,i)';
end;
Yq = Yq + my;
