UnitBallSampler.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.rng.sampling.shape;

  18. import org.apache.commons.rng.UniformRandomProvider;
  19. import org.apache.commons.rng.sampling.SharedStateObjectSampler;
  20. import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
  21. import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
  22. import org.apache.commons.rng.sampling.distribution.ZigguratSampler;

  23. /**
  24.  * Generate coordinates <a href="http://mathworld.wolfram.com/BallPointPicking.html">
  25.  * uniformly distributed within the unit n-ball</a>.
  26.  *
  27.  * <p>Sampling uses:</p>
  28.  *
  29.  * <ul>
  30.  *   <li>{@link UniformRandomProvider#nextLong()}
  31.  *   <li>{@link UniformRandomProvider#nextDouble()} (only for dimensions above 2)
  32.  * </ul>
  33.  *
  34.  * @since 1.4
  35.  */
  36. public abstract class UnitBallSampler implements SharedStateObjectSampler<double[]> {
  37.     /** The dimension for 1D sampling. */
  38.     private static final int ONE_D = 1;
  39.     /** The dimension for 2D sampling. */
  40.     private static final int TWO_D = 2;
  41.     /** The dimension for 3D sampling. */
  42.     private static final int THREE_D = 3;
  43.     /**
  44.      * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
  45.      * Taken from o.a.c.rng.core.utils.NumberFactory.
  46.      *
  47.      * <p>This is equivalent to {@code 1.0 / (1L << 53)}.
  48.      */
  49.     private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;

  50.     /**
  51.      * Sample uniformly from a 1D unit line.
  52.      */
  53.     private static final class UnitBallSampler1D extends UnitBallSampler {
  54.         /** The source of randomness. */
  55.         private final UniformRandomProvider rng;

  56.         /**
  57.          * @param rng Source of randomness.
  58.          */
  59.         UnitBallSampler1D(UniformRandomProvider rng) {
  60.             this.rng = rng;
  61.         }

  62.         @Override
  63.         public double[] sample() {
  64.             return new double[] {makeSignedDouble(rng.nextLong())};
  65.         }

  66.         @Override
  67.         public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
  68.             return new UnitBallSampler1D(rng);
  69.         }
  70.     }

  71.     /**
  72.      * Sample uniformly from a 2D unit disk.
  73.      */
  74.     private static final class UnitBallSampler2D extends UnitBallSampler {
  75.         /** The source of randomness. */
  76.         private final UniformRandomProvider rng;

  77.         /**
  78.          * @param rng Source of randomness.
  79.          */
  80.         UnitBallSampler2D(UniformRandomProvider rng) {
  81.             this.rng = rng;
  82.         }

  83.         @Override
  84.         public double[] sample() {
  85.             // Generate via rejection method of a circle inside a square of edge length 2.
  86.             // This should compute approximately 2^2 / pi = 1.27 square positions per sample.
  87.             double x;
  88.             double y;
  89.             do {
  90.                 x = makeSignedDouble(rng.nextLong());
  91.                 y = makeSignedDouble(rng.nextLong());
  92.             } while (x * x + y * y > 1.0);
  93.             return new double[] {x, y};
  94.         }

  95.         @Override
  96.         public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
  97.             return new UnitBallSampler2D(rng);
  98.         }
  99.     }

  100.     /**
  101.      * Sample uniformly from a 3D unit ball. This is an non-array based specialisation of
  102.      * {@link UnitBallSamplerND} for performance.
  103.      */
  104.     private static final class UnitBallSampler3D extends UnitBallSampler {
  105.         /** The standard normal distribution. */
  106.         private final NormalizedGaussianSampler normal;
  107.         /** The exponential distribution (mean=1). */
  108.         private final ContinuousSampler exp;

  109.         /**
  110.          * @param rng Source of randomness.
  111.          */
  112.         UnitBallSampler3D(UniformRandomProvider rng) {
  113.             normal = ZigguratSampler.NormalizedGaussian.of(rng);
  114.             // Require an Exponential(mean=2).
  115.             // Here we use mean = 1 and scale the output later.
  116.             exp = ZigguratSampler.Exponential.of(rng);
  117.         }

  118.         @Override
  119.         public double[] sample() {
  120.             final double x = normal.sample();
  121.             final double y = normal.sample();
  122.             final double z = normal.sample();
  123.             // Include the exponential sample. It has mean 1 so multiply by 2.
  124.             final double sum = exp.sample() * 2 + x * x + y * y + z * z;
  125.             // Note: Handle the possibility of a zero sum and invalid inverse
  126.             if (sum == 0) {
  127.                 return sample();
  128.             }
  129.             final double f = 1.0 / Math.sqrt(sum);
  130.             return new double[] {x * f, y * f, z * f};
  131.         }

  132.         @Override
  133.         public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
  134.             return new UnitBallSampler3D(rng);
  135.         }
  136.     }

  137.     /**
  138.      * Sample using ball point picking.
  139.      * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball point picking</a>
  140.      */
  141.     private static final class UnitBallSamplerND extends UnitBallSampler {
  142.         /** The dimension. */
  143.         private final int dimension;
  144.         /** The standard normal distribution. */
  145.         private final NormalizedGaussianSampler normal;
  146.         /** The exponential distribution (mean=1). */
  147.         private final ContinuousSampler exp;

  148.         /**
  149.          * @param rng Source of randomness.
  150.          * @param dimension Space dimension.
  151.          */
  152.         UnitBallSamplerND(UniformRandomProvider rng, int dimension) {
  153.             this.dimension  = dimension;
  154.             normal = ZigguratSampler.NormalizedGaussian.of(rng);
  155.             // Require an Exponential(mean=2).
  156.             // Here we use mean = 1 and scale the output later.
  157.             exp = ZigguratSampler.Exponential.of(rng);
  158.         }

  159.         @Override
  160.         public double[] sample() {
  161.             final double[] sample = new double[dimension];
  162.             // Include the exponential sample. It has mean 1 so multiply by 2.
  163.             double sum = exp.sample() * 2;
  164.             for (int i = 0; i < dimension; i++) {
  165.                 final double x = normal.sample();
  166.                 sum += x * x;
  167.                 sample[i] = x;
  168.             }
  169.             // Note: Handle the possibility of a zero sum and invalid inverse
  170.             if (sum == 0) {
  171.                 return sample();
  172.             }
  173.             final double f = 1.0 / Math.sqrt(sum);
  174.             for (int i = 0; i < dimension; i++) {
  175.                 sample[i] *= f;
  176.             }
  177.             return sample;
  178.         }

  179.         @Override
  180.         public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
  181.             return new UnitBallSamplerND(rng, dimension);
  182.         }
  183.     }

  184.     /**
  185.      * Create an instance.
  186.      */
  187.     public UnitBallSampler() {}

  188.     /**
  189.      * @return a random Cartesian coordinate within the unit n-ball.
  190.      */
  191.     @Override
  192.     public abstract double[] sample();

  193.     /** {@inheritDoc} */
  194.     // Redeclare the signature to return a UnitBallSampler not a SharedStateObjectSampler<double[]>
  195.     @Override
  196.     public abstract UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng);

  197.     /**
  198.      * Create a unit n-ball sampler for the given dimension.
  199.      * Sampled points are uniformly distributed within the unit n-ball.
  200.      *
  201.      * <p>Sampling is supported in dimensions of 1 or above.
  202.      *
  203.      * @param rng Source of randomness.
  204.      * @param dimension Space dimension.
  205.      * @return the sampler
  206.      * @throws IllegalArgumentException If {@code dimension <= 0}
  207.      */
  208.     public static UnitBallSampler of(UniformRandomProvider rng,
  209.                                      int dimension) {
  210.         if (dimension <= 0) {
  211.             throw new IllegalArgumentException("Dimension must be strictly positive");
  212.         } else if (dimension == ONE_D) {
  213.             return new UnitBallSampler1D(rng);
  214.         } else if (dimension == TWO_D) {
  215.             return new UnitBallSampler2D(rng);
  216.         } else if (dimension == THREE_D) {
  217.             return new UnitBallSampler3D(rng);
  218.         }
  219.         return new UnitBallSamplerND(rng, dimension);
  220.     }

  221.     /**
  222.      * Creates a signed double in the range {@code [-1, 1)}. The magnitude is sampled evenly
  223.      * from the 2<sup>54</sup> dyadic rationals in the range.
  224.      *
  225.      * <p>Note: This method will not return samples for both -0.0 and 0.0.
  226.      *
  227.      * @param bits the bits
  228.      * @return the double
  229.      */
  230.     private static double makeSignedDouble(long bits) {
  231.         // As per o.a.c.rng.core.utils.NumberFactory.makeDouble(long) but using a signed
  232.         // shift of 10 in place of an unsigned shift of 11.
  233.         // Use the upper 54 bits on the assumption they are more random.
  234.         // The sign bit is maintained by the signed shift.
  235.         // The next 53 bits generates a magnitude in the range [0, 2^53) or [-2^53, 0).
  236.         return (bits >> 10) * DOUBLE_MULTIPLIER;
  237.     }
  238. }