/* akaros.git: user/parlib/mcs.c — MCS queue locks and a dissemination
 * barrier for userspace (parlib).  Last change: updated memory-barrier
 * usage. */
1 #include <vcore.h>
2 #include <mcs.h>
3 #include <arch/atomic.h>
4 #include <string.h>
5 #include <stdlib.h>
6 #include <uthread.h>
7
8 // MCS locks
9 void mcs_lock_init(struct mcs_lock *lock)
10 {
11         memset(lock,0,sizeof(mcs_lock_t));
12 }
13
/* Atomically swap *addr with val and return the previous value (the old
 * tail of the queue, or NULL if the queue was empty).  Typed wrapper
 * around atomic_swap_ptr; per the callers' comments, the swap also acts
 * as a full CPU memory barrier. */
static inline mcs_lock_qnode_t *mcs_qnode_swap(mcs_lock_qnode_t **addr,
                                               mcs_lock_qnode_t *val)
{
	return (mcs_lock_qnode_t*)atomic_swap_ptr((void**)addr, val);
}
19
/* Acquire the MCS lock.  The caller provides a qnode (its queue entry);
 * the qnode must stay valid until the matching mcs_lock_unlock().
 *
 * We swap ourselves in as the new queue tail.  If there was a previous
 * tail (a predecessor), we publish our qnode to it and spin locally on
 * our own 'locked' flag until the lock is handed to us. */
void mcs_lock_lock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
	qnode->next = 0;
	cmb();	/* swap provides a CPU mb() */
	mcs_lock_qnode_t *predecessor = mcs_qnode_swap(&lock->lock, qnode);
	if (predecessor) {
		/* Set our locked flag *before* linking into the predecessor, so
		 * it cannot clear a flag we have not set yet. */
		qnode->locked = 1;
		wmb();
		predecessor->next = qnode;
		/* no need for a wrmb(), since this will only get unlocked after they
		 * read our previous write */
		while (qnode->locked)
			cpu_relax();
	}
	cmb();	/* just need a cmb, the swap handles the CPU wmb/wrmb() */
}
36
/* Release the MCS lock, using the same qnode passed to mcs_lock_lock().
 *
 * Fast path: we appear to have no successor, so we try to swap the tail
 * back to NULL.  If the swap returns our own qnode, nobody queued behind
 * us and we are done.  Otherwise someone (a "usurper") slipped in between
 * our read and our swap: we swap the old tail back in, wait for our real
 * successor to finish linking to us, then either splice it behind the
 * usurper or hand it the lock directly.
 *
 * Slow path: we already have a visible successor; just clear its flag. */
void mcs_lock_unlock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
	if (qnode->next == 0) {
		cmb();	/* no need for CPU mbs, since there's an atomic_swap() */
		mcs_lock_qnode_t *old_tail = mcs_qnode_swap(&lock->lock,0);
		if (old_tail == qnode)
			return;
		/* Someone else became tail after we read next == 0.  Restore the
		 * tail; the return value tells us who raced in. */
		mcs_lock_qnode_t *usurper = mcs_qnode_swap(&lock->lock,old_tail);
		/* Our successor exists but may not have linked to us yet. */
		while (qnode->next == 0)
			cpu_relax();
		if (usurper)
			usurper->next = qnode->next;
		else
			qnode->next->locked = 0;
	} else {
		/* mb()s necessary since we didn't call an atomic_swap() */
		wmb();	/* need to make sure any previous writes don't pass unlocking */
		rwmb();	/* need to make sure any reads happen before the unlocking */
		qnode->next->locked = 0;
	}
}
58
/* We don't bother saving the state, like we do with irqsave, since we can use
 * whether or not we are in vcore context to determine that.  This means you
 * shouldn't call this from those moments when you fake being in vcore context
 * (when switching into the TLS, etc). */
void mcs_lock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
	if (!in_vcore_context()) {
		/* DONT_MIGRATE pins us so the vcore_id() we read stays ours until
		 * notifs are actually disabled; cleared once that's done. */
		if (current_uthread)
			current_uthread->flags |= UTHREAD_DONT_MIGRATE;
		cmb();	/* don't issue the flag write before the vcore_id() read */
		disable_notifs(vcore_id());
		cmb();	/* don't issue the flag write before the disable */
		if (current_uthread)
			current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
	}
	mcs_lock_lock(lock, qnode);
}
76
/* Release a lock taken with mcs_lock_notifsafe(), then re-enable notifs
 * if we disabled them on the way in (i.e. we're a uthread in MCP mode).
 * Same DONT_MIGRATE dance as in mcs_lock_notifsafe(), for the same
 * reason: the vcore_id() we read must still be ours when we act on it. */
void mcs_unlock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
	mcs_lock_unlock(lock, qnode);
	if (!in_vcore_context() && in_multi_mode()) {
		if (current_uthread)
			current_uthread->flags |= UTHREAD_DONT_MIGRATE;
		cmb();	/* don't issue the flag write before the vcore_id() read */
		enable_notifs(vcore_id());
		cmb();	/* don't issue the flag write before the enable */
		if (current_uthread)
			current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
	}
}
90
91 // MCS dissemination barrier!
92 int mcs_barrier_init(mcs_barrier_t* b, size_t np)
93 {
94         if(np > max_vcores())
95                 return -1;
96         b->allnodes = (mcs_dissem_flags_t*)malloc(np*sizeof(mcs_dissem_flags_t));
97         memset(b->allnodes,0,np*sizeof(mcs_dissem_flags_t));
98         b->nprocs = np;
99
100         b->logp = (np & (np-1)) != 0;
101         while(np >>= 1)
102                 b->logp++;
103
104         size_t i,k;
105         for(i = 0; i < b->nprocs; i++)
106         {
107                 b->allnodes[i].parity = 0;
108                 b->allnodes[i].sense = 1;
109
110                 for(k = 0; k < b->logp; k++)
111                 {
112                         size_t j = (i+(1<<k)) % b->nprocs;
113                         b->allnodes[i].partnerflags[0][k] = &b->allnodes[j].myflags[0][k];
114                         b->allnodes[i].partnerflags[1][k] = &b->allnodes[j].myflags[1][k];
115                 } 
116         }
117
118         return 0;
119 }
120
/* Wait at the dissemination barrier as participant pid (0 <= pid < nprocs).
 *
 * Each of the logp rounds: raise our partner's flag for this round, then
 * spin until our own flag for the round is raised.  Sense reversal (flip
 * 'sense' every other pass) plus the two-buffer 'parity' index lets the
 * barrier be reused immediately without resetting flags.
 * NOTE(review): the spin loop has no cpu_relax()/barrier — presumably
 * myflags is declared volatile in mcs_dissem_flags_t; confirm in mcs.h. */
void mcs_barrier_wait(mcs_barrier_t* b, size_t pid)
{
	mcs_dissem_flags_t* localflags = &b->allnodes[pid];
	size_t i;
	for(i = 0; i < b->logp; i++)
	{
		*localflags->partnerflags[localflags->parity][i] = localflags->sense;
		while(localflags->myflags[localflags->parity][i] != localflags->sense);
	}
	if(localflags->parity)
		localflags->sense = 1-localflags->sense;
	localflags->parity = 1-localflags->parity;
}
134