001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.commons.math3.optimization;
019
020import java.util.Arrays;
021import java.util.Comparator;
022
023import org.apache.commons.math3.analysis.MultivariateVectorFunction;
024import org.apache.commons.math3.exception.ConvergenceException;
025import org.apache.commons.math3.exception.MathIllegalStateException;
026import org.apache.commons.math3.exception.NotStrictlyPositiveException;
027import org.apache.commons.math3.exception.NullArgumentException;
028import org.apache.commons.math3.exception.util.LocalizedFormats;
029import org.apache.commons.math3.random.RandomVectorGenerator;
030
031/**
032 * Base class for all implementations of a multi-start optimizer.
033 *
034 * This interface is mainly intended to enforce the internal coherence of
035 * Commons-Math. Users of the API are advised to base their code on
036 * {@link DifferentiableMultivariateVectorMultiStartOptimizer}.
037 *
038 * @param <FUNC> Type of the objective function to be optimized.
039 *
040 * @deprecated As of 3.1 (to be removed in 4.0).
041 * @since 3.0
042 */
043@Deprecated
044public class BaseMultivariateVectorMultiStartOptimizer<FUNC extends MultivariateVectorFunction>
045    implements BaseMultivariateVectorOptimizer<FUNC> {
046    /** Underlying classical optimizer. */
047    private final BaseMultivariateVectorOptimizer<FUNC> optimizer;
048    /** Maximal number of evaluations allowed. */
049    private int maxEvaluations;
050    /** Number of evaluations already performed for all starts. */
051    private int totalEvaluations;
052    /** Number of starts to go. */
053    private int starts;
054    /** Random generator for multi-start. */
055    private RandomVectorGenerator generator;
056    /** Found optima. */
057    private PointVectorValuePair[] optima;
058
059    /**
060     * Create a multi-start optimizer from a single-start optimizer.
061     *
062     * @param optimizer Single-start optimizer to wrap.
063     * @param starts Number of starts to perform. If {@code starts == 1},
064     * the {@link #optimize(int,MultivariateVectorFunction,double[],double[],double[])
065     * optimize} will return the same solution as {@code optimizer} would.
066     * @param generator Random vector generator to use for restarts.
067     * @throws NullArgumentException if {@code optimizer} or {@code generator}
068     * is {@code null}.
069     * @throws NotStrictlyPositiveException if {@code starts < 1}.
070     */
071    protected BaseMultivariateVectorMultiStartOptimizer(final BaseMultivariateVectorOptimizer<FUNC> optimizer,
072                                                           final int starts,
073                                                           final RandomVectorGenerator generator) {
074        if (optimizer == null ||
075            generator == null) {
076            throw new NullArgumentException();
077        }
078        if (starts < 1) {
079            throw new NotStrictlyPositiveException(starts);
080        }
081
082        this.optimizer = optimizer;
083        this.starts = starts;
084        this.generator = generator;
085    }
086
087    /**
088     * Get all the optima found during the last call to {@link
089     * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize}.
090     * The optimizer stores all the optima found during a set of
091     * restarts. The {@link #optimize(int,MultivariateVectorFunction,double[],double[],double[])
092     * optimize} method returns the best point only. This method
093     * returns all the points found at the end of each starts, including
094     * the best one already returned by the {@link
095     * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method.
096     * <br/>
097     * The returned array as one element for each start as specified
098     * in the constructor. It is ordered with the results from the
099     * runs that did converge first, sorted from best to worst
100     * objective value (i.e. in ascending order if minimizing and in
101     * descending order if maximizing), followed by and null elements
102     * corresponding to the runs that did not converge. This means all
103     * elements will be null if the {@link
104     * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} method did
105     * throw a {@link ConvergenceException}). This also means that if
106     * the first element is not {@code null}, it is the best point found
107     * across all starts.
108     *
109     * @return array containing the optima
110     * @throws MathIllegalStateException if {@link
111     * #optimize(int,MultivariateVectorFunction,double[],double[],double[]) optimize} has not been
112     * called.
113     */
114    public PointVectorValuePair[] getOptima() {
115        if (optima == null) {
116            throw new MathIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET);
117        }
118        return optima.clone();
119    }
120
121    /** {@inheritDoc} */
122    public int getMaxEvaluations() {
123        return maxEvaluations;
124    }
125
126    /** {@inheritDoc} */
127    public int getEvaluations() {
128        return totalEvaluations;
129    }
130
131    /** {@inheritDoc} */
132    public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() {
133        return optimizer.getConvergenceChecker();
134    }
135
136    /**
137     * {@inheritDoc}
138     */
139    public PointVectorValuePair optimize(int maxEval, final FUNC f,
140                                            double[] target, double[] weights,
141                                            double[] startPoint) {
142        maxEvaluations = maxEval;
143        RuntimeException lastException = null;
144        optima = new PointVectorValuePair[starts];
145        totalEvaluations = 0;
146
147        // Multi-start loop.
148        for (int i = 0; i < starts; ++i) {
149
150            // CHECKSTYLE: stop IllegalCatch
151            try {
152                optima[i] = optimizer.optimize(maxEval - totalEvaluations, f, target, weights,
153                                               i == 0 ? startPoint : generator.nextVector());
154            } catch (ConvergenceException oe) {
155                optima[i] = null;
156            } catch (RuntimeException mue) {
157                lastException = mue;
158                optima[i] = null;
159            }
160            // CHECKSTYLE: resume IllegalCatch
161
162            totalEvaluations += optimizer.getEvaluations();
163        }
164
165        sortPairs(target, weights);
166
167        if (optima[0] == null) {
168            throw lastException; // cannot be null if starts >=1
169        }
170
171        // Return the found point given the best objective function value.
172        return optima[0];
173    }
174
175    /**
176     * Sort the optima from best to worst, followed by {@code null} elements.
177     *
178     * @param target Target value for the objective functions at optimum.
179     * @param weights Weights for the least-squares cost computation.
180     */
181    private void sortPairs(final double[] target,
182                           final double[] weights) {
183        Arrays.sort(optima, new Comparator<PointVectorValuePair>() {
184                /** {@inheritDoc} */
185                public int compare(final PointVectorValuePair o1,
186                                   final PointVectorValuePair o2) {
187                    if (o1 == null) {
188                        return (o2 == null) ? 0 : 1;
189                    } else if (o2 == null) {
190                        return -1;
191                    }
192                    return Double.compare(weightedResidual(o1), weightedResidual(o2));
193                }
194                private double weightedResidual(final PointVectorValuePair pv) {
195                    final double[] value = pv.getValueRef();
196                    double sum = 0;
197                    for (int i = 0; i < value.length; ++i) {
198                        final double ri = value[i] - target[i];
199                        sum += weights[i] * ri * ri;
200                    }
201                    return sum;
202                }
203            });
204    }
205}