printk: check for user pointers in format string parameters
[akaros.git] / kern / src / string.c
1 // Basic string routines.  Not hardware optimized, but not shabby.
2
3 #include <stdio.h>
4 #include <string.h>
5 #include <ros/memlayout.h>
6 #include <assert.h>
7
8 int strlen(const char *s)
9 {
10         int n;
11
12         for (n = 0; *s != '\0'; s++)
13                 n++;
14         return n;
15 }
16
17 int strnlen(const char *s, size_t size)
18 {
19         int n;
20
21         for (n = 0; size > 0 && *s != '\0'; s++, size--)
22                 n++;
23         return n;
24 }
25
26 /* zra: These aren't being used, and they are dangerous, so I'm rm'ing them
27 char *
28 strcpy(char *dst, const char *src)
29 {
30         char *ret;
31
32         ret = dst;
33         while ((*dst++ = *src++) != '\0')
34                 ;
35         return ret;
36 }
37
38 char *
39 strcat(char *dst, const char *src)
40 {
41         strcpy(dst+strlen(dst),src);
42         return dst;
43 }
44 */
45
46 char *strncpy(char *dst, const char *src, size_t size) {
47         size_t i;
48         char *ret;
49
50         ret = dst;
51         for (i = 0; i < size; i++) {
52                 *dst++ = *src;
53                 // If strlen(src) < size, null-pad 'dst' out to 'size' chars
54                 if (*src != '\0')
55                         src++;
56         }
57         return ret;
58 }
59
60 size_t strlcpy(char *dst, const char *src, size_t size)
61 {
62         if (size > 0) {
63                 while (--size > 0 && *src != '\0')
64                         *dst++ = *src++;
65                 *dst = '\0';
66         }
67
68         return strlen(src);
69 }
70
71 size_t strlcat(char *dst, const char *src, size_t size)
72 {
73         size_t rem;     /* Buffer space remaining after null in dst. */
74
75         /* We must find the terminating NUL byte in dst, but abort the
76          * search if we go past 'size' bytes.  At the end of this loop,
77          * 'dst' will point to either the NUL byte in the original
78          * destination or to one byte beyond the end of the buffer.
79          *
80          * 'rem' will be the amount of 'size' remaining beyond the NUL byte;
81          * potentially zero. This implies that 'size - rem' is equal to the
82          * distance from the beginning of the destination buffer to 'dst'.
83          *
84          * The return value of strlcat is the sum of the length of the
85          * original destination buffer (size - rem) plus the size of the
86          * src string (the return value of strlcpy). */
87         rem = size;
88         while ((rem > 0) && (*dst != '\0')) {
89                 rem--;
90                 dst++;
91         }
92
93         return (size - rem) + strlcpy(dst, src, rem);
94 }
95
96 int strcmp(const char *p, const char *q)
97 {
98         while (*p && *p == *q)
99                 p++, q++;
100         return (int) ((unsigned char) *p - (unsigned char) *q);
101 }
102
103 int strncmp(const char *p, const char *q, size_t n)
104 {
105         while (n > 0 && *p && *p == *q)
106                 n--, p++, q++;
107         if (n == 0)
108                 return 0;
109         else
110                 return (int) ((unsigned char) *p - (unsigned char) *q);
111 }
112
113 // Return a pointer to the first occurrence of 'c' in 's',
114 // or a null pointer if the string has no 'c'.
115 char *strchr(const char *s, char c)
116 {
117         for (; *s; s++)
118                 if (*s == c)
119                         return (char *) s;
120         return 0;
121 }
122
123 // Return a pointer to the last occurrence of 'c' in 's',
124 // or a null pointer if the string has no 'c'.
125 char *strrchr(const char *s, char c)
126 {
127         char *lastc = NULL;
128         for (; *s; s++)
129                 if (*s == c){
130                         lastc = (char*)s;
131                 }
132         return lastc;
133 }
134
135 void *memchr(const void *mem, int chr, int len)
136 {
137         char *s = (char*) mem;
138
139         for (int i = 0; i < len; i++)
140                 if (s[i] == (char) chr)
141                         return s + i;
142         return NULL;
143 }
144
145 // Return a pointer to the first occurrence of 'c' in 's',
146 // or a pointer to the string-ending null character if the string has no 'c'.
147 char *strfind(const char *s, char c)
148 {
149         for (; *s; s++)
150                 if (*s == c)
151                         break;
152         return (char *) s;
153 }
154
155 // memset aligned words.
156 static inline void *memsetw(long* _v, long c, size_t n)
157 {
158         long *start, *end, *v;
159
160         start = _v;
161         end = _v + n/sizeof(long);
162         v = _v;
163         c = c & 0xff;
164         c = c | c<<8;
165         c = c | c<<16;
166         #if NUM_ADDR_BITS == 64
167         c = c | c<<32;
168         #elif NUM_ADDR_BITS != 32
169         # error
170         #endif
171
172         while(v < end - (8-1))
173         {
174                 v[3] = v[2] = v[1] = v[0] = c;
175                 v += 4;
176                 v[3] = v[2] = v[1] = v[0] = c;
177                 v += 4;
178         }
179
180         while(v < end)
181           *v++ = c;
182
183         return start;
184 }
185
186 // copy aligned words.
187 // unroll 9 ways to get multiple misses in flight
188 #define memcpyw(type, _dst, _src, n) \
189   do { \
190         type* restrict src = (type*)(_src); \
191         type* restrict dst = (type*)(_dst); \
192         type* srcend = src + (n)/sizeof(type); \
193         type* dstend = dst + (n)/sizeof(type); \
194         while (dst < dstend - (9-1)) { \
195                 dst[0] = src[0]; \
196                 dst[1] = src[1]; \
197                 dst[2] = src[2]; \
198                 dst[3] = src[3]; \
199                 dst[4] = src[4]; \
200                 dst[5] = src[5]; \
201                 dst[6] = src[6]; \
202                 dst[7] = src[7]; \
203                 dst[8] = src[8]; \
204                 src += 9; \
205                 dst += 9; \
206         } \
207         while(dst < dstend) \
208           *dst++ = *src++; \
209   } while(0)
210
211 void *memset(void *v, int c, size_t _n)
212 {
213         char *p;
214         size_t n0;
215         size_t n = _n;
216
217         if (n == 0) return NULL; // zra: complain here?
218
219         p = v;
220
221         while (n > 0 && ((uintptr_t)p & (sizeof(long)-1))) {
222                 *p++ = c;
223                 n--;
224         }
225
226         if (n >= sizeof(long)) {
227                 n0 = n / sizeof(long) * sizeof(long);
228                 memsetw((long*)p, c, n0);
229                 n -= n0;
230                 p += n0;
231         }
232
233         while (n > 0) {
234                 *p++ = c;
235                 n--;
236         }
237
238         return v;
239 }
240
241 void *memcpy(void* dst, const void* src, size_t _n)
242 {
243         const char* s;
244         char* d;
245         size_t n0 = 0;
246         size_t n = _n;
247         int align = sizeof(long)-1;
248
249         s = src;
250         d = dst;
251
252         if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(long)-1)) == 0) {
253                 n0 = n / sizeof(long) * sizeof(long);
254                 memcpyw(long, d, s, n0);
255         } else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(int)-1)) == 0) {
256                 n0 = n / sizeof(int) * sizeof(int);
257                 memcpyw(int, d, s, n0);
258         } else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(short)-1)) == 0) {
259                 n0 = n / sizeof(short) * sizeof(short);
260                 memcpyw(short, d, s, n0);
261         }
262
263         n -= n0;
264         s += n0;
265         d += n0;
266
267         while (n-- > 0)
268                 *d++ = *s++;
269
270         return dst;
271 }
272
273 void *memmove(void *dst, const void *src, size_t _n)
274 {
275 #ifdef CONFIG_X86
276         bcopy(src, dst, _n);
277         return dst;
278 #else
279         const char *s;
280         char *d;
281         size_t n = _n;
282
283         s = src;
284         d = dst;
285         if (s < d && s + n > d) {
286                 s += n;
287                 d += n;
288                 while (n-- > 0)
289                         *--d = *--s;
290         } else
291                 while (n-- > 0)
292                         *d++ = *s++;
293
294         return dst;
295 #endif
296 }
297
298 int memcmp(const void *v1, const void *v2, size_t n)
299 {
300         const uint8_t *s1 = (const uint8_t *) v1;
301         const uint8_t *s2 = (const uint8_t *) v2;
302
303         while (n-- > 0) {
304                 if (*s1 != *s2)
305                         return (int) *s1 - (int) *s2;
306                 s1++, s2++;
307         }
308
309         return 0;
310 }
311
312 void *memfind(const void *_s, int c, size_t n)
313 {
314         const void *ends = (const char *) _s + n;
315         const void *s = _s;
316         for (; s < ends; s++)
317                 if (*(const unsigned char *) s == (unsigned char) c)
318                         break;
319         return (void *)s;
320 }
321
322 long strtol(const char *s, char **endptr, int base)
323 {
324         int neg = 0;
325         long val = 0;
326
327         // gobble initial whitespace
328         while (*s == ' ' || *s == '\t')
329                 s++;
330
331         // plus/minus sign
332         if (*s == '+')
333                 s++;
334         else if (*s == '-')
335                 s++, neg = 1;
336
337         // hex or octal base prefix
338         if ((base == 0 || base == 16) && (s[0] == '0' && s[1] == 'x'))
339                 s += 2, base = 16;
340         else if (base == 0 && s[0] == '0')
341                 s++, base = 8;
342         else if (base == 0)
343                 base = 10;
344
345         // digits
346         while (1) {
347                 int dig;
348
349                 if (*s >= '0' && *s <= '9')
350                         dig = *s - '0';
351                 else if (*s >= 'a' && *s <= 'z')
352                         dig = *s - 'a' + 10;
353                 else if (*s >= 'A' && *s <= 'Z')
354                         dig = *s - 'A' + 10;
355                 else
356                         break;
357                 if (dig >= base)
358                         break;
359                 s++, val = (val * base) + dig;
360                 // we don't properly detect overflow!
361         }
362
363         if (endptr)
364                 *endptr = (char *) s;
365         return (neg ? -val : val);
366 }
367
368 unsigned long strtoul(const char *s, char **endptr, int base)
369 {
370         int neg = 0;
371         unsigned long val = 0;
372
373         // gobble initial whitespace
374         while (*s == ' ' || *s == '\t')
375                 s++;
376
377         // plus/minus sign
378         if (*s == '+')
379                 s++;
380         else if (*s == '-')
381                 s++, neg = 1;
382
383         // hex or octal base prefix
384         if ((base == 0 || base == 16) && (s[0] == '0' && s[1] == 'x'))
385                 s += 2, base = 16;
386         else if (base == 0 && s[0] == '0')
387                 s++, base = 8;
388         else if (base == 0)
389                 base = 10;
390
391         // digits
392         while (1) {
393                 int dig;
394
395                 if (*s >= '0' && *s <= '9')
396                         dig = *s - '0';
397                 else if (*s >= 'a' && *s <= 'z')
398                         dig = *s - 'a' + 10;
399                 else if (*s >= 'A' && *s <= 'Z')
400                         dig = *s - 'A' + 10;
401                 else
402                         break;
403                 if (dig >= base)
404                         break;
405                 s++, val = (val * base) + dig;
406                 // we don't properly detect overflow!
407         }
408
409         if (endptr)
410                 *endptr = (char *) s;
411         return (neg ? -val : val);
412 }
413
414 int atoi(const char *s)
415 {
416         if (!s)
417                 return 0;
418         if (s[0] == '0' && s[1] == 'x')
419                 warn("atoi() used on a hex string!");
420         // no overflow detection
421         return (int)strtol(s,NULL,10);
422 }
423
424 int sigchecksum(void *address, int length)
425 {
426         uint8_t *p, sum;
427
428         sum = 0;
429         for (p = address; length-- > 0; p++)
430                 sum += *p;
431
432         return sum;
433 }
434
435 void *sigscan(uint8_t *address, int length, char *signature)
436 {
437         uint8_t *e, *p;
438         int siglength;
439
440         e = address + length;
441         siglength = strlen(signature);
442         for (p = address; p + siglength < e; p += 16) {
443                 if (memcmp(p, signature, siglength))
444                         continue;
445                 return p;
446         }
447
448         return NULL;
449 }