Implement transmit checksum offload.
[akaros.git] / kern / src / net / udp.c
1 // INFERNO
2 #define DEBUG
3 #include <vfs.h>
4 #include <kfs.h>
5 #include <slab.h>
6 #include <kmalloc.h>
7 #include <kref.h>
8 #include <string.h>
9 #include <stdio.h>
10 #include <assert.h>
11 #include <error.h>
12 #include <cpio.h>
13 #include <pmap.h>
14 #include <smp.h>
15 #include <ip.h>
16
17 #include <vfs.h>
18 #include <kfs.h>
19 #include <slab.h>
20 #include <kmalloc.h>
21 #include <kref.h>
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <error.h>
26 #include <cpio.h>
27 #include <pmap.h>
28 #include <smp.h>
29 #include <ip.h>
30
/* Debug print that is compiled (so arguments stay type-checked) but never runs. */
#define DPRINT if (0) print

/* Header sizes/offsets and tunables for the UDP protocol module. */
enum {
	/* IPv4: 20-byte IP header; the 12-byte pseudo-header is built in place
	 * starting at byte 8 (the TTL slot, Udp4hdr.Unused). */
	UDP4_IPHDR_SZ = 20,
	UDP4_PHDR_OFF = 8,
	UDP4_PHDR_SZ = 12,

	/* IPv6: 40-byte IP header, reused from offset 0 as the pseudo-header. */
	UDP6_IPHDR_SZ = 40,
	UDP6_PHDR_OFF = 0,
	UDP6_PHDR_SZ = 40,

	/* The UDP header itself and its IP protocol number. */
	UDP_UDPHDR_SZ = 8,
	IP_UDPPROTO = 17,

	/* Length of the address prefix exchanged with the user in header mode:
	 * 7 ("headers"):    raddr, laddr, ifc addr (3 * 16) + rport + lport
	 * 6 ("oldheaders"): raddr, laddr (2 * 16) + rport + lport */
	UDP_USEAD7 = 52,
	UDP_USEAD6 = 36,

	/* Timing/retransmit tunables (not referenced in this file). */
	Udprxms = 200,
	Udptickms = 100,
	Udpmaxxmit = 10,
};
51
52 typedef struct Udp4hdr Udp4hdr;
53 struct Udp4hdr {
54         /* ip header */
55         uint8_t vihl;                           /* Version and header length */
56         uint8_t tos;                            /* Type of service */
57         uint8_t length[2];                      /* packet length */
58         uint8_t id[2];                          /* Identification */
59         uint8_t frag[2];                        /* Fragment information */
60         uint8_t Unused;
61         uint8_t udpproto;                       /* Protocol */
62         uint8_t udpplen[2];                     /* Header plus data length */
63         uint8_t udpsrc[IPv4addrlen];    /* Ip source */
64         uint8_t udpdst[IPv4addrlen];    /* Ip destination */
65
66         /* udp header */
67         uint8_t udpsport[2];            /* Source port */
68         uint8_t udpdport[2];            /* Destination port */
69         uint8_t udplen[2];                      /* data length */
70         uint8_t udpcksum[2];            /* Checksum */
71 };
72
73 typedef struct Udp6hdr Udp6hdr;
74 struct Udp6hdr {
75         uint8_t viclfl[4];
76         uint8_t len[2];
77         uint8_t nextheader;
78         uint8_t hoplimit;
79         uint8_t udpsrc[IPaddrlen];
80         uint8_t udpdst[IPaddrlen];
81
82         /* udp header */
83         uint8_t udpsport[2];            /* Source port */
84         uint8_t udpdport[2];            /* Destination port */
85         uint8_t udplen[2];                      /* data length */
86         uint8_t udpcksum[2];            /* Checksum */
87 };
88
/* MIB II UDP counters, printed by udpstats(). */
typedef struct Udpstats {
	uint32_t udpInDatagrams;	/* datagrams entering udpiput */
	uint32_t udpNoPorts;		/* no matching conversation */
	uint32_t udpInErrors;		/* dropped for a bad checksum */
	uint32_t udpOutDatagrams;	/* datagrams sent by udpkick */
} Udpstats;
97
98 typedef struct Udppriv Udppriv;
99 struct Udppriv {
100         struct Ipht ht;
101
102         /* MIB counters */
103         Udpstats ustats;
104
105         /* non-MIB stats */
106         uint32_t csumerr;                       /* checksum errors */
107         uint32_t lenerr;                        /* short packet */
108 };
109
110 void (*etherprofiler) (char *name, int qlen);
111 void udpkick(void *x, struct block *bp);
112
113 /*
114  *  protocol specific part of Conv
115  */
116 typedef struct Udpcb Udpcb;
117 struct Udpcb {
118         qlock_t qlock;
119         uint8_t headers;
120 };
121
122 static char *udpconnect(struct conv *c, char **argv, int argc)
123 {
124         char *e;
125         Udppriv *upriv;
126
127         upriv = c->p->priv;
128         e = Fsstdconnect(c, argv, argc);
129         Fsconnected(c, e);
130         if (e != NULL)
131                 return e;
132
133         iphtadd(&upriv->ht, c);
134         return NULL;
135 }
136
137 static int udpstate(struct conv *c, char *state, int n)
138 {
139         return snprintf(state, n, "%s qin %d qout %d",
140                                         c->inuse ? "Open" : "Closed",
141                                         c->rq ? qlen(c->rq) : 0, c->wq ? qlen(c->wq) : 0);
142 }
143
144 static char *udpannounce(struct conv *c, char **argv, int argc)
145 {
146         char *e;
147         Udppriv *upriv;
148
149         upriv = c->p->priv;
150         e = Fsstdannounce(c, argv, argc);
151         if (e != NULL)
152                 return e;
153         Fsconnected(c, NULL);
154         iphtadd(&upriv->ht, c);
155
156         return NULL;
157 }
158
159 static void udpcreate(struct conv *c)
160 {
161         c->rq = qopen(128 * 1024, Qmsg, 0, 0);
162         c->wq = qbypass(udpkick, c);
163 }
164
165 static void udpclose(struct conv *c)
166 {
167         Udpcb *ucb;
168         Udppriv *upriv;
169
170         upriv = c->p->priv;
171         iphtrem(&upriv->ht, c);
172
173         c->state = 0;
174         qclose(c->rq);
175         qclose(c->wq);
176         qclose(c->eq);
177         ipmove(c->laddr, IPnoaddr);
178         ipmove(c->raddr, IPnoaddr);
179         c->lport = 0;
180         c->rport = 0;
181
182         ucb = (Udpcb *) c->ptcl;
183         ucb->headers = 0;
184
185         qunlock(&c->qlock);
186 }
187
188 void udpkick(void *x, struct block *bp)
189 {
190         struct conv *c = x;
191         Udp4hdr *uh4;
192         Udp6hdr *uh6;
193         uint16_t rport;
194         uint8_t laddr[IPaddrlen], raddr[IPaddrlen];
195         Udpcb *ucb;
196         int dlen, ptcllen;
197         Udppriv *upriv;
198         struct Fs *f;
199         int version;
200         struct conv *rc;
201
202         upriv = c->p->priv;
203         assert(upriv);
204         f = c->p->f;
205
206         netlog(c->p->f, Logudp, "udp: kick\n");
207         if (bp == NULL)
208                 return;
209
210         ucb = (Udpcb *) c->ptcl;
211         switch (ucb->headers) {
212                 case 7:
213                         /* get user specified addresses */
214                         bp = pullupblock(bp, UDP_USEAD7);
215                         if (bp == NULL)
216                                 return;
217                         ipmove(raddr, bp->rp);
218                         bp->rp += IPaddrlen;
219                         ipmove(laddr, bp->rp);
220                         bp->rp += IPaddrlen;
221                         /* pick interface closest to dest */
222                         if (ipforme(f, laddr) != Runi)
223                                 findlocalip(f, laddr, raddr);
224                         bp->rp += IPaddrlen;    /* Ignore ifc address */
225                         rport = nhgets(bp->rp);
226                         bp->rp += 2 + 2;        /* Ignore local port */
227                         break;
228                 case 6:
229                         /* get user specified addresses */
230                         bp = pullupblock(bp, UDP_USEAD6);
231                         if (bp == NULL)
232                                 return;
233                         ipmove(raddr, bp->rp);
234                         bp->rp += IPaddrlen;
235                         ipmove(laddr, bp->rp);
236                         bp->rp += IPaddrlen;
237                         /* pick interface closest to dest */
238                         if (ipforme(f, laddr) != Runi)
239                                 findlocalip(f, laddr, raddr);
240                         rport = nhgets(bp->rp);
241                         bp->rp += 2 + 2;        /* Ignore local port */
242                         break;
243                 default:
244                         rport = 0;
245                         break;
246         }
247
248         if (ucb->headers) {
249                 if (memcmp(laddr, v4prefix, IPv4off) == 0 ||
250                         ipcmp(laddr, IPnoaddr) == 0)
251                         version = V4;
252                 else
253                         version = V6;
254         } else {
255                 if ((memcmp(c->raddr, v4prefix, IPv4off) == 0 &&
256                          memcmp(c->laddr, v4prefix, IPv4off) == 0)
257                         || ipcmp(c->raddr, IPnoaddr) == 0)
258                         version = V4;
259                 else
260                         version = V6;
261         }
262
263         dlen = blocklen(bp);
264
265         /* fill in pseudo header and compute checksum */
266         switch (version) {
267                 case V4:
268                         bp = padblock(bp, UDP4_IPHDR_SZ + UDP_UDPHDR_SZ);
269                         if (bp == NULL)
270                                 return;
271
272                         uh4 = (Udp4hdr *) (bp->rp);
273                         ptcllen = dlen + UDP_UDPHDR_SZ;
274                         uh4->Unused = 0;
275                         uh4->udpproto = IP_UDPPROTO;
276                         uh4->frag[0] = 0;
277                         uh4->frag[1] = 0;
278                         hnputs(uh4->udpplen, ptcllen);
279                         if (ucb->headers) {
280                                 v6tov4(uh4->udpdst, raddr);
281                                 hnputs(uh4->udpdport, rport);
282                                 v6tov4(uh4->udpsrc, laddr);
283                                 rc = NULL;
284                         } else {
285                                 v6tov4(uh4->udpdst, c->raddr);
286                                 hnputs(uh4->udpdport, c->rport);
287                                 if (ipcmp(c->laddr, IPnoaddr) == 0)
288                                         findlocalip(f, c->laddr, c->raddr);
289                                 v6tov4(uh4->udpsrc, c->laddr);
290                                 rc = c;
291                         }
292                         hnputs(uh4->udpsport, c->lport);
293                         hnputs(uh4->udplen, ptcllen);
294                         uh4->udpcksum[0] = 0;
295                         uh4->udpcksum[1] = 0;
296                         hnputs(uh4->udpcksum,
297                                    ~ptclcsum(bp, UDP4_PHDR_OFF, UDP4_PHDR_SZ));
298                         bp->checksum_start = UDP4_IPHDR_SZ;
299                         bp->checksum_offset = uh4->udpcksum - uh4->udpsport;
300                         bp->flag |= Budpck;
301                         uh4->vihl = IP_VER4;
302                         ipoput4(f, bp, 0, c->ttl, c->tos, rc);
303                         break;
304
305                 case V6:
306                         bp = padblock(bp, UDP6_IPHDR_SZ + UDP_UDPHDR_SZ);
307                         if (bp == NULL)
308                                 return;
309
310                         // using the v6 ip header to create pseudo header 
311                         // first then reset it to the normal ip header
312                         uh6 = (Udp6hdr *) (bp->rp);
313                         memset(uh6, 0, 8);
314                         ptcllen = dlen + UDP_UDPHDR_SZ;
315                         hnputl(uh6->viclfl, ptcllen);
316                         uh6->hoplimit = IP_UDPPROTO;
317                         if (ucb->headers) {
318                                 ipmove(uh6->udpdst, raddr);
319                                 hnputs(uh6->udpdport, rport);
320                                 ipmove(uh6->udpsrc, laddr);
321                                 rc = NULL;
322                         } else {
323                                 ipmove(uh6->udpdst, c->raddr);
324                                 hnputs(uh6->udpdport, c->rport);
325                                 if (ipcmp(c->laddr, IPnoaddr) == 0)
326                                         findlocalip(f, c->laddr, c->raddr);
327                                 ipmove(uh6->udpsrc, c->laddr);
328                                 rc = c;
329                         }
330                         hnputs(uh6->udpsport, c->lport);
331                         hnputs(uh6->udplen, ptcllen);
332                         uh6->udpcksum[0] = 0;
333                         uh6->udpcksum[1] = 0;
334                         hnputs(uh6->udpcksum,
335                                    ptclcsum(bp, UDP6_PHDR_OFF,
336                                                         dlen + UDP_UDPHDR_SZ + UDP6_PHDR_SZ));
337                         memset(uh6, 0, 8);
338                         uh6->viclfl[0] = IP_VER6;
339                         hnputs(uh6->len, ptcllen);
340                         uh6->nextheader = IP_UDPPROTO;
341                         ipoput6(f, bp, 0, c->ttl, c->tos, rc);
342                         break;
343
344                 default:
345                         panic("udpkick: version %d", version);
346         }
347         upriv->ustats.udpOutDatagrams++;
348 }
349
350 void udpiput(struct Proto *udp, struct Ipifc *ifc, struct block *bp)
351 {
352         int len;
353         Udp4hdr *uh4;
354         Udp6hdr *uh6;
355         struct conv *c;
356         Udpcb *ucb;
357         uint8_t raddr[IPaddrlen], laddr[IPaddrlen];
358         uint16_t rport, lport;
359         Udppriv *upriv;
360         struct Fs *f;
361         int version;
362         int ottl, oviclfl, olen;
363         uint8_t *p;
364
365         upriv = udp->priv;
366         f = udp->f;
367         upriv->ustats.udpInDatagrams++;
368
369         uh4 = (Udp4hdr *) (bp->rp);
370         version = ((uh4->vihl & 0xF0) == IP_VER6) ? V6 : V4;
371
372         /*
373          * Put back pseudo header for checksum 
374          * (remember old values for icmpnoconv())
375          */
376         switch (version) {
377                 case V4:
378                         ottl = uh4->Unused;
379                         uh4->Unused = 0;
380                         len = nhgets(uh4->udplen);
381                         olen = nhgets(uh4->udpplen);
382                         hnputs(uh4->udpplen, len);
383
384                         v4tov6(raddr, uh4->udpsrc);
385                         v4tov6(laddr, uh4->udpdst);
386                         lport = nhgets(uh4->udpdport);
387                         rport = nhgets(uh4->udpsport);
388
389                         if (!(bp->flag & Budpck) &&
390                             (uh4->udpcksum[0] || uh4->udpcksum[1]) &&
391                             ptclcsum(bp, UDP4_PHDR_OFF, len + UDP4_PHDR_SZ)) {
392                                 upriv->ustats.udpInErrors++;
393                                 netlog(f, Logudp, "udp: checksum error %I\n",
394                                        raddr);
395                                 printd("udp: checksum error %I\n", raddr);
396                                 freeblist(bp);
397                                 return;
398                         }
399                         uh4->Unused = ottl;
400                         hnputs(uh4->udpplen, olen);
401                         break;
402                 case V6:
403                         uh6 = (Udp6hdr *) (bp->rp);
404                         len = nhgets(uh6->udplen);
405                         oviclfl = nhgetl(uh6->viclfl);
406                         olen = nhgets(uh6->len);
407                         ottl = uh6->hoplimit;
408                         ipmove(raddr, uh6->udpsrc);
409                         ipmove(laddr, uh6->udpdst);
410                         lport = nhgets(uh6->udpdport);
411                         rport = nhgets(uh6->udpsport);
412                         memset(uh6, 0, 8);
413                         hnputl(uh6->viclfl, len);
414                         uh6->hoplimit = IP_UDPPROTO;
415                         if (ptclcsum(bp, UDP6_PHDR_OFF, len + UDP6_PHDR_SZ)) {
416                                 upriv->ustats.udpInErrors++;
417                                 netlog(f, Logudp, "udp: checksum error %I\n", raddr);
418                                 printd("udp: checksum error %I\n", raddr);
419                                 freeblist(bp);
420                                 return;
421                         }
422                         hnputl(uh6->viclfl, oviclfl);
423                         hnputs(uh6->len, olen);
424                         uh6->nextheader = IP_UDPPROTO;
425                         uh6->hoplimit = ottl;
426                         break;
427                 default:
428                         panic("udpiput: version %d", version);
429                         return; /* to avoid a warning */
430         }
431
432         qlock(&udp->qlock);
433
434         c = iphtlook(&upriv->ht, raddr, rport, laddr, lport);
435         if (c == NULL) {
436                 /* no converstation found */
437                 upriv->ustats.udpNoPorts++;
438                 qunlock(&udp->qlock);
439                 netlog(f, Logudp, "udp: no conv %I!%d -> %I!%d\n", raddr, rport,
440                            laddr, lport);
441
442                 switch (version) {
443                         case V4:
444                                 icmpnoconv(f, bp);
445                                 break;
446                         case V6:
447                                 icmphostunr(f, ifc, bp, icmp6_port_unreach, 0);
448                                 break;
449                         default:
450                                 panic("udpiput2: version %d", version);
451                 }
452
453                 freeblist(bp);
454                 return;
455         }
456         ucb = (Udpcb *) c->ptcl;
457
458         if (c->state == Announced) {
459                 if (ucb->headers == 0) {
460                         /* create a new conversation */
461                         if (ipforme(f, laddr) != Runi) {
462                                 switch (version) {
463                                         case V4:
464                                                 v4tov6(laddr, ifc->lifc->local);
465                                                 break;
466                                         case V6:
467                                                 ipmove(laddr, ifc->lifc->local);
468                                                 break;
469                                         default:
470                                                 panic("udpiput3: version %d", version);
471                                 }
472                         }
473                         c = Fsnewcall(c, raddr, rport, laddr, lport, version);
474                         if (c == NULL) {
475                                 qunlock(&udp->qlock);
476                                 freeblist(bp);
477                                 return;
478                         }
479                         iphtadd(&upriv->ht, c);
480                         ucb = (Udpcb *) c->ptcl;
481                 }
482         }
483
484         qlock(&c->qlock);
485         qunlock(&udp->qlock);
486
487         /*
488          * Trim the packet down to data size
489          */
490         len -= UDP_UDPHDR_SZ;
491         switch (version) {
492                 case V4:
493                         bp = trimblock(bp, UDP4_IPHDR_SZ + UDP_UDPHDR_SZ, len);
494                         break;
495                 case V6:
496                         bp = trimblock(bp, UDP6_IPHDR_SZ + UDP_UDPHDR_SZ, len);
497                         break;
498                 default:
499                         bp = NULL;
500                         panic("udpiput4: version %d", version);
501         }
502         if (bp == NULL) {
503                 qunlock(&c->qlock);
504                 netlog(f, Logudp, "udp: len err %I.%d -> %I.%d\n", raddr, rport,
505                            laddr, lport);
506                 upriv->lenerr++;
507                 return;
508         }
509
510         netlog(f, Logudpmsg, "udp: %I.%d -> %I.%d l %d\n", raddr, rport,
511                    laddr, lport, len);
512
513         switch (ucb->headers) {
514                 case 7:
515                         /* pass the src address */
516                         bp = padblock(bp, UDP_USEAD7);
517                         p = bp->rp;
518                         ipmove(p, raddr);
519                         p += IPaddrlen;
520                         ipmove(p, laddr);
521                         p += IPaddrlen;
522                         ipmove(p, ifc->lifc->local);
523                         p += IPaddrlen;
524                         hnputs(p, rport);
525                         p += 2;
526                         hnputs(p, lport);
527                         break;
528                 case 6:
529                         /* pass the src address */
530                         bp = padblock(bp, UDP_USEAD6);
531                         p = bp->rp;
532                         ipmove(p, raddr);
533                         p += IPaddrlen;
534                         ipmove(p, ipforme(f, laddr) == Runi ? laddr : ifc->lifc->local);
535                         p += IPaddrlen;
536                         hnputs(p, rport);
537                         p += 2;
538                         hnputs(p, lport);
539                         break;
540         }
541
542         if (bp->next)
543                 bp = concatblock(bp);
544
545         if (qfull(c->rq)) {
546                 qunlock(&c->qlock);
547                 netlog(f, Logudp, "udp: qfull %I.%d -> %I.%d\n", raddr, rport,
548                            laddr, lport);
549                 freeblist(bp);
550                 return;
551         }
552
553         qpass(c->rq, bp);
554         qunlock(&c->qlock);
555
556 }
557
558 char *udpctl(struct conv *c, char **f, int n)
559 {
560         Udpcb *ucb;
561
562         ucb = (Udpcb *) c->ptcl;
563         if (n == 1) {
564                 if (strcmp(f[0], "oldheaders") == 0) {
565                         ucb->headers = 6;
566                         return NULL;
567                 } else if (strcmp(f[0], "headers") == 0) {
568                         ucb->headers = 7;
569                         return NULL;
570                 }
571         }
572         return "unknown control request";
573 }
574
575 void udpadvise(struct Proto *udp, struct block *bp, char *msg)
576 {
577         Udp4hdr *h4;
578         Udp6hdr *h6;
579         uint8_t source[IPaddrlen], dest[IPaddrlen];
580         uint16_t psource, pdest;
581         struct conv *s, **p;
582         int version;
583
584         h4 = (Udp4hdr *) (bp->rp);
585         version = ((h4->vihl & 0xF0) == IP_VER6) ? V6 : V4;
586
587         switch (version) {
588                 case V4:
589                         v4tov6(dest, h4->udpdst);
590                         v4tov6(source, h4->udpsrc);
591                         psource = nhgets(h4->udpsport);
592                         pdest = nhgets(h4->udpdport);
593                         break;
594                 case V6:
595                         h6 = (Udp6hdr *) (bp->rp);
596                         ipmove(dest, h6->udpdst);
597                         ipmove(source, h6->udpsrc);
598                         psource = nhgets(h6->udpsport);
599                         pdest = nhgets(h6->udpdport);
600                         break;
601                 default:
602                         panic("udpadvise: version %d", version);
603                         return; /* to avoid a warning */
604         }
605
606         /* Look for a connection */
607         qlock(&udp->qlock);
608         for (p = udp->conv; *p; p++) {
609                 s = *p;
610                 if (s->rport == pdest)
611                         if (s->lport == psource)
612                                 if (ipcmp(s->raddr, dest) == 0)
613                                         if (ipcmp(s->laddr, source) == 0) {
614                                                 if (s->ignoreadvice)
615                                                         break;
616                                                 qlock(&s->qlock);
617                                                 qunlock(&udp->qlock);
618                                                 qhangup(s->rq, msg);
619                                                 qhangup(s->wq, msg);
620                                                 qunlock(&s->qlock);
621                                                 freeblist(bp);
622                                                 return;
623                                         }
624         }
625         qunlock(&udp->qlock);
626         freeblist(bp);
627 }
628
629 int udpstats(struct Proto *udp, char *buf, int len)
630 {
631         Udppriv *upriv;
632         char *p, *e;
633
634         upriv = udp->priv;
635         p = buf;
636         e = p + len;
637         p = seprintf(p, e, "InDatagrams: %u\n", upriv->ustats.udpInDatagrams);
638         p = seprintf(p, e, "NoPorts: %u\n", upriv->ustats.udpNoPorts);
639         p = seprintf(p, e, "InErrors: %u\n", upriv->ustats.udpInErrors);
640         p = seprintf(p, e, "OutDatagrams: %u\n", upriv->ustats.udpOutDatagrams);
641         return p - buf;
642 }
643
644 void udpnewconv(struct Proto *udp, struct conv *conv)
645 {
646         /* Fsprotoclone alloc'd our priv struct and attached it to conv already.
647          * Now we need to init it */
648         struct Udpcb *ucb = (struct Udpcb *)conv->ptcl;
649         qlock_init(&ucb->qlock);
650 }
651
652 void udpinit(struct Fs *fs)
653 {
654         struct Proto *udp;
655
656         udp = kzmalloc(sizeof(struct Proto), 0);
657         udp->priv = kzmalloc(sizeof(Udppriv), 0);
658         udp->name = "udp";
659         udp->connect = udpconnect;
660         udp->announce = udpannounce;
661         udp->ctl = udpctl;
662         udp->state = udpstate;
663         udp->create = udpcreate;
664         udp->close = udpclose;
665         udp->rcv = udpiput;
666         udp->advise = udpadvise;
667         udp->stats = udpstats;
668         udp->ipproto = IP_UDPPROTO;
669         udp->nc = Nchans;
670         udp->newconv = udpnewconv;
671         udp->ptclsize = sizeof(Udpcb);
672
673         Fsproto(fs, udp);
674 }