Implement TSO
[akaros.git] / kern / src / net / tcp.c
index 5f8c942..ec405c4 100644 (file)
@@ -80,6 +80,7 @@ enum {
        RETRAN = 4,
        ACTIVE = 8,
        SYNACK = 16,
+       TSO = 32,
 
        LOGAGAIN = 3,
        LOGDGAIN = 2,
@@ -602,7 +603,7 @@ void tcpacktimer(void *v)
 static void tcpcreate(struct conv *c)
 {
        c->rq = qopen(QMAX, Qcoalesce, tcpacktimer, c);
-       c->wq = qopen((3 * QMAX) / 2, Qkick, tcpkick, c);
+       c->wq = qopen(8 * QMAX, Qkick, tcpkick, c);
 }
 
 static void timerstate(struct tcppriv *priv, Tcptimer * t, int newstate)
@@ -736,17 +737,20 @@ void localclose(struct conv *s, char *reason)
 
        if (tcb->state == Syn_sent)
                Fsconnected(s, reason);
-       if (s->state == Announced)
-               rendez_wakeup(&s->listenr);
 
        qhangup(s->rq, reason);
        qhangup(s->wq, reason);
 
        tcpsetstate(s, Closed);
+
+       /* listener will check the rq state */
+       if (s->state == Announced)
+               rendez_wakeup(&s->listenr);
 }
 
 /* mtu (- TCP + IP hdr len) of 1st hop */
-int tcpmtu(struct Proto *tcp, uint8_t * addr, int version, int *scale)
+int tcpmtu(struct Proto *tcp, uint8_t * addr, int version, int *scale,
+          uint8_t *flags)
 {
        struct Ipifc *ifc;
        int mtu;
@@ -765,6 +769,8 @@ int tcpmtu(struct Proto *tcp, uint8_t * addr, int version, int *scale)
                                mtu = ifc->maxtu - ifc->m->hsize - (TCP6_PKT + TCP6_HDRSIZE);
                        break;
        }
+       *flags &= ~TSO;
+
        if (ifc != NULL) {
                if (ifc->mbps > 100)
                        *scale = HaveWS | 3;
@@ -772,6 +778,8 @@ int tcpmtu(struct Proto *tcp, uint8_t * addr, int version, int *scale)
                        *scale = HaveWS | 1;
                else
                        *scale = HaveWS | 0;
+               if (ifc->feat & NETF_TSO)
+                       *flags |= TSO;
        } else
                *scale = HaveWS | 0;
 
@@ -854,14 +862,15 @@ void tcpstart(struct conv *s, int mode)
 {
        Tcpctl *tcb;
        struct tcppriv *tpriv;
-       char kpname[KNAMELEN];
+       /* tcpackproc needs to free this if it ever exits */
+       char *kpname = kmalloc(KNAMELEN, KMALLOC_WAIT);
 
        tpriv = s->p->priv;
 
        if (tpriv->ackprocstarted == 0) {
                qlock(&tpriv->apl);
                if (tpriv->ackprocstarted == 0) {
-                       snprintf(kpname, sizeof(kpname), "#I%dtcpack", s->p->f->dev);
+                       snprintf(kpname, KNAMELEN, "#I%dtcpack", s->p->f->dev);
                        ktask(kpname, tcpackproc, s->p);
                        tpriv->ackprocstarted = 1;
                }
@@ -1060,8 +1069,11 @@ struct block *htontcp4(Tcp * tcph, struct block *data, Tcp4hdr * ph,
        if (tcb != NULL && tcb->nochecksum) {
                h->tcpcksum[0] = h->tcpcksum[1] = 0;
        } else {
-               csum = ptclcsum(data, TCP4_IPLEN, hdrlen + dlen + TCP4_PHDRSIZE);
+               csum = ~ptclcsum(data, TCP4_IPLEN, TCP4_PHDRSIZE);
                hnputs(h->tcpcksum, csum);
+               data->checksum_start = TCP4_IPLEN + TCP4_PHDRSIZE;
+               data->checksum_offset = ph->tcpcksum - ph->tcpsport;
+               data->flag |= Btcpck;
        }
 
        return data;
@@ -1207,7 +1219,8 @@ void tcpsndsyn(struct conv *s, Tcpctl * tcb)
        tcb->sndsyntime = NOW;
 
        /* set desired mss and scale */
-       tcb->mss = tcpmtu(s->p, s->laddr, s->ipversion, &tcb->scale);
+       tcb->mss = tcpmtu(s->p, s->laddr, s->ipversion, &tcb->scale,
+                         &tcb->flags);
 }
 
 void
@@ -1309,7 +1322,7 @@ char *tcphangup(struct conv *s)
                poperror();
                return commonerror();
        }
-       if (s->raddr != 0) {
+       if (ipcmp(s->raddr, IPnoaddr)) {
                /* discard error style, poperror regardless */
                if (!waserror()) {
                        seg.flags = RST | ACK;
@@ -1352,6 +1365,7 @@ int sndsynack(struct Proto *tcp, Limbo * lp)
        Tcp6hdr ph6;
        Tcp seg;
        int scale;
+       uint8_t flag = 0;
 
        /* make pseudo header */
        switch (lp->version) {
@@ -1383,7 +1397,7 @@ int sndsynack(struct Proto *tcp, Limbo * lp)
        seg.ack = lp->irs + 1;
        seg.flags = SYN | ACK;
        seg.urg = 0;
-       seg.mss = tcpmtu(tcp, lp->laddr, lp->version, &scale);
+       seg.mss = tcpmtu(tcp, lp->laddr, lp->version, &scale, &flag);
        seg.wnd = QMAX;
 
        /* if the other side set scale, we should too */
@@ -2439,8 +2453,33 @@ void tcpoutput(struct conv *s)
                                   tcb->snd.wnd, tcb->cwind);
                if (usable < ssize)
                        ssize = usable;
-               if (tcb->mss < ssize)
-                       ssize = tcb->mss;
+               if (ssize > tcb->mss) {
+                       if ((tcb->flags & TSO) == 0) {
+                               ssize = tcb->mss;
+                       } else {
+                               int segs, window;
+
+                               /*  Don't send too much.  32K is arbitrary..
+                                */
+                               if (ssize > 32 * 1024)
+                                       ssize = 32 * 1024;
+
+                               /* Clamp xmit to an integral MSS to
+                                * avoid ragged tail segments causing
+                                * poor link utilization.  Also
+                                * account for each segment sent in
+                                * msg heuristic, and round up to the
+                                * next multiple of 4, to ensure we
+                                * still yeild.
+                                */
+                               segs = ssize / tcb->mss;
+                               ssize = segs * tcb->mss;
+                               msgs += segs;
+                               if (segs > 3)
+                                       msgs = (msgs + 4) & ~3;
+                       }
+               }
+
                dsize = ssize;
                seg.urg = 0;
 
@@ -2496,6 +2535,10 @@ void tcpoutput(struct conv *s)
                                seg.flags |= FIN;
                                dsize--;
                        }
+                       if (BLEN(bp) > tcb->mss) {
+                               bp->flag |= Btso;
+                               bp->mss = tcb->mss;
+                       }
                }
 
                if (sent + dsize == sndcnt)
@@ -3051,7 +3094,7 @@ int tcpstats(struct Proto *tcp, char *buf, int len)
        p = buf;
        e = p + len;
        for (i = 0; i < Nstats; i++)
-               p = seprintf(p, e, "%s: %lu\n", statnames[i], priv->stats[i]);
+               p = seprintf(p, e, "%s: %u\n", statnames[i], priv->stats[i]);
        return p - buf;
 }