9240abdc1f025e1a934a7f26465d2579afd483bc
[akaros.git] / kern / src / socket.c
1 /*
2  * Copyright (c) 2011 The Regents of the University of California
3  * David Zhu <yuzhu@cs.berkeley.edu>
4  * See LICENSE for details.
5  * 
6  * Socket layer on top of TCP abstraction. Similar to the BSD implementation.
7  *
8  */
9 #include <ros/common.h>
10 #include <socket.h>
11 #include <vfs.h>
12 #include <kref.h>
13 #include <syscall.h>
14 #include <sys/uio.h>
15 #include <mbuf.h>
16 #include <ros/errno.h>
17 #include <net.h>
18 #include <net/udp.h>
19 #include <net/pbuf.h>
20 #include <umem.h>
21 /*
22  *TODO: Figure out which socket.h is used where
23  *There are several socket.h in kern, and a couple more in glibc. Perhaps the glibc ones
24  *should grab from here..
25  */
26
27 struct kmem_cache *sock_kcache;
28 struct kmem_cache *mbuf_kcache;
29 struct kmem_cache *udp_pcb_kcache;
30 // file ops needed to support read/write on socket fd
31 static struct file_operations socket_op = {
32         0,
33         0,//soo_read,
34         0,//soo_write,
35         0,
36         0,
37         0,
38         0,
39         0,
40         0,
41         0,//soo_poll,
42         0,
43         0,
44         0, // sendpage might apply here
45         0,
46 };
47 static struct socket* getsocket(struct proc *p, int fd){
48         /* look up fd -> file */
49         struct file *so_file = get_file_from_fd(&(p->open_files), fd);
50
51         /* get socket and verify its type */
52         if (so_file == NULL){
53                 printd("getsocket() fd -> null file: fd %d\n", fd);
54                 return NULL;
55         }
56         if (so_file->f_op != &socket_op) {
57                 set_errno(ENOTSOCK);
58                 printd("fd %d maps to non-socket file\n");
59                 return NULL;
60         } else
61                 return (struct socket*) so_file->f_privdata;
62 }
63
64 struct socket* alloc_sock(int socket_family, int socket_type, int protocol){
65         struct socket *newsock = kmem_cache_alloc(sock_kcache, 0);
66         assert(newsock);
67
68         newsock->so_family = socket_family;
69         newsock->so_type = socket_type;
70         newsock->so_protocol = protocol;
71         newsock->so_state = SS_ISDISCONNECTED;
72         pbuf_head_init(&newsock->recv_buff);
73         pbuf_head_init(&newsock->send_buff);
74         init_sem(&newsock->sem, 0);
75         if (socket_type == SOCK_DGRAM){
76                 newsock->so_pcb = udp_new();
77                 /* back link */
78                 ((struct udp_pcb*) (newsock->so_pcb))->pcbsock = newsock;
79         }
80         return newsock;
81
82 }
83 // TODO: refactor vfs so we can allocate fd and do the basic initialization
84 struct file *alloc_socket_file(struct socket* sock) {
85         struct file *file = alloc_file();
86         if (file == NULL) return 0;
87
88         // Linux fakes a dentry and an inode for socks, see socket.c : sock_alloc_file
89         file->f_dentry = NULL; // This might break things?
90         file->f_vfsmnt = 0;
91         file->f_flags = 0;
92
93         file->f_mode = S_IRUSR | S_IWUSR; // both read and write for socket files
94
95         file->f_pos = 0;
96         file->f_uid = 0;
97         file->f_gid = 0;
98         file->f_error = 0;
99
100         file->f_op = &socket_op;
101         file->f_privdata = sock;
102         file->f_mapping = 0;
103         return file;
104 }
105
106 void socket_init(){
107         
108         /* allocate buf for socket */
109         sock_kcache = kmem_cache_create("socket", sizeof(struct socket),
110                                                                         __alignof__(struct socket), 0, 0, 0);
111         udp_pcb_kcache = kmem_cache_create("udppcb", sizeof(struct udp_pcb), 
112                                                                         __alignof__(struct udp_pcb), 0, 0, 0);
113
114         pbuf_init();
115
116 }
117
118 intreg_t sys_socket(struct proc *p, int socket_family, int socket_type, int protocol){
119         //check validity of params
120         if (socket_family !=AF_INET && socket_type!=SOCK_DGRAM)
121                 return 0;
122         struct socket *sock = alloc_sock(socket_family, socket_type, protocol);
123         struct file *file = alloc_socket_file(sock);
124         
125         if (file == NULL) return -1;
126         int fd = insert_file(&p->open_files, file, 0);
127         if (fd < 0) {
128                 warn("File insertion for socket open failed");
129                 return -1;
130         }
131         kref_put(&file->f_kref);
132         printk("Socket open, res = %d\n", fd);
133         return fd;
134 }
135 intreg_t send_iov(struct socket* sock, struct iovec* iov, int flags){
136         
137         // COPY_COUNT: for each iov, copy into mbuf, and send
138         // should not copy here, copy in the protocol..
139         // should be esomething like this sock->so_proto->pr_send(sock, iov, flags);
140         // make it datagram specific for now...
141         send_datagram(sock, iov, flags);
142         // finally time to check for validity of UA, in the protocol send
143         return 0;       
144 }
145 /*TODO: iov support currently broken */
146 int send_datagram(struct socket* sock, struct iovec* iov, int flags){
147         // is this a connection oriented protocol? 
148         struct pbuf *prev = NULL;
149         struct pbuf *curr = NULL;
150         if (sock->so_type == SOCK_STREAM){
151                 set_errno(ENOTCONN);
152                 return -1;
153         }
154         
155         // possible sock locks needed
156         if ((sock->so_state & SS_ISCONNECTED) == 0){
157                 set_errno(EINVAL);
158                 return -1;
159         }
160     // pbuf_ref needs to map in the user ref
161         for (int i = 0; i< sizeof(iov) / sizeof (struct iovec); i++){
162                 prev = curr;
163                 curr = pbuf_alloc(PBUF_TRANSPORT, iov[i].iov_len, PBUF_REF);
164                 if (prev!=NULL) pbuf_chain(prev, curr);
165         }
166         // struct pbuf* pb = pbuf_alloc(PBUF_TRANSPORT, PBUF_REF);
167         udp_send(sock->so_pcb, prev);
168         return 0;
169         
170 }
171
172 /* sys_sendto can send SOCK_DGRAM and eventually SOCK_STREAM 
173  * SOCK_DGRAM uses PBUF_REF since UDP does not need to wait for ack
174  * SOCK_STREAM uses PBUF_
175  *
176  */
177 intreg_t sys_sendto(struct proc *p_proc, int fd, const void *buffer, size_t length, 
178                         int flags, const struct sockaddr *dest_addr, socklen_t dest_len){
179         // look up the socket
180         struct socket* sock = getsocket(p_proc, fd);
181         int error;
182         struct sockaddr_in *in_addr;
183         uint16_t r_port;
184         if (sock == NULL) {
185                 set_errno(EBADF);
186                 return -1;      
187         }
188         if (sock->so_type == SOCK_DGRAM){
189                 in_addr = (struct sockaddr_in *)dest_addr;
190                 struct pbuf* buf = pbuf_alloc(PBUF_TRANSPORT, length, PBUF_REF);
191                 if (buf != NULL)
192                         buf->payload = (void*)buffer;
193                 else 
194                         warn("pbuf alloc failed \n");
195                 // potentially unsafe cast to udp_pcb 
196                 return udp_sendto((struct udp_pcb*) sock->so_pcb, buf, &in_addr->sin_addr, in_addr->sin_port);
197         }
198
199         return -1;
200   //TODO: support for sendmsg and iovectors? Let's get the basis working first!
201         #if 0 
202         // use iovector to handle sendmsg calls too, and potentially scatter-gather
203         struct msghdr msg;
204         struct iovec iov;
205         struct uio auio;
206         
207         // checking for permission only when you are sending it
208         // potential bug TOCTOU, especially with async calls
209                 
210     msg.msg_name = dest_addr;
211     msg.msg_namelen = dest_len;
212     msg.msg_iov = &iov;
213     msg.msg_iovlen = 1;
214     msg.msg_control = 0;
215     
216         iov.iov_base = buffer;
217     iov.iov_len = length;
218         
219
220         // this is why we need another function to populate auio
221
222         auio.uio_iov = iov;
223         auio.uio_iovcnt = 1;
224         auio.uio_offset = 0;
225         auio.uio_resid = 0;
226         auio.uio_rw = UIO_WRITE;
227         auio.uio_proc = p;
228
229         // consider changing to send_uaio, since we care about progress.
230     error = send_iov(soc, iov, flags);
231         #endif
232 }
233
234 /* UDP and TCP has different waiting semantics
235  * UDP requires any packet to be available. 
236  * TCP requires accumulation of certain size? 
237  */
238 intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t length, int flags, struct sockaddr *restrict address, socklen_t *restrict address_len){
239         struct socket* sock = getsocket(p, socket);     
240         int copied = 0;
241         int returnval = 0;
242         if (sock == NULL) {
243                 set_errno(EBADF);
244                 return -1;
245         }
246         if (sock->so_type == SOCK_DGRAM){
247                 struct pbuf_head *ph = &(sock->recv_buff);
248                 struct pbuf* buf = NULL;
249                 buf = detach_pbuf(ph);
250                 if (!buf){
251                         // about to sleep
252                         sleep_on(&sock->sem);
253                         buf = detach_pbuf(ph);
254                         // Someone woke me up, there should be data..
255                         assert(buf);
256                 } else {
257                         __down_sem(&sock->sem, NULL);
258                 }
259                         copied = buf->len - sizeof(struct udp_hdr);
260                         if (copied > length)
261                                 copied = length;
262                         pbuf_header(buf, -UDP_HDR_SZ);
263                         // copy it to user space
264                         returnval = memcpy_to_user_errno(p, buffer, buf->payload, copied);
265                 }
266         if (returnval < 0) 
267                 return -1;
268         else
269                 return copied;
270 }