Neuron.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. package org.apache.commons.math4.neuralnet;

  18. import java.util.concurrent.atomic.AtomicReference;
  19. import java.util.concurrent.atomic.AtomicLong;

  20. import org.apache.commons.numbers.core.Precision;
  21. import org.apache.commons.math4.neuralnet.internal.NeuralNetException;

  22. /**
  23.  * Describes a neuron element of a neural network.
  24.  *
  25.  * This class aims to be thread-safe.
  26.  *
  27.  * @since 3.3
  28.  */
  29. public class Neuron {
  30.     /** Identifier. */
  31.     private final long identifier;
  32.     /** Length of the feature set. */
  33.     private final int size;
  34.     /** Neuron data. */
  35.     private final AtomicReference<double[]> features;
  36.     /** Number of attempts to update a neuron. */
  37.     private final AtomicLong numberOfAttemptedUpdates = new AtomicLong(0);
  38.     /** Number of successful updates  of a neuron. */
  39.     private final AtomicLong numberOfSuccessfulUpdates = new AtomicLong(0);

  40.     /**
  41.      * Creates a neuron.
  42.      * The size of the feature set is fixed to the length of the given
  43.      * argument.
  44.      * <br>
  45.      * Constructor is package-private: Neurons must be
  46.      * {@link Network#createNeuron(double[]) created} by the network
  47.      * instance to which they will belong.
  48.      *
  49.      * @param identifier Identifier (assigned by the {@link Network}).
  50.      * @param features Initial values of the feature set.
  51.      */
  52.     Neuron(long identifier,
  53.            double[] features) {
  54.         this.identifier = identifier;
  55.         this.size = features.length;
  56.         this.features = new AtomicReference<>(features.clone());
  57.     }

  58.     /**
  59.      * Performs a deep copy of this instance.
  60.      * Upon return, the copied and original instances will be independent:
  61.      * Updating one will not affect the other.
  62.      *
  63.      * @return a new instance with the same state as this instance.
  64.      * @since 3.6
  65.      */
  66.     public synchronized Neuron copy() {
  67.         final Neuron copy = new Neuron(getIdentifier(),
  68.                                        getFeatures());
  69.         copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get());
  70.         copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get());

  71.         return copy;
  72.     }

  73.     /**
  74.      * Gets the neuron's identifier.
  75.      *
  76.      * @return the identifier.
  77.      */
  78.     public long getIdentifier() {
  79.         return identifier;
  80.     }

  81.     /**
  82.      * Gets the length of the feature set.
  83.      *
  84.      * @return the number of features.
  85.      */
  86.     public int getSize() {
  87.         return size;
  88.     }

  89.     /**
  90.      * Gets the neuron's features.
  91.      *
  92.      * @return a copy of the neuron's features.
  93.      */
  94.     public double[] getFeatures() {
  95.         return features.get().clone();
  96.     }

  97.     /**
  98.      * Tries to atomically update the neuron's features.
  99.      * Update will be performed only if the expected values match the
  100.      * current values.<br>
  101.      * In effect, when concurrent threads call this method, the state
  102.      * could be modified by one, so that it does not correspond to the
  103.      * the state assumed by another.
  104.      * Typically, a caller {@link #getFeatures() retrieves the current state},
  105.      * and uses it to compute the new state.
  106.      * During this computation, another thread might have done the same
  107.      * thing, and updated the state: If the current thread were to proceed
  108.      * with its own update, it would overwrite the new state (which might
  109.      * already have been used by yet other threads).
  110.      * To prevent this, the method does not perform the update when a
  111.      * concurrent modification has been detected, and returns {@code false}.
  112.      * When this happens, the caller should fetch the new current state,
  113.      * redo its computation, and call this method again.
  114.      *
  115.      * @param expect Current values of the features, as assumed by the caller.
  116.      * Update will never succeed if the contents of this array does not match
  117.      * the values returned by {@link #getFeatures()}.
  118.      * @param update Features's new values.
  119.      * @return {@code true} if the update was successful, {@code false}
  120.      * otherwise.
  121.      * @throws IllegalArgumentException if the length of {@code update} is
  122.      * not the same as specified in the {@link #Neuron(long,double[])
  123.      * constructor}.
  124.      */
  125.     public boolean compareAndSetFeatures(double[] expect,
  126.                                          double[] update) {
  127.         if (update.length != size) {
  128.             throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
  129.                                          update.length, size);
  130.         }

  131.         // Get the internal reference. Note that this must not be a copy;
  132.         // otherwise the "compareAndSet" below will always fail.
  133.         final double[] current = features.get();
  134.         if (!containSameValues(current, expect)) {
  135.             // Some other thread already modified the state.
  136.             return false;
  137.         }

  138.         // Increment attempt counter.
  139.         numberOfAttemptedUpdates.incrementAndGet();

  140.         if (features.compareAndSet(current, update.clone())) {
  141.             // The current thread could atomically update the state (attempt succeeded).
  142.             numberOfSuccessfulUpdates.incrementAndGet();
  143.             return true;
  144.         } else {
  145.             // Some other thread came first (attempt failed).
  146.             return false;
  147.         }
  148.     }

  149.     /**
  150.      * Retrieves the number of calls to the
  151.      * {@link #compareAndSetFeatures(double[],double[]) compareAndSetFeatures}
  152.      * method.
  153.      * Note that if the caller wants to use this method in combination with
  154.      * {@link #getNumberOfSuccessfulUpdates()}, additional synchronization
  155.      * may be required to ensure consistency.
  156.      *
  157.      * @return the number of update attempts.
  158.      * @since 3.6
  159.      */
  160.     public long getNumberOfAttemptedUpdates() {
  161.         return numberOfAttemptedUpdates.get();
  162.     }

  163.     /**
  164.      * Retrieves the number of successful calls to the
  165.      * {@link #compareAndSetFeatures(double[],double[]) compareAndSetFeatures}
  166.      * method.
  167.      * Note that if the caller wants to use this method in combination with
  168.      * {@link #getNumberOfAttemptedUpdates()}, additional synchronization
  169.      * may be required to ensure consistency.
  170.      *
  171.      * @return the number of successful updates.
  172.      * @since 3.6
  173.      */
  174.     public long getNumberOfSuccessfulUpdates() {
  175.         return numberOfSuccessfulUpdates.get();
  176.     }

  177.     /**
  178.      * Checks whether the contents of both arrays is the same.
  179.      *
  180.      * @param current Current values.
  181.      * @param expect Expected values.
  182.      * @throws IllegalArgumentException if the length of {@code expect}
  183.      * is not the same as specified in the {@link #Neuron(long,double[])
  184.      * constructor}.
  185.      * @return {@code true} if the arrays contain the same values.
  186.      */
  187.     private boolean containSameValues(double[] current,
  188.                                       double[] expect) {
  189.         if (expect.length != size) {
  190.             throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH,
  191.                                          expect.length, size);
  192.         }

  193.         for (int i = 0; i < size; i++) {
  194.             if (!Precision.equals(current[i], expect[i])) {
  195.                 return false;
  196.             }
  197.         }
  198.         return true;
  199.     }
  200. }