1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.analysis.differentiation;
18
19 import java.util.Collections;
20 import java.util.HashMap;
21 import java.util.Map;
22
23 import org.apache.commons.numbers.core.Sum;
24 import org.apache.commons.numbers.core.Precision;
25 import org.apache.commons.math4.legacy.core.Field;
26 import org.apache.commons.math4.legacy.core.FieldElement;
27 import org.apache.commons.math4.legacy.core.RealFieldElement;
28 import org.apache.commons.math4.core.jdkmath.JdkMath;
29
30
31
32
33
34
35
36
37
38
39
40
41
42 public final class SparseGradient implements RealFieldElement<SparseGradient> {
43
44 private double value;
45
46
47 private final Map<Integer, Double> derivatives;
48
49
50
51
52
53
54
55 private SparseGradient(final double value, final Map<Integer, Double> derivatives) {
56 this.value = value;
57 this.derivatives = new HashMap<>();
58 if (derivatives != null) {
59 this.derivatives.putAll(derivatives);
60 }
61 }
62
63
64
65
66
67
68
69
70 private SparseGradient(final double value, final double scale,
71 final Map<Integer, Double> derivatives) {
72 this.value = value;
73 this.derivatives = new HashMap<>();
74 if (derivatives != null) {
75 for (final Map.Entry<Integer, Double> entry : derivatives.entrySet()) {
76 this.derivatives.put(entry.getKey(), scale * entry.getValue());
77 }
78 }
79 }
80
81
82
83
84
85 public static SparseGradient createConstant(final double value) {
86 return new SparseGradient(value, Collections.<Integer, Double>emptyMap());
87 }
88
89
90
91
92
93
94 public static SparseGradient createVariable(final int idx, final double value) {
95 return new SparseGradient(value, Collections.singletonMap(idx, 1.0));
96 }
97
98
99
100
101
102 public int numVars() {
103 return derivatives.size();
104 }
105
106
107
108
109
110
111
112 public double getDerivative(final int index) {
113 final Double out = derivatives.get(index);
114 return (out == null) ? 0.0 : out;
115 }
116
117
118
119
120
121 public double getValue() {
122 return value;
123 }
124
125
126 @Override
127 public double getReal() {
128 return value;
129 }
130
131
132 @Override
133 public SparseGradient add(final SparseGradient a) {
134 final SparseGradient out = new SparseGradient(value + a.value, derivatives);
135 for (Map.Entry<Integer, Double> entry : a.derivatives.entrySet()) {
136 final int id = entry.getKey();
137 final Double old = out.derivatives.get(id);
138 if (old == null) {
139 out.derivatives.put(id, entry.getValue());
140 } else {
141 out.derivatives.put(id, old + entry.getValue());
142 }
143 }
144
145 return out;
146 }
147
148
149
150
151
152
153
154
155
156
157
158
159
160 public void addInPlace(final SparseGradient a) {
161 value += a.value;
162 for (final Map.Entry<Integer, Double> entry : a.derivatives.entrySet()) {
163 final int id = entry.getKey();
164 final Double old = derivatives.get(id);
165 if (old == null) {
166 derivatives.put(id, entry.getValue());
167 } else {
168 derivatives.put(id, old + entry.getValue());
169 }
170 }
171 }
172
173
174 @Override
175 public SparseGradient add(final double c) {
176 return new SparseGradient(value + c, derivatives);
177 }
178
179
180 @Override
181 public SparseGradient subtract(final SparseGradient a) {
182 final SparseGradient out = new SparseGradient(value - a.value, derivatives);
183 for (Map.Entry<Integer, Double> entry : a.derivatives.entrySet()) {
184 final int id = entry.getKey();
185 final Double old = out.derivatives.get(id);
186 if (old == null) {
187 out.derivatives.put(id, -entry.getValue());
188 } else {
189 out.derivatives.put(id, old - entry.getValue());
190 }
191 }
192 return out;
193 }
194
195
196 @Override
197 public SparseGradient subtract(double c) {
198 return new SparseGradient(value - c, derivatives);
199 }
200
201
202 @Override
203 public SparseGradient multiply(final SparseGradient a) {
204 final SparseGradient out =
205 new SparseGradient(value * a.value, Collections.<Integer, Double>emptyMap());
206
207
208 for (Map.Entry<Integer, Double> entry : derivatives.entrySet()) {
209 out.derivatives.put(entry.getKey(), a.value * entry.getValue());
210 }
211 for (Map.Entry<Integer, Double> entry : a.derivatives.entrySet()) {
212 final int id = entry.getKey();
213 final Double old = out.derivatives.get(id);
214 if (old == null) {
215 out.derivatives.put(id, value * entry.getValue());
216 } else {
217 out.derivatives.put(id, old + value * entry.getValue());
218 }
219 }
220 return out;
221 }
222
223
224
225
226
227
228
229
230
231
232
233
234
235 public void multiplyInPlace(final SparseGradient a) {
236
237 for (Map.Entry<Integer, Double> entry : derivatives.entrySet()) {
238 derivatives.put(entry.getKey(), a.value * entry.getValue());
239 }
240 for (Map.Entry<Integer, Double> entry : a.derivatives.entrySet()) {
241 final int id = entry.getKey();
242 final Double old = derivatives.get(id);
243 if (old == null) {
244 derivatives.put(id, value * entry.getValue());
245 } else {
246 derivatives.put(id, old + value * entry.getValue());
247 }
248 }
249 value *= a.value;
250 }
251
252
253 @Override
254 public SparseGradient multiply(final double c) {
255 return new SparseGradient(value * c, c, derivatives);
256 }
257
258
259 @Override
260 public SparseGradient multiply(final int n) {
261 return new SparseGradient(value * n, n, derivatives);
262 }
263
264
265 @Override
266 public SparseGradient divide(final SparseGradient a) {
267 final SparseGradient out = new SparseGradient(value / a.value, Collections.<Integer, Double>emptyMap());
268
269
270 for (Map.Entry<Integer, Double> entry : derivatives.entrySet()) {
271 out.derivatives.put(entry.getKey(), entry.getValue() / a.value);
272 }
273 for (Map.Entry<Integer, Double> entry : a.derivatives.entrySet()) {
274 final int id = entry.getKey();
275 final Double old = out.derivatives.get(id);
276 if (old == null) {
277 out.derivatives.put(id, -out.value / a.value * entry.getValue());
278 } else {
279 out.derivatives.put(id, old - out.value / a.value * entry.getValue());
280 }
281 }
282 return out;
283 }
284
285
286 @Override
287 public SparseGradient divide(final double c) {
288 return new SparseGradient(value / c, 1.0 / c, derivatives);
289 }
290
291
292 @Override
293 public SparseGradient negate() {
294 return new SparseGradient(-value, -1.0, derivatives);
295 }
296
297
298 @Override
299 public Field<SparseGradient> getField() {
300 return new Field<SparseGradient>() {
301
302
303 @Override
304 public SparseGradient getZero() {
305 return createConstant(0);
306 }
307
308
309 @Override
310 public SparseGradient getOne() {
311 return createConstant(1);
312 }
313
314
315 @Override
316 public Class<? extends FieldElement<SparseGradient>> getRuntimeClass() {
317 return SparseGradient.class;
318 }
319 };
320 }
321
322
323 @Override
324 public SparseGradient remainder(final double a) {
325 return new SparseGradient(JdkMath.IEEEremainder(value, a), derivatives);
326 }
327
328
329 @Override
330 public SparseGradient remainder(final SparseGradient a) {
331
332
333 final double rem = JdkMath.IEEEremainder(value, a.value);
334 final double k = JdkMath.rint((value - rem) / a.value);
335
336 return subtract(a.multiply(k));
337 }
338
339
340 @Override
341 public SparseGradient abs() {
342 if (Double.doubleToLongBits(value) < 0) {
343
344 return negate();
345 } else {
346 return this;
347 }
348 }
349
350
351 @Override
352 public SparseGradient ceil() {
353 return createConstant(JdkMath.ceil(value));
354 }
355
356
357 @Override
358 public SparseGradient floor() {
359 return createConstant(JdkMath.floor(value));
360 }
361
362
363 @Override
364 public SparseGradient rint() {
365 return createConstant(JdkMath.rint(value));
366 }
367
368
369 @Override
370 public long round() {
371 return JdkMath.round(value);
372 }
373
374
375 @Override
376 public SparseGradient signum() {
377 return createConstant(JdkMath.signum(value));
378 }
379
380
381 @Override
382 public SparseGradient copySign(final SparseGradient sign) {
383 final long m = Double.doubleToLongBits(value);
384 final long s = Double.doubleToLongBits(sign.value);
385 if ((m >= 0 && s >= 0) || (m < 0 && s < 0)) {
386 return this;
387 }
388 return negate();
389 }
390
391
392 @Override
393 public SparseGradient copySign(final double sign) {
394 final long m = Double.doubleToLongBits(value);
395 final long s = Double.doubleToLongBits(sign);
396 if ((m >= 0 && s >= 0) || (m < 0 && s < 0)) {
397 return this;
398 }
399 return negate();
400 }
401
402
403 @Override
404 public SparseGradient scalb(final int n) {
405 final SparseGradient out = new SparseGradient(JdkMath.scalb(value, n), Collections.<Integer, Double>emptyMap());
406 for (Map.Entry<Integer, Double> entry : derivatives.entrySet()) {
407 out.derivatives.put(entry.getKey(), JdkMath.scalb(entry.getValue(), n));
408 }
409 return out;
410 }
411
412
413 @Override
414 public SparseGradient hypot(final SparseGradient y) {
415 if (Double.isInfinite(value) || Double.isInfinite(y.value)) {
416 return createConstant(Double.POSITIVE_INFINITY);
417 } else if (Double.isNaN(value) || Double.isNaN(y.value)) {
418 return createConstant(Double.NaN);
419 } else {
420
421 final int expX = JdkMath.getExponent(value);
422 final int expY = JdkMath.getExponent(y.value);
423 if (expX > expY + 27) {
424
425 return abs();
426 } else if (expY > expX + 27) {
427
428 return y.abs();
429 } else {
430
431
432 final int middleExp = (expX + expY) / 2;
433
434
435 final SparseGradient scaledX = scalb(-middleExp);
436 final SparseGradient scaledY = y.scalb(-middleExp);
437
438
439 final SparseGradient scaledH =
440 scaledX.multiply(scaledX).add(scaledY.multiply(scaledY)).sqrt();
441
442
443 return scaledH.scalb(middleExp);
444 }
445 }
446 }
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462 public static SparseGradient hypot(final SparseGradient x, final SparseGradient y) {
463 return x.hypot(y);
464 }
465
466
467 @Override
468 public SparseGradient reciprocal() {
469 return new SparseGradient(1.0 / value, -1.0 / (value * value), derivatives);
470 }
471
472
473 @Override
474 public SparseGradient sqrt() {
475 final double sqrt = JdkMath.sqrt(value);
476 return new SparseGradient(sqrt, 0.5 / sqrt, derivatives);
477 }
478
479
480 @Override
481 public SparseGradient cbrt() {
482 final double cbrt = JdkMath.cbrt(value);
483 return new SparseGradient(cbrt, 1.0 / (3 * cbrt * cbrt), derivatives);
484 }
485
486
487 @Override
488 public SparseGradient rootN(final int n) {
489 if (n == 2) {
490 return sqrt();
491 } else if (n == 3) {
492 return cbrt();
493 } else {
494 final double root = JdkMath.pow(value, 1.0 / n);
495 return new SparseGradient(root, 1.0 / (n * JdkMath.pow(root, n - 1)), derivatives);
496 }
497 }
498
499
500 @Override
501 public SparseGradient pow(final double p) {
502 return new SparseGradient(JdkMath.pow(value, p), p * JdkMath.pow(value, p - 1), derivatives);
503 }
504
505
506 @Override
507 public SparseGradient pow(final int n) {
508 if (n == 0) {
509 return getField().getOne();
510 } else {
511 final double valueNm1 = JdkMath.pow(value, n - 1);
512 return new SparseGradient(value * valueNm1, n * valueNm1, derivatives);
513 }
514 }
515
516
517 @Override
518 public SparseGradient pow(final SparseGradient e) {
519 return log().multiply(e).exp();
520 }
521
522
523
524
525
526
527 public static SparseGradient pow(final double a, final SparseGradient x) {
528 if (a == 0) {
529 if (x.value == 0) {
530 return x.compose(1.0, Double.NEGATIVE_INFINITY);
531 } else if (x.value < 0) {
532 return x.compose(Double.NaN, Double.NaN);
533 } else {
534 return x.getField().getZero();
535 }
536 } else {
537 final double ax = JdkMath.pow(a, x.value);
538 return new SparseGradient(ax, ax * JdkMath.log(a), x.derivatives);
539 }
540 }
541
542
543 @Override
544 public SparseGradient exp() {
545 final double e = JdkMath.exp(value);
546 return new SparseGradient(e, e, derivatives);
547 }
548
549
550 @Override
551 public SparseGradient expm1() {
552 return new SparseGradient(JdkMath.expm1(value), JdkMath.exp(value), derivatives);
553 }
554
555
556 @Override
557 public SparseGradient log() {
558 return new SparseGradient(JdkMath.log(value), 1.0 / value, derivatives);
559 }
560
561
562
563
564 @Override
565 public SparseGradient log10() {
566 return new SparseGradient(JdkMath.log10(value), 1.0 / (JdkMath.log(10.0) * value), derivatives);
567 }
568
569
570 @Override
571 public SparseGradient log1p() {
572 return new SparseGradient(JdkMath.log1p(value), 1.0 / (1.0 + value), derivatives);
573 }
574
575
576 @Override
577 public SparseGradient cos() {
578 return new SparseGradient(JdkMath.cos(value), -JdkMath.sin(value), derivatives);
579 }
580
581
582 @Override
583 public SparseGradient sin() {
584 return new SparseGradient(JdkMath.sin(value), JdkMath.cos(value), derivatives);
585 }
586
587
588 @Override
589 public SparseGradient tan() {
590 final double t = JdkMath.tan(value);
591 return new SparseGradient(t, 1 + t * t, derivatives);
592 }
593
594
595 @Override
596 public SparseGradient acos() {
597 return new SparseGradient(JdkMath.acos(value), -1.0 / JdkMath.sqrt(1 - value * value), derivatives);
598 }
599
600
601 @Override
602 public SparseGradient asin() {
603 return new SparseGradient(JdkMath.asin(value), 1.0 / JdkMath.sqrt(1 - value * value), derivatives);
604 }
605
606
607 @Override
608 public SparseGradient atan() {
609 return new SparseGradient(JdkMath.atan(value), 1.0 / (1 + value * value), derivatives);
610 }
611
612
613 @Override
614 public SparseGradient atan2(final SparseGradient x) {
615
616
617 final SparseGradient r = multiply(this).add(x.multiply(x)).sqrt();
618
619 final SparseGradient a;
620 if (x.value >= 0) {
621
622
623 a = divide(r.add(x)).atan().multiply(2);
624 } else {
625
626
627 final SparseGradient tmp = divide(r.subtract(x)).atan().multiply(-2);
628 a = tmp.add(tmp.value <= 0 ? -JdkMath.PI : JdkMath.PI);
629 }
630
631
632 a.value = JdkMath.atan2(value, x.value);
633
634 return a;
635 }
636
637
638
639
640
641
642 public static SparseGradient atan2(final SparseGradient y, final SparseGradient x) {
643 return y.atan2(x);
644 }
645
646
647 @Override
648 public SparseGradient cosh() {
649 return new SparseGradient(JdkMath.cosh(value), JdkMath.sinh(value), derivatives);
650 }
651
652
653 @Override
654 public SparseGradient sinh() {
655 return new SparseGradient(JdkMath.sinh(value), JdkMath.cosh(value), derivatives);
656 }
657
658
659 @Override
660 public SparseGradient tanh() {
661 final double t = JdkMath.tanh(value);
662 return new SparseGradient(t, 1 - t * t, derivatives);
663 }
664
665
666 @Override
667 public SparseGradient acosh() {
668 return new SparseGradient(JdkMath.acosh(value), 1.0 / JdkMath.sqrt(value * value - 1.0), derivatives);
669 }
670
671
672 @Override
673 public SparseGradient asinh() {
674 return new SparseGradient(JdkMath.asinh(value), 1.0 / JdkMath.sqrt(value * value + 1.0), derivatives);
675 }
676
677
678 @Override
679 public SparseGradient atanh() {
680 return new SparseGradient(JdkMath.atanh(value), 1.0 / (1.0 - value * value), derivatives);
681 }
682
683
684
685
686 public SparseGradient toDegrees() {
687 return new SparseGradient(JdkMath.toDegrees(value), JdkMath.toDegrees(1.0), derivatives);
688 }
689
690
691
692
693 public SparseGradient toRadians() {
694 return new SparseGradient(JdkMath.toRadians(value), JdkMath.toRadians(1.0), derivatives);
695 }
696
697
698
699
700
701 public double taylor(final double ... delta) {
702 double y = value;
703 for (int i = 0; i < delta.length; ++i) {
704 y += delta[i] * getDerivative(i);
705 }
706 return y;
707 }
708
709
710
711
712
713
714
715 public SparseGradient compose(final double f0, final double f1) {
716 return new SparseGradient(f0, f1, derivatives);
717 }
718
719
720 @Override
721 public SparseGradient linearCombination(final SparseGradient[] a,
722 final SparseGradient[] b) {
723
724 SparseGradient out = a[0].getField().getZero();
725 for (int i = 0; i < a.length; ++i) {
726 out = out.add(a[i].multiply(b[i]));
727 }
728
729
730 final double[] aDouble = new double[a.length];
731 for (int i = 0; i < a.length; ++i) {
732 aDouble[i] = a[i].getValue();
733 }
734 final double[] bDouble = new double[b.length];
735 for (int i = 0; i < b.length; ++i) {
736 bDouble[i] = b[i].getValue();
737 }
738 out.value = Sum.ofProducts(aDouble, bDouble).getAsDouble();
739
740 return out;
741 }
742
743
744 @Override
745 public SparseGradient linearCombination(final double[] a, final SparseGradient[] b) {
746
747
748 SparseGradient out = b[0].getField().getZero();
749 for (int i = 0; i < a.length; ++i) {
750 out = out.add(b[i].multiply(a[i]));
751 }
752
753
754 final double[] bDouble = new double[b.length];
755 for (int i = 0; i < b.length; ++i) {
756 bDouble[i] = b[i].getValue();
757 }
758 out.value = Sum.ofProducts(a, bDouble).getAsDouble();
759
760 return out;
761 }
762
763
764 @Override
765 public SparseGradient linearCombination(final SparseGradient a1, final SparseGradient b1,
766 final SparseGradient a2, final SparseGradient b2) {
767
768
769 SparseGradient out = a1.multiply(b1).add(a2.multiply(b2));
770
771
772 out.value = Sum.create()
773 .addProduct(a1.value, b1.value)
774 .addProduct(a2.value, b2.value).getAsDouble();
775
776 return out;
777 }
778
779
780 @Override
781 public SparseGradient linearCombination(final double a1, final SparseGradient b1,
782 final double a2, final SparseGradient b2) {
783
784
785 SparseGradient out = b1.multiply(a1).add(b2.multiply(a2));
786
787
788 out.value = Sum.create()
789 .addProduct(a1, b1.value)
790 .addProduct(a2, b2.value).getAsDouble();
791
792 return out;
793 }
794
795
796 @Override
797 public SparseGradient linearCombination(final SparseGradient a1, final SparseGradient b1,
798 final SparseGradient a2, final SparseGradient b2,
799 final SparseGradient a3, final SparseGradient b3) {
800
801
802 SparseGradient out = a1.multiply(b1).add(a2.multiply(b2)).add(a3.multiply(b3));
803
804
805 out.value = Sum.create()
806 .addProduct(a1.value, b1.value)
807 .addProduct(a2.value, b2.value)
808 .addProduct(a3.value, b3.value).getAsDouble();
809
810 return out;
811 }
812
813
814 @Override
815 public SparseGradient linearCombination(final double a1, final SparseGradient b1,
816 final double a2, final SparseGradient b2,
817 final double a3, final SparseGradient b3) {
818
819
820 SparseGradient out = b1.multiply(a1).add(b2.multiply(a2)).add(b3.multiply(a3));
821
822
823 out.value = Sum.create()
824 .addProduct(a1, b1.value)
825 .addProduct(a2, b2.value)
826 .addProduct(a3, b3.value).getAsDouble();
827
828 return out;
829 }
830
831
832 @Override
833 public SparseGradient linearCombination(final SparseGradient a1, final SparseGradient b1,
834 final SparseGradient a2, final SparseGradient b2,
835 final SparseGradient a3, final SparseGradient b3,
836 final SparseGradient a4, final SparseGradient b4) {
837
838
839 SparseGradient out = a1.multiply(b1).add(a2.multiply(b2)).add(a3.multiply(b3)).add(a4.multiply(b4));
840
841
842 out.value = Sum.create()
843 .addProduct(a1.value, b1.value)
844 .addProduct(a2.value, b2.value)
845 .addProduct(a3.value, b3.value)
846 .addProduct(a4.value, b4.value).getAsDouble();
847
848 return out;
849 }
850
851
852 @Override
853 public SparseGradient linearCombination(final double a1, final SparseGradient b1,
854 final double a2, final SparseGradient b2,
855 final double a3, final SparseGradient b3,
856 final double a4, final SparseGradient b4) {
857
858
859 SparseGradient out = b1.multiply(a1).add(b2.multiply(a2)).add(b3.multiply(a3)).add(b4.multiply(a4));
860
861
862 out.value = Sum.create()
863 .addProduct(a1, b1.value)
864 .addProduct(a2, b2.value)
865 .addProduct(a3, b3.value)
866 .addProduct(a4, b4.value).getAsDouble();
867
868 return out;
869 }
870
871
872
873
874
875
876
877
878
879
880 @Override
881 public boolean equals(Object other) {
882
883 if (this == other) {
884 return true;
885 }
886
887 if (other instanceof SparseGradient) {
888 final SparseGradient rhs = (SparseGradient)other;
889 if (!Precision.equals(value, rhs.value, 1)) {
890 return false;
891 }
892 if (derivatives.size() != rhs.derivatives.size()) {
893 return false;
894 }
895 for (final Map.Entry<Integer, Double> entry : derivatives.entrySet()) {
896 if (!rhs.derivatives.containsKey(entry.getKey())) {
897 return false;
898 }
899 if (!Precision.equals(entry.getValue(), rhs.derivatives.get(entry.getKey()), 1)) {
900 return false;
901 }
902 }
903 return true;
904 }
905
906 return false;
907 }
908
909
910
911
912
913
914 @Override
915 public int hashCode() {
916 return 743 + 809 * Double.hashCode(value) + 167 * derivatives.hashCode();
917 }
918 }