001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.Collection;
020import java.util.HashMap;
021import java.util.Map;
022import java.util.Map.Entry;
023import java.util.Set;
024import java.util.function.Function;
025
026/**
027 * Measures the intersection of two sets created from a pair of character sequences.
028 *
029 * <p>It is assumed that the type {@code T} correctly conforms to the requirements for storage
030 * within a {@link Set} or {@link HashMap}. Ideally the type is immutable and implements
031 * {@link Object#equals(Object)} and {@link Object#hashCode()}.</p>
032 *
033 * @param <T> the type of the elements extracted from the character sequence
034 * @since 1.7
035 * @see Set
036 * @see HashMap
037 */
038public class IntersectionSimilarity<T> implements SimilarityScore<IntersectionResult> {
039
040    /**
041     * Mutable counter class for storing the count of elements.
042     */
043    private static final class BagCount {
044
045        /** Private, mutable but must be used as immutable. */
046        private static final BagCount ZERO = new BagCount();
047
048        /** The count. */
049        private int count;
050
051        private BagCount() {
052            this.count = 0;
053        }
054    }
055
056    // The following is adapted from commons-collections for a Bag.
057    // A Bag is a collection that can store the count of the number
058    // of copies of each element.
059
060    /**
061     * A minimal implementation of a Bag that can store elements and a count.
062     *
063     * <p>
064     * For the intended purpose the Bag does not have to be a {@link Collection}. It does not
065     * even have to know its own size.
066     * </p>
067     */
068    private final class TinyBag {
069
070        /** The backing map. */
071        private final Map<T, BagCount> map;
072
073        /**
074         * Create a new tiny bag.
075         *
076         * @param initialCapacity the initial capacity
077         */
078        private TinyBag(final int initialCapacity) {
079            map = new HashMap<>(initialCapacity);
080        }
081
082        /**
083         * Adds a new element to the bag, incrementing its count in the underlying map.
084         *
085         * @param object the object to add
086         */
087        private void add(final T object) {
088            map.computeIfAbsent(object, k -> new BagCount()).count++;
089        }
090
091        /**
092         * Returns a Set view of the mappings contained in this bag.
093         *
094         * @return The Set view
095         */
096        private Set<Entry<T, BagCount>> entrySet() {
097            return map.entrySet();
098        }
099
100        /**
101         * Returns the number of occurrence of the given element in this bag by
102         * looking up its count in the underlying map.
103         *
104         * @param object the object to search for
105         * @return The number of occurrences of the object, zero if not found
106         */
107        private int getCount(final Object object) {
108            return map.getOrDefault(object, BagCount.ZERO).count;
109        }
110
111        /**
112         * Gets the number of unique elements in the bag.
113         *
114         * @return The unique element size
115         */
116        private int uniqueElementSize() {
117            return map.size();
118        }
119    }
120
121    /**
122     * Computes the intersection between two sets. This is the count of all the elements
123     * that are within both sets.
124     *
125     * @param <T> the type of the elements in the set
126     * @param setA the set A
127     * @param setB the set B
128     * @return The intersection
129     */
130    private static <T> int getIntersection(final Set<T> setA, final Set<T> setB) {
131        int intersection = 0;
132        for (final T element : setA) {
133            if (setB.contains(element)) {
134                intersection++;
135            }
136        }
137        return intersection;
138    }
139
140    /** The converter used to create the elements from the characters. */
141    private final Function<CharSequence, Collection<T>> converter;
142
143    /**
144     * Create a new intersection similarity using the provided converter.
145     *
146     * <p>
147     * If the converter returns a {@link Set} then the intersection result will
148     * not include duplicates. Any other {@link Collection} is used to produce a result
149     * that will include duplicates in the intersect and union.
150     * </p>
151     *
152     * @param converter the converter used to create the elements from the characters
153     * @throws IllegalArgumentException if the converter is null
154     */
155    public IntersectionSimilarity(final Function<CharSequence, Collection<T>> converter) {
156        if (converter == null) {
157            throw new IllegalArgumentException("Converter must not be null");
158        }
159        this.converter = converter;
160    }
161
162    /**
163     * Calculates the intersection of two character sequences passed as input.
164     *
165     * @param left first character sequence
166     * @param right second character sequence
167     * @return The intersection result
168     * @throws IllegalArgumentException if either input sequence is {@code null}
169     */
170    @Override
171    public IntersectionResult apply(final CharSequence left, final CharSequence right) {
172        if (left == null || right == null) {
173            throw new IllegalArgumentException("Input cannot be null");
174        }
175
176        // Create the elements from the sequences
177        final Collection<T> objectsA = converter.apply(left);
178        final Collection<T> objectsB = converter.apply(right);
179        final int sizeA = objectsA.size();
180        final int sizeB = objectsB.size();
181
182        // Short-cut if either collection is empty
183        if (Math.min(sizeA, sizeB) == 0) {
184            // No intersection
185            return new IntersectionResult(sizeA, sizeB, 0);
186        }
187
188        // Intersection = count the number of shared elements
189        final int intersection;
190        if (objectsA instanceof Set && objectsB instanceof Set) {
191            // If a Set then the elements will only have a count of 1.
192            // Iterate over the smaller set.
193            intersection = sizeA < sizeB
194                    ? getIntersection((Set<T>) objectsA, (Set<T>) objectsB)
195                    : getIntersection((Set<T>) objectsB, (Set<T>) objectsA);
196        } else  {
197            // Create a bag for each collection
198            final TinyBag bagA = toBag(objectsA);
199            final TinyBag bagB = toBag(objectsB);
200            // Iterate over the smaller number of unique elements
201            intersection = bagA.uniqueElementSize() < bagB.uniqueElementSize()
202                    ? getIntersection(bagA, bagB)
203                    : getIntersection(bagB, bagA);
204        }
205
206        return new IntersectionResult(sizeA, sizeB, intersection);
207    }
208
209    /**
210     * Computes the intersection between two bags. This is the sum of the minimum
211     * count of each element that is within both sets.
212     *
213     * @param bagA the bag A
214     * @param bagB the bag B
215     * @return The intersection
216     */
217    private int getIntersection(final TinyBag bagA, final TinyBag bagB) {
218        int intersection = 0;
219        for (final Entry<T, BagCount> entry : bagA.entrySet()) {
220            final T element = entry.getKey();
221            final int count = entry.getValue().count;
222            // The intersection of this entry in both bags is the minimum count
223            intersection += Math.min(count, bagB.getCount(element));
224        }
225        return intersection;
226    }
227
228    /**
229     * Converts the collection to a bag. The bag will contain the count of each element
230     * in the collection.
231     *
232     * @param objects the objects
233     * @return The bag
234     */
235    private TinyBag toBag(final Collection<T> objects) {
236        final TinyBag bag = new TinyBag(objects.size());
237        objects.forEach(bag::add);
238        return bag;
239    }
240}