Fixes ICMP block reuse (checksumming)
[akaros.git] / kern / src / net / icmp.c
1 // INFERNO
2 #include <vfs.h>
3 #include <kfs.h>
4 #include <slab.h>
5 #include <kmalloc.h>
6 #include <kref.h>
7 #include <string.h>
8 #include <stdio.h>
9 #include <assert.h>
10 #include <error.h>
11 #include <cpio.h>
12 #include <pmap.h>
13 #include <smp.h>
14 #include <ip.h>
15
16 #include <vfs.h>
17 #include <kfs.h>
18 #include <slab.h>
19 #include <kmalloc.h>
20 #include <kref.h>
21 #include <string.h>
22 #include <stdio.h>
23 #include <assert.h>
24 #include <error.h>
25 #include <cpio.h>
26 #include <pmap.h>
27 #include <smp.h>
28 #include <ip.h>
29
/* Layout of an ICMP packet as this file sees it at bp->rp: the fixed-size
 * IPv4 header fields come first, followed directly by the ICMP header fields.
 * Every member is a byte or a byte array; multi-byte values are read and
 * written through helpers such as nhgets()/hnputs(). */
typedef struct Icmp {
	/* IPv4 header (ICMP_IPSIZE bytes) */
	uint8_t vihl;		/* Version and header length */
	uint8_t tos;		/* Type of service */
	uint8_t length[2];	/* Packet length */
	uint8_t id[2];		/* Identification */
	uint8_t frag[2];	/* Fragment information */
	uint8_t ttl;		/* Time to live */
	uint8_t proto;		/* Protocol */
	uint8_t ipcksum[2];	/* Header checksum */
	uint8_t src[4];		/* IP source */
	uint8_t dst[4];		/* IP destination */

	/* ICMP header (ICMP_HDRSIZE bytes) */
	uint8_t type;		/* Packet type, see the packet type enum */
	uint8_t code;		/* Type-specific code */
	uint8_t cksum[2];	/* ICMP checksum */
	uint8_t icmpid[2];	/* Identifier, matched against a conv's lport */
	uint8_t seq[2];		/* Sequence number */

	/* Start of the data that follows the ICMP header; callers copy more
	 * than one byte through this member. */
	uint8_t data[1];
} Icmp;
48
/* Packet Types: values carried in the type field of the ICMP header. */
enum {
	EchoReply = 0,
	Unreachable = 3,
	SrcQuench = 4,
	Redirect = 5,
	EchoRequest = 8,
	TimeExceed = 11,
	InParmProblem = 12,
	Timestamp = 13,
	TimestampReply = 14,
	InfoRequest = 15,
	InfoReply = 16,
	AddrMaskRequest = 17,
	AddrMaskReply = 18,

	/* Largest type value above; the per-type tables in this file hold
	 * Maxtype + 1 entries. */
	Maxtype = 18,
};
66
/* Minimum number of bytes that must remain after the outer IP and ICMP
 * headers before we advise another protocol about an error. */
enum {
	MinAdvise = 24,
};
70
71 char *icmpnames[Maxtype + 1] = {
72         [EchoReply] "EchoReply",
73         [Unreachable] "Unreachable",
74         [SrcQuench] "SrcQuench",
75         [Redirect] "Redirect",
76         [EchoRequest] "EchoRequest",
77         [TimeExceed] "TimeExceed",
78         [InParmProblem] "InParmProblem",
79         [Timestamp] "Timestamp",
80         [TimestampReply] "TimestampReply",
81         [InfoRequest] "InfoRequest",
82         [InfoReply] "InfoReply",
83         [AddrMaskRequest] "AddrMaskRequest",
84         [AddrMaskReply] "AddrMaskReply  ",
85 };
86
/* Protocol number and fixed header sizes used throughout this file. */
enum {
	IP_ICMPPROTO = 1,	/* value stored in the IP header's proto field */
	ICMP_IPSIZE = 20,	/* bytes of IP header preceding the ICMP header */
	ICMP_HDRSIZE = 8,	/* bytes of ICMP header (type through seq) */
};
92
/* Indices into Icmppriv.stats[]; Nstats is the number of counters. */
enum {
	InMsgs,		/* packets handed to icmpiput */
	InErrors,	/* inbound packets rejected */
	OutMsgs,	/* packets sent via icmpkick */
	CsumErrs,	/* inbound checksum failures */
	LenErrs,	/* inbound length problems */
	HlenErrs,	/* inbound packets shorter than the headers */

	Nstats,
};
103
104 static char *statnames[Nstats] = {
105         [InMsgs] "InMsgs",
106         [InErrors] "InErrors",
107         [OutMsgs] "OutMsgs",
108         [CsumErrs] "CsumErrs",
109         [LenErrs] "LenErrs",
110         [HlenErrs] "HlenErrs",
111 };
112
113 typedef struct Icmppriv Icmppriv;
114 struct Icmppriv {
115         uint32_t stats[Nstats];
116
117         /* message counts */
118         uint32_t in[Maxtype + 1];
119         uint32_t out[Maxtype + 1];
120 };
121
122 static void icmpkick(void *x, struct block *);
123
124 static void icmpcreate(struct conv *c)
125 {
126         c->rq = qopen(64 * 1024, Qmsg, 0, c);
127         c->wq = qbypass(icmpkick, c);
128 }
129
/* Connect handler for ICMP conversations.  Runs the standard connect logic;
 * on success marks the conversation connected and returns NULL, otherwise
 * returns the error string produced by Fsstdconnect(). */
extern char *icmpconnect(struct conv *c, char **argv, int argc)
{
	char *err = Fsstdconnect(c, argv, argc);

	/* err is NULL here, so this reports a successful connect */
	if (err == NULL)
		Fsconnected(c, err);

	return err;
}
141
142 extern int icmpstate(struct conv *c, char *state, int n)
143 {
144         return snprintf(state, n, "%s qin %d qout %d\n",
145                                         "Datagram",
146                                         c->rq ? qlen(c->rq) : 0, c->wq ? qlen(c->wq) : 0);
147 }
148
/* Announce handler for ICMP conversations.  Runs the standard announce
 * logic; on success marks the conversation connected and returns NULL,
 * otherwise returns the error string produced by Fsstdannounce(). */
extern char *icmpannounce(struct conv *c, char **argv, int argc)
{
	char *err = Fsstdannounce(c, argv, argc);

	if (err == NULL)
		Fsconnected(c, NULL);

	return err;
}
160
161 extern void icmpclose(struct conv *c)
162 {
163         qclose(c->rq);
164         qclose(c->wq);
165         ipmove(c->laddr, IPnoaddr);
166         ipmove(c->raddr, IPnoaddr);
167         c->lport = 0;
168 }
169
170 static void icmpkick(void *x, struct block *bp)
171 {
172         struct conv *c = x;
173         Icmp *p;
174         Icmppriv *ipriv;
175
176         if (bp == NULL)
177                 return;
178
179         bp = pullupblock(bp, ICMP_IPSIZE + ICMP_HDRSIZE);
180         if (bp == 0)
181                 return;
182         p = (Icmp *) (bp->rp);
183         p->vihl = IP_VER4;
184         ipriv = c->p->priv;
185         if (p->type <= Maxtype)
186                 ipriv->out[p->type]++;
187
188         v6tov4(p->dst, c->raddr);
189         v6tov4(p->src, c->laddr);
190         p->proto = IP_ICMPPROTO;
191         hnputs(p->icmpid, c->lport);
192         memset(p->cksum, 0, sizeof(p->cksum));
193         hnputs(p->cksum, ptclcsum(bp, ICMP_IPSIZE, blocklen(bp) - ICMP_IPSIZE));
194         ipriv->stats[OutMsgs]++;
195         ipoput4(c->p->f, bp, 0, c->ttl, c->tos, NULL);
196 }
197
198 extern void icmpttlexceeded(struct Fs *f, uint8_t * ia, struct block *bp)
199 {
200         struct block *nbp;
201         Icmp *p, *np;
202
203         p = (Icmp *) bp->rp;
204
205         netlog(f, Logicmp, "sending icmpttlexceeded -> %V\n", p->src);
206         nbp = allocb(ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8);
207         nbp->wp += ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8;
208         np = (Icmp *) nbp->rp;
209         np->vihl = IP_VER4;
210         memmove(np->dst, p->src, sizeof(np->dst));
211         v6tov4(np->src, ia);
212         memmove(np->data, bp->rp, ICMP_IPSIZE + 8);
213         np->type = TimeExceed;
214         np->code = 0;
215         np->proto = IP_ICMPPROTO;
216         hnputs(np->icmpid, 0);
217         hnputs(np->seq, 0);
218         memset(np->cksum, 0, sizeof(np->cksum));
219         hnputs(np->cksum, ptclcsum(nbp, ICMP_IPSIZE, blocklen(nbp) - ICMP_IPSIZE));
220         ipoput4(f, nbp, 0, MAXTTL, DFLTTOS, NULL);
221
222 }
223
224 static void icmpunreachable(struct Fs *f, struct block *bp, int code, int seq)
225 {
226         struct block *nbp;
227         Icmp *p, *np;
228         int i;
229         uint8_t addr[IPaddrlen];
230
231         p = (Icmp *) bp->rp;
232
233         /* only do this for unicast sources and destinations */
234         v4tov6(addr, p->dst);
235         i = ipforme(f, addr);
236         if ((i & Runi) == 0)
237                 return;
238         v4tov6(addr, p->src);
239         i = ipforme(f, addr);
240         if (i != 0 && (i & Runi) == 0)
241                 return;
242
243         netlog(f, Logicmp, "sending icmpnoconv -> %V\n", p->src);
244         nbp = allocb(ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8);
245         nbp->wp += ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8;
246         np = (Icmp *) nbp->rp;
247         np->vihl = IP_VER4;
248         memmove(np->dst, p->src, sizeof(np->dst));
249         memmove(np->src, p->dst, sizeof(np->src));
250         memmove(np->data, bp->rp, ICMP_IPSIZE + 8);
251         np->type = Unreachable;
252         np->code = code;
253         np->proto = IP_ICMPPROTO;
254         hnputs(np->icmpid, 0);
255         hnputs(np->seq, seq);
256         memset(np->cksum, 0, sizeof(np->cksum));
257         hnputs(np->cksum, ptclcsum(nbp, ICMP_IPSIZE, blocklen(nbp) - ICMP_IPSIZE));
258         ipoput4(f, nbp, 0, MAXTTL, DFLTTOS, NULL);
259 }
260
/* Report that no conversation exists for packet bp: sends an Unreachable
 * message with code 3 ("port unreachable" in unreachcode[]) and seq 0. */
extern void icmpnoconv(struct Fs *f, struct block *bp)
{
	enum { PortUnreach = 3 };

	icmpunreachable(f, bp, PortUnreach, 0);
}
265
/* Report that packet bp could not be fragmented: sends an Unreachable
 * message with code 4 ("fragmentation needed and DF set" in unreachcode[]).
 * mtu is passed as icmpunreachable()'s seq argument, so it ends up in the
 * reply's seq field. */
extern void icmpcantfrag(struct Fs *f, struct block *bp, int mtu)
{
	enum { FragNeeded = 4 };

	icmpunreachable(f, bp, FragNeeded, mtu);
}
270
271 static void goticmpkt(struct Proto *icmp, struct block *bp)
272 {
273         struct conv **c, *s;
274         Icmp *p;
275         uint8_t dst[IPaddrlen];
276         uint16_t recid;
277
278         p = (Icmp *) bp->rp;
279         v4tov6(dst, p->src);
280         recid = nhgets(p->icmpid);
281
282         for (c = icmp->conv; *c; c++) {
283                 s = *c;
284                 if (s->lport == recid)
285                         if (ipcmp(s->raddr, dst) == 0) {
286                                 bp = concatblock(bp);
287                                 if (bp != NULL)
288                                         qpass(s->rq, bp);
289                                 return;
290                         }
291         }
292         freeblist(bp);
293 }
294
295 static struct block *mkechoreply(struct block *bp)
296 {
297         Icmp *q;
298         uint8_t ip[4];
299
300         /* we're repurposing bp to send it back out.  we need to remove any inbound
301          * checksum flags (which were saying the HW did the checksum) */
302         bp->flag &= ~BCKSUM_FLAGS;
303         q = (Icmp *) bp->rp;
304         q->vihl = IP_VER4;
305         memmove(ip, q->src, sizeof(q->dst));
306         memmove(q->src, q->dst, sizeof(q->src));
307         memmove(q->dst, ip, sizeof(q->dst));
308         q->type = EchoReply;
309         memset(q->cksum, 0, sizeof(q->cksum));
310         hnputs(q->cksum, ptclcsum(bp, ICMP_IPSIZE, blocklen(bp) - ICMP_IPSIZE));
311
312         return bp;
313 }
314
/* Messages for the code field of an Unreachable packet, indexed by code.
 * Uses standard C99 designated initializers ("[index] = value") instead of
 * the obsolete "[index] value" GNU form. */
static char *unreachcode[] = {
	[0] = "net unreachable",
	[1] = "host unreachable",
	[2] = "protocol unreachable",
	[3] = "port unreachable",
	[4] = "fragmentation needed and DF set",
	[5] = "source route failed",
};
323
324 static void icmpiput(struct Proto *icmp, struct Ipifc *unused, struct block *bp)
325 {
326         int n, iplen;
327         Icmp *p;
328         struct block *r;
329         struct Proto *pr;
330         char *msg;
331         char m2[128];
332         Icmppriv *ipriv;
333
334         ipriv = icmp->priv;
335
336         ipriv->stats[InMsgs]++;
337
338         p = (Icmp *) bp->rp;
339         netlog(icmp->f, Logicmp, "icmpiput %d %d\n", p->type, p->code);
340         n = blocklen(bp);
341         if (n < ICMP_IPSIZE + ICMP_HDRSIZE) {
342                 ipriv->stats[InErrors]++;
343                 ipriv->stats[HlenErrs]++;
344                 netlog(icmp->f, Logicmp, "icmp hlen %d\n", n);
345                 goto raise;
346         }
347         iplen = nhgets(p->length);
348         if (iplen > n || (iplen % 1)) {
349                 ipriv->stats[LenErrs]++;
350                 ipriv->stats[InErrors]++;
351                 netlog(icmp->f, Logicmp, "icmp length %d\n", iplen);
352                 goto raise;
353         }
354         if (ptclcsum(bp, ICMP_IPSIZE, iplen - ICMP_IPSIZE)) {
355                 ipriv->stats[InErrors]++;
356                 ipriv->stats[CsumErrs]++;
357                 netlog(icmp->f, Logicmp, "icmp checksum error\n");
358                 goto raise;
359         }
360         if (p->type <= Maxtype)
361                 ipriv->in[p->type]++;
362
363         switch (p->type) {
364                 case EchoRequest:
365                         if (iplen < n)
366                                 bp = trimblock(bp, 0, iplen);
367                         r = mkechoreply(bp);
368                         ipriv->out[EchoReply]++;
369                         ipoput4(icmp->f, r, 0, MAXTTL, DFLTTOS, NULL);
370                         break;
371                 case Unreachable:
372                         if (p->code > 5)
373                                 msg = unreachcode[1];
374                         else
375                                 msg = unreachcode[p->code];
376
377                         bp->rp += ICMP_IPSIZE + ICMP_HDRSIZE;
378                         if (blocklen(bp) < MinAdvise) {
379                                 ipriv->stats[LenErrs]++;
380                                 goto raise;
381                         }
382                         p = (Icmp *) bp->rp;
383                         pr = Fsrcvpcolx(icmp->f, p->proto);
384                         if (pr != NULL && pr->advise != NULL) {
385                                 (*pr->advise) (pr, bp, msg);
386                                 return;
387                         }
388
389                         bp->rp -= ICMP_IPSIZE + ICMP_HDRSIZE;
390                         goticmpkt(icmp, bp);
391                         break;
392                 case TimeExceed:
393                         if (p->code == 0) {
394                                 snprintf(m2, sizeof(m2), "ttl exceeded at %V", p->src);
395
396                                 bp->rp += ICMP_IPSIZE + ICMP_HDRSIZE;
397                                 if (blocklen(bp) < MinAdvise) {
398                                         ipriv->stats[LenErrs]++;
399                                         goto raise;
400                                 }
401                                 p = (Icmp *) bp->rp;
402                                 pr = Fsrcvpcolx(icmp->f, p->proto);
403                                 if (pr != NULL && pr->advise != NULL) {
404                                         (*pr->advise) (pr, bp, m2);
405                                         return;
406                                 }
407                                 bp->rp -= ICMP_IPSIZE + ICMP_HDRSIZE;
408                         }
409
410                         goticmpkt(icmp, bp);
411                         break;
412                 default:
413                         goticmpkt(icmp, bp);
414                         break;
415         }
416         return;
417
418 raise:
419         freeblist(bp);
420 }
421
422 void icmpadvise(struct Proto *icmp, struct block *bp, char *msg)
423 {
424         struct conv **c, *s;
425         Icmp *p;
426         uint8_t dst[IPaddrlen];
427         uint16_t recid;
428
429         p = (Icmp *) bp->rp;
430         v4tov6(dst, p->dst);
431         recid = nhgets(p->icmpid);
432
433         for (c = icmp->conv; *c; c++) {
434                 s = *c;
435                 if (s->lport == recid)
436                         if (ipcmp(s->raddr, dst) == 0) {
437                                 qhangup(s->rq, msg);
438                                 qhangup(s->wq, msg);
439                                 break;
440                         }
441         }
442         freeblist(bp);
443 }
444
445 int icmpstats(struct Proto *icmp, char *buf, int len)
446 {
447         Icmppriv *priv;
448         char *p, *e;
449         int i;
450
451         priv = icmp->priv;
452         p = buf;
453         e = p + len;
454         for (i = 0; i < Nstats; i++)
455                 p = seprintf(p, e, "%s: %u\n", statnames[i], priv->stats[i]);
456         for (i = 0; i <= Maxtype; i++) {
457                 if (icmpnames[i])
458                         p = seprintf(p, e, "%s: %u %u\n", icmpnames[i], priv->in[i],
459                                                  priv->out[i]);
460                 else
461                         p = seprintf(p, e, "%d: %u %u\n", i, priv->in[i], priv->out[i]);
462         }
463         return p - buf;
464 }
465
466 void icmpinit(struct Fs *fs)
467 {
468         struct Proto *icmp;
469
470         icmp = kzmalloc(sizeof(struct Proto), 0);
471         icmp->priv = kzmalloc(sizeof(Icmppriv), 0);
472         icmp->name = "icmp";
473         icmp->connect = icmpconnect;
474         icmp->announce = icmpannounce;
475         icmp->state = icmpstate;
476         icmp->create = icmpcreate;
477         icmp->close = icmpclose;
478         icmp->rcv = icmpiput;
479         icmp->stats = icmpstats;
480         icmp->ctl = NULL;
481         icmp->advise = icmpadvise;
482         icmp->gc = NULL;
483         icmp->ipproto = IP_ICMPPROTO;
484         icmp->nc = 128;
485         icmp->ptclsize = 0;
486
487         Fsproto(fs, icmp);
488 }