net: Fix double-free snoop bug
[akaros.git] / kern / src / net / tcp.c
index 6dea5ce..d544fd9 100644 (file)
@@ -463,16 +463,10 @@ void tcpsetstate(struct conv *s, uint8_t newstate)
                Fsconnected(s, NULL);
 }
 
-static char *tcpconnect(struct conv *c, char **argv, int argc)
+static void tcpconnect(struct conv *c, char **argv, int argc)
 {
-       char *e;
-
-       e = Fsstdconnect(c, argv, argc);
-       if (e != NULL)
-               return e;
+       Fsstdconnect(c, argv, argc);
        tcpstart(c, TCP_CONNECT);
-
-       return NULL;
 }
 
 static int tcpstate(struct conv *c, char *state, int n)
@@ -500,17 +494,40 @@ static int tcpinuse(struct conv *c)
        return s->state != Closed;
 }
 
-static char *tcpannounce(struct conv *c, char **argv, int argc)
+static void tcpannounce(struct conv *c, char **argv, int argc)
 {
-       char *e;
-
-       e = Fsstdannounce(c, argv, argc);
-       if (e != NULL)
-               return e;
+       Fsstdannounce(c, argv, argc);
        tcpstart(c, TCP_LISTEN);
        Fsconnected(c, NULL);
+}
+
+static void tcpbypass(struct conv *cv, char **argv, int argc)
+{
+       struct tcppriv *tpriv = cv->p->priv;
 
-       return NULL;
+       Fsstdbypass(cv, argv, argc);
+       iphtadd(&tpriv->ht, cv);
+}
+
+static void tcpshutdown(struct conv *c, int how)
+{
+       Tcpctl *tcb = (Tcpctl*)c->ptcl;
+
+       /* Do nothing for the read side */
+       if (how == SHUT_RD)
+               return;
+       /* Sends a FIN.  If we're in another state (like Listen), we'll run into
+        * issues, since we'll never send the FIN.  We'll be shutdown on our end,
+        * but we'll never tell the distant end.  Might just be an app issue. */
+       switch (tcb->state) {
+       case Syn_received:
+       case Established:
+               tcb->flgcnt++;
+               tcb->snd.nxt++;
+               tcpsetstate(c, Finwait1);
+               tcpoutput(c);
+               break;
+       }
 }
 
 /*
@@ -564,11 +581,11 @@ void tcpkick(void *x)
 
        tcb = (Tcpctl *) s->ptcl;
 
+       qlock(&s->qlock);
        if (waserror()) {
                qunlock(&s->qlock);
                nexterror();
        }
-       qlock(&s->qlock);
 
        switch (tcb->state) {
                case Syn_sent:
@@ -613,11 +630,11 @@ void tcpacktimer(void *v)
        s = v;
        tcb = (Tcpctl *) s->ptcl;
 
+       qlock(&s->qlock);
        if (waserror()) {
                qunlock(&s->qlock);
                nexterror();
        }
-       qlock(&s->qlock);
        if (tcb->state != Closed) {
                tcb->flags |= FORCE;
                tcprcvwin(s);
@@ -629,7 +646,7 @@ void tcpacktimer(void *v)
 
 static void tcpcreate(struct conv *c)
 {
-       c->rq = qopen(QMAX, Qcoalesce, tcpacktimer, c);
+       c->rq = qopen(QMAX, Qcoalesce, 0, 0);
        c->wq = qopen(8 * QMAX, Qkick, tcpkick, c);
 }
 
@@ -890,7 +907,7 @@ void tcpstart(struct conv *s, int mode)
        Tcpctl *tcb;
        struct tcppriv *tpriv;
        /* tcpackproc needs to free this if it ever exits */
-       char *kpname = kmalloc(KNAMELEN, KMALLOC_WAIT);
+       char *kpname = kmalloc(KNAMELEN, MEM_WAIT);
 
        tpriv = s->p->priv;
 
@@ -975,7 +992,8 @@ struct block *htontcp6(Tcp * tcph, struct block *data, Tcp6hdr * ph,
                        return NULL;
        } else {
                dlen = 0;
-               data = allocb(hdrlen + TCP6_PKT + 64);  /* the 64 pad is to meet mintu's */
+               /* the 64 pad is to meet mintu's */
+               data = block_alloc(hdrlen + TCP6_PKT + 64, MEM_WAIT);
                if (data == NULL)
                        return NULL;
                data->wp += hdrlen + TCP6_PKT;
@@ -1058,7 +1076,8 @@ struct block *htontcp4(Tcp * tcph, struct block *data, Tcp4hdr * ph,
                        return NULL;
        } else {
                dlen = 0;
-               data = allocb(hdrlen + TCP4_PKT + 64);  /* the 64 pad is to meet mintu's */
+               /* the 64 pad is to meet mintu's */
+               data = block_alloc(hdrlen + TCP4_PKT + 64, MEM_WAIT);
                if (data == NULL)
                        return NULL;
                data->wp += hdrlen + TCP4_PKT;
@@ -1235,7 +1254,7 @@ int ntohtcp4(Tcp * tcph, struct block **bpp)
  */
 void tcpsndsyn(struct conv *s, Tcpctl * tcb)
 {
-       tcb->iss = (nrand(1 << 16) << 16) | nrand(1 << 16);
+       urandom_read(&tcb->iss, sizeof(tcb->iss));
        tcb->rttseq = tcb->iss;
        tcb->snd.wl2 = tcb->iss;
        tcb->snd.una = tcb->iss;
@@ -1337,18 +1356,14 @@ sndrst(struct Proto *tcp, uint8_t * source, uint8_t * dest,
  *  send a reset to the remote side and close the conversation
  *  called with s qlocked
  */
-char *tcphangup(struct conv *s)
+static void tcphangup(struct conv *s)
 {
-       ERRSTACK(2);
+       ERRSTACK(1);
        Tcp seg;
        Tcpctl *tcb;
        struct block *hbp;
 
        tcb = (Tcpctl *) s->ptcl;
-       if (waserror()) {
-               poperror();
-               return commonerror();
-       }
        if (ipcmp(s->raddr, IPnoaddr)) {
                /* discard error style, poperror regardless */
                if (!waserror()) {
@@ -1378,8 +1393,6 @@ char *tcphangup(struct conv *s)
                poperror();
        }
        localclose(s, NULL);
-       poperror();
-       return NULL;
 }
 
 /*
@@ -1508,7 +1521,7 @@ limbo(struct conv *s, uint8_t * source, uint8_t * dest, Tcp * seg, int version)
                lp->mss = seg->mss;
                lp->rcvscale = seg->ws;
                lp->irs = seg->seq;
-               lp->iss = (nrand(1 << 16) << 16) | nrand(1 << 16);
+               urandom_read(&lp->iss, sizeof(lp->iss));
        }
 
        if (sndsynack(s->p, lp) < 0) {
@@ -1983,6 +1996,12 @@ void tcpiput(struct Proto *tcp, struct Ipifc *unused, struct block *bp)
                        return;
                }
 
+               s = iphtlook(&tpriv->ht, source, seg.source, dest, seg.dest);
+               if (s && s->state == Bypass) {
+                       bypass_or_drop(s, bp);
+                       return;
+               }
+
                /* trim the packet to the size claimed by the datagram */
                length -= hdrlen + TCP4_PKT;
                bp = trimblock(bp, hdrlen + TCP4_PKT, length);
@@ -2024,6 +2043,12 @@ void tcpiput(struct Proto *tcp, struct Ipifc *unused, struct block *bp)
                        return;
                }
 
+               s = iphtlook(&tpriv->ht, source, seg.source, dest, seg.dest);
+               if (s && s->state == Bypass) {
+                       bypass_or_drop(s, bp);
+                       return;
+               }
+
                /* trim the packet to the size claimed by the datagram */
                length -= hdrlen;
                bp = trimblock(bp, hdrlen + TCP6_PKT, length);
@@ -2035,20 +2060,19 @@ void tcpiput(struct Proto *tcp, struct Ipifc *unused, struct block *bp)
                }
        }
 
-       /* lock protocol while searching for a conversation */
-       qlock(&tcp->qlock);
-
-       /* Look for a matching conversation */
-       s = iphtlook(&tpriv->ht, source, seg.source, dest, seg.dest);
+       /* s, the conv matching the n-tuple, was set above */
        if (s == NULL) {
                netlog(f, Logtcp, "iphtlook failed\n");
 reset:
-               qunlock(&tcp->qlock);
                sndrst(tcp, source, dest, length, &seg, version, "no conversation");
                freeblist(bp);
                return;
        }
 
+       /* lock protocol for unstate Plan 9 invariants.  funcs like limbo or
+        * incoming might rely on it. */
+       qlock(&tcp->qlock);
+
        /* if it's a listener, look for the right flags and get a new conv */
        tcb = (Tcpctl *) s->ptcl;
        if (tcb->state == Listen) {
@@ -2072,8 +2096,10 @@ reset:
                 *  return it in state Syn_received
                 */
                s = tcpincoming(s, &seg, source, dest, version);
-               if (s == NULL)
+               if (s == NULL) {
+                       qunlock(&tcp->qlock);
                        goto reset;
+               }
        }
 
        /* The rest of the input state machine is run with the control block
@@ -2109,7 +2135,7 @@ reset:
                        }
                        if (seg.flags & RST) {
                                if (seg.flags & ACK)
-                                       localclose(s, errno_to_string(ECONNREFUSED));
+                                       localclose(s, "connection refused");
                                goto raise;
                        }
 
@@ -2214,7 +2240,7 @@ reset:
                                                 s->raddr, s->rport, s->laddr, s->lport, tcb->rcv.nxt,
                                                 seg.seq);
                        }
-                       localclose(s, errno_to_string(ECONNREFUSED));
+                       localclose(s, "connection refused");
                        goto raise;
                }
 
@@ -2680,7 +2706,7 @@ void tcpsendka(struct conv *s)
        seg.mss = 0;
        seg.ws = 0;
        if (tcpporthogdefense)
-               seg.seq = tcb->snd.una - (1 << 30) - nrand(1 << 20);
+               urandom_read(&seg.seq, sizeof(seg.seq));
        else
                seg.seq = tcb->snd.una - 1;
        seg.ack = tcb->rcv.nxt;
@@ -2689,7 +2715,7 @@ void tcpsendka(struct conv *s)
        if (tcb->state == Finwait2) {
                seg.flags |= FIN;
        } else {
-               dbp = allocb(1);
+               dbp = block_alloc(1, MEM_WAIT);
                dbp->wp++;
        }
 
@@ -2736,14 +2762,14 @@ void tcpkeepalive(void *v)
 
        s = v;
        tcb = (Tcpctl *) s->ptcl;
+       qlock(&s->qlock);
        if (waserror()) {
                qunlock(&s->qlock);
                nexterror();
        }
-       qlock(&s->qlock);
        if (tcb->state != Closed) {
                if (--(tcb->kacounter) <= 0) {
-                       localclose(s, errno_to_string(ETIMEDOUT));
+                       localclose(s, "connection timed out");
                } else {
                        tcpsendka(s);
                        tcpgo(s->p->priv, &tcb->katimer);
@@ -2756,14 +2782,14 @@ void tcpkeepalive(void *v)
 /*
  *  start keepalive timer
  */
-char *tcpstartka(struct conv *s, char **f, int n)
+static void tcpstartka(struct conv *s, char **f, int n)
 {
        Tcpctl *tcb;
        int x;
 
        tcb = (Tcpctl *) s->ptcl;
        if (tcb->state != Established)
-               return "connection must be in Establised state";
+               error(ENOTCONN, "connection must be in Establised state");
        if (n > 1) {
                x = atoi(f[1]);
                if (x >= MSPTICK)
@@ -2771,21 +2797,17 @@ char *tcpstartka(struct conv *s, char **f, int n)
        }
        tcpsetkacounter(tcb);
        tcpgo(s->p->priv, &tcb->katimer);
-
-       return NULL;
 }
 
 /*
  *  turn checksums on/off
  */
-char *tcpsetchecksum(struct conv *s, char **f, int unused)
+static void tcpsetchecksum(struct conv *s, char **f, int unused)
 {
        Tcpctl *tcb;
 
        tcb = (Tcpctl *) s->ptcl;
        tcb->nochecksum = !atoi(f[1]);
-
-       return NULL;
 }
 
 void tcprxmit(struct conv *s)
@@ -2822,11 +2844,11 @@ void tcptimeout(void *arg)
        tpriv = s->p->priv;
        tcb = (Tcpctl *) s->ptcl;
 
+       qlock(&s->qlock);
        if (waserror()) {
                qunlock(&s->qlock);
                nexterror();
        }
-       qlock(&s->qlock);
        switch (tcb->state) {
                default:
                        tcb->backoff++;
@@ -2836,7 +2858,7 @@ void tcptimeout(void *arg)
                                maxback = MAXBACKMS;
                        tcb->backedoff += tcb->timer.start * MSPTICK;
                        if (tcb->backedoff >= maxback) {
-                               localclose(s, errno_to_string(ETIMEDOUT));
+                               localclose(s, "connection timed out");
                                break;
                        }
                        netlog(s->p->f, Logtcprxmt, "timeout rexmit 0x%lx %llu/%llu\n",
@@ -3061,7 +3083,6 @@ void tcpadvise(struct Proto *tcp, struct block *bp, char *msg)
        }
 
        /* Look for a connection */
-       qlock(&tcp->qlock);
        for (p = tcp->conv; *p; p++) {
                s = *p;
                tcb = (Tcpctl *) s->ptcl;
@@ -3071,7 +3092,6 @@ void tcpadvise(struct Proto *tcp, struct block *bp, char *msg)
                                        if (ipcmp(s->raddr, dest) == 0)
                                                if (ipcmp(s->laddr, source) == 0) {
                                                        qlock(&s->qlock);
-                                                       qunlock(&tcp->qlock);
                                                        switch (tcb->state) {
                                                                case Syn_sent:
                                                                        localclose(s, msg);
@@ -3082,33 +3102,32 @@ void tcpadvise(struct Proto *tcp, struct block *bp, char *msg)
                                                        return;
                                                }
        }
-       qunlock(&tcp->qlock);
        freeblist(bp);
 }
 
-static char *tcpporthogdefensectl(char *val)
+static void tcpporthogdefensectl(char *val)
 {
        if (strcmp(val, "on") == 0)
                tcpporthogdefense = 1;
        else if (strcmp(val, "off") == 0)
                tcpporthogdefense = 0;
        else
-               return "unknown value for tcpporthogdefense";
-       return NULL;
+               error(EINVAL, "unknown value for tcpporthogdefense");
 }
 
 /* called with c qlocked */
-char *tcpctl(struct conv *c, char **f, int n)
+static void tcpctl(struct conv *c, char **f, int n)
 {
        if (n == 1 && strcmp(f[0], "hangup") == 0)
-               return tcphangup(c);
-       if (n >= 1 && strcmp(f[0], "keepalive") == 0)
-               return tcpstartka(c, f, n);
-       if (n >= 1 && strcmp(f[0], "checksum") == 0)
-               return tcpsetchecksum(c, f, n);
-       if (n >= 1 && strcmp(f[0], "tcpporthogdefense") == 0)
-               return tcpporthogdefensectl(f[1]);
-       return "unknown control request";
+               tcphangup(c);
+       else if (n >= 1 && strcmp(f[0], "keepalive") == 0)
+               tcpstartka(c, f, n);
+       else if (n >= 1 && strcmp(f[0], "checksum") == 0)
+               tcpsetchecksum(c, f, n);
+       else if (n >= 1 && strcmp(f[0], "tcpporthogdefense") == 0)
+               tcpporthogdefensectl(f[1]);
+       else
+               error(EINVAL, "unknown command to %s", __func__);
 }
 
 int tcpstats(struct Proto *tcp, char *buf, int len)
@@ -3196,17 +3215,19 @@ void tcpinit(struct Fs *fs)
        tcp->name = "tcp";
        tcp->connect = tcpconnect;
        tcp->announce = tcpannounce;
+       tcp->bypass = tcpbypass;
        tcp->ctl = tcpctl;
        tcp->state = tcpstate;
        tcp->create = tcpcreate;
        tcp->close = tcpclose;
+       tcp->shutdown = tcpshutdown;
        tcp->rcv = tcpiput;
        tcp->advise = tcpadvise;
        tcp->stats = tcpstats;
        tcp->inuse = tcpinuse;
        tcp->gc = tcpgc;
        tcp->ipproto = IP_TCPPROTO;
-       tcp->nc = scalednconv();
+       tcp->nc = 4096;
        tcp->ptclsize = sizeof(Tcpctl);
        tpriv->stats[MaxConn] = tcp->nc;