printk: check for user pointers in format string parameters
authorBarret Rhoden <brho@cs.berkeley.edu>
Thu, 2 May 2019 01:40:27 +0000 (21:40 -0400)
committerBarret Rhoden <brho@cs.berkeley.edu>
Thu, 2 May 2019 01:41:48 +0000 (21:41 -0400)
These may be a potential source of vulernability.

Signed-off-by: Barret Rhoden <brho@cs.berkeley.edu>
kern/src/printfmt.c

index 75937ec..35ad17a 100644 (file)
@@ -52,6 +52,7 @@ void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt,
        struct Gas *g;
        int i;
        uint32_t *lp;
+       void *p;
 
        while (1) {
                while ((ch = *(unsigned char *) fmt) != '%') {
@@ -125,7 +126,9 @@ void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt,
 
                // chan
                case 'C':
-                       printchan(putch, putdat, va_arg(ap, void*));
+                       p = va_arg(ap, void*);
+                       warn_on_user_ptr(p);
+                       printchan(putch, putdat, p);
                        break;
 
                // character
@@ -134,18 +137,24 @@ void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt,
                        break;
 
                case 'E': // ENET MAC
-                       if ((mac = va_arg(ap, uint8_t *)) == NULL){
+                       mac = va_arg(ap, uint8_t *);
+                       warn_on_user_ptr(mac);
+                       if (!mac) {
                                char *s = "00:00:00:00:00:00";
-                               while(*s)
+
+                               while (*s)
                                        putch(*s++, putdat);
                        }
                        printemac(putch, putdat, mac);
                        break;
                case 'i':
                        /* what to do if they screw up? */
-                       if ((lp = va_arg(ap, uint32_t *)) != NULL){
+                       lp = va_arg(ap, uint32_t *);
+                       warn_on_user_ptr(lp);
+                       if (lp) {
                                uint32_t hostfmt;
-                               for(i = 0; i < 4; i++){
+
+                               for (i = 0; i < 4; i++) {
                                        hnputl(&hostfmt, lp[i]);
                                        printfmt(putch, putdat, "%08lx",
                                                 hostfmt);
@@ -154,23 +163,31 @@ void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt,
                        break;
                case 'I':
                        /* what to do if they screw up? */
-                       if ((ip = va_arg(ap, uint8_t *)) != NULL)
+                       ip = va_arg(ap, uint8_t *);
+                       warn_on_user_ptr(ip);
+                       if (ip)
                                printip(putch, putdat, ip);
                        break;
                case 'M':
                        /* what to do if they screw up? */
-                       if ((mask = va_arg(ap, uint8_t *)) != NULL)
+                       mask = va_arg(ap, uint8_t *);
+                       warn_on_user_ptr(mask);
+                       if (mask)
                                printipmask(putch, putdat, mask);
                        break;
                case 'V':
                        /* what to do if they screw up? */
-                       if ((ip = va_arg(ap, uint8_t *)) != NULL)
+                       ip = va_arg(ap, uint8_t *);
+                       warn_on_user_ptr(ip);
+                       if (ip)
                                printipv4(putch, putdat, ip);
                        break;
 
                // string
                case 's':
-                       if ((s = va_arg(ap, char *)) == NULL)
+                       s = va_arg(ap, char *);
+                       warn_on_user_ptr(s);
+                       if (!s)
                                s = "(null)";
                        if (width > 0 && padc != '-')
                                for (width -= strnlen(s, precision);
@@ -239,7 +256,9 @@ void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt,
 
                // qid
                case 'Q':
-                       printqid(putch, putdat, va_arg(ap, void*));
+                       p = va_arg(ap, void*);
+                       warn_on_user_ptr(p);
+                       printqid(putch, putdat, p);
                        break;
                number:
                        printnum(putch, putdat, num, base, width, padc);