FuzzyKMeansClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.ml.clustering;

  18. import java.util.ArrayList;
  19. import java.util.Collection;
  20. import java.util.Collections;
  21. import java.util.List;

  22. import org.apache.commons.math4.legacy.exception.NullArgumentException;
  23. import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
  24. import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
  25. import org.apache.commons.math4.legacy.linear.MatrixUtils;
  26. import org.apache.commons.math4.legacy.linear.RealMatrix;
  27. import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
  28. import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
  29. import org.apache.commons.rng.simple.RandomSource;
  30. import org.apache.commons.rng.UniformRandomProvider;
  31. import org.apache.commons.math4.core.jdkmath.JdkMath;
  32. import org.apache.commons.math4.legacy.core.MathArrays;

  33. /**
  34.  * Fuzzy K-Means clustering algorithm.
  35.  * <p>
  36.  * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the
  37.  * major difference that a single data point is not uniquely assigned to a single cluster.
  38.  * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership
  39.  * to the cluster j.
  40.  * <p>
  41.  * The algorithm then tries to minimize the objective function:
  42.  * <div style="white-space: pre"><code>
  43.  * J = &#8721;<sub>i=1..C</sub>&#8721;<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup>
  44.  * </code></div>
  45.  * with d<sub>ik</sub> being the distance between data point i and the cluster center k.
  46.  * <p>
  47.  * The algorithm requires two parameters:
  48.  * <ul>
  49.  *   <li>k: the number of clusters
  50.  *   <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters
  51.  * </ul>
  52.  * Additional, optional parameters:
  53.  * <ul>
  54.  *   <li>maxIterations: the maximum number of iterations
  55.  *   <li>epsilon: the convergence criteria, default is 1e-3
  56.  * </ul>
  57.  * <p>
  58.  * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection
  59.  * of the initial cluster centers.
  60.  *
  61.  * @param <T> type of the points to cluster
  62.  * @since 3.3
  63.  */
  64. public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {

  65.     /** The default value for the convergence criteria. */
  66.     private static final double DEFAULT_EPSILON = 1e-3;

  67.     /** The number of clusters. */
  68.     private final int k;

  69.     /** The maximum number of iterations. */
  70.     private final int maxIterations;

  71.     /** The fuzziness factor. */
  72.     private final double fuzziness;

  73.     /** The convergence criteria. */
  74.     private final double epsilon;

  75.     /** Random generator for choosing initial centers. */
  76.     private final UniformRandomProvider random;

  77.     /** The membership matrix. */
  78.     private double[][] membershipMatrix;

  79.     /** The list of points used in the last call to {@link #cluster(Collection)}. */
  80.     private List<T> points;

  81.     /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */
  82.     private List<CentroidCluster<T>> clusters;

  83.     /**
  84.      * Creates a new instance of a FuzzyKMeansClusterer.
  85.      * <p>
  86.      * The euclidean distance will be used as default distance measure.
  87.      *
  88.      * @param k the number of clusters to split the data into
  89.      * @param fuzziness the fuzziness factor, must be &gt; 1.0
  90.      * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
  91.      */
  92.     public FuzzyKMeansClusterer(final int k, final double fuzziness) {
  93.         this(k, fuzziness, -1, new EuclideanDistance());
  94.     }

  95.     /**
  96.      * Creates a new instance of a FuzzyKMeansClusterer.
  97.      *
  98.      * @param k the number of clusters to split the data into
  99.      * @param fuzziness the fuzziness factor, must be &gt; 1.0
  100.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  101.      *   If negative, no maximum will be used.
  102.      * @param measure the distance measure to use
  103.      * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
  104.      */
  105.     public FuzzyKMeansClusterer(final int k, final double fuzziness,
  106.                                 final int maxIterations, final DistanceMeasure measure) {
  107.         this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, RandomSource.MT_64.create());
  108.     }

  109.     /**
  110.      * Creates a new instance of a FuzzyKMeansClusterer.
  111.      *
  112.      * @param k the number of clusters to split the data into
  113.      * @param fuzziness the fuzziness factor, must be &gt; 1.0
  114.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  115.      *   If negative, no maximum will be used.
  116.      * @param measure the distance measure to use
  117.      * @param epsilon the convergence criteria (default is 1e-3)
  118.      * @param random random generator to use for choosing initial centers
  119.      * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
  120.      */
  121.     public FuzzyKMeansClusterer(final int k, final double fuzziness,
  122.                                 final int maxIterations, final DistanceMeasure measure,
  123.                                 final double epsilon, final UniformRandomProvider random) {
  124.         super(measure);

  125.         if (fuzziness <= 1.0d) {
  126.             throw new NumberIsTooSmallException(fuzziness, 1.0, false);
  127.         }
  128.         this.k = k;
  129.         this.fuzziness = fuzziness;
  130.         this.maxIterations = maxIterations;
  131.         this.epsilon = epsilon;
  132.         this.random = random;

  133.         this.membershipMatrix = null;
  134.         this.points = null;
  135.         this.clusters = null;
  136.     }

  137.     /**
  138.      * Return the number of clusters this instance will use.
  139.      * @return the number of clusters
  140.      */
  141.     public int getK() {
  142.         return k;
  143.     }

  144.     /**
  145.      * Returns the fuzziness factor used by this instance.
  146.      * @return the fuzziness factor
  147.      */
  148.     public double getFuzziness() {
  149.         return fuzziness;
  150.     }

  151.     /**
  152.      * Returns the maximum number of iterations this instance will use.
  153.      * @return the maximum number of iterations, or -1 if no maximum is set
  154.      */
  155.     public int getMaxIterations() {
  156.         return maxIterations;
  157.     }

  158.     /**
  159.      * Returns the convergence criteria used by this instance.
  160.      * @return the convergence criteria
  161.      */
  162.     public double getEpsilon() {
  163.         return epsilon;
  164.     }

  165.     /**
  166.      * Returns the random generator this instance will use.
  167.      * @return the random generator
  168.      */
  169.     public UniformRandomProvider getRandomGenerator() {
  170.         return random;
  171.     }

  172.     /**
  173.      * Returns the {@code nxk} membership matrix, where {@code n} is the number
  174.      * of data points and {@code k} the number of clusters.
  175.      * <p>
  176.      * The element U<sub>i,j</sub> represents the membership value for data point {@code i}
  177.      * to cluster {@code j}.
  178.      *
  179.      * @return the membership matrix
  180.      * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
  181.      */
  182.     public RealMatrix getMembershipMatrix() {
  183.         if (membershipMatrix == null) {
  184.             throw new MathIllegalStateException();
  185.         }
  186.         return MatrixUtils.createRealMatrix(membershipMatrix);
  187.     }

  188.     /**
  189.      * Returns an unmodifiable list of the data points used in the last
  190.      * call to {@link #cluster(Collection)}.
  191.      * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has
  192.      *   not been called before.
  193.      */
  194.     public List<T> getDataPoints() {
  195.         return points;
  196.     }

  197.     /**
  198.      * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}.
  199.      * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has
  200.      *   not been called before.
  201.      */
  202.     public List<CentroidCluster<T>> getClusters() {
  203.         return clusters;
  204.     }

  205.     /**
  206.      * Get the value of the objective function.
  207.      * @return the objective function evaluation as double value
  208.      * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
  209.      */
  210.     public double getObjectiveFunctionValue() {
  211.         if (points == null || clusters == null) {
  212.             throw new MathIllegalStateException();
  213.         }

  214.         int i = 0;
  215.         double objFunction = 0.0;
  216.         for (final T point : points) {
  217.             int j = 0;
  218.             for (final CentroidCluster<T> cluster : clusters) {
  219.                 final double dist = distance(point, cluster.getCenter());
  220.                 objFunction += (dist * dist) * JdkMath.pow(membershipMatrix[i][j], fuzziness);
  221.                 j++;
  222.             }
  223.             i++;
  224.         }
  225.         return objFunction;
  226.     }

  227.     /**
  228.      * Performs Fuzzy K-Means cluster analysis.
  229.      *
  230.      * @param dataPoints the points to cluster
  231.      * @return the list of clusters
  232.      * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException if
  233.      * the data points are null or the number of clusters is larger than the number
  234.      * of data points
  235.      */
  236.     @Override
  237.     public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) {
  238.         // sanity checks
  239.         NullArgumentException.check(dataPoints);

  240.         final int size = dataPoints.size();

  241.         // number of clusters has to be smaller or equal the number of data points
  242.         if (size < k) {
  243.             throw new NumberIsTooSmallException(size, k, false);
  244.         }

  245.         // copy the input collection to an unmodifiable list with indexed access
  246.         points = Collections.unmodifiableList(new ArrayList<>(dataPoints));
  247.         clusters = new ArrayList<>();
  248.         membershipMatrix = new double[size][k];
  249.         final double[][] oldMatrix = new double[size][k];

  250.         // if no points are provided, return an empty list of clusters
  251.         if (size == 0) {
  252.             return clusters;
  253.         }

  254.         initializeMembershipMatrix();

  255.         // there is at least one point
  256.         final int pointDimension = points.get(0).getPoint().length;
  257.         for (int i = 0; i < k; i++) {
  258.             clusters.add(new CentroidCluster<>(new DoublePoint(new double[pointDimension])));
  259.         }

  260.         int iteration = 0;
  261.         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
  262.         double difference = 0.0;

  263.         do {
  264.             saveMembershipMatrix(oldMatrix);
  265.             updateClusterCenters();
  266.             updateMembershipMatrix();
  267.             difference = calculateMaxMembershipChange(oldMatrix);
  268.         } while (difference > epsilon && ++iteration < max);

  269.         return clusters;
  270.     }

  271.     /**
  272.      * Update the cluster centers.
  273.      */
  274.     private void updateClusterCenters() {
  275.         int j = 0;
  276.         final List<CentroidCluster<T>> newClusters = new ArrayList<>(k);
  277.         for (final CentroidCluster<T> cluster : clusters) {
  278.             final Clusterable center = cluster.getCenter();
  279.             int i = 0;
  280.             double[] arr = new double[center.getPoint().length];
  281.             double sum = 0.0;
  282.             for (final T point : points) {
  283.                 final double u = JdkMath.pow(membershipMatrix[i][j], fuzziness);
  284.                 final double[] pointArr = point.getPoint();
  285.                 for (int idx = 0; idx < arr.length; idx++) {
  286.                     arr[idx] += u * pointArr[idx];
  287.                 }
  288.                 sum += u;
  289.                 i++;
  290.             }
  291.             MathArrays.scaleInPlace(1.0 / sum, arr);
  292.             newClusters.add(new CentroidCluster<>(new DoublePoint(arr)));
  293.             j++;
  294.         }
  295.         clusters.clear();
  296.         clusters = newClusters;
  297.     }

  298.     /**
  299.      * Updates the membership matrix and assigns the points to the cluster with
  300.      * the highest membership.
  301.      */
  302.     private void updateMembershipMatrix() {
  303.         for (int i = 0; i < points.size(); i++) {
  304.             final T point = points.get(i);
  305.             double maxMembership = Double.MIN_VALUE;
  306.             int newCluster = -1;
  307.             for (int j = 0; j < clusters.size(); j++) {
  308.                 double sum = 0.0;
  309.                 final double distA = JdkMath.abs(distance(point, clusters.get(j).getCenter()));

  310.                 if (distA != 0.0) {
  311.                     for (final CentroidCluster<T> c : clusters) {
  312.                         final double distB = JdkMath.abs(distance(point, c.getCenter()));
  313.                         if (distB == 0.0) {
  314.                             sum = Double.POSITIVE_INFINITY;
  315.                             break;
  316.                         }
  317.                         sum += JdkMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
  318.                     }
  319.                 }

  320.                 double membership;
  321.                 if (sum == 0.0) {
  322.                     membership = 1.0;
  323.                 } else if (sum == Double.POSITIVE_INFINITY) {
  324.                     membership = 0.0;
  325.                 } else {
  326.                     membership = 1.0 / sum;
  327.                 }
  328.                 membershipMatrix[i][j] = membership;

  329.                 if (membershipMatrix[i][j] > maxMembership) {
  330.                     maxMembership = membershipMatrix[i][j];
  331.                     newCluster = j;
  332.                 }
  333.             }
  334.             clusters.get(newCluster).addPoint(point);
  335.         }
  336.     }

  337.     /**
  338.      * Initialize the membership matrix with random values.
  339.      */
  340.     private void initializeMembershipMatrix() {
  341.         for (int i = 0; i < points.size(); i++) {
  342.             for (int j = 0; j < k; j++) {
  343.                 membershipMatrix[i][j] = random.nextDouble();
  344.             }
  345.             membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
  346.         }
  347.     }

  348.     /**
  349.      * Calculate the maximum element-by-element change of the membership matrix
  350.      * for the current iteration.
  351.      *
  352.      * @param matrix the membership matrix of the previous iteration
  353.      * @return the maximum membership matrix change
  354.      */
  355.     private double calculateMaxMembershipChange(final double[][] matrix) {
  356.         double maxMembership = 0.0;
  357.         for (int i = 0; i < points.size(); i++) {
  358.             for (int j = 0; j < clusters.size(); j++) {
  359.                 double v = JdkMath.abs(membershipMatrix[i][j] - matrix[i][j]);
  360.                 maxMembership = JdkMath.max(v, maxMembership);
  361.             }
  362.         }
  363.         return maxMembership;
  364.     }

  365.     /**
  366.      * Copy the membership matrix into the provided matrix.
  367.      *
  368.      * @param matrix the place to store the membership matrix
  369.      */
  370.     private void saveMembershipMatrix(final double[][] matrix) {
  371.         for (int i = 0; i < points.size(); i++) {
  372.             System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
  373.         }
  374.     }
  375. }