-rw-r--r-- 5836 attackntrw-20220829/attackntrw.c raw
#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "crypto_hash_sha3256.h"
#include "crypto_kem_ntruhrss701.h"
// Short local names for the ntruhrss701 KEM sizes and entry points
// (presumably provided by crypto_kem_ntruhrss701.h, included above).
#define PUBLICKEYBYTES crypto_kem_ntruhrss701_PUBLICKEYBYTES
#define SECRETKEYBYTES crypto_kem_ntruhrss701_SECRETKEYBYTES
#define CIPHERTEXTBYTES crypto_kem_ntruhrss701_CIPHERTEXTBYTES
// BYTES is the session-key length produced by enc and dec.
#define BYTES crypto_kem_ntruhrss701_BYTES
#define keypair crypto_kem_ntruhrss701_keypair
#define enc crypto_kem_ntruhrss701_enc
#define dec crypto_kem_ntruhrss701_dec
// ----- legitimate alice and bob
// Number of ciphertexts Bob sends; attack() targets each of them.
#define TARGETS 10
// Alice's key pair, generated by alice_prep(). alice_sk is later modified
// in place by alice_fault().
unsigned char alice_pk[PUBLICKEYBYTES];
unsigned char alice_sk[SECRETKEYBYTES];
// Scratch output for Alice's own decapsulations in alice().
unsigned char alice_k[BYTES];
// Bob's ciphertexts and the session keys he derived for them (filled by bob()).
unsigned char bob_ct[TARGETS][CIPHERTEXTBYTES];
unsigned char bob_k[TARGETS][BYTES];
/* Alice creates her key pair, storing it in alice_pk / alice_sk. */
void alice_prep(void)
{
  keypair(alice_pk, alice_sk);
}
/* Bob encapsulates TARGETS times to Alice's public key. Ciphertext number idx
   goes to bob_ct[idx] and the matching session key to bob_k[idx]. */
void bob(void)
{
  long long idx = 0;
  while (idx < TARGETS) {
    enc(bob_ct[idx], bob_k[idx], alice_pk);
    idx++;
  }
}
/* Sanity check of the legitimate protocol. Alice decapsulates each of Bob's
   ciphertexts with her secret key and must obtain exactly Bob's session key.
   main() calls this before attack() injects the fault. */
void alice(void)
{
  long long idx;
  for (idx = 0; idx < TARGETS; idx++) {
    dec(alice_k, bob_ct[idx], alice_sk);
    assert(memcmp(alice_k, bob_k[idx], BYTES) == 0);
  }
}
/* Decryption oracle offered to Eve: decapsulates ct with Alice's current
   (possibly faulted, see alice_fault()) secret key and writes the result to k.

   Eve may submit any ciphertext except one of Bob's targets. This rule is the
   whole point of the experiment, so it is enforced with an explicit check plus
   abort() rather than assert(). Otherwise, building with -DNDEBUG would
   silently let Eve decrypt the targets directly. */
void alice_oracle(unsigned char *k,const unsigned char *ct)
{
  for (long long t = 0;t < TARGETS;++t) {
    if (memcmp(ct,bob_ct[t],CIPHERTEXTBYTES) == 0) {
      fprintf(stderr,"alice_oracle: refusing to decrypt target ciphertext %lld\n",t);
      abort();
    }
  }
  dec(k,ct,alice_sk);
}
/* Inject a single-bit fault into Alice's stored secret key by flipping the
   bit with mask 0x02 in its final byte. */
void alice_fault(void)
{
  unsigned char *last = &alice_sk[SECRETKEYBYTES-1];
  *last ^= 0x02;
}
// ----- for eve: useful subroutines from ntruhrss701
// Parameters mirroring the ntruhrss701 implementation. NOTE(review): these must
// agree with the library's own internal definitions; nothing here checks that.
#define NTRU_N 701
#define NTRU_PACK_DEG (NTRU_N-1)
// ceil(NTRU_PACK_DEG/5) = 140 bytes for one packed trinary polynomial.
#define NTRU_PACK_TRINARY_BYTES ((NTRU_PACK_DEG+4)/5)
// 280 bytes: attack_onetarget() packs r into the first half and m into the
// second half, then hashes the whole buffer.
#define NTRU_OWCPA_MSGBYTES (2*NTRU_PACK_TRINARY_BYTES)
// Round X up to a multiple of 32 (PAD32(701) == 704).
#define PAD32(X) ((((X) + 31)/32)*32)
#include <immintrin.h>
// Polynomial type passed by pointer to the library routines declared below.
// Presumably it must match the library's internal layout exactly (704 uint16_t
// coefficients, also viewable as 44 __m256i). Requires an x86 target with AVX.
typedef union{ /* align to 32 byte boundary for vmovdqa */
uint16_t coeffs[PAD32(NTRU_N)];
__m256i coeffs_x16[PAD32(NTRU_N)/16];
} poly;
// Short names for internal routines of the avx2 "constbranchindex"
// ntruhrss701 implementation, followed by their prototypes.
#define poly_lift crypto_kem_ntruhrss701_avx2_constbranchindex_poly_lift
#define poly_S3_tobytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_S3_tobytes
#define poly_trinary_Zq_to_Z3 crypto_kem_ntruhrss701_avx2_constbranchindex_poly_trinary_Zq_to_Z3
#define poly_Rq_inv crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Rq_inv
#define poly_Rq_sum_zero_frombytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Rq_sum_zero_frombytes
#define poly_Sq_mul crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Sq_mul
#define poly_Sq_frombytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Sq_frombytes
#define poly_Sq_tobytes crypto_kem_ntruhrss701_avx2_constbranchindex_poly_Sq_tobytes
void poly_lift(poly *r, const poly *a);
void poly_S3_tobytes(unsigned char msg[NTRU_PACK_TRINARY_BYTES], const poly *a);
void poly_trinary_Zq_to_Z3(poly *r);
void poly_Rq_inv(poly *r, const poly *a);
void poly_Rq_sum_zero_frombytes(poly *r, const unsigned char *a);
void poly_Sq_mul(poly *r, const poly *a, const poly *b);
void poly_Sq_frombytes(poly *r, const unsigned char *a);
void poly_Sq_tobytes(unsigned char *r, const poly *a);
// ----- attack
// Modified ciphertext Eve submits to alice_oracle(), and the key it returns.
unsigned char eve_ct[CIPHERTEXTBYTES];
unsigned char eve_k[BYTES];
// Two modifications per coefficient position pos: index 2*pos adds -2 at pos
// (and +2 at pos+1); index 2*pos+1 adds +2 at pos (and -2 at pos+1).
#define EVE_MODS (2*NTRU_N)
// Oracle outputs recorded before the fault, per target and per modification.
unsigned char eve_k_stored[TARGETS][EVE_MODS][BYTES];
// After the fault: 1 if the oracle output for that modification is unchanged.
long long eve_match[EVE_MODS];
// Per-position value in {1,-1,0} derived from eve_match in attack_onetarget().
long long eve_m1x1[NTRU_N];
// NOTE(review): not referenced anywhere in this file; candidate for removal.
long long eve_reconstruction[NTRU_N];
// Packed r || m, hashed to produce eve_final_k.
unsigned char eve_rm[NTRU_OWCPA_MSGBYTES];
// Session key Eve reconstructs for the current target.
unsigned char eve_final_k[BYTES];
// Alice's public key as a polynomial, and its inverse from poly_Rq_inv.
poly eve_pk_poly;
poly eve_pk_inv;
// Intermediates of the reconstruction in attack_onetarget():
// m, b = c - lift(m), r, and lift(m).
poly eve_m;
poly eve_b;
poly eve_r;
poly eve_liftm;
// Unmodified target ciphertext, and the modified copy sent to the oracle.
poly eve_ct_starting_poly;
poly eve_ct_poly;
/* Reconstruct Bob's session key for target ciphertext t from eve_match[],
   which attack() fills in after the fault, and verify it against bob_k[t].

   Steps:
   1. Map eve_match to eve_m1x1[i]:
      - 1 if modification 2*i (offset -2) left the oracle's key unchanged;
      - otherwise -1 if modification 2*i+1 (offset +2) did;
      - otherwise 0.
   2. Set eve_m[i] = eve_m1x1[i] - eve_m1x1[i-1] (mod 3, indices cyclic). Then
      subtract the top coefficient from every coefficient (mod 3), so that
      coefficient NTRU_N-1 becomes 0. The result is taken as the message m.
   3. Following the reference owcpa_dec, compute b = c - lift(m) and
      r = b * pk^-1 (poly_Sq_mul with eve_pk_inv), reduced to Z_3. Pack r || m
      and hash it with SHA3-256, as in the reference crypto_kem_dec.

   The final comparison is an explicit check rather than assert(). This way a
   wrong key is always reported and aborts, and the success message can never
   be printed falsely when the file is built with NDEBUG. */
void attack_onetarget(long long t)
{
  poly_Rq_sum_zero_frombytes(&eve_ct_starting_poly,bob_ct[t]);

  for (long long start = 0;start < NTRU_N;++start) {
    if (eve_match[2*start]) eve_m1x1[start] = 1;
    else if (eve_match[2*start+1]) eve_m1x1[start] = -1;
    else eve_m1x1[start] = 0;
  }
  for (long long start = 0;start < NTRU_N;++start)
    eve_m.coeffs[start] = (3-eve_m1x1[(start+NTRU_N-1)%NTRU_N]+eve_m1x1[start])%3;
  // Normalize so the top coefficient is 0. coeffs[NTRU_N-1] is only
  // overwritten on the last iteration, so every earlier step uses its old value.
  for (long long start = 0;start < NTRU_N;++start)
    eve_m.coeffs[start] = (3+eve_m.coeffs[start]-eve_m.coeffs[NTRU_N-1])%3;

  // now follow (portions of) ref owcpa_dec to reconstruct r
  poly_S3_tobytes(eve_rm+NTRU_PACK_TRINARY_BYTES,&eve_m);
  poly_lift(&eve_liftm,&eve_m);
  for (long long i = 0;i < NTRU_N;++i)
    eve_b.coeffs[i] = eve_ct_starting_poly.coeffs[i] - eve_liftm.coeffs[i];
  poly_Sq_mul(&eve_r,&eve_b,&eve_pk_inv);
  poly_trinary_Zq_to_Z3(&eve_r);
  poly_S3_tobytes(eve_rm,&eve_r);

  // and hash as in ref crypto_kem_dec
  crypto_hash_sha3256(eve_final_k,eve_rm,NTRU_OWCPA_MSGBYTES);

  // Must not rely on assert(), which NDEBUG compiles out.
  if (memcmp(bob_k[t],eve_final_k,BYTES) != 0) {
    fflush(stdout);
    fprintf(stderr,"failed to break plaintext %lld\n",t);
    abort();
  }
  printf("successfully broke plaintext %lld\n",t);
}
/* One oracle query for modification number mod of the current target. The
   query ciphertext is eve_ct_starting_poly with delta added at coefficient
   mod/2 and subtracted at the cyclically next coefficient. delta is -2 for
   even mod and +2 for odd mod. The oracle's output key is left in eve_k. */
static void attack_query(long long mod)
{
  long long pos = mod/2;
  long long next = (pos+1)%NTRU_N;
  long long delta = (mod%2) ? 2 : -2;

  memcpy(eve_ct_poly.coeffs,eve_ct_starting_poly.coeffs,NTRU_N*sizeof eve_ct_poly.coeffs[0]);
  eve_ct_poly.coeffs[pos] = eve_ct_starting_poly.coeffs[pos]+delta;
  eve_ct_poly.coeffs[next] = eve_ct_starting_poly.coeffs[next]-delta;
  poly_Sq_tobytes(eve_ct,&eve_ct_poly);
  alice_oracle(eve_k,eve_ct);
}

/* Eve's attack on all TARGETS ciphertexts.
   - Setup: load Alice's public key as a polynomial and invert it. The
     inverse (eve_pk_inv) is used later by attack_onetarget().
   - Collect (before the fault): for every target and every modification,
     query the oracle and remember the returned key in eve_k_stored.
   - Fault: flip one bit of Alice's secret key via alice_fault().
   - Analyse (after the fault): repeat exactly the same queries.
     eve_match[mod] records whether each key is unchanged, and
     attack_onetarget(t) turns that into Bob's session key for target t. */
void attack(void)
{
  long long t, mod;

  poly_Rq_sum_zero_frombytes(&eve_pk_poly,alice_pk);
  poly_Rq_inv(&eve_pk_inv,&eve_pk_poly);

  for (t = 0;t < TARGETS;++t) {
    poly_Rq_sum_zero_frombytes(&eve_ct_starting_poly,bob_ct[t]);
    for (mod = 0;mod < EVE_MODS;++mod) {
      attack_query(mod);
      memcpy(eve_k_stored[t][mod],eve_k,BYTES);
    }
    printf("collected data for ciphertext %lld\n",t);
    fflush(stdout);
  }

  // one single-bit fault between the two rounds of queries
  alice_fault();
  printf("fault!\n");
  fflush(stdout);

  for (t = 0;t < TARGETS;++t) {
    poly_Rq_sum_zero_frombytes(&eve_ct_starting_poly,bob_ct[t]);
    for (mod = 0;mod < EVE_MODS;++mod) {
      attack_query(mod);
      eve_match[mod] = !memcmp(eve_k_stored[t][mod],eve_k,BYTES);
    }
    attack_onetarget(t);
    fflush(stdout);
  }
}
/* Runs the whole scenario in order: key generation, Bob's encapsulations, a
   legitimate decapsulation check, then Eve's fault-assisted oracle attack. */
int main()
{
  alice_prep();   /* Alice generates her key pair */
  bob();          /* Bob encapsulates TARGETS session keys to her */
  alice();        /* Alice can decapsulate all of them correctly */
  attack();       /* Eve recovers them via oracle queries plus one fault */
  return 0;
}