GMM的EM算法实现

GMM的EM算法实现

大家好,又见面了,我是全栈君,祝每个程序员都可以多学几门语言。

聚类算法K-Means, K-Medoids, GMM, Spectral clustering,Ncut一文中我们给出了GMM算法的基本模型与似然函数,在EM算法原理中对EM算法的实现与收敛性证明进行了具体说明。本文主要针对怎样用EM算法在混合高斯模型下进行聚类进行代码上的分析说明。


1. GMM模型:

每一个 GMM 由 K 个 Gaussian 分布组成,每一个 Gaussian 称为一个“Component”,这些 Component 线性加成在一起就组成了 GMM 的概率密度函数:

<span>GMM的EM算法实现</span>

依据上面的式子,如果我们要从 GMM 的分布中随机地取一个点的话,实际上能够分为两步:首先随机地在这 K个Gaussian Component 之中选一个,每一个 Component 被选中的概率实际上就是它的系数 pi(k) ,选中了 Component 之后,再单独地考虑从这个 Component 的分布中选取一个点就能够了──这里已经回到了普通的 Gaussian 分布,转化为了已知的问题。

那么怎样用 GMM 来做 clustering 呢?事实上非常easy,如今我们有了数据,假定它们是由 GMM 生成出来的,那么我们仅仅要依据数据推出 GMM 的概率分布来就能够了,然后 GMM 的 K 个 Component 实际上就相应了 K 个 cluster 了。依据数据来推算概率密度通常被称作 density estimation ,特别地,当我们在已知(或假定)了概率密度函数的形式,而要预计当中的參数的过程被称作“參数预计”。


2. 參数与似然函数:

如今如果我们有 N 个数据点,并如果它们服从某个分布(记作 p(x) ),如今要确定里面的一些參数的值,比如,在 GMM 中,我们就须要确定 影响因子pi(k)、各类均值pMiu(k) 和 各类协方差pSigma(k) 这些參数。 我们的想法是,找到这样一组參数,它所确定的概率分布生成这些给定的数据点的概率最大,而这个概率实际上就等于 <span>GMM的EM算法实现</span> ,我们把这个乘积称作似然函数 (Likelihood Function)。通常单个点的概率都非常小,很多非常小的数字相乘起来在计算机里非常容易造成浮点数下溢,因此我们一般会对其取对数,把乘积变成加和 \sum_{i=1}^N \log p(x_i),得到 log-likelihood function 。接下来我们仅仅要将这个函数最大化(通常的做法是求导并令导数等于零,然后解方程),亦即找到这样一组參数值,它让似然函数取得最大值,我们就觉得这是最合适的參数,这样就完毕了參数预计的过程。

以下让我们来看一看 GMM 的 log-likelihood function :

<span>GMM的EM算法实现</span>

因为在对数函数里面又有加和,我们没法直接用求导解方程的办法直接求得最大值。为了解决问题,我们採取之前从 GMM 中随机选点的办法:分成两步,实际上也就相似于K-means 的两步。



3. 算法流程:

1.  预计数据由每一个 Component 生成的概率(并非每一个 Component 被选中的概率):对于每一个数据 x_i 来说,它由第 k 个 Component 生成的概率为

<span>GMM的EM算法实现</span>


当中N(xi | μk,Σk)就是后验概率<span>GMM的EM算法实现</span>


2. 通过极大似然预计能够通过求到令參数=0得到參数pMiu,pSigma的值。具体请见这篇文章第三部分。

<span>GMM的EM算法实现</span>

当中 N_k = \sum_{i=1}^N \gamma(i, k) ,而且 \pi_k 也顺理成章地能够预计为 N_k/N 。


3. 反复迭代前面两步,直到似然函数的值收敛为止。



4. matlab实现GMM聚类代码与解释:


说明:fea为训练样本数据,gnd为样本标号。算法中的思想和上面写的一模一样,在最后的推断accuracy方面,因为聚类和分类不同,仅仅是得到一些 cluster ,而并不知道这些 cluster 应该被打上什么标签,或者说。因为我们的目的是衡量聚类算法的 performance ,因此直接假定这一步能实现最优的相应关系,将每一个 cluster 相应到一类上去。一种办法是枚举全部可能的情况并选出最优解,另外,对于这种问题,我们还能够用 Hungarian algorithm 来求解。具体的Hungarian代码我放在了资源里,调用方法已经写在以下函数中了。


注意:资源里我放的是Kmeans的代码,大家下载的时候仅仅要用bestMap.m等几个文件就好~


1. gmm.m,最核心的函数,进行模型与參数确定。

function varargout = gmm(X, K_or_centroids)
% ============================================================
% Expectation-Maximization iteration implementation of
% Gaussian Mixture Model.
%
% PX = GMM(X, K_OR_CENTROIDS)
% [PX MODEL] = GMM(X, K_OR_CENTROIDS)
%
%  - X: N-by-D data matrix.
%  - K_OR_CENTROIDS: either K indicating the number of
%       components or a K-by-D matrix indicating the
%       choosing of the initial K centroids.
%
%  - PX: N-by-K matrix indicating the probability of each
%       component generating each point.
%  - MODEL: a structure containing the parameters for a GMM:
%       MODEL.Miu: a K-by-D matrix.
%       MODEL.Sigma: a D-by-D-by-K matrix.
%       MODEL.Pi: a 1-by-K vector.
% ============================================================
% @SourceCode Author: Pluskid (http://blog.pluskid.org)
% @Appended by : Sophia_qing (http://blog.csdn.net/abcjennifer)
    

%% Generate Initial Centroids
    threshold = 1e-15;
    [N, D] = size(X);
 
    if isscalar(K_or_centroids) %if K_or_centroid is a 1*1 number
        K = K_or_centroids;
        Rn_index = randperm(N); %random index N samples
        centroids = X(Rn_index(1:K), :); %generate K random centroid
    else % K_or_centroid is a initial K centroid
        K = size(K_or_centroids, 1); 
        centroids = K_or_centroids;
    end
 
    %% initial values
    [pMiu pPi pSigma] = init_params();
 
    Lprev = -inf; %上一次聚类的误差
    
    %% EM Algorithm
    while true
        %% Estimation Step
        Px = calc_prob();
 
        % new value for pGamma(N*k), pGamma(i,k) = Xi由第k个Gaussian生成的概率
        % 或者说xi中有pGamma(i,k)是由第k个Gaussian生成的
        pGamma = Px .* repmat(pPi, N, 1); %分子 = pi(k) * N(xi | pMiu(k), pSigma(k))
        pGamma = pGamma ./ repmat(sum(pGamma, 2), 1, K); %分母 = pi(j) * N(xi | pMiu(j), pSigma(j))对全部j求和
 
        %% Maximization Step - through Maximize likelihood Estimation
        
        Nk = sum(pGamma, 1); %Nk(1*k) = 第k个高斯生成每一个样本的概率的和,全部Nk的总和为N。
        
        % update pMiu
        pMiu = diag(1./Nk) * pGamma' * X; %update pMiu through MLE(通过令导数 = 0得到)
        pPi = Nk/N;
        
        % update k个 pSigma
        for kk = 1:K 
            Xshift = X-repmat(pMiu(kk, :), N, 1);
            pSigma(:, :, kk) = (Xshift' * ...
                (diag(pGamma(:, kk)) * Xshift)) / Nk(kk);
        end
 
        % check for convergence
        L = sum(log(Px*pPi'));
        if L-Lprev < threshold
            break;
        end
        Lprev = L;
    end
 
    if nargout == 1
        varargout = {Px};
    else
        model = [];
        model.Miu = pMiu;
        model.Sigma = pSigma;
        model.Pi = pPi;
        varargout = {Px, model};
    end
 
    %% Function Definition
    
    function [pMiu pPi pSigma] = init_params()
        pMiu = centroids; %k*D, 即k类的中心点
        pPi = zeros(1, K); %k类GMM所占权重(influence factor)
        pSigma = zeros(D, D, K); %k类GMM的协方差矩阵,每一个是D*D的
 
        % 距离矩阵,计算N*K的矩阵(x-pMiu)^2 = x^2+pMiu^2-2*x*Miu
        distmat = repmat(sum(X.*X, 2), 1, K) + ... %x^2, N*1的矩阵replicateK列
            repmat(sum(pMiu.*pMiu, 2)', N, 1) - ...%pMiu^2,1*K的矩阵replicateN行
            2*X*pMiu';
        [~, labels] = min(distmat, [], 2);%Return the minimum from each row
 
        for k=1:K
            Xk = X(labels == k, :);
            pPi(k) = size(Xk, 1)/N;
            pSigma(:, :, k) = cov(Xk);
        end
    end
 
    function Px = calc_prob() 
        %Gaussian posterior probability 
        %N(x|pMiu,pSigma) = 1/((2pi)^(D/2))*(1/(abs(sigma))^0.5)*exp(-1/2*(x-pMiu)'pSigma^(-1)*(x-pMiu))
        Px = zeros(N, K);
        for k = 1:K
            Xshift = X-repmat(pMiu(k, :), N, 1); %X-pMiu
            inv_pSigma = inv(pSigma(:, :, k));
            tmp = sum((Xshift*inv_pSigma) .* Xshift, 2);
            coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));
            Px(:, k) = coef * exp(-0.5*tmp);
        end
    end
end

2. gmm_accuracy.m调用gmm.m,计算准确率:

function [ Accuracy ] = gmm_accuracy( Data_fea, gnd_label, K )
%Calculate the accuracy Clustered by GMM model

px = gmm(Data_fea,K);
[~, cls_ind] = max(px,[],1); %cls_ind = cluster label 
Accuracy = cal_accuracy(cls_ind, gnd_label);

    function [acc] = cal_accuracy(gnd,estimate_label)
        res = bestMap(gnd,estimate_label);
        acc = length(find(gnd == res))/length(gnd);
    end

end

3. 主函数调用

gmm_acc = gmm_accuracy(fea,gnd,N_classes);








写了本文进行总结后自己非常受益,也希望大家能够好好YM下上面pluskid的gmm.m,不光是算法,当中的矩阵处理代码也写的非常简洁,非常值得学习。

另外看了两份东西非常受益,一个是pluskid大牛的漫谈 Clustering (3): Gaussian Mixture Model》,一个是JerryLead的EM算法具体解释,大家有兴趣也能够看一下,写的非常好。



关于Machine Learning很多其它的学习资料与相关讨论将继续更新,敬请关注本博客和新浪微博Sophia_qing




版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/118973.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • sqlserver中类型decimal

    sqlserver中类型decimaldecimal(18,0)18是定点精度,0是小数位数。decimal(a,b)a指定指定小数点左边和右边可以存储的十进制数字的最大个数,最大精度38。b指定小数点右边可以存储的十进制数字的最大个数。小数位数必须是从0到a之间的值。默认小数位数是0。

    2022年7月20日
    73
  • 卷积神经网络(3)—-经典网络

    卷积神经网络(3)—-经典网络

    2021年6月17日
    150
  • kafka删除topic数据[通俗易懂]

    kafka删除topic数据[通俗易懂]kafka删除topic数据一、概述生产环境中,有一个topic的数据量非常大。这些数据不是非常重要,需要定期清理。要求:默认保持24小时,某些topic需要保留2小时或者6小时二、清除方式主要有3个:1.基于时间2.基于日志大小3.基于日志起始偏移量详情,请参考链接:https://blog.csdn.net/u013256816/article/details/80418297接下来,主要介绍基于时间的清除!kafka版本为:2.11-..

    2022年10月17日
    4
  • pyd文件介绍

    pyd文件介绍pyd 一般是 python 外的其他语言如 C C 编写的 python 扩展模块 即 python 的一个动态链接库 与 dll 文件相当 在 linux 系统中一般为 so 文件 也有的时候 为了对 python 文件进行加密 会把 python 模块编译成 pyd 文件 供其他人使用 拿到一个 pyd 文件 在没有文档说明的情况下 可以试试查看模块内的一些函数和类的用法 首先 importXXX pyd 的文件名 然后直接 print dir XXX print help XXX 其中 dir 列出了属性和方法 help

    2025年8月12日
    2
  • java全局变量和静态变量的区别_java局部静态变量

    java全局变量和静态变量的区别_java局部静态变量全局变量和局部变量的简介:1、全局变量:描述对象有什么(在类中定义),类中所有方法都可以使用。2、局部变量:临时保存数据(在类的方法中定义),只能在当前方法中使用。全局变量和局部变量的区别:1、作用域不同全局变量=整个类中可用局部变量=该方法内可用2、初始值不同全局变量=有初始值局部变量=无初始值3、声明上不同全局变量=在一个类中不允许声明同名变量局部变量=在一个方…

    2022年8月21日
    4
  • redis常见命令使用

    redis常见命令使用

    2021年10月10日
    46

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号