parlib: Use 'const' in the set_dtls() interface
[akaros.git] / user / parlib / dtls.c
1 /* Copyright (c) 2012 The Regents of the University of California
2  * Kevin Klues <klueska@cs.berkeley.edu>
3  *
4  * See LICENSE for details. */
5
6 #include <parlib/assert.h>
7 #include <parlib/dtls.h>
8 #include <parlib/slab.h>
9 #include <parlib/spinlock.h>
10 #include <stddef.h>
11
12 /* The current dymamic tls implementation uses a locked linked list
13  * to find the key for a given thread. We should probably find a better way to
14  * do this based on a custom lock-free hash table or something. */
15 #include <parlib/spinlock.h>
16 #include <sys/queue.h>
17
18 /* Define some number of static keys, for which the memory containing the keys
19  * and the per-thread memory for the values associated with those keys is
20  * allocated statically. This is adapted from glibc's notion of the
21  * "specific_1stblock" field embedded directly into its pthread structure for
22  * pthread_get/specific() calls. */
23 #define NUM_STATIC_KEYS 32
24
25 /* The dynamic tls key structure */
26 struct dtls_key {
27         int id;
28         int ref_count;
29         bool valid;
30         void (*dtor)(void *);
31 };
32
33 /* The definition of a dtls_key list and its elements */
34 struct dtls_value {
35         TAILQ_ENTRY(dtls_value) link;
36         struct dtls_key *key;
37         const void *dtls;
38 };
39 TAILQ_HEAD(dtls_list, dtls_value);
40
41 /* A struct containing all of the per thread (i.e. vcore or uthread) data
42  * associated with dtls */
43 typedef struct dtls_data {
44         /* A per-thread list of dtls regions */
45         struct dtls_list list;
46         /* Memory to hold dtls values for the first NUM_STATIC_KEYS keys */
47         struct dtls_value early_values[NUM_STATIC_KEYS];
48 } dtls_data_t;
49
50 /* A slab of dtls keys (global to all threads) */
51 static struct kmem_cache *__dtls_keys_cache;
52
53 /* A slab of values for use when mapping a dtls_key to its per-thread value */
54 struct kmem_cache *__dtls_values_cache;
55
56 static __thread dtls_data_t __dtls_data;
57 static __thread bool __dtls_initialized;
58 static struct dtls_key static_dtls_keys[NUM_STATIC_KEYS];
59 static int num_dtls_keys;
60
61 /* Initialize the slab caches for allocating dtls keys and values. */
62 int dtls_cache_init(void)
63 {
64         /* Make sure this only runs once */
65         static bool initialized;
66
67         if (initialized)
68                 return 0;
69         initialized = true;
70
71         /* Initialize the global cache of dtls_keys */
72         __dtls_keys_cache =
73             kmem_cache_create("dtls_keys_cache", sizeof(struct dtls_key),
74                               __alignof__(struct dtls_key), 0, NULL, NULL);
75
76         /* Initialize the global cache of dtls_values */
77         __dtls_values_cache =
78             kmem_cache_create("dtls_values_cache", sizeof(struct dtls_value),
79                               __alignof__(struct dtls_value), 0, NULL, NULL);
80
81         return 0;
82 }
83
84 static dtls_key_t __allocate_dtls_key(void)
85 {
86         dtls_key_t key;
87         int keyid = __sync_fetch_and_add(&num_dtls_keys, 1);
88
89         if (keyid < NUM_STATIC_KEYS) {
90                 key = &static_dtls_keys[keyid];
91         } else {
92                 dtls_cache_init();
93                 key = kmem_cache_alloc(__dtls_keys_cache, 0);
94         }
95         assert(key);
96         key->id = keyid;
97         key->ref_count = 1;
98         return key;
99 }
100
101 static void __maybe_free_dtls_key(dtls_key_t key)
102 {
103         int ref_count = __sync_add_and_fetch(&key->ref_count, -1);
104
105         if (ref_count == 0 && key->id >= NUM_STATIC_KEYS)
106                 kmem_cache_free(__dtls_keys_cache, key);
107 }
108
109 static struct dtls_value *__allocate_dtls_value(struct dtls_data *dtls_data,
110                                                 struct dtls_key *key)
111 {
112         struct dtls_value *v;
113
114         if (key->id < NUM_STATIC_KEYS)
115                 v = &dtls_data->early_values[key->id];
116         else
117                 v = kmem_cache_alloc(__dtls_values_cache, 0);
118         assert(v);
119         return v;
120 }
121
122 static void __free_dtls_value(struct dtls_value *v)
123 {
124         if (v->key->id >= NUM_STATIC_KEYS)
125                 kmem_cache_free(__dtls_values_cache, v);
126 }
127
128 dtls_key_t dtls_key_create(dtls_dtor_t dtor)
129 {
130         dtls_key_t key = __allocate_dtls_key();
131
132         key->valid = true;
133         key->dtor = dtor;
134         return key;
135 }
136
137 void dtls_key_delete(dtls_key_t key)
138 {
139         assert(key);
140
141         key->valid = false;
142         __maybe_free_dtls_key(key);
143 }
144
145 static inline struct dtls_value *__get_dtls(dtls_data_t *dtls_data,
146                                             dtls_key_t key)
147 {
148         struct dtls_value *v;
149
150         assert(key);
151         if (key->id < NUM_STATIC_KEYS) {
152                 v = &dtls_data->early_values[key->id];
153                 if (v->key != NULL)
154                         return v;
155         } else {
156                 TAILQ_FOREACH(v, &dtls_data->list, link)
157                         if (v->key == key)
158                                 return v;
159         }
160         return NULL;
161 }
162
163 static inline void __set_dtls(dtls_data_t *dtls_data, dtls_key_t key,
164                               const void *dtls)
165 {
166         struct dtls_value *v;
167
168         assert(key);
169         v = __get_dtls(dtls_data, key);
170         if (!v) {
171                 v = __allocate_dtls_value(dtls_data, key);
172                 __sync_fetch_and_add(&key->ref_count, 1);
173                 v->key = key;
174                 TAILQ_INSERT_HEAD(&dtls_data->list, v, link);
175         }
176         v->dtls = dtls;
177 }
178
179 static inline void __destroy_dtls(dtls_data_t *dtls_data)
180 {
181         struct dtls_value *v, *n;
182         dtls_key_t key;
183         const void *dtls;
184
185         v = TAILQ_FIRST(&dtls_data->list);
186         while (v != NULL) {
187                 key = v->key;
188                 /* The dtor must be called outside of a spinlock so that it can call
189                  * code that may deschedule it for a while (i.e. a mutex). Probably a
190                  * good idea anyway since it can be arbitrarily long and is written by
191                  * the user. Note, there is a small race here on the valid field,
192                  * whereby we may run a destructor on an invalid key. At least the keys
193                  * memory wont be deleted though, as protected by the ref count. Any
194                  * reasonable usage of this interface should safeguard that a key is
195                  * never destroyed before all of the threads that use it have exited
196                  * anyway. */
197                 if (key->valid && key->dtor) {
198                         dtls = v->dtls;
199                         v->dtls = NULL;
200                         key->dtor((void*)dtls);
201                 }
202                 n = TAILQ_NEXT(v, link);
203                 TAILQ_REMOVE(&dtls_data->list, v, link);
204                 /* Free both the key (which is v->key) and v *after* removing v from the
205                  * list.  It's possible that free() will call back into the DTLS (e.g.
206                  * pthread_getspecific()), and v must be off the list by then.
207                  *
208                  * For a similar, hilarious bug in glibc, check out:
209                  * https://sourceware.org/bugzilla/show_bug.cgi?id=3317 */
210                 __maybe_free_dtls_key(key);
211                 __free_dtls_value(v);
212                 v = n;
213         }
214 }
215
216 void set_dtls(dtls_key_t key, const void *dtls)
217 {
218         bool initialized = true;
219         dtls_data_t *dtls_data = NULL;
220
221         if (!__dtls_initialized) {
222                 initialized = false;
223                 __dtls_initialized = true;
224         }
225         dtls_data = &__dtls_data;
226         if (!initialized)
227                 TAILQ_INIT(&dtls_data->list);
228         __set_dtls(dtls_data, key, dtls);
229 }
230
231 void *get_dtls(dtls_key_t key)
232 {
233         dtls_data_t *dtls_data = NULL;
234         struct dtls_value *v;
235
236         if (!__dtls_initialized)
237                 return NULL;
238         dtls_data = &__dtls_data;
239         v = __get_dtls(dtls_data, key);
240         return v ? (void*)v->dtls : NULL;
241 }
242
243 void destroy_dtls(void)
244 {
245         dtls_data_t *dtls_data = NULL;
246
247         if (!__dtls_initialized)
248                 return;
249         dtls_data = &__dtls_data;
250         __destroy_dtls(dtls_data);
251 }