backpropagationalgo.h
1 /********************************************************************************
2  * Neural Network Framework. *
3  * Copyright (C) 2005-2011 Gianluca Massera <emmegian@yahoo.it> *
4  * *
5  * This program is free software; you can redistribute it and/or modify *
6  * it under the terms of the GNU General Public License as published by *
7  * the Free Software Foundation; either version 2 of the License, or *
8  * (at your option) any later version. *
9  * *
10  * This program is distributed in the hope that it will be useful, *
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of *
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
13  * GNU General Public License for more details. *
14  * *
15  * You should have received a copy of the GNU General Public License *
16  * along with this program; if not, write to the Free Software *
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA *
18  ********************************************************************************/
19 
20 #ifndef BACKPROPAGATIONALGO_H
21 #define BACKPROPAGATIONALGO_H
22 
23 #include "nnfwconfig.h"
24 #include "learningalgorithm.h"
25 #include "biasedcluster.h"
26 #include "matrixlinker.h"
27 #include <QMap>
28 #include <QVector>
29 
30 namespace farsa {
31 
35 class FARSA_NNFW_API BackPropagationAlgo : public LearningAlgorithm {
36 public:
43  BackPropagationAlgo( NeuralNet *n_n, UpdatableList update_order, double l_r = 0.1 );
46 
49 
54  void setUpdateOrder( const UpdatableList& update_order );
55 
57  UpdatableList updateOrder() const {
58  return update_order;
59  };
63  void setTeachingInput( Cluster* output, const DoubleVector& ti );
64 
65  virtual void learn();
66 
68  virtual void learn( const Pattern& );
69 
71  virtual double calculateMSE( const Pattern& );
72 
74  void setRate( double newrate ) {
75  learn_rate = newrate;
76  };
77 
79  double rate() const {
80  return learn_rate;
81  };
82 
84  void setMomentum( double newmom ) {
85  momentumv = newmom;
86  };
87 
89  double momentum() const {
90  return momentumv;
91  };
92 
94  void enableMomentum();
95 
97  void disableMomentum() {
98  useMomentum = false;
99  };
100 
122  DoubleVector getError( Cluster* );
166  virtual void configure(ConfigurationParameters& params, QString prefix);
174  virtual void save(ConfigurationParameters& params, QString prefix);
176  static void describe( QString type );
177 protected:
179  virtual void neuralNetChanged();
180 private:
182  double learn_rate;
184  double momentumv;
186  double useMomentum;
188  UpdatableList update_order;
189 
191  class FARSA_NNFW_API cluster_deltas {
192  public:
193  BiasedCluster* cluster;
194  bool isOutput;
195  DoubleVector deltas_outputs;
196  DoubleVector deltas_inputs;
197  DoubleVector last_deltas_inputs;
198  QList<MatrixLinker*> incoming_linkers_vec;
199  QVector<DoubleVector> incoming_last_outputs;
200  };
202  QMap<Cluster*, int> mapIndex;
204  QVector<cluster_deltas> cluster_deltas_vec;
205  // --- propagate delta through the net
206  void propagDeltas();
207  // --- add a Cluster into the structures above
208  void addCluster( Cluster*, bool );
209  // --- add a Linker into the structures above
210  void addLinker( Linker* );
211 
212 };
213 
214 }
215 
216 #endif
217 
This file contains the common type definitions used throughout the whole framework.
void disableMomentum()
Disable momentum.
double momentum() const
return the momentum
Back-Propagation Algorithm implementation.
In a BiasedCluster each neuron has an input, an output and a bias value.
Definition: biasedcluster.h:41
UpdatableList updateOrder() const
Return the order in which the error is backpropagated through the NeuralNet.
double rate() const
return the learning rate
Define the common interface among Clusters.
Definition: cluster.h:73
LearningAlgorithm object.
void setMomentum(double newmom)
Set the momentum value.
The Neural Network Class.
Definition: neuralnet.h:221
Pattern object.
void setRate(double newrate)
Set the learning rate.
This file contains the declaration of BiasedCluster class.