aboutsummaryrefslogtreecommitdiff
path: root/pix_recNN/RecurrentNeuron.h
blob: ee870686aacf41adfd2ffe66a22a951715f3aa79 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
/////////////////////////////////////////////////////////////////////////////
//
// class RecurrentNeuron
//
//   this is an implementation of one neuron of a Recurrent Neural Network
//   this neuron can have n input values, m values in its memory and
//   one output value
//   (see NeuralNet documentations for more information)
//
//   header file
//
//   Copyright (c) 2005 Georg Holzmann <grh@gmx.at>
//
//   This program is free software; you can redistribute it and/or
//   modify it under the terms of the GNU General Public License
//   as published by the Free Software Foundation; either version 2
//   of the License, or (at your option) any later version.
//
/////////////////////////////////////////////////////////////////////////////


#ifndef _INCLUDE_RECURRENT_NEURON_NET__
#define _INCLUDE_RECURRENT_NEURON_NET__

#include <stdlib.h>
#include <stdexcept>
#include "Neuron.h"

namespace TheBrain
{

//------------------------------------------------------
/* class of one neuron of a recurrent network:
 * besides the input weights handled by Neuron it holds a
 * second weight vector (LW_) for m memorized output values
 * that are fed back as additional input
 */
class RecurrentNeuron : public Neuron
{
 protected:

  /* this determines how many output values the net
   * can remember
   * these values are fed back as new input
   */
  int memory_;

  /* the weight matrix for the recurrent
   * values (size: memory_)
   */
  float *LW_;


 public:

  /* Constructor
   * inputs: number of input values
   * memory: number of fed-back output values
   * NOTE(review): definition not visible here; presumably inputs
   * goes to Neuron and memory to memory_ -- confirm
   */
  RecurrentNeuron(int inputs, int memory);

  /* Destructor
   */
  virtual ~RecurrentNeuron();


  //-----------------------------------------------------
  /* some more get/set methods
   */

  /* returns the number of memorized (fed-back) values
   */
  virtual int getMemory() const
  {  return memory_; }

  /* returns the internal recurrent weight array (memory_ entries);
   * the pointer is not owned by the caller and must not be freed
   */
  virtual float *getLW() const
  {  return LW_; }

  /* returns one recurrent weight
   * ATTENTION: index must be in [0, memory_), it is not checked
   */
  virtual float getLW(int index) const
  {  return LW_[index]; }

  /* copies the recurrent weights from LW (deep copy)
   * the loop bound is memory_ because that is the size of LW_;
   * using inputs_ here would overrun LW_ whenever inputs_ > memory_
   * ATTENTION: LW must point to at least memory_ floats
   */
  virtual void setLW(const float *LW)
  {  for(int i=0; i<memory_; i++) LW_[i] = LW[i]; }

  /* sets one recurrent weight
   * ATTENTION: index must be in [0, memory_), it is not checked
   */
  virtual void setLW(int index, float value)
  {  LW_[index] = value; }


  //-----------------------------------------------------

  /* creates a new IW-matrix (size: inputs_) and
   * b1-vector
   * NOTE(review): presumably also (re)allocates LW_ (size: memory_)
   * and the memory storage -- confirm in the implementation
   * ATTENTION: if they exist they'll be deleted
   *
   * NOTE(review): dynamic exception specifications like
   * throw(NNExcept) are ill-formed since C++17; they have to be
   * removed together with the definitions and Neuron's declarations
   */
  virtual void create()
    throw(NNExcept);

  /* inits the weight matrix and the bias vector of
   * the network with random values between [min|max]
   * NOTE(review): whether LW_ is randomized too is not visible here
   */
  virtual void initRand(const int &min, const int &max)
    throw(NNExcept);

  /* inits the net with given weight matrix and bias
   * (makes a deep copy)
   * ATTENTION: IW must hold inputs_ values; LW presumably must hold
   *            memory_ values (the size of LW_) -- the original note
   *            said "same as the inputs (also for LW)", so confirm
   *            against the implementation
   */
  virtual void init(const float *IW, const float *LW, float b1)
    throw(NNExcept);

  /* calculates the output with the current IW, b1 values
   * NOTE(review): presumably also uses LW_ and the stored memory
   * values -- confirm in the implementation
   * ATTENTION: the array input_data must be in the same
   *            size as inputs_
   */
  virtual float calculate(float *input_data);

  /* this method trains the network:
   * input_data is, as above, the input data, output_data is the 
   * output of the current net with input_data (output_data is not
   * calculated in that method !), target_output is the desired
   * output data
   * (this is the LMS-algorithm to train linear neural networks)
   * ATTENTION: the array input_data must be in the same
   *            size as inputs_
   * returns the calculated output
   */
/*   virtual float trainLMS(const float *input_data,  */
/* 			 const float &target_output); */


  //-----------------------------------------------------
 private:

  /* the storage for the memory data
   */
  float *mem_data_;

  /* this index is used to make something
   * like a simple list or ringbuffer
   */
  int index_;

  /* Copy Construction is not allowed
   * (declared but intentionally never defined: any use, even from
   *  inside the class, fails at compile or link time instead of
   *  producing an object with uninitialized LW_/mem_data_ pointers)
   */
  RecurrentNeuron(const RecurrentNeuron &src);

  /* assignment operator is not allowed
   * (declared but intentionally never defined, see above)
   */
  RecurrentNeuron& operator= (const RecurrentNeuron &src);
};


} // end of namespace

#endif //_INCLUDE_RECURRENT_NEURON_NET__