function [gmm, llh] = GMM_EM(x,K,ploton)
% Copyright 2017: Steven Van Kuyk
% This program comes WITHOUT ANY WARRANTY.
%
% Train a Gaussian mixture model using the EM algorithm
% See 'Pattern Recognition and Machine Learning', Bishop
%
% Inputs
% x: DxN training data where each column is an observation
% K: number of mixtures
% ploton: enable/disable plotting (only dimensions 1 and 2 are drawn;
%         plotting is skipped when the data has fewer than 2 dimensions)
%
% Outputs:
% gmm: model structure
% gmm.mu: DxK matrix where each column represents a mean
% gmm.sigma: DxDxK matrix containing covariance matrices
% gmm.w: Kx1 vector containing the weights of each mixture
% llh: log-likelihood of the model divided by N, one entry per EM
%      iteration starting from the second iteration
display('training GMM...')

% initialize
maxIter = 2e3;
tol = 1e-10; % convergence tolerance (relative change in log-likelihood)
reg = 1e-4;  % regularization to prevent singularities
useKmeansPP = true; % k-means++ seeding of the centroids
useKmeans   = true; % refine the seeds with k-means before EM
[D,N] = size(x);
doPlot = ploton && D>=2; % the plots index rows 1 and 2 of the data

% fallback initialisation, used when the switches above are disabled
mu = x(:,randi(N,1,K));
sigma = repmat(eye(D),1,1,K);
w = ones(K,1)/K;

% k-means++ initialization (spreads out initial centroids)
if useKmeansPP
    mu(:,1) = x(:,randi(N));
    dmin = inf(1,N); % squared distance of every point to its NEAREST chosen centroid
    for k=2:K
        dk = sum( abs(bsxfun(@minus,x,mu(:,k-1))).^2, 1 ); % explicit dim: also correct for D==1
        dmin = min(dmin,dk);
        c = cumsum(dmin);
        if c(end)>0
            % inverse-CDF sampling with probability proportional to dmin;
            % points that coincide with a centroid have zero mass and are never drawn
            i = find(c >= rand*c(end), 1, 'first');
        else
            i = randi(N); % every point coincides with a chosen centroid
        end
        mu(:,k) = x(:,i);
    end
end

% apply k-means (faster + more robust initialization)
if useKmeans
    [mu, sigma, w] = Kmeans(x,mu,N,K,D,doPlot);
end

% apply EM
llh = -inf;          % placeholder so llh(end-1) exists on the first comparison
converged = false;
logNum = zeros(K,N); % log of the numerator in (9.23): log( w_k * N(x | mu_k, sigma_k) )
for iter=1:maxIter
    % expectation (carried out in the log domain so that points far from
    % every component do not underflow to 0/0)
    for k=1:K
        R = chol(sigma(:,:,k));           % sigma = R'*R, R upper triangular
        y = R'\bsxfun(@minus,x,mu(:,k));  % whitened residuals: column norms are Mahalanobis distances
        logNum(k,:) = log(w(k)) - 0.5*D*log(2*pi) ...
            - sum(log(diag(R))) ...       % = 0.5*log(det(sigma))
            - 0.5*sum(y.^2,1);
    end
    m = max(logNum,[],1);
    m(~isfinite(m)) = 0; % avoid (-inf)-(-inf) when every component assigns zero density
    logSum = m + log(sum(exp(bsxfun(@minus,logNum,m)),1)); % log-sum-exp over components
    Gamma = exp(bsxfun(@minus,logNum,logSum)); % responsibilities p(z|x) (9.23)
    
    % log likelihood
    if iter>1
        llh(end+1) = sum(logSum); % log-likelihood (9.28)
        if (abs(llh(end)-llh(end-1))<tol*abs(llh(end)))
            display(sprintf('converged after %d iterations',iter))
            converged = true;
            break
        end
    end
    
    % maximization
    Nk = max(sum(Gamma,2),eps); % (9.27), floored so empty components do not divide by zero
    
    mu = bsxfun(@rdivide,Gamma*x',Nk)'; % update mean (9.24), DxK
    for k=1:K % update covariance (9.25)
        x0 = bsxfun(@minus,x,mu(:,k));
        S = (bsxfun(@times,Gamma(k,:),x0)*x0')/Nk(k);
        sigma(:,:,k) = (S+S')/2 + reg*eye(D); % symmetrise away round-off, then regularise
    end
    w = Nk'/N; % update weights (9.26)
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%% plots
    if doPlot && mod(iter,10)==0
        figure(1); clf; hold on; set(gcf,'color','w');
        [~,z] = max(Gamma,[],1); % most likely cluster
        
        scale = 3; % 3 stds
        t = linspace(0,2*pi,100);
        circle = [cos(t) ; sin(t)]; % unit circle
        for k=1:K
            if Nk(k)<1e-10
                continue % do not display clusters with no responsibility
            end
            
            % covariance ellipse of dimensions 1 and 2
            [v, d] = eig( sigma(1:2,1:2,k) );
            [d, order] = sort(diag(d), 'descend');
            d = diag(d);
            v = v(:, order);
            VV = scale*v*sqrt(d);         % scale eigenvectors
            c = bsxfun(@plus, VV*circle, mu(1:2,k)); % project circle back to orig space
            fill(c(1,:),c(2,:),'k','EdgeColor','k','FaceAlpha',0.8*sqrt(w(k)))
        end
        
        for k=1:K
            ii = (z==k);
            % data
            plot(x(1,ii),x(2,ii),'.')
            % mean
            plot(mu(1,k),mu(2,k),'xk')
        end
        
        hold off
        title('Expectation Maximization','interpreter','latex')
        xlabel('dimension 1','interpreter','latex')
        ylabel('dimension 2','interpreter','latex')
        set(gca,'fontsize',8,'TickLabelInterpreter', 'latex')
        drawnow;
    end
    
end

if ploton
    figure(2)
    plot(llh/N)
    xlabel('iteration')
    ylabel('log-likelihood')
    title('EM')
end

% a flag is used instead of re-testing the tolerance so that a NaN
% log-likelihood is also reported as a failure to converge
if ~converged
    display('warning: did not converge')
    display(iter) % echo the number of iterations that were run
end

llh = llh(2:end)/N; % drop the -inf placeholder and normalise per observation
gmm.mu = mu;
gmm.sigma = sigma;
gmm.w = w(:);

%%
function [mu, sigma, w] = Kmeans(x,mu,N,K,D,ploton)
% Hard-assignment k-means, used to initialise the mixture parameters.
%
% Inputs
% x: DxN data, one observation per column
% mu: DxK initial centroids
% N, K, D: number of observations, clusters and dimensions
% ploton: enable/disable plotting (dimensions 1 and 2 only; skipped when D<2)
%
% Outputs
% mu: DxK centroids (zeros for clusters that finished empty)
% sigma: DxDxK within-cluster covariances (identity for empty clusters)
% w: 1xK fraction of the observations assigned to each cluster
doPlot = ploton && D>=2; % the plot indexes rows 1 and 2 of the data
labels_old = zeros(1,N);
labels_new = ones(1,N);  % differs from labels_old so the loop runs at least once
dx2 = sum(x.^2,1); % squared magnitude of x (explicit dim: also correct for D==1)

while any(labels_old~=labels_new)
    if doPlot
        figure(1)
        clf; hold on
    end
    labels_old = labels_new;
    % squared distance between every centroid (rows) and observation (columns):
    % |mu|^2 + |x|^2 - 2*mu'*x
    d2 = bsxfun(@plus,sum(mu.^2,1)',dx2)-2*(mu'*x);
    [~,labels_new] = min(d2,[],1); % explicit dim: also correct for K==1; NaN rows are ignored
    for k=1:K
        ii = k==labels_new;
        if any(ii)
            mu(:,k) = mean(x(:,ii),2);
        else
            % retire an empty cluster: a NaN centroid is never the minimum
            % above, so the cluster stays empty for the rest of the run
            mu(:,k) = NaN;
        end
        
        if doPlot
            plot(x(1,ii),x(2,ii),'.',mu(1,k),mu(2,k),'kx')
        end
    end
    
    if doPlot
        title('Expectation Maximization','interpreter','latex')
        xlabel('dimension 1','interpreter','latex')
        ylabel('dimension 2','interpreter','latex')
        set(gca,'fontsize',8,'TickLabelInterpreter', 'latex')
        pause(0.05)
        hold off
    end
end

% compute initial weights and covariance based on k-means
sigma = zeros(D,D,K);
w = zeros(1,K);
for k=1:K
    ii = (k==labels_new);
    nk = sum(ii);
    if nk==0
        % empty cluster: zero weight with a neutral mean and covariance
        w(k) = 0;
        mu(:,k) = zeros(D,1);
        sigma(:,:,k) = eye(D);
        continue
    end
    x0 = bsxfun(@minus,x(:,ii),mu(:,k));
    sigma(:,:,k) = (1/nk)*(x0*x0') + 1e-10*eye(D); % covariance, with a tiny ridge
    w(k) = nk/N; % weights
end