-rwxr-xr-x 6705 nttcompiler-20220411/scripts/doublecheck
#!/usr/bin/env python3
import os
import sys
from functools import reduce
from random import randrange
import subprocess
# Command line: doublecheck N ntt [L H]
#   N    number of inputs/outputs the program must have
#   ntt  program to run and compare against
#   L,H  optional inclusive range for the random inputs
N = int(sys.argv[1])
ntt = sys.argv[2]
if len(sys.argv) > 4:
    L, H = int(sys.argv[3]), int(sys.argv[4])
else:
    L, H = None, None
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
def group(s):
    """Return a parse action that tags a statement's tokens with the label s.

    The returned function takes the matched tokens and returns them wrapped
    as a single nested list ``[[s, tok0, tok1, ...]]``.  A token sequence of
    length exactly one is returned as a plain list, untagged.
    """
    def tagger(tokens):
        items = list(tokens)
        if len(items) != 1:
            return [[s] + items]
        return items
    return tagger
# Punctuation is matched but dropped from the token stream.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()

number = Word(nums)                       # run of decimal digits
pmnumber = Word(nums + "-")               # digits and '-', used for range bounds
name = Word(alphas, alphas + nums + "_")  # letter, then letters/digits/underscores


def _call(opname, operands):
    """Return the grammar for ``dest = opname(operands)``.

    The destination and operand tokens are tagged with opname via group().
    """
    statement = (
        name + equal + Literal(opname).suppress()
        + lparen + operands + rparen
    )
    return statement.setParseAction(group(opname))


# Alternatives are tried in the order they are added below; the plain copy
# ``dest = src`` is added last.
assignment = _call('constant', number + comma + number)

for binary in ['__sub__', '__rshift__', '__lshift__', 'LShR', 'mulhi16', 'mulhrs16']:
    assignment |= _call(binary, name + comma + name)

assignment |= _call('Extract', name + comma + number + comma + number)
assignment |= _call('SignExt', name + comma + number)
assignment |= _call('ZeroExt', name + comma + number)

for manyary in ['Concat', '__or__', '__and__', '__add__', '__mul__']:
    assignment |= _call(manyary, name + ZeroOrMore(comma + name))

assignment |= (name + equal + name).setParseAction(group('copy'))

# assertsignedminmax(var, min, max) is a statement without a destination.
rangecheck = (
    Literal('assertsignedminmax').suppress()
    + lparen + name + comma + pmnumber + comma + pmnumber + rparen
).setParseAction(group('assertsignedminmax'))

# A program is any sequence of assignments and range checks, up to end of input.
assignments = ZeroOrMore(assignment | rangecheck) + StringEnd()
# Read the program from stdin.  inputcopy keeps every line verbatim so it can
# be passed along at the end; lines starting with '#' or 'rem' are left out
# of the text handed to the parser.
raw_lines = []
kept_lines = []
for line in sys.stdin:
    raw_lines.append(line)
    if line.startswith(('#', 'rem')):
        continue
    kept_lines.append(line + '\n')
inputcopy = ''.join(raw_lines)
program = list(assignments.parseString(''.join(kept_lines)))
def op_constant(x, y):
    """constant(x, y): return the pair (x, y) unchanged.

    The other operations treat the first component as a bit width and the
    second as the value; nothing is checked or reduced here.
    """
    pair = (x, y)
    return pair
def op___add__(*args):
    """Sum one or more (bits, value) operands of equal width, modulo 2**bits.

    Raises AssertionError if no operand is given or the widths differ.
    A single operand is returned with its value untouched.
    """
    assert len(args) > 0
    b = args[0][0]
    assert all(operand[0] == b for operand in args)
    total = args[0][1]
    for operand in args[1:]:
        total = (total + operand[1]) % (2**b)
    return b, total
def op___mul__(*args):
    """Multiply one or more (bits, value) operands of equal width, modulo 2**bits.

    Raises AssertionError if no operand is given or the widths differ.
    A single operand is returned with its value untouched.
    """
    assert len(args) > 0
    b = args[0][0]
    assert all(operand[0] == b for operand in args)
    product = args[0][1]
    for operand in args[1:]:
        product = (product * operand[1]) % (2**b)
    return b, product
def op___and__(*args):
    """Bitwise AND of one or more (bits, value) operands of equal width.

    Raises AssertionError if no operand is given or the widths differ.
    """
    assert len(args) > 0
    b = args[0][0]
    assert all(operand[0] == b for operand in args)
    acc = args[0][1]
    for operand in args[1:]:
        acc = acc & operand[1]
    return b, acc
def op___or__(*args):
    """Bitwise OR of one or more (bits, value) operands of equal width.

    Raises AssertionError if no operand is given or the widths differ.
    """
    assert len(args) > 0
    b = args[0][0]
    assert all(operand[0] == b for operand in args)
    acc = args[0][1]
    for operand in args[1:]:
        acc = acc | operand[1]
    return b, acc
def op___sub__(x, y):
    """Return x - y modulo 2**bits for two (bits, value) operands of equal width.

    Raises AssertionError if the widths differ.
    """
    bits = x[0]
    assert bits == y[0]
    difference = x[1] - y[1]
    return bits, difference % (2**bits)
def op_assertsignedminmax(x, L, H):
    """Assert that x, read as a signed two's-complement number, lies in [L, H].

    x is a (bits, value) pair.  Returns None; raises AssertionError when the
    signed value is below L or above H.
    """
    bits, raw = x[0], x[1]
    sign = 2 ** (bits - 1)
    # Flipping the top bit and subtracting it maps [0, 2**bits) to signed.
    signed = (raw ^ sign) - sign
    assert signed >= L
    assert signed <= H
def op_mulhi16(x, y):
    """Signed 16x16 multiply, keeping the product shifted right by 16.

    x and y are (bits, value) pairs and must both be 16 bits wide.  The
    result is returned as a 16-bit pair with the value reduced modulo 2**16.
    """
    width = 16
    assert x[0] == width
    assert y[0] == width
    sign = 2 ** (width - 1)
    lhs = (x[1] ^ sign) - sign
    rhs = (y[1] ^ sign) - sign
    high = (lhs * rhs) >> 16
    wrapped = high % 2**width
    # Sanity check: the high half fits in a signed 16-bit value.
    assert (high + sign) ^ sign == wrapped
    return width, wrapped
def op_mulhrs16(x, y):
    """Signed 16x16 multiply with rounding: ((x*y >> 14) + 1) >> 1.

    x and y are (bits, value) pairs and must both be 16 bits wide.  The
    result is returned as a 16-bit pair with the value reduced modulo 2**16.
    """
    width = 16
    assert x[0] == width
    assert y[0] == width
    sign = 2 ** (width - 1)
    lhs = (x[1] ^ sign) - sign
    rhs = (y[1] ^ sign) - sign
    rounded = (((lhs * rhs) >> 14) + 1) >> 1
    # XXX: no signed-range check here; it would fail when both operands
    # are -32768, because the rounded result is then +32768.
    return width, rounded % 2**width
def op___lshift__(x, y):
    """Shift x left by y bits, keeping the low bits of the result.

    x and y are (bits, value) pairs of equal width; the shifted value is
    reduced modulo 2**bits.  Raises AssertionError if the widths differ.
    """
    bits = x[0]
    assert bits == y[0]
    shifted = x[1] << y[1]
    return bits, shifted % (2**bits)
def op_LShR(x, y):
    """Unsigned (logical) right shift of x by y.

    x and y are (bits, value) pairs of equal width; both values are used as
    stored, i.e. unsigned.  Raises AssertionError if the widths differ.
    """
    bits = x[0]
    assert bits == y[0]
    return bits, x[1] >> y[1]
def op___rshift__(x, y):
    """Signed (arithmetic) right shift of x by y.

    x and y are (bits, value) pairs of equal width, both read as signed
    two's-complement numbers.  The shift amount must satisfy 0 <= y < bits.
    Raises AssertionError if the widths differ or the amount is out of range.
    """
    bits = x[0]
    assert bits == y[0]
    sign = 2 ** (bits - 1)
    amount = (y[1] ^ sign) - sign
    value_signed = (x[1] ^ sign) - sign
    assert 0 <= amount
    assert amount < bits
    shifted = value_signed >> amount
    # Map the signed result back to its unsigned representation.
    return bits, (shifted + sign) ^ sign
def op_Concat(*args):
    """Concatenate (bits, value) operands; the first operand ends up on top.

    Returns (total width, combined value).  With no operands the result
    is (0, 0).
    """
    width = 0
    combined = 0
    # Walk from the last operand (lowest bits) to the first (highest bits).
    for piece in args[::-1]:
        combined = combined + (piece[1] << width)
        width = width + piece[0]
    return width, combined
def op_Extract(x, top, bot):
    """Extract bits bot..top (inclusive, bit 0 = lowest) from x.

    x is a (bits, value) pair.  Returns (top + 1 - bot, extracted value).
    Raises AssertionError unless 0 <= bot <= top < bits.
    """
    assert x[0] > top
    assert top >= bot
    assert bot >= 0
    low_mask = (2 << top) - 1  # ones in bit positions 0..top
    field = (x[1] & low_mask) >> bot
    return top + 1 - bot, field
def op_SignExt(x, bits):
    """Widen the (width, value) pair x by `bits` bits, replicating the sign bit.

    Returns (width + bits, extended value).
    """
    width, raw = x
    wide = width + bits
    if raw & (2 ** (width - 1)):
        # Top bit set: fill the new high bits with ones.
        raw = raw + 2**wide - 2**width
    return wide, raw
def op_ZeroExt(x, bits):
    """Widen the (width, value) pair x by `bits` bits; the value is unchanged.

    Returns (width + bits, value).
    """
    width, raw = x
    wide = width + bits
    return wide, raw
def op_copy(x):
    """Plain assignment ``dest = src``: return the operand itself."""
    result = x
    return result
# Interpreter state; every entry is a (bit width, value) pair.
value = dict()   # by full variable name, e.g., value['in_static_f_37_99_16']
input = dict()   # by input index, e.g., input[37] for in_static_f_37_99_16
output = dict()  # by output index, e.g., output[37] for out_static_f_37
def evaluate(x):
    """Return the (bits, value) pair of variable x, creating inputs on demand.

    A name not yet in `value` must start with 'in'.  Its numeric
    underscore-separated fields give the input index (first field) and the
    bit width (last field).  A random value is drawn -- from [L, H] when H
    is set, otherwise from [0, 2**bits) -- reduced modulo 2**bits, and
    recorded in both `input` and `value`.

    Raises AssertionError if an unknown name is not an input, has fewer
    than two numeric fields, or reuses an input index.
    """
    if x in value:
        return value[x]

    assert x.startswith('in')
    numbers = [int(field) for field in x.split('_') if field.isnumeric()]
    assert len(numbers) >= 2  # first is key, last is numbits
    key = numbers[0]
    assert key not in input
    b = int(numbers[-1])

    if H is None:
        v = randrange(2**b)
    else:
        v = randrange(L, H + 1)
    v %= 2**b  # negative draws wrap to two's complement

    input[key] = b, v
    value[x] = b, v
    return value[x]
# Interpret the parsed program one statement at a time.  Each statement is a
# list whose first element is the operation name, as produced by group().
for p in program:
    if p[0] == 'assertsignedminmax':
        # [op, variable, min, max]: only the variable is evaluated; the
        # bounds are integer literals (possibly negative).
        args = [evaluate(pj) for pj in p[1:2]] + [int(pj) for pj in p[2:]]
        op_assertsignedminmax(*args)
        continue

    # Every other statement is [op, destination, operand, ...].
    if p[1] in value:
        # Interpolate the name into the message with '%'; passing it as a
        # second argument would leave a literal '%s' in the error text.
        raise Exception('%s assigned twice' % p[1])

    if p[0] == 'constant':
        # Both operands are integer literals.
        args = [int(pj) for pj in p[2:]]
    elif p[0] in ('Extract', 'SignExt', 'ZeroExt'):
        # One variable operand followed by integer literals.
        args = [evaluate(pj) for pj in p[2:3]] + [int(pj) for pj in p[3:]]
    else:
        # All operands are variables.
        args = [evaluate(pj) for pj in p[2:]]

    # Dispatch to the module-level function op_<operation name>.
    op = getattr(sys.modules[__name__], 'op_' + p[0])
    value[p[1]] = op(*args)

    if p[1].startswith('out'):
        # Outputs are indexed by the first numeric field of their name.
        z = p[1].split('_')
        z = [int(zj) for zj in z if zj.isnumeric()]
        assert len(z) >= 1
        key = z[0]
        output[key] = value[p[1]]
# The program must have used inputs 0..N-1 and produced outputs 0..N-1.
assert sorted(input) == list(range(N))
assert sorted(output) == list(range(N))

# Feed the inputs to the external program, one signed 16-bit decimal per line.
input_lines = []
for j in range(N):
    assert input[j][0] == 16
    v = input[j][1]
    input_lines.append('%d\n' % (v - 2**16 if v >= 2**15 else v))
inputstr = ''.join(input_lines)

proc = subprocess.Popen(
    ntt,
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    universal_newlines=True,
)
executestr, err = proc.communicate(input=inputstr)

# Render the interpreted outputs in the same format.
output_lines = []
for j in range(N):
    assert output[j][0] == 16
    v = output[j][1]
    output_lines.append('%d\n' % (v - 2**16 if v >= 2**15 else v))
outputstr = ''.join(output_lines)

# and now the real test...
assert executestr == outputstr

# ok, allowed to pass along the input
sys.stdout.write(inputcopy)