001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.Collection;
020import java.util.HashMap;
021import java.util.Map;
022import java.util.Map.Entry;
023import java.util.Set;
024import java.util.function.Function;
025
026/**
027 * Measures the intersection of two sets created from a pair of character sequences.
028 *
029 * <p>It is assumed that the type {@code T} correctly conforms to the requirements for storage
030 * within a {@link Set} or {@link HashMap}. Ideally the type is immutable and implements
031 * {@link Object#equals(Object)} and {@link Object#hashCode()}.</p>
032 *
033 * @param <T> the type of the elements extracted from the character sequence
034 * @since 1.7
035 * @see Set
036 * @see HashMap
037 */
038public class IntersectionSimilarity<T> implements SimilarityScore<IntersectionResult> {
039
040    /**
041     * Mutable counter class for storing the count of elements.
042     */
043    private static final class BagCount {
044
045        /** Private, mutable but must be used as immutable. */
046        private static final BagCount ZERO = new BagCount();
047
048        /** The count. */
049        int count;
050
051        private BagCount() {
052            this.count = 0;
053        }
054    }
055
056    // The following is adapted from commons-collections for a Bag.
057    // A Bag is a collection that can store the count of the number
058    // of copies of each element.
059
060    /**
061     * A minimal implementation of a Bag that can store elements and a count.
062     *
063     * <p>For the intended purpose the Bag does not have to be a {@link Collection}. It does not
064     * even have to know its own size.
065     */
066    private class TinyBag {
067        /** The backing map. */
068        private final Map<T, BagCount> map;
069
070        /**
071         * Create a new tiny bag.
072         *
073         * @param initialCapacity the initial capacity
074         */
075        TinyBag(final int initialCapacity) {
076            map = new HashMap<>(initialCapacity);
077        }
078
079        /**
080         * Adds a new element to the bag, incrementing its count in the underlying map.
081         *
082         * @param object the object to add
083         */
084        void add(final T object) {
085            map.computeIfAbsent(object, k -> new BagCount()).count++;
086        }
087
088        /**
089         * Returns a Set view of the mappings contained in this bag.
090         *
091         * @return The Set view
092         */
093        Set<Entry<T, BagCount>> entrySet() {
094            return map.entrySet();
095        }
096
097        /**
098         * Returns the number of occurrence of the given element in this bag by
099         * looking up its count in the underlying map.
100         *
101         * @param object the object to search for
102         * @return The number of occurrences of the object, zero if not found
103         */
104        int getCount(final Object object) {
105            return map.getOrDefault(object, BagCount.ZERO).count;
106        }
107
108        /**
109         * Gets the number of unique elements in the bag.
110         *
111         * @return The unique element size
112         */
113        int uniqueElementSize() {
114            return map.size();
115        }
116    }
117
118    /**
119     * Computes the intersection between two sets. This is the count of all the elements
120     * that are within both sets.
121     *
122     * @param <T> the type of the elements in the set
123     * @param setA the set A
124     * @param setB the set B
125     * @return The intersection
126     */
127    private static <T> int getIntersection(final Set<T> setA, final Set<T> setB) {
128        int intersection = 0;
129        for (final T element : setA) {
130            if (setB.contains(element)) {
131                intersection++;
132            }
133        }
134        return intersection;
135    }
136
137    /** The converter used to create the elements from the characters. */
138    private final Function<CharSequence, Collection<T>> converter;
139
140    /**
141     * Create a new intersection similarity using the provided converter.
142     *
143     * <p>
144     * If the converter returns a {@link Set} then the intersection result will
145     * not include duplicates. Any other {@link Collection} is used to produce a result
146     * that will include duplicates in the intersect and union.
147     * </p>
148     *
149     * @param converter the converter used to create the elements from the characters
150     * @throws IllegalArgumentException if the converter is null
151     */
152    public IntersectionSimilarity(final Function<CharSequence, Collection<T>> converter) {
153        if (converter == null) {
154            throw new IllegalArgumentException("Converter must not be null");
155        }
156        this.converter = converter;
157    }
158
159    /**
160     * Calculates the intersection of two character sequences passed as input.
161     *
162     * @param left first character sequence
163     * @param right second character sequence
164     * @return The intersection result
165     * @throws IllegalArgumentException if either input sequence is {@code null}
166     */
167    @Override
168    public IntersectionResult apply(final CharSequence left, final CharSequence right) {
169        if (left == null || right == null) {
170            throw new IllegalArgumentException("Input cannot be null");
171        }
172
173        // Create the elements from the sequences
174        final Collection<T> objectsA = converter.apply(left);
175        final Collection<T> objectsB = converter.apply(right);
176        final int sizeA = objectsA.size();
177        final int sizeB = objectsB.size();
178
179        // Short-cut if either collection is empty
180        if (Math.min(sizeA, sizeB) == 0) {
181            // No intersection
182            return new IntersectionResult(sizeA, sizeB, 0);
183        }
184
185        // Intersection = count the number of shared elements
186        final int intersection;
187        if (objectsA instanceof Set && objectsB instanceof Set) {
188            // If a Set then the elements will only have a count of 1.
189            // Iterate over the smaller set.
190            intersection = sizeA < sizeB
191                    ? getIntersection((Set<T>) objectsA, (Set<T>) objectsB)
192                    : getIntersection((Set<T>) objectsB, (Set<T>) objectsA);
193        } else  {
194            // Create a bag for each collection
195            final TinyBag bagA = toBag(objectsA);
196            final TinyBag bagB = toBag(objectsB);
197            // Iterate over the smaller number of unique elements
198            intersection = bagA.uniqueElementSize() < bagB.uniqueElementSize()
199                    ? getIntersection(bagA, bagB)
200                    : getIntersection(bagB, bagA);
201        }
202
203        return new IntersectionResult(sizeA, sizeB, intersection);
204    }
205
206    /**
207     * Computes the intersection between two bags. This is the sum of the minimum
208     * count of each element that is within both sets.
209     *
210     * @param bagA the bag A
211     * @param bagB the bag B
212     * @return The intersection
213     */
214    private int getIntersection(final TinyBag bagA, final TinyBag bagB) {
215        int intersection = 0;
216        for (final Entry<T, BagCount> entry : bagA.entrySet()) {
217            final T element = entry.getKey();
218            final int count = entry.getValue().count;
219            // The intersection of this entry in both bags is the minimum count
220            intersection += Math.min(count, bagB.getCount(element));
221        }
222        return intersection;
223    }
224
225    /**
226     * Converts the collection to a bag. The bag will contain the count of each element
227     * in the collection.
228     *
229     * @param objects the objects
230     * @return The bag
231     */
232    private TinyBag toBag(final Collection<T> objects) {
233        final TinyBag bag = new TinyBag(objects.size());
234        objects.forEach(bag::add);
235        return bag;
236    }
237}