IntersectionSimilarity.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.text.similarity;

  18. import java.util.Collection;
  19. import java.util.HashMap;
  20. import java.util.Map;
  21. import java.util.Map.Entry;
  22. import java.util.Set;
  23. import java.util.function.Function;

  24. /**
  25.  * Measures the intersection of two sets created from a pair of character sequences.
  26.  *
  27.  * <p>It is assumed that the type {@code T} correctly conforms to the requirements for storage
  28.  * within a {@link Set} or {@link HashMap}. Ideally the type is immutable and implements
  29.  * {@link Object#equals(Object)} and {@link Object#hashCode()}.</p>
  30.  *
  31.  * @param <T> the type of the elements extracted from the character sequence
  32.  * @since 1.7
  33.  * @see Set
  34.  * @see HashMap
  35.  */
  36. public class IntersectionSimilarity<T> implements SimilarityScore<IntersectionResult> {

  37.     /**
  38.      * Mutable counter class for storing the count of elements.
  39.      */
  40.     private static final class BagCount {

  41.         /** Private, mutable but must be used as immutable. */
  42.         private static final BagCount ZERO = new BagCount();

  43.         /** The count. */
  44.         private int count;

  45.         private BagCount() {
  46.             this.count = 0;
  47.         }
  48.     }

  49.     // The following is adapted from commons-collections for a Bag.
  50.     // A Bag is a collection that can store the count of the number
  51.     // of copies of each element.

  52.     /**
  53.      * A minimal implementation of a Bag that can store elements and a count.
  54.      *
  55.      * <p>
  56.      * For the intended purpose the Bag does not have to be a {@link Collection}. It does not
  57.      * even have to know its own size.
  58.      * </p>
  59.      */
  60.     private final class TinyBag {

  61.         /** The backing map. */
  62.         private final Map<T, BagCount> map;

  63.         /**
  64.          * Create a new tiny bag.
  65.          *
  66.          * @param initialCapacity the initial capacity
  67.          */
  68.         private TinyBag(final int initialCapacity) {
  69.             map = new HashMap<>(initialCapacity);
  70.         }

  71.         /**
  72.          * Adds a new element to the bag, incrementing its count in the underlying map.
  73.          *
  74.          * @param object the object to add
  75.          */
  76.         private void add(final T object) {
  77.             map.computeIfAbsent(object, k -> new BagCount()).count++;
  78.         }

  79.         /**
  80.          * Returns a Set view of the mappings contained in this bag.
  81.          *
  82.          * @return The Set view
  83.          */
  84.         private Set<Entry<T, BagCount>> entrySet() {
  85.             return map.entrySet();
  86.         }

  87.         /**
  88.          * Returns the number of occurrence of the given element in this bag by
  89.          * looking up its count in the underlying map.
  90.          *
  91.          * @param object the object to search for
  92.          * @return The number of occurrences of the object, zero if not found
  93.          */
  94.         private int getCount(final Object object) {
  95.             return map.getOrDefault(object, BagCount.ZERO).count;
  96.         }

  97.         /**
  98.          * Gets the number of unique elements in the bag.
  99.          *
  100.          * @return The unique element size
  101.          */
  102.         private int uniqueElementSize() {
  103.             return map.size();
  104.         }
  105.     }

  106.     /**
  107.      * Computes the intersection between two sets. This is the count of all the elements
  108.      * that are within both sets.
  109.      *
  110.      * @param <T> the type of the elements in the set
  111.      * @param setA the set A
  112.      * @param setB the set B
  113.      * @return The intersection
  114.      */
  115.     private static <T> int getIntersection(final Set<T> setA, final Set<T> setB) {
  116.         int intersection = 0;
  117.         for (final T element : setA) {
  118.             if (setB.contains(element)) {
  119.                 intersection++;
  120.             }
  121.         }
  122.         return intersection;
  123.     }

  124.     /** The converter used to create the elements from the characters. */
  125.     private final Function<CharSequence, Collection<T>> converter;

  126.     /**
  127.      * Create a new intersection similarity using the provided converter.
  128.      *
  129.      * <p>
  130.      * If the converter returns a {@link Set} then the intersection result will
  131.      * not include duplicates. Any other {@link Collection} is used to produce a result
  132.      * that will include duplicates in the intersect and union.
  133.      * </p>
  134.      *
  135.      * @param converter the converter used to create the elements from the characters
  136.      * @throws IllegalArgumentException if the converter is null
  137.      */
  138.     public IntersectionSimilarity(final Function<CharSequence, Collection<T>> converter) {
  139.         if (converter == null) {
  140.             throw new IllegalArgumentException("Converter must not be null");
  141.         }
  142.         this.converter = converter;
  143.     }

  144.     /**
  145.      * Calculates the intersection of two character sequences passed as input.
  146.      *
  147.      * @param left first character sequence
  148.      * @param right second character sequence
  149.      * @return The intersection result
  150.      * @throws IllegalArgumentException if either input sequence is {@code null}
  151.      */
  152.     @Override
  153.     public IntersectionResult apply(final CharSequence left, final CharSequence right) {
  154.         if (left == null || right == null) {
  155.             throw new IllegalArgumentException("Input cannot be null");
  156.         }

  157.         // Create the elements from the sequences
  158.         final Collection<T> objectsA = converter.apply(left);
  159.         final Collection<T> objectsB = converter.apply(right);
  160.         final int sizeA = objectsA.size();
  161.         final int sizeB = objectsB.size();

  162.         // Short-cut if either collection is empty
  163.         if (Math.min(sizeA, sizeB) == 0) {
  164.             // No intersection
  165.             return new IntersectionResult(sizeA, sizeB, 0);
  166.         }

  167.         // Intersection = count the number of shared elements
  168.         final int intersection;
  169.         if (objectsA instanceof Set && objectsB instanceof Set) {
  170.             // If a Set then the elements will only have a count of 1.
  171.             // Iterate over the smaller set.
  172.             intersection = sizeA < sizeB
  173.                     ? getIntersection((Set<T>) objectsA, (Set<T>) objectsB)
  174.                     : getIntersection((Set<T>) objectsB, (Set<T>) objectsA);
  175.         } else  {
  176.             // Create a bag for each collection
  177.             final TinyBag bagA = toBag(objectsA);
  178.             final TinyBag bagB = toBag(objectsB);
  179.             // Iterate over the smaller number of unique elements
  180.             intersection = bagA.uniqueElementSize() < bagB.uniqueElementSize()
  181.                     ? getIntersection(bagA, bagB)
  182.                     : getIntersection(bagB, bagA);
  183.         }

  184.         return new IntersectionResult(sizeA, sizeB, intersection);
  185.     }

  186.     /**
  187.      * Computes the intersection between two bags. This is the sum of the minimum
  188.      * count of each element that is within both sets.
  189.      *
  190.      * @param bagA the bag A
  191.      * @param bagB the bag B
  192.      * @return The intersection
  193.      */
  194.     private int getIntersection(final TinyBag bagA, final TinyBag bagB) {
  195.         int intersection = 0;
  196.         for (final Entry<T, BagCount> entry : bagA.entrySet()) {
  197.             final T element = entry.getKey();
  198.             final int count = entry.getValue().count;
  199.             // The intersection of this entry in both bags is the minimum count
  200.             intersection += Math.min(count, bagB.getCount(element));
  201.         }
  202.         return intersection;
  203.     }

  204.     /**
  205.      * Converts the collection to a bag. The bag will contain the count of each element
  206.      * in the collection.
  207.      *
  208.      * @param objects the objects
  209.      * @return The bag
  210.      */
  211.     private TinyBag toBag(final Collection<T> objects) {
  212.         final TinyBag bag = new TinyBag(objects.size());
  213.         objects.forEach(bag::add);
  214.         return bag;
  215.     }
  216. }