HedarFukushimaTransform.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.apache.commons.math4.legacy.optim.nonlinear.scalar.noderiv;

  18. import java.util.Comparator;
  19. import java.util.List;
  20. import java.util.ArrayList;
  21. import java.util.Collections;
  22. import java.util.function.UnaryOperator;
  23. import java.util.function.DoublePredicate;
  24. import org.apache.commons.rng.UniformRandomProvider;
  25. import org.apache.commons.rng.simple.RandomSource;
  26. import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
  27. import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
  28. import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
  29. import org.apache.commons.math4.legacy.optim.PointValuePair;

  30. /**
  31.  * DSSA algorithm.
  32.  *
  33.  * Described in
  34.  * <blockquote>
  35.  *  <em>Abdel-Rahman Hedar and Masao Fukushima (2002)</em>,
  36.  *  <b>
  37.  *   Hybrid simulated annealing and direct search method
  38.  *   for nonlinear unconstrained global optimization
  39.  *  </b>,
  40.  *  Optimization Methods and Software, 17:5, 891-912,
  41.  *  DOI: 10.1080/1055678021000030084
  42.  * </blockquote>
  43.  *
  44.  * <p>
  45.  * A note about the {@link #HedarFukushimaTransform(double) "shrink" factor}:
  46.  * Per DSSA's description, the simplex must keep its size during the simulated
  47.  * annealing (SA) phase to avoid premature convergence.  This assumes that the
  48.  * best candidates from the SA phase will each subsequently serve as starting
  49.  * point for another optimization to hone in on the local optimum.
  50.  * Values lower than 1 and no subsequent "best list" search correspond to the
  51.  * "SSA" algorithm in the above paper.
  52.  */
  53. public class HedarFukushimaTransform
  54.     implements Simplex.TransformFactory {
  55.     /** Shrinkage coefficient. */
  56.     private final double sigma;
  57.     /** Sampler for reflection coefficient. */
  58.     private final ContinuousSampler alphaSampler;
  59.     /** No shrink indicator. */
  60.     private final boolean noShrink;

  61.     /**
  62.      * @param sigma Shrink factor.
  63.      * @param rng Random generator.
  64.      * @throws IllegalArgumentException if {@code sigma <= 0} or
  65.      * {@code sigma > 1}.
  66.      */
  67.     public HedarFukushimaTransform(double sigma,
  68.                                    UniformRandomProvider rng) {
  69.         if (sigma <= 0 ||
  70.             sigma > 1) {
  71.             throw new IllegalArgumentException("Shrink factor out of range: " +
  72.                                                sigma);
  73.         }

  74.         this.sigma = sigma;
  75.         alphaSampler = ContinuousUniformSampler.of(rng, 0.9, 1.1);
  76.         noShrink = sigma == 1d;
  77.     }

  78.     /**
  79.      * @param sigma Shrink factor.
  80.      * @throws IllegalArgumentException if {@code sigma <= 0} or
  81.      * {@code sigma > 1}.
  82.      */
  83.     public HedarFukushimaTransform(double sigma) {
  84.         this(sigma, RandomSource.KISS.create());
  85.     }

  86.     /**
  87.      * Disable shrinking of the simplex (as mandated by DSSA).
  88.      */
  89.     public HedarFukushimaTransform() {
  90.         this(1d);
  91.     }

  92.     /** {@inheritDoc} */
  93.     @Override
  94.     public UnaryOperator<Simplex> create(final MultivariateFunction evaluationFunction,
  95.                                          final Comparator<PointValuePair> comparator,
  96.                                          final DoublePredicate saAcceptance) {
  97.         if (saAcceptance == null) {
  98.             throw new IllegalArgumentException("Missing SA acceptance test");
  99.         }

  100.         return original -> transform(original,
  101.                                      saAcceptance,
  102.                                      evaluationFunction,
  103.                                      comparator);
  104.     }

  105.     /**
  106.      * Simulated annealing step (at fixed temperature).
  107.      *
  108.      * @param original Current simplex.  Points must be sorted from best to worst.
  109.      * @param sa Simulated annealing acceptance test.
  110.      * @param eval Evaluation function.
  111.      * @param comp Objective function comparator.
  112.      * @return a new instance.
  113.      */
  114.     private Simplex transform(Simplex original,
  115.                               DoublePredicate sa,
  116.                               MultivariateFunction eval,
  117.                               Comparator<PointValuePair> comp) {
  118.         final int size = original.getSize();
  119.         // Current best point.
  120.         final PointValuePair best = original.get(0);
  121.         final double bestValue = best.getValue();

  122.         for (int k = 1; k < size; k++) {
  123.             // Perform reflections of the "k" worst points.
  124.             final List<PointValuePair> reflected = reflectPoints(original, k, eval);
  125.             Collections.sort(reflected, comp);

  126.             // Check whether the best of the reflected points is better than the
  127.             // current overall best.
  128.             final PointValuePair candidate = reflected.get(0);
  129.             final boolean candidateIsBetter = comp.compare(candidate, best) < 0;
  130.             final boolean candidateIsAccepted = candidateIsBetter ||
  131.                 sa.test(candidate.getValue() - bestValue);

  132.             if (candidateIsAccepted) {
  133.                 // Replace worst points with the reflected points.
  134.                 return original.replaceLast(reflected);
  135.             }
  136.         }

  137.         // No direction provided a better point.
  138.         return noShrink ?
  139.             original :
  140.             original.shrink(sigma, eval);
  141.     }

  142.     /**
  143.      * @param simplex Current simplex (whose points must be sorted, from best
  144.      * to worst).
  145.      * @param nPoints Number of points to reflect.
  146.      * The {@code nPoints} worst points will be reflected through the centroid
  147.      * of the {@code n + 1 - nPoints} best points.
  148.      * @param eval Evaluation function.
  149.      * @return the (unsorted) list of reflected points.
  150.      * @throws IllegalArgumentException if {@code nPoints < 1} or
  151.      * {@code nPoints > n}.
  152.      */
  153.     private List<PointValuePair> reflectPoints(Simplex simplex,
  154.                                                int nPoints,
  155.                                                MultivariateFunction eval) {
  156.         final int size = simplex.getSize();
  157.         if (nPoints < 1 ||
  158.             nPoints >= size) {
  159.             throw new IllegalArgumentException("Out of range: " + nPoints);
  160.         }

  161.         final int nCentroid = size - nPoints;
  162.         final List<PointValuePair> centroidList = simplex.asList(0, nCentroid);
  163.         final List<PointValuePair> reflectList = simplex.asList(nCentroid, size);

  164.         final double[] centroid = Simplex.centroid(centroidList);

  165.         final List<PointValuePair> reflected = new ArrayList<>(nPoints);
  166.         for (int i = 0; i < reflectList.size(); i++) {
  167.             reflected.add(newReflectedPoint(reflectList.get(i),
  168.                                             centroid,
  169.                                             eval));
  170.         }

  171.         return reflected;
  172.     }

  173.     /**
  174.      * @param point Current point.
  175.      * @param centroid Coordinates through which reflection must be performed.
  176.      * @param eval Evaluation function.
  177.      * @return a new point with Cartesian coordinates set to the reflection
  178.      * of {@code point} through {@code centroid}.
  179.      */
  180.     private PointValuePair newReflectedPoint(PointValuePair point,
  181.                                              double[] centroid,
  182.                                              MultivariateFunction eval) {
  183.         final double alpha = alphaSampler.sample();
  184.         return Simplex.newPoint(centroid,
  185.                                 -alpha,
  186.                                 point.getPoint(),
  187.                                 eval);
  188.     }

  189.     /** {@inheritDoc} */
  190.     @Override
  191.     public String toString() {
  192.         return "Hedar-Fukushima [s=" + sigma + "]";
  193.     }
  194. }