Refactored icmpkick6
[akaros.git] / kern / src / net / compress.c
1 // INFERNO
2 #include <vfs.h>
3 #include <kfs.h>
4 #include <slab.h>
5 #include <kmalloc.h>
6 #include <kref.h>
7 #include <string.h>
8 #include <stdio.h>
9 #include <assert.h>
10 #include <error.h>
11 #include <cpio.h>
12 #include <pmap.h>
13 #include <smp.h>
14 #include <ip.h>
15
16 typedef struct Iphdr Iphdr;
17 typedef struct Tcphdr Tcphdr;
18 typedef struct Ilhdr Ilhdr;
19 typedef struct Hdr Hdr;
20 typedef struct Tcpc Tcpc;
21
22 struct Iphdr {
23         uint8_t vihl;                           /* Version and header length */
24         uint8_t tos;                            /* Type of service */
25         uint8_t length[2];                      /* packet length */
26         uint8_t id[2];                          /* Identification */
27         uint8_t frag[2];                        /* Fragment information */
28         uint8_t ttl;                            /* Time to live */
29         uint8_t proto;                          /* Protocol */
30         uint8_t cksum[2];                       /* Header checksum */
31         uint32_t src;                           /* Ip source (byte ordering unimportant) */
32         uint32_t dst;                           /* Ip destination (byte ordering unimportant) */
33 };
34
35 struct Tcphdr {
36         uint32_t ports;                         /* defined as a uint32_t to make comparisons easier */
37         uint8_t seq[4];
38         uint8_t ack[4];
39         uint8_t flag[2];
40         uint8_t win[2];
41         uint8_t cksum[2];
42         uint8_t urg[2];
43 };
44
45 struct Ilhdr {
46         uint8_t sum[2];                         /* Checksum including header */
47         uint8_t len[2];                         /* Packet length */
48         uint8_t type;                           /* Packet type */
49         uint8_t spec;                           /* Special */
50         uint8_t src[2];                         /* Src port */
51         uint8_t dst[2];                         /* Dst port */
52         uint8_t id[4];                          /* Sequence id */
53         uint8_t ack[4];                         /* Acked sequence */
54 };
55
56 enum {
57         URG = 0x20,                                     /* Data marked urgent */
58         ACK = 0x10,     /* Aknowledge is valid */
59         PSH = 0x08,     /* Whole data pipe is pushed */
60         RST = 0x04,     /* Reset connection */
61         SYN = 0x02,     /* Pkt. is synchronise */
62         FIN = 0x01,     /* Start close down */
63
64         IP_DF = 0x4000, /* Don't fragment */
65
66         IP_TCPPROTO = 6,
67         IP_ILPROTO = 40,
68         IL_IPHDR = 20,
69 };
70
71 struct Hdr {
72         uint8_t buf[128];
73         Iphdr *ip;
74         Tcphdr *tcp;
75         int len;
76 };
77
78 struct Tcpc {
79         uint8_t lastrecv;
80         uint8_t lastxmit;
81         uint8_t basexmit;
82         uint8_t err;
83         uint8_t compressid;
84         Hdr t[MAX_STATES];
85         Hdr r[MAX_STATES];
86 };
87
88 enum {                                                  /* flag bits for what changed in a packet */
89         NEW_U = (1 << 0),                       /* tcp only */
90         NEW_W = (1 << 1),       /* tcp only */
91         NEW_A = (1 << 2),       /* il tcp */
92         NEW_S = (1 << 3),       /* tcp only */
93         NEW_P = (1 << 4),       /* tcp only */
94         NEW_I = (1 << 5),       /* il tcp */
95         NEW_C = (1 << 6),       /* il tcp */
96         NEW_T = (1 << 7),       /* il only */
97         TCP_PUSH_BIT = 0x10,
98 };
99
100 /* reserved, special-case values of above for tcp */
101 #define SPECIAL_I (NEW_S|NEW_W|NEW_U)   /* echoed interactive traffic */
102 #define SPECIAL_D (NEW_S|NEW_A|NEW_W|NEW_U)     /* unidirectional data */
103 #define SPECIALS_MASK (NEW_S|NEW_A|NEW_W|NEW_U)
104
105 int encode(void *p, uint32_t n)
106 {
107         uint8_t *cp;
108
109         cp = p;
110         if (n >= 256 || n == 0) {
111                 *cp++ = 0;
112                 cp[0] = n >> 8;
113                 cp[1] = n;
114                 return 3;
115         } else
116                 *cp = n;
117         return 1;
118 }
119
120 #define DECODEL(f) { \
121         if (*cp == 0) {\
122                 hnputl(f, nhgetl(f) + ((cp[1] << 8) | cp[2])); \
123                 cp += 3; \
124         } else { \
125                 hnputl(f, nhgetl(f) + (uint32_t)*cp++); \
126         } \
127 }
128 #define DECODES(f) { \
129         if (*cp == 0) {\
130                 hnputs(f, nhgets(f) + ((cp[1] << 8) | cp[2])); \
131                 cp += 3; \
132         } else { \
133                 hnputs(f, nhgets(f) + (uint32_t)*cp++); \
134         } \
135 }
136
137 uint16_t tcpcompress(Tcpc * comp, struct block * b, struct Fs *)
138 {
139         Iphdr *ip;                                      /* current packet */
140         Tcphdr *tcp;                            /* current pkt */
141         uint32_t iplen, tcplen, hlen;   /* header length in bytes */
142         uint32_t deltaS, deltaA;        /* general purpose temporaries */
143         uint32_t changes;                       /* change mask */
144         uint8_t new_seq[16];            /* changes from last to current */
145         uint8_t *cp;
146         Hdr *h;                                         /* last packet */
147         int i, j;
148
149         /*
150          * Bail if this is not a compressible TCP/IP packet
151          */
152         ip = (Iphdr *) b->rp;
153         iplen = (ip->vihl & 0xf) << 2;
154         tcp = (Tcphdr *) (b->rp + iplen);
155         tcplen = (tcp->flag[0] & 0xf0) >> 2;
156         hlen = iplen + tcplen;
157         if ((tcp->flag[1] & (SYN | FIN | RST | ACK)) != ACK)
158                 return Pip;     /* connection control */
159
160         /*
161          * Packet is compressible, look for a connection
162          */
163         changes = 0;
164         cp = new_seq;
165         j = comp->lastxmit;
166         h = &comp->t[j];
167         if (ip->src != h->ip->src || ip->dst != h->ip->dst
168                 || tcp->ports != h->tcp->ports) {
169                 for (i = 0; i < MAX_STATES; ++i) {
170                         j = (comp->basexmit + i) % MAX_STATES;
171                         h = &comp->t[j];
172                         if (ip->src == h->ip->src && ip->dst == h->ip->dst
173                                 && tcp->ports == h->tcp->ports)
174                                 goto found;
175                 }
176
177                 /* no connection, reuse the oldest */
178                 if (i == MAX_STATES) {
179                         j = comp->basexmit;
180                         j = (j + MAX_STATES - 1) % MAX_STATES;
181                         comp->basexmit = j;
182                         h = &comp->t[j];
183                         goto raise;
184                 }
185         }
186 found:
187
188         /*
189          * Make sure that only what we expect to change changed. 
190          */
191         if (ip->vihl != h->ip->vihl || ip->tos != h->ip->tos ||
192                 ip->ttl != h->ip->ttl || ip->proto != h->ip->proto)
193                 goto raise;     /* headers changed */
194         if (iplen != sizeof(Iphdr)
195                 && memcmp(ip + 1, h->ip + 1, iplen - sizeof(Iphdr)))
196                 goto raise;     /* ip options changed */
197         if (tcplen != sizeof(Tcphdr)
198                 && memcmp(tcp + 1, h->tcp + 1, tcplen - sizeof(Tcphdr)))
199                 goto raise;     /* tcp options changed */
200
201         if (tcp->flag[1] & URG) {
202                 cp += encode(cp, nhgets(tcp->urg));
203                 changes |= NEW_U;
204         } else if (memcmp(tcp->urg, h->tcp->urg, sizeof(tcp->urg)) != 0)
205                 goto raise;
206         if (deltaS = nhgets(tcp->win) - nhgets(h->tcp->win)) {
207                 cp += encode(cp, deltaS);
208                 changes |= NEW_W;
209         }
210         if (deltaA = nhgetl(tcp->ack) - nhgetl(h->tcp->ack)) {
211                 if (deltaA > 0xffff)
212                         goto raise;
213                 cp += encode(cp, deltaA);
214                 changes |= NEW_A;
215         }
216         if (deltaS = nhgetl(tcp->seq) - nhgetl(h->tcp->seq)) {
217                 if (deltaS > 0xffff)
218                         goto raise;
219                 cp += encode(cp, deltaS);
220                 changes |= NEW_S;
221         }
222
223         /*
224          * Look for the special-case encodings.
225          */
226         switch (changes) {
227                 case 0:
228                         /*
229                          * Nothing changed. If this packet contains data and the last
230                          * one didn't, this is probably a data packet following an
231                          * ack (normal on an interactive connection) and we send it
232                          * compressed. Otherwise it's probably a retransmit,
233                          * retransmitted ack or window probe.  Send it uncompressed
234                          * in case the other side missed the compressed version.
235                          */
236                         if (nhgets(ip->length) == nhgets(h->ip->length) ||
237                                 nhgets(h->ip->length) != hlen)
238                                 goto raise;
239                         break;
240                 case SPECIAL_I:
241                 case SPECIAL_D:
242                         /*
243                          * Actual changes match one of our special case encodings --
244                          * send packet uncompressed.
245                          */
246                         goto raise;
247                 case NEW_S | NEW_A:
248                         if (deltaS == deltaA && deltaS == nhgets(h->ip->length) - hlen) {
249                                 /* special case for echoed terminal traffic */
250                                 changes = SPECIAL_I;
251                                 cp = new_seq;
252                         }
253                         break;
254                 case NEW_S:
255                         if (deltaS == nhgets(h->ip->length) - hlen) {
256                                 /* special case for data xfer */
257                                 changes = SPECIAL_D;
258                                 cp = new_seq;
259                         }
260                         break;
261         }
262         deltaS = nhgets(ip->id) - nhgets(h->ip->id);
263         if (deltaS != 1) {
264                 cp += encode(cp, deltaS);
265                 changes |= NEW_I;
266         }
267         if (tcp->flag[1] & PSH)
268                 changes |= TCP_PUSH_BIT;
269         /*
270          * Grab the cksum before we overwrite it below. Then update our
271          * state with this packet's header.
272          */
273         deltaA = nhgets(tcp->cksum);
274         memmove(h->buf, b->rp, hlen);
275         h->len = hlen;
276         h->tcp = (Tcphdr *) (h->buf + iplen);
277
278         /*
279          * We want to use the original packet as our compressed packet. (cp -
280          * new_seq) is the number of bytes we need for compressed sequence
281          * numbers. In addition we need one byte for the change mask, one
282          * for the connection id and two for the tcp checksum. So, (cp -
283          * new_seq) + 4 bytes of header are needed. hlen is how many bytes
284          * of the original packet to toss so subtract the two to get the new
285          * packet size. The temporaries are gross -egs.
286          */
287         deltaS = cp - new_seq;
288         cp = b->rp;
289         if (comp->lastxmit != j || comp->compressid == 0) {
290                 comp->lastxmit = j;
291                 hlen -= deltaS + 4;
292                 cp += hlen;
293                 *cp++ = (changes | NEW_C);
294                 *cp++ = j;
295         } else {
296                 hlen -= deltaS + 3;
297                 cp += hlen;
298                 *cp++ = changes;
299         }
300         b->rp += hlen;
301         hnputs(cp, deltaA);
302         cp += 2;
303         memmove(cp, new_seq, deltaS);
304         return Pvjctcp;
305
306 raise:
307         /*
308          * Update connection state & send uncompressed packet
309          */
310         memmove(h->buf, b->rp, hlen);
311         h->tcp = (Tcphdr *) (h->buf + iplen);
312         h->len = hlen;
313         h->ip->proto = j;
314         comp->lastxmit = j;
315         return Pvjutcp;
316 }
317
318 struct block *tcpuncompress(Tcpc * comp, struct block *b, uint16_t type,
319                                                         struct Fs *f)
320 {
321         uint8_t *cp, changes;
322         int i;
323         int iplen, len;
324         Iphdr *ip;
325         Tcphdr *tcp;
326         Hdr *h;
327
328         if (type == Pvjutcp) {
329                 /*
330                  *  Locate the saved state for this connection. If the state
331                  *  index is legal, clear the 'discard' flag.
332                  */
333                 ip = (Iphdr *) b->rp;
334                 if (ip->proto >= MAX_STATES)
335                         goto raise;
336                 iplen = (ip->vihl & 0xf) << 2;
337                 tcp = (Tcphdr *) (b->rp + iplen);
338                 comp->lastrecv = ip->proto;
339                 len = iplen + ((tcp->flag[0] & 0xf0) >> 2);
340                 comp->err = 0;
341                 netlog(f, Logcompress, "uncompressed %d\n", comp->lastrecv);
342                 /*
343                  * Restore the IP protocol field then save a copy of this
344                  * packet header. The checksum is zeroed in the copy so we
345                  * don't have to zero it each time we process a compressed
346                  * packet.
347                  */
348                 ip->proto = IP_TCPPROTO;
349                 h = &comp->r[comp->lastrecv];
350                 memmove(h->buf, b->rp, len);
351                 h->tcp = (Tcphdr *) (h->buf + iplen);
352                 h->len = len;
353                 h->ip->cksum[0] = h->ip->cksum[1] = 0;
354                 return b;
355         }
356
357         cp = b->rp;
358         changes = *cp++;
359         if (changes & NEW_C) {
360                 /*
361                  * Make sure the state index is in range, then grab the
362                  * state. If we have a good state index, clear the 'discard'
363                  * flag.
364                  */
365                 if (*cp >= MAX_STATES)
366                         goto raise;
367                 comp->err = 0;
368                 comp->lastrecv = *cp++;
369                 netlog(f, Logcompress, "newc %d\n", comp->lastrecv);
370         } else {
371                 /*
372                  * This packet has no state index. If we've had a
373                  * line error since the last time we got an explicit state
374                  * index, we have to toss the packet.
375                  */
376                 if (comp->err != 0) {
377                         freeblist(b);
378                         return NULL;
379                 }
380                 netlog(f, Logcompress, "oldc %d\n", comp->lastrecv);
381         }
382
383         /*
384          * Find the state then fill in the TCP checksum and PUSH bit.
385          */
386         h = &comp->r[comp->lastrecv];
387         ip = h->ip;
388         tcp = h->tcp;
389         len = h->len;
390         memmove(tcp->cksum, cp, sizeof tcp->cksum);
391         cp += 2;
392         if (changes & TCP_PUSH_BIT)
393                 tcp->flag[1] |= PSH;
394         else
395                 tcp->flag[1] &= ~PSH;
396         /*
397          * Fix up the state's ack, seq, urg and win fields based on the
398          * changemask.
399          */
400         switch (changes & SPECIALS_MASK) {
401                 case SPECIAL_I:
402                         i = nhgets(ip->length) - len;
403                         hnputl(tcp->ack, nhgetl(tcp->ack) + i);
404                         hnputl(tcp->seq, nhgetl(tcp->seq) + i);
405                         break;
406
407                 case SPECIAL_D:
408                         hnputl(tcp->seq, nhgetl(tcp->seq) + nhgets(ip->length) - len);
409                         break;
410
411                 default:
412                         if (changes & NEW_U) {
413                                 tcp->flag[1] |= URG;
414                                 if (*cp == 0) {
415                                         hnputs(tcp->urg, nhgets(cp + 1));
416                                         cp += 3;
417                                 } else {
418                                         hnputs(tcp->urg, *cp++);
419                                 }
420                         } else {
421                                 tcp->flag[1] &= ~URG;
422                         }
423                         if (changes & NEW_W)
424                                 DECODES(tcp->win)
425                                         if (changes & NEW_A)
426                                         DECODEL(tcp->ack)
427                                                 if (changes & NEW_S)
428                                                 DECODEL(tcp->seq)
429                                                         break;
430         }
431
432         /* Update the IP ID */
433         if (changes & NEW_I)
434                 DECODES(ip->id)
435                         else
436                 hnputs(ip->id, nhgets(ip->id) + 1);
437
438         /*
439          *  At this po int unused_int, cp points to the first byte of data in the packet.
440          *  Back up cp by the TCP/IP header length to make room for the
441          *  reconstructed header.
442          *  We assume the packet we were handed has enough space to prepend
443          *  up to 128 bytes of header.
444          */
445         b->rp = cp;
446         if (b->rp - b->base < len) {
447                 b = padblock(b, len);
448                 b = pullupblock(b, blocklen(b));
449         } else
450                 b->rp -= len;
451         hnputs(ip->length, BLEN(b));
452         memmove(b->rp, ip, len);
453
454         /* recompute the ip header checksum */
455         ip = (Iphdr *) b->rp;
456         hnputs(ip->cksum, ipcsum(b->rp));
457         return b;
458
459 raise:
460         netlog(f, Logcompress, "Bad Packet!\n");
461         comp->err = 1;
462         freeblist(b);
463         return NULL;
464 }
465
466 Tcpc *compress_init(Tcpc * c)
467 {
468         int i;
469         Hdr *h;
470
471         if (c == NULL) {
472                 c = kzmalloc(sizeof(Tcpc), 0);
473                 if (c == NULL)
474                         return NULL;
475         }
476         memset(c, 0, sizeof(*c));
477         for (i = 0; i < MAX_STATES; i++) {
478                 h = &c->t[i];
479                 h->ip = (Iphdr *) h->buf;
480                 h->tcp = (Tcphdr *) (h->buf + 10);
481                 h->len = 20;
482                 h = &c->r[i];
483                 h->ip = (Iphdr *) h->buf;
484                 h->tcp = (Tcphdr *) (h->buf + 10);
485                 h->len = 20;
486         }
487
488         return c;
489 }
490
491 uint16_t compress(Tcpc * tcp, struct block * b, struct Fs * f)
492 {
493         Iphdr *ip;
494
495         /*
496          * Bail if this is not a compressible IP packet
497          */
498         ip = (Iphdr *) b->rp;
499         if ((nhgets(ip->frag) & 0x3fff) != 0)
500                 return Pip;
501
502         switch (ip->proto) {
503                 case IP_TCPPROTO:
504                         return tcpcompress(tcp, b, f);
505                 default:
506                         return Pip;
507         }
508 }
509
510 int compress_negotiate(Tcpc * tcp, uint8_t * data)
511 {
512         if (data[0] != MAX_STATES - 1)
513                 return -1;
514         tcp->compressid = data[1];
515         return 0;
516 }