function [gmm, L] = GMM_VI(x,K,ploton)
% Copyright 2017: Steven Van Kuyk
% This program comes WITHOUT ANY WARRANTY.
%
% Train a Gaussian mixture model using variational inference
% See 'Pattern Recognition and Machine Learning', Bishop
% (equation numbers in the comments below refer to that book)
%
% Inputs
% x: DxN training data where each column is an observation
% K: number of mixtures
% ploton: enable/disable plotting (optional, default false). The plots
%         show dimensions 1 and 2 of the data, so they require D >= 2.
%
% Outputs:
% gmm: model structure
% gmm.mu: DxK matrix where each column represents a mean
% gmm.sigma: DxDxK matrix containing covariance matrices
% gmm.w: vector containing the weights of each mixture
% L: evidence lower bound of the model divided by the number of
%    observations, one entry per iteration starting at iteration 2
if nargin < 3
    ploton = false;
end
display('training GMM...')

maxIter = 2e3;
tol = 1e-10; % convergence tolerance
[D,N] = size(x);

% hyper parameters for generating GMMs (section 10.2, Bishop)
alpha0 = 0.1*(1/K); % Dirichlet prior (for weights)
W0 = 1*eye(D); % Wishart prior (for precision matrix i.e., inverse covariance)
v0 = D+1;  % Wishart degrees of freedom
m0 = mean(x,2); % Gaussian prior (for mean)
beta0 = 0.1; % Gaussian prior (for precision of the mean)

% quantities that depend only on the prior (constant across iterations)
invW0 = inv(W0);
logB0 = -v0*sum(log(diag(chol(W0)))) - 0.5*v0*D*log(2) - (D*(D-1)/4)*log(pi) - sum(gammaln( 0.5*(v0+1-(1:D)) )); % log of B.79

% random initialization: hard-assign every observation to a random mixture
z = ceil(K*rand(1,N)); % labels
r = zeros(K,N); % p(z|x), responsibilities
for k=1:K
    r(k,k==z) = 1;
end

% pre-allocation
S = zeros(D,D,K);
Winv = zeros(D,D,K);
sigma = zeros(D,D,K);
Equad = zeros(K,N);
ElogGamma = zeros(K,1);

% apply VI
L = -inf; % placeholder so that L(end-1) exists at the first bound evaluation
converged = false;
for iter=1:maxIter
    % Maximization-like step
    Nk = max(sum(r,2),eps); % 10.51 (clamped to avoid division by zero)
    xbar = bsxfun(@times,x*r',1./Nk'); % 10.52
    for k=1:K
        x0 = bsxfun(@minus,x,xbar(:,k));
        x0 = bsxfun(@times,x0,sqrt(r(k,:)));
        S(:,:,k) = (1/Nk(k))*(x0*x0'); % 10.53
    end
    
    alphaK = alpha0+Nk; % 10.58
    betaK = beta0+Nk; % 10.60
    m = bsxfun(@times,repmat(beta0*m0,1,K) + x*r',1./betaK'); % 10.61
    
    for k=1:K
        dm = xbar(:,k)-m0;
        Winv(:,:,k) = invW0 + Nk(k)*S(:,:,k) + (beta0*Nk(k)/(beta0+Nk(k))) * (dm*dm'); % 10.62
    end
    vk = v0+Nk; % 10.63
    
    % Expectations under the updated q(pi) and q(Lambda). They are
    % evaluated here, before the lower bound, so that the bound and the
    % following E-like step both use the current variational parameters.
    for k=1:K
        ElogGamma(k) = D*log(2) + (-2*sum(log(diag(chol(Winv(:,:,k)))))) + sum( psi(0,(vk(k)+1-(1:D))/2) ); % 10.65
    end
    Elogw = psi(0,alphaK)-psi(0,sum(alphaK)); % 10.66
    
    % predictive distribution (computed before the convergence test so the
    % returned model always matches the final variational parameters)
    mu = m; % 10.81
    w = alphaK/sum(alphaK); % 10.81
    for k=1:K
        scale = (1+betaK(k))/(betaK(k)*(vk(k)+1-D));
        sigma(:,:,k) = scale*Winv(:,:,k); % 10.82
    end
    
    % compute lower bound (needs logr, which first exists after one E-like step)
    if iter>1
        traceSW = zeros(size(Nk));
        quadXWX = zeros(size(Nk));
        quadMWM = zeros(size(Nk));
        traceWW = zeros(size(Nk));
        Hq = zeros(size(Nk));
        for k=1:K
            dxm = xbar(:,k)-m(:,k);
            dm0 = m(:,k)-m0;
            traceSW(k) = trace( S(:,:,k)/Winv(:,:,k) );
            quadXWX(k) = dxm'*(Winv(:,:,k)\dxm);
            quadMWM(k) = dm0'*(Winv(:,:,k)\dm0);
            traceWW(k) = trace( Winv(:,:,k)\invW0 );
            
            logB = vk(k)*sum(log(diag(chol(Winv(:,:,k))))) - 0.5*vk(k)*D*log(2) - (D*(D-1)/4)*log(pi) - sum(gammaln( 0.5*(vk(k)+1-(1:D)) )); % log of B.79
            Hq(k) = -logB - 0.5*(vk(k)-D-1)*ElogGamma(k) + 0.5*vk(k)*D; % B.82
        end
        
        % every term inside the braces of 10.71 other than ElogGamma is subtracted
        term1 = 0.5*sum( Nk.*(ElogGamma - D./betaK - vk.*traceSW - vk.*quadXWX - D*log(2*pi)) ); % 10.71
        term2 = sum(r'*Elogw); % 10.72
        term3 = gammaln(K*alpha0)-K*gammaln(alpha0)+(alpha0-1)*sum(Elogw); % 10.73
        term4 = 0.5*sum( D*log(beta0/(2*pi)) + ElogGamma -D*beta0./betaK - beta0*vk.*quadMWM ) + K*logB0 + 0.5*(v0-D-1)*sum(ElogGamma) - 0.5*sum(vk.*traceWW); % 10.74
        term5 = sum(sum(r.*logr)); % 10.75
        term6 = (alphaK'-1)*Elogw + gammaln(sum(alphaK)) - sum(gammaln(alphaK)); % 10.76
        term7 = sum( 0.5*ElogGamma + 0.5*D*log(betaK/(2*pi)) - D/2 - Hq ); % 10.77
        L(end+1) = term1 + term2 + term3 + term4 - term5 - term6 - term7; % 10.70
        
        if L(end)<L(end-1)
            display( sprintf('warning: lower bound decreased by %d',(L(end)-L(end-1))/N) )
        end
        % relative change test; at least 40 iterations are always run
        if abs( L(end)-L(end-1) ) < tol*abs(L(end)) && iter>40
            converged = true;
            display(sprintf('converged after %d iterations',iter))
            break
        end
    end
    
    % Expectation-like step
    for k=1:K
        x0 = x-repmat(m(:,k),1,N);
        Equad(k,:) = (D/betaK(k)) + vk(k)*dot(x0,Winv(:,:,k)\x0); % 10.64
    end
    
    logRho = bsxfun(@plus,Elogw + 0.5*ElogGamma - 0.5*D*log(2*pi), -0.5*Equad); % 10.46
    % Normalise in the log domain (log-sum-exp) to avoid overflow/underflow.
    % The reduction dimension is given explicitly so that K == 1 (logRho is
    % a row vector) is still reduced over mixtures, not over observations.
    mx = max(logRho,[],1);
    logsumexp = mx + log(sum(exp( bsxfun(@minus,logRho,mx) ),1)); % 10.49
    logr = bsxfun(@minus,logRho,logsumexp); % 10.49
    r = exp( logr ); % 10.49
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%% plots
    if mod(iter,10)==0 && ploton
        figure(3); clf; hold on; set(gcf,'color','w');
        [~,z] = max(r,[],1); % most likely cluster
        
        % first pass: covariance ellipses (drawn first so they sit behind the data)
        for k=1:K
            if sum(r(k,:),2)<1
                continue % do not display clusters with no responsibility
            end
            
            % covariance
            scale = 3; % 3 stds
            [v, d] = eig( sigma(1:2,1:2,k) );
            [d, order] = sort(diag(d), 'descend');
            d = diag(d);
            v = v(:, order);
            t = linspace(0,2*pi,100);
            c = [cos(t) ; sin(t)];        % unit circle
            VV = scale*v*sqrt(d);         % scale eigenvectors
            c = bsxfun(@plus, VV*c, m(1:2,k)); % project circle back to orig space
            fill(c(1,:),c(2,:),'k','EdgeColor','k','FaceAlpha',0.8*sqrt(w(k)))
        end
        
        % second pass: data points and means
        for k=1:K
            if sum(r(k,:),2)<1
                continue % do not display clusters with no responsibility
            end
            ii = (z==k);
            % data
            plot(x(1,ii),x(2,ii),'.')
            % mean
            plot(m(1,k),m(2,k),'xk')
        end
        
        hold off
        title('Variational Inference','interpreter','latex')
        xlabel('dimension 1','interpreter','latex')
        ylabel('dimension 2','interpreter','latex')
        set(gca,'fontsize',8,'TickLabelInterpreter', 'latex')
        drawnow;
    end
    
end

if ~converged
    display(sprintf('warning: did not converge after %d iterations',iter))
end

if ploton
    figure(4);
    plot(L/N)
    xlabel('iteration')
    ylabel('evidence lower bound')
    title('VI')
end

L = L(2:end)/N; % drop the -inf placeholder and normalise per observation
gmm.mu = mu;
gmm.sigma = sigma;
gmm.w = w(:);