-rw-r--r-- 2723 nttcompiler-20220411/command/ops-512-speed.c
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include "ntt_ops_512.h"
#include "ntt_ops.h"
/* 16-bit signed element type handed to the transform routines. */
typedef int16_t int16;

/* Request 32-byte alignment for the objects declared with it. */
#define ALIGN __attribute__((aligned(32)))

/* Common signature of the forward and inverse transform entry points. */
typedef void (*ops512_transform_fn)(int16 *, long long);

/*
 * Work buffers.  Of these, only M is referenced by the code visible in
 * this file; base, f and g are kept because they have external linkage.
 */
ALIGN int16 base[512];
ALIGN int16 M[512 * 512];
ALIGN int16 f[512 * 512];
ALIGN int16 g[512 * 512];

/*
 * Parallel tables indexed by modulus position: entry i of nttlist and
 * invnttlist is the forward / inverse routine paired with qlist[i].
 */
long long qlist[2] = {
    7681,
    10753,
};
ops512_transform_fn nttlist[2] = {
    ntt_ops_512_7681,
    ntt_ops_512_10753,
};
ops512_transform_fn invnttlist[2] = {
    ntt_ops_512_7681_inv,
    ntt_ops_512_10753_inv,
};
/*
 * Value passed as the second argument of each transform call and used as
 * the divisor when reporting counters.
 * NOTE(review): presumably the number of transforms processed per call --
 * confirm against ntt_ops_512.h.
 */
enum { OPS512_COUNT = 512 };

/* Zero every operation counter so the next transform call is measured alone. */
static void reset_op_counters(void)
{
    ntt_ops_mul = 0;
    ntt_ops_add = 0;
    ntt_ops_mul_x16 = 0;
    ntt_ops_add_x16 = 0;
    ntt_ops_mulmod = 0;
    ntt_ops_reduce = 0;
}

/*
 * Check that each counter is an exact multiple of OPS512_COUNT, then print
 * the counters divided by OPS512_COUNT, one per line, each prefixed with
 * `label` ("ntt" or "invntt").  The x16 counters are additionally shown
 * multiplied by 16.  Standard output is flushed before returning.
 */
static void report_op_counters(const char *label)
{
    /* The divisions below are only meaningful if they are exact. */
    assert(ntt_ops_mulmod % OPS512_COUNT == 0);
    assert(ntt_ops_reduce % OPS512_COUNT == 0);
    assert(ntt_ops_mul % OPS512_COUNT == 0);
    assert(ntt_ops_add % OPS512_COUNT == 0);
    assert(ntt_ops_mul_x16 % OPS512_COUNT == 0);
    assert(ntt_ops_add_x16 % OPS512_COUNT == 0);

    printf("%s mulmod %lld\n", label, ntt_ops_mulmod / OPS512_COUNT);
    printf("%s reduce %lld\n", label, ntt_ops_reduce / OPS512_COUNT);
    printf("%s mul %lld (underlying 16-bit mul instructions)\n",
           label, ntt_ops_mul / OPS512_COUNT);
    printf("%s add %lld (underlying 16-bit add/sub instructions)\n",
           label, ntt_ops_add / OPS512_COUNT);
    printf("%s mul_x16 %lld -> %lld\n", label,
           ntt_ops_mul_x16 / OPS512_COUNT, 16 * ntt_ops_mul_x16 / OPS512_COUNT);
    printf("%s add_x16 %lld -> %lld\n", label,
           ntt_ops_add_x16 / OPS512_COUNT, 16 * ntt_ops_add_x16 / OPS512_COUNT);
    fflush(stdout);
}

/*
 * Print the implementation/version/compiler strings, then for every
 * modulus in qlist run the forward transform on a zeroed M and report the
 * operation counters, followed by the same for the inverse transform
 * applied to the resulting M.
 *
 * Returns 0; a failed assertion aborts the program instead.
 */
int main(void)
{
    const long long moduli = (long long)(sizeof qlist / sizeof qlist[0]);
    const long long cells = (long long)(sizeof M / sizeof M[0]);

    printf("ntt_ops_512_implementation %s\n", ntt_ops_512_implementation);
    printf("ntt_ops_512_version %s\n", ntt_ops_512_version);
    printf("ntt_ops_512_compiler %s\n", ntt_ops_512_compiler);
    /* Flush so the header is written out before any transform runs. */
    fflush(stdout);

    for (long long qpos = 0; qpos < moduli; ++qpos) {
        long long q = qlist[qpos];
        void (*ntt)(int16 *, long long) = nttlist[qpos];
        void (*invntt)(int16 *, long long) = invnttlist[qpos];

        printf("q %lld\n", q);
        fflush(stdout);

        /* Start each modulus from an all-zero input buffer. */
        for (long long j = 0; j < cells; ++j) {
            M[j] = 0;
        }

        reset_op_counters();
        ntt(M, OPS512_COUNT);
        report_op_counters("ntt");

        /* The inverse runs on the forward transform's output in M. */
        reset_op_counters();
        invntt(M, OPS512_COUNT);
        report_op_counters("invntt");
    }

    return 0;
}