#include <cassert>
#include <cmath>
#include <fstream>
#include <iostream>
#include "Img.h"
using namespace std;


class WeightedIndex
{
public:
    // The default state is a sentinel: both coordinates are -1 and the weight
    // is INFINITY. An index that was never given real values should then trip
    // an assertion as soon as it is used to index, exposing the missing setup.
    WeightedIndex() : i(-1), j(-1), weight(INFINITY) {}

    // Stores coordinates (x, y) together with weight w.
    WeightedIndex(int x, int y, double w) : i(x), j(y), weight(w) {}

    int i;          // first coordinate (the constructor's x)
    int j;          // second coordinate (the constructor's y)
    double weight;  // weight attached to this (i, j) pair
};

class NearestNeighbors
{
public:
    NearestNeighbors() {
        k = -1;
        neighbors = 0;
    }
    
    void init(int num_neighbors) {
        assert( num_neighbors >= 0 );
        k = num_neighbors;
        neighbors = new WeightedIndex[k];
    }

    ~NearestNeighbors() {
        assert(neighbors != 0);
        delete[] neighbors;
    }

    WeightedIndex getNeighborInd(int ki) {
        assert(ki >= 0 && ki < k);
        return neighbors[ki];
    }

    void addNeighbor(int i, int j, double weight);

    int get_k() {
        return k;
    }

private:
    int k;
    WeightedIndex *neighbors;
};



// Neighbour table over an image grid: one NearestNeighbors list per pixel,
// stored row-major (slot w*j + i for column i in [0, w), row j in [0, h)).
//
// NOTE(review): nns is allocated with new[] but never released, because there
// is no destructor. Adding one is unsafe while the implicit (shallow) copy
// constructor lets KNN copies share the same array. Ownership should be
// settled together with how callers pass KNN around.
class KNN
{
public:
    // Creates an empty table; call one of the init() overloads before use.
    KNN(){
        k = -1; w = -1; h = -1; nradius = -1;
        nns = 0;
    }

    // Builds the table from an image and runs compute_nns() on it.
    //   img    - source image; must already hold data (img.is_data()).
    //   kn     - number of neighbour slots per pixel, > 0.
    //   nbdrad - neighbourhood radius (square neighbourhood), > 0.
    // May be called at most once per object.
    void init(Img& img, int kn, int nbdrad) {
        assert(!nns);
        assert(img.is_data());
        assert(kn > 0);
        assert(nbdrad > 0);

        k = kn;
        nradius = nbdrad;

        w = img.get_w();
        h = img.get_h();

        nns = new NearestNeighbors[w*h];
        for(int i = 0; i < w*h; ++i)
            nns[i].init(k);

        compute_nns(img);
    }

    // Loads a table in the format written by print_all_nns():
    //   k
    //   w h
    //   nradius
    //   then w*h lines (row-major), each holding k "i j weight" triples.
    // May be called at most once per object. Missing or malformed files are
    // reported through assert, matching the other init() overload.
    void init(const char* filename)
    {
        assert(!nns);

        ifstream in(filename);
        assert(in.is_open());

        in>>k;
        in>>w>>h;
        in>>nradius;
        // A failed header read would otherwise leave w == h == -1. Then
        // w*h == 1 and the load would "succeed" with a bogus one-pixel table.
        assert(in);
        assert(k > 0 && w > 0 && h > 0 && nradius > 0);

        nns = new NearestNeighbors[w*h];

        for(int i = 0; i < w*h; ++i)
            nns[i].init(k);

        for(int f = 0; f < w*h; ++f)
        {
            for(int d = 0; d < k; ++d)
            {
                int i, j;
                double weight;
                in>>i>>j>>weight;
                assert(in);  // truncated or malformed neighbour data
                nns[f].addNeighbor(i,j,weight);
            }
        }
    }


    // Defined out of line (not visible here); fills nns from img.
    void compute_nns(Img& img);

    int get_k() const {
        return k;
    }
    int get_nradius() const {
        return nradius;
    }
    int get_w() const {
        return w;
    }
    int get_h() const {
        return h;
    }

    // Writes the k entries of pixel (i, j) to stdout as "i j weight " triples
    // (each followed by a space, with no trailing newline).
    void print_nns(int i, int j) const
    {
        assert(i >= 0 && i < w);
        assert(j >= 0 && j < h);

        for(int d = 0; d < k; ++d)
        {
            WeightedIndex t = nns[ind(i,j)].getNeighborInd(d);
            cout<<t.i<<" "<<t.j<<" "<<t.weight<<" ";
        }
    }

    // Writes the whole table to stdout in the format read by init(filename).
    void print_all_nns() const
    {
        // Output to custom file format.
        // '\n' avoids flushing once per pixel line; a single flush at the end
        // still leaves the complete table written out.
        cout<<k<<'\n';
        cout<<w<<" "<<h<<'\n';
        cout<<nradius<<'\n';
        for(int j = 0; j < h; ++j)
        {
            for(int i = 0; i < w; ++i)
            {
                print_nns(i,j);
                cout<<'\n';
            }
        }
        cout.flush();
    }

    // Returns the (mutable) neighbour list of pixel (i, j).
    NearestNeighbors* getNNs(int i, int j)
    {
        assert(i >= 0 && i < w);
        assert(j >= 0 && j < h);
        return &nns[ind(i,j)];
    }


private:
    NearestNeighbors* nns;
    int w;
    int h;

    int k;
    int nradius; // Neighborhood radius. Square nbd.

    // Row-major slot of pixel (i, j).
    int ind(int i, int j) const {
        return w*j + i;
    }
};

