about summary refs log tree commit diff
path: root/src/libunixonacid/unixmessage_receive.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libunixonacid/unixmessage_receive.c')
-rw-r--r--src/libunixonacid/unixmessage_receive.c56
1 files changed, 39 insertions, 17 deletions
diff --git a/src/libunixonacid/unixmessage_receive.c b/src/libunixonacid/unixmessage_receive.c
index 21491fa..5fa16c4 100644
--- a/src/libunixonacid/unixmessage_receive.c
+++ b/src/libunixonacid/unixmessage_receive.c
@@ -5,7 +5,8 @@
 #include <errno.h>
 #include <sys/socket.h>
 #include <sys/uio.h>
-#include <skalibs/uint.h>
+#include <skalibs/uint16.h>
+#include <skalibs/uint32.h>
 #include <skalibs/cbuffer.h>
 #include <skalibs/djbunix.h>
 #include <skalibs/error.h>
@@ -45,12 +46,12 @@ static int unixmessage_receiver_fill (unixmessage_receiver_t *b)
     .msg_iov = iov,
     .msg_iovlen = 2,
     .msg_flags = 0,
-    .msg_control = ancilbuf,
-    .msg_controllen = sizeof(ancilbuf)
+    .msg_control = b->fds_ok & 1 ? ancilbuf : 0,
+    .msg_controllen = b->fds_ok & 1 ? sizeof(ancilbuf) : 0
   } ;
   unsigned int auxlen ;
   int r = -1 ;
-  if (cbuffer_isfull(&b->mainb) || cbuffer_isfull(&b->auxb))
+  if (cbuffer_isfull(&b->mainb) || ((b->fds_ok & 1) && cbuffer_isfull(&b->auxb)))
     return (errno = ENOBUFS, -1) ;
   {
     siovec_t v[2] ;
@@ -62,6 +63,7 @@ static int unixmessage_receiver_fill (unixmessage_receiver_t *b)
     r = recvmsg(b->fd, &msghdr, awesomeflags) ;
     if (!r || (r < 0 && errno != EINTR)) return r ;
   }
+  if (b->fds_ok & 1)
   {
     struct cmsghdr *c = CMSG_FIRSTHDR(&msghdr) ;
     if (c)
@@ -69,16 +71,32 @@ static int unixmessage_receiver_fill (unixmessage_receiver_t *b)
       if (c->cmsg_level != SOL_SOCKET
        || c->cmsg_type != SCM_RIGHTS) return (errno = EPROTO, -1) ;
       auxlen = (unsigned int)(c->cmsg_len - (CMSG_DATA(c) - (unsigned char *)c)) ;
+      if (auxlen && !(b->fds_ok & 2))
+      {
+        register unsigned int i = auxlen/sizeof(int) ;
+        while (i--) fd_close(((int *)CMSG_DATA(c))[i]) ;
+        return (errno = EPROTO, -1) ;
+      }
 #ifndef SKALIBS_HASCMSGCLOEXEC
       {
         register unsigned int i = 0 ;
         for (; i < auxlen/sizeof(int) ; i++)
-          if (coe(((int *)CMSG_DATA(c))[i]) < 0) return -1 ;
+          if (coe(((int *)CMSG_DATA(c))[i]) < 0)
+          {
+            int e = errno ;
+            i++ ;
+            while (i--) fd_close(((int *)CMSG_DATA(c))[i]) ;
+            errno = e ;
+            return -1 ;
+          }
       }
 #endif
-      if (msghdr.msg_flags & MSG_CTRUNC) return (errno = EPROTO, -1) ;
-      if (cbuffer_put(&b->auxb, (char *)CMSG_DATA(c), auxlen) < auxlen)
+      if ((msghdr.msg_flags & MSG_CTRUNC) || cbuffer_put(&b->auxb, (char *)CMSG_DATA(c), auxlen) < auxlen)
+      {
+        register unsigned int i = auxlen/sizeof(int) ;
+        while (i--) fd_close(((int *)CMSG_DATA(c))[i]) ;
         return (errno = ENOBUFS, -1) ;
+      }
     }
   }
   cbuffer_WSEEK(&b->mainb, r) ;
@@ -89,21 +107,25 @@ int unixmessage_receive (unixmessage_receiver_t *b, unixmessage_t *m)
 {
   if (b->maindata.len == b->mainlen && b->auxdata.len == b->auxlen)
   {
-    char pack[sizeof(unsigned int) << 1] ;
-    if (cbuffer_len(&b->mainb) < sizeof(unsigned int) << 1)
+    char pack[6] ;
+    if (cbuffer_len(&b->mainb) < 6)
     {
       register int r = sanitize_read(unixmessage_receiver_fill(b)) ;
       if (r <= 0) return r ;
-      if (cbuffer_len(&b->mainb) < sizeof(unsigned int) << 1)
-        return (errno = EWOULDBLOCK, 0) ;
+      if (cbuffer_len(&b->mainb) < 6) return (errno = EWOULDBLOCK, 0) ;
     }
-    cbuffer_get(&b->mainb, pack, sizeof(unsigned int) << 1) ;
-    uint_unpack_big(pack, &b->mainlen) ;
-    uint_unpack_big(pack + sizeof(unsigned int), &b->auxlen) ;
+    cbuffer_get(&b->mainb, pack, 6) ;
+    uint32_unpack_big(pack, &b->mainlen) ;
+    if (b->fds_ok & 1) uint16_unpack_big(pack + 4, &b->auxlen) ;
+    else b->auxlen = 0 ;
     b->auxlen *= sizeof(int) ;
-    if (!stralloc_ready(&b->maindata, b->mainlen)) return -1 ;
+    if (b->mainlen > UNIXMESSAGE_MAXSIZE
+     || b->auxlen > ((b->fds_ok & 2) ? UNIXMESSAGE_MAXFDS * sizeof(int) : 0))
+      return (errno = EPROTO, -1) ;
+    if (!stralloc_ready(&b->maindata, b->mainlen)
+     || !stralloc_ready(&b->auxdata, b->auxlen))
+      return -1 ;
     b->maindata.len = 0 ;
-    if (!stralloc_ready(&b->auxdata, b->auxlen)) return -1 ;
     b->auxdata.len = 0 ;
   }
 
@@ -124,6 +146,6 @@ int unixmessage_receive (unixmessage_receiver_t *b, unixmessage_t *m)
   m->s = b->maindata.s ;
   m->len = b->maindata.len ;
   m->fds = (int *)b->auxdata.s ;
-  m->nfds = b->auxlen / sizeof(int) ;
+  m->nfds = b->auxdata.len / sizeof(int) ;
   return 1 ;
 }