Machine Learning Mathematical Derivations (13): EM Algorithm and GMM
Chen Kai

The EM (Expectation-Maximization) algorithm is a general framework for handling latent variable models — when data contains unobserved latent variables, direct likelihood maximization becomes difficult. EM iterates between "expectation" and "maximization" steps, guaranteeing monotonic likelihood increase until convergence. From parameter estimation in Gaussian mixture models to image segmentation and speech recognition, the EM algorithm demonstrates both theoretical elegance and practical value. This chapter systematically derives the mathematical principles, convergence theory, Gaussian mixture models, and their variants.


Latent Variables and Incomplete Data

Problem Background

Complete data: Observed data $\mathbf{X} = \{\mathbf{x}_1, \dots, \mathbf{x}_N\}$ together with latent variables $\mathbf{Z} = \{\mathbf{z}_1, \dots, \mathbf{z}_N\}$

Incomplete data: Only $\mathbf{X}$ is observed; the latent variables $\mathbf{Z}$ are unknown

Objective: Maximize the incomplete data log-likelihood

$$\ell(\boldsymbol{\theta}) = \sum_{i=1}^{N} \log p(\mathbf{x}_i \mid \boldsymbol{\theta}) = \sum_{i=1}^{N} \log \sum_{\mathbf{z}_i} p(\mathbf{x}_i, \mathbf{z}_i \mid \boldsymbol{\theta})$$

Difficulty: The sum over $\mathbf{z}_i$ sits inside the logarithm, making direct optimization challenging.

Example: Gaussian Mixture Model - Observed: data points $\mathbf{x}_i$ - Latent variables: component membership $z_i \in \{1, \dots, K\}$ - Likelihood: $p(\mathbf{x}_i \mid \boldsymbol{\theta}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$. The summation inside the logarithm makes direct differentiation difficult.
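To make the "log of a sum" concrete, the sketch below (with made-up parameters for a 1-D two-component mixture) evaluates the mixture log-likelihood; `logsumexp` handles the sum inside the logarithm in a numerically stable way:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Hypothetical 1-D two-component GMM: pi = (0.4, 0.6), means (-2, 3), std 1.
pi = np.array([0.4, 0.6])
mu = np.array([-2.0, 3.0])
x = np.array([-1.5, 0.0, 2.8])  # three observations

# log p(x_i) = log sum_k pi_k N(x_i | mu_k, 1): the sum sits inside the log,
# so we evaluate it via logsumexp over per-component log-densities.
log_terms = np.log(pi) + norm.logpdf(x[:, None], loc=mu, scale=1.0)  # (N, K)
log_lik = logsumexp(log_terms, axis=1).sum()
```

Because the sum cannot be pulled out of the logarithm, differentiating `log_lik` with respect to the parameters does not decouple across components, which is exactly the difficulty EM addresses.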

Jensen's Inequality and Log-Likelihood Lower Bound

Jensen's Inequality: For a concave function $f$ and random variable $X$:

$$f(\mathbb{E}[X]) \ge \mathbb{E}[f(X)]$$

For $f = \log$ (a concave function):

$$\log \mathbb{E}[X] \ge \mathbb{E}[\log X]$$
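A quick numerical sanity check of the inequality for $f = \log$ (illustrative only, with an arbitrary positive random variable):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(0.5, 2.0, size=10_000)  # samples of a positive random variable

lhs = np.log(X.mean())   # log E[X]
rhs = np.log(X).mean()   # E[log X]
assert lhs >= rhs        # Jensen's inequality: log is concave
```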

Application to log-likelihood: Introduce an arbitrary distribution $q(\mathbf{z})$ over the latent variables:

$$\log p(\mathbf{x} \mid \boldsymbol{\theta}) = \log \sum_{\mathbf{z}} q(\mathbf{z}) \frac{p(\mathbf{x}, \mathbf{z} \mid \boldsymbol{\theta})}{q(\mathbf{z})} \ge \sum_{\mathbf{z}} q(\mathbf{z}) \log \frac{p(\mathbf{x}, \mathbf{z} \mid \boldsymbol{\theta})}{q(\mathbf{z})} \equiv \mathcal{L}(q, \boldsymbol{\theta})$$

ELBO (Evidence Lower Bound): $\mathcal{L}(q, \boldsymbol{\theta})$ is a lower bound on the log-likelihood.

KL Divergence and Tight Bound Condition

Rewriting the ELBO:

$$\log p(\mathbf{x} \mid \boldsymbol{\theta}) = \mathcal{L}(q, \boldsymbol{\theta}) + \mathrm{KL}\big(q(\mathbf{z}) \,\|\, p(\mathbf{z} \mid \mathbf{x}, \boldsymbol{\theta})\big)$$

where the KL divergence is:

$$\mathrm{KL}\big(q \,\|\, p\big) = \sum_{\mathbf{z}} q(\mathbf{z}) \log \frac{q(\mathbf{z})}{p(\mathbf{z} \mid \mathbf{x}, \boldsymbol{\theta})} \ge 0$$

Tight bound condition: When $q(\mathbf{z}) = p(\mathbf{z} \mid \mathbf{x}, \boldsymbol{\theta})$, the KL divergence vanishes and the lower bound equals the log-likelihood.
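The decomposition above can be verified numerically. The sketch below uses a hypothetical toy model with a discrete latent $z \in \{0, 1, 2\}$ and made-up joint probabilities $p(x, z)$ for one fixed observation $x$; for any choice of $q$, ELBO + KL recovers $\log p(x)$ exactly:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical joint p(x, z) for one fixed x, as a vector over z = 0, 1, 2.
joint = np.array([0.10, 0.25, 0.15])
log_px = np.log(joint.sum())        # log p(x)
posterior = joint / joint.sum()     # p(z | x)

q = rng.dirichlet(np.ones(3))       # an arbitrary distribution over z

elbo = np.sum(q * (np.log(joint) - np.log(q)))
kl = np.sum(q * (np.log(q) - np.log(posterior)))

# log p(x) = ELBO + KL holds for any q; the bound is tight iff q equals the posterior.
assert np.isclose(elbo + kl, log_px)
```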

EM Algorithm Derivation

Algorithm Framework

Initialization: Parameters $\boldsymbol{\theta}^{(0)}$

Iteration (for $t = 0, 1, 2, \dots$):

E-step (Expectation): Fix $\boldsymbol{\theta}^{(t)}$, choose $q(\mathbf{z}) = p(\mathbf{z} \mid \mathbf{x}, \boldsymbol{\theta}^{(t)})$ to make the bound tight.

Compute the Q-function:

$$Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)}) = \mathbb{E}_{\mathbf{Z} \sim p(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta}^{(t)})}\big[\log p(\mathbf{X}, \mathbf{Z} \mid \boldsymbol{\theta})\big]$$

M-step (Maximization): Fix $q$, optimize the parameters:

$$\boldsymbol{\theta}^{(t+1)} = \arg\max_{\boldsymbol{\theta}} Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)})$$

Termination: Convergence (change in log-likelihood or parameters below threshold)

Meaning of the Q-function

Interpretation: The Q-function is the expected complete data log-likelihood under the posterior distribution of latent variables.

Convergence Proof

Theorem: The EM algorithm guarantees monotonic increase of the log-likelihood:

$$\ell(\boldsymbol{\theta}^{(t+1)}) \ge \ell(\boldsymbol{\theta}^{(t)})$$

Proof: Write $\mathcal{L}(q, \boldsymbol{\theta}) = Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)}) + H(q)$, where $H(q) = -\sum_{\mathbf{z}} q(\mathbf{z}) \log q(\mathbf{z})$ is the entropy of $q$ (a constant in $\boldsymbol{\theta}$). Then

$$\ell(\boldsymbol{\theta}^{(t+1)}) \ge \mathcal{L}(q^{(t)}, \boldsymbol{\theta}^{(t+1)}) \ge \mathcal{L}(q^{(t)}, \boldsymbol{\theta}^{(t)}) = \ell(\boldsymbol{\theta}^{(t)})$$

The first inequality follows from non-negativity of the KL divergence, the second from the definition of the M-step, and the final equality holds because the E-step makes the bound tight at $\boldsymbol{\theta}^{(t)}$.

Corollary: The EM algorithm converges to a (local) maximum or saddle point of the log-likelihood.

Gaussian Mixture Model (GMM)

Model Definition

Generative process:

  1. Choose a component $z_i \sim \mathrm{Categorical}(\boldsymbol{\pi})$, where $\pi_k = p(z_i = k)$ is the prior probability of component $k$
  2. Generate the observation $\mathbf{x}_i \mid z_i = k \sim \mathcal{N}(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$

Joint distribution:

$$p(\mathbf{x}_i, z_i = k \mid \boldsymbol{\theta}) = \pi_k \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$

Marginal distribution (observed data likelihood):

$$p(\mathbf{x}_i \mid \boldsymbol{\theta}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$

Parameters: $\boldsymbol{\theta} = \{\boldsymbol{\pi}, \{\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\}_{k=1}^K\}$. Constraints: $\pi_k \ge 0$, $\sum_{k=1}^{K} \pi_k = 1$.

EM Algorithm for GMM

E-step: Compute responsibilities:

$$\gamma_{ik} = p(z_i = k \mid \mathbf{x}_i, \boldsymbol{\theta}) = \frac{\pi_k \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)}$$

Physical meaning: Posterior probability that sample $\mathbf{x}_i$ belongs to component $k$ (soft assignment).

M-step: Update the parameters.

Q-function:

$$Q(\boldsymbol{\theta}, \boldsymbol{\theta}^{(t)}) = \sum_{i=1}^{N} \sum_{k=1}^{K} \gamma_{ik} \big[\log \pi_k + \log \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)\big]$$

Update $\pi_k$ (using Lagrange multipliers for the constraint $\sum_k \pi_k = 1$):

$$\pi_k = \frac{N_k}{N}, \qquad N_k = \sum_{i=1}^{N} \gamma_{ik}$$

where $N_k$ is the effective sample count for component $k$.

Update $\boldsymbol{\mu}_k$: Differentiate $Q$ with respect to $\boldsymbol{\mu}_k$:

$$\boldsymbol{\mu}_k = \frac{1}{N_k} \sum_{i=1}^{N} \gamma_{ik} \, \mathbf{x}_i$$

Weighted average: Samples $\mathbf{x}_i$ weighted by their responsibilities $\gamma_{ik}$.

Update $\boldsymbol{\Sigma}_k$:

$$\boldsymbol{\Sigma}_k = \frac{1}{N_k} \sum_{i=1}^{N} \gamma_{ik} (\mathbf{x}_i - \boldsymbol{\mu}_k)(\mathbf{x}_i - \boldsymbol{\mu}_k)^{\top}$$

Weighted covariance.
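The E-step and M-step updates above can be sketched as a single vectorized iteration (a minimal NumPy sketch, not a full trainer; synthetic data is used only to exercise the function):

```python
import numpy as np
from scipy.stats import multivariate_normal

def em_iteration(X, pi, mu, Sigma):
    """One EM iteration for a GMM, following the update equations above.
    X: (N, d) data; pi: (K,); mu: (K, d); Sigma: (K, d, d)."""
    N, K = X.shape[0], pi.shape[0]

    # E-step: responsibilities gamma_ik, normalized across components.
    dens = np.stack([multivariate_normal.pdf(X, mu[k], Sigma[k]) for k in range(K)], axis=1)
    gamma = pi * dens
    gamma /= gamma.sum(axis=1, keepdims=True)

    # M-step: effective counts N_k, then weights, means, covariances.
    N_k = gamma.sum(axis=0)
    pi_new = N_k / N                        # pi_k = N_k / N
    mu_new = (gamma.T @ X) / N_k[:, None]   # responsibility-weighted means
    Sigma_new = np.stack([
        ((gamma[:, k, None] * (X - mu_new[k])).T @ (X - mu_new[k])) / N_k[k]
        for k in range(K)
    ])
    return pi_new, mu_new, Sigma_new, gamma

# Exercise the iteration on synthetic two-cluster data.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-3.0, 1.0, (50, 2)), rng.normal(3.0, 1.0, (50, 2))])
pi0 = np.array([0.5, 0.5])
mu0 = np.array([[-1.0, 0.0], [1.0, 0.0]])
Sigma0 = np.stack([np.eye(2), np.eye(2)])

pi1, mu1, Sigma1, gamma = em_iteration(X, pi0, mu0, Sigma0)
```

Iterating this function until the log-likelihood stabilizes is exactly the full EM loop implemented later in this chapter.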

Geometric Interpretation of GMM

K-means vs GMM: - K-means: Hard assignment ($z_i \in \{1, \dots, K\}$), spherical clusters - GMM: Soft assignment ($\gamma_{ik} \in [0, 1]$), ellipsoidal clusters, probabilistic output

Advantages of GMM: - Allows overlapping clusters - Different shapes for different clusters (covariance) - Provides uncertainty estimates

Initialization Strategies

The EM algorithm is sensitive to initialization. Common strategies:

  1. K-means initialization: Use K-means results to initialize $\boldsymbol{\mu}_k$, $\boldsymbol{\Sigma}_k$, $\pi_k$

  2. Random initialization: Randomly select $K$ points from the data as $\boldsymbol{\mu}_k$

  3. Multiple restarts: Run EM multiple times, select result with highest log-likelihood
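As an illustration, scikit-learn's `GaussianMixture` combines strategies 1 and 3: `init_params='kmeans'` seeds the means from K-means, and `n_init` runs EM from several initializations and keeps the fit with the best lower bound (a brief sketch, not this chapter's implementation):

```python
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=0.6, random_state=42)

# n_init=5: five EM restarts, best result kept; init_params='kmeans':
# initial means come from a K-means run.
gm = GaussianMixture(n_components=3, n_init=5, init_params='kmeans',
                     random_state=0).fit(X)
```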

GMM Applications

Density Estimation

Objective: Learn the data distribution $p(\mathbf{x})$. GMM provides a flexible density model: any continuous distribution can be approximated by sufficiently many Gaussian components (universal approximation).

Applications: Anomaly detection, sample generation, probabilistic prediction.

Cluster Analysis

Hard clustering: $\hat{z}_i = \arg\max_k \gamma_{ik}$. Soft clustering: Retain $\gamma_{ik}$ as membership degrees.

BIC model selection: Choose the number of components $K$:

$$\mathrm{BIC}(K) = -2 \, \ell(\hat{\boldsymbol{\theta}}_K) + M_K \log N$$

where $M_K$ is the number of free parameters: $M_K = (K - 1) + Kd + K \frac{d(d+1)}{2}$.
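A BIC sweep over candidate $K$ can be sketched with scikit-learn, whose `bic()` method uses the same lower-is-better convention (illustrative data only):

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

X, _ = make_blobs(n_samples=400, centers=3, cluster_std=0.5, random_state=42)

# Fit a GMM for each candidate K and record its BIC (lower is better).
ks = range(1, 6)
bics = [GaussianMixture(n_components=k, random_state=0).fit(X).bic(X)
        for k in ks]
best_k = ks[int(np.argmin(bics))]
```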

Image Segmentation

Problem: Cluster image pixels into $K$ regions.

Features: RGB color, position coordinates, texture features.

GMM model: Each component corresponds to a region.

Segmentation result: Assign pixel $i$ to $\arg\max_k \gamma_{ik}$.
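The pipeline can be sketched on a tiny synthetic "image" (a made-up two-region color array; a real application would use actual pixels and could append position or texture columns to the feature matrix):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)

# Hypothetical 32x32 image: two flat color regions plus small noise.
img = np.zeros((32, 32, 3))
img[:, :16] = [0.9, 0.2, 0.2]   # reddish left half
img[:, 16:] = [0.2, 0.2, 0.9]   # bluish right half
img += rng.normal(0, 0.05, img.shape)

# Features: one RGB row per pixel.
features = img.reshape(-1, 3)

# Each GMM component corresponds to one region; predict() gives the
# hard assignment argmax_k gamma_ik per pixel.
gm = GaussianMixture(n_components=2, random_state=0).fit(features)
labels = gm.predict(features).reshape(32, 32)
```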

Complete Implementation

import numpy as np
from scipy.stats import multivariate_normal

class GMM:
    def __init__(self, n_components=3, max_iter=100, tol=1e-4, init='kmeans'):
        self.n_components = n_components
        self.max_iter = max_iter
        self.tol = tol
        self.init = init

        self.weights = None      # pi_k
        self.means = None        # mu_k
        self.covariances = None  # Sigma_k
        self.converged = False

    def _initialize(self, X):
        N, d = X.shape
        K = self.n_components

        if self.init == 'kmeans':
            # K-means initialization
            from sklearn.cluster import KMeans
            kmeans = KMeans(n_clusters=K, random_state=0).fit(X)
            self.means = kmeans.cluster_centers_
            labels = kmeans.labels_

            self.weights = np.array([np.sum(labels == k) / N for k in range(K)])
            self.covariances = np.array([
                np.cov(X[labels == k].T) + 1e-6 * np.eye(d) for k in range(K)
            ])
        elif self.init == 'random':
            # Random initialization
            indices = np.random.choice(N, K, replace=False)
            self.means = X[indices]
            self.weights = np.ones(K) / K
            self.covariances = np.array([np.cov(X.T) + 1e-6 * np.eye(d) for _ in range(K)])

    def _e_step(self, X):
        """E-step: Compute responsibilities"""
        N = X.shape[0]
        K = self.n_components
        gamma = np.zeros((N, K))

        for k in range(K):
            gamma[:, k] = self.weights[k] * multivariate_normal.pdf(
                X, mean=self.means[k], cov=self.covariances[k]
            )

        # Normalize
        gamma_sum = gamma.sum(axis=1, keepdims=True)
        gamma /= gamma_sum + 1e-10

        return gamma

    def _m_step(self, X, gamma):
        """M-step: Update parameters"""
        N, d = X.shape
        K = self.n_components

        # Effective sample counts
        N_k = gamma.sum(axis=0)

        # Update weights
        self.weights = N_k / N

        # Update means
        self.means = (gamma.T @ X) / N_k[:, np.newaxis]

        # Update covariances
        for k in range(K):
            diff = X - self.means[k]
            self.covariances[k] = (gamma[:, k, np.newaxis] * diff).T @ diff / N_k[k]
            self.covariances[k] += 1e-6 * np.eye(d)  # Regularization

    def _compute_log_likelihood(self, X):
        """Compute log-likelihood"""
        N = X.shape[0]
        K = self.n_components

        log_likelihood = 0
        for i in range(N):
            prob = 0
            for k in range(K):
                prob += self.weights[k] * multivariate_normal.pdf(
                    X[i], mean=self.means[k], cov=self.covariances[k]
                )
            log_likelihood += np.log(prob + 1e-10)

        return log_likelihood

    def fit(self, X):
        """Train GMM"""
        self._initialize(X)

        log_likelihood_old = -np.inf

        for iteration in range(self.max_iter):
            # E-step
            gamma = self._e_step(X)

            # M-step
            self._m_step(X, gamma)

            # Check convergence
            log_likelihood = self._compute_log_likelihood(X)

            if abs(log_likelihood - log_likelihood_old) < self.tol:
                self.converged = True
                print(f"Converged at iteration {iteration + 1}")
                break

            log_likelihood_old = log_likelihood

        return self

    def predict(self, X):
        """Predict cluster labels (hard assignment)"""
        gamma = self._e_step(X)
        return np.argmax(gamma, axis=1)

    def predict_proba(self, X):
        """Predict probabilities (soft assignment)"""
        return self._e_step(X)

    def score(self, X):
        """Compute log-likelihood"""
        return self._compute_log_likelihood(X)

    def sample(self, n_samples=100):
        """Generate samples"""
        # Sample components
        components = np.random.choice(
            self.n_components, size=n_samples, p=self.weights
        )

        # Sample from each component
        samples = np.zeros((n_samples, self.means.shape[1]))
        for k in range(self.n_components):
            mask = components == k
            n_k = np.sum(mask)
            if n_k > 0:
                samples[mask] = np.random.multivariate_normal(
                    self.means[k], self.covariances[k], size=n_k
                )

        return samples, components

# Example usage
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_blobs

    # Generate data
    X, y_true = make_blobs(n_samples=500, centers=3, n_features=2,
                           cluster_std=0.6, random_state=42)

    # Train GMM
    gmm = GMM(n_components=3, max_iter=100)
    gmm.fit(X)

    # Predict
    y_pred = gmm.predict(X)
    gamma = gmm.predict_proba(X)

    # Visualization
    plt.figure(figsize=(15, 5))

    plt.subplot(131)
    plt.scatter(X[:, 0], X[:, 1], c=y_true, cmap='viridis')
    plt.title('True Labels')

    plt.subplot(132)
    plt.scatter(X[:, 0], X[:, 1], c=y_pred, cmap='viridis')
    plt.scatter(gmm.means[:, 0], gmm.means[:, 1],
                c='red', marker='x', s=200, linewidths=3)
    plt.title('GMM Clustering')

    plt.subplot(133)
    plt.scatter(X[:, 0], X[:, 1], c=gamma[:, 0], cmap='coolwarm')
    plt.title('Responsibility for Component 1')
    plt.colorbar()

    plt.tight_layout()
    plt.savefig('gmm_clustering.png', dpi=150)
    plt.show()

    print(f"Log-likelihood: {gmm.score(X):.2f}")
    print(f"Weights: {gmm.weights}")

    # Generate new samples
    X_new, _ = gmm.sample(n_samples=100)
    plt.figure()
    plt.scatter(X_new[:, 0], X_new[:, 1], alpha=0.6)
    plt.scatter(gmm.means[:, 0], gmm.means[:, 1],
                c='red', marker='x', s=200, linewidths=3)
    plt.title('Samples Generated by GMM')
    plt.savefig('gmm_samples.png', dpi=150)
    plt.show()

Q&A

Q1: Why is the EM algorithm called Expectation-Maximization?

A: The E-step computes the Q-function (the expectation of complete data log-likelihood), and the M-step maximizes the Q-function. The name directly describes the algorithm steps.


Q2: Does the EM algorithm converge to the global optimum?

A: Not guaranteed. EM guarantees monotonic increase of log-likelihood but may converge to local optima or saddle points. Solutions: - Multiple random initializations - Better initialization (K-means) - Combine with other optimization algorithms (e.g., gradient descent)


Q3: How to choose the number of components K in GMM?

A: Model selection criteria: - BIC: $-2\ell(\hat{\boldsymbol{\theta}}) + M \log N$ - AIC: $-2\ell(\hat{\boldsymbol{\theta}}) + 2M$ - Cross-validation: Maximize held-out log-likelihood - Silhouette coefficient: Clustering quality assessment

Typically choose the $K$ that minimizes BIC.


Q4: What is the relationship between GMM and K-means?

A: K-means is a special case of GMM: - All covariances $\boldsymbol{\Sigma}_k = \epsilon \mathbf{I}$ (spherical clusters) - As $\epsilon \to 0$, GMM degenerates to K-means (hard assignment)

GMM is a soft and generalized version of K-means.


Q5: Can the EM algorithm be parallelized?

A: The E-step is naturally parallel (each sample computes its $\gamma_{ik}$ independently); the M-step is partially parallel (statistics computation can be parallelized, but parameter updates need synchronization). Distributed implementations: MapReduce or parameter servers.


Q6: What to do when the covariance matrix is singular?

A: Common causes: high dimensionality, few samples in some components. Solutions: - Regularization: $\boldsymbol{\Sigma}_k \leftarrow \boldsymbol{\Sigma}_k + \epsilon \mathbf{I}$ - Diagonal covariance: $\boldsymbol{\Sigma}_k = \mathrm{diag}(\sigma_{k1}^2, \dots, \sigma_{kd}^2)$ - Shared covariance: All components use the same $\boldsymbol{\Sigma}$ - PCA dimensionality reduction
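The first two remedies are exposed directly in scikit-learn's `GaussianMixture` (a brief sketch): `reg_covar` adds a small constant to each covariance diagonal, and `covariance_type='diag'` restricts every $\boldsymbol{\Sigma}_k$ to a diagonal matrix.

```python
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

X, _ = make_blobs(n_samples=200, centers=3, n_features=5, random_state=0)

# reg_covar: Sigma_k + eps*I regularization; covariance_type='diag':
# diagonal covariances, cutting parameters from O(K d^2) to O(K d).
gm = GaussianMixture(n_components=3, reg_covar=1e-5,
                     covariance_type='diag', random_state=0).fit(X)
```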


Q7: What are variants of the EM algorithm?

A: - Incremental EM: Online learning, batch updates - Stochastic EM: Random subset sampling at each iteration - Variational EM: E-step uses variational inference to approximate posterior - Generalized EM: M-step doesn't fully optimize (e.g., a few gradient ascent steps)


Q8: Can GMM handle high-dimensional data?

A: High-dimensional difficulties (curse of dimensionality): - Covariance parameters ($O(Kd^2)$) explode - Sparse samples, inaccurate estimation

Solutions: - Diagonal covariance (assume feature independence) - Factor analysis (low-rank covariance) - Feature selection/dimensionality reduction


Q9: How to visualize high-dimensional GMM?

A: - PCA projection: Project to first 2-3 principal components - t-SNE: Nonlinear dimensionality reduction - Responsibility heatmap: Sample × component ($N \times K$) matrix - Component weights: Bar chart showing $\pi_k$

Q10: What are other applications of the EM algorithm?

A: - Hidden Markov Models: Baum-Welch algorithm - Topic models: Variational EM for LDA - Missing data imputation: EM iteratively estimates missing values - Mixture of experts: Gating network + expert networks

EM is a general framework for handling latent variables.


✏️ Exercises and Solutions

Exercise 1: E-step Calculation

Problem: A GMM with given mixing weights, means, and covariances; for an observation $\mathbf{x}$, compute the responsibilities. Solution: Apply the E-step formula $\gamma_{ik} = \pi_k \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \big/ \sum_{j} \pi_j \, \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)$, substituting the given parameter values.

Exercise 2: M-step Update

Problem: Given samples $\mathbf{x}_i$ and responsibilities $\gamma_{ik}$, update $\boldsymbol{\mu}_k$. Solution: $\boldsymbol{\mu}_k = \frac{1}{N_k} \sum_i \gamma_{ik} \, \mathbf{x}_i$, where $N_k = \sum_i \gamma_{ik}$.

Exercise 3: EM Convergence

Problem: What does EM converge to? Solution: Local optimum (or saddle point). EM guarantees likelihood increase but not global optimum.

Exercise 4: GMM vs K-means

Problem: Compare GMM and K-means. Solution: K-means: hard assignment, spherical clusters. GMM: soft assignment (probabilistic), elliptical clusters.

Exercise 5: Missing Data

Problem: Data $\mathbf{x}_i$ with some entries missing. How to use EM? Solution: E-step: estimate the conditional expectation of the missing entries given the observed entries and current parameters. M-step: update the parameters using the completed data. Iterate.


References

  1. Dempster, A. P., Laird, N. M., & Rubin, D. B. (1977). Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society: Series B, 39(1), 1-22.
  2. Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer. [Chapter 9: Mixture Models and EM]
  3. Murphy, K. P. (2012). Machine Learning: A Probabilistic Perspective. MIT Press. [Chapter 11: Mixture Models and EM]
  4. McLachlan, G., & Krishnan, T. (2007). The EM Algorithm and Extensions (2nd ed.). Wiley.
  5. Neal, R. M., & Hinton, G. E. (1998). A view of the EM algorithm that justifies incremental, sparse, and other variants. Learning in Graphical Models, 89, 355-368.

The EM algorithm, with its elegant mathematical structure and broad practical value, has become one of the cornerstones of modern machine learning. From parameter estimation in Gaussian mixture models to training hidden Markov models, from missing data handling to semi-supervised learning, the EM algorithm demonstrates the wisdom of "divide and conquer"— decomposing a difficult incomplete data optimization problem into simple "expectation" and "maximization" steps. Understanding the EM algorithm is not only mastering a classical algorithm but also appreciating the deep connection between probabilistic inference and optimization — a theme that runs through variational inference, Monte Carlo methods, and other modern Bayesian machine learning techniques.

  • Post title:Machine Learning Mathematical Derivations (13): EM Algorithm and GMM
  • Post author:Chen Kai
  • Create time:2021-11-05 09:45:00
  • Post link:https://www.chenk.top/Machine-Learning-Mathematical-Derivations-13-EM-Algorithm-and-GMM/
  • Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stating additionally.