-rwxr-xr-x 19249 nttcompiler-20220411/scripts/unroll2opt
#!/usr/bin/env python3
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
def group(s):
  """Return a pyparsing parse action that tags a statement's tokens with s.

  The action converts the matched tokens to a list.  A single token is
  returned as-is; two or more tokens (always the case for the statements in
  this grammar: destination name plus operands) are wrapped into one nested
  list [s, token0, token1, ...].
  """
  def tag(tokens):
    items = list(tokens)
    return items if len(items) == 1 else [[s, *items]]
  return tag
# Punctuation is suppressed so each parsed statement keeps only names/numbers.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")
# Every statement has the shape "dest = op(args)" (or the bare copy
# "dest = src").  Each alternative parses one op's argument pattern and, via
# group(op), yields [op, dest, args...].  The alternatives are combined with
# `|` in list order, so earlier forms are tried first; the bare copy stays
# last because it would otherwise match just "dest = op" of any call.
alternatives = [
  (name + equal + Literal('constant').suppress()
   + lparen + number + comma + number + rparen).setParseAction(group('constant')),
]
for binary in ['__sub__','__rshift__','__lshift__','LShR']:
  alternatives.append(
    (name + equal + Literal(binary).suppress()
     + lparen + name + comma + name + rparen).setParseAction(group(binary)))
alternatives.append(
  (name + equal + Literal('Extract').suppress()
   + lparen + name + comma + number + comma + number + rparen).setParseAction(group('Extract')))
for extension in ['SignExt','ZeroExt']:
  alternatives.append(
    (name + equal + Literal(extension).suppress()
     + lparen + name + comma + number + rparen).setParseAction(group(extension)))
for manyary in ['Concat','__or__','__and__','__add__','__mul__']:
  alternatives.append(
    (name + equal + Literal(manyary).suppress()
     + lparen + name + ZeroOrMore(comma + name) + rparen).setParseAction(group(manyary)))
alternatives.append((name + equal + name).setParseAction(group('copy')))
assignment = alternatives[0]
for alternative in alternatives[1:]:
  assignment |= alternative
# A program is any number of statements followed by end of input.
assignments = ZeroOrMore(assignment) + StringEnd()
# Read the whole program from stdin and parse it into a plain list of
# statements, each of the form [op, dest, args...].
program = list(assignments.parseString(sys.stdin.read()))
# Node ids are handed out by incrementing this counter, so they start at 1.
nextvalue = 0
# indexed by variable name: node id bound to that name
value = {}
# indexed by node id: [opname, extra non-node arguments...]
operation = {}
# indexed by node id: list of operand node ids
parents = {}
# indexed by node id: bit width
bits = {}
def input(v):
  """Create a fresh input node for variable v and bind v to it.

  The node's bit width is the last '_'-separated field of v (a name ending
  in '_16' gives 16).  If that field is not an integer, ValueError is raised
  after value/operation/parents have already been updated.
  NOTE(review): this name shadows the builtin input().
  """
  global nextvalue
  nextvalue += 1
  node = nextvalue
  value[v] = node
  operation[node] = ['input',v]
  parents[node] = []
  bits[node] = int(v.rsplit('_',1)[-1])
# Build the value graph from the parsed statements.  Each statement p is
# [op, dest, args...]; every non-copy statement allocates a fresh node id,
# while a copy only binds dest to the node of its source (creating an input
# node the first time an unknown source name is copied).
for p in program:
  if p[1] in value:
    # format the name into the message (it used to be passed as a separate
    # Exception argument, leaving the '%s' unformatted)
    raise Exception('%s assigned twice' % p[1])
  if p[0] == 'copy':
    if p[2] not in value:
      input(p[2])
    value[p[1]] = value[p[2]]
    continue
  nextvalue += 1
  operation[nextvalue] = [p[0]]
  if p[0] == 'constant':
    # constant(width, value): no operand nodes
    parents[nextvalue] = []
    operation[nextvalue] += [int(p[2]),int(p[3])]
    bits[nextvalue] = int(p[2])
  elif p[0] in ['__sub__','__rshift__','__lshift__','LShR']:
    # binary size-preserving operation
    assert bits[value[p[2]]] == bits[value[p[3]]]
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = bits[value[p[2]]]
  elif p[0] in ['__or__','__and__','__add__','__mul__']:
    # n-ary operation; all operands must share one width
    b = bits[value[p[2]]]
    assert all(b == bits[value[v]] for v in p[2:])
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = b
  elif p[0] == 'Concat':
    # result width is the sum of the operand widths
    parents[nextvalue] = [value[v] for v in p[2:]]
    bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
  elif p[0] == 'Extract':
    # Extract(x, top, bot): bits top..bot inclusive
    top = int(p[3])
    bot = int(p[4])
    assert top >= bot
    assert bits[value[p[2]]] > top
    assert bot >= 0
    operation[nextvalue] += [top,bot]
    parents[nextvalue] = [value[p[2]]]
    bits[nextvalue] = top + 1 - bot
  elif p[0] in ('SignExt','ZeroExt'):
    # extension by morebits extra bits
    morebits = int(p[3])
    operation[nextvalue] += [morebits]
    parents[nextvalue] = [value[p[2]]]
    bits[nextvalue] = bits[value[p[2]]] + morebits
  else:
    raise Exception('unknown internal operation %s' % p[0])
  value[p[1]] = nextvalue
optloop = 0
progress = True
# Rewrite the graph until a whole pass makes no progress.  Each pass applies
# the rules below in order and then drops unused nodes and merges duplicate
# nodes.  Some rules (adding helper subtractions, reordering sum operands)
# deliberately do not set progress.
while progress:
  progress = False
  print('# opt loop %d' % optloop)
  optloop += 1
  # (width, value) -> id of an existing constant node, reused by the
  # mulhi16/mulhrs16 rewrites below
  constants = {}
  for v in operation:
    if operation[v][0] != 'constant': continue
    key = (operation[v][1],operation[v][2])
    constants[key] = v
  # for any __mul__(mulhrs16(mulhi16(x,A),B),C) with constants A, B, C,
  # want to have x minus that (add the __sub__ node unless it exists)
  differences = set()
  for v in operation:
    if operation[v][0] != '__sub__': continue
    if len(parents[v]) != 2: continue
    differences.add((parents[v][0],parents[v][1]))
  hihrslo = []
  for v in operation:
    if operation[v][0] != '__mul__': continue
    # only two-operand products fit; also avoids an unpacking error on
    # n-ary __mul__ nodes, which the input grammar allows
    if len(parents[v]) != 2: continue
    c,C = parents[v]
    if operation[C][0] != 'constant':
      C,c = parents[v]
    if operation[C][0] != 'constant': continue
    if operation[c][0] != 'mulhrs16': continue
    b,B = parents[c]
    if operation[B][0] != 'constant':
      B,b = parents[c]
    if operation[B][0] != 'constant': continue
    if operation[b][0] != 'mulhi16': continue
    a,A = parents[b]
    if operation[A][0] != 'constant':
      A,a = parents[b]
    if operation[A][0] != 'constant': continue
    if (a,v) not in differences:
      hihrslo += [(a,v)]
  for a,v in hihrslo:
    nextvalue += 1
    operation[nextvalue] = ['__sub__']
    parents[nextvalue] = [a,v]
    bits[nextvalue] = bits[v]
    # print('# providing v%d-v%d' % (a,v))
  # for any __mul__(mulhi16(x,A),B) with constant A and constant B <= 4,
  # want to have x minus that (the <= 4 limit is to save time)
  # NOTE(review): the rule matches mulhi16 (not mulhrs16) and limits the
  # multiplier B rather than A; confirm this is the intended pattern.
  differences = set()
  for v in operation:
    if operation[v][0] != '__sub__': continue
    if len(parents[v]) != 2: continue
    differences.add((parents[v][0],parents[v][1]))
  hihrslo = []
  for v in operation:
    if operation[v][0] != '__mul__': continue
    if len(parents[v]) != 2: continue
    b,B = parents[v]
    if operation[B][0] != 'constant':
      B,b = parents[v]
    if operation[B][0] != 'constant': continue
    if operation[B][2] > 4: continue
    if operation[b][0] != 'mulhi16': continue
    a,A = parents[b]
    if operation[A][0] != 'constant':
      A,a = parents[b]
    if operation[A][0] != 'constant': continue
    if (a,v) not in differences:
      hihrslo += [(a,v)]
  for a,v in hihrslo:
    nextvalue += 1
    operation[nextvalue] = ['__sub__']
    parents[nextvalue] = [a,v]
    bits[nextvalue] = bits[v]
    # print('# providing v%d-v%d' % (a,v))
  # given a-b, try ...+a+(c-b) -> ...+c+(a-b)
  # (but do not mark as progress)
  differences = {}
  for v in operation:
    if operation[v][0] != '__sub__': continue
    if len(parents[v]) != 2: continue
    differences[parents[v][0],parents[v][1]] = v
  differenceflips = []
  for v in operation:
    if operation[v][0] != '__add__': continue
    todo = None
    for j in range(len(parents[v])):
      for i in range(j):
        A,D = parents[v][i],parents[v][j]
        for a,d in (A,D),(D,A):
          if operation[d][0] != '__sub__': continue
          c,b = parents[d]
          if (a,b) not in differences: continue
          todo = (v,i,j,c,differences[a,b])
    if todo is None: continue
    differenceflips += [todo]
  for v,i,j,c,ab in differenceflips:
    parents[v][i] = c
    parents[v][j] = ab
  # given a-b, do ...+b+(x-a) -> (...+x)-(a-b)
  differences = {}
  for v in operation:
    if operation[v][0] != '__sub__': continue
    if len(parents[v]) != 2: continue
    differences[parents[v][0],parents[v][1]] = v
  differenceflips = []
  for v in operation:
    if operation[v][0] != '__add__': continue
    todo = None
    for j in range(len(parents[v])):
      for i in range(j):
        B,D = parents[v][i],parents[v][j]
        for b,d in (B,D),(D,B):
          if operation[d][0] != '__sub__': continue
          x,a = parents[d]
          if (a,b) not in differences: continue
          todo = (v,i,j,x,differences[a,b])
    if todo is None: continue
    differenceflips += [todo]
  for v,i,j,x,ab in differenceflips:
    if len(parents[v]) == 2:
      plusx = x
    else:
      # new sum of the remaining operands plus x
      nextvalue += 1
      operation[nextvalue] = ['__add__']
      parents[nextvalue] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
      parents[nextvalue] += [x]
      bits[nextvalue] = bits[v]
      plusx = nextvalue
    operation[v] = ['__sub__']
    parents[v] = [plusx,ab]
    progress = True
  # given a-b, do (c+b)-a -> c-(a-b)
  differences = {}
  for v in operation:
    if operation[v][0] != '__sub__': continue
    if len(parents[v]) != 2: continue
    differences[parents[v][0],parents[v][1]] = v
  cablist = []
  for v in operation:
    if operation[v][0] != '__sub__': continue
    s,a = parents[v]
    if operation[s][0] != '__add__': continue
    if len(parents[s]) != 2: continue
    C,B = parents[s]
    todo = None
    for b,c in (B,C),(C,B):
      if (a,b) in differences:
        todo = (v,c,differences[a,b])
    if todo is None: continue
    cablist += [todo]
  for v,c,ab in cablist:
    # a list like every other parents entry: the merge step below rewrites
    # parents[t][j] in place and compares parents lists for equality
    parents[v] = [c,ab]
    progress = True
  if 0:
    # a-b is high priority for simplifying a+b+...
    differences = set()
    for v in operation:
      if operation[v][0] != '__sub__': continue
      if len(parents[v]) != 2: continue
      differences.add((parents[v][0],parents[v][1]))
    differencesums = {}
    for v in operation:
      if operation[v][0] != '__add__': continue
      if len(parents[v]) != 2: continue
      if (parents[v][0],parents[v][1]) not in differences: continue
      differencesums[parents[v][0],parents[v][1]] = v
      differencesums[parents[v][1],parents[v][0]] = v
    simplifysums = []
    for v in operation:
      if operation[v][0] != '__add__': continue
      if len(parents[v]) <= 2: continue
      todo = None
      for j in range(len(parents[v])):
        for i in range(j):
          if (parents[v][i],parents[v][j]) in differences:
            todo = (v,i,j)
      if todo is None: continue
      simplifysums += [todo]
    for v,i,j in simplifysums:
      if (parents[v][i],parents[v][j]) not in differencesums:
        nextvalue += 1
        operation[nextvalue] = ['__add__']
        parents[nextvalue] = [parents[v][i],parents[v][j]]
        bits[nextvalue] = bits[v]
        differencesums[parents[v][i],parents[v][j]] = nextvalue
        differencesums[parents[v][j],parents[v][i]] = nextvalue
      u = differencesums[parents[v][i],parents[v][j]]
      parents[v] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
      parents[v] += [u]
      progress = True
  # if a rule above made progress, start the next pass right away; the rules
  # below and the cleanup run only on passes where the rules above changed
  # nothing
  if progress: continue
  # use a+b to simplify a+b+c+...
  pairsums = {}
  for v in operation:
    if operation[v][0] != '__add__': continue
    if len(parents[v]) != 2: continue
    pairsums[parents[v][0],parents[v][1]] = v
    pairsums[parents[v][1],parents[v][0]] = v
  simplifysums = []
  for v in operation:
    if operation[v][0] != '__add__': continue
    if len(parents[v]) <= 2: continue
    s = None
    for j in range(len(parents[v])):
      for i in range(j):
        if (parents[v][i],parents[v][j]) in pairsums:
          s = (v,i,j,pairsums[parents[v][i],parents[v][j]])
    if s is None: continue
    simplifysums += [s]
  for v,i,j,u in simplifysums:
    parents[v] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
    parents[v] += [u]
    progress = True
  # (a+b+c)-d, (a+b+d)-c -> (a+b)+(c-d),(a+b)-(c-d)
  abcdpatterns = {}
  abcdrewrite = []
  for v in operation:
    if operation[v][0] != '__sub__': continue
    s,d = parents[v]
    if operation[s][0] != '__add__': continue
    if len(parents[s]) != 3: continue
    A,B,C = parents[s]
    todo = None
    for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
      if (a,b) not in abcdpatterns:
        abcdpatterns[a,b] = []
      for u,uc,ud in abcdpatterns[a,b]:
        if (uc,ud) == (d,c):
          todo = (v,u,a,b,c,d)
    if todo is None:
      # remember v so a later partner can pair with it
      for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
        abcdpatterns[a,b] += [(v,c,d)]
      continue
    abcdrewrite += [todo]
  abcdrewritten = set()
  for v,u,a,b,c,d in abcdrewrite:
    if v in abcdrewritten: continue
    if u in abcdrewritten: continue
    nextvalue += 1
    operation[nextvalue] = ['__add__']
    parents[nextvalue] = [a,b]
    bits[nextvalue] = bits[v]
    ab = nextvalue
    nextvalue += 1
    operation[nextvalue] = ['__sub__']
    parents[nextvalue] = [c,d]
    bits[nextvalue] = bits[v]
    cd = nextvalue
    operation[v] = ['__add__']
    parents[v] = [ab,cd]
    abcdrewritten.add(v)
    operation[u] = ['__sub__']
    parents[u] = [ab,cd]
    abcdrewritten.add(u)
    progress = True
  # a+b+c, a+b+d -> (a+b)+c, (a+b)+d
  abcpatterns = {}
  abcrewrite = []
  for v in operation:
    if operation[v][0] != '__add__': continue
    if len(parents[v]) != 3: continue
    A,B,C = parents[v]
    todo = None
    for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
      if (a,b) not in abcpatterns:
        abcpatterns[a,b] = []
      for u,d in abcpatterns[a,b]:
        todo = (v,u,a,b,c,d)
    if todo is None:
      for a,b,c in (A,B,C),(A,C,B),(B,A,C),(B,C,A),(C,A,B),(C,B,A):
        abcpatterns[a,b] += [(v,c)]
      continue
    abcrewrite += [todo]
  abcrewritten = set()
  for v,u,a,b,c,d in abcrewrite:
    if v in abcrewritten: continue
    if u in abcrewritten: continue
    nextvalue += 1
    operation[nextvalue] = ['__add__']
    parents[nextvalue] = [a,b]
    bits[nextvalue] = bits[v]
    ab = nextvalue
    operation[v] = ['__add__']
    parents[v] = [ab,c]
    abcrewritten.add(v)
    operation[u] = ['__add__']
    parents[u] = [ab,d]
    abcrewritten.add(u)
    progress = True
  if 1:
    # (a-b)+mulhi16(...) -> a+(mulhi16(...)-b)
    # NOTE(review): if a is itself a mulhi16 node, the result a+(m-b) matches
    # this rule again with a and m swapped, so unless another rule intervenes
    # the loop could flip between the two forms forever; confirm such inputs
    # cannot occur.
    subaddmul = []
    for v in operation:
      if operation[v][0] != '__add__': continue
      if len(parents[v]) != 2: continue
      c,m = parents[v]
      if operation[m][0] != 'mulhi16':
        m,c = parents[v]
        if operation[m][0] != 'mulhi16':
          continue
      if operation[c][0] != '__sub__': continue
      a,b = parents[c]
      subaddmul += [(v,a,b,m)]
    for v,a,b,m in subaddmul:
      nextvalue += 1
      operation[nextvalue] = ['__sub__']
      parents[nextvalue] = [m,b]
      bits[nextvalue] = bits[v]
      operation[v] = ['__add__']
      parents[v] = [a,nextvalue]
      progress = True
  # extract x+((x>>(b-1))&u) from sums
  # NOTE(review): this rewrite does not set progress; confirm that is intended.
  reduce = []
  for v in operation:
    if operation[v][0] != '__add__': continue
    if len(parents[v]) < 3: continue
    b = bits[v]
    vreduce = False
    for i in range(len(parents[v])):
      if vreduce: continue
      y = parents[v][i]
      if operation[y][0] != '__and__': continue
      # n-ary __and__ nodes cannot be unpacked into two operands
      if len(parents[y]) != 2: continue
      t,u = parents[y]
      if operation[t][0] != '__rshift__':
        u,t = parents[y]
      if operation[t][0] != '__rshift__': continue
      x,d = parents[t]
      # shift amount must be the constant b-1 at width b
      if operation[d] != ['constant',b,b-1]: continue
      for j in range(len(parents[v])):
        if vreduce: continue
        if parents[v][j] != x: continue
        reduce += [(v,x,y)]
        vreduce = True
  for v,x,y in reduce:
    i = parents[v].index(y)
    j = parents[v].index(x)
    assert i != j
    nextvalue += 1
    operation[nextvalue] = ['__add__']
    parents[nextvalue] = [x,y]
    bits[nextvalue] = bits[v]
    parents[v] = [parents[v][k] for k in range(len(parents[v])) if k != i and k != j]
    parents[v] += [nextvalue]
  # (SignExt(x,16)*constant(32,y))[31:16] -> mulhi16(x,constant(16,y))
  # if x has 16 bits and -2^15 <= y < 2^15
  mulhi16 = []
  for v in operation:
    if operation[v] != ['Extract',31,16]: continue
    hi = parents[v][0]
    if operation[hi][0] != '__mul__': continue
    if len(parents[hi]) != 2: continue
    s,c = parents[hi]
    if operation[s] != ['SignExt',16]:
      c,s = parents[hi]
    if operation[s] != ['SignExt',16]: continue
    if operation[c][0] != 'constant': continue
    if operation[c][1] != 32: continue
    y = operation[c][2]
    # reinterpret the 32-bit constant as signed
    if y > 2**31: y -= 2**32
    if y < -2**15: continue
    if y >= 2**15: continue
    y %= 2**16
    x = parents[s][0]
    if bits[x] != 16: continue
    mulhi16 += [(v,x,y)]
  for v,x,y in mulhi16:
    if (16,y) in constants:
      c = constants[16,y]
    else:
      nextvalue += 1
      operation[nextvalue] = ['constant',16,y]
      parents[nextvalue] = []
      bits[nextvalue] = 16
      c = nextvalue
      # record it so later rewrites reuse this node; constants have no
      # operands, so the merge step could never unify duplicates
      constants[16,y] = c
    operation[v] = ['mulhi16']
    parents[v] = [x,c]
    progress = True
  # Extract(LShR(LShR(__rshift__(Concat(x,x),16)*y,14)+1,1),15,0)
  # -> mulhrs16(x,y)
  # if x has 16 bits and -2^15 <= y < 2^15
  mulhrs16 = []
  for v in operation:
    if operation[v] != ['Extract',15,0]: continue
    a = parents[v][0]
    if operation[a][0] != 'LShR': continue
    b,c = parents[a]
    if operation[c] != ['constant',32,1]: continue
    if operation[b][0] != '__add__': continue
    if len(parents[b]) != 2: continue
    d,e = parents[b]
    if operation[d][0] != 'LShR':
      e,d = parents[b]
      if operation[d][0] != 'LShR':
        continue
    if operation[e] != ['constant',32,1]: continue
    f,g = parents[d]
    if operation[g] != ['constant',32,14]: continue
    if operation[f][0] != '__mul__': continue
    if len(parents[f]) != 2: continue
    h,i = parents[f]
    if operation[h][0] != '__rshift__':
      i,h = parents[f]
      if operation[h][0] != '__rshift__':
        continue
    if operation[i][:2] != ['constant',32]: continue
    y = operation[i][2]
    # reinterpret the 32-bit constant as signed
    if y > 2**31: y -= 2**32
    if y < -2**15: continue
    if y >= 2**15: continue
    y %= 2**16
    j,k = parents[h]
    if operation[k] != ['constant',32,16]: continue
    if operation[j][0] != 'Concat': continue
    if len(parents[j]) != 2: continue
    x = parents[j][0]
    if x != parents[j][1]: continue
    if bits[x] != 16: continue
    mulhrs16 += [(v,x,y)]
  for v,x,y in mulhrs16:
    if (16,y) in constants:
      c = constants[16,y]
    else:
      nextvalue += 1
      operation[nextvalue] = ['constant',16,y]
      parents[nextvalue] = []
      bits[nextvalue] = 16
      c = nextvalue
      constants[16,y] = c
    operation[v] = ['mulhrs16']
    parents[v] = [x,c]
    progress = True
  # ----- clean up unused nodes, merge
  # children[z]: ids of nodes using z as an operand; -1 marks an output
  children = dict()
  for z in operation: children[z] = set()
  for v in value:
    if v.startswith('out'): children[value[v]].add(-1)
  for z in operation:
    for x in parents[z]:
      children[x].add(z)
  # nodes with no users are dropped (one level of dead code per pass)
  deleting = set(v for v in operation if len(children[v]) == 0)
  merging = deleting.copy()
  merge = []
  # two users of a common operand x with equal operation lists and equal
  # operand lists (order-insensitive for signedmin/signedmax) are duplicates
  for x in operation:
    c = list(children[x])
    for y,z in [(c[i],c[j]) for j in range(len(c)) for i in range(j)]:
      if y == -1: continue
      if z == -1: continue
      if operation[y] != operation[z]: continue
      parentsmatch = False
      if parents[y] == parents[z]: parentsmatch = True
      if operation[y][0] in ['signedmin','signedmax']:
        if set(parents[y]) == set(parents[z]):
          parentsmatch = True
      if not parentsmatch: continue
      assert bits[y] == bits[z]
      if y in merging: continue
      if z in merging: continue
      merge += [(y,z)]
      merging.add(y)
      merging.add(z)
  for y,z in merge:
    # eliminate z in favor of y
    for t in children[z]:
      if t == -1:
        for v in value:
          if v.startswith('out'):
            if value[v] == z:
              value[v] = y
      else:
        for j in range(len(parents[t])):
          if parents[t][j] == z:
            parents[t][j] = y
    deleting.add(z)
  for v in deleting:
    del operation[v]
    del parents[v]
    del bits[v]
# ids of nodes already printed (shared by all calls to do())
done = set()
def do(v):
  """Print node v, preceded by every not-yet-printed node it depends on.

  Nodes are printed in post-order: operands before their users, operands in
  parents[] order, each node at most once.  Input nodes print as
  "vN = name"; other nodes print as "vN = op(vA,vB,...,extra args)".
  An explicit stack is used instead of recursion so dependency chains longer
  than Python's recursion limit do not raise RecursionError.
  """
  if v in done: return
  done.add(v)
  stack = [(v, iter(parents[v]))]
  while stack:
    node, pending = stack[-1]
    for operand in pending:
      if operand not in done:
        # descend into this operand before continuing with node's others
        done.add(operand)
        stack.append((operand, iter(parents[operand])))
        break
    else:
      # every operand of node has been handled: emit node itself
      stack.pop()
      if operation[node][0] == 'input':
        print('v%d = %s' % (node,operation[node][1]))
      else:
        p = ['v%s' % x for x in parents[node]]
        p += ['%s' % x for x in operation[node][1:]]
        print('v%d = %s(%s)' % (node,operation[node][0],','.join(p)))
# For every variable whose name starts with 'out', print the nodes it
# depends on (via do) and then bind the output name to its node.
for outname, node in value.items():
  if outname.startswith('out'):
    do(node)
    print('%s = v%d' % (outname, node))