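"""
AST for a very simple imperative language: variables, simple arithmetic and
comparison operations, and a small set of commands (skip, assignment, assert,
if, while and sequential composition), plus scaffolding for computing
verification conditions and solving them with cvc5's SyGuS mode.
"""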

from abc import ABC, abstractmethod
from typing import Dict, List, Union
from dataclasses import dataclass

import cvc5
from cvc5 import Kind


def fresh_var(name):
    """Return ``name`` with a unique numeric suffix appended.

    A call counter is kept as the ``_counter`` attribute of this function.
    It starts at 1 and grows by one on every call, regardless of ``name``.
    Successive calls therefore yield e.g. ``x_1``, ``y_2``, ``x_3``.
    Uniqueness holds only among results of this function: nothing stops a
    clash with a hand-written identifier that happens to look like ``x_1``.
    """
    next_id = getattr(fresh_var, "_counter", 0) + 1
    fresh_var._counter = next_id
    return name + "_" + str(next_id)


class Expr(ABC):
    """Abstract base class for expression nodes of the AST.

    Concrete expressions must implement four operations:

    * ``to_cvc5`` - translate the expression into a cvc5 term;
    * ``variables`` - the set of program-variable names it mentions;
    * ``substitute`` - apply a mapping from variable names to replacement
      expressions;
    * ``invariant_arities`` - a mapping from invariant name to arity for
      every invariant symbol the expression uses.
    """

    @abstractmethod
    def to_cvc5(self, slv, variables, invariants):
        """Build the cvc5 term for this expression using solver ``slv``.

        NOTE(review): ``variables`` and ``invariants`` are presumably lookup
        tables from names to cvc5 terms; confirm once implemented.
        """

    @abstractmethod
    def variables(self):
        """Return the set of variable names occurring in this expression."""

    @abstractmethod
    def substitute(self, subst):
        """Return a copy of this expression with ``subst`` applied."""

    @abstractmethod
    def invariant_arities(self) -> Dict[str, int]:
        """Return ``{invariant_name: arity}`` for the invariants used here."""


@dataclass
class VariableExpr(Expr):
    """Reference to a program variable, identified by ``name``."""

    name: str

    def to_cvc5(self, slv, variables, invariants):
        # Question 2b
        # YOUR CODE HERE
        raise NotImplementedError

    def variables(self):
        """A variable mentions exactly one name: its own."""
        own_name = self.name
        return {own_name}

    def substitute(self, subst):
        """Return ``subst``'s replacement for this variable, or ``self`` if none."""
        return subst.get(self.name, self)

    def invariant_arities(self) -> Dict[str, int]:
        """A bare variable uses no invariant symbols."""
        return dict()


@dataclass
class ConstantExpr(Expr):
    """Integer or Boolean literal."""

    value: Union[int, bool]

    def to_cvc5(self, slv, variables, invariants):
        # Question 2b
        # YOUR CODE HERE
        raise NotImplementedError

    def variables(self):
        """Literals mention no variables."""
        no_names = set()
        return no_names

    def substitute(self, subst):
        """Literals are unaffected by substitution, so ``self`` is returned."""
        return self

    def invariant_arities(self) -> Dict[str, int]:
        """Literals use no invariant symbols."""
        return dict()


@dataclass
class Op:
    """Operator tag named after a cvc5 ``Kind`` member (compared upper-cased)."""

    name: str

    def cvc5_kind(self):
        """Return the cvc5 ``Kind`` member whose name is ``self.name`` upper-cased.

        For example ``Op("lt")`` gives ``Kind.LT`` and ``Op("add")`` gives
        ``Kind.ADD``. An unknown name raises ``AttributeError``.
        """
        member_name = self.name.upper()
        return getattr(Kind, member_name)


@dataclass
class UnOpExpr(Expr):
    """Application of the unary operator ``op`` to ``operand``."""

    op: Op
    operand: Expr

    def to_cvc5(self, slv, variables, invariants):
        # Question 2b
        # YOUR CODE HERE
        raise NotImplementedError

    def variables(self):
        """The variables of a unary application are those of its operand."""
        return self.operand.variables()

    def substitute(self, subst):
        """Rebuild this node with ``subst`` applied to the operand."""
        new_operand = self.operand.substitute(subst)
        return UnOpExpr(op=self.op, operand=new_operand)

    def invariant_arities(self) -> Dict[str, int]:
        """Invariants used by the operand."""
        return self.operand.invariant_arities()


@dataclass
class BinOpExpr(Expr):
    """Application of the binary operator ``op`` to ``left`` and ``right``."""

    op: Op
    left: Expr
    right: Expr

    def to_cvc5(self, slv, variables, invariants):
        # Question 2b
        # YOUR CODE HERE
        raise NotImplementedError

    def variables(self):
        """Union of the variables of both operands."""
        left_names = self.left.variables()
        right_names = self.right.variables()
        return left_names | right_names

    def substitute(self, subst):
        """Rebuild this node with ``subst`` applied to both operands."""
        new_left = self.left.substitute(subst)
        new_right = self.right.substitute(subst)
        return BinOpExpr(op=self.op, left=new_left, right=new_right)

    def invariant_arities(self) -> Dict[str, int]:
        """Merge the invariant arities of both operands.

        If both sides report the same invariant name, the right operand's
        arity wins.
        """
        merged = dict(self.left.invariant_arities())
        merged.update(self.right.invariant_arities())
        return merged


class Command(ABC):
    """Abstract base class for command (statement) nodes of the AST."""

    @abstractmethod
    def verification_condition(self, postcondition: Expr) -> Expr:
        """Return this command's verification condition for ``postcondition``."""

    @abstractmethod
    def variables_modified_in(self):
        """Return the set of variable names this command may assign to."""

    @abstractmethod
    def variables(self):
        """Return the set of all variable names mentioned by this command."""


@dataclass
class SkipCommand(Command):
    """The no-op command."""

    def verification_condition(self, postcondition: Expr) -> Expr:
        # Question 2a
        # YOUR CODE HERE
        raise NotImplementedError

    def variables_modified_in(self):
        """``skip`` assigns nothing."""
        nothing_assigned = set()
        return nothing_assigned

    def variables(self):
        """``skip`` mentions no variables."""
        nothing_mentioned = set()
        return nothing_mentioned


@dataclass
class AssignCommand(Command):
    """Assignment ``variable := expression``."""

    variable: str
    expression: Expr

    def verification_condition(self, postcondition: Expr) -> Expr:
        # Question 2a
        # YOUR CODE HERE
        raise NotImplementedError

    def variables_modified_in(self):
        """Only the assigned variable is modified."""
        target = self.variable
        return {target}

    def variables(self):
        """Variables read by the right-hand side, plus the assigned one."""
        rhs_names = self.expression.variables()
        return rhs_names | {self.variable}


@dataclass
class IfCommand(Command):
    """Conditional: run ``true_command`` if ``condition`` holds, else ``false_command``."""

    condition: Expr
    true_command: Command
    false_command: Command

    def verification_condition(self, postcondition: Expr) -> Expr:
        # Question 2a
        # YOUR CODE HERE
        raise NotImplementedError

    def variables_modified_in(self):
        """Variables that either branch may assign."""
        then_assigned = self.true_command.variables_modified_in()
        else_assigned = self.false_command.variables_modified_in()
        return then_assigned | else_assigned

    def variables(self):
        """Variables mentioned by either branch or by the condition."""
        then_names = self.true_command.variables()
        else_names = self.false_command.variables()
        cond_names = self.condition.variables()
        return then_names | else_names | cond_names




@dataclass
class WhileCommand(Command):
    """Loop: repeat ``body`` while ``condition`` holds."""

    condition: Expr
    body: Command

    def verification_condition(self, postcondition: Expr) -> Expr:
        # Question 2a
        # YOUR CODE HERE
        raise NotImplementedError

    def variables_modified_in(self):
        """A loop may assign exactly what its body may assign."""
        return self.body.variables_modified_in()

    def variables(self):
        """Variables mentioned by the body or the loop condition."""
        body_names = self.body.variables()
        cond_names = self.condition.variables()
        return body_names | cond_names


@dataclass
class SeqCommand(Command):
    """Sequential composition: ``first_command`` followed by ``second_command``."""

    first_command: Command
    second_command: Command

    def verification_condition(self, postcondition: Expr) -> Expr:
        # Question 2a
        # YOUR CODE HERE
        raise NotImplementedError

    def variables_modified_in(self):
        """Variables assigned by either component."""
        first_assigned = self.first_command.variables_modified_in()
        second_assigned = self.second_command.variables_modified_in()
        return first_assigned | second_assigned

    def variables(self):
        """Variables mentioned by either component."""
        first_names = self.first_command.variables()
        second_names = self.second_command.variables()
        return first_names | second_names


@dataclass
class AssertCommand(Command):
    """Assertion that ``condition`` holds at this program point."""

    condition: Expr

    def verification_condition(self, postcondition: Expr) -> Expr:
        # Question 2a
        # YOUR CODE HERE
        raise NotImplementedError

    def variables_modified_in(self):
        """An assertion assigns nothing."""
        nothing_assigned = set()
        return nothing_assigned

    def variables(self):
        """Variables mentioned by the asserted condition."""
        cond_names = self.condition.variables()
        return cond_names


def block_command(*commands):
    """Chain one or more commands into a right-nested sequence.

    ``block_command(a)`` returns ``a`` unchanged. For more commands the
    result is ``SeqCommand(a, SeqCommand(b, ... SeqCommand(y, z)))``, so
    ``block_command(a, b, c) == SeqCommand(a, SeqCommand(b, c))``.

    The chain is built iteratively from the last command backwards. This
    avoids the quadratic cost of repeated slicing and Python's recursion
    limit on long blocks.

    Raises:
        ValueError: if no commands are given. An explicit exception is used
            instead of ``assert`` because asserts are stripped under ``-O``.
    """
    if not commands:
        raise ValueError("block_command requires at least one command")
    result = commands[-1]
    # Fold right-to-left so the nesting matches the left-to-right order.
    for command in reversed(commands[:-1]):
        result = SeqCommand(command, result)
    return result


# The example program for Questions 2a/2b. It appears twice: first in the
# language's concrete syntax (the bare string below is documentation only
# and has no runtime effect), then as the AST ``example`` built with
# ``block_command``.
"""
x := 0;
y := 0;
t := 0;
while (x < 10) {
    if (5 < t) {
        x := x + 1;
        y := y + 1;
    }
    t := t + 1;
}
"""

example = block_command(
    AssignCommand("x", ConstantExpr(0)),
    AssignCommand("y", ConstantExpr(0)),
    AssignCommand("t", ConstantExpr(0)),
    WhileCommand(
        BinOpExpr(Op("lt"), VariableExpr("x"), ConstantExpr(10)),
        block_command(
            IfCommand(
                BinOpExpr(Op("lt"), ConstantExpr(5), VariableExpr("t")),
                block_command(
                    AssignCommand(
                        "x", BinOpExpr(Op("add"), VariableExpr("x"), ConstantExpr(1))
                    ),
                    AssignCommand(
                        "y", BinOpExpr(Op("add"), VariableExpr("y"), ConstantExpr(1))
                    ),
                ),
                # The source ``if`` has no else branch; it is modelled as skip.
                SkipCommand(),
            ),
            AssignCommand(
                "t", BinOpExpr(Op("add"), VariableExpr("t"), ConstantExpr(1))
            ),
        ),
    ),
)

# Specification used for ``example``: no assumption on the initial state
# (precondition ``True``), and ``x == y`` is required afterwards.
precondition = ConstantExpr(True)
postcondition = BinOpExpr(Op("equal"), VariableExpr("x"), VariableExpr("y"))


def boolean_grammar(slv, variables):
    """Build a SyGuS grammar of Boolean predicates over integer variables.

    Args:
        slv: the cvc5 solver used to create sorts, terms and the grammar.
        variables: names of the program variables; each becomes an integer
            bound variable (an input of the function-to-synthesize).

    Returns:
        ``(grammar, bound_vars)``, where ``bound_vars`` is the list of bound
        integer variables in the same order as ``variables``.

    Grammar shape:
        Start    ::= Start or Start | Start and Start | not Start
                   | StartInt = StartInt | StartInt <= StartInt
                   | StartInt < StartInt
        StartInt ::= 0 | 1 | <each bound var>
                   | StartInt + StartInt | StartInt - StartInt
                   | StartInt * StartInt
    """
    int_sort = slv.getIntegerSort()
    bool_sort = slv.getBooleanSort()

    bound_vars = [slv.mkVar(int_sort, var_name) for var_name in variables]

    # Non-terminal symbols for integer and Boolean sub-terms.
    int_nt = slv.mkVar(int_sort, "StartInt")
    bool_nt = slv.mkVar(bool_sort, "Start")

    int_literals = [slv.mkInteger(0), slv.mkInteger(1)]

    # Available kinds: https://cvc5.github.io/docs/cvc5-1.0.2/api/python/base/kind.html
    arithmetic = [
        slv.mkTerm(kind, int_nt, int_nt) for kind in (Kind.ADD, Kind.SUB, Kind.MULT)
    ]
    # No if-then-else production (Kind.ITE) is offered for integers.

    disjunction = slv.mkTerm(Kind.OR, bool_nt, bool_nt)
    conjunction = slv.mkTerm(Kind.AND, bool_nt, bool_nt)
    negation = slv.mkTerm(Kind.NOT, bool_nt)
    le_term, lt_term, eq_term = (
        slv.mkTerm(kind, int_nt, int_nt) for kind in (Kind.LEQ, Kind.LT, Kind.EQUAL)
    )

    # The Boolean non-terminal is listed first (cvc5 uses the first
    # non-terminal as the start symbol).
    grammar = slv.mkGrammar(bound_vars, [bool_nt, int_nt])
    grammar.addRules(int_nt, [*int_literals, *bound_vars, *arithmetic])
    grammar.addRules(
        bool_nt, [disjunction, conjunction, negation, eq_term, le_term, lt_term]
    )

    return grammar, bound_vars


def solve(assumption, vc):
    """Solve verification condition ``vc`` under ``assumption`` with cvc5 SyGuS.

    Only the solver set-up is written so far. After creating and configuring
    the solver, the function reaches the Question 2b placeholder and raises
    ``NotImplementedError``.
    """
    solver = cvc5.Solver()

    # Options required for SyGuS solving in non-incremental mode.
    for option_name, option_value in (("sygus", "true"), ("incremental", "false")):
        solver.setOption(option_name, option_value)

    # Fixing the logic is currently disabled:
    # solver.setLogic("LIA")

    # Question 2b
    # YOUR CODE HERE
    raise NotImplementedError


def example_1():
    """Return the example from Questions 2a/2b.

    The result is a dict with keys ``"example"`` (the program AST),
    ``"precondition"`` and ``"postcondition"``.
    """
    return {
        "example": example,
        "precondition": precondition,
        "postcondition": postcondition,
    }


def example_2():
    """Second example program (Question 2c placeholder).

    Currently raises ``NotImplementedError``. The unreachable ``return``
    below is the template for the finished version: it expects locals
    ``code_2``, ``precondition_2`` and ``postcondition_2``.
    """
    # Question 2c
    # YOUR CODE HERE
    raise NotImplementedError
    return {
        "example": code_2,
        "precondition": precondition_2,
        "postcondition": postcondition_2,
    }


def example_3():
    """Third example program (Question 2c placeholder).

    Currently raises ``NotImplementedError``. The unreachable ``return``
    below is the template for the finished version: it expects locals
    ``code_3``, ``precondition_3`` and ``postcondition_3``.
    """
    # Question 2c
    # YOUR CODE HERE
    raise NotImplementedError
    return {
        "example": code_3,
        "precondition": precondition_3,
        "postcondition": postcondition_3,
    }


def example_4():
    """Fourth example program (Question 2c placeholder).

    Currently raises ``NotImplementedError``. The unreachable ``return``
    below is the template for the finished version.

    It uses per-example names (``code_4``, ``precondition_4``,
    ``postcondition_4``), matching ``example_2`` and ``example_3``. Plain
    ``precondition``/``postcondition`` would silently resolve to the
    module-level example-1 specification if a local definition were
    forgotten.
    """
    # Question 2c
    # YOUR CODE HERE
    raise NotImplementedError
    return dict(
        example=code_4,
        precondition=precondition_4,
        postcondition=postcondition_4,
    )


def example_5():
    """Fifth example program (Question 2c placeholder).

    Currently raises ``NotImplementedError``. The unreachable ``return``
    below is the template for the finished version.

    It uses per-example names (``code_5``, ``precondition_5``,
    ``postcondition_5``), matching ``example_2`` and ``example_3``. Plain
    ``precondition``/``postcondition`` would silently resolve to the
    module-level example-1 specification if a local definition were
    forgotten.
    """
    # Question 2c
    # YOUR CODE HERE
    raise NotImplementedError
    return dict(
        example=code_5,
        precondition=precondition_5,
        postcondition=postcondition_5,
    )


def example_6():
    """Sixth example program (Question 2c placeholder).

    Currently raises ``NotImplementedError``. The unreachable ``return``
    below is the template for the finished version.

    It uses per-example names (``code_6``, ``precondition_6``,
    ``postcondition_6``), matching ``example_2`` and ``example_3``. Plain
    ``precondition``/``postcondition`` would silently resolve to the
    module-level example-1 specification if a local definition were
    forgotten.
    """
    # Question 2c
    # YOUR CODE HERE
    raise NotImplementedError
    return dict(
        example=code_6,
        precondition=precondition_6,
        postcondition=postcondition_6,
    )


# Registry of example factories run by ``main`` in "2c" mode, in insertion order.
examples = dict(
    example_1=example_1,
    example_2=example_2,
    example_3=example_3,
    example_4=example_4,
    example_5=example_5,
    example_6=example_6,
)


def main():
    """Command-line entry point: ``python3 q2.py <2a|2b|2c>``.

    * ``2a``: print the verification condition of ``example`` with respect
      to ``postcondition``.
    * ``2b``: additionally pass it, with ``precondition``, to ``solve`` and
      print the result.
    * ``2c``: do the same for every factory in ``examples``.

    With no argument, prints a usage line and exits with status 1.

    Raises:
        ValueError: for an unrecognised mode argument.
    """
    import sys

    argv = sys.argv
    if len(argv) == 1:
        print("Usage: python3 q2.py <2a|2b|2c>")
        # Use sys.exit, not the builtin ``exit``. The builtin is injected by
        # ``site`` for interactive use, closes stdin, and is absent under
        # ``python -S``.
        sys.exit(1)

    mode = argv[1]
    if mode == "2a":
        print("Finding verification condition for example")
        print(example.verification_condition(postcondition))

    elif mode == "2b":
        print("Finding solution for example")
        print(solve(precondition, example.verification_condition(postcondition)))

    elif mode == "2c":
        print("Running examples")
        for name, make_example in examples.items():
            print("*" * 80)
            print(f"Example {name}")
            eg = make_example()
            print(eg["example"])

            print("Finding verification condition")
            vc = eg["example"].verification_condition(eg["postcondition"])
            print(vc)

            print("Finding solution")

            print(solve(eg["precondition"], vc))

    else:
        raise ValueError(f"Unknown argument {argv[1]}")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
