Sanitize vcoreid from untrusted sources
[akaros.git] / kern / src / printfmt.c
1 // Stripped-down primitive printf-style formatting routines,
2 // used in common by printf, sprintf, fprintf, etc.
3 // This code is also used by both the kernel and user programs.
4 #include <ros/common.h>
5 #include <error.h>
6 #include <stdio.h>
7 #include <string.h>
8 #include <stdarg.h>
9 #include <kthread.h>
10 #include <syscall.h>
11 #include <ns.h>
12
/* Print 'num' in 'base' (2..16, lowercase digits), most significant digit
 * first, through putch(ch, putdat).
 *
 * The number is right-justified in a field of at least 'width' characters by
 * emitting 'padc' before the digits; a width at or below the digit count
 * (e.g. -1) means no padding.  Bases outside 2..16 print nothing. */
void printnum(void (*putch)(int, void**), void **putdat,
              unsigned long long num, unsigned base, int width, int padc)
{
	/* Worst case is base 2: one digit per bit of 'num'. */
	char digits[sizeof(num) * 8];
	int nr_digits = 0;

	/* Base 0 would divide by zero, base 1 would never terminate, and
	 * bases above 16 would index past the digit table. */
	if (base < 2 || base > 16)
		return;

	/* Collect digits least-significant first in one pass (instead of
	 * re-dividing for every digit); do/while so that 0 prints "0". */
	do {
		digits[nr_digits++] = "0123456789abcdef"[num % base];
		num /= base;
	} while (num);

	/* Compare rather than subtract, so a very negative width can't
	 * overflow. */
	for (; width > nr_digits; width--)
		putch(padc, putdat);

	/* Emit the collected digits in reverse: most significant first. */
	while (nr_digits > 0)
		putch(digits[--nr_digits], putdat);
}
39
void printfmt(void (*putch)(int, void**), void **putdat, const char *fmt, ...);

/* Emit one integer conversion for vprintfmt(): an optional '-' sign, padding
 * out to 'width' columns (the sign counts toward the width), then the digits
 * of 'num' in 'base' via printnum().
 *
 * 'padc' selects the layout:
 *   '-'   left-justify: sign and digits first, then spaces on the right
 *   '0'   sign first, then zeros, then digits     (e.g. "-0042")
 *   other padc characters first, then sign/digits (e.g. "  -42") */
static void print_int_field(void (*putch)(int, void**), void **putdat,
                            unsigned long long num, unsigned base, int width,
                            int padc, int negative)
{
	unsigned long long temp = num;
	/* Field length: one column for the sign (if any) plus the digits. */
	int len = negative ? 2 : 1;

	while ((temp /= base))
		len++;

	switch (padc) {
	case '-':
		if (negative)
			putch('-', putdat);
		printnum(putch, putdat, num, base, -1, ' ');
		for (; width > len; width--)
			putch(' ', putdat);
		break;
	case '0':
		if (negative)
			putch('-', putdat);
		/* The sign already used one column of the field. */
		printnum(putch, putdat, num, base,
		         (negative && width > 0) ? width - 1 : width, '0');
		break;
	default:
		for (; width > len; width--)
			putch(padc, putdat);
		if (negative)
			putch('-', putdat);
		printnum(putch, putdat, num, base, -1, padc);
		break;
	}
}

/* Core formatter behind printfmt(), vsnprintf(), and friends.
 *
 * Walks 'fmt', handing literal characters to putch(ch, putdat) and expanding
 * %-escapes from 'ap'.  Understood:
 *   flags      '-' (left-justify), '0' (zero-pad), '#' (with %s, print
 *              non-printable characters as '?')
 *   width/precision  decimal digits, '*', '.'
 *   length     'l' (long), 'll' (long long)
 *   standard   %c %s %d %u %o %x %p %%
 *   kernel     %C -> printchan(), %Q -> printqid(), %E -> printemac()
 *              (NULL prints "00:00:00:00:00:00"), %i -> four 32-bit words
 *              from a uint32_t array, each run through hnputl() and printed
 *              as 8 hex digits, %I -> printip(), %M -> printipmask(),
 *              %V -> printipv4().  A NULL pointer for %i/%I/%M/%V prints
 *              nothing.
 * Unrecognized escapes are echoed literally. */
void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt,
               va_list ap)
{
	const char *p;
	const char *last_fmt;
	int ch;
	unsigned long long num;
	int base, lflag, width, precision, altflag, negative;
	char padc;
	uint8_t *mac, *ip, *mask;
	uint32_t *lp;

	while (1) {
		while ((ch = *(unsigned char *) fmt) != '%') {
			if (ch == '\0')
				return;
			fmt++;
			putch(ch, putdat);
		}
		fmt++;

		// Process a %-escape sequence
		last_fmt = fmt;
		padc = ' ';
		width = -1;
		precision = -1;
		lflag = 0;
		altflag = 0;
		negative = 0;
	reswitch:
		switch (ch = *(unsigned char *) fmt++) {

		// flag to pad on the right
		case '-':
			padc = '-';
			goto reswitch;

		// flag to pad with 0's instead of spaces
		case '0':
			padc = '0';
			goto reswitch;

		// width field
		case '1':
		case '2':
		case '3':
		case '4':
		case '5':
		case '6':
		case '7':
		case '8':
		case '9':
			/* Digits accumulate into 'precision'; process_precision
			 * moves the value to 'width' unless a '.' was seen. */
			for (precision = 0; ; ++fmt) {
				precision = precision * 10 + ch - '0';
				ch = *fmt;
				if (ch < '0' || ch > '9')
					break;
			}
			goto process_precision;

		case '*':
			precision = va_arg(ap, int);
			goto process_precision;

		case '.':
			if (width < 0)
				width = 0;
			goto reswitch;

		case '#':
			altflag = 1;
			goto reswitch;

		process_precision:
			if (width < 0)
				width = precision, precision = -1;
			goto reswitch;

		// long flag (doubled for long long)
		case 'l':
			lflag++;
			goto reswitch;

		// chan
		case 'C':
			printchan(putch, putdat, va_arg(ap, void*));
			break;

		// character
		case 'c':
			putch(va_arg(ap, int), putdat);
			break;

		case 'E': // ENET MAC
			if ((mac = va_arg(ap, uint8_t *)) == NULL) {
				const char *s = "00:00:00:00:00:00";

				while (*s)
					putch(*s++, putdat);
			} else {
				/* Only dereference a non-NULL MAC. */
				printemac(putch, putdat, mac);
			}
			break;
		case 'i':
			/* what to do if they screw up? */
			if ((lp = va_arg(ap, uint32_t *)) != NULL) {
				for (int i = 0; i < 4; i++) {
					uint32_t hostfmt;

					hnputl(&hostfmt, lp[i]);
					/* hostfmt is 32 bits: %x (unsigned
					 * int) matches it, %lx would make
					 * va_arg read an unsigned long. */
					printfmt(putch, putdat, "%08x",
					         hostfmt);
				}
			}
			break;
		case 'I':
			/* what to do if they screw up? */
			if ((ip = va_arg(ap, uint8_t *)) != NULL)
				printip(putch, putdat, ip);
			break;
		case 'M':
			/* what to do if they screw up? */
			if ((mask = va_arg(ap, uint8_t *)) != NULL)
				printipmask(putch, putdat, mask);
			break;
		case 'V':
			/* what to do if they screw up? */
			if ((ip = va_arg(ap, uint8_t *)) != NULL)
				printipv4(putch, putdat, ip);
			break;

		// string
		case 's':
			if ((p = va_arg(ap, char *)) == NULL)
				p = "(null)";
			if (width > 0 && padc != '-')
				for (width -= strnlen(p, precision);
				     width > 0;
				     width--)
					putch(padc, putdat);
			for (;
			     (ch = *p) != '\0' && (precision < 0
			                           || --precision >= 0);
			     width--) {
				if (altflag && (ch < ' ' || ch > '~'))
					putch('?', putdat);
				else
					putch(ch, putdat);
				// zra: make sure *p isn't '\0' before inc'ing
				p++;
			}
			for (; width > 0; width--)
				putch(' ', putdat);
			break;

		case 'd': /* (signed) decimal */
			if (lflag >= 2)
				num = va_arg(ap, long long);
			else if (lflag)
				num = va_arg(ap, long);
			else
				num = va_arg(ap, int);
			if ((long long) num < 0) {
				negative = 1;
				/* Negate in unsigned arithmetic:
				 * -(long long)num overflows for
				 * LLONG_MIN. */
				num = 0ULL - num;
			}
			base = 10;
			goto number;

		case 'u': /* unsigned decimal */
		case 'o': /* (unsigned) octal */
		case 'x': /* (unsigned) hexadecimal */
			if (lflag >= 2)
				num = va_arg(ap, unsigned long long);
			else if (lflag)
				num = va_arg(ap, unsigned long);
			else
				num = va_arg(ap, unsigned int);
			if (ch == 'u')
				base = 10;
			else if (ch == 'o')
				base = 8;
			else    /* x */
				base = 16;
			goto number;

		// pointer
		case 'p':
			putch('0', putdat);
			putch('x', putdat);
			/* automatically zero-pad pointers, out to the length of
			 * a ptr */
			padc = '0';
			/* 8 bits per byte / 4 bits per char */
			width = sizeof(void*) * 2;
			num = (unsigned long long)
				(uintptr_t) va_arg(ap, void *);
			base = 16;
			goto number;

		// qid
		case 'Q':
			printqid(putch, putdat, va_arg(ap, void*));
			break;
		number:
			print_int_field(putch, putdat, num, base, width, padc,
			                negative);
			break;

		// escaped '%' character
		case '%':
			putch(ch, putdat);
			break;

		// unrecognized escape sequence - just print it literally
		default:
			putch('%', putdat);
			/* Rewind to just past the '%' so the rest of the
			 * unrecognized sequence is printed as plain text. */
			fmt = last_fmt;
			break;
		}
	}
}
263
/* printf-style front end for vprintfmt(): gathers the arguments that follow
 * 'fmt' into a va_list and forwards them, with putch/putdat, unchanged. */
void printfmt(void (*putch)(int, void**), void **putdat, const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vprintfmt(putch, putdat, fmt, args);
	va_end(args);
}
272
/* Output cursor used when formatting into a character array. */
typedef struct sprintbuf {
	char *buf;	/* where the next stored character goes */
	char *ebuf;	/* exclusive limit: nothing is stored at or past this */
	int cnt;	/* characters received, whether or not they were stored */
} sprintbuf_t;

/* putch callback for buffer output: store 'ch' if there is still room before
 * ebuf, and count it either way. */
static void sprintputch(int ch, sprintbuf_t **b)
{
	sprintbuf_t *sb = *b;

	/* Count every character, including ones dropped because the buffer
	 * is already full. */
	sb->cnt++;
	if (sb->buf >= sb->ebuf)
		return;
	*sb->buf = ch;
	sb->buf++;
}
285
286 int vsnprintf(char *buf, size_t n, const char *fmt, va_list ap)
287 {
288         sprintbuf_t b;// = {buf, buf+n-1, 0};
289         sprintbuf_t *bp = &b;
290
291         /* this isn't quite the snprintf 'spec', but errors aren't helpful */
292         assert(buf);
293         /* We might get large, 'negative' values for code that repeatedly calls
294          * snprintf(), e.g.:
295          *              len += snprintf(buf + len, bufsz - len, "foo");
296          *              len += snprintf(buf + len, bufsz - len, "bar");
297          * If len > bufsz, that will appear as a large value.  This is not quite
298          * the glibc semantics (we aren't returning the size we would have
299          * printed), but it short circuits the rest of the function and avoids
300          * potential errors in the putch() functions. */
301         if (!n || (n > INT32_MAX))
302                 return 0;
303
304         b.buf = NULL; // zra : help out the Deputy optimizer a bit
305         b.ebuf = buf+n-1;
306         b.cnt = 0;
307         b.buf = buf;
308
309         vprintfmt((void*)sprintputch, (void*)&bp, fmt, ap);
310
311         // null terminate the buffer
312         *b.buf = '\0';
313
314         return b.cnt;
315 }
316
/* Variadic wrapper around vsnprintf(): same buffer, size, and return
 * semantics; the arguments after 'fmt' are collected into a va_list. */
int snprintf(char *buf, size_t n, const char *fmt, ...)
{
	va_list args;
	int len;

	va_start(args, fmt);
	len = vsnprintf(buf, n, fmt, args);
	va_end(args);
	return len;
}
328
/* Convenience function: do a print, return the pointer to the null at the end.
 *
 * Unlike snprintf(), when we overflow, this doesn't return the 'end' where we
 * would have written to.  Instead, we'll return 'end - 1', which is the last
 * byte, and enforce the null-termination.
 *
 * If the range has no room at all (end <= buf), nothing is written and 'buf'
 * is returned unchanged. */
char *seprintf(char *buf, char *end, const char *fmt, ...)
{
	va_list ap;
	int rc;
	size_t n;

	/* An empty or inverted range has no byte for even the NUL.  Without
	 * this check, n == 0 makes vsnprintf() return 0 and the overflow path
	 * below would write to end - 1, which is before buf; an inverted range
	 * wraps n to a huge value and nothing gets terminated. */
	if (end <= buf)
		return buf;
	n = end - buf;

	va_start(ap, fmt);
	rc = vsnprintf(buf, n, fmt, ap);
	va_end(ap);

	/* Some error - leave them where they were. */
	if (rc < 0)
		return buf;
	/* Overflow - put them at the end.  rc >= 0 here, so the unsigned
	 * comparison is exact. */
	if ((size_t) rc >= n) {
		*(end - 1) = '\0';
		return end - 1;
	}
	assert(buf[rc] == '\0');
	return buf + rc;
}