% Calculation of objective and gradient for the mean/covariance matrix
% parameterization of the Multivariate Normal.
%
% function [obj,grad,rowObj] = multiNormal(v,X,varargin)
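% v - vector of parameters [d*(k+1),1]: the mean mu (first d entries)
%     followed by L(:), where L is d x k and the covariance is L*L'
% X - data points (one per row) [n,d]
% obj - value of objective at v [scalar]: the negative log-likelihood of X,
%       omitting the additive constant n*d/2*log(2*pi)
% grad - gradient of obj at v [d*(k+1),1]
% rowObj - objective per datum/row [n,1]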
%
% Written by Jason Rennie, July 2006
% Last modified: Mon Aug 21 18:42:05 2006

% To test:
% checkgrad2(@multiNormal,randn(12,1),{rand(1000,3)})

function [obj,grad,rowObj] = multiNormal(v,X,varargin)
  % Objective: for each row x_i of X,
  %   rowObj(i) = (log(det(S)) + (x_i-mu)*inv(S)*(x_i-mu)')/2,  S = L*L'
  % and obj = sum(rowObj). grad is the gradient of obj w.r.t. v = [mu(:); L(:)].
  %
  % Numerics: S is never explicitly inverted and det(S) is never formed.
  % S is factored as S = R'*R (Cholesky), and triangular solves are used.
  % log(det(S)) = 2*sum(log(diag(R))) avoids the overflow/underflow that
  % det(S) suffers for moderate d.
  fn = mfilename;  % NOTE(review): presumably read by the paramgt script, which shares this workspace
  if nargin < 2
    error('insufficient parameters')
  end
  % Parameters that can be set via varargin
  verbose = 1;
  % Process varargin (paramgt is a project script; presumably it overrides
  % the defaults above from varargin -- confirm)
  paramgt;

  t0 = clock;
  [n,d] = size(X);
  if mod(length(v),d) ~= 0
    error('Sizes of X and v are incompatible. size(X)=[%d,%d] length(v)=%d',n,d,length(v));
  end
  k = length(v)/d - 1;
  mu = reshape(v(1:d),1,d);          % 1 x d mean
  L = reshape(v(d+1:d*(k+1)),d,k);   % d x k factor; covariance S = L*L'

  Xmu = X - repmat(mu,n,1);          % centered data, n x d
  S = L*L';
  r = rank(S);                       % computed once (SVD is not cheap)
  if r ~= d
    error('Rank of L*L'' must be d=%d.  rank(L*L'')=%d\n',d,r);
  end
  % S = R'*R with R upper triangular. p ~= 0 means chol found S not
  % numerically positive definite even though rank() accepted it.
  [R,p] = chol(S);
  if p ~= 0
    error('L*L'' is not numerically positive definite (chol failed at column %d)',p);
  end

  Z = Xmu / R;                       % Xmu*inv(R); sum(Z.^2,2) = diag(Xmu*inv(S)*Xmu')
  logdetS = 2*sum(log(diag(R)));     % log(det(S)) without forming det(S)
  rowObj = (logdetS + sum(Z.^2,2))./2;
  obj = sum(rowObj);

  % Gradient, with XmuSinv = Xmu*inv(S):
  %   d obj/d mu = -sum over rows of XmuSinv
  %   d obj/d L  = n*inv(S)*L - inv(S)*Xmu'*Xmu*inv(S)*L
  XmuSinv = Z / R';                  % Xmu*inv(R)*inv(R') = Xmu*inv(S)
  dmu = -sum(XmuSinv,1);
  SinvL = R \ (R' \ L);              % inv(S)*L via two triangular solves
  dL = n.*SinvL - XmuSinv'*(XmuSinv*L);
  grad = [dmu(:); dL(:)];

  if verbose
    fprintf(1,'obj=%.2e grad''*grad=%.2e time=%.1f\n',obj,grad'*grad,etime(clock,t0));
  end

% ChangeLog
% 8/21/06 - It works!!!!  Key to generating data: randn(n,k)*Lambda' + mu; then
% covariance is Lambda*Lambda' and mean is mu
% 8/21/06 - fixed typo (was multiplying zero by log(det(S))!!!)
% 8/21/06 - checkgrad2(@multiNormal,rand(6,1),{rand(1000,2)}) fails
% 8/18/06 - convert to Multivariate Normal objective and gradient
% 8/17/06 - Use sum((Xmu*Sinv).*Xmu,2) to calculate row objective value
% 8/17/06 - Added rowObj return value; ran checkgrad2(@faML,rand(70,1),{rand(100,10)})
% 8/10/06 - Implemented and checked (checkgrad2) derivative for L
% 8/8/06 - Implemented and checked (checkgrad2) derivative for psi
% 8/8/06 - Implemented and checked (checkgrad2) derivative for mu
% 8/8/06 - Implemented objective
