// Copyright (C) 2006-2009 Kent-Andre Mardal and Simula Research Laboratory
//
// This file is part of SyFi.
//
// SyFi is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 2 of the License, or
// (at your option) any later version.
//
// SyFi is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with SyFi. If not, see <http://www.gnu.org/licenses/>.

#include "diff_tools.h"
#include "symbol_factory.h"

#include <stdexcept>

namespace SyFi
{

	/* Alternative implementation for any vector representation
	GiNaC::ex div(GiNaC::ex v){
		using SyFi::nsd;
		using SyFi::p;

		if(nsd != v.nops())
	{
	throw std::invalid_argument("In div(v): The number of elements must equal nsd.");
	}

	GiNaC::ex ret = 0;
	for(int i=0; i<nsd; i++)
	{
	ret = ret + v.op(i).diff(p[i]);
	}
	return ret;
	}
	*/

	/**
	 * Divergence of a vector- or matrix-valued expression with respect to the
	 * coordinates SyFi::x, SyFi::y, SyFi::z, using the global dimension SyFi::nsd.
	 *
	 * - v is an nsd x 1 GiNaC::matrix: returns dv_0/dx + dv_1/dy (+ dv_2/dz).
	 *   For nsd == 1 the whole 1x1 matrix is differentiated, so the result is
	 *   a 1x1 matrix rather than a scalar (unchanged legacy behaviour).
	 * - v is any other GiNaC::matrix with nsd rows: returns an (m.cols() x 1)
	 *   matrix whose entry c is sum_r d v(r,c) / d x_r, i.e. the divergence of
	 *   each column. The row index is the spatial direction, which is the
	 *   layout grad() produces (ret_m(direction, component)).
	 * - v is a GiNaC::lst: forwarded to div(GiNaC::lst&), which assigns
	 *   nsd = number of list elements.
	 *
	 * @throws std::runtime_error    if nsd is not supported (nsd x 1 case: not 1..3;
	 *                               general matrix case: nsd > 3).
	 * @throws std::invalid_argument if v is a matrix whose row count differs from nsd,
	 *                               or if v is neither a matrix nor a lst.
	 */
	GiNaC::ex div(GiNaC::ex v)
	{
		using SyFi::nsd;
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		GiNaC::ex ret;
		if (GiNaC::is_a<GiNaC::matrix>(v))
		{
			GiNaC::matrix m = GiNaC::ex_to<GiNaC::matrix>(v);
			if ( m.cols() == 1 && m.rows() == nsd )
			{
				// Column vector: scalar divergence.
				if (nsd == 1)
				{
					ret = diff(m,x);
				}
				else if (nsd == 2)
				{
					ret = diff(m.op(0),x) + diff(m.op(1),y);
				}
				else if (nsd == 3)
				{
					ret = diff(m.op(0),x) + diff(m.op(1),y) + diff(m.op(2),z);
				}
				else
				{
					throw std::runtime_error("Invalid nsd");
				}
			}
			else
			{
				// General matrix: column-wise divergence, rows run over the spatial directions.
				if ( nsd != m.rows() )
				{
					throw(std::invalid_argument("The number of rows must equal nsd."));
				}
				if (nsd > 3)
				{
					// Only x, y, z exist; extra rows would otherwise be differentiated w.r.t. z again.
					throw std::runtime_error("Invalid nsd");
				}
				GiNaC::matrix retm(m.cols(), 1);
				GiNaC::symbol xr;
				for (unsigned int c=0; c<m.cols(); c++)
				{
					for (unsigned int r=0; r<m.rows(); r++)
					{
						if (r == 0) xr = x;
						if (r == 1) xr = y;
						if (r == 2) xr = z;
						// Entry (row r, column c). The former m(c,r) was transposed and
						// indexed outside the matrix whenever it was not square.
						retm(c,0) += diff(m(r,c), xr);
					}
				}
				ret = retm;
			}
			return ret;
		}
		else if (GiNaC::is_a<GiNaC::lst>(v))
		{
			GiNaC::lst l = GiNaC::ex_to<GiNaC::lst>(v);
			return div(l);
		}
		throw std::invalid_argument("v must be a matrix or lst.");
	}

	/**
	 * Mapped divergence: sum over i, j of d v_i / d x_j * G(i,j), with
	 * x_0 = SyFi::x, x_1 = SyFi::y, x_2 = SyFi::z.
	 *
	 * @param v  an nsd x 1 GiNaC::matrix, or a GiNaC::lst (delegated to
	 *           div(GiNaC::lst&, GiNaC::ex)).
	 * @param G  an nsd x nsd GiNaC::matrix of weights.
	 * @throws std::invalid_argument if the shapes do not match, or if v is neither
	 *                               a matrix (with G a matrix) nor a lst.
	 * @throws std::runtime_error    if nsd is not 1, 2 or 3.
	 */
	GiNaC::ex div(GiNaC::ex v, GiNaC::ex G)
	{
		using SyFi::nsd;
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		const bool both_matrices = GiNaC::is_a<GiNaC::matrix>(v) && GiNaC::is_a<GiNaC::matrix>(G);
		if (!both_matrices)
		{
			// A list is handed to the lst overload; everything else is rejected.
			if (!GiNaC::is_a<GiNaC::lst>(v))
			{
				throw std::invalid_argument("v must be a matrix or lst.");
			}
			GiNaC::lst components = GiNaC::ex_to<GiNaC::lst>(v);
			return div(components, G);
		}

		GiNaC::matrix vec = GiNaC::ex_to<GiNaC::matrix>(v);
		GiNaC::matrix weights = GiNaC::ex_to<GiNaC::matrix>(G);

		const bool shape_ok = vec.cols() == 1 && vec.rows() == nsd
			&& weights.rows() == nsd && weights.cols() == nsd;
		if (!shape_ok)
		{
			throw std::invalid_argument("This functions needs v and G on the form: v.cols()=1, v.rows()=G.rows()=G.cols()=nsd.");
		}
		if (nsd != 1 && nsd != 2 && nsd != 3)
		{
			throw std::runtime_error("Invalid nsd");
		}

		GiNaC::ex result = GiNaC::numeric(0);
		GiNaC::symbol coord;
		for (unsigned int i = 0; i < nsd; ++i)
		{
			for (unsigned int j = 0; j < nsd; ++j)
			{
				// j is at most 2 here because nsd was validated above.
				switch (j)
				{
					case 0:  coord = x; break;
					case 1:  coord = y; break;
					default: coord = z; break;
				}
				result += vec.op(i).diff(coord)*weights(i,j);
			}
		}
		return result;
	}

	/**
	 * Divergence of a vector given as a GiNaC::lst: dv_0/dx (+ dv_1/dy (+ dv_2/dz)).
	 *
	 * Side effect: assigns the global SyFi::nsd to the number of list elements.
	 * Lists whose length is not 1, 2 or 3 (including the empty list) yield a
	 * default-constructed GiNaC::ex, i.e. 0; no exception is thrown.
	 */
	GiNaC::ex div(GiNaC::lst& v)
	{
		using SyFi::nsd;
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		nsd = v.nops();

		GiNaC::ex divergence;
		switch (nsd)
		{
			case 3:
				divergence = v.op(0).diff(x) + v.op(1).diff(y) + v.op(2).diff(z);
				break;
			case 2:
				divergence = v.op(0).diff(x) + v.op(1).diff(y);
				break;
			case 1:
				divergence = v.op(0).diff(x);
				break;
			default:
				// Unsupported length: leave the default value.
				break;
		}
		return divergence;
	}

	/**
	 * Mapped divergence of a vector given as a GiNaC::lst:
	 * sum over i, j of d v_i / d x_j * G(i,j), with x_0 = SyFi::x,
	 * x_1 = SyFi::y, x_2 = SyFi::z. This is the same formula as the
	 * matrix overload div(GiNaC::ex, GiNaC::ex).
	 *
	 * Side effect: assigns the global SyFi::nsd to the number of list elements.
	 *
	 * @throws std::invalid_argument if G is not a matrix or is not nsd x nsd.
	 * @throws std::runtime_error    if the list length is not 1, 2 or 3.
	 */
	GiNaC::ex div(GiNaC::lst& v, GiNaC::ex G)
	{
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		using SyFi::nsd;
		nsd = v.nops();
		if (!GiNaC::is_a<GiNaC::matrix>(G))
		{
			// The check is on G, so the message names G (it previously said "v").
			throw std::invalid_argument("G must be a matrix.");
		}

		GiNaC::matrix GG = GiNaC::ex_to<GiNaC::matrix>(G);
		if ( nsd != GG.cols() || nsd != GG.rows())
		{
			throw(std::invalid_argument("The number of rows and cols in G must equal the size of v."));
		}
		if (nsd != 1 && nsd != 2 && nsd != 3)
		{
			throw std::runtime_error("Invalid nsd");
		}

		GiNaC::ex ret = GiNaC::numeric(0);
		GiNaC::symbol xj;
		for (unsigned int i=0; i< nsd; i++)
		{
			for (unsigned int j=0; j< nsd; j++)
			{
				// The derivative direction follows the column index j of G. Selecting
				// it by i computed sum_i dv_i/dx_i * sum_j G(i,j) instead.
				if (j == 0) xj = x;
				if (j == 1) xj = y;
				if (j == 2) xj = z;
				ret += v.op(i).diff(xj)*GG(i,j);
			}
		}
		return ret;
	}

	/**
	 * Divergence of a vector stored in a GiNaC::exvector, using the global
	 * dimension SyFi::nsd: dv_0/dx (+ dv_1/dy (+ dv_2/dz)).
	 *
	 * Only the first nsd entries are used. For an nsd other than 1, 2 or 3 a
	 * default-constructed GiNaC::ex (0) is returned, as before.
	 *
	 * @throws std::invalid_argument if v has fewer than nsd entries (previously
	 *                               this indexed past the end of the vector).
	 */
	GiNaC::ex div(GiNaC::exvector& v)
	{
		using SyFi::nsd;
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		GiNaC::ex ret;
		if (nsd >= 1 && nsd <= 3
			&& v.size() < static_cast<GiNaC::exvector::size_type>(nsd))
		{
			throw std::invalid_argument("In div(v): v must have at least nsd elements.");
		}

		if (nsd == 1)
		{
			// Previously missing: the 1D case silently returned 0.
			ret = v[0].diff(x);
		}
		else if (nsd == 2)
		{
			ret = v[0].diff(x) + v[1].diff(y);
		}
		else if (nsd == 3)
		{
			ret = v[0].diff(x) + v[1].diff(y) + v[2].diff(z);
		}
		return ret;
	}

	/**
	 * Gradient with respect to SyFi::x, SyFi::y, SyFi::z, using the global
	 * dimension SyFi::nsd.
	 *
	 * - nsd == 1: returns diff(f, x) (no matrix wrapper, for scalars and matrices alike).
	 * - f is a GiNaC::matrix: returns an nsd x f.rows() matrix with
	 *   entry (d, r) = d f_r / d x_d, where f_r = f.op(r). Presumably f is a column
	 *   vector; op(r) reads the r-th entry in storage order, so for multi-column
	 *   matrices only the first rows() entries are used — confirm with callers.
	 * - otherwise: returns the nsd x 1 column (df/dx, df/dy[, df/dz]).
	 *
	 * @throws std::invalid_argument if nsd is not 1, 2 or 3. The matrix branch
	 *         used to return a zero matrix silently in that case.
	 */
	GiNaC::ex grad(GiNaC::ex f)
	{
		using SyFi::nsd;
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		if (nsd != 1 && nsd != 2 && nsd != 3)
		{
			throw(std::invalid_argument("nsd must be either 1, 2, or 3."));
		}

		if (nsd == 1)
		{
			// In 1D the gradient is the plain x-derivative.
			return diff(f, x);
		}

		if (GiNaC::is_a<GiNaC::matrix>(f))
		{
			GiNaC::matrix m = GiNaC::ex_to<GiNaC::matrix>(f);
			GiNaC::matrix ret_m(nsd, m.rows());
			for (unsigned int r=0; r< m.rows(); r++)
			{
				ret_m(0,r) = diff(m.op(r),x);
				ret_m(1,r) = diff(m.op(r),y);
				if (nsd == 3)
				{
					ret_m(2,r) = diff(m.op(r),z);
				}
			}
			return ret_m;
		}

		if (nsd == 2)
		{
			return GiNaC::matrix(nsd,1,GiNaC::lst(diff(f,x), diff(f,y)));
		}
		return GiNaC::matrix(nsd,1,GiNaC::lst(diff(f,x), diff(f,y), diff(f,z)));
	}

	/**
	 * Mapped gradient: component c is sum over r of d f / d x_r * G(r,c), with
	 * x_0 = SyFi::x, x_1 = SyFi::y, x_2 = SyFi::z.
	 *
	 * - f is a GiNaC::matrix: returns an nsd x f.rows() matrix whose column k is
	 *   the mapped gradient of component f.op(k).
	 * - otherwise: returns the nsd x 1 mapped gradient of the scalar f.
	 *
	 * @param G an nsd x nsd GiNaC::matrix.
	 * @throws std::invalid_argument if G is not a matrix or is not nsd x nsd.
	 */
	GiNaC::ex grad(GiNaC::ex f, GiNaC::ex G)
	{
		using SyFi::nsd;
		using SyFi::x;
		using SyFi::y;
		using SyFi::z;

		if (!GiNaC::is_a<GiNaC::matrix>(G))
		{
			throw(std::invalid_argument("G must be a matrix."));
		}

		GiNaC::matrix GG = GiNaC::ex_to<GiNaC::matrix>(G);
		if (! (GG.rows() == nsd && GG.cols() == nsd ))
		{
			throw(std::invalid_argument("The number of cols/rows in G must equal nsd."));
		}

		GiNaC::symbol xr;
		if (GiNaC::is_a<GiNaC::matrix>(f))
		{
			GiNaC::matrix m = GiNaC::ex_to<GiNaC::matrix>(f);
			GiNaC::matrix ret_m(nsd,m.rows());
			for (unsigned int k=0; k< m.rows(); k++)
			{
				// Differentiate component k. The former diff(f, xr) differentiated
				// the whole matrix and stored a matrix into every entry.
				const GiNaC::ex component = m.op(k);
				for (unsigned int r=0; r<nsd; r++)
				{
					if (r == 0) xr = x;
					if (r == 1) xr = y;
					if (r == 2) xr = z;
					const GiNaC::ex dfdr = diff(component, xr);
					for (unsigned int c=0; c<nsd; c++)
					{
						ret_m(c,k) += dfdr*GG(r,c);
					}
				}
			}
			return ret_m;
		}

		GiNaC::matrix ret_m(nsd,1);
		for (unsigned int r=0; r<nsd; r++)
		{
			if (r == 0) xr = x;
			if (r == 1) xr = y;
			if (r == 2) xr = z;
			// The derivative does not depend on c, so compute it once per direction.
			const GiNaC::ex dfdr = diff(f, xr);
			for (unsigned int c=0; c<nsd; c++)
			{
				ret_m(c,0) += dfdr*GG(r,c);
			}
		}
		return ret_m;
	}

}								 // namespace SyFi
