View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math4.legacy.distribution.fitting;
18  
19  import java.util.ArrayList;
20  import java.util.Arrays;
21  import java.util.List;
22  
23  import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
24  import org.apache.commons.math4.legacy.distribution.MultivariateNormalDistribution;
25  import org.apache.commons.math4.legacy.exception.ConvergenceException;
26  import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
27  import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
28  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
29  import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
30  import org.apache.commons.math4.legacy.linear.RealMatrix;
31  import org.apache.commons.math4.legacy.core.Pair;
32  import org.junit.Assert;
33  import org.junit.Test;
34  
35  /**
36   * Test that demonstrates the use of
37   * {@link MultivariateNormalMixtureExpectationMaximization}.
38   */
39  public class MultivariateNormalMixtureExpectationMaximizationTest {
40  
41      @Test(expected = NotStrictlyPositiveException.class)
42      public void testNonEmptyData() {
43          // Should not accept empty data
44          new MultivariateNormalMixtureExpectationMaximization(new double[][] {});
45      }
46  
47      @Test(expected = DimensionMismatchException.class)
48      public void testNonJaggedData() {
49          // Reject data with nonconstant numbers of columns
50          double[][] data = new double[][] {
51                  { 1, 2, 3 },
52                  { 4, 5, 6, 7 },
53          };
54          new MultivariateNormalMixtureExpectationMaximization(data);
55      }
56  
57      @Test(expected = NumberIsTooSmallException.class)
58      public void testMultipleColumnsRequired() {
59          // Data should have at least 1 column
60          double[][] data = new double[][] {
61                  {}, {}
62          };
63          new MultivariateNormalMixtureExpectationMaximization(data);
64      }
65  
66      @Test(expected = NotStrictlyPositiveException.class)
67      public void testMaxIterationsPositive() {
68          // Maximum iterations for fit must be positive integer
69          double[][] data = getTestSamples();
70          MultivariateNormalMixtureExpectationMaximization fitter =
71                  new MultivariateNormalMixtureExpectationMaximization(data);
72  
73          MixtureMultivariateNormalDistribution
74              initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
75  
76          fitter.fit(initialMix, 0, 1E-5);
77      }
78  
79      @Test(expected = NotStrictlyPositiveException.class)
80      public void testThresholdPositive() {
81          // Maximum iterations for fit must be positive
82          double[][] data = getTestSamples();
83          MultivariateNormalMixtureExpectationMaximization fitter =
84                  new MultivariateNormalMixtureExpectationMaximization(
85                      data);
86  
87          MixtureMultivariateNormalDistribution
88              initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
89  
90          fitter.fit(initialMix, 1000, 0);
91      }
92  
93      @Test(expected = ConvergenceException.class)
94      public void testConvergenceException() {
95          // ConvergenceException thrown if fit terminates before threshold met
96          double[][] data = getTestSamples();
97          MultivariateNormalMixtureExpectationMaximization fitter
98              = new MultivariateNormalMixtureExpectationMaximization(data);
99  
100         MixtureMultivariateNormalDistribution
101             initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
102 
103         // 5 iterations not enough to meet convergence threshold
104         fitter.fit(initialMix, 5, 1E-5);
105     }
106 
107     @Test(expected = DimensionMismatchException.class)
108     public void testIncompatibleInitialMixture() {
109         // Data has 3 columns
110         double[][] data = new double[][] {
111                 { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }
112         };
113         double[] weights = new double[] { 0.5, 0.5 };
114 
115         // These distributions are compatible with 2-column data, not 3-column
116         // data
117         MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
118 
119         mvns[0] = new MultivariateNormalDistribution(new double[] {
120                         -0.0021722935000328823, 3.5432892936887908 },
121                         new double[][] {
122                                 { 4.537422569229048, 3.5266152281729304 },
123                                 { 3.5266152281729304, 6.175448814169779 } });
124         mvns[1] = new MultivariateNormalDistribution(new double[] {
125                         5.090902706507635, 8.68540656355283 }, new double[][] {
126                         { 2.886778573963039, 1.5257474543463154 },
127                         { 1.5257474543463154, 3.3794567673616918 } });
128 
129         // Create components and mixture
130         List<Pair<Double, MultivariateNormalDistribution>> components =
131                 new ArrayList<>();
132         components.add(new Pair<>(
133                 weights[0], mvns[0]));
134         components.add(new Pair<>(
135                 weights[1], mvns[1]));
136 
137         MixtureMultivariateNormalDistribution badInitialMix
138             = new MixtureMultivariateNormalDistribution(components);
139 
140         MultivariateNormalMixtureExpectationMaximization fitter
141             = new MultivariateNormalMixtureExpectationMaximization(data);
142 
143         fitter.fit(badInitialMix);
144     }
145 
146     @Test
147     public void testInitialMixture() {
148         // Testing initial mixture estimated from data
149         final double[] correctWeights = new double[] { 0.5, 0.5 };
150 
151         final double[][] correctMeans = new double[][] {
152             {-0.0021722935000328823, 3.5432892936887908},
153             {5.090902706507635, 8.68540656355283},
154         };
155 
156         final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
157 
158         correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
159                 { 4.537422569229048, 3.5266152281729304 },
160                 { 3.5266152281729304, 6.175448814169779 } });
161 
162         correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
163                 { 2.886778573963039, 1.5257474543463154 },
164                 { 1.5257474543463154, 3.3794567673616918 } });
165 
166         final MultivariateNormalDistribution[] correctMVNs = new
167                 MultivariateNormalDistribution[2];
168 
169         correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
170                 correctCovMats[0].getData());
171 
172         correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
173                 correctCovMats[1].getData());
174 
175         final MixtureMultivariateNormalDistribution initialMix
176             = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
177 
178         int i = 0;
179         for (Pair<Double, MultivariateNormalDistribution> component : initialMix
180                 .getComponents()) {
181             Assert.assertEquals(correctWeights[i], component.getFirst(),
182                     Math.ulp(1d));
183 
184             final double[] means = component.getValue().getMeans();
185             Assert.assertArrayEquals(correctMeans[i], means, 0.0);
186 
187             final RealMatrix covMat = component.getValue().getCovariances();
188             Assert.assertEquals(correctCovMats[i], covMat);
189             i++;
190         }
191     }
192 
193     @Test
194     public void testFit2Dimensions2Components() {
195         final double[][] data = getTestSamples();
196 
197         // Fit using the test samples using Matlab R2023b (Update 6):
198         // GMModel = fitgmdist(X,2);
199 
200         // Expected results use the component order generated by the CM code for convenience
201         // i.e. ComponentProportion from matlab is reversed: [0.703722, 0.296278]
202 
203         // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
204         final double logLikelihood = -4.292430883324220e+02 / data.length;
205         // ComponentProportion
206         final double[] weights = new double[] {0.2962324189652912, 0.7037675810347089};
207         // mu
208         final double[][] means = new double[][]{
209             {-1.421239458366293, 1.692604555824222},
210             {4.213949861591596, 7.975974466776790}
211         };
212         // Sigma
213         final double[][][] covar = new double[][][] {
214             {{1.739441346307267, -0.586740858187563},
215              {-0.586740858187563, 1.023420964341543}},
216             {{4.243780645051973, 2.578176622652551},
217              {2.578176622652551, 3.918302056479298}}
218         };
219 
220         assertFit(data, 2, logLikelihood, weights, means, covar, 1e-3);
221     }
222 
223     @Test
224     public void testFit1Dimension2Components() {
225         // Use only the first column of the test data
226         final double[][] data = Arrays.stream(getTestSamples())
227             .map(x -> new double[] {x[0]}).toArray(double[][]::new);
228 
229         // Fit the first column of test samples using Matlab R2023b (Update 6):
230         // GMModel = fitgmdist(X,2);
231 
232         // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
233         final double logLikelihood = -2.512197016873482e+02 / data.length;
234         // ComponentProportion
235         final double[] weights = new double[] {0.240510201974078, 0.759489798025922};
236         // Since data has 1 dimension the means and covariances are single values
237         // mu
238         final double[][] means = new double[][]{
239             {-1.736139126623031},
240             {3.899886984922886}
241         };
242         // Sigma
243         final double[][][] covar = new double[][][] {
244             {{1.371327786710623}},
245             {{5.254286022455004}}
246         };
247 
248         assertFit(data, 2, logLikelihood, weights, means, covar, 0.05);
249     }
250 
251     @Test
252     public void testFit1Dimension1Component() {
253         // Use only the first column of the test data
254         final double[][] data = Arrays.stream(getTestSamples())
255             .map(x -> new double[] {x[0]}).toArray(double[][]::new);
256 
257         // Fit the first column of test samples using Matlab R2023b (Update 6):
258         // GMModel = fitgmdist(X,1);
259 
260         // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
261         final double logLikelihood = -2.576329329354790e+02 / data.length;
262         // ComponentProportion
263         final double[] weights = new double[] {1.0};
264         // Since data has 1 dimension the means and covariances are single values
265         // mu
266         final double[][] means = new double[][]{
267             {2.544365206503801},
268         };
269         // Sigma
270         final double[][][] covar = new double[][][] {
271             {{10.122711799089901}},
272         };
273 
274         assertFit(data, 1, logLikelihood, weights, means, covar, 1e-3);
275     }
276 
277     private static void assertFit(double[][] data, int numComponents,
278             double logLikelihood, double[] weights,
279             double[][] means, double[][][] covar, double relError) {
280         MultivariateNormalMixtureExpectationMaximization fitter
281             = new MultivariateNormalMixtureExpectationMaximization(data);
282 
283         MixtureMultivariateNormalDistribution initialMix
284             = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
285         fitter.fit(initialMix);
286         MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
287         List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
288 
289         Assert.assertEquals(logLikelihood,
290             fitter.getLogLikelihood(),
291             Math.abs(logLikelihood) * relError);
292 
293         int i = 0;
294         for (Pair<Double, MultivariateNormalDistribution> component : components) {
295             final double weight = component.getFirst();
296             final MultivariateNormalDistribution mvn = component.getSecond();
297             Assert.assertEquals(weights[i], weight, weights[i] * relError);
298             assertArrayEquals(means[i], mvn.getMeans(), relError);
299             final double[][] c = mvn.getCovariances().getData();
300             Assert.assertEquals(covar[i].length, c.length);
301             for (int j = 0; j < covar[i].length; j++) {
302                 assertArrayEquals(covar[i][j], c[j], relError);
303             }
304             i++;
305         }
306     }
307 
308     private static void assertArrayEquals(double[] e, double[] a, double relError) {
309         Assert.assertEquals("length", e.length, a.length);
310         for (int i = 0; i < e.length; i++) {
311             Assert.assertEquals(e[i], a[i], Math.abs(e[i]) * relError);
312         }
313     }
314 
315     private double[][] getTestSamples() {
316         // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
317         // [4, 8.2]
318         return new double[][] { { 7.358553610469948, 11.31260831446758 },
319                 { 7.175770420124739, 8.988812210204454 },
320                 { 4.324151905768422, 6.837727899051482 },
321                 { 2.157832219173036, 6.317444585521968 },
322                 { -1.890157421896651, 1.74271202875498 },
323                 { 0.8922409354455803, 1.999119343923781 },
324                 { 3.396949764787055, 6.813170372579068 },
325                 { -2.057498232686068, -0.002522983830852255 },
326                 { 6.359932157365045, 8.343600029975851 },
327                 { 3.353102234276168, 7.087541882898689 },
328                 { -1.763877221595639, 0.9688890460330644 },
329                 { 6.151457185125111, 9.075011757431174 },
330                 { 4.281597398048899, 5.953270070976117 },
331                 { 3.549576703974894, 8.616038155992861 },
332                 { 6.004706732349854, 8.959423391087469 },
333                 { 2.802915014676262, 6.285676742173564 },
334                 { -0.6029879029880616, 1.083332958357485 },
335                 { 3.631827105398369, 6.743428504049444 },
336                 { 6.161125014007315, 9.60920569689001 },
337                 { -1.049582894255342, 0.2020017892080281 },
338                 { 3.910573022688315, 8.19609909534937 },
339                 { 8.180454017634863, 7.861055769719962 },
340                 { 1.488945440439716, 8.02699903761247 },
341                 { 4.813750847823778, 12.34416881332515 },
342                 { 0.0443208501259158, 5.901148093240691 },
343                 { 4.416417235068346, 4.465243084006094 },
344                 { 4.0002433603072, 6.721937850166174 },
345                 { 3.190113818788205, 10.51648348411058 },
346                 { 4.493600914967883, 7.938224231022314 },
347                 { -3.675669533266189, 4.472845076673303 },
348                 { 6.648645511703989, 12.03544085965724 },
349                 { -1.330031331404445, 1.33931042964811 },
350                 { -3.812111460708707, 2.50534195568356 },
351                 { 5.669339356648331, 6.214488981177026 },
352                 { 1.006596727153816, 1.51165463112716 },
353                 { 5.039466365033024, 7.476532610478689 },
354                 { 4.349091929968925, 7.446356406259756 },
355                 { -1.220289665119069, 3.403926955951437 },
356                 { 5.553003979122395, 6.886518211202239 },
357                 { 2.274487732222856, 7.009541508533196 },
358                 { 4.147567059965864, 7.34025244349202 },
359                 { 4.083882618965819, 6.362852861075623 },
360                 { 2.203122344647599, 7.260295257904624 },
361                 { -2.147497550770442, 1.262293431529498 },
362                 { 2.473700950426512, 6.558900135505638 },
363                 { 8.267081298847554, 12.10214104577748 },
364                 { 6.91977329776865, 9.91998488301285 },
365                 { 0.1680479852730894, 6.28286034168897 },
366                 { -1.268578659195158, 2.326711221485755 },
367                 { 1.829966451374701, 6.254187605304518 },
368                 { 5.648849025754848, 9.330002040750291 },
369                 { -2.302874793257666, 3.585545172776065 },
370                 { -2.629218791709046, 2.156215538500288 },
371                 { 4.036618140700114, 10.2962785719958 },
372                 { 0.4616386422783874, 0.6782756325806778 },
373                 { -0.3447896073408363, 0.4999834691645118 },
374                 { -0.475281453118318, 1.931470384180492 },
375                 { 2.382509690609731, 6.071782429815853 },
376                 { -3.203934441889096, 2.572079552602468 },
377                 { 8.465636032165087, 13.96462998683518 },
378                 { 2.36755660870416, 5.7844595007273 },
379                 { 0.5935496528993371, 1.374615871358943 },
380                 { -2.467481505748694, 2.097224634713005 },
381                 { 4.27867444328542, 10.24772361238549 },
382                 { -2.013791907543137, 2.013799426047639 },
383                 { 6.424588084404173, 9.185334939684516 },
384                 { -0.8448238876802175, 0.5447382022282812 },
385                 { 1.342955703473923, 8.645456317633556 },
386                 { 3.108712208751979, 8.512156853800064 },
387                 { 4.343205178315472, 8.056869549234374 },
388                 { -2.971767642212396, 3.201180146824761 },
389                 { 2.583820931523672, 5.459873414473854 },
390                 { 4.209139115268925, 8.171098193546225 },
391                 { 0.4064909057902746, 1.454390775518743 },
392                 { 3.068642411145223, 6.959485153620035 },
393                 { 6.085968972900461, 7.391429799500965 },
394                 { -1.342265795764202, 1.454550012997143 },
395                 { 6.249773274516883, 6.290269880772023 },
396                 { 4.986225847822566, 7.75266344868907 },
397                 { 7.642443254378944, 10.19914817500263 },
398                 { 6.438181159163673, 8.464396764810347 },
399                 { 2.520859761025108, 7.68222425260111 },
400                 { 2.883699944257541, 6.777960331348503 },
401                 { 2.788004550956599, 6.634735386652733 },
402                 { 3.331661231995638, 5.794191300046592 },
403                 { 3.526172276645504, 6.710802266815884 },
404                 { 3.188298528138741, 10.34495528210205 },
405                 { 0.7345539486114623, 5.807604004180681 },
406                 { 1.165044595880125, 7.830121829295257 },
407                 { 7.146962523500671, 11.62995162065415 },
408                 { 7.813872137162087, 10.62827008714735 },
409                 { 3.118099164870063, 8.286003148186371 },
410                 { -1.708739286262571, 1.561026755374264 },
411                 { 1.786163047580084, 4.172394388214604 },
412                 { 3.718506403232386, 7.807752990130349 },
413                 { 6.167414046828899, 10.01104941031293 },
414                 { -1.063477247689196, 1.61176085846339 },
415                 { -3.396739609433642, 0.7127911050002151 },
416                 { 2.438885945896797, 7.353011138689225 },
417                 { -0.2073204144780931, 0.850771146627012 }, };
418     }
419 }