001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math4.legacy.optim.univariate; 019 020import java.util.Arrays; 021import java.util.Comparator; 022 023import org.apache.commons.math4.legacy.exception.MathIllegalStateException; 024import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; 025import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; 026import org.apache.commons.math4.legacy.optim.MaxEval; 027import org.apache.commons.math4.legacy.optim.OptimizationData; 028import org.apache.commons.math4.legacy.optim.nonlinear.scalar.GoalType; 029import org.apache.commons.rng.UniformRandomProvider; 030 031/** 032 * Special implementation of the {@link UnivariateOptimizer} interface 033 * adding multi-start features to an existing optimizer. 034 * <br> 035 * This class wraps an optimizer in order to use it several times in 036 * turn with different starting points (trying to avoid being trapped 037 * in a local extremum when looking for a global one). 038 * 039 * @since 3.0 040 */ 041public class MultiStartUnivariateOptimizer 042 extends UnivariateOptimizer { 043 /** Underlying classical optimizer. */ 044 private final UnivariateOptimizer optimizer; 045 /** Number of evaluations already performed for all starts. */ 046 private int totalEvaluations; 047 /** Number of starts to go. */ 048 private final int starts; 049 /** Random generator for multi-start. */ 050 private final UniformRandomProvider generator; 051 /** Found optima. */ 052 private UnivariatePointValuePair[] optima; 053 /** Optimization data. */ 054 private OptimizationData[] optimData; 055 /** 056 * Location in {@link #optimData} where the updated maximum 057 * number of evaluations will be stored. 058 */ 059 private int maxEvalIndex = -1; 060 /** 061 * Location in {@link #optimData} where the updated start value 062 * will be stored. 063 */ 064 private int searchIntervalIndex = -1; 065 066 /** 067 * Create a multi-start optimizer from a single-start optimizer. 068 * 069 * @param optimizer Single-start optimizer to wrap. 070 * @param starts Number of starts to perform. If {@code starts == 1}, 071 * the {@code optimize} methods will return the same solution as 072 * {@code optimizer} would. 073 * @param generator Random generator to use for restarts. 074 * @throws NotStrictlyPositiveException if {@code starts < 1}. 075 */ 076 public MultiStartUnivariateOptimizer(final UnivariateOptimizer optimizer, 077 final int starts, 078 final UniformRandomProvider generator) { 079 super(optimizer.getConvergenceChecker()); 080 081 if (starts < 1) { 082 throw new NotStrictlyPositiveException(starts); 083 } 084 085 this.optimizer = optimizer; 086 this.starts = starts; 087 this.generator = generator; 088 } 089 090 /** {@inheritDoc} */ 091 @Override 092 public int getEvaluations() { 093 return totalEvaluations; 094 } 095 096 /** 097 * Gets all the optima found during the last call to {@code optimize}. 098 * The optimizer stores all the optima found during a set of 099 * restarts. The {@code optimize} method returns the best point only. 100 * This method returns all the points found at the end of each starts, 101 * including the best one already returned by the {@code optimize} method. 102 * <br> 103 * The returned array as one element for each start as specified 104 * in the constructor. It is ordered with the results from the 105 * runs that did converge first, sorted from best to worst 106 * objective value (i.e in ascending order if minimizing and in 107 * descending order if maximizing), followed by {@code null} elements 108 * corresponding to the runs that did not converge. This means all 109 * elements will be {@code null} if the {@code optimize} method did throw 110 * an exception. 111 * This also means that if the first element is not {@code null}, it is 112 * the best point found across all starts. 113 * 114 * @return an array containing the optima. 115 * @throws MathIllegalStateException if {@link #optimize(OptimizationData[]) 116 * optimize} has not been called. 117 */ 118 public UnivariatePointValuePair[] getOptima() { 119 if (optima == null) { 120 throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET); 121 } 122 return optima.clone(); 123 } 124 125 /** 126 * {@inheritDoc} 127 * 128 * @throws MathIllegalStateException if {@code optData} does not contain an 129 * instance of {@link MaxEval} or {@link SearchInterval}. 130 */ 131 @Override 132 public UnivariatePointValuePair optimize(OptimizationData... optData) { 133 // Store arguments in order to pass them to the internal optimizer. 134 optimData = optData; 135 // Set up base class and perform computations. 136 return super.optimize(optData); 137 } 138 139 /** {@inheritDoc} */ 140 @Override 141 protected UnivariatePointValuePair doOptimize() { 142 // Remove all instances of "MaxEval" and "SearchInterval" from the 143 // array that will be passed to the internal optimizer. 144 // The former is to enforce smaller numbers of allowed evaluations 145 // (according to how many have been used up already), and the latter 146 // to impose a different start value for each start. 147 for (int i = 0; i < optimData.length; i++) { 148 if (optimData[i] instanceof MaxEval) { 149 optimData[i] = null; 150 maxEvalIndex = i; 151 continue; 152 } 153 if (optimData[i] instanceof SearchInterval) { 154 optimData[i] = null; 155 searchIntervalIndex = i; 156 continue; 157 } 158 } 159 if (maxEvalIndex == -1) { 160 throw new MathIllegalStateException(); 161 } 162 if (searchIntervalIndex == -1) { 163 throw new MathIllegalStateException(); 164 } 165 166 RuntimeException lastException = null; 167 optima = new UnivariatePointValuePair[starts]; 168 totalEvaluations = 0; 169 170 final int maxEval = getMaxEvaluations(); 171 final double min = getMin(); 172 final double max = getMax(); 173 final double startValue = getStartValue(); 174 175 // Multi-start loop. 176 for (int i = 0; i < starts; i++) { 177 // CHECKSTYLE: stop IllegalCatch 178 try { 179 // Decrease number of allowed evaluations. 180 optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations); 181 // New start value. 182 final double s = (i == 0) ? 183 startValue : 184 min + generator.nextDouble() * (max - min); 185 optimData[searchIntervalIndex] = new SearchInterval(min, max, s); 186 // Optimize. 187 optima[i] = optimizer.optimize(optimData); 188 } catch (RuntimeException mue) { 189 lastException = mue; 190 optima[i] = null; 191 } 192 // CHECKSTYLE: resume IllegalCatch 193 194 totalEvaluations += optimizer.getEvaluations(); 195 } 196 197 sortPairs(getGoalType()); 198 199 if (optima[0] == null) { 200 throw lastException; // Cannot be null if starts >= 1. 201 } 202 203 // Return the point with the best objective function value. 204 return optima[0]; 205 } 206 207 /** 208 * Sort the optima from best to worst, followed by {@code null} elements. 209 * 210 * @param goal Goal type. 211 */ 212 private void sortPairs(final GoalType goal) { 213 Arrays.sort(optima, new Comparator<UnivariatePointValuePair>() { 214 /** {@inheritDoc} */ 215 @Override 216 public int compare(final UnivariatePointValuePair o1, 217 final UnivariatePointValuePair o2) { 218 if (o1 == null) { 219 return (o2 == null) ? 0 : 1; 220 } else if (o2 == null) { 221 return -1; 222 } 223 final double v1 = o1.getValue(); 224 final double v2 = o2.getValue(); 225 return (goal == GoalType.MINIMIZE) ? 226 Double.compare(v1, v2) : Double.compare(v2, v1); 227 } 228 }); 229 } 230}