Open mode checks, don't use == for OREAD
[akaros.git] / kern / src / net / icmp.c
1 // INFERNO
2 #include <vfs.h>
3 #include <kfs.h>
4 #include <slab.h>
5 #include <kmalloc.h>
6 #include <kref.h>
7 #include <string.h>
8 #include <stdio.h>
9 #include <assert.h>
10 #include <error.h>
11 #include <cpio.h>
12 #include <pmap.h>
13 #include <smp.h>
14 #include <ip.h>
15
16 #include <vfs.h>
17 #include <kfs.h>
18 #include <slab.h>
19 #include <kmalloc.h>
20 #include <kref.h>
21 #include <string.h>
22 #include <stdio.h>
23 #include <assert.h>
24 #include <error.h>
25 #include <cpio.h>
26 #include <pmap.h>
27 #include <smp.h>
28 #include <ip.h>
29
/*
 * Icmp - an IPv4 header (20 bytes) immediately followed by an ICMP header
 * (8 bytes: type, code, checksum, identifier, sequence), exactly as they sit
 * in a packet buffer.  Every multi-byte field is a byte array so the struct
 * has no padding and can be overlaid on bp->rp.
 */
typedef struct Icmp {
	/* ---- IPv4 header ---- */
	uint8_t vihl;		/* version and header length */
	uint8_t tos;		/* type of service */
	uint8_t length[2];	/* total packet length */
	uint8_t id[2];		/* identification */
	uint8_t frag[2];	/* fragment information */
	uint8_t ttl;		/* time to live */
	uint8_t proto;		/* protocol */
	uint8_t ipcksum[2];	/* IP header checksum */
	uint8_t src[4];		/* IPv4 source address */
	uint8_t dst[4];		/* IPv4 destination address */
	/* ---- ICMP header ---- */
	uint8_t type;		/* ICMP message type */
	uint8_t code;		/* type-specific code */
	uint8_t cksum[2];	/* ICMP checksum */
	uint8_t icmpid[2];	/* identifier (matched against conv lport) */
	uint8_t seq[2];		/* sequence number */
	uint8_t data[1];	/* first byte of the ICMP payload */
} Icmp;
48
/* ICMP message type numbers (the Icmp.type field). */
enum {
	EchoReply	= 0,
	Unreachable	= 3,
	SrcQuench	= 4,
	Redirect	= 5,
	EchoRequest	= 8,
	TimeExceed	= 11,
	InParmProblem	= 12,
	Timestamp	= 13,
	TimestampReply	= 14,
	InfoRequest	= 15,
	InfoReply	= 16,
	AddrMaskRequest	= 17,
	AddrMaskReply	= 18,

	/* Highest type number tracked; per-type tables have Maxtype+1 slots. */
	Maxtype		= 18,
};
66
/*
 * Minimum number of bytes of a quoted datagram (after stripping our own IP
 * and ICMP headers) needed before we pass it to another protocol's advise
 * hook.
 */
enum { MinAdvise = 24 };
71
72 char *icmpnames[Maxtype+1] =
73 {
74 [EchoReply]             "EchoReply",
75 [Unreachable]           "Unreachable",
76 [SrcQuench]             "SrcQuench",
77 [Redirect]              "Redirect",
78 [EchoRequest]           "EchoRequest",
79 [TimeExceed]            "TimeExceed",
80 [InParmProblem]         "InParmProblem",
81 [Timestamp]             "Timestamp",
82 [TimestampReply]        "TimestampReply",
83 [InfoRequest]           "InfoRequest",
84 [InfoReply]             "InfoReply",
85 [AddrMaskRequest]       "AddrMaskRequest",
86 [AddrMaskReply  ]       "AddrMaskReply  ",
87 };
88
/* Wire constants used when building and parsing ICMP packets. */
enum {
	IP_ICMPPROTO	= 1,	/* value placed in the IP proto field */
	ICMP_IPSIZE	= 20,	/* bytes of IPv4 header in struct Icmp */
	ICMP_HDRSIZE	= 8,	/* bytes of ICMP header in struct Icmp */
};
94
/* Indices into Icmppriv.stats[]; Nstats is the number of counters. */
enum {
	InMsgs = 0,	/* messages handed to icmpiput() */
	InErrors,	/* received messages rejected for any reason */
	OutMsgs,	/* messages sent through icmpkick() */
	CsumErrs,	/* received messages with a bad checksum */
	LenErrs,	/* received messages with a bad length */
	HlenErrs,	/* received messages too short for the headers */

	Nstats,
};
106
107 static char *statnames[Nstats] =
108 {
109 [InMsgs]        "InMsgs",
110 [InErrors]      "InErrors",
111 [OutMsgs]       "OutMsgs",
112 [CsumErrs]      "CsumErrs",
113 [LenErrs]       "LenErrs",
114 [HlenErrs]      "HlenErrs",
115 };
116
117 typedef struct Icmppriv Icmppriv;
118 struct Icmppriv
119 {
120         uint32_t        stats[Nstats];
121
122         /* message counts */
123         uint32_t        in[Maxtype+1];
124         uint32_t        out[Maxtype+1];
125 };
126
127 static void icmpkick(void *x, struct block*);
128
129 static void
130 icmpcreate(struct conv *c)
131 {
132         c->rq = qopen(64*1024, Qmsg, 0, c);
133         c->wq = qbypass(icmpkick, c);
134 }
135
136 extern char*
137 icmpconnect(struct conv *c, char **argv, int argc)
138 {
139         char *e;
140
141         e = Fsstdconnect(c, argv, argc);
142         if(e != NULL)
143                 return e;
144         Fsconnected(c, e);
145
146         return NULL;
147 }
148
149 extern int
150 icmpstate(struct conv *c, char *state, int n)
151 {
152         return snprintf(state, n, "%s qin %d qout %d",
153                 "Datagram",
154                 c->rq ? qlen(c->rq) : 0,
155                 c->wq ? qlen(c->wq) : 0
156         );
157 }
158
159 extern char*
160 icmpannounce(struct conv *c, char **argv, int argc)
161 {
162         char *e;
163
164         e = Fsstdannounce(c, argv, argc);
165         if(e != NULL)
166                 return e;
167         Fsconnected(c, NULL);
168
169         return NULL;
170 }
171
172 extern void
173 icmpclose(struct conv *c)
174 {
175         qclose(c->rq);
176         qclose(c->wq);
177         ipmove(c->laddr, IPnoaddr);
178         ipmove(c->raddr, IPnoaddr);
179         c->lport = 0;
180 }
181
182 static void
183 icmpkick(void *x, struct block *bp)
184 {
185         struct conv *c = x;
186         Icmp *p;
187         Icmppriv *ipriv;
188
189         if(bp == NULL)
190                 return;
191
192         if(blocklen(bp) < ICMP_IPSIZE + ICMP_HDRSIZE){
193                 freeblist(bp);
194                 return;
195         }
196         p = (Icmp *)(bp->rp);
197         p->vihl = IP_VER4;
198         ipriv = c->p->priv;
199         if(p->type <= Maxtype)  
200                 ipriv->out[p->type]++;
201         
202         v6tov4(p->dst, c->raddr);
203         v6tov4(p->src, c->laddr);
204         p->proto = IP_ICMPPROTO;
205         hnputs(p->icmpid, c->lport);
206         memset(p->cksum, 0, sizeof(p->cksum));
207         hnputs(p->cksum, ptclcsum(bp, ICMP_IPSIZE, blocklen(bp) - ICMP_IPSIZE));
208         ipriv->stats[OutMsgs]++;
209         ipoput4(c->p->f, bp, 0, c->ttl, c->tos, NULL);
210 }
211
212 extern void
213 icmpttlexceeded(struct Fs *f, uint8_t *ia, struct block *bp)
214 {
215         struct block    *nbp;
216         Icmp    *p, *np;
217
218         p = (Icmp *)bp->rp;
219
220         netlog(f, Logicmp, "sending icmpttlexceeded -> %V\n", p->src);
221         nbp = allocb(ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8);
222         nbp->wp += ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8;
223         np = (Icmp *)nbp->rp;
224         np->vihl = IP_VER4;
225         memmove(np->dst, p->src, sizeof(np->dst));
226         v6tov4(np->src, ia);
227         memmove(np->data, bp->rp, ICMP_IPSIZE + 8);
228         np->type = TimeExceed;
229         np->code = 0;
230         np->proto = IP_ICMPPROTO;
231         hnputs(np->icmpid, 0);
232         hnputs(np->seq, 0);
233         memset(np->cksum, 0, sizeof(np->cksum));
234         hnputs(np->cksum, ptclcsum(nbp, ICMP_IPSIZE, blocklen(nbp) - ICMP_IPSIZE));
235         ipoput4(f, nbp, 0, MAXTTL, DFLTTOS, NULL);
236
237 }
238
239 static void
240 icmpunreachable(struct Fs *f, struct block *bp, int code, int seq)
241 {
242         struct block    *nbp;
243         Icmp    *p, *np;
244         int     i;
245         uint8_t addr[IPaddrlen];
246
247         p = (Icmp *)bp->rp;
248
249         /* only do this for unicast sources and destinations */
250         v4tov6(addr, p->dst);
251         i = ipforme(f, addr);
252         if((i&Runi) == 0)
253                 return;
254         v4tov6(addr, p->src);
255         i = ipforme(f, addr);
256         if(i != 0 && (i&Runi) == 0)
257                 return;
258
259         netlog(f, Logicmp, "sending icmpnoconv -> %V\n", p->src);
260         nbp = allocb(ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8);
261         nbp->wp += ICMP_IPSIZE + ICMP_HDRSIZE + ICMP_IPSIZE + 8;
262         np = (Icmp *)nbp->rp;
263         np->vihl = IP_VER4;
264         memmove(np->dst, p->src, sizeof(np->dst));
265         memmove(np->src, p->dst, sizeof(np->src));
266         memmove(np->data, bp->rp, ICMP_IPSIZE + 8);
267         np->type = Unreachable;
268         np->code = code;
269         np->proto = IP_ICMPPROTO;
270         hnputs(np->icmpid, 0);
271         hnputs(np->seq, seq);
272         memset(np->cksum, 0, sizeof(np->cksum));
273         hnputs(np->cksum, ptclcsum(nbp, ICMP_IPSIZE, blocklen(nbp) - ICMP_IPSIZE));
274         ipoput4(f, nbp, 0, MAXTTL, DFLTTOS, NULL);
275 }
276
/*
 * icmpnoconv - report that no conversation accepted the datagram in bp by
 * sending Destination Unreachable code 3 ("port unreachable" in
 * unreachcode[]) with a zero sequence field.
 */
extern void
icmpnoconv(struct Fs *f, struct block *bp)
{
	enum { PortUnreachCode = 3 };

	icmpunreachable(f, bp, PortUnreachCode, 0);
}
282
/*
 * icmpcantfrag - report that the datagram in bp needed fragmentation but
 * could not be fragmented: Destination Unreachable code 4 ("fragmentation
 * needed and DF set" in unreachcode[]), with mtu carried in the sequence
 * field.
 */
extern void
icmpcantfrag(struct Fs *f, struct block *bp, int mtu)
{
	enum { FragNeededCode = 4 };

	icmpunreachable(f, bp, FragNeededCode, mtu);
}
288
289 static void
290 goticmpkt(struct Proto *icmp, struct block *bp)
291 {
292         struct conv     **c, *s;
293         Icmp    *p;
294         uint8_t dst[IPaddrlen];
295         uint16_t        recid;
296
297         p = (Icmp *) bp->rp;
298         v4tov6(dst, p->src);
299         recid = nhgets(p->icmpid);
300
301         for(c = icmp->conv; *c; c++) {
302                 s = *c;
303                 if(s->lport == recid)
304                 if(ipcmp(s->raddr, dst) == 0){
305                         bp = concatblock(bp);
306                         if(bp != NULL)
307                                 qpass(s->rq, bp);
308                         return;
309                 }
310         }
311         freeblist(bp);
312 }
313
314 static struct block *
315 mkechoreply(struct block *bp)
316 {
317         Icmp    *q;
318         uint8_t ip[4];
319
320         q = (Icmp *)bp->rp;
321         q->vihl = IP_VER4;
322         memmove(ip, q->src, sizeof(q->dst));
323         memmove(q->src, q->dst, sizeof(q->src));
324         memmove(q->dst, ip,  sizeof(q->dst));
325         q->type = EchoReply;
326         memset(q->cksum, 0, sizeof(q->cksum));
327         hnputs(q->cksum, ptclcsum(bp, ICMP_IPSIZE, blocklen(bp) - ICMP_IPSIZE));
328
329         return bp;
330 }
331
/*
 * Human-readable reasons for ICMP Destination Unreachable codes 0-5, used
 * as the advise message when a quoted datagram is handed to its protocol.
 * Uses standard C99 designated initializers ("[idx] = value") instead of
 * the obsolete GNU "[idx] value" form.
 */
static char *unreachcode[] =
{
	[0] = "net unreachable",
	[1] = "host unreachable",
	[2] = "protocol unreachable",
	[3] = "port unreachable",
	[4] = "fragmentation needed and DF set",
	[5] = "source route failed",
};
341
342 static void
343 icmpiput(struct Proto *icmp, struct Ipifc*unused, struct block *bp)
344 {
345         int     n, iplen;
346         Icmp    *p;
347         struct block    *r;
348         struct Proto    *pr;
349         char    *msg;
350         char    m2[128];
351         Icmppriv *ipriv;
352
353         ipriv = icmp->priv;
354         
355         ipriv->stats[InMsgs]++;
356
357         p = (Icmp *)bp->rp;
358         netlog(icmp->f, Logicmp, "icmpiput %d %d\n", p->type, p->code);
359         n = blocklen(bp);
360         if(n < ICMP_IPSIZE+ICMP_HDRSIZE){
361                 ipriv->stats[InErrors]++;
362                 ipriv->stats[HlenErrs]++;
363                 netlog(icmp->f, Logicmp, "icmp hlen %d\n", n);
364                 goto raise;
365         }
366         iplen = nhgets(p->length);
367         if(iplen > n || (iplen % 1)){
368                 ipriv->stats[LenErrs]++;
369                 ipriv->stats[InErrors]++;
370                 netlog(icmp->f, Logicmp, "icmp length %d\n", iplen);
371                 goto raise;
372         }
373         if(ptclcsum(bp, ICMP_IPSIZE, iplen - ICMP_IPSIZE)){
374                 ipriv->stats[InErrors]++;
375                 ipriv->stats[CsumErrs]++;
376                 netlog(icmp->f, Logicmp, "icmp checksum error\n");
377                 goto raise;
378         }
379         if(p->type <= Maxtype)
380                 ipriv->in[p->type]++;
381
382         switch(p->type) {
383         case EchoRequest:
384                 if (iplen < n)
385                         bp = trimblock(bp, 0, iplen);
386                 r = mkechoreply(bp);
387                 ipriv->out[EchoReply]++;
388                 ipoput4(icmp->f, r, 0, MAXTTL, DFLTTOS, NULL);
389                 break;
390         case Unreachable:
391                 if(p->code > 5)
392                         msg = unreachcode[1];
393                 else
394                         msg = unreachcode[p->code];
395
396                 bp->rp += ICMP_IPSIZE+ICMP_HDRSIZE;
397                 if(blocklen(bp) < MinAdvise){
398                         ipriv->stats[LenErrs]++;
399                         goto raise;
400                 }
401                 p = (Icmp *)bp->rp;
402                 pr = Fsrcvpcolx(icmp->f, p->proto);
403                 if(pr != NULL && pr->advise != NULL) {
404                         (*pr->advise)(pr, bp, msg);
405                         return;
406                 }
407
408                 bp->rp -= ICMP_IPSIZE+ICMP_HDRSIZE;
409                 goticmpkt(icmp, bp);
410                 break;
411         case TimeExceed:
412                 if(p->code == 0){
413                         snprintf(m2, sizeof(m2), "ttl exceeded at %V", p->src);
414
415                         bp->rp += ICMP_IPSIZE+ICMP_HDRSIZE;
416                         if(blocklen(bp) < MinAdvise){
417                                 ipriv->stats[LenErrs]++;
418                                 goto raise;
419                         }
420                         p = (Icmp *)bp->rp;
421                         pr = Fsrcvpcolx(icmp->f, p->proto);
422                         if(pr != NULL && pr->advise != NULL) {
423                                 (*pr->advise)(pr, bp, m2);
424                                 return;
425                         }
426                         bp->rp -= ICMP_IPSIZE+ICMP_HDRSIZE;
427                 }
428
429                 goticmpkt(icmp, bp);
430                 break;
431         default:
432                 goticmpkt(icmp, bp);
433                 break;
434         }
435         return;
436
437 raise:
438         freeblist(bp);
439 }
440
441 void
442 icmpadvise(struct Proto *icmp, struct block *bp, char *msg)
443 {
444         struct conv     **c, *s;
445         Icmp    *p;
446         uint8_t dst[IPaddrlen];
447         uint16_t        recid;
448
449         p = (Icmp *) bp->rp;
450         v4tov6(dst, p->dst);
451         recid = nhgets(p->icmpid);
452
453         for(c = icmp->conv; *c; c++) {
454                 s = *c;
455                 if(s->lport == recid)
456                 if(ipcmp(s->raddr, dst) == 0){
457                         qhangup(s->rq, msg);
458                         qhangup(s->wq, msg);
459                         break;
460                 }
461         }
462         freeblist(bp);
463 }
464
465 int
466 icmpstats(struct Proto *icmp, char *buf, int len)
467 {
468         Icmppriv *priv;
469         char *p, *e;
470         int i;
471
472         priv = icmp->priv;
473         p = buf;
474         e = p+len;
475         for(i = 0; i < Nstats; i++)
476                 p = seprintf(p, e, "%s: %lud\n", statnames[i], priv->stats[i]);
477         for(i = 0; i <= Maxtype; i++){
478                 if(icmpnames[i])
479                         p = seprintf(p, e, "%s: %lud %lud\n", icmpnames[i], priv->in[i], priv->out[i]);
480                 else
481                         p = seprintf(p, e, "%d: %lud %lud\n", i, priv->in[i], priv->out[i]);
482         }
483         return p - buf;
484 }
485         
486 void
487 icmpinit(struct Fs *fs)
488 {
489         struct Proto *icmp;
490
491         icmp = kzmalloc(sizeof(struct Proto), 0);
492         icmp->priv = kzmalloc(sizeof(Icmppriv), 0);
493         icmp->name = "icmp";
494         icmp->connect = icmpconnect;
495         icmp->announce = icmpannounce;
496         icmp->state = icmpstate;
497         icmp->create = icmpcreate;
498         icmp->close = icmpclose;
499         icmp->rcv = icmpiput;
500         icmp->stats = icmpstats;
501         icmp->ctl = NULL;
502         icmp->advise = icmpadvise;
503         icmp->gc = NULL;
504         icmp->ipproto = IP_ICMPPROTO;
505         icmp->nc = 128;
506         icmp->ptclsize = 0;
507
508         Fsproto(fs, icmp);
509 }