Merge origin/netpush (networking code) (XCC)
[akaros.git] / kern / src / socket.c
1 /*
2  * Copyright (c) 2011 The Regents of the University of California
3  * David Zhu <yuzhu@cs.berkeley.edu>
4  * See LICENSE for details.
5  * 
6  * Socket layer on top of TCP abstraction. Similar to the BSD implementation.
7  *
8  */
9 #include <ros/common.h>
10 #include <socket.h>
11 #include <vfs.h>
12 #include <time.h>
13 #include <kref.h>
14 #include <syscall.h>
15 #include <sys/uio.h>
16 #include <ros/errno.h>
17 #include <net.h>
18 #include <net/udp.h>
19 #include <net/tcp.h>
20 #include <net/pbuf.h>
21 #include <net/tcp_impl.h>
22 #include <umem.h>
23 #include <kthread.h>
24 #include <bitmask.h>
25 #include <debug.h>
26 /*
27  *TODO: Figure out which socket.h is used where
28  *There are several socket.h in kern, and a couple more in glibc. Perhaps the glibc ones
29  *should grab from here..
30  */
31
32 struct kmem_cache *sock_kcache;
33 struct kmem_cache *mbuf_kcache;
34 struct kmem_cache *udp_pcb_kcache;
35 struct kmem_cache *tcp_pcb_kcache;
36 struct kmem_cache *tcp_pcb_listen_kcache;
37 struct kmem_cache *tcp_segment_kcache;
38
39 // file ops needed to support read/write on socket fd
40 static struct file_operations socket_op = {
41         0,
42         0,//soo_read,
43         0,//soo_write,
44         0,
45         0,
46         0,
47         0,
48         0,
49         0,
50         0,//soo_poll,
51         0,
52         0,
53         0, // sendpage might apply here
54         0,
55 };
56 static struct socket* getsocket(struct proc *p, int fd){
57         /* look up fd -> file */
58         struct file *so_file = get_file_from_fd(&(p->open_files), fd);
59
60         /* get socket and verify its type */
61         if (so_file == NULL){
62                 printd("getsocket() fd -> null file: fd %d\n", fd);
63                 return NULL;
64         }
65         if (so_file->f_op != &socket_op) {
66                 set_errno(ENOTSOCK);
67                 printd("fd %d maps to non-socket file\n");
68                 return NULL;
69         } else
70                 return (struct socket*) so_file->f_privdata;
71 }
72
73 struct socket* alloc_sock(int socket_family, int socket_type, int protocol){
74         struct socket *newsock = kmem_cache_alloc(sock_kcache, 0);
75         assert(newsock);
76
77         newsock->so_family = socket_family;
78         newsock->so_type = socket_type;
79         newsock->so_protocol = protocol;
80         newsock->so_state = SS_ISDISCONNECTED;
81         STAILQ_INIT(&(newsock->acceptq));
82         pbuf_head_init(&newsock->recv_buff);
83         pbuf_head_init(&newsock->send_buff);
84         sem_init_irqsave(&newsock->sem, 0);
85         sem_init_irqsave(&newsock->accept_sem, 0);
86         spinlock_init(&newsock->waiter_lock);
87         LIST_INIT(&newsock->waiters);
88         return newsock;
89
90 }
91 // TODO: refactor vfs so we can allocate fd and do the basic initialization
92 struct file *alloc_socket_file(struct socket* sock) {
93         struct file *file = alloc_file();
94         if (file == NULL) return 0;
95
96         // Linux fakes a dentry and an inode for socks, see socket.c : sock_alloc_file
97         file->f_dentry = NULL; // This might break things?
98         file->f_vfsmnt = 0;
99         file->f_flags = 0;
100
101         file->f_mode = S_IRUSR | S_IWUSR; // both read and write for socket files
102
103         file->f_pos = 0;
104         file->f_uid = 0;
105         file->f_gid = 0;
106         file->f_error = 0;
107
108         file->f_op = &socket_op;
109         file->f_privdata = sock;
110         file->f_mapping = 0;
111         return file;
112 }
113
114 void socket_init(){
115         
116         /* allocate buf for socket */
117         sock_kcache = kmem_cache_create("socket", sizeof(struct socket),
118                                                                         __alignof__(struct socket), 0, 0, 0);
119         udp_pcb_kcache = kmem_cache_create("udppcb", sizeof(struct udp_pcb),
120                                                                         __alignof__(struct udp_pcb), 0, 0, 0);
121         tcp_pcb_kcache = kmem_cache_create("tcppcb", sizeof(struct tcp_pcb),
122                                                                         __alignof__(struct tcp_pcb), 0, 0, 0);
123         tcp_pcb_listen_kcache = kmem_cache_create("tcppcblisten", sizeof(struct tcp_pcb_listen),
124                                                                         __alignof__(struct tcp_pcb_listen), 0, 0, 0);
125         tcp_segment_kcache = kmem_cache_create("tcpsegment", sizeof(struct tcp_seg),
126                                                                         __alignof__(struct tcp_seg), 0, 0, 0);
127         pbuf_init();
128
129 }
130 intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
131         printk ("sysaccept called\n");
132         struct socket* sock = getsocket(p, sockfd);
133         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
134         uint16_t r_port;
135         struct socket *accepted = NULL;
136         int8_t irq_state = 0;
137         if (sock == NULL) {
138                 set_errno(EBADF);
139                 return -1;      
140         }
141         if (sock->so_type == SOCK_DGRAM){
142                 return -1; // indicates false for connect
143         } else if (sock->so_type == SOCK_STREAM) {
144                 /* XXX these do the same thing, what is it you actually wanted to do?
145                  * (Originally the first was sleep_on, and the second __down_sem */
146                 if (STAILQ_EMPTY(&(sock->acceptq))) {
147                         // block on the acceptq
148                         sem_down_irqsave(&sock->accept_sem, &irq_state);
149                 } else {
150                         sem_down_irqsave(&sock->accept_sem, &irq_state);
151                 }
152                 spin_lock_irqsave(&sock->waiter_lock);
153                 accepted = STAILQ_FIRST(&(sock->acceptq));
154                 STAILQ_REMOVE_HEAD((&(sock->acceptq)), next);
155                 spin_unlock_irqsave(&sock->waiter_lock);
156                 if (accepted == NULL) return -1;
157                 struct file *file = alloc_socket_file(accepted);
158                 if (file == NULL) return -1;
159                 int fd = insert_file(&p->open_files, file, 0);
160                 if (fd < 0) {
161                         warn("File insertion for socket open failed");
162                         return -1;
163                 }
164                 kref_put(&file->f_kref);
165         }
166         return -1;
167 }
168
169 static error_t accept_callback(void *arg, struct tcp_pcb *newpcb, error_t err) {
170         struct socket *sockold = (struct socket *) arg;
171         struct socket *sock = alloc_sock(sockold->so_family, sockold->so_type, sockold->so_protocol);
172         int8_t irq_state = 0;
173         
174         sock->so_pcb = newpcb;
175         newpcb->pcbsock = sock;
176         spin_lock_irqsave(&sockold->waiter_lock);
177         STAILQ_INSERT_TAIL(&sockold->acceptq, sock, next);
178         // wake up any kthread who is potentially waiting
179         spin_unlock_irqsave(&sockold->waiter_lock);
180         sem_up_irqsave(&sock->accept_sem, &irq_state);
181         return 0;
182 }
183 intreg_t sys_listen(struct proc *p, int sockfd, int backlog) {
184         struct socket* sock = getsocket(p, sockfd);
185         if (sock == NULL) {
186                 set_errno(EBADF);
187                 return -1;      
188         }
189         if (sock->so_type == SOCK_DGRAM){
190                 return -1; // indicates false for connect
191         } else if (sock->so_type == SOCK_STREAM) {
192                 // check if the socket is in WAIT state
193                 struct tcp_pcb *tpcb = (struct tcp_pcb*)sock->so_pcb;
194                 struct tcp_pcb* lpcb = tcp_listen_with_backlog(tpcb, backlog);
195                 if (lpcb == NULL) {
196                         return -1;
197                 }
198                 sock->so_pcb = lpcb;
199
200                 // register callback for new connection
201                 tcp_arg(lpcb, sock);                                                  
202                 tcp_accept(lpcb, accept_callback); 
203
204                 return 0;
205
206
207                 // XXX: add backlog later
208         }
209         return -1;
210 }
211 intreg_t sys_connect(struct proc *p, int sock_fd, const struct sockaddr* addr, int addrlen) {
212         printk("sys_connect called \n");
213         struct socket* sock = getsocket(p, sock_fd);
214         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
215         uint16_t r_port;
216         if (sock == NULL) {
217                 set_errno(EBADF);
218                 return -1;      
219         }
220         if (sock->so_type == SOCK_DGRAM){
221                 return -1; // indicates false for connect
222         } else if (sock->so_type == SOCK_STREAM) {
223                 error_t err = tcp_connect((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port, NULL);
224                 return err;
225         }
226
227         return -1;
228 }
229
230 intreg_t sys_send(struct proc *p, int sockfd, const void *buf, size_t len,
231                   int flags) {
232         printk("sys_send called \n");
233         struct socket* sock = getsocket(p, sockfd);
234         const struct sockaddr_in *in_addr = (const struct sockaddr_in *)buf;
235         uint16_t r_port;
236         if (sock == NULL) {
237                 set_errno(EBADF);
238                 return -1;      
239         }
240         return len;
241
242 }
243 intreg_t sys_recv(struct proc *p, int sockfd, void *buf, size_t len, int flags) {
244         printk("sys_recv called \n");
245         // return actual length filled
246         return len;
247 }
248
249 intreg_t sys_bind(struct proc* p_proc, int fd, const struct sockaddr *addr, socklen_t addrlen) {
250         struct socket* sock = getsocket(p_proc, fd);
251         const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr;
252         uint16_t r_port;
253         if (sock == NULL) {
254                 set_errno(EBADF);
255                 return -1;      
256         }
257         if (sock->so_type == SOCK_DGRAM){
258                 return udp_bind((struct udp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
259         } else if (sock->so_type == SOCK_STREAM) {
260                 return tcp_bind((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
261         } else {
262                 printk("SOCK type not supported in bind operation \n");
263                 return -1;
264         }
265         return 0;
266 }
267  
268 intreg_t sys_socket(struct proc *p, int socket_family, int socket_type, int protocol){
269         //check validity of params
270         if (socket_family != AF_INET && socket_type != SOCK_DGRAM)
271                 return 0;
272         struct socket *sock = alloc_sock(socket_family, socket_type, protocol);
273         if (socket_type == SOCK_DGRAM){
274                 /* udp socket */
275                 sock->so_pcb = udp_new();
276                 /* back link */
277                 ((struct udp_pcb*) (sock->so_pcb))->pcbsock = sock;
278         } else if (socket_type == SOCK_STREAM) {
279                 /* tcp socket */
280                 sock->so_pcb = tcp_new();
281                 ((struct tcp_pcb*) (sock->so_pcb))->pcbsock = sock;
282         }
283         struct file *file = alloc_socket_file(sock);
284         
285         if (file == NULL) return -1;
286         int fd = insert_file(&p->open_files, file, 0);
287         if (fd < 0) {
288                 warn("File insertion for socket open failed");
289                 return -1;
290         }
291         kref_put(&file->f_kref);
292         printk("Socket open, res = %d\n", fd);
293         return fd;
294 }
295
296 intreg_t send_iov(struct socket* sock, struct iovec* iov, int flags){
297         // COPY_COUNT: for each iov, copy into mbuf, and send
298         // should not copy here, copy in the protocol..
299         // should be esomething like this sock->so_proto->pr_send(sock, iov, flags);
300         // make it datagram specific for now...
301         send_datagram(sock, iov, flags);
302         // finally time to check for validity of UA, in the protocol send
303         return 0;       
304 }
305
306 /*TODO: iov support currently broken */
307 int send_datagram(struct socket* sock, struct iovec* iov, int flags){
308         // is this a connection oriented protocol? 
309         struct pbuf *prev = NULL;
310         struct pbuf *curr = NULL;
311         if (sock->so_type == SOCK_STREAM){
312                 set_errno(ENOTCONN);
313                 return -1;
314         }
315         
316         // possible sock locks needed
317         if ((sock->so_state & SS_ISCONNECTED) == 0){
318                 set_errno(EINVAL);
319                 return -1;
320         }
321     // pbuf_ref needs to map in the user ref
322         for (int i = 0; i< sizeof(iov) / sizeof (struct iovec); i++){
323                 prev = curr;
324                 curr = pbuf_alloc(PBUF_TRANSPORT, iov[i].iov_len, PBUF_REF);
325                 if (prev!=NULL) pbuf_chain(prev, curr);
326         }
327         // struct pbuf* pb = pbuf_alloc(PBUF_TRANSPORT, PBUF_REF);
328         udp_send(sock->so_pcb, prev);
329         return 0;
330         
331 }
332
333 /* sys_sendto can send SOCK_DGRAM and eventually SOCK_STREAM 
334  * SOCK_DGRAM uses PBUF_REF since UDP does not need to wait for ack
335  * SOCK_STREAM uses PBUF_
336  *
337  */
338 intreg_t sys_sendto(struct proc *p_proc, int fd, const void *buffer, size_t length, 
339                         int flags, const struct sockaddr *dest_addr, socklen_t dest_len){
340         // look up the socket
341         struct socket* sock = getsocket(p_proc, fd);
342         int error;
343         struct sockaddr_in *in_addr;
344         uint16_t r_port;
345         if (sock == NULL) {
346                 set_errno(EBADF);
347                 return -1;      
348         }
349         if (sock->so_type == SOCK_DGRAM){
350                 in_addr = (struct sockaddr_in *)dest_addr;
351                 struct pbuf* buf = pbuf_alloc(PBUF_TRANSPORT, length, PBUF_REF);
352                 if (buf != NULL)
353                         buf->payload = (void*)buffer;
354                 else 
355                         warn("pbuf alloc failed \n");
356                 // potentially unsafe cast to udp_pcb 
357                 return udp_sendto((struct udp_pcb*) sock->so_pcb, buf, &in_addr->sin_addr, in_addr->sin_port);
358         }
359
360         return -1;
361   //TODO: support for sendmsg and iovectors? Let's get the basics working first!
362         #if 0 
363         // use iovector to handle sendmsg calls too, and potentially scatter-gather
364         struct msghdr msg;
365         struct iovec iov;
366         struct uio auio;
367         
368         // checking for permission only when you are sending it
369         // potential bug TOCTOU, especially with async calls
370                 
371     msg.msg_name = dest_addr;
372     msg.msg_namelen = dest_len;
373     msg.msg_iov = &iov;
374     msg.msg_iovlen = 1;
375     msg.msg_control = 0;
376     
377         iov.iov_base = buffer;
378     iov.iov_len = length;
379         
380
381         // this is why we need another function to populate auio
382
383         auio.uio_iov = iov;
384         auio.uio_iovcnt = 1;
385         auio.uio_offset = 0;
386         auio.uio_resid = 0;
387         auio.uio_rw = UIO_WRITE;
388         auio.uio_proc = p;
389
390         // consider changing to send_uaio, since we care about progress.
391     error = send_iov(soc, iov, flags);
392         #endif
393 }
394
395 /* UDP and TCP has different waiting semantics
396  * UDP requires any packet to be available. 
397  * TCP requires accumulation of certain size? 
398  */
399 intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t length, int flags, struct sockaddr *restrict address, socklen_t *restrict address_len){
400         struct socket* sock = getsocket(p, socket);     
401         int copied = 0;
402         int returnval = 0;
403         int8_t irq_state = 0;
404         if (sock == NULL) {
405                 set_errno(EBADF);
406                 return -1;
407         }
408         if (sock->so_type == SOCK_DGRAM){
409                 struct pbuf_head *ph = &(sock->recv_buff);
410                 struct pbuf* buf = NULL;
411                 buf = detach_pbuf(ph);
412                 if (!buf){
413                         // about to sleep
414                         sem_down_irqsave(&sock->sem, &irq_state);
415                         buf = detach_pbuf(ph);
416                         // Someone woke me up, there should be data..
417                         assert(buf);
418                 } else {
419                         sem_down_irqsave(&sock->sem, &irq_state);
420                 }
421                         copied = buf->len - sizeof(struct udp_hdr);
422                         if (copied > length)
423                                 copied = length;
424                         pbuf_header(buf, -UDP_HDR_SZ);
425                         // copy it to user space
426                         returnval = memcpy_to_user_errno(p, buffer, buf->payload, copied);
427                 }
428         if (returnval < 0) 
429                 return -1;
430         else
431                 return copied;
432 }
433
434 static int selscan(int maxfdp1, fd_set *readset_in, fd_set *writeset_in, fd_set *exceptset_in,
435              fd_set *readset_out, fd_set *writeset_out, fd_set *exceptset_out){
436         return 0;
437 }
438
439 /* TODO: Start respecting the time out value */ 
440 /* TODO: start respecting writefds and exceptfds */
441 intreg_t sys_select(struct proc *p, int nfds, fd_set *readfds, fd_set *writefds,
442                                 fd_set *exceptfds, struct timeval *timeout){
443         /* Create a semaphore */
444         struct semaphore_entry read_sem; 
445         int8_t irq_state = 0;
446
447         sem_init_irqsave(&(read_sem.sem), 0);
448
449         /* insert into the sem list of a fd / socket */
450         int low_fd = 0;
451         for (int i = low_fd; i< nfds; i++) {
452                 if(FD_ISSET(i, readfds)){
453                   struct socket* sock = getsocket(p, i);
454                         /* if the fd is not open or if the file descriptor is not a socket 
455                          * go to the next in the fd set 
456                          */
457                         if (sock == NULL) continue;
458                         /* for each file that is open, insert this semaphore to be woken up when there
459                         * is data available to be read
460                         */
461                         spin_lock(&sock->waiter_lock);
462                         LIST_INSERT_HEAD(&sock->waiters, &read_sem, link);
463                         spin_unlock(&sock->waiter_lock);
464                 }
465         }
466         /* At this point wait on the semaphore */
467         sem_down_irqsave(&read_sem.sem, &irq_state);
468         /* someone woke me up, so walk through the list of descriptors and find one that is ready */
469         /* remove itself from all the lists that it is waiting on */
470         for (int i = low_fd; i<nfds; i++) {
471                 if (FD_ISSET(i, readfds)){
472                         struct socket* sock = getsocket(p,i);
473                         if (sock == NULL) continue;
474                         spin_lock(&sock->waiter_lock);
475                         LIST_REMOVE(&read_sem, link);
476                         spin_unlock(&sock->waiter_lock);
477                 }
478         }
479         fd_set readout, writeout, exceptout;
480         FD_ZERO(&readout);
481         FD_ZERO(&writeout);
482         FD_ZERO(&exceptout);
483         for (int i = low_fd; i< nfds; i ++){
484                 if (readfds && FD_ISSET(i, readfds)){
485                   struct socket* sock = getsocket(p, i);
486                         if ((sock->recv_buff).qlen > 0){
487                                 FD_SET(i, &readout);
488                         }
489                         /* if the socket is ready, then we can return it */
490                 }
491         }
492         if (readfds)
493                 memcpy(readfds, &readout, sizeof(*readfds));
494         if (writefds)
495                 memcpy(writefds, &writeout, sizeof(*writefds));
496         if (exceptfds)
497                 memcpy(readfds, &readout, sizeof(*readfds));
498
499         /* Sleep on that semaphore */
500         /* Somehow get these file descriptors to wake me up when there is new data */
501         return 0;
502 }