001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.rng.core;
019
020import org.apache.commons.rng.RestorableUniformRandomProvider;
021import org.apache.commons.rng.RandomProviderState;
022
023/**
024 * Base class with default implementation for common methods.
025 */
026public abstract class BaseProvider
027    implements RestorableUniformRandomProvider {
028    /** Error message when an integer is not positive. */
029    private static final String NOT_POSITIVE = "Must be strictly positive: ";
030    /** 2^32. */
031    private static final long POW_32 = 1L << 32;
032
033    /** {@inheritDoc} */
034    @Override
035    public int nextInt(int n) {
036        if (n <= 0) {
037            throw new IllegalArgumentException(NOT_POSITIVE + n);
038        }
039
040        // Lemire (2019): Fast Random Integer Generation in an Interval
041        // https://arxiv.org/abs/1805.10941
042        long m = (nextInt() & 0xffffffffL) * n;
043        long l = m & 0xffffffffL;
044        if (l < n) {
045            // 2^32 % n
046            final long t = POW_32 % n;
047            while (l < t) {
048                m = (nextInt() & 0xffffffffL) * n;
049                l = m & 0xffffffffL;
050            }
051        }
052        return (int) (m >>> 32);
053    }
054
055    /** {@inheritDoc} */
056    @Override
057    public long nextLong(long n) {
058        if (n <= 0) {
059            throw new IllegalArgumentException(NOT_POSITIVE + n);
060        }
061
062        long bits;
063        long val;
064        do {
065            bits = nextLong() >>> 1;
066            val  = bits % n;
067        } while (bits - val + (n - 1) < 0);
068
069        return val;
070    }
071
072    /** {@inheritDoc} */
073    @Override
074    public RandomProviderState saveState() {
075        return new RandomProviderDefaultState(getStateInternal());
076    }
077
078    /** {@inheritDoc} */
079    @Override
080    public void restoreState(RandomProviderState state) {
081        if (state instanceof RandomProviderDefaultState) {
082            setStateInternal(((RandomProviderDefaultState) state).getState());
083        } else {
084            throw new IllegalArgumentException("Foreign instance");
085        }
086    }
087
088    /** {@inheritDoc} */
089    @Override
090    public String toString() {
091        return getClass().getName();
092    }
093
094    /**
095     * Combine parent and subclass states.
096     * This method must be called by all subclasses in order to ensure
097     * that state can be restored in case some of it is stored higher
098     * up in the class hierarchy.
099     *
100     * I.e. the body of the overridden {@link #getStateInternal()},
101     * will end with a statement like the following:
102     * <pre>
103     *  <code>
104     *    return composeStateInternal(state,
105     *                                super.getStateInternal());
106     *  </code>
107     * </pre>
108     * where {@code state} is the state needed and defined by the class
109     * where the method is overridden.
110     *
111     * @param state State of the calling class.
112     * @param parentState State of the calling class' parent.
113     * @return the combined state.
114     * Bytes that belong to the local state will be stored at the
115     * beginning of the resulting array.
116     */
117    protected byte[] composeStateInternal(byte[] state,
118                                          byte[] parentState) {
119        final int len = parentState.length + state.length;
120        final byte[] c = new byte[len];
121        System.arraycopy(state, 0, c, 0, state.length);
122        System.arraycopy(parentState, 0, c, state.length, parentState.length);
123        return c;
124    }
125
126    /**
127     * Splits the given {@code state} into a part to be consumed by the caller
128     * in order to restore its local state, while the reminder is passed to
129     * the parent class.
130     *
131     * I.e. the body of the overridden {@link #setStateInternal(byte[])},
132     * will contain statements like the following:
133     * <pre>
134     *  <code>
135     *    final byte[][] s = splitState(state, localStateLength);
136     *    // Use "s[0]" to recover the local state.
137     *    super.setStateInternal(s[1]);
138     *  </code>
139     * </pre>
140     * where {@code state} is the combined state of the calling class and of
141     * all its parents.
142     *
143     * @param state State.
144     * The local state must be stored at the beginning of the array.
145     * @param localStateLength Number of elements that will be consumed by the
146     * locally defined state.
147     * @return the local state (in slot 0) and the parent state (in slot 1).
148     * @throws IllegalStateException if {@code state.length < localStateLength}.
149     */
150    protected byte[][] splitStateInternal(byte[] state,
151                                          int localStateLength) {
152        checkStateSize(state, localStateLength);
153
154        final byte[] local = new byte[localStateLength];
155        System.arraycopy(state, 0, local, 0, localStateLength);
156        final int parentLength = state.length - localStateLength;
157        final byte[] parent = new byte[parentLength];
158        System.arraycopy(state, localStateLength, parent, 0, parentLength);
159
160        return new byte[][] {local, parent};
161    }
162
163    /**
164     * Creates a snapshot of the RNG state.
165     *
166     * @return the internal state.
167     */
168    protected byte[] getStateInternal() {
169        // This class has no state (and is the top-level class that
170        // declares this method).
171        return new byte[0];
172    }
173
174    /**
175     * Resets the RNG to the given {@code state}.
176     *
177     * @param state State (previously obtained by a call to
178     * {@link #getStateInternal()}).
179     * @throws IllegalStateException if the size of the given array is
180     * not consistent with the state defined by this class.
181     *
182     * @see #checkStateSize(byte[],int)
183     */
184    protected void setStateInternal(byte[] state) {
185        if (state.length != 0) {
186            // This class has no state.
187            throw new IllegalStateException("State not fully recovered by subclasses");
188        }
189    }
190
191    /**
192     * Simple filling procedure.
193     * It will
194     * <ol>
195     *  <li>
196     *   fill the beginning of {@code state} by copying
197     *   {@code min(seed.length, state.length)} elements from
198     *   {@code seed},
199     *  </li>
200     *  <li>
201     *   set all remaining elements of {@code state} with non-zero
202     *   values (even if {@code seed.length < state.length}).
203     *  </li>
204     * </ol>
205     *
206     * @param state State. Must be allocated.
207     * @param seed Seed. Cannot be null.
208     */
209    protected void fillState(int[] state,
210                             int[] seed) {
211        final int stateSize = state.length;
212        final int seedSize = seed.length;
213        System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize));
214
215        if (seedSize < stateSize) {
216            for (int i = seedSize; i < stateSize; i++) {
217                state[i] = (int) (scrambleWell(state[i - seed.length], i) & 0xffffffffL);
218            }
219        }
220    }
221
222    /**
223     * Simple filling procedure.
224     * It will
225     * <ol>
226     *  <li>
227     *   fill the beginning of {@code state} by copying
228     *   {@code min(seed.length, state.length)} elements from
229     *   {@code seed},
230     *  </li>
231     *  <li>
232     *   set all remaining elements of {@code state} with non-zero
233     *   values (even if {@code seed.length < state.length}).
234     *  </li>
235     * </ol>
236     *
237     * @param state State. Must be allocated.
238     * @param seed Seed. Cannot be null.
239     */
240    protected void fillState(long[] state,
241                             long[] seed) {
242        final int stateSize = state.length;
243        final int seedSize = seed.length;
244        System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize));
245
246        if (seedSize < stateSize) {
247            for (int i = seedSize; i < stateSize; i++) {
248                state[i] = scrambleWell(state[i - seed.length], i);
249            }
250        }
251    }
252
253    /**
254     * Checks that the {@code state} has the {@code expected} size.
255     *
256     * @param state State.
257     * @param expected Expected length of {@code state} array.
258     * @throws IllegalStateException if {@code state.length < expected}.
259     * @deprecated Method is used internally and should be made private in
260     * some future release.
261     */
262    @Deprecated
263    protected void checkStateSize(byte[] state,
264                                  int expected) {
265        if (state.length < expected) {
266            throw new IllegalStateException("State size must be larger than " +
267                                            expected + " but was " + state.length);
268        }
269    }
270
271    /**
272     * Checks whether {@code index} is in the range {@code [min, max]}.
273     *
274     * @param min Lower bound.
275     * @param max Upper bound.
276     * @param index Value that must lie within the {@code [min, max]} interval.
277     * @throws IndexOutOfBoundsException if {@code index} is not within the
278     * {@code [min, max]} interval.
279     */
280    protected void checkIndex(int min,
281                              int max,
282                              int index) {
283        if (index < min ||
284            index > max) {
285            throw new IndexOutOfBoundsException(index + " is out of interval [" +
286                                                min + ", " +
287                                                max + "]");
288        }
289    }
290
291    /**
292     * Transformation used to scramble the initial state of
293     * a generator.
294     *
295     * @param n Seed element.
296     * @param mult Multiplier.
297     * @param shift Shift.
298     * @param add Offset.
299     * @return the transformed seed element.
300     */
301    private static long scramble(long n,
302                                 long mult,
303                                 int shift,
304                                 int add) {
305        // Code inspired from "AbstractWell" class.
306        return mult * (n ^ (n >> shift)) + add;
307    }
308
309    /**
310     * Transformation used to scramble the initial state of
311     * a generator.
312     *
313     * @param n Seed element.
314     * @param add Offset.
315     * @return the transformed seed element.
316     * @see #scramble(long,long,int,int)
317     */
318    private static long scrambleWell(long n,
319                                     int add) {
320        // Code inspired from "AbstractWell" class.
321        return scramble(n, 1812433253L, 30, add);
322    }
323}