Preemptively fixes walk malloc errors
[akaros.git] / kern / drivers / dev / mnt.c
1 // INFERNO
2 #include <vfs.h>
3 #include <kfs.h>
4 #include <slab.h>
5 #include <kmalloc.h>
6 #include <kref.h>
7 #include <string.h>
8 #include <stdio.h>
9 #include <assert.h>
10 #include <error.h>
11 #include <cpio.h>
12 #include <pmap.h>
13 #include <smp.h>
14 #include <ip.h>
15
16 /*
17  * References are managed as follows:
18  * The channel to the server - a network connection or pipe - has one
19  * reference for every Chan open on the server.  The server channel has
20  * c->mux set to the Mnt used for muxing control to that server.  Mnts
21  * have no reference count; they go away when c goes away.
22  * Each channel derived from the mount point has mchan set to c,
23  * and increfs/decrefs mchan to manage references on the server
24  * connection.
25  */
26
27 #define MAXRPC (IOHDRSZ+8192)
28
29 struct mntrpc
30 {
31         struct chan*    c;              /* Channel for whom we are working */
32         struct mntrpc*  list;           /* Free/pending list */
33         struct fcall    request;        /* Outgoing file system protocol message */
34         struct fcall    reply;          /* Incoming reply */
35         struct mnt*     m;              /* Mount device during rpc */
36         struct rendez   r;              /* Place to hang out */
37         uint8_t*        rpc;            /* I/O Data buffer */
38         unsigned int            rpclen; /* len of buffer */
39         struct block    *b;             /* reply blocks */
40         char    done;           /* Rpc completed */
41         uint64_t        stime;          /* start time for mnt statistics */
42         uint32_t        reqlen;         /* request length for mnt statistics */
43         uint32_t        replen;         /* reply length for mnt statistics */
44         struct mntrpc*  flushed;        /* message this one flushes */
45 };
46
47 enum
48 {
49         TAGSHIFT = 5,                   /* uint32_t has to be 32 bits */
50         TAGMASK = (1<<TAGSHIFT)-1,
51         NMASK = (64*1024)>>TAGSHIFT,
52 };
53
54 struct Mntalloc
55 {
56         spinlock_t      l;
57         struct mnt*     list;           /* Mount devices in use */
58         struct mnt*     mntfree;        /* Free list */
59         struct mntrpc*  rpcfree;
60         int     nrpcfree;
61         int     nrpcused;
62         uint32_t        id;
63         uint32_t        tagmask[NMASK];
64 }mntalloc;
65
66 void    mattach(struct mnt*, struct chan*, char *unused_char_p_t);
67 struct mnt*     mntchk(struct chan*);
68 void    mntdirfix( uint8_t *unused_uint8_p_t, struct chan*);
69 struct mntrpc*  mntflushalloc(struct mntrpc*, uint32_t);
70 void    mntflushfree(struct mnt*, struct mntrpc*);
71 void    mntfree(struct mntrpc*);
72 void    mntgate(struct mnt*);
73 void    mntpntfree(struct mnt*);
74 void    mntqrm(struct mnt*, struct mntrpc*);
75 struct mntrpc*  mntralloc(struct chan*, uint32_t);
76 long    mntrdwr( int unused_int, struct chan*, void*, long, int64_t);
77 int     mntrpcread(struct mnt*, struct mntrpc*);
78 void    mountio(struct mnt*, struct mntrpc*);
79 void    mountmux(struct mnt*, struct mntrpc*);
80 void    mountrpc(struct mnt*, struct mntrpc*);
81 int     rpcattn(void*);
82 struct chan*    mntchan(void);
83
84 char    Esbadstat[] = "invalid directory entry received from server";
85 char Enoversion[] = "version not established for mount channel";
86
87
88 void (*mntstats)( int unused_int, struct chan*, uint64_t, uint32_t);
89
90 static void
91 mntinit(void)
92 {
93         mntalloc.id = 1;
94         mntalloc.tagmask[0] = 1;                        /* don't allow 0 as a tag */
95         mntalloc.tagmask[NMASK-1] = 0x80000000UL;       /* don't allow NOTAG */
96         //fmtinstall('F', fcallfmt);
97 /*      fmtinstall('D', dirfmt); */
98 /*      fmtinstall('M', dirmodefmt);  */
99
100         cinit();
101 }
102
103 /*
104  * Version is not multiplexed: message sent only once per connection.
105  */
106 long
107 mntversion(struct chan *c, char *version, int msize, int returnlen)
108 {
109         ERRSTACK(4);
110         struct fcall f;
111         uint8_t *msg;
112         struct mnt *m;
113         char *v;
114         long k, l;
115         uint64_t oo;
116         char buf[128];
117
118         qlock(&c->umqlock);     /* make sure no one else does this until we've established ourselves */
119         if(waserror()){
120                 qunlock(&c->umqlock);
121                 nexterror();
122         }
123
124         /* defaults */
125         if(msize == 0)
126                 msize = MAXRPC;
127         if(msize > c->iounit && c->iounit != 0)
128                 msize = c->iounit;
129         v = version;
130         if(v == NULL || v[0] == '\0')
131                 v = VERSION9P;
132
133         /* validity */
134         if(msize < 0)
135                 error("bad iounit in version call");
136         if(strncmp(v, VERSION9P, strlen(VERSION9P)) != 0)
137                 error("bad 9P version specification");
138
139         m = c->mux;
140
141         if(m != NULL){
142                 qunlock(&c->umqlock);
143                 poperror();
144
145                 strncpy(buf, m->version, sizeof buf);
146                 k = strlen(buf);
147                 if(strncmp(buf, v, k) != 0){
148                         snprintf(buf, sizeof buf, "incompatible 9P versions %s %s", m->version, v);
149                         error(buf);
150                 }
151                 if(returnlen > 0){
152                         if(returnlen < k)
153                                 error(Eshort);
154                         memmove(version, buf, k);
155                 }
156                 return k;
157         }
158
159         f.type = Tversion;
160         f.tag = NOTAG;
161         f.msize = msize;
162         f.version = v;
163         msg = kzmalloc(8192 + IOHDRSZ, 0);
164         if(msg == NULL)
165                 exhausted("version memory");
166         if(waserror()){
167                 kfree(msg);
168                 nexterror();
169         }
170         k = convS2M(&f, msg, 8192+IOHDRSZ);
171         if(k == 0)
172                 error("bad fversion conversion on send");
173
174         spin_lock(&c->lock);
175         oo = c->offset;
176         c->offset += k;
177         spin_unlock(&c->lock);
178
179         l = devtab[c->type]->write(c, msg, k, oo);
180
181         if(l < k){
182                 spin_lock(&c->lock);
183                 c->offset -= k - l;
184                 spin_unlock(&c->lock);
185                 error("short write in fversion");
186         }
187
188         /* message sent; receive and decode reply */
189         k = devtab[c->type]->read(c, msg, 8192+IOHDRSZ, c->offset);
190         if(k <= 0)
191                 error("EOF receiving fversion reply");
192
193         spin_lock(&c->lock);
194         c->offset += k;
195         spin_unlock(&c->lock);
196
197         l = convM2S(msg, k, &f);
198         if(l != k)
199                 error("bad fversion conversion on reply");
200         if(f.type != Rversion){
201                 if(f.type == Rerror)
202                         error(f.ename);
203                 error("unexpected reply type in fversion");
204         }
205         if(f.msize > msize)
206                 error("server tries to increase msize in fversion");
207         if(f.msize<256 || f.msize>1024*1024)
208                 error("nonsense value of msize in fversion");
209         if(strncmp(f.version, v, strlen(f.version)) != 0)
210                 error("bad 9P version returned from server");
211
212         /* now build Mnt associated with this connection */
213         spin_lock(&mntalloc.l);
214         m = mntalloc.mntfree;
215         if(m != 0)
216                 mntalloc.mntfree = m->list;
217         else {
218                 m = kzmalloc(sizeof(struct mnt), 0);
219                 if(m == 0) {
220                         spin_unlock(&mntalloc.l);
221                         exhausted("mount devices");
222                 }
223         }
224         m->list = mntalloc.list;
225         mntalloc.list = m;
226         m->version = NULL;
227         kstrdup(&m->version, f.version);
228         m->id = mntalloc.id++;
229         m->q = qopen(10*MAXRPC, 0, NULL, NULL);
230         m->msize = f.msize;
231         spin_unlock(&mntalloc.l);
232
233         poperror();     /* msg */
234         kfree(msg);
235
236         spin_lock(&m->lock);
237         m->queue = 0;
238         m->rip = 0;
239
240         c->flag |= CMSG;
241         c->mux = m;
242         m->c = c;
243         spin_unlock(&m->lock);
244
245         poperror();     /* c */
246         qunlock(&c->umqlock);
247         k = strlen(f.version);
248         if(returnlen > 0){
249                 if(returnlen < k)
250                         error(Eshort);
251                 memmove(version, f.version, k);
252         }
253
254         return k;
255 }
256
257 struct chan*
258 mntauth(struct chan *c, char *spec)
259 {
260         ERRSTACK(2);
261         struct mnt *m;
262         struct mntrpc *r;
263
264         m = c->mux;
265
266         if(m == NULL){
267                 mntversion(c, VERSION9P, MAXRPC, 0);
268                 m = c->mux;
269                 if(m == NULL)
270                         error(Enoversion);
271         }
272
273         c = mntchan();
274         if(waserror()) {
275                 /* Close must not be called since it will
276                  * call mnt recursively
277                  */
278                 chanfree(c);
279                 nexterror();
280         }
281
282         r = mntralloc(0, m->msize);
283
284         if(waserror()) {
285                 mntfree(r);
286                 nexterror();
287         }
288
289         r->request.type = Tauth;
290         r->request.afid = c->fid;
291         r->request.uname = current->user;
292         r->request.aname = spec;
293         mountrpc(m, r);
294
295         c->qid = r->reply.aqid;
296         c->mchan = m->c;
297         kref_get(&m->c->ref, 1);
298         c->mqid = c->qid;
299         c->mode = ORDWR;
300
301         poperror();     /* r */
302         mntfree(r);
303
304         poperror();     /* c */
305
306         return c;
307
308 }
309
310 static struct chan*
311 mntattach(char *muxattach)
312 {
313         ERRSTACK(2);
314         struct mnt *m;
315         struct chan *c;
316         struct mntrpc *r;
317         struct bogus{
318                 struct chan     *chan;
319                 struct chan     *authchan;
320                 char    *spec;
321                 int     flags;
322         }bogus;
323
324         bogus = *((struct bogus *)muxattach);
325         c = bogus.chan;
326
327         m = c->mux;
328
329         if(m == NULL){
330                 mntversion(c, NULL, 0, 0);
331                 m = c->mux;
332                 if(m == NULL)
333                         error(Enoversion);
334         }
335
336         c = mntchan();
337         if(waserror()) {
338                 /* Close must not be called since it will
339                  * call mnt recursively
340                  */
341                 chanfree(c);
342                 nexterror();
343         }
344
345         r = mntralloc(0, m->msize);
346
347         if(waserror()) {
348                 mntfree(r);
349                 nexterror();
350         }
351
352         r->request.type = Tattach;
353         r->request.fid = c->fid;
354         if(bogus.authchan == NULL)
355                 r->request.afid = NOFID;
356         else
357                 r->request.afid = bogus.authchan->fid;
358         r->request.uname = current->user;
359         r->request.aname = bogus.spec;
360         mountrpc(m, r);
361
362         c->qid = r->reply.qid;
363         c->mchan = m->c;
364         kref_get(&m->c->ref, 1);
365         c->mqid = c->qid;
366
367         poperror();     /* r */
368         mntfree(r);
369
370         poperror();     /* c */
371
372         if(bogus.flags&MCACHE)
373                 c->flag |= CCACHE;
374         return c;
375 }
376
377 struct chan*
378 mntchan(void)
379 {
380         struct chan *c;
381
382         c = devattach('M', 0);
383         spin_lock(&mntalloc.l);
384         c->dev = mntalloc.id++;
385         spin_unlock(&mntalloc.l);
386
387         if(c->mchan)
388                 panic("mntchan non-zero %p", c->mchan);
389         return c;
390 }
391
392 static struct walkqid*
393 mntwalk(struct chan *c, struct chan *nc, char **name, int nname)
394 {
395         ERRSTACK(2);
396         volatile int alloc;
397         int i;
398         struct mnt *m;
399         struct mntrpc *r;
400         struct walkqid *wq;
401
402         if(nc != NULL)
403                 printd("mntwalk: nc != NULL\n");
404         if(nname > MAXWELEM)
405                 error("devmnt: too many name elements");
406         alloc = 0;
407         wq = kzmalloc(sizeof(struct walkqid) + nname * sizeof(struct qid),
408                       KMALLOC_WAIT);
409         if(waserror()){
410                 if(alloc && wq->clone!=NULL)
411                         cclose(wq->clone);
412                 kfree(wq);
413                 return NULL;
414         }
415
416         alloc = 0;
417         m = mntchk(c);
418         r = mntralloc(c, m->msize);
419         if(nc == NULL){
420                 nc = devclone(c);
421                 /*
422                  * Until the other side accepts this fid, we can't mntclose it.
423                  * Therefore set type to 0 for now; rootclose is known to be safe.
424                  */
425                 nc->type = 0;
426                 alloc = 1;
427         }
428         wq->clone = nc;
429
430         if(waserror()) {
431                 mntfree(r);
432                 nexterror();
433         }
434         r->request.type = Twalk;
435         r->request.fid = c->fid;
436         r->request.newfid = nc->fid;
437         r->request.nwname = nname;
438         memmove(r->request.wname, name, nname*sizeof( char *));
439
440         mountrpc(m, r);
441
442         if(r->reply.nwqid > nname)
443                 error("too many QIDs returned by walk");
444         if(r->reply.nwqid < nname){
445                 if(alloc)
446                         cclose(nc);
447                 wq->clone = NULL;
448                 if(r->reply.nwqid == 0){
449                         kfree(wq);
450                         wq = NULL;
451                         goto Return;
452                 }
453         }
454
455         /* move new fid onto mnt device and update its qid */
456         if(wq->clone != NULL){
457                 if(wq->clone != c){
458                         wq->clone->type = c->type;
459                         wq->clone->mchan = c->mchan;
460                         kref_get(&c->mchan->ref, 1);
461                 }
462                 if(r->reply.nwqid > 0)
463                         wq->clone->qid = r->reply.wqid[r->reply.nwqid-1];
464         }
465         wq->nqid = r->reply.nwqid;
466         for(i=0; i<wq->nqid; i++)
467                 wq->qid[i] = r->reply.wqid[i];
468
469     Return:
470         poperror();
471         mntfree(r);
472         poperror();
473         return wq;
474 }
475
476 static int
477 mntstat(struct chan *c, uint8_t *dp, int n)
478 {
479         ERRSTACK(2);
480         struct mnt *m;
481         struct mntrpc *r;
482
483         if(n < BIT16SZ)
484                 error(Eshortstat);
485         m = mntchk(c);
486         r = mntralloc(c, m->msize);
487         if(waserror()) {
488                 mntfree(r);
489                 nexterror();
490         }
491         r->request.type = Tstat;
492         r->request.fid = c->fid;
493         mountrpc(m, r);
494
495         if(r->reply.nstat > n){
496                 /* doesn't fit; just patch the count and return */
497                 PBIT16(( uint8_t *)dp, r->reply.nstat);
498                 n = BIT16SZ;
499         }else{
500                 n = r->reply.nstat;
501                 memmove(dp, r->reply.stat, n);
502                 validstat(dp, n);
503                 mntdirfix(dp, c);
504         }
505         poperror();
506         mntfree(r);
507         return n;
508 }
509
510 static struct chan*
511 mntopencreate(int type, struct chan *c, char *name, int omode, uint32_t perm)
512 {
513         ERRSTACK(2);
514         struct mnt *m;
515         struct mntrpc *r;
516
517         m = mntchk(c);
518         r = mntralloc(c, m->msize);
519         if(waserror()) {
520                 mntfree(r);
521                 nexterror();
522         }
523         r->request.type = type;
524         r->request.fid = c->fid;
525         r->request.mode = omode;
526         if(type == Tcreate){
527                 r->request.perm = perm;
528                 r->request.name = name;
529         }
530         mountrpc(m, r);
531
532         c->qid = r->reply.qid;
533         c->offset = 0;
534         c->mode = openmode(omode);
535         c->iounit = r->reply.iounit;
536         if(c->iounit == 0 || c->iounit > m->msize-IOHDRSZ)
537                 c->iounit = m->msize-IOHDRSZ;
538         c->flag |= COPEN;
539         poperror();
540         mntfree(r);
541
542         if(c->flag & CCACHE)
543                 copen(c);
544
545         return c;
546 }
547
548 static struct chan*
549 mntopen(struct chan *c, int omode)
550 {
551         return mntopencreate(Topen, c, NULL, omode, 0);
552 }
553
554 static void
555 mntcreate(struct chan *c, char *name, int omode, uint32_t perm)
556 {
557         mntopencreate(Tcreate, c, name, omode, perm);
558 }
559
560 static void
561 mntclunk(struct chan *c, int t)
562 {
563         ERRSTACK(2);
564         struct mnt *m;
565         struct mntrpc *r;
566
567         m = mntchk(c);
568         r = mntralloc(c, m->msize);
569         if(waserror()){
570                 mntfree(r);
571                 nexterror();
572         }
573
574         r->request.type = t;
575         r->request.fid = c->fid;
576         mountrpc(m, r);
577         mntfree(r);
578         poperror();
579 }
580
581 void
582 muxclose(struct mnt *m)
583 {
584         struct mntrpc *q, *r;
585
586         for(q = m->queue; q; q = r) {
587                 r = q->list;
588                 mntfree(q);
589         }
590         m->id = 0;
591         kfree(m->version);
592         m->version = NULL;
593         mntpntfree(m);
594 }
595
596 void
597 mntpntfree(struct mnt *m)
598 {
599         struct mnt *f, **l;
600         struct queue *q;
601
602         spin_lock(&mntalloc.l);
603         l = &mntalloc.list;
604         for(f = *l; f; f = f->list) {
605                 if(f == m) {
606                         *l = m->list;
607                         break;
608                 }
609                 l = &f->list;
610         }
611         m->list = mntalloc.mntfree;
612         mntalloc.mntfree = m;
613         q = m->q;
614         spin_unlock(&mntalloc.l);
615
616         qfree(q);
617 }
618
619 static void
620 mntclose(struct chan *c)
621 {
622         mntclunk(c, Tclunk);
623 }
624
625 static void
626 mntremove(struct chan *c)
627 {
628         mntclunk(c, Tremove);
629 }
630
631 static int
632 mntwstat(struct chan *c, uint8_t *dp, int n)
633 {
634         ERRSTACK(2);
635         struct mnt *m;
636         struct mntrpc *r;
637
638         m = mntchk(c);
639         r = mntralloc(c, m->msize);
640         if(waserror()) {
641                 mntfree(r);
642                 nexterror();
643         }
644         r->request.type = Twstat;
645         r->request.fid = c->fid;
646         r->request.nstat = n;
647         r->request.stat = dp;
648         mountrpc(m, r);
649         poperror();
650         mntfree(r);
651         return n;
652 }
653
654 static long
655 mntread(struct chan *c, void *buf, long n, int64_t off)
656 {
657         uint8_t *p, *e;
658         int nc, cache, isdir, dirlen;
659
660         isdir = 0;
661         cache = c->flag & CCACHE;
662         if(c->qid.type & QTDIR) {
663                 cache = 0;
664                 isdir = 1;
665         }
666
667         p = buf;
668         if(cache) {
669                 nc = cread(c, buf, n, off);
670                 if(nc > 0) {
671                         n -= nc;
672                         if(n == 0)
673                                 return nc;
674                         p += nc;
675                         off += nc;
676                 }
677                 n = mntrdwr(Tread, c, p, n, off);
678                 cupdate(c, p, n, off);
679                 return n + nc;
680         }
681
682         n = mntrdwr(Tread, c, buf, n, off);
683         if(isdir) {
684                 for(e = &p[n]; p+BIT16SZ < e; p += dirlen){
685                         dirlen = BIT16SZ+GBIT16(p);
686                         if(p+dirlen > e)
687                                 break;
688                         validstat(p, dirlen);
689                         mntdirfix(p, c);
690                 }
691                 if(p != e)
692                         error(Esbadstat);
693         }
694         return n;
695 }
696
697 static long
698 mntwrite(struct chan *c, void *buf, long n, int64_t off)
699 {
700         return mntrdwr(Twrite, c, buf, n, off);
701 }
702
703 long
704 mntrdwr(int type, struct chan *c, void *buf, long n, int64_t off)
705 {
706         ERRSTACK(2);
707         struct mnt *m;
708         struct mntrpc *r;       /* TO DO: volatile struct { Mntrpc *r; } r; */
709         char *uba;
710         int cache;
711         uint32_t cnt, nr, nreq;
712
713         m = mntchk(c);
714         uba = buf;
715         cnt = 0;
716         cache = c->flag & CCACHE;
717         if(c->qid.type & QTDIR)
718                 cache = 0;
719         for(;;) {
720                 r = mntralloc(c, m->msize);
721                 if(waserror()) {
722                         mntfree(r);
723                         nexterror();
724                 }
725                 r->request.type = type;
726                 r->request.fid = c->fid;
727                 r->request.offset = off;
728                 r->request.data = uba;
729                 nr = n;
730                 if(nr > m->msize-IOHDRSZ)
731                         nr = m->msize-IOHDRSZ;
732                 r->request.count = nr;
733                 mountrpc(m, r);
734                 nreq = r->request.count;
735                 nr = r->reply.count;
736                 if(nr > nreq)
737                         nr = nreq;
738
739                 if(type == Tread)
740                         r->b = bl2mem(( uint8_t *)uba, r->b, nr);
741                 else if(cache)
742                         cwrite(c, ( uint8_t *)uba, nr, off);
743
744                 poperror();
745                 mntfree(r);
746                 off += nr;
747                 uba += nr;
748                 cnt += nr;
749                 n -= nr;
750                 if(nr != nreq || n == 0 /*|| current->killed*/)
751                         break;
752         }
753         return cnt;
754 }
755
756 void
757 mountrpc(struct mnt *m, struct mntrpc *r)
758 {
759         char *sn, *cn;
760         int t;
761
762         r->reply.tag = 0;
763         r->reply.type = Tmax;   /* can't ever be a valid message type */
764
765         mountio(m, r);
766
767         t = r->reply.type;
768         switch(t) {
769         case Rerror:
770                 error(r->reply.ename);
771         case Rflush:
772                 error(Eintr);
773         default:
774                 if(t == r->request.type+1)
775                         break;
776                 sn = "?";
777                 if(m->c->name != NULL)
778                         sn = m->c->name->s;
779                 cn = "?";
780                 if(r->c != NULL && r->c->name != NULL)
781                         cn = r->c->name->s;
782                 printd("mnt: proc %s %lud: mismatch from %s %s rep 0x%p tag %d fid %d T%d R%d rp %d\n",
783                        "current->text", "current->pid", sn, cn,
784                         r, r->request.tag, r->request.fid, r->request.type,
785                         r->reply.type, r->reply.tag);
786                 error(Emountrpc);
787         }
788 }
789
790 void
791 mountio(struct mnt *m, struct mntrpc *r)
792 {
793         ERRSTACK(4);
794         int n;
795
796         while(waserror()) {
797                 if(m->rip == current)
798                         mntgate(m);
799                 if(strcmp(current_errstr(), Eintr) != 0){
800                         mntflushfree(m, r);
801                         nexterror();
802                 }
803                 r = mntflushalloc(r, m->msize);
804         }
805
806         spin_lock(&m->lock);
807         r->m = m;
808         r->list = m->queue;
809         m->queue = r;
810         spin_unlock(&m->lock);
811
812         /* Transmit a file system rpc */
813         if(m->msize == 0)
814                 panic("msize");
815         n = convS2M(&r->request, r->rpc, m->msize);
816         if(n < 0)
817                 panic("bad message type in mountio");
818         if(devtab[m->c->type]->write(m->c, r->rpc, n, 0) != n)
819                 error(Emountrpc);
820 /*      r->stime = fastticks(NULL); */
821         r->reqlen = n;
822
823         /* Gate readers onto the mount point one at a time */
824         for(;;) {
825                 spin_lock(&m->lock);
826                 if(m->rip == 0)
827                         break;
828                 spin_unlock(&m->lock);
829                 rendez_sleep(&r->r, rpcattn, r);
830                 if(r->done){
831                         poperror();
832                         mntflushfree(m, r);
833                         return;
834                 }
835         }
836         m->rip = current;
837         spin_unlock(&m->lock);
838         while(r->done == 0) {
839                 if(mntrpcread(m, r) < 0)
840                         error(Emountrpc);
841                 mountmux(m, r);
842         }
843         mntgate(m);
844         poperror();
845         mntflushfree(m, r);
846 }
847
848 static int
849 doread(struct mnt *m, int len)
850 {
851         struct block *b;
852
853         while(qlen(m->q) < len){
854                 b = devtab[m->c->type]->bread(m->c, m->msize, 0);
855                 if(b == NULL)
856                         return -1;
857                 if(blocklen(b) == 0){
858                         freeblist(b);
859                         return -1;
860                 }
861                 qaddlist(m->q, b);
862         }
863         return 0;
864 }
865
866 int
867 mntrpcread(struct mnt *m, struct mntrpc *r)
868 {
869         int i, t, len, hlen;
870         struct block *b, **l, *nb;
871
872         r->reply.type = 0;
873         r->reply.tag = 0;
874
875         /* read at least length, type, and tag and pullup to a single block */
876         if(doread(m, BIT32SZ+BIT8SZ+BIT16SZ) < 0)
877                 return -1;
878         nb = pullupqueue(m->q, BIT32SZ+BIT8SZ+BIT16SZ);
879
880         /* read in the rest of the message, avoid ridiculous (for now) message sizes */
881         len = GBIT32(nb->rp);
882         if(len > m->msize){
883                 qdiscard(m->q, qlen(m->q));
884                 return -1;
885         }
886         if(doread(m, len) < 0)
887                 return -1;
888
889         /* pullup the header (i.e. everything except data) */
890         t = nb->rp[BIT32SZ];
891         switch(t){
892         case Rread:
893                 hlen = BIT32SZ+BIT8SZ+BIT16SZ+BIT32SZ;
894                 break;
895         default:
896                 hlen = len;
897                 break;
898         }
899         nb = pullupqueue(m->q, hlen);
900
901         if(convM2S(nb->rp, len, &r->reply) <= 0){
902                 /* bad message, dump it */
903                 printd("mntrpcread: convM2S failed\n");
904                 qdiscard(m->q, len);
905                 return -1;
906         }
907
908         /* hang the data off of the fcall struct */
909         l = &r->b;
910         *l = NULL;
911         do {
912                 b = qremove(m->q);
913                 if(hlen > 0){
914                         b->rp += hlen;
915                         len -= hlen;
916                         hlen = 0;
917                 }
918                 i = BLEN(b);
919                 if(i <= len){
920                         len -= i;
921                         *l = b;
922                         l = &(b->next);
923                 } else {
924                         /* split block and put unused bit back */
925                         nb = allocb(i-len);
926                         memmove(nb->wp, b->rp+len, i-len);
927                         b->wp = b->rp+len;
928                         nb->wp += i-len;
929                         qputback(m->q, nb);
930                         *l = b;
931                         return 0;
932                 }
933         }while(len > 0);
934
935         return 0;
936 }
937
938 void
939 mntgate(struct mnt *m)
940 {
941         struct mntrpc *q;
942
943         spin_lock(&m->lock);
944         m->rip = 0;
945         for(q = m->queue; q; q = q->list) {
946                 if(q->done == 0)
947                         if (rendez_wakeup(&q->r))
948                                 break;
949         }
950         spin_unlock(&m->lock);
951 }
952
953 void
954 mountmux(struct mnt *m, struct mntrpc *r)
955 {
956         struct mntrpc **l, *q;
957
958         spin_lock(&m->lock);
959         l = &m->queue;
960         for(q = *l; q; q = q->list) {
961                 /* look for a reply to a message */
962                 if(q->request.tag == r->reply.tag) {
963                         *l = q->list;
964                         if(q != r) {
965                                 /*
966                                  * Completed someone else.
967                                  * Trade pointers to receive buffer.
968                                  */
969                                 q->reply = r->reply;
970                                 q->b = r->b;
971                                 r->b = NULL;
972                         }
973                         q->done = 1;
974                         spin_unlock(&m->lock);
975                         if(mntstats != NULL)
976                                 (*mntstats)(q->request.type,
977                                         m->c, q->stime,
978                                         q->reqlen + r->replen);
979                         if(q != r)
980                                 rendez_wakeup(&q->r);
981                         return;
982                 }
983                 l = &q->list;
984         }
985         spin_unlock(&m->lock);
986         if(r->reply.type == Rerror){
987                 printd("unexpected reply tag %ud; type %d (error %q)\n", r->reply.tag, r->reply.type, r->reply.ename);
988         }else{
989                 printd("unexpected reply tag %ud; type %d\n", r->reply.tag, r->reply.type);
990         }
991 }
992
993 /*
994  * Create a new flush request and chain the previous
995  * requests from it
996  */
997 struct mntrpc*
998 mntflushalloc(struct mntrpc *r, uint32_t iounit)
999 {
1000         struct mntrpc *fr;
1001
1002         fr = mntralloc(0, iounit);
1003
1004         fr->request.type = Tflush;
1005         if(r->request.type == Tflush)
1006                 fr->request.oldtag = r->request.oldtag;
1007         else
1008                 fr->request.oldtag = r->request.tag;
1009         fr->flushed = r;
1010
1011         return fr;
1012 }
1013
1014 /*
1015  *  Free a chain of flushes.  Remove each unanswered
1016  *  flush and the original message from the unanswered
1017  *  request queue.  Mark the original message as done
1018  *  and if it hasn't been answered set the reply to to
1019  *  Rflush.
1020  */
1021 void
1022 mntflushfree(struct mnt *m, struct mntrpc *r)
1023 {
1024         struct mntrpc *fr;
1025
1026         while(r){
1027                 fr = r->flushed;
1028                 if(!r->done){
1029                         r->reply.type = Rflush;
1030                         mntqrm(m, r);
1031                 }
1032                 if(fr)
1033                         mntfree(r);
1034                 r = fr;
1035         }
1036 }
1037
1038 static int
1039 alloctag(void)
1040 {
1041         int i, j;
1042         uint32_t v;
1043
1044         for(i = 0; i < NMASK; i++){
1045                 v = mntalloc.tagmask[i];
1046                 if(v == ~0UL)
1047                         continue;
1048                 for(j = 0; j < 1<<TAGSHIFT; j++)
1049                         if((v & (1<<j)) == 0){
1050                                 mntalloc.tagmask[i] |= 1<<j;
1051                                 return (i<<TAGSHIFT) + j;
1052                         }
1053         }
1054         /* panic("no devmnt tags left"); */
1055         return NOTAG;
1056 }
1057
1058 static void
1059 freetag(int t)
1060 {
1061         mntalloc.tagmask[t>>TAGSHIFT] &= ~(1<<(t&TAGMASK));
1062 }
1063
1064 struct mntrpc*
1065 mntralloc(struct chan *c, uint32_t msize)
1066 {
1067         struct mntrpc *new;
1068
1069         spin_lock(&mntalloc.l);
1070         new = mntalloc.rpcfree;
1071         if(new == NULL){
1072                 new = kzmalloc(sizeof(struct mntrpc), 0);
1073                 if(new == NULL) {
1074                         spin_unlock(&mntalloc.l);
1075                         exhausted("mount rpc header");
1076                 }
1077                 /*
1078                  * The header is split from the data buffer as
1079                  * mountmux may swap the buffer with another header.
1080                  */
1081                 new->rpc = kzmalloc(msize, KMALLOC_WAIT);
1082                 if(new->rpc == NULL){
1083                         kfree(new);
1084                         spin_unlock(&mntalloc.l);
1085                         exhausted("mount rpc buffer");
1086                 }
1087                 new->rpclen = msize;
1088                 new->request.tag = alloctag();
1089                 if(new->request.tag == NOTAG){
1090                         kfree(new);
1091                         spin_unlock(&mntalloc.l);
1092                         exhausted("rpc tags");
1093                 }
1094         }
1095         else {
1096                 mntalloc.rpcfree = new->list;
1097                 mntalloc.nrpcfree--;
1098                 if(new->rpclen < msize){
1099                         kfree(new->rpc);
1100                         new->rpc = kzmalloc(msize, KMALLOC_WAIT);
1101                         if(new->rpc == NULL){
1102                                 kfree(new);
1103                                 mntalloc.nrpcused--;
1104                                 spin_unlock(&mntalloc.l);
1105                                 exhausted("mount rpc buffer");
1106                         }
1107                         new->rpclen = msize;
1108                 }
1109         }
1110         mntalloc.nrpcused++;
1111         spin_unlock(&mntalloc.l);
1112         new->c = c;
1113         new->done = 0;
1114         new->flushed = NULL;
1115         new->b = NULL;
1116         return new;
1117 }
1118
1119 void
1120 mntfree(struct mntrpc *r)
1121 {
1122         if(r->b != NULL)
1123                 freeblist(r->b);
1124         spin_lock(&mntalloc.l);
1125         if(mntalloc.nrpcfree >= 10){
1126                 kfree(r->rpc);
1127                 freetag(r->request.tag);
1128                 kfree(r);
1129         }
1130         else{
1131                 r->list = mntalloc.rpcfree;
1132                 mntalloc.rpcfree = r;
1133                 mntalloc.nrpcfree++;
1134         }
1135         mntalloc.nrpcused--;
1136         spin_unlock(&mntalloc.l);
1137 }
1138
1139 void
1140 mntqrm(struct mnt *m, struct mntrpc *r)
1141 {
1142         struct mntrpc **l, *f;
1143
1144         spin_lock(&m->lock);
1145         r->done = 1;
1146
1147         l = &m->queue;
1148         for(f = *l; f; f = f->list) {
1149                 if(f == r) {
1150                         *l = r->list;
1151                         break;
1152                 }
1153                 l = &f->list;
1154         }
1155         spin_unlock(&m->lock);
1156 }
1157
1158 struct mnt*
1159 mntchk(struct chan *c)
1160 {
1161         struct mnt *m;
1162
1163         /* This routine is mostly vestiges of prior lives; now it's just sanity checking */
1164
1165         if(c->mchan == NULL)
1166                 panic("mntchk 1: NULL mchan c %s\n", /*c2name(c)*/"channame?");
1167
1168         m = c->mchan->mux;
1169
1170         if(m == NULL)
1171                 printd("mntchk 2: NULL mux c %s c->mchan %s \n", c2name(c), c2name(c->mchan));
1172
1173         /*
1174          * Was it closed and reused (was error(Eshutdown); now, it can't happen)
1175          */
1176         if(m->id == 0 || m->id >= c->dev)
1177                 panic("mntchk 3: can't happen");
1178
1179         return m;
1180 }
1181
1182 /*
1183  * Rewrite channel type and dev for in-flight data to
1184  * reflect local values.  These entries are known to be
1185  * the first two in the Dir encoding after the count.
1186  */
1187 void
1188 mntdirfix(uint8_t *dirbuf, struct chan *c)
1189 {
1190         unsigned int r;
1191
1192         r = devtab[c->type]->dc;
1193         dirbuf += BIT16SZ;      /* skip count */
1194         PBIT16(dirbuf, r);
1195         dirbuf += BIT16SZ;
1196         PBIT32(dirbuf, c->dev);
1197 }
1198
1199 int
1200 rpcattn(void *v)
1201 {
1202         struct mntrpc *r;
1203
1204         r = v;
1205         return r->done || r->m->rip == 0;
1206 }
1207
1208 struct dev mntdevtab = {
1209         'M',
1210         "mnt",
1211
1212         devreset,
1213         mntinit,
1214         devshutdown,
1215         mntattach,
1216         mntwalk,
1217         mntstat,
1218         mntopen,
1219         mntcreate,
1220         mntclose,
1221         mntread,
1222         devbread,
1223         mntwrite,
1224         devbwrite,
1225         mntremove,
1226         mntwstat,
1227 };