Basic socket stubs and functionalities.
[akaros.git] / kern / src / socket.c
1 /*
2  * Copyright (c) 2011 The Regents of the University of California
3  * David Zhu <yuzhu@cs.berkeley.edu>
4  * See LICENSE for details.
5  * 
6  * Socket layer on top of TCP abstraction. Similar to the BSD implementation.
7  *
8  */
9 #include <ros/common.h>
10 #include <socket.h>
11 #include <vfs.h>
12 #include <time.h>
13 #include <kref.h>
14 #include <syscall.h>
15 #include <sys/uio.h>
16 #include <mbuf.h>
17 #include <ros/errno.h>
18 #include <net.h>
19 #include <net/udp.h>
20 #include <net/tcp.h>
21 #include <net/pbuf.h>
22 #include <net/tcp_impl.h>
23 #include <umem.h>
24 #include <kthread.h>
25 #include <bitmask.h>
26 #include <debug.h>
27 /*
28  *TODO: Figure out which socket.h is used where
29  *There are several socket.h in kern, and a couple more in glibc. Perhaps the glibc ones
30  *should grab from here..
31  */
32
33 struct kmem_cache *sock_kcache;
34 struct kmem_cache *mbuf_kcache;
35 struct kmem_cache *udp_pcb_kcache;
36 struct kmem_cache *tcp_pcb_kcache;
37 struct kmem_cache *tcp_pcb_listen_kcache;
38 struct kmem_cache *tcp_segment_kcache;
39
40 // file ops needed to support read/write on socket fd
41 static struct file_operations socket_op = {
42         0,
43         0,//soo_read,
44         0,//soo_write,
45         0,
46         0,
47         0,
48         0,
49         0,
50         0,
51         0,//soo_poll,
52         0,
53         0,
54         0, // sendpage might apply here
55         0,
56 };
57 static struct socket* getsocket(struct proc *p, int fd){
58         /* look up fd -> file */
59         struct file *so_file = get_file_from_fd(&(p->open_files), fd);
60
61         /* get socket and verify its type */
62         if (so_file == NULL){
63                 printd("getsocket() fd -> null file: fd %d\n", fd);
64                 return NULL;
65         }
66         if (so_file->f_op != &socket_op) {
67                 set_errno(ENOTSOCK);
68                 printd("fd %d maps to non-socket file\n");
69                 return NULL;
70         } else
71                 return (struct socket*) so_file->f_privdata;
72 }
73
74 struct socket* alloc_sock(int socket_family, int socket_type, int protocol){
75         struct socket *newsock = kmem_cache_alloc(sock_kcache, 0);
76         assert(newsock);
77
78         newsock->so_family = socket_family;
79         newsock->so_type = socket_type;
80         newsock->so_protocol = protocol;
81         newsock->so_state = SS_ISDISCONNECTED;
82         pbuf_head_init(&newsock->recv_buff);
83         pbuf_head_init(&newsock->send_buff);
84         init_sem(&newsock->sem, 0);
85         spinlock_init(&newsock->waiter_lock);
86         LIST_INIT(&newsock->waiters);
87         return newsock;
88
89 }
90 // TODO: refactor vfs so we can allocate fd and do the basic initialization
91 struct file *alloc_socket_file(struct socket* sock) {
92         struct file *file = alloc_file();
93         if (file == NULL) return 0;
94
95         // Linux fakes a dentry and an inode for socks, see socket.c : sock_alloc_file
96         file->f_dentry = NULL; // This might break things?
97         file->f_vfsmnt = 0;
98         file->f_flags = 0;
99
100         file->f_mode = S_IRUSR | S_IWUSR; // both read and write for socket files
101
102         file->f_pos = 0;
103         file->f_uid = 0;
104         file->f_gid = 0;
105         file->f_error = 0;
106
107         file->f_op = &socket_op;
108         file->f_privdata = sock;
109         file->f_mapping = 0;
110         return file;
111 }
112
113 void socket_init(){
114         
115         /* allocate buf for socket */
116         sock_kcache = kmem_cache_create("socket", sizeof(struct socket),
117                                                                         __alignof__(struct socket), 0, 0, 0);
118         udp_pcb_kcache = kmem_cache_create("udppcb", sizeof(struct udp_pcb),
119                                                                         __alignof__(struct udp_pcb), 0, 0, 0);
120         tcp_pcb_kcache = kmem_cache_create("tcppcb", sizeof(struct tcp_pcb),
121                                                                         __alignof__(struct tcp_pcb), 0, 0, 0);
122         tcp_pcb_listen_kcache = kmem_cache_create("tcppcblisten", sizeof(struct tcp_pcb_listen),
123                                                                         __alignof__(struct tcp_pcb_listen), 0, 0, 0);
124         tcp_segment_kcache = kmem_cache_create("tcpsegment", sizeof(struct tcp_seg),
125                                                                         __alignof__(struct tcp_seg), 0, 0, 0);
126         pbuf_init();
127
128 }
129 intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
130         printk ("sysaccept called\n");
131         struct socket* sock = getsocket(p, sockfd);
132         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
133         uint16_t r_port;
134         if (sock == NULL) {
135                 set_errno(EBADF);
136                 return -1;      
137         }
138         if (sock->so_type == SOCK_DGRAM){
139                 return -1; // indicates false for connect
140         } else if (sock->so_type == SOCK_STREAM) {
141                 // check if the socket is in WAIT state
142                 // pulling new socket from the queue
143         }
144         return -1;
145 }
146 static error_t accept_callback(void *arg, struct tcp_pcb *newpcb, error_t err) {
147         struct socket *sockold = (struct socket *) arg;
148         struct socket *sock = alloc_sock(sockold->so_family, sockold->so_type, sockold->so_protocol);
149         struct file *file = alloc_socket_file(sock);
150         
151         sock->so_pcb = newpcb;
152         newpcb->pcbsock = sock;
153         if (file == NULL) return -1;
154         kref_put(&file->f_kref);
155         return 0;
156 }
157 intreg_t sys_listen(struct proc *p, int sockfd, int backlog) {
158         struct socket* sock = getsocket(p, sockfd);
159         if (sock == NULL) {
160                 set_errno(EBADF);
161                 return -1;      
162         }
163         if (sock->so_type == SOCK_DGRAM){
164                 return -1; // indicates false for connect
165         } else if (sock->so_type == SOCK_STREAM) {
166                 // check if the socket is in WAIT state
167                 struct tcp_pcb *tpcb = (struct tcp_pcb*)sock->so_pcb;
168                 struct tcp_pcb* lpcb = tcp_listen_with_backlog(tpcb, backlog);
169                 if (lpcb == NULL) {
170                         return -1;
171                 }
172                 sock->so_pcb = lpcb;
173
174                 // register callback for new connection
175                 tcp_arg(lpcb, sock);                                                  
176                 tcp_accept(lpcb, accept_callback); 
177
178                 return 0;
179
180
181                 // XXX: add backlog later
182         }
183         return -1;
184 }
185 intreg_t sys_connect(struct proc *p, int sock_fd, const struct sockaddr* addr, int addrlen) {
186         printk("sys_connect called \n");
187         struct socket* sock = getsocket(p, sock_fd);
188         struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
189         uint16_t r_port;
190         if (sock == NULL) {
191                 set_errno(EBADF);
192                 return -1;      
193         }
194         if (sock->so_type == SOCK_DGRAM){
195                 return -1; // indicates false for connect
196         } else if (sock->so_type == SOCK_STREAM) {
197                 error_t err = tcp_connect((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port, NULL);
198                 return err;
199         }
200
201         return -1;
202 }
203
204 intreg_t sys_send(struct proc *p, int sockfd, const void *buf, size_t len, int flags) {
205         printk("sys_send called \n");
206         return len; // indicates success, by returning length
207
208 }
209 intreg_t sys_recv(struct proc *p, int sockfd, void *buf, size_t len, int flags) {
210         printk("sys_recv called \n");
211         // return actual length filled
212         return len;
213 }
214
215 intreg_t sys_bind(struct proc* p_proc, int fd, const struct sockaddr *addr, socklen_t addrlen) {
216         struct socket* sock = getsocket(p_proc, fd);
217         const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr;
218         uint16_t r_port;
219         if (sock == NULL) {
220                 set_errno(EBADF);
221                 return -1;      
222         }
223         if (sock->so_type == SOCK_DGRAM){
224                 return udp_bind((struct udp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
225         } else if (sock->so_type == SOCK_STREAM) {
226                 return tcp_bind((struct tcp_pcb*)sock->so_pcb, & (in_addr->sin_addr), in_addr->sin_port);
227         } else {
228                 printk("SOCK type not supported in bind operation \n");
229                 return -1;
230         }
231         return 0;
232 }
233  
234 intreg_t sys_socket(struct proc *p, int socket_family, int socket_type, int protocol){
235         //check validity of params
236         if (socket_family != AF_INET && socket_type != SOCK_DGRAM)
237                 return 0;
238         struct socket *sock = alloc_sock(socket_family, socket_type, protocol);
239         if (socket_type == SOCK_DGRAM){
240                 /* udp socket */
241                 sock->so_pcb = udp_new();
242                 /* back link */
243                 ((struct udp_pcb*) (sock->so_pcb))->pcbsock = sock;
244         } else if (socket_type == SOCK_STREAM) {
245                 /* tcp socket */
246                 sock->so_pcb = tcp_new();
247                 ((struct tcp_pcb*) (sock->so_pcb))->pcbsock = sock;
248         }
249         struct file *file = alloc_socket_file(sock);
250         
251         if (file == NULL) return -1;
252         int fd = insert_file(&p->open_files, file, 0);
253         if (fd < 0) {
254                 warn("File insertion for socket open failed");
255                 return -1;
256         }
257         kref_put(&file->f_kref);
258         printk("Socket open, res = %d\n", fd);
259         return fd;
260 }
261
262 intreg_t send_iov(struct socket* sock, struct iovec* iov, int flags){
263         // COPY_COUNT: for each iov, copy into mbuf, and send
264         // should not copy here, copy in the protocol..
265         // should be esomething like this sock->so_proto->pr_send(sock, iov, flags);
266         // make it datagram specific for now...
267         send_datagram(sock, iov, flags);
268         // finally time to check for validity of UA, in the protocol send
269         return 0;       
270 }
271
272 /*TODO: iov support currently broken */
273 int send_datagram(struct socket* sock, struct iovec* iov, int flags){
274         // is this a connection oriented protocol? 
275         struct pbuf *prev = NULL;
276         struct pbuf *curr = NULL;
277         if (sock->so_type == SOCK_STREAM){
278                 set_errno(ENOTCONN);
279                 return -1;
280         }
281         
282         // possible sock locks needed
283         if ((sock->so_state & SS_ISCONNECTED) == 0){
284                 set_errno(EINVAL);
285                 return -1;
286         }
287     // pbuf_ref needs to map in the user ref
288         for (int i = 0; i< sizeof(iov) / sizeof (struct iovec); i++){
289                 prev = curr;
290                 curr = pbuf_alloc(PBUF_TRANSPORT, iov[i].iov_len, PBUF_REF);
291                 if (prev!=NULL) pbuf_chain(prev, curr);
292         }
293         // struct pbuf* pb = pbuf_alloc(PBUF_TRANSPORT, PBUF_REF);
294         udp_send(sock->so_pcb, prev);
295         return 0;
296         
297 }
298
299 /* sys_sendto can send SOCK_DGRAM and eventually SOCK_STREAM 
300  * SOCK_DGRAM uses PBUF_REF since UDP does not need to wait for ack
301  * SOCK_STREAM uses PBUF_
302  *
303  */
304 intreg_t sys_sendto(struct proc *p_proc, int fd, const void *buffer, size_t length, 
305                         int flags, const struct sockaddr *dest_addr, socklen_t dest_len){
306         // look up the socket
307         struct socket* sock = getsocket(p_proc, fd);
308         int error;
309         struct sockaddr_in *in_addr;
310         uint16_t r_port;
311         if (sock == NULL) {
312                 set_errno(EBADF);
313                 return -1;      
314         }
315         if (sock->so_type == SOCK_DGRAM){
316                 in_addr = (struct sockaddr_in *)dest_addr;
317                 struct pbuf* buf = pbuf_alloc(PBUF_TRANSPORT, length, PBUF_REF);
318                 if (buf != NULL)
319                         buf->payload = (void*)buffer;
320                 else 
321                         warn("pbuf alloc failed \n");
322                 // potentially unsafe cast to udp_pcb 
323                 return udp_sendto((struct udp_pcb*) sock->so_pcb, buf, &in_addr->sin_addr, in_addr->sin_port);
324         }
325
326         return -1;
327   //TODO: support for sendmsg and iovectors? Let's get the basics working first!
328         #if 0 
329         // use iovector to handle sendmsg calls too, and potentially scatter-gather
330         struct msghdr msg;
331         struct iovec iov;
332         struct uio auio;
333         
334         // checking for permission only when you are sending it
335         // potential bug TOCTOU, especially with async calls
336                 
337     msg.msg_name = dest_addr;
338     msg.msg_namelen = dest_len;
339     msg.msg_iov = &iov;
340     msg.msg_iovlen = 1;
341     msg.msg_control = 0;
342     
343         iov.iov_base = buffer;
344     iov.iov_len = length;
345         
346
347         // this is why we need another function to populate auio
348
349         auio.uio_iov = iov;
350         auio.uio_iovcnt = 1;
351         auio.uio_offset = 0;
352         auio.uio_resid = 0;
353         auio.uio_rw = UIO_WRITE;
354         auio.uio_proc = p;
355
356         // consider changing to send_uaio, since we care about progress.
357     error = send_iov(soc, iov, flags);
358         #endif
359 }
360
361 /* UDP and TCP has different waiting semantics
362  * UDP requires any packet to be available. 
363  * TCP requires accumulation of certain size? 
364  */
365 intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t length, int flags, struct sockaddr *restrict address, socklen_t *restrict address_len){
366         struct socket* sock = getsocket(p, socket);     
367         int copied = 0;
368         int returnval = 0;
369         if (sock == NULL) {
370                 set_errno(EBADF);
371                 return -1;
372         }
373         if (sock->so_type == SOCK_DGRAM){
374                 struct pbuf_head *ph = &(sock->recv_buff);
375                 struct pbuf* buf = NULL;
376                 buf = detach_pbuf(ph);
377                 if (!buf){
378                         // about to sleep
379                         sleep_on(&sock->sem);
380                         buf = detach_pbuf(ph);
381                         // Someone woke me up, there should be data..
382                         assert(buf);
383                 } else {
384                         __down_sem(&sock->sem, NULL);
385                 }
386                         copied = buf->len - sizeof(struct udp_hdr);
387                         if (copied > length)
388                                 copied = length;
389                         pbuf_header(buf, -UDP_HDR_SZ);
390                         // copy it to user space
391                         returnval = memcpy_to_user_errno(p, buffer, buf->payload, copied);
392                 }
393         if (returnval < 0) 
394                 return -1;
395         else
396                 return copied;
397 }
398
399 static int selscan(int maxfdp1, fd_set *readset_in, fd_set *writeset_in, fd_set *exceptset_in,
400              fd_set *readset_out, fd_set *writeset_out, fd_set *exceptset_out){
401         return 0;
402 }
403
404 /* TODO: Start respecting the time out value */ 
405 /* TODO: start respecting writefds and exceptfds */
406 intreg_t sys_select(struct proc *p, int nfds, fd_set *readfds, fd_set *writefds,
407                                 fd_set *exceptfds, struct timeval *timeout){
408         /* Create a semaphore */
409         struct semaphore_entry read_sem; 
410
411         init_sem(&(read_sem.sem), 0);
412
413         /* insert into the sem list of a fd / socket */
414         int low_fd = 0;
415         for (int i = low_fd; i< nfds; i++) {
416                 if(FD_ISSET(i, readfds)){
417                   struct socket* sock = getsocket(p, i);
418                         /* if the fd is not open or if the file descriptor is not a socket 
419                          * go to the next in the fd set 
420                          */
421                         if (sock == NULL) continue;
422                         /* for each file that is open, insert this semaphore to be woken up when there
423                         * is data available to be read
424                         */
425                         spin_lock(&sock->waiter_lock);
426                         LIST_INSERT_HEAD(&sock->waiters, &read_sem, link);
427                         spin_unlock(&sock->waiter_lock);
428                 }
429         }
430         /* At this point wait on the semaphore */
431   sleep_on(&(read_sem.sem));
432         /* someone woke me up, so walk through the list of descriptors and find one that is ready */
433         /* remove itself from all the lists that it is waiting on */
434         for (int i = low_fd; i<nfds; i++) {
435                 if (FD_ISSET(i, readfds)){
436                         struct socket* sock = getsocket(p,i);
437                         if (sock == NULL) continue;
438                         spin_lock(&sock->waiter_lock);
439                         LIST_REMOVE(&read_sem, link);
440                         spin_unlock(&sock->waiter_lock);
441                 }
442         }
443         fd_set readout, writeout, exceptout;
444         FD_ZERO(&readout);
445         FD_ZERO(&writeout);
446         FD_ZERO(&exceptout);
447         for (int i = low_fd; i< nfds; i ++){
448                 if (readfds && FD_ISSET(i, readfds)){
449                   struct socket* sock = getsocket(p, i);
450                         if ((sock->recv_buff).qlen > 0){
451                                 FD_SET(i, &readout);
452                         }
453                         /* if the socket is ready, then we can return it */
454                 }
455         }
456         if (readfds)
457                 memcpy(readfds, &readout, sizeof(*readfds));
458         if (writefds)
459                 memcpy(writefds, &writeout, sizeof(*writefds));
460         if (exceptfds)
461                 memcpy(readfds, &readout, sizeof(*readfds));
462
463         /* Sleep on that semaphore */
464         /* Somehow get these file descriptors to wake me up when there is new data */
465         return 0;
466 }