big speedups for memcpy/memset
[akaros.git] / kern / src / string.c
index 8cde851..4e3db1e 100644 (file)
@@ -4,6 +4,7 @@
 #pragma nosharc
 #endif
 
+#include <stdio.h>
 #include <string.h>
 #include <ros/memlayout.h>
 #include <assert.h>
@@ -129,87 +130,62 @@ strfind(const char *s, char c)
        return (char *) s;
 }
 
-// n must be a multiple of 16 and v must be uint32_t-aligned
+// memset aligned words.
 static inline void *
-memset16(uint32_t *COUNT(n/sizeof(uint32_t)) _v, uint32_t c, size_t n)
+memsetw(long* _v, long c, size_t n)
 {
-       uint32_t *start, *end;
-       uint32_t *BND(_v, end) v;
+       long *start, *end, *v;
 
        start = _v;
-       end = _v + n/sizeof(uint32_t);
+       end = _v + n/sizeof(long);
        v = _v;
-       c = c | c<<8 | c<<16 | c<<24;
-
-       if(n >= 64 && ((uintptr_t)v) % 8 == 0)
-       {
-               uint64_t* v64 = (uint64_t*)v;
-               uint64_t c64 = c | ((uint64_t)c)<<32;
-               while(v64 < (uint64_t*)end-7)
-               {
-                       v64[3] = v64[2] = v64[1] = v64[0] = c64;
-                       v64[7] = v64[6] = v64[5] = v64[4] = c64;
-                       v64 += 8;
-               }
-               v = (uint32_t*)v64;
-       }
-
-       while(v < end)
+       c = (char)c;
+       c = c | c<<8;
+       c = c | c<<16;
+       #if NUM_ADDR_BITS == 64
+       c = c | c<<32;
+       #elif NUM_ADDR_BITS != 32
+       # error
+       #endif
+
+       while(v < end - (8-1))
        {
                v[3] = v[2] = v[1] = v[0] = c;
                v += 4;
+               v[3] = v[2] = v[1] = v[0] = c;
+               v += 4;
        }
 
-       return start;
-}
-
-// n must be a multiple of 16 and v must be 4-byte aligned.
-// as allowed by ISO, behavior undefined if dst/src overlap
-static inline void *
-memcpy16(uint32_t *COUNT(n/sizeof(uint32_t)) _dst,
-         const uint32_t *COUNT(n/sizeof(uint32_t)) _src, size_t n)
-{
-       uint32_t *dststart, *SNT dstend, *SNT srcend;
-       uint32_t *BND(_dst,dstend) dst;
-       const uint32_t *BND(_src,srcend) src;
-
-       dststart = _dst;
-       dstend = (uint32_t *SNT)(_dst + n/sizeof(uint32_t));
-       srcend = (uint32_t *SNT)(_src + n/sizeof(uint32_t));
-       dst = _dst;
-       src = _src;
-
-       while(dst < dstend && src < srcend)
-       {
-               dst[0] = src[0];
-               dst[1] = src[1];
-               dst[2] = src[2];
-               dst[3] = src[3];
-
-               src += 4;
-               dst += 4;
-       }
+       while(v < end)
+         *v++ = c;
 
-       return dststart;
+       return start;
 }
 
-void *
-pagecopy(void* d, void* s)
-{
-       static_assert(PGSIZE % 64 == 0);
-       for(int i = 0; i < PGSIZE; i += 64)
-       {
-               *((uint64_t*)(d+i+0)) = *((uint64_t*)(s+i+0));
-               *((uint64_t*)(d+i+8)) = *((uint64_t*)(s+i+8));
-               *((uint64_t*)(d+i+16)) = *((uint64_t*)(s+i+16));
-               *((uint64_t*)(d+i+24)) = *((uint64_t*)(s+i+24));
-               *((uint64_t*)(d+i+32)) = *((uint64_t*)(s+i+32));
-               *((uint64_t*)(d+i+40)) = *((uint64_t*)(s+i+40));
-               *((uint64_t*)(d+i+48)) = *((uint64_t*)(s+i+48));
-               *((uint64_t*)(d+i+56)) = *((uint64_t*)(s+i+56));
-       }
-       return d;
-}
+// copy aligned words.
+// unroll 9 ways to get multiple misses in flight
+#define memcpyw(type, _dst, _src, n) \
+  do { \
+       type* restrict src = (type*)(_src); \
+       type* restrict dst = (type*)(_dst); \
+       type* srcend = src + (n)/sizeof(type); \
+       type* dstend = dst + (n)/sizeof(type); \
+       while (dst < dstend - (9-1)) { \
+               dst[0] = src[0]; \
+               dst[1] = src[1]; \
+               dst[2] = src[2]; \
+               dst[3] = src[3]; \
+               dst[4] = src[4]; \
+               dst[5] = src[5]; \
+               dst[6] = src[6]; \
+               dst[7] = src[7]; \
+               dst[8] = src[8]; \
+               src += 9; \
+               dst += 9; \
+       } \
+       while(dst < dstend) \
+         *dst++ = *src++; \
+  } while(0)
 
 void *
 memset(void *COUNT(_n) v, int c, size_t _n)
@@ -222,16 +198,16 @@ memset(void *COUNT(_n) v, int c, size_t _n)
 
        p = v;
 
-    while (n > 0 && ((uintptr_t)p & 7))
+    while (n > 0 && ((uintptr_t)p & (sizeof(long)-1)))
        {
                *p++ = c;
                n--;
        }
 
-       if(n >= 16 && ((uintptr_t)p & 3) == 0)
+       if (n >= sizeof(long))
        {
-               n0 = (n/16)*16;
-               memset16((uint32_t*COUNT(n0/sizeof(uint32_t)))p,c,n0);
+               n0 = n / sizeof(long) * sizeof(long);
+               memsetw((long*)p, c, n0);
                n -= n0;
                p += n0;
        }
@@ -246,26 +222,37 @@ memset(void *COUNT(_n) v, int c, size_t _n)
 }
 
 void *
-(DMEMCPY(1,2,3) memcpy)(void *COUNT(_n) dst, const void *COUNT(_n) src, size_t _n)
+memcpy(void* dst, const void* src, size_t _n)
 {
-       const char *BND(src,src+_n) s;
-       char *BND(dst,dst+_n) d;
-       size_t n0;
+       const char* s;
+       char* d;
+       size_t n0 = 0;
        size_t n = _n;
+       int align = sizeof(long)-1;
 
        s = src;
        d = dst;
 
-       if(n >= 16 && ((uintptr_t)src  & 3) == 0 && ((uintptr_t)dst & 3) == 0)
+       if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(long)-1)) == 0)
        {
-               n0 = (n/16)*16;
-               memcpy16((uint32_t*COUNT(n0/sizeof(uint32_t)))dst,
-                 (const uint32_t*COUNT(n0/sizeof(uint32_t)))src,n0);
-               n -= n0;
-               s += n0;
-               d += n0;
+               n0 = n / sizeof(long) * sizeof(long);
+               memcpyw(long, d, s, n0);
+       }
+       else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(int)-1)) == 0)
+       {
+               n0 = n / sizeof(int) * sizeof(int);
+               memcpyw(int, d, s, n0);
+       }
+       else if ((((uintptr_t)s | (uintptr_t)d) & (sizeof(short)-1)) == 0)
+       {
+               n0 = n / sizeof(short) * sizeof(short);
+               memcpyw(short, d, s, n0);
        }
 
+       n -= n0;
+       s += n0;
+       d += n0;
+
        while (n-- > 0)
                *d++ = *s++;