big speedups for memcpy/memset
[akaros.git] / kern / src / string.c
1 // Basic string routines.  Not hardware optimized, but not shabby.
2
3 #ifdef __SHARC__
4 #pragma nosharc
5 #endif
6
7 #include <stdio.h>
8 #include <string.h>
9 #include <ros/memlayout.h>
10 #include <assert.h>
11
12 int
13 strlen(const char *s)
14 {
15         int n;
16
17         for (n = 0; *s != '\0'; s++)
18                 n++;
19         return n;
20 }
21
22 int
23 strnlen(const char *s, size_t size)
24 {
25         int n;
26
27         for (n = 0; size > 0 && *s != '\0'; s++, size--)
28                 n++;
29         return n;
30 }
31
32 /* zra: These aren't being used, and they are dangerous, so I'm rm'ing them
33 char *
34 strcpy(char *dst, const char *src)
35 {
36         char *ret;
37
38         ret = dst;
39         while ((*dst++ = *src++) != '\0')
40                 ;
41         return ret;
42 }
43
44 char *
45 strcat(char *dst, const char *src)
46 {
47         strcpy(dst+strlen(dst),src);
48         return dst;
49 }
50 */
51
52 char *
53 strncpy(char *dst, const char *src, size_t size) {
54         size_t i;
55         char *ret;
56
57         ret = dst;
58         for (i = 0; i < size; i++) {
59                 // TODO: ivy bitches about this
60                 *dst++ = *src;
61                 // If strlen(src) < size, null-pad 'dst' out to 'size' chars
62                 if (*src != '\0')
63                         src++;
64         }
65         return ret;
66 }
67
68 size_t
69 strlcpy(char *dst, const char *src, size_t size)
70 {
71         char *dst_in;
72
73         dst_in = dst;
74         if (size > 0) {
75                 while (--size > 0 && *src != '\0')
76                         *dst++ = *src++;
77                 *dst = '\0';
78         }
79         return dst - dst_in;
80 }
81
82 int
83 strcmp(const char *p, const char *q)
84 {
85         while (*p && *p == *q)
86                 p++, q++;
87         return (int) ((unsigned char) *p - (unsigned char) *q);
88 }
89
90 int
91 strncmp(const char *p, const char *q, size_t n)
92 {
93         while (n > 0 && *p && *p == *q)
94                 n--, p++, q++;
95         if (n == 0)
96                 return 0;
97         else
98                 return (int) ((unsigned char) *p - (unsigned char) *q);
99 }
100
101 // Return a pointer to the first occurrence of 'c' in 's',
102 // or a null pointer if the string has no 'c'.
103 char *
104 strchr(const char *s, char c)
105 {
106         for (; *s; s++)
107                 if (*s == c)
108                         return (char *) s;
109         return 0;
110 }
111
112 void *
113 memchr(void* mem, int chr, int len)
114 {
115         char* s = (char*)mem;
116         for(int i = 0; i < len; i++)
117                 if(s[i] == (char)chr)
118                         return s+i;
119         return NULL;
120 }
121
122 // Return a pointer to the first occurrence of 'c' in 's',
123 // or a pointer to the string-ending null character if the string has no 'c'.
124 char *
125 strfind(const char *s, char c)
126 {
127         for (; *s; s++)
128                 if (*s == c)
129                         break;
130         return (char *) s;
131 }
132
133 // memset aligned words.
134 static inline void *
135 memsetw(long* _v, long c, size_t n)
136 {
137         long *start, *end, *v;
138
139         start = _v;
140         end = _v + n/sizeof(long);
141         v = _v;
142         c = (char)c;
143         c = c | c<<8;
144         c = c | c<<16;
145         #if NUM_ADDR_BITS == 64
146         c = c | c<<32;
147         #elif NUM_ADDR_BITS != 32
148         # error
149         #endif
150
151         while(v < end - (8-1))
152         {
153                 v[3] = v[2] = v[1] = v[0] = c;
154                 v += 4;
155                 v[3] = v[2] = v[1] = v[0] = c;
156                 v += 4;
157         }
158
159         while(v < end)
160           *v++ = c;
161
162         return start;
163 }
164
165 // copy aligned words.
166 // unroll 9 ways to get multiple misses in flight
167 #define memcpyw(type, _dst, _src, n) \
168   do { \
169         type* restrict src = (type*)(_src); \
170         type* restrict dst = (type*)(_dst); \
171         type* srcend = src + (n)/sizeof(type); \
172         type* dstend = dst + (n)/sizeof(type); \
173         while (dst < dstend - (9-1)) { \
174                 dst[0] = src[0]; \
175                 dst[1] = src[1]; \
176                 dst[2] = src[2]; \
177                 dst[3] = src[3]; \
178                 dst[4] = src[4]; \
179                 dst[5] = src[5]; \
180                 dst[6] = src[6]; \
181                 dst[7] = src[7]; \
182                 dst[8] = src[8]; \
183                 src += 9; \
184                 dst += 9; \
185         } \
186         while(dst < dstend) \
187           *dst++ = *src++; \
188   } while(0)
189
190 void *
191 memset(void *COUNT(_n) v, int c, size_t _n)
192 {
193         char *BND(v,v+_n) p;
194         size_t n0;
195         size_t n = _n;
196
197         if (n == 0) return NULL; // zra: complain here?
198
199         p = v;
200
201     while (n > 0 && ((uintptr_t)p & (sizeof(long)-1)))
202         {
203                 *p++ = c;
204                 n--;
205         }
206
207         if (n >= sizeof(long))
208         {
209                 n0 = n / sizeof(long) * sizeof(long);
210                 memsetw((long*)p, c, n0);
211                 n -= n0;
212                 p += n0;
213         }
214
215         while (n > 0)
216         {
217                 *p++ = c;
218                 n--;
219         }
220
221         return v;
222 }
223
224 void *
225 memcpy(void* dst, const void* src, size_t _n)
226 {
227         const char* s;
228         char* d;
229         size_t n0 = 0;
230         size_t n = _n;
231         int align = sizeof(long)-1;
232
233         s = src;
234         d = dst;
235
236         if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(long)-1)) == 0)
237         {
238                 n0 = n / sizeof(long) * sizeof(long);
239                 memcpyw(long, d, s, n0);
240         }
241         else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(int)-1)) == 0)
242         {
243                 n0 = n / sizeof(int) * sizeof(int);
244                 memcpyw(int, d, s, n0);
245         }
246         else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(short)-1)) == 0)
247         {
248                 n0 = n / sizeof(short) * sizeof(short);
249                 memcpyw(short, d, s, n0);
250         }
251
252         n -= n0;
253         s += n0;
254         d += n0;
255
256         while (n-- > 0)
257                 *d++ = *s++;
258
259         return dst;
260 }
261
262 void *
263 memmove(void *COUNT(_n) dst, const void *COUNT(_n) src, size_t _n)
264 {
265         const char *BND(src,src+_n) s;
266         char *BND(dst,dst+_n) d;
267         size_t n = _n;
268         
269         s = src;
270         d = dst;
271         if (s < d && s + n > d) {
272                 s += n;
273                 d += n;
274                 while (n-- > 0)
275                         *--d = *--s;
276         } else
277                 while (n-- > 0)
278                         *d++ = *s++;
279
280         return dst;
281 }
282
283 int
284 memcmp(const void *COUNT(n) v1, const void *COUNT(n) v2, size_t n)
285 {
286         const uint8_t *BND(v1,v1+n) s1 = (const uint8_t *) v1;
287         const uint8_t *BND(v2,v2+n) s2 = (const uint8_t *) v2;
288
289         while (n-- > 0) {
290                 if (*s1 != *s2)
291                         return (int) *s1 - (int) *s2;
292                 s1++, s2++;
293         }
294
295         return 0;
296 }
297
298 void *
299 memfind(const void *COUNT(n) _s, int c, size_t n)
300 {
301         const void *SNT ends = (const char *) _s + n;
302         const void *BND(_s,_s + n) s = _s;
303         for (; s < ends; s++)
304                 if (*(const unsigned char *) s == (unsigned char) c)
305                         break;
306         return (void *BND(_s,_s+n)) s;
307 }
308
309 long
310 strtol(const char *s, char **endptr, int base)
311 {
312         int neg = 0;
313         long val = 0;
314
315         // gobble initial whitespace
316         while (*s == ' ' || *s == '\t')
317                 s++;
318
319         // plus/minus sign
320         if (*s == '+')
321                 s++;
322         else if (*s == '-')
323                 s++, neg = 1;
324
325         // hex or octal base prefix
326         if ((base == 0 || base == 16) && (s[0] == '0' && s[1] == 'x'))
327                 s += 2, base = 16;
328         else if (base == 0 && s[0] == '0')
329                 s++, base = 8;
330         else if (base == 0)
331                 base = 10;
332
333         // digits
334         while (1) {
335                 int dig;
336
337                 if (*s >= '0' && *s <= '9')
338                         dig = *s - '0';
339                 else if (*s >= 'a' && *s <= 'z')
340                         dig = *s - 'a' + 10;
341                 else if (*s >= 'A' && *s <= 'Z')
342                         dig = *s - 'A' + 10;
343                 else
344                         break;
345                 if (dig >= base)
346                         break;
347                 s++, val = (val * base) + dig;
348                 // we don't properly detect overflow!
349         }
350
351         if (endptr)
352                 *endptr = (char *) s;
353         return (neg ? -val : val);
354 }
355
356 int
357 atoi(const char* s)
358 {
359         // no overflow detection
360         return (int)strtol(s,NULL,10);
361 }