CosineSimilarity.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.text.similarity;

  18. import java.util.HashSet;
  19. import java.util.Map;
  20. import java.util.Set;

  /**
   * Measures the cosine similarity of two vectors of an inner product space,
   * i.e. the cosine of the angle between them.
   *
   * <p>
   * Each vector is supplied as a {@link Map} from a key to that key's integer
   * component. A key that is absent from a map is treated as a zero component.
   * </p>
   *
   * <p>
   * For further explanation about the Cosine Similarity, refer to
   * <a href="https://en.wikipedia.org/wiki/Cosine_similarity">https://en.wikipedia.org/wiki/Cosine_similarity</a>.
   * </p>
   *
   * @since 1.0
   */
  public class CosineSimilarity {

      /**
       * Calculates the cosine similarity for two given vectors.
       *
       * <p>
       * The result is the dot product of the two vectors divided by the product of
       * their Euclidean norms. If either vector has a zero norm (for example, an
       * empty map), the result is {@code 0.0}.
       * </p>
       *
       * @param leftVector left vector
       * @param rightVector right vector
       * @return cosine similarity between the two vectors
       * @throws IllegalArgumentException if either vector is {@code null}
       * @throws NullPointerException if either vector contains a {@code null} value
       */
      public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector) {
          if (leftVector == null || rightVector == null) {
              throw new IllegalArgumentException("Vectors must not be null");
          }

          final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);
          final double dotProduct = dot(leftVector, rightVector, intersection);
          final double d1 = sumOfSquares(leftVector);
          final double d2 = sumOfSquares(rightVector);

          // A zero-length vector has no direction; report "no similarity" instead of dividing by zero.
          if (d1 <= 0.0 || d2 <= 0.0) {
              return 0.0;
          }
          return dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
      }

      /**
       * Computes the sum of the squared components of a vector, i.e. the square of
       * its Euclidean norm.
       *
       * @param vector vector whose values are squared and summed
       * @return the sum of squares of all values in the map
       */
      private double sumOfSquares(final Map<CharSequence, Integer> vector) {
          double sum = 0.0d;
          for (final Integer value : vector.values()) {
              // Widen to double before multiplying so the square cannot overflow int.
              final double component = value;
              sum += component * component;
          }
          return sum;
      }

      /**
       * Returns a set with the keys common to the two given maps.
       *
       * @param leftVector left vector map
       * @param rightVector right vector map
       * @return a new set holding the keys present in both maps
       */
      private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
              final Map<CharSequence, Integer> rightVector) {
          // Copy the key set first: retainAll on the live key-set view would modify the caller's map.
          final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
          intersection.retainAll(rightVector.keySet());
          return intersection;
      }

      /**
       * Computes the dot product of two vectors over their common keys. Keys that
       * appear in only one of the vectors are ignored, because the corresponding
       * component of the other vector is zero and contributes nothing to the sum.
       *
       * @param leftVector left vector
       * @param rightVector right vector
       * @param intersection keys present in both vectors
       * @return the dot product
       */
      private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
              final Set<CharSequence> intersection) {
          double dotProduct = 0.0d;
          for (final CharSequence key : intersection) {
              // Widen to long BEFORE multiplying: an int * int product silently wraps
              // around for large components and would corrupt the result.
              final long product = (long) leftVector.get(key) * rightVector.get(key);
              // Accumulate in double (like the norms) so that many large products
              // cannot overflow an integral accumulator either.
              dotProduct += product;
          }
          return dotProduct;
      }

  }