/* akaros.git: user/parlib/mcs.c (commit 9b25f27bfa6e7ba87bf4f04aacb8074e41fe6656) */
#include <vcore.h>
#include <mcs.h>
#include <arch/atomic.h>
#include <string.h>
#include <stdlib.h>
#include <uthread.h>

// MCS locks
void mcs_lock_init(struct mcs_lock *lock)
{
        memset(lock, 0, sizeof(mcs_lock_t));
}

static inline mcs_lock_qnode_t *mcs_qnode_swap(mcs_lock_qnode_t **addr,
                                               mcs_lock_qnode_t *val)
{
        return (mcs_lock_qnode_t*)atomic_swap_ptr((void**)addr, val);
}

void mcs_lock_lock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
        qnode->next = 0;
        /* Swap ourselves in as the new tail.  A non-zero predecessor means the
         * lock is held, so queue up behind it. */
        mcs_lock_qnode_t *predecessor = mcs_qnode_swap(&lock->lock, qnode);
        if (predecessor) {
                qnode->locked = 1;
                wmb();
                predecessor->next = qnode;
                /* no need for a wrmb(), since this will only get unlocked after they
                 * read our previous write */
                while (qnode->locked)
                        cpu_relax();
        }
        cmb();  /* just need a cmb, the swap handles the CPU wmb/wrmb() */
}

void mcs_lock_unlock(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
        if (qnode->next == 0) {
                cmb();  /* no need for CPU mbs, since there's an atomic_swap() */
                /* If we are still the tail, this swap releases the lock outright */
                mcs_lock_qnode_t *old_tail = mcs_qnode_swap(&lock->lock, 0);
                if (old_tail == qnode)
                        return;
                /* Someone enqueued after our check; put the old tail back.  Whoever
                 * we displace with this swap is the usurper. */
                mcs_lock_qnode_t *usurper = mcs_qnode_swap(&lock->lock, old_tail);
                /* Wait for our successor to finish linking in behind us */
                while (qnode->next == 0)
                        cpu_relax();
                if (usurper)
                        /* Hand our waiting successor chain over to the usurper */
                        usurper->next = qnode->next;
                else
                        qnode->next->locked = 0;
        } else {
                /* mb()s necessary since we didn't call an atomic_swap() */
                wmb();  /* need to make sure any previous writes don't pass unlocking */
                rwmb(); /* need to make sure any reads happen before the unlocking */
                qnode->next->locked = 0;
        }
}
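
/* Usage sketch (not in the original file): a qnode lives on the caller's
 * stack and must stay valid from lock to unlock.  The lock and caller below
 * are hypothetical, for illustration only. */
#if 0
static struct mcs_lock example_lock;    /* zeroed via mcs_lock_init() at startup */

static void example_critical_section(void)
{
        struct mcs_lock_qnode qnode;    /* one qnode per acquisition attempt */

        mcs_lock_lock(&example_lock, &qnode);
        /* ... critical section ... */
        mcs_lock_unlock(&example_lock, &qnode);
}
#endif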

/* We don't bother saving the state, like we do with irqsave, since we can use
 * whether or not we are in vcore context to determine that.  This means you
 * shouldn't call this from those moments when you fake being in vcore context
 * (when switching into the TLS, etc). */
void mcs_lock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
        if (!in_vcore_context()) {
                if (current_uthread)
                        current_uthread->flags |= UTHREAD_DONT_MIGRATE;
                cmb();  /* don't issue the vcore_id() read before the flag write */
                disable_notifs(vcore_id());
                cmb();  /* don't issue the flag clear before the disable */
                if (current_uthread)
                        current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
        }
        mcs_lock_lock(lock, qnode);
}

void mcs_unlock_notifsafe(struct mcs_lock *lock, struct mcs_lock_qnode *qnode)
{
        mcs_lock_unlock(lock, qnode);
        if (!in_vcore_context() && in_multi_mode()) {
                if (current_uthread)
                        current_uthread->flags |= UTHREAD_DONT_MIGRATE;
                cmb();  /* don't issue the vcore_id() read before the flag write */
                enable_notifs(vcore_id());
                cmb();  /* don't issue the flag clear before the enable */
                if (current_uthread)
                        current_uthread->flags &= ~UTHREAD_DONT_MIGRATE;
        }
}
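
/* Usage sketch (not in the original file): the notif-safe variants disable
 * notifications around the lock so a uthread can't be interrupted or migrated
 * while queued.  Hypothetical caller, for illustration only. */
#if 0
static struct mcs_lock example_nlock;   /* zeroed via mcs_lock_init() at startup */

static void example_notifsafe_section(void)
{
        struct mcs_lock_qnode qnode;

        mcs_lock_notifsafe(&example_nlock, &qnode);
        /* ... critical section, no notifications delivered here ... */
        mcs_unlock_notifsafe(&example_nlock, &qnode);
}
#endif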

// MCS dissemination barrier!
int mcs_barrier_init(mcs_barrier_t* b, size_t np)
{
        if(np > max_vcores())
                return -1;
        b->allnodes = (mcs_dissem_flags_t*)malloc(np*sizeof(mcs_dissem_flags_t));
        if(!b->allnodes)
                return -1;
        memset(b->allnodes,0,np*sizeof(mcs_dissem_flags_t));
        b->nprocs = np;

        /* logp = ceil(log2(np)), the number of dissemination rounds */
        b->logp = (np & (np-1)) != 0;
        while(np >>= 1)
                b->logp++;

        size_t i,k;
        for(i = 0; i < b->nprocs; i++)
        {
                b->allnodes[i].parity = 0;
                b->allnodes[i].sense = 1;

                /* in round k, proc i signals proc (i + 2^k) mod np */
                for(k = 0; k < b->logp; k++)
                {
                        size_t j = (i+(1<<k)) % b->nprocs;
                        b->allnodes[i].partnerflags[0][k] = &b->allnodes[j].myflags[0][k];
                        b->allnodes[i].partnerflags[1][k] = &b->allnodes[j].myflags[1][k];
                }
        }

        return 0;
}
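
/* Sanity sketch (not in the original file): the two-step computation above is
 * ceil(log2(np)).  A hypothetical standalone version, for clarity: */
#if 0
static size_t ceil_log2(size_t np)
{
        size_t logp = (np & (np - 1)) != 0;     /* 1 if np is not a power of two */
        while (np >>= 1)
                logp++;
        return logp;    /* e.g. 4 -> 2, 5 -> 3, 8 -> 3 */
}
#endif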

void mcs_barrier_wait(mcs_barrier_t* b, size_t pid)
{
        mcs_dissem_flags_t* localflags = &b->allnodes[pid];
        size_t i;
        for(i = 0; i < b->logp; i++)
        {
                /* signal this round's partner, then wait to be signaled ourselves */
                *localflags->partnerflags[localflags->parity][i] = localflags->sense;
                while(localflags->myflags[localflags->parity][i] != localflags->sense)
                        cpu_relax();
        }
        /* alternate parity every episode; flip sense every other episode */
        if(localflags->parity)
                localflags->sense = 1-localflags->sense;
        localflags->parity = 1-localflags->parity;
}
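
/* Usage sketch (not in the original file): each participant calls
 * mcs_barrier_wait() with its own id in [0, nprocs).  Hypothetical worker,
 * for illustration only. */
#if 0
static mcs_barrier_t example_barrier;   /* set up once with mcs_barrier_init() */

static void example_worker(size_t my_id)
{
        /* ... phase 1 work ... */
        mcs_barrier_wait(&example_barrier, my_id);      /* blocks until all arrive */
        /* ... phase 2 work ... */
}
#endif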