from functools import lru_cache
import json
import subprocess
import tempfile
import openai
from datasets import load_dataset
import tqdm

# Address of the completion server that the openai client is pointed at.
# You can find the port on piazza.
# YOUR CODE HERE
port = 'FILL HERE'  # placeholder; must be replaced with the real port before running
openai.api_base = f"http://sketch5.csail.mit.edu:{port}/v1"


@lru_cache(maxsize=1)
def get_model():
    """Return the id of the one model listed by the server.

    The listing is requested through the List models API and the result is
    cached, so the request is made at most once per process.

    Raises:
        ValueError: if the listing does not contain exactly one entry.
    """
    listing = openai.Model.list()
    # Single-element unpacking doubles as a check that exactly one model exists.
    (only_entry,) = listing["data"]
    return only_entry["id"]


@lru_cache(maxsize=1)
def get_data():
    """Return the "train" split of the "mbpp" dataset as a list of records.

    Records with a truthy ``test_setup_code`` field are left out. The result
    is cached, so the dataset is loaded at most once per process.
    """
    kept = []
    for record in load_dataset("mbpp", split="train"):
        if record["test_setup_code"]:
            # Skip problems whose tests need extra setup code.
            continue
        kept.append(record)
    return kept


def datum(idx):
    """Return problem number ``idx`` from ``get_data()`` as a small dict.

    The returned dict has the keys ``text`` and ``code`` (copied from the
    record), ``example_test`` (the first entry of the record's test list) and
    ``test_list`` (all remaining entries).
    """
    record = get_data()[idx]
    all_tests = record["test_list"]
    return {
        "text": record["text"],
        "code": record["code"],
        # The first test may be shown in prompts; the rest are held out.
        "example_test": all_tests[0],
        "test_list": all_tests[1:],
    }


def get_completions(prompt, *, num_completions, max_tokens, temperature=1):
    """Sample completions for ``prompt`` through the Completion API.

    Args:
        prompt: text sent to the model.
        num_completions: number of completions to request (sent as ``n``).
        max_tokens: token limit forwarded to the API.
        temperature: sampling temperature forwarded to the API.

    Returns:
        A list with the ``text`` of every returned choice.
    """
    request = dict(
        model=get_model(),
        prompt=prompt,
        echo=False,
        n=num_completions,
        max_tokens=max_tokens,
        stream=False,
        temperature=temperature,
    )
    response = openai.Completion.create(**request)
    texts = []
    for choice in response["choices"]:
        texts.append(choice["text"])
    return texts


# Instruction text passed as the `instructions` argument of the prompt
# builders; it shows the [PYTHON] ... [/PYTHON] delimiters expected around code.
standard_instructions = (
    "Put your fixed program within code delimiters, for example: [PYTHON]\n"
    "# YOUR CODE HERE\n"
    "[/PYTHON]"
)


def _task_section(text, example_test, answer=None):
    """Render one task: its statement, the example test and an opening tag.

    When ``answer`` is given, the section is closed with the answer followed
    by the ``[/PYTHON]`` delimiter, so it reads as a fully solved example.
    """
    section = (
        f"Problem: {text}\n"
        f"Your program must pass this test:\n{example_test}\n"
        "[PYTHON]\n"
    )
    if answer is not None:
        # Normalise Windows line endings so the shown solution is plain "\n" text.
        shown = answer.replace("\r\n", "\n").strip("\n")
        section += f"{shown}\n[/PYTHON]\n"
    return section


def zero_shot_prompt(
    instructions, text, example_test, *, include_answer=False, code=None
):
    """
    Zero-shot prompt for the Python program completion task, where the
    model is given a question and must generate a correct Python program
    that matches the specification and passes all tests. You are allowed to use
    the example_test in the prompt, but not the test_list.

    The prompt consists of the instructions, the problem statement, the
    example test and an opening ``[PYTHON]`` delimiter, so that the model's
    continuation is the program itself followed by ``[/PYTHON]``.

    Args:
        instructions: general instructions placed at the top of the prompt.
        text: natural-language description of the problem.
        example_test: one test the program must pass, shown to the model.
        include_answer: when true, append ``code`` and the closing delimiter
            so the result is a solved example (used for one-shot prompting).
        code: reference solution; required when ``include_answer`` is true.

    Returns:
        The prompt as a string.

    Raises:
        ValueError: if ``include_answer`` is true but ``code`` is None.
    """
    if include_answer and code is None:
        raise ValueError("code must be provided when include_answer is True")
    p = instructions + "\n"
    # Question 2a
    p += _task_section(text, example_test, code if include_answer else None)
    return p


def one_shot_prompt(instructions, text, example_test):
    """
    In one-shot prompting, the model is given an example of a question and
    answer, and must generate a correct Python program that matches the
    specification and passes all tests.

    The worked example is ``datum(0)``; it is rendered as a solved zero-shot
    prompt and followed by the actual task, left open for the model.

    Returns:
        The prompt as a string.
    """
    reference_sample = datum(0)
    # The worked example must not be the problem being asked.
    assert text != reference_sample["text"]
    # Question 2b
    solved_example = zero_shot_prompt(
        instructions,
        reference_sample["text"],
        reference_sample["example_test"],
        include_answer=True,
        code=reference_sample["code"],
    )
    # Instructions appear once (inside solved_example); a blank line separates
    # the worked example from the task the model has to complete.
    return solved_example + "\n" + _task_section(text, example_test)


def clean_program(completion):
    """Extract the program text from a raw model completion.

    Everything from the first ``[/PYTHON]`` delimiter onwards is discarded.
    If the completion also contains an opening ``[PYTHON]`` delimiter, only
    the text after it is kept. Blank lines at both ends are removed, while
    the indentation of the remaining lines is left untouched.

    Args:
        completion: raw text returned by the model.

    Returns:
        The program source as a string (possibly empty).
    """
    start_token = "[PYTHON]"
    end_token = "[/PYTHON]"
    # Question 2a
    # Text after the closing delimiter is not part of the program.
    program, _, _ = completion.partition(end_token)
    # The model may repeat the opening delimiter before the code.
    _, found_start, after_start = program.partition(start_token)
    if found_start:
        program = after_start
    # Only strip newlines: leading spaces may be meaningful indentation.
    return program.strip("\n")


def test_program(test_list, completion):
    """Return whether the program in ``completion`` passes every test.

    The completion is first reduced to program text with ``clean_program``;
    the tests are then appended after the program and the resulting script is
    executed with ``does_program_run_without_error``.

    Args:
        test_list: test statements to run after the program; presumably
            ``assert`` statements as in the dataset's test lists. May be
            empty, in which case only the program itself is executed.
        completion: raw model completion.

    Returns:
        True if the combined script runs without error, False otherwise.
    """
    # Question 2a
    program = clean_program(completion)
    # A blank line keeps the tests visually separate from the program body.
    script = "\n".join([program, "", *test_list]) + "\n"
    return does_program_run_without_error(script)


def accuracy_on_set(test_list, completions):
    """Return the fraction of ``completions`` that pass ``test_list``.

    Each completion is judged with ``test_program``. An empty (or otherwise
    falsy) ``completions`` yields 0.
    """
    if completions:
        passed = 0
        for candidate in completions:
            passed += test_program(test_list, candidate)
        return passed / len(completions)
    return 0


def does_program_run_without_error(code):
    """Run ``code`` as a Python script and report whether it succeeded.

    The source is written to a temporary file and executed in a child
    process with a one-second time limit.

    Args:
        code: Python source text to execute.

    Returns:
        True if the child exits with status 0 within the time limit; False if
        it exits with a non-zero status or times out.
    """
    import sys  # local import, as main() does, so the block is self-contained

    # Python reads source files as UTF-8 by default, so write the file that
    # way instead of relying on the locale's encoding.
    with tempfile.NamedTemporaryFile("w", suffix=".py", encoding="utf-8") as f:
        f.write(code)
        f.flush()  # make sure the child process sees the complete source
        try:
            subprocess.run(
                # Use the interpreter running this script rather than
                # whatever "python3" happens to resolve to on PATH.
                [sys.executable, f.name],
                stdin=subprocess.DEVNULL,  # a program calling input() fails fast
                stdout=subprocess.DEVNULL,  # output is never inspected
                stderr=subprocess.DEVNULL,
                timeout=1,
                check=True,
            )
        except (subprocess.TimeoutExpired, subprocess.CalledProcessError):
            return False
        return True


def just_prompt(prompt_type, instructions, text, example_test):
    """Sample completions for one problem without any filtering.

    ``prompt_type`` is called with ``(instructions, text, example_test)`` to
    build the prompt; 100 completions of at most 100 tokens each are then
    sampled at temperature 0.8.

    Returns:
        The list of completion texts from ``get_completions``.
    """
    prompt = prompt_type(instructions, text, example_test)
    sampling = {"num_completions": 100, "max_tokens": 100, "temperature": 0.8}
    return get_completions(prompt, **sampling)


def check_syntactic(prompt_type, instructions, text, example_test):
    """Sample completions and keep only those whose program runs cleanly.

    Completions are sampled with ``just_prompt`` and then checked with
    ``test_program`` using an empty test list, i.e. only the program itself
    is executed. This rejects programs that fail to parse (and any that raise
    while being loaded).

    Returns:
        The list of raw completions that passed the check, in sampling order.
    """
    programs = just_prompt(prompt_type, instructions, text, example_test)
    # Question 2c
    return [program for program in programs if test_program([], program)]


def check_single_test(prompt_type, instructions, text, example_test):
    """Sample completions and keep only those that pass the example test.

    Completions are sampled with ``just_prompt``; each one is then run
    against ``example_test`` (the one test that may be shown to the model)
    with ``test_program``.

    Returns:
        The list of raw completions that passed, in sampling order.
    """
    programs = just_prompt(prompt_type, instructions, text, example_test)
    # Question 2d
    return [program for program in programs if test_program([example_test], program)]


def compute_accuracy(algorithm, prompt_type, instructions):
    """Print per-problem and average accuracy over problems 2 through 11.

    For each problem, ``algorithm`` is called with ``(prompt_type,
    instructions, text, example_test)`` and the completions it returns are
    scored against the problem's remaining tests with ``accuracy_on_set``.
    """
    print("Computing accuracy...")
    accs = []
    for i in tqdm.trange(2, 12):
        # Look the problem up once and reuse it for prompting and scoring.
        sample = datum(i)
        completions = algorithm(
            prompt_type, instructions, sample["text"], sample["example_test"]
        )
        acc = accuracy_on_set(sample["test_list"], completions)
        accs.append(acc)
        print(f"Accuracy on datum {i}: {acc:.0%}")
    print(f"Average accuracy: {sum(accs) / len(accs):.1%}")


def renaming_robustness_example():
    """Provide the inputs for the renaming-robustness experiment.

    Expected to return a 4-tuple
    ``(prompt_1, prompt_2, example_test_case, test_cases)``:
    ``prompt_1`` and ``prompt_2`` are the two problem texts that
    ``renaming_robustness`` compares, ``example_test_case`` is the test shown
    to the model for both, and ``test_cases`` are the tests both sets of
    completions are scored against.

    Raises:
        NotImplementedError: always, until the example has been filled in.
    """
    # Question 2e
    # YOUR CODE HERE
    # Fail explicitly rather than with a NameError on the undefined names.
    raise NotImplementedError(
        "renaming_robustness_example must return "
        "(prompt_1, prompt_2, example_test_case, test_cases)"
    )


def renaming_robustness():
    """Compare zero-shot accuracy on the two prompts of the renaming example.

    Both prompts come from ``renaming_robustness_example`` and share the same
    example test and the same scoring tests; the accuracy of each is printed.
    """
    prompt_1, prompt_2, example_test_case, test_cases = renaming_robustness_example()

    def sample(text):
        # Same prompt builder, instructions and example test for both prompts.
        return just_prompt(
            zero_shot_prompt, standard_instructions, text, example_test_case
        )

    compl_1 = sample(prompt_1)
    compl_2 = sample(prompt_2)
    acc_1, acc_2 = [accuracy_on_set(test_cases, compl) for compl in (compl_1, compl_2)]

    print(f"Accuracy on prompt 1: {acc_1:.0%}")
    print(f"Accuracy on prompt 2: {acc_2:.0%}")


def main():
    """Run the experiment selected by the single command-line argument.

    ``a``-``d`` run ``compute_accuracy`` with different algorithm / prompt
    combinations, ``e`` runs ``renaming_robustness``. Anything else prints a
    usage line.
    """
    from sys import argv

    # Lambdas defer all work until the chosen entry is actually invoked.
    experiments = {
        "a": lambda: compute_accuracy(
            just_prompt, zero_shot_prompt, standard_instructions
        ),
        "b": lambda: compute_accuracy(
            just_prompt, one_shot_prompt, standard_instructions
        ),
        "c": lambda: compute_accuracy(
            check_syntactic, zero_shot_prompt, standard_instructions
        ),
        "d": lambda: compute_accuracy(
            check_single_test, zero_shot_prompt, standard_instructions
        ),
        "e": lambda: renaming_robustness(),
    }
    chosen = experiments.get(argv[1]) if len(argv) == 2 else None
    if chosen is None:
        print("Usage: python3 q2.py [a|b|c|d|e]")
        return
    chosen()


if __name__ == "__main__":
    main()
