View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.text.similarity;
18  
19  import java.util.Collection;
20  import java.util.HashMap;
21  import java.util.Map;
22  import java.util.Map.Entry;
23  import java.util.Set;
24  import java.util.function.Function;
25  
26  /**
27   * Measures the intersection of two sets created from a pair of character sequences.
28   *
29   * <p>It is assumed that the type {@code T} correctly conforms to the requirements for storage
30   * within a {@link Set} or {@link HashMap}. Ideally the type is immutable and implements
31   * {@link Object#equals(Object)} and {@link Object#hashCode()}.</p>
32   *
33   * @param <T> the type of the elements extracted from the character sequence
34   * @since 1.7
35   * @see Set
36   * @see HashMap
37   */
38  public class IntersectionSimilarity<T> implements SimilarityScore<IntersectionResult> {
39  
40      /**
41       * Mutable counter class for storing the count of elements.
42       */
43      private static final class BagCount {
44  
45          /** Private, mutable but must be used as immutable. */
46          private static final BagCount ZERO = new BagCount();
47  
48          /** The count. */
49          private int count;
50  
51          private BagCount() {
52              this.count = 0;
53          }
54      }
55  
56      // The following is adapted from commons-collections for a Bag.
57      // A Bag is a collection that can store the count of the number
58      // of copies of each element.
59  
60      /**
61       * A minimal implementation of a Bag that can store elements and a count.
62       *
63       * <p>
64       * For the intended purpose the Bag does not have to be a {@link Collection}. It does not
65       * even have to know its own size.
66       * </p>
67       */
68      private final class TinyBag {
69  
70          /** The backing map. */
71          private final Map<T, BagCount> map;
72  
73          /**
74           * Create a new tiny bag.
75           *
76           * @param initialCapacity the initial capacity
77           */
78          private TinyBag(final int initialCapacity) {
79              map = new HashMap<>(initialCapacity);
80          }
81  
82          /**
83           * Adds a new element to the bag, incrementing its count in the underlying map.
84           *
85           * @param object the object to add
86           */
87          private void add(final T object) {
88              map.computeIfAbsent(object, k -> new BagCount()).count++;
89          }
90  
91          /**
92           * Returns a Set view of the mappings contained in this bag.
93           *
94           * @return The Set view
95           */
96          private Set<Entry<T, BagCount>> entrySet() {
97              return map.entrySet();
98          }
99  
100         /**
101          * Returns the number of occurrence of the given element in this bag by
102          * looking up its count in the underlying map.
103          *
104          * @param object the object to search for
105          * @return The number of occurrences of the object, zero if not found
106          */
107         private int getCount(final Object object) {
108             return map.getOrDefault(object, BagCount.ZERO).count;
109         }
110 
111         /**
112          * Gets the number of unique elements in the bag.
113          *
114          * @return The unique element size
115          */
116         private int uniqueElementSize() {
117             return map.size();
118         }
119     }
120 
121     /**
122      * Computes the intersection between two sets. This is the count of all the elements
123      * that are within both sets.
124      *
125      * @param <T> the type of the elements in the set
126      * @param setA the set A
127      * @param setB the set B
128      * @return The intersection
129      */
130     private static <T> int getIntersection(final Set<T> setA, final Set<T> setB) {
131         int intersection = 0;
132         for (final T element : setA) {
133             if (setB.contains(element)) {
134                 intersection++;
135             }
136         }
137         return intersection;
138     }
139 
140     /** The converter used to create the elements from the characters. */
141     private final Function<CharSequence, Collection<T>> converter;
142 
143     /**
144      * Create a new intersection similarity using the provided converter.
145      *
146      * <p>
147      * If the converter returns a {@link Set} then the intersection result will
148      * not include duplicates. Any other {@link Collection} is used to produce a result
149      * that will include duplicates in the intersect and union.
150      * </p>
151      *
152      * @param converter the converter used to create the elements from the characters
153      * @throws IllegalArgumentException if the converter is null
154      */
155     public IntersectionSimilarity(final Function<CharSequence, Collection<T>> converter) {
156         if (converter == null) {
157             throw new IllegalArgumentException("Converter must not be null");
158         }
159         this.converter = converter;
160     }
161 
162     /**
163      * Calculates the intersection of two character sequences passed as input.
164      *
165      * @param left first character sequence
166      * @param right second character sequence
167      * @return The intersection result
168      * @throws IllegalArgumentException if either input sequence is {@code null}
169      */
170     @Override
171     public IntersectionResult apply(final CharSequence left, final CharSequence right) {
172         if (left == null || right == null) {
173             throw new IllegalArgumentException("Input cannot be null");
174         }
175 
176         // Create the elements from the sequences
177         final Collection<T> objectsA = converter.apply(left);
178         final Collection<T> objectsB = converter.apply(right);
179         final int sizeA = objectsA.size();
180         final int sizeB = objectsB.size();
181 
182         // Short-cut if either collection is empty
183         if (Math.min(sizeA, sizeB) == 0) {
184             // No intersection
185             return new IntersectionResult(sizeA, sizeB, 0);
186         }
187 
188         // Intersection = count the number of shared elements
189         final int intersection;
190         if (objectsA instanceof Set && objectsB instanceof Set) {
191             // If a Set then the elements will only have a count of 1.
192             // Iterate over the smaller set.
193             intersection = sizeA < sizeB
194                     ? getIntersection((Set<T>) objectsA, (Set<T>) objectsB)
195                     : getIntersection((Set<T>) objectsB, (Set<T>) objectsA);
196         } else  {
197             // Create a bag for each collection
198             final TinyBag bagA = toBag(objectsA);
199             final TinyBag bagB = toBag(objectsB);
200             // Iterate over the smaller number of unique elements
201             intersection = bagA.uniqueElementSize() < bagB.uniqueElementSize()
202                     ? getIntersection(bagA, bagB)
203                     : getIntersection(bagB, bagA);
204         }
205 
206         return new IntersectionResult(sizeA, sizeB, intersection);
207     }
208 
209     /**
210      * Computes the intersection between two bags. This is the sum of the minimum
211      * count of each element that is within both sets.
212      *
213      * @param bagA the bag A
214      * @param bagB the bag B
215      * @return The intersection
216      */
217     private int getIntersection(final TinyBag bagA, final TinyBag bagB) {
218         int intersection = 0;
219         for (final Entry<T, BagCount> entry : bagA.entrySet()) {
220             final T element = entry.getKey();
221             final int count = entry.getValue().count;
222             // The intersection of this entry in both bags is the minimum count
223             intersection += Math.min(count, bagB.getCount(element));
224         }
225         return intersection;
226     }
227 
228     /**
229      * Converts the collection to a bag. The bag will contain the count of each element
230      * in the collection.
231      *
232      * @param objects the objects
233      * @return The bag
234      */
235     private TinyBag toBag(final Collection<T> objects) {
236         final TinyBag bag = new TinyBag(objects.size());
237         objects.forEach(bag::add);
238         return bag;
239     }
240 }