// - - - - - - - - - - - - - - - - - - -
// File: visual_ann.cpp                 |
// Purpose: implements class VisualANN  |
// Author: Taivo Lints, Estonia         |
// Date: May, 2003                      |
// Copyright: see copyright.txt         |
// - - - - - - - - - - - - - - - - - - -

#include "visual_ann.h"
#include "ann/neuralnetwork.h"
#include "ann/connection.h"
#include "mypalette.h"
#include "allegro.h"
#include "stdio.h"  // For sprintf()
#include <vector>
using namespace std;

// - - - - - Class VisualANN - - - - -

// ******************************
// * Construction & Destruction *
// ******************************

// - VisualANN constructor - -
// VisualANN constructor. Creates a network from your configuration
// file (using NeuralNetwork constructor), sets up the palette entries used
// for drawing and precomputes the on-screen position of every node.
//
// Positions are appended to vPositions layer by layer (top to bottom within
// a layer). Layers are spread evenly across the screen width, edge_x pixels
// from the left and right edges; the nodes of a layer are spread evenly
// across the screen height, edge_y pixels from the top and bottom edges.
// neuron_radius is shrunk (never enlarged) when neighbouring nodes would be
// too close to fit the circles.
VisualANN::VisualANN(char* config_file) : NeuralNetwork(config_file) {

  // Setting colors on palette.
  color_of_pin = 100;
  color_of_txt = 101;
  color_of_node_circle = 102;
  start_of_blue = 120;   // First entry of the blueish range.
  start_of_white = 190;  // First entry of the grayscale range.

  change_color(color_of_pin, 0, 14, 48);
  change_color(color_of_node_circle, 0, 45, 0);
  change_color(color_of_txt, 2, 35, 10);
  make_blueish_palette_64(start_of_blue);
  make_grayscale_palette_64(start_of_white);

  // Some network parameters are duplicated for quicker / easier access.
  num_of_nodes = get_number_of_nodes();
  second_layer_start = get_number_of_inputs();
  output_layer_start = num_of_nodes - get_number_of_outputs();
  max_weight = get_max_weight();

  // Drawing related parameters.
  neuron_radius = 15;
  edge_x = 50;

  // If network is empty, then nothing to do.
  if(get_number_of_layers() == 0)
    return;

  // How far should neurons be from lower and upper edges of bitmap?
  int edge_y = 2 * neuron_radius;

  // Horizontal distance between consecutive layers. It stays 0 when there is
  // only one layer, but it is still read below to compute that layer's x
  // coordinate, so it must always be initialized.
  double gap_x = 0;
  if(get_number_of_layers() > 1) {
    gap_x = static_cast<double>(SCREEN_W - 2 * edge_x) /
            (get_number_of_layers() - 1);

    // If neuron circles are horizontally too close, then reduces
    // neuron_radius. The formula can yield a value larger than the current
    // radius, so only accept it when it actually shrinks the circles.
    if(gap_x < 3 * neuron_radius) {
      int reduced_radius = static_cast<int>(gap_x / 3 + 4);
      if(reduced_radius < neuron_radius)
        neuron_radius = reduced_radius;
    }
  }

  // This will remember the smallest vertical gap between neurons and is
  // used to keep vertical gap between neurons reasonable (by reducing
  // neuron radius).
  int min_gap_y = SCREEN_H;

  // For storing the position of a neuron.
  Position Pos;

  // Goes through all layers.
  for(int i = 0; i < get_number_of_layers(); i++) {

    // How many nodes in this layer?
    int nodes_in_lr = v_int_layers[i];

    // Vertical gap between nodes of this layer; only meaningful with at
    // least two nodes (a single node is placed at the top edge).
    double gap_y = 0;
    if(nodes_in_lr > 1)
      gap_y = static_cast<double>(SCREEN_H - 2 * edge_y) / (nodes_in_lr - 1);

    // Updates min_gap_y.
    if((gap_y < min_gap_y) && (gap_y != 0))
      min_gap_y = static_cast<int>(gap_y);

    // x is same for all neurons in one layer.
    Pos.x = static_cast<int>(edge_x + i * gap_x);

    // Stores positions for all neurons in this layer.
    for(int j = 0; j < nodes_in_lr; j++) {
      Pos.y = static_cast<int>(edge_y + j * gap_y);
      vPositions.push_back(Pos);
    }
  }

  // If neuron circles are vertically too close, then reduces neuron_radius
  // (again only when the formula really makes the circles smaller).
  if(min_gap_y < 3 * neuron_radius) {
    int reduced_radius = min_gap_y / 3 + 4;
    if(reduced_radius < neuron_radius)
      neuron_radius = reduced_radius;
  }

}

// *************
// * Functions *
// *************

// - Helper: value_to_palette_color - -
// Maps a value that is expected to lie in [-1, 1] to a palette index.
// Negative values map into the blueish range starting at start_of_blue.
// Non-negative values map into the grayscale range starting at start_of_white.
// Out-of-range values are clamped (NaN is treated as 0), so the result always
// stays within offsets 0..63 of the chosen range. This prevents it spilling
// into unrelated palette entries or overflowing the int conversion.
static int value_to_palette_color(double value, int start_of_blue,
                                  int start_of_white) {
  if(value != value)        // NaN: never equal to itself.
    value = 0.0;
  else if(value < -1.0)
    value = -1.0;
  else if(value > 1.0)
    value = 1.0;

  if(value < 0)
    return static_cast<int>(start_of_blue - value * 63);
  return static_cast<int>(start_of_white + value * 63);
}

// - Function: draw - -
// Draws network on given bitmap: one circle per node (filled with a color
// derived from its output, with the output value printed underneath), pins
// for input and output nodes, and a line for every incoming connection of
// non-input nodes (colored by weight relative to max_weight).
void VisualANN::draw(BITMAP* pBitmap) {

  // A buffer for storing some text (output numbers).
  char buf[80] = "";

  for(int i = 0; i < num_of_nodes; i++) {

    // Position of this neuron.
    int x = vPositions[i].x;
    int y = vPositions[i].y;

    // Output of this neuron.
    double output = vpNeurons[i]->output;

    // Calculates color of node.
    int color_of_node =
        value_to_palette_color(output, start_of_blue, start_of_white);

    // If input layer, then will draw the input pins,
    // and circle color is also different.
    if(i < second_layer_start) {
      rectfill(pBitmap, 0, y - 1, x, y + 1, color_of_pin);

      // If we have ONLY input layer, then must also draw output pins.
      if(output_layer_start == 0)
        rectfill(pBitmap, x, y - 1, SCREEN_W - 1, y + 1, color_of_pin);

      circlefill(pBitmap, x, y, neuron_radius - 1, color_of_node);
      circle(pBitmap, x, y, neuron_radius, color_of_pin);
    }
    else {
      if(i >= output_layer_start) {  // Output layer needs output pins.
        rectfill(pBitmap, x, y - 1, SCREEN_W - 1, y + 1, color_of_pin);
      }

      // All non-input layer nodes are drawn here.
      circlefill(pBitmap, x, y, neuron_radius - 1, color_of_node);
      circle(pBitmap, x, y, neuron_radius, color_of_node_circle);
    }

    // Prints output value under node. "%f" of a large double can produce
    // far more than 80 characters, so the write must be bounded.
    snprintf(buf, sizeof(buf), "%f", output);
    textout(pBitmap, font, buf, x - 15, y + neuron_radius, color_of_txt);

    // Draws connections (except for input layer).
    if(i >= second_layer_start) {

      vector<Connection*>& connections = vpNeurons[i]->vpConnections;
      for(vector<Connection*>::iterator iter = connections.begin();
          iter != connections.end(); ++iter) {

        // Calculates color of connection. Guard against a zero max_weight
        // to avoid dividing by zero.
        double weight = (*iter)->weight;
        double relative_weight = (max_weight != 0) ? weight / max_weight : 0.0;
        int col_con = value_to_palette_color(relative_weight, start_of_blue,
                                             start_of_white);

        // Source coordinates.
        int sx = vPositions[(*iter)->source].x;
        int sy = vPositions[(*iter)->source].y;

        // Draws the connection.
        line(pBitmap, sx, sy, x, y, col_con);
      }
    }

  }
}