-rw-r--r-- 3348 nttcompiler-20220411/512/works.c
#include <stdlib.h>
#include <assert.h>
#include "ntt_512.h"
typedef int16_t int16;

/* Every buffer handed to the NTT routines is 32-byte aligned. */
#define ALIGN __attribute__((aligned(32)))

/* Scratch storage: each of M, f, g holds 512 polynomials of 512 int16
   coefficients, laid out one polynomial after another. */
ALIGN int16 base[512];
ALIGN int16 M[512 * 512];
ALIGN int16 f[512 * 512];
ALIGN int16 g[512 * 512];

/* h holds REPS polynomials plus OFFSETS entries of slack, so that a batch
   may start at any of OFFSETS positions inside it. */
#define REPS 30
#define OFFSETS 128
ALIGN int16 h[REPS * 512 + OFFSETS];

/* Moduli under test, indexed in parallel with their forward and inverse
   transforms. */
long long qlist[2] = { [0] = 7681, [1] = 10753 };
void (*nttlist[2])(int16 *, long long) = {
  [0] = ntt_512_7681,
  [1] = ntt_512_10753,
};
void (*invnttlist[2])(int16 *, long long) = {
  [0] = ntt_512_7681_inv,
  [1] = ntt_512_10753_inv,
};
int main()
{
for (long long qpos = 0;qpos < 2;++qpos) {
long long q = qlist[qpos];
void (*ntt)(int16*,long long) = nttlist[qpos];
void (*invntt)(int16*,long long) = invnttlist[qpos];
// test that basis gives powers
char seenbase[q];
for (long long i = 0;i < q;++i) seenbase[i] = 0;
for (long long e = 0;e < 512;++e) {
for (long long j = 0;j < 512;++j)
M[e*512+j] = 0;
M[e*512+e] = 1;
}
ntt(M,512);
for (long long j = 0;j < 512;++j) {
long long z = M[1*512+j];
z %= q; if (z < 0) z += q;
assert(z >= 0);
assert(z < q);
seenbase[z] = 1;
long long ze = 1;
for (long long e = 0;e < 512;++e) {
assert((ze-M[e*512+j])%q == 0);
ze *= z;
ze %= q; if (ze < 0) ze += q;
}
}
// test that powers are of 512th roots of 1
char root[10][q];
for (long long r = 0;r < 10;++r) {
if (r == 0)
for (long long i = 0;i < q;++i)
root[r][i] = i == 1;
else
for (long long i = 0;i < q;++i) {
long long ii = i*i % q;
assert(ii >= 0);
assert(ii < q);
assert(r-1 >= 0);
root[r][i] = root[r-1][ii];
}
}
for (long long i = 0;i < q;++i)
assert(root[9][i] == seenbase[i]);
// test that some random examples pass bounds checks and linearity checks
// XXX: rethink how input bounds should be selected here
for (long long j = 0;j < 512*512;++j) {
M[j] %= q;
M[j] += q;
M[j] %= q;
if (random()&1) M[j] -= q;
if (M[j] > 8000) M[j] -= q;
if (M[j] < -8000) M[j] += q;
}
for (long long loop = 0;loop < 30;++loop) {
for (long long j = 0;j < 512*512;++j)
g[j] = f[j] = (random()%q)-(q/2);
ntt(g,512);
if (loop == 0)
for (long long e = 0;e < 512;++e)
for (long long j = 0;j < 512;++j) {
long long s = 0;
for (long long i = 0;i < 512;++i)
s += f[e*512+i]*(long long) M[i*512+j];
assert((s-g[e*512+j])%q == 0);
}
}
// test that inverse gives identity
invntt(M,512);
for (long long e = 0;e < 512;++e) {
for (long long j = 0;j < 512;++j)
if (j != e)
assert(M[e*512+j]%q == 0);
assert(((M[e*512+e]%q)+q)%q == 512);
}
// test for consistency across reps and alignment
for (long long reps = 1;reps < REPS;++reps) {
if (reps > 512) continue;
for (long long offset = 0;offset < OFFSETS;++offset) {
for (long long j = 0;j < 512*reps;++j)
g[j] = h[offset+j] = (random()%q)-(q/2);
for (long long j = 0;j < 512*reps;j += 512)
ntt(g+j,1);
ntt(h+offset,reps);
for (long long j = 0;j < 512*reps;++j)
assert(g[j] == h[offset+j]);
}
}
}
return 0;
}