1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.text.similarity;
18
19 import java.util.HashSet;
20 import java.util.Map;
21 import java.util.Set;
22
/**
 * Measures the cosine similarity of two vectors of an inner product space, that is, the cosine of the angle
 * between them.
 * <p>
 * For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
 * </p>
 * <p>
 * Instances of this class are immutable and are safe for use by multiple concurrent threads.
 * </p>
 *
 * @since 1.0
 */
public class CosineSimilarity {

    /**
     * The singleton instance.
     */
    static final CosineSimilarity INSTANCE = new CosineSimilarity();

    /**
     * Constructs a new instance.
     */
    public CosineSimilarity() {
        // empty
    }

    /**
     * Calculates the cosine similarity for two given vectors.
     * <p>
     * Each vector is a map from key to integer component. A key that is missing from one of the maps
     * contributes nothing to the dot product. If either vector has zero magnitude (for example, it is
     * empty), the similarity is reported as {@code 0.0}.
     * </p>
     *
     * @param leftVector left vector.
     * @param rightVector right vector.
     * @return cosine similarity between the two vectors.
     * @throws IllegalArgumentException if either vector is {@code null}.
     * @throws NullPointerException if either vector contains a {@code null} value.
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }

        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);

        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);
        // A sum of squares is never negative, so this catches exactly the zero-magnitude vectors,
        // for which the angle (and hence the cosine) is undefined.
        if (d1 <= 0.0 || d2 <= 0.0) {
            return 0.0;
        }
        return dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
    }

    /**
     * Computes the dot product of two vectors over their common keys. A key present in only one vector
     * corresponds to a zero component in the other and therefore contributes nothing.
     * <p>
     * Each individual product is computed exactly in {@code long} arithmetic. The products are then
     * accumulated in {@code double}, because summing them in a {@code long} can silently overflow.
     * For example, two components of {@code Integer.MIN_VALUE} squared already total 2<sup>63</sup>,
     * and such an overflow would flip the sign of the result.
     * </p>
     *
     * @param leftVector left vector.
     * @param rightVector right vector.
     * @param intersection keys common to both vectors.
     * @return The dot product.
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            dotProduct += leftVector.get(key) * (long) rightVector.get(key);
        }
        return dotProduct;
    }

    /**
     * Computes the squared Euclidean norm of a vector, that is, the sum of the squares of its components.
     * The sum is accumulated in {@code double}, consistent with {@link #dot}.
     *
     * @param vector the vector whose components are squared and summed.
     * @return the sum of the squared components; {@code 0.0} for an empty vector.
     */
    private double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            final double component = value;
            sum += component * component;
        }
        return sum;
    }

    /**
     * Returns a set with keys common to the two given maps.
     *
     * @param leftVector left vector map.
     * @param rightVector right vector map.
     * @return common keys, as a new mutable set.
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }

}