Loading...
/* SPDX-License-Identifier: GPL-2.0 */
/*
 * NH - ε-almost-universal hash function, x86_64 AVX2 accelerated
 *
 * Copyright 2018 Google LLC
 *
 * Author: Eric Biggers <ebiggers@google.com>
 */

#include <linux/linkage.h>
#include <linux/cfi_types.h>

// Accumulators: PASSi_SUMS holds four 64-bit partial sums for NH pass i.
#define		PASS0_SUMS	%ymm0
#define		PASS1_SUMS	%ymm1
#define		PASS2_SUMS	%ymm2
#define		PASS3_SUMS	%ymm3

// 32-byte key windows.  They are rotated between the two halves of the
// main loop, so which Ki feeds which pass depends on the macro call site.
#define		K0		%ymm4
#define		K0_XMM		%xmm4
#define		K1		%ymm5
#define		K1_XMM		%xmm5
#define		K2		%ymm6
#define		K2_XMM		%xmm6
#define		K3		%ymm7
#define		K3_XMM		%xmm7

// Temporaries.  T3 also carries the message bytes into _nh_2xstride.
#define		T0		%ymm8
#define		T1		%ymm9
#define		T2		%ymm10
#define		T2_XMM		%xmm10
#define		T3		%ymm11
#define		T3_XMM		%xmm11
#define		T4		%ymm12
#define		T5		%ymm13
#define		T6		%ymm14
#define		T7		%ymm15

// Arguments, in x86-64 SysV argument-register order.
#define		KEY		%rdi
#define		MESSAGE		%rsi
#define		MESSAGE_LEN	%rdx
#define		HASH		%rcx

/*
 * _nh_2xstride k0, k1, k2, k3
 *
 * Process 32 bytes (two 16-byte strides) of message for all four passes.
 * The message bytes must already be loaded into T3.  \k0..\k3 are the
 * 32-byte key windows for passes 0..3 (pass i uses the key offset by 16*i
 * bytes relative to pass 0).
 *
 * Within each 128-bit lane, let w[j] = message word j + key word j (32-bit
 * wraparound add).  Pass i then adds the two 64-bit products w[0]*w[2] and
 * w[1]*w[3] into the corresponding qwords of PASSi_SUMS.
 *
 * Clobbers T0-T7.  \k0..\k3 are only read.
 */
.macro _nh_2xstride	k0, k1, k2, k3

	// Add message words to key words
	vpaddd		\k0, T3, T0
	vpaddd		\k1, T3, T1
	vpaddd		\k2, T3, T2
	vpaddd		\k3, T3, T3

	// Multiply 32x32 => 64 and accumulate.
	// vpshufd $0x10 moves dwords 0,1 of each lane into the low halves of
	// the lane's two qwords; $0x32 does the same for dwords 2,3.  The
	// following vpmuludq therefore forms (w0*w2, w1*w3) in every lane.
	vpshufd		$0x10, T0, T4
	vpshufd		$0x32, T0, T0
	vpshufd		$0x10, T1, T5
	vpshufd		$0x32, T1, T1
	vpshufd		$0x10, T2, T6
	vpshufd		$0x32, T2, T2
	vpshufd		$0x10, T3, T7
	vpshufd		$0x32, T3, T3
	vpmuludq	T4, T0, T0
	vpmuludq	T5, T1, T1
	vpmuludq	T6, T2, T2
	vpmuludq	T7, T3, T3
	vpaddq		T0, PASS0_SUMS, PASS0_SUMS
	vpaddq		T1, PASS1_SUMS, PASS1_SUMS
	vpaddq		T2, PASS2_SUMS, PASS2_SUMS
	vpaddq		T3, PASS3_SUMS, PASS3_SUMS
.endm

/*
 * void nh_avx2(const u32 *key, const u8 *message, size_t message_len,
 *		__le64 hash[NH_NUM_PASSES])
 *
 * It's guaranteed that message_len % 16 == 0.
 *
 * Computes the four NH pass sums of 'message' under 'key' and stores them
 * as four 64-bit values (32 bytes) to 'hash'.  Reads message_len + 48 bytes
 * of 'key', because pass i uses the key shifted by 16*i bytes.
 *
 * Clobbers: %ymm0-%ymm15, %rdi, %rsi, %rdx, flags.
 */
SYM_TYPED_FUNC_START(nh_avx2)

	// Invariant from here on: K0/K1 hold the pass 0/1 key windows for the
	// 32 message bytes at MESSAGE, and KEY points 0x20 bytes past the key
	// position that corresponds to MESSAGE.
	vmovdqu		0x00(KEY), K0
	vmovdqu		0x10(KEY), K1
	add		$0x20, KEY
	vpxor		PASS0_SUMS, PASS0_SUMS, PASS0_SUMS
	vpxor		PASS1_SUMS, PASS1_SUMS, PASS1_SUMS
	vpxor		PASS2_SUMS, PASS2_SUMS, PASS2_SUMS
	vpxor		PASS3_SUMS, PASS3_SUMS, PASS3_SUMS

	// MESSAGE_LEN is kept biased by -0x40; loop while >= 64 bytes remain.
	// The biased count goes negative on exit, hence the signed conditions
	// (valid for any message_len < 2^63).
	sub		$0x40, MESSAGE_LEN
	jl		.Lloop4_done
.Lloop4:
	vmovdqu		(MESSAGE), T3
	vmovdqu		0x00(KEY), K2
	vmovdqu		0x10(KEY), K3
	_nh_2xstride	K0, K1, K2, K3

	vmovdqu		0x20(MESSAGE), T3
	vmovdqu		0x20(KEY), K0
	vmovdqu		0x30(KEY), K1
	_nh_2xstride	K2, K3, K0, K1

	add		$0x40, MESSAGE
	add		$0x40, KEY
	sub		$0x40, MESSAGE_LEN
	jge		.Lloop4

.Lloop4_done:
	// Undo the bias: remaining byte count is 0, 16, 32 or 48.
	and		$0x3f, MESSAGE_LEN
	jz		.Ldone

	cmp		$0x20, MESSAGE_LEN
	jl		.Llast

	// 2 or 3 strides remain; do 2 more.
	vmovdqu		(MESSAGE), T3
	vmovdqu		0x00(KEY), K2
	vmovdqu		0x10(KEY), K3
	_nh_2xstride	K0, K1, K2, K3
	add		$0x20, MESSAGE
	add		$0x20, KEY
	sub		$0x20, MESSAGE_LEN
	jz		.Ldone
	// Restore the invariant for the single remaining stride.
	vmovdqa		K2, K0
	vmovdqa		K3, K1
.Llast:
	// Last stride.  Zero the high 128 bits of the message and keys so they
	// don't affect the result when processing them like 2 strides.
	// (VEX-encoded xmm writes zero bits 255:128 of the ymm register.)
	vmovdqu		(MESSAGE), T3_XMM
	vmovdqa		K0_XMM, K0_XMM
	vmovdqa		K1_XMM, K1_XMM
	vmovdqu		0x00(KEY), K2_XMM
	vmovdqu		0x10(KEY), K3_XMM
	_nh_2xstride	K0, K1, K2, K3

.Ldone:
	// Sum the accumulators for each pass, then store the sums to 'hash'

	// PASS0_SUMS is (0A 0B 0C 0D)
	// PASS1_SUMS is (1A 1B 1C 1D)
	// PASS2_SUMS is (2A 2B 2C 2D)
	// PASS3_SUMS is (3A 3B 3C 3D)
	// We need the horizontal sums:
	//     (0A + 0B + 0C + 0D,
	//	1A + 1B + 1C + 1D,
	//	2A + 2B + 2C + 2D,
	//	3A + 3B + 3C + 3D)
	//

	vpunpcklqdq	PASS1_SUMS, PASS0_SUMS, T0	// T0 = (0A 1A 0C 1C)
	vpunpckhqdq	PASS1_SUMS, PASS0_SUMS, T1	// T1 = (0B 1B 0D 1D)
	vpunpcklqdq	PASS3_SUMS, PASS2_SUMS, T2	// T2 = (2A 3A 2C 3C)
	vpunpckhqdq	PASS3_SUMS, PASS2_SUMS, T3	// T3 = (2B 3B 2D 3D)

	vinserti128	$0x1, T2_XMM, T0, T4		// T4 = (0A 1A 2A 3A)
	vinserti128	$0x1, T3_XMM, T1, T5		// T5 = (0B 1B 2B 3B)
	vperm2i128	$0x31, T2, T0, T0		// T0 = (0C 1C 2C 3C)
	vperm2i128	$0x31, T3, T1, T1		// T1 = (0D 1D 2D 3D)

	vpaddq		T5, T4, T4
	vpaddq		T1, T0, T0
	vpaddq		T4, T0, T0
	vmovdqu		T0, (HASH)

	// The upper halves of the ymm registers are dirty; clear them before
	// returning so that subsequent legacy-SSE code does not pay the
	// AVX-to-SSE transition penalty.
	vzeroupper
	RET
SYM_FUNC_END(nh_avx2)