import re
import RAM

class While:
    """A program of the WHILE language on unsigned words of fixed bit width.

    Instructions are tuples whose first item names the operation:

    - ``('inc', i)`` / ``('dec', i)``             -- ``xi ← xi ± 1``
    - ``('zero', i)``                             -- ``xi ← 0``
    - ``('copy', i, j)``                          -- ``xi ← xj``
    - ``('add', i, j, k)`` / ``('sub', i, j, k)`` -- ``xi ← xj ± xk``
    - ``('while', i, body)``                      -- ``while xi ≠ 0: body``
    - ``('for', i, j, body)``                     -- ``for xi = 0 to xj-1: body``
    - ``('if', i, j, then_body, else_body)``      -- ``if xi > xj: … else: …``

    Arithmetic is modulo ``2**wordsize`` (see ``to_Python``). ``x1, x2, …``
    hold the inputs and ``run`` returns ``[x0]`` by default.
    """

    # Number of arguments (after the name) expected by each instruction.
    _ARITY = {'inc': 1, 'dec': 1, 'zero': 1, 'copy': 2, 'add': 3, 'sub': 3,
              'while': 2, 'for': 3, 'if': 4}

    def __init__(self, n, instructions=None, inputs=None):
        """Create a program.

        :param n: word size in bits.
        :param instructions: nested instruction tuples (see the class docstring).
        :param inputs: initial values of ``x1, x2, …``.
        :raises ValueError: on an unknown instruction or a wrong argument count.
        """
        # ``None`` sentinels rather than ``[]`` defaults: ``__setitem__``
        # mutates ``self.inputs`` in place, and a shared default list would
        # leak inputs between unrelated programs.
        if instructions is None:
            instructions = []
        if inputs is None:
            inputs = []
        self.wordsize = n
        self.instructions = instructions
        self.nvar = 1 + self.__check(instructions)
        self.inputs = inputs

    def __check(self, instructions):
        """Validate *instructions* recursively; return the largest variable index used."""
        n = 0
        for inst in instructions:
            name, args = inst[0], inst[1:]
            if name not in self._ARITY:
                raise ValueError(f'Unknown instruction {name}')
            arity = self._ARITY[name]
            if len(args) != arity:
                raise ValueError(f'{name} requires {arity} arguments ({len(args)} given)')
            if name == 'while':
                n = max(n, args[0], self.__check(args[1]))
            elif name == 'for':
                n = max(n, args[0], args[1], self.__check(args[2]))
            elif name == 'if':
                n = max(n, args[0], args[1], self.__check(args[2]), self.__check(args[3]))
            else:
                # Straight-line instructions: every argument is a variable index.
                n = max(n, *args)
        return n

    def __repr__(self):
        return f"WHILE program with wordsize {self.wordsize}, {self.nvar} variables"

    def __str(self, instructions, ind=0):
        """Return the pseudo-code lines of *instructions*, indented by *ind* spaces."""
        pad = ' ' * ind
        L = []
        for inst in instructions:
            name = inst[0]
            if name == 'inc':
                L.append(pad + f'x{inst[1]} ← x{inst[1]} + 1')
            elif name == 'dec':
                L.append(pad + f'x{inst[1]} ← x{inst[1]} - 1')
            elif name == 'zero':
                L.append(pad + f'x{inst[1]} ← 0')
            elif name == 'copy':
                L.append(pad + f'x{inst[1]} ← x{inst[2]}')
            elif name == 'add':
                L.append(pad + f'x{inst[1]} ← x{inst[2]} + x{inst[3]}')
            elif name == 'sub':
                L.append(pad + f'x{inst[1]} ← x{inst[2]} - x{inst[3]}')
            elif name == 'while':
                L.append(pad + f'while x{inst[1]} ≠ 0:')
                L.extend(self.__str(inst[2], ind + 2))
            elif name == 'for':
                L.append(pad + f'for x{inst[1]} = 0 to x{inst[2]}-1:')
                L.extend(self.__str(inst[3], ind + 2))
            elif name == 'if':
                L.append(pad + f'if x{inst[1]} > x{inst[2]}:')
                L.extend(self.__str(inst[3], ind + 2))
                if len(inst[4]) > 0:
                    L.append(pad + 'else:')
                    L.extend(self.__str(inst[4], ind + 2))
            else:
                raise ValueError(f'Unknown instruction {name}')
        return L

    def __str__(self):
        s = f'# wordsize: {self.wordsize}\n# inputs = {self.inputs}\n'
        return s + '\n'.join(self.__str(self.instructions))

    def __getitem__(self, i):
        """Return the value of input ``xi`` (``None`` for non-input variables)."""
        assert i < self.nvar, f"The program uses only {self.nvar} variables"
        if 1 <= i <= len(self.inputs): return self.inputs[i-1]
        return None

    def __setitem__(self, i, v):
        """Set input ``xi`` (``1 <= i <= len(inputs)``) to *v*, in place."""
        assert 1 <= i <= len(self.inputs), f"Only inputs x1, ..., x{len(self.inputs)} can be set"
        assert v < 1 << self.wordsize, f"Inputs must be < 2**{self.wordsize} = {1<<self.wordsize}"
        self.inputs[i-1] = v

    @classmethod
    def read_code(cls, s):
        """Parse a program written in the pseudo-code produced by ``str()``.

        *s* is a string or an object with a ``read()`` method. The header
        comments ``# wordsize: N`` and ``# inputs: [a, b, …]`` (``=`` is also
        accepted, as written by ``__str__``) set the word size and inputs.
        Blocks are delimited by indentation as in Python; a header may also
        carry its body after the colon, with ``;`` separating instructions.
        ``#`` starts a comment. A block with no indented line is empty.

        :raises ValueError: on an unparsable instruction (reporting its
            1-based source line) or a line with more than one colon.
        """
        ws = re.compile(r'# *word[- ]*size *: *(\d+) *')
        # ``[:=]`` because ``__str__`` writes ``# inputs = [...]``; the list may be empty.
        inp = re.compile(r'# *inputs *[:=] *(\[ *(\d+ *(, *\d+ *)*)?\]) *')
        zero = re.compile(r'x(\d+) *← *0 *')
        inc = re.compile(r'x(\d+) *← *x\1 *\+ *1 *')
        dec = re.compile(r'x(\d+) *← *x\1 *- *1 *')
        copy = re.compile(r'x(\d+) *← *x(\d+) *')
        add = re.compile(r'x(\d+) *← *x(\d+) *\+ *x(\d+) *')
        sub = re.compile(r'x(\d+) *← *x(\d+) *- *x(\d+) *')
        whilenz = re.compile(r'while +x(\d+) *≠ *0 *')
        forloop = re.compile(r'for +x(\d+) *= *0 +to +x(\d+) *- *1 *')
        ifsup = re.compile(r'if +x(\d+) *> *x(\d+) *')
        ifelse = re.compile(r'else *')

        if not isinstance(s, str):
            s = s.read()

        # Flatten the source into (indentation, instruction text, source line).
        tokens = []
        inputs = []
        wordsize = 0
        # splitlines() also strips '\r', so CRLF files parse like LF files.
        for lineno, line in enumerate(s.splitlines(), start=1):
            if m := ws.fullmatch(line):
                wordsize = int(m[1])
                continue
            if m := inp.fullmatch(line):
                # ``inp`` only admits digits, so no eval() is needed.
                inputs = [int(v) for v in re.findall(r'\d+', m[1])]
                continue

            ind = len(line) - len(line.lstrip())
            line = line.split('#')[0]

            parts = line.split(':')
            if len(parts) > 2:
                raise ValueError(f'Too many colons in line "{line}"')
            body_ind = ind
            if len(parts) == 2:
                tokens.append((ind, parts[0].strip(), lineno))
                # Instructions after the colon form the header's body, as if
                # they had been written on their own lines two columns deeper.
                body_ind = ind + 2
            tokens.extend((body_ind, ins.strip(), lineno)
                          for ins in parts[-1].split(';') if ins.strip())

        def parse_block(pos, parent_ind, block_ind=None):
            """Parse one block of instructions starting at ``tokens[pos]``.

            Lines of the block must be indented more than *parent_ind*; the
            first one fixes *block_ind* (unless given) and the block ends at
            the first line indented less than it. Returns
            ``(next unparsed position, instructions)``.
            """
            instructions = []
            while pos < len(tokens):
                ind, text, lineno = tokens[pos]
                if block_ind is None:
                    if ind <= parent_ind:
                        break  # the header has an empty body
                    block_ind = ind
                elif ind < block_ind:
                    break
                pos += 1

                if m := zero.fullmatch(text): instructions.append(('zero', int(m[1])))
                elif m := inc.fullmatch(text): instructions.append(('inc', int(m[1])))
                elif m := dec.fullmatch(text): instructions.append(('dec', int(m[1])))
                elif m := copy.fullmatch(text): instructions.append(('copy', int(m[1]), int(m[2])))
                elif m := add.fullmatch(text): instructions.append(('add', int(m[1]), int(m[2]), int(m[3])))
                elif m := sub.fullmatch(text): instructions.append(('sub', int(m[1]), int(m[2]), int(m[3])))

                elif m := whilenz.fullmatch(text):
                    pos, body = parse_block(pos, ind)
                    instructions.append(('while', int(m[1]), body))

                elif m := forloop.fullmatch(text):
                    pos, body = parse_block(pos, ind)
                    instructions.append(('for', int(m[1]), int(m[2]), body))

                elif m := ifsup.fullmatch(text):
                    pos, then_body = parse_block(pos, ind)
                    else_body = []
                    # An ``else`` dedented past this ``if`` belongs to an
                    # enclosing ``if``; leave it for that caller to claim.
                    if (pos < len(tokens) and tokens[pos][0] >= ind
                            and ifelse.fullmatch(tokens[pos][1])):
                        pos, else_body = parse_block(pos + 1, tokens[pos][0])
                    instructions.append(('if', int(m[1]), int(m[2]), then_body, else_body))

                else:
                    raise ValueError(f'Instruction "{text}" not parsed at line {lineno}')
            return pos, instructions

        # The top level accepts every indentation, so it consumes all tokens.
        _, instructions = parse_block(0, -1, 0)
        return cls(wordsize, instructions, inputs)

    def to_RAM(self):
        """Translate this program into a ``RAM.RAM`` machine.

        Only the minimal WHILE language (``inc``, ``dec``, ``zero``, ``copy``,
        ``while``) is supported; use ``simplify()`` first for the others.

        :raises ValueError: if another instruction is present.
        """
        # Register ``nvar`` is never referenced by the program, so it is used
        # as a constant: ``zero`` copies from it and the loop back edge tests it.
        # NOTE(review): this assumes RAM's ``jump`` branches when the tested
        # register is 0, that addresses are 1-based and that registers beyond
        # the given ones start at 0 -- confirm against RAM.py.
        z = self.nvar

        def _translate(instructions, l=0):
            """Translate *instructions*; *l* is the address just before the first one."""
            L = []
            for inst in instructions:
                l += 1
                name = inst[0]
                if name in ['inc', 'dec', 'copy']:
                    L.append(inst)
                elif name == 'zero':
                    L.append(('copy', inst[1], z))
                elif name == 'while':
                    # Layout: l = exit test, l+1..l+k = body, l+k+1 = back
                    # edge to the test, l+k+2 = first instruction after the loop.
                    W = _translate(inst[2], l)
                    k = len(W)
                    L.append(('jump', inst[1], l + k + 2))
                    L.extend(W)
                    L.append(('jump', z, l))
                    l += k + 1
                else:
                    raise ValueError(f'Only minimal WHILE language supported (no instruction "{name}")')
            return L

        instructions = _translate(self.instructions) + [('stop',)]
        return RAM.RAM(self.wordsize, instructions, [0] + self.inputs)

    def __to_Python(self, instructions, m, ind=0, verbose=False):
        """Return Python source lines running *instructions* on list ``X`` modulo *m*."""
        pad = ' ' * ind

        def block(body):
            lines = self.__to_Python(body, m, ind + 2, verbose)
            # Python rejects an empty suite; an empty WHILE block does nothing.
            return lines or [' ' * (ind + 2) + 'pass']

        L = []
        for inst in instructions:
            name = inst[0]
            if name == 'inc':
                L.append(pad + f'X[{inst[1]}] = (X[{inst[1]}] + 1) % {m}')
            elif name == 'dec':
                L.append(pad + f'X[{inst[1]}] = (X[{inst[1]}] - 1) % {m}')
            elif name == 'zero':
                L.append(pad + f'X[{inst[1]}] = 0')
            elif name == 'copy':
                L.append(pad + f'X[{inst[1]}] = X[{inst[2]}]')
            elif name == 'add':
                L.append(pad + f'X[{inst[1]}] = (X[{inst[2]}] + X[{inst[3]}]) % {m}')
            elif name == 'sub':
                L.append(pad + f'X[{inst[1]}] = (X[{inst[2]}] - X[{inst[3]}]) % {m}')
            elif name == 'while':
                L.append(pad + f'while X[{inst[1]}] != 0:')
                L.extend(block(inst[2]))
            elif name == 'for':
                L.append(pad + f'for X[{inst[1]}] in range(X[{inst[2]}]):')
                L.extend(block(inst[3]))
            elif name == 'if':
                L.append(pad + f'if X[{inst[1]}] > X[{inst[2]}]:')
                L.extend(block(inst[3]))
                if len(inst[4]) > 0:
                    L.append(pad + 'else:')
                    L.extend(block(inst[4]))
            else:
                raise ValueError(f'Unknown instruction {name}')
            if verbose:
                inst_str = f'{name}({",".join(str(a) if isinstance(a,int) else "…" for a in inst[1:])})'
                L.append(pad + f'print("{inst_str}:", X)')
        return L

    def to_Python(self, inputs=None, verbose=False):
        """Return Python source that runs the program on a list ``X``.

        ``X`` starts as ``[0, *inputs, 0, …]`` with ``nvar`` entries (more if
        there are extra inputs). *inputs* defaults to ``self.inputs`` and
        must have the same length. With *verbose*, ``X`` is printed after
        every instruction.
        """
        m = 1 << self.wordsize
        if inputs is None:
            inputs = self.inputs
        else:
            assert len(inputs) == len(self.inputs), f"Incorrect number of inputs ({len(inputs)} instead of {len(self.inputs)})"
        # Build the initial list first: formatting pieces directly produced
        # the invalid literal ``[0,,0]`` when there are no inputs.
        X = [0, *inputs] + [0] * (self.nvar - len(inputs) - 1)
        s = f'X = [{",".join(str(x) for x in X)}]\n'
        return s + '\n'.join(self.__to_Python(self.instructions, m, 0, verbose))

    def run(self, inputs=None, output=None, verbose=False):
        """Execute the program and return the requested variables.

        :param inputs: overrides ``self.inputs`` (same length required).
        :param output: list of variable indices to return (default ``[0]``),
            or ``'all'`` for the whole variable list.
        :param verbose: print the variables after each instruction.
        """
        if output is None:
            output = [0]
        s = self.to_Python(inputs, verbose)
        if output == 'all':
            s += '\nOutput = X'
        else:
            s += '\nOutput = [X[v] for v in output]'
        # NOTE(review): inputs are embedded in the generated source via str(),
        # so only pass trusted (integer) inputs to exec().
        d = {'output': output}
        exec(s, d)
        return d['Output']

    def __simplify(self, instructions, t):
        """Rewrite *instructions* using only inc/dec/zero/copy/while.

        Variables ``t, t+1, …`` are free scratch registers (every index the
        program uses is below ``t``); nested bodies get a base above the
        scratch registers this level is still using.
        """
        newinst = []
        for inst in instructions:
            name = inst[0]
            if name in ['inc', 'dec', 'zero', 'copy']:
                newinst.append(inst)
            elif name == 'while':
                newinst.append(('while', inst[1], self.__simplify(inst[2], t)))
            elif name in ['add', 'sub']:
                i, j, k = inst[1:]
                step = 'inc' if name == 'add' else 'dec'
                # Save xk before writing xi: with i == k (e.g. x1 ← x2 + x1)
                # the copy xi ← xj would otherwise clobber the operand.
                newinst.append(('copy', t, k))
                if i != j:
                    newinst.append(('copy', i, j))
                newinst.append(('while', t, [(step, i), ('dec', t)]))

            elif name == 'for':
                i, j = inst[1], inst[2]
                body = self.__simplify(inst[3], t + 1)
                # Snapshot the bound before resetting the counter so that
                # ``for xi = 0 to xi-1`` still iterates the original xi times.
                # NOTE(review): afterwards xi == old xj, whereas ``run`` (a
                # Python for loop) leaves xi at xj-1, or unchanged if xj == 0;
                # a body writing xi also diverges. Confirm intended semantics.
                newinst.extend([('copy', t, j), ('zero', i)])
                newinst.append(('while', t, body + [('inc', i), ('dec', t)]))

            elif name == 'if':
                i, j = inst[1], inst[2]
                # Copies of xi and xj are decremented in lockstep; if the copy
                # of xj reaches 0 while the copy of xi is still nonzero, then
                # xi > xj and the then-branch runs.
                if len(inst[4]) > 0:
                    # t: else still pending, t+1/t+2: copies of xi/xj,
                    # t+3: zero-test scratch, t+4: "xj copy exhausted" flag.
                    bodytrue = self.__simplify(inst[3], t + 5)
                    bodyfalse = self.__simplify(inst[4], t + 5)
                    newinst.extend([('zero', t), ('inc', t), ('copy', t + 1, i), ('copy', t + 2, j)])
                    while1 = ('while', t + 3, [('dec', t + 1), ('dec', t + 2), ('zero', t + 3), ('zero', t + 4)])
                    while2 = ('while', t + 4, bodytrue + [('zero', t + 1), ('zero', t), ('zero', t + 4)])
                    newinst.append(('while', t + 1, [('zero', t + 4), ('inc', t + 4), ('copy', t + 3, t + 2), while1, while2]))
                    newinst.append(('while', t, bodyfalse + [('zero', t)]))
                else:
                    # t/t+1: copies of xi/xj, t+2: zero-test scratch, t+3: flag.
                    bodytrue = self.__simplify(inst[3], t + 4)
                    newinst.extend([('copy', t, i), ('copy', t + 1, j)])
                    while1 = ('while', t + 2, [('dec', t), ('dec', t + 1), ('zero', t + 2), ('zero', t + 3)])
                    while2 = ('while', t + 3, bodytrue + [('zero', t), ('zero', t + 3)])
                    newinst.append(('while', t, [('zero', t + 3), ('inc', t + 3), ('copy', t + 2, t + 1), while1, while2]))
            else:
                raise ValueError(f'Unknown instruction {name}')

        return newinst

    def simplify(self):
        """Return an equivalent program in the minimal WHILE language.

        Scratch variables are allocated from ``x{nvar}`` upwards. The result
        gets its own copy of the inputs, so setting an input on one program
        does not affect the other.
        """
        newinst = self.__simplify(self.instructions, self.nvar)
        return While(self.wordsize, newinst, list(self.inputs))



