Adding select support for basic socket udp receive.
[akaros.git] / kern / src / net / udp.c
1 /**
2  * Contains shamelessly stolen code from BSD & lwip, both have
3  * BSD-style licenses
4  *
5  */
6 #include <ros/common.h>
7 #include <string.h>
8 #include <kmalloc.h>
9 #include <socket.h>
10 #include <net.h>
11 #include <sys/queue.h>
12 #include <atomic.h>
13
14 #include <bits/netinet.h>
15 #include <net/ip.h>
16 #include <net/udp.h>
17 #include <slab.h>
18 #include <socket.h>
19
20 struct udp_pcb *udp_pcbs;
21 uint16_t udp_port_num = SOCKET_PORT_START;
22
23 struct udp_pcb* udp_new(void){
24         struct udp_pcb *pcb = kmem_cache_alloc(udp_pcb_kcache, 0);
25     // if pcb is only tracking ttl, then no need!
26         if (pcb!= NULL){
27                 pcb->ttl = UDP_TTL;
28         memset(pcb, 0, sizeof(struct udp_pcb));
29         }
30         return pcb;
31 }
32
33 int udp_send(struct udp_pcb *pcb, struct pbuf *p)
34 {
35   /* send to the packet using remote ip and port stored in the pcb */
36   // rip and rport should be in socket not pcb?
37   return udp_sendto(pcb, p, &pcb->remote_ip, pcb->remote_port);
38 }
39
40 typedef unsigned char u16;
41 typedef unsigned long u32;
42
43 u16 udp_sum_calc(u16 len_udp, u16 src_addr[],u16 dest_addr[],  int padding, u16 buff[])
44 {
45 u16 prot_udp=17;
46 u16 padd=0;
47 u16 word16;
48 u32 sum;
49 int i;
50         
51         // Find out if the length of data is even or odd number. If odd,
52         // add a padding byte = 0 at the end of packet
53         if ((padding&1)==1){
54                 padd=1;
55                 buff[len_udp]=0;
56         }
57         
58         //initialize sum to zero
59         sum=0;
60         
61         // make 16 bit words out of every two adjacent 8 bit words and 
62         // calculate the sum of all 16 vit words
63         for (i=0;i<len_udp+padd;i=i+2){
64                 word16 =((buff[i]<<8)&0xFF00)+(buff[i+1]&0xFF);
65                 sum = sum + (unsigned long)word16;
66         }       
67         // add the UDP pseudo header which contains the IP source and destinationn addresses
68         for (i=0;i<4;i=i+2){
69                 word16 =((src_addr[i]<<8)&0xFF00)+(src_addr[i+1]&0xFF);
70                 sum=sum+word16; 
71         }
72         for (i=0;i<4;i=i+2){
73                 word16 =((dest_addr[i]<<8)&0xFF00)+(dest_addr[i+1]&0xFF);
74                 sum=sum+word16;         
75         }
76         // the protocol number and the length of the UDP packet
77         sum = sum + prot_udp + len_udp;
78
79         // keep only the last 16 bits of the 32 bit calculated sum and add the carries
80         while (sum>>16)
81                 sum = (sum & 0xFFFF)+(sum >> 16);
82                 
83         // Take the one's complement of sum
84         sum = ~sum;
85
86 return ((u16) sum);
87 }
88
89 int udp_sendto(struct udp_pcb *pcb, struct pbuf *p,
90                     struct in_addr *dst_ip, uint16_t dst_port){
91     // we now have one netif to send to, otherwise we need to route
92     // ip_route();
93     struct udp_hdr *udphdr;
94     struct pbuf *q;
95                 printd("udp_sendto ip %x, port %d\n", dst_ip->s_addr, dst_port); 
96     // broadcast?
97     if (pcb->local_port == 0) {
98                                 /* if the PCB not bound to a port, bind it and give local ip */
99         if (udp_bind(pcb, &pcb->local_ip, pcb->local_port)!=0)
100                                         warn("udp binding failed \n");
101     }
102     if (pbuf_header(p, UDP_HLEN)){ // we could probably save this check for block.
103         // CHECK: should allocate enough for the other headers too
104         q = pbuf_alloc(PBUF_IP, UDP_HLEN, PBUF_RAM);
105         if (q == NULL)
106            panic("out of memory");
107         // if the original packet is not empty
108         if (p->tot_len !=0) {
109             pbuf_chain(q,p);
110             // check if it is chained properly ..
111         }
112     } else {
113                                 /* Successfully padded the header*/
114                                 q = p;
115     }
116
117     udphdr = (struct udp_hdr *) q->payload;
118                 printd("src port %d, dst port %d \n, length %d ", pcb->local_port, ntohs(dst_port), q->tot_len);
119     udphdr->src_port = htons(pcb->local_port);
120     udphdr->dst_port = (dst_port);
121     udphdr->length = htons(q->tot_len); 
122                 udphdr->checksum = 0; // just to be sure.
123                 // printd("checksum inet_chksum %x \n", udphdr->checksum);
124                 printd("params src addr %x, dst addr %x, length %x \n", global_ip.s_addr, (dst_ip->s_addr), 
125                                           q->tot_len);
126
127     udphdr->checksum = inet_chksum_pseudo(q, htonl(global_ip.s_addr), dst_ip->s_addr,
128                                                                                          IPPROTO_UDP, q->tot_len);
129                 printd ("method ours %x\n", udphdr->checksum);
130                 // 0x0000; //either use brho's checksum or use cards' capabilities
131                 // ip_output(q, src_ip, dst_ip, pcb->ttl, pcb->tos, IP_PROTO_UDP);
132                 ip_output(q, &global_ip, dst_ip, IPPROTO_UDP);
133     return 0;
134 }
135 /* TODO: use the real queues we have implemented... */
136 int udp_bind(struct udp_pcb *pcb, struct in_addr *ip, uint16_t port){ 
137     int rebind = pcb->local_port;
138     struct udp_pcb *ipcb;
139                 assert(pcb);
140                 /* trying to assign port */
141     if (port != 0)
142         pcb->local_port = port;
143
144     /* no lock needed since we are just traversing/reading */
145     /* Check for double bind and rebind of the same pcb */
146     for (ipcb = udp_pcbs; ipcb != NULL; ipcb = ipcb->next) {
147         /* is this UDP PCB already on active list? */
148         if (pcb == ipcb) {
149             rebind = 1; //already on the list
150         } else if (ipcb->local_port == port){
151             warn("someone else is using the port %d\n" , port); 
152             return -1;
153         }
154     }
155     /* accept data for all interfaces */
156     if (ip == NULL || (ip->s_addr == INADDR_ANY.s_addr))
157                 /* true right now */
158         pcb->local_ip = INADDR_ANY;
159     /* assign a port */
160     if (port == 0) {
161         port = SOCKET_PORT_START; 
162         ipcb = udp_pcbs;
163         while ((ipcb != NULL) && (port != SOCKET_PORT_END)) {
164             if (ipcb->local_port == port) {
165                 /* port is already used by another udp_pcb */
166                 port++;
167                 /* restart scanning all udp pcbs */
168                 ipcb = udp_pcbs;
169             } else {
170                 /* go on with next udp pcb */
171                 ipcb = ipcb->next;
172             }
173         }
174         if (ipcb != NULL){
175             warn("No more udp ports available!");
176         }
177     }
178     if (rebind == 0) {
179         /* place the PCB on the active list if not already there */
180                                 pcb->next = udp_pcbs;
181                                 udp_pcbs = pcb;
182     }
183                 printk("local port bound to 0x%x \n", port);
184     pcb->local_port = port;
185     return 0;
186 }
187
188 /* port are in host order, ips are in network order */
189 /* Think: a pcb is here, if someone is waiting for a connection or the udp conn
190  * has been established */
191 static struct udp_pcb* find_pcb(struct udp_pcb* list, uint16_t src_port, uint16_t dst_port,
192                                                                 uint16_t srcip, uint16_t dstip) {
193         struct udp_pcb* uncon_pcb = NULL;
194         struct udp_pcb* pcb = NULL;
195         uint8_t local_match = 0;
196
197         for (pcb = list; pcb != NULL; pcb = pcb->next) {
198                 local_match = 0;
199                 if ((pcb->local_port == dst_port) 
200                         && (pcb->local_ip.s_addr == dstip 
201                         || ip_addr_isany(&pcb->local_ip))){
202                                 local_match = 1;
203         if ((uncon_pcb == NULL) && 
204             ((pcb->flags & UDP_FLAGS_CONNECTED) == 0)) {
205           /* the first unconnected matching PCB */
206           uncon_pcb = pcb;
207         }
208                 }
209
210                 if (local_match && (pcb->remote_port == src_port) &&
211                                 (ip_addr_isany(&pcb->remote_ip) ||
212                                  pcb->remote_ip.s_addr == srcip))
213                         /* perfect match */
214                         return pcb;
215         }
216         return uncon_pcb;
217 }
218
219 #if 0 // not working yet
220 // need to have pbuf queue support
221 int udp_attach(struct pbuf *p, struct sock *socket) {
222         // pretend the attaching of packet is succesful
223         /*
224         recv_q->last->next = p;
225         recv_q->last=p->last
226         */ 
227 }
228
229 #endif 
230
231 /** Process an incoming UDP datagram. 
232  * Given an incoming UDP datagram, this function finds the right PCB
233  * which links to the right socket buffer, and attaches the datagram
234  * to the right socket. 
235  * If no appropriate PCB is found, the pbuf is freed.
236  */ 
237
238 /** TODO: think about combining udp_input and ip_input together */
239 // TODO: figure out if we even need a PCB? or just socket buff. 
240 // TODO: test out looking up pcbs.. since matching function may fail
241
242 void wrap_restart_kthread(struct trapframe *tf, uint32_t srcid,
243                                         long a0, long a1, long a2){
244         restart_kthread((struct kthread*) a0);
245 }
246
247 int udp_input(struct pbuf *p){
248         struct udp_hdr *udphdr;
249
250         struct udp_pcb *pcb, uncon_pcb;
251         struct ip_hdr *iphdr;
252         uint16_t src, dst;
253         bool local_match = 0;
254         iphdr = (struct ip_hdr *)p->payload;
255         /* Move the header to where the udp header is */
256         if (pbuf_header(p, - PBUF_IP_HLEN)){
257                 warn("udp_input: Did not find a matching PCB for a udp packet\n");
258                 pbuf_free(p);
259                 return -1;
260         }
261         printk("start of udp %p\n", p->payload);
262         udphdr = (struct udp_hdr *)p->payload;
263         /* convert the src port and dst port to host order */
264         src = ntohs(udphdr->src_port);
265         dst = ntohs(udphdr->dst_port);
266         pcb = find_pcb(udp_pcbs, src, dst, iphdr->src_addr, iphdr->dst_addr);
267         /* TODO: Possibly adjust the pcb to the head of the queue? */
268         /* TODO: Linux uses a set of hashtables to lookup PCBs 
269          * Look at __udp4_lib_lookup function in Linux kernel 2.6.21.1
270          */
271         /* Anything that is not directed at this pcb should have been dropped */
272         if (pcb == NULL){
273                 warn("udp_input: Did not find a matching PCB for a udp packet\n");
274                 pbuf_free(p);
275                 return -1;
276         }
277
278         /* checksum check */
279   if (udphdr->checksum != 0) {
280     if (inet_chksum_pseudo(p, (iphdr->src_addr), (iphdr->dst_addr), 
281                                  IPPROTO_UDP, p->tot_len) != 0){
282                         warn("udp_input: UPD datagram discarded due to failed chksum!");
283                         pbuf_free(p);
284                         return -1;
285     }
286         }
287   /* ignore SO_REUSE */
288         if (pcb != NULL && pcb->pcbsock != NULL){
289                 /* For each in the pbuf chain, disconnect from the chain and add it to the
290                  * recv_buff of the correct socket 
291                  */ 
292                 struct socket *sock = pcb->pcbsock;
293                 attach_pbuf(p, &sock->recv_buff);
294                 struct kthread *kthread;
295                 /* First notify any blocking recv calls,
296                  * then notify anyone who might be waiting in a select
297                  */ 
298                 // multiple people might be waiting on the socket here..
299                 kthread = __up_sem(&(sock->sem), FALSE);
300                 if (kthread) {
301                          send_kernel_message(core_id(), (amr_t)wrap_restart_kthread, (long)kthread, 0, 0,
302                                                                                                   KMSG_ROUTINE);
303                 } else {
304                         // wake up all waiters
305                         struct semaphore_entry *sentry, *sentry_tmp;
306                         spin_lock(&sock->waiter_lock);
307                   LIST_FOREACH_SAFE(sentry, &(sock->waiters), link, sentry_tmp){
308                                 //should only wake up one waiter
309                                 kthread = __up_sem(&sentry->sem, true);
310                                 if (kthread){
311                                 send_kernel_message(core_id(), (amr_t)wrap_restart_kthread, (long)kthread, 0, 0,
312                                                                                                   KMSG_ROUTINE);
313                                 }
314                                 LIST_REMOVE(sentry, link);
315                                 /* do not need to free since all the sentry are stack-based vars */
316                         }
317                         spin_unlock(&sock->waiter_lock);
318                 }
319                 // the attaching of pbuf should have increfed pbuf ref, so free is simply a decref
320                 pbuf_free(p);
321         }
322         return 0;
323 }