More simple fixes.
[akaros.git] / kern / src / socket.c
1 /*
2  * Copyright (c) 2011 The Regents of the University of California
3  * David Zhu <yuzhu@cs.berkeley.edu>
4  * See LICENSE for details.
5  * 
6  * Socket layer on top of TCP abstraction. Similar to the BSD implementation.
7  *
8  */
9 #include <ros/common.h>
10 #include <socket.h>
11 #include <vfs.h>
12 #include <time.h>
13 #include <kref.h>
14 #include <syscall.h>
15 #include <sys/uio.h>
16 #include <mbuf.h>
17 #include <ros/errno.h>
18 #include <net.h>
19 #include <net/udp.h>
20 #include <net/tcp.h>
21 #include <net/pbuf.h>
22 #include <net/tcp_impl.h>
23 #include <umem.h>
24 #include <kthread.h>
25 #include <bitmask.h>
26 #include <debug.h>
27 /*
28  *TODO: Figure out which socket.h is used where
29  *There are several socket.h in kern, and a couple more in glibc. Perhaps the glibc ones
30  *should grab from here..
31  */
32
33 struct kmem_cache *sock_kcache;
34 struct kmem_cache *mbuf_kcache;
35 struct kmem_cache *udp_pcb_kcache;
36 struct kmem_cache *tcp_pcb_kcache;
37 struct kmem_cache *tcp_pcb_listen_kcache;
38 struct kmem_cache *tcp_segment_kcache;
39
40 // file ops needed to support read/write on socket fd
41 static struct file_operations socket_op = {
42         0,
43         0,//soo_read,
44         0,//soo_write,
45         0,
46         0,
47         0,
48         0,
49         0,
50         0,
51         0,//soo_poll,
52         0,
53         0,
54         0, // sendpage might apply here
55         0,
56 };
57 static struct socket* getsocket(struct proc *p, int fd){
58         /* look up fd -> file */
59         struct file *so_file = get_file_from_fd(&(p->open_files), fd);
60
61         /* get socket and verify its type */
62         if (so_file == NULL){
63                 printd("getsocket() fd -> null file: fd %d\n", fd);
64                 return NULL;
65         }
66         if (so_file->f_op != &socket_op) {
67                 set_errno(ENOTSOCK);
68                 printd("fd %d maps to non-socket file\n");
69                 return NULL;
70         } else
71                 return (struct socket*) so_file->f_privdata;
72 }
73
74 struct socket* alloc_sock(int socket_family, int socket_type, int protocol){
75         struct socket *newsock = kmem_cache_alloc(sock_kcache, 0);
76         assert(newsock);
77
78         newsock->so_family = socket_family;
79         newsock->so_type = socket_type;
80         newsock->so_protocol = protocol;
81         newsock->so_state = SS_ISDISCONNECTED;
82         STAILQ_INIT(&(newsock->acceptq));
83         pbuf_head_init(&newsock->recv_buff);
84         pbuf_head_init(&newsock->send_buff);
85         init_sem(&newsock->sem, 0);
86         init_sem(&newsock->accept_sem, 0);
87         spinlock_init(&newsock->waiter_lock);
88         LIST_INIT(&newsock->waiters);
89         return newsock;
90
91 }
92 // TODO: refactor vfs so we can allocate fd and do the basic initialization
93 struct file *alloc_socket_file(struct socket* sock) {
94         struct file *file = alloc_file();
95         if (file == NULL) return 0;
96
97         // Linux fakes a dentry and an inode for socks, see socket.c : sock_alloc_file
98         file->f_dentry = NULL; // This might break things?
99         file->f_vfsmnt = 0;
100         file->f_flags = 0;
101
102         file->f_mode = S_IRUSR | S_IWUSR; // both read and write for socket files
103
104         file->f_pos = 0;
105         file->f_uid = 0;
106         file->f_gid = 0;
107         file->f_error = 0;
108
109         file->f_op = &socket_op;
110         file->f_privdata = sock;
111         file->f_mapping = 0;
112         return file;
113 }
114
115 void socket_init(){
116         
117         /* allocate buf for socket */
118         sock_kcache = kmem_cache_create("socket", sizeof(struct socket),
119                                                                         __alignof__(struct socket), 0, 0, 0);
120         udp_pcb_kcache = kmem_cache_create("udppcb", sizeof(struct udp_pcb),
121                                                                         __alignof__(struct udp_pcb), 0, 0, 0);
122         tcp_pcb_kcache = kmem_cache_create("tcppcb", sizeof(struct tcp_pcb),
123                                                                         __alignof__(struct tcp_pcb), 0, 0, 0);
124         tcp_pcb_listen_kcache = kmem_cache_create("tcppcblisten", sizeof(struct tcp_pcb_listen),
125                                                                         __alignof__(struct tcp_pcb_listen), 0, 0, 0);
126         tcp_segment_kcache = kmem_cache_create("tcpsegment", sizeof(struct tcp_seg),
127                                                                         __alignof__(struct tcp_seg), 0, 0, 0);
128         pbuf_init();
129
130 }
131 intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
132         printk ("sysaccept called\n");
133         struct socket* sock = getsocket(p, sockfd);
134         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
135         uint16_t r_port;
136         struct socket *accepted = NULL;
137         if (sock == NULL) {
138                 set_errno(EBADF);
139                 return -1;      
140         }
141         if (sock->so_type == SOCK_DGRAM){
142                 return -1; // indicates false for connect
143         } else if (sock->so_type == SOCK_STREAM) {
144                 if (STAILQ_EMPTY(&(sock->acceptq))) {
145                         // block on the acceptq
146                         sleep_on(&sock->accept_sem);
147                 } else {
148                         __down_sem(&sock->accept_sem, NULL);
149                 }
150                 spin_lock_irqsave(&sock->waiter_lock);
151                 accepted = STAILQ_FIRST(&(sock->acceptq));
152                 STAILQ_REMOVE_HEAD((&(sock->acceptq)), next);
153                 spin_unlock_irqsave(&sock->waiter_lock);
154                 if (accepted == NULL) return -1;
155                 struct file *file = alloc_socket_file(accepted);
156                 if (file == NULL) return -1;
157                 int fd = insert_file(&p->open_files, file, 0);
158                 if (fd < 0) {
159                         warn("File insertion for socket open failed");
160                         return -1;
161                 }
162                 kref_put(&file->f_kref);
163         }
164         return -1;
165 }
166
167 static void wrap_restart_kthread(struct trapframe *tf, uint32_t srcid,
168                                         long a0, long a1, long a2){
169         restart_kthread((struct kthread*) a0);
170 }
171
172 static error_t accept_callback(void *arg, struct tcp_pcb *newpcb, error_t err) {
173         struct socket *sockold = (struct socket *) arg;
174         struct socket *sock = alloc_sock(sockold->so_family, sockold->so_type, sockold->so_protocol);
175         
176         sock->so_pcb = newpcb;
177         newpcb->pcbsock = sock;
178         spin_lock_irqsave(&sockold->waiter_lock);
179         STAILQ_INSERT_TAIL(&sockold->acceptq, sock, next);
180         // wake up any kthread who is potentially waiting
181         spin_unlock_irqsave(&sockold->waiter_lock);
182         struct kthread *kthread = __up_sem(&(sock->accept_sem), FALSE);
183         if (kthread) {
184                 send_kernel_message(core_id(), (amr_t)wrap_restart_kthread, (long)kthread, 0, 0,
185                                                                                                   KMSG_ROUTINE);
186         } 
187         return 0;
188 }
189 intreg_t sys_listen(struct proc *p, int sockfd, int backlog) {
190         struct socket* sock = getsocket(p, sockfd);
191         if (sock == NULL) {
192                 set_errno(EBADF);
193                 return -1;      
194         }
195         if (sock->so_type == SOCK_DGRAM){
196                 return -1; // indicates false for connect
197         } else if (sock->so_type == SOCK_STREAM) {
198                 // check if the socket is in WAIT state
199                 struct tcp_pcb *tpcb = (struct tcp_pcb*)sock->so_pcb;
200                 struct tcp_pcb* lpcb = tcp_listen_with_backlog(tpcb, backlog);
201                 if (lpcb == NULL) {
202                         return -1;
203                 }
204                 sock->so_pcb = lpcb;
205
206                 // register callback for new connection
207                 tcp_arg(lpcb, sock);                                                  
208                 tcp_accept(lpcb, accept_callback); 
209
210                 return 0;
211
212
213                 // XXX: add backlog later
214         }
215         return -1;
216 }
217 intreg_t sys_connect(struct proc *p, int sock_fd, const struct sockaddr* addr, int addrlen) {
218         printk("sys_connect called \n");
219         struct socket* sock = getsocket(p, sock_fd);
220         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
221         uint16_t r_port;
222         if (sock == NULL) {
223                 set_errno(EBADF);
224                 return -1;      
225         }
226         if (sock->so_type == SOCK_DGRAM){
227                 return -1; // indicates false for connect
228         } else if (sock->so_type == SOCK_STREAM) {
229                 error_t err = tcp_connect((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port, NULL);
230                 return err;
231         }
232
233         return -1;
234 }
235
236 intreg_t sys_send(struct proc *p, int sockfd, const void *buf, size_t len, int flags) {
237         printk("sys_send called \n");
238         struct socket* sock = getsocket(p_proc, fd);
239         const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr;
240         uint16_t r_port;
241         if (sock == NULL) {
242                 set_errno(EBADF);
243                 return -1;      
244         }
245         return len;
246
247 }
248 intreg_t sys_recv(struct proc *p, int sockfd, void *buf, size_t len, int flags) {
249         printk("sys_recv called \n");
250         // return actual length filled
251         return len;
252 }
253
254 intreg_t sys_bind(struct proc* p_proc, int fd, const struct sockaddr *addr, socklen_t addrlen) {
255         struct socket* sock = getsocket(p_proc, fd);
256         const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr;
257         uint16_t r_port;
258         if (sock == NULL) {
259                 set_errno(EBADF);
260                 return -1;      
261         }
262         if (sock->so_type == SOCK_DGRAM){
263                 return udp_bind((struct udp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
264         } else if (sock->so_type == SOCK_STREAM) {
265                 return tcp_bind((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
266         } else {
267                 printk("SOCK type not supported in bind operation \n");
268                 return -1;
269         }
270         return 0;
271 }
272  
273 intreg_t sys_socket(struct proc *p, int socket_family, int socket_type, int protocol){
274         //check validity of params
275         if (socket_family != AF_INET && socket_type != SOCK_DGRAM)
276                 return 0;
277         struct socket *sock = alloc_sock(socket_family, socket_type, protocol);
278         if (socket_type == SOCK_DGRAM){
279                 /* udp socket */
280                 sock->so_pcb = udp_new();
281                 /* back link */
282                 ((struct udp_pcb*) (sock->so_pcb))->pcbsock = sock;
283         } else if (socket_type == SOCK_STREAM) {
284                 /* tcp socket */
285                 sock->so_pcb = tcp_new();
286                 ((struct tcp_pcb*) (sock->so_pcb))->pcbsock = sock;
287         }
288         struct file *file = alloc_socket_file(sock);
289         
290         if (file == NULL) return -1;
291         int fd = insert_file(&p->open_files, file, 0);
292         if (fd < 0) {
293                 warn("File insertion for socket open failed");
294                 return -1;
295         }
296         kref_put(&file->f_kref);
297         printk("Socket open, res = %d\n", fd);
298         return fd;
299 }
300
301 intreg_t send_iov(struct socket* sock, struct iovec* iov, int flags){
302         // COPY_COUNT: for each iov, copy into mbuf, and send
303         // should not copy here, copy in the protocol..
304         // should be esomething like this sock->so_proto->pr_send(sock, iov, flags);
305         // make it datagram specific for now...
306         send_datagram(sock, iov, flags);
307         // finally time to check for validity of UA, in the protocol send
308         return 0;       
309 }
310
311 /*TODO: iov support currently broken */
312 int send_datagram(struct socket* sock, struct iovec* iov, int flags){
313         // is this a connection oriented protocol? 
314         struct pbuf *prev = NULL;
315         struct pbuf *curr = NULL;
316         if (sock->so_type == SOCK_STREAM){
317                 set_errno(ENOTCONN);
318                 return -1;
319         }
320         
321         // possible sock locks needed
322         if ((sock->so_state & SS_ISCONNECTED) == 0){
323                 set_errno(EINVAL);
324                 return -1;
325         }
326     // pbuf_ref needs to map in the user ref
327         for (int i = 0; i< sizeof(iov) / sizeof (struct iovec); i++){
328                 prev = curr;
329                 curr = pbuf_alloc(PBUF_TRANSPORT, iov[i].iov_len, PBUF_REF);
330                 if (prev!=NULL) pbuf_chain(prev, curr);
331         }
332         // struct pbuf* pb = pbuf_alloc(PBUF_TRANSPORT, PBUF_REF);
333         udp_send(sock->so_pcb, prev);
334         return 0;
335         
336 }
337
338 /* sys_sendto can send SOCK_DGRAM and eventually SOCK_STREAM 
339  * SOCK_DGRAM uses PBUF_REF since UDP does not need to wait for ack
340  * SOCK_STREAM uses PBUF_
341  *
342  */
343 intreg_t sys_sendto(struct proc *p_proc, int fd, const void *buffer, size_t length, 
344                         int flags, const struct sockaddr *dest_addr, socklen_t dest_len){
345         // look up the socket
346         struct socket* sock = getsocket(p_proc, fd);
347         int error;
348         struct sockaddr_in *in_addr;
349         uint16_t r_port;
350         if (sock == NULL) {
351                 set_errno(EBADF);
352                 return -1;      
353         }
354         if (sock->so_type == SOCK_DGRAM){
355                 in_addr = (struct sockaddr_in *)dest_addr;
356                 struct pbuf* buf = pbuf_alloc(PBUF_TRANSPORT, length, PBUF_REF);
357                 if (buf != NULL)
358                         buf->payload = (void*)buffer;
359                 else 
360                         warn("pbuf alloc failed \n");
361                 // potentially unsafe cast to udp_pcb 
362                 return udp_sendto((struct udp_pcb*) sock->so_pcb, buf, &in_addr->sin_addr, in_addr->sin_port);
363         }
364
365         return -1;
366   //TODO: support for sendmsg and iovectors? Let's get the basics working first!
367         #if 0 
368         // use iovector to handle sendmsg calls too, and potentially scatter-gather
369         struct msghdr msg;
370         struct iovec iov;
371         struct uio auio;
372         
373         // checking for permission only when you are sending it
374         // potential bug TOCTOU, especially with async calls
375                 
376     msg.msg_name = dest_addr;
377     msg.msg_namelen = dest_len;
378     msg.msg_iov = &iov;
379     msg.msg_iovlen = 1;
380     msg.msg_control = 0;
381     
382         iov.iov_base = buffer;
383     iov.iov_len = length;
384         
385
386         // this is why we need another function to populate auio
387
388         auio.uio_iov = iov;
389         auio.uio_iovcnt = 1;
390         auio.uio_offset = 0;
391         auio.uio_resid = 0;
392         auio.uio_rw = UIO_WRITE;
393         auio.uio_proc = p;
394
395         // consider changing to send_uaio, since we care about progress.
396     error = send_iov(soc, iov, flags);
397         #endif
398 }
399
400 /* UDP and TCP has different waiting semantics
401  * UDP requires any packet to be available. 
402  * TCP requires accumulation of certain size? 
403  */
404 intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t length, int flags, struct sockaddr *restrict address, socklen_t *restrict address_len){
405         struct socket* sock = getsocket(p, socket);     
406         int copied = 0;
407         int returnval = 0;
408         if (sock == NULL) {
409                 set_errno(EBADF);
410                 return -1;
411         }
412         if (sock->so_type == SOCK_DGRAM){
413                 struct pbuf_head *ph = &(sock->recv_buff);
414                 struct pbuf* buf = NULL;
415                 buf = detach_pbuf(ph);
416                 if (!buf){
417                         // about to sleep
418                         sleep_on(&sock->sem);
419                         buf = detach_pbuf(ph);
420                         // Someone woke me up, there should be data..
421                         assert(buf);
422                 } else {
423                         __down_sem(&sock->sem, NULL);
424                 }
425                         copied = buf->len - sizeof(struct udp_hdr);
426                         if (copied > length)
427                                 copied = length;
428                         pbuf_header(buf, -UDP_HDR_SZ);
429                         // copy it to user space
430                         returnval = memcpy_to_user_errno(p, buffer, buf->payload, copied);
431                 }
432         if (returnval < 0) 
433                 return -1;
434         else
435                 return copied;
436 }
437
438 static int selscan(int maxfdp1, fd_set *readset_in, fd_set *writeset_in, fd_set *exceptset_in,
439              fd_set *readset_out, fd_set *writeset_out, fd_set *exceptset_out){
440         return 0;
441 }
442
443 /* TODO: Start respecting the time out value */ 
444 /* TODO: start respecting writefds and exceptfds */
445 intreg_t sys_select(struct proc *p, int nfds, fd_set *readfds, fd_set *writefds,
446                                 fd_set *exceptfds, struct timeval *timeout){
447         /* Create a semaphore */
448         struct semaphore_entry read_sem; 
449
450         init_sem(&(read_sem.sem), 0);
451
452         /* insert into the sem list of a fd / socket */
453         int low_fd = 0;
454         for (int i = low_fd; i< nfds; i++) {
455                 if(FD_ISSET(i, readfds)){
456                   struct socket* sock = getsocket(p, i);
457                         /* if the fd is not open or if the file descriptor is not a socket 
458                          * go to the next in the fd set 
459                          */
460                         if (sock == NULL) continue;
461                         /* for each file that is open, insert this semaphore to be woken up when there
462                         * is data available to be read
463                         */
464                         spin_lock(&sock->waiter_lock);
465                         LIST_INSERT_HEAD(&sock->waiters, &read_sem, link);
466                         spin_unlock(&sock->waiter_lock);
467                 }
468         }
469         /* At this point wait on the semaphore */
470   sleep_on(&(read_sem.sem));
471         /* someone woke me up, so walk through the list of descriptors and find one that is ready */
472         /* remove itself from all the lists that it is waiting on */
473         for (int i = low_fd; i<nfds; i++) {
474                 if (FD_ISSET(i, readfds)){
475                         struct socket* sock = getsocket(p,i);
476                         if (sock == NULL) continue;
477                         spin_lock(&sock->waiter_lock);
478                         LIST_REMOVE(&read_sem, link);
479                         spin_unlock(&sock->waiter_lock);
480                 }
481         }
482         fd_set readout, writeout, exceptout;
483         FD_ZERO(&readout);
484         FD_ZERO(&writeout);
485         FD_ZERO(&exceptout);
486         for (int i = low_fd; i< nfds; i ++){
487                 if (readfds && FD_ISSET(i, readfds)){
488                   struct socket* sock = getsocket(p, i);
489                         if ((sock->recv_buff).qlen > 0){
490                                 FD_SET(i, &readout);
491                         }
492                         /* if the socket is ready, then we can return it */
493                 }
494         }
495         if (readfds)
496                 memcpy(readfds, &readout, sizeof(*readfds));
497         if (writefds)
498                 memcpy(writefds, &writeout, sizeof(*writefds));
499         if (exceptfds)
500                 memcpy(readfds, &readout, sizeof(*readfds));
501
502         /* Sleep on that semaphore */
503         /* Somehow get these file descriptors to wake me up when there is new data */
504         return 0;
505 }