001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;
018
019import java.util.List;
020import java.util.ArrayList;
021import java.util.Comparator;
022import java.util.function.UnaryOperator;
023import java.util.function.DoublePredicate;
024
025import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
026import org.apache.commons.math4.legacy.optim.PointValuePair;
027
028/**
029 * <a href="https://scholarship.rice.edu/handle/1911/16304">Multi-directional</a> search method.
030 */
031public class MultiDirectionalTransform
032    implements Simplex.TransformFactory {
033    /** Reflection coefficient. */
034    private static final double ALPHA = 1;
035    /** Default value for {@link #gamma}: {@value}. */
036    private static final double DEFAULT_GAMMA = 2;
037    /** Default value for {@link #sigma}: {@value}. */
038    private static final double DEFAULT_SIGMA = 0.5;
039    /** Expansion coefficient. */
040    private final double gamma;
041    /** Contraction coefficient. */
042    private final double sigma;
043
044    /**
045     * @param gamma Expansion coefficient.
046     * @param sigma Shrinkage coefficient.
047     */
048    public MultiDirectionalTransform(double gamma,
049                                     double sigma) {
050        if (gamma < 1) {
051            throw new IllegalArgumentException("gamma: " + gamma);
052        }
053        if (sigma < 0 ||
054            sigma > 1) {
055            throw new IllegalArgumentException("sigma: " + sigma);
056        }
057
058        this.gamma = gamma;
059        this.sigma = sigma;
060    }
061
062    /**
063     * Transform with default values.
064     */
065    public MultiDirectionalTransform() {
066        this(DEFAULT_GAMMA,
067             DEFAULT_SIGMA);
068    }
069
070    /** {@inheritDoc} */
071    @Override
072    public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction,
073                                         final Comparator<PointValuePair> comparator,
074                                         final DoublePredicate sa) {
075        return original -> {
076            final PointValuePair best = original.get(0);
077
078            // Perform a reflection step.
079            final Simplex reflectedSimplex = transform(original,
080                                                       ALPHA,
081                                                       comparator,
082                                                       evaluationFunction);
083            final PointValuePair reflectedBest = reflectedSimplex.get(0);
084
085            if (comparator.compare(reflectedBest, best) < 0) {
086                // Compute the expanded simplex.
087                final Simplex expandedSimplex = transform(original,
088                                                          gamma,
089                                                          comparator,
090                                                          evaluationFunction);
091                final PointValuePair expandedBest = expandedSimplex.get(0);
092
093                if (comparator.compare(expandedBest, reflectedBest) <= 0 ||
094                    (sa != null &&
095                     sa.test(expandedBest.getValue() - reflectedBest.getValue()))) {
096                    return expandedSimplex;
097                } else {
098                    return reflectedSimplex;
099                }
100            } else {
101                // Compute the contracted simplex.
102                return original.shrink(sigma, evaluationFunction);
103            }
104        };
105    }
106
107    /**
108     * Computes and evaluates a new simplex.
109     *
110     * @param original Original simplex.
111     * @param coeff Linear coefficient.
112     * @param comp Fitness comparator.
113     * @param evalFunc Objective function.
114     * @return the transformed simplex.
115     * @throws org.apache.commons.math4.legacy.exception.TooManyEvaluationsException
116     * if the maximal number of evaluations is exceeded.
117     */
118    private Simplex transform(Simplex original,
119                              double coeff,
120                              Comparator<PointValuePair> comp,
121                              MultivariateFunction evalFunc) {
122        // Transformed simplex is the result a linear transformation on all
123        // points except the first one.
124        final int replSize = original.getSize() - 1;
125        final List<PointValuePair> replacement = new ArrayList<>();
126        final double[] bestPoint = original.get(0).getPoint();
127        for (int i = 0; i < replSize; i++) {
128            replacement.add(Simplex.newPoint(bestPoint,
129                                             -coeff,
130                                             original.get(i + 1).getPoint(),
131                                             evalFunc));
132        }
133
134        return original.replaceLast(replacement).evaluate(evalFunc, comp);
135    }
136
137    /** {@inheritDoc} */
138    @Override
139    public String toString() {
140        return "Multidirectional [g=" + gamma +
141            " s=" + sigma + "]";
142    }
143}