#include "Reward.h"

#include <sstream>

namespace prismparser {
using namespace std;

void Reward::addStateReward(const Expr &guard, const Expr &expr) {
	stateRewards.push_back(make_pair(guard, expr));
}

void Reward::addTransReward(const string &action, const Expr &guard, const Expr &expr) {
	transRewards[action].push_back(make_pair(guard, expr));
}

/**
 * Returns the state reward as a single expression.
 *
 * The result is the simplified sum of ite(guard, expr, 0) over all entries in
 * stateRewards. It is computed on the first call, stored in stateRewExpr and
 * returned from that cache on subsequent calls.
 *
 * @return reference to the cached combined state reward expression
 */
const Expr &Reward::getStateRewardExpr() const {
	// Cache hit: nothing to build, so do not even construct the zero constant.
	if (!stateRewExpr.isNull()) {
		return stateRewExpr;
	}

	Expr zero = Expr::ratExpr(0l, 1l);

	// No entries at all: the reward is 0, matching what getTransRewardExpr()
	// returns for an action without entries. This also avoids depending on
	// how Expr::sumExpr treats an empty operand list.
	if (0 == stateRewards.size()) {
		stateRewExpr = zero;
		return stateRewExpr;
	}

	vector<Expr> summands;
	summands.reserve(stateRewards.size());
	for (unsigned entryNr = 0; entryNr < stateRewards.size(); entryNr++) {
		const Expr &guard(stateRewards[entryNr].first);
		const Expr &expr(stateRewards[entryNr].second);
		// Each entry contributes its reward only where its guard holds.
		summands.push_back(Expr::iteExpr(guard, expr, zero));
	}
	stateRewExpr = Expr::sumExpr(summands).simplify();
	return stateRewExpr;
}

/**
 * Returns the transition reward for @a action as a single expression.
 *
 * If the action has no entries in transRewards the result is the rational
 * constant 0. Otherwise it is the simplified sum of ite(guard, expr, 0) over
 * the action's entries. The result is stored in transRewExpr[action] and
 * returned from there on later calls.
 *
 * @param action action label to look up
 * @return reference to the cached expression for @a action
 */
const Expr &Reward::getTransRewardExpr(const string &action) const {
	// Single lookup of the cache slot; operator[] creates a null entry for
	// actions that have not been queried before.
	Expr &cached = transRewExpr[action];
	if (!cached.isNull()) {
		return cached;
	}

	Expr zero = Expr::ratExpr(0l, 1l);

	// Single lookup of the action's reward entries.
	TransRewards::const_iterator found = transRewards.find(action);
	if (transRewards.end() == found) {
		cached = zero;
		return cached;
	}

	const vector<pair<Expr, Expr> > &rewards(found->second);
	vector<Expr> summands;
	summands.reserve(rewards.size());
	for (unsigned entryNr = 0; entryNr < rewards.size(); entryNr++) {
		const Expr &guard(rewards[entryNr].first);
		const Expr &expr(rewards[entryNr].second);
		// Each entry contributes its reward only where its guard holds.
		summands.push_back(Expr::iteExpr(guard, expr, zero));
	}
	cached = Expr::sumExpr(summands).simplify();
	return cached;
}

/**
 * Writes a reward structure to @a stream.
 *
 * Output starts with a "rewards" line and ends with an "endrewards" line.
 * In between, each state reward is printed as "  guard : expr;" and each
 * transition reward as "  [action]guard : expr;", one per line. State
 * rewards come first, then transition rewards in the iteration order of
 * the transRewards container.
 *
 * @param stream stream to write to
 * @param e      reward structure to print
 * @return @a stream
 */
ostream &operator<<(ostream &stream, const Reward &e) {
	typedef vector<pair<Expr, Expr> > GuardedRewards;

	stream << "rewards\n";

	// State rewards: "  guard : expr;"
	for (unsigned idx = 0; idx < e.stateRewards.size(); ++idx) {
		stream << "  " << e.stateRewards[idx].first << " : " << e.stateRewards[idx].second << ";\n";
	}

	// Transition rewards: "  [action]guard : expr;"
	Reward::TransRewards::const_iterator actionIt = e.transRewards.begin();
	while (actionIt != e.transRewards.end()) {
		const GuardedRewards &entries = actionIt->second;
		for (GuardedRewards::const_iterator entry = entries.begin(); entry != entries.end(); ++entry) {
			stream << "  " << "[" << actionIt->first << "]" << entry->first << " : " << entry->second << ";\n";
		}
		++actionIt;
	}

	stream << "endrewards\n";
	return stream;
}

/**
 * Returns the textual form of this reward structure.
 *
 * The text is whatever operator<< writes for this object. It is produced on
 * the first call, stored in str and returned from that cache afterwards.
 *
 * @return reference to the cached string representation
 */
const string &Reward::toString() const {
	// Only pay for a string stream when the cache actually has to be filled.
	if (str.empty()) {
		ostringstream sstream;
		sstream << *this;
		str = sstream.str();
	}
	return str;
}
}
