diff --git a/networks.rkt b/networks.rkt index ab54b9a..5bafb46 100644 --- a/networks.rkt +++ b/networks.rkt @@ -108,6 +108,9 @@ [make-tbn (-> (listof (cons/c variable? tbf/state?)) tbn?)] [tbn->network (-> tbn? network?)] [make-sbn (-> (listof (cons/c variable? tbf/state?)) sbn?)] + [parse-org-tbn (->* ((listof any/c)) + (#:headers boolean? #:func-names boolean?) + tbn?)] [read-org-tbn (->* (string?) (#:headers boolean? #:func-names boolean?) tbn?)]) @@ -1333,6 +1336,40 @@ (check-equal? (update sn s2 '(a b)) (make-state '((a . 0) (b . 1)))))) +;;; A helper function for read-org-tbn and read-org-sbn. It reads a +;;; TBN from an Org-mode sexp containing a list of lists of numbers. +;;; As in lists->tbfs/state, the last element of each list is taken to +;;; be the threshold of the TBFs, and the rest of the elements are +;;; taken to be the weights. +;;; +;;; As in read-org-tbfs/state, if headers is #t, the names of the +;;; variables to appear as the inputs of the TBF are taken from the +;;; first list. The last element of this list is discarded. +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +;;; +;;; If func-names is #t, the first element in every row except the +;;; first one, are taken to be the name of the variable to which the +;;; TBF should be associated. If func-names is #f, the functions are +;;; assigned to variables in alphabetical order. +;;; +;;; func-names cannot be #t if headers is #f. The function does not +;;; check this condition. +(define (parse-org-tbn sexp + #:headers [headers #t] + #:func-names [func-names #t]) + (cond + [(eq? func-names #t) + (define-values (vars rows) (multi-split-at sexp 1)) + (define tbfs (lists->tbfs/state rows #:headers headers)) + (for/hash ([tbf (in-list tbfs)] [var (in-list (cdr vars))]) + (values (car var) tbf))] + [else + (define tbfs (lists->tbfs/state sexp #:headers headers)) + (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t)) + (for/hash ([tbf (in-list tbfs)] [var (in-list vars)]) + (values var tbf))])) + ;;; Reads a TBN from an Org-mode string containing a sexp, containing ;;; a list of lists of numbers. As in lists->tbfs/state, the last ;;; element of each list is taken to be the threshold of the TBFs, and @@ -1354,19 +1391,30 @@ (define (read-org-tbn str #:headers [headers #t] #:func-names [func-names #t]) - (define sexp (read-org-sexp str)) - (cond - [(eq? func-names #t) - (define-values (vars rows) (multi-split-at sexp 1)) - (define tbfs (lists->tbfs/state rows #:headers headers)) - (for/hash ([tbf (in-list tbfs)] [var (in-list (cdr vars))]) - (values (car var) tbf))] - [else - (define tbfs (lists->tbfs/state sexp #:headers headers)) - (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t)) - (for/hash ([tbf (in-list tbfs)] [var (in-list vars)]) - (values var tbf))])) + (parse-org-tbn (read-org-sexp str) + #:headers headers + #:func-names func-names)) +(module+ test + (test-case "read-org-tbn, parse-org-tbn" + (check-equal? (read-org-tbn "((\"-\" \"x\" \"y\" \"θ\") (\"y\" -1 0 -1) (\"x\" 0 -1 -1))") + (hash + 'x + (tbf/state '#hash((x . 0) (y . -1)) -1) + 'y + (tbf/state '#hash((x . -1) (y . 0)) -1))) + (check-equal? (read-org-tbn "((\"x\" \"y\" \"θ\") (-1 0 -1) (0 -1 -1))" #:func-names #f) + (hash + 'x + (tbf/state '#hash((x . -1) (y . 0)) -1) + 'y + (tbf/state '#hash((x . 0) (y . -1)) -1))) + (check-equal? (read-org-tbn "((-1 0 -1) (0 -1 -1))" #:headers #f #:func-names #f) + (hash + 'x0 + (tbf/state '#hash((x0 . -1) (x1 . 0)) -1) + 'x1 + (tbf/state '#hash((x0 . 0) (x1 . -1)) -1))))) (module+ test (test-case "read-org-tbn" (check-equal? (read-org-tbn "((\"-\" \"x\" \"y\" \"θ\") (\"y\" -1 0 -1) (\"x\" 0 -1 -1))")