#include "rheolef.h"
#include "sphere.h"
#include "helmholtz_band_assembly.h"
using namespace std;
using namespace rheolef;
// Scalar test function w(x) = cos(theta), where theta is the second angle
// returned by get_spherical_coordinates (x, rho, theta, phi).
// The radius and the other angle are computed but not used.
Float w (const point& x) {
  Float radius, angle_theta, angle_phi;
  get_spherical_coordinates (x, radius, angle_theta, angle_phi);
  const Float value = cos (angle_theta);
  return value;
}
// Exact tangential (surface) gradient of w, 2D case only.
// On Gamma = C(0,1), the unit circle, and assuming x = cos(theta),
// y = sin(theta) on Gamma:
//   grad_s w(x) = -sin(theta) * tau = (y*y, -x*y),
// where tau = (-y, x) is the counter-clockwise unit tangent.
// The code computes the same vector as y * (y, -x).
// No 3D formula is provided (the original note said "= ? in 3D").
point grad_s_w (const point& x) {
  const Float y = x[1];
  const point tangent (y, -x[0]);
  return y*tangent;
}
int main (int argc, char**argv) {
  geo Lambda (argv[1]);
  Float tol = (argc > 2) ? atof(argv[2]) : 1e-10;
  size_t d = Lambda.dimension();
  space Vh (Lambda, "P1");
  field phi_h_lambda = interpolate(Vh, phi);
  geo band = banded_level_set (phi_h_lambda);

  // compute grad_s(w)
  space Bh  (band, "P1");
  Bh.block("isolated");
  space Bvh (band, "P1", "vector");
  field phi_h = interpolate(Bh, phi);
  form bs (Bh, Bvh, "grad_s", phi_h);
  form ms (Bh, Bh,  "mass_s", phi_h);
  // ms is scalar, not vectorial: go component by component
  field wh = interpolate(Bh, w);
  field lvh = bs*wh;
  field l0h = lvh[0];
  field l1h = lvh[1];
  csr<Float> M = band_assembly<Float> (ms, phi_h);
  ssk<Float> fact_M = ldlt(M);

  // comp 0
  vec<Float> L0(M.nrow(), 0.0);
  for (size_t i = 0; i < l0h.u.size(); i++) L0.at(i) = l0h.u.at(i);
  for (size_t i = l0h.u.size(); i < L0.size(); i++) L0.at(i) = 0;
  vec<Float> U0 (L0.size());
  U0 = fact_M.solve(L0);
  field g0h (Bh);
  for (size_t i = 0; i < g0h.u.size(); i++) g0h.u.at(i) = U0.at(i);
  // comp 1
  vec<Float> L1(M.nrow(), 0.0);
  for (size_t i = 0; i < l1h.u.size(); i++) L1.at(i) = l1h.u.at(i);
  for (size_t i = l1h.u.size(); i < L1.size(); i++) L1.at(i) = 0;
  vec<Float> U1 (L1.size());
  U1 = fact_M.solve(L1);
  field g1h (Bh);
  for (size_t i = 0; i < g1h.u.size(); i++) g1h.u.at(i) = U1.at(i);

  // all comps
  field gh (Bvh, 0.0);
  gh[0] = g0h;
  gh[1] = g1h;

  // error
  field pi_h_g = interpolate (Bvh, grad_s_w);
  field eh = gh - pi_h_g;
  Float err = sqrt(ms (eh[0],eh[0]) + ms (eh[1],eh[1]));
  cerr << "err = " << err << endl;
  cerr << "tol = " << tol << endl;
  return (err <= tol) ? 0 : 1;
}
