001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math3.stat.descriptive;
018
019import java.io.Serializable;
020import java.util.Arrays;
021
022import org.apache.commons.math3.exception.util.LocalizedFormats;
023import org.apache.commons.math3.exception.DimensionMismatchException;
024import org.apache.commons.math3.exception.MathIllegalStateException;
025import org.apache.commons.math3.linear.RealMatrix;
026import org.apache.commons.math3.stat.descriptive.moment.GeometricMean;
027import org.apache.commons.math3.stat.descriptive.moment.Mean;
028import org.apache.commons.math3.stat.descriptive.moment.VectorialCovariance;
029import org.apache.commons.math3.stat.descriptive.rank.Max;
030import org.apache.commons.math3.stat.descriptive.rank.Min;
031import org.apache.commons.math3.stat.descriptive.summary.Sum;
032import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs;
033import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares;
034import org.apache.commons.math3.util.MathUtils;
035import org.apache.commons.math3.util.MathArrays;
036import org.apache.commons.math3.util.Precision;
037import org.apache.commons.math3.util.FastMath;
038
039/**
040 * <p>Computes summary statistics for a stream of n-tuples added using the
041 * {@link #addValue(double[]) addValue} method. The data values are not stored
042 * in memory, so this class can be used to compute statistics for very large
043 * n-tuple streams.</p>
044 *
045 * <p>The {@link StorelessUnivariateStatistic} instances used to maintain
046 * summary state and compute statistics are configurable via setters.
047 * For example, the default implementation for the mean can be overridden by
048 * calling {@link #setMeanImpl(StorelessUnivariateStatistic[])}. Actual
049 * parameters to these methods must implement the
050 * {@link StorelessUnivariateStatistic} interface and configuration must be
051 * completed before <code>addValue</code> is called. No configuration is
052 * necessary to use the default, commons-math provided implementations.</p>
053 *
054 * <p>To compute statistics for a stream of n-tuples, construct a
055 * MultivariateStatistics instance with dimension n and then use
056 * {@link #addValue(double[])} to add n-tuples. The <code>getXxx</code>
057 * methods where Xxx is a statistic return an array of <code>double</code>
058 * values, where for <code>i = 0,...,n-1</code> the i<sup>th</sup> array element is the
059 * value of the given statistic for data range consisting of the i<sup>th</sup> element of
060 * each of the input n-tuples.  For example, if <code>addValue</code> is called
061 * with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8},
062 * <code>getSum</code> will return a three-element array with values
063 * {0+3+6, 1+4+7, 2+5+8}</p>
064 *
065 * <p>Note: This class is not thread-safe. Use
066 * {@link SynchronizedMultivariateSummaryStatistics} if concurrent access from multiple
067 * threads is required.</p>
068 *
069 * @since 1.2
070 * @version $Id: MultivariateSummaryStatistics.java 1416643 2012-12-03 19:37:14Z tn $
071 */
072public class MultivariateSummaryStatistics
073    implements StatisticalMultivariateSummary, Serializable {
074
075    /** Serialization UID */
076    private static final long serialVersionUID = 2271900808994826718L;
077
078    /** Dimension of the data. */
079    private int k;
080
081    /** Count of values that have been added */
082    private long n = 0;
083
084    /** Sum statistic implementation - can be reset by setter. */
085    private StorelessUnivariateStatistic[] sumImpl;
086
087    /** Sum of squares statistic implementation - can be reset by setter. */
088    private StorelessUnivariateStatistic[] sumSqImpl;
089
090    /** Minimum statistic implementation - can be reset by setter. */
091    private StorelessUnivariateStatistic[] minImpl;
092
093    /** Maximum statistic implementation - can be reset by setter. */
094    private StorelessUnivariateStatistic[] maxImpl;
095
096    /** Sum of log statistic implementation - can be reset by setter. */
097    private StorelessUnivariateStatistic[] sumLogImpl;
098
099    /** Geometric mean statistic implementation - can be reset by setter. */
100    private StorelessUnivariateStatistic[] geoMeanImpl;
101
102    /** Mean statistic implementation - can be reset by setter. */
103    private StorelessUnivariateStatistic[] meanImpl;
104
105    /** Covariance statistic implementation - cannot be reset. */
106    private VectorialCovariance covarianceImpl;
107
108    /**
109     * Construct a MultivariateSummaryStatistics instance
110     * @param k dimension of the data
111     * @param isCovarianceBiasCorrected if true, the unbiased sample
112     * covariance is computed, otherwise the biased population covariance
113     * is computed
114     */
115    public MultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
116        this.k = k;
117
118        sumImpl     = new StorelessUnivariateStatistic[k];
119        sumSqImpl   = new StorelessUnivariateStatistic[k];
120        minImpl     = new StorelessUnivariateStatistic[k];
121        maxImpl     = new StorelessUnivariateStatistic[k];
122        sumLogImpl  = new StorelessUnivariateStatistic[k];
123        geoMeanImpl = new StorelessUnivariateStatistic[k];
124        meanImpl    = new StorelessUnivariateStatistic[k];
125
126        for (int i = 0; i < k; ++i) {
127            sumImpl[i]     = new Sum();
128            sumSqImpl[i]   = new SumOfSquares();
129            minImpl[i]     = new Min();
130            maxImpl[i]     = new Max();
131            sumLogImpl[i]  = new SumOfLogs();
132            geoMeanImpl[i] = new GeometricMean();
133            meanImpl[i]    = new Mean();
134        }
135
136        covarianceImpl =
137            new VectorialCovariance(k, isCovarianceBiasCorrected);
138
139    }
140
141    /**
142     * Add an n-tuple to the data
143     *
144     * @param value  the n-tuple to add
145     * @throws DimensionMismatchException if the length of the array
146     * does not match the one used at construction
147     */
148    public void addValue(double[] value) throws DimensionMismatchException {
149        checkDimension(value.length);
150        for (int i = 0; i < k; ++i) {
151            double v = value[i];
152            sumImpl[i].increment(v);
153            sumSqImpl[i].increment(v);
154            minImpl[i].increment(v);
155            maxImpl[i].increment(v);
156            sumLogImpl[i].increment(v);
157            geoMeanImpl[i].increment(v);
158            meanImpl[i].increment(v);
159        }
160        covarianceImpl.increment(value);
161        n++;
162    }
163
164    /**
165     * Returns the dimension of the data
166     * @return The dimension of the data
167     */
168    public int getDimension() {
169        return k;
170    }
171
172    /**
173     * Returns the number of available values
174     * @return The number of available values
175     */
176    public long getN() {
177        return n;
178    }
179
180    /**
181     * Returns an array of the results of a statistic.
182     * @param stats univariate statistic array
183     * @return results array
184     */
185    private double[] getResults(StorelessUnivariateStatistic[] stats) {
186        double[] results = new double[stats.length];
187        for (int i = 0; i < results.length; ++i) {
188            results[i] = stats[i].getResult();
189        }
190        return results;
191    }
192
193    /**
194     * Returns an array whose i<sup>th</sup> entry is the sum of the
195     * i<sup>th</sup> entries of the arrays that have been added using
196     * {@link #addValue(double[])}
197     *
198     * @return the array of component sums
199     */
200    public double[] getSum() {
201        return getResults(sumImpl);
202    }
203
204    /**
205     * Returns an array whose i<sup>th</sup> entry is the sum of squares of the
206     * i<sup>th</sup> entries of the arrays that have been added using
207     * {@link #addValue(double[])}
208     *
209     * @return the array of component sums of squares
210     */
211    public double[] getSumSq() {
212        return getResults(sumSqImpl);
213    }
214
215    /**
216     * Returns an array whose i<sup>th</sup> entry is the sum of logs of the
217     * i<sup>th</sup> entries of the arrays that have been added using
218     * {@link #addValue(double[])}
219     *
220     * @return the array of component log sums
221     */
222    public double[] getSumLog() {
223        return getResults(sumLogImpl);
224    }
225
226    /**
227     * Returns an array whose i<sup>th</sup> entry is the mean of the
228     * i<sup>th</sup> entries of the arrays that have been added using
229     * {@link #addValue(double[])}
230     *
231     * @return the array of component means
232     */
233    public double[] getMean() {
234        return getResults(meanImpl);
235    }
236
237    /**
238     * Returns an array whose i<sup>th</sup> entry is the standard deviation of the
239     * i<sup>th</sup> entries of the arrays that have been added using
240     * {@link #addValue(double[])}
241     *
242     * @return the array of component standard deviations
243     */
244    public double[] getStandardDeviation() {
245        double[] stdDev = new double[k];
246        if (getN() < 1) {
247            Arrays.fill(stdDev, Double.NaN);
248        } else if (getN() < 2) {
249            Arrays.fill(stdDev, 0.0);
250        } else {
251            RealMatrix matrix = covarianceImpl.getResult();
252            for (int i = 0; i < k; ++i) {
253                stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i));
254            }
255        }
256        return stdDev;
257    }
258
259    /**
260     * Returns the covariance matrix of the values that have been added.
261     *
262     * @return the covariance matrix
263     */
264    public RealMatrix getCovariance() {
265        return covarianceImpl.getResult();
266    }
267
268    /**
269     * Returns an array whose i<sup>th</sup> entry is the maximum of the
270     * i<sup>th</sup> entries of the arrays that have been added using
271     * {@link #addValue(double[])}
272     *
273     * @return the array of component maxima
274     */
275    public double[] getMax() {
276        return getResults(maxImpl);
277    }
278
279    /**
280     * Returns an array whose i<sup>th</sup> entry is the minimum of the
281     * i<sup>th</sup> entries of the arrays that have been added using
282     * {@link #addValue(double[])}
283     *
284     * @return the array of component minima
285     */
286    public double[] getMin() {
287        return getResults(minImpl);
288    }
289
290    /**
291     * Returns an array whose i<sup>th</sup> entry is the geometric mean of the
292     * i<sup>th</sup> entries of the arrays that have been added using
293     * {@link #addValue(double[])}
294     *
295     * @return the array of component geometric means
296     */
297    public double[] getGeometricMean() {
298        return getResults(geoMeanImpl);
299    }
300
301    /**
302     * Generates a text report displaying
303     * summary statistics from values that
304     * have been added.
305     * @return String with line feeds displaying statistics
306     */
307    @Override
308    public String toString() {
309        final String separator = ", ";
310        final String suffix = System.getProperty("line.separator");
311        StringBuilder outBuffer = new StringBuilder();
312        outBuffer.append("MultivariateSummaryStatistics:" + suffix);
313        outBuffer.append("n: " + getN() + suffix);
314        append(outBuffer, getMin(), "min: ", separator, suffix);
315        append(outBuffer, getMax(), "max: ", separator, suffix);
316        append(outBuffer, getMean(), "mean: ", separator, suffix);
317        append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix);
318        append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix);
319        append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix);
320        append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix);
321        outBuffer.append("covariance: " + getCovariance().toString() + suffix);
322        return outBuffer.toString();
323    }
324
325    /**
326     * Append a text representation of an array to a buffer.
327     * @param buffer buffer to fill
328     * @param data data array
329     * @param prefix text prefix
330     * @param separator elements separator
331     * @param suffix text suffix
332     */
333    private void append(StringBuilder buffer, double[] data,
334                        String prefix, String separator, String suffix) {
335        buffer.append(prefix);
336        for (int i = 0; i < data.length; ++i) {
337            if (i > 0) {
338                buffer.append(separator);
339            }
340            buffer.append(data[i]);
341        }
342        buffer.append(suffix);
343    }
344
345    /**
346     * Resets all statistics and storage
347     */
348    public void clear() {
349        this.n = 0;
350        for (int i = 0; i < k; ++i) {
351            minImpl[i].clear();
352            maxImpl[i].clear();
353            sumImpl[i].clear();
354            sumLogImpl[i].clear();
355            sumSqImpl[i].clear();
356            geoMeanImpl[i].clear();
357            meanImpl[i].clear();
358        }
359        covarianceImpl.clear();
360    }
361
362    /**
363     * Returns true iff <code>object</code> is a <code>MultivariateSummaryStatistics</code>
364     * instance and all statistics have the same values as this.
365     * @param object the object to test equality against.
366     * @return true if object equals this
367     */
368    @Override
369    public boolean equals(Object object) {
370        if (object == this ) {
371            return true;
372        }
373        if (object instanceof MultivariateSummaryStatistics == false) {
374            return false;
375        }
376        MultivariateSummaryStatistics stat = (MultivariateSummaryStatistics) object;
377        return MathArrays.equalsIncludingNaN(stat.getGeometricMean(), getGeometricMean()) &&
378               MathArrays.equalsIncludingNaN(stat.getMax(),           getMax())           &&
379               MathArrays.equalsIncludingNaN(stat.getMean(),          getMean())          &&
380               MathArrays.equalsIncludingNaN(stat.getMin(),           getMin())           &&
381               Precision.equalsIncludingNaN(stat.getN(),             getN())             &&
382               MathArrays.equalsIncludingNaN(stat.getSum(),           getSum())           &&
383               MathArrays.equalsIncludingNaN(stat.getSumSq(),         getSumSq())         &&
384               MathArrays.equalsIncludingNaN(stat.getSumLog(),        getSumLog())        &&
385               stat.getCovariance().equals( getCovariance());
386    }
387
388    /**
389     * Returns hash code based on values of statistics
390     *
391     * @return hash code
392     */
393    @Override
394    public int hashCode() {
395        int result = 31 + MathUtils.hash(getGeometricMean());
396        result = result * 31 + MathUtils.hash(getGeometricMean());
397        result = result * 31 + MathUtils.hash(getMax());
398        result = result * 31 + MathUtils.hash(getMean());
399        result = result * 31 + MathUtils.hash(getMin());
400        result = result * 31 + MathUtils.hash(getN());
401        result = result * 31 + MathUtils.hash(getSum());
402        result = result * 31 + MathUtils.hash(getSumSq());
403        result = result * 31 + MathUtils.hash(getSumLog());
404        result = result * 31 + getCovariance().hashCode();
405        return result;
406    }
407
408    // Getters and setters for statistics implementations
409    /**
410     * Sets statistics implementations.
411     * @param newImpl new implementations for statistics
412     * @param oldImpl old implementations for statistics
413     * @throws DimensionMismatchException if the array dimension
414     * does not match the one used at construction
415     * @throws MathIllegalStateException if data has already been added
416     * (i.e. if n > 0)
417     */
418    private void setImpl(StorelessUnivariateStatistic[] newImpl,
419                         StorelessUnivariateStatistic[] oldImpl) throws MathIllegalStateException,
420                         DimensionMismatchException {
421        checkEmpty();
422        checkDimension(newImpl.length);
423        System.arraycopy(newImpl, 0, oldImpl, 0, newImpl.length);
424    }
425
426    /**
427     * Returns the currently configured Sum implementation
428     *
429     * @return the StorelessUnivariateStatistic implementing the sum
430     */
431    public StorelessUnivariateStatistic[] getSumImpl() {
432        return sumImpl.clone();
433    }
434
435    /**
436     * <p>Sets the implementation for the Sum.</p>
437     * <p>This method must be activated before any data has been added - i.e.,
438     * before {@link #addValue(double[]) addValue} has been used to add data;
439     * otherwise an IllegalStateException will be thrown.</p>
440     *
441     * @param sumImpl the StorelessUnivariateStatistic instance to use
442     * for computing the Sum
443     * @throws DimensionMismatchException if the array dimension
444     * does not match the one used at construction
445     * @throws MathIllegalStateException if data has already been added
446     *  (i.e if n > 0)
447     */
448    public void setSumImpl(StorelessUnivariateStatistic[] sumImpl)
449    throws MathIllegalStateException, DimensionMismatchException {
450        setImpl(sumImpl, this.sumImpl);
451    }
452
453    /**
454     * Returns the currently configured sum of squares implementation
455     *
456     * @return the StorelessUnivariateStatistic implementing the sum of squares
457     */
458    public StorelessUnivariateStatistic[] getSumsqImpl() {
459        return sumSqImpl.clone();
460    }
461
462    /**
463     * <p>Sets the implementation for the sum of squares.</p>
464     * <p>This method must be activated before any data has been added - i.e.,
465     * before {@link #addValue(double[]) addValue} has been used to add data;
466     * otherwise an IllegalStateException will be thrown.</p>
467     *
468     * @param sumsqImpl the StorelessUnivariateStatistic instance to use
469     * for computing the sum of squares
470     * @throws DimensionMismatchException if the array dimension
471     * does not match the one used at construction
472     * @throws MathIllegalStateException if data has already been added
473     *  (i.e if n > 0)
474     */
475    public void setSumsqImpl(StorelessUnivariateStatistic[] sumsqImpl)
476    throws MathIllegalStateException, DimensionMismatchException {
477        setImpl(sumsqImpl, this.sumSqImpl);
478    }
479
480    /**
481     * Returns the currently configured minimum implementation
482     *
483     * @return the StorelessUnivariateStatistic implementing the minimum
484     */
485    public StorelessUnivariateStatistic[] getMinImpl() {
486        return minImpl.clone();
487    }
488
489    /**
490     * <p>Sets the implementation for the minimum.</p>
491     * <p>This method must be activated before any data has been added - i.e.,
492     * before {@link #addValue(double[]) addValue} has been used to add data;
493     * otherwise an IllegalStateException will be thrown.</p>
494     *
495     * @param minImpl the StorelessUnivariateStatistic instance to use
496     * for computing the minimum
497     * @throws DimensionMismatchException if the array dimension
498     * does not match the one used at construction
499     * @throws MathIllegalStateException if data has already been added
500     *  (i.e if n > 0)
501     */
502    public void setMinImpl(StorelessUnivariateStatistic[] minImpl)
503    throws MathIllegalStateException, DimensionMismatchException {
504        setImpl(minImpl, this.minImpl);
505    }
506
507    /**
508     * Returns the currently configured maximum implementation
509     *
510     * @return the StorelessUnivariateStatistic implementing the maximum
511     */
512    public StorelessUnivariateStatistic[] getMaxImpl() {
513        return maxImpl.clone();
514    }
515
516    /**
517     * <p>Sets the implementation for the maximum.</p>
518     * <p>This method must be activated before any data has been added - i.e.,
519     * before {@link #addValue(double[]) addValue} has been used to add data;
520     * otherwise an IllegalStateException will be thrown.</p>
521     *
522     * @param maxImpl the StorelessUnivariateStatistic instance to use
523     * for computing the maximum
524     * @throws DimensionMismatchException if the array dimension
525     * does not match the one used at construction
526     * @throws MathIllegalStateException if data has already been added
527     *  (i.e if n > 0)
528     */
529    public void setMaxImpl(StorelessUnivariateStatistic[] maxImpl)
530    throws MathIllegalStateException, DimensionMismatchException{
531        setImpl(maxImpl, this.maxImpl);
532    }
533
534    /**
535     * Returns the currently configured sum of logs implementation
536     *
537     * @return the StorelessUnivariateStatistic implementing the log sum
538     */
539    public StorelessUnivariateStatistic[] getSumLogImpl() {
540        return sumLogImpl.clone();
541    }
542
543    /**
544     * <p>Sets the implementation for the sum of logs.</p>
545     * <p>This method must be activated before any data has been added - i.e.,
546     * before {@link #addValue(double[]) addValue} has been used to add data;
547     * otherwise an IllegalStateException will be thrown.</p>
548     *
549     * @param sumLogImpl the StorelessUnivariateStatistic instance to use
550     * for computing the log sum
551     * @throws DimensionMismatchException if the array dimension
552     * does not match the one used at construction
553     * @throws MathIllegalStateException if data has already been added
554     *  (i.e if n > 0)
555     */
556    public void setSumLogImpl(StorelessUnivariateStatistic[] sumLogImpl)
557    throws MathIllegalStateException, DimensionMismatchException{
558        setImpl(sumLogImpl, this.sumLogImpl);
559    }
560
561    /**
562     * Returns the currently configured geometric mean implementation
563     *
564     * @return the StorelessUnivariateStatistic implementing the geometric mean
565     */
566    public StorelessUnivariateStatistic[] getGeoMeanImpl() {
567        return geoMeanImpl.clone();
568    }
569
570    /**
571     * <p>Sets the implementation for the geometric mean.</p>
572     * <p>This method must be activated before any data has been added - i.e.,
573     * before {@link #addValue(double[]) addValue} has been used to add data;
574     * otherwise an IllegalStateException will be thrown.</p>
575     *
576     * @param geoMeanImpl the StorelessUnivariateStatistic instance to use
577     * for computing the geometric mean
578     * @throws DimensionMismatchException if the array dimension
579     * does not match the one used at construction
580     * @throws MathIllegalStateException if data has already been added
581     *  (i.e if n > 0)
582     */
583    public void setGeoMeanImpl(StorelessUnivariateStatistic[] geoMeanImpl)
584    throws MathIllegalStateException, DimensionMismatchException {
585        setImpl(geoMeanImpl, this.geoMeanImpl);
586    }
587
588    /**
589     * Returns the currently configured mean implementation
590     *
591     * @return the StorelessUnivariateStatistic implementing the mean
592     */
593    public StorelessUnivariateStatistic[] getMeanImpl() {
594        return meanImpl.clone();
595    }
596
597    /**
598     * <p>Sets the implementation for the mean.</p>
599     * <p>This method must be activated before any data has been added - i.e.,
600     * before {@link #addValue(double[]) addValue} has been used to add data;
601     * otherwise an IllegalStateException will be thrown.</p>
602     *
603     * @param meanImpl the StorelessUnivariateStatistic instance to use
604     * for computing the mean
605     * @throws DimensionMismatchException if the array dimension
606     * does not match the one used at construction
607     * @throws MathIllegalStateException if data has already been added
608     *  (i.e if n > 0)
609     */
610    public void setMeanImpl(StorelessUnivariateStatistic[] meanImpl)
611    throws MathIllegalStateException, DimensionMismatchException{
612        setImpl(meanImpl, this.meanImpl);
613    }
614
615    /**
616     * Throws MathIllegalStateException if the statistic is not empty.
617     * @throws MathIllegalStateException if n > 0.
618     */
619    private void checkEmpty() throws MathIllegalStateException {
620        if (n > 0) {
621            throw new MathIllegalStateException(
622                    LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, n);
623        }
624    }
625
626    /**
627     * Throws DimensionMismatchException if dimension != k.
628     * @param dimension dimension to check
629     * @throws DimensionMismatchException if dimension != k
630     */
631    private void checkDimension(int dimension) throws DimensionMismatchException {
632        if (dimension != k) {
633            throw new DimensionMismatchException(dimension, k);
634        }
635    }
636}