001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.numbers.arrays;
018
/**
 * Computes linear combinations accurately.
 *
 * <p>The methods of this class compute the sum of the products
 * <code>a<sub>i</sub> b<sub>i</sub></code> to high accuracy.
 * They do so by using specific multiplication and addition algorithms to
 * preserve accuracy and reduce cancellation effects.
 *
 * <p>The implementation is based on the 2005 paper
 * <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.2.1547">
 * Accurate Sum and Dot Product</a> by Takeshi Ogita, Siegfried M. Rump,
 * and Shin'ichi Oishi published in <em>SIAM J. Sci. Comput</em>.
 *
 * <p>If the accurate computation yields {@code NaN} (e.g. because an infinite
 * value was split, or an input was {@code NaN}), every method falls back to the
 * naive sum of products so that the result follows plain IEEE754 semantics.
 */
public final class LinearCombination {
    /*
     * Caveat:
     *
     * The code below is split in many additions/subtractions that may
     * appear redundant. However, they should NOT be simplified, as they
     * do use IEEE754 floating point arithmetic rounding properties.
     * The variables naming conventions are that xyzHigh contains the most significant
     * bits of xyz and xyzLow contains its least significant bits. So theoretically
     * xyz is the sum xyzHigh + xyzLow, but in many cases below, this sum cannot
     * be represented in only one double precision number so we preserve two numbers
     * to hold it as long as we can, combining the high and low order bits together
     * only at the end, after cancellation may have occurred on high order bits.
     */

    /** Private constructor: this is a utility class with static methods only. */
    private LinearCombination() {
        // intentionally empty.
    }

    /**
     * Computes the sum of the element-wise products of two arrays.
     *
     * <p>Empty arrays yield {@code 0} (the empty sum). A single pair of factors
     * is multiplied directly.
     *
     * @param a Factors.
     * @param b Factors.
     * @return \( \sum_i a_i b_i \).
     * @throws IllegalArgumentException if the sizes of the arrays are different.
     */
    public static double value(double[] a,
                               double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Dimension mismatch: " + a.length + " != " + b.length);
        }

        final int len = a.length;

        if (len == 0) {
            // Empty sum; the general algorithm below needs at least two terms.
            return 0;
        }

        if (len == 1) {
            // Revert to scalar multiplication.
            return a[0] * b[0];
        }

        // Rounded products are kept individually; their rounding errors are accumulated.
        final double[] prodHigh = new double[len];
        double prodLowSum = 0;

        for (int i = 0; i < len; i++) {
            final double ai    = a[i];
            final double aHigh = highPart(ai);
            final double aLow  = ai - aHigh;

            final double bi    = b[i];
            final double bHigh = highPart(bi);
            final double bLow  = bi - bHigh;
            prodHigh[i] = ai * bi;
            final double prodLow = prodLow(aLow, bLow, prodHigh[i], aHigh, bHigh);
            prodLowSum += prodLow;
        }

        // Accurate addition of the first two products: sHighPrev is the rounded
        // sum and sLowSum collects the rounding error of each addition.
        final double prodHighCur = prodHigh[0];
        double prodHighNext = prodHigh[1];
        double sHighPrev = prodHighCur + prodHighNext;
        double sPrime = sHighPrev - prodHighNext;
        double sLowSum = (prodHighNext - (sHighPrev - sPrime)) + (prodHighCur - sPrime);

        // Fold the remaining products into the running sum in the same way.
        final int lenMinusOne = len - 1;
        for (int i = 1; i < lenMinusOne; i++) {
            prodHighNext = prodHigh[i + 1];
            final double sHighCur = sHighPrev + prodHighNext;
            sPrime = sHighCur - prodHighNext;
            sLowSum += (prodHighNext - (sHighCur - sPrime)) + (sHighPrev - sPrime);
            sHighPrev = sHighCur;
        }

        // Final rounding: recover bits from the accumulated low-order words.
        double result = sHighPrev + (prodLowSum + sLowSum);

        if (Double.isNaN(result)) {
            // either we have split infinite numbers or some coefficients were NaNs,
            // just rely on the naive implementation and let IEEE754 handle this
            result = 0;
            for (int i = 0; i < len; ++i) {
                result += a[i] * b[i];
            }
        }

        return result;
    }

    /**
     * Computes the sum of two products.
     *
     * @param a1 First factor of the first term.
     * @param b1 Second factor of the first term.
     * @param a2 First factor of the second term.
     * @param b2 Second factor of the second term.
     * @return \( a_1 b_1 + a_2 b_2 \)
     *
     * @see #value(double, double, double, double, double, double)
     * @see #value(double, double, double, double, double, double, double, double)
     * @see #value(double[], double[])
     */
    public static double value(double a1, double b1,
                               double a2, double b2) {
        // split a1 and b1 as one 26 bits number and one 27 bits number
        final double a1High     = highPart(a1);
        final double a1Low      = a1 - a1High;
        final double b1High     = highPart(b1);
        final double b1Low      = b1 - b1High;

        // accurate multiplication a1 * b1
        final double prod1High  = a1 * b1;
        final double prod1Low   = prodLow(a1Low, b1Low, prod1High, a1High, b1High);

        // split a2 and b2 as one 26 bits number and one 27 bits number
        final double a2High     = highPart(a2);
        final double a2Low      = a2 - a2High;
        final double b2High     = highPart(b2);
        final double b2Low      = b2 - b2High;

        // accurate multiplication a2 * b2
        final double prod2High  = a2 * b2;
        final double prod2Low   = prodLow(a2Low, b2Low, prod2High, a2High, b2High);

        // accurate addition a1 * b1 + a2 * b2
        final double s12High    = prod1High + prod2High;
        final double s12Prime   = s12High - prod2High;
        final double s12Low     = (prod2High - (s12High - s12Prime)) + (prod1High - s12Prime);

        // final rounding, s12 may have suffered many cancellations, we try
        // to recover some bits from the extra words we have saved up to now
        double result = s12High + (prod1Low + prod2Low + s12Low);

        if (Double.isNaN(result)) {
            // either we have split infinite numbers or some coefficients were NaNs,
            // just rely on the naive implementation and let IEEE754 handle this
            result = a1 * b1 + a2 * b2;
        }

        return result;
    }

    /**
     * Computes the sum of three products.
     *
     * @param a1 First factor of the first term.
     * @param b1 Second factor of the first term.
     * @param a2 First factor of the second term.
     * @param b2 Second factor of the second term.
     * @param a3 First factor of the third term.
     * @param b3 Second factor of the third term.
     * @return \( a_1 b_1 + a_2 b_2 + a_3 b_3 \)
     *
     * @see #value(double, double, double, double)
     * @see #value(double, double, double, double, double, double, double, double)
     * @see #value(double[], double[])
     */
    public static double value(double a1, double b1,
                               double a2, double b2,
                               double a3, double b3) {
        // split a1 and b1 as one 26 bits number and one 27 bits number
        final double a1High     = highPart(a1);
        final double a1Low      = a1 - a1High;
        final double b1High     = highPart(b1);
        final double b1Low      = b1 - b1High;

        // accurate multiplication a1 * b1
        final double prod1High  = a1 * b1;
        final double prod1Low   = prodLow(a1Low, b1Low, prod1High, a1High, b1High);

        // split a2 and b2 as one 26 bits number and one 27 bits number
        final double a2High     = highPart(a2);
        final double a2Low      = a2 - a2High;
        final double b2High     = highPart(b2);
        final double b2Low      = b2 - b2High;

        // accurate multiplication a2 * b2
        final double prod2High  = a2 * b2;
        final double prod2Low   = prodLow(a2Low, b2Low, prod2High, a2High, b2High);

        // split a3 and b3 as one 26 bits number and one 27 bits number
        final double a3High     = highPart(a3);
        final double a3Low      = a3 - a3High;
        final double b3High     = highPart(b3);
        final double b3Low      = b3 - b3High;

        // accurate multiplication a3 * b3
        final double prod3High  = a3 * b3;
        final double prod3Low   = prodLow(a3Low, b3Low, prod3High, a3High, b3High);

        // accurate addition a1 * b1 + a2 * b2
        final double s12High    = prod1High + prod2High;
        final double s12Prime   = s12High - prod2High;
        final double s12Low     = (prod2High - (s12High - s12Prime)) + (prod1High - s12Prime);

        // accurate addition a1 * b1 + a2 * b2 + a3 * b3
        final double s123High   = s12High + prod3High;
        final double s123Prime  = s123High - prod3High;
        final double s123Low    = (prod3High - (s123High - s123Prime)) + (s12High - s123Prime);

        // final rounding, s123 may have suffered many cancellations, we try
        // to recover some bits from the extra words we have saved up to now
        double result = s123High + (prod1Low + prod2Low + prod3Low + s12Low + s123Low);

        if (Double.isNaN(result)) {
            // either we have split infinite numbers or some coefficients were NaNs,
            // just rely on the naive implementation and let IEEE754 handle this
            result = a1 * b1 + a2 * b2 + a3 * b3;
        }

        return result;
    }

    /**
     * Computes the sum of four products.
     *
     * @param a1 First factor of the first term.
     * @param b1 Second factor of the first term.
     * @param a2 First factor of the second term.
     * @param b2 Second factor of the second term.
     * @param a3 First factor of the third term.
     * @param b3 Second factor of the third term.
     * @param a4 First factor of the fourth term.
     * @param b4 Second factor of the fourth term.
     * @return \( a_1 b_1 + a_2 b_2 + a_3 b_3 + a_4 b_4 \)
     *
     * @see #value(double, double, double, double)
     * @see #value(double, double, double, double, double, double)
     * @see #value(double[], double[])
     */
    public static double value(double a1, double b1,
                               double a2, double b2,
                               double a3, double b3,
                               double a4, double b4) {
        // split a1 and b1 as one 26 bits number and one 27 bits number
        final double a1High     = highPart(a1);
        final double a1Low      = a1 - a1High;
        final double b1High     = highPart(b1);
        final double b1Low      = b1 - b1High;

        // accurate multiplication a1 * b1
        final double prod1High  = a1 * b1;
        final double prod1Low   = prodLow(a1Low, b1Low, prod1High, a1High, b1High);

        // split a2 and b2 as one 26 bits number and one 27 bits number
        final double a2High     = highPart(a2);
        final double a2Low      = a2 - a2High;
        final double b2High     = highPart(b2);
        final double b2Low      = b2 - b2High;

        // accurate multiplication a2 * b2
        final double prod2High  = a2 * b2;
        final double prod2Low   = prodLow(a2Low, b2Low, prod2High, a2High, b2High);

        // split a3 and b3 as one 26 bits number and one 27 bits number
        final double a3High     = highPart(a3);
        final double a3Low      = a3 - a3High;
        final double b3High     = highPart(b3);
        final double b3Low      = b3 - b3High;

        // accurate multiplication a3 * b3
        final double prod3High  = a3 * b3;
        final double prod3Low   = prodLow(a3Low, b3Low, prod3High, a3High, b3High);

        // split a4 and b4 as one 26 bits number and one 27 bits number
        final double a4High     = highPart(a4);
        final double a4Low      = a4 - a4High;
        final double b4High     = highPart(b4);
        final double b4Low      = b4 - b4High;

        // accurate multiplication a4 * b4
        final double prod4High  = a4 * b4;
        final double prod4Low   = prodLow(a4Low, b4Low, prod4High, a4High, b4High);

        // accurate addition a1 * b1 + a2 * b2
        final double s12High    = prod1High + prod2High;
        final double s12Prime   = s12High - prod2High;
        final double s12Low     = (prod2High - (s12High - s12Prime)) + (prod1High - s12Prime);

        // accurate addition a1 * b1 + a2 * b2 + a3 * b3
        final double s123High   = s12High + prod3High;
        final double s123Prime  = s123High - prod3High;
        final double s123Low    = (prod3High - (s123High - s123Prime)) + (s12High - s123Prime);

        // accurate addition a1 * b1 + a2 * b2 + a3 * b3 + a4 * b4
        final double s1234High  = s123High + prod4High;
        final double s1234Prime = s1234High - prod4High;
        final double s1234Low   = (prod4High - (s1234High - s1234Prime)) + (s123High - s1234Prime);

        // final rounding, s1234 may have suffered many cancellations, we try
        // to recover some bits from the extra words we have saved up to now
        double result = s1234High + (prod1Low + prod2Low + prod3Low + prod4Low + s12Low + s123Low + s1234Low);

        if (Double.isNaN(result)) {
            // either we have split infinite numbers or some coefficients were NaNs,
            // just rely on the naive implementation and let IEEE754 handle this
            result = a1 * b1 + a2 * b2 + a3 * b3 + a4 * b4;
        }

        return result;
    }

    /**
     * Extracts the high-order part of a value by clearing the 27 least
     * significant bits of its raw IEEE754 bit representation.
     *
     * @param value Value.
     * @return the high part of the value.
     */
    private static double highPart(double value) {
        return Double.longBitsToDouble(Double.doubleToRawLongBits(value) & ((-1L) << 27));
    }

    /**
     * Computes the low-order part of a product from the split factors
     * and the rounded product.
     *
     * @param aLow Low part of first factor.
     * @param bLow Low part of second factor.
     * @param prodHigh Product of the factors.
     * @param aHigh High part of first factor.
     * @param bHigh High part of second factor.
     * @return <code>aLow * bLow - (((prodHigh - aHigh * bHigh) - aLow * bHigh) - aHigh * bLow)</code>
     */
    private static double prodLow(double aLow,
                                  double bLow,
                                  double prodHigh,
                                  double aHigh,
                                  double bHigh) {
        // The evaluation order below is significant; do not simplify.
        return aLow * bLow - (((prodHigh - aHigh * bHigh) - aLow * bHigh) - aHigh * bLow);
    }
}