Add decrypt

This commit is contained in:
Adam Ierymenko 2019-08-16 18:40:22 -07:00
parent 846f03504e
commit 7bdca83de3
No known key found for this signature in database
GPG key ID: 1657198823E52A61
3 changed files with 253 additions and 47 deletions

View file

@ -37,14 +37,13 @@
#define ZT_AES_AESNI 1
#endif
#define ZT_AES_KEY_SIZE 32
#define ZT_AES_BLOCK_SIZE 16
namespace ZeroTier {
/**
* AES-256 and GCM AEAD
*
* AES with 128-bit or 192-bit key sizes isn't supported here. This also only
* supports the encrypt operation since we use AES in GCM mode. For HW acceleration
* the code is inlined for maximum performance.
* AES-256 and AES-GCM AEAD
*/
class AES
{
@ -81,7 +80,18 @@ public:
_encryptSW(in,out);
}
inline void ecbEncrypt(const void *in,unsigned int inlen,void *out)
inline void decrypt(const uint8_t in[16],uint8_t out[16]) const
{
#ifdef ZT_AES_AESNI
if (likely(HW_ACCEL)) {
_decrypt_aesni(in,out);
return;
}
#endif
_decryptSW(in,out);
}
inline void ecbScramble(const void *in,unsigned int inlen,void *out)
{
if (inlen < 16)
return;
@ -101,7 +111,7 @@ public:
o += 16;
inlen -= 16;
}
if (inlen != 0) {
if (inlen) {
i -= (16 - inlen);
o -= (16 - inlen);
_encrypt_aesni(i,o);
@ -117,7 +127,7 @@ public:
o += 16;
inlen -= 16;
}
if (inlen != 0) {
if (inlen) {
i -= (16 - inlen);
o -= (16 - inlen);
_encryptSW(i,o);
@ -151,16 +161,18 @@ public:
private:
void _initSW(const uint8_t key[32]);
void _encryptSW(const uint8_t in[16],uint8_t out[16]) const;
void _decryptSW(const uint8_t in[16],uint8_t out[16]) const;
union {
#ifdef ZT_AES_AESNI
struct {
__m128i k[15];
__m128i k[28];
__m128i h,hh,hhh,hhhh;
} ni;
#endif
struct {
uint32_t k[60];
uint32_t ek[60];
uint32_t dk[60];
} sw;
} _k;
@ -211,6 +223,19 @@ private:
_k.ni.k[12] = t1 = _init256_1_aesni(t1,_mm_aeskeygenassist_si128(t2,0x20));
_k.ni.k[13] = t2 = _init256_2_aesni(t1,t2);
_k.ni.k[14] = _init256_1_aesni(t1,_mm_aeskeygenassist_si128(t2,0x40));
_k.ni.k[15] = _mm_aesimc_si128(_k.ni.k[13]);
_k.ni.k[16] = _mm_aesimc_si128(_k.ni.k[12]);
_k.ni.k[17] = _mm_aesimc_si128(_k.ni.k[11]);
_k.ni.k[18] = _mm_aesimc_si128(_k.ni.k[10]);
_k.ni.k[19] = _mm_aesimc_si128(_k.ni.k[9]);
_k.ni.k[20] = _mm_aesimc_si128(_k.ni.k[8]);
_k.ni.k[21] = _mm_aesimc_si128(_k.ni.k[7]);
_k.ni.k[22] = _mm_aesimc_si128(_k.ni.k[6]);
_k.ni.k[23] = _mm_aesimc_si128(_k.ni.k[5]);
_k.ni.k[24] = _mm_aesimc_si128(_k.ni.k[4]);
_k.ni.k[25] = _mm_aesimc_si128(_k.ni.k[3]);
_k.ni.k[26] = _mm_aesimc_si128(_k.ni.k[2]);
_k.ni.k[27] = _mm_aesimc_si128(_k.ni.k[1]);
/* Init GCM / GHASH */
__m128i h = _mm_xor_si128(_mm_setzero_si128(),_k.ni.k[0]);
@ -412,6 +437,26 @@ private:
_mm_storeu_si128((__m128i *)((uint8_t *)out + 112),_mm_aesenclast_si128(tmp7,k14));
}
}
inline void _decrypt_aesni(const void *in,void *out) const
{
__m128i tmp;
tmp = _mm_loadu_si128((const __m128i *)in);
tmp = _mm_xor_si128(tmp,_k.ni.k[14]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[15]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[16]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[17]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[18]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[19]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[20]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[21]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[22]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[23]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[24]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[25]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[26]);
tmp = _mm_aesdec_si128(tmp,_k.ni.k[27]);
_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(tmp,_k.ni.k[0]));
}
static inline __m128i _swap128_aesni(__m128i x) { return _mm_shuffle_epi8(x,_mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)); }
static inline __m128i _mult_block_aesni(__m128i h,__m128i y)
@ -828,22 +873,6 @@ private:
__m128i *bi = (__m128i *)in;
__m128i *bo = (__m128i *)out;
__m128i k0 = _k.ni.k[0];
__m128i k1 = _k.ni.k[1];
__m128i k2 = _k.ni.k[2];
__m128i k3 = _k.ni.k[3];
__m128i k4 = _k.ni.k[4];
__m128i k5 = _k.ni.k[5];
__m128i k6 = _k.ni.k[6];
__m128i k7 = _k.ni.k[7];
__m128i k8 = _k.ni.k[8];
__m128i k9 = _k.ni.k[9];
__m128i k10 = _k.ni.k[10];
__m128i k11 = _k.ni.k[11];
__m128i k12 = _k.ni.k[12];
__m128i k13 = _k.ni.k[13];
__m128i k14 = _k.ni.k[14];
unsigned int i;
for (i=0;i<pblocks;i+=4) {
__m128i d1 = _mm_loadu_si128(bi + i + 0);
@ -852,7 +881,11 @@ private:
__m128i d4 = _mm_loadu_si128(bi + i + 3);
y = _mm_xor_si128(y,d1);
y = _mult4xor_aesni(_k.ni.hhhh,_k.ni.hhh,_k.ni.hh,_k.ni.h,y,d2,d3,d4);
__m128i t1 = _mm_xor_si128(cb,k0);
__m128i k0 = _k.ni.k[0];
__m128i k1 = _k.ni.k[1];
__m128i k2 = _k.ni.k[2];
__m128i k3 = _k.ni.k[3];
__m128i t1 = _mm_xor_si128(cb,k0);
cb = _increment_be_aesni(cb);
__m128i t2 = _mm_xor_si128(cb,k0);
cb = _increment_be_aesni(cb);
@ -872,6 +905,10 @@ private:
t2 = _mm_aesenc_si128(t2,k3);
t3 = _mm_aesenc_si128(t3,k3);
t4 = _mm_aesenc_si128(t4,k3);
__m128i k4 = _k.ni.k[4];
__m128i k5 = _k.ni.k[5];
__m128i k6 = _k.ni.k[6];
__m128i k7 = _k.ni.k[7];
t1 = _mm_aesenc_si128(t1,k4);
t2 = _mm_aesenc_si128(t2,k4);
t3 = _mm_aesenc_si128(t3,k4);
@ -888,6 +925,10 @@ private:
t2 = _mm_aesenc_si128(t2,k7);
t3 = _mm_aesenc_si128(t3,k7);
t4 = _mm_aesenc_si128(t4,k7);
__m128i k8 = _k.ni.k[8];
__m128i k9 = _k.ni.k[9];
__m128i k10 = _k.ni.k[10];
__m128i k11 = _k.ni.k[11];
t1 = _mm_aesenc_si128(t1,k8);
t2 = _mm_aesenc_si128(t2,k8);
t3 = _mm_aesenc_si128(t3,k8);
@ -904,6 +945,9 @@ private:
t2 = _mm_aesenc_si128(t2,k11);
t3 = _mm_aesenc_si128(t3,k11);
t4 = _mm_aesenc_si128(t4,k11);
__m128i k12 = _k.ni.k[12];
__m128i k13 = _k.ni.k[13];
__m128i k14 = _k.ni.k[14];
t1 = _mm_aesenc_si128(t1,k12);
t2 = _mm_aesenc_si128(t2,k12);
t3 = _mm_aesenc_si128(t3,k12);
@ -929,18 +973,33 @@ private:
for (i=pblocks;i<blocks;i++) {
__m128i d1 = _mm_loadu_si128(bi + i);
y = _ghash_aesni(_k.ni.h,y,d1);
__m128i k0 = _k.ni.k[0];
__m128i k1 = _k.ni.k[1];
__m128i k2 = _k.ni.k[2];
__m128i k3 = _k.ni.k[3];
__m128i t1 = _mm_xor_si128(cb,k0);
t1 = _mm_aesenc_si128(t1,k1);
t1 = _mm_aesenc_si128(t1,k2);
t1 = _mm_aesenc_si128(t1,k3);
__m128i k4 = _k.ni.k[4];
__m128i k5 = _k.ni.k[5];
__m128i k6 = _k.ni.k[6];
__m128i k7 = _k.ni.k[7];
t1 = _mm_aesenc_si128(t1,k4);
t1 = _mm_aesenc_si128(t1,k5);
t1 = _mm_aesenc_si128(t1,k6);
t1 = _mm_aesenc_si128(t1,k7);
__m128i k8 = _k.ni.k[8];
__m128i k9 = _k.ni.k[9];
__m128i k10 = _k.ni.k[10];
__m128i k11 = _k.ni.k[11];
t1 = _mm_aesenc_si128(t1,k8);
t1 = _mm_aesenc_si128(t1,k9);
t1 = _mm_aesenc_si128(t1,k10);
t1 = _mm_aesenc_si128(t1,k11);
__m128i k12 = _k.ni.k[12];
__m128i k13 = _k.ni.k[13];
__m128i k14 = _k.ni.k[14];
t1 = _mm_aesenc_si128(t1,k12);
t1 = _mm_aesenc_si128(t1,k13);
t1 = _mm_aesenclast_si128(t1,k14);