MCS locks properly disable thread migration
[akaros.git] / user / parlib / mcs.c
1 #include <vcore.h>
2 #include <mcs.h>
3 #include <arch/atomic.h>
4 #include <string.h>
5 #include <stdlib.h>
6 #include <uthread.h>
7
8 // MCS locks
9 void mcs_lock_init(struct mcs_lock *lock)
10 {
11         memset(lock,0,sizeof(mcs_lock_t));
12 }
13
14 static inline mcs_lock_qnode_t *mcs_qnode_swap(mcs_lock_qnode_t **addr,
15                                                mcs_lock_qnode_t *val)
16 {
17         return (mcs_lock_qnode_t*)atomic_swap_ptr((void**)addr, val);
18 }
19
20 void mcs_lock_lock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
21 {
22         qnode->next = 0;
23         mcs_lock_qnode_t* predecessor = mcs_qnode_swap(&lock->lock,qnode);
24         if(predecessor)
25         {
26                 qnode->locked = 1;
27                 predecessor->next = qnode;
28                 while(qnode->locked)
29                         cpu_relax();
30         }
31 }
32
33 void mcs_lock_unlock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
34 {
35         if(qnode->next == 0)
36         {
37                 mcs_lock_qnode_t* old_tail = mcs_qnode_swap(&lock->lock,0);
38                 if(old_tail == qnode)
39                         return;
40
41                 mcs_lock_qnode_t* usurper = mcs_qnode_swap(&lock->lock,old_tail);
42                 while(qnode->next == 0);
43                 if(usurper)
44                         usurper->next = qnode->next;
45                 else
46                         qnode->next->locked = 0;
47         }
48         else
49                 qnode->next->locked = 0;
50 }
51
52 /* We don't bother saving the state, like we do with irqsave, since we can use
53  * whether or not we are in vcore context to determine that.  This means you
54  * shouldn't call this from those moments when you fake being in vcore context
55  * (when switching into the TLS, etc). */
56 void mcs_lock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
57 {
58         if (!in_vcore_context()) {
59                 if (current_uthread)
60                         current_uthread->flags |= UTHREAD_DONT_MIGRATE;
61                 wmb();
62                 disable_notifs(vcore_id());
63                 wmb();
64                 if (current_uthread)
65                         current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
66         }
67         mcs_lock_lock(lock, qnode);
68 }
69
70 void mcs_unlock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
71 {
72         mcs_lock_unlock(lock, qnode);
73         if (!in_vcore_context() && in_multi_mode()) {
74                 if (current_uthread)
75                         current_uthread->flags |= UTHREAD_DONT_MIGRATE;
76                 wmb();
77                 enable_notifs(vcore_id());
78                 wmb();
79                 if (current_uthread)
80                         current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
81         }
82 }
83
84 // MCS dissemination barrier!
85 int mcs_barrier_init(mcs_barrier_t* b, size_t np)
86 {
87         if(np > max_vcores())
88                 return -1;
89         b->allnodes = (mcs_dissem_flags_t*)malloc(np*sizeof(mcs_dissem_flags_t));
90         memset(b->allnodes,0,np*sizeof(mcs_dissem_flags_t));
91         b->nprocs = np;
92
93         b->logp = (np & (np-1)) != 0;
94         while(np >>= 1)
95                 b->logp++;
96
97         size_t i,k;
98         for(i = 0; i < b->nprocs; i++)
99         {
100                 b->allnodes[i].parity = 0;
101                 b->allnodes[i].sense = 1;
102
103                 for(k = 0; k < b->logp; k++)
104                 {
105                         size_t j = (i+(1<<k)) % b->nprocs;
106                         b->allnodes[i].partnerflags[0][k] = &b->allnodes[j].myflags[0][k];
107                         b->allnodes[i].partnerflags[1][k] = &b->allnodes[j].myflags[1][k];
108                 } 
109         }
110
111         return 0;
112 }
113
114 void mcs_barrier_wait(mcs_barrier_t* b, size_t pid)
115 {
116         mcs_dissem_flags_t* localflags = &b->allnodes[pid];
117         size_t i;
118         for(i = 0; i < b->logp; i++)
119         {
120                 *localflags->partnerflags[localflags->parity][i] = localflags->sense;
121                 while(localflags->myflags[localflags->parity][i] != localflags->sense);
122         }
123         if(localflags->parity)
124                 localflags->sense = 1-localflags->sense;
125         localflags->parity = 1-localflags->parity;
126 }
127