001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.HashSet;
020import java.util.Map;
021import java.util.Set;
022
/**
 * Measures the cosine similarity of two vectors of an inner product space,
 * i.e. the cosine of the angle between them.
 *
 * <p>
 * Vectors are represented as maps from a term to its integer weight; a term
 * that is absent from a map is treated as having weight zero.
 * </p>
 *
 * <p>
 * For further explanation about the Cosine Similarity, refer to
 * https://en.wikipedia.org/wiki/Cosine_similarity.
 * </p>
 *
 * @since 1.0
 */
public class CosineSimilarity {

    /**
     * Calculates the cosine similarity for two given vectors.
     *
     * <p>
     * The result is the dot product of the two vectors divided by the product
     * of their Euclidean norms. If either vector has a zero norm (for example
     * because it is empty), {@code 0.0} is returned instead of dividing by zero.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @return cosine similarity between the two vectors
     * @throws IllegalArgumentException if either vector is {@code null}
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }

        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);
        final double dotProduct = dot(leftVector, rightVector, intersection);

        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);
        if (d1 <= 0.0 || d2 <= 0.0) {
            // A zero-length vector has no direction; report "no similarity"
            // rather than producing NaN from a division by zero.
            return 0.0;
        }
        return dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
    }

    /**
     * Computes the sum of the squared components of a vector, i.e. the square
     * of its Euclidean norm.
     *
     * @param vector the vector whose components are squared and summed
     * @return the sum of squares of all values in the map
     */
    private double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            sum += Math.pow(value, 2);
        }
        return sum;
    }

    /**
     * Returns a set with the keys common to the two given maps.
     *
     * @param leftVector left vector map
     * @param rightVector right vector map
     * @return a new set holding the keys present in both maps
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        // Copy first so that retainAll does not modify the caller's map.
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }

    /**
     * Computes the dot product of two vectors over their common keys. Keys
     * that appear in only one of the vectors are ignored, since their
     * contribution to the dot product would be zero.
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @param intersection keys present in both vectors
     * @return the dot product
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        long dotProduct = 0;
        for (final CharSequence key : intersection) {
            // Widen to long BEFORE multiplying: int * int would silently wrap
            // around for large weights and corrupt the result.
            dotProduct += (long) leftVector.get(key) * rightVector.get(key);
        }
        return dotProduct;
    }

}