001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.rng.sampling.shape;
019
020import org.apache.commons.rng.UniformRandomProvider;
021import org.apache.commons.rng.sampling.SharedStateObjectSampler;
022import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
023import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
024import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
025
026/**
027 * Generate coordinates <a href="http://mathworld.wolfram.com/BallPointPicking.html">
028 * uniformly distributed within the unit n-ball</a>.
029 *
030 * <p>Sampling uses:</p>
031 *
032 * <ul>
033 *   <li>{@link UniformRandomProvider#nextLong()}
034 *   <li>{@link UniformRandomProvider#nextDouble()} (only for dimensions above 2)
035 * </ul>
036 *
037 * @since 1.4
038 */
039public abstract class UnitBallSampler implements SharedStateObjectSampler<double[]> {
    /** The dimension for 1D sampling. */
    private static final int ONE_D = 1;
    /** The dimension for 2D sampling. */
    private static final int TWO_D = 2;
    /** The dimension for 3D sampling. */
    private static final int THREE_D = 3;
    /**
     * The multiplier to convert the least significant 53-bits of a {@code long} to a {@code double}.
     * Taken from o.a.c.rng.core.utils.NumberFactory.
     *
     * <p>This is equivalent to {@code 1.0 / (1L << 53)}, i.e. 2<sup>-53</sup>.
     */
    private static final double DOUBLE_MULTIPLIER = 0x1.0p-53d;
053
054    /**
055     * Sample uniformly from a 1D unit line.
056     */
057    private static class UnitBallSampler1D extends UnitBallSampler {
058        /** The source of randomness. */
059        private final UniformRandomProvider rng;
060
061        /**
062         * @param rng Source of randomness.
063         */
064        UnitBallSampler1D(UniformRandomProvider rng) {
065            this.rng = rng;
066        }
067
068        @Override
069        public double[] sample() {
070            return new double[] {makeSignedDouble(rng.nextLong())};
071        }
072
073        @Override
074        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
075            return new UnitBallSampler1D(rng);
076        }
077    }
078
079    /**
080     * Sample uniformly from a 2D unit disk.
081     */
082    private static class UnitBallSampler2D extends UnitBallSampler {
083        /** The source of randomness. */
084        private final UniformRandomProvider rng;
085
086        /**
087         * @param rng Source of randomness.
088         */
089        UnitBallSampler2D(UniformRandomProvider rng) {
090            this.rng = rng;
091        }
092
093        @Override
094        public double[] sample() {
095            // Generate via rejection method of a circle inside a square of edge length 2.
096            // This should compute approximately 2^2 / pi = 1.27 square positions per sample.
097            double x;
098            double y;
099            do {
100                x = makeSignedDouble(rng.nextLong());
101                y = makeSignedDouble(rng.nextLong());
102            } while (x * x + y * y > 1.0);
103            return new double[] {x, y};
104        }
105
106        @Override
107        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
108            return new UnitBallSampler2D(rng);
109        }
110    }
111
112    /**
113     * Sample uniformly from a 3D unit ball. This is an non-array based specialisation of
114     * {@link UnitBallSamplerND} for performance.
115     */
116    private static class UnitBallSampler3D extends UnitBallSampler {
117        /** The standard normal distribution. */
118        private final NormalizedGaussianSampler normal;
119        /** The exponential distribution (mean=1). */
120        private final ContinuousSampler exp;
121
122        /**
123         * @param rng Source of randomness.
124         */
125        UnitBallSampler3D(UniformRandomProvider rng) {
126            normal = ZigguratSampler.NormalizedGaussian.of(rng);
127            // Require an Exponential(mean=2).
128            // Here we use mean = 1 and scale the output later.
129            exp = ZigguratSampler.Exponential.of(rng);
130        }
131
132        @Override
133        public double[] sample() {
134            final double x = normal.sample();
135            final double y = normal.sample();
136            final double z = normal.sample();
137            // Include the exponential sample. It has mean 1 so multiply by 2.
138            final double sum = exp.sample() * 2 + x * x + y * y + z * z;
139            // Note: Handle the possibility of a zero sum and invalid inverse
140            if (sum == 0) {
141                return sample();
142            }
143            final double f = 1.0 / Math.sqrt(sum);
144            return new double[] {x * f, y * f, z * f};
145        }
146
147        @Override
148        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
149            return new UnitBallSampler3D(rng);
150        }
151    }
152
153    /**
154     * Sample using ball point picking.
155     * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball point picking</a>
156     */
157    private static class UnitBallSamplerND extends UnitBallSampler {
158        /** The dimension. */
159        private final int dimension;
160        /** The standard normal distribution. */
161        private final NormalizedGaussianSampler normal;
162        /** The exponential distribution (mean=1). */
163        private final ContinuousSampler exp;
164
165        /**
166         * @param rng Source of randomness.
167         * @param dimension Space dimension.
168         */
169        UnitBallSamplerND(UniformRandomProvider rng, int dimension) {
170            this.dimension  = dimension;
171            normal = ZigguratSampler.NormalizedGaussian.of(rng);
172            // Require an Exponential(mean=2).
173            // Here we use mean = 1 and scale the output later.
174            exp = ZigguratSampler.Exponential.of(rng);
175        }
176
177        @Override
178        public double[] sample() {
179            final double[] sample = new double[dimension];
180            // Include the exponential sample. It has mean 1 so multiply by 2.
181            double sum = exp.sample() * 2;
182            for (int i = 0; i < dimension; i++) {
183                final double x = normal.sample();
184                sum += x * x;
185                sample[i] = x;
186            }
187            // Note: Handle the possibility of a zero sum and invalid inverse
188            if (sum == 0) {
189                return sample();
190            }
191            final double f = 1.0 / Math.sqrt(sum);
192            for (int i = 0; i < dimension; i++) {
193                sample[i] *= f;
194            }
195            return sample;
196        }
197
198        @Override
199        public UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng) {
200            return new UnitBallSamplerND(rng, dimension);
201        }
202    }
203
    /**
     * Creates a sample of a point within the unit n-ball. Each call returns a newly
     * allocated array with one element per dimension.
     *
     * @return a random Cartesian coordinate within the unit n-ball.
     */
    @Override
    public abstract double[] sample();
209
    /** {@inheritDoc} */
    // Redeclared only to narrow the return type: implementations return a UnitBallSampler
    // rather than the more general SharedStateObjectSampler<double[]>.
    @Override
    public abstract UnitBallSampler withUniformRandomProvider(UniformRandomProvider rng);
214
215    /**
216     * Create a unit n-ball sampler for the given dimension.
217     * Sampled points are uniformly distributed within the unit n-ball.
218     *
219     * <p>Sampling is supported in dimensions of 1 or above.
220     *
221     * @param rng Source of randomness.
222     * @param dimension Space dimension.
223     * @return the sampler
224     * @throws IllegalArgumentException If {@code dimension <= 0}
225     */
226    public static UnitBallSampler of(UniformRandomProvider rng,
227                                     int dimension) {
228        if (dimension <= 0) {
229            throw new IllegalArgumentException("Dimension must be strictly positive");
230        } else if (dimension == ONE_D) {
231            return new UnitBallSampler1D(rng);
232        } else if (dimension == TWO_D) {
233            return new UnitBallSampler2D(rng);
234        } else if (dimension == THREE_D) {
235            return new UnitBallSampler3D(rng);
236        }
237        return new UnitBallSamplerND(rng, dimension);
238    }
239
240    /**
241     * Creates a signed double in the range {@code [-1, 1)}. The magnitude is sampled evenly
242     * from the 2<sup>54</sup> dyadic rationals in the range.
243     *
244     * <p>Note: This method will not return samples for both -0.0 and 0.0.
245     *
246     * @param bits the bits
247     * @return the double
248     */
249    private static double makeSignedDouble(long bits) {
250        // As per o.a.c.rng.core.utils.NumberFactory.makeDouble(long) but using a signed
251        // shift of 10 in place of an unsigned shift of 11.
252        // Use the upper 54 bits on the assumption they are more random.
253        // The sign bit is maintained by the signed shift.
254        // The next 53 bits generates a magnitude in the range [0, 2^53) or [-2^53, 0).
255        return (bits >> 10) * DOUBLE_MULTIPLIER;
256    }
257}