diff options
Diffstat (limited to 'src/ann_mlp.c')
-rwxr-xr-x | src/ann_mlp.c | 124 |
1 files changed, 62 insertions, 62 deletions
diff --git a/src/ann_mlp.c b/src/ann_mlp.c index e83aea3..e01b203 100755 --- a/src/ann_mlp.c +++ b/src/ann_mlp.c @@ -30,8 +30,8 @@ typedef struct _ann_mlp { int mode; // 0 = training, 1 = running
t_symbol *filename; // name of the file where this ann is saved
t_symbol *filenametrain; // name of the file with training data
- float desired_error; - unsigned int max_iterations; + float desired_error;
+ unsigned int max_iterations;
unsigned int iterations_between_reports;
t_outlet *l_out, *f_out;
} t_ann_mlp;
@@ -49,48 +49,48 @@ void help(t_ann_mlp *x) }
void createFann(t_ann_mlp *x, t_symbol *sl, int argc, t_atom *argv)
-{ - unsigned int num_input = 2; - unsigned int num_output = 1; - unsigned int num_layers = 3; - unsigned int num_neurons_hidden = 3; - float connection_rate = 1; - float learning_rate = (float)0.7; - - if (argc>0) - num_input = atom_getint(argv++); - - if (argc>1) - num_output = atom_getint(argv++); - - if (argc>2) - num_layers = atom_getint(argv++); - - if (argc>3) - num_neurons_hidden = atom_getint(argv++); - - if (argc>4) - connection_rate = atom_getfloat(argv++); - - if (argc>5) - learning_rate = atom_getfloat(argv++); - - if (num_input>=MAXINPUT) - { - error("too many inputs, maximum allowed is MAXINPUT"); - return; - } - - if (num_output>=MAXOUTPUT) - { - error("too many outputs, maximum allowed is MAXOUTPUT"); - return; - } - - x->ann = fann_create(connection_rate, learning_rate, num_layers, - num_input, num_neurons_hidden, num_output); - - fann_set_activation_function_hidden(x->ann, FANN_SIGMOID_SYMMETRIC); +{
+ unsigned int num_input = 2;
+ unsigned int num_output = 1;
+ unsigned int num_layers = 3;
+ unsigned int num_neurons_hidden = 3;
+ float connection_rate = 1;
+ float learning_rate = (float)0.7;
+
+ if (argc>0)
+ num_input = atom_getint(argv++);
+
+ if (argc>1)
+ num_output = atom_getint(argv++);
+
+ if (argc>2)
+ num_layers = atom_getint(argv++);
+
+ if (argc>3)
+ num_neurons_hidden = atom_getint(argv++);
+
+ if (argc>4)
+ connection_rate = atom_getfloat(argv++);
+
+ if (argc>5)
+ learning_rate = atom_getfloat(argv++);
+
+ if (num_input>=MAXINPUT)
+ {
+ error("too many inputs, maximum allowed is MAXINPUT");
+ return;
+ }
+
+ if (num_output>=MAXOUTPUT)
+ {
+ error("too many outputs, maximum allowed is MAXOUTPUT");
+ return;
+ }
+
+ x->ann = fann_create(connection_rate, learning_rate, num_layers,
+ num_input, num_neurons_hidden, num_output);
+
+ fann_set_activation_function_hidden(x->ann, FANN_SIGMOID_SYMMETRIC);
fann_set_activation_function_output(x->ann, FANN_SIGMOID_SYMMETRIC);
if (x->ann == 0)
@@ -150,25 +150,25 @@ void set_mode(t_ann_mlp *x, t_symbol *sl, int argc, t_atom *argv) void train_on_file(t_ann_mlp *x, t_symbol *sl, int argc, t_atom *argv)
-{ +{
if (x->ann == 0)
{
error("ann not initialized");
return;
- } - - if (argc<1) - { - error("you must specify the filename with training data"); - return; - } else - { - x->filenametrain = atom_gensym(argv); - }
- - //post("nn: starting training on file %s, please be patient and wait for my next message (it could take severeal minutes to complete training)", x->filenametrain->s_name); - - fann_train_on_file(x->ann, x->filenametrain->s_name, x->max_iterations, + }
+
+ if (argc<1)
+ {
+ error("you must specify the filename with training data");
+ return;
+ } else
+ {
+ x->filenametrain = atom_gensym(argv);
+ }
+
+ //post("nn: starting training on file %s, please be patient and wait for my next message (it could take severeal minutes to complete training)", x->filenametrain->s_name);
+
+ fann_train_on_file(x->ann, x->filenametrain->s_name, x->max_iterations,
x->iterations_between_reports, x->desired_error);
post("nn: finished training on file %s", x->filenametrain->s_name);
@@ -496,8 +496,8 @@ void *nn_new(t_symbol *s, int argc, t_atom *argv) x->l_out = outlet_new(&x->x_obj, &s_list);
x->f_out = outlet_new(&x->x_obj, &s_float);
- x->desired_error = (float)0.001; - x->max_iterations = 500000; + x->desired_error = (float)0.001;
+ x->max_iterations = 500000;
x->iterations_between_reports = 1000;
x->mode=RUN;
@@ -562,4 +562,4 @@ void ann_mlp_setup(void) { class_sethelpsymbol(ann_mlp_class, gensym("help-ann_mlp"));
-}
\ No newline at end of file +}
|