ST793 Project: A Blogpost on “Doubly Enhanced EM Algorithm for Model-Based Tensor Clustering” by Mai et al

Ayumi Mutoh, Jisu Oh, Shih-Ni Prim

December 7, 2023

1 Introduction

In recent decades, tensor data have gained popularity in modern science, but their high-dimensional structure often poses challenges for statistical analysis, particularly in model-based clustering. Model-based clustering is a statistical approach in which the observed data are assumed to arise from a finite mixture of component models, such as the Gaussian mixture model (GMM). Since its formalization by Dempster et al. (1977), the expectation-maximization (EM) algorithm has been employed in the majority of model-based clustering applications. While GMMs can be readily extended to higher-order tensors using the standard EM algorithm, performance can be further enhanced by the Doubly Enhanced EM algorithm (DEEM) proposed by Mai et al. (2022). The authors consider a tensor normal mixture model (TNMM) that incorporates the tensor correlation structure and variable selection for clustering and parameter estimation, and they develop the DEEM algorithm, which excels in high-dimensional tensor data analysis. Like the standard EM algorithm, DEEM alternates between an E-step and an M-step, both suitably enhanced.

In this blogpost, we first introduce the DEEM method, spelling out intermediate steps of the theoretical derivation. The objective is to break down the steps so that the derivation is more accessible for our readers. Subsequently, we conduct a simulation study to evaluate the performance of DEEM.

2 Theoretical Derivation

2.1 EM Algorithm

Before delving into DEEM, we review the EM algorithm and how it is used for clustering.

The EM algorithm is an iterative approach that cycles between two steps for maximum likelihood estimation in the presence of latent variables. The observed data $Y$ are incomplete, and the data $Z$ are missing. The first step is to write down the joint likelihood $L_c(\theta \mid Y, Z)$ of the “complete” data $(Y, Z)$. The E-step of the EM algorithm computes the conditional expectation of the log-likelihood, $\log L_c(\theta \mid Y, Z)$, given $Y$, assuming the true parameter value is $\theta^{(\nu)}$:

$$Q(\theta, \theta^{(\nu)}, Y) = E_{\theta^{(\nu)}}\bigl( \log L_c(\theta \mid Y, Z) \mid Y \bigr).$$

In the M-step, we maximize $Q(\theta, \theta^{(\nu)}, Y)$ with respect to $\theta$ with $\theta^{(\nu)}$ fixed. We repeat the E-step and the M-step until convergence.

The EM algorithm is well-known for its use in unsupervised learning problems such as clustering with a mixture model. The process from Yang et al. (2012) goes as follows:

  1. Identify the number of clusters.
  2. Define each cluster by generating a Gaussian model.
  3. For every observation, calculate the probability that it belongs to each cluster (e.g., observation 12 has a 40% probability of belonging to Cluster A and a 60% probability of belonging to Cluster B).
  4. Use the above probabilities to recalculate the Gaussian models.
  5. Repeat until observations “converge” on their assignments.

Let’s consider a simple example. Suppose we have data $X_i$, as shown in Figure 1, which come from two distinct classes. We want to build a Gaussian model for each class. Since we don’t know which class each observation belongs to, there is no straightforward way to construct two Gaussian models that partition the data. Therefore, we begin with random guesses of the Gaussian model parameters $\mu_1, \sigma_1^2, \mu_2, \sigma_2^2$.

The class labels are the “missing” data: each $X_i$ belongs to one of the two distributions, but we do not observe which. After initializing two random Gaussian models, we move to the E-step, where we compute, for each observation $X_i$, the probability that it belongs to each of the two distributions. We now have each point’s probability of belonging to either distribution.

In the M-step, we update the parameters $\mu_1, \sigma_1^2, \mu_2, \sigma_2^2$ to their most likely values. For the new $\mu_1$, we take a weighted average of all the points, weighted by the probability that they belong to the first distribution. Let $p_i$ denote the probability that $X_i$ belongs to the first distribution:

$$\mu_1 = \frac{p_1 X_1 + p_2 X_2 + \cdots + p_n X_n}{p_1 + p_2 + \cdots + p_n}.$$

The new $\sigma_1^2$ is updated similarly:

$$\sigma_1^2 = \frac{p_1 (X_1 - \mu_1)^2 + p_2 (X_2 - \mu_1)^2 + \cdots + p_n (X_n - \mu_1)^2}{p_1 + p_2 + \cdots + p_n}.$$

We repeat this process for $\mu_2$ and $\sigma_2^2$ and update our distributions. We iterate through the E-step and M-step until convergence, obtaining the two clusters shown in Figure 2.
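To make the two-component example concrete, here is a minimal base-R sketch of this EM loop. The toy data, the initial values, and the update of the mixing proportions (which the prose above leaves implicit) are our own illustrative choices, not part of the original example.

```r
# Minimal EM for a two-component, one-dimensional Gaussian mixture (illustrative sketch)
set.seed(1)
x <- c(rnorm(100, mean = 0, sd = 1), rnorm(100, mean = 4, sd = 1))  # toy data

mu     <- c(-1, 1)       # random initial guesses
sigma2 <- c(1, 1)
pi_k   <- c(0.5, 0.5)

for (iter in 1:200) {
  # E-step: probability p_i that each x_i belongs to each component
  d1 <- pi_k[1] * dnorm(x, mu[1], sqrt(sigma2[1]))
  d2 <- pi_k[2] * dnorm(x, mu[2], sqrt(sigma2[2]))
  p1 <- d1 / (d1 + d2)
  p2 <- 1 - p1

  # M-step: weighted means, variances, and mixing proportions
  mu_new     <- c(sum(p1 * x) / sum(p1), sum(p2 * x) / sum(p2))
  sigma2_new <- c(sum(p1 * (x - mu_new[1])^2) / sum(p1),
                  sum(p2 * (x - mu_new[2])^2) / sum(p2))
  pi_new     <- c(mean(p1), mean(p2))

  delta <- max(abs(mu_new - mu))           # crude convergence check on the means
  mu <- mu_new; sigma2 <- sigma2_new; pi_k <- pi_new
  if (delta < 1e-8) break
}
round(c(mu, sigma2, pi_k), 3)
```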


Figure 1: Mixture of two Gaussian Distributions


Figure 2: Clusters Found by EM algorithm

2.2 Tensor

While the term “tensor” might sound unfamiliar to some, tensors are simply multi-way arrays. Data are often structured as matrices, which are in fact second-order tensors. When we use the term “tensor,” we usually mean tensors of third or higher order. The order of a tensor is its number of dimensions, and the dimensions are also called modes. You can think of a third-order tensor as a cube. As shown in Figure 3, a tensor can be manipulated similarly to a matrix. In a matrix, we talk about rows and columns. In a third-order tensor, we talk about fibers, obtained by fixing all modes but one and letting the remaining mode run over all of its values; the fiber is named after that free mode. Slices are obtained by fixing one mode and keeping all values of the remaining modes; the slice is named after the fixed mode.


Figure 3: Dimensions and Terminology of a Tensor, taken from Kolda and Bader (2009)
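As a quick base-R illustration of fibers and slices (our own example, not from the paper), the indexing of an `array` makes the definitions concrete; the 3 × 4 × 2 array below is the same tensor used in the matricization example later in this subsection.

```r
# A 3 x 4 x 2 array viewed as a third-order tensor
X <- array(1:24, dim = c(3, 4, 2))

X[, 2, 1]   # a mode-1 (column) fiber: modes 2 and 3 fixed, mode 1 runs over all values
X[1, , 2]   # a mode-2 (row) fiber:    modes 1 and 3 fixed
X[1, 2, ]   # a mode-3 (tube) fiber:   modes 1 and 2 fixed

X[, , 1]    # a frontal slice:    mode 3 fixed
X[1, , ]    # a horizontal slice: mode 1 fixed
X[, 2, ]    # a lateral slice:    mode 2 fixed
```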

Before we continue on to the DEEM algorithm, some concepts and notation are needed to understand the derivations in the following section. Note that the following notation is taken from Kolda and Bader (2009). First of all, we should go over the concept of matricization. To matricize a third-order tensor, we can think of cutting a cube into slices and placing the slices side by side to form a matrix. We now borrow an example (Example 2.1) from Kolda and Bader (2009) to demonstrate how to matricize a third-order tensor $X \in \mathbb{R}^{3 \times 4 \times 2}$ whose frontal slices are

$$X_1 = \begin{bmatrix} 1 & 4 & 7 & 10 \\ 2 & 5 & 8 & 11 \\ 3 & 6 & 9 & 12 \end{bmatrix}, \qquad X_2 = \begin{bmatrix} 13 & 16 & 19 & 22 \\ 14 & 17 & 20 & 23 \\ 15 & 18 & 21 & 24 \end{bmatrix}.$$

The three mode-n unfoldings/matricizations are then

$$X_{(1)} = \begin{bmatrix} 1 & 4 & 7 & 10 & 13 & 16 & 19 & 22 \\ 2 & 5 & 8 & 11 & 14 & 17 & 20 & 23 \\ 3 & 6 & 9 & 12 & 15 & 18 & 21 & 24 \end{bmatrix},$$
$$X_{(2)} = \begin{bmatrix} 1 & 2 & 3 & 13 & 14 & 15 \\ 4 & 5 & 6 & 16 & 17 & 18 \\ 7 & 8 & 9 & 19 & 20 & 21 \\ 10 & 11 & 12 & 22 & 23 & 24 \end{bmatrix},$$
$$X_{(3)} = \begin{bmatrix} 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 & 11 & 12 \\ 13 & 14 & 15 & 16 & 17 & 18 & 19 & 20 & 21 & 22 & 23 & 24 \end{bmatrix}.$$

The example above shows that the mode-1 unfolding places the frontal slices $X_1$ and $X_2$ side by side, the mode-2 unfolding places their transposes side by side, and the mode-3 unfolding turns each slice into a row.

This concept is intuitive but much more awkward to define formally. In Kolda and Bader (2009), the mode-$n$ matricization of a tensor $X \in \mathbb{R}^{I_1 \times I_2 \times \cdots \times I_M}$ is denoted by $X_{(n)}$, a matrix of dimension $I_n \times \prod_{p \neq n} I_p$. (The number of rows is the dimension of mode $n$, and the number of columns is the product of the dimensions of all the other modes.) The tensor element $(i_1, \ldots, i_M)$ is mapped to the matrix element $(i_n, j)$ in the following manner:

$$j = 1 + \sum_{\substack{k = 1 \\ k \neq n}}^{M} (i_k - 1) J_k, \qquad \text{with } J_k = \prod_{\substack{m = 1 \\ m \neq n}}^{k - 1} I_m.$$
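The index formula can be checked directly in R. The sketch below (our own helper, named `unfold`; not the authors' code) implements the mode-$n$ unfolding by permuting the array so that mode $n$ comes first and then reshaping column-major, and it reproduces $X_{(1)}$, $X_{(2)}$, and $X_{(3)}$ from the example above.

```r
# Mode-n matricization of an array: bring mode n to the front, then reshape column-major
unfold <- function(X, n) {
  dims <- dim(X)
  perm <- c(n, seq_along(dims)[-n])       # mode n first, remaining modes in their original order
  matrix(aperm(X, perm), nrow = dims[n])  # I_n rows, product of the other dimensions as columns
}

X <- array(1:24, dim = c(3, 4, 2))        # the example tensor, with frontal slices X1 and X2
unfold(X, 1)   # 3 x 8, equals [X1 X2]
unfold(X, 2)   # 4 x 6, equals [t(X1) t(X2)]
unfold(X, 3)   # 2 x 12
```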

Next, the shorthand notation $[\![\, \cdot \,]\!]$ is defined as

$$[\![\, G;\, A^{(1)}, A^{(2)}, \ldots, A^{(M)} \,]\!] := G \times_1 A^{(1)} \times_2 A^{(2)} \cdots \times_M A^{(M)},$$

where the $A^{(n)}$ are matrices and $G$ is a tensor. The symbol $\times_n$ denotes the $n$-mode (matrix) product of a tensor $G \in \mathbb{R}^{I_1 \times I_2 \times \cdots \times I_M}$ with a matrix $A^{(n)} \in \mathbb{R}^{J \times I_n}$. The product $G \times_n A^{(n)}$ is then a tensor of dimension $I_1 \times \cdots \times I_{n-1} \times J \times I_{n+1} \times \cdots \times I_M$, with elements

$$\bigl( G \times_n A^{(n)} \bigr)_{i_1 \cdots i_{n-1}\, j\, i_{n+1} \cdots i_M} = \sum_{i_n = 1}^{I_n} g_{i_1 i_2 \cdots i_M}\, a^{(n)}_{j i_n}.$$

Also, for any two tensors $A$ and $B$ in $\mathbb{R}^{I_1 \times I_2 \times \cdots \times I_M}$, we define their inner product by

$$\langle A, B \rangle = \sum_{\mathcal{J}} a_{\mathcal{J}}\, b_{\mathcal{J}},$$
where the sum runs over all index tuples $\mathcal{J} = (i_1, \ldots, i_M)$.
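Both operations are easy to sketch in base R on top of the `unfold` helper from the previous code block (again our own illustration, not the authors' implementation): the $n$-mode product is computed via $Y_{(n)} = A^{(n)} X_{(n)}$ and then folded back, and the inner product is an elementwise sum.

```r
# n-mode (matrix) product: multiply the matrix A into mode n of the tensor X
ttm <- function(X, A, n) {
  dims <- dim(X)
  Yn <- A %*% unfold(X, n)                 # J x prod(other dims)
  new_dims <- dims; new_dims[n] <- nrow(A)
  perm <- c(n, seq_along(dims)[-n])
  Y <- array(Yn, dim = new_dims[perm])     # fold back with mode n in front ...
  aperm(Y, order(perm))                    # ... then restore the original mode order
}

# inner product of two tensors with identical dimensions
tinner <- function(A, B) sum(A * B)

X <- array(rnorm(3 * 4 * 2), dim = c(3, 4, 2))
A <- matrix(rnorm(5 * 4), 5, 4)
dim(ttm(X, A, 2))   # 3 x 5 x 2
tinner(X, X)        # squared Frobenius norm of X
```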

With that, we are ready to learn about the DEEM algorithm. If you are interested in knowing more about tensors, Kolda and Bader (2009) has lots of great details. So be sure to check it out!

2.3 Doubly Enhanced EM Algorithm

In this subsection, we introduce the doubly enhanced EM (DEEM) algorithm and discuss its theoretical properties. Algorithm 1 from Mai et al. (2022) is provided in the Appendix as our Figure 6.

Let $Z$ denote a random tensor in $\mathbb{R}^{p_1 \times \cdots \times p_M}$ whose elements are i.i.d. $N(0, 1)$. Then we say that a random tensor $X$ has a tensor normal distribution, denoted by $X \sim \mathrm{TN}(\mu; \Sigma_1, \ldots, \Sigma_M)$, if $X = \mu + [\![\, Z;\, \Sigma_1^{1/2}, \ldots, \Sigma_M^{1/2} \,]\!]$, where $\mu \in \mathbb{R}^{p_1 \times \cdots \times p_M}$ is the mean tensor and each $\Sigma_i \in \mathbb{R}^{p_i \times p_i}$ is the covariance matrix along the $i$th mode. In other words, the tensor normal distribution generalizes the multivariate normal distribution to higher-order arrays. An $M$th-order random tensor $X$ has a total of $\prod_{i=1}^{M} p_i$ mean entries and $M$ covariance matrices, each of dimension $p_i \times p_i$ for $i = 1, \ldots, M$. The density of $X$ has the form

$$p(X \mid \mu; \Sigma_1, \ldots, \Sigma_M) = \frac{1}{(2\pi)^{p/2}\, |\Sigma_1|^{q_1/2} \cdots |\Sigma_M|^{q_M/2}} \exp\!\left( -\frac{1}{2} \Bigl\langle [\![\, X - \mu;\, \Sigma_1^{-1}, \ldots, \Sigma_M^{-1} \,]\!],\; X - \mu \Bigr\rangle \right), \tag{1}$$

where $p = p_1 p_2 \cdots p_M$ and $q_i = p / p_i$.
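For readers who want to experiment, here is a hedged base-R sketch (ours, not the authors' code) that draws a tensor normal variate via the definition $X = \mu + [\![\, Z;\, \Sigma_1^{1/2}, \ldots, \Sigma_M^{1/2} \,]\!]$ and evaluates the logarithm of density (1); it reuses the `ttm`, `tinner`, and `unfold` helpers sketched in Section 2.2.

```r
# matrix square root via eigendecomposition (symmetric positive definite input assumed)
msqrt <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  e$vectors %*% diag(sqrt(pmax(e$values, 0)), nrow(S)) %*% t(e$vectors)
}

# draw one X ~ TN(mu; Sigma_1, ..., Sigma_M)
rtn <- function(mu, Sigmas) {
  Z <- array(rnorm(prod(dim(mu))), dim = dim(mu))
  X <- Z
  for (m in seq_along(Sigmas)) X <- ttm(X, msqrt(Sigmas[[m]]), m)
  mu + X
}

# log of the density (1): p = prod(p_i), q_i = p / p_i
tn_logdens <- function(X, mu, Sigmas) {
  p  <- prod(dim(X))
  qs <- p / dim(X)
  R  <- X - mu
  W  <- R
  for (m in seq_along(Sigmas)) W <- ttm(W, solve(Sigmas[[m]]), m)  # [[X - mu; Sigma_1^-1, ...]]
  logdets <- sapply(Sigmas, function(S) as.numeric(determinant(S, logarithm = TRUE)$modulus))
  -0.5 * p * log(2 * pi) - 0.5 * sum(qs * logdets) - 0.5 * tinner(W, R)
}
```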

We consider independent tensor-variate observations in $\mathbb{R}^{p_1 \times \cdots \times p_M}$ drawn from $K$ clusters that share the same within-class covariance matrices. Let $\mu_k$ denote the mean tensor of the $k$th cluster, and let $\pi_k$ be the probability that an observation is drawn from the $k$th cluster.

Then a sample $\{X_i\}_{i=1}^{n}$ from the mixture of tensor normal distributions can be written as

$$X_i \sim \sum_{k=1}^{K} \pi_k\, \mathrm{TN}(\mu_k; \Sigma_1, \ldots, \Sigma_M), \qquad i = 1, 2, \ldots, n,$$

or equivalently,

$$P(Y_i = k) = \pi_k, \qquad X_i \mid Y_i = k \;\sim\; \mathrm{TN}(\mu_k; \Sigma_1, \ldots, \Sigma_M), \qquad i = 1, 2, \ldots, n. \tag{2}$$

Hence $Y_i$ indicates the cluster from which $X_i$ was drawn; given $Y_i = k$, $X_i$ has a tensor normal distribution with the mean $\mu_k$ of cluster $k$ and the within-class covariance matrices $\Sigma_1, \ldots, \Sigma_M$. (Recall that we assume all clusters share the same within-class covariance matrices.)

Suppose that $\{X_i\}_{i=1}^{n}$ is a sample from model (2). Let $\theta = \{\pi_k, \mu_k, \Sigma_j : 1 \le k \le K,\ 1 \le j \le M\}$ denote the set of all parameters in the model. If we could observe the $Y_i$, the complete-data log-likelihood would be

$$\ell_c(\theta; X, Y) = \log \prod_{i=1}^{n} \pi_{Y_i}\, p(X_i \mid \mu_{Y_i}; \Sigma_1, \ldots, \Sigma_M) = \sum_{i=1}^{n} \bigl[ \log \pi_{Y_i} + \log p(X_i \mid \mu_{Y_i}; \Sigma_1, \ldots, \Sigma_M) \bigr].$$

In general, however, we cannot observe the $Y_i$. Hence, starting from an initial value $\tilde{\theta}^{(0)}$, we create a sequence $\tilde{\theta}^{(t)}$ by alternating the E-step, which computes the $Q$ function

$$Q(\theta; \tilde{\theta}^{(t)}) = E_{Y \mid X, \tilde{\theta}^{(t)}}\bigl[ \ell_c(\theta; X, Y) \bigr] = \sum_{i=1}^{n} \sum_{k=1}^{K} \tilde{\xi}_{ik}^{(t)} \bigl[ \log \pi_k + \log p(X_i \mid \mu_k; \Sigma_1, \ldots, \Sigma_M) \bigr],$$

where

$$\tilde{\xi}_{ik}^{(t)} = P(Y_i = k \mid X, \tilde{\theta}^{(t)}) = \frac{\tilde{\pi}_k^{(t)}\, p(X_i \mid \tilde{\mu}_k^{(t)}; \tilde{\Sigma}_1^{(t)}, \ldots, \tilde{\Sigma}_M^{(t)})}{\sum_{j=1}^{K} \tilde{\pi}_j^{(t)}\, p(X_i \mid \tilde{\mu}_j^{(t)}; \tilde{\Sigma}_1^{(t)}, \ldots, \tilde{\Sigma}_M^{(t)})}, \tag{3}$$

and the M-step, which updates the parameters via

$$\tilde{\theta}^{(t+1)} = \operatorname*{arg\,max}_{\theta}\, Q(\theta; \tilde{\theta}^{(t)}).$$
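As a concrete illustration of the E-step in (3), the responsibilities can be computed stably on the log scale with the log-sum-exp trick. The sketch below reuses the `tn_logdens` helper from the previous code block and is our own illustration, not the authors' implementation.

```r
# E-step: responsibilities xi_ik for a list of observed tensors X_list,
# mixture weights pis, cluster means mus, and shared covariance matrices Sigmas
e_step <- function(X_list, pis, mus, Sigmas) {
  K <- length(pis)
  logw <- sapply(1:K, function(k)
    sapply(X_list, function(Xi) log(pis[k]) + tn_logdens(Xi, mus[[k]], Sigmas)))
  m  <- apply(logw, 1, max)      # logw is n x K; normalize each row via log-sum-exp
  xi <- exp(logw - m)
  xi / rowSums(xi)
}
```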

The EM sequence $\tilde{\theta}^{(t)}$ then converges to a (local) maximizer of the likelihood, but two issues arise in our situation. First, obtaining the updates for $\pi_k$ and $\mu_k$ is easy and straightforward, but obtaining the updates for the covariance matrices $\Sigma_j$ is challenging. Second, when we compute $\tilde{\xi}_{ik}^{(t)}$ in (3), all the elements of $X_i$ are used, and the standard EM algorithm has no mechanism for variable selection. With an excessive number of parameters in the model, estimation errors can accumulate, potentially resulting in inaccurate estimates.

To overcome these problems, DEEM introduces an enhanced E-step, in which $\tilde{\xi}_{ik}^{(t)}$ is replaced with an estimator $\hat{\xi}_{ik}^{(t)}$ that can be computed relatively quickly under a sparsity assumption. We want to find an objective function $Q_{\mathrm{DEEM}}$ with better properties than the standard $Q$ function above. First, it can be seen that

$$\xi_{i1} = P(Y_i = 1 \mid X_i, \theta) = \frac{\pi_1}{\pi_1 + \sum_{k=2}^{K} \pi_k \exp\bigl[ \bigl\langle X_i - (\mu_k + \mu_1)/2,\; B_k \bigr\rangle \bigr]} \tag{4}$$

and

$$\xi_{ik} = P(Y_i = k \mid X_i, \theta) = \frac{\pi_k \exp\bigl[ \bigl\langle X_i - (\mu_k + \mu_1)/2,\; B_k \bigr\rangle \bigr]}{\pi_1 + \sum_{j=2}^{K} \pi_j \exp\bigl[ \bigl\langle X_i - (\mu_j + \mu_1)/2,\; B_j \bigr\rangle \bigr]} \tag{5}$$

for $k \ge 2$, where

$$B_k = [\![\, \mu_k - \mu_1;\ \Sigma_1^{-1}, \ldots, \Sigma_M^{-1} \,]\!] \in \mathbb{R}^{p_1 \times \cdots \times p_M}.$$

Note that we are considering the clustering problem; that is, we are interested in recovering the $Y_i$ and in finding a method that minimizes the clustering error. It can be shown that the covariance matrices $\Sigma_j$ are nuisance parameters in the optimal clustering rule, i.e., the values of $\Sigma_j$ are not needed by the optimal clustering rule if we already know the $B_k$. To cover more general cases, we do not impose conditions on the $\Sigma_j$. Instead, we assume a sparsity condition on the $B_k$: writing $B_k = [b_{k,\mathcal{J}}]_{\mathcal{J}}$, where $\mathcal{J} = (j_1, \ldots, j_M)$ denotes an index of the tensor, we impose the condition $b_{2,\mathcal{J}} = \cdots = b_{K,\mathcal{J}} = 0$ for most indices $\mathcal{J}$. In other words, if $\mathcal{D} = \{\mathcal{J} : b_{k,\mathcal{J}} \neq 0 \text{ for some } k = 2, \ldots, K\}$, then we assume that the number of elements in $\mathcal{D}$ is significantly smaller than $p = p_1 p_2 \cdots p_M$. This assumption reflects the belief that, in the high-dimensional setting, most of the variables are not informative for clustering. The expressions for $\xi_{ik}$ above show how this assumption reduces the computational cost and improves the estimation efficiency.

If we accept the fact that the true $(B_2, \ldots, B_K)$ minimizes the quantity (a short verification is sketched after the definition of $Q_{\mathrm{DEEM}}$ below)

$$\sum_{k=2}^{K} \Bigl( \bigl\langle B_k,\ [\![\, B_k;\ \Sigma_1, \ldots, \Sigma_M \,]\!] \bigr\rangle - 2 \bigl\langle B_k,\ \mu_k - \mu_1 \bigr\rangle \Bigr),$$

then it is reasonable to obtain the sequence of estimates $\hat{B}_k^{(t+1)}$ by solving the optimization problem

$$\operatorname*{arg\,min}_{B_2, \ldots, B_K}\; \left\{ \sum_{k=2}^{K} \Bigl( \bigl\langle B_k,\ [\![\, B_k;\ \hat{\Sigma}_1^{(t)}, \ldots, \hat{\Sigma}_M^{(t)} \,]\!] \bigr\rangle - 2 \bigl\langle B_k,\ \hat{\mu}_k^{(t)} - \hat{\mu}_1^{(t)} \bigr\rangle \Bigr) + \lambda^{(t+1)} \sum_{\mathcal{J}} \sqrt{\sum_{k=2}^{K} b_{k,\mathcal{J}}^2} \right\},$$

where the lasso-type penalty term encourages the estimates to satisfy the sparsity assumption. Using these $\hat{B}_k^{(t+1)}$, we obtain the sequence $\hat{\xi}_{ik}^{(t+1)}$ by replacing the parameters in (4) and (5) with their estimates. The objective $Q_{\mathrm{DEEM}}$ is then defined using $\hat{\xi}_{ik}^{(t)}$ as follows:

$$Q_{\mathrm{DEEM}}(\theta; \hat{\theta}^{(t)}) = \sum_{i=1}^{n} \sum_{k=1}^{K} \hat{\xi}_{ik}^{(t)} \bigl[ \log \pi_k + \log p(X_i \mid \mu_k; \Sigma_1, \ldots, \Sigma_M) \bigr].$$

In light of the sparsity assumption, $\hat{\xi}_{ik}^{(t)}$ can be computed from a relatively small number of variables.
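As promised earlier, here is a short verification (our own, not reproduced from the paper) that the true $(B_2, \ldots, B_K)$ minimizes the unpenalized objective: the $k$th summand is a quadratic function of $B_k$, and setting its gradient to zero gives
$$2\,[\![\, B_k;\ \Sigma_1, \ldots, \Sigma_M \,]\!] - 2(\mu_k - \mu_1) = 0 \quad \Longleftrightarrow \quad B_k = [\![\, \mu_k - \mu_1;\ \Sigma_1^{-1}, \ldots, \Sigma_M^{-1} \,]\!],$$
which is exactly the definition of $B_k$ in the enhanced E-step.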

In the M-step, the parameters are updated from the proposed $Q_{\mathrm{DEEM}}$ function. The estimates of $\pi_k$ and $\mu_k$ are obtained by the formulas

$$\hat{\pi}_k^{(t+1)} = \frac{1}{n} \sum_{i=1}^{n} \hat{\xi}_{ik}^{(t+1)} \qquad \text{and} \qquad \hat{\mu}_k^{(t+1)} = \frac{\sum_{i=1}^{n} \hat{\xi}_{ik}^{(t+1)} X_i}{\sum_{i=1}^{n} \hat{\xi}_{ik}^{(t+1)}}, \qquad k = 1, 2, \ldots, K.$$

Then, given the $\hat{\xi}_{ik}^{(t+1)}$, we calculate the intermediate covariance matrices

$$\Sigma_j^{*(t+1)} = \frac{1}{n q_j} \sum_{i=1}^{n} \sum_{k=1}^{K} \hat{\xi}_{ik}^{(t+1)} \bigl( X_i - \hat{\mu}_k^{(t+1)} \bigr)_{(j)} \bigl( X_i - \hat{\mu}_k^{(t+1)} \bigr)_{(j)}^{T},$$

and the conditional variance of the first entry $X_{i,1\cdots1}$,

$$\bigl( \hat{\sigma}^{1\cdots1} \bigr)^{(t+1)} = \frac{1}{n} \sum_{i=1}^{n} \sum_{k=1}^{K} \hat{\xi}_{ik}^{(t+1)} \bigl( X_{i,1\cdots1} - \hat{\mu}_{k,1\cdots1}^{(t+1)} \bigr)^2.$$

The final covariance estimators are obtained by rescaling the intermediate covariances with $\bigl( \hat{\sigma}^{1\cdots1} \bigr)^{(t+1)}$ and the $(1,1)$ entries $\bigl( \sigma_j^{*11} \bigr)^{(t+1)} = \bigl[ \Sigma_j^{*(t+1)} \bigr]_{11}$:

$$\hat{\Sigma}_j^{(t+1)} = \begin{cases} \dfrac{1}{\bigl( \sigma_j^{*11} \bigr)^{(t+1)}}\, \Sigma_j^{*(t+1)}, & j \ge 2, \\[2ex] \dfrac{\bigl( \hat{\sigma}^{1\cdots1} \bigr)^{(t+1)}}{\bigl( \sigma_1^{*11} \bigr)^{(t+1)}}\, \Sigma_1^{*(t+1)}, & j = 1. \end{cases}$$

The general process is summarized in Algorithm 1 in the Appendix as our Figure 6.
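To make the enhanced M-step concrete, the sketch below implements the weight, mean, and intermediate covariance updates in base R, reusing the `unfold` helper from Section 2.2. It is our own illustration (the final rescaling by $\bigl( \hat{\sigma}^{1\cdots1} \bigr)^{(t+1)}$ is only indicated in a comment), not the implementation in the TensorClustering package.

```r
# M-step updates given responsibilities xi (an n x K matrix) and observations X_list
m_step <- function(X_list, xi) {
  n <- length(X_list); K <- ncol(xi)
  dims <- dim(X_list[[1]]); p <- prod(dims)
  pis <- colMeans(xi)                                    # pi_k update
  mus <- lapply(1:K, function(k) {                       # mu_k update: weighted average
    Reduce(`+`, Map(function(Xi, w) w * Xi, X_list, xi[, k])) / sum(xi[, k])
  })
  Sigmas <- lapply(seq_along(dims), function(j) {        # intermediate Sigma_j update
    qj <- p / dims[j]
    S <- matrix(0, dims[j], dims[j])
    for (i in 1:n) for (k in 1:K) {
      Rj <- unfold(X_list[[i]] - mus[[k]], j)
      S <- S + xi[i, k] * Rj %*% t(Rj)
    }
    S / (n * qj)
  })
  # The enhanced M-step then rescales these intermediate matrices using the estimated
  # variance of the first entry, as in the display above; omitted here for brevity.
  list(pis = pis, mus = mus, Sigmas = Sigmas)
}
```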

Now we are interested in how the sequence of parameters $\hat{\theta}^{(t)}$ behaves. In fact, under a suitable initialization condition, it can be seen that there exist a constant $C$ and $0 < \kappa < 1/2$ such that, for large $t$, with probability $1 - O\bigl( \sum_{i=1}^{M} p_i^{-1} \bigr)$,

$$\bigl\| \hat{B}^{(t)} - B^{*} \bigr\|_F \;\le\; \kappa^{t} d_0 + C\, s \sqrt{\frac{\sum_{i=1}^{M} \log p_i}{n}},$$

where $s$ is the number of important variables (the cardinality of $\mathcal{D}$), with $s = o\bigl( \sqrt{n / \sum_{i=1}^{M} \log p_i} \bigr)$, and $d_0$ is a measure of the difference between the initial value $\hat{\theta}^{(0)}$ and the true parameter $\theta^{*}$. This result implies that if the number of iterations $t$ is large, the DEEM estimator $\hat{B}^{(t)}$ converges to the true parameter $B^{*}$ at the rate shown above. The condition on $s$ encodes the sparsity assumption.

We consider the error rate of DEEM and the optimal clustering rule:

$$R(\mathrm{DEEM}) = \min_{\Pi}\, P\bigl( \Pi(\hat{Y}_i^{\mathrm{DEEM}}) \neq Y_i \bigr) \qquad \text{and} \qquad R(\mathrm{Opt}) = P\bigl( \hat{Y}_i^{\mathrm{opt}} \neq Y_i \bigr),$$

where $\Pi$ ranges over permutations of the cluster labels and $\hat{Y}_i^{\mathrm{DEEM}} = \arg\max_k \hat{\xi}_{ik}$. Here, the optimal clustering rule is optimal in the sense that it minimizes the clustering error. From the result above, we can show that if $t$ is large, then with probability $1 - O\bigl( \sum_{i=1}^{M} p_i^{-1} \bigr)$,

$$R(\mathrm{DEEM}) - R(\mathrm{Opt}) \;\le\; C\, s \sqrt{\frac{\sum_{i=1}^{M} \log p_i}{n}}.$$

Consequently, this result shows that the error rate of DEEM converges to the error rate of the optimal clustering rule if t is large.

3 Simulation Study

3.1 Data Generation

For our simulation studies, we follow the framework used in Mai et al. (2022). For each setting, $K$ denotes the number of mixture groups, and the observations are generated as $M$th-order tensors from the mixture

$$X_i \sim \sum_{k=1}^{K} \pi_k\, \mathrm{TN}(\mu_k; \Sigma_1, \ldots, \Sigma_M), \qquad i = 1, \ldots, n. \tag{6}$$

For the first mixture group, the observations are simply the tensor normal noise; for the other $K - 1$ mixture groups, $X_i$ is the given $B_k$ plus that noise. Mai et al. consider two ways of specifying the covariance matrices:

$$\mathrm{AR}(\rho): \ \omega_{ij} = \rho^{|i - j|}, \qquad \mathrm{CS}(\rho): \ \omega_{ij} = \rho + (1 - \rho)\, \mathbf{1}(i = j).$$

Covariance matrices constructed in either of these two ways are not sparse. For each setting, we generate 100 independent datasets, the same number of replicates used by the authors, and present the mean error rate and standard deviation.
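Below is a hedged base-R sketch of how one dataset could be generated for a setting like M1, reusing the `rtn` sampler from the sketch in Section 2.3. The constructors follow the AR and CS displays above; the sample size `n = 100`, the equal mixing proportions, and the helper names are our own choices (in the simulations reported here we used Tlasso::Trnorm instead).

```r
AR <- function(p, rho) rho^abs(outer(1:p, 1:p, "-"))         # AR(rho): omega_ij = rho^|i-j|
CS <- function(p, rho) matrix(rho, p, p) + diag(1 - rho, p)  # CS(rho): compound symmetry

# one simulated dataset resembling setting M1 (K = 2, 10 x 10 x 4 tensors)
gen_data <- function(n = 100, pis = c(0.5, 0.5)) {
  dims   <- c(10, 10, 4)
  Sigmas <- list(CS(10, 0.3), AR(10, 0.8), CS(4, 0.3))
  mu1 <- array(0, dims)                          # cluster 1: pure noise
  mu2 <- array(0, dims); mu2[1:6, 1, 1] <- 0.5   # cluster 2: the given B_2 plus noise
  mus <- list(mu1, mu2)
  y <- sample(seq_along(pis), n, replace = TRUE, prob = pis)
  X <- lapply(y, function(k) rtn(mus[[k]], Sigmas))
  list(X = X, y = y)
}
```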

3.2 Settings

The settings are provided in Table 1. Note that, for $B_k$, the entries whose indices are not listed in the subscript are 0; in other words, $B_k$ is a sparse tensor. Both the DEEM and EM algorithms require $K$ to be given. $\lambda$ is a tuning parameter for regularization; however, due to the long computation time, we experimented with several values of $\lambda$ and fixed it at 0.05 instead of tuning it for each setting. We encourage interested readers to try the two functions tune_K and tune_lambda in the R package TensorClustering.

We choose four of the seven settings in the paper because they are increasingly computationally expensive and, as shown in Table 2, they demonstrate the accuracy advantage of the DEEM algorithm over the classical EM algorithm. We also run three extra settings of our own, which we call M8, M9, and M10 to avoid confusion with the seven settings in the paper. The results for these extra settings are shown in Tables 4 and 5.

Setting   Parameters
M1        K = 2, p = 10 × 10 × 4, Σ1 = CS(0.3), Σ2 = AR(0.8), Σ3 = CS(0.3), B2,[1:6,1,1] = 0.5
M3        K = 3, p = 10 × 10 × 4, Σ1 = CS(0.3), Σ2 = AR(0.8), Σ3 = CS(0.5), B2,[1:6,1,1] = 0.5, B3,[1:6,1,1] = 0.5
M4        K = 4, p = 10 × 10 × 4, Σ1 = I10, Σ2 = AR(0.8), Σ3 = I4, B2,[1:6,1,1] = 0.8, B3,[1:6,1,1] = 0.8
M5        K = 6, p = 10 × 10 × 4, Σ1 = AR(0.9), Σ2 = CS(0.6), Σ3 = AR(0.9), B2,[1:6,1,1] = 0.6, B3,[1:6,1,1] = 1.2, B4,[1:6,1,1] = 1.8, B5,[1:6,1,1] = 2.4, B6,[1:6,1,1] = 3
M8        K = 2, p = 10 × 10 × 10, Σ1 = AR(0.5), Σ2 = CS(0.5), Σ3 = AR(0.5), B2,[1:6,1,1] = 1.5
M9        K = 2, p = 10 × 10 × 10, Σ1 = I10, Σ2 = I10, Σ3 = I10, B2,[1:6,1,1] = 1.5
M10       K = 2, p = 10 × 10 × 4 × 4 × 4, Σ1 = AR(0.5), Σ2 = CS(0.5), Σ3 = AR(0.5), Σ4 = I4, Σ5 = I4, B2,[1:6,1,1,1,1] = 5
Table 1: Simulation settings

3.3 Metrics

Note that calculating the mean error rate for a clustering problem is not as straightforward as for a classification problem. Both methods return labels for the groups; however, the group labels themselves are arbitrary. For example, if there are five observations whose true group labels are (1, 1, 2, 2, 2) and a method returns (2, 2, 1, 1, 1), the error rate should be 0. In the paper, the authors explain that the clustering error rate is calculated as

$$\min_{\Pi} \frac{1}{n} \sum_{i=1}^{n} \mathbf{1}\bigl( \hat{Y}_i \neq \Pi(Y_i) \bigr) \qquad \text{over all permutations } \Pi : \{1, \ldots, K\} \to \{1, \ldots, K\}.$$

We write a function (sketched below) that permutes the true labels, compares the estimated labels with each permuted version, and returns the lowest error rate. To compare the speed of the two methods, we also record the computation time. Tables 3 and 5 provide the mean computation time and standard error (in parentheses) for each setting.
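Our error-rate helper can be sketched as follows. It relies on gtools::permutations to enumerate all relabelings, which is feasible for the small values of K considered here; the function name is our own.

```r
library(gtools)

# clustering error rate minimized over all permutations of the K labels
cluster_error <- function(y_hat, y_true, K = max(y_true)) {
  perms <- permutations(K, K)                   # all K! relabelings of 1, ..., K
  errs  <- apply(perms, 1, function(pm) mean(y_hat != pm[y_true]))
  min(errs)
}

cluster_error(c(2, 2, 1, 1, 1), c(1, 1, 2, 2, 2))   # flipped labels -> error rate 0
```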

3.4 External R Packages and Functions

For the DEEM algorithm, we use the function DEEM; for the standard EM algorithm, we use the function TGMM. Both functions are from the R package TensorClustering. We use the Trnorm function from the R package Tlasso to generate tensor normal noise with the designated covariance matrices, and the permutations function from the gtools package to permute the true labels. In short, be sure to install the three R packages TensorClustering, Tlasso, and gtools if you would like to reproduce our simulation.

3.5 Simulation result

The error rates and computation times are shown in Tables 2 and 3. DEEM clearly has lower mean error rates in all four settings. The error rates are in general higher than those reported in the article, possibly because we did not tune the hyperparameters for each setting.

The computation time tells a different story, however. As seen in Table 3, DEEM is not always the winner in terms of speed. As the setting becomes more complicated and estimating the clusters becomes more challenging, DEEM takes longer to converge. For setting M5, DEEM may have reached the maximum number of iterations in some runs.

Setting DEEM EM
M1 0.41 (0.05) 0.45 (0.03)
M3 0.46 (0.09) 0.56 (0.05)
M4 0.35 (0.03) 0.57 (0.06)
M5 0.31 (0.11) 0.43 (0.06)
Table 2: Error Rates from 100 Replicates

Setting DEEM EM
M1 0.72 (0.45) 0.93 (0.39)
M3 13.95 (7.78) 7.74 (3.99)
M4 15.8 (0.81) 21.9 (9.68)
M5 332.96 (124.38) 14.66 (5.88)
Table 3: Computation Time (seconds) from 100 Replicates

Next, we present the values in the tables as figures. As shown in Figure 4, DEEM always has the lower mean error rate, but as the model becomes more complicated, its error rates become more variable, even though the mean remains lower. Figure 5 tells a more complicated story. (Note that we cannot use the same y-axis for all four plots, because DEEM's computation time for M5 is so long that some of the boxes would become very small and uninformative.) For settings M1 and M4, DEEM has the shorter computation time. For M3, DEEM's computation time is much more variable, and EM is faster overall. For M5, DEEM takes very long; in fact, the 100 replicates took almost 10 hours. It is unclear whether the reduced error rate is worth the computational cost.


Figure 4: Boxplots of Mean Error Rates from 100 Replicates


Figure 5: Boxplots of Mean Computation Time (in seconds) from 100 Replicates

Turning to the extra settings, we notice that DEEM's error rate is much lower than EM's for M8, in which we set two clusters whose covariance matrices have moderate correlation. M9 is designed to be an easy case, since the covariance matrices are all identity matrices, and DEEM still has a much lower mean error rate than EM. For M10, we use a fifth-order tensor and make the estimation task an easy one. As seen in Table 4, DEEM again has a much lower mean error rate than EM, although its computation time is once more quite long.

One recurring issue in our simulations is that DEEM has a much longer running time than EM. We leave this question about DEEM's computation time to interested readers.

Setting DEEM EM
M8 0.28 (0.11) 0.45 (0.03)
M9 0.18 (0.05) 0.34 (0.10)
M10 0.0003 (0.002) 0.31 (0.13)
Table 4: Error Rates from 100 Replicates

Setting DEEM EM
M8 23.7 (4.3) 0.17 (0.07)
M9 7.99 (1.88) 0.32 (0.09)
M10 72.8 (211.41) 1.51 (0.08)
Table 5: Computation Time (seconds) from 100 Replicates

4 Summary

In this blogpost, we reviewed the method proposed by Mai et al. (2022), which is essentially an upgraded version of the classical EM algorithm for tensor data. The new method, DEEM, tends to achieve lower error rates on tensor data. However, despite the paper's claim that the enhanced M-step in the DEEM algorithm facilitates fast covariance estimation, we encountered situations where the running time was prohibitive. While DEEM proves effective in handling tensor data, there remains potential for further enhancement, particularly in computational speed.

References

   Dempster, A. P., Laird, N. M. and Rubin, D. B. (1977) Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society, Series B, 39, 1–22.

   Kolda, T. G. and Bader, B. W. (2009) Tensor decompositions and applications. SIAM Review, 51, 455–500. URL: https://doi.org/10.1137/07070111X.

   Mai, Q., Zhang, X., Pan, Y. and Deng, K. (2022) A doubly enhanced EM algorithm for model-based tensor clustering. Journal of the American Statistical Association, 117, 2120–2134.

   Yang, M.-S., Lai, C.-Y. and Lin, C.-Y. (2012) A robust EM clustering algorithm for Gaussian mixture models. Pattern Recognition, 45, 3950–3961.

Appendix


Figure 6: DEEM algorithm