001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.rng.core; 019 020import org.apache.commons.rng.RestorableUniformRandomProvider; 021import org.apache.commons.rng.RandomProviderState; 022 023/** 024 * Base class with default implementation for common methods. 025 */ 026public abstract class BaseProvider 027 implements RestorableUniformRandomProvider { 028 /** Error message when an integer is not positive. */ 029 private static final String NOT_POSITIVE = "Must be strictly positive: "; 030 /** 2^32. */ 031 private static final long POW_32 = 1L << 32; 032 033 /** {@inheritDoc} */ 034 @Override 035 public int nextInt(int n) { 036 if (n <= 0) { 037 throw new IllegalArgumentException(NOT_POSITIVE + n); 038 } 039 040 // Lemire (2019): Fast Random Integer Generation in an Interval 041 // https://arxiv.org/abs/1805.10941 042 long m = (nextInt() & 0xffffffffL) * n; 043 long l = m & 0xffffffffL; 044 if (l < n) { 045 // 2^32 % n 046 final long t = POW_32 % n; 047 while (l < t) { 048 m = (nextInt() & 0xffffffffL) * n; 049 l = m & 0xffffffffL; 050 } 051 } 052 return (int) (m >>> 32); 053 } 054 055 /** {@inheritDoc} */ 056 @Override 057 public long nextLong(long n) { 058 if (n <= 0) { 059 throw new IllegalArgumentException(NOT_POSITIVE + n); 060 } 061 062 long bits; 063 long val; 064 do { 065 bits = nextLong() >>> 1; 066 val = bits % n; 067 } while (bits - val + (n - 1) < 0); 068 069 return val; 070 } 071 072 /** {@inheritDoc} */ 073 @Override 074 public RandomProviderState saveState() { 075 return new RandomProviderDefaultState(getStateInternal()); 076 } 077 078 /** {@inheritDoc} */ 079 @Override 080 public void restoreState(RandomProviderState state) { 081 if (state instanceof RandomProviderDefaultState) { 082 setStateInternal(((RandomProviderDefaultState) state).getState()); 083 } else { 084 throw new IllegalArgumentException("Foreign instance"); 085 } 086 } 087 088 /** {@inheritDoc} */ 089 @Override 090 public String toString() { 091 return getClass().getName(); 092 } 093 094 /** 095 * Combine parent and subclass states. 096 * This method must be called by all subclasses in order to ensure 097 * that state can be restored in case some of it is stored higher 098 * up in the class hierarchy. 099 * 100 * I.e. the body of the overridden {@link #getStateInternal()}, 101 * will end with a statement like the following: 102 * <pre> 103 * <code> 104 * return composeStateInternal(state, 105 * super.getStateInternal()); 106 * </code> 107 * </pre> 108 * where {@code state} is the state needed and defined by the class 109 * where the method is overridden. 110 * 111 * @param state State of the calling class. 112 * @param parentState State of the calling class' parent. 113 * @return the combined state. 114 * Bytes that belong to the local state will be stored at the 115 * beginning of the resulting array. 116 */ 117 protected byte[] composeStateInternal(byte[] state, 118 byte[] parentState) { 119 final int len = parentState.length + state.length; 120 final byte[] c = new byte[len]; 121 System.arraycopy(state, 0, c, 0, state.length); 122 System.arraycopy(parentState, 0, c, state.length, parentState.length); 123 return c; 124 } 125 126 /** 127 * Splits the given {@code state} into a part to be consumed by the caller 128 * in order to restore its local state, while the reminder is passed to 129 * the parent class. 130 * 131 * I.e. the body of the overridden {@link #setStateInternal(byte[])}, 132 * will contain statements like the following: 133 * <pre> 134 * <code> 135 * final byte[][] s = splitState(state, localStateLength); 136 * // Use "s[0]" to recover the local state. 137 * super.setStateInternal(s[1]); 138 * </code> 139 * </pre> 140 * where {@code state} is the combined state of the calling class and of 141 * all its parents. 142 * 143 * @param state State. 144 * The local state must be stored at the beginning of the array. 145 * @param localStateLength Number of elements that will be consumed by the 146 * locally defined state. 147 * @return the local state (in slot 0) and the parent state (in slot 1). 148 * @throws IllegalStateException if {@code state.length < localStateLength}. 149 */ 150 protected byte[][] splitStateInternal(byte[] state, 151 int localStateLength) { 152 checkStateSize(state, localStateLength); 153 154 final byte[] local = new byte[localStateLength]; 155 System.arraycopy(state, 0, local, 0, localStateLength); 156 final int parentLength = state.length - localStateLength; 157 final byte[] parent = new byte[parentLength]; 158 System.arraycopy(state, localStateLength, parent, 0, parentLength); 159 160 return new byte[][] {local, parent}; 161 } 162 163 /** 164 * Creates a snapshot of the RNG state. 165 * 166 * @return the internal state. 167 */ 168 protected byte[] getStateInternal() { 169 // This class has no state (and is the top-level class that 170 // declares this method). 171 return new byte[0]; 172 } 173 174 /** 175 * Resets the RNG to the given {@code state}. 176 * 177 * @param state State (previously obtained by a call to 178 * {@link #getStateInternal()}). 179 * @throws IllegalStateException if the size of the given array is 180 * not consistent with the state defined by this class. 181 * 182 * @see #checkStateSize(byte[],int) 183 */ 184 protected void setStateInternal(byte[] state) { 185 if (state.length != 0) { 186 // This class has no state. 187 throw new IllegalStateException("State not fully recovered by subclasses"); 188 } 189 } 190 191 /** 192 * Simple filling procedure. 193 * It will 194 * <ol> 195 * <li> 196 * fill the beginning of {@code state} by copying 197 * {@code min(seed.length, state.length)} elements from 198 * {@code seed}, 199 * </li> 200 * <li> 201 * set all remaining elements of {@code state} with non-zero 202 * values (even if {@code seed.length < state.length}). 203 * </li> 204 * </ol> 205 * 206 * @param state State. Must be allocated. 207 * @param seed Seed. Cannot be null. 208 */ 209 protected void fillState(int[] state, 210 int[] seed) { 211 final int stateSize = state.length; 212 final int seedSize = seed.length; 213 System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize)); 214 215 if (seedSize < stateSize) { 216 for (int i = seedSize; i < stateSize; i++) { 217 state[i] = (int) (scrambleWell(state[i - seed.length], i) & 0xffffffffL); 218 } 219 } 220 } 221 222 /** 223 * Simple filling procedure. 224 * It will 225 * <ol> 226 * <li> 227 * fill the beginning of {@code state} by copying 228 * {@code min(seed.length, state.length)} elements from 229 * {@code seed}, 230 * </li> 231 * <li> 232 * set all remaining elements of {@code state} with non-zero 233 * values (even if {@code seed.length < state.length}). 234 * </li> 235 * </ol> 236 * 237 * @param state State. Must be allocated. 238 * @param seed Seed. Cannot be null. 239 */ 240 protected void fillState(long[] state, 241 long[] seed) { 242 final int stateSize = state.length; 243 final int seedSize = seed.length; 244 System.arraycopy(seed, 0, state, 0, Math.min(seedSize, stateSize)); 245 246 if (seedSize < stateSize) { 247 for (int i = seedSize; i < stateSize; i++) { 248 state[i] = scrambleWell(state[i - seed.length], i); 249 } 250 } 251 } 252 253 /** 254 * Checks that the {@code state} has the {@code expected} size. 255 * 256 * @param state State. 257 * @param expected Expected length of {@code state} array. 258 * @throws IllegalStateException if {@code state.length < expected}. 259 * @deprecated Method is used internally and should be made private in 260 * some future release. 261 */ 262 @Deprecated 263 protected void checkStateSize(byte[] state, 264 int expected) { 265 if (state.length < expected) { 266 throw new IllegalStateException("State size must be larger than " + 267 expected + " but was " + state.length); 268 } 269 } 270 271 /** 272 * Checks whether {@code index} is in the range {@code [min, max]}. 273 * 274 * @param min Lower bound. 275 * @param max Upper bound. 276 * @param index Value that must lie within the {@code [min, max]} interval. 277 * @throws IndexOutOfBoundsException if {@code index} is not within the 278 * {@code [min, max]} interval. 279 */ 280 protected void checkIndex(int min, 281 int max, 282 int index) { 283 if (index < min || 284 index > max) { 285 throw new IndexOutOfBoundsException(index + " is out of interval [" + 286 min + ", " + 287 max + "]"); 288 } 289 } 290 291 /** 292 * Transformation used to scramble the initial state of 293 * a generator. 294 * 295 * @param n Seed element. 296 * @param mult Multiplier. 297 * @param shift Shift. 298 * @param add Offset. 299 * @return the transformed seed element. 300 */ 301 private static long scramble(long n, 302 long mult, 303 int shift, 304 int add) { 305 // Code inspired from "AbstractWell" class. 306 return mult * (n ^ (n >> shift)) + add; 307 } 308 309 /** 310 * Transformation used to scramble the initial state of 311 * a generator. 312 * 313 * @param n Seed element. 314 * @param add Offset. 315 * @return the transformed seed element. 316 * @see #scramble(long,long,int,int) 317 */ 318 private static long scrambleWell(long n, 319 int add) { 320 // Code inspired from "AbstractWell" class. 321 return scramble(n, 1812433253L, 30, add); 322 } 323}