Reorganized user-memory checking/copying functions
[akaros.git] / kern / src / umem.c
1 /* Copyright (c) 2009, 2010 The Regents of the University of California
2  * Barret Rhoden <brho@cs.berkeley.edu>
3  * Andrew Waterman <waterman@cs.berkeley.edu>
4  * See LICENSE for details.
5  *
6  * Functions for working with userspace's address space.  The user_mem ones need
7  * to involve some form of pinning (TODO), and that global static needs to go. */
8
9 #include <ros/common.h>
10 #include <umem.h>
11 #include <process.h>
12 #include <error.h>
13 #include <kmalloc.h>
14 #include <assert.h>
15 #include <pmap.h>
16
17 /**
18  * @brief Global variable used to store erroneous virtual addresses as the
19  *        result of a failed user_mem_check().
20  *
21  * zra: What if two checks fail at the same time? Maybe this should be per-cpu?
22  *
23  */
24 static void *DANGEROUS RACY user_mem_check_addr;
25
26 /**
27  * @brief Check that an environment is allowed to access the range of memory
28  * [va, va+len) with permissions 'perm | PTE_P'.
29  *
30  * Normally 'perm' will contain PTE_U at least, but this is not required.  The
31  * function get_va_perms only checks for PTE_U, PTE_W, and PTE_P.  It won't
32  * check for things like PTE_PS, PTE_A, etc.
33  * 'va' and 'len' need not be page-aligned;
34  *
35  * A user program can access a virtual address if:
36  *     -# the address is below ULIM
37  *     -# the page table gives it permission.  
38  *
39  * If there is an error, 'user_mem_check_addr' is set to the first
40  * erroneous virtual address.
41  *
42  * @param p    the process associated with the user program trying to access
43  *             the virtual address range
44  * @param va   the first virtual address in the range
45  * @param len  the length of the virtual address range
46  * @param perm the permissions the user is trying to access the virtual address 
47  *             range with
48  *
49  * @return VA a pointer of type COUNT(len) to the address range
50  * @return NULL trying to access this range of virtual addresses is not allowed
51  */
52 void *user_mem_check(struct proc *p, const void *DANGEROUS va, size_t len,
53                      int perm)
54 {
55         if (len == 0) {
56                 warn("Called user_mem_check with a len of 0. Don't do that. Returning NULL");
57                 return NULL;
58         }
59         
60         // TODO - will need to sort this out wrt page faulting / PTE_P
61         // also could be issues with sleeping and waking up to find pages
62         // are unmapped, though i think the lab ignores this since the 
63         // kernel is uninterruptible
64         void *DANGEROUS start, *DANGEROUS end;
65         size_t num_pages, i;
66         int page_perms = 0;
67
68         perm |= PTE_P;
69         start = ROUNDDOWN((void*DANGEROUS)va, PGSIZE);
70         end = ROUNDUP((void*DANGEROUS)va + len, PGSIZE);
71         if (start >= end) {
72                 warn("Blimey!  Wrap around in VM range calculation!");  
73                 return NULL;
74         }
75         num_pages = LA2PPN(end - start);
76         for (i = 0; i < num_pages; i++, start += PGSIZE) {
77                 page_perms = get_va_perms(p->env_pgdir, start);
78                 // ensures the bits we want on are turned on.  if not, error out
79                 if ((page_perms & perm) != perm) {
80                         if (i == 0)
81                                 user_mem_check_addr = (void*DANGEROUS)va;
82                         else
83                                 user_mem_check_addr = start;
84                         return NULL;
85                 }
86         }
87         // this should never be needed, since the perms should catch it
88         if ((uintptr_t)end > ULIM) {
89                 warn ("I suck - Bug in user permission mappings!");
90                 return NULL;
91         }
92         return (void *COUNT(len))TC(va);
93 }
94
95 /**
96  * @brief Checks that process 'p' is allowed to access the range
97  * of memory [va, va+len) with permissions 'perm | PTE_U'. Destroy 
98  * process 'p' if the assertion fails.
99  *
100  * This function is identical to user_mem_assert() except that it has a side
101  * affect of destroying the process 'p' if the memory check fails.
102  *
103  * @param p    the process associated with the user program trying to access
104  *             the virtual address range
105  * @param va   the first virtual address in the range
106  * @param len  the length of the virtual address range
107  * @param perm the permissions the user is trying to access the virtual address 
108  *             range with
109  *
110  * @return VA a pointer of type COUNT(len) to the address range
111  * @return NULL trying to access this range of virtual addresses is not allowed
112  *              process 'p' is destroyed
113  */
114 void *user_mem_assert(struct proc *p, const void *DANGEROUS va, size_t len,
115                        int perm)
116 {
117         if (len == 0) {
118                 warn("Called user_mem_assert with a len of 0. Don't do that. Returning NULL");
119                 return NULL;
120         }
121         
122         void *COUNT(len) res = user_mem_check(p, va, len, perm | PTE_USER_RO);
123         if (!res) {
124                 cprintf("[%08x] user_mem_check assertion failure for "
125                         "va %08x\n", p->pid, user_mem_check_addr);
126                 proc_destroy(p);        // may not return
127         return NULL;
128         }
129     return res;
130 }
131
132 /**
133  * @brief Copies data from a user buffer to a kernel buffer.
134  * 
135  * @param p    the process associated with the user program
136  *             from which the buffer is being copied
137  * @param dest the destination address of the kernel buffer
138  * @param va   the address of the userspace buffer from which we are copying
139  * @param len  the length of the userspace buffer
140  *
141  * @return ESUCCESS on success
142  * @return -EFAULT  the page assocaited with 'va' is not present, the user 
143  *                  lacks the proper permissions, or there was an invalid 'va'
144  */
145 int memcpy_from_user(struct proc *p, void *dest, const void *DANGEROUS va,
146                      size_t len)
147 {
148         const void *DANGEROUS start, *DANGEROUS end;
149         size_t num_pages, i;
150         pte_t *pte;
151         uintptr_t perm = PTE_P | PTE_USER_RO;
152         size_t bytes_copied = 0;
153
154         static_assert(ULIM % PGSIZE == 0 && ULIM != 0); // prevent wrap-around
155
156         start = ROUNDDOWN(va, PGSIZE);
157         end = ROUNDUP(va + len, PGSIZE);
158
159         if (start >= (void*SNT)ULIM || end > (void*SNT)ULIM)
160                 return -EFAULT;
161
162         num_pages = LA2PPN(end - start);
163         for (i = 0; i < num_pages; i++) {
164                 pte = pgdir_walk(p->env_pgdir, start + i * PGSIZE, 0);
165                 if (!pte)
166                         return -EFAULT;
167                 if ((*pte & PTE_P) && (*pte & PTE_USER_RO) != PTE_USER_RO)
168                         return -EFAULT;
169                 if (!(*pte & PTE_P))
170                         if (handle_page_fault(p, (uintptr_t)start + i * PGSIZE, PROT_READ))
171                                 return -EFAULT;
172
173                 void *kpage = KADDR(PTE_ADDR(*pte));
174                 const void *src_start = i > 0 ? kpage : kpage + (va - start);
175                 void *dst_start = dest + bytes_copied;
176                 size_t copy_len = PGSIZE;
177                 if (i == 0)
178                         copy_len -= va - start;
179                 if (i == num_pages-1)
180                         copy_len -= end - (va + len);
181
182                 memcpy(dst_start, src_start, copy_len);
183                 bytes_copied += copy_len;
184         }
185         assert(bytes_copied == len);
186         return 0;
187 }
188
189 /**
190  * @brief Copies data to a user buffer from a kernel buffer.
191  * 
192  * @param p    the process associated with the user program
193  *             to which the buffer is being copied
194  * @param dest the destination address of the user buffer
195  * @param va   the address of the kernel buffer from which we are copying
196  * @param len  the length of the user buffer
197  *
198  * @return ESUCCESS on success
199  * @return -EFAULT  the page assocaited with 'va' is not present, the user 
200  *                  lacks the proper permissions, or there was an invalid 'va'
201  */
202 int memcpy_to_user(struct proc *p, void *va, const void *src, size_t len)
203 {
204         const void *DANGEROUS start, *DANGEROUS end;
205         size_t num_pages, i;
206         pte_t *pte;
207         uintptr_t perm = PTE_P | PTE_USER_RW;
208         size_t bytes_copied = 0;
209
210         static_assert(ULIM % PGSIZE == 0 && ULIM != 0); // prevent wrap-around
211
212         start = ROUNDDOWN(va, PGSIZE);
213         end = ROUNDUP(va + len, PGSIZE);
214
215         if (start >= (void*SNT)ULIM || end > (void*SNT)ULIM)
216                 return -EFAULT;
217
218         num_pages = LA2PPN(end - start);
219         for (i = 0; i < num_pages; i++) {
220                 pte = pgdir_walk(p->env_pgdir, start + i * PGSIZE, 0);
221                 if (!pte)
222                         return -EFAULT;
223                 if ((*pte & PTE_P) && (*pte & PTE_USER_RW) != PTE_USER_RW)
224                         return -EFAULT;
225                 if (!(*pte & PTE_P))
226                         if (handle_page_fault(p, (uintptr_t)start + i * PGSIZE, PROT_WRITE))
227                                 return -EFAULT;
228                 void *kpage = KADDR(PTE_ADDR(*pte));
229                 void *dst_start = i > 0 ? kpage : kpage + (va - start);
230                 const void *src_start = src + bytes_copied;
231                 size_t copy_len = PGSIZE;
232                 if (i == 0)
233                         copy_len -= va - start;
234                 if (i == num_pages - 1)
235                         copy_len -= end - (va + len);
236                 memcpy(dst_start, src_start, copy_len);
237                 bytes_copied += copy_len;
238         }
239         assert(bytes_copied == len);
240         return 0;
241 }
242
243 int memcpy_to_user_errno(struct proc *p, void *dst, const void *src, int len)
244 {
245         if (memcpy_to_user(p, dst, src, len)) {
246                 set_errno(current_tf, EINVAL);
247                 return -1;
248         }
249         return 0;
250 }
251
252 /* Creates a buffer (kmalloc) and safely copies into it from va.  Can return an
253  * error code.  Check its response with IS_ERR().  Must be paired with
254  * user_memdup_free() if this succeeded. */
255 void *user_memdup(struct proc *p, const void *va, int len)
256 {
257         void* kva = NULL;
258         if (len < 0 || (kva = kmalloc(len, 0)) == NULL)
259                 return ERR_PTR(-ENOMEM);
260         if (memcpy_from_user(p, kva, va, len)) {
261                 kfree(kva);
262                 return ERR_PTR(-EINVAL);
263         }
264         return kva;
265 }
266
267 void *user_memdup_errno(struct proc *p, const void *va, int len)
268 {
269         void *kva = user_memdup(p, va, len);
270         if (IS_ERR(kva)) {
271                 set_errno(current_tf, -PTR_ERR(kva));
272                 return NULL;
273         }
274         return kva;
275 }
276
277 void user_memdup_free(struct proc *p, void *va)
278 {
279         kfree(va);
280 }
281
282 /* Same as memdup, but just does strings.  still needs memdup_freed */
283 char *user_strdup(struct proc *p, const char *va0, int max)
284 {
285         max++;
286         char* kbuf = (char*)kmalloc(PGSIZE, 0);
287         if (kbuf == NULL)
288                 return ERR_PTR(-ENOMEM);
289         int pos = 0, len = 0;
290         const char* va = va0;
291         while (max > 0 && len == 0) {
292                 int thislen = MIN(PGSIZE - (uintptr_t)va % PGSIZE, max);
293                 if (memcpy_from_user(p, kbuf, va, thislen)) {
294                         kfree(kbuf);
295                         return ERR_PTR(-EINVAL);
296                 }
297                 const char *nullterm = memchr(kbuf, 0, thislen);
298                 if (nullterm)
299                         len = pos + (nullterm - kbuf) + 1;
300                 pos += thislen;
301                 va += thislen;
302                 max -= thislen;
303         }
304         kfree(kbuf);
305         return len ? user_memdup(p, va0, len) : ERR_PTR(-EINVAL);
306 }
307
308 char *user_strdup_errno(struct proc *p, const char *va, int max)
309 {
310         void *kva = user_strdup(p, va, max);
311         if (IS_ERR(kva)) {
312                 set_errno(current_tf, -PTR_ERR(kva));
313                 return NULL;
314         }
315         return kva;
316 }
317
318 void *kmalloc_errno(int len)
319 {
320         void *kva = NULL;
321         if (len < 0 || (kva = kmalloc(len, 0)) == NULL)
322                 set_errno(current_tf, ENOMEM);
323         return kva;
324 }