Merge origin/netpush (networking code) (XCC)
[akaros.git] / kern / src / socket.c
index 1972e40..b9f4084 100644 (file)
@@ -13,7 +13,6 @@
 #include <kref.h>
 #include <syscall.h>
 #include <sys/uio.h>
-#include <mbuf.h>
 #include <ros/errno.h>
 #include <net.h>
 #include <net/udp.h>
@@ -82,8 +81,8 @@ struct socket* alloc_sock(int socket_family, int socket_type, int protocol){
        STAILQ_INIT(&(newsock->acceptq));
        pbuf_head_init(&newsock->recv_buff);
        pbuf_head_init(&newsock->send_buff);
-       init_sem(&newsock->sem, 0);
-       init_sem(&newsock->accept_sem, 0);
+       sem_init_irqsave(&newsock->sem, 0);
+       sem_init_irqsave(&newsock->accept_sem, 0);
        spinlock_init(&newsock->waiter_lock);
        LIST_INIT(&newsock->waiters);
        return newsock;
@@ -134,6 +133,7 @@ intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t
        struct sockaddr_in *in_addr = (struct sockaddr_in *)addr;
        uint16_t r_port;
        struct socket *accepted = NULL;
+       int8_t irq_state = 0;
        if (sock == NULL) {
                set_errno(EBADF);
                return -1;      
@@ -141,11 +141,13 @@ intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t
        if (sock->so_type == SOCK_DGRAM){
                return -1; // indicates false for connect
        } else if (sock->so_type == SOCK_STREAM) {
+               /* XXX these do the same thing, what is it you actually wanted to do?
+                * (Originally the first was sleep_on, and the second __down_sem */
                if (STAILQ_EMPTY(&(sock->acceptq))) {
                        // block on the acceptq
-                       sleep_on(&sock->accept_sem);
+                       sem_down_irqsave(&sock->accept_sem, &irq_state);
                } else {
-                       __down_sem(&sock->accept_sem, NULL);
+                       sem_down_irqsave(&sock->accept_sem, &irq_state);
                }
                spin_lock_irqsave(&sock->waiter_lock);
                accepted = STAILQ_FIRST(&(sock->acceptq));
@@ -164,14 +166,10 @@ intreg_t sys_accept(struct proc *p, int sockfd, struct sockaddr *addr, socklen_t
        return -1;
 }
 
-static void wrap_restart_kthread(struct trapframe *tf, uint32_t srcid,
-                                       long a0, long a1, long a2){
-       restart_kthread((struct kthread*) a0);
-}
-
 static error_t accept_callback(void *arg, struct tcp_pcb *newpcb, error_t err) {
        struct socket *sockold = (struct socket *) arg;
        struct socket *sock = alloc_sock(sockold->so_family, sockold->so_type, sockold->so_protocol);
+       int8_t irq_state = 0;
        
        sock->so_pcb = newpcb;
        newpcb->pcbsock = sock;
@@ -179,11 +177,7 @@ static error_t accept_callback(void *arg, struct tcp_pcb *newpcb, error_t err) {
        STAILQ_INSERT_TAIL(&sockold->acceptq, sock, next);
        // wake up any kthread who is potentially waiting
        spin_unlock_irqsave(&sockold->waiter_lock);
-       struct kthread *kthread = __up_sem(&(sock->accept_sem), FALSE);
-       if (kthread) {
-               send_kernel_message(core_id(), (amr_t)wrap_restart_kthread, (long)kthread, 0, 0,
-                                                                                                 KMSG_ROUTINE);
-       } 
+       sem_up_irqsave(&sock->accept_sem, &irq_state);
        return 0;
 }
 intreg_t sys_listen(struct proc *p, int sockfd, int backlog) {
@@ -233,10 +227,11 @@ intreg_t sys_connect(struct proc *p, int sock_fd, const struct sockaddr* addr, i
        return -1;
 }
 
-intreg_t sys_send(struct proc *p, int sockfd, const void *buf, size_t len, int flags) {
+intreg_t sys_send(struct proc *p, int sockfd, const void *buf, size_t len,
+                  int flags) {
        printk("sys_send called \n");
-       struct socket* sock = getsocket(p_proc, fd);
-       const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr;
+       struct socket* sock = getsocket(p, sockfd);
+       const struct sockaddr_in *in_addr = (const struct sockaddr_in *)buf;
        uint16_t r_port;
        if (sock == NULL) {
                set_errno(EBADF);
@@ -405,6 +400,7 @@ intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t
        struct socket* sock = getsocket(p, socket);     
        int copied = 0;
        int returnval = 0;
+       int8_t irq_state = 0;
        if (sock == NULL) {
                set_errno(EBADF);
                return -1;
@@ -415,12 +411,12 @@ intreg_t sys_recvfrom(struct proc *p, int socket, void *restrict buffer, size_t
                buf = detach_pbuf(ph);
                if (!buf){
                        // about to sleep
-                       sleep_on(&sock->sem);
+                       sem_down_irqsave(&sock->sem, &irq_state);
                        buf = detach_pbuf(ph);
                        // Someone woke me up, there should be data..
                        assert(buf);
                } else {
-                       __down_sem(&sock->sem, NULL);
+                       sem_down_irqsave(&sock->sem, &irq_state);
                }
                        copied = buf->len - sizeof(struct udp_hdr);
                        if (copied > length)
@@ -446,8 +442,9 @@ intreg_t sys_select(struct proc *p, int nfds, fd_set *readfds, fd_set *writefds,
                                fd_set *exceptfds, struct timeval *timeout){
        /* Create a semaphore */
        struct semaphore_entry read_sem; 
+       int8_t irq_state = 0;
 
-       init_sem(&(read_sem.sem), 0);
+       sem_init_irqsave(&(read_sem.sem), 0);
 
        /* insert into the sem list of a fd / socket */
        int low_fd = 0;
@@ -467,7 +464,7 @@ intreg_t sys_select(struct proc *p, int nfds, fd_set *readfds, fd_set *writefds,
                }
        }
        /* At this point wait on the semaphore */
-  sleep_on(&(read_sem.sem));
+       sem_down_irqsave(&read_sem.sem, &irq_state);
        /* someone woke me up, so walk through the list of descriptors and find one that is ready */
        /* remove itself from all the lists that it is waiting on */
        for (int i = low_fd; i<nfds; i++) {