# proof_synthesis_env/tactics.py
"""
A collection of tactics that transform one ProofState into another.
Each tactic returns a new ProofState on success or None on failure.
"""

import copy
from env.state import ProofState
from env.language import Implies, And, Or, PropFalse


def tactic_intros(state: ProofState, hyp_names: list[str]) -> ProofState | None:
    """
    Introduce one hypothesis per name in ``hyp_names`` from an implication goal.

    Each step rewrites a first goal ``A -> B`` into ``B`` and stores ``A`` in the
    context under the next name, so ``["h1", "h2"]`` on ``P -> Q -> R`` leaves
    goal ``R`` with ``h1: P`` and ``h2: Q``. The input state is never modified.

    Returns None as soon as the first goal is not an implication, e.g. when more
    names are supplied than there are arrows to unroll, or when no goal remains
    (``next_goal`` is presumably None then). An empty ``hyp_names`` returns an
    unchanged copy.

    Names are stored with plain dict assignment, so a name already present in the
    context (or repeated within ``hyp_names``) overwrites the earlier hypothesis.

    NOTE(review): the context looks like a single dict shared by every pending
    goal, so a hypothesis introduced here stays visible after the current goal is
    closed. E.g. ``split`` on ``(P -> P) /\\ P``, ``intros ["h"]``, then
    ``exact "h"`` twice closes both goals without ever proving P. Confirm how
    ProofState scopes contexts; per-goal contexts look necessary for soundness.
    """
    working = copy.deepcopy(state)
    for name in hyp_names:
        arrow = working.next_goal
        if not isinstance(arrow, Implies):
            # Ran out of implications before running out of names.
            return None
        working.goals[0] = arrow.consequent
        working.context[name] = arrow.antecedent
    return working


def tactic_exact(state: ProofState, hyp_name: str) -> ProofState | None:
    """
    Close the first goal with a hypothesis that is equal (``==``) to it.

    Returns a copy of ``state`` with the first goal removed, or None when there
    is no goal, ``hyp_name`` is not in the context, or the hypothesis differs
    from the goal.
    """
    target = state.next_goal
    if target is None or target != state.context.get(hyp_name):
        return None
    closed = copy.deepcopy(state)
    closed.goals.pop(0)
    return closed


def tactic_split(state: ProofState) -> ProofState | None:
    """
    Break a conjunctive first goal ``P /\\ Q`` into two goals: ``P``, then ``Q``.

    Both new goals go to the front of the goal list, ahead of any goals that
    were already pending. Returns None if the first goal is not a conjunction.
    """
    conjunction = state.next_goal
    if not isinstance(conjunction, And):
        return None
    result = copy.deepcopy(state)
    # Overwrite the conjunction with its right half, then put the left half first.
    result.goals[0] = conjunction.right
    result.goals.insert(0, conjunction.left)
    return result


def tactic_destruct(
    state: ProofState, hyp_to_break: str, new_hyp_names: list[str]
) -> ProofState | None:
    """
    Split a conjunctive hypothesis into its two halves.

    If hypothesis ``hyp_to_break`` is P /\\ Q, it is removed from the context and
    replaced by P under ``new_hyp_names[0]`` and Q under ``new_hyp_names[1]``.

    Returns None when:
      - ``new_hyp_names`` does not contain exactly two names;
      - the hypothesis is missing or is not a conjunction;
      - the two names are equal (the second assignment would overwrite the
        first and silently drop P).

    Either new name may still coincide with some other existing hypothesis,
    which is then overwritten (plain dict assignment).
    """
    if len(new_hyp_names) != 2:
        return None
    left_name, right_name = new_hyp_names

    hyp = state.context.get(hyp_to_break)
    if not isinstance(hyp, And) or left_name == right_name:
        return None

    new_state = copy.deepcopy(state)
    # Delete first so either new name may safely reuse ``hyp_to_break``.
    del new_state.context[hyp_to_break]
    new_state.context[left_name] = hyp.left
    new_state.context[right_name] = hyp.right
    return new_state


def tactic_apply(state: ProofState, hyp_name: str) -> ProofState | None:
    """
    Use hypothesis ``hyp_name`` backwards to make progress on the first goal.

    - If the hypothesis is False, the goal is closed outright (ex falso
      quodlibet: if pigs could fly...).
    - Otherwise, view the hypothesis as ``A1 -> A2 -> ... -> An -> C`` and find
      the smallest k (0 <= k <= n) such that peeling off ``A1 .. Ak`` leaves
      exactly the goal. The goal is then replaced by the sub-goals
      ``A1, ..., Ak`` in that order. For example, with ``H : P -> Q -> R``:
        goal ``R``           -> new goals ``P``, ``Q``
        goal ``Q -> R``      -> new goal ``P``
        goal ``P -> Q -> R`` -> goal closed (k = 0, same effect as ``exact``)

    Returns None when there is no goal, the hypothesis is missing, or no
    conclusion of the hypothesis matches the goal.
    """
    goal = state.next_goal
    if goal is None:
        # Report failure as None, like the other tactics, instead of asserting
        # (assert statements are stripped under ``python -O``).
        return None

    hyp = state.context.get(hyp_name)
    if hyp is None:
        return None

    # Ex falso quodlibet
    if isinstance(hyp, PropFalse):
        new_state = copy.deepcopy(state)
        del new_state.goals[0]
        return new_state

    # Peel antecedents off the front until the remaining conclusion equals the
    # goal. Stopping at the first match keeps the number of new sub-goals
    # minimal. Under structural equality a proposition never equals one of its
    # own proper sub-terms, so whenever the innermost conclusion C equals the
    # goal, this loop runs all the way down and yields every antecedent.
    antecedents = []
    conclusion = hyp
    while conclusion != goal:
        if not isinstance(conclusion, Implies):
            return None
        antecedents.append(conclusion.antecedent)
        conclusion = conclusion.consequent

    # Only copy once we know the tactic succeeds.
    new_state = copy.deepcopy(state)
    del new_state.goals[0]
    # Prepend the antecedents in order so A1 is the next goal to be solved.
    new_state.goals = antecedents + new_state.goals
    return new_state


def tactic_left(state: ProofState) -> ProofState | None:
    """
    Commit to the left disjunct: a first goal ``P \\/ Q`` becomes ``P``.

    Returns None if the first goal is not a disjunction.
    """
    disjunction = state.next_goal
    if not isinstance(disjunction, Or):
        return None
    result = copy.deepcopy(state)
    result.goals[0] = disjunction.left
    return result


def tactic_right(state: ProofState) -> ProofState | None:
    """
    Commit to the right disjunct: a first goal ``P \\/ Q`` becomes ``Q``.

    Returns None if the first goal is not a disjunction.
    """
    disjunction = state.next_goal
    if not isinstance(disjunction, Or):
        return None
    result = copy.deepcopy(state)
    result.goals[0] = disjunction.right
    return result


def tactic_cases(state: ProofState, hyp_to_break: str) -> ProofState | None:
    """
    Case analysis on a disjunctive hypothesis.

    If hypothesis ``hyp_to_break`` is P \\/ Q and the first goal is R, replaces R
    with two separate goals, (P -> R) followed by (Q -> R), so the goal is proved
    once for each case of the disjunction. The hypothesis is removed from the
    context.

    Returns None when the hypothesis is missing or not a disjunction, or when
    there is no goal.

    NOTE(review): the context appears to be one dict shared by all pending goals.
    Deleting the hypothesis therefore also hides it from any later goals, and a
    hypothesis introduced while proving (P -> R) is still present while proving
    (Q -> R). E.g. intros ["H"], cases "H", intros ["hp"], exact "hp",
    intros ["hq"], exact "hp" closes (P \\/ Q) -> P. Confirm against ProofState;
    per-goal contexts look necessary for soundness.
    """
    hyp = state.context.get(hyp_to_break)
    if not isinstance(hyp, Or):
        return None

    r = state.next_goal
    if r is None:
        # Report failure as None, like the other tactics, instead of asserting
        # (assert statements are stripped under ``python -O``).
        return None

    new_state = copy.deepcopy(state)
    del new_state.goals[0]
    # Two separate goals (not a single conjunction), with the P case first.
    new_state.goals.insert(0, Implies(hyp.right, r))
    new_state.goals.insert(0, Implies(hyp.left, r))
    del new_state.context[hyp_to_break]
    return new_state
