1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling;
19
20 import java.util.List;
21 import java.util.Objects;
22 import java.util.ArrayList;
23
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
26 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
27 import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
28 import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;
29 import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
30 import org.apache.commons.rng.sampling.distribution.LongSampler;
31 import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
32 import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
33 import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
34 import org.apache.commons.rng.sampling.distribution.SharedStateLongSampler;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 public final class CompositeSamplers {
71
72
73
74
75
76 public interface DiscreteProbabilitySamplerFactory {
77
78
79
80
81
82
83
84 DiscreteSampler create(UniformRandomProvider rng,
85 double[] probabilities);
86 }
87
88
89
90
91
92
93
94
95 public enum DiscreteProbabilitySampler implements DiscreteProbabilitySamplerFactory {
96
97 GUIDE_TABLE {
98 @Override
99 public SharedStateDiscreteSampler create(UniformRandomProvider rng, double[] probabilities) {
100 return GuideTableDiscreteSampler.of(rng, probabilities);
101 }
102 },
103
104 ALIAS_METHOD {
105 @Override
106 public SharedStateDiscreteSampler create(UniformRandomProvider rng, double[] probabilities) {
107 return AliasMethodDiscreteSampler.of(rng, probabilities);
108 }
109 },
110
111
112
113
114
115 LOOKUP_TABLE {
116 @Override
117 public SharedStateDiscreteSampler create(UniformRandomProvider rng, double[] probabilities) {
118 return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
119 }
120 }
121 }
122
123
124
125
126
127
128 private static class SharedStateDiscreteProbabilitySampler implements SharedStateDiscreteSampler {
129
130 private final DiscreteSampler sampler;
131
132 private final DiscreteProbabilitySamplerFactory factory;
133
134 private final double[] probabilities;
135
136
137
138
139
140
141
142 SharedStateDiscreteProbabilitySampler(DiscreteSampler sampler,
143 DiscreteProbabilitySamplerFactory factory,
144 double[] probabilities) {
145 this.sampler = Objects.requireNonNull(sampler, "discrete sampler");
146
147 this.factory = factory;
148 this.probabilities = probabilities;
149 }
150
151 @Override
152 public int sample() {
153
154 return sampler.sample();
155 }
156
157 @Override
158 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
159
160 return new SharedStateDiscreteProbabilitySampler(factory.create(rng, probabilities.clone()),
161 factory, probabilities);
162 }
163 }
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187 public interface Builder<S> {
188
189
190
191
192
193
194 int size();
195
196
197
198
199
200
201
202
203
204
205 Builder<S> add(S sampler, double weight);
206
207
208
209
210
211
212
213
214
215
216
217 Builder<S> setFactory(DiscreteProbabilitySamplerFactory factory);
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232 S build(UniformRandomProvider rng);
233 }
234
235
236
237
238
239
240
241
242
243
244
245
246 private static class SamplerBuilder<S> implements Builder<S> {
247
248 private final Specialisation specialisation;
249
250 private final List<WeightedSampler<S>> weightedSamplers;
251
252 private DiscreteProbabilitySamplerFactory factory;
253
254 private final SamplerFactory<S> compositeFactory;
255
256
257
258
259
260
261 enum Specialisation {
262
263 SHARED_STATE_SAMPLER,
264
265 NONE
266 }
267
268
269
270
271
272
273
274
275
276 interface SamplerFactory<S> {
277
278
279
280
281
282
283
284
285
286
287
288
289 S createSampler(DiscreteSampler discreteSampler,
290 List<S> samplers);
291 }
292
293
294
295
296
297
298 private static class WeightedSampler<S> {
299
300 private final double weight;
301
302 private final S sampler;
303
304
305
306
307
308
309
310 WeightedSampler(double weight, S sampler) {
311 this.weight = requirePositiveFinite(weight, "weight");
312 this.sampler = Objects.requireNonNull(sampler, "sampler");
313 }
314
315
316
317
318
319
320 double getWeight() {
321 return weight;
322 }
323
324
325
326
327
328
329 S getSampler() {
330 return sampler;
331 }
332
333
334
335
336
337
338
339
340
341
342
343 private static double requirePositiveFinite(double value, String message) {
344
345 if (!(value >= 0 && value < Double.POSITIVE_INFINITY)) {
346 throw new IllegalArgumentException(message + " is not positive finite: " + value);
347 }
348 return value;
349 }
350 }
351
352
353
354
355
356 SamplerBuilder(Specialisation specialisation,
357 SamplerFactory<S> compositeFactory) {
358 this.specialisation = specialisation;
359 this.compositeFactory = compositeFactory;
360 weightedSamplers = new ArrayList<>();
361 factory = DiscreteProbabilitySampler.GUIDE_TABLE;
362 }
363
364 @Override
365 public int size() {
366 return weightedSamplers.size();
367 }
368
369 @Override
370 public Builder<S> add(S sampler, double weight) {
371
372 if (weight != 0) {
373 weightedSamplers.add(new WeightedSampler<>(weight, sampler));
374 }
375 return this;
376 }
377
378
379
380
381
382
383
384 @Override
385 public Builder<S> setFactory(DiscreteProbabilitySamplerFactory samplerFactory) {
386 this.factory = Objects.requireNonNull(samplerFactory, "factory");
387 return this;
388 }
389
390
391
392
393
394
395
396
397
398 @Override
399 public S build(UniformRandomProvider rng) {
400 final List<WeightedSampler<S>> list = this.weightedSamplers;
401 final int n = list.size();
402 if (n == 0) {
403 throw new IllegalStateException("No samplers to build the composite");
404 }
405 if (n == 1) {
406
407 final S sampler = list.get(0).sampler;
408 reset();
409 return sampler;
410 }
411
412
413 final double[] weights = new double[n];
414 final ArrayList<S> samplers = new ArrayList<>(n);
415 for (int i = 0; i < n; i++) {
416 final WeightedSampler<S> weightedItem = list.get(i);
417 weights[i] = weightedItem.getWeight();
418 samplers.add(weightedItem.getSampler());
419 }
420
421 reset();
422
423 final DiscreteSampler discreteSampler = createDiscreteSampler(rng, weights);
424
425 return compositeFactory.createSampler(discreteSampler, samplers);
426 }
427
428
429
430
431 private void reset() {
432 weightedSamplers.clear();
433 }
434
435
436
437
438
439
440
441
442
443
444
445 private DiscreteSampler createDiscreteSampler(UniformRandomProvider rng,
446 double[] weights) {
447
448 final int n = weights.length;
449 if (uniform(weights)) {
450
451
452 return DiscreteUniformSampler.of(rng, 0, n - 1);
453 }
454
455
456 final double sum = sum(weights);
457 if (sum < Double.POSITIVE_INFINITY) {
458
459
460 for (int i = 0; i < n; i++) {
461 weights[i] /= sum;
462 }
463 } else {
464
465
466 final double mean = mean(weights);
467 for (int i = 0; i < n; i++) {
468
469 weights[i] = weights[i] / mean / n;
470 }
471 }
472
473
474
475
476
477 if (specialisation == Specialisation.SHARED_STATE_SAMPLER &&
478 !(factory instanceof DiscreteProbabilitySampler)) {
479
480
481 final DiscreteSampler sampler = factory.create(rng, weights.clone());
482 return sampler instanceof SharedStateDiscreteSampler ?
483 sampler :
484 new SharedStateDiscreteProbabilitySampler(sampler, factory, weights);
485 }
486
487 return factory.create(rng, weights);
488 }
489
490
491
492
493
494
495
496
497
498
499 private static boolean uniform(double[] values) {
500 final double value = values[0];
501 for (int i = 1; i < values.length; i++) {
502 if (value != values[i]) {
503 return false;
504 }
505 }
506 return true;
507 }
508
509
510
511
512
513
514
515 private static double sum(double[] values) {
516 double sum = 0;
517 for (final double value : values) {
518 sum += value;
519 }
520 return sum;
521 }
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544 private static double mean(double[] values) {
545 double mean = values[0];
546 int i = 1;
547 while (i < values.length) {
548
549 final double dev = values[i] - mean;
550 i++;
551 mean += dev / i;
552 }
553 return mean;
554 }
555 }
556
557
558
559
560
561
562
563
564
565 private static class CompositeSampler<S> {
566
567 protected final DiscreteSampler discreteSampler;
568
569 protected final List<S> samplers;
570
571
572
573
574
575 CompositeSampler(DiscreteSampler discreteSampler,
576 List<S> samplers) {
577 this.discreteSampler = discreteSampler;
578 this.samplers = samplers;
579 }
580
581
582
583
584
585
586 S nextSampler() {
587 return samplers.get(discreteSampler.sample());
588 }
589 }
590
591
592
593
594
595
596 private static class ObjectSamplerFactory<T> implements
597 SamplerBuilder.SamplerFactory<ObjectSampler<T>> {
598
599 @SuppressWarnings("rawtypes")
600 private static final ObjectSamplerFactory INSTANCE = new ObjectSamplerFactory();
601
602
603
604
605
606
607
608 @SuppressWarnings("unchecked")
609 static <T> ObjectSamplerFactory<T> instance() {
610 return (ObjectSamplerFactory<T>) INSTANCE;
611 }
612
613 @Override
614 public ObjectSampler<T> createSampler(DiscreteSampler discreteSampler,
615 List<ObjectSampler<T>> samplers) {
616 return new CompositeObjectSampler<>(discreteSampler, samplers);
617 }
618
619
620
621
622
623
624 private static class CompositeObjectSampler<T>
625 extends CompositeSampler<ObjectSampler<T>>
626 implements ObjectSampler<T> {
627
628
629
630
631 CompositeObjectSampler(DiscreteSampler discreteSampler,
632 List<ObjectSampler<T>> samplers) {
633 super(discreteSampler, samplers);
634 }
635
636 @Override
637 public T sample() {
638 return nextSampler().sample();
639 }
640 }
641 }
642
643
644
645
646
647
648 private static class SharedStateObjectSamplerFactory<T> implements
649 SamplerBuilder.SamplerFactory<SharedStateObjectSampler<T>> {
650
651 @SuppressWarnings("rawtypes")
652 private static final SharedStateObjectSamplerFactory INSTANCE = new SharedStateObjectSamplerFactory();
653
654
655
656
657
658
659
660 @SuppressWarnings("unchecked")
661 static <T> SharedStateObjectSamplerFactory<T> instance() {
662 return (SharedStateObjectSamplerFactory<T>) INSTANCE;
663 }
664
665 @Override
666 public SharedStateObjectSampler<T> createSampler(DiscreteSampler discreteSampler,
667 List<SharedStateObjectSampler<T>> samplers) {
668
669 return new CompositeSharedStateObjectSampler<>(
670 (SharedStateDiscreteSampler) discreteSampler, samplers);
671 }
672
673
674
675
676
677
678
679
680
681 private static class CompositeSharedStateObjectSampler<T>
682 extends CompositeSampler<SharedStateObjectSampler<T>>
683 implements SharedStateObjectSampler<T> {
684
685
686
687
688 CompositeSharedStateObjectSampler(SharedStateDiscreteSampler discreteSampler,
689 List<SharedStateObjectSampler<T>> samplers) {
690 super(discreteSampler, samplers);
691 }
692
693 @Override
694 public T sample() {
695 return nextSampler().sample();
696 }
697
698 @Override
699 public CompositeSharedStateObjectSampler<T> withUniformRandomProvider(UniformRandomProvider rng) {
700
701 return new CompositeSharedStateObjectSampler<>(
702 ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng),
703 copy(samplers, rng));
704 }
705 }
706 }
707
708
709
710
711 private static class DiscreteSamplerFactory implements
712 SamplerBuilder.SamplerFactory<DiscreteSampler> {
713
714 static final DiscreteSamplerFactory INSTANCE = new DiscreteSamplerFactory();
715
716 @Override
717 public DiscreteSampler createSampler(DiscreteSampler discreteSampler,
718 List<DiscreteSampler> samplers) {
719 return new CompositeDiscreteSampler(discreteSampler, samplers);
720 }
721
722
723
724
725 private static class CompositeDiscreteSampler
726 extends CompositeSampler<DiscreteSampler>
727 implements DiscreteSampler {
728
729
730
731
732 CompositeDiscreteSampler(DiscreteSampler discreteSampler,
733 List<DiscreteSampler> samplers) {
734 super(discreteSampler, samplers);
735 }
736
737 @Override
738 public int sample() {
739 return nextSampler().sample();
740 }
741 }
742 }
743
744
745
746
747 private static class SharedStateDiscreteSamplerFactory implements
748 SamplerBuilder.SamplerFactory<SharedStateDiscreteSampler> {
749
750 static final SharedStateDiscreteSamplerFactory INSTANCE = new SharedStateDiscreteSamplerFactory();
751
752 @Override
753 public SharedStateDiscreteSampler createSampler(DiscreteSampler discreteSampler,
754 List<SharedStateDiscreteSampler> samplers) {
755
756 return new CompositeSharedStateDiscreteSampler(
757 (SharedStateDiscreteSampler) discreteSampler, samplers);
758 }
759
760
761
762
763 private static class CompositeSharedStateDiscreteSampler
764 extends CompositeSampler<SharedStateDiscreteSampler>
765 implements SharedStateDiscreteSampler {
766
767
768
769
770 CompositeSharedStateDiscreteSampler(SharedStateDiscreteSampler discreteSampler,
771 List<SharedStateDiscreteSampler> samplers) {
772 super(discreteSampler, samplers);
773 }
774
775 @Override
776 public int sample() {
777 return nextSampler().sample();
778 }
779
780 @Override
781 public CompositeSharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
782
783 return new CompositeSharedStateDiscreteSampler(
784 ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng),
785 copy(samplers, rng));
786 }
787 }
788 }
789
790
791
792
793 private static class ContinuousSamplerFactory implements
794 SamplerBuilder.SamplerFactory<ContinuousSampler> {
795
796 static final ContinuousSamplerFactory INSTANCE = new ContinuousSamplerFactory();
797
798 @Override
799 public ContinuousSampler createSampler(DiscreteSampler discreteSampler,
800 List<ContinuousSampler> samplers) {
801 return new CompositeContinuousSampler(discreteSampler, samplers);
802 }
803
804
805
806
807 private static class CompositeContinuousSampler
808 extends CompositeSampler<ContinuousSampler>
809 implements ContinuousSampler {
810
811
812
813
814 CompositeContinuousSampler(DiscreteSampler discreteSampler,
815 List<ContinuousSampler> samplers) {
816 super(discreteSampler, samplers);
817 }
818
819 @Override
820 public double sample() {
821 return nextSampler().sample();
822 }
823 }
824 }
825
826
827
828
829 private static class SharedStateContinuousSamplerFactory implements
830 SamplerBuilder.SamplerFactory<SharedStateContinuousSampler> {
831
832 static final SharedStateContinuousSamplerFactory INSTANCE = new SharedStateContinuousSamplerFactory();
833
834 @Override
835 public SharedStateContinuousSampler createSampler(DiscreteSampler discreteSampler,
836 List<SharedStateContinuousSampler> samplers) {
837
838 return new CompositeSharedStateContinuousSampler(
839 (SharedStateDiscreteSampler) discreteSampler, samplers);
840 }
841
842
843
844
845 private static class CompositeSharedStateContinuousSampler
846 extends CompositeSampler<SharedStateContinuousSampler>
847 implements SharedStateContinuousSampler {
848
849
850
851
852 CompositeSharedStateContinuousSampler(SharedStateDiscreteSampler discreteSampler,
853 List<SharedStateContinuousSampler> samplers) {
854 super(discreteSampler, samplers);
855 }
856
857 @Override
858 public double sample() {
859 return nextSampler().sample();
860 }
861
862 @Override
863 public CompositeSharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
864
865 return new CompositeSharedStateContinuousSampler(
866 ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng),
867 copy(samplers, rng));
868 }
869 }
870 }
871
872
873
874
875 private static class LongSamplerFactory implements
876 SamplerBuilder.SamplerFactory<LongSampler> {
877
878 static final LongSamplerFactory INSTANCE = new LongSamplerFactory();
879
880 @Override
881 public LongSampler createSampler(DiscreteSampler discreteSampler,
882 List<LongSampler> samplers) {
883 return new CompositeLongSampler(discreteSampler, samplers);
884 }
885
886
887
888
889 private static class CompositeLongSampler
890 extends CompositeSampler<LongSampler>
891 implements LongSampler {
892
893
894
895
896 CompositeLongSampler(DiscreteSampler discreteSampler,
897 List<LongSampler> samplers) {
898 super(discreteSampler, samplers);
899 }
900
901 @Override
902 public long sample() {
903 return nextSampler().sample();
904 }
905 }
906 }
907
908
909
910
911 private static class SharedStateLongSamplerFactory implements
912 SamplerBuilder.SamplerFactory<SharedStateLongSampler> {
913
914 static final SharedStateLongSamplerFactory INSTANCE = new SharedStateLongSamplerFactory();
915
916 @Override
917 public SharedStateLongSampler createSampler(DiscreteSampler discreteSampler,
918 List<SharedStateLongSampler> samplers) {
919
920 return new CompositeSharedStateLongSampler(
921 (SharedStateDiscreteSampler) discreteSampler, samplers);
922 }
923
924
925
926
927 private static class CompositeSharedStateLongSampler
928 extends CompositeSampler<SharedStateLongSampler>
929 implements SharedStateLongSampler {
930
931
932
933
934 CompositeSharedStateLongSampler(SharedStateDiscreteSampler discreteSampler,
935 List<SharedStateLongSampler> samplers) {
936 super(discreteSampler, samplers);
937 }
938
939 @Override
940 public long sample() {
941 return nextSampler().sample();
942 }
943
944 @Override
945 public CompositeSharedStateLongSampler withUniformRandomProvider(UniformRandomProvider rng) {
946
947 return new CompositeSharedStateLongSampler(
948 ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng),
949 copy(samplers, rng));
950 }
951 }
952 }
953
954
955 private CompositeSamplers() {}
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971 public static <T> Builder<ObjectSampler<T>> newObjectSamplerBuilder() {
972 final SamplerBuilder.SamplerFactory<ObjectSampler<T>> factory = ObjectSamplerFactory.instance();
973 return new SamplerBuilder<>(
974 SamplerBuilder.Specialisation.NONE, factory);
975 }
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991 public static <T> Builder<SharedStateObjectSampler<T>> newSharedStateObjectSamplerBuilder() {
992 final SamplerBuilder.SamplerFactory<SharedStateObjectSampler<T>> factory =
993 SharedStateObjectSamplerFactory.instance();
994 return new SamplerBuilder<>(
995 SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, factory);
996 }
997
998
999
1000
1001
1002
1003 public static Builder<DiscreteSampler> newDiscreteSamplerBuilder() {
1004 return new SamplerBuilder<>(
1005 SamplerBuilder.Specialisation.NONE, DiscreteSamplerFactory.INSTANCE);
1006 }
1007
1008
1009
1010
1011
1012
1013 public static Builder<SharedStateDiscreteSampler> newSharedStateDiscreteSamplerBuilder() {
1014 return new SamplerBuilder<>(
1015 SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, SharedStateDiscreteSamplerFactory.INSTANCE);
1016 }
1017
1018
1019
1020
1021
1022
1023 public static Builder<ContinuousSampler> newContinuousSamplerBuilder() {
1024 return new SamplerBuilder<>(
1025 SamplerBuilder.Specialisation.NONE, ContinuousSamplerFactory.INSTANCE);
1026 }
1027
1028
1029
1030
1031
1032
1033 public static Builder<SharedStateContinuousSampler> newSharedStateContinuousSamplerBuilder() {
1034 return new SamplerBuilder<>(
1035 SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, SharedStateContinuousSamplerFactory.INSTANCE);
1036 }
1037
1038
1039
1040
1041
1042
1043 public static Builder<LongSampler> newLongSamplerBuilder() {
1044 return new SamplerBuilder<>(
1045 SamplerBuilder.Specialisation.NONE, LongSamplerFactory.INSTANCE);
1046 }
1047
1048
1049
1050
1051
1052
1053 public static Builder<SharedStateLongSampler> newSharedStateLongSamplerBuilder() {
1054 return new SamplerBuilder<>(
1055 SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, SharedStateLongSamplerFactory.INSTANCE);
1056 }
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067 private static <T extends SharedStateSampler<T>> List<T> copy(List<T> samplers,
1068 UniformRandomProvider rng) {
1069 final ArrayList<T> newSamplers = new ArrayList<>(samplers.size());
1070 for (final T s : samplers) {
1071 newSamplers.add(s.withUniformRandomProvider(rng));
1072 }
1073 return newSamplers;
1074 }
1075 }