Fix mount chan type error
[akaros.git] / kern / drivers / dev / mnt.c
1 // INFERNO
2 #include <vfs.h>
3 #include <kfs.h>
4 #include <slab.h>
5 #include <kmalloc.h>
6 #include <kref.h>
7 #include <string.h>
8 #include <stdio.h>
9 #include <assert.h>
10 #include <error.h>
11 #include <cpio.h>
12 #include <pmap.h>
13 #include <smp.h>
14 #include <ip.h>
15
16 /*
17  * References are managed as follows:
18  * The channel to the server - a network connection or pipe - has one
19  * reference for every Chan open on the server.  The server channel has
20  * c->mux set to the Mnt used for muxing control to that server.  Mnts
21  * have no reference count; they go away when c goes away.
22  * Each channel derived from the mount point has mchan set to c,
23  * and increfs/decrefs mchan to manage references on the server
24  * connection.
25  */
26
27 #define MAXRPC (IOHDRSZ+8192)
28
29 struct mntrpc
30 {
31         struct chan*    c;              /* Channel for whom we are working */
32         struct mntrpc*  list;           /* Free/pending list */
33         struct fcall    request;        /* Outgoing file system protocol message */
34         struct fcall    reply;          /* Incoming reply */
35         struct mnt*     m;              /* Mount device during rpc */
36         struct rendez   r;              /* Place to hang out */
37         uint8_t*        rpc;            /* I/O Data buffer */
38         unsigned int            rpclen; /* len of buffer */
39         struct block    *b;             /* reply blocks */
40         char    done;           /* Rpc completed */
41         uint64_t        stime;          /* start time for mnt statistics */
42         uint32_t        reqlen;         /* request length for mnt statistics */
43         uint32_t        replen;         /* reply length for mnt statistics */
44         struct mntrpc*  flushed;        /* message this one flushes */
45 };
46
47 enum
48 {
49         TAGSHIFT = 5,                   /* uint32_t has to be 32 bits */
50         TAGMASK = (1<<TAGSHIFT)-1,
51         NMASK = (64*1024)>>TAGSHIFT,
52 };
53
54 struct Mntalloc
55 {
56         spinlock_t      l;
57         struct mnt*     list;           /* Mount devices in use */
58         struct mnt*     mntfree;        /* Free list */
59         struct mntrpc*  rpcfree;
60         int     nrpcfree;
61         int     nrpcused;
62         uint32_t        id;
63         uint32_t        tagmask[NMASK];
64 }mntalloc;
65
66 void    mattach(struct mnt*, struct chan*, char *unused_char_p_t);
67 struct mnt*     mntchk(struct chan*);
68 void    mntdirfix( uint8_t *unused_uint8_p_t, struct chan*);
69 struct mntrpc*  mntflushalloc(struct mntrpc*, uint32_t);
70 void    mntflushfree(struct mnt*, struct mntrpc*);
71 void    mntfree(struct mntrpc*);
72 void    mntgate(struct mnt*);
73 void    mntpntfree(struct mnt*);
74 void    mntqrm(struct mnt*, struct mntrpc*);
75 struct mntrpc*  mntralloc(struct chan*, uint32_t);
76 long    mntrdwr( int unused_int, struct chan*, void*, long, int64_t);
77 int     mntrpcread(struct mnt*, struct mntrpc*);
78 void    mountio(struct mnt*, struct mntrpc*);
79 void    mountmux(struct mnt*, struct mntrpc*);
80 void    mountrpc(struct mnt*, struct mntrpc*);
81 int     rpcattn(void*);
82 struct chan*    mntchan(void);
83
84 char    Esbadstat[] = "invalid directory entry received from server";
85 char Enoversion[] = "version not established for mount channel";
86
87
88 void (*mntstats)( int unused_int, struct chan*, uint64_t, uint32_t);
89
90 static void
91 mntinit(void)
92 {
93         mntalloc.id = 1;
94         mntalloc.tagmask[0] = 1;                        /* don't allow 0 as a tag */
95         mntalloc.tagmask[NMASK-1] = 0x80000000UL;       /* don't allow NOTAG */
96         //fmtinstall('F', fcallfmt);
97 /*      fmtinstall('D', dirfmt); */
98 /*      fmtinstall('M', dirmodefmt);  */
99
100         cinit();
101 }
102
103 /*
104  * Version is not multiplexed: message sent only once per connection.
105  */
106 long
107 mntversion(struct chan *c, char *version, int msize, int returnlen)
108 {
109         ERRSTACK(2);
110         struct fcall f;
111         uint8_t *msg;
112         struct mnt *m;
113         char *v;
114         long k, l;
115         uint64_t oo;
116         char buf[128];
117
118         qlock(&c->umqlock);     /* make sure no one else does this until we've established ourselves */
119         if(waserror()){
120                 qunlock(&c->umqlock);
121                 nexterror();
122         }
123
124         /* defaults */
125         if(msize == 0)
126                 msize = MAXRPC;
127         if(msize > c->iounit && c->iounit != 0)
128                 msize = c->iounit;
129         v = version;
130         if(v == NULL || v[0] == '\0')
131                 v = VERSION9P;
132
133         /* validity */
134         if(msize < 0)
135                 error("bad iounit in version call");
136         if(strncmp(v, VERSION9P, strlen(VERSION9P)) != 0)
137                 error("bad 9P version specification");
138
139         m = c->mux;
140
141         if(m != NULL){
142                 qunlock(&c->umqlock);
143                 poperror();
144
145                 strncpy(buf, m->version, sizeof buf);
146                 k = strlen(buf);
147                 if(strncmp(buf, v, k) != 0){
148                         snprintf(buf, sizeof buf, "incompatible 9P versions %s %s", m->version, v);
149                         error(buf);
150                 }
151                 if(returnlen > 0){
152                         if(returnlen < k)
153                                 error(Eshort);
154                         memmove(version, buf, k);
155                 }
156                 return k;
157         }
158
159         f.type = Tversion;
160         f.tag = NOTAG;
161         f.msize = msize;
162         f.version = v;
163         msg = kzmalloc(8192 + IOHDRSZ, 0);
164         if(msg == NULL)
165                 exhausted("version memory");
166         if(waserror()){
167                 kfree(msg);
168                 nexterror();
169         }
170         k = convS2M(&f, msg, 8192+IOHDRSZ);
171         if(k == 0)
172                 error("bad fversion conversion on send");
173
174         spin_lock(&c->lock);
175         oo = c->offset;
176         c->offset += k;
177         spin_unlock(&c->lock);
178
179         l = devtab[c->type].write(c, msg, k, oo);
180
181         if(l < k){
182                 spin_lock(&c->lock);
183                 c->offset -= k - l;
184                 spin_unlock(&c->lock);
185                 error("short write in fversion");
186         }
187
188         /* message sent; receive and decode reply */
189         k = devtab[c->type].read(c, msg, 8192+IOHDRSZ, c->offset);
190         if(k <= 0)
191                 error("EOF receiving fversion reply");
192
193         spin_lock(&c->lock);
194         c->offset += k;
195         spin_unlock(&c->lock);
196
197         l = convM2S(msg, k, &f);
198         if(l != k)
199                 error("bad fversion conversion on reply");
200         if(f.type != Rversion){
201                 if(f.type == Rerror)
202                         error(f.ename);
203                 error("unexpected reply type in fversion");
204         }
205         if(f.msize > msize)
206                 error("server tries to increase msize in fversion");
207         if(f.msize<256 || f.msize>1024*1024)
208                 error("nonsense value of msize in fversion");
209         if(strncmp(f.version, v, strlen(f.version)) != 0)
210                 error("bad 9P version returned from server");
211
212         /* now build Mnt associated with this connection */
213         spin_lock(&mntalloc.l);
214         m = mntalloc.mntfree;
215         if(m != 0)
216                 mntalloc.mntfree = m->list;
217         else {
218                 m = kzmalloc(sizeof(struct mnt), 0);
219                 if(m == 0) {
220                         spin_unlock(&mntalloc.l);
221                         exhausted("mount devices");
222                 }
223         }
224         m->list = mntalloc.list;
225         mntalloc.list = m;
226         m->version = NULL;
227         kstrdup(&m->version, f.version);
228         m->id = mntalloc.id++;
229         m->q = qopen(10*MAXRPC, 0, NULL, NULL);
230         m->msize = f.msize;
231         spin_unlock(&mntalloc.l);
232
233         poperror();     /* msg */
234         kfree(msg);
235
236         spin_lock(&m->lock);
237         m->queue = 0;
238         m->rip = 0;
239
240         c->flag |= CMSG;
241         c->mux = m;
242         m->c = c;
243         spin_unlock(&m->lock);
244
245         poperror();     /* c */
246         qunlock(&c->umqlock);
247         k = strlen(f.version);
248         if(returnlen > 0){
249                 if(returnlen < k)
250                         error(Eshort);
251                 memmove(version, f.version, k);
252         }
253
254         return k;
255 }
256
257 struct chan*
258 mntauth(struct chan *c, char *spec)
259 {
260         ERRSTACK(2);
261         struct mnt *m;
262         struct mntrpc *r;
263
264         m = c->mux;
265
266         if(m == NULL){
267                 mntversion(c, VERSION9P, MAXRPC, 0);
268                 m = c->mux;
269                 if(m == NULL)
270                         error(Enoversion);
271         }
272
273         c = mntchan();
274         if(waserror()) {
275                 /* Close must not be called since it will
276                  * call mnt recursively
277                  */
278                 chanfree(c);
279                 nexterror();
280         }
281
282         r = mntralloc(0, m->msize);
283
284         if(waserror()) {
285                 mntfree(r);
286                 nexterror();
287         }
288
289         r->request.type = Tauth;
290         r->request.afid = c->fid;
291         r->request.uname = current->user;
292         r->request.aname = spec;
293         mountrpc(m, r);
294
295         c->qid = r->reply.aqid;
296         c->mchan = m->c;
297         kref_get(&m->c->ref, 1);
298         c->mqid = c->qid;
299         c->mode = ORDWR;
300
301         poperror();     /* r */
302         mntfree(r);
303
304         poperror();     /* c */
305
306         return c;
307
308 }
309
310 static struct chan*
311 mntattach(char *muxattach)
312 {
313         ERRSTACK(2);
314         struct mnt *m;
315         struct chan *c;
316         struct mntrpc *r;
317         struct bogus{
318                 struct chan     *chan;
319                 struct chan     *authchan;
320                 char    *spec;
321                 int     flags;
322         }bogus;
323
324         bogus = *((struct bogus *)muxattach);
325         c = bogus.chan;
326
327         m = c->mux;
328
329         if(m == NULL){
330                 mntversion(c, NULL, 0, 0);
331                 m = c->mux;
332                 if(m == NULL)
333                         error(Enoversion);
334         }
335
336         c = mntchan();
337         if(waserror()) {
338                 /* Close must not be called since it will
339                  * call mnt recursively
340                  */
341                 chanfree(c);
342                 nexterror();
343         }
344
345         r = mntralloc(0, m->msize);
346
347         if(waserror()) {
348                 mntfree(r);
349                 nexterror();
350         }
351
352         r->request.type = Tattach;
353         r->request.fid = c->fid;
354         if(bogus.authchan == NULL)
355                 r->request.afid = NOFID;
356         else
357                 r->request.afid = bogus.authchan->fid;
358         r->request.uname = current->user;
359         r->request.aname = bogus.spec;
360         mountrpc(m, r);
361
362         c->qid = r->reply.qid;
363         c->mchan = m->c;
364         kref_get(&m->c->ref, 1);
365         c->mqid = c->qid;
366
367         poperror();     /* r */
368         mntfree(r);
369
370         poperror();     /* c */
371
372         if(bogus.flags&MCACHE)
373                 c->flag |= CCACHE;
374         return c;
375 }
376
377 struct chan*
378 mntchan(void)
379 {
380         struct chan *c;
381
382         c = devattach('M', 0);
383         spin_lock(&mntalloc.l);
384         c->dev = mntalloc.id++;
385         spin_unlock(&mntalloc.l);
386
387         if(c->mchan)
388                 panic("mntchan non-zero %p", c->mchan);
389         return c;
390 }
391
392 static struct walkqid*
393 mntwalk(struct chan *c, struct chan *nc, char **name, int nname)
394 {
395         ERRSTACK(2);
396         volatile int alloc;
397         int i;
398         struct mnt *m;
399         struct mntrpc *r;
400         struct walkqid *wq;
401
402         if(nc != NULL)
403                 printd("mntwalk: nc != NULL\n");
404         if(nname > MAXWELEM)
405                 error("devmnt: too many name elements");
406         alloc = 0;
407         wq = kzmalloc(sizeof(struct walkqid) + nname * sizeof(struct qid),
408                       KMALLOC_WAIT);
409         if(waserror()){
410                 if(alloc && wq->clone!=NULL)
411                         cclose(wq->clone);
412                 kfree(wq);
413                 poperror();
414                 return NULL;
415         }
416
417         alloc = 0;
418         m = mntchk(c);
419         r = mntralloc(c, m->msize);
420         if(nc == NULL){
421                 nc = devclone(c);
422                 /* Until the other side accepts this fid, we can't mntclose it.
423                  * Therefore set type to -1 for now.  inferno was setting this to 0,
424                  * assuming it was devroot.  lining up with chanrelease and newchan */
425                 nc->type = -1;
426                 alloc = 1;
427         }
428         wq->clone = nc;
429
430         if(waserror()) {
431                 mntfree(r);
432                 nexterror();
433         }
434         r->request.type = Twalk;
435         r->request.fid = c->fid;
436         r->request.newfid = nc->fid;
437         r->request.nwname = nname;
438         memmove(r->request.wname, name, nname*sizeof( char *));
439
440         mountrpc(m, r);
441
442         if(r->reply.nwqid > nname)
443                 error("too many QIDs returned by walk");
444         if(r->reply.nwqid < nname){
445                 if(alloc)
446                         cclose(nc);
447                 wq->clone = NULL;
448                 if(r->reply.nwqid == 0){
449                         kfree(wq);
450                         wq = NULL;
451                         goto Return;
452                 }
453         }
454
455         /* move new fid onto mnt device and update its qid */
456         if(wq->clone != NULL){
457                 if(wq->clone != c){
458                         wq->clone->type = c->type;
459                         wq->clone->mchan = c->mchan;
460                         kref_get(&c->mchan->ref, 1);
461                 }
462                 if(r->reply.nwqid > 0)
463                         wq->clone->qid = r->reply.wqid[r->reply.nwqid-1];
464         }
465         wq->nqid = r->reply.nwqid;
466         for(i=0; i<wq->nqid; i++)
467                 wq->qid[i] = r->reply.wqid[i];
468
469     Return:
470         poperror();
471         mntfree(r);
472         poperror();
473         return wq;
474 }
475
476 static int
477 mntstat(struct chan *c, uint8_t *dp, int n)
478 {
479         ERRSTACK(1);
480         struct mnt *m;
481         struct mntrpc *r;
482
483         if(n < BIT16SZ)
484                 error(Eshortstat);
485         m = mntchk(c);
486         r = mntralloc(c, m->msize);
487         if(waserror()) {
488                 mntfree(r);
489                 nexterror();
490         }
491         r->request.type = Tstat;
492         r->request.fid = c->fid;
493         mountrpc(m, r);
494
495         if(r->reply.nstat > n){
496                 /* doesn't fit; just patch the count and return */
497                 PBIT16(( uint8_t *)dp, r->reply.nstat);
498                 n = BIT16SZ;
499         }else{
500                 n = r->reply.nstat;
501                 memmove(dp, r->reply.stat, n);
502                 validstat(dp, n);
503                 mntdirfix(dp, c);
504         }
505         poperror();
506         mntfree(r);
507         return n;
508 }
509
510 static struct chan*
511 mntopencreate(int type, struct chan *c, char *name, int omode, uint32_t perm)
512 {
513         ERRSTACK(1);
514         struct mnt *m;
515         struct mntrpc *r;
516
517         m = mntchk(c);
518         r = mntralloc(c, m->msize);
519         if(waserror()) {
520                 mntfree(r);
521                 nexterror();
522         }
523         r->request.type = type;
524         r->request.fid = c->fid;
525         r->request.mode = omode;
526         if(type == Tcreate){
527                 r->request.perm = perm;
528                 r->request.name = name;
529         }
530         mountrpc(m, r);
531
532         c->qid = r->reply.qid;
533         c->offset = 0;
534         c->mode = openmode(omode);
535         c->iounit = r->reply.iounit;
536         if(c->iounit == 0 || c->iounit > m->msize-IOHDRSZ)
537                 c->iounit = m->msize-IOHDRSZ;
538         c->flag |= COPEN;
539         poperror();
540         mntfree(r);
541
542         if(c->flag & CCACHE)
543                 copen(c);
544
545         return c;
546 }
547
548 static struct chan*
549 mntopen(struct chan *c, int omode)
550 {
551         return mntopencreate(Topen, c, NULL, omode, 0);
552 }
553
554 static void
555 mntcreate(struct chan *c, char *name, int omode, uint32_t perm)
556 {
557         mntopencreate(Tcreate, c, name, omode, perm);
558 }
559
560 static void
561 mntclunk(struct chan *c, int t)
562 {
563         ERRSTACK(1);
564         struct mnt *m;
565         struct mntrpc *r;
566
567         m = mntchk(c);
568         r = mntralloc(c, m->msize);
569         if(waserror()){
570                 mntfree(r);
571                 nexterror();
572         }
573
574         r->request.type = t;
575         r->request.fid = c->fid;
576         mountrpc(m, r);
577         mntfree(r);
578         poperror();
579 }
580
581 void
582 muxclose(struct mnt *m)
583 {
584         struct mntrpc *q, *r;
585
586         for(q = m->queue; q; q = r) {
587                 r = q->list;
588                 mntfree(q);
589         }
590         m->id = 0;
591         kfree(m->version);
592         m->version = NULL;
593         mntpntfree(m);
594 }
595
596 void
597 mntpntfree(struct mnt *m)
598 {
599         struct mnt *f, **l;
600         struct queue *q;
601
602         spin_lock(&mntalloc.l);
603         l = &mntalloc.list;
604         for(f = *l; f; f = f->list) {
605                 if(f == m) {
606                         *l = m->list;
607                         break;
608                 }
609                 l = &f->list;
610         }
611         m->list = mntalloc.mntfree;
612         mntalloc.mntfree = m;
613         q = m->q;
614         spin_unlock(&mntalloc.l);
615
616         qfree(q);
617 }
618
619 static void
620 mntclose(struct chan *c)
621 {
622         mntclunk(c, Tclunk);
623 }
624
625 static void
626 mntremove(struct chan *c)
627 {
628         mntclunk(c, Tremove);
629 }
630
631 static int
632 mntwstat(struct chan *c, uint8_t *dp, int n)
633 {
634         ERRSTACK(1);
635         struct mnt *m;
636         struct mntrpc *r;
637
638         m = mntchk(c);
639         r = mntralloc(c, m->msize);
640         if(waserror()) {
641                 mntfree(r);
642                 nexterror();
643         }
644         r->request.type = Twstat;
645         r->request.fid = c->fid;
646         r->request.nstat = n;
647         r->request.stat = dp;
648         mountrpc(m, r);
649         poperror();
650         mntfree(r);
651         return n;
652 }
653
654 static long
655 mntread(struct chan *c, void *buf, long n, int64_t off)
656 {
657         uint8_t *p, *e;
658         int nc, cache, isdir, dirlen;
659
660         isdir = 0;
661         cache = c->flag & CCACHE;
662         if(c->qid.type & QTDIR) {
663                 cache = 0;
664                 isdir = 1;
665         }
666
667         p = buf;
668         if(cache) {
669                 nc = cread(c, buf, n, off);
670                 if(nc > 0) {
671                         n -= nc;
672                         if(n == 0)
673                                 return nc;
674                         p += nc;
675                         off += nc;
676                 }
677                 n = mntrdwr(Tread, c, p, n, off);
678                 cupdate(c, p, n, off);
679                 return n + nc;
680         }
681
682         n = mntrdwr(Tread, c, buf, n, off);
683         if(isdir) {
684                 for(e = &p[n]; p+BIT16SZ < e; p += dirlen){
685                         dirlen = BIT16SZ+GBIT16(p);
686                         if(p+dirlen > e)
687                                 break;
688                         validstat(p, dirlen);
689                         mntdirfix(p, c);
690                 }
691                 if(p != e)
692                         error(Esbadstat);
693         }
694         return n;
695 }
696
697 static long
698 mntwrite(struct chan *c, void *buf, long n, int64_t off)
699 {
700         return mntrdwr(Twrite, c, buf, n, off);
701 }
702
703 long
704 mntrdwr(int type, struct chan *c, void *buf, long n, int64_t off)
705 {
706         ERRSTACK(1);
707         struct mnt *m;
708         struct mntrpc *r;       /* TO DO: volatile struct { Mntrpc *r; } r; */
709         char *uba;
710         int cache;
711         uint32_t cnt, nr, nreq;
712
713         m = mntchk(c);
714         uba = buf;
715         cnt = 0;
716         cache = c->flag & CCACHE;
717         if(c->qid.type & QTDIR)
718                 cache = 0;
719         for(;;) {
720                 r = mntralloc(c, m->msize);
721                 if(waserror()) {
722                         mntfree(r);
723                         nexterror();
724                 }
725                 r->request.type = type;
726                 r->request.fid = c->fid;
727                 r->request.offset = off;
728                 r->request.data = uba;
729                 nr = n;
730                 if(nr > m->msize-IOHDRSZ)
731                         nr = m->msize-IOHDRSZ;
732                 r->request.count = nr;
733                 mountrpc(m, r);
734                 nreq = r->request.count;
735                 nr = r->reply.count;
736                 if(nr > nreq)
737                         nr = nreq;
738
739                 if(type == Tread)
740                         r->b = bl2mem(( uint8_t *)uba, r->b, nr);
741                 else if(cache)
742                         cwrite(c, ( uint8_t *)uba, nr, off);
743
744                 poperror();
745                 mntfree(r);
746                 off += nr;
747                 uba += nr;
748                 cnt += nr;
749                 n -= nr;
750                 if(nr != nreq || n == 0 /*|| current->killed*/)
751                         break;
752         }
753         return cnt;
754 }
755
756 void
757 mountrpc(struct mnt *m, struct mntrpc *r)
758 {
759         char *sn, *cn;
760         int t;
761
762         r->reply.tag = 0;
763         r->reply.type = Tmax;   /* can't ever be a valid message type */
764
765         mountio(m, r);
766
767         t = r->reply.type;
768         switch(t) {
769         case Rerror:
770                 error(r->reply.ename);
771         case Rflush:
772                 error(Eintr);
773         default:
774                 if(t == r->request.type+1)
775                         break;
776                 sn = "?";
777                 if(m->c->name != NULL)
778                         sn = m->c->name->s;
779                 cn = "?";
780                 if(r->c != NULL && r->c->name != NULL)
781                         cn = r->c->name->s;
782                 printd("mnt: proc %s %lu: mismatch from %s %s rep 0x%p tag %d fid %d T%d R%d rp %d\n",
783                        "current->text", "current->pid", sn, cn,
784                         r, r->request.tag, r->request.fid, r->request.type,
785                         r->reply.type, r->reply.tag);
786                 error(Emountrpc);
787         }
788 }
789
790 void
791 mountio(struct mnt *m, struct mntrpc *r)
792 {
793         ERRSTACK(1);
794         int n;
795
796         while(waserror()) {
797                 if(m->rip == current)
798                         mntgate(m);
799                 if(strcmp(current_errstr(), Eintr) != 0){
800                         mntflushfree(m, r);
801                         nexterror();
802                 }
803                 r = mntflushalloc(r, m->msize);
804                 /* need one for every waserror call (so this plus one outside) */
805                 poperror();
806         }
807
808         spin_lock(&m->lock);
809         r->m = m;
810         r->list = m->queue;
811         m->queue = r;
812         spin_unlock(&m->lock);
813
814         /* Transmit a file system rpc */
815         if(m->msize == 0)
816                 panic("msize");
817         n = convS2M(&r->request, r->rpc, m->msize);
818         if(n < 0)
819                 panic("bad message type in mountio");
820         if(devtab[m->c->type].write(m->c, r->rpc, n, 0) != n)
821                 error(Emountrpc);
822 /*      r->stime = fastticks(NULL); */
823         r->reqlen = n;
824
825         /* Gate readers onto the mount point one at a time */
826         for(;;) {
827                 spin_lock(&m->lock);
828                 if(m->rip == 0)
829                         break;
830                 spin_unlock(&m->lock);
831                 rendez_sleep(&r->r, rpcattn, r);
832                 if(r->done){
833                         poperror();
834                         mntflushfree(m, r);
835                         return;
836                 }
837         }
838         m->rip = current;
839         spin_unlock(&m->lock);
840         while(r->done == 0) {
841                 if(mntrpcread(m, r) < 0)
842                         error(Emountrpc);
843                 mountmux(m, r);
844         }
845         mntgate(m);
846         poperror();
847         mntflushfree(m, r);
848 }
849
850 static int
851 doread(struct mnt *m, int len)
852 {
853         struct block *b;
854
855         while(qlen(m->q) < len){
856                 b = devtab[m->c->type].bread(m->c, m->msize, 0);
857                 if(b == NULL)
858                         return -1;
859                 if(blocklen(b) == 0){
860                         freeblist(b);
861                         return -1;
862                 }
863                 qaddlist(m->q, b);
864         }
865         return 0;
866 }
867
868 int
869 mntrpcread(struct mnt *m, struct mntrpc *r)
870 {
871         int i, t, len, hlen;
872         struct block *b, **l, *nb;
873
874         r->reply.type = 0;
875         r->reply.tag = 0;
876
877         /* read at least length, type, and tag and pullup to a single block */
878         if(doread(m, BIT32SZ+BIT8SZ+BIT16SZ) < 0)
879                 return -1;
880         nb = pullupqueue(m->q, BIT32SZ+BIT8SZ+BIT16SZ);
881
882         /* read in the rest of the message, avoid ridiculous (for now) message sizes */
883         len = GBIT32(nb->rp);
884         if(len > m->msize){
885                 qdiscard(m->q, qlen(m->q));
886                 return -1;
887         }
888         if(doread(m, len) < 0)
889                 return -1;
890
891         /* pullup the header (i.e. everything except data) */
892         t = nb->rp[BIT32SZ];
893         switch(t){
894         case Rread:
895                 hlen = BIT32SZ+BIT8SZ+BIT16SZ+BIT32SZ;
896                 break;
897         default:
898                 hlen = len;
899                 break;
900         }
901         nb = pullupqueue(m->q, hlen);
902
903         if(convM2S(nb->rp, len, &r->reply) <= 0){
904                 /* bad message, dump it */
905                 printd("mntrpcread: convM2S failed\n");
906                 qdiscard(m->q, len);
907                 return -1;
908         }
909
910         /* hang the data off of the fcall struct */
911         l = &r->b;
912         *l = NULL;
913         do {
914                 b = qremove(m->q);
915                 if(hlen > 0){
916                         b->rp += hlen;
917                         len -= hlen;
918                         hlen = 0;
919                 }
920                 i = BLEN(b);
921                 if(i <= len){
922                         len -= i;
923                         *l = b;
924                         l = &(b->next);
925                 } else {
926                         /* split block and put unused bit back */
927                         nb = allocb(i-len);
928                         memmove(nb->wp, b->rp+len, i-len);
929                         b->wp = b->rp+len;
930                         nb->wp += i-len;
931                         qputback(m->q, nb);
932                         *l = b;
933                         return 0;
934                 }
935         }while(len > 0);
936
937         return 0;
938 }
939
940 void
941 mntgate(struct mnt *m)
942 {
943         struct mntrpc *q;
944
945         spin_lock(&m->lock);
946         m->rip = 0;
947         for(q = m->queue; q; q = q->list) {
948                 if(q->done == 0)
949                         if (rendez_wakeup(&q->r))
950                                 break;
951         }
952         spin_unlock(&m->lock);
953 }
954
955 void
956 mountmux(struct mnt *m, struct mntrpc *r)
957 {
958         struct mntrpc **l, *q;
959
960         spin_lock(&m->lock);
961         l = &m->queue;
962         for(q = *l; q; q = q->list) {
963                 /* look for a reply to a message */
964                 if(q->request.tag == r->reply.tag) {
965                         *l = q->list;
966                         if(q != r) {
967                                 /*
968                                  * Completed someone else.
969                                  * Trade pointers to receive buffer.
970                                  */
971                                 q->reply = r->reply;
972                                 q->b = r->b;
973                                 r->b = NULL;
974                         }
975                         q->done = 1;
976                         spin_unlock(&m->lock);
977                         if(mntstats != NULL)
978                                 (*mntstats)(q->request.type,
979                                         m->c, q->stime,
980                                         q->reqlen + r->replen);
981                         if(q != r)
982                                 rendez_wakeup(&q->r);
983                         return;
984                 }
985                 l = &q->list;
986         }
987         spin_unlock(&m->lock);
988         if(r->reply.type == Rerror){
989                 printd("unexpected reply tag %u; type %d (error %q)\n", r->reply.tag, r->reply.type, r->reply.ename);
990         }else{
991                 printd("unexpected reply tag %u; type %d\n", r->reply.tag, r->reply.type);
992         }
993 }
994
995 /*
996  * Create a new flush request and chain the previous
997  * requests from it
998  */
999 struct mntrpc*
1000 mntflushalloc(struct mntrpc *r, uint32_t iounit)
1001 {
1002         struct mntrpc *fr;
1003
1004         fr = mntralloc(0, iounit);
1005
1006         fr->request.type = Tflush;
1007         if(r->request.type == Tflush)
1008                 fr->request.oldtag = r->request.oldtag;
1009         else
1010                 fr->request.oldtag = r->request.tag;
1011         fr->flushed = r;
1012
1013         return fr;
1014 }
1015
1016 /*
1017  *  Free a chain of flushes.  Remove each unanswered
1018  *  flush and the original message from the unanswered
1019  *  request queue.  Mark the original message as done
1020  *  and if it hasn't been answered set the reply to to
1021  *  Rflush.
1022  */
1023 void
1024 mntflushfree(struct mnt *m, struct mntrpc *r)
1025 {
1026         struct mntrpc *fr;
1027
1028         while(r){
1029                 fr = r->flushed;
1030                 if(!r->done){
1031                         r->reply.type = Rflush;
1032                         mntqrm(m, r);
1033                 }
1034                 if(fr)
1035                         mntfree(r);
1036                 r = fr;
1037         }
1038 }
1039
1040 static int
1041 alloctag(void)
1042 {
1043         int i, j;
1044         uint32_t v;
1045
1046         for(i = 0; i < NMASK; i++){
1047                 v = mntalloc.tagmask[i];
1048                 if(v == ~0UL)
1049                         continue;
1050                 for(j = 0; j < 1<<TAGSHIFT; j++)
1051                         if((v & (1<<j)) == 0){
1052                                 mntalloc.tagmask[i] |= 1<<j;
1053                                 return (i<<TAGSHIFT) + j;
1054                         }
1055         }
1056         /* panic("no devmnt tags left"); */
1057         return NOTAG;
1058 }
1059
1060 static void
1061 freetag(int t)
1062 {
1063         mntalloc.tagmask[t>>TAGSHIFT] &= ~(1<<(t&TAGMASK));
1064 }
1065
1066 struct mntrpc*
1067 mntralloc(struct chan *c, uint32_t msize)
1068 {
1069         struct mntrpc *new;
1070
1071         spin_lock(&mntalloc.l);
1072         new = mntalloc.rpcfree;
1073         if(new == NULL){
1074                 new = kzmalloc(sizeof(struct mntrpc), 0);
1075                 if(new == NULL) {
1076                         spin_unlock(&mntalloc.l);
1077                         exhausted("mount rpc header");
1078                 }
1079                 /*
1080                  * The header is split from the data buffer as
1081                  * mountmux may swap the buffer with another header.
1082                  */
1083                 new->rpc = kzmalloc(msize, KMALLOC_WAIT);
1084                 if(new->rpc == NULL){
1085                         kfree(new);
1086                         spin_unlock(&mntalloc.l);
1087                         exhausted("mount rpc buffer");
1088                 }
1089                 new->rpclen = msize;
1090                 new->request.tag = alloctag();
1091                 if(new->request.tag == NOTAG){
1092                         kfree(new);
1093                         spin_unlock(&mntalloc.l);
1094                         exhausted("rpc tags");
1095                 }
1096         }
1097         else {
1098                 mntalloc.rpcfree = new->list;
1099                 mntalloc.nrpcfree--;
1100                 if(new->rpclen < msize){
1101                         kfree(new->rpc);
1102                         new->rpc = kzmalloc(msize, KMALLOC_WAIT);
1103                         if(new->rpc == NULL){
1104                                 kfree(new);
1105                                 mntalloc.nrpcused--;
1106                                 spin_unlock(&mntalloc.l);
1107                                 exhausted("mount rpc buffer");
1108                         }
1109                         new->rpclen = msize;
1110                 }
1111         }
1112         mntalloc.nrpcused++;
1113         spin_unlock(&mntalloc.l);
1114         new->c = c;
1115         new->done = 0;
1116         new->flushed = NULL;
1117         new->b = NULL;
1118         return new;
1119 }
1120
1121 void
1122 mntfree(struct mntrpc *r)
1123 {
1124         if(r->b != NULL)
1125                 freeblist(r->b);
1126         spin_lock(&mntalloc.l);
1127         if(mntalloc.nrpcfree >= 10){
1128                 kfree(r->rpc);
1129                 freetag(r->request.tag);
1130                 kfree(r);
1131         }
1132         else{
1133                 r->list = mntalloc.rpcfree;
1134                 mntalloc.rpcfree = r;
1135                 mntalloc.nrpcfree++;
1136         }
1137         mntalloc.nrpcused--;
1138         spin_unlock(&mntalloc.l);
1139 }
1140
1141 void
1142 mntqrm(struct mnt *m, struct mntrpc *r)
1143 {
1144         struct mntrpc **l, *f;
1145
1146         spin_lock(&m->lock);
1147         r->done = 1;
1148
1149         l = &m->queue;
1150         for(f = *l; f; f = f->list) {
1151                 if(f == r) {
1152                         *l = r->list;
1153                         break;
1154                 }
1155                 l = &f->list;
1156         }
1157         spin_unlock(&m->lock);
1158 }
1159
1160 struct mnt*
1161 mntchk(struct chan *c)
1162 {
1163         struct mnt *m;
1164
1165         /* This routine is mostly vestiges of prior lives; now it's just sanity checking */
1166
1167         if(c->mchan == NULL)
1168                 panic("mntchk 1: NULL mchan c %s\n", /*c2name(c)*/"channame?");
1169
1170         m = c->mchan->mux;
1171
1172         if(m == NULL)
1173                 printd("mntchk 2: NULL mux c %s c->mchan %s \n", c2name(c), c2name(c->mchan));
1174
1175         /*
1176          * Was it closed and reused (was error(Eshutdown); now, it can't happen)
1177          */
1178         if(m->id == 0 || m->id >= c->dev)
1179                 panic("mntchk 3: can't happen");
1180
1181         return m;
1182 }
1183
1184 /*
1185  * Rewrite channel type and dev for in-flight data to
1186  * reflect local values.  These entries are known to be
1187  * the first two in the Dir encoding after the count.
1188  */
1189 void
1190 mntdirfix(uint8_t *dirbuf, struct chan *c)
1191 {
1192         unsigned int r;
1193
1194         r = devtab[c->type].dc;
1195         dirbuf += BIT16SZ;      /* skip count */
1196         PBIT16(dirbuf, r);
1197         dirbuf += BIT16SZ;
1198         PBIT32(dirbuf, c->dev);
1199 }
1200
1201 int
1202 rpcattn(void *v)
1203 {
1204         struct mntrpc *r;
1205
1206         r = v;
1207         return r->done || r->m->rip == 0;
1208 }
1209
1210 struct dev mntdevtab __devtab = {
1211         'M',
1212         "mnt",
1213
1214         devreset,
1215         mntinit,
1216         devshutdown,
1217         mntattach,
1218         mntwalk,
1219         mntstat,
1220         mntopen,
1221         mntcreate,
1222         mntclose,
1223         mntread,
1224         devbread,
1225         mntwrite,
1226         devbwrite,
1227         mntremove,
1228         mntwstat,
1229 };