#include "ga.h"

/* Target string value the dummy population should converge to (tested by
 * isConverged).  The body is parenthesized so the macro always expands as a
 * single operand, whatever expression it is used in. */
#define OPTIMAL (UINT_MAX / 16)

/* Forward declarations for helpers defined later in this file. */
Element search(Population* pop, double* wheel, double key);
Element* maxFitness(Element* buffer, BITS numStr);
void initPop(Population* pop);

/* Random fitness landscape: filled by initFitArray() and read by fitness(). */
static double randomFitnessArray[65536];

/* Builds the random fitness landscape.
 *
 * Seeds the C library PRNG with `seed`, then stores one rand() draw in every
 * entry of randomFitnessArray, in index order.  The same seed therefore
 * reproduces the same landscape. */
void initFitArray(int seed) {
  const size_t tableSize = sizeof randomFitnessArray / sizeof randomFitnessArray[0];
  size_t k;

  srand(seed);
  for (k = 0; k < tableSize; k++) {
    randomFitnessArray[k] = rand();
  }
}

/* Fitness function for the dummy population.
 *
 * Looks the string up in the random landscape built by initFitArray().
 * The table index is bits 16..31 of the string; the low 16 bits are ignored.
 * Strings that agree in those bits therefore share a fitness value.
 *
 * Earlier experiments (now removed) used gaussian() landscapes here instead:
 * a double-peaked one (apparently centred near OPTIMAL and 5*OPTIMAL) and a
 * single peak near OPTIMAL.
 */
double fitness(BITS str) {
  /* The mask keeps the index inside the 65536-entry table even if BITS is
   * wider than 32 bits.  For a 32-bit BITS it changes nothing. */
  return randomFitnessArray[(str >> 16) & 0xFFFFu];
}

/* Recomputes the fitness of every string in the population and stores it in
 * that string's Element.
 *
 * flag selects the fitness function:
 *   0        - fitness(): the dummy population's landscape.
 *   non-zero - meta level.  The string is treated as packed mutation rates
 *              and scored as 1 / (epochs gaRound() needs with them), so
 *              fewer epochs means a higher fitness.
 */
void applyFitness(Population* pop, BITS flag) {
  BITS i;
  for (i = 0; i < pop->numStr; i++) {
    Element* e = &pop->strings[i];
    if (flag == 0) {
      e->fitness = fitness(e->str);
    } else {
      BITS epochs = gaRound(e->str, NUMSTRINGS);
      /* gaRound() returns 0 when the dummy population starts out converged.
       * Score that like a 1-epoch run rather than dividing by zero, which
       * would put an infinite fitness into the roulette wheel. */
      if (epochs == 0) {
        epochs = 1;
      }
      e->fitness = 1.0 / (double)epochs;
    }
  }
}

/* Returns the mean of the fitness values already stored in the population.
 * It does not recompute them.  An empty population yields 0.0 instead of
 * the NaN produced by dividing by zero. */
double avgFitness(Population* pop) {
  BITS i;
  double total = 0.0;

  if (pop->numStr == 0) {
    return 0.0;
  }
  for (i = 0; i < pop->numStr; i++) {
    total += pop->strings[i].fitness;
  }
  return total / (double)pop->numStr;
}

/* Returns the mean string value of the population, with each string read as
 * an unsigned number.  An empty population yields 0.0 instead of the NaN
 * produced by dividing by zero. */
double avgStr(Population* pop) {
  BITS i;
  double total = 0.0;

  if (pop->numStr == 0) {
    return 0.0;
  }
  for (i = 0; i < pop->numStr; i++) {
    total += pop->strings[i].str;
  }
  return total / (double)pop->numStr;
}

/* Reports whether the population has converged on OPTIMAL.
 *
 * The population counts as converged when its average string (avgStr) lies
 * in [OPTIMAL * (1 - percent/2), OPTIMAL * (1 + percent/2)].  `percent` is
 * applied as a fraction of OPTIMAL and gives the total width of the window.
 *
 * Returns 1 if converged, 0 otherwise (including when the average is NaN).
 */
BITS isConverged(Population* pop, double percent) {
  double avg = avgStr(pop);
  double halfWidth = percent / 2;
  double upper = (1 + halfWidth) * OPTIMAL;
  double lower = (1 - halfWidth) * OPTIMAL;

  return (avg >= lower && avg <= upper) ? 1 : 0;
}

/* Writes every string of the population to `output`, one line each.
 *
 * flag != 0: the string in binary (via printBinary) followed by its fitness.
 * flag == 0: "<index> - <string>", with the string in unsigned decimal.
 */
void printPop(Population* pop, BITS flag, FILE* output) {
  BITS idx = 0;
  while (idx < pop->numStr) {
    const Element* cur = &pop->strings[idx];
    if (!flag) {
      fprintf(output, "%u - %u\n", idx, cur->str);
    } else {
      printBinary(output, cur->str);
      fprintf(output, " %f\n", cur->fitness);
    }
    idx++;
  }
}

/* Writes "<index> - <fitness>" for each string of the population to
 * `output`, one line per string. */
void printFit(Population* pop, FILE* output) {
  BITS n = 0;
  while (n < pop->numStr) {
    fprintf(output, "%u - %f\n", n, pop->strings[n].fitness);
    ++n;
  }
}

/* Reproduction (selection) operator: fitness-proportional roulette wheel,
 * with replacement.
 *
 * Steps:
 *   1. Build a cumulative-fitness wheel from the stored fitness values.
 *   2. For each of the numStr slots, draw a key with
 *      randomFloat(0, total fitness) and copy the Element that search()
 *      maps it to into a fresh array.
 *   3. That array replaces population->strings; the old array is freed.
 *
 * Presumably expects the stored fitness values to be current and
 * non-negative.
 */
void reproduction(Population* population) {
  const BITS count = population->numStr;
  double* wheel = (double*)safe_malloc(count * sizeof(double));
  Element* offspring = (Element*)safe_malloc(count * sizeof(Element));
  double runningTotal = 0;
  BITS k;

  /* wheel[k] = total fitness of elements 0..k inclusive */
  for (k = 0; k < count; k++) {
    runningTotal += population->strings[k].fitness;
    wheel[k] = runningTotal;
  }

  for (k = 0; k < count; k++) {
    offspring[k] = search(population, wheel, randomFloat(0, runningTotal));
  }

  free(population->strings);
  population->strings = offspring;
  free(wheel);
}

/* One-point crossover on a pair of strings.
 *
 * Picks a random cut position `cut` in [0, bit width of BITS - 1] and
 * exchanges the bits below it between *first and *second.  Concretely, the
 * low (width - cut) bits are swapped and the bits above stay in place.
 * A cut of 0 swaps the strings entirely.
 */
void swapTails(BITS* first, BITS* second) {
  BITS cut = randomInt(0, (BITS_PER_BYTE * sizeof(BITS)) - 1);
  /* zeros above the cut, ones from the cut down to bit 0 */
  BITS tailMask = ((BITS)~0) >> cut;
  /* bits that differ inside the tail; toggling them in both strings
   * exchanges the two tails */
  BITS diff = (*first ^ *second) & tailMask;

  *first ^= diff;
  *second ^= diff;
}

/* Crossover operator.
 *
 * 1. Shuffles the population into a random order (Fisher-Yates: random
 *    selection without replacement).
 * 2. Walks the shuffled population in adjacent pairs (0,1), (2,3), ...
 *    Each pair is first copied into `prior`, so `prior` ends up holding the
 *    shuffled, pre-crossover population for the election operator.
 * 3. Each pair then has its tails swapped with swapTails() whenever
 *    event(rate) fires.
 *
 * With an odd population the last string has no partner.  It is copied
 * into `prior` and left unchanged.
 *
 * prior->strings must already be allocated with room for at least
 * population->numStr Elements.
 */
void crossover(Population* population, double rate, Population* prior) {
  Element* s = population->strings;
  BITS n = population->numStr;
  BITS remaining, i;

  /* Shuffle whole Elements so every string keeps its own fitness value.
   * Swapping only .str would pair strings with other strings' fitness and
   * hand those mismatched values to election() through `prior`. */
  for (remaining = n; remaining > 0; remaining--) {
    BITS last = remaining - 1;
    BITS chosen = randomInt(0, last);
    Element tmp = s[chosen];
    s[chosen] = s[last];
    s[last] = tmp;
  }

  /* `i + 1 < n` keeps an odd population from touching index n. */
  for (i = 0; i + 1 < n; i += 2) {
    prior->strings[i] = s[i];
    prior->strings[i + 1] = s[i + 1];
    if (event(rate)) {
      swapTails(&s[i].str, &s[i + 1].str);
    }
  }
  if (n % 2 != 0) {
    prior->strings[n - 1] = s[n - 1];
  }
}

/* Inverts the bit at `position` (0 = least significant) in *str.
 * position must be smaller than the bit width of BITS; shifting by the full
 * width or more is undefined behaviour. */
void flipBit(BITS* str, BITS position) {
  /* XOR with a one-bit mask toggles exactly that bit and leaves the rest
   * untouched. */
  *str ^= (BITS)1 << position;
}

/* Mutation operator.
 *
 * Every bit of every string is flipped when event(rate) fires (presumably
 * with probability `rate`, independently per bit).  The stored fitness
 * values are then refreshed with applyFitness(population, flag).
 *
 * flag is forwarded to applyFitness() to select the fitness function.
 */
void mutate(Population* population, double rate, BITS flag) {
  const BITS length = BITS_PER_BYTE * sizeof(BITS);
  BITS i, j;

  for (i = 0; i < population->numStr; i++) {
    for (j = 0; j < length; j++) {
      if (event(rate)) {
        flipBit(&population->strings[i].str, j);
      }
    }
  }
  /* bit flips invalidate the cached fitness values */
  applyFitness(population, flag);
}

/* Election operator: compares `old` and `next` pairwise and keeps the best
 * Elements in `next`.
 *
 * For each index pair (i, i+1), consider the four candidates old[i],
 * old[i+1], next[i] and next[i+1].  The fittest is written to next[i] and
 * the second fittest to next[i+1].  Ties are resolved by maxFitness().
 *
 * If next has an odd number of strings, the unpaired last slot keeps the
 * fitter of old[last] and next[last].
 *
 * `next` is modified.  `old` is only read and must hold at least
 * next->numStr Elements.
 */
void election(Population* old, Population* next) {
  Element candidates[4]; /* fixed and tiny: no heap allocation needed */
  BITS n = next->numStr;
  BITS i;

  for (i = 0; i + 1 < n; i += 2) {
    Element* best;
    Element winnerA;

    candidates[0] = old->strings[i];
    candidates[1] = old->strings[i + 1];
    candidates[2] = next->strings[i];
    candidates[3] = next->strings[i + 1];

    best = maxFitness(candidates, 4);
    winnerA = *best;
    /* Drop the winner by overwriting its slot with the last candidate.
     * The remaining three candidates then sit in slots 0..2. */
    *best = candidates[3];

    next->strings[i] = winnerA;
    next->strings[i + 1] = *maxFitness(candidates, 3);
  }

  if (n % 2 != 0) {
    candidates[0] = old->strings[n - 1];
    candidates[1] = next->strings[n - 1];
    next->strings[n - 1] = *maxFitness(candidates, 2);
  }
}

/* Roulette-wheel lookup used by reproduction().
 *
 * Parameters:
 *   wheel - cumulative fitness: wheel[k] is the total fitness of
 *           pop->strings[0..k].  It is non-decreasing as long as the fitness
 *           values are non-negative, which this binary search assumes.
 *   key   - the drawn position on the wheel.
 *
 * Returns:
 *   - the first Element whose wheel entry is >= key;
 *   - the last Element when the key lies beyond the final entry (e.g. from
 *     floating-point rounding), instead of reading past the end of either
 *     array;
 *   - a zero-initialised Element when the population is empty.
 *
 * O(log n) per lookup.
 */
Element search(Population* pop, double* wheel, double key) {
  BITS lo = 0;
  BITS hi;

  if (pop->numStr == 0) {
    Element none = {0};
    return none;
  }

  /* Invariant: the answer lies in [lo, hi].  Starting hi at the last index
   * clamps out-of-range keys to the final Element. */
  hi = pop->numStr - 1;
  while (lo < hi) {
    BITS mid = lo + (hi - lo) / 2;
    if (wheel[mid] < key) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  return pop->strings[lo];
}

/* Returns a pointer to the Element with the highest fitness in
 * buffer[0 .. num-1].  Used by the election operator.
 *
 * The comparison is strict, so on ties the earliest such Element wins.
 * With num <= 1 the result is simply buffer. */
Element* maxFitness(Element* buffer, BITS num) {
  BITS best = 0;
  BITS k;

  for (k = 1; k < num; ++k) {
    if (buffer[k].fitness > buffer[best].fitness) {
      best = k;
    }
  }
  return buffer + best;
}

/* Gives every string in the population a random starting value from
 * randomInt(0, INIT_MAX), in index order.  Only .str is set; the fitness
 * values are left for the caller to compute (e.g. with applyFitness). */
void initPop(Population* pop) {
  BITS remaining = pop->numStr;
  Element* it = pop->strings;

  while (remaining-- > 0) {
    it->str = randomInt(0, INIT_MAX);
    it++;
  }
}

/* INPUT:  a 32-bit string.  each byte encodes a mutation rate for a
 *	certain amount of time and the number of dummy strings in the
 *	population.
 * OUTPUT:  the number of epochs it takes for the other population to
 * 	converge.
 */
BITS gaRound(BITS mutRates, BITS num) {
  Population population;
  Population backupCopy;
  BITS mutRate;
  double oldAvg, newAvg;
  BITS numEpochs=0;
  BITS i;

  /* allocate space for the population and a backup copy of it */
  population.numStr = num;
  population.strings = (Element*)safe_malloc(population.numStr * sizeof(Element));
  backupCopy.numStr = num;
  backupCopy.strings = (Element*)safe_malloc(backupCopy.numStr * sizeof(Element));

  /* initialize population */
  initPop(&population);

  oldAvg = -1;
  newAvg = 0;

  //while ((oldAvg != newAvg) && (i<EPOCHS)) {
  /* while we aren't fully converged and haven't gone MAX_STR_EPOCHS */
  while ((!isConverged(&population, FINAL_RATE)) && (numEpochs < MAX_STR_EPOCHS)) {
    applyFitness(&population, 0); /* calculate fitnesses */
    //printf("Epoch: %u\n", numEpochs);
    //printf("Avg String: %f\n", avgStr(&population));
    //printf("optimal: %f\n", (double)OPTIMAL);

    /*
    momo = fopen("lala", "a");
    fprintf(momo, "%f\n", avgStr(&population)-(double)OPTIMAL);
    fclose(momo);
    */

    //kaka = fopen("pupu", "a");
    //fprintf(kaka, "%f\n", avgFitness(&population));

    //printFit(&population, stdout);

    //printf("%f\n", avgFitness(&population));
    //printf("%f\n", avgStr(&population)-(double)OPTIMAL);
    reproduction(&population); /* reproduction operator */
    //printFit(&population, stdout);
    //printf("%f\n", avgFitness(&population));

    crossover(&population, PCROSS, &backupCopy); /* crossover operator */
    if (isConverged(&population, THIRD_RATE)) { /* decide which mutation rate to use */
      mutRate = (mutRates >> 0) & 0xff;
      //printf("converged to third rate\n");
    } else {
      if (isConverged(&population, SECOND_RATE)) {
        mutRate = (mutRates >> 8) & 0xff;
        //printf("converged to second rate\n");
      } else {
        if (isConverged(&population, FIRST_RATE)) {
          mutRate = (mutRates >> 16) & 0xff;
          //printf("converged to first rate\n");
	} else {
          mutRate = mutRates >> 24;
          //printf("not converged to first rate\n");
	}
      }
    }
    mutate(&population, (double)mutRate/(double)RATE_DIVISOR, 0); /* mutation operator */
    //mutate(&population, .0033, 0);
    election(&backupCopy, &population); /* election operator */
    oldAvg = newAvg;
    newAvg = avgFitness(&population);

    numEpochs++;
    i++;
  }
  /* clean up */
  free(population.strings);
  free(backupCopy.strings);

  return numEpochs;
}

