accept function and wake up logic
[akaros.git] / kern / src / socket.c
1 /*
2  * Copyright (c) 2011 The Regents of the University of California
3  * David Zhu <yuzhu@cs.berkeley.edu>
4  * See LICENSE for details.
5  * 
6  * Socket layer on top of TCP abstraction. Similar to the BSD implementation.
7  *
8  */
9 #include <ros/common.h>
10 #include <socket.h>
11 #include <vfs.h>
12 #include <time.h>
13 #include <kref.h>
14 #include <syscall.h>
15 #include <sys/uio.h>
16 #include <mbuf.h>
17 #include <ros/errno.h>
18 #include <net.h>
19 #include <net/udp.h>
20 #include <net/tcp.h>
21 #include <net/pbuf.h>
22 #include <net/tcp_impl.h>
23 #include <umem.h>
24 #include <kthread.h>
25 #include <bitmask.h>
26 #include <debug.h>
27 /*
28  *TODO: Figure out which socket.h is used where
29  *There are several socket.h in kern, and a couple more in glibc. Perhaps the glibc ones
30  *should grab from here..
31  */
32
33 struct kmem_cache *sock_kcache;
34 struct kmem_cache *mbuf_kcache;
35 struct kmem_cache *udp_pcb_kcache;
36 struct kmem_cache *tcp_pcb_kcache;
37 struct kmem_cache *tcp_pcb_listen_kcache;
38 struct kmem_cache *tcp_segment_kcache;
39
40 // file ops needed to support read/write on socket fd
41 static struct file_operations socket_op = {
42         0,
43         0,//soo_read,
44         0,//soo_write,
45         0,
46         0,
47         0,
48         0,
49         0,
50         0,
51         0,//soo_poll,
52         0,
53         0,
54         0, // sendpage might apply here
55         0,
56 };
57 static struct socket* getsocket(struct proc *p, int fd){
58         /* look up fd -> file */
59         struct file *so_file = get_file_from_fd(&(p->open_files), fd);
60
61         /* get socket and verify its type */
62         if (so_file == NULL){
63                 printd("getsocket() fd -> null file: fd %d\n", fd);
64                 return NULL;
65         }
66         if (so_file->f_op != &socket_op) {
67                 set_errno(ENOTSOCK);
68                 printd("fd %d maps to non-socket file\n");
69                 return NULL;
70         } else
71                 return (struct socket*) so_file->f_privdata;
72 }
73
74 struct socket* alloc_sock(int socket_family, int socket_type, int protocol){
75         struct socket *newsock = kmem_cache_alloc(sock_kcache, 0);
76         assert(newsock);
77
78         newsock->so_family = socket_family;
79         newsock->so_type = socket_type;
80         newsock->so_protocol = protocol;
81         newsock->so_state = SS_ISDISCONNECTED;
82         STAILQ_INIT(&(newsock->acceptq));
83         pbuf_head_init(&newsock->recv_buff);
84         pbuf_head_init(&newsock->send_buff);
85         init_sem(&newsock->sem, 0);
86         init_sem(&newsock->accept_sem, 0);
87         spinlock_init(&newsock->waiter_lock);
88         LIST_INIT(&newsock->waiters);
89         return newsock;
90
91 }
92 // TODO: refactor vfs so we can allocate fd and do the basic initialization
93 struct file *alloc_socket_file(struct socket* sock) {
94         struct file *file = alloc_file();
95         if (file == NULL) return 0;
96
97         // Linux fakes a dentry and an inode for socks, see socket.c : sock_alloc_file
98         file->f_dentry = NULL; // This might break things?
99         file->f_vfsmnt = 0;
100         file->f_flags = 0;
101
102         file->f_mode = S_IRUSR | S_IWUSR; // both read and write for socket files
103
104         file->f_pos = 0;
105         file->f_uid = 0;
106         file->f_gid = 0;
107         file->f_error = 0;
108
109         file->f_op = &socket_op;
110         file->f_privdata = sock;
111         file->f_mapping = 0;
112         return file;
113 }
114
115 void socket_init(){
116         
117         /* allocate buf for socket */
118         sock_kcache = kmem_cache_create("socket", sizeof(struct socket),
119                                                                         __alignof__(struct socket), 0, 0, 0);
120         udp_pcb_kcache = kmem_cache_create("udppcb", sizeof(struct udp_pcb),
121                                                                         __alignof__(struct udp_pcb), 0, 0, 0);
122         tcp_pcb_kcache = kmem_cache_create("tcppcb", sizeof(struct tcp_pcb),
123                                                                         __alignof__(struct tcp_pcb), 0, 0, 0);
124         tcp_pcb_listen_kcache = kmem_cache_create("tcppcblisten", sizeof(struct tcp_pcb_listen),
125                                                                         __alignof__(struct tcp_pcb_listen), 0, 0, 0);
126         tcp_segment_kcache = kmem_cache_create("tcpsegment", sizeof(struct tcp_seg),
127                                                                         __alignof__(struct tcp_seg), 0, 0, 0);
128         pbuf_init();
129
130 }
131 intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
132         printk ("sysaccept called\n");
133         struct socket* sock = getsocket(p, sockfd);
134         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
135         uint16_t r_port;
136         struct socket *accepted = NULL;
137         if (sock == NULL) {
138                 set_errno(EBADF);
139                 return -1;      
140         }
141         if (sock->so_type == SOCK_DGRAM){
142                 return -1; // indicates false for connect
143         } else if (sock->so_type == SOCK_STREAM) {
144                 if (STAILQ_EMPTY(&(sock->acceptq))) {
145                         // block on the acceptq
146                         sleep_on(&sock->accept_sem);
147                 } else {
148                         __down_sem(&sock->accept_sem, NULL);
149                 }
150                 spin_lock_irqsave(&sock->waiter_lock);
151                 accepted = STAILQ_FIRST(&(sock->acceptq));
152                 STAILQ_REMOVE_HEAD((&(sock->acceptq)), next);
153                 spin_unlock_irqsave(&sock->waiter_lock);
154                 if (accepted == NULL) return -1;
155                 struct file *file = alloc_socket_file(accepted);
156                 if (file == NULL) return -1;
157                 int fd = insert_file(&p->open_files, file, 0);
158                 if (fd < 0) {
159                         warn("File insertion for socket open failed");
160                         return -1;
161                 }
162                 kref_put(&file->f_kref);
163         }
164         return -1;
165 }
166
167 static void wrap_restart_kthread(struct trapframe *tf, uint32_t srcid,
168                                         long a0, long a1, long a2){
169         restart_kthread((struct kthread*) a0);
170 }
171
172 static error_t accept_callback(void *arg, struct tcp_pcb *newpcb, error_t err) {
173         struct socket *sockold = (struct socket *) arg;
174         struct socket *sock = alloc_sock(sockold->so_family, sockold->so_type, sockold->so_protocol);
175         
176         sock->so_pcb = newpcb;
177         newpcb->pcbsock = sock;
178         spin_lock_irqsave(&sockold->waiter_lock);
179         STAILQ_INSERT_TAIL(&sockold->acceptq, sock, next);
180         // wake up any kthread who is potentially waiting
181         spin_unlock_irqsave(&sockold->waiter_lock);
182         struct kthread *kthread = __up_sem(&(sock->accept_sem), FALSE);
183         if (kthread) {
184                 send_kernel_message(core_id(), (amr_t)wrap_restart_kthread, (long)kthread, 0, 0,
185                                                                                                   KMSG_ROUTINE);
186         } 
187         return 0;
188 }
189 intreg_t sys_listen(struct proc *p, int sockfd, int backlog) {
190         struct socket* sock = getsocket(p, sockfd);
191         if (sock == NULL) {
192                 set_errno(EBADF);
193                 return -1;      
194         }
195         if (sock->so_type == SOCK_DGRAM){
196                 return -1; // indicates false for connect
197         } else if (sock->so_type == SOCK_STREAM) {
198                 // check if the socket is in WAIT state
199                 struct tcp_pcb *tpcb = (struct tcp_pcb*)sock->so_pcb;
200                 struct tcp_pcb* lpcb = tcp_listen_with_backlog(tpcb, backlog);
201                 if (lpcb == NULL) {
202                         return -1;
203                 }
204                 sock->so_pcb = lpcb;
205
206                 // register callback for new connection
207                 tcp_arg(lpcb, sock);                                                  
208                 tcp_accept(lpcb, accept_callback); 
209
210                 return 0;
211
212
213                 // XXX: add backlog later
214         }
215         return -1;
216 }
217 intreg_t sys_connect(struct proc *p, int sock_fd, const struct sockaddr* addr, int addrlen) {
218         printk("sys_connect called \n");
219         struct socket* sock = getsocket(p, sock_fd);
220         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
221         uint16_t r_port;
222         if (sock == NULL) {
223                 set_errno(EBADF);
224                 return -1;      
225         }
226         if (sock->so_type == SOCK_DGRAM){
227                 return -1; // indicates false for connect
228         } else if (sock->so_type == SOCK_STREAM) {
229                 error_t err = tcp_connect((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port, NULL);
230                 return err;
231         }
232
233         return -1;
234 }
235
236 intreg_t sys_send(struct proc *p, int sockfd, const void *buf, size_t len, int flags) {
237         printk("sys_send called \n");
238         return len; // indicates success, by returning length
239
240 }
241 intreg_t sys_recv(struct proc *p, int sockfd, void *buf, size_t len, int flags) {
242         printk("sys_recv called \n");
243         // return actual length filled
244         return len;
245 }
246
247 intreg_t sys_bind(struct proc* p_proc, int fd, const struct sockaddr *addr, socklen_t addrlen) {
248         struct socket* sock = getsocket(p_proc, fd);
249         const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr;
250         uint16_t r_port;
251         if (sock == NULL) {
252                 set_errno(EBADF);
253                 return -1;      
254         }
255         if (sock->so_type == SOCK_DGRAM){
256                 return udp_bind((struct udp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
257         } else if (sock->so_type == SOCK_STREAM) {
258                 return tcp_bind((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
259         } else {
260                 printk("SOCK type not supported in bind operation \n");
261                 return -1;
262         }
263         return 0;
264 }
265  
266 intreg_t sys_socket(struct proc *p, int socket_family, int socket_type, int protocol){
267         //check validity of params
268         if (socket_family != AF_INET && socket_type != SOCK_DGRAM)
269                 return 0;
270         struct socket *sock = alloc_sock(socket_family, socket_type, protocol);
271         if (socket_type == SOCK_DGRAM){
272                 /* udp socket */
273                 sock->so_pcb = udp_new();
274                 /* back link */
275                 ((struct udp_pcb*) (sock->so_pcb))->pcbsock = sock;
276         } else if (socket_type == SOCK_STREAM) {
277                 /* tcp socket */
278                 sock->so_pcb = tcp_new();
279                 ((struct tcp_pcb*) (sock->so_pcb))->pcbsock = sock;
280         }
281         struct file *file = alloc_socket_file(sock);
282         
283         if (file == NULL) return -1;
284         int fd = insert_file(&p->open_files, file, 0);
285         if (fd < 0) {
286                 warn("File insertion for socket open failed");
287                 return -1;
288         }
289         kref_put(&file->f_kref);
290         printk("Socket open, res = %d\n", fd);
291         return fd;
292 }
293
294 intreg_t send_iov(struct socket* sock, struct iovec* iov, int flags){
295         // COPY_COUNT: for each iov, copy into mbuf, and send
296         // should not copy here, copy in the protocol..
297         // should be esomething like this sock->so_proto->pr_send(sock, iov, flags);
298         // make it datagram specific for now...
299         send_datagram(sock, iov, flags);
300         // finally time to check for validity of UA, in the protocol send
301         return 0;       
302 }
303
304 /*TODO: iov support currently broken */
305 int send_datagram(struct socket* sock, struct iovec* iov, int flags){
306         // is this a connection oriented protocol? 
307         struct pbuf *prev = NULL;
308         struct pbuf *curr = NULL;
309         if (sock->so_type == SOCK_STREAM){
310                 set_errno(ENOTCONN);
311                 return -1;
312         }
313         
314         // possible sock locks needed
315         if ((sock->so_state & SS_ISCONNECTED) == 0){
316                 set_errno(EINVAL);
317                 return -1;
318         }
319     // pbuf_ref needs to map in the user ref
320         for (int i = 0; i< sizeof(iov) / sizeof (struct iovec); i++){
321                 prev = curr;
322                 curr = pbuf_alloc(PBUF_TRANSPORT, iov[i].iov_len, PBUF_REF);
323                 if (prev!=NULL) pbuf_chain(prev, curr);
324         }
325         // struct pbuf* pb = pbuf_alloc(PBUF_TRANSPORT, PBUF_REF);
326         udp_send(sock->so_pcb, prev);
327         return 0;
328         
329 }
330
331 /* sys_sendto can send SOCK_DGRAM and eventually SOCK_STREAM 
332  * SOCK_DGRAM uses PBUF_REF since UDP does not need to wait for ack
333  * SOCK_STREAM uses PBUF_
334  *
335  */
336 intreg_t sys_sendto(struct proc *p_proc, int fd, const void *buffer, size_t length, 
337                         int flags, const struct sockaddr *dest_addr, socklen_t dest_len){
338         // look up the socket
339         struct socket* sock = getsocket(p_proc, fd);
340         int error;
341         struct sockaddr_in *in_addr;
342         uint16_t r_port;
343         if (sock == NULL) {
344                 set_errno(EBADF);
345                 return -1;      
346         }
347         if (sock->so_type == SOCK_DGRAM){
348                 in_addr = (struct sockaddr_in *)dest_addr;
349                 struct pbuf* buf = pbuf_alloc(PBUF_TRANSPORT, length, PBUF_REF);
350                 if (buf != NULL)
351                         buf->payload = (void*)buffer;
352                 else 
353                         warn("pbuf alloc failed \n");
354                 // potentially unsafe cast to udp_pcb 
355                 return udp_sendto((struct udp_pcb*) sock->so_pcb, buf, &in_addr->sin_addr, in_addr->sin_port);
356         }
357
358         return -1;
359   //TODO: support for sendmsg and iovectors? Let's get the basics working first!
360         #if 0 
361         // use iovector to handle sendmsg calls too, and potentially scatter-gather
362         struct msghdr msg;
363         struct iovec iov;
364         struct uio auio;
365         
366         // checking for permission only when you are sending it
367         // potential bug TOCTOU, especially with async calls
368                 
369     msg.msg_name = dest_addr;
370     msg.msg_namelen = dest_len;
371     msg.msg_iov = &iov;
372     msg.msg_iovlen = 1;
373     msg.msg_control = 0;
374     
375         iov.iov_base = buffer;
376     iov.iov_len = length;
377         
378
379         // this is why we need another function to populate auio
380
381         auio.uio_iov = iov;
382         auio.uio_iovcnt = 1;
383         auio.uio_offset = 0;
384         auio.uio_resid = 0;
385         auio.uio_rw = UIO_WRITE;
386         auio.uio_proc = p;
387
388         // consider changing to send_uaio, since we care about progress.
389     error = send_iov(soc, iov, flags);
390         #endif
391 }
392
393 /* UDP and TCP has different waiting semantics
394  * UDP requires any packet to be available. 
395  * TCP requires accumulation of certain size? 
396  */
397 intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t length, int flags, struct sockaddr *restrict address, socklen_t *restrict address_len){
398         struct socket* sock = getsocket(p, socket);     
399         int copied = 0;
400         int returnval = 0;
401         if (sock == NULL) {
402                 set_errno(EBADF);
403                 return -1;
404         }
405         if (sock->so_type == SOCK_DGRAM){
406                 struct pbuf_head *ph = &(sock->recv_buff);
407                 struct pbuf* buf = NULL;
408                 buf = detach_pbuf(ph);
409                 if (!buf){
410                         // about to sleep
411                         sleep_on(&sock->sem);
412                         buf = detach_pbuf(ph);
413                         // Someone woke me up, there should be data..
414                         assert(buf);
415                 } else {
416                         __down_sem(&sock->sem, NULL);
417                 }
418                         copied = buf->len - sizeof(struct udp_hdr);
419                         if (copied > length)
420                                 copied = length;
421                         pbuf_header(buf, -UDP_HDR_SZ);
422                         // copy it to user space
423                         returnval = memcpy_to_user_errno(p, buffer, buf->payload, copied);
424                 }
425         if (returnval < 0) 
426                 return -1;
427         else
428                 return copied;
429 }
430
431 static int selscan(int maxfdp1, fd_set *readset_in, fd_set *writeset_in, fd_set *exceptset_in,
432              fd_set *readset_out, fd_set *writeset_out, fd_set *exceptset_out){
433         return 0;
434 }
435
436 /* TODO: Start respecting the time out value */ 
437 /* TODO: start respecting writefds and exceptfds */
438 intreg_t sys_select(struct proc *p, int nfds, fd_set *readfds, fd_set *writefds,
439                                 fd_set *exceptfds, struct timeval *timeout){
440         /* Create a semaphore */
441         struct semaphore_entry read_sem; 
442
443         init_sem(&(read_sem.sem), 0);
444
445         /* insert into the sem list of a fd / socket */
446         int low_fd = 0;
447         for (int i = low_fd; i< nfds; i++) {
448                 if(FD_ISSET(i, readfds)){
449                   struct socket* sock = getsocket(p, i);
450                         /* if the fd is not open or if the file descriptor is not a socket 
451                          * go to the next in the fd set 
452                          */
453                         if (sock == NULL) continue;
454                         /* for each file that is open, insert this semaphore to be woken up when there
455                         * is data available to be read
456                         */
457                         spin_lock(&sock->waiter_lock);
458                         LIST_INSERT_HEAD(&sock->waiters, &read_sem, link);
459                         spin_unlock(&sock->waiter_lock);
460                 }
461         }
462         /* At this point wait on the semaphore */
463   sleep_on(&(read_sem.sem));
464         /* someone woke me up, so walk through the list of descriptors and find one that is ready */
465         /* remove itself from all the lists that it is waiting on */
466         for (int i = low_fd; i<nfds; i++) {
467                 if (FD_ISSET(i, readfds)){
468                         struct socket* sock = getsocket(p,i);
469                         if (sock == NULL) continue;
470                         spin_lock(&sock->waiter_lock);
471                         LIST_REMOVE(&read_sem, link);
472                         spin_unlock(&sock->waiter_lock);
473                 }
474         }
475         fd_set readout, writeout, exceptout;
476         FD_ZERO(&readout);
477         FD_ZERO(&writeout);
478         FD_ZERO(&exceptout);
479         for (int i = low_fd; i< nfds; i ++){
480                 if (readfds && FD_ISSET(i, readfds)){
481                   struct socket* sock = getsocket(p, i);
482                         if ((sock->recv_buff).qlen > 0){
483                                 FD_SET(i, &readout);
484                         }
485                         /* if the socket is ready, then we can return it */
486                 }
487         }
488         if (readfds)
489                 memcpy(readfds, &readout, sizeof(*readfds));
490         if (writefds)
491                 memcpy(writefds, &writeout, sizeof(*writefds));
492         if (exceptfds)
493                 memcpy(readfds, &readout, sizeof(*readfds));
494
495         /* Sleep on that semaphore */
496         /* Somehow get these file descriptors to wake me up when there is new data */
497         return 0;
498 }