All qlocks are initialized
[akaros.git] / kern / src / net / udp.c
1 // INFERNO
2 #define DEBUG
3 #include <vfs.h>
4 #include <kfs.h>
5 #include <slab.h>
6 #include <kmalloc.h>
7 #include <kref.h>
8 #include <string.h>
9 #include <stdio.h>
10 #include <assert.h>
11 #include <error.h>
12 #include <cpio.h>
13 #include <pmap.h>
14 #include <smp.h>
15 #include <ip.h>
16
17 #include <vfs.h>
18 #include <kfs.h>
19 #include <slab.h>
20 #include <kmalloc.h>
21 #include <kref.h>
22 #include <string.h>
23 #include <stdio.h>
24 #include <assert.h>
25 #include <error.h>
26 #include <cpio.h>
27 #include <pmap.h>
28 #include <smp.h>
29 #include <ip.h>
30
31
/*
 * Debug print that is compiled (so its arguments stay type-checked) but never
 * executed; change the 0 to enable it.  Wrapped in do/while(0) so that
 * "if (x) DPRINT(...); else ..." binds the else to the caller's if, not to a
 * hidden if(0) inside the macro as the old "if(0)print" expansion did.
 */
#define DPRINT(...) do { if (0) print(__VA_ARGS__); } while (0)
33
/* Header sizes/offsets and tunables used by the UDP protocol module. */
enum {
	/* UDP header: sport, dport, length, checksum (2 bytes each) */
	UDP_UDPHDR_SZ = 8,

	/* IPv4: the checksum pseudo-header is laid over IP header bytes 8..19 */
	UDP4_PHDR_OFF = 8,
	UDP4_PHDR_SZ = 12,
	UDP4_IPHDR_SZ = 20,

	/* IPv6: the pseudo-header is built in place over the whole IP header */
	UDP6_IPHDR_SZ = 40,
	UDP6_PHDR_SZ = 40,
	UDP6_PHDR_OFF = 0,

	IP_UDPPROTO = 17,

	/* "headers" mode prefix: raddr, laddr, ifc addr (16 each), rport, lport */
	UDP_USEAD7 = 16 + 16 + 16 + 2 + 2,
	/* "oldheaders" mode prefix: raddr, laddr (16 each), rport, lport */
	UDP_USEAD6 = 16 + 16 + 2 + 2,

	Udprxms = 200,
	Udptickms = 100,
	Udpmaxxmit = 10,
};
53
54 typedef struct Udp4hdr Udp4hdr;
55 struct Udp4hdr
56 {
57         /* ip header */
58         uint8_t vihl;           /* Version and header length */
59         uint8_t tos;            /* Type of service */
60         uint8_t length[2];      /* packet length */
61         uint8_t id[2];          /* Identification */
62         uint8_t frag[2];        /* Fragment information */
63         uint8_t Unused; 
64         uint8_t udpproto;       /* Protocol */
65         uint8_t udpplen[2];     /* Header plus data length */
66         uint8_t udpsrc[IPv4addrlen];    /* Ip source */
67         uint8_t udpdst[IPv4addrlen];    /* Ip destination */
68
69         /* udp header */
70         uint8_t udpsport[2];    /* Source port */
71         uint8_t udpdport[2];    /* Destination port */
72         uint8_t udplen[2];      /* data length */
73         uint8_t udpcksum[2];    /* Checksum */
74 };
75
76 typedef struct Udp6hdr Udp6hdr;
77 struct Udp6hdr {
78         uint8_t viclfl[4];
79         uint8_t len[2];
80         uint8_t nextheader;
81         uint8_t hoplimit;
82         uint8_t udpsrc[IPaddrlen];
83         uint8_t udpdst[IPaddrlen];
84
85         /* udp header */
86         uint8_t udpsport[2];    /* Source port */
87         uint8_t udpdport[2];    /* Destination port */
88         uint8_t udplen[2];      /* data length */
89         uint8_t udpcksum[2];    /* Checksum */
90 };
91
/* MIB-II UDP group counters, reported by udpstats(). */
typedef struct Udpstats {
	uint32_t udpInDatagrams;	/* datagrams handed to udpiput() */
	uint32_t udpNoPorts;		/* datagrams with no matching conversation */
	uint32_t udpInErrors;		/* datagrams dropped for bad checksum */
	uint32_t udpOutDatagrams;	/* datagrams sent by udpkick() */
} Udpstats;
101
102 typedef struct Udppriv Udppriv;
103 struct Udppriv
104 {
105         struct Ipht             ht;
106
107         /* MIB counters */
108         Udpstats        ustats;
109
110         /* non-MIB stats */
111         uint32_t                csumerr;                /* checksum errors */
112         uint32_t                lenerr;                 /* short packet */
113 };
114
/* Optional profiling hook; stays NULL unless some other code assigns it. */
void (*etherprofiler)(char *name, int qlen);

/* Forward declaration: udpcreate() installs udpkick as the write-queue handler. */
void udpkick(void *conv, struct block *bp);
117
118 /*
119  *  protocol specific part of Conv
120  */
121 typedef struct Udpcb Udpcb;
122 struct Udpcb
123 {
124         qlock_t qlock;
125         uint8_t headers;
126 };
127
128 static char*
129 udpconnect(struct conv *c, char **argv, int argc)
130 {
131         char *e;
132         Udppriv *upriv;
133
134         upriv = c->p->priv;
135         e = Fsstdconnect(c, argv, argc);
136         Fsconnected(c, e);
137         if(e != NULL)
138                 return e;
139
140         iphtadd(&upriv->ht, c);
141         return NULL;
142 }
143
144
145 static int
146 udpstate(struct conv *c, char *state, int n)
147 {
148         return snprintf(state, n, "%s qin %d qout %d",
149                 c->inuse ? "Open" : "Closed",
150                 c->rq ? qlen(c->rq) : 0,
151                 c->wq ? qlen(c->wq) : 0
152         );
153 }
154
155 static char*
156 udpannounce(struct conv *c, char** argv, int argc)
157 {
158         char *e;
159         Udppriv *upriv;
160
161         upriv = c->p->priv;
162         e = Fsstdannounce(c, argv, argc);
163         if(e != NULL)
164                 return e;
165         Fsconnected(c, NULL);
166         iphtadd(&upriv->ht, c);
167
168         return NULL;
169 }
170
171 static void
172 udpcreate(struct conv *c)
173 {
174         c->rq = qopen(64*1024, Qmsg, 0, 0);
175         c->wq = qbypass(udpkick, c);
176 }
177
178 static void
179 udpclose(struct conv *c)
180 {
181         Udpcb *ucb;
182         Udppriv *upriv;
183
184         upriv = c->p->priv;
185         iphtrem(&upriv->ht, c);
186
187         c->state = 0;
188         qclose(c->rq);
189         qclose(c->wq);
190         qclose(c->eq);
191         ipmove(c->laddr, IPnoaddr);
192         ipmove(c->raddr, IPnoaddr);
193         c->lport = 0;
194         c->rport = 0;
195
196         ucb = (Udpcb*)c->ptcl;
197         ucb->headers = 0;
198
199         qunlock(&c->qlock);
200 }
201
202 void
203 udpkick(void *x, struct block *bp)
204 {
205         struct conv *c = x;
206         Udp4hdr *uh4;
207         Udp6hdr *uh6;
208         uint16_t rport;
209         uint8_t laddr[IPaddrlen], raddr[IPaddrlen];
210         Udpcb *ucb;
211         int dlen, ptcllen;
212         Udppriv *upriv;
213         struct Fs *f;
214         int version;
215         struct conv *rc;
216
217         upriv = c->p->priv;
218         f = c->p->f;
219
220         netlog(c->p->f, Logudp, "udp: kick\n");
221         if(bp == NULL)
222                 return;
223
224         ucb = (Udpcb*)c->ptcl;
225         switch(ucb->headers) {
226         case 7:
227                 /* get user specified addresses */
228                 bp = pullupblock(bp, UDP_USEAD7);
229                 if(bp == NULL)
230                         return;
231                 ipmove(raddr, bp->rp);
232                 bp->rp += IPaddrlen;
233                 ipmove(laddr, bp->rp);
234                 bp->rp += IPaddrlen;
235                 /* pick interface closest to dest */
236                 if(ipforme(f, laddr) != Runi)
237                         findlocalip(f, laddr, raddr);
238                 bp->rp += IPaddrlen;            /* Ignore ifc address */
239                 rport = nhgets(bp->rp);
240                 bp->rp += 2+2;                  /* Ignore local port */
241                 break;
242         case 6:
243                 /* get user specified addresses */
244                 bp = pullupblock(bp, UDP_USEAD6);
245                 if(bp == NULL)
246                         return;
247                 ipmove(raddr, bp->rp);
248                 bp->rp += IPaddrlen;
249                 ipmove(laddr, bp->rp);
250                 bp->rp += IPaddrlen;
251                 /* pick interface closest to dest */
252                 if(ipforme(f, laddr) != Runi)
253                         findlocalip(f, laddr, raddr);
254                 rport = nhgets(bp->rp);
255                 bp->rp += 2+2;                  /* Ignore local port */
256                 break;
257         default:
258                 rport = 0;
259                 break;
260         }
261
262         if(ucb->headers) {
263                 if(memcmp(laddr, v4prefix, IPv4off) == 0 ||
264                     ipcmp(laddr, IPnoaddr) == 0)
265                         version = V4;
266                 else
267                         version = V6;
268         } else {
269                 if( (memcmp(c->raddr, v4prefix, IPv4off) == 0 &&
270                         memcmp(c->laddr, v4prefix, IPv4off) == 0)
271                         || ipcmp(c->raddr, IPnoaddr) == 0)
272                         version = V4;
273                 else
274                         version = V6;
275         }
276
277         dlen = blocklen(bp);
278
279         /* fill in pseudo header and compute checksum */
280         switch(version){
281         case V4:
282                 bp = padblock(bp, UDP4_IPHDR_SZ+UDP_UDPHDR_SZ);
283                 if(bp == NULL)
284                         return;
285
286                 uh4 = (Udp4hdr *)(bp->rp);
287                 ptcllen = dlen + UDP_UDPHDR_SZ;
288                 uh4->Unused = 0;
289                 uh4->udpproto = IP_UDPPROTO;
290                 uh4->frag[0] = 0;
291                 uh4->frag[1] = 0;
292                 hnputs(uh4->udpplen, ptcllen);
293                 if(ucb->headers) {
294                         v6tov4(uh4->udpdst, raddr);
295                         hnputs(uh4->udpdport, rport);
296                         v6tov4(uh4->udpsrc, laddr);
297                         rc = NULL;
298                 } else {
299                         v6tov4(uh4->udpdst, c->raddr);
300                         hnputs(uh4->udpdport, c->rport);
301                         if(ipcmp(c->laddr, IPnoaddr) == 0)
302                                 findlocalip(f, c->laddr, c->raddr);
303                         v6tov4(uh4->udpsrc, c->laddr);
304                         rc = c;
305                 }
306                 hnputs(uh4->udpsport, c->lport);
307                 hnputs(uh4->udplen, ptcllen);
308                 uh4->udpcksum[0] = 0;
309                 uh4->udpcksum[1] = 0;
310                 hnputs(uh4->udpcksum, 
311                        ptclcsum(bp, UDP4_PHDR_OFF, dlen+UDP_UDPHDR_SZ+UDP4_PHDR_SZ));
312                 uh4->vihl = IP_VER4;
313                 ipoput4(f, bp, 0, c->ttl, c->tos, rc);
314                 break;
315
316         case V6:
317                 bp = padblock(bp, UDP6_IPHDR_SZ+UDP_UDPHDR_SZ);
318                 if(bp == NULL)
319                         return;
320
321                 // using the v6 ip header to create pseudo header 
322                 // first then reset it to the normal ip header
323                 uh6 = (Udp6hdr *)(bp->rp);
324                 memset(uh6, 0, 8);
325                 ptcllen = dlen + UDP_UDPHDR_SZ;
326                 hnputl(uh6->viclfl, ptcllen);
327                 uh6->hoplimit = IP_UDPPROTO;
328                 if(ucb->headers) {
329                         ipmove(uh6->udpdst, raddr);
330                         hnputs(uh6->udpdport, rport);
331                         ipmove(uh6->udpsrc, laddr);
332                         rc = NULL;
333                 } else {
334                         ipmove(uh6->udpdst, c->raddr);
335                         hnputs(uh6->udpdport, c->rport);
336                         if(ipcmp(c->laddr, IPnoaddr) == 0)
337                                 findlocalip(f, c->laddr, c->raddr);
338                         ipmove(uh6->udpsrc, c->laddr);
339                         rc = c;
340                 }
341                 hnputs(uh6->udpsport, c->lport);
342                 hnputs(uh6->udplen, ptcllen);
343                 uh6->udpcksum[0] = 0;
344                 uh6->udpcksum[1] = 0;
345                 hnputs(uh6->udpcksum, 
346                        ptclcsum(bp, UDP6_PHDR_OFF, dlen+UDP_UDPHDR_SZ+UDP6_PHDR_SZ));
347                 memset(uh6, 0, 8);
348                 uh6->viclfl[0] = IP_VER6;
349                 hnputs(uh6->len, ptcllen);
350                 uh6->nextheader = IP_UDPPROTO;
351                 ipoput6(f, bp, 0, c->ttl, c->tos, rc);
352                 break;
353
354         default:
355                 panic("udpkick: version %d", version);
356         }
357         upriv->ustats.udpOutDatagrams++;
358 }
359
360 void
361 udpiput(struct Proto *udp, struct Ipifc *ifc, struct block *bp)
362 {
363         int len;
364         Udp4hdr *uh4;
365         Udp6hdr *uh6;
366         struct conv *c;
367         Udpcb *ucb;
368         uint8_t raddr[IPaddrlen], laddr[IPaddrlen];
369         uint16_t rport, lport;
370         Udppriv *upriv;
371         struct Fs *f;
372         int version;
373         int ottl, oviclfl, olen;
374         uint8_t *p;
375
376         upriv = udp->priv;
377         f = udp->f;
378         upriv->ustats.udpInDatagrams++;
379
380         uh4 = (Udp4hdr*)(bp->rp);
381         version = ((uh4->vihl&0xF0)==IP_VER6) ? V6 : V4;
382
383         /*
384          * Put back pseudo header for checksum 
385          * (remember old values for icmpnoconv())
386          */
387         switch(version) {
388         case V4:
389                 ottl = uh4->Unused;
390                 uh4->Unused = 0;
391                 len = nhgets(uh4->udplen);
392                 olen = nhgets(uh4->udpplen);
393                 hnputs(uh4->udpplen, len);
394
395                 v4tov6(raddr, uh4->udpsrc);
396                 v4tov6(laddr, uh4->udpdst);
397                 lport = nhgets(uh4->udpdport);
398                 rport = nhgets(uh4->udpsport);
399
400                 if(nhgets(uh4->udpcksum)) {
401                         if(ptclcsum(bp, UDP4_PHDR_OFF, len+UDP4_PHDR_SZ)) {
402                                 upriv->ustats.udpInErrors++;
403                                 netlog(f, Logudp, "udp: checksum error %I\n", raddr);
404                                 printd("udp: checksum error %I\n", raddr);
405                                 freeblist(bp);
406                                 return;
407                         }
408                 }
409                 uh4->Unused = ottl;
410                 hnputs(uh4->udpplen, olen);
411                 break;
412         case V6:
413                 uh6 = (Udp6hdr*)(bp->rp);
414                 len = nhgets(uh6->udplen);
415                 oviclfl = nhgetl(uh6->viclfl);
416                 olen = nhgets(uh6->len);
417                 ottl = uh6->hoplimit;
418                 ipmove(raddr, uh6->udpsrc);
419                 ipmove(laddr, uh6->udpdst);
420                 lport = nhgets(uh6->udpdport);
421                 rport = nhgets(uh6->udpsport);
422                 memset(uh6, 0, 8);
423                 hnputl(uh6->viclfl, len);
424                 uh6->hoplimit = IP_UDPPROTO;
425                 if(ptclcsum(bp, UDP6_PHDR_OFF, len+UDP6_PHDR_SZ)) {
426                         upriv->ustats.udpInErrors++;
427                         netlog(f, Logudp, "udp: checksum error %I\n", raddr);
428                         printd("udp: checksum error %I\n", raddr);
429                         freeblist(bp);
430                         return;
431                 }
432                 hnputl(uh6->viclfl, oviclfl);
433                 hnputs(uh6->len, olen);
434                 uh6->nextheader = IP_UDPPROTO;
435                 uh6->hoplimit = ottl;
436                 break;
437         default:
438                 panic("udpiput: version %d", version);
439                 return; /* to avoid a warning */
440         }
441
442         qlock(&udp->qlock);
443
444         c = iphtlook(&upriv->ht, raddr, rport, laddr, lport);
445         if(c == NULL){
446                 /* no converstation found */
447                 upriv->ustats.udpNoPorts++;
448                 qunlock(&udp->qlock);
449                 netlog(f, Logudp, "udp: no conv %I!%d -> %I!%d\n", raddr, rport,
450                        laddr, lport);
451
452                 switch(version){
453                 case V4:
454                         icmpnoconv(f, bp);
455                         break;
456                 case V6:
457                         icmphostunr(f, ifc, bp, icmp6_port_unreach, 0);
458                         break;
459                 default:
460                         panic("udpiput2: version %d", version);
461                 }
462
463                 freeblist(bp);
464                 return;
465         }
466         ucb = (Udpcb*)c->ptcl;
467
468         if(c->state == Announced){
469                 if(ucb->headers == 0){
470                         /* create a new conversation */
471                         if(ipforme(f, laddr) != Runi) {
472                                 switch(version){
473                                 case V4:
474                                         v4tov6(laddr, ifc->lifc->local);
475                                         break;
476                                 case V6:
477                                         ipmove(laddr, ifc->lifc->local);
478                                         break;
479                                 default:
480                                         panic("udpiput3: version %d", version);
481                                 }
482                         }
483                         c = Fsnewcall(c, raddr, rport, laddr, lport, version);
484                         if(c == NULL){
485                                 qunlock(&udp->qlock);
486                                 freeblist(bp);
487                                 return;
488                         }
489                         iphtadd(&upriv->ht, c);
490                         ucb = (Udpcb*)c->ptcl;
491                 }
492         }
493
494         qlock(&c->qlock);
495         qunlock(&udp->qlock);
496
497         /*
498          * Trim the packet down to data size
499          */
500         len -= UDP_UDPHDR_SZ;
501         switch(version){
502         case V4:
503                 bp = trimblock(bp, UDP4_IPHDR_SZ+UDP_UDPHDR_SZ, len);
504                 break;
505         case V6:
506                 bp = trimblock(bp, UDP6_IPHDR_SZ+UDP_UDPHDR_SZ, len);
507                 break;
508         default:
509                 bp = NULL;
510                 panic("udpiput4: version %d", version);
511         }
512         if(bp == NULL){
513                 qunlock(&c->qlock);
514                 netlog(f, Logudp, "udp: len err %I.%d -> %I.%d\n", raddr, rport,
515                        laddr, lport);
516                 upriv->lenerr++;
517                 return;
518         }
519
520         netlog(f, Logudpmsg, "udp: %I.%d -> %I.%d l %d\n", raddr, rport,
521                laddr, lport, len);
522
523         switch(ucb->headers){
524         case 7:
525                 /* pass the src address */
526                 bp = padblock(bp, UDP_USEAD7);
527                 p = bp->rp;
528                 ipmove(p, raddr); p += IPaddrlen;
529                 ipmove(p, laddr); p += IPaddrlen;
530                 ipmove(p, ifc->lifc->local); p += IPaddrlen;
531                 hnputs(p, rport); p += 2;
532                 hnputs(p, lport);
533                 break;
534         case 6:
535                 /* pass the src address */
536                 bp = padblock(bp, UDP_USEAD6);
537                 p = bp->rp;
538                 ipmove(p, raddr); p += IPaddrlen;
539                 ipmove(p, ipforme(f, laddr)==Runi ? laddr : ifc->lifc->local); p += IPaddrlen;
540                 hnputs(p, rport); p += 2;
541                 hnputs(p, lport);
542                 break;
543         }
544
545         if(bp->next)
546                 bp = concatblock(bp);
547
548         if(qfull(c->rq)){
549                 qunlock(&c->qlock);
550                 netlog(f, Logudp, "udp: qfull %I.%d -> %I.%d\n", raddr, rport,
551                        laddr, lport);
552                 freeblist(bp);
553                 return;
554         }
555
556         qpass(c->rq, bp);
557         qunlock(&c->qlock);
558
559 }
560
561 char*
562 udpctl(struct conv *c, char **f, int n)
563 {
564         Udpcb *ucb;
565
566         ucb = (Udpcb*)c->ptcl;
567         if(n == 1){
568                 if(strcmp(f[0], "oldheaders") == 0){
569                         ucb->headers = 6;
570                         return NULL;
571                 } else if(strcmp(f[0], "headers") == 0){
572                         ucb->headers = 7;
573                         return NULL;
574                 }
575         }
576         return "unknown control request";
577 }
578
579 void
580 udpadvise(struct Proto *udp, struct block *bp, char *msg)
581 {
582         Udp4hdr *h4;
583         Udp6hdr *h6;
584         uint8_t source[IPaddrlen], dest[IPaddrlen];
585         uint16_t psource, pdest;
586         struct conv *s, **p;
587         int version;
588
589         h4 = (Udp4hdr*)(bp->rp);
590         version = ((h4->vihl&0xF0)==IP_VER6) ? V6 : V4;
591
592         switch(version) {
593         case V4:
594                 v4tov6(dest, h4->udpdst);
595                 v4tov6(source, h4->udpsrc);
596                 psource = nhgets(h4->udpsport);
597                 pdest = nhgets(h4->udpdport);
598                 break;
599         case V6:
600                 h6 = (Udp6hdr*)(bp->rp);
601                 ipmove(dest, h6->udpdst);
602                 ipmove(source, h6->udpsrc);
603                 psource = nhgets(h6->udpsport);
604                 pdest = nhgets(h6->udpdport);
605                 break;
606         default:
607                 panic("udpadvise: version %d", version);
608                 return;  /* to avoid a warning */
609         }
610
611         /* Look for a connection */
612         qlock(&udp->qlock);
613         for(p = udp->conv; *p; p++) {
614                 s = *p;
615                 if(s->rport == pdest)
616                 if(s->lport == psource)
617                 if(ipcmp(s->raddr, dest) == 0)
618                 if(ipcmp(s->laddr, source) == 0){
619                         if(s->ignoreadvice)
620                                 break;
621                         qlock(&s->qlock);
622                         qunlock(&udp->qlock);
623                         qhangup(s->rq, msg);
624                         qhangup(s->wq, msg);
625                         qunlock(&s->qlock);
626                         freeblist(bp);
627                         return;
628                 }
629         }
630         qunlock(&udp->qlock);
631         freeblist(bp);
632 }
633
634 int
635 udpstats(struct Proto *udp, char *buf, int len)
636 {
637         Udppriv *upriv;
638
639         upriv = udp->priv;
640         return snprintf(buf, len, "InDatagrams: %lud\nNoPorts: %lud\nInErrors: %lud\nOutDatagrams: %lud\n",
641                 upriv->ustats.udpInDatagrams,
642                 upriv->ustats.udpNoPorts,
643                 upriv->ustats.udpInErrors,
644                 upriv->ustats.udpOutDatagrams);
645 }
646
647 void udpnewconv(struct Proto *udp, struct conv *conv)
648 {
649         /* Fsprotoclone alloc'd our priv struct and attached it to conv already.
650          * Now we need to init it */
651         struct Udpcb *ucb = (struct Udpcb*)conv->ptcl;
652         qlock_init(&ucb->qlock);
653 }
654
655 void
656 udpinit(struct Fs *fs)
657 {
658         struct Proto *udp;
659
660         udp = kzmalloc(sizeof(struct Proto), 0);
661         udp->priv = kzmalloc(sizeof(Udppriv), 0);
662         udp->name = "udp";
663         udp->connect = udpconnect;
664         udp->announce = udpannounce;
665         udp->ctl = udpctl;
666         udp->state = udpstate;
667         udp->create = udpcreate;
668         udp->close = udpclose;
669         udp->rcv = udpiput;
670         udp->advise = udpadvise;
671         udp->stats = udpstats;
672         udp->ipproto = IP_UDPPROTO;
673         udp->nc = Nchans;
674         udp->newconv = udpnewconv;
675         udp->ptclsize = sizeof(Udpcb);
676
677         Fsproto(fs, udp);
678 }