about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/thread/sem_post.c9
-rw-r--r--src/thread/sem_timedwait.c36
-rw-r--r--src/thread/sem_trywait.c5
3 files changed, 23 insertions, 27 deletions
diff --git a/src/thread/sem_post.c b/src/thread/sem_post.c
index 8f4700c3..148ab780 100644
--- a/src/thread/sem_post.c
+++ b/src/thread/sem_post.c
@@ -3,8 +3,11 @@
 
 int sem_post(sem_t *sem)
 {
-	a_inc(sem->__val);
-	if (sem->__val[1])
-		__wake(sem->__val, 1, 0);
+	int val, waiters;
+	do {
+		val = sem->__val[0];
+		waiters = sem->__val[1];
+	} while (a_cas(sem->__val, val, val+1+(val<0)) != val);
+	if (val<0 || waiters) __wake(sem->__val, 1, 0);
 	return 0;
 }
diff --git a/src/thread/sem_timedwait.c b/src/thread/sem_timedwait.c
index db05b417..1d4b3e2c 100644
--- a/src/thread/sem_timedwait.c
+++ b/src/thread/sem_timedwait.c
@@ -8,31 +8,21 @@ static void cleanup(void *p)
 
 int sem_timedwait(sem_t *sem, const struct timespec *at)
 {
-	int r;
-
-	if (a_fetch_add(sem->__val, -1) > 0) return 0;
-	a_inc(sem->__val);
-
-	if (at && at->tv_nsec >= 1000000000UL) {
-		errno = EINVAL;
-		return -1;
-	}
-
-	a_inc(sem->__val+1);
-	pthread_cleanup_push(cleanup, sem->__val+1)
-
-	for (;;) {
-		r = 0;
-		if (!sem_trywait(sem)) break;
-		r = __timedwait_cp(sem->__val, 0, CLOCK_REALTIME, at, 0);
+	while (sem_trywait(sem)) {
+		int r;
+		if (at && at->tv_nsec >= 1000000000UL) {
+			errno = EINVAL;
+			return -1;
+		}
+		a_inc(sem->__val+1);
+		a_cas(sem->__val, 0, -1);
+		pthread_cleanup_push(cleanup, sem->__val+1);
+		r = __timedwait_cp(sem->__val, -1, CLOCK_REALTIME, at, 0);
+		pthread_cleanup_pop(1);
 		if (r) {
 			errno = r;
-			r = -1;
-			break;
+			return -1;
 		}
 	}
-
-	pthread_cleanup_pop(1);
-
-	return r;
+	return 0;
 }
diff --git a/src/thread/sem_trywait.c b/src/thread/sem_trywait.c
index dd8f57e3..55d90075 100644
--- a/src/thread/sem_trywait.c
+++ b/src/thread/sem_trywait.c
@@ -4,7 +4,10 @@
 int sem_trywait(sem_t *sem)
 {
 	int val = sem->__val[0];
-	if (val>0 && a_cas(sem->__val, val, val-1)==val) return 0;
+	if (val>0) {
+		int new = val-1-(val==1 && sem->__val[1]);
+		if (a_cas(sem->__val, val, new)==val) return 0;
+	}
 	errno = EAGAIN;
 	return -1;
 }