1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.distribution;
18
19 import java.util.ArrayList;
20 import java.util.List;
21
22 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
23 import org.apache.commons.math4.legacy.exception.MathArithmeticException;
24 import org.apache.commons.math4.legacy.exception.NotPositiveException;
25 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
26 import org.apache.commons.rng.UniformRandomProvider;
27 import org.apache.commons.math4.legacy.core.Pair;
28
29
30
31
32
33
34
35
36
37 public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
38 extends AbstractMultivariateRealDistribution {
39
40 private final double[] weight;
41
42 private final List<T> distribution;
43
44
45
46
47
48
49
50
51
52
53 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
54 super(components.get(0).getSecond().getDimension());
55
56 final int numComp = components.size();
57 final int dim = getDimension();
58 double weightSum = 0;
59 for (int i = 0; i < numComp; i++) {
60 final Pair<Double, T> comp = components.get(i);
61 if (comp.getSecond().getDimension() != dim) {
62 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
63 }
64 if (comp.getFirst() < 0) {
65 throw new NotPositiveException(comp.getFirst());
66 }
67 weightSum += comp.getFirst();
68 }
69
70
71 if (Double.isInfinite(weightSum)) {
72 throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
73 }
74
75
76 distribution = new ArrayList<>();
77 weight = new double[numComp];
78 for (int i = 0; i < numComp; i++) {
79 final Pair<Double, T> comp = components.get(i);
80 weight[i] = comp.getFirst() / weightSum;
81 distribution.add(comp.getSecond());
82 }
83 }
84
85
86 @Override
87 public double density(final double[] values) {
88 double p = 0;
89 for (int i = 0; i < weight.length; i++) {
90 p += weight[i] * distribution.get(i).density(values);
91 }
92 return p;
93 }
94
95
96
97
98
99
100 public List<Pair<Double, T>> getComponents() {
101 final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
102
103 for (int i = 0; i < weight.length; i++) {
104 list.add(new Pair<>(weight[i], distribution.get(i)));
105 }
106
107 return list;
108 }
109
110
111 @Override
112 public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) {
113 return new MixtureSampler(rng);
114 }
115
116
117
118
119 private final class MixtureSampler implements MultivariateRealDistribution.Sampler {
120
121 private final UniformRandomProvider rng;
122
123 private final MultivariateRealDistribution.Sampler[] samplers;
124
125
126
127
128 MixtureSampler(UniformRandomProvider generator) {
129 rng = generator;
130
131 samplers = new MultivariateRealDistribution.Sampler[weight.length];
132 for (int i = 0; i < weight.length; i++) {
133 samplers[i] = distribution.get(i).createSampler(rng);
134 }
135 }
136
137
138 @Override
139 public double[] sample() {
140
141 double[] vals = null;
142
143
144 final double randomValue = rng.nextDouble();
145 double sum = 0;
146
147 for (int i = 0; i < weight.length; i++) {
148 sum += weight[i];
149 if (randomValue <= sum) {
150
151 vals = samplers[i].sample();
152 break;
153 }
154 }
155
156 if (vals == null) {
157
158
159
160 vals = samplers[weight.length - 1].sample();
161 }
162
163 return vals;
164 }
165 }
166 }