// - - - - - - - - - - - - - - - - - - - - -
// File: neuron.cpp                         |
// Purpose: implements class Neuron         |
// Author: Taivo Lints, Estonia             |
// Date: May, 2003                          |
// Copyright: see copyright.txt             |
// - - - - - - - - - - - - - - - - - - - - -

#include "neuron.h"
#include "connection.h"
#include "math.h"
#include <vector>
using namespace std;

// - - - - - Class Neuron - - - - -


// ******************************
// * Construction & Destruction *
// ******************************

// - Neuron constructor - -
// Sets Neuron parameters to default values. Needs a pointer to the
// vpNeurons vector to which this neuron will belong (because it needs
// to communicate with other neurons).
Neuron::Neuron(vector<Neuron*>* init_pvpNeurons, bool inp_flag_comp_neuron) {

  pvpNeurons = init_pvpNeurons;

  treshold = 0;
  slope_parameter = 0.5;

  output = 0;

  error = 0;
  learning_rate = 5;
  max_weight = 10;

  flag_comp_neuron = inp_flag_comp_neuron;

}

// - Neuron copy-constructor - -
Neuron::Neuron(const Neuron& rN) {

  pvpNeurons = rN.pvpNeurons;
  treshold = rN.treshold;
  slope_parameter = rN.slope_parameter;
  learning_rate = rN.learning_rate;
  max_weight = rN.max_weight;
  output = rN.output;
  error = rN.error;
  flag_comp_neuron = rN.flag_comp_neuron;

  Neuron& rN2 = const_cast<Neuron&>(rN);
              // A suspicious thing to do, but compiler will throw warnings
              // about invalid conversions in stl_container when I don't
              // cast away the const-ness. At the moment I don't have time
              // to trace the source of this problem. At least it SEEMS to
              // be a harmless problem...

  // Must NOT point to the connections of the OTHER neuron,
  // must create its own connections.
  for(vector<Connection*>::iterator i = rN2.vpConnections.begin();
      i != rN2.vpConnections.end(); i++) {
    vpConnections.push_back(new Connection(0, 0));
    *(*(--vpConnections.end())) = *(*i);
  }

}

// - Operator= overloading - -
Neuron& Neuron::operator=(const Neuron& rN) {

  if(&rN != this) {   // Check for self-assignment.

    pvpNeurons = rN.pvpNeurons;
    treshold = rN.treshold;
    slope_parameter = rN.slope_parameter;
    learning_rate = rN.learning_rate;
    max_weight = rN.max_weight;
    output = rN.output;
    error = rN.error;
    flag_comp_neuron = rN.flag_comp_neuron;

    Neuron& rN2 = const_cast<Neuron&>(rN);
                // A suspicious thing to do, but compiler will throw warnings
                // about invalid conversions in stl_container when I don't
                // cast away the const-ness. At the moment I don't have time
                // to trace the source of this problem. At least it SEEMS to
                // be a harmless problem...

    // Must NOT point to the connections of the OTHER neuron,
    // must create its own connections.
    for(vector<Connection*>::iterator i = rN2.vpConnections.begin();
        i != rN2.vpConnections.end(); i++) {
      vpConnections.push_back(new Connection(0, 0));
      *(*(--vpConnections.end())) = *(*i);
    }
  }

  return *this;

}

// - Neuron destructor - -
Neuron::~Neuron() {

  // Deletes all Connections (to prevent memory leak).
  for(vector<Connection*>::iterator i = vpConnections.begin();
      i != vpConnections.end(); i++)
    delete *i;

}


// *************
// * Functions *
// *************

// - Function: add_connection() - -
// Adds a connection to the vpConnections vector. Does NOT verify
// its correctness!
void Neuron::add_connection(int source, double weight) {

  // Input nodes can't have connections.
  if(not flag_comp_neuron)
    return;

  vpConnections.push_back(new Connection(source, weight));

}

// - Function: update() - -
// Reads inputs through connections and calculates the output.
void Neuron::update() {

  // Input nodes don't update themselves.
  if(not flag_comp_neuron)
    return;
  
  double sum = 0;  // The sum of all inputs (source neuron outputs
                   // multiplied by corresponding connection weights).

  // Calculating the sum.
  for(vector<Connection*>::iterator i = vpConnections.begin();
      i != vpConnections.end(); i++)
    sum += (*pvpNeurons)[(*i)->source]->output * (*i)->weight;

  sum -= treshold;  // Applying the threshold value.

  // Calculating new output (i.e. applying the activation function).
  output = 1 / (1 + exp(-slope_parameter * sum));
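                   // This is the standard logistic (sigmoid) activation
                   // applied to the thresholded sum computed above; it
                   // squashes any real-valued input into the open
                   // interval (0, 1).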

}

// - Function: learn() - -
// Updates the error value, modifies the weights of the connections,
// and backpropagates the error.
void Neuron::learn() {

  // Input nodes don't learn.
  if(not flag_comp_neuron) {
    error = 0;
    return;
  }

  // Calculating the error signal (multiplying the error value by
  // the derivative of the activation function).
  double error_signal = error * slope_parameter * output * (1 - output);
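              // Note: for the logistic activation y = 1 / (1 + exp(-a * x))
              // used in update(), the derivative is dy/dx = a * y * (1 - y),
              // which is exactly the factor
              // slope_parameter * output * (1 - output) above.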

  // Backpropagating error signal and THEN updating weights.
  for(vector<Connection*>::iterator i = vpConnections.begin();
      i != vpConnections.end(); i++) {
    (*pvpNeurons)[(*i)->source]->error += error_signal * (*i)->weight;
    (*i)->weight += learning_rate * error_signal *
                      (*pvpNeurons)[(*i)->source]->output;

    if((*i)->weight > max_weight)   // Weights should NOT be allowed to go
      (*i)->weight = max_weight;    // to infinity.

    if((*i)->weight < -max_weight)
      (*i)->weight = -max_weight;
  }

  error = 0;  // Must be done, otherwise the error will accumulate (see
              // the "+=" operation a few lines above).

}
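

// ****************
// * Usage sketch *
// ****************

// A minimal, hypothetical example of wiring two input nodes and one
// computational neuron together. It assumes that the data members used
// from the outside ("output" and "error") are publicly accessible in
// neuron.h; if they are not, the corresponding accessors would be needed
// instead. The guard macro NEURON_USAGE_EXAMPLE is not part of the
// project; it simply keeps the sketch out of the normal build.
#ifdef NEURON_USAGE_EXAMPLE

#include <iostream>

int main() {

  vector<Neuron*> vpNeurons;

  // Two input nodes (flag_comp_neuron == false) and one computational
  // neuron (flag_comp_neuron == true).
  vpNeurons.push_back(new Neuron(&vpNeurons, false));
  vpNeurons.push_back(new Neuron(&vpNeurons, false));
  vpNeurons.push_back(new Neuron(&vpNeurons, true));

  // Connect the computational neuron to both input nodes.
  vpNeurons[2]->add_connection(0, 0.5);
  vpNeurons[2]->add_connection(1, -0.5);

  // Present an input pattern and compute the output.
  vpNeurons[0]->output = 1.0;
  vpNeurons[1]->output = 0.0;
  vpNeurons[2]->update();
  cout << "Output: " << vpNeurons[2]->output << endl;

  // One backpropagation step towards a target value of 1.0.
  vpNeurons[2]->error = 1.0 - vpNeurons[2]->output;
  vpNeurons[2]->learn();

  // Clean up (each Neuron deletes its own Connections in its destructor).
  for(vector<Neuron*>::iterator i = vpNeurons.begin();
      i != vpNeurons.end(); i++)
    delete *i;

  return 0;
}

#endif  // NEURON_USAGE_EXAMPLE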