001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.stat.descriptive;
018
019import java.io.Serializable;
020import java.util.Arrays;
021
022import org.apache.commons.math3.exception.util.LocalizedFormats;
023import org.apache.commons.math3.exception.DimensionMismatchException;
024import org.apache.commons.math3.exception.MathIllegalStateException;
025import org.apache.commons.math3.linear.RealMatrix;
026import org.apache.commons.math3.stat.descriptive.moment.GeometricMean;
027import org.apache.commons.math3.stat.descriptive.moment.Mean;
028import org.apache.commons.math3.stat.descriptive.moment.VectorialCovariance;
029import org.apache.commons.math3.stat.descriptive.rank.Max;
030import org.apache.commons.math3.stat.descriptive.rank.Min;
031import org.apache.commons.math3.stat.descriptive.summary.Sum;
032import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs;
033import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares;
034import org.apache.commons.math3.util.MathUtils;
035import org.apache.commons.math3.util.MathArrays;
036import org.apache.commons.math3.util.Precision;
037import org.apache.commons.math3.util.FastMath;
038
039/**
040 * <p>Computes summary statistics for a stream of n-tuples added using the
041 * {@link #addValue(double[]) addValue} method. The data values are not stored
042 * in memory, so this class can be used to compute statistics for very large
043 * n-tuple streams.</p>
044 *
045 * <p>The {@link StorelessUnivariateStatistic} instances used to maintain
046 * summary state and compute statistics are configurable via setters.
047 * For example, the default implementation for the mean can be overridden by
048 * calling {@link #setMeanImpl(StorelessUnivariateStatistic[])}. Actual
049 * parameters to these methods must implement the
050 * {@link StorelessUnivariateStatistic} interface and configuration must be
051 * completed before <code>addValue</code> is called. No configuration is
052 * necessary to use the default, commons-math provided implementations.</p>
053 *
054 * <p>To compute statistics for a stream of n-tuples, construct a
055 * MultivariateStatistics instance with dimension n and then use
056 * {@link #addValue(double[])} to add n-tuples. The <code>getXxx</code>
057 * methods where Xxx is a statistic return an array of <code>double</code>
058 * values, where for <code>i = 0,...,n-1</code> the i<sup>th</sup> array element is the
059 * value of the given statistic for data range consisting of the i<sup>th</sup> element of
060 * each of the input n-tuples.  For example, if <code>addValue</code> is called
061 * with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8},
062 * <code>getSum</code> will return a three-element array with values
063 * {0+3+6, 1+4+7, 2+5+8}</p>
064 *
065 * <p>Note: This class is not thread-safe. Use
066 * {@link SynchronizedMultivariateSummaryStatistics} if concurrent access from multiple
067 * threads is required.</p>
068 *
069 * @since 1.2
070 */
071public class MultivariateSummaryStatistics
072    implements StatisticalMultivariateSummary, Serializable {
073
074    /** Serialization UID */
075    private static final long serialVersionUID = 2271900808994826718L;
076
077    /** Dimension of the data. */
078    private int k;
079
080    /** Count of values that have been added */
081    private long n = 0;
082
083    /** Sum statistic implementation - can be reset by setter. */
084    private StorelessUnivariateStatistic[] sumImpl;
085
086    /** Sum of squares statistic implementation - can be reset by setter. */
087    private StorelessUnivariateStatistic[] sumSqImpl;
088
089    /** Minimum statistic implementation - can be reset by setter. */
090    private StorelessUnivariateStatistic[] minImpl;
091
092    /** Maximum statistic implementation - can be reset by setter. */
093    private StorelessUnivariateStatistic[] maxImpl;
094
095    /** Sum of log statistic implementation - can be reset by setter. */
096    private StorelessUnivariateStatistic[] sumLogImpl;
097
098    /** Geometric mean statistic implementation - can be reset by setter. */
099    private StorelessUnivariateStatistic[] geoMeanImpl;
100
101    /** Mean statistic implementation - can be reset by setter. */
102    private StorelessUnivariateStatistic[] meanImpl;
103
104    /** Covariance statistic implementation - cannot be reset. */
105    private VectorialCovariance covarianceImpl;
106
107    /**
108     * Construct a MultivariateSummaryStatistics instance
109     * @param k dimension of the data
110     * @param isCovarianceBiasCorrected if true, the unbiased sample
111     * covariance is computed, otherwise the biased population covariance
112     * is computed
113     */
114    public MultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
115        this.k = k;
116
117        sumImpl     = new StorelessUnivariateStatistic[k];
118        sumSqImpl   = new StorelessUnivariateStatistic[k];
119        minImpl     = new StorelessUnivariateStatistic[k];
120        maxImpl     = new StorelessUnivariateStatistic[k];
121        sumLogImpl  = new StorelessUnivariateStatistic[k];
122        geoMeanImpl = new StorelessUnivariateStatistic[k];
123        meanImpl    = new StorelessUnivariateStatistic[k];
124
125        for (int i = 0; i < k; ++i) {
126            sumImpl[i]     = new Sum();
127            sumSqImpl[i]   = new SumOfSquares();
128            minImpl[i]     = new Min();
129            maxImpl[i]     = new Max();
130            sumLogImpl[i]  = new SumOfLogs();
131            geoMeanImpl[i] = new GeometricMean();
132            meanImpl[i]    = new Mean();
133        }
134
135        covarianceImpl =
136            new VectorialCovariance(k, isCovarianceBiasCorrected);
137
138    }
139
140    /**
141     * Add an n-tuple to the data
142     *
143     * @param value  the n-tuple to add
144     * @throws DimensionMismatchException if the length of the array
145     * does not match the one used at construction
146     */
147    public void addValue(double[] value) throws DimensionMismatchException {
148        checkDimension(value.length);
149        for (int i = 0; i < k; ++i) {
150            double v = value[i];
151            sumImpl[i].increment(v);
152            sumSqImpl[i].increment(v);
153            minImpl[i].increment(v);
154            maxImpl[i].increment(v);
155            sumLogImpl[i].increment(v);
156            geoMeanImpl[i].increment(v);
157            meanImpl[i].increment(v);
158        }
159        covarianceImpl.increment(value);
160        n++;
161    }
162
163    /**
164     * Returns the dimension of the data
165     * @return The dimension of the data
166     */
167    public int getDimension() {
168        return k;
169    }
170
171    /**
172     * Returns the number of available values
173     * @return The number of available values
174     */
175    public long getN() {
176        return n;
177    }
178
179    /**
180     * Returns an array of the results of a statistic.
181     * @param stats univariate statistic array
182     * @return results array
183     */
184    private double[] getResults(StorelessUnivariateStatistic[] stats) {
185        double[] results = new double[stats.length];
186        for (int i = 0; i < results.length; ++i) {
187            results[i] = stats[i].getResult();
188        }
189        return results;
190    }
191
192    /**
193     * Returns an array whose i<sup>th</sup> entry is the sum of the
194     * i<sup>th</sup> entries of the arrays that have been added using
195     * {@link #addValue(double[])}
196     *
197     * @return the array of component sums
198     */
199    public double[] getSum() {
200        return getResults(sumImpl);
201    }
202
203    /**
204     * Returns an array whose i<sup>th</sup> entry is the sum of squares of the
205     * i<sup>th</sup> entries of the arrays that have been added using
206     * {@link #addValue(double[])}
207     *
208     * @return the array of component sums of squares
209     */
210    public double[] getSumSq() {
211        return getResults(sumSqImpl);
212    }
213
214    /**
215     * Returns an array whose i<sup>th</sup> entry is the sum of logs of the
216     * i<sup>th</sup> entries of the arrays that have been added using
217     * {@link #addValue(double[])}
218     *
219     * @return the array of component log sums
220     */
221    public double[] getSumLog() {
222        return getResults(sumLogImpl);
223    }
224
225    /**
226     * Returns an array whose i<sup>th</sup> entry is the mean of the
227     * i<sup>th</sup> entries of the arrays that have been added using
228     * {@link #addValue(double[])}
229     *
230     * @return the array of component means
231     */
232    public double[] getMean() {
233        return getResults(meanImpl);
234    }
235
236    /**
237     * Returns an array whose i<sup>th</sup> entry is the standard deviation of the
238     * i<sup>th</sup> entries of the arrays that have been added using
239     * {@link #addValue(double[])}
240     *
241     * @return the array of component standard deviations
242     */
243    public double[] getStandardDeviation() {
244        double[] stdDev = new double[k];
245        if (getN() < 1) {
246            Arrays.fill(stdDev, Double.NaN);
247        } else if (getN() < 2) {
248            Arrays.fill(stdDev, 0.0);
249        } else {
250            RealMatrix matrix = covarianceImpl.getResult();
251            for (int i = 0; i < k; ++i) {
252                stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i));
253            }
254        }
255        return stdDev;
256    }
257
258    /**
259     * Returns the covariance matrix of the values that have been added.
260     *
261     * @return the covariance matrix
262     */
263    public RealMatrix getCovariance() {
264        return covarianceImpl.getResult();
265    }
266
267    /**
268     * Returns an array whose i<sup>th</sup> entry is the maximum of the
269     * i<sup>th</sup> entries of the arrays that have been added using
270     * {@link #addValue(double[])}
271     *
272     * @return the array of component maxima
273     */
274    public double[] getMax() {
275        return getResults(maxImpl);
276    }
277
278    /**
279     * Returns an array whose i<sup>th</sup> entry is the minimum of the
280     * i<sup>th</sup> entries of the arrays that have been added using
281     * {@link #addValue(double[])}
282     *
283     * @return the array of component minima
284     */
285    public double[] getMin() {
286        return getResults(minImpl);
287    }
288
289    /**
290     * Returns an array whose i<sup>th</sup> entry is the geometric mean of the
291     * i<sup>th</sup> entries of the arrays that have been added using
292     * {@link #addValue(double[])}
293     *
294     * @return the array of component geometric means
295     */
296    public double[] getGeometricMean() {
297        return getResults(geoMeanImpl);
298    }
299
300    /**
301     * Generates a text report displaying
302     * summary statistics from values that
303     * have been added.
304     * @return String with line feeds displaying statistics
305     */
306    @Override
307    public String toString() {
308        final String separator = ", ";
309        final String suffix = System.getProperty("line.separator");
310        StringBuilder outBuffer = new StringBuilder();
311        outBuffer.append("MultivariateSummaryStatistics:" + suffix);
312        outBuffer.append("n: " + getN() + suffix);
313        append(outBuffer, getMin(), "min: ", separator, suffix);
314        append(outBuffer, getMax(), "max: ", separator, suffix);
315        append(outBuffer, getMean(), "mean: ", separator, suffix);
316        append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix);
317        append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix);
318        append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix);
319        append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix);
320        outBuffer.append("covariance: " + getCovariance().toString() + suffix);
321        return outBuffer.toString();
322    }
323
324    /**
325     * Append a text representation of an array to a buffer.
326     * @param buffer buffer to fill
327     * @param data data array
328     * @param prefix text prefix
329     * @param separator elements separator
330     * @param suffix text suffix
331     */
332    private void append(StringBuilder buffer, double[] data,
333                        String prefix, String separator, String suffix) {
334        buffer.append(prefix);
335        for (int i = 0; i < data.length; ++i) {
336            if (i > 0) {
337                buffer.append(separator);
338            }
339            buffer.append(data[i]);
340        }
341        buffer.append(suffix);
342    }
343
344    /**
345     * Resets all statistics and storage
346     */
347    public void clear() {
348        this.n = 0;
349        for (int i = 0; i < k; ++i) {
350            minImpl[i].clear();
351            maxImpl[i].clear();
352            sumImpl[i].clear();
353            sumLogImpl[i].clear();
354            sumSqImpl[i].clear();
355            geoMeanImpl[i].clear();
356            meanImpl[i].clear();
357        }
358        covarianceImpl.clear();
359    }
360
361    /**
362     * Returns true iff <code>object</code> is a <code>MultivariateSummaryStatistics</code>
363     * instance and all statistics have the same values as this.
364     * @param object the object to test equality against.
365     * @return true if object equals this
366     */
367    @Override
368    public boolean equals(Object object) {
369        if (object == this ) {
370            return true;
371        }
372        if (object instanceof MultivariateSummaryStatistics == false) {
373            return false;
374        }
375        MultivariateSummaryStatistics stat = (MultivariateSummaryStatistics) object;
376        return MathArrays.equalsIncludingNaN(stat.getGeometricMean(), getGeometricMean()) &&
377               MathArrays.equalsIncludingNaN(stat.getMax(),           getMax())           &&
378               MathArrays.equalsIncludingNaN(stat.getMean(),          getMean())          &&
379               MathArrays.equalsIncludingNaN(stat.getMin(),           getMin())           &&
380               Precision.equalsIncludingNaN(stat.getN(),             getN())             &&
381               MathArrays.equalsIncludingNaN(stat.getSum(),           getSum())           &&
382               MathArrays.equalsIncludingNaN(stat.getSumSq(),         getSumSq())         &&
383               MathArrays.equalsIncludingNaN(stat.getSumLog(),        getSumLog())        &&
384               stat.getCovariance().equals( getCovariance());
385    }
386
387    /**
388     * Returns hash code based on values of statistics
389     *
390     * @return hash code
391     */
392    @Override
393    public int hashCode() {
394        int result = 31 + MathUtils.hash(getGeometricMean());
395        result = result * 31 + MathUtils.hash(getGeometricMean());
396        result = result * 31 + MathUtils.hash(getMax());
397        result = result * 31 + MathUtils.hash(getMean());
398        result = result * 31 + MathUtils.hash(getMin());
399        result = result * 31 + MathUtils.hash(getN());
400        result = result * 31 + MathUtils.hash(getSum());
401        result = result * 31 + MathUtils.hash(getSumSq());
402        result = result * 31 + MathUtils.hash(getSumLog());
403        result = result * 31 + getCovariance().hashCode();
404        return result;
405    }
406
407    // Getters and setters for statistics implementations
408    /**
409     * Sets statistics implementations.
410     * @param newImpl new implementations for statistics
411     * @param oldImpl old implementations for statistics
412     * @throws DimensionMismatchException if the array dimension
413     * does not match the one used at construction
414     * @throws MathIllegalStateException if data has already been added
415     * (i.e. if n > 0)
416     */
417    private void setImpl(StorelessUnivariateStatistic[] newImpl,
418                         StorelessUnivariateStatistic[] oldImpl) throws MathIllegalStateException,
419                         DimensionMismatchException {
420        checkEmpty();
421        checkDimension(newImpl.length);
422        System.arraycopy(newImpl, 0, oldImpl, 0, newImpl.length);
423    }
424
425    /**
426     * Returns the currently configured Sum implementation
427     *
428     * @return the StorelessUnivariateStatistic implementing the sum
429     */
430    public StorelessUnivariateStatistic[] getSumImpl() {
431        return sumImpl.clone();
432    }
433
434    /**
435     * <p>Sets the implementation for the Sum.</p>
436     * <p>This method must be activated before any data has been added - i.e.,
437     * before {@link #addValue(double[]) addValue} has been used to add data;
438     * otherwise an IllegalStateException will be thrown.</p>
439     *
440     * @param sumImpl the StorelessUnivariateStatistic instance to use
441     * for computing the Sum
442     * @throws DimensionMismatchException if the array dimension
443     * does not match the one used at construction
444     * @throws MathIllegalStateException if data has already been added
445     *  (i.e if n > 0)
446     */
447    public void setSumImpl(StorelessUnivariateStatistic[] sumImpl)
448    throws MathIllegalStateException, DimensionMismatchException {
449        setImpl(sumImpl, this.sumImpl);
450    }
451
452    /**
453     * Returns the currently configured sum of squares implementation
454     *
455     * @return the StorelessUnivariateStatistic implementing the sum of squares
456     */
457    public StorelessUnivariateStatistic[] getSumsqImpl() {
458        return sumSqImpl.clone();
459    }
460
461    /**
462     * <p>Sets the implementation for the sum of squares.</p>
463     * <p>This method must be activated before any data has been added - i.e.,
464     * before {@link #addValue(double[]) addValue} has been used to add data;
465     * otherwise an IllegalStateException will be thrown.</p>
466     *
467     * @param sumsqImpl the StorelessUnivariateStatistic instance to use
468     * for computing the sum of squares
469     * @throws DimensionMismatchException if the array dimension
470     * does not match the one used at construction
471     * @throws MathIllegalStateException if data has already been added
472     *  (i.e if n > 0)
473     */
474    public void setSumsqImpl(StorelessUnivariateStatistic[] sumsqImpl)
475    throws MathIllegalStateException, DimensionMismatchException {
476        setImpl(sumsqImpl, this.sumSqImpl);
477    }
478
479    /**
480     * Returns the currently configured minimum implementation
481     *
482     * @return the StorelessUnivariateStatistic implementing the minimum
483     */
484    public StorelessUnivariateStatistic[] getMinImpl() {
485        return minImpl.clone();
486    }
487
488    /**
489     * <p>Sets the implementation for the minimum.</p>
490     * <p>This method must be activated before any data has been added - i.e.,
491     * before {@link #addValue(double[]) addValue} has been used to add data;
492     * otherwise an IllegalStateException will be thrown.</p>
493     *
494     * @param minImpl the StorelessUnivariateStatistic instance to use
495     * for computing the minimum
496     * @throws DimensionMismatchException if the array dimension
497     * does not match the one used at construction
498     * @throws MathIllegalStateException if data has already been added
499     *  (i.e if n > 0)
500     */
501    public void setMinImpl(StorelessUnivariateStatistic[] minImpl)
502    throws MathIllegalStateException, DimensionMismatchException {
503        setImpl(minImpl, this.minImpl);
504    }
505
506    /**
507     * Returns the currently configured maximum implementation
508     *
509     * @return the StorelessUnivariateStatistic implementing the maximum
510     */
511    public StorelessUnivariateStatistic[] getMaxImpl() {
512        return maxImpl.clone();
513    }
514
515    /**
516     * <p>Sets the implementation for the maximum.</p>
517     * <p>This method must be activated before any data has been added - i.e.,
518     * before {@link #addValue(double[]) addValue} has been used to add data;
519     * otherwise an IllegalStateException will be thrown.</p>
520     *
521     * @param maxImpl the StorelessUnivariateStatistic instance to use
522     * for computing the maximum
523     * @throws DimensionMismatchException if the array dimension
524     * does not match the one used at construction
525     * @throws MathIllegalStateException if data has already been added
526     *  (i.e if n > 0)
527     */
528    public void setMaxImpl(StorelessUnivariateStatistic[] maxImpl)
529    throws MathIllegalStateException, DimensionMismatchException{
530        setImpl(maxImpl, this.maxImpl);
531    }
532
533    /**
534     * Returns the currently configured sum of logs implementation
535     *
536     * @return the StorelessUnivariateStatistic implementing the log sum
537     */
538    public StorelessUnivariateStatistic[] getSumLogImpl() {
539        return sumLogImpl.clone();
540    }
541
542    /**
543     * <p>Sets the implementation for the sum of logs.</p>
544     * <p>This method must be activated before any data has been added - i.e.,
545     * before {@link #addValue(double[]) addValue} has been used to add data;
546     * otherwise an IllegalStateException will be thrown.</p>
547     *
548     * @param sumLogImpl the StorelessUnivariateStatistic instance to use
549     * for computing the log sum
550     * @throws DimensionMismatchException if the array dimension
551     * does not match the one used at construction
552     * @throws MathIllegalStateException if data has already been added
553     *  (i.e if n > 0)
554     */
555    public void setSumLogImpl(StorelessUnivariateStatistic[] sumLogImpl)
556    throws MathIllegalStateException, DimensionMismatchException{
557        setImpl(sumLogImpl, this.sumLogImpl);
558    }
559
560    /**
561     * Returns the currently configured geometric mean implementation
562     *
563     * @return the StorelessUnivariateStatistic implementing the geometric mean
564     */
565    public StorelessUnivariateStatistic[] getGeoMeanImpl() {
566        return geoMeanImpl.clone();
567    }
568
569    /**
570     * <p>Sets the implementation for the geometric mean.</p>
571     * <p>This method must be activated before any data has been added - i.e.,
572     * before {@link #addValue(double[]) addValue} has been used to add data;
573     * otherwise an IllegalStateException will be thrown.</p>
574     *
575     * @param geoMeanImpl the StorelessUnivariateStatistic instance to use
576     * for computing the geometric mean
577     * @throws DimensionMismatchException if the array dimension
578     * does not match the one used at construction
579     * @throws MathIllegalStateException if data has already been added
580     *  (i.e if n > 0)
581     */
582    public void setGeoMeanImpl(StorelessUnivariateStatistic[] geoMeanImpl)
583    throws MathIllegalStateException, DimensionMismatchException {
584        setImpl(geoMeanImpl, this.geoMeanImpl);
585    }
586
587    /**
588     * Returns the currently configured mean implementation
589     *
590     * @return the StorelessUnivariateStatistic implementing the mean
591     */
592    public StorelessUnivariateStatistic[] getMeanImpl() {
593        return meanImpl.clone();
594    }
595
596    /**
597     * <p>Sets the implementation for the mean.</p>
598     * <p>This method must be activated before any data has been added - i.e.,
599     * before {@link #addValue(double[]) addValue} has been used to add data;
600     * otherwise an IllegalStateException will be thrown.</p>
601     *
602     * @param meanImpl the StorelessUnivariateStatistic instance to use
603     * for computing the mean
604     * @throws DimensionMismatchException if the array dimension
605     * does not match the one used at construction
606     * @throws MathIllegalStateException if data has already been added
607     *  (i.e if n > 0)
608     */
609    public void setMeanImpl(StorelessUnivariateStatistic[] meanImpl)
610    throws MathIllegalStateException, DimensionMismatchException{
611        setImpl(meanImpl, this.meanImpl);
612    }
613
614    /**
615     * Throws MathIllegalStateException if the statistic is not empty.
616     * @throws MathIllegalStateException if n > 0.
617     */
618    private void checkEmpty() throws MathIllegalStateException {
619        if (n > 0) {
620            throw new MathIllegalStateException(
621                    LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, n);
622        }
623    }
624
625    /**
626     * Throws DimensionMismatchException if dimension != k.
627     * @param dimension dimension to check
628     * @throws DimensionMismatchException if dimension != k
629     */
630    private void checkDimension(int dimension) throws DimensionMismatchException {
631        if (dimension != k) {
632            throw new DimensionMismatchException(dimension, k);
633        }
634    }
635}