package uk.ac.leeds.ccg.projects.MedAction.NeuralNetwork; /** * A class for Neurons that keep a history of weights that can be adjusted in * training and their change control their processing and training. */ public class Neuron extends Object implements java.io.Serializable { // For storing the number of inputs private int inputNumber; // For storing the learning rate private double rate; // For storing a parameter of sigmoid function private double alfa; // For storing the weighted sum private double dot; // For storing the output private double fire; // For storing the error signal private double errorSignal; // For storing the weights private double[] weights; // For storing the old weights private double[] oldWeights; // For storing the difference in the weights private double[] weightsChange; // For storing the inputs private double[] inputs; /** * Creates new Neuron */ public Neuron( int inputNumber, double rate, double alfa ) { initialise( inputNumber, rate, alfa ); } /** * Initialises OutputNeuron * @param inputNumber * @param rate * @param alfa */ protected void initialise( int inputNumber, double rate, double alfa ) { setInputNumber( inputNumber ); setRate( rate ); setAlfa( alfa ); double[] weights = new double[ inputNumber + 1 ]; double[] oldWeights = new double[ inputNumber + 1 ]; double[] weightsChange = new double[ inputNumber + 1 ]; for( int i = 0; i < inputNumber + 1; i ++ ) { weights[ i ] = Math.random(); //oldWeights[ i ] = 0.0d; //weightsChange[ i ] = weight; } setWeights( weights ); setOldWeights( oldWeights ); setWeightsChange( weightsChange ); } /** * Adjusts weights by back propogation * @param lastLayerErrors * @param lastLayerWeights */ public void adjustWeightsInBackPropogationTraining( double[] lastLayerErrors, double[] lastLayerWeights ) { double[] weights = getWeights(); double[] oldWeights = getOldWeights(); double[] weightsChange = getWeightsChange(); double lastLayerErrorSum = 0.0d; for ( int i = 0; i < lastLayerErrors.length; i ++ ) { lastLayerErrorSum += lastLayerErrors[ i ] * lastLayerWeights[ i ]; } double fire = getFire(); double errorSignal = getAlfa() * lastLayerErrorSum * ( 1.0d - fire ) * fire; setErrorSignal( errorSignal ); double rate = getRate(); oldWeights[ 0 ] = weights[ 0 ]; // Why multiply by 0.09d? weights[ 0 ] += errorSignal * rate + weightsChange[ 0 ] * 0.9d; weightsChange[ 0 ] = weights[ 0 ] - oldWeights[ 0 ]; int inputNumber = getInputNumber(); double[] inputs = getInputs(); for ( int i = 0; i < inputNumber; i ++ ) { oldWeights[ i + 1 ] = weights[ i + 1 ]; weights[ i + 1 ] += errorSignal * rate * inputs[ i ] + weightsChange[ i + 1 ] * 0.9d; //weightsChange[ 0 ] = weight[ i + 1 ] - oldWeight[ i + 1 ]; weightsChange[ i + 1 ] = weights[ i + 1 ] - oldWeights[ i + 1 ]; } setWeights( weights ); setOldWeights( oldWeights ); setWeightsChange( weightsChange ); } /** * Process data * @param inputs */ public void classify( double[] inputs ) { setInputs( inputs ); double[] weights = getWeights(); double dot = weights[ 0 ]; for(int i = 0 ; i < inputNumber ; i ++ ) { dot += weights[ i + 1 ] * inputs[ i ]; } setDot( dot ); setFire( 1.0d / ( 1.0d + Math.exp( - 1.0d * getAlfa() * dot ) ) ); } public double getRate() { return this.rate; } public void setRate( double rate ) { this.rate = rate; } public int getInputNumber() { return this.inputNumber; } public void setInputNumber( int inputNumber ) { this.inputNumber = inputNumber; } public double getAlfa() { return this.alfa; } public void setAlfa( double alfa ) { this.alfa = alfa; } public double getDot() { return this.dot; } public void setDot( double dot ) { this.dot = dot; } public double getFire() { return this.fire; } public void setFire( double fire ) { this.fire = fire; } public double getErrorSignal() { return this.errorSignal; } public void setErrorSignal( double errorSignal ) { this.errorSignal = errorSignal; } public double[] getWeights() { return this.weights; } public void setWeights( double[] weights ) { this.weights = weights; } public double[] getOldWeights() { return this.oldWeights; } public void setOldWeights( double[] oldWeights ) { this.oldWeights = oldWeights; } public double[] getInputs() { return this.inputs; } public void setInputs( double[] inputs ) { this.inputs = inputs; } public double[] getWeightsChange() { return this.weightsChange; } public void setWeightsChange( double[] weightsChange ) { this.weightsChange = weightsChange; } /*public double[][] getOldWeight(){ double[][] temp=new double[inputNumber][1]; for(int i=0;i