from yaml import safe_load
from re import sub

class TuringMachine:
    """Deterministic (multi-tape) Turing machine backed by a dict transition table.

    ``table[state][read] == (written, move, newst)``: ``read``, ``written`` and
    ``move`` are strings holding one character per tape (``ntape`` of them).
    :meth:`run` moves a head right on ``'R'``, left on ``'L'`` and leaves it in
    place for any other move character; the machine halts as soon as no
    transition matches the current state and the symbols read.

    ``alphabet`` is a list whose first element is the blank symbol; symbols
    used by transitions are appended to it automatically.  States that are
    only reached (no outgoing transition) appear in ``table`` with an empty dict.
    """

    def __init__(self, table=None, init=None, tape='', blank=' ', alphabet=None, ntape=None):
        """Build a machine.

        table: mapping ``state -> {read: (written, move, newst)}`` (copied).
        init: start state; defaults to the first state given a transition.
        tape: initial tape contents; with several tapes, the contents are
            separated by the ``wall`` symbol later passed to :meth:`run`.
        blank: blank symbol, always stored first in ``alphabet``.
        alphabet: symbols to declare up front (duplicates are dropped).
        ntape: number of tapes; inferred from the first transition when None.

        Raises ValueError if a read string is empty or its length differs
        from ``ntape`` (or from the other read strings).
        """
        self.table = {}
        self.alphabet = list(dict.fromkeys([blank, *(alphabet or ())]))
        self.ntape = ntape
        self.init = init
        self.tape = tape
        # __setitem__ validates every read length against self.ntape (setting
        # it on the first transition when it is None), so the table is known
        # to be consistent once this loop completes.
        for state, row in (table or {}).items():
            self.table.setdefault(state, {})
            for read, triple in (row or {}).items():
                self[state, read] = triple

    def __repr__(self):
        nstate = len(self.table)
        return f"Turing machine with {nstate} states and {len(self.alphabet)} symbols"

    @property
    def states(self):
        """List of all states, in insertion order."""
        return list(self.table)

    @property
    def blank(self):
        """The blank symbol (first element of ``alphabet``)."""
        return self.alphabet[0]

    def __getitem__(self, pair):
        """Return the transition for ``(state, read)``, or None if undefined."""
        state, read = pair
        return self.table.get(state, {}).get(read)

    def __setitem__(self, pair, triple):
        """Add or replace the transition ``(state, read) -> (written, move, newst)``.

        The target state is registered (without transitions) if unknown, the
        symbols of ``read`` and ``written`` are added to the alphabet, and the
        state becomes ``init`` if no start state was set yet.

        Raises ValueError, leaving the machine unchanged, if ``read`` is empty
        or its length differs from ``ntape``.
        """
        state, read = pair
        n = len(read)
        # Validate before mutating anything so a rejected transition leaves no trace.
        if n == 0 or (self.ntape is not None and n != self.ntape):
            raise ValueError("inconsistent length of read/write")
        if self.ntape is None:
            self.ntape = n
        self.table.setdefault(state, {})[read] = triple
        self.table.setdefault(triple[2], {})
        for c in read + triple[0]:
            if c not in self.alphabet:
                self.alphabet.append(c)
        if self.init is None:
            self.init = state

    def __contains__(self, state):
        """True if ``state`` is a state of the machine."""
        return state in self.table

    def __eq__(self, other):
        """Two machines are equal when they have exactly the same transitions.

        Start state, tape, blank and states without transitions are ignored.
        """
        if not isinstance(other, TuringMachine):
            return NotImplemented

        def transitions(tm):
            # Normalize triples to tuples so list-valued triples compare equal.
            return {(state, read): tuple(triple)
                    for state, row in tm.table.items()
                    for read, triple in row.items()}

        return transitions(self) == transitions(other)

    @classmethod
    def from_yaml(cls, s):
        """Build a machine from a turingmachine.io YAML description.

        s: the YAML text, or a file-like object whose ``read()`` returns it.
        A document wrapping the description under a ``'source code'`` key is
        unwrapped first.  Keys used: ``input`` (optional, empty by default),
        ``blank``, ``start state`` and ``table``.  In ``table`` a transition is
        either a bare move, or a mapping with an optional ``write`` key and a
        move key whose value (if not null) is the next state.  A read key of
        the form ``[a, b]`` applies the same transition to several symbols.

        Raises ValueError when a transition mapping has no move key.
        """
        if not isinstance(s, str):
            s = s.read()
        # PyYAML rejects unhashable (list) mapping keys, so "[a, b]:" is turned
        # into the string key "[a, b]" and split by hand below.
        s = sub(r'\[([^]]*)\]:', r'"[\1]":', s)
        doc = safe_load(s)
        if 'source code' in doc:
            doc = safe_load(doc['source code'])

        tape = '' if doc.get('input') is None else str(doc['input'])
        blank = str(doc['blank'])
        init = str(doc['start state'])
        D = {}
        for state, trans in doc['table'].items():
            name = str(state)
            D[name] = {}
            if trans is None:  # state without outgoing transitions
                continue
            for read, value in trans.items():
                written = None  # None means: write back the symbol read
                newst = name
                if isinstance(value, dict):
                    # Read without pop(): the dict may be shared via a YAML alias.
                    if 'write' in value:
                        written = str(value['write'])
                    moves = [key for key in value if key != 'write']
                    if not moves:
                        raise ValueError(f"missing move in transition of state {name!r} on {read!r}")
                    move = moves[0]
                    if value[move] is not None:
                        newst = str(value[move])
                else:
                    move = value

                if isinstance(read, str) and read.startswith('[') and read.endswith(']'):
                    symbols = [sym.strip() for sym in read[1:-1].split(',')]
                    symbols = [sym[1:-1] if len(sym) >= 2 and sym[0] == sym[-1] == "'" else sym
                               for sym in symbols]
                else:
                    symbols = [str(read)]
                for r in symbols:
                    D[name][r] = (r if written is None else written, move, newst)

        return cls(D, init, tape, blank)

    @classmethod
    def universal(cls, filename="universelle.yaml"):
        """Load the machine described by the turingmachine.io file ``filename``."""
        # The context manager closes the file once it has been read.
        with open(filename, encoding='utf-8') as f:
            return cls.from_yaml(f)

    def __tmio(self):
        """Render the machine in turingmachine.io YAML syntax.

        Reads sharing the same action are grouped under one list key; the
        ``write`` part and the target state are omitted when they equal the
        symbol read and the current state respectively.
        """
        lines = [f"input: '{self.tape}'",
                 f"blank: '{self.blank}'",
                 f"start state: '{self.init}'",
                 "table:"]
        for state, row in self.table.items():
            lines.append(f"  '{state}':")
            groups = {}  # rendered action -> list of reads using it
            for read, (written, move, newst) in row.items():
                if written == read:
                    written = None
                if newst == state:
                    newst = None
                if written is None and newst is None:
                    action = move
                elif written is None:
                    action = f"{{{move}: '{newst}'}}"
                elif newst is None:
                    action = f"{{write: '{written}', {move}}}"
                else:
                    action = f"{{write: '{written}', {move}: '{newst}'}}"
                groups.setdefault(action, []).append(read)
            for action, reads in groups.items():
                lines.append(f"    {reads}: {action}")
        return '\n'.join(lines)

    def __tmscom(self):
        """Render the machine in turingmachinesimulator.com syntax.

        The blank is written as '_', moves 'L'/'R' as '<'/'>'.
        """
        blank = self.blank
        lines = ['name: Turing machine', f'init: {self.init}']
        # A state named 'oui' (French for "yes") is the accepting state if present.
        lines.append(f'accept: {"oui" if "oui" in self.states else self.init}')
        for state in self.table:
            for read, (written, move, newst) in self.table[state].items():
                lines.append(f'{state},{",".join(c.replace(blank, "_") for c in read)}')
                lines.append(f'{newst},{",".join(c.replace(blank, "_") for c in written)},{",".join(move.replace("L", "<").replace("R", ">"))}')
                lines.append('')
        return '\n'.join(lines)

    def str(self, simulator='turingmachine.io'):
        """Return the machine's source code for ``simulator``.

        Supported: ``'turingmachine.io'`` and ``'turingmachinesimulator.com'``.
        Raises ValueError for any other simulator.
        """
        if simulator == 'turingmachine.io':
            return self.__tmio()
        if simulator == 'turingmachinesimulator.com':
            return self.__tmscom()
        raise ValueError(f'Unknown simulator {simulator}')

    def __str__(self, bla=0):
        # ``bla`` is unused; kept so existing explicit calls keep working.
        return self.str()

    def run(self, wall='#', verb=False, count=False):
        """Run the machine from ``init`` on ``tape`` until it halts.

        wall: separator between the tapes' initial contents (several tapes).
        verb: print the state and the tapes after every step.
        count: also return the number of steps taken.

        Returns ``(state, tape)`` or ``(state, tape, steps)``, where ``tape``
        joins every tape with ``wall``; a ``'|'`` marks each head position and
        surrounding whitespace (e.g. a ``' '`` blank) is stripped.  There is
        no step limit: a machine that never halts makes this never return.
        """
        n = self.ntape if self.ntape is not None else 1  # empty machine: one tape
        contents = self.tape.split(wall) if n > 1 else [self.tape]
        assert len(contents) == n, f'Not enough tapes given ({len(contents)} instead of {n})'
        blank = self.blank
        # Mutable cell lists make each write O(1) instead of rebuilding a string.
        cells = [list(t) if t else [blank] for t in contents]
        heads = [0] * n

        def show(i):
            t, p = cells[i], heads[i]
            return ''.join(t[:p]) + '|' + ''.join(t[p:])

        q = self.init
        steps = 0
        while q in self.table:
            read = ''.join(t[p] for t, p in zip(cells, heads))
            transition = self[q, read]
            if transition is None:
                break
            written, move, newst = transition
            for i in range(n):
                t, p = cells[i], heads[i]
                t[p] = written[i]
                if move[i] == 'R':
                    p += 1
                    if p == len(t):
                        t.append(blank)
                elif move[i] == 'L':
                    if p > 0:
                        p -= 1
                    else:
                        t.insert(0, blank)  # extend the tape to the left
                heads[i] = p
            q = newst
            if verb:
                print(f"{q}:\t{wall.join(show(i) for i in range(n))}")
            steps += 1

        tape = wall.join(show(i).strip() for i in range(n))
        if count:
            return q, tape, steps
        return q, tape

    def binary_alphabet(self, code=None, clean=True, verb=False):
        """Return an equivalent single-tape machine over ``{blank, '0', '1'}``.

        Each symbol is replaced by a fixed-width block of bits; one original
        step becomes: read the block left to right, rewrite it right to left,
        then move the head by a whole block.

        code: mapping symbol -> code string, all of the same width.  When
            None, the non-blank symbols get consecutive binary numbers (at
            least 3 non-blank symbols are required).  A missing blank is
            encoded as a block of blanks.  The argument is not modified.
        clean: meant to remove useless states; currently unused.
        verb: print the computed code and every generated transition.
        """
        assert self.ntape == 1, f'Not implemented for machines with {self.ntape} tapes'
        blank = self.blank

        if code is None:
            assert len(self.alphabet) > 3, "The machine is already binary"
            nonblank_symbols = len(self.alphabet) - 1
            width = max(1, (nonblank_symbols - 1).bit_length())
            code = {a: format(i, f'0{width}b') for i, a in enumerate(self.alphabet[1:])}
            if verb: print(f"Computed code: {code}")
        else:
            assert code, "empty code"
            code = dict(code)  # never modify the caller's mapping
            width = len(next(iter(code.values())))
            assert all(len(c) == width for c in code.values()), "codes de longueurs différentes"

        # The blank is encoded as a block of blanks (for simplicity).
        if blank not in code: code[blank] = blank * width
        # Passing blank keeps the machine's blank (the default would be ' ').
        T = TuringMachine(alphabet=[blank, '0', '1'], blank=blank, init=self.init,
                          tape=''.join(code[a] for a in self.tape))

        def make_trans(state, lu, ecrit, depl, nv):
            # Blanks inside state names are shown as '_' to keep them readable.
            src, dst = state.replace(blank, '_'), nv.replace(blank, '_')
            T[src, lu] = (ecrit, depl, dst)
            if verb: print(f"  {src},{lu} → {ecrit},{depl},{dst}")

        for state in self.table:
            for read, (written, move, newst) in self.table[state].items():
                if verb: print(f"Transition: {state},{read} → {written},{move},{newst}")
                code_lu, code_ecrit = code[read], code[written]  # codes read / written
                if verb: print(f"Codes: {read}={code_lu}, {written}={code_ecrit}")

                # Read the first bit b: q -> q:b.
                b = code_lu[0]
                make_trans(state, b, b, 'R', state + ':' + b)

                # Read the following bits: q:m -> q:mb.
                for i in range(1, width - 1):
                    b = code_lu[i]
                    make_trans(state + ':' + code_lu[:i], b, b, 'R', state + ':' + code_lu[:i + 1])

                # Read the last bit, overwrite it with the last written bit and
                # turn back: q:m -> q':m'M (m' = bits still to write, M = move).
                make_trans(state + ':' + code_lu[:-1], code_lu[width - 1], code_ecrit[width - 1],
                           'L', newst + ':' + code_ecrit[:-1] + move)

                # Write the remaining bits right to left: q':mbM -> q':mM.
                for i in range(1, width):
                    b = code_ecrit[i - 1]
                    for s in T.alphabet:
                        make_trans(newst + ':' + code_ecrit[:i] + move, s, b, 'L',
                                   newst + ':' + code_ecrit[:i - 1] + move)

                # First bit of the block (overrides the i == 1 case above):
                # write it and start moving: q':bM -> q':MMM.
                b = code_ecrit[0]
                for s in T.alphabet:
                    make_trans(newst + ':' + b + move, s, b, move, newst + ':' + move * (width - 1))

                # Move the head across the block: q':MMM -> q':MM.
                for i in range(1, width - 1):
                    for s in T.alphabet:
                        make_trans(newst + ':' + move * (i + 1), s, s, move, newst + ':' + move * i)

                # Last move: q':M -> q'.
                for s in T.alphabet:
                    make_trans(newst + ':' + move, s, s, move, newst)

        return T

    def left_bounded(self, wall='#'):
        """Return an equivalent machine whose tape is bounded on the left.

        The new tape starts with ``wall``.  Whenever the head moves onto the
        wall, the tape contents are shifted one cell to the right and the head
        resumes, in the same state, on the freed blank cell.  The start state
        is ``'s:' + init``, which just steps off the wall.
        """
        assert self.ntape == 1, f'Not implemented for machines with {self.ntape} tapes'
        assert wall not in self.alphabet, f'The wall {wall} is already in the alphabet'
        blank = self.blank
        # Copy of self; passing blank keeps the machine's blank symbol.
        T = TuringMachine(table=self.table, alphabet=self.alphabet + [wall], blank=blank,
                          init='s:' + self.init, tape=wall + self.tape)

        # For each state, add a shifting loop triggered by reading the wall.
        for state in self.table:
            T[state, wall] = (wall, 'R', state + ':d')
            for read in self.alphabet:
                r = read.replace(' ', '_')  # readable state names
                # Free the first cell and carry its symbol to the right.
                T[state + ':d', read] = (blank, 'R', state + ':d' + r)
                for write in self.alphabet:
                    w = write.replace(' ', '_')
                    T[state + ':d' + r, write] = (read, 'R', state + ':d' + w)
                # Blank reached: drop the carried symbol and go back to the wall.
                T[state + ':d' + r, blank] = (read, 'L', state + ':r')
                T[state + ':r', read] = (read, 'L', state + ':r')
            T[state + ':r', wall] = (wall, 'R', state)

        T['s:' + self.init, wall] = (wall, 'R', self.init)
        return T

    def universal_encoding(self, verb=False):
        """Return the input tape describing this machine for the universal machine.

        Layout: ``'+'``, then for each state (start state first, then table
        order) one field per alphabet symbol -- ``'0'`` if there is no
        transition, otherwise ``'1' * number_of_target_state + move + written``
        -- followed by ``'|'``; then the input tape, whose first symbol is
        encoded as ``b``/``o``/``i`` (head position) and the others as
        ``_``/``0``/``1`` (blank, first and second symbol of the alphabet).
        """
        assert self.ntape == 1, f'Not implemented for machines with {self.ntape} tapes'
        assert len(self.alphabet) == 3, 'can only encode machines with binary alphabet'
        blank, zero, one = self.alphabet

        # State numbering: the start state is 1, the others follow table order.
        E2N = {self.init: 1}
        for state in self.table:
            if state != self.init:
                E2N[state] = len(E2N) + 1
        if verb: print(f"State encoding: {E2N}")

        # Translate each symbol individually: chained str.replace calls would
        # re-translate characters produced by an earlier replacement.
        symbol_code = {blank: '_', zero: '0', one: '1'}
        head_code = {blank: 'b', zero: 'o', one: 'i'}

        parts = ['+']
        for state in E2N:
            row = self.table[state]
            for s in self.alphabet:
                if s in row:
                    written, move, newst = row[s]
                    parts.append('1' * E2N[newst] + move + symbol_code.get(written, written))
                else:
                    parts.append('0')
            parts.append('|')

        tape = self.tape if self.tape else blank
        parts.append(head_code.get(tape[0], tape[0]))
        parts.extend(symbol_code.get(c, c) for c in tape[1:])
        return ''.join(parts)

    def two_to_one_tape(self, wall='#', marked=None):
        """Return a single-tape machine simulating this two-tape machine.

        The single tape holds ``tape1 + wall + tape2``; each head position is
        shown by replacing the symbol under it with its *marked* counterpart.
        One simulated step scans from the left end to find both marked heads,
        then writes and moves each head in turn, shifting a tape segment by
        one cell when a head runs into the wall.

        wall: separator symbol (must not be in the alphabet).
        marked: optional mapping symbol -> marked symbol; missing entries are
            chosen automatically (blank -> '_', others from 'a' upwards,
            skipping symbols already in use) and printed.  The argument is
            not modified.

        NOTE: the scanning states only step over non-blank symbols, so a
        machine that writes the blank inside a tape segment may make the
        simulation halt early.
        """
        assert wall not in self.alphabet, f'The wall {wall} is already in the alphabet'
        assert self.ntape == 2, f'The machine needs to have two tapes, not {self.ntape}'

        alphabet = self.alphabet
        blank = self.blank
        # Work on a copy so the caller's dict is not modified.
        marked = {} if marked is None else dict(marked)

        if blank not in marked:
            assert '_' not in alphabet, "automatic attribution of marked symbols failed"
            marked[blank] = '_'

        o = 97  # ord('a')
        for c in alphabet[1:]:
            if c not in marked:
                # Avoid symbols of the alphabet, the wall and marks already used.
                while chr(o) in alphabet or chr(o) == wall or chr(o) in marked.values():
                    o += 1
                marked[c] = chr(o)
                o += 1
        if o > 97: print(f'Automatic marked symbols: {marked}')

        tape1, tape2 = self.tape.split(wall)
        if len(tape1) == 0: tape1 = blank
        if len(tape2) == 0: tape2 = blank

        tape = marked[tape1[0]] + tape1[1:] + wall + marked[tape2[0]] + tape2[1:]
        # Passing blank keeps the machine's blank (the default would be ' ').
        T = TuringMachine(alphabet=alphabet + [wall] + [marked[c] for c in alphabet[1:]],
                          blank=blank, init=self.init, tape=tape)

        symbols = alphabet[1:]  # unmarked, non-blank symbols
        # Every non-blank symbol that can appear on T's tape, in the order
        # they enter T's alphabet (fixed up front instead of read from the
        # still-growing T.alphabet).
        nonblank = symbols + [wall] + [marked[c] for c in symbols] + [marked[blank]]

        for state in self.table:
            for read, (written, move, newst) in self.table[state].items():
                r1, r2 = read[0], read[1]
                wr1, wr2 = written[0], written[1]
                m1, m2 = move[0], move[1]
                name = state + ':' + r1 + r2

                # state: scan right to head 1 and remember r1.
                T[state, marked[r1]] = (marked[r1], 'R', state + ':' + r1)
                for c in symbols: T[state, c] = (c, 'R', state)

                # state:r1: scan right to head 2.
                T[state + ':' + r1, marked[r2]] = (marked[r2], 'L', name + '1')
                for c in symbols + [wall]: T[state + ':' + r1, c] = (c, 'R', state + ':' + r1)

                # name1: go back to the left end of the tape.
                T[name + '1', blank] = (blank, 'R', name + '2')
                for c in nonblank: T[name + '1', c] = (c, 'L', name + '1')

                # name2: find head 1 again, write wr1 and move it.
                T[name + '2', marked[r1]] = (wr1, m1, name + '3')
                for c in symbols: T[name + '2', c] = (c, 'R', name + '2')

                # name3: mark the new position of head 1.
                for c in symbols: T[name + '3', c] = (marked[c], 'R', name + '4')
                # Head 1 went left past the start of tape 1 (or is on a blank).
                T[name + '3', blank] = (marked[blank], 'R', name + '4')
                # Head 1 ran into the wall: shift tape 1 one cell to the left.
                T[name + '3', wall] = (wall, 'L', name + 'l')
                for c in symbols:
                    T[name + 'l', c] = (marked[blank], 'L', name + 'l:' + c)
                    for d in symbols: T[name + 'l:' + c, d] = (c, 'L', name + 'l:' + d)
                    T[name + 'l:' + c, blank] = (c, 'R', name + 'lend')
                for c in [blank] + nonblank:
                    if c != wall:
                        T[name + 'lend', c] = (c, 'R', name + 'lend')
                    else:
                        T[name + 'lend', wall] = (wall, 'R', name + '4')

                # name4: scan right to head 2, write wr2 and move it.
                T[name + '4', marked[r2]] = (wr2, m2, name + '5')
                for c in symbols + [wall]: T[name + '4', c] = (c, 'R', name + '4')

                # name5: mark the new position of head 2.
                for c in symbols: T[name + '5', c] = (marked[c], 'L', name + '6')
                # Head 2 went right past the end of tape 2 (or is on a blank).
                T[name + '5', blank] = (marked[blank], 'L', name + '6')
                # Head 2 ran into the wall: shift tape 2 one cell to the right.
                T[name + '5', wall] = (wall, 'R', name + 'r')
                for c in symbols:
                    T[name + 'r', c] = (marked[blank], 'R', name + 'r:' + c)
                    for d in symbols: T[name + 'r:' + c, d] = (c, 'R', name + 'r:' + d)
                    T[name + 'r:' + c, blank] = (c, 'L', name + 'rend')
                for c in [blank] + nonblank:
                    if c != wall:
                        T[name + 'rend', c] = (c, 'L', name + 'rend')
                    else:
                        T[name + 'rend', wall] = (wall, 'L', name + '6')

                # name6: back to the left end, then continue in newst.
                T[name + '6', blank] = (blank, 'R', newst)
                for c in nonblank: T[name + '6', c] = (c, 'L', name + '6')

        # Return only once every state has been translated.
        return T

