CosineSimilarity.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.text.similarity;

  18. import java.util.HashSet;
  19. import java.util.Map;
  20. import java.util.Set;

  21. /**
  22.  * Measures the Cosine similarity of two vectors of an inner product space and compares the angle between them.
  23.  * <p>
  24.  * For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
  25.  * </p>
  26.  * <p>
  27.  * Instances of this class are immutable and are safe for use by multiple concurrent threads.
  28.  * </p>
  29.  *
  30.  * @since 1.0
  31.  */
  32. public class CosineSimilarity {

  33.     /**
  34.      * Singleton instance.
  35.      */
  36.     static final CosineSimilarity INSTANCE = new CosineSimilarity();

  37.     /**
  38.      * Construct a new instance.
  39.      */
  40.     public CosineSimilarity() {
  41.         // empty
  42.     }

  43.     /**
  44.      * Calculates the cosine similarity for two given vectors.
  45.      *
  46.      * @param leftVector left vector
  47.      * @param rightVector right vector
  48.      * @return cosine similarity between the two vectors
  49.      */
  50.     public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
  51.                                    final Map<CharSequence, Integer> rightVector) {
  52.         if (leftVector == null || rightVector == null) {
  53.             throw new IllegalArgumentException("Vectors must not be null");
  54.         }

  55.         final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);

  56.         final double dotProduct = dot(leftVector, rightVector, intersection);
  57.         double d1 = 0.0d;
  58.         for (final Integer value : leftVector.values()) {
  59.             d1 += Math.pow(value, 2);
  60.         }
  61.         double d2 = 0.0d;
  62.         for (final Integer value : rightVector.values()) {
  63.             d2 += Math.pow(value, 2);
  64.         }
  65.         final double cosineSimilarity;
  66.         if (d1 <= 0.0 || d2 <= 0.0) {
  67.             cosineSimilarity = 0.0;
  68.         } else {
  69.             cosineSimilarity = dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
  70.         }
  71.         return cosineSimilarity;
  72.     }

  73.     /**
  74.      * Computes the dot product of two vectors. It ignores remaining elements. It means
  75.      * that if a vector is longer than other, then a smaller part of it will be used to compute
  76.      * the dot product.
  77.      *
  78.      * @param leftVector left vector
  79.      * @param rightVector right vector
  80.      * @param intersection common elements
  81.      * @return The dot product
  82.      */
  83.     private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
  84.             final Set<CharSequence> intersection) {
  85.         long dotProduct = 0;
  86.         for (final CharSequence key : intersection) {
  87.             dotProduct += leftVector.get(key) * (long) rightVector.get(key);
  88.         }
  89.         return dotProduct;
  90.     }

  91.     /**
  92.      * Returns a set with keys common to the two given maps.
  93.      *
  94.      * @param leftVector left vector map
  95.      * @param rightVector right vector map
  96.      * @return common strings
  97.      */
  98.     private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
  99.             final Map<CharSequence, Integer> rightVector) {
  100.         final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
  101.         intersection.retainAll(rightVector.keySet());
  102.         return intersection;
  103.     }

  104. }