aboutsummaryrefslogtreecommitdiff
path: root/knn/knn.c
diff options
context:
space:
mode:
Diffstat (limited to 'knn/knn.c')
-rw-r--r-- knn/knn.c 475
1 files changed, 475 insertions, 0 deletions
diff --git a/knn/knn.c b/knn/knn.c
new file mode 100644
index 0000000..592f982
--- /dev/null
+++ b/knn/knn.c
@@ -0,0 +1,475 @@
+/*
+ * Copyright (C) 2000 Ichiro Fujinaga and Karl MacMillan
+ *
+ * Minor updates and maintenance (2008) Jamie Bullock <jamie@postlude.co.uk>
+ *
+ *
+ * This program is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU General Public License
+ * as published by the Free Software Foundation; either version 2
+ * of the License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+ */
+
+/* This is a simple KNN object for PureData */
+
+#include <stdio.h>
+#include <math.h>
+#include <stdlib.h>
+#include <errno.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include "m_pd.h"
+
+
+/* Compile-time limits and helper macros shared by the whole file. */
+#undef N_FEATURES
+/* Length of every feature vector and of the weight/mean/stdev tables. */
+#define N_FEATURES 40 /* FIXME - this really needs to be in the file somehow */
+/* Size of the per-class vote table used by majority(); class ids index it. */
+#define MAX_N_CLASSES 100 /* number of instruments */
+/* Integer booleans used for the normalize and learn_mode flags. */
+#define TRUE 1
+#define FALSE 0
+
+/* Absolute value; evaluates its argument twice. Not referenced in the code
+ * visible in this file. */
+#define ABS(x) (((x) > 0) ? (x) : -(x))
+
+/* One slot of the nearest-neighbour list kept in t_knn.knn. */
+typedef struct _nn {
+ int index; /* position of the neighbour in the feature database */
+ float dist; /* its (squared, weighted) distance from the unknown */
+} NN;
+
+/* One database entry (or the unknown being classified). */
+typedef struct _feature {
+ unsigned long id; /* never assigned by the code visible in this file */
+ unsigned long symbol_id; /* class (instrument) number of this entry */
+ float *feature; /* heap array of N_FEATURES values */
+} FEATURE;
+
+static t_class *knn_class; /* PD class pointer; assigned in knn_setup(), used by knn_new() */
+
+/* Per-instance state of the knn PD object. */
+typedef struct _knn {
+ t_object x_obj; /* object pointer for PD */
+ t_outlet *out; /* this outputs the instrument class
+ * (a float) */
+ t_outlet *loo; /* this is used for debugging - it outputs a
+ * feature vec (see knn_loo) */
+ FEATURE *feature_db; /* the database of features; grown by
+ * add_feature, released by knn_free_db */
+ FEATURE unknown; /* the most recent input list (see knn_list) */
+ int db_size; /* number of entries in feature_db */
+ float mean[N_FEATURES]; /* per-feature mean, set by normalize_database */
+ float stdev[N_FEATURES]; /* per-feature stdev, set by normalize_database */
+ int k; /* number of neighbours; taken from the creation
+ * argument in knn_new */
+ NN *knn; /* array of nearest neighbors of length k */
+ float wt[N_FEATURES]; /* the feature weight vector (all 1 until
+ * knn_read_wt loads a file) */
+ int normalize; /* flag for normalization ("normal" message) */
+ int learn_mode; /* flag for learn (i.e. add) mode ("learn" message) */
+ int learn_symbol_id; /* the symbol id to add if in learn mode */
+} t_knn;
+
+/* Forward declarations of file-local functions defined further down. */
+static void *knn_new(t_floatarg f);
+/* Classification: neighbour search, neighbour list upkeep, voting. */
+static int find_nn(t_knn * x, FEATURE * unknown, FEATURE * database,
+ int size, float *weights, int k);
+static NN *k_nn_new(int k);
+static int majority(t_knn * x);
+static float k_nn(t_knn * x, int candidate, float dist);
+static float calc_dist(FEATURE * unknown, FEATURE * known, float *weights,
+ float min_dist);
+/* Database loading, normalization and growth. */
+static void get_timbre_features(t_knn * x, char *fname);
+static void normalize_database(t_knn * x, FEATURE * Features, int size);
+static void normalize(t_knn * x, FEATURE * f);
+static void add_feature(t_knn * x, float *features, int id);
+
+/*
+ * Constructor called by PD when a [knn] object is created.
+ *
+ * f is the optional creation argument giving k, the number of neighbours.
+ * A missing, zero, fractional (< 1) or negative value falls back to k = 1,
+ * because the neighbour list must hold at least one slot: k_nn() reads
+ * slot k - 1 unconditionally.
+ *
+ * Returns the new object with an empty database, normalization and learn
+ * mode switched off, unit weights, zero means and unit standard deviations.
+ */
+static void *knn_new(t_floatarg f)
+{
+    int i;
+    t_knn *x = (t_knn *) pd_new(knn_class);
+
+    x->out = outlet_new(&x->x_obj, &s_float);
+    x->loo = outlet_new(&x->x_obj, gensym("list"));
+
+    /* create the NN for the correct size of k (never smaller than 1) */
+    x->k = (int) f;
+    if (x->k < 1)
+        x->k = 1;               /* default */
+    x->knn = k_nn_new(x->k);
+
+    /* create zeroed space for the unknown feature */
+    x->unknown.id = 0;
+    x->unknown.symbol_id = 0;
+    x->unknown.feature = (float *) calloc(N_FEATURES, sizeof(float));
+    assert(x->unknown.feature != NULL);
+
+    x->db_size = 0;
+    x->feature_db = NULL;
+    x->normalize = FALSE;
+    x->learn_mode = FALSE;
+    x->learn_symbol_id = 0;     /* defined value until "learn" sets it */
+
+    for (i = 0; i < N_FEATURES; i++) {
+        x->wt[i] = 1;
+        x->mean[i] = 0;
+        x->stdev[i] = 1;
+    }
+    return (x);
+}
+
+/*
+ * Destructor registered with PD: releases everything the object owns.
+ *
+ * Frees each database entry's feature vector, the database array, the
+ * unknown's feature vector and the neighbour list. Pointers are cleared
+ * afterwards so that a stray second call is harmless.
+ */
+static void knn_free(t_knn * x)
+{
+    int i;
+
+    /* db_size is 0 whenever feature_db is NULL, so the loop is safe */
+    for (i = 0; i < x->db_size; i++)
+        free(x->feature_db[i].feature);
+    free(x->feature_db);
+    x->feature_db = NULL;
+    x->db_size = 0;
+
+    free(x->unknown.feature);
+    x->unknown.feature = NULL;
+
+    free(x->knn);
+    x->knn = NULL;
+}
+
+/*
+ * List handler: receives one feature vector from PD.
+ *
+ * The list is copied into x->unknown.feature. At most N_FEATURES elements
+ * are used; extra elements are ignored with a message, and missing ones are
+ * set to 0 so that values from a previous list cannot leak into this one.
+ * Non-float atoms and infinite values are stored as 0.
+ *
+ * In learn mode the vector is appended to the database under the current
+ * learn symbol id. Otherwise, if a database is loaded, the vector is
+ * (optionally) normalized, classified with find_nn() and the resulting
+ * class number is sent to the float outlet.
+ *
+ * s is the message selector supplied by PD and is not used.
+ */
+static void knn_list(t_knn * x, t_symbol * s, int argc, t_atom * argv)
+{
+    int i;
+    int used = argc;
+    int answer;
+
+    (void) s;
+
+    if (used > N_FEATURES) {
+        post("KNN: list has %d elements, only the first %d are used",
+             argc, N_FEATURES);
+        used = N_FEATURES;
+    }
+
+    /* read in the features from the PD list */
+    for (i = 0; i < used; i++) {
+        float value = 0;
+
+        if (argv[i].a_type == A_FLOAT) {
+            value = argv[i].a_w.w_float;
+            if (isinf(value))
+                value = 0;
+        } else {
+            /* second number is the atom type that was received */
+            post("KNN: error - wrong type for list elem %d %d", i,
+                 (int) argv[i].a_type);
+        }
+        x->unknown.feature[i] = value;
+    }
+    /* a short (or empty) list leaves the remaining features at zero */
+    for (i = (used > 0) ? used : 0; i < N_FEATURES; i++)
+        x->unknown.feature[i] = 0;
+
+    if (x->learn_mode == TRUE) {
+        add_feature(x, x->unknown.feature, x->learn_symbol_id);
+    } else if (x->feature_db != NULL) {
+        post("Looking for answer");
+        if (x->normalize == TRUE)
+            normalize(x, &x->unknown);
+        answer =
+            find_nn(x, &x->unknown, x->feature_db, x->db_size, x->wt, 1);
+        post("ANSWER: %d", answer);
+        outlet_float(x->out, answer);
+    } else {
+        post("No database loaded!");
+    }
+}
+
+/*
+ * Appends one entry to the feature database.
+ *
+ * features points to N_FEATURES values, which are copied; id becomes the
+ * entry's symbol_id (class number). On allocation failure a message is
+ * posted and the database is left exactly as it was.
+ *
+ * When the object is in learn mode a confirmation is posted for each
+ * entry added.
+ */
+static void add_feature(t_knn * x, float *features, int id)
+{
+    int i;
+    float *copy;
+    FEATURE *grown;
+
+    copy = (float *) malloc(sizeof(float) * N_FEATURES);
+    if (copy == NULL) {
+        post("KNN: out of memory - entry not added");
+        return;
+    }
+
+    /* realloc(NULL, n) acts as malloc, so the empty database needs no
+     * special case; keep the old pointer until the call has succeeded */
+    grown = (FEATURE *) realloc(x->feature_db,
+                                sizeof(FEATURE) * ((size_t) x->db_size + 1));
+    if (grown == NULL) {
+        free(copy);
+        post("KNN: out of memory - entry not added");
+        return;
+    }
+    x->feature_db = grown;
+
+    for (i = 0; i < N_FEATURES; i++)
+        copy[i] = features[i];
+
+    grown[x->db_size].id = 0;   /* unused here; avoid an indeterminate value */
+    grown[x->db_size].symbol_id = id;
+    grown[x->db_size].feature = copy;
+    x->db_size++;
+
+    if (x->learn_mode == TRUE)
+        post("Added entry for instrument %d - db now has %d entries.",
+             id, x->db_size);
+}
+
+/*
+ * "loo" message handler (debugging aid): outputs database entry number f.
+ *
+ * f is truncated to an integer index. If it names an existing entry, the
+ * entry's feature vector is sent as a list to the loo outlet and its class
+ * number as a float to the main outlet; otherwise "Index out of range" is
+ * posted.
+ */
+static void knn_loo(t_knn * x, float f)
+{
+    int i;
+    int idx = (int) f;
+    t_atom at[N_FEATURES];
+
+    if (idx >= 0 && idx <= x->db_size - 1) {
+        FEATURE *entry = &x->feature_db[idx];
+
+        for (i = 0; i < N_FEATURES; i++) {
+            SETFLOAT(&at[i], entry->feature[i]);
+        }
+        outlet_list(x->loo, 0, N_FEATURES, at);
+        outlet_float(x->out, entry->symbol_id);
+
+        /* symbol_id is unsigned long; cast so it matches %d */
+        post("KNN: output instrument %d", (int) entry->symbol_id);
+    } else {
+        post("Index out of range");
+    }
+}
+
+/*
+ * "normal" message handler: 0 switches normalization off, 1 switches it on.
+ * Any other value (after truncation to int) leaves the flag untouched.
+ * The resulting flag is always posted.
+ */
+static void knn_set_normal(t_knn * x, float f)
+{
+    switch ((int) f) {
+    case FALSE:
+        x->normalize = FALSE;
+        break;
+    case TRUE:
+        x->normalize = TRUE;
+        break;
+    default:
+        break;
+    }
+    post("set normalize %d", x->normalize);
+}
+
+/*
+ * "learn" message handler.
+ *
+ * mode == 1 enters learn mode and records symbol (truncated to int) as the
+ * class id given to subsequently received lists. Any other mode leaves
+ * learn mode; the stored symbol id is not touched in that case.
+ */
+static void knn_set_learn(t_knn * x, float mode, float symbol)
+{
+    if ((int) mode != TRUE) {
+        x->learn_mode = FALSE;
+        post("Learn mode deactivated");
+        return;
+    }
+
+    x->learn_symbol_id = (int) symbol;
+    x->learn_mode = TRUE;
+    post("Now in learn mode for instrument %d", (int) symbol);
+}
+
+/*
+ * "clear" message handler: discards the whole feature database.
+ * Every entry's feature vector is freed, then the array itself, and the
+ * object is left with an empty (NULL, size 0) database.
+ */
+static void knn_free_db(t_knn * x)
+{
+    int entry = 0;
+
+    while (entry < x->db_size) {
+        free(x->feature_db[entry].feature);
+        entry++;
+    }
+    free(x->feature_db);
+
+    x->feature_db = NULL;
+    x->db_size = 0;
+    post("Database cleared");
+}
+
+/*
+ * "save" message handler: writes the database to the text file named by s.
+ * One line per entry: the class id followed by N_FEATURES values, each
+ * preceded by a space. A failure to open the file is posted and nothing is
+ * written.
+ */
+static void knn_save(t_knn * x, t_symbol * s)
+{
+    const char *path = s->s_name;
+    FILE *out = fopen(path, "w");
+    int row, col;
+
+    if (out == NULL) {
+        post("Error saving file %s", strerror(errno));
+        return;
+    }
+
+    row = 0;
+    while (row < x->db_size) {
+        const FEATURE *entry = &x->feature_db[row];
+
+        fprintf(out, "%i", (int) entry->symbol_id);
+        for (col = 0; col < N_FEATURES; col++)
+            fprintf(out, " %g", entry->feature[col]);
+        fprintf(out, "\n");
+        row++;
+    }
+
+    fclose(out);
+    post("File saved to %s", path);
+}
+
+/*
+ * "read" message handler: replaces the database with the contents of the
+ * file named by s. An existing database is cleared first.
+ */
+static void knn_read(t_knn * x, t_symbol * s)
+{
+    const int has_db = (x->feature_db != NULL);
+
+    if (has_db) {
+        knn_free_db(x);
+    }
+    get_timbre_features(x, s->s_name);
+}
+
+/*
+ * "readweights" message handler: loads feature weights from the text file
+ * named by s into x->wt.
+ *
+ * Whitespace-separated numbers are read until the first item that does not
+ * parse or the end of the file. Weights not present in the file keep their
+ * previous value. If the file holds more than N_FEATURES numbers, the first
+ * N_FEATURES are stored, "too many weights!" is posted and reading stops.
+ */
+static void knn_read_wt(t_knn * x, t_symbol * s)
+{
+    FILE *fp;
+    float weight;
+    int count = 0;
+
+    fp = fopen(s->s_name, "r");
+    if (fp == NULL) {
+        post("Error reading weights %s", strerror(errno));
+        return;
+    }
+    printf("Reading weights from %s . . . ", s->s_name);
+
+    /* parse into a temporary so the bound is checked before x->wt is written */
+    while (fscanf(fp, "%g", &weight) == 1) {
+        if (count >= N_FEATURES) {
+            post("too many weights!");
+            fclose(fp);         /* do not leak the handle on this path */
+            return;
+        }
+        x->wt[count] = weight;
+        count++;
+    }
+    fclose(fp);                 /* JB-070505 - close the file handle */
+    printf("done\n");
+}
+
+/*
+ * Entry point called by PD when the external is loaded: creates the "knn"
+ * class (one optional float creation argument) and registers its message
+ * handlers - loo, normal, learn, save, read, clear, readweights - plus the
+ * list handler.
+ */
+void knn_setup(void)
+{
+    t_class *c;
+
+    c = class_new(gensym("knn"),
+                  (t_newmethod) knn_new,
+                  (t_method) knn_free,
+                  sizeof(t_knn), 0, A_DEFFLOAT, 0);
+    knn_class = c;
+
+    class_addmethod(c, (t_method) knn_loo,
+                    gensym("loo"), A_FLOAT, 0);
+    class_addmethod(c, (t_method) knn_set_normal,
+                    gensym("normal"), A_FLOAT, 0);
+    class_addmethod(c, (t_method) knn_set_learn,
+                    gensym("learn"), A_FLOAT, A_DEFFLOAT, 0);
+    class_addmethod(c, (t_method) knn_save,
+                    gensym("save"), A_SYMBOL, 0);
+    class_addmethod(c, (t_method) knn_read,
+                    gensym("read"), A_SYMBOL, 0);
+    class_addmethod(c, (t_method) knn_free_db,
+                    gensym("clear"), 0);
+    class_addmethod(c, (t_method) knn_read_wt,
+                    gensym("readweights"), A_SYMBOL, 0);
+
+    class_addlist(c, knn_list);
+}
+
+/*
+ * Classifies unknown against the database and returns the winning class.
+ *
+ * The neighbour list x->knn (x->k slots) is reset, every database entry is
+ * measured with calc_dist(), and entries closer than the current worst
+ * listed neighbour are inserted with k_nn(). The class is then chosen by
+ * majority().
+ *
+ * The size and k parameters are accepted for interface compatibility but
+ * not consulted: the loop runs over x->db_size entries and the list length
+ * is x->k, which is the length the list was allocated with.
+ *
+ * Empty slots start at FLT_MAX, the largest finite float, so that any
+ * finite distance can displace them.
+ */
+static int find_nn(t_knn * x, FEATURE * unknown, FEATURE * database,
+                   int size, float *weights, int k)
+{
+    int i;
+    float dist;
+    float cutoff = FLT_MAX;     /* distance of the worst listed neighbour */
+
+    (void) size;
+    (void) k;
+
+    for (i = 0; i < x->k; i++) {
+        x->knn[i].index = 0;
+        x->knn[i].dist = FLT_MAX;
+    }
+
+    for (i = 0; i < x->db_size; i++) {
+        /* calc_dist gives up early once the running sum reaches cutoff */
+        dist = calc_dist(unknown, &database[i], weights, cutoff);
+        if (dist < cutoff)
+            cutoff = k_nn(x, i, dist);
+    }
+
+    return majority(x);
+}
+
+/*
+ * Allocates a neighbour list with k slots. Every slot starts with index -1
+ * and distance FLT_MAX. Allocation failure trips an assert.
+ */
+static NN *k_nn_new(int k)
+{
+    NN *table = (NN *) malloc(sizeof(NN) * k);
+    int slot = 0;
+
+    assert(table != NULL);
+
+    while (slot < k) {
+        table[slot].dist = FLT_MAX;
+        table[slot].index = -1;
+        slot++;
+    }
+    return table;
+}
+
+
+/*
+ * Returns the class number with the most votes among the listed nearest
+ * neighbours. NB: no tie breaker - on a tie the lowest class number wins.
+ * Returns 0 when no vote could be counted.
+ *
+ * Only the first min(k, db_size) slots vote: k_nn() keeps the list sorted
+ * by ascending distance, so when the database holds fewer than k entries
+ * the trailing slots are unused placeholders and must not be counted.
+ * Slots whose index is outside the database, and entries whose class id
+ * does not fit the MAX_N_CLASSES vote table, are skipped instead of being
+ * used as array indices.
+ */
+static int majority(t_knn * x)
+{
+    int i, voters;
+    int max_i = 0, max = 0;
+    int classes[MAX_N_CLASSES];
+
+    for (i = 0; i < MAX_N_CLASSES; i++)
+        classes[i] = 0;
+
+    voters = (x->k < x->db_size) ? x->k : x->db_size;
+
+    for (i = 0; i < voters; i++) {
+        int idx = x->knn[i].index;
+        unsigned long cls;
+
+        if (idx < 0 || idx >= x->db_size)
+            continue;           /* slot does not name a database entry */
+
+        cls = x->feature_db[idx].symbol_id;
+        if (cls >= MAX_N_CLASSES) {
+            post("KNN: class id %lu exceeds the supported maximum %d - vote ignored",
+                 cls, MAX_N_CLASSES - 1);
+            continue;
+        }
+        classes[cls] += 1;
+    }
+
+    for (i = 0; i < MAX_N_CLASSES; i++) {
+        if (classes[i] > max) {
+            max = classes[i];
+            max_i = i;
+        }
+    }
+    return (max_i);
+}
+
+/*
+ * Offers one database entry to the sorted neighbour list.
+ *
+ * candidate is the entry's database index and dist its distance. If dist is
+ * smaller than some listed distance, the entry is inserted in front of the
+ * first such slot and the tail moves down by one, dropping the last slot.
+ * Returns the distance stored in the last slot after the update.
+ */
+static float k_nn(t_knn * x, int candidate, float dist)
+{
+    const int last = x->k - 1;
+    int pos = 0;
+    int j;
+
+    /* advance past every slot that is not farther than the candidate */
+    while (pos < x->k && !(dist < x->knn[pos].dist))
+        pos++;
+
+    if (pos <= last) {
+        /* open a gap at pos by shifting the tail one slot down */
+        for (j = last; j > pos; j--)
+            x->knn[j] = x->knn[j - 1];
+
+        x->knn[pos].dist = dist;
+        x->knn[pos].index = candidate;
+    }
+    return x->knn[last].dist;
+}
+
+/*
+ * Weighted squared Euclidean distance (no square root) between two feature
+ * vectors of N_FEATURES elements.
+ *
+ * Accumulation stops as soon as the running sum is no longer below
+ * min_dist, so the value returned for a rejected entry may be a partial sum
+ * that is at least min_dist.
+ */
+static float calc_dist(FEATURE * unknown, FEATURE * known, float *weights,
+                       float min_dist)
+{
+    float acc = 0;
+    int j = 0;
+
+    while (j < N_FEATURES && acc < min_dist) {
+        const float u = unknown->feature[j];
+        const float v = known->feature[j];
+
+        /* Euclidean w/o the sqrt */
+        acc += weights[j] * (u - v) * (u - v);
+        j++;
+    }
+    return acc;
+}
+
+/*
+ * Computes the per-feature mean and sample standard deviation of the first
+ * size entries of Features and stores them in x->mean and x->stdev. If
+ * x->normalize is set, each feature value is also replaced in place by
+ * (value - mean) / stdev.
+ *
+ * An empty database (size < 1) leaves x->mean and x->stdev untouched. With
+ * a single entry the variance is taken as 0. The standard deviation is
+ * never allowed below 0.00001, so the division is always defined.
+ */
+static void normalize_database(t_knn * x, FEATURE * Features, int size)
+{
+    int i, k;
+
+    if (size < 1)
+        return;
+
+    for (k = 0; k < N_FEATURES; k++) {
+        /* accumulate in double: the sums of squares lose precision quickly
+         * in float */
+        double sum = 0, sum2 = 0, var = 0;
+
+        for (i = 0; i < size; i++) {
+            double v = Features[i].feature[k];
+
+            sum += v;
+            sum2 += v * v;
+        }
+        x->mean[k] = (float) (sum / size);
+
+        if (size > 1) {
+            /* the denominator is computed in double to avoid int overflow */
+            var = (size * sum2 - sum * sum) / ((double) size * (size - 1));
+            if (!(var > 0))
+                var = 0;        /* rounding can yield a tiny negative (or NaN) */
+        }
+        x->stdev[k] = (float) sqrt(var);
+
+        if (x->stdev[k] < 0.00001)
+            x->stdev[k] = 0.00001;
+        if (x->normalize == TRUE) {
+            for (i = 0; i < size; i++)
+                Features[i].feature[k] =
+                    (Features[i].feature[k] - x->mean[k]) / x->stdev[k];
+        }
+    }
+}
+
+/*
+ * Standardizes one feature vector in place using the stored statistics:
+ * each element becomes (value - x->mean[i]) / x->stdev[i].
+ */
+static void normalize(t_knn * x, FEATURE * f)
+{
+    float *values = f->feature;
+    int j = 0;
+
+    while (j < N_FEATURES) {
+        values[j] = (values[j] - x->mean[j]) / x->stdev[j];
+        j++;
+    }
+}
+
+
+/*
+ * Loads database entries from the text file fname and appends them to the
+ * object's database, then runs normalize_database() over the result.
+ *
+ * Each record is an integer class id followed by N_FEATURES numbers.
+ * Infinite values are stored as 0. Reading stops at the first item that is
+ * not an integer class id (normally the end of the file). A record whose
+ * feature values are missing or unparsable is reported and discarded, and
+ * reading stops there; the entries read before it are kept.
+ *
+ * If the file cannot be opened an error is posted and the database is left
+ * unchanged.
+ */
+void get_timbre_features(t_knn * x, char *fname)
+{
+    FILE *fp;
+    int i, instr = 0;
+    float features[N_FEATURES];
+
+    printf("Reading %s . . . ", fname);
+
+    fp = fopen(fname, "r");
+    if (fp == NULL) {
+        post("\nError opening file - %s", strerror(errno));
+        return;
+    }
+
+    while (fscanf(fp, "%d", &instr) == 1) {
+        for (i = 0; i < N_FEATURES; i++) {
+            if (fscanf(fp, "%g ", &features[i]) != 1)
+                break;          /* features[i] was not written */
+            if (isinf(features[i]))
+                features[i] = 0.0;
+        }
+        if (i < N_FEATURES) {
+            post("KNN: entry %d in %s is incomplete or malformed - stopped reading",
+                 x->db_size + 1, fname);
+            break;
+        }
+
+        add_feature(x, features, instr);
+    }
+    printf("done.\nDatabase contains %d entries.\n", x->db_size);
+    fclose(fp);
+    normalize_database(x, x->feature_db, x->db_size);
+}