001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.HashSet;
020import java.util.Map;
021import java.util.Set;
022
/**
 * Measures the Cosine similarity of two vectors of an inner product space and compares the angle between them.
 * <p>
 * For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
 * </p>
 * <p>
 * Instances of this class are immutable and are safe for use by multiple concurrent threads.
 * </p>
 *
 * @since 1.0
 */
public class CosineSimilarity {

    /**
     * Singleton instance.
     */
    static final CosineSimilarity INSTANCE = new CosineSimilarity();

    /**
     * Calculates the cosine similarity for two given vectors.
     * <p>
     * Each vector is a sparse map from a key to its integer component; a key missing from a map is treated
     * as a zero component. If either vector has zero magnitude (for example, it is empty), the angle is
     * undefined and {@code 0.0} is returned.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @return cosine similarity between the two vectors
     * @throws IllegalArgumentException if either vector is {@code null}
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
                                   final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }

        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);

        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = squaredNorm(leftVector);
        final double d2 = squaredNorm(rightVector);
        if (d1 <= 0.0 || d2 <= 0.0) {
            // A zero-magnitude vector has no direction, so there is no angle to measure.
            return 0.0;
        }
        return dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
    }

    /**
     * Computes the dot product of two vectors over their common keys.
     * <p>
     * Keys present in only one of the vectors contribute nothing, because the other vector's component for
     * that key is implicitly zero. Products and the running sum are computed in {@code double}, matching
     * the norm computation; a {@code long} accumulator would silently wrap once the sum exceeded
     * {@link Long#MAX_VALUE}, which only a few components near {@link Integer#MAX_VALUE} are enough to do.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @param intersection keys present in both vectors
     * @return The dot product
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            dotProduct += leftVector.get(key) * (double) rightVector.get(key);
        }
        return dotProduct;
    }

    /**
     * Computes the squared Euclidean norm of a vector, i.e. the sum of its squared components.
     *
     * @param vector the vector
     * @return the sum of squared components, {@code 0.0} for an empty vector
     */
    private double squaredNorm(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            // Widen before multiplying so the square cannot overflow int.
            sum += value * (double) value;
        }
        return sum;
    }

    /**
     * Returns a set with keys common to the two given maps.
     *
     * @param leftVector left vector map
     * @param rightVector right vector map
     * @return a new, mutable set containing the keys present in both maps
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }

}