"""Module that converts a problem data and a plan to a trajectory."""
import logging
from typing import List, Tuple

import grounding
from pddl.pddl import Problem, Domain, Action, Predicate
from pyperplan import Parser
from task import Operator

from sam_learner.core.grounded_action_locator import parse_plan_action_string, locate_lifted_action, \
	ground_lifted_action
from sam_learner.sam_models.parameter_binding import ParameterBinding
from sam_learner.sam_models.state import State
from sam_learner.sam_models.trajectory_component import TrajectoryComponent
from sam_learner.sam_models.types import Trajectory


class TrajectoryGenerator:
	"""Class that generates a trajectory out of a problem file and its corresponding plan."""

	logger: logging.Logger
	domain: Domain
	problem: Problem

	def __init__(self, domain_path: str, problem_path: str):
		"""Parse the domain and problem files.

		:param domain_path: the path to the PDDL domain file.
		:param problem_path: the path to the PDDL problem file.
		"""
		parser = Parser(domain_path, problem_path)
		self.logger = logging.getLogger(__name__)
		self.domain = parser.parse_domain(read_from_file=True)
		self.problem = parser.parse_problem(dom=self.domain, read_from_file=True)

	def _parse_plan_actions(self, plan_path: str) -> Tuple[List[str], List[Action]]:
		"""Parse the actions that exist in the plan file and extract the lifted action objects.

		Blank lines and lines starting with ';' (plan comments) are ignored. A final line
		mentioning "cost" that is not a parenthesized action is treated as the plan-cost
		line and dropped.

		:param plan_path: the path to the file containing the plan that solves the input problem.
		:return: both the grounded operator names and the lifted actions to execute.
		"""
		self.logger.debug(f"Parsing the plan file in the path - {plan_path}")
		with open(plan_path, "rt") as plan_file:
			# strip() (not strip("\n")) also removes '\r' from CRLF files and stray spaces,
			# which would otherwise prevent matching the grounded operator names.
			lines = [line.strip() for line in plan_file]

		lines = [line for line in lines if line and not line.startswith(";")]
		# Only drop a non-action line, so a real final action such as "(reduce-cost a)" is kept.
		if lines and "cost" in lines[-1] and not lines[-1].startswith("("):
			lines.pop()  # remove the cost line from the plan.

		actions = []
		self.logger.debug("The plan contains the following grounded actions:")
		for index, line in enumerate(lines):
			self.logger.debug(line)
			action_name, action_params = parse_plan_action_string(line)
			if not action_params:
				# Normalize parameterless actions to "(name)" so the line can be used as a key
				# into the grounded operators mapping.
				lines[index] = f"({action_name})"

			actions.append(locate_lifted_action(self.domain, action_name))

		return lines, actions

	def _create_state_parameter_matching(self, grounded_state: frozenset) -> State:
		"""Match between the grounded and the lifted predicates.

		:param grounded_state: the grounded set of facts that represent the state.
		:return: the state built from predicates whose parameters are bound to the objects.
		"""
		lifted_state_data: List[Tuple[Predicate, List[ParameterBinding]]] = \
			State.generate_problem_state(grounded_state, self.domain, self.problem)
		grounded_predicates = [
			Predicate(predicate.name, [binding.bind_parameter() for binding in bindings])
			for predicate, bindings in lifted_state_data
		]
		return State(grounded_predicates, self.domain)

	def _create_single_trajectory_component(
			self, index: int, action: Action, op_name: str,
			operators: dict, previous_state: State) -> TrajectoryComponent:
		"""Create a single trajectory component by applying the action on the previous state.

		:param index: the index of the step that is being parsed currently.
		:param action: the lifted action that is to be executed on the state.
		:param op_name: the grounded operator's name.
		:param operators: the grounded operators that can apply an action on a state.
		:param previous_state: the previous state to be changed.
		:return: the trajectory component representing the current stage.
		:raises AssertionError: if the grounded operator is not applicable in the previous state.
		:raises KeyError: if op_name is not one of the grounded operators.
		"""
		# NOTE(review): grounding.ground may prune operators it deems irrelevant to the goal;
		# a plan step missing from `operators` surfaces here as a KeyError — confirm.
		grounded_action: Operator = operators[op_name]
		previous_facts = previous_state.ground_facts()
		# Check explicitly instead of relying on an `assert` inside Operator.apply, which is
		# stripped under `python -O` and would silently apply an inapplicable step.
		if not grounded_action.applicable(previous_facts):
			raise AssertionError(f"The operator {op_name} is not applicable in the previous state.")

		grounded_next_state_statements = grounded_action.apply(previous_facts)
		next_state = self._create_state_parameter_matching(grounded_next_state_statements)
		_, action_objects = parse_plan_action_string(grounded_action.name)
		_, bindings = ground_lifted_action(action, action_objects)
		component = TrajectoryComponent(index=index,
										previous_state=previous_state,
										action=action,
										action_parameters_binding=bindings,
										next_state=next_state)
		self.logger.debug(f"Finished creating the component:\n{component}")
		return component

	def generate_trajectory(self, plan_path: str, should_return_partial_trajectory: bool = False) -> Trajectory:
		"""Generate a trajectory out of the problem and a plan file.

		:param plan_path: the path to the plan generated by a solver.
		:param should_return_partial_trajectory: whether or not to return a partial trajectory in case of failure.
		:return: the trajectory that represents the plan file.
		:raises AssertionError: if a plan step is not applicable and partial trajectories are not requested.

		Note:
			the plan output should be in the form of: (load-truck obj23 tru2 pos2) ...
		"""
		self.logger.info(f"Generating a trajectory from the plan file: {plan_path}")
		trajectory = []
		grounded_planning_task = grounding.ground(problem=self.problem)
		operators = {operator.name: operator for operator in grounded_planning_task.operators}
		op_names, plan_actions_sequence = self._parse_plan_actions(plan_path)
		previous_state = State(self.problem.initial_state, self.domain)

		self.logger.info("Starting to iterate over the actions sequence.")
		for index, (op_name, action) in enumerate(zip(op_names, plan_actions_sequence)):
			try:
				component = self._create_single_trajectory_component(
					index, action, op_name, operators, previous_state)
			except AssertionError as error:
				error_message = f"The operation {op_name} is not applicable! The failed action - {action.name}"
				self.logger.warning(error_message)
				if should_return_partial_trajectory:
					self.logger.debug("Returning partial trajectory since the flag was turned on.")
					return trajectory

				raise AssertionError(error_message) from error

			trajectory.append(component)
			previous_state = component.next_state

		return trajectory

	def validate_trajectory(self, trajectory: Trajectory) -> bool:
		"""Validate that the last state in the trajectory satisfies the problem's goals.

		An empty trajectory is validated against the problem's initial state.

		:param trajectory: the trajectory to validate.
		:return: whether or not the trajectory ends in a state satisfying the goals.
		"""
		if trajectory:
			last_component: TrajectoryComponent = trajectory[-1]
			last_state = last_component.next_state
		else:
			# An empty plan leaves the problem in its initial state.
			last_state = State(self.problem.initial_state, self.domain)

		grounded_goals = grounding.ground(problem=self.problem).goals
		return grounded_goals <= last_state.ground_facts()
