about summary refs log tree commit diff
path: root/sysdeps/x86_64/multiarch/strstr-avx512.c
blob: 3ac53accbdde0b400dfd19a2070fbb579aff4177 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
/* strstr optimized with 512-bit AVX-512 instructions
   Copyright (C) 2022-2024 Free Software Foundation, Inc.
   This file is part of the GNU C Library.

   The GNU C Library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.

   The GNU C Library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with the GNU C Library; if not, see
   <https://www.gnu.org/licenses/>.  */

#include <immintrin.h>
#include <inttypes.h>
#include <stdbool.h>
#include <string.h>

/* All 64 lanes of a byte-granular ZMM mask selected.  */
#define FULL_MMASK64 0xffffffffffffffff
/* Used as X ^ (X - ONE_64BIT) to build "up to and including lowest set
   bit" masks (all ones when X == 0).  */
#define ONE_64BIT 0x1ull
/* Width of one ZMM register in bytes; the haystack loop below also
   advances by this amount, one cache line per iteration.  */
#define ZMM_SIZE_IN_BYTES 64
/* Page size used to decide whether a 64-byte load may cross into the next
   page (hard-coded to 4 KiB here).  */
#define PAGESIZE 4096

/* Mask helpers written with plain integer operators.  They stand in for
   the similarly named _cvtmask64_u64 / _kshiftri_mask64 / _kand_mask64
   intrinsics.  NOTE(review): presumably to avoid depending on compiler
   support for those intrinsics; confirm.  */
#define cvtmask64_u64(...) (uint64_t) (__VA_ARGS__)
#define kshiftri_mask64(x, y) ((x) >> (y))
#define kand_mask64(x, y) ((x) & (y))

/*
 Locate the first "edge" of the needle: the smallest IND such that
 NED[IND] != NED[IND + 1], with NED[IND + 1] not the terminator.
 Returns IND.

 Returns 0 when the needle consists of a single repeated character (no
 edge exists). That result cannot be told apart from an edge at index 0.
 In both cases NED[IND + 1] is a valid non-NUL byte, provided NED has at
 least two characters.

 Example: in "aaab" the edge "ab" starts at index 2.

 NED must not be empty, because NED[1] is read unconditionally.
 */
static inline size_t
find_edge_in_needle (const char *ned)
{
  for (size_t pos = 0; ned[pos + 1] != '\0'; ++pos)
    {
      if (ned[pos + 1] != ned[pos])
        return pos;
    }
  return 0;
}

/*
 Byte-wise comparison of the needle against the haystack at HAY + HAY_INDEX.

 Starts at needle offset IND. Bytes before IND are expected to have been
 checked already by the caller.

 Returns true when every remaining needle byte matches. The scan stops at
 the first mismatch, and every needle byte compared is non-NUL. A NUL in
 the haystack therefore ends the scan without reading past it.
 */
static inline bool
verify_string_match (const char *hay, const size_t hay_index, const char *ned,
                     size_t ind)
{
  for (size_t pos = ind; ned[pos] != '\0'; ++pos)
    {
      if (hay[hay_index + pos] != ned[pos])
        return false;
    }
  return true;
}

/*
 Full needle comparison at HAY + HAY_INDEX.

 The needle bytes selected by NED_MASK (held in NED_ZMM) are compared
 against a single unmasked 64-byte haystack load. The caller must
 guarantee that this load stays within one page.

 If NED_MASK is all ones, the needle has more than 64 bytes to check, and
 the rest is compared scalarly starting at offset 64.
 */
static inline bool
verify_string_match_avx512 (const char *hay, const size_t hay_index,
                            const char *ned, const __mmask64 ned_mask,
                            const __m512i ned_zmm)
{
  const __m512i window = _mm512_loadu_si512 (hay + hay_index);
  /* Set bits mark needle positions whose haystack byte differs.  */
  const __mmask64 mismatch
      = _mm512_mask_cmpneq_epi8_mask (ned_mask, window, ned_zmm);
  if (mismatch != 0x0)
    return false;
  if (ned_mask != FULL_MMASK64)
    return true;
  /* Needle continues past the vector-checked prefix.  */
  return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES);
}

/*
 Test each candidate in K2 for a full needle match.

 Bit I of K2 is set when BLOCK[I] == NED[EDGE] and
 BLOCK[I + 1] == NED[EDGE + 1]. The candidate occurrence therefore starts
 at BLOCK + I - EDGE.

 Returns a pointer to the first candidate that matches the whole needle,
 or NULL if none does.
 */
static inline const char *
check_candidates (const char *block, uint64_t k2, size_t edge,
                  const char *ned, const __mmask64 ned_mask,
                  const __m512i ned_zmm)
{
  while (k2)
    {
      uint64_t bitcount = _tzcnt_u64 (k2);
      k2 = _blsr_u64 (k2);
      /* Form the pointer as (BLOCK + BITCOUNT) - EDGE.  BLOCK + BITCOUNT
         is a non-NUL byte of the haystack.  Subtracting EDGE yields an
         address at or after the start of the haystack.  In contrast,
         BITCOUNT - EDGE computed as size_t can wrap when the candidate
         starts before BLOCK.  Adding such a wrapped offset to a pointer
         is undefined behavior.  */
      const char *match = block + bitcount - edge;
      if (((uintptr_t) match & (PAGESIZE - 1))
          < PAGESIZE - 1 - ZMM_SIZE_IN_BYTES)
        {
          /* Use vector compare as long as you are not crossing a page.  */
          if (verify_string_match_avx512 (match, 0, ned, ned_mask, ned_zmm))
            return match;
        }
      else
        {
          /* Compare byte by byte.  */
          if (verify_string_match (match, 0, ned, 0))
            return match;
        }
    }
  return NULL;
}

/*
 AVX-512 strstr.

 Returns a pointer to the first occurrence of NED in HAYSTACK. Returns
 HAYSTACK itself when NED is empty, and NULL when there is no occurrence.

 How it works:
 - Candidates are found 64 bytes at a time by comparing the two needle
   bytes at its first edge (see find_edge_in_needle) against every
   haystack position.
 - Each candidate is then verified against the whole needle.
 - Loads are either masked or confined to cache lines known to hold part
   of the string, so they do not touch pages the string does not occupy.
 */
char *
__strstr_avx512 (const char *haystack, const char *ned)
{
  char first = ned[0];
  if (first == '\0')
    return (char *) haystack;
  if (ned[1] == '\0')
    return strchr (haystack, first);

  size_t edge = find_edge_in_needle (ned);

  /* Ensure haystack is as long as the pos of edge in needle.  The first
     vector load below starts at HAYSTACK + EDGE.  The index is size_t to
     match EDGE.  An int index would mix signed and unsigned types and
     overflow for more than INT_MAX leading repeated needle bytes.  */
  for (size_t ii = 0; ii < edge; ++ii)
    {
      if (haystack[ii] == '\0')
        return NULL;
    }

  /*
   Load 64 bytes of the needle and save it to a zmm register
   Read one cache line at a time to avoid loading across a page boundary
   */
  __mmask64 ned_load_mask = _bzhi_u64 (
      FULL_MMASK64, 64 - ((uintptr_t) (ned) & 63));
  __m512i ned_zmm = _mm512_maskz_loadu_epi8 (ned_load_mask, ned);
  __mmask64 ned_nullmask
      = _mm512_mask_testn_epi8_mask (ned_load_mask, ned_zmm, ned_zmm);

  if (__glibc_unlikely (ned_nullmask == 0x0))
    {
      /* No NUL in the rest of this cache line, so the needle continues
         into the next one.  That line is therefore mapped, and an
         unaligned 64-byte load from NED is safe.  */
      ned_zmm = _mm512_loadu_si512 (ned);
      ned_nullmask = _mm512_testn_epi8_mask (ned_zmm, ned_zmm);
      /* Bits below the first NUL; all ones if the needle is longer.  */
      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
      if (ned_nullmask != 0x0)
        ned_load_mask = ned_load_mask >> 1;
    }
  else
    {
      /* Bits below the first NUL.  */
      ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT);
      ned_load_mask = ned_load_mask >> 1;
    }
  const __m512i ned0 = _mm512_set1_epi8 (ned[edge]);
  const __m512i ned1 = _mm512_set1_epi8 (ned[edge + 1]);

  /*
   Read the bytes of haystack in the current cache line
   */
  size_t hay_index = edge;
  __mmask64 loadmask = _bzhi_u64 (
      FULL_MMASK64, 64 - ((uintptr_t) (haystack + hay_index) & 63));
  /* First load is a partial cache line */
  __m512i hay0 = _mm512_maskz_loadu_epi8 (loadmask, haystack + hay_index);
  /* Search for NULL and compare only till null char */
  uint64_t nullmask
      = cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0));
  uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT);
  cmpmask = cmpmask & cvtmask64_u64 (loadmask);
  /* Search for the 2 characters of needle */
  __mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
  __mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1);
  k1 = kshiftri_mask64 (k1, 1);
  /* k2 masks tell us if both chars from needle match */
  uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
  /* For every match, search for the entire needle for a full match */
  const char *found = check_candidates (haystack + hay_index, k2, edge, ned,
                                        ned_load_mask, ned_zmm);
  if (found != NULL)
    return (char *) found;

  /* We haven't checked for potential match at the last char yet.  Rebase
     onto the last byte of the current cache line, so that HAY0 below
     starts there and HAY1 is the next aligned cache line.  */
  haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63));
  hay_index = 0;

  /*
   Loop over one cache line at a time to prevent reading over page
   boundary
   */
  __m512i hay1;
  while (nullmask == 0)
    {
      hay0 = _mm512_loadu_si512 (haystack + hay_index);
      hay1 = _mm512_load_si512 (haystack + hay_index
                                + 1); // Always 64 byte aligned
      nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1));
      /* Compare only till null char */
      cmpmask = nullmask ^ (nullmask - ONE_64BIT);
      k0 = _mm512_cmpeq_epi8_mask (hay0, ned0);
      k1 = _mm512_cmpeq_epi8_mask (hay1, ned1);
      /* k2 masks tell us if both chars from needle match */
      k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask;
      /* For every match, compare full strings for potential match */
      found = check_candidates (haystack + hay_index, k2, edge, ned,
                                ned_load_mask, ned_zmm);
      if (found != NULL)
        return (char *) found;
      hay_index += ZMM_SIZE_IN_BYTES;
    }
  return NULL;
}