Add make-tabulate* to factor out the common parts.

This commit is contained in:
Sergiu Ivanov 2022-03-06 21:45:21 +01:00
parent 929bf09299
commit 37dddb190f

View file

@ -10,7 +10,8 @@
(require "utils.rkt") (require "utils.rkt")
(module typed typed/racket (module typed typed/racket
(require "utils.rkt") (require "utils.rkt"
(for-syntax syntax/parse))
(provide (provide
tabulate* tabulate*/strict) tabulate* tabulate*/strict)
@ -18,12 +19,17 @@
(module+ test (module+ test
(require typed/rackunit)) (require typed/rackunit))
(define-syntax (make-tabulate* stx)
(syntax-parse stx
[(_ name:id row-op)
#'(define (name funcs doms)
(for/list ([xs (in-list (apply cartesian-product doms))])
(row-op xs (for/list ([f funcs]) : (Listof b)
(apply f xs)))))]))
(: tabulate* (All (b a ... ) (-> (Listof (-> a ... b)) (List (Listof a) ... a) (: tabulate* (All (b a ... ) (-> (Listof (-> a ... b)) (List (Listof a) ... a)
(Listof (Listof (U Any b)))))) (Listof (Listof (U Any b))))))
(define (tabulate* funcs doms) (make-tabulate* tabulate* append)
(for/list ([xs (in-list (apply cartesian-product doms))])
(append xs (for/list ([f funcs]) : (Listof b)
(apply f xs)))))
(module+ test (module+ test
(test-case "tabulate*" (test-case "tabulate*"
@ -37,10 +43,7 @@
(: tabulate*/strict (All (b a ...) (-> (Listof (-> a ... b)) (List (Listof a) ... a) (: tabulate*/strict (All (b a ...) (-> (Listof (-> a ... b)) (List (Listof a) ... a)
(Listof (List (List a ...) (Listof b)))))) (Listof (List (List a ...) (Listof b))))))
(define (tabulate*/strict funcs doms) (make-tabulate* tabulate*/strict list)
(for/list ([xs (in-list (apply cartesian-product doms))])
(list xs (for/list ([f funcs]) : (Listof b)
(apply f xs)))))
(module+ test (module+ test
(test-case "tabulate*/strict" (test-case "tabulate*/strict"