summary refs log tree commit diff
diff options
context:
space:
mode:
authorLeah Neukirchen <leah@vuxu.org>2022-11-27 16:27:10 +0100
committerLeah Neukirchen <leah@vuxu.org>2022-11-27 16:27:10 +0100
commitc53cde4625aba15a5f6f660e11fe91c61f5b6e24 (patch)
tree68a6d0a8bbc015ef097b5854c38bc0581810da47
parentc5d3b30866dff0bbe3a2891ea4fcc3cd2d075d74 (diff)
downloadmew-c53cde4625aba15a5f6f660e11fe91c61f5b6e24.tar.gz
mew-c53cde4625aba15a5f6f660e11fe91c61f5b6e24.tar.xz
mew-c53cde4625aba15a5f6f660e11fe91c61f5b6e24.zip
extend sample to sample multiple elements without replacement
-rw-r--r--mew.scm58
-rw-r--r--mew.svnwiki8
-rw-r--r--tests/test.mew6
3 files changed, 60 insertions, 12 deletions
diff --git a/mew.scm b/mew.scm
index 22a8752..8507634 100644
--- a/mew.scm
+++ b/mew.scm
@@ -371,6 +371,13 @@
         (and (hash-table? o)
              (zero? (hash-table-size o)))))
 
+  (define (empty o)
+    (cond ((list? o) '())
+          ((string? o) "")
+          ((vector? o) #())
+          ((hash-table? o) (tbl))
+          (else "no empty defined")))
+
   (define (len o)
     (cond ((list? o) (length o))
           ((string? o) (string-length o))
@@ -426,17 +433,46 @@
             (vector-set! res i (vector-ref res j)))
           (vector-set! res j (vector-ref v i))))))
 
-  (define (sample o)
-    (if (hash-table? o)
-      (esc ret
-        (let ((n (rand (hash-table-size o)))
-              (i 0))
-          (hash-table-for-each o
-                               (lambda (k v)
-                                 (if (= i n)
-                                   (ret (cons k v))
-                                   (set! i (inc i)))))))
-      (get o (rand (len o)))))
+  (define sample
+    (let ((gen-get (lambda (o)
+                     (if (hash-table? o)
+                       (lambda (n)
+                         (esc ret
+                           (let ((i 0))
+                             (hash-table-for-each o
+                                                  (lambda (k v)
+                                                    (if (= i n)
+                                                      (ret (cons k v))
+                                                      (set! i (inc i))))))))
+                       (lambda (n)
+                         (get o n))))))
+      (case-lambda
+        ((o)
+         ((gen-get o) (rand (len o))))
+        ((o k)
+         (if (or (<= k 0) (< (len o) k))
+           #()
+           ;; Algorithm L with additional shuffle at the end.
+           ;; https://dl.acm.org/doi/pdf/10.1145/198429.198435
+           (let ((geto (gen-get o))
+                 (r (make-vector k))
+                 (w (exp (/ (log (rand)) k)))
+                 (n (len o))
+                 (i 0))
+             (while (< i k)
+               (vector-set! r i (geto i))
+               (set! i (inc i)))
+             (while (< i n)
+               (set! i (+ i 1 (inexact->exact (floor (/ (log (rand))
+                                                        (log (- 1 w)))))))
+               (when (< i n)
+                 (vector-set! r (rand k) (geto i))
+                 (set! w (* w (exp (/ (log (rand)) k))))))
+             (shuffle! r)
+             (if (vector? o)
+               r
+               (into (empty o) r)))))
+         )))
 
   (define range
     (case-lambda
diff --git a/mew.svnwiki b/mew.svnwiki
index 7d5dce7..f27a8bd 100644
--- a/mew.svnwiki
+++ b/mew.svnwiki
@@ -536,6 +536,14 @@ Shuffles the vector {{<vector>}} randomly in-place using a Fisher-Yates shuffle.
 Returns a random element of the list/vector/string {{<obj>}}.
 Returns a random key/value pair of the hash-table {{<obj>}}.
 
+<procedure>(sample <obj> <N>)</procedure>
+
+Returns a random list/vector/string consisting of {{<N>}} elements of
+the list/vector/string {{<obj>}}, without replacement.
+
+Returns a random hash-table consisting of {{<N>}} key/value pairs of
+the hash-table {{<obj>}}, without replacement.
+
 
 == Special syntax
 
diff --git a/tests/test.mew b/tests/test.mew
index 50cb6bd..e3a527a 100644
--- a/tests/test.mew
+++ b/tests/test.mew
@@ -44,7 +44,11 @@
 (test-group "sample"
   (test #t ((one-of 1 2 3) (sample '(1 2 3))))
   (test #t ((one-of 1 2 3) (sample #(1 2 3))))
-  (test #t ((one-of '(1 . 2) '(3 . 4)) (sample (tbl 1 2 3 4)))))
+  (test #t ((one-of '(1 . 2) '(3 . 4)) (sample (tbl 1 2 3 4))))
+  (test #t (string? (sample "foobar" 3)))
+  (test 3 (len (sample "foobar" 3)))
+  (test "ooo" (sample "ooooo" 3))
+  (test #(1 2 3) (sort (sample #(1 2 3) 3) <)))
 
 (test-group "range"
   (test '(1 2 3) (into '() (range 1 4)))