View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math3.stat.descriptive;
18  
19  import java.io.Serializable;
20  import java.util.Arrays;
21  
22  import org.apache.commons.math3.exception.util.LocalizedFormats;
23  import org.apache.commons.math3.exception.DimensionMismatchException;
24  import org.apache.commons.math3.exception.MathIllegalStateException;
25  import org.apache.commons.math3.linear.RealMatrix;
26  import org.apache.commons.math3.stat.descriptive.moment.GeometricMean;
27  import org.apache.commons.math3.stat.descriptive.moment.Mean;
28  import org.apache.commons.math3.stat.descriptive.moment.VectorialCovariance;
29  import org.apache.commons.math3.stat.descriptive.rank.Max;
30  import org.apache.commons.math3.stat.descriptive.rank.Min;
31  import org.apache.commons.math3.stat.descriptive.summary.Sum;
32  import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs;
33  import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares;
34  import org.apache.commons.math3.util.MathUtils;
35  import org.apache.commons.math3.util.MathArrays;
36  import org.apache.commons.math3.util.Precision;
37  import org.apache.commons.math3.util.FastMath;
38  
39  /**
40   * <p>Computes summary statistics for a stream of n-tuples added using the
41   * {@link #addValue(double[]) addValue} method. The data values are not stored
42   * in memory, so this class can be used to compute statistics for very large
43   * n-tuple streams.</p>
44   *
45   * <p>The {@link StorelessUnivariateStatistic} instances used to maintain
46   * summary state and compute statistics are configurable via setters.
47   * For example, the default implementation for the mean can be overridden by
48   * calling {@link #setMeanImpl(StorelessUnivariateStatistic[])}. Actual
49   * parameters to these methods must implement the
50   * {@link StorelessUnivariateStatistic} interface and configuration must be
51   * completed before <code>addValue</code> is called. No configuration is
52   * necessary to use the default, commons-math provided implementations.</p>
53   *
54   * <p>To compute statistics for a stream of n-tuples, construct a
55   * MultivariateStatistics instance with dimension n and then use
56   * {@link #addValue(double[])} to add n-tuples. The <code>getXxx</code>
57   * methods where Xxx is a statistic return an array of <code>double</code>
58   * values, where for <code>i = 0,...,n-1</code> the i<sup>th</sup> array element is the
59   * value of the given statistic for data range consisting of the i<sup>th</sup> element of
60   * each of the input n-tuples.  For example, if <code>addValue</code> is called
61   * with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8},
62   * <code>getSum</code> will return a three-element array with values
63   * {0+3+6, 1+4+7, 2+5+8}</p>
64   *
65   * <p>Note: This class is not thread-safe. Use
66   * {@link SynchronizedMultivariateSummaryStatistics} if concurrent access from multiple
67   * threads is required.</p>
68   *
69   * @since 1.2
70   * @version $Id: MultivariateSummaryStatistics.java 1416643 2012-12-03 19:37:14Z tn $
71   */
72  public class MultivariateSummaryStatistics
73      implements StatisticalMultivariateSummary, Serializable {
74  
75      /** Serialization UID */
76      private static final long serialVersionUID = 2271900808994826718L;
77  
78      /** Dimension of the data. */
79      private int k;
80  
81      /** Count of values that have been added */
82      private long n = 0;
83  
84      /** Sum statistic implementation - can be reset by setter. */
85      private StorelessUnivariateStatistic[] sumImpl;
86  
87      /** Sum of squares statistic implementation - can be reset by setter. */
88      private StorelessUnivariateStatistic[] sumSqImpl;
89  
90      /** Minimum statistic implementation - can be reset by setter. */
91      private StorelessUnivariateStatistic[] minImpl;
92  
93      /** Maximum statistic implementation - can be reset by setter. */
94      private StorelessUnivariateStatistic[] maxImpl;
95  
96      /** Sum of log statistic implementation - can be reset by setter. */
97      private StorelessUnivariateStatistic[] sumLogImpl;
98  
99      /** Geometric mean statistic implementation - can be reset by setter. */
100     private StorelessUnivariateStatistic[] geoMeanImpl;
101 
102     /** Mean statistic implementation - can be reset by setter. */
103     private StorelessUnivariateStatistic[] meanImpl;
104 
105     /** Covariance statistic implementation - cannot be reset. */
106     private VectorialCovariance covarianceImpl;
107 
108     /**
109      * Construct a MultivariateSummaryStatistics instance
110      * @param k dimension of the data
111      * @param isCovarianceBiasCorrected if true, the unbiased sample
112      * covariance is computed, otherwise the biased population covariance
113      * is computed
114      */
115     public MultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
116         this.k = k;
117 
118         sumImpl     = new StorelessUnivariateStatistic[k];
119         sumSqImpl   = new StorelessUnivariateStatistic[k];
120         minImpl     = new StorelessUnivariateStatistic[k];
121         maxImpl     = new StorelessUnivariateStatistic[k];
122         sumLogImpl  = new StorelessUnivariateStatistic[k];
123         geoMeanImpl = new StorelessUnivariateStatistic[k];
124         meanImpl    = new StorelessUnivariateStatistic[k];
125 
126         for (int i = 0; i < k; ++i) {
127             sumImpl[i]     = new Sum();
128             sumSqImpl[i]   = new SumOfSquares();
129             minImpl[i]     = new Min();
130             maxImpl[i]     = new Max();
131             sumLogImpl[i]  = new SumOfLogs();
132             geoMeanImpl[i] = new GeometricMean();
133             meanImpl[i]    = new Mean();
134         }
135 
136         covarianceImpl =
137             new VectorialCovariance(k, isCovarianceBiasCorrected);
138 
139     }
140 
141     /**
142      * Add an n-tuple to the data
143      *
144      * @param value  the n-tuple to add
145      * @throws DimensionMismatchException if the length of the array
146      * does not match the one used at construction
147      */
148     public void addValue(double[] value) throws DimensionMismatchException {
149         checkDimension(value.length);
150         for (int i = 0; i < k; ++i) {
151             double v = value[i];
152             sumImpl[i].increment(v);
153             sumSqImpl[i].increment(v);
154             minImpl[i].increment(v);
155             maxImpl[i].increment(v);
156             sumLogImpl[i].increment(v);
157             geoMeanImpl[i].increment(v);
158             meanImpl[i].increment(v);
159         }
160         covarianceImpl.increment(value);
161         n++;
162     }
163 
164     /**
165      * Returns the dimension of the data
166      * @return The dimension of the data
167      */
168     public int getDimension() {
169         return k;
170     }
171 
172     /**
173      * Returns the number of available values
174      * @return The number of available values
175      */
176     public long getN() {
177         return n;
178     }
179 
180     /**
181      * Returns an array of the results of a statistic.
182      * @param stats univariate statistic array
183      * @return results array
184      */
185     private double[] getResults(StorelessUnivariateStatistic[] stats) {
186         double[] results = new double[stats.length];
187         for (int i = 0; i < results.length; ++i) {
188             results[i] = stats[i].getResult();
189         }
190         return results;
191     }
192 
193     /**
194      * Returns an array whose i<sup>th</sup> entry is the sum of the
195      * i<sup>th</sup> entries of the arrays that have been added using
196      * {@link #addValue(double[])}
197      *
198      * @return the array of component sums
199      */
200     public double[] getSum() {
201         return getResults(sumImpl);
202     }
203 
204     /**
205      * Returns an array whose i<sup>th</sup> entry is the sum of squares of the
206      * i<sup>th</sup> entries of the arrays that have been added using
207      * {@link #addValue(double[])}
208      *
209      * @return the array of component sums of squares
210      */
211     public double[] getSumSq() {
212         return getResults(sumSqImpl);
213     }
214 
215     /**
216      * Returns an array whose i<sup>th</sup> entry is the sum of logs of the
217      * i<sup>th</sup> entries of the arrays that have been added using
218      * {@link #addValue(double[])}
219      *
220      * @return the array of component log sums
221      */
222     public double[] getSumLog() {
223         return getResults(sumLogImpl);
224     }
225 
226     /**
227      * Returns an array whose i<sup>th</sup> entry is the mean of the
228      * i<sup>th</sup> entries of the arrays that have been added using
229      * {@link #addValue(double[])}
230      *
231      * @return the array of component means
232      */
233     public double[] getMean() {
234         return getResults(meanImpl);
235     }
236 
237     /**
238      * Returns an array whose i<sup>th</sup> entry is the standard deviation of the
239      * i<sup>th</sup> entries of the arrays that have been added using
240      * {@link #addValue(double[])}
241      *
242      * @return the array of component standard deviations
243      */
244     public double[] getStandardDeviation() {
245         double[] stdDev = new double[k];
246         if (getN() < 1) {
247             Arrays.fill(stdDev, Double.NaN);
248         } else if (getN() < 2) {
249             Arrays.fill(stdDev, 0.0);
250         } else {
251             RealMatrix matrix = covarianceImpl.getResult();
252             for (int i = 0; i < k; ++i) {
253                 stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i));
254             }
255         }
256         return stdDev;
257     }
258 
259     /**
260      * Returns the covariance matrix of the values that have been added.
261      *
262      * @return the covariance matrix
263      */
264     public RealMatrix getCovariance() {
265         return covarianceImpl.getResult();
266     }
267 
268     /**
269      * Returns an array whose i<sup>th</sup> entry is the maximum of the
270      * i<sup>th</sup> entries of the arrays that have been added using
271      * {@link #addValue(double[])}
272      *
273      * @return the array of component maxima
274      */
275     public double[] getMax() {
276         return getResults(maxImpl);
277     }
278 
279     /**
280      * Returns an array whose i<sup>th</sup> entry is the minimum of the
281      * i<sup>th</sup> entries of the arrays that have been added using
282      * {@link #addValue(double[])}
283      *
284      * @return the array of component minima
285      */
286     public double[] getMin() {
287         return getResults(minImpl);
288     }
289 
290     /**
291      * Returns an array whose i<sup>th</sup> entry is the geometric mean of the
292      * i<sup>th</sup> entries of the arrays that have been added using
293      * {@link #addValue(double[])}
294      *
295      * @return the array of component geometric means
296      */
297     public double[] getGeometricMean() {
298         return getResults(geoMeanImpl);
299     }
300 
301     /**
302      * Generates a text report displaying
303      * summary statistics from values that
304      * have been added.
305      * @return String with line feeds displaying statistics
306      */
307     @Override
308     public String toString() {
309         final String separator = ", ";
310         final String suffix = System.getProperty("line.separator");
311         StringBuilder outBuffer = new StringBuilder();
312         outBuffer.append("MultivariateSummaryStatistics:" + suffix);
313         outBuffer.append("n: " + getN() + suffix);
314         append(outBuffer, getMin(), "min: ", separator, suffix);
315         append(outBuffer, getMax(), "max: ", separator, suffix);
316         append(outBuffer, getMean(), "mean: ", separator, suffix);
317         append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix);
318         append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix);
319         append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix);
320         append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix);
321         outBuffer.append("covariance: " + getCovariance().toString() + suffix);
322         return outBuffer.toString();
323     }
324 
325     /**
326      * Append a text representation of an array to a buffer.
327      * @param buffer buffer to fill
328      * @param data data array
329      * @param prefix text prefix
330      * @param separator elements separator
331      * @param suffix text suffix
332      */
333     private void append(StringBuilder buffer, double[] data,
334                         String prefix, String separator, String suffix) {
335         buffer.append(prefix);
336         for (int i = 0; i < data.length; ++i) {
337             if (i > 0) {
338                 buffer.append(separator);
339             }
340             buffer.append(data[i]);
341         }
342         buffer.append(suffix);
343     }
344 
345     /**
346      * Resets all statistics and storage
347      */
348     public void clear() {
349         this.n = 0;
350         for (int i = 0; i < k; ++i) {
351             minImpl[i].clear();
352             maxImpl[i].clear();
353             sumImpl[i].clear();
354             sumLogImpl[i].clear();
355             sumSqImpl[i].clear();
356             geoMeanImpl[i].clear();
357             meanImpl[i].clear();
358         }
359         covarianceImpl.clear();
360     }
361 
362     /**
363      * Returns true iff <code>object</code> is a <code>MultivariateSummaryStatistics</code>
364      * instance and all statistics have the same values as this.
365      * @param object the object to test equality against.
366      * @return true if object equals this
367      */
368     @Override
369     public boolean equals(Object object) {
370         if (object == this ) {
371             return true;
372         }
373         if (object instanceof MultivariateSummaryStatistics == false) {
374             return false;
375         }
376         MultivariateSummaryStatistics stat = (MultivariateSummaryStatistics) object;
377         return MathArrays.equalsIncludingNaN(stat.getGeometricMean(), getGeometricMean()) &&
378                MathArrays.equalsIncludingNaN(stat.getMax(),           getMax())           &&
379                MathArrays.equalsIncludingNaN(stat.getMean(),          getMean())          &&
380                MathArrays.equalsIncludingNaN(stat.getMin(),           getMin())           &&
381                Precision.equalsIncludingNaN(stat.getN(),             getN())             &&
382                MathArrays.equalsIncludingNaN(stat.getSum(),           getSum())           &&
383                MathArrays.equalsIncludingNaN(stat.getSumSq(),         getSumSq())         &&
384                MathArrays.equalsIncludingNaN(stat.getSumLog(),        getSumLog())        &&
385                stat.getCovariance().equals( getCovariance());
386     }
387 
388     /**
389      * Returns hash code based on values of statistics
390      *
391      * @return hash code
392      */
393     @Override
394     public int hashCode() {
395         int result = 31 + MathUtils.hash(getGeometricMean());
396         result = result * 31 + MathUtils.hash(getGeometricMean());
397         result = result * 31 + MathUtils.hash(getMax());
398         result = result * 31 + MathUtils.hash(getMean());
399         result = result * 31 + MathUtils.hash(getMin());
400         result = result * 31 + MathUtils.hash(getN());
401         result = result * 31 + MathUtils.hash(getSum());
402         result = result * 31 + MathUtils.hash(getSumSq());
403         result = result * 31 + MathUtils.hash(getSumLog());
404         result = result * 31 + getCovariance().hashCode();
405         return result;
406     }
407 
408     // Getters and setters for statistics implementations
409     /**
410      * Sets statistics implementations.
411      * @param newImpl new implementations for statistics
412      * @param oldImpl old implementations for statistics
413      * @throws DimensionMismatchException if the array dimension
414      * does not match the one used at construction
415      * @throws MathIllegalStateException if data has already been added
416      * (i.e. if n > 0)
417      */
418     private void setImpl(StorelessUnivariateStatistic[] newImpl,
419                          StorelessUnivariateStatistic[] oldImpl) throws MathIllegalStateException,
420                          DimensionMismatchException {
421         checkEmpty();
422         checkDimension(newImpl.length);
423         System.arraycopy(newImpl, 0, oldImpl, 0, newImpl.length);
424     }
425 
426     /**
427      * Returns the currently configured Sum implementation
428      *
429      * @return the StorelessUnivariateStatistic implementing the sum
430      */
431     public StorelessUnivariateStatistic[] getSumImpl() {
432         return sumImpl.clone();
433     }
434 
435     /**
436      * <p>Sets the implementation for the Sum.</p>
437      * <p>This method must be activated before any data has been added - i.e.,
438      * before {@link #addValue(double[]) addValue} has been used to add data;
439      * otherwise an IllegalStateException will be thrown.</p>
440      *
441      * @param sumImpl the StorelessUnivariateStatistic instance to use
442      * for computing the Sum
443      * @throws DimensionMismatchException if the array dimension
444      * does not match the one used at construction
445      * @throws MathIllegalStateException if data has already been added
446      *  (i.e if n > 0)
447      */
448     public void setSumImpl(StorelessUnivariateStatistic[] sumImpl)
449     throws MathIllegalStateException, DimensionMismatchException {
450         setImpl(sumImpl, this.sumImpl);
451     }
452 
453     /**
454      * Returns the currently configured sum of squares implementation
455      *
456      * @return the StorelessUnivariateStatistic implementing the sum of squares
457      */
458     public StorelessUnivariateStatistic[] getSumsqImpl() {
459         return sumSqImpl.clone();
460     }
461 
462     /**
463      * <p>Sets the implementation for the sum of squares.</p>
464      * <p>This method must be activated before any data has been added - i.e.,
465      * before {@link #addValue(double[]) addValue} has been used to add data;
466      * otherwise an IllegalStateException will be thrown.</p>
467      *
468      * @param sumsqImpl the StorelessUnivariateStatistic instance to use
469      * for computing the sum of squares
470      * @throws DimensionMismatchException if the array dimension
471      * does not match the one used at construction
472      * @throws MathIllegalStateException if data has already been added
473      *  (i.e if n > 0)
474      */
475     public void setSumsqImpl(StorelessUnivariateStatistic[] sumsqImpl)
476     throws MathIllegalStateException, DimensionMismatchException {
477         setImpl(sumsqImpl, this.sumSqImpl);
478     }
479 
480     /**
481      * Returns the currently configured minimum implementation
482      *
483      * @return the StorelessUnivariateStatistic implementing the minimum
484      */
485     public StorelessUnivariateStatistic[] getMinImpl() {
486         return minImpl.clone();
487     }
488 
489     /**
490      * <p>Sets the implementation for the minimum.</p>
491      * <p>This method must be activated before any data has been added - i.e.,
492      * before {@link #addValue(double[]) addValue} has been used to add data;
493      * otherwise an IllegalStateException will be thrown.</p>
494      *
495      * @param minImpl the StorelessUnivariateStatistic instance to use
496      * for computing the minimum
497      * @throws DimensionMismatchException if the array dimension
498      * does not match the one used at construction
499      * @throws MathIllegalStateException if data has already been added
500      *  (i.e if n > 0)
501      */
502     public void setMinImpl(StorelessUnivariateStatistic[] minImpl)
503     throws MathIllegalStateException, DimensionMismatchException {
504         setImpl(minImpl, this.minImpl);
505     }
506 
507     /**
508      * Returns the currently configured maximum implementation
509      *
510      * @return the StorelessUnivariateStatistic implementing the maximum
511      */
512     public StorelessUnivariateStatistic[] getMaxImpl() {
513         return maxImpl.clone();
514     }
515 
516     /**
517      * <p>Sets the implementation for the maximum.</p>
518      * <p>This method must be activated before any data has been added - i.e.,
519      * before {@link #addValue(double[]) addValue} has been used to add data;
520      * otherwise an IllegalStateException will be thrown.</p>
521      *
522      * @param maxImpl the StorelessUnivariateStatistic instance to use
523      * for computing the maximum
524      * @throws DimensionMismatchException if the array dimension
525      * does not match the one used at construction
526      * @throws MathIllegalStateException if data has already been added
527      *  (i.e if n > 0)
528      */
529     public void setMaxImpl(StorelessUnivariateStatistic[] maxImpl)
530     throws MathIllegalStateException, DimensionMismatchException{
531         setImpl(maxImpl, this.maxImpl);
532     }
533 
534     /**
535      * Returns the currently configured sum of logs implementation
536      *
537      * @return the StorelessUnivariateStatistic implementing the log sum
538      */
539     public StorelessUnivariateStatistic[] getSumLogImpl() {
540         return sumLogImpl.clone();
541     }
542 
543     /**
544      * <p>Sets the implementation for the sum of logs.</p>
545      * <p>This method must be activated before any data has been added - i.e.,
546      * before {@link #addValue(double[]) addValue} has been used to add data;
547      * otherwise an IllegalStateException will be thrown.</p>
548      *
549      * @param sumLogImpl the StorelessUnivariateStatistic instance to use
550      * for computing the log sum
551      * @throws DimensionMismatchException if the array dimension
552      * does not match the one used at construction
553      * @throws MathIllegalStateException if data has already been added
554      *  (i.e if n > 0)
555      */
556     public void setSumLogImpl(StorelessUnivariateStatistic[] sumLogImpl)
557     throws MathIllegalStateException, DimensionMismatchException{
558         setImpl(sumLogImpl, this.sumLogImpl);
559     }
560 
561     /**
562      * Returns the currently configured geometric mean implementation
563      *
564      * @return the StorelessUnivariateStatistic implementing the geometric mean
565      */
566     public StorelessUnivariateStatistic[] getGeoMeanImpl() {
567         return geoMeanImpl.clone();
568     }
569 
570     /**
571      * <p>Sets the implementation for the geometric mean.</p>
572      * <p>This method must be activated before any data has been added - i.e.,
573      * before {@link #addValue(double[]) addValue} has been used to add data;
574      * otherwise an IllegalStateException will be thrown.</p>
575      *
576      * @param geoMeanImpl the StorelessUnivariateStatistic instance to use
577      * for computing the geometric mean
578      * @throws DimensionMismatchException if the array dimension
579      * does not match the one used at construction
580      * @throws MathIllegalStateException if data has already been added
581      *  (i.e if n > 0)
582      */
583     public void setGeoMeanImpl(StorelessUnivariateStatistic[] geoMeanImpl)
584     throws MathIllegalStateException, DimensionMismatchException {
585         setImpl(geoMeanImpl, this.geoMeanImpl);
586     }
587 
588     /**
589      * Returns the currently configured mean implementation
590      *
591      * @return the StorelessUnivariateStatistic implementing the mean
592      */
593     public StorelessUnivariateStatistic[] getMeanImpl() {
594         return meanImpl.clone();
595     }
596 
597     /**
598      * <p>Sets the implementation for the mean.</p>
599      * <p>This method must be activated before any data has been added - i.e.,
600      * before {@link #addValue(double[]) addValue} has been used to add data;
601      * otherwise an IllegalStateException will be thrown.</p>
602      *
603      * @param meanImpl the StorelessUnivariateStatistic instance to use
604      * for computing the mean
605      * @throws DimensionMismatchException if the array dimension
606      * does not match the one used at construction
607      * @throws MathIllegalStateException if data has already been added
608      *  (i.e if n > 0)
609      */
610     public void setMeanImpl(StorelessUnivariateStatistic[] meanImpl)
611     throws MathIllegalStateException, DimensionMismatchException{
612         setImpl(meanImpl, this.meanImpl);
613     }
614 
615     /**
616      * Throws MathIllegalStateException if the statistic is not empty.
617      * @throws MathIllegalStateException if n > 0.
618      */
619     private void checkEmpty() throws MathIllegalStateException {
620         if (n > 0) {
621             throw new MathIllegalStateException(
622                     LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, n);
623         }
624     }
625 
626     /**
627      * Throws DimensionMismatchException if dimension != k.
628      * @param dimension dimension to check
629      * @throws DimensionMismatchException if dimension != k
630      */
631     private void checkDimension(int dimension) throws DimensionMismatchException {
632         if (dimension != k) {
633             throw new DimensionMismatchException(dimension, k);
634         }
635     }
636 }