#!/usr/bin/env python3

import os
import sys
import io
import resource
import subprocess
import angr
import claripy

# Base name of the harness: the C source is written to '<binary>.c'
# and the compiled executable is '<binary>'.
binary = 'syminsn-sparc32-x'

# True: the condition-code inputs are 1-bit symbolic variables.
# False: they are the constant 0.
symflags = True

maxsplit = 300
# max number of universes within an angr run

# hostlittle: register values are byte-reversed on the way into stdin
# and on the way back out of stdout.
# targetlittle: the instruction word is byte-reversed before being fed in.
hostlittle = True
targetlittle = False

flags = ('cf','vf','zf','nf')

# g1..g7 followed by o0..o7, l0..l7, i0..i7, the same order as the
# regs32[] table inside the C harness.
regs32 = [f'g{i}' for i in range(1,8)]
for bank in 'oli':
  regs32.extend(f'{bank}{i}' for i in range(8))
regs = [(name,32) for name in regs32]

# The harness emits one stdout packet per register and one per flag.
expectedoutputs = len(flags)+len(regs)

# Pieces of the symbolic stdin, in the order the harness reads them:
# all registers, then one byte per flag, then the instruction word.
invars = []

for regname,regbits in regs:
  symreg = claripy.BVS('in_'+regname,regbits,explicit_name=True)
  invars.append(claripy.Reverse(symreg) if hostlittle else symreg)

for flagname in flags:
  if symflags:
    flagbit = claripy.BVS('in_'+flagname,1,explicit_name=True)
  else:
    flagbit = claripy.BVV(0,1)
  # each flag occupies a whole input byte; only its low bit varies
  invars.append(flagbit.zero_extend(7))

# Command line: [insn in hex [tweak]]
insn = 0x82c08003 # flags,g1 = g2+g3+carry
tweak = 0
argc = len(sys.argv)
if argc >= 2:
  insn = int(sys.argv[1],base=16)
if argc >= 3:
  tweak = int(sys.argv[2])

print(f'exploring insn {hex(insn)} tweak {tweak}',flush=True)

insn = claripy.BVV(insn,32)
if tweak > 0:
  # OR a symbolic tweak-bit variable into the low bits of the instruction
  imm = claripy.BVS('in_n',tweak,explicit_name=True)
  insn |= imm.zero_extend(32-tweak)

if targetlittle:
  insn = claripy.Reverse(insn)
invars.append(insn)

# Write the C source of the test harness to '<binary>.c'.
#
# The harness (see the C text below):
# - opens a Unicorn engine for UC_ARCH_SPARC in 32-bit big-endian mode
#   and maps memory starting at 0x10000;
# - reads from stdin 4 bytes for each entry of regs32[] and writes them
#   to the corresponding Unicorn register;
# - reads one byte per entry of flags[], keeps only its low bit, and
#   places it at that entry's bit position (20..23) of the value written
#   to UC_SPARC_REG_PSR;
# - reads 4 more bytes, stores them at 0x10000, and calls uc_emu_start
#   on [0x10000,0x10004) with a final argument of 1;
# - writes each regs32[] register (4 bytes) and then each flag bit
#   (1 byte) to stdout;
# - exits with status 111 if uc_open, uc_mem_map, or uc_emu_start fails.
#
# This stdin layout is the order in which invars is assembled above, and
# the stdout layout is what the result extraction at the end of this
# script expects.
#
# The harness also supplies its own vasprintf, clock_gettime and
# gettimeofday; the two time functions always report zero.
with open(f'{binary}.c','w') as f:
  f.write(r'''#include <inttypes.h>
#include <string.h>
#include <stdio.h>
#include <stdarg.h>
#include <errno.h>
#include <unicorn/unicorn.h>
#include <time.h>

long long flags[] = {
  20, // C
  21, // V
  22, // Z
  23 // N
} ;

long long regs32[] = {
  UC_SPARC_REG_G1,
  UC_SPARC_REG_G2,
  UC_SPARC_REG_G3,
  UC_SPARC_REG_G4,
  UC_SPARC_REG_G5,
  UC_SPARC_REG_G6,
  UC_SPARC_REG_G7,
  UC_SPARC_REG_O0,
  UC_SPARC_REG_O1,
  UC_SPARC_REG_O2,
  UC_SPARC_REG_O3,
  UC_SPARC_REG_O4,
  UC_SPARC_REG_O5,
  UC_SPARC_REG_O6,
  UC_SPARC_REG_O7,
  UC_SPARC_REG_L0,
  UC_SPARC_REG_L1,
  UC_SPARC_REG_L2,
  UC_SPARC_REG_L3,
  UC_SPARC_REG_L4,
  UC_SPARC_REG_L5,
  UC_SPARC_REG_L6,
  UC_SPARC_REG_L7,
  UC_SPARC_REG_I0,
  UC_SPARC_REG_I1,
  UC_SPARC_REG_I2,
  UC_SPARC_REG_I3,
  UC_SPARC_REG_I4,
  UC_SPARC_REG_I5,
  UC_SPARC_REG_I6,
  UC_SPARC_REG_I7,
} ;

int vasprintf(char **x,const char *fmt,va_list ap)
{
  int len;
  va_list apcopy;
  va_copy(apcopy,ap);
  len = vsnprintf(0,0,fmt,apcopy);
  va_end(apcopy);
  *x = 0;
  if (len < 0) return -1;
  len += 1;
  *x = malloc(len);
  if (!*x) return -1;
  return vsnprintf(*x,len,fmt,ap);
}

int clock_gettime(clockid_t clockid, struct timespec *tp)
{
  tp->tv_sec = 0;
  tp->tv_nsec = 0;
  return 0;
}

int gettimeofday(struct timeval *tv,void *tz)
{
  tv->tv_sec = 0;
  tv->tv_usec = 0;
  return 0;
}

int main()
{
  uc_engine *uc;
  uc_err err;
  uint64_t insnpos = 0x10000;
  unsigned char r8;
  uint32_t r32;
  long long i;

  err = uc_open(UC_ARCH_SPARC,UC_MODE_SPARC32|UC_MODE_BIG_ENDIAN,&uc);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_open failed: %s\n",uc_strerror(err));
    exit(111);
  }

  err = uc_mem_map(uc,insnpos,insnpos+4096,UC_PROT_ALL);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_mem_map failed: %s\n",uc_strerror(err));
    exit(111);
  }

  for (long long i = 0;i < sizeof(regs32)/sizeof(regs32[0]);++i) {
    fread(&r32,4,1,stdin);
    uc_reg_write(uc,regs32[i],&r32);
  }

  r32 = 0;
  for (long long i = 0;i < sizeof(flags)/sizeof(flags[0]);++i) {
    fread(&r8,1,1,stdin);
    r8 &= 1;
    r32 |= ((uint32_t) r8) << flags[i];
  }
  uc_reg_write(uc,UC_SPARC_REG_PSR,&r32);

  fread(&r32,4,1,stdin);
  uc_mem_write(uc,insnpos,&r32,4);

  err = uc_emu_start(uc,insnpos,insnpos+4,0,1);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_emu_start failed: %s\n",uc_strerror(err));
    exit(111);
  }

  for (i = 0;i < sizeof(regs32)/sizeof(regs32[0]);++i) {
    uc_reg_read(uc,regs32[i],&r32);
    fwrite(&r32,4,1,stdout);
  }

  uc_reg_read(uc,UC_SPARC_REG_PSR,&r32);
  for (long long i = 0;i < sizeof(flags)/sizeof(flags[0]);++i) {
    r8 = 1 & (r32 >> flags[i]);
    fwrite(&r8,1,1,stdout);
  }

  return 0;
}
''')

# Compile the harness together with setjmp.s (a separate file, expected in
# the current directory) and link it against libunicorn.  check=True makes
# a failing gcc raise subprocess.CalledProcessError and stop the script.
subprocess.run(f'gcc -Os -o {binary} {binary}.c setjmp.s -lunicorn'.split(),check=True)

def clarikey(e):
  # Return a hashable key identifying the claripy expression e, for use
  # as a dictionary key or set member.
  #
  # Uses e.cache_key when that attribute exists and falls back to
  # e.hash() otherwise.  Only a missing attribute triggers the fallback:
  # any other exception (including KeyboardInterrupt) propagates instead
  # of being silently swallowed.
  try:
    # e.g. angr 9.2.102
    return e.cache_key
  except AttributeError:
    # e.g. angr 9.2.144
    return e.hash()

def values(terms,replacements):
  # Evaluate claripy expression trees on concrete inputs in plain Python.
  #
  # input: terms is an iterable of claripy expressions
  # input: replacements maps the clarikey of each BVS to an object whose
  #   args[0] is that variable's integer value
  #   (presumably a claripy BVV -- confirm against callers)
  # output: pair (V,W)
  #   V is a dictionary mapping clarikey to pairs (b,i) where i is a b-bit value
  #     (1-bit results of comparisons are Python bools)
  #   V includes all terms and every subterm that was evaluated
  #   _or_ V is None if terms use variables outside replacements
  #     or an operation that is not handled below
  #   W is a set of warning tags: 'mul' is added when a multiplication is
  #     evaluated, 'div' when a division or remainder is evaluated
  #
  # Division by zero yields 0 and remainder by zero yields the dividend.

  # functools.reduce is needed below and this file's import block does not
  # provide functools, so bring it in here.
  import functools

  V = {}
  W = set() # warnings

  def evaluate(t):
    # Store the value of t (and of its subterms) in V.
    # Returns False if t cannot be evaluated here.
    if clarikey(t) in V:
      return True
    if t.op == 'BoolV':
      V[clarikey(t)] = 1,t.args[0]
      return True
    if t.op == 'BVV':
      V[clarikey(t)] = t.size(),t.args[0]
      return True
    if t.op == 'BVS':
      if clarikey(t) not in replacements: return False
      V[clarikey(t)] = t.size(),replacements[clarikey(t)].args[0]
      return True
    if t.op == 'Extract':
      # args are (top bit, bottom bit, expression), not all expressions
      assert len(t.args) == 3
      top = t.args[0]
      bot = t.args[1]
      if not evaluate(t.args[2]): return False
      x0 = V[clarikey(t.args[2])]
      assert x0[0] > top
      assert top >= bot
      assert bot >= 0
      V[clarikey(t)] = top+1-bot,((x0[1] & ((2<<top)-1)) >> bot)
      return True
    if t.op in ('SignExt','ZeroExt'):
      # args are (number of extra bits, expression)
      assert len(t.args) == 2
      if not evaluate(t.args[1]): return False
      x0bits,x0 = V[clarikey(t.args[1])]
      extend = t.args[0]
      assert extend >= 0
      if t.op == 'SignExt':
        if x0 >= (1<<(x0bits-1)):
          # negative input: fill the new top bits with ones
          x0 -= 1<<x0bits
          x0 += 1<<(x0bits+extend)
      V[clarikey(t)] = x0bits+extend,x0
      return True

    # every remaining operation takes expressions for all arguments
    for a in t.args:
      if not evaluate(a): return False
    x = [V[clarikey(a)] for a in t.args]

    if t.op == 'Concat':
      # first argument ends up in the most significant bits
      y = 0
      ybits = 0
      for xbitsi,xi in x:
        y <<= xbitsi
        y += xi
        ybits += xbitsi
      V[clarikey(t)] = ybits,y
      return True
    if t.op == 'Reverse':
      # byte reversal
      assert len(x) == 1
      xbits0,x0 = x[0]
      ybits = xbits0
      assert ybits%8 == 0
      y = 0
      for i in range(ybits//8):
        y = (y<<8)+((x0>>(8*i))&255)
      V[clarikey(t)] = ybits,y
      return True
    if t.op in ('__eq__','__ne__'):
      assert len(x) == 2
      assert x[0][0] == x[1][0]
      if t.op == '__eq__': V[clarikey(t)] = 1,(x[0][1]==x[1][1])
      elif t.op == '__ne__': V[clarikey(t)] = 1,(x[0][1]!=x[1][1])
      else:
        print('values: internal error %s, falling back to Z3' % t.op)
        return False
      return True
    if t.op in ('__add__','__mul__','SDiv','__floordiv__','SMod','__mod__','__sub__','__lshift__','LShR','__rshift__','__and__','__or__','__xor__'):
      bits = x[0][0]
      assert all(xi[0] == bits for xi in x)
      if t.op == '__and__': reduction = (lambda a,b:a&b)
      elif t.op == '__or__': reduction = (lambda a,b:a|b)
      elif t.op == '__xor__': reduction = (lambda a,b:a^b)
      elif t.op == '__add__': reduction = (lambda a,b:(a+b)%(2**bits))
      elif t.op == '__sub__': reduction = (lambda a,b:(a-b)%(2**bits))
      elif t.op == '__lshift__':
        def reduction(a,b):
          # The shift amount is an unsigned bits-bit value.  Shifting by
          # bits or more leaves nothing, so answer directly instead of
          # building a gigantic intermediate integer.
          if b >= bits: return 0
          return (a<<b)%(2**bits)
      elif t.op == 'LShR':
        # clamping the amount to bits does not change the result (a < 2**bits)
        reduction = (lambda a,b:(a>>min(b,bits))%(2**bits))
      elif t.op == '__rshift__':
        def reduction(a,b):
          # Arithmetic shift: a is interpreted as signed, while the shift
          # amount stays unsigned (as for SMT-LIB bvashr).  Amounts of
          # bits or more fill the result with the sign bit, which is also
          # what shifting by exactly bits produces.
          flip = 2**(bits-1)
          asigned = (a ^ flip) - flip
          return (asigned >> min(b,bits))%(2**bits)
      elif t.op == '__mul__':
        reduction = (lambda a,b:(a*b)%(2**bits))
        W.add('mul')
      elif t.op == '__floordiv__':
        def reduction(a,b):
          if b == 0: return 0
          return (a//b)%(2**bits)
        W.add('div')
      elif t.op == '__mod__':
        def reduction(a,b):
          if b == 0: return a
          return (a%b)%(2**bits)
        W.add('div')
      elif t.op == 'SDiv':
        def reduction(a,b):
          if b == 0: return 0
          flip = 2**(bits-1)
          asigned = (a ^ flip) - flip
          bsigned = (b ^ flip) - flip
          # sdiv definition in Z3:
          # - The \c floor of [t1/t2] if \c t2 is different from zero, and [t1*t2 >= 0].
          # - The \c ceiling of [t1/t2] if \c t2 is different from zero, and [t1*t2 < 0].
          if asigned*bsigned >= 0:
            quotient = asigned // bsigned
          else:
            quotient = -((-asigned) // bsigned)
          return quotient%(2**bits)
        W.add('div')
      elif t.op == 'SMod':
        def reduction(a,b):
          if b == 0: return a
          flip = 2**(bits-1)
          asigned = (a ^ flip) - flip
          bsigned = (b ^ flip) - flip
          # srem definition in Z3:
          # It is defined as t1 - (t1 /s t2) * t2, where /s represents signed division.
          # The most significant bit (sign) of the result is equal to the most significant bit of \c t1.
          if asigned*bsigned >= 0:
            remainder = asigned - bsigned*(asigned // bsigned)
          else:
            remainder = asigned + bsigned*((-asigned) // bsigned)
          return remainder%(2**bits)
        W.add('div')
      else:
        print('values: internal error %s, falling back to Z3' % t.op)
        return False
      V[clarikey(t)] = bits,functools.reduce(reduction,(xi[1] for xi in x))
      return True
    if t.op == '__invert__':
      assert len(x) == 1
      bits = x[0][0]
      V[clarikey(t)] = bits,(1<<bits)-1-x[0][1]
      return True
    if t.op == 'Not':
      assert len(x) == 1
      assert all(xi[0] == 1 for xi in x)
      V[clarikey(t)] = 1,1-x[0][1]
      return True
    if t.op in ('And','Or'):
      assert all(xi[0] == 1 for xi in x)
      # on 0/1 values: product is conjunction, a+b-a*b is disjunction
      if t.op == 'And': reduction = (lambda a,b:a*b)
      elif t.op == 'Or': reduction = (lambda a,b:a+b-a*b)
      else:
        print('values: internal error %s, falling back to Z3' % t.op)
        return False
      V[clarikey(t)] = 1,functools.reduce(reduction,(xi[1] for xi in x))
      return True
    if t.op == 'If':
      # both branches have already been evaluated above
      assert len(x) == 3
      assert x[0][0] == 1
      if x[0][1]:
        V[clarikey(t)] = x[1]
      else:
        V[clarikey(t)] = x[2]
      return True
    if t.op in ('__le__','ULE','__lt__','ULT','__ge__','UGE','__gt__','UGT','SLE','SLT','SGE','SGT'):
      assert len(x) == 2
      bits = x[0][0]
      assert bits == x[1][0]
      # flipping the top bit maps signed order onto unsigned order
      flip = 2**(bits-1)
      x0,x1 = x[0][1],x[1][1]
      if t.op == '__le__': V[clarikey(t)] = (1,x0<=x1)
      elif t.op == 'ULE': V[clarikey(t)] = (1,x0<=x1)
      elif t.op == '__lt__': V[clarikey(t)] = (1,x0<x1)
      elif t.op == 'ULT': V[clarikey(t)] = (1,x0<x1)
      elif t.op == '__ge__': V[clarikey(t)] = (1,x0>=x1)
      elif t.op == 'UGE': V[clarikey(t)] = (1,x0>=x1)
      elif t.op == '__gt__': V[clarikey(t)] = (1,x0>x1)
      elif t.op == 'UGT': V[clarikey(t)] = (1,x0>x1)
      elif t.op == 'SLE': V[clarikey(t)] = (1,(x0^flip)<=(x1^flip))
      elif t.op == 'SLT': V[clarikey(t)] = (1,(x0^flip)<(x1^flip))
      elif t.op == 'SGE': V[clarikey(t)] = (1,(x0^flip)>=(x1^flip))
      elif t.op == 'SGT': V[clarikey(t)] = (1,(x0^flip)>(x1^flip))
      else:
        print('values: internal error %s, falling back to Z3' % t.op)
        return False
      return True

    # XXX: add support for more
    print('values: unsupported operation %s, falling back to Z3' % t.op)
    return False

    # XXX: also add more validation for all of the above

  for t in terms:
    if not evaluate(t): return None,W
  return V,W

def unroll_print(names,unrolled,f):
  # Write the expressions in unrolled to f as straight-line code.
  #
  # Every distinct subexpression (distinct by clarikey) gets one line
  # 'vN = ...', numbered from 1 in the order the lines are written, with
  # operands always defined before they are used.  Afterwards one line
  # '<name> = vN' is written for each entry of names, paired by position
  # with the entries of unrolled.
  numbering = {}

  def emit(node):
    # Return the number assigned to node, writing its definition first
    # if it has not been written yet.
    key = clarikey(node)
    if key in numbering:
      return numbering[key]

    op = node.op
    if op == 'BoolV':
      rhs = 'bool(%d)' % (node.args[0],)
    elif op == 'BVV':
      rhs = 'constant(%d,%d)' % (node.size(),node.args[0])
    elif op == 'BVS':
      # a variable is printed as its bare name
      rhs = str(node.args[0])
    elif op == 'Extract':
      assert len(node.args) == 3
      operand = emit(node.args[2])
      rhs = 'Extract(v%d,%d,%d)' % (operand,node.args[0],node.args[1])
    elif op in ('SignExt','ZeroExt'):
      assert len(node.args) == 2
      operand = emit(node.args[1])
      rhs = '%s(v%d,%d)' % (op,operand,node.args[0])
    else:
      operands = ['v%d' % emit(a) for a in node.args]
      rhs = '%s(%s)' % (op,','.join(operands))

    # the number is taken only after all operands have claimed theirs
    index = len(numbering)+1
    f.write('v%d = %s\n' % (index,rhs))
    numbering[key] = index
    return index

  for expression in unrolled:
    emit(expression)

  for position,varname in enumerate(names):
    f.write('%s = v%s\n' % (varname,emit(unrolled[position])))

# Raise the interpreter's recursion limit far above its default.
sys.setrecursionlimit(1000000)

# All of angr's unicorn options except UNICORN_SYM_REGS_SUPPORT.
unicornoptions = set(angr.options.unicorn)
unicornoptions.discard(angr.options.UNICORN_SYM_REGS_SUPPORT)

# Options added to the initial state, on top of angr's defaults.
add_options = unicornoptions | {
  angr.options.SYMBOLIC_WRITE_ADDRESSES,
  angr.options.CONSERVATIVE_READ_STRATEGY,
  angr.options.CONSERVATIVE_WRITE_STRATEGY,
  angr.options.ZERO_FILL_UNCONSTRAINED_MEMORY,
  angr.options.ZERO_FILL_UNCONSTRAINED_REGISTERS,
  angr.options.CONSTRAINT_TRACKING_IN_SOLVER,
  angr.options.SIMPLIFY_EXPRS,
  angr.options.SIMPLIFY_MEMORY_READS,
  angr.options.SIMPLIFY_MEMORY_WRITES,
  angr.options.SIMPLIFY_REGISTER_READS,
  angr.options.SIMPLIFY_REGISTER_WRITES,
  angr.options.SIMPLIFY_CONSTRAINTS,
  angr.options.SIMPLIFY_RETS,
}

# The harness's stdin: the concatenation of all inputs, with a definite end.
stdincontent = claripy.Concat(*invars)
stdin = angr.SimFile('/dev/stdin',content=stdincontent,has_end=True)

# Symbols for which angr's built-in SimProcedures are excluded when the
# project is loaded.
avoidsimprocedures = (
  'clock_gettime',
  '_setjmp',
  '__setjmp',
  '___setjmp',
  'longjmp',
  '_longjmp',
  '__longjmp',
  '___longjmp',
  'sigsetjmp',
  '_sigsetjmp',
  '__sigsetjmp',
  '___sigsetjmp',
  'siglongjmp',
  '_siglongjmp',
  '__siglongjmp',
  '___siglongjmp',
)

def findlib(name,searchpath):
  # Locate a shared library along a colon-separated directory list.
  #
  # name: file name of the library, e.g. 'libunicorn.so'
  # searchpath: directories separated by ':' (LD_LIBRARY_PATH syntax)
  # returns: the first existing path among '<dir>/<name>' and
  #   '<dir>/<name>.1', trying directories in order;
  #   or name unchanged if nothing is found
  #
  # The search stops at the first hit, so earlier directories take
  # precedence and a hit never affects how later candidates are built.
  for libdir in searchpath.split(':'):
    for candidate in (f'{libdir}/{name}',f'{libdir}/{name}.1'):
      if os.path.exists(candidate):
        return candidate
  return name

# Path (or bare name, if not found) of libunicorn to force-load into angr.
lib = findlib('libunicorn.so',os.getenv('LD_LIBRARY_PATH',''))
# Load the compiled harness.  Libraries are not auto-loaded; only the
# libunicorn located above is forced in, and the listed symbols do not
# get angr's built-in SimProcedures.
proj = angr.Project(
  binary,
  exclude_sim_procedures_list=avoidsimprocedures,
  auto_load_libs=False,
  force_load_libs=[lib],
)

class posix_memalign(angr.SimProcedure):
  # Hook for posix_memalign: allocates sim_size bytes from the state's
  # heap, stores the resulting address in 8 bytes at sim_ptr, and
  # returns 0.  sim_alignment is not used.
  def run(self,sim_ptr,sim_alignment,sim_size):
    chunk = self.state.heap._malloc(sim_size)
    # NOTE(review): no endness is passed to store(); confirm that the
    # default byte order is the one the harness expects for a pointer.
    self.state.memory.store(sim_ptr,chunk,size=8)
    return claripy.BVV(0,32)

proj.hook_symbol('posix_memalign',posix_memalign())

class sysconf(angr.SimProcedure):
  # Hook for sysconf: answers with the value that os.sysconf reports on
  # the machine running this script, as a 64-bit constant.
  # num must be concrete.
  def run(self,num):
    hostvalue = os.sysconf(num.concrete_value)
    return claripy.BVV(hostvalue,64)

proj.hook_symbol('sysconf',sysconf())

class getpagesize(angr.SimProcedure):
  # Hook for getpagesize: answers with the page size of the machine
  # running this script, as a 32-bit constant.
  def run(self):
    pagesize = resource.getpagesize()
    return claripy.BVV(pagesize,32)

proj.hook_symbol('getpagesize',getpagesize())

class strerror(angr.SimProcedure):
  # Hook for strerror: looks up the message for the (concrete) error
  # number with os.strerror, copies it with a terminating NUL into memory
  # obtained from angr's libc malloc procedure, and returns that address.
  def run(self,num):
    message = os.strerror(num.concrete_value)+'\0'
    nbytes = len(message)
    allocator = angr.SIM_PROCEDURES["libc"]["malloc"]
    buf = self.inline_call(allocator,nbytes).ret_expr
    self.state.memory.store(buf,claripy.BVV(message),size=nbytes)
    return buf

proj.hook_symbol('strerror',strerror())

# mmap64 is served by angr's posix mmap procedure.
mmaphook = angr.procedures.posix.mmap.mmap()
proj.hook_symbol('mmap64',mmaphook)

# Initial state: the harness started with argv = [binary] and the
# symbolic stdin built above.
state = proj.factory.full_init_state(
  add_options=add_options,
  args=[binary],
  stdin=stdin,
)
print('options',state.options._options)

simgr = proj.factory.simgr(state)

# Step the simulation manager until no state is active.
#
# printingsteps: set to True to print, before each step, the stash sizes
#   and the disassembly of the block at each active state's address.
# fast: True until stdin's read position has been seen to be non-integer
#   in some active state; only reported in the debug output.
# step: number of loop iterations started so far.
#
# The loop stops with an exception if any state errored, or if the number
# of deadended plus active states exceeds maxsplit (the exception text
# then includes every state's path constraints).
printingsteps = False
fast = True
step = 0
while True:
  if printingsteps:
    print(f'step {step} fast {fast} deadended {len(simgr.deadended)} active {len(simgr.active)}')
    for epos,e in enumerate(simgr.active):
      print(f'active {epos}:')
      try:
        for ins in proj.factory.block(e.addr).capstone.insns:
          print(ins)
      except Exception:
        # Disassembly is best-effort debug output, but do not swallow
        # KeyboardInterrupt/SystemExit while producing it.
        print('exception')
    sys.stdout.flush()

  step += 1
  if len(simgr.errored) > 0:
    raise Exception(simgr.errored)
  if len(simgr.deadended)+len(simgr.active) > maxsplit:
    constraintbuf = io.StringIO()
    constraintnames = [f'exited{i}' for i in range(len(simgr.deadended))]
    constraintnames += [f'active{i}' for i in range(len(simgr.active))]
    unroll_print(constraintnames,[claripy.And(*e.solver.constraints) for e in simgr.deadended+simgr.active],constraintbuf)
    raise Exception(f'limiting split to {maxsplit}\n{constraintbuf.getvalue()}')
  if len(simgr.active) == 0:
    break
  # XXX: should find a documented interface to do this
  # Once some active state's stdin read position is no longer a plain
  # int, remove the unicorn options.
  # NOTE(review): this edits the options of the initial `state` object;
  # confirm that the states in simgr.active actually observe the change.
  if any(type(e.posix.fd[0]._read_pos) != int for e in simgr.active):
    state.options -= angr.options.unicorn
    fast = False
  simgr.step()

# Every finished run of the harness.
exits = simgr.deadended
assert len(exits) > 0

# Check that each exit wrote exactly the expected number of stdout
# packets, collecting a human-readable report (including whatever the
# harness printed on stderr) in case one did not.
ok = True
comment = f'have {len(exits)} exits\n'
for epos,e in enumerate(exits):
  receivedpackets = len(e.posix.stdout.content)
  ok = ok and receivedpackets == expectedoutputs
  comment += f'exit {epos} stdout packets {receivedpackets} expecting {expectedoutputs} stderr packets {len(e.posix.stderr.content)}\n'
  for packetpos,packet in enumerate(e.posix.stderr.content):
    if not packet[0].concrete or not packet[1].concrete:
      comment += f'stderr symbolic packet {packetpos}\n'
      continue
    payload = packet[0].concrete_value
    nbytes = packet[1].concrete_value
    # the low nbytes bytes of payload, most significant byte first
    text = bytes((payload >> (8*k)) & 255 for k in reversed(range(nbytes)))
    for line in text.splitlines():
      comment += f'stderr packet {packetpos} line: {line}\n'

if not ok:
  constraintbuf = io.StringIO()
  constraintnames = [f'exited{i}' for i in range(len(exits))]
  exitconditions = [claripy.And(claripy.true,*e.solver.constraints) for e in exits]
  unroll_print(constraintnames,exitconditions,constraintbuf)
  raise Exception(comment+constraintbuf.getvalue())

# Fold all exits into a single state, each guarded by its own path
# constraints.
mergedexit = exits[0]
if len(exits) > 1:
  mergeconditions = [e2.solver.constraints for e2 in exits]
  mergedexit,_,_ = exits[0].merge(*exits[1:],merge_conditions=mergeconditions)

# Pull the outputs out of the merged state's stdout packets: one packet
# per register (in regs order) followed by one packet per flag, and print
# them as straight-line code on stdout.
names = []
results = []
stdoutpackets = mergedexit.posix.stdout.content

for packetpos,(r,rbits) in enumerate(regs):
  packet = stdoutpackets[packetpos]
  assert packet[1].concrete_value == rbits//8
  names.append('out_'+r)
  # undo the byte reversal applied to the inputs when hostlittle is set
  results.append(claripy.Reverse(packet[0]) if hostlittle else packet[0])

for packetpos,r in enumerate(flags,start=len(regs)):
  packet = stdoutpackets[packetpos]
  assert packet[1].concrete_value == 1
  names.append('out_'+r)
  # a flag is the low bit of its one-byte packet
  results.append(claripy.Extract(0,0,packet[0]))

unroll_print(names,results,sys.stdout)
