/*
 * This file is part of a parser for an extension of the PRISM language.
 *
 * This is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * The parser is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with the program this parser is part of.
 * If not, see <http://www.gnu.org/licenses/>.
 *
 * Copyright 2007-2010 Bjoern Wachter (Bjoern.Wachter@comlab.ox.ac.uk)
 * Copyright 2009-2012 Ernst Moritz Hahn (emh@cs.uni-saarland.de)
 */

#include <limits>
#include <string>
#include <stdio.h>
#include <iostream>
#include "AST.h"
#include "Util.h"
#include "Property.h"
#include "PropertyImpl.h"
#include "Model.h"
#include "ModelImpl.h"
#include "PRISMParser.h"
extern void PRISMparse();
extern FILE *PRISMin;
#include "Expr.h"
#include "System.h"

// Name of the file currently being parsed; assigned by PRISMParser::run.
// NOTE(review): defined outside namespace prismparser, presumably so the
// generated lexer/parser can reference it (e.g. in diagnostics) — confirm.
std::string file_name;

namespace prismparser {
extern unsigned rewStructNr;

using namespace std;

class PRISMParserImpl {
private:
	static Expr translateBaseExpr(boost::shared_ptr<prismparser_ast::Expr>);
	static void translateModel(prismparser_ast::Model &, Model &);
	static ProbBound *boundFromAST(const prismparser_ast::Expr &, bool, bool);
	static Filter *filterFromAST(const prismparser_ast::Expr &);
	static Property* translateProperty(boost::shared_ptr<prismparser_ast::Expr>);
	static Expr translateExpr(boost::shared_ptr<prismparser_ast::Expr>);
	static Alternative* translateAlternative(boost::shared_ptr<prismparser_ast::Alternative>);
	static Command* translateCommand(boost::shared_ptr<prismparser_ast::Command>);
	static Module* translateModule(boost::shared_ptr<prismparser_ast::Module>);
	static void translateVariables(const prismparser_ast::Variables &, Model &);

	friend class PRISMParser;
};

// Reward-structure name -> number table (defined elsewhere); translateProperty
// looks names up here when a reward property references a structure by name.
extern prismparser_ast::RewNameToNr rewNameToNr;
// Current input line number (defined elsewhere); PRISMParser::run resets it to 1.
extern int line_number;

// AST of the most recently parsed file. NOTE(review): presumably filled by the
// grammar actions of PRISMparse(); PRISMParser::run translates and clears it.
prismparser_ast::Model prismparser::PRISMParser::astModel;
// Constant substitution table; not referenced in this file.
// NOTE(review): presumably used by the grammar/AST code — confirm.
prismparser_ast::Substitution constants;
// Constant definitions supplied from outside via PRISMParser::addConstDef.
prismparser_ast::ExternalConstants externalConstants;
// System object produced during parsing, if any. PRISMParser::run resets it to
// NULL and translateModel hands a non-NULL value to the model.
prismparser::System *prismsystem;

// Neither construction nor destruction has any work to do: the state used by
// run() lives in file-level globals and the static PRISMParser::astModel.
PRISMParser::PRISMParser() {}

PRISMParser::~PRISMParser() {}

/*
 * Parses the PRISM file `file` and translates the result into `model`.
 *
 * Throws prismparser_error if the file cannot be opened. Any exception thrown
 * by PRISMparse() or by the translation is propagated to the caller.
 *
 * Both the success and the failure paths close the input file and clear the
 * static astModel. Otherwise a failed run would leak the FILE handle, and the
 * next run would insert the stale declarations a second time.
 */
void PRISMParser::run(const string& file, prismparser::Model &model) {
	//    rewStructNr = 0;
	prismsystem = NULL;
	file_name = file;
	line_number = 1;
	PRISMin = fopen(file.c_str(), "r");
	if (NULL == PRISMin) {
		throw prismparser_error("File " + file + " not found\n");
	}

	try {
		PRISMparse();
	} catch (...) {
		fclose(PRISMin);
		PRISMin = NULL;
		PRISMParser::astModel.clear(); // drop the partially built AST
		throw;
	}
	// The translation below only reads the AST, so the input can be released now.
	fclose(PRISMin);
	PRISMin = NULL; // do not leave a dangling FILE* in the lexer's global

	try {
		PRISMParserImpl::translateModel(PRISMParser::astModel, model);
	} catch (...) {
		PRISMParser::astModel.clear(); // avoid double insertions on the next run
		throw;
	}
	PRISMParser::astModel.clear(); // avoid double insertions into ::model
}

void PRISMParser::addConstDef(const string &constant, const string &value) {
	externalConstants.insert(make_pair(constant, value));
}

// Wraps an AST node in the public Expr type without any further processing
// (no simplification and no `init` substitution).
Expr PRISMParserImpl::translateBaseExpr(boost::shared_ptr<prismparser_ast::Expr> ae) {
	Expr wrapped(ae);
	return wrapped;
}

/*
 * Builds the ProbBound of a P / S / reward operator from its bound AST node.
 *
 * A comparison node (>, >=, <, <=, =) maps to the matching BoundType. Any
 * other kind maps to DkBound ("= ?": the value is to be computed).
 * The bound value is taken from children[1] and simplified, but only when the
 * node has exactly two children; otherwise a default-constructed Expr is
 * stored. `minimize` and `minOrMax` are copied into the result unchanged.
 * The returned ProbBound is heap-allocated.
 */
ProbBound *PRISMParserImpl::boundFromAST(const prismparser_ast::Expr& boundExpr, bool minimize, bool minOrMax) {
	Expr value;
	if (boundExpr.arity() == 2) {
		value = translateBaseExpr(boundExpr.children[1]).simplify();
	}

	BoundType kind = DkBound; // "= ?" ... value to be computed
	switch (boundExpr.getKind()) {
	case prismparser_ast::Gt: kind = GrBound; break; // >  strictly greater
	case prismparser_ast::Ge: kind = GeBound; break; // >= greater or equal
	case prismparser_ast::Lt: kind = LtBound; break; // <  strictly less
	case prismparser_ast::Le: kind = LeBound; break; // <= less or equal
	case prismparser_ast::Eq: kind = EqBound; break; // =  equal
	default: break;
	}

	ProbBound *pb = new ProbBound();
	pb->impl->type = kind;
	pb->impl->bound = new Expr(value);
	pb->impl->min = minimize;
	pb->impl->minOrMax = minOrMax;
	return pb;
}

/*
 * Translates a filter AST node into a Filter, or returns NULL when the
 * property has no filter (NullFilter).
 *
 * filterExpr must be a prismparser_ast::Filter node with two children:
 * children[0] is the filter's state expression and children[1] is the
 * filter-operator node. The returned Filter is heap-allocated;
 * translateProperty stores it in the Property it builds.
 *
 * An unrecognised operator asserts in debug builds. In every build it throws
 * prismparser_error, so a Filter whose type was never set is not handed out.
 */
Filter *PRISMParserImpl::filterFromAST(const prismparser_ast::Expr &filterExpr) {
	assert(prismparser_ast::Filter == filterExpr.getKind());
	assert(2 == filterExpr.arity());
	Expr filter(translateBaseExpr(filterExpr.children[0]));

	if (prismparser_ast::NullFilter == filterExpr.children[1]->getKind()) {
		return NULL;
	}
	Filter *f = new Filter;
	f->impl->expr = new Expr(filter);
	switch (filterExpr.children[1]->getKind()) {
	case prismparser_ast::MinFilter:
		f->impl->type = MinFilter;
		break;
	case prismparser_ast::MaxFilter:
		f->impl->type = MaxFilter;
		break;
	case prismparser_ast::CountFilter:
		f->impl->type = CountFilter;
		break;
	case prismparser_ast::SumFilter:
		f->impl->type = SumFilter;
		break;
	case prismparser_ast::AvgFilter:
		f->impl->type = AvgFilter;
		break;
	case prismparser_ast::FirstFilter:
		f->impl->type = FirstFilter;
		break;
	case prismparser_ast::RangeFilter:
		f->impl->type = RangeFilter;
		break;
	case prismparser_ast::ForallFilter:
		f->impl->type = ForallFilter;
		break;
	case prismparser_ast::ExistsFilter:
		f->impl->type = ExistsFilter;
		break;
	case prismparser_ast::StateFilter:
		f->impl->type = StateFilter;
		break;
	case prismparser_ast::ArgminFilter:
		f->impl->type = ArgminFilter;
		break;
	case prismparser_ast::ArgmaxFilter:
		f->impl->type = ArgmaxFilter;
		break;
	case prismparser_ast::PrintFilter:
		f->impl->type = PrintFilter;
		break;
	default:
		// Grammar and translator disagree about the filter kinds. Fail loudly
		// even with NDEBUG instead of returning a Filter with an unset type.
		assert(false);
		delete f;
		throw prismparser_error("unsupported filter operator in property");
	}

	return f;
}

// Substitution that translateProperty applies to every plain (non-temporal)
// expression it produces. Before translating the properties, translateModel
// maps the variable `init` to the model's initial-state expression here.
static ExprHashMap<Expr> replaceInit;

Property* PRISMParserImpl::translateProperty(boost::shared_ptr<prismparser_ast::Expr> ae) {
	Property* result = NULL;
	assert(ae.get());
	prismparser_ast::Expr e(*ae.get());

	switch (e.getKind()) {
	case prismparser_ast::Next: {
		Property* p1(translateProperty(e.children[0]));
		TimeBound *tb = new TimeBound();

		result = new Property();
		result->impl->type = UntilProp;
		result->impl->timeBound = tb;
		result->impl->children.push_back(p1);
		Expr t1 = Expr(e.children[1]).simplify();
		if (!t1.isRational()) {
			throw prismparser_error("time bound of Next property does not evaluate to constant");
		}
		Expr t2 = Expr(e.children[2]).simplify();
		if (!t2.isRational()) {
			throw prismparser_error("time bound of Next property does not evaluate to constant");
		}
		tb->impl->t1 = t1.getRatAsDouble();
		tb->impl->t2 = t2.getRatAsDouble();
		break;
	}
	case prismparser_ast::Until: {
		Property* p1(translateProperty(e.children[0]));
		Property* p2(translateProperty(e.children[1]));
		TimeBound *tb = new TimeBound();
		result = new Property();
		result->impl->type = UntilProp;
		result->impl->timeBound = tb;
		result->impl->children.push_back(p1);
		result->impl->children.push_back(p2);
		Expr t1 = Expr(e.children[2]).simplify();
		if (!t1.isRational()) {
			throw prismparser_error("time bound of Until property does not evaluate to constant");
		}
		Expr t2 = Expr(e.children[3]).simplify();
		if (!t2.isRational()) {
			throw prismparser_error("time bound of Until property does not evaluate to constant");
		}
		tb->impl->t1 = t1.getRatAsDouble();
		tb->impl->t2 = t2.getRatAsDouble();
		break;
	}
	case prismparser_ast::P:
	case prismparser_ast::Steady:
	case prismparser_ast::ReachabilityReward:
	case prismparser_ast::CumulativeReward:
	case prismparser_ast::InstantaneousReward:
	case prismparser_ast::SteadyStateReward: {
		const int direction((e.children[0].get())->getInt());
		const prismparser_ast::Expr& bound_expr(*e.children[1].get());
		const boost::shared_ptr<prismparser_ast::Expr> &inner_expr(e.children[2]);
		const prismparser_ast::Expr& filter(*e.children[4].get());
		unsigned rewNr = 0;
		if ((e.getKind() != prismparser_ast::P) && (e.getKind() != prismparser_ast::Steady)) {
			Expr rewStruct = Expr(e.children[3]).simplify();
			if (rewStruct.isInteger()) {
				rewNr = rewStruct.getNumerator();
				rewNr--;
			} else if (rewStruct.isVar()) {
				const string rewName(rewStruct.getName());
				if (0 == rewNameToNr.count(rewName)) {
					throw prismparser_error("Reward structure \"" + rewName + "\" not specified.");
				}
				rewNr = rewNameToNr[rewName];
			}
		}

		bool min(false);
		bool minOrMax(false);
		if (0 == direction) {
			if ((prismparser_ast::Gt == bound_expr.getKind()) || (prismparser_ast::Ge == bound_expr.getKind())) {
				min = true;
				minOrMax = false;
			} else if ((prismparser_ast::Lt == bound_expr.getKind()) || (prismparser_ast::Le == bound_expr.getKind())) {
				min = false;
				minOrMax = false;
			} else {
				min = false;
				minOrMax = false;
			}
		} else if (-1 == direction) {
			min = true;
			minOrMax = true;
		} else if (1 == direction) {
			min = false;
			minOrMax = true;
		}

		result = new Property();
		result->impl->rewStructNr = rewNr;
		result->impl->probBound = boundFromAST(bound_expr, min, minOrMax);
		result->impl->filter = filterFromAST(filter);
		if (prismparser_ast::P == e.getKind()) {
			result->impl->type = QuantProp;
			result->impl->children.push_back(translateProperty(inner_expr));
		} else if (prismparser_ast::Steady == e.getKind()) {
			result->impl->type = SteadySProp;
			result->impl->children.push_back(translateProperty(inner_expr));
		} else if (prismparser_ast::ReachabilityReward == e.getKind()) {
			result->impl->type = ReachRewProp;
			result->impl->children.push_back(translateProperty(inner_expr));
		} else if (prismparser_ast::CumulativeReward == e.getKind()) {
			result->impl->type = CumulRewProp;
			result->impl->doubleTime = inner_expr->getDoubleVal();
		} else if (prismparser_ast::SteadyStateReward == e.getKind()) {
			result->impl->type = SteadySRewProp;
		} else if (prismparser_ast::InstantaneousReward == e.getKind()) {
			result->impl->type = InstRewProp;
			result->impl->doubleTime = inner_expr->getDoubleVal();
		}
		break;
	}
	case prismparser_ast::Not: {
		Property *inner = translateProperty(e.children[0]);
		if (ExprProp == inner->getType()) {
			result = new Property();
			result->impl->type = ExprProp;
			result->impl->expr = new Expr(Expr::notExpr(inner->getExpr()));
			delete inner;
		} else {
			result = new Property();
			result->impl->type = NegProp;
			result->impl->children.push_back(inner);
		}
		break;
	}
	case prismparser_ast::And: {
		Property *innerA = translateProperty(e.children[0]);
		Property *innerB = translateProperty(e.children[1]);
		if ((ExprProp == innerA->getType()) && (ExprProp == innerB->getType())) {
			result = new Property();
			result->impl->type = ExprProp;
			result->impl->expr = new Expr(Expr::andExpr(innerA->getExpr(), innerB->getExpr()));
			delete innerA;
			delete innerB;
		} else {
			result = new Property();
			result->impl->type = AndProp;
			result->impl->children.push_back(innerA);
			result->impl->children.push_back(innerB);
		}
		break;
	}
	case prismparser_ast::Or: {
		Property *innerA = translateProperty(e.children[0]);
		Property *innerB = translateProperty(e.children[1]);
		if ((ExprProp == innerA->getType()) && (ExprProp == innerB->getType())) {
			result = new Property();
			result->impl->type = ExprProp;
			result->impl->expr = new Expr(Expr::orExpr(innerA->getExpr(), innerB->getExpr()));
			delete innerA;
			delete innerB;
		} else {
			result = new Property();
			result->impl->type = OrProp;
			result->impl->children.push_back(innerA);
			result->impl->children.push_back(innerB);
		}
		break;
	}
	case prismparser_ast::Impl: {
		Property *innerA = translateProperty(e.children[0]);
		Property *innerB = translateProperty(e.children[1]);
		if ((ExprProp == innerA->getType()) && (ExprProp == innerB->getType())) {
			result = new Property();
			result->impl->type = ExprProp;
			result->impl->expr = new Expr(Expr::implExpr(innerA->getExpr(), innerB->getExpr()));
			delete innerA;
			delete innerB;
		} else {
			result = new Property();
			result->impl->type = ImplProp;
			result->impl->children.push_back(innerA);
			result->impl->children.push_back(innerB);
		}
		break;
	}
	default: {
		Expr nested_expr(translateBaseExpr(ae));
		nested_expr = nested_expr.substExpr(replaceInit);
		result = new Property();
		result->impl->type = ExprProp;
		result->impl->expr = new Expr(nested_expr);
		break;
	}
	}
	return result;
}

/*
 * Translates an AST node that must denote a plain (non-temporal) expression.
 *
 * The node goes through translateProperty, so boolean connectives are folded
 * and replaceInit is applied exactly as for properties. If the result is not
 * an ExprProp, the node contained an operator that has no Expr form. The
 * intermediate Property is then freed and prismparser_error is thrown. The
 * check is done in every build, not only when assertions are enabled.
 */
Expr PRISMParserImpl::translateExpr(boost::shared_ptr<prismparser_ast::Expr> ae) {
	Property *prop = translateProperty(ae);
	if (ExprProp != prop->getType()) {
		delete prop;
		throw prismparser_error("probabilistic, reward or temporal operator used where a plain expression is expected");
	}
	Expr result(prop->getExpr());
	delete prop;

	return result;
}

/*
 * Translates one weighted alternative of a command.
 *
 * Each assignment of the alternative's update is applied in source order;
 * for each one the left-hand side is translated before the right-hand side.
 * The weight expression is translated afterwards. The returned Alternative
 * is heap-allocated.
 */
Alternative* PRISMParserImpl::translateAlternative(boost::shared_ptr<prismparser_ast::Alternative> aa) {
	const prismparser_ast::Alternative &alt = *aa;
	const prismparser_ast::Update &upd = alt.update;
	Alternative *translated = new Alternative();

	typedef prismparser_ast::Assignment::const_iterator AssignIter;
	for (AssignIter it = upd.assignment.begin(); it != upd.assignment.end(); ++it) {
		Expr target(translateExpr(it->first));
		Expr value(translateExpr(it->second));
		translated->impl->Assign(target, value);
	}

	Expr weight(translateExpr(alt.weight));
	translated->impl->setWeight(weight);
	return translated;
}

/*
 * Translates one guarded command. Its alternatives are added in source order,
 * then the guard is set, then the action label. The returned Command is
 * heap-allocated.
 */
Command* PRISMParserImpl::translateCommand(boost::shared_ptr<prismparser_ast::Command> ac) {
	const prismparser_ast::Command& command(*ac.get());
	Command* result(new Command());

	for (prismparser_ast::Alternatives::const_iterator i(command.alternatives.begin()); i != command.alternatives.end(); ++i) {
		result->impl->addAlternative(translateAlternative(*i));
	}

	result->impl->setGuard(translateExpr(command.guard));
	result->impl->setAction(command.label);
	return result;
}

/*
 * Translates a module: copies its name and translates each of its commands
 * in source order. The returned Module is heap-allocated.
 */
Module* PRISMParserImpl::translateModule(boost::shared_ptr<prismparser_ast::Module> am) {
	const prismparser_ast::Module &source = *am;
	Module *translated = new Module();
	translated->impl->name = source.name;

	typedef prismparser_ast::Commands::const_iterator CommandIter;
	for (CommandIter it = source.commands.begin(); it != source.commands.end(); ++it) {
		translated->impl->addCommand(translateCommand(*it));
	}
	return translated;
}

/*
 * Declares every variable of `vars` in `model`. For each variable it sets the
 * type, the bounds and the default initial value, and it marks parameters.
 *
 * Bounds are passed to Expr::setVarBounds as (lowerNum, lowerDen, upperNum,
 * upperDen). Booleans use [0/1, 1/1]. Integers and doubles without a range
 * use -1/0 and 1/0 (NOTE(review): presumably -inf / +inf).
 *
 * A ranged variable's bounds must simplify to rational constants. Because
 * they come from user input, this is checked in every build, and
 * prismparser_error is thrown otherwise.
 *
 * Without an explicit init, booleans start false, integers and doubles start
 * at 0, and ranged variables start at their lower bound.
 */
void PRISMParserImpl::translateVariables(const prismparser_ast::Variables& vars, Model& model) {
	for (prismparser_ast::Variables::const_iterator i(vars.begin()); i != vars.end(); ++i) {
		const prismparser_ast::Variable& var(*i->second.get());

		Expr var_expr;

		switch (var.type->kind) {
		case prismparser_ast::Type::Boolean: {
			var_expr = Expr::varExpr(i->first);
			var_expr.setVarType(BoolVar);
			Expr::setVarBounds(var_expr, 0, 1, 1, 1);
			model.impl->addVariable(var_expr);
			model.impl->setDefaultInitialValue(var_expr, var.init.get() ? translateExpr(var.init) : Expr::falseExpr());
		}
			break;
		case prismparser_ast::Type::Integer: {
			var_expr = Expr::varExpr(i->first);
			var_expr.setVarType(IntVar);
			Expr::setVarBounds(var_expr, -1, 0, 1, 0);
			model.impl->addVariable(var_expr);
			model.impl->setDefaultInitialValue(var_expr, var.init.get() ? translateExpr(var.init) : Expr::ratExpr(0ll, 1ll));
		}
			break;
		case prismparser_ast::Type::Double: {
			var_expr = Expr::varExpr(i->first);
			var_expr.setVarType(RealVar);
			Expr::setVarBounds(var_expr, -1, 0, 1, 0);
			model.impl->addVariable(var_expr);
			model.impl->setDefaultInitialValue(var_expr, var.init.get() ? translateExpr(var.init) : Expr::ratExpr(0ll, 1ll));
		}
			break;
		case prismparser_ast::Type::Range: {
			Expr lower(translateExpr(var.type->range_data.lower).simplify());
			Expr upper(translateExpr(var.type->range_data.upper).simplify());
			// getNumerator()/getDenominator() are only meaningful on rational
			// constants, so reject anything else before using them.
			if (!lower.isRational() || !upper.isRational()) {
				throw prismparser_error(string("Bounds of variable \"") + i->first + "\" do not evaluate to constants.");
			}
			var_expr = Expr::varExpr(i->first);
			var_expr.setVarType(RangeVar);
			Expr::setVarBounds(var_expr, lower.getNumerator(), lower.getDenominator(), upper.getNumerator(),
					upper.getDenominator());
			model.impl->addVariable(var_expr);
			model.impl->setDefaultInitialValue(var_expr, var.init.get() ? translateExpr(var.init) : lower);
		}
			break;
		}

		if (var.is_parameter) {
			model.impl->setAsParameter(var_expr);
		}
	}
}

/*
 * Translates the parsed AST `am` into `model`, in this order:
 *   1. model type;
 *   2. variables (globals, then each module's locals);
 *   3. modules;
 *   4. initial states, invariants, actions, predicates and properties;
 *   5. state rewards, transition rewards and, if one was parsed, the System
 *      object.
 *
 * All variables are declared before any module is translated (presumably so
 * commands may refer to variables of other modules). Properties come after
 * the initial states, because `init` inside a property is replaced by the
 * model's initial-state expression through replaceInit.
 */
void PRISMParserImpl::translateModel(prismparser_ast::Model &am, Model &model) {
	// replaceInit is file-static, and translateProperty applies it to every
	// plain expression, including the model's own guards, updates and variable
	// bounds. Reset it first so that an `init` substitution left over from a
	// previously parsed model cannot affect this one.
	replaceInit.clear();

	switch (am.model_type) {
	case prismparser_ast::DTMC:
		model.impl->setModelType(DTMC);
		break;
	case prismparser_ast::MDP:
		model.impl->setModelType(MDP);
		break;
	case prismparser_ast::CTMC:
		model.impl->setModelType(CTMC);
		break;
	case prismparser_ast::Unspecified:
		// No model type given: treat the model as an MDP.
		model.impl->setModelType(MDP);
		break;
	}

	/* 1) Variable table
	 *
	 * build the variable table by traversing the model
	 * collecting variables from each module */

	/* global variables */
	translateVariables(am.globals, model);

	/* local module variables */
	for (prismparser_ast::Modules::const_iterator i(am.modules.begin()); i != am.modules.end(); ++i) {
		translateVariables(i->second->locals, model);
	}

	/* 2) translate modules and add them to the model */
	for (prismparser_ast::Modules::const_iterator i(am.modules.begin()); i != am.modules.end(); ++i) {
		model.impl->addModule(translateModule(i->second));
	}

	/* 3) translate the rest */

	// Initial states: use the explicit initial-state expression if the model
	// has one; otherwise let the model compute it via computeDefaultInitialValue().
	if (am.initial.get()) {
		Expr e(translateExpr(am.initial));
		model.impl->setInitial(e);
	}
	if (model.getInitial().isNull()) {
		model.impl->computeDefaultInitialValue();
	}

	// Exprs invariants
	for (prismparser_ast::Exprs::const_iterator i = am.invariants.begin(); i != am.invariants.end(); ++i) {
		Expr e(translateExpr(*i));
		model.impl->addInvariant(e);
	}

	// Actions actions;
	for (prismparser_ast::Actions::const_iterator i = am.actions.begin(); i != am.actions.end(); ++i) {
		model.impl->addAction(*i);
	}

	// Exprs predicates;
	for (prismparser_ast::Exprs::const_iterator i = am.predicates.begin(); i != am.predicates.end(); ++i) {
		Expr e(translateExpr(*i));
		model.impl->addPredicate(e);
	}

	// Exprs properties: `init` inside a property stands for the model's
	// initial-state expression, which is final at this point.
	replaceInit.insert(make_pair(Expr::varExpr("init"), model.getInitial()));
	for (prismparser_ast::Exprs::const_iterator i = am.properties.begin(); i != am.properties.end(); ++i) {
		Property* p(translateProperty(*i));
		model.impl->addProperty(p);
	}

	// StateRewards state_rewards: (structure, guard, reward)
	for (prismparser_ast::StateRewards::const_iterator i = am.state_rewards.begin(); i != am.state_rewards.end(); ++i) {
		const unsigned structure(i->get<0>());
		Expr guard(translateExpr(i->get<1>()));
		Expr reward(translateExpr(i->get<2>()));
		model.impl->rewards[structure].addStateReward(guard.simplify(), reward.simplify());
	}

	// TransitionRewards transition_rewards: (structure, action, guard, reward)
	for (prismparser_ast::TransitionRewards::const_iterator i = am.transition_rewards.begin(); i != am.transition_rewards.end();
			++i) {
		const unsigned structure(i->get<0>());
		Action action(i->get<1>());
		Expr guard(translateExpr(i->get<2>()));
		Expr reward(translateExpr(i->get<3>()));
		model.impl->rewards[structure].addTransReward(action, guard.simplify(), reward.simplify());
	}

	if (NULL != prismsystem) {
		model.impl->system = prismsystem;
	}

}
}

