about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--malloc/malloc.c24
-rw-r--r--sysdeps/aarch64/morello/libc-cap.h66
-rw-r--r--sysdeps/generic/libc-cap.h1
3 files changed, 54 insertions, 37 deletions
diff --git a/malloc/malloc.c b/malloc/malloc.c
index ededc5cfe2..3f4d4a2356 100644
--- a/malloc/malloc.c
+++ b/malloc/malloc.c
@@ -571,12 +571,12 @@ cap_narrow (void *p, size_t n)
 /* Used in realloc if p is already narrowed or NULL.
    Must match a previous cap_reserve call.  */
 static __always_inline bool
-cap_narrow_check (void *p, void *oldp)
+cap_narrow_check (void *p, void *oldp, void *narrow_oldp)
 {
   if (cap_narrowing_enabled)
     {
       if (p == NULL)
-	(void) __libc_cap_narrow (oldp, 0);
+	__libc_cap_put_back (oldp, narrow_oldp);
       else
 	__libc_cap_unreserve ();
     }
@@ -586,12 +586,12 @@ cap_narrow_check (void *p, void *oldp)
 /* Used in realloc if p is new allocation or NULL but not yet narrowed.
    Must match a previous cap_reserve call.  */
 static __always_inline void *
-cap_narrow_try (void *p, size_t n, void *oldp)
+cap_narrow_try (void *p, size_t n, void *oldp, void *narrow_oldp)
 {
   if (cap_narrowing_enabled)
     {
       if (p == NULL)
-	(void) __libc_cap_narrow (oldp, 0);
+	__libc_cap_put_back (oldp, narrow_oldp);
       else
 	p = __libc_cap_narrow (p, n);
     }
@@ -3588,8 +3588,9 @@ __libc_free (void *mem)
   if (mem == 0)                              /* free(0) has no effect */
     return;
 
+  void *orig_mem = mem;
   mem = cap_widen (mem);
-  cap_drop (mem);
+  cap_drop (orig_mem);
 
   /* Quickly check that the freed pointer matches the tag for the memory.
      This gives a useful double-free detection.  */
@@ -3652,6 +3653,7 @@ __libc_realloc (void *oldmem, size_t bytes)
   if (oldmem == 0)
     return __libc_malloc (bytes);
 
+  void *orig_oldmem = oldmem;
   oldmem = cap_widen (oldmem);
 
   /* Perform a quick check to ensure that the pointer's tag matches the
@@ -3692,7 +3694,7 @@ __libc_realloc (void *oldmem, size_t bytes)
   /* Every return path below should unreserve using the cap_narrow* apis.  */
   if (!cap_reserve ())
     return NULL;
-  cap_drop (oldmem);
+  cap_drop (orig_oldmem);
 
   if (chunk_is_mmapped (oldp))
     {
@@ -3717,7 +3719,7 @@ __libc_realloc (void *oldmem, size_t bytes)
 	     caller for doing this, so we might want to
 	     reconsider.  */
 	  newmem = tag_new_usable (newmem);
-	  newmem = cap_narrow_try (newmem, bytes, oldmem);
+	  newmem = cap_narrow_try (newmem, bytes, oldmem, orig_oldmem);
 	  return newmem;
 	}
 #endif
@@ -3742,7 +3744,7 @@ __libc_realloc (void *oldmem, size_t bytes)
       else
 #endif
       newmem = __libc_malloc (bytes);
-      if (!cap_narrow_check (newmem, oldmem))
+      if (!cap_narrow_check (newmem, oldmem, orig_oldmem))
         return 0;              /* propagate failure */
 
 #ifdef __CHERI_PURE_CAPABILITY__
@@ -3760,7 +3762,7 @@ __libc_realloc (void *oldmem, size_t bytes)
     {
       /* Use memalign, copy, free.  */
       void *newmem = _mid_memalign (align, bytes, 0);
-      if (!cap_narrow_check (newmem, oldmem))
+      if (!cap_narrow_check (newmem, oldmem, orig_oldmem))
 	return newmem;
       size_t sz = memsize (oldp);
       memcpy (newmem, oldmem, sz < bytes ? sz : bytes);
@@ -3774,7 +3776,7 @@ __libc_realloc (void *oldmem, size_t bytes)
       newp = _int_realloc (ar_ptr, oldp, oldsize, nb);
       assert (!newp || chunk_is_mmapped (mem2chunk (newp)) ||
 	      ar_ptr == arena_for_chunk (mem2chunk (newp)));
-      return cap_narrow_try (newp, bytes, oldmem);
+      return cap_narrow_try (newp, bytes, oldmem, orig_oldmem);
     }
 
   __libc_lock_lock (ar_ptr->mutex);
@@ -3790,7 +3792,7 @@ __libc_realloc (void *oldmem, size_t bytes)
       /* Try harder to allocate memory in other arenas.  */
       LIBC_PROBE (memory_realloc_retry, 2, bytes, oldmem);
       newp = __libc_malloc (bytes);
-      if (!cap_narrow_check (newp, oldmem))
+      if (!cap_narrow_check (newp, oldmem, orig_oldmem))
 	return NULL;
       size_t sz = memsize (oldp);
       memcpy (newp, oldmem, sz);
diff --git a/sysdeps/aarch64/morello/libc-cap.h b/sysdeps/aarch64/morello/libc-cap.h
index 9e6f66aa22..84c20e9df8 100644
--- a/sysdeps/aarch64/morello/libc-cap.h
+++ b/sysdeps/aarch64/morello/libc-cap.h
@@ -30,8 +30,7 @@
 
 struct htentry
 {
-  uint64_t key;
-  uint64_t unused;
+  void *key;
   void *value;
 };
 
@@ -48,19 +47,19 @@ struct ht
 static inline bool
 htentry_isempty (struct htentry *e)
 {
-  return e->key == 0;
+  return (uint64_t) e->key == 0;
 }
 
 static inline bool
 htentry_isdeleted (struct htentry *e)
 {
-  return e->key == -1;
+  return (uint64_t) e->key == -1;
 }
 
 static inline bool
 htentry_isused (struct htentry *e)
 {
-  return e->key != 0 && e->key != -1;
+  return !htentry_isempty (e) && !htentry_isdeleted (e);
 }
 
 static inline uint64_t
@@ -154,9 +153,10 @@ ht_resize (struct ht *ht)
     {
       if (htentry_isused (e))
 	{
-	  uint64_t hash = ht_key_hash (e->key);
+	  uint64_t k = (uint64_t) e->key;
+	  uint64_t hash = ht_key_hash (k);
 	  used--;
-	  *ht_lookup (ht, e->key, hash) = *e;
+	  *ht_lookup (ht, k, hash) = *e;
 	}
     }
   ht_tab_free (oldtab, oldlen);
@@ -191,48 +191,61 @@ ht_unreserve (struct ht *ht)
 }
 
 static bool
-ht_add (struct ht *ht, uint64_t key, void *value)
+ht_add (struct ht *ht, void *key, void *value)
 {
+  uint64_t k = (uint64_t) key;
+  uint64_t hash = ht_key_hash (k);
+  assert (k != 0 && k != -1);
+
   __libc_lock_lock (ht->mutex);
   assert (ht->reserve > 0);
   ht->reserve--;
-  uint64_t hash = ht_key_hash (key);
-  struct htentry *e = ht_lookup (ht, key, hash);
+  struct htentry *e = ht_lookup (ht, k, hash);
   bool r = false;
   if (!htentry_isused (e))
     {
       if (htentry_isempty (e))
         ht->fill++;
       ht->used++;
-      e->key = key;
       r = true;
     }
+  e->key = key;
   e->value = value;
   __libc_lock_unlock (ht->mutex);
   return r;
 }
 
 static bool
-ht_del (struct ht *ht, uint64_t key)
+ht_del (struct ht *ht, void *key)
 {
+  uint64_t k = (uint64_t) key;
+  uint64_t hash = ht_key_hash (k);
+  assert (k != 0 && k != -1);
+
   __libc_lock_lock (ht->mutex);
-  struct htentry *e = ht_lookup (ht, key, ht_key_hash (key));
+  struct htentry *e = ht_lookup (ht, k, hash);
   bool r = htentry_isused (e);
   if (r)
     {
+      r = __builtin_cheri_equal_exact(e->key, key);
       ht->used--;
-      e->key = -1;
+      e->key = (void *) -1;
+      e->value = NULL;
     }
   __libc_lock_unlock (ht->mutex);
   return r;
 }
 
 static void *
-ht_get (struct ht *ht, uint64_t key)
+ht_get (struct ht *ht, void *key)
 {
+  uint64_t k = (uint64_t) key;
+  uint64_t hash = ht_key_hash (k);
+  assert (k != 0 && k != -1);
+
   __libc_lock_lock (ht->mutex);
-  struct htentry *e = ht_lookup (ht, key, ht_key_hash (key));
-  void *v = htentry_isused (e) ? e->value : NULL;
+  struct htentry *e = ht_lookup (ht, k, hash);
+  void *v = __builtin_cheri_equal_exact(e->key, key) ? e->value : NULL;
   __libc_lock_unlock (ht->mutex);
   return v;
 }
@@ -317,10 +330,9 @@ __libc_cap_align (size_t n)
 static __always_inline void *
 __libc_cap_narrow (void *p, size_t n)
 {
-  assert (p != NULL);
-  uint64_t key = (uint64_t)(uintptr_t) p;
-  assert (ht_add (&__libc_cap_ht, key, p));
   void *narrow = __builtin_cheri_bounds_set_exact (p, n);
+  assert (__builtin_cheri_tag_get (narrow));
+  assert (ht_add (&__libc_cap_ht, narrow, p));
   return narrow;
 }
 
@@ -329,9 +341,7 @@ __libc_cap_narrow (void *p, size_t n)
 static __always_inline void *
 __libc_cap_widen (void *p)
 {
-  assert (__builtin_cheri_tag_get (p) && __builtin_cheri_offset_get (p) == 0);
-  uint64_t key = (uint64_t)(uintptr_t) p;
-  void *cap = ht_get (&__libc_cap_ht, key);
+  void *cap = ht_get (&__libc_cap_ht, p);
   assert (cap == p);
   return cap;
 }
@@ -351,9 +361,13 @@ __libc_cap_unreserve (void)
 static __always_inline void
 __libc_cap_drop (void *p)
 {
-  assert (p != NULL);
-  uint64_t key = (uint64_t)(uintptr_t) p;
-  assert (ht_del (&__libc_cap_ht, key));
+  assert (ht_del (&__libc_cap_ht, p));
+}
+
+static __always_inline void
+__libc_cap_put_back (void *p, void *narrow)
+{
+  assert (ht_add (&__libc_cap_ht, narrow, p));
 }
 
 #endif
diff --git a/sysdeps/generic/libc-cap.h b/sysdeps/generic/libc-cap.h
index 9d93d61c9e..4a385d823b 100644
--- a/sysdeps/generic/libc-cap.h
+++ b/sysdeps/generic/libc-cap.h
@@ -39,5 +39,6 @@ void __libc_cap_link_error (void);
 #define __libc_cap_reserve(p) __libc_cap_fail (bool)
 #define __libc_cap_unreserve(p) __libc_cap_fail (void)
 #define __libc_cap_drop(p) __libc_cap_fail (void)
+#define __libc_cap_put_back(p, q) __libc_cap_fail (void)
 
 #endif