Getting more of the IP stack to compile
[akaros.git] / kern / src / net / compress.c
1 // INFERNO
2 #include <vfs.h>
3 #include <kfs.h>
4 #include <slab.h>
5 #include <kmalloc.h>
6 #include <kref.h>
7 #include <string.h>
8 #include <stdio.h>
9 #include <assert.h>
10 #include <error.h>
11 #include <cpio.h>
12 #include <pmap.h>
13 #include <smp.h>
14 #include <ip.h>
15
16 typedef struct Iphdr    Iphdr;
17 typedef struct Tcphdr   Tcphdr;
18 typedef struct Ilhdr    Ilhdr;
19 typedef struct Hdr      Hdr;
20 typedef struct Tcpc     Tcpc;
21
22 struct Iphdr
23 {
24         uint8_t vihl;           /* Version and header length */
25         uint8_t tos;            /* Type of service */
26         uint8_t length[2];      /* packet length */
27         uint8_t id[2];          /* Identification */
28         uint8_t frag[2];        /* Fragment information */
29         uint8_t ttl;            /* Time to live */
30         uint8_t proto;          /* Protocol */
31         uint8_t cksum[2];       /* Header checksum */
32         uint32_t        src;            /* Ip source (byte ordering unimportant) */
33         uint32_t        dst;            /* Ip destination (byte ordering unimportant) */
34 };
35
36 struct Tcphdr
37 {
38         uint32_t        ports;          /* defined as a uint32_t to make comparisons easier */
39         uint8_t seq[4];
40         uint8_t ack[4];
41         uint8_t flag[2];
42         uint8_t win[2];
43         uint8_t cksum[2];
44         uint8_t urg[2];
45 };
46
47 struct Ilhdr
48 {
49         uint8_t sum[2]; /* Checksum including header */
50         uint8_t len[2]; /* Packet length */
51         uint8_t type;           /* Packet type */
52         uint8_t spec;           /* Special */
53         uint8_t src[2]; /* Src port */
54         uint8_t dst[2]; /* Dst port */
55         uint8_t id[4];  /* Sequence id */
56         uint8_t ack[4]; /* Acked sequence */
57 };
58
59 enum
60 {
61         URG             = 0x20,         /* Data marked urgent */
62         ACK             = 0x10,         /* Aknowledge is valid */
63         PSH             = 0x08,         /* Whole data pipe is pushed */
64         RST             = 0x04,         /* Reset connection */
65         SYN             = 0x02,         /* Pkt. is synchronise */
66         FIN             = 0x01,         /* Start close down */
67
68         IP_DF           = 0x4000,       /* Don't fragment */
69
70         IP_TCPPROTO     = 6,
71         IP_ILPROTO      = 40,
72         IL_IPHDR        = 20,
73 };
74
75 struct Hdr
76 {
77         uint8_t buf[128];
78         Iphdr   *ip;
79         Tcphdr  *tcp;
80         int     len;
81 };
82
83 struct Tcpc
84 {
85         uint8_t lastrecv;
86         uint8_t lastxmit;
87         uint8_t basexmit;
88         uint8_t err;
89         uint8_t compressid;
90         Hdr     t[MAX_STATES];
91         Hdr     r[MAX_STATES];
92 };
93
94 enum
95 {       /* flag bits for what changed in a packet */
96         NEW_U=(1<<0),   /* tcp only */
97         NEW_W=(1<<1),   /* tcp only */
98         NEW_A=(1<<2),   /* il tcp */
99         NEW_S=(1<<3),   /* tcp only */
100         NEW_P=(1<<4),   /* tcp only */
101         NEW_I=(1<<5),   /* il tcp */
102         NEW_C=(1<<6),   /* il tcp */
103         NEW_T=(1<<7),   /* il only */
104         TCP_PUSH_BIT    = 0x10,
105 };
106
107 /* reserved, special-case values of above for tcp */
108 #define SPECIAL_I (NEW_S|NEW_W|NEW_U)           /* echoed interactive traffic */
109 #define SPECIAL_D (NEW_S|NEW_A|NEW_W|NEW_U)     /* unidirectional data */
110 #define SPECIALS_MASK (NEW_S|NEW_A|NEW_W|NEW_U)
111
112 int
113 encode(void *p, uint32_t n)
114 {
115         uint8_t *cp;
116
117         cp = p;
118         if(n >= 256 || n == 0) {
119                 *cp++ = 0;
120                 cp[0] = n >> 8;
121                 cp[1] = n;
122                 return 3;
123         } else 
124                 *cp = n;
125         return 1;
126 }
127
128 #define DECODEL(f) { \
129         if (*cp == 0) {\
130                 hnputl(f, nhgetl(f) + ((cp[1] << 8) | cp[2])); \
131                 cp += 3; \
132         } else { \
133                 hnputl(f, nhgetl(f) + (uint32_t)*cp++); \
134         } \
135 }
136 #define DECODES(f) { \
137         if (*cp == 0) {\
138                 hnputs(f, nhgets(f) + ((cp[1] << 8) | cp[2])); \
139                 cp += 3; \
140         } else { \
141                 hnputs(f, nhgets(f) + (uint32_t)*cp++); \
142         } \
143 }
144
145 uint16_t
146 tcpcompress(Tcpc *comp, struct block *b, struct Fs *)
147 {
148         Iphdr   *ip;            /* current packet */
149         Tcphdr  *tcp;           /* current pkt */
150         uint32_t        iplen, tcplen, hlen;    /* header length in bytes */
151         uint32_t        deltaS, deltaA; /* general purpose temporaries */
152         uint32_t        changes;        /* change mask */
153         uint8_t new_seq[16];    /* changes from last to current */
154         uint8_t *cp;
155         Hdr     *h;             /* last packet */
156         int     i, j;
157
158         /*
159          * Bail if this is not a compressible TCP/IP packet
160          */
161         ip = (Iphdr*)b->rp;
162         iplen = (ip->vihl & 0xf) << 2;
163         tcp = (Tcphdr*)(b->rp + iplen);
164         tcplen = (tcp->flag[0] & 0xf0) >> 2;
165         hlen = iplen + tcplen;
166         if((tcp->flag[1] & (SYN|FIN|RST|ACK)) != ACK)
167                 return Pip;     /* connection control */
168
169         /*
170          * Packet is compressible, look for a connection
171          */
172         changes = 0;
173         cp = new_seq;
174         j = comp->lastxmit;
175         h = &comp->t[j];
176         if(ip->src != h->ip->src || ip->dst != h->ip->dst
177         || tcp->ports != h->tcp->ports) {
178                 for(i = 0; i < MAX_STATES; ++i) {
179                         j = (comp->basexmit + i) % MAX_STATES;
180                         h = &comp->t[j];
181                         if(ip->src == h->ip->src && ip->dst == h->ip->dst
182                         && tcp->ports == h->tcp->ports)
183                                 goto found;
184                 }
185
186                 /* no connection, reuse the oldest */
187                 if(i == MAX_STATES) {
188                         j = comp->basexmit;
189                         j = (j + MAX_STATES - 1) % MAX_STATES;
190                         comp->basexmit = j;
191                         h = &comp->t[j];
192                         goto raise;
193                 }
194         }
195 found:
196
197         /*
198          * Make sure that only what we expect to change changed. 
199          */
200         if(ip->vihl  != h->ip->vihl || ip->tos   != h->ip->tos ||
201            ip->ttl   != h->ip->ttl  || ip->proto != h->ip->proto)
202                 goto raise;     /* headers changed */
203         if(iplen != sizeof(Iphdr) && memcmp(ip+1, h->ip+1, iplen - sizeof(Iphdr)))
204                 goto raise;     /* ip options changed */
205         if(tcplen != sizeof(Tcphdr) && memcmp(tcp+1, h->tcp+1, tcplen - sizeof(Tcphdr)))
206                 goto raise;     /* tcp options changed */
207
208         if(tcp->flag[1] & URG) {
209                 cp += encode(cp, nhgets(tcp->urg));
210                 changes |= NEW_U;
211         } else if(memcmp(tcp->urg, h->tcp->urg, sizeof(tcp->urg)) != 0)
212                 goto raise;
213         if(deltaS = nhgets(tcp->win) - nhgets(h->tcp->win)) {
214                 cp += encode(cp, deltaS);
215                 changes |= NEW_W;
216         }
217         if(deltaA = nhgetl(tcp->ack) - nhgetl(h->tcp->ack)) {
218                 if(deltaA > 0xffff)
219                         goto raise;
220                 cp += encode(cp, deltaA);
221                 changes |= NEW_A;
222         }
223         if(deltaS = nhgetl(tcp->seq) - nhgetl(h->tcp->seq)) {
224                 if (deltaS > 0xffff)
225                         goto raise;
226                 cp += encode(cp, deltaS);
227                 changes |= NEW_S;
228         }
229
230         /*
231          * Look for the special-case encodings.
232          */
233         switch(changes) {
234         case 0:
235                 /*
236                  * Nothing changed. If this packet contains data and the last
237                  * one didn't, this is probably a data packet following an
238                  * ack (normal on an interactive connection) and we send it
239                  * compressed. Otherwise it's probably a retransmit,
240                  * retransmitted ack or window probe.  Send it uncompressed
241                  * in case the other side missed the compressed version.
242                  */
243                 if(nhgets(ip->length) == nhgets(h->ip->length) ||
244                    nhgets(h->ip->length) != hlen)
245                         goto raise;
246                 break;
247         case SPECIAL_I:
248         case SPECIAL_D:
249                 /*
250                  * Actual changes match one of our special case encodings --
251                  * send packet uncompressed.
252                  */
253                 goto raise;
254         case NEW_S | NEW_A:
255                 if (deltaS == deltaA &&
256                         deltaS == nhgets(h->ip->length) - hlen) {
257                         /* special case for echoed terminal traffic */
258                         changes = SPECIAL_I;
259                         cp = new_seq;
260                 }
261                 break;
262         case NEW_S:
263                 if (deltaS == nhgets(h->ip->length) - hlen) {
264                         /* special case for data xfer */
265                         changes = SPECIAL_D;
266                         cp = new_seq;
267                 }
268                 break;
269         }
270         deltaS = nhgets(ip->id) - nhgets(h->ip->id);
271         if(deltaS != 1) {
272                 cp += encode(cp, deltaS);
273                 changes |= NEW_I;
274         }
275         if (tcp->flag[1] & PSH)
276                 changes |= TCP_PUSH_BIT;
277         /*
278          * Grab the cksum before we overwrite it below. Then update our
279          * state with this packet's header.
280          */
281         deltaA = nhgets(tcp->cksum);
282         memmove(h->buf, b->rp, hlen);
283         h->len = hlen;
284         h->tcp = (Tcphdr*)(h->buf + iplen);
285
286         /*
287          * We want to use the original packet as our compressed packet. (cp -
288          * new_seq) is the number of bytes we need for compressed sequence
289          * numbers. In addition we need one byte for the change mask, one
290          * for the connection id and two for the tcp checksum. So, (cp -
291          * new_seq) + 4 bytes of header are needed. hlen is how many bytes
292          * of the original packet to toss so subtract the two to get the new
293          * packet size. The temporaries are gross -egs.
294          */
295         deltaS = cp - new_seq;
296         cp = b->rp;
297         if(comp->lastxmit != j || comp->compressid == 0) {
298                 comp->lastxmit = j;
299                 hlen -= deltaS + 4;
300                 cp += hlen;
301                 *cp++ = (changes | NEW_C);
302                 *cp++ = j;
303         } else {
304                 hlen -= deltaS + 3;
305                 cp += hlen;
306                 *cp++ = changes;
307         }
308         b->rp += hlen;
309         hnputs(cp, deltaA);
310         cp += 2;
311         memmove(cp, new_seq, deltaS);
312         return Pvjctcp;
313
314 raise:
315         /*
316          * Update connection state & send uncompressed packet
317          */
318         memmove(h->buf, b->rp, hlen);
319         h->tcp = (Tcphdr*)(h->buf + iplen);
320         h->len = hlen;
321         h->ip->proto = j;
322         comp->lastxmit = j;
323         return Pvjutcp;
324 }
325
326 struct block*
327 tcpuncompress(Tcpc *comp, struct block *b, uint16_t type, struct Fs *f)
328 {
329         uint8_t *cp, changes;
330         int     i;
331         int     iplen, len;
332         Iphdr   *ip;
333         Tcphdr  *tcp;
334         Hdr     *h;
335
336         if(type == Pvjutcp) {
337                 /*
338                  *  Locate the saved state for this connection. If the state
339                  *  index is legal, clear the 'discard' flag.
340                  */
341                 ip = (Iphdr*)b->rp;
342                 if(ip->proto >= MAX_STATES)
343                         goto raise;
344                 iplen = (ip->vihl & 0xf) << 2;
345                 tcp = (Tcphdr*)(b->rp + iplen);
346                 comp->lastrecv = ip->proto;
347                 len = iplen + ((tcp->flag[0] & 0xf0) >> 2);
348                 comp->err = 0;
349 netlog(f, Logcompress, "uncompressed %d\n", comp->lastrecv);
350                 /*
351                  * Restore the IP protocol field then save a copy of this
352                  * packet header. The checksum is zeroed in the copy so we
353                  * don't have to zero it each time we process a compressed
354                  * packet.
355                  */
356                 ip->proto = IP_TCPPROTO;
357                 h = &comp->r[comp->lastrecv];
358                 memmove(h->buf, b->rp, len);
359                 h->tcp = (Tcphdr*)(h->buf + iplen);
360                 h->len = len;
361                 h->ip->cksum[0] = h->ip->cksum[1] = 0;
362                 return b;
363         }
364
365         cp = b->rp;
366         changes = *cp++;
367         if(changes & NEW_C) {
368                 /*
369                  * Make sure the state index is in range, then grab the
370                  * state. If we have a good state index, clear the 'discard'
371                  * flag.
372                  */
373                 if(*cp >= MAX_STATES)
374                         goto raise;
375                 comp->err = 0;
376                 comp->lastrecv = *cp++;
377 netlog(f, Logcompress, "newc %d\n", comp->lastrecv);
378         } else {
379                 /*
380                  * This packet has no state index. If we've had a
381                  * line error since the last time we got an explicit state
382                  * index, we have to toss the packet.
383                  */
384                 if(comp->err != 0){
385                         freeblist(b);
386                         return NULL;
387                 }
388 netlog(f, Logcompress, "oldc %d\n", comp->lastrecv);
389         }
390
391         /*
392          * Find the state then fill in the TCP checksum and PUSH bit.
393          */
394         h = &comp->r[comp->lastrecv];
395         ip = h->ip;
396         tcp = h->tcp;
397         len = h->len;
398         memmove(tcp->cksum, cp, sizeof tcp->cksum);
399         cp += 2;
400         if(changes & TCP_PUSH_BIT)
401                 tcp->flag[1] |= PSH;
402         else
403                 tcp->flag[1] &= ~PSH;
404         /*
405          * Fix up the state's ack, seq, urg and win fields based on the
406          * changemask.
407          */
408         switch (changes & SPECIALS_MASK) {
409         case SPECIAL_I:
410                 i = nhgets(ip->length) - len;
411                 hnputl(tcp->ack, nhgetl(tcp->ack) + i);
412                 hnputl(tcp->seq, nhgetl(tcp->seq) + i);
413                 break;
414
415         case SPECIAL_D:
416                 hnputl(tcp->seq, nhgetl(tcp->seq) + nhgets(ip->length) - len);
417                 break;
418
419         default:
420                 if(changes & NEW_U) {
421                         tcp->flag[1] |= URG;
422                         if(*cp == 0){
423                                 hnputs(tcp->urg, nhgets(cp+1));
424                                 cp += 3;
425                         }else
426                                 hnputs(tcp->urg, *cp++);
427                 } else
428                         tcp->flag[1] &= ~URG;
429                 if(changes & NEW_W)
430                         DECODES(tcp->win)
431                 if(changes & NEW_A)
432                         DECODEL(tcp->ack)
433                 if(changes & NEW_S)
434                         DECODEL(tcp->seq)
435                 break;
436         }
437
438         /* Update the IP ID */
439         if(changes & NEW_I)
440                 DECODES(ip->id)
441         else
442                 hnputs(ip->id, nhgets(ip->id) + 1);
443
444         /*
445          *  At this po int unused_int, cp points to the first byte of data in the packet.
446          *  Back up cp by the TCP/IP header length to make room for the
447          *  reconstructed header.
448          *  We assume the packet we were handed has enough space to prepend
449          *  up to 128 bytes of header.
450          */
451         b->rp = cp;
452         if(b->rp - b->base < len){
453                 b = padblock(b, len);
454                 b = pullupblock(b, blocklen(b));
455         } else
456                 b->rp -= len;
457         hnputs(ip->length, BLEN(b));
458         memmove(b->rp, ip, len);
459         
460         /* recompute the ip header checksum */
461         ip = (Iphdr*)b->rp;
462         hnputs(ip->cksum, ipcsum(b->rp));
463         return b;
464
465 raise:
466         netlog(f, Logcompress, "Bad Packet!\n");
467         comp->err = 1;
468         freeblist(b);
469         return NULL;
470 }
471
472 Tcpc*
473 compress_init(Tcpc *c)
474 {
475         int i;
476         Hdr *h;
477
478         if(c == NULL){
479                 c = kzmalloc(sizeof(Tcpc), 0);
480                 if(c == NULL)
481                         return NULL;
482         }
483         memset(c, 0, sizeof(*c));
484         for(i = 0; i < MAX_STATES; i++){
485                 h = &c->t[i];
486                 h->ip = (Iphdr*)h->buf;
487                 h->tcp = (Tcphdr*)(h->buf + 10);
488                 h->len = 20;
489                 h = &c->r[i];
490                 h->ip = (Iphdr*)h->buf;
491                 h->tcp = (Tcphdr*)(h->buf + 10);
492                 h->len = 20;
493         }
494
495         return c;
496 }
497
498 uint16_t
499 compress(Tcpc *tcp, struct block *b, struct Fs *f)
500 {
501         Iphdr           *ip;
502
503         /*
504          * Bail if this is not a compressible IP packet
505          */
506         ip = (Iphdr*)b->rp;
507         if((nhgets(ip->frag) & 0x3fff) != 0)
508                 return Pip;
509
510         switch(ip->proto) {
511         case IP_TCPPROTO:
512                 return tcpcompress(tcp, b, f);
513         default:
514                 return Pip;
515         }
516 }
517
518 int
519 compress_negotiate(Tcpc *tcp, uint8_t *data)
520 {
521         if(data[0] != MAX_STATES - 1)
522                 return -1;
523         tcp->compressid = data[1];
524         return 0;
525 }