1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.distribution.fitting;
18
19 import java.util.ArrayList;
20 import java.util.Arrays;
21 import java.util.List;
22
23 import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
24 import org.apache.commons.math4.legacy.distribution.MultivariateNormalDistribution;
25 import org.apache.commons.math4.legacy.exception.ConvergenceException;
26 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
27 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
28 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
29 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
30 import org.apache.commons.math4.legacy.linear.RealMatrix;
31 import org.apache.commons.math4.legacy.core.Pair;
32 import org.junit.Assert;
33 import org.junit.Test;
34
35
36
37
38
39 public class MultivariateNormalMixtureExpectationMaximizationTest {
40
41 @Test(expected = NotStrictlyPositiveException.class)
42 public void testNonEmptyData() {
43
44 new MultivariateNormalMixtureExpectationMaximization(new double[][] {});
45 }
46
47 @Test(expected = DimensionMismatchException.class)
48 public void testNonJaggedData() {
49
50 double[][] data = new double[][] {
51 { 1, 2, 3 },
52 { 4, 5, 6, 7 },
53 };
54 new MultivariateNormalMixtureExpectationMaximization(data);
55 }
56
57 @Test(expected = NumberIsTooSmallException.class)
58 public void testMultipleColumnsRequired() {
59
60 double[][] data = new double[][] {
61 {}, {}
62 };
63 new MultivariateNormalMixtureExpectationMaximization(data);
64 }
65
66 @Test(expected = NotStrictlyPositiveException.class)
67 public void testMaxIterationsPositive() {
68
69 double[][] data = getTestSamples();
70 MultivariateNormalMixtureExpectationMaximization fitter =
71 new MultivariateNormalMixtureExpectationMaximization(data);
72
73 MixtureMultivariateNormalDistribution
74 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
75
76 fitter.fit(initialMix, 0, 1E-5);
77 }
78
79 @Test(expected = NotStrictlyPositiveException.class)
80 public void testThresholdPositive() {
81
82 double[][] data = getTestSamples();
83 MultivariateNormalMixtureExpectationMaximization fitter =
84 new MultivariateNormalMixtureExpectationMaximization(
85 data);
86
87 MixtureMultivariateNormalDistribution
88 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
89
90 fitter.fit(initialMix, 1000, 0);
91 }
92
93 @Test(expected = ConvergenceException.class)
94 public void testConvergenceException() {
95
96 double[][] data = getTestSamples();
97 MultivariateNormalMixtureExpectationMaximization fitter
98 = new MultivariateNormalMixtureExpectationMaximization(data);
99
100 MixtureMultivariateNormalDistribution
101 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
102
103
104 fitter.fit(initialMix, 5, 1E-5);
105 }
106
107 @Test(expected = DimensionMismatchException.class)
108 public void testIncompatibleInitialMixture() {
109
110 double[][] data = new double[][] {
111 { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }
112 };
113 double[] weights = new double[] { 0.5, 0.5 };
114
115
116
117 MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
118
119 mvns[0] = new MultivariateNormalDistribution(new double[] {
120 -0.0021722935000328823, 3.5432892936887908 },
121 new double[][] {
122 { 4.537422569229048, 3.5266152281729304 },
123 { 3.5266152281729304, 6.175448814169779 } });
124 mvns[1] = new MultivariateNormalDistribution(new double[] {
125 5.090902706507635, 8.68540656355283 }, new double[][] {
126 { 2.886778573963039, 1.5257474543463154 },
127 { 1.5257474543463154, 3.3794567673616918 } });
128
129
130 List<Pair<Double, MultivariateNormalDistribution>> components =
131 new ArrayList<>();
132 components.add(new Pair<>(
133 weights[0], mvns[0]));
134 components.add(new Pair<>(
135 weights[1], mvns[1]));
136
137 MixtureMultivariateNormalDistribution badInitialMix
138 = new MixtureMultivariateNormalDistribution(components);
139
140 MultivariateNormalMixtureExpectationMaximization fitter
141 = new MultivariateNormalMixtureExpectationMaximization(data);
142
143 fitter.fit(badInitialMix);
144 }
145
146 @Test
147 public void testInitialMixture() {
148
149 final double[] correctWeights = new double[] { 0.5, 0.5 };
150
151 final double[][] correctMeans = new double[][] {
152 {-0.0021722935000328823, 3.5432892936887908},
153 {5.090902706507635, 8.68540656355283},
154 };
155
156 final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
157
158 correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
159 { 4.537422569229048, 3.5266152281729304 },
160 { 3.5266152281729304, 6.175448814169779 } });
161
162 correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
163 { 2.886778573963039, 1.5257474543463154 },
164 { 1.5257474543463154, 3.3794567673616918 } });
165
166 final MultivariateNormalDistribution[] correctMVNs = new
167 MultivariateNormalDistribution[2];
168
169 correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
170 correctCovMats[0].getData());
171
172 correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
173 correctCovMats[1].getData());
174
175 final MixtureMultivariateNormalDistribution initialMix
176 = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
177
178 int i = 0;
179 for (Pair<Double, MultivariateNormalDistribution> component : initialMix
180 .getComponents()) {
181 Assert.assertEquals(correctWeights[i], component.getFirst(),
182 Math.ulp(1d));
183
184 final double[] means = component.getValue().getMeans();
185 Assert.assertArrayEquals(correctMeans[i], means, 0.0);
186
187 final RealMatrix covMat = component.getValue().getCovariances();
188 Assert.assertEquals(correctCovMats[i], covMat);
189 i++;
190 }
191 }
192
193 @Test
194 public void testFit2Dimensions2Components() {
195 final double[][] data = getTestSamples();
196
197
198
199
200
201
202
203
204 final double logLikelihood = -4.292430883324220e+02 / data.length;
205
206 final double[] weights = new double[] {0.2962324189652912, 0.7037675810347089};
207
208 final double[][] means = new double[][]{
209 {-1.421239458366293, 1.692604555824222},
210 {4.213949861591596, 7.975974466776790}
211 };
212
213 final double[][][] covar = new double[][][] {
214 {{1.739441346307267, -0.586740858187563},
215 {-0.586740858187563, 1.023420964341543}},
216 {{4.243780645051973, 2.578176622652551},
217 {2.578176622652551, 3.918302056479298}}
218 };
219
220 assertFit(data, 2, logLikelihood, weights, means, covar, 1e-3);
221 }
222
223 @Test
224 public void testFit1Dimension2Components() {
225
226 final double[][] data = Arrays.stream(getTestSamples())
227 .map(x -> new double[] {x[0]}).toArray(double[][]::new);
228
229
230
231
232
233 final double logLikelihood = -2.512197016873482e+02 / data.length;
234
235 final double[] weights = new double[] {0.240510201974078, 0.759489798025922};
236
237
238 final double[][] means = new double[][]{
239 {-1.736139126623031},
240 {3.899886984922886}
241 };
242
243 final double[][][] covar = new double[][][] {
244 {{1.371327786710623}},
245 {{5.254286022455004}}
246 };
247
248 assertFit(data, 2, logLikelihood, weights, means, covar, 0.05);
249 }
250
251 @Test
252 public void testFit1Dimension1Component() {
253
254 final double[][] data = Arrays.stream(getTestSamples())
255 .map(x -> new double[] {x[0]}).toArray(double[][]::new);
256
257
258
259
260
261 final double logLikelihood = -2.576329329354790e+02 / data.length;
262
263 final double[] weights = new double[] {1.0};
264
265
266 final double[][] means = new double[][]{
267 {2.544365206503801},
268 };
269
270 final double[][][] covar = new double[][][] {
271 {{10.122711799089901}},
272 };
273
274 assertFit(data, 1, logLikelihood, weights, means, covar, 1e-3);
275 }
276
277 private static void assertFit(double[][] data, int numComponents,
278 double logLikelihood, double[] weights,
279 double[][] means, double[][][] covar, double relError) {
280 MultivariateNormalMixtureExpectationMaximization fitter
281 = new MultivariateNormalMixtureExpectationMaximization(data);
282
283 MixtureMultivariateNormalDistribution initialMix
284 = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
285 fitter.fit(initialMix);
286 MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
287 List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
288
289 Assert.assertEquals(logLikelihood,
290 fitter.getLogLikelihood(),
291 Math.abs(logLikelihood) * relError);
292
293 int i = 0;
294 for (Pair<Double, MultivariateNormalDistribution> component : components) {
295 final double weight = component.getFirst();
296 final MultivariateNormalDistribution mvn = component.getSecond();
297 Assert.assertEquals(weights[i], weight, weights[i] * relError);
298 assertArrayEquals(means[i], mvn.getMeans(), relError);
299 final double[][] c = mvn.getCovariances().getData();
300 Assert.assertEquals(covar[i].length, c.length);
301 for (int j = 0; j < covar[i].length; j++) {
302 assertArrayEquals(covar[i][j], c[j], relError);
303 }
304 i++;
305 }
306 }
307
308 private static void assertArrayEquals(double[] e, double[] a, double relError) {
309 Assert.assertEquals("length", e.length, a.length);
310 for (int i = 0; i < e.length; i++) {
311 Assert.assertEquals(e[i], a[i], Math.abs(e[i]) * relError);
312 }
313 }
314
315 private double[][] getTestSamples() {
316
317
318 return new double[][] { { 7.358553610469948, 11.31260831446758 },
319 { 7.175770420124739, 8.988812210204454 },
320 { 4.324151905768422, 6.837727899051482 },
321 { 2.157832219173036, 6.317444585521968 },
322 { -1.890157421896651, 1.74271202875498 },
323 { 0.8922409354455803, 1.999119343923781 },
324 { 3.396949764787055, 6.813170372579068 },
325 { -2.057498232686068, -0.002522983830852255 },
326 { 6.359932157365045, 8.343600029975851 },
327 { 3.353102234276168, 7.087541882898689 },
328 { -1.763877221595639, 0.9688890460330644 },
329 { 6.151457185125111, 9.075011757431174 },
330 { 4.281597398048899, 5.953270070976117 },
331 { 3.549576703974894, 8.616038155992861 },
332 { 6.004706732349854, 8.959423391087469 },
333 { 2.802915014676262, 6.285676742173564 },
334 { -0.6029879029880616, 1.083332958357485 },
335 { 3.631827105398369, 6.743428504049444 },
336 { 6.161125014007315, 9.60920569689001 },
337 { -1.049582894255342, 0.2020017892080281 },
338 { 3.910573022688315, 8.19609909534937 },
339 { 8.180454017634863, 7.861055769719962 },
340 { 1.488945440439716, 8.02699903761247 },
341 { 4.813750847823778, 12.34416881332515 },
342 { 0.0443208501259158, 5.901148093240691 },
343 { 4.416417235068346, 4.465243084006094 },
344 { 4.0002433603072, 6.721937850166174 },
345 { 3.190113818788205, 10.51648348411058 },
346 { 4.493600914967883, 7.938224231022314 },
347 { -3.675669533266189, 4.472845076673303 },
348 { 6.648645511703989, 12.03544085965724 },
349 { -1.330031331404445, 1.33931042964811 },
350 { -3.812111460708707, 2.50534195568356 },
351 { 5.669339356648331, 6.214488981177026 },
352 { 1.006596727153816, 1.51165463112716 },
353 { 5.039466365033024, 7.476532610478689 },
354 { 4.349091929968925, 7.446356406259756 },
355 { -1.220289665119069, 3.403926955951437 },
356 { 5.553003979122395, 6.886518211202239 },
357 { 2.274487732222856, 7.009541508533196 },
358 { 4.147567059965864, 7.34025244349202 },
359 { 4.083882618965819, 6.362852861075623 },
360 { 2.203122344647599, 7.260295257904624 },
361 { -2.147497550770442, 1.262293431529498 },
362 { 2.473700950426512, 6.558900135505638 },
363 { 8.267081298847554, 12.10214104577748 },
364 { 6.91977329776865, 9.91998488301285 },
365 { 0.1680479852730894, 6.28286034168897 },
366 { -1.268578659195158, 2.326711221485755 },
367 { 1.829966451374701, 6.254187605304518 },
368 { 5.648849025754848, 9.330002040750291 },
369 { -2.302874793257666, 3.585545172776065 },
370 { -2.629218791709046, 2.156215538500288 },
371 { 4.036618140700114, 10.2962785719958 },
372 { 0.4616386422783874, 0.6782756325806778 },
373 { -0.3447896073408363, 0.4999834691645118 },
374 { -0.475281453118318, 1.931470384180492 },
375 { 2.382509690609731, 6.071782429815853 },
376 { -3.203934441889096, 2.572079552602468 },
377 { 8.465636032165087, 13.96462998683518 },
378 { 2.36755660870416, 5.7844595007273 },
379 { 0.5935496528993371, 1.374615871358943 },
380 { -2.467481505748694, 2.097224634713005 },
381 { 4.27867444328542, 10.24772361238549 },
382 { -2.013791907543137, 2.013799426047639 },
383 { 6.424588084404173, 9.185334939684516 },
384 { -0.8448238876802175, 0.5447382022282812 },
385 { 1.342955703473923, 8.645456317633556 },
386 { 3.108712208751979, 8.512156853800064 },
387 { 4.343205178315472, 8.056869549234374 },
388 { -2.971767642212396, 3.201180146824761 },
389 { 2.583820931523672, 5.459873414473854 },
390 { 4.209139115268925, 8.171098193546225 },
391 { 0.4064909057902746, 1.454390775518743 },
392 { 3.068642411145223, 6.959485153620035 },
393 { 6.085968972900461, 7.391429799500965 },
394 { -1.342265795764202, 1.454550012997143 },
395 { 6.249773274516883, 6.290269880772023 },
396 { 4.986225847822566, 7.75266344868907 },
397 { 7.642443254378944, 10.19914817500263 },
398 { 6.438181159163673, 8.464396764810347 },
399 { 2.520859761025108, 7.68222425260111 },
400 { 2.883699944257541, 6.777960331348503 },
401 { 2.788004550956599, 6.634735386652733 },
402 { 3.331661231995638, 5.794191300046592 },
403 { 3.526172276645504, 6.710802266815884 },
404 { 3.188298528138741, 10.34495528210205 },
405 { 0.7345539486114623, 5.807604004180681 },
406 { 1.165044595880125, 7.830121829295257 },
407 { 7.146962523500671, 11.62995162065415 },
408 { 7.813872137162087, 10.62827008714735 },
409 { 3.118099164870063, 8.286003148186371 },
410 { -1.708739286262571, 1.561026755374264 },
411 { 1.786163047580084, 4.172394388214604 },
412 { 3.718506403232386, 7.807752990130349 },
413 { 6.167414046828899, 10.01104941031293 },
414 { -1.063477247689196, 1.61176085846339 },
415 { -3.396739609433642, 0.7127911050002151 },
416 { 2.438885945896797, 7.353011138689225 },
417 { -0.2073204144780931, 0.850771146627012 }, };
418 }
419 }