aboutsummaryrefslogtreecommitdiff
path: root/pix_linNN/LinNeuralNet.cpp
blob: 87318a092d33b43b123d1cfa0ca4cedf6b60c4ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
/////////////////////////////////////////////////////////////////////////////
//
// class LinNeuralNet
//
//   source file
//
//   Copyright (c) 2004 Georg Holzmann <grh@gmx.at>
//
//   For information on usage and redistribution, and for a DISCLAIMER OF ALL
//   WARRANTIES, see the file, "GEM.LICENSE.TERMS" in this distribution.
//
/////////////////////////////////////////////////////////////////////////////

#include "LinNeuralNet.h"

//--------------------------------------------------
/* Constructor
 * initializes all members; learn rate starts at 0,
 * the mapping range at 1 and the weights unallocated
 */
LinNeuralNet::LinNeuralNet(int netsize) : learn_rate_(0), range_(1), IW_(NULL), b1_(0)
{
  // clamp the network size to at least one neuron
  netsize_ = netsize;
  if(netsize_ < 1)
    netsize_ = 1;

  // seed the RNG so initNetworkRand() gives fresh weights per run
  srand( (unsigned)time(NULL) );
}

//--------------------------------------------------
/* Destructor
 * frees the weight matrix
 */
LinNeuralNet::~LinNeuralNet()
{
  // delete[] on a NULL pointer is a well-defined no-op,
  // so no guard is required here
  delete[] IW_;
}

//--------------------------------------------------
/* creates a new IW-matrix (size: netsize_) and 
 * b1-vector
 * ATTENTION: if they exist they'll be deleted
 * returns false if the allocation fails
 */
bool LinNeuralNet::createNeurons()
{
  // delete old weights if they exist and reset the pointer,
  // so the destructor can't double-delete a dangling pointer
  // in case the allocation below fails
  if(IW_)
    {
      delete[] IW_;
      IW_ = NULL;
    }

  // use nothrow new: plain new[] throws std::bad_alloc instead
  // of returning NULL, which made the failure check dead code
  IW_ = new(std::nothrow) float[netsize_];
  if(!IW_)
    return false;

  return true;
}

//--------------------------------------------------
/* inits the weight matrix and the bias vector of
 * the network with random values between [min|max]
 */
bool LinNeuralNet::initNetworkRand(const int &min, const int &max)
{
  // the weight matrix must exist first
  if(!IW_)
    return false;

  // draw uniform values in [0,1], then scale them into [min|max]
  const float span = (float)(max - min);

  b1_ = ((float)rand()/(float)RAND_MAX) * span + min;

  int i = 0;
  while(i < netsize_)
    {
      IW_[i] = ((float)rand()/(float)RAND_MAX) * span + min;
      i++;
    }

  return true;
}

//--------------------------------------------------
/* inits the net with a given weight matrix and bias
 * (makes a deep copy)
 * ATTENTION: the dimension of IW-pointer must be the same
 *            as the netsize !!!
 * returns false if there's a failure (no net created yet,
 * or no source weights given)
 */
bool LinNeuralNet::initNetwork(const float *IW, float b1)
{
  // guard against an uncreated net AND a NULL source pointer,
  // which the original code would have dereferenced blindly
  if(!IW_ || !IW)
    return false;

  b1_ = b1;

  // deep copy of the weight matrix
  for(int i=0; i<netsize_; i++)
      IW_[i] = IW[i];

  return true;
}

//--------------------------------------------------
/* calculates the output with the current IW, b1 values
 * ATTENTION: the array input_data must be in the same
 *            size as netsize_
 * returns 0 if the net wasn't created or no input is given
 */
float LinNeuralNet::calculateNet(float *input_data)
{
  // a missing net or a NULL input buffer yields a neutral output
  // (the original dereferenced input_data unchecked)
  if(!IW_ || !input_data)
    return 0;

  float output = 0;

  // multiply the inputs with the weight matrix IW
  // and add the bias vector b1
  for(int i=0; i<netsize_; i++)
      output += input_data[i] * IW_[i];

  // map input values to the range
  output /= range_;

  return (output+b1_);
}

//--------------------------------------------------
/* this method trains the network:
 * input_data is, as above, the input data, output_data is the 
 * output of the current net with input_data (output_data is not
 * calculated in that method !), target_output is the desired
 * output data
 * (this is the LMS-algorithm to train linear neural networks)
 * ATTENTION: the array input_data must be in the same
 *            size as netsize_
 * returns false if the net wasn't created or no input is given
 */
bool LinNeuralNet::trainNet(float *input_data, const float &output_data, 
			    const float &target_output)
{
  // a missing net or a NULL input buffer can't be trained
  // (the original dereferenced input_data unchecked)
  if(!IW_ || !input_data)
    return false;

  // this is the LMS-algorithm to train linear
  // neural networks

  // calculate the error signal:
  float error = (target_output - output_data);

  // now change the weights and the bias
  // (inputs are scaled by range_, matching calculateNet)
  for(int i=0; i<netsize_; i++)
    IW_[i] += 2 * learn_rate_ * error * (input_data[i]/range_);

  b1_ += 2 * learn_rate_ * error; 

  return true;
}