View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.text.similarity;
18  
19  import java.util.HashSet;
20  import java.util.Map;
21  import java.util.Set;
22  
/**
 * Measures the Cosine similarity of two vectors of an inner product space and compares the angle between them.
 * <p>
 * For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
 * </p>
 * <p>
 * Instances of this class are immutable and are safe for use by multiple concurrent threads.
 * </p>
 *
 * @since 1.0
 */
public class CosineSimilarity {

    /**
     * The singleton instance.
     */
    static final CosineSimilarity INSTANCE = new CosineSimilarity();

    /**
     * Constructs a new instance.
     */
    public CosineSimilarity() {
        // empty
    }

    /**
     * Calculates the cosine similarity for two given vectors.
     * <p>
     * Each vector is a sparse map from a key to its integer weight; a key absent from a map contributes nothing
     * to that vector. The result is {@code dot(left, right) / (|left| * |right|)}, or {@code 0.0} when either
     * vector has a sum of squared weights that is not positive (for example, when the map is empty).
     * </p>
     *
     * @param leftVector left vector.
     * @param rightVector right vector.
     * @return cosine similarity between the two vectors.
     * @throws IllegalArgumentException if either vector is {@code null}.
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
                                   final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }

        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);

        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);
        if (d1 <= 0.0 || d2 <= 0.0) {
            return 0.0;
        }
        // Taking each square root separately avoids forming the potentially huge product d1 * d2.
        return dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
    }

    /**
     * Computes the sum of the squared weights of a vector (its squared Euclidean norm).
     *
     * @param vector the vector.
     * @return the sum of the squares of all values in {@code vector}.
     */
    private static double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            final double v = value;
            sum += v * v;
        }
        return sum;
    }

    /**
     * Computes the dot product of two sparse vectors. Only keys present in both vectors are summed; a key missing
     * from either vector would contribute a zero term, so it can be skipped.
     *
     * @param leftVector left vector.
     * @param rightVector right vector.
     * @param intersection keys common to both vectors.
     * @return The dot product.
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        // Accumulate in double, like the norms. A long accumulator wraps silently once the sum exceeds
        // Long.MAX_VALUE (three products of Integer.MAX_VALUE weights suffice), producing a bogus similarity.
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            dotProduct += leftVector.get(key) * (double) rightVector.get(key);
        }
        return dotProduct;
    }

    /**
     * Returns a set with keys common to the two given maps.
     *
     * @param leftVector left vector map.
     * @param rightVector right vector map.
     * @return common keys.
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }

}
114 }