/**
 * Copyright (C) 2007-2014 Lawrence Murray
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 * 
 * @author Lawrence Murray <lawrence@indii.org>
 */
#ifndef INDII_CLUSTER_KMEANSCLUSTERER_HPP
#define INDII_CLUSTER_KMEANSCLUSTERER_HPP

#include <cassert>
#include <vector>

#include "boost/random.hpp"

namespace indii {
/**
 * K-means clustering by Lloyd's algorithm: starting from a random partition
 * of the data, alternately recompute the component means and reassign each
 * data point to its nearest mean.
 *
 * @tparam D Distance metric type. Must provide a static
 * <tt>distance(x, mu)</tt> giving the distance between a data point and a
 * component mean, and a static <tt>prepare(mu)</tt> that is applied to each
 * mean after it has been recomputed.
 */
template<class D>
class KMeansClusterer {
public:
  /**
   * Distance metric type (the template argument @p D).
   */
  typedef D distance_type;

  /**
   * Constructor.
   *
   * @param seed Seed for the internal random number generator (used by
   * init() to draw the initial assignments).
   */
  KMeansClusterer(const int seed);

  /**
   * Cluster data. Draws initial assignments with init(), then alternates
   * update() and assign() until no assignment changes or a maximum number of
   * iterations is reached. At least one iteration is always performed.
   *
   * If the loop stops because @p maxIters was reached, @p mus holds the means
   * used by the final assign(); these need not be the means of the final
   * @p ks.
   *
   * @tparam V1 Vector type.
   * @tparam V2 Vector type.
   * @tparam V3 Vector type.
   * @tparam V4 Vector type.
   * @tparam V5 Vector type.
   *
   * @param data Data set.
   * @param[out] ks Component assignments; must be the same size as @p data.
   * @param[out] ds Distance of each data point from its component mean; must
   * be the same size as @p data.
   * @param[out] mus Component means. The size of this determines the number
   * of components, which must be non-zero.
   * @param[out] cs Count of number of data points assigned to each component;
   * must be the same size as @p mus.
   * @param maxIters Maximum number of iterations to perform, zero for
   * no limit (use with caution!).
   */
  template<class V1, class V2, class V3, class V4, class V5>
  void cluster(const V1& data, V2& ks, V3& ds, V4& mus, V5& cs,
      const int maxIters = 0);

  /**
   * Initialise assignments by drawing a component uniformly at random for
   * each data point (random-partition initialisation), and count the points
   * in each component.
   *
   * @tparam V1 Vector type.
   * @tparam V2 Vector type.
   * @tparam V3 Vector type.
   *
   * @param data Data set.
   * @param[out] ks Component assignments; must be the same size as @p data.
   * @param[out] cs Count of number of data points assigned to each component.
   * Its size determines the number of components and must be non-zero.
   */
  template<class V1, class V2, class V3>
  void init(const V1& data, V2& ks, V3& cs);

  /**
   * Find the component mean nearest to a single data point.
   *
   * @tparam T1 Value type.
   * @tparam V1 Vector type.
   * @tparam V2 Vector type.
   *
   * @param x Data point.
   * @param[out] ds Distance of @p x from each component mean, indexed by
   * component; must be the same size as @p mus.
   * @param mus Component means; must be non-empty.
   *
   * @return Index of the component nearest to @p x; ties resolve to the
   * lowest index.
   */
  template<class T1, class V1, class V2>
  int assign(const T1& x, V1& ds, V2& mus);

  /**
   * Assign each data point to its nearest component, keeping the counts in
   * @p cs consistent with any assignments that change.
   *
   * @tparam V1 Vector type.
   * @tparam V2 Vector type.
   * @tparam V3 Vector type.
   * @tparam V4 Vector type.
   * @tparam V5 Vector type.
   *
   * @param data Data set.
   * @param[in,out] ks Component assignments.
   * @param[out] ds Distance of each data point from the mean of the component
   * it is (re)assigned to.
   * @param mus Component means; must be non-empty.
   * @param[in,out] cs Count of number of data points assigned to each
   * component.
   *
   * @return Number of assignments that have changed from input @p ks.
   */
  template<class V1, class V2, class V3, class V4, class V5>
  int assign(const V1& data, V2& ks, V3& ds, const V4& mus, V5& cs);

  /**
   * Update means. For each component with at least one assigned point, its
   * mean becomes the sum of those points divided by its count in @p cs, after
   * which <tt>D::prepare()</tt> is applied to it.
   *
   * @tparam V1 Vector type.
   * @tparam V2 Vector type.
   * @tparam V3 Vector type.
   * @tparam V4 Vector type.
   *
   * @param data Data set.
   * @param ks Component assignments.
   * @param[out] mus Component means.
   * @param cs Count of number of data points assigned to each component;
   * expected to agree with @p ks.
   */
  template<class V1, class V2, class V3, class V4>
  void update(const V1& data, const V2& ks, V3& mus, const V4& cs);

private:
  /*
   * Random number generator, seeded by the constructor. A Mersenne Twister
   * would be preferable to this linear congruential generator, but appears
   * to give compile errors in Visual C++ 9.
   */
  boost::minstd_rand rng;
};
}

/*
 * Seed the generator directly in the member initializer; Boost's engine
 * constructors forward their integer argument to seed().
 */
template<class D>
indii::KMeansClusterer<D>::KMeansClusterer(const int seed) :
    rng(seed) {
  //
}

/*
 * Lloyd iteration: random initial partition, then alternate mean updates and
 * reassignments. One round always runs; afterwards stop as soon as no point
 * moves or the iteration budget is spent (maxIters == 0 means unlimited).
 */
template<class D>
template<class V1, class V2, class V3, class V4, class V5>
void indii::KMeansClusterer<D>::cluster(const V1& data, V2& ks, V3& ds, V4& mus,
    V5& cs, const int maxIters) {
  /* pre-conditions */
  assert(data.size() == ks.size());
  assert(data.size() == ds.size());
  assert(mus.size() == cs.size());

  init(data, ks, cs);

  int round = 0;
  bool done = false;
  while (!done) {
    update(data, ks, mus, cs);
    const int moved = assign(data, ks, ds, mus, cs);
    ++round;

    const bool budgetSpent = (maxIters != 0) && (round >= maxIters);
    done = (moved <= 0) || budgetSpent;
  }
}

/*
 * Random-partition initialisation: each data point is given a component drawn
 * uniformly from [0, cs.size() - 1], and cs is set to the resulting counts.
 * At least one component is required; with none, the draw range would be
 * empty (cs.size() - 1 wraps around as an unsigned value).
 */
template<class D>
template<class V1, class V2, class V3>
void indii::KMeansClusterer<D>::init(const V1& data, V2& ks, V3& cs) {
  /* pre-conditions */
  assert(data.size() == ks.size());
  assert(cs.size() > 0);

  const int K = static_cast<int>(cs.size());
  const int N = static_cast<int>(data.size());

  boost::uniform_int<> dist(0, K - 1);
  boost::variate_generator<boost::minstd_rand&, boost::uniform_int<> > rc(rng,
      dist);

  for (int k = 0; k < K; ++k) {
    cs[k] = 0;
  }
  for (int i = 0; i < N; ++i) {
    const int k = rc();
    ks[i] = k;
    ++cs[k];
  }
}

/*
 * Recompute component means from the current assignments. Each non-empty
 * component's mean is the sum of its points divided by its count in cs; a
 * component that received no points is reseeded at a data point drawn
 * uniformly at random, instead of dividing an unwritten value by zero.
 * D::prepare() is then applied to every mean. With no data at all, the means
 * are left untouched.
 */
template<class D>
template<class V1, class V2, class V3, class V4>
void indii::KMeansClusterer<D>::update(const V1& data, const V2& ks, V3& mus,
    const V4& cs) {
  /* pre-conditions */
  assert(data.size() == ks.size());
  assert(mus.size() == cs.size());

  const int N = static_cast<int>(data.size());
  const int K = static_cast<int>(mus.size());
  int i, k;

  /* accumulate the sum of the points in each component; the first point is
   * copied rather than added so that the element type of mus needs no zero
   * value, and empty[k] stays true for components that receive no points */
  std::vector<bool> empty(K, true);
  for (i = 0; i < N; ++i) {
    k = ks[i];
    if (empty[k]) {
      mus[k] = data[i];
      empty[k] = false;
    } else {
      mus[k] += data[i];
    }
  }

  if (N == 0) {
    /* no points to average or to reseed from; leave the means unchanged */
    return;
  }

  /* an empty component has no mean (dividing by its zero count is invalid),
   * so reseed it at a random data point; the next assign() can then move
   * points into it. The generator draws only when a component is empty. */
  boost::uniform_int<> dist(0, N - 1);
  boost::variate_generator<boost::minstd_rand&, boost::uniform_int<> > rp(rng,
      dist);
  for (k = 0; k < K; ++k) {
    if (empty[k]) {
      mus[k] = data[rp()];
    } else {
      assert(cs[k] > 0);  // cs must agree with ks
      mus[k] /= cs[k];
    }
    D::prepare(mus[k]);
  }
}

/*
 * Distance from a single point x to every component mean, written to ds
 * (indexed by component). Returns the index of the nearest mean; ties resolve
 * to the lowest index.
 */
template<class D>
template<class T1, class V1, class V2>
int indii::KMeansClusterer<D>::assign(const T1& x, V1& ds, V2& mus) {
  /* pre-conditions */
  assert(ds.size() == mus.size());
  assert(mus.size() > 0);  // mus[0] seeds the search below

  /* hold distances at the precision of ds, as the data-set overload of
   * assign() does, rather than narrowing them through float */
  typedef typename V1::value_type T;

  const int K = static_cast<int>(mus.size());
  int kmin = 0;
  T dmin = D::distance(x, mus[0]);
  ds[0] = dmin;
  for (int k = 1; k < K; ++k) {
    const T d = D::distance(x, mus[k]);
    ds[k] = d;
    if (d < dmin) {  // strict comparison: ties keep the lowest index
      kmin = k;
      dmin = d;
    }
  }
  return kmin;
}

/*
 * Move every data point to its nearest component mean (ties go to the lowest
 * index), record that distance in ds, and keep cs in step with each move.
 * Returns how many points changed component.
 */
template<class D>
template<class V1, class V2, class V3, class V4, class V5>
int indii::KMeansClusterer<D>::assign(const V1& data, V2& ks, V3& ds,
    const V4& mus, V5& cs) {
  /* pre-conditions */
  assert(data.size() == ks.size());
  assert(data.size() == ds.size());
  assert(mus.size() == cs.size());

  typedef typename V3::value_type distance_value;

  const int numPoints = static_cast<int>(data.size());
  const int numComponents = static_cast<int>(mus.size());
  int numMoved = 0;

  for (int n = 0; n < numPoints; ++n) {
    /* nearest mean to point n */
    int best = 0;
    distance_value bestDist = D::distance(data[n], mus[0]);
    for (int c = 1; c < numComponents; ++c) {
      const distance_value dist = D::distance(data[n], mus[c]);
      if (dist < bestDist) {
        best = c;
        bestDist = dist;
      }
    }

    /* transfer the point between components if it moved */
    const int previous = ks[n];
    if (previous != best) {
      ks[n] = best;
      --cs[previous];
      ++cs[best];
      ++numMoved;
    }
    ds[n] = bestDist;
  }
  return numMoved;
}

#endif
