-rwxr-xr-x 5118 nttcompiler-20220411/scripts/range2linear
#!/usr/bin/env python3
import sys
import math
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
# The modulus q is the first command-line argument; the expressions to
# translate are read from standard input.
if len(sys.argv) < 2:
    sys.exit('usage: %s q < input' % sys.argv[0])
q = int(sys.argv[1])
def group(s):
    """Build a pyparsing parse action that tags compound matches with s.

    The returned action leaves a single-token match unchanged and wraps any
    other match as one nested list [s, tok1, tok2, ...], so evaluate() can
    dispatch on the tag s.
    """
    def tag_tokens(tokens):
        items = list(tokens)
        if len(items) != 1:
            return [[s] + items]
        return items
    return tag_tokens
# Grammar for the input program.  Each statement is `name = expr` terminated
# by the synthetic `_NEWLINE` token that the reader appends to every line.
# Parse actions built by group() wrap multi-token matches as [tag, *tokens];
# single-token matches pass through untagged.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
equal = Literal('=').suppress()
caret = Literal('^').suppress() | Literal('**').suppress()
newline = Literal('_NEWLINE').suppress()
timesdiv = Literal('*') | Literal('/')
pm = Literal('+') | Literal('-')
# Integer literal with optional leading '-': ['number', digits] or
# ['number', '-', digits].
number = (Word(nums)|Literal('-')+Word(nums)).setParseAction(lambda x:[['number'] + list(x)])
name = Word(alphas,alphas+nums+"_").setParseAction(lambda x:[['name'] + list(x)])
expr = Forward()
atom = lparen + expr + rparen | name | number
power = (atom + Optional(caret + number)).setParseAction(group('power'))
term = (power + ZeroOrMore(timesdiv + power)).setParseAction(group('term'))
# Sums are tagged 'expr', which is the tag evaluate() dispatches on.  The rule
# is named `additive` so it does not shadow the builtin sum().
additive = (term + ZeroOrMore(pm + term)).setParseAction(group('expr'))
expr << additive
assignment = (name + equal + expr + newline).setParseAction(group('assignment'))
file_input = ZeroOrMore(newline | assignment) + StringEnd()
# Collect the input program.  Only lines starting with 'rem' are kept; each is
# followed by the synthetic _NEWLINE token that terminates a statement in the
# grammar.  A single join avoids the quadratic cost of repeated string +=.
program = ''.join(line + ' _NEWLINE ' for line in sys.stdin
                  if line.startswith('rem'))
program = list(file_input.parseString(program))
# Symbol table: variable name -> value modulo q.  A value is one of
#   []                      zero (evaluate() yields this for numbers divisible by q)
#   [(7,)]                  the constant 7 (mod q)
#   [(3,'in1'),(4,'in2')]   the linear form 3*in1 + 4*in2 (mod q)
# Within a linear form each name should appear only once:
# [(3,'in5'),(4,'in5')] is not allowed.
# XXX: a slightly cleaner constant encoding would be [(7,'')].
# XXX: mixed constant + linear values such as 7+3*in1 are not supported.
variables = dict()
def exprtimes(e,f):
    """Return the product e * f modulo q; at least one factor must be constant.

    Values are [] (zero), [(c,)] (the constant c) or [(c1, name1), ...]
    (a linear form).  The product of two linear forms is not linear and fails
    an assertion.
    """
    # Zero times anything is zero.  Without this check, 0*in1 and in1*0
    # would trip the constant-operand assertions below.
    if not e or not f:
        return []
    # Multiplication commutes, so move the constant operand into f.
    if len(e) == 1 and len(e[0]) == 1:
        e,f = f,e
    assert len(f) == 1
    assert len(f[0]) == 1
    x = f[0][0]
    if len(e) == 1 and len(e[0]) == 1:
        # constant * constant
        return [((x*e[0][0])%q,)]
    # constant * linear form: scale every coefficient
    return [((x*c[0])%q,c[1]) for c in e]
def exprpow(e,f):
    """Return e ** f modulo q for two constant values.

    Both operands must be single constants [(c,)]; anything else (the zero
    value [] or a linear form) fails an assertion.  The result is the
    constant [(pow(base, exponent, q),)].
    """
    for operand in (e, f):
        assert len(operand) == 1
        assert len(operand[0]) == 1
    base = e[0][0]
    exponent = f[0][0]
    return [(pow(base, exponent, q),)]
def exprdiv(e,f):
    """Return e / f modulo q, where the divisor f must be a constant.

    e may be zero ([]), a constant [(c,)] or a linear form [(c1, name1), ...].
    f must be a nonzero constant [(d,)].  Division multiplies by d^(q-2) mod q,
    which is the inverse of d only when q is prime.

    The operands are never swapped, because division is not commutative.
    c/d for constants is c * d^-1, and 3/in1 (not linear) is rejected
    instead of being computed as in1/3.

    Raises ZeroDivisionError if f is zero; fails an assertion if f is not a
    constant.
    """
    if not f:
        raise ZeroDivisionError('division by zero: %s / %s' % (e, f))
    assert len(f) == 1
    assert len(f[0]) == 1
    d = f[0][0]
    if d % q == 0:
        raise ZeroDivisionError('division by zero: %s / %s' % (e, f))
    x = pow(d,q-2,q) # XXX: only for primes
    if len(e) == 1 and len(e[0]) == 1:
        return [((x*e[0][0])%q,)]
    # linear form (or [] for zero): scale every coefficient
    return [((x*c[0])%q,c[1]) for c in e]
def exprplus(e,f):
    """Return the sum e + f of two linear forms modulo q.

    Operands are lists of (coefficient, name) pairs, with [] meaning zero.
    Terms sharing a name are merged, and names keep first-seen order.  A bare
    constant operand [(c,)] is not supported and raises Exception.
    """
    if any(len(side) == 1 and len(side[0]) == 1 for side in (e, f)):
        raise Exception('numerical %s + %s' % (e,f))
    merged = {}
    for term in e + f:
        key = term[1]
        # the first occurrence is stored as-is; later ones are added mod q
        merged[key] = (merged[key] + term[0]) % q if key in merged else term[0]
    return [(coeff, key) for key, coeff in merged.items()]
def exprminus(e,f):
    """Return the difference e - f of two linear forms modulo q.

    Operands are lists of (coefficient, name) pairs, with [] meaning zero.
    Terms sharing a name are merged, and names keep first-seen order (names
    from e first).  A bare constant operand [(c,)] raises Exception.
    """
    for side in (e, f):
        if len(side) == 1 and len(side[0]) == 1:
            raise Exception('numerical exprminus')
    diff = {}
    for term in e:
        key = term[1]
        # the first occurrence is stored as-is; later ones are added mod q
        diff[key] = (diff[key] + term[0]) % q if key in diff else term[0]
    for term in f:
        key = term[1]
        diff[key] = (diff.get(key, 0) - term[0]) % q
    return [(coeff, key) for key, coeff in diff.items()]
def _number_value(node):
    """Return the exact integer (not reduced mod q) of a 'number' parse node.

    node is ['number', digits] or ['number', '-', digits].
    """
    assert node[0] == 'number'
    if node[1] == '-':
        assert len(node) == 3
        return -int(node[2])
    assert len(node) == 2
    return int(node[1])


def evaluate(e):
    """Evaluate parse-tree node e to a value modulo q.

    Values are represented as:
      []                   zero
      [(c,)]               the constant c
      [(c1, name1), ...]   the linear form c1*name1 + ... (mod q)
    A name that has no value yet is treated as an input variable.  It is
    entered into `variables` as [(1, name)], and that list is returned.
    Unsupported constructs (e.g. a product of two variables) fail an
    assertion or raise Exception.
    """
    if e[0] == 'number':
        value = _number_value(e) % q
        return [(value,)] if value else []
    if e[0] == 'name':
        assert len(e) == 2
        if e[1] not in variables:
            variables[e[1]] = [(1,e[1])]
        return variables[e[1]]
    if e[0] == 'power':
        assert len(e) == 3
        base = evaluate(e[1])
        # Take the exponent straight from the literal rather than via
        # evaluate(), which reduces numbers mod q.  That reduction is wrong
        # for exponents (with q = 7, 3^8 must not become 3^1), and it turns
        # exponent 0 into [].
        n = _number_value(e[2])
        if not base and n > 0:
            return []  # 0^n = 0
        if n < 0:
            # b^-k = (b^(q-2))^k for prime q and b != 0
            # XXX: only for primes, like the inverse in exprdiv
            n = (q-2)*(-n)
        return exprpow(base,[(n,)])
    if e[0] == 'term':
        # ['term', p1, op1, p2, op2, p3, ...]: apply * and / left to right
        result = evaluate(e[1])
        for j in range(2,len(e),2):
            if e[j] == '*':
                result = exprtimes(result,evaluate(e[j+1]))
            elif e[j] == '/':
                result = exprdiv(result,evaluate(e[j+1]))
            else:
                raise Exception('unknown expression %s' % e)
        return result
    if e[0] == 'expr':
        # ['expr', t1, op1, t2, ...]: apply + and - left to right
        result = evaluate(e[1])
        for j in range(2,len(e),2):
            if e[j] == '+':
                result = exprplus(result,evaluate(e[j+1]))
            elif e[j] == '-':
                result = exprminus(result,evaluate(e[j+1]))
            else:
                raise Exception('unknown expression %s' % e)
        return result
    raise Exception('unknown expression %s' % e)
# Evaluate every assignment in order and record each value in `variables`.
# Variables whose names start with 'remout' are the outputs.  They are listed
# once each, in order of first assignment, so a reassigned output is printed
# once (with its final value) rather than once per assignment.
outputs = []
for p in program:
    if p[0] != 'assignment':
        raise Exception('unknown statement %s' % p)
    assert p[1][0] == 'name'
    target = p[1][1]
    variables[target] = evaluate(p[2])
    if target.startswith('remout') and target not in outputs:
        outputs.append(target)
def varnumber(c):
    """Sort key for a (coefficient, name) term.

    Returns the integer spelled by all ASCII digits of the name, read left
    to right ('in12_3' gives 123); a name without digits gives 0.
    """
    digits = ''.join(ch for ch in c[1] if ch in '0123456789')
    return int('0' + digits)
# Print each output as `name = c1*v1 + c2*v2 + ...`.  Terms are ordered by
# varnumber(), zero-coefficient terms are omitted, and a value with no
# nonzero terms prints as 0.
for o in outputs:
    value = variables[o]
    if len(value) == 1 and len(value[0]) == 1:
        # A constant [(c,)] has no variable name to sort on or print; emit the
        # constant itself instead of failing with IndexError in varnumber().
        terms = ['%d' % value[0][0]] if value[0][0] != 0 else []
    else:
        value.sort(key=varnumber)
        # Print only the first three '_'-separated parts of each name.
        terms = ['%d*%s' % (c[0], '_'.join(c[1].split('_')[:3]))
                 for c in value if c[0] != 0]
    print('%s = %s' % (o, ' + '.join(terms) if terms else '0'))