DONT_MIGRATE set more carefully
[akaros.git] / user / parlib / mcs.c
1 #include <vcore.h>
2 #include <mcs.h>
3 #include <arch/atomic.h>
4 #include <string.h>
5 #include <stdlib.h>
6 #include <uthread.h>
7
8 // MCS locks
9 void mcs_lock_init(struct mcs_lock *lock)
10 {
11         memset(lock,0,sizeof(mcs_lock_t));
12 }
13
14 static inline mcs_lock_qnode_t *mcs_qnode_swap(mcs_lock_qnode_t **addr,
15                                                mcs_lock_qnode_t *val)
16 {
17         return (mcs_lock_qnode_t*)atomic_swap_ptr((void**)addr, val);
18 }
19
20 void mcs_lock_lock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
21 {
22         qnode->next = 0;
23         cmb();  /* swap provides a CPU mb() */
24         mcs_lock_qnode_t *predecessor = mcs_qnode_swap(&lock->lock, qnode);
25         if (predecessor) {
26                 qnode->locked = 1;
27                 wmb();
28                 predecessor->next = qnode;
29                 /* no need for a wrmb(), since this will only get unlocked after they
30                  * read our previous write */
31                 while (qnode->locked)
32                         cpu_relax();
33         }
34         cmb();  /* just need a cmb, the swap handles the CPU wmb/wrmb() */
35 }
36
37 void mcs_lock_unlock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
38 {
39         if (qnode->next == 0) {
40                 cmb();  /* no need for CPU mbs, since there's an atomic_swap() */
41                 mcs_lock_qnode_t *old_tail = mcs_qnode_swap(&lock->lock,0);
42                 if (old_tail == qnode)
43                         return;
44                 mcs_lock_qnode_t *usurper = mcs_qnode_swap(&lock->lock,old_tail);
45                 while (qnode->next == 0)
46                         cpu_relax();
47                 if (usurper)
48                         usurper->next = qnode->next;
49                 else
50                         qnode->next->locked = 0;
51         } else {
52                 /* mb()s necessary since we didn't call an atomic_swap() */
53                 wmb();  /* need to make sure any previous writes don't pass unlocking */
54                 rwmb(); /* need to make sure any reads happen before the unlocking */
55                 qnode->next->locked = 0;
56         }
57 }
58
59 /* We don't bother saving the state, like we do with irqsave, since we can use
60  * whether or not we are in vcore context to determine that.  This means you
61  * shouldn't call this from those moments when you fake being in vcore context
62  * (when switching into the TLS, etc). */
63 void mcs_lock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
64 {
65         if (!in_vcore_context() && in_multi_mode()) {
66                 if (current_uthread)
67                         current_uthread->flags |= UTHREAD_DONT_MIGRATE;
68                 cmb();  /* don't issue the flag write before the vcore_id() read */
69                 disable_notifs(vcore_id());
70         }
71         mcs_lock_lock(lock, qnode);
72 }
73
74 /* Note we turn off the DONT_MIGRATE flag before enabling notifs.  This is fine,
75  * since we wouldn't receive any notifs that could lead to us migrating after we
76  * set DONT_MIGRATE but before enable_notifs().  We need it to be in this order,
77  * since we need to check messages after ~DONT_MIGRATE. */
78 void mcs_unlock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
79 {
80         mcs_lock_unlock(lock, qnode);
81         if (!in_vcore_context() && in_multi_mode()) {
82                 if (current_uthread)
83                         current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
84                 cmb();  /* don't enable before ~DONT_MIGRATE */
85                 enable_notifs(vcore_id());
86         }
87 }
88
89 // MCS dissemination barrier!
90 int mcs_barrier_init(mcs_barrier_t* b, size_t np)
91 {
92         if(np > max_vcores())
93                 return -1;
94         b->allnodes = (mcs_dissem_flags_t*)malloc(np*sizeof(mcs_dissem_flags_t));
95         memset(b->allnodes,0,np*sizeof(mcs_dissem_flags_t));
96         b->nprocs = np;
97
98         b->logp = (np & (np-1)) != 0;
99         while(np >>= 1)
100                 b->logp++;
101
102         size_t i,k;
103         for(i = 0; i < b->nprocs; i++)
104         {
105                 b->allnodes[i].parity = 0;
106                 b->allnodes[i].sense = 1;
107
108                 for(k = 0; k < b->logp; k++)
109                 {
110                         size_t j = (i+(1<<k)) % b->nprocs;
111                         b->allnodes[i].partnerflags[0][k] = &b->allnodes[j].myflags[0][k];
112                         b->allnodes[i].partnerflags[1][k] = &b->allnodes[j].myflags[1][k];
113                 } 
114         }
115
116         return 0;
117 }
118
119 void mcs_barrier_wait(mcs_barrier_t* b, size_t pid)
120 {
121         mcs_dissem_flags_t* localflags = &b->allnodes[pid];
122         size_t i;
123         for(i = 0; i < b->logp; i++)
124         {
125                 *localflags->partnerflags[localflags->parity][i] = localflags->sense;
126                 while(localflags->myflags[localflags->parity][i] != localflags->sense);
127         }
128         if(localflags->parity)
129                 localflags->sense = 1-localflags->sense;
130         localflags->parity = 1-localflags->parity;
131 }
132