/**
 * Parametric SCC based Model Checking
 *  
 * This is a stand-alone tool which performs model checking
 * for parametric discrete-time Markov Chains (PDTMCs).
 * 
 * Copyright (c) 2013 RWTH Aachen University.
 * Authors: Florian Corzilius, Nils Jansen, Matthias Volk
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see
 * http://www.gnu.org/licenses/gpl.html.
 * 
 * 
 * Main Contact:
 * 
 * Nils Jansen
 * Theory of Hybrid Systems
 * RWTH Aachen
 * 52056 Aachen
 * Germany
 * nils.jansen@cs.rwth-aachen.de
 */

#include "../defines.h"
#include "Cancellator2.h"
#include "Parameters.h"
#include "Factorization.h"
#include <sstream>
#include <list>

#ifdef USE_LOGLIB
#include <log4cplus/logger.h>
#endif

using namespace log4cplus;

namespace parametric {

/** Logger shared by all routines of this file; its level is set by the Cancellator constructors. */
Logger cancellogger(Logger::getInstance("Cancellator"));

/**
 * Default constructor.
 * Sets the level of the file-wide logger and starts all timing counters at zero.
 */
Cancellator::Cancellator() :
		cancelTime(0), additionTime(0), commonCoefficientTime(0) {
	cancellogger.setLogLevel(DEFAULT_LOG_LEVEL);
}

/**
 * Copy constructor.
 * Copies the accumulated timing counters and sets the level of the file-wide logger.
 * NOTE(review): gcdPairSet is not copied here, so the copy starts with a
 * default-constructed pair set — confirm this is intended.
 * @param cancellator instance whose timing counters are copied
 */
Cancellator::Cancellator(const Cancellator& cancellator) :
		cancelTime(cancellator.cancelTime),
		additionTime(cancellator.additionTime),
		commonCoefficientTime(cancellator.commonCoefficientTime) {
	cancellogger.setLogLevel(DEFAULT_LOG_LEVEL);
}

/**
 * Destructor. Nothing is released explicitly.
 */
Cancellator::~Cancellator()
{
	// intentionally empty
}

/**
 * Singleton access.
 * The instance is a function-local static that is created on the first call.
 * @return the unique Cancellator instance
 */
Cancellator& Cancellator::getInstance() {
	static Cancellator theCancellator;
	return theCancellator;
}

/**
 * Cancels a rational function in place.
 * Numerator and denominator are divided by their gcd if it is not 1.
 * A zero numerator is normalized to 0/1 without any gcd computation.
 * The time needed for the non-trivial case is added to cancelTime.
 * @param rat rational function, canceled in place
 * @param completely if true, factorization is refined while computing the gcd
 */
void Cancellator::cancel(Rational& rat, bool completely) {
	// 0/x is normalized to 0/1
	if (*rat.numerator == 0) {
		rat.setDenominator(Parameters::getInstance().getOne());
		return;
	}

	const clock_t startClock = clock();

#ifdef DEBUGGING
	Rational original = rat;
#endif

	// The denominator is passed first, because its factorization may be smaller;
	// hence rest1 belongs to the denominator and rest2 to the numerator.
	GCDResult common = computeGCD(rat.denominator, rat.numerator, completely);
	if (*(common.gcd) != 1) {
		rat.setDenominator(common.rest1);
		rat.setNumerator(common.rest2);
	}

#ifdef DEBUGGING
	// original == canceled * gcd must hold for both parts
	ex commonExpr = common.gcd->getExpression();
	assert(original.denominator->getExpression().expand() == (rat.denominator->getExpression() * commonExpr).expand());
	assert(original.numerator->getExpression().expand() == (rat.numerator->getExpression() * commonExpr).expand());
#endif

	cancelTime += clock() - startClock;
}

/**
 * Computes the gcd of two polynomials together with both rests.
 * Currently the factorization based algorithm is always used; the GiNaC based
 * variant is kept as an alternative.
 * @param pol1 first polynomial
 * @param pol2 second polynomial
 * @param completely if true, factorization will be refined (factorization based algorithm only)
 * @return GCDResult with gcd, rest1 (pol1 / gcd) and rest2 (pol2 / gcd)
 */
GCDResult Cancellator::computeGCD(constPoly pol1, constPoly pol2, bool completely) {
	//TODO change if Ginac should be used
	const bool useFactorization = true;
	return useFactorization ? gcdFactorization(pol1, pol2, completely) : gcdGinac(pol1, pol2);
}

/**
 * GCD computation with use of GiNaC's gcd function.
 * The cofactors returned by GiNaC become the rests. Afterwards the factorization
 * of pol1 (resp. pol2) is set to gcd * rest1 (resp. gcd * rest2) whenever that
 * rest is not 1.
 * NOTE(review): when the gcd is 1, setFactors is still called with (1, rest);
 * confirm Factorization::setFactors tolerates such a trivial split.
 * @param pol1 first polynomial
 * @param pol2 second polynomial
 * @return GCDResult with gcd, rest1 (pol1 / gcd) and rest2 (pol2 / gcd)
 */
GCDResult Cancellator::gcdGinac(constPoly pol1, constPoly pol2) {
	ex expr1, expr2;
	// expr1 / expr2 receive the cofactors pol1 / gcd and pol2 / gcd
	ex exprGCD = gcd(pol1->getExpression(), pol2->getExpression(), &expr1, &expr2, false);

	// Avoid -1 occurring as gcd or rest. Negating all three keeps
	// pol1 == expr1 * exprGCD and pol2 == expr2 * exprGCD valid.
	if (exprGCD == -1 || expr1 == -1 || expr2 == -1) {
		exprGCD = -exprGCD;
		expr1 = -expr1;
		expr2 = -expr2;
	}

	GCDResult result(Parameters::getInstance().createPolynomial(exprGCD.expand()),
			Parameters::getInstance().createPolynomial(expr1.expand()),
			Parameters::getInstance().createPolynomial(expr2.expand()));

	// Store the found split pol = gcd * rest in the factorizations
	if (*result.rest1 != 1) {
		pol1->getFactorization()->setFactors(result.gcd, result.rest1);
	}
	if (*result.rest2 != 1) {
		pol2->getFactorization()->setFactors(result.gcd, result.rest2);
	}

#ifdef DEBUGGING
	// Checking result (expensive expansion, therefore only in debugging builds
	// like all other correctness checks in this file)
	assert(pol1->getExpression().expand() == (expr1 * exprGCD).expand());
	assert(pol2->getExpression().expand() == (expr2 * exprGCD).expand());
#endif

	return result;
}

/**
 * Computes the gcd of two polynomials and the rest of both polynomials by using
 * their stored factorizations instead of a full polynomial gcd.
 *
 * Phase 1: every factor of pol1 is searched for in the factorization of pol2,
 * also with negated sign. Common factors (with the smaller exponent) form the
 * gcd; all leftovers are collected in factorList1 / factorList2. A common factor
 * occurring with opposite signs contributes a power of -1 to the rest of pol2.
 *
 * Phase 2 (only if completely is true): the leftover factors are compared
 * pairwise. Equal factors (up to sign) are moved into the gcd as in phase 1.
 * For pairs where at least one factor is not irreducible and which are not yet
 * known to be coprime, gcdGinac() is used to split them; the pieces are appended
 * to the lists and examined later in the same loop. Coprime pairs are
 * remembered in gcdPairSet. The factorizations of the involved factors are
 * refined during this phase.
 *
 * @param pol1 first polynomial
 * @param pol2 second polynomial
 * @param completely if true, gcd will be computed completely by refining the current factorization
 * @return GCDResult with gcd, rest1 (pol1 / gcd) and rest2 (pol2 / gcd)
 */
GCDResult Cancellator::gcdFactorization(constPoly pol1, constPoly pol2, bool completely) {
	//Avoid trivial cases
	if (*pol1 == 1 || *pol2 == 1 || *pol1 == 0 || *pol2 == 0) {
		//Notice: case for 0 is not really correct (gcd(0, p) = p), but suffices for our computations
		return GCDResult(Parameters::getInstance().getOne(), pol1, pol2);
	} else if (pol1 == pol2) {
		return GCDResult(pol1, Parameters::getInstance().getOne(), Parameters::getInstance().getOne());
	}

	FactorList gcd;
	FactorMap factors1 = pol1->getFactorization()->getFactors();
	FactorMap factors2 = pol2->getFactorization()->getFactors();

	//Construct two (temporary) lists containing the factorization of the rests at the end
	FactorList factorList1;
	FactorList factorList2;
	for (factors_it iter1 = factors1.begin(); iter1 != factors1.end(); /*Increment done later*/) {
		bool found;
		//Search for common factors
		factors_it iter2 = factors2.find(iter1->first);
		if (iter2 != factors2.end()) {
			found = true;
		} else {
			//Possible sign
			constPoly minusIter1 = Parameters::getInstance().getMinusOne()->times(iter1->first);
			iter2 = factors2.find(minusIter1);
			found = (iter2 != factors2.end());
		}
		if (found) {
			//Part of gcd found
			constPoly factor = iter1->first;
			unsigned int count1 = iter1->second;
			unsigned int count2 = iter2->second;

			constPoly minusIter2 = Parameters::getInstance().getMinusOne()->times(iter2->first);
			//Possible sign for rest2: (-f)^c2 = (-1)^min(c1,c2) * f^min(c1,c2) * (-f)^(c2-min)
			if (iter1->first == minusIter2) {
				if (count1 >= count2) {
					factorList2.push_back(Factor(Parameters::getInstance().getMinusOne(), count2));
				} else {
					factorList2.push_back(Factor(Parameters::getInstance().getMinusOne(), count1));
				}
			}

			if (count1 == count2) {
				gcd.push_back(Factor(factor, count1));
				//Remove gcd from both factorizations
				iter1 = factors1.erase(iter1);
				factors2.erase(iter2);
			} else if (count1 < count2) {
				gcd.push_back(Factor(factor, count1));
				//Remove gcd from first factorization
				iter1 = factors1.erase(iter1);
				//Update second factor
				iter2->second = count2 - count1;
			} else {
				//count2 < count1
				gcd.push_back(Factor(factor, count2));
				//Add updated factor to list
				factorList1.push_back(Factor(iter1->first, count1 - count2));
				++iter1;
				//Remove gcd from second factorization
				factors2.erase(iter2);
			}
		} else {
			//Add factor to list
			factorList1.push_back(Factor(iter1->first, iter1->second));
			++iter1;
		}
	}

	for (factors_it iter = factors2.begin(); iter != factors2.end(); ++iter) {
		factorList2.push_back(Factor(iter->first, iter->second));
	}

	if (completely) {
		//Consider all factors. Erasing always continues with the iterator returned by
		//erase(); erased iterators are never touched again.
		factorList_it iter1 = factorList1.begin();
		while (iter1 != factorList1.end()) {
			//Set when the current factor of the first list was erased;
			//iter1 then already points to its successor
			bool removed1 = false;
			factorList_it iter2 = factorList2.begin();
			while (iter2 != factorList2.end()) {
				unsigned int count1 = iter1->exponent;
				unsigned int count2 = iter2->exponent;

				constPoly minusIter2 = Parameters::getInstance().getMinusOne()->times(iter2->pol);
				if (iter1->pol == iter2->pol || iter1->pol == minusIter2) {
					//Part of gcd found
					constPoly factor = iter1->pol;

					//Possible sign for rest2
					if (iter1->pol == minusIter2) {
						if (count1 >= count2) {
							factorList2.push_back(Factor(Parameters::getInstance().getMinusOne(), count2));
						} else {
							factorList2.push_back(Factor(Parameters::getInstance().getMinusOne(), count1));
						}
					}

					if (count1 == count2) {
						gcd.push_back(Factor(factor, count1));
						//Remove gcd from both factorizations
						iter1 = factorList1.erase(iter1);
						factorList2.erase(iter2);
						//Current factor1 was removed
						removed1 = true;
						break;
					} else if (count1 < count2) {
						gcd.push_back(Factor(factor, count1));
						//Remove gcd from first factorization
						iter1 = factorList1.erase(iter1);
						//Update second factor
						iter2->exponent = count2 - count1;
						//Current factor1 was removed
						removed1 = true;
						break;
					} else {
						//count2 < count1
						gcd.push_back(Factor(factor, count2));
						//Update first factor
						iter1->exponent = count1 - count2;
						//Remove gcd from second factorization and continue with its successor
						iter2 = factorList2.erase(iter2);
						continue;
					}
				} else {
					//Possible refinement
					constPoly factor1 = iter1->pol;
					constPoly factor2 = iter2->pol;

					if (!factor1->isIrreducible() || !factor2->isIrreducible()) {
						//Maybe gcd was already computed with result 1
						//TODO check if working as planned
						if (!gcdPairSet.isInGCDPairSet(factor1, factor2)) {
							//Compute GCD for better factorization with use of Ginac
							GCDResult gcdResult = gcdGinac(factor1, factor2);
							constPoly polGCD = gcdResult.gcd;

							if (*polGCD != 1) {
								//Extend gcd
								unsigned int exponent = count1 < count2 ? count1 : count2;
								gcd.push_back(Factor(polGCD, exponent));

								if (TriBool::toBoolUndefTrue(*polGCD != *factor1)) {
									//Update factorization fac=rest*gcd
									factor1->getFactorization()->setFactors(gcdResult.rest1, polGCD);
									//Insert rest at end (examined later in this loop)
									factorList1.push_back(Factor(gcdResult.rest1, count1));
								}

								if (TriBool::toBoolUndefTrue(*polGCD != *factor2)) {
									//Update factorization fac=rest*gcd
									factor2->getFactorization()->setFactors(gcdResult.rest2, polGCD);
									//Insert rest at end
									factorList2.push_back(Factor(gcdResult.rest2, count2));
								}

								if (count1 > count2) {
									//Insert part of gcd
									factorList1.push_back(Factor(polGCD, count1 - count2));
								} else if (count2 > count1) {
									//Insert part of gcd
									factorList2.push_back(Factor(polGCD, count2 - count1));
								}

								//Remove old factors from factorizations
								iter1 = factorList1.erase(iter1);
								factorList2.erase(iter2);
								//Current factor1 was removed
								removed1 = true;
								break;
							} else {
								//Result of gcd is 1
								gcdPairSet.insertGCDPair(factor1, factor2);
							}
						}
					}
				}
				++iter2;
			}
			if (!removed1) {
				++iter1;
			}
		}
	}

	if (gcd.empty()) {
		//No gcd found
		return GCDResult(Parameters::getInstance().getOne(), pol1, pol2);
	}

	//Compute rest from leftovers of factorList
	constPoly gcdPol = Parameters::getInstance().createPolynomial(gcd);
	constPoly rest1Pol = Parameters::getInstance().createPolynomial(factorList1);
	constPoly rest2Pol = Parameters::getInstance().createPolynomial(factorList2);

#ifdef DEBUGGING
	//Checking result for correctness
	assert((gcdPol->getExpression() * rest1Pol->getExpression()).expand() == pol1->getExpression().expand());
	assert((gcdPol->getExpression() * rest2Pol->getExpression()).expand() == pol2->getExpression().expand());
	ex result = GiNaC::gcd(pol1->getExpression(), pol2->getExpression());
	//Compare GiNac's and our result
	ex resDiv;
	assert(divide(result.expand(), gcdPol->getExpression().expand(), resDiv));
	if (completely) {
		assert(resDiv == 1 || resDiv == -1);
	}
#endif

	return GCDResult(gcdPol, rest1Pol, rest2Pol);
}

/**
 * Addition of two rational functions num1/denom1 + num2/denom2 w.r.t. the LCM of
 * the denominators.
 * With denom1 = g * d1, denom2 = g * d2 (g = gcd of the denominators) the common
 * denominator is d1 * denom2. With num1 = h * n1, num2 = h * n2 (h = gcd of the
 * numerators) the new numerator is h * (n1 * d2 + n2 * d1).
 * The result is canceled completely before it is returned. The time spent before
 * the final cancellation is added to additionTime.
 * @param rat1 first rational function
 * @param rat2 second rational function
 * @return result as rat1+rat2
 */
Rational Cancellator::additionFactorization(const Rational& rat1, const Rational& rat2) {
	const clock_t startClock = clock();

	Rational sum;

	// g = gcd(denom1, denom2); rest1 = d1, rest2 = d2
	GCDResult denomGCD = computeGCD(rat1.denominator, rat2.denominator, false);

	// lcm = d1 * denom2
	sum.setDenominator(denomGCD.rest1->times(rat2.denominator));

	// h = gcd(num1, num2); rest1 = n1, rest2 = n2
	//TODO only factorization and not completely with GCD?
	GCDResult numGCD = computeGCD(rat1.numerator, rat2.numerator, false);

	// each numerator rest is scaled by the rest of the other denominator
	constPoly scaledSecond = numGCD.rest2->times(denomGCD.rest1);
	constPoly scaledFirst = numGCD.rest1->times(denomGCD.rest2);

	sum.setNumerator(numGCD.gcd->times(scaledFirst->add(scaledSecond)));

	additionTime += clock() - startClock;

	// Cancel completely
	cancel(sum, true);

#ifdef DEBUGGING
	// Compare with GiNaC's own addition
	ex expected = (rat1.getExpression() + rat2.getExpression()).numer_denom();
	assert(expected.nops() == 2);

	ex computed = sum.getExpression().numer_denom();
	assert(computed.nops() == 2);

	assert(expected.expand() == computed.expand());
#endif

	return sum;
}

/**
 * Greatest common numeric coefficient of a polynomial.
 * Delegates to the expression overload; the time needed is added to commonCoefficientTime.
 * @param pol polynomial
 * @return greatest common (numeric) coefficient
 */
numeric Cancellator::getCommonCoefficient(constPoly pol) {
	const clock_t startClock = clock();
	numeric coefficient = getCommonCoefficient(pol->getExpression());
	commonCoefficientTime += clock() - startClock;
	return coefficient;
}

/**
 * Greatest common numeric coefficient of a (polynomial) expression.
 * - number: the number itself
 * - sum: GiNaC gcd of the coefficients of all summands
 * - product: product of all numeric factors
 * - anything else (symbols, powers, ...): 1
 * @param expr expression
 * @return greatest common (numeric) coefficient
 */
numeric Cancellator::getCommonCoefficient(const ex& expr) {
	if (is_exactly_a<numeric>(expr)) {
		return ex_to<numeric>(expr);
	}

	if (is_exactly_a<GiNaC::add>(expr)) {
		const_iterator term = expr.begin();
		numeric common = getCommonCoefficient(*term);
		while (++term != expr.end()) {
			common = gcd(common, getCommonCoefficient(*term));
		}
		return common;
	}

	if (is_exactly_a<mul>(expr)) {
		numeric product = 1;
		for (const_iterator factor = expr.begin(); factor != expr.end(); ++factor) {
			if (is_exactly_a<numeric>(*factor)) {
				product *= ex_to<numeric>(*factor);
			}
		}
		return product;
	}

	// symbols, powers and all other expressions have no numeric coefficient
	return 1;
}

//#############
//# GCDRESULT #
//#############

/**
 * Constructor.
 * With USE_POLY_MANAGEMENT, the usage counter of every stored polynomial is incremented.
 * @param gcd gcd result
 * @param rest1 remaining part of first polynomial
 * @param rest2 remaining part of second polynomial
 */
GCDResult::GCDResult(constPoly gcd, constPoly rest1, constPoly rest2) :
		gcd(gcd), rest1(rest1), rest2(rest2) {
#ifdef USE_POLY_MANAGEMENT
	this->gcd->incUsageCounter();
	this->rest1->incUsageCounter();
	this->rest2->incUsageCounter();
#endif
}

/**
 * Copy constructor.
 * Shares the polynomials of the source; with USE_POLY_MANAGEMENT their usage
 * counters are incremented once more.
 * @param result gcdresult to copy
 */
GCDResult::GCDResult(const GCDResult& result) :
		gcd(result.gcd),
		rest1(result.rest1),
		rest2(result.rest2) {
#ifdef USE_POLY_MANAGEMENT
	this->gcd->incUsageCounter();
	this->rest1->incUsageCounter();
	this->rest2->incUsageCounter();
#endif
}

/**
 * Destructor.
 * With USE_POLY_MANAGEMENT, releases one usage of each stored polynomial.
 */
GCDResult::~GCDResult() {
#ifdef USE_POLY_MANAGEMENT
	this->gcd->decUsageCounter();
	this->rest1->decUsageCounter();
	this->rest2->decUsageCounter();
#endif
}

/**
 * Output of a gcd computation; only the gcd member is written.
 * NOTE(review): result.gcd is a constPoly; if that is a raw pointer type without
 * a dedicated operator<<, this prints the address — confirm.
 * @param stream stream
 * @param result result of gcd computation
 * @return the given stream
 */
std::ostream &operator<<(std::ostream &stream, const GCDResult& result) {
	stream << result.gcd;
	return stream;
}
}
