001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      https://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.HashSet;
020import java.util.Map;
021import java.util.Set;
022
/**
 * Measures the cosine similarity of two vectors of an inner product space, that is, the cosine of the angle
 * between them.
 * <p>
 * For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
 * </p>
 * <p>
 * Instances of this class are immutable and are safe for use by multiple concurrent threads.
 * </p>
 *
 * @since 1.0
 */
public class CosineSimilarity {

    /**
     * The singleton instance.
     */
    static final CosineSimilarity INSTANCE = new CosineSimilarity();

    /**
     * Construct a new instance.
     */
    public CosineSimilarity() {
        // empty
    }

    /**
     * Calculates the cosine similarity for two given vectors.
     * <p>
     * Each vector is a sparse map from a key to its integer component; a key absent from one map contributes
     * nothing to the dot product. If either vector has zero magnitude (for example, an empty map), the result
     * is {@code 0.0}. Otherwise the result lies in the closed interval [-1, 1].
     * </p>
     *
     * @param leftVector left vector.
     * @param rightVector right vector.
     * @return cosine similarity between the two vectors.
     * @throws IllegalArgumentException if either vector is {@code null}.
     * @throws NullPointerException if a vector contains a {@code null} component value.
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }
        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);
        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);
        if (d1 <= 0.0 || d2 <= 0.0) {
            return 0.0;
        }
        // One square root of the product involves fewer rounding steps than sqrt(d1) * sqrt(d2); e.g. a
        // vector compared with itself then yields exactly 1.0 instead of 1.0000000000000002.
        final double cosineSimilarity = dotProduct / Math.sqrt(d1 * d2);
        // Cauchy-Schwarz bounds the exact value to [-1, 1]; clamp away any residual floating-point error.
        return Math.max(-1.0, Math.min(1.0, cosineSimilarity));
    }

    /**
     * Computes the dot product of two vectors over their common keys only. Components whose key is present
     * in just one of the vectors would be multiplied by zero, so they are skipped.
     *
     * @param leftVector left vector.
     * @param rightVector right vector.
     * @param intersection keys present in both vectors.
     * @return The dot product.
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            // Accumulate in double: a single int*int product fits in a long, but a sum of several such
            // products can overflow a long accumulator and silently flip its sign.
            dotProduct += leftVector.get(key).doubleValue() * rightVector.get(key);
        }
        return dotProduct;
    }

    /**
     * Computes the squared Euclidean norm of a vector (the sum of its squared components).
     *
     * @param vector the vector.
     * @return the sum of the squares of the vector's values.
     */
    private double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            final double component = value;
            sum += component * component;
        }
        return sum;
    }

    /**
     * Returns a set with keys common to the two given maps.
     *
     * @param leftVector left vector map.
     * @param rightVector right vector map.
     * @return common keys.
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }
}