001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.commons.math3.ml.neuralnet.twod.util; 019 020import org.apache.commons.math3.ml.neuralnet.MapUtils; 021import org.apache.commons.math3.ml.neuralnet.Neuron; 022import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; 023import org.apache.commons.math3.ml.distance.DistanceMeasure; 024 025/** 026 * Computes the quantization error histogram. 027 * Each bin will contain the average of the distances between samples 028 * mapped to the corresponding unit and the weight vector of that unit. 029 * @since 3.6 030 */ 031public class QuantizationError implements MapDataVisualization { 032 /** Distance. */ 033 private final DistanceMeasure distance; 034 035 /** 036 * @param distance Distance. 037 */ 038 public QuantizationError(DistanceMeasure distance) { 039 this.distance = distance; 040 } 041 042 /** {@inheritDoc} */ 043 public double[][] computeImage(NeuronSquareMesh2D map, 044 Iterable<double[]> data) { 045 final int nR = map.getNumberOfRows(); 046 final int nC = map.getNumberOfColumns(); 047 048 final LocationFinder finder = new LocationFinder(map); 049 050 // Hit bins. 051 final int[][] hit = new int[nR][nC]; 052 // Error bins. 053 final double[][] error = new double[nR][nC]; 054 055 for (double[] sample : data) { 056 final Neuron best = MapUtils.findBest(sample, map, distance); 057 058 final LocationFinder.Location loc = finder.getLocation(best); 059 final int row = loc.getRow(); 060 final int col = loc.getColumn(); 061 hit[row][col] += 1; 062 error[row][col] += distance.compute(sample, best.getFeatures()); 063 } 064 065 for (int r = 0; r < nR; r++) { 066 for (int c = 0; c < nC; c++) { 067 final int count = hit[r][c]; 068 if (count != 0) { 069 error[r][c] /= count; 070 } 071 } 072 } 073 074 return error; 075 } 076}