Speed up memmove on x86
[akaros.git] / kern / src / string.c
1 // Basic string routines.  Not hardware optimized, but not shabby.
2
3 #ifdef __SHARC__
4 #pragma nosharc
5 #endif
6
7 #include <stdio.h>
8 #include <string.h>
9 #include <ros/memlayout.h>
10 #include <assert.h>
11
12 int
13 strlen(const char *s)
14 {
15         int n;
16
17         for (n = 0; *s != '\0'; s++)
18                 n++;
19         return n;
20 }
21
22 int
23 strnlen(const char *s, size_t size)
24 {
25         int n;
26
27         for (n = 0; size > 0 && *s != '\0'; s++, size--)
28                 n++;
29         return n;
30 }
31
32 /* zra: These aren't being used, and they are dangerous, so I'm rm'ing them
33 char *
34 strcpy(char *dst, const char *src)
35 {
36         char *ret;
37
38         ret = dst;
39         while ((*dst++ = *src++) != '\0')
40                 ;
41         return ret;
42 }
43
44 char *
45 strcat(char *dst, const char *src)
46 {
47         strcpy(dst+strlen(dst),src);
48         return dst;
49 }
50 */
51
52 char *
53 strncpy(char *dst, const char *src, size_t size) {
54         size_t i;
55         char *ret;
56
57         ret = dst;
58         for (i = 0; i < size; i++) {
59                 // TODO: ivy bitches about this
60                 *dst++ = *src;
61                 // If strlen(src) < size, null-pad 'dst' out to 'size' chars
62                 if (*src != '\0')
63                         src++;
64         }
65         return ret;
66 }
67
68 size_t
69 strlcpy(char *dst, const char *src, size_t size)
70 {
71         char *dst_in;
72
73         dst_in = dst;
74         if (size > 0) {
75                 while (--size > 0 && *src != '\0')
76                         *dst++ = *src++;
77                 *dst = '\0';
78         }
79         return dst - dst_in;
80 }
81
82 int
83 strcmp(const char *p, const char *q)
84 {
85         while (*p && *p == *q)
86                 p++, q++;
87         return (int) ((unsigned char) *p - (unsigned char) *q);
88 }
89
90 int
91 strncmp(const char *p, const char *q, size_t n)
92 {
93         while (n > 0 && *p && *p == *q)
94                 n--, p++, q++;
95         if (n == 0)
96                 return 0;
97         else
98                 return (int) ((unsigned char) *p - (unsigned char) *q);
99 }
100
101 // Return a pointer to the first occurrence of 'c' in 's',
102 // or a null pointer if the string has no 'c'.
103 char *
104 strchr(const char *s, char c)
105 {
106         for (; *s; s++)
107                 if (*s == c)
108                         return (char *) s;
109         return 0;
110 }
111
112 // Return a pointer to the last occurrence of 'c' in 's',
113 // or a null pointer if the string has no 'c'.
114 char *
115 strrchr(const char *s, char c)
116 {
117         char *lastc = NULL;
118         for (; *s; s++)
119                 if (*s == c){
120                         lastc = (char*)s;
121                 }
122         return lastc;
123 }
124
125 void *
126 memchr(void* mem, int chr, int len)
127 {
128         char* s = (char*)mem;
129         for(int i = 0; i < len; i++)
130                 if(s[i] == (char)chr)
131                         return s+i;
132         return NULL;
133 }
134
135 // Return a pointer to the first occurrence of 'c' in 's',
136 // or a pointer to the string-ending null character if the string has no 'c'.
137 char *
138 strfind(const char *s, char c)
139 {
140         for (; *s; s++)
141                 if (*s == c)
142                         break;
143         return (char *) s;
144 }
145
146 // memset aligned words.
147 static inline void *
148 memsetw(long* _v, long c, size_t n)
149 {
150         long *start, *end, *v;
151
152         start = _v;
153         end = _v + n/sizeof(long);
154         v = _v;
155         c = c & 0xff;
156         c = c | c<<8;
157         c = c | c<<16;
158         #if NUM_ADDR_BITS == 64
159         c = c | c<<32;
160         #elif NUM_ADDR_BITS != 32
161         # error
162         #endif
163
164         while(v < end - (8-1))
165         {
166                 v[3] = v[2] = v[1] = v[0] = c;
167                 v += 4;
168                 v[3] = v[2] = v[1] = v[0] = c;
169                 v += 4;
170         }
171
172         while(v < end)
173           *v++ = c;
174
175         return start;
176 }
177
178 // copy aligned words.
179 // unroll 9 ways to get multiple misses in flight
180 #define memcpyw(type, _dst, _src, n) \
181   do { \
182         type* restrict src = (type*)(_src); \
183         type* restrict dst = (type*)(_dst); \
184         type* srcend = src + (n)/sizeof(type); \
185         type* dstend = dst + (n)/sizeof(type); \
186         while (dst < dstend - (9-1)) { \
187                 dst[0] = src[0]; \
188                 dst[1] = src[1]; \
189                 dst[2] = src[2]; \
190                 dst[3] = src[3]; \
191                 dst[4] = src[4]; \
192                 dst[5] = src[5]; \
193                 dst[6] = src[6]; \
194                 dst[7] = src[7]; \
195                 dst[8] = src[8]; \
196                 src += 9; \
197                 dst += 9; \
198         } \
199         while(dst < dstend) \
200           *dst++ = *src++; \
201   } while(0)
202
203 void *
204 memset(void *COUNT(_n) v, int c, size_t _n)
205 {
206         char *BND(v,v+_n) p;
207         size_t n0;
208         size_t n = _n;
209
210         if (n == 0) return NULL; // zra: complain here?
211
212         p = v;
213
214     while (n > 0 && ((uintptr_t)p & (sizeof(long)-1)))
215         {
216                 *p++ = c;
217                 n--;
218         }
219
220         if (n >= sizeof(long))
221         {
222                 n0 = n / sizeof(long) * sizeof(long);
223                 memsetw((long*)p, c, n0);
224                 n -= n0;
225                 p += n0;
226         }
227
228         while (n > 0)
229         {
230                 *p++ = c;
231                 n--;
232         }
233
234         return v;
235 }
236
237 void *
238 memcpy(void* dst, const void* src, size_t _n)
239 {
240         const char* s;
241         char* d;
242         size_t n0 = 0;
243         size_t n = _n;
244         int align = sizeof(long)-1;
245
246         s = src;
247         d = dst;
248
249         if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(long)-1)) == 0)
250         {
251                 n0 = n / sizeof(long) * sizeof(long);
252                 memcpyw(long, d, s, n0);
253         }
254         else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(int)-1)) == 0)
255         {
256                 n0 = n / sizeof(int) * sizeof(int);
257                 memcpyw(int, d, s, n0);
258         }
259         else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(short)-1)) == 0)
260         {
261                 n0 = n / sizeof(short) * sizeof(short);
262                 memcpyw(short, d, s, n0);
263         }
264
265         n -= n0;
266         s += n0;
267         d += n0;
268
269         while (n-- > 0)
270                 *d++ = *s++;
271
272         return dst;
273 }
274
275 #ifdef CONFIG_X86
276 void bcopy(const void *src, void *dst, size_t len);
277 #endif
278
279 void *
280 memmove(void *COUNT(_n) dst, const void *COUNT(_n) src, size_t _n)
281 {
282 #ifdef CONFIG_X86
283         bcopy(src, dst, _n);
284         return dst;
285 #else
286         const char *BND(src,src+_n) s;
287         char *BND(dst,dst+_n) d;
288         size_t n = _n;
289         
290         s = src;
291         d = dst;
292         if (s < d && s + n > d) {
293                 s += n;
294                 d += n;
295                 while (n-- > 0)
296                         *--d = *--s;
297         } else
298                 while (n-- > 0)
299                         *d++ = *s++;
300
301         return dst;
302 #endif
303 }
304
305 int
306 memcmp(const void *COUNT(n) v1, const void *COUNT(n) v2, size_t n)
307 {
308         const uint8_t *BND(v1,v1+n) s1 = (const uint8_t *) v1;
309         const uint8_t *BND(v2,v2+n) s2 = (const uint8_t *) v2;
310
311         while (n-- > 0) {
312                 if (*s1 != *s2)
313                         return (int) *s1 - (int) *s2;
314                 s1++, s2++;
315         }
316
317         return 0;
318 }
319
320 void *
321 memfind(const void *COUNT(n) _s, int c, size_t n)
322 {
323         const void *SNT ends = (const char *) _s + n;
324         const void *BND(_s,_s + n) s = _s;
325         for (; s < ends; s++)
326                 if (*(const unsigned char *) s == (unsigned char) c)
327                         break;
328         return (void *BND(_s,_s+n)) s;
329 }
330
331 long
332 strtol(const char *s, char **endptr, int base)
333 {
334         int neg = 0;
335         long val = 0;
336
337         // gobble initial whitespace
338         while (*s == ' ' || *s == '\t')
339                 s++;
340
341         // plus/minus sign
342         if (*s == '+')
343                 s++;
344         else if (*s == '-')
345                 s++, neg = 1;
346
347         // hex or octal base prefix
348         if ((base == 0 || base == 16) && (s[0] == '0' && s[1] == 'x'))
349                 s += 2, base = 16;
350         else if (base == 0 && s[0] == '0')
351                 s++, base = 8;
352         else if (base == 0)
353                 base = 10;
354
355         // digits
356         while (1) {
357                 int dig;
358
359                 if (*s >= '0' && *s <= '9')
360                         dig = *s - '0';
361                 else if (*s >= 'a' && *s <= 'z')
362                         dig = *s - 'a' + 10;
363                 else if (*s >= 'A' && *s <= 'Z')
364                         dig = *s - 'A' + 10;
365                 else
366                         break;
367                 if (dig >= base)
368                         break;
369                 s++, val = (val * base) + dig;
370                 // we don't properly detect overflow!
371         }
372
373         if (endptr)
374                 *endptr = (char *) s;
375         return (neg ? -val : val);
376 }
377
378 unsigned long
379 strtoul(const char *s, char **endptr, int base)
380 {
381         int neg = 0;
382         unsigned long val = 0;
383
384         // gobble initial whitespace
385         while (*s == ' ' || *s == '\t')
386                 s++;
387
388         // plus/minus sign
389         if (*s == '+')
390                 s++;
391         else if (*s == '-')
392                 s++, neg = 1;
393
394         // hex or octal base prefix
395         if ((base == 0 || base == 16) && (s[0] == '0' && s[1] == 'x'))
396                 s += 2, base = 16;
397         else if (base == 0 && s[0] == '0')
398                 s++, base = 8;
399         else if (base == 0)
400                 base = 10;
401
402         // digits
403         while (1) {
404                 int dig;
405
406                 if (*s >= '0' && *s <= '9')
407                         dig = *s - '0';
408                 else if (*s >= 'a' && *s <= 'z')
409                         dig = *s - 'a' + 10;
410                 else if (*s >= 'A' && *s <= 'Z')
411                         dig = *s - 'A' + 10;
412                 else
413                         break;
414                 if (dig >= base)
415                         break;
416                 s++, val = (val * base) + dig;
417                 // we don't properly detect overflow!
418         }
419
420         if (endptr)
421                 *endptr = (char *) s;
422         return (neg ? -val : val);
423 }
424
425 int atoi(const char *s)
426 {
427         if (!s)
428                 return 0;
429         if (s[0] == '0' && s[1] == 'x')
430                 warn("atoi() used on a hex string!");
431         // no overflow detection
432         return (int)strtol(s,NULL,10);
433 }
434
435 int sigchecksum(void *address, int length)
436 {
437         uint8_t *p, sum;
438
439         sum = 0;
440         for (p = address; length-- > 0; p++)
441                 sum += *p;
442
443         return sum;
444 }
445
446 void *sigscan(uint8_t *address, int length, char *signature)
447 {
448         uint8_t *e, *p;
449         int siglength;
450
451         e = address + length;
452         siglength = strlen(signature);
453         for (p = address; p + siglength < e; p += 16) {
454                 if (memcmp(p, signature, siglength))
455                         continue;
456                 return p;
457         }
458
459         return NULL;
460 }