AdventOfCode/2019/intcode.py

187 lines
4.6 KiB
Python
Raw Normal View History

2019-12-19 12:56:55 +01:00
"""Intcode interpreter."""
2019-12-07 16:06:06 +01:00
from enum import IntEnum
class Op(IntEnum):
    """Intcode opcodes: the two lowest decimal digits of an instruction.

    In the comments below ``pN`` is the resolved value of parameter N and
    ``[pN]`` is the address that parameter designates.
    """

    ADD = 1    # [p3] = p1 + p2
    MUL = 2    # [p3] = p1 * p2
    IN = 3     # [p1] = next input value
    OUT = 4    # emit p1
    JNZ = 5    # jump to p2 when p1 != 0
    JEZ = 6    # jump to p2 when p1 == 0
    LT = 7     # [p3] = 1 if p1 < p2 else 0
    EQ = 8     # [p3] = 1 if p1 == p2 else 0
    SRB = 9    # relative base += p1
    HALT = 99  # stop execution
class PMode(IntEnum):
    """Parameter modes, one decimal digit per parameter of an instruction."""

    POS = 0  # operand is an address; the parameter is the value stored there
    IMM = 1  # operand is the parameter's value itself
    REL = 2  # operand is an offset added to the relative base
class Intcode:
    """Intcode virtual machine.

    Memory is a list of ints that is zero-extended on demand whenever an
    address past its end is read or written. IN instructions take values
    from :meth:`input_data` and OUT instructions pass values to
    :meth:`output_data`; subclasses customise I/O by overriding those two.
    """

    def __init__(self, program, debug=False):
        """Create a machine loaded with ``program``.

        :param program: the program as a list of ints (e.g. from
            :meth:`parse_text`). It is copied, so execution never modifies
            the caller's list.
        :param debug: when true, :meth:`log` prints trace lines.
        """
        self._program = program.copy()  # pristine image restored by reset()
        self._memory = program.copy()
        self.ip = 0  # instruction pointer
        self.rel_base = 0  # base address for PMode.REL parameters
        self.ctx_op = None  # opcode of the instruction being executed
        self.ctx_modes = None  # parameter modes of that instruction
        self.halt = False
        self.debug = debug
        self._inputs_list = None  # set by run()

    def reset(self):
        """Reload the original program and rewind the execution state.

        The pending input list is left untouched; :meth:`run` replaces it.
        """
        self._memory = self._program.copy()
        self.ip = 0
        self.rel_base = 0
        self.ctx_op = None
        self.ctx_modes = None
        self.halt = False

    @staticmethod
    def parse_file(filename):
        """Read and parse a comma-separated Intcode program from ``filename``."""
        with open(filename, "rt") as input_file:
            # parse_text already strips the trailing newline.
            return Intcode.parse_text(input_file.read())

    @staticmethod
    def parse_text(text):
        """Parse a comma-separated program such as ``"1,0,0,3,99"``.

        Trailing whitespace (e.g. the final newline of an input file) is
        ignored.

        :raises ValueError: if a field is not an integer.
        """
        return [int(i) for i in text.rstrip().split(",")]

    def log(self, message, *args):
        """Print a ``debug:``-prefixed trace line when debugging is enabled."""
        if self.debug:
            print("debug:", message, *args)

    def run(self, inputs=None):
        """Execute from the current ``ip`` until a HALT instruction.

        Once halted, further calls return immediately until :meth:`reset`.

        :param inputs: values consumed, front first, by IN instructions. A
            non-empty list is used directly (not copied), so it is consumed
            in place; ``None`` or an empty list is replaced by a new list.
        :raises ValueError: on an unknown opcode or parameter mode.
        """
        self._inputs_list = inputs or []
        handlers = self.get_handlers()
        while not self.halt:
            self.read_code()
            # Direct indexing: a missing handler surfaces as a KeyError naming
            # the opcode rather than "'NoneType' object is not callable".
            handlers[self.ctx_op]()

    def get_handlers(self):
        """Map each opcode to the bound method that executes it."""
        return {
            Op.ADD: self.op_add,
            Op.MUL: self.op_mul,
            Op.IN: self.op_in,
            Op.OUT: self.op_out,
            Op.JNZ: self.op_jnz,
            Op.JEZ: self.op_jez,
            Op.LT: self.op_lt,
            Op.EQ: self.op_eq,
            Op.SRB: self.op_srb,
            Op.HALT: self.op_halt,
        }

    def mem_get(self, pos):
        """Return the value at address ``pos``, zero-extending memory if needed.

        :raises IndexError: if ``pos`` is negative.
        """
        self._check_memory_limits(pos)
        return self._memory[pos]

    def mem_set(self, pos, value):
        """Store ``value`` at address ``pos``, zero-extending memory if needed.

        :raises IndexError: if ``pos`` is negative.
        """
        self._check_memory_limits(pos)
        self._memory[pos] = value

    def _check_memory_limits(self, index):
        """Validate ``index`` and grow memory with zeros so it is addressable.

        :raises IndexError: if ``index`` is negative. Without this check
            Python's negative indexing would silently read or overwrite
            memory counted from the end of the list.
        """
        if index < 0:
            raise IndexError(f"negative memory address: {index}")
        missing = index + 1 - len(self._memory)
        if missing > 0:
            self._memory.extend([0] * missing)

    def read_code(self):
        """Decode the instruction at ``ip`` into ``ctx_op`` and ``ctx_modes``.

        The two lowest decimal digits are the opcode; the next three digits,
        ones place first, are the modes of parameters 1 to 3.

        :raises ValueError: if the opcode or a mode digit is unknown.
        """
        raw_code = self.mem_get(self.ip)
        self.ctx_op = Op(raw_code % 100)
        code = raw_code // 100
        self.ctx_modes = []
        for _ in range(3):
            self.ctx_modes.append(PMode(code % 10))
            code //= 10
        self.log("read_code", raw_code, self.ctx_op, self.ctx_modes)

    def param(self, index, pointer=False):
        """Resolve parameter ``index`` (1-based) of the current instruction.

        With ``pointer=False`` the parameter's value is returned. With
        ``pointer=True`` the address it designates is returned instead (used
        by instructions that write a result); in immediate mode the raw
        operand is returned either way.
        """
        mode = self.ctx_modes[index - 1]
        operand = self.mem_get(self.ip + index)
        if mode == PMode.IMM:
            return operand
        if mode == PMode.POS:
            address = operand
        elif mode == PMode.REL:
            address = self.rel_base + operand
        else:  # unreachable while read_code builds ctx_modes from PMode
            raise ValueError(f"unsupported parameter mode: {mode!r}")
        return address if pointer else self.mem_get(address)

    def write_at_param(self, param_offset, value):
        """Store ``value`` at the address designated by parameter ``param_offset``."""
        self.mem_set(self.param(param_offset, pointer=True), value)

    def op_add(self):
        """ADD: [p3] = p1 + p2."""
        self.write_at_param(3, self.param(1) + self.param(2))
        self.ip += 4

    def op_mul(self):
        """MUL: [p3] = p1 * p2."""
        self.write_at_param(3, self.param(1) * self.param(2))
        self.ip += 4

    def op_in(self):
        """IN: [p1] = next input value."""
        self.write_at_param(1, self.input_data())
        self.ip += 2

    def op_out(self):
        """OUT: emit p1 through output_data."""
        self.output_data(self.param(1))
        self.ip += 2

    def op_jnz(self):
        """JNZ: jump to p2 when p1 is non-zero."""
        if self.param(1) != 0:
            self.ip = self.param(2)
        else:
            self.ip += 3

    def op_jez(self):
        """JEZ: jump to p2 when p1 is zero."""
        if self.param(1) == 0:
            self.ip = self.param(2)
        else:
            self.ip += 3

    def op_lt(self):
        """LT: [p3] = 1 if p1 < p2 else 0."""
        self.write_at_param(3, int(self.param(1) < self.param(2)))
        self.ip += 4

    def op_eq(self):
        """EQ: [p3] = 1 if p1 == p2 else 0."""
        self.write_at_param(3, int(self.param(1) == self.param(2)))
        self.ip += 4

    def op_srb(self):
        """SRB: add p1 to the relative base."""
        self.rel_base += self.param(1)
        self.ip += 2

    def op_halt(self):
        """HALT: stop the run loop; ``ip`` stays on the HALT instruction."""
        self.halt = True

    def input_data(self):
        """Return the next value from the list given to :meth:`run`.

        :raises IndexError: when no input is left.
        """
        return self._inputs_list.pop(0)

    def output_data(self, data):
        """Handle a value emitted by OUT; the default prints it."""
        print(">", data)
class AIC(Intcode):
    """ASCII-enabled Intcode interpreter.

    IN instructions consume ``input_text`` one character at a time. OUT
    values that are character codes (``0..MAX_CHAR``) are printed as text;
    any other value is passed to :meth:`handle_int_output`.
    """

    LN = ord("\n")  # character code of a newline
    # Largest output value printed as a character. 256 and above, and any
    # negative value, go to handle_int_output instead (previously 256 was
    # printed as chr(256) and negatives crashed chr()). Subclasses may lower
    # this to 127 to treat only 7-bit ASCII as text.
    MAX_CHAR = 255

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_text = ""  # pending characters fed to IN instructions

    def input_data(self):
        """Consume the next character of ``input_text`` and return its code.

        :raises IndexError: if ``input_text`` is exhausted.
        """
        if not self.input_text:
            raise IndexError("no ASCII input left")
        data, self.input_text = self.input_text[0], self.input_text[1:]
        return ord(data)

    def output_data(self, data):
        """Print character codes as text; report any other value as an int."""
        if 0 <= data <= self.MAX_CHAR:
            print(chr(data), end="")
        else:
            self.handle_int_output(data)

    def handle_int_output(self, i):
        """Report an output value that is not a character code."""
        print(f"Int output: {i}.")