#lang typed/racket

;;; Threshold Boolean functions over named variables (TBF/State) and
;;; threshold Boolean networks (TBN) built from them.

;; NOTE(review): "utils.rkt" is also required unrestricted right after the
;; except-in form, so lists-transpose is still imported — confirm whether
;; the except-in is really meant to hide it.
(require (except-in "utils.rkt" lists-transpose) "utils.rkt"
         "functions.rkt" "networks.rkt"
         typed/graph typed/racket/random)

;; Typed imports from racket/hash.  hash-intersect is imported a second
;; time, as hash-intersect/tbn-tbf, with a type that lets it intersect a
;; TBN (variable -> TBF/State) with a weight map (variable -> Real).
(require/typed racket/hash
  [hash-intersect
   (->* ((HashTable Variable Real))
        (#:combine (-> Real Real Real))
        #:rest (HashTable Variable Real)
        (HashTable Variable Real))]
  [(hash-intersect hash-intersect/tbn-tbf)
   (->* ((HashTable Variable TBF/State))
        (#:combine (-> TBF/State Real Real))
        #:rest (HashTable Variable Real)
        (HashTable Variable Real))]
  [hash-union
   (->* ((HashTable Variable Real))
        (#:combine (-> Real Real Real))
        #:rest (HashTable Variable Real)
        (HashTable Variable Real))])

(module+ test
  (require typed/rackunit))

(provide
 apply-tbf-to-state
 (struct-out tbf/state) TBF/State tbf/state-w tbf/state-θ make-tbf/state
 sbf/state? apply-tbf/state compact-tbf
 lists+vars->tbfs/state lists+headers->tbfs/state lists->tbfs/state
 lists+vars->sbfs/state lists+headers->sbfs/state lists->sbfs/state
 read-org-tbfs/state read-org-tbfs/state+headers
 tbfs/state->lists tbfs/state->lists+headers lists->tbfs/state/opt-headers
 sbfs/state->lists sbfs/state->lists+headers
 tabulate-tbfs/state tabulate-tbfs/state+headers
 tabulate-tbf/state tabulate-tbf/state+headers
 group-truth-table-by-nai
 TBN sbn? tbn->network build-tbn-state-graph
 normalized-tbn? normalize-tbn compact-tbn
 tbn-interaction-graph pretty-print-tbn-interaction-graph sbn-interaction-graph
 parse-org-tbn read-org-tbn read-org-sbn
 tbn->lists sbn->lists)

;; Applies the vector-based TBF `tbf` to the state `st`.
;; The state's values are passed as a vector in the sorted order of its
;; variables (hash-map with try-order? = #t), so the weights of `tbf` are
;; expected to follow that same order.  TBF and apply-tbf presumably come
;; from "functions.rkt" — TODO confirm.
(: apply-tbf-to-state (-> TBF (State (U Zero One)) (U Zero One)))
(define (apply-tbf-to-state tbf st)
  (apply-tbf tbf (list->vector
                  (hash-map st (λ (_ [val : (U Zero One)]) val) #t))))

(module+ test
  (test-case "apply-tbf-to-state"
    (define st (hash 'x1 0 'x2 1))
    (define f (tbf #(1 1) 1))
    (check-equal? (apply-tbf-to-state f st) 0)))

;; A threshold Boolean function whose inputs are named: `weights` maps each
;; input variable to its weight, `threshold` is the threshold.
(struct tbf/state ([weights : (VariableMapping Real)]
                   [threshold : Real])
  #:transparent
  #:type-name TBF/State)

;; Short accessor names.
(define tbf/state-w tbf/state-weights)
(define tbf/state-θ tbf/state-threshold)

;; Builds a TBF/State from an association list of (variable . weight)
;; pairs and a threshold.
(: make-tbf/state (-> (Listof (Pairof Variable Real)) Real TBF/State))
(define (make-tbf/state pairs threshold)
  (tbf/state (make-immutable-hash pairs) threshold))

(module+ test
  (test-case "tbf/state"
    (define f (make-tbf/state '((x1 . 1) (x2 . 1)) 1))
    (check-equal? (tbf/state-w f) #hash((x1 . 1) (x2 . 1)))
    (check-equal? (tbf/state-θ f) 1)))

;; True when the threshold of `tbfs` is 0.  In this module, SBFs are
;; represented as TBF/States with threshold 0.
(: sbf/state? (-> TBF/State Boolean))
(define (sbf/state? tbfs)
  (zero? (tbf/state-θ tbfs)))

(module+ test
  (test-case "sbf/state?"
    (check-true (sbf/state? (tbf/state (hash 'a -1 'b 1) 0)))
    (check-false (sbf/state? (tbf/state (hash 'a -1 'b 1) 1)))))

;; Applies `tbfs` to the state `st`: returns 1 when the weighted sum of the
;; inputs strictly exceeds the threshold, 0 otherwise.  Only variables
;; present both in the weights and in the state contribute to the sum
;; (hash-intersect).  any->01 presumably maps #t/#f to 1/0 (see tests).
(: apply-tbf/state (-> TBF/State (State (U Zero One)) (U Zero One)))
(define (apply-tbf/state tbfs st)
  (any->01
   (> (apply + (hash-values
                (hash-intersect (tbf/state-w tbfs) st #:combine *)))
      (tbf/state-θ tbfs))))

(module+ test
  (test-case "apply-tbf/state"
    (define st1 (hash 'a 1 'b 0 'c 1))
    (define st2 (hash 'a 1 'b 1 'c 0))
    (define tbf (make-tbf/state '((a . 2) (b . -2)) 1))
    (check-equal? (apply-tbf/state tbf st1) 1)
    (check-equal? (apply-tbf/state tbf st2) 0)))

;; Returns a copy of `tbf` without its zero weights.
(: compact-tbf (-> TBF/State TBF/State))
(define (compact-tbf tbf)
  (tbf/state
   (for/hash : (VariableMapping Real)
             ([(k v) (in-hash (tbf/state-w tbf))]
              #:unless (zero? v))
     (values k v))
   (tbf/state-θ tbf)))

(module+ test
  (test-case "compact-tbf"
    (check-equal? (compact-tbf (tbf/state (hash 'a 0 'b 1 'c 2 'd 0) 2))
                  (tbf/state '#hash((b . 1) (c . 2)) 2))))

;; Builds one TBF/State per list in `lsts`.  Each list holds the weights,
;; in the order of `vars`, followed by the threshold as its last element.
(: lists+vars->tbfs/state (-> (Listof Variable) (Listof (Listof Real))
                              (Listof TBF/State)))
(define (lists+vars->tbfs/state vars lsts)
  (for/list ([lst (in-list lsts)])
    (define-values (ws θ) (split-at-right lst 1))
    (make-tbf/state (for/list ([x (in-list vars)]
                               [w (in-list ws)])
                      (cons x w))
                    (car θ))))

(module+ test
  (test-case "lists+vars->tbfs/state"
    (check-equal? (lists+vars->tbfs/state '(x y) '((1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 3)
                        (tbf/state '#hash((x . 1) (y . 1)) 2)))))

;; Like lists+vars->tbfs/state, but the first list is a header row naming
;; the inputs; its last element (the threshold column header) is ignored.
(: lists+headers->tbfs/state (-> (Pairof (Listof Variable) (Listof (Listof Real)))
                                 (Listof TBF/State)))
(define (lists+headers->tbfs/state lsts+headers)
  (lists+vars->tbfs/state (drop-right (car lsts+headers) 1)
                          (cdr lsts+headers)))

(module+ test
  (test-case "lists+headers->tbfs/state"
    (check-equal? (lists+headers->tbfs/state '((x y f) (1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 3)
                        (tbf/state '#hash((x . 1) (y . 1)) 2)))))

;; Like lists+vars->tbfs/state, with inputs named x0, x1, ... after their
;; column index.  The first list is used to count the columns, so `lsts`
;; must not be empty.
(: lists->tbfs/state (-> (Listof (Listof Real)) (Listof TBF/State)))
(define (lists->tbfs/state lsts)
  (lists+vars->tbfs/state
   (for/list ([i (in-range (length (car lsts)))])
     (string->symbol (format "x~a" i)))
   lsts))

(module+ test
  (test-case "lists->tbfs/state"
    (check-equal? (lists->tbfs/state '((1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))))

;; Dispatches to lists+headers->tbfs/state when `hdr` is true, and to
;; lists->tbfs/state otherwise.  The shape of `lsts` is checked at run time
;; with assert-type.
(: lists->tbfs/state/opt-headers (-> (Listof (Listof (U Variable Real)))
                                     #:headers Boolean
                                     (Listof TBF/State)))
(define (lists->tbfs/state/opt-headers lsts #:headers hdr)
  (if hdr
      (lists+headers->tbfs/state
       (assert-type lsts (Pairof (Listof Variable) (Listof (Listof Real)))))
      (lists->tbfs/state
       (assert-type lsts (Listof (Listof Real))))))

(module+ test
  (test-case "lists->tbfs/state/opt-headers"
    (check-equal?
     (lists->tbfs/state/opt-headers '((1 2 3) (1 1 2)) #:headers #f)
     (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
           (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))
    (check-equal?
     (lists->tbfs/state/opt-headers '((x y f) (1 2 3) (1 1 2)) #:headers #t)
     (list (tbf/state '#hash((x . 1) (y . 2)) 3)
           (tbf/state '#hash((x . 1) (y . 1)) 2)))))

;; Builds one SBF (TBF/State with threshold 0) per list in `lsts`.  Each
;; list holds only the weights, in the order of `vars`.
(: lists+vars->sbfs/state (-> (Listof Variable) (Listof (Listof Real))
                              (Listof TBF/State)))
(define (lists+vars->sbfs/state vars lsts)
  (for/list ([lst (in-list lsts)])
    (make-tbf/state (map (inst cons Variable Real) vars lst) 0)))

(module+ test
  (test-case "lists+vars->sbfs/state"
    (check-equal? (lists+vars->sbfs/state '(x y) '((1 2) (1 1)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 0)
                        (tbf/state '#hash((x . 1) (y . 1)) 0)))))

;; Like lists+vars->sbfs/state, but the first list is a header row naming
;; the inputs.
(: lists+headers->sbfs/state (-> (Pairof (Listof Variable) (Listof (Listof Real)))
                                 (Listof TBF/State)))
(define (lists+headers->sbfs/state lsts)
  (lists+vars->sbfs/state (car lsts) (cdr lsts)))

(module+ test
  (test-case "lists+headers->sbfs/state"
    (check-equal? (lists+headers->sbfs/state '((x y) (1 2) (1 1)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 0)
                        (tbf/state '#hash((x . 1) (y . 1)) 0)))))

;; Like lists+vars->sbfs/state, with inputs named x0, x1, ... after their
;; column index.  `lsts` must not be empty.
(: lists->sbfs/state (-> (Listof (Listof Real)) (Listof TBF/State)))
(define (lists->sbfs/state lsts)
  (lists+vars->sbfs/state
   (for/list ([i (in-range (length (car lsts)))])
     (string->symbol (format "x~a" i)))
   lsts))

(module+ test
  (test-case "lists->sbfs/state"
    (check-equal? (lists->sbfs/state '((1 2) (1 1)))
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 0)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 0)))))

;; Reads a table of numbers from `str` with read-org-sexp and converts it
;; with lists->tbfs/state.
(: read-org-tbfs/state (-> String (Listof TBF/State)))
(define (read-org-tbfs/state str)
  (lists->tbfs/state
   (assert-type (read-org-sexp str) (Listof (Listof Real)))))

(module+ test
  (test-case "read-org-tbfs/state"
    (check-equal? (read-org-tbfs/state "((1 2 3) (1 1 2))")
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))))

;; Reads a table whose first row is a header row from `str` with
;; read-org-sexp and converts it with lists+headers->tbfs/state.
(: read-org-tbfs/state+headers (-> String (Listof TBF/State)))
(define (read-org-tbfs/state+headers str)
  (lists+headers->tbfs/state
   (assert-type (read-org-sexp str)
                (Pairof (Listof Variable) (Listof (Listof Real))))))

(module+ test
  (test-case "read-org-tbfs/state+headers"
    (check-equal? (read-org-tbfs/state+headers "((a b f) (1 2 3) (1 1 2))")
                  (list (tbf/state '#hash((a . 1) (b . 2)) 3)
                        (tbf/state '#hash((a . 1) (b . 1)) 2)))))

;; Converts each TBF to a list of its weights, in sorted variable order,
;; followed by its threshold.
(: tbfs/state->lists (-> (Listof TBF/State) (Listof (Listof Real))))
(define (tbfs/state->lists tbfs)
  (for/list ([tbf (in-list tbfs)])
    (append (hash-map (tbf/state-w tbf) (λ (_ [w : Real]) w) #t)
            (list (tbf/state-θ tbf)))))

(module+ test
  (test-case "tbfs/state->lists"
    (check-equal?
     (tbfs/state->lists (list (tbf/state (hash 'a 1 'b 2) 3)
                              (tbf/state (hash 'a -2 'b 1) 1)))
     '((1 2 3) (-2 1 1)))))

;; Like tbfs/state->lists, preceded by a header row: the input names of the
;; first TBF in sorted order, then θ.  `tbfs` must not be empty.
(: tbfs/state->lists+headers (-> (Listof TBF/State)
                                 (Pairof (Listof Variable) (Listof (Listof Real)))))
(define (tbfs/state->lists+headers tbfs)
  (cons (append (hash-map (tbf/state-w (car tbfs)) (λ ([x : Symbol] _) x) #t)
                '(θ))
        (tbfs/state->lists tbfs)))

(module+ test
  (test-case "tbfs/state->list+headers"
    (check-equal?
     (tbfs/state->lists+headers (list (tbf/state (hash 'a 1 'b 2) 3)
                                      (tbf/state (hash 'a -2 'b 1) 1)))
     '((a b θ) (1 2 3) (-2 1 1)))))

;; Converts each SBF to the list of its weights in sorted variable order
;; (thresholds are not included).
(: sbfs/state->lists (-> (Listof TBF/State) (Listof (Listof Real))))
(define (sbfs/state->lists tbfs)
  (for/list ([tbf (in-list tbfs)])
    (hash-map (tbf/state-w tbf) (λ (_ [w : Real]) w) #t)))

(module+ test
  (test-case "sbfs/state->lists"
    (check-equal?
     (sbfs/state->lists (list (tbf/state (hash 'a 1 'b 2) 0)
                              (tbf/state (hash 'a -2 'b 1) 0)))
     '((1 2) (-2 1)))))

;; Like sbfs/state->lists, preceded by a header row with the input names of
;; the first SBF in sorted order.  `tbfs` must not be empty.
(: sbfs/state->lists+headers (-> (Listof TBF/State)
                                 (Pairof (Listof Variable) (Listof (Listof Real)))))
(define (sbfs/state->lists+headers tbfs)
  (cons (hash-map (tbf/state-w (car tbfs)) (λ ([x : Symbol] _) x) #t)
        (sbfs/state->lists tbfs)))

(module+ test
  (test-case "sbfs/state->list+headers"
    (check-equal?
     (sbfs/state->lists+headers (list (tbf/state (hash 'a 1 'b 2) 0)
                                      (tbf/state (hash 'a -2 'b 1) 0)))
     '((a b) (1 2) (-2 1)))))

;; Tabulates the TBFs over all 0/1 states of the inputs of the first TBF.
;; Each row lists an input state followed by the output of every TBF (see
;; tests).  `tbfs` must not be empty.
(: tabulate-tbfs/state (-> (Listof TBF/State) (Listof (Listof Real))))
(define (tabulate-tbfs/state tbfs)
  (define vars (hash-map (tbf/state-w (car tbfs)) (λ ([x : Variable] _) x) #t))
  (tabulate-state* (map (curry apply-tbf/state) tbfs)
                   (make-same-domains vars '(0 1))))

(module+ test
  (test-case "tabulate-tbfs/state"
    (check-equal? (tabulate-tbfs/state
                   (list (tbf/state (hash 'a 1 'b 2) 2)
                         (tbf/state (hash 'a -2 'b 2) 1)))
                  '((0 0 0 0)
                    (0 1 0 1)
                    (1 0 0 0)
                    (1 1 1 0)))))

;; Like tabulate-tbfs/state, preceded by a header row produced by
;; tabulate-state*+headers (input names, then f1, f2, ...; see tests).
(: tabulate-tbfs/state+headers (-> (Listof TBF/State)
                                   (Pairof (Listof Variable) (Listof (Listof Real)))))
(define (tabulate-tbfs/state+headers tbfs)
  (define vars (hash-map (tbf/state-w (car tbfs)) (λ ([x : Variable] _) x) #t))
  (tabulate-state*+headers (map (curry apply-tbf/state) tbfs)
                           (make-same-domains vars '(0 1))))

(module+ test
  (test-case "tabulate-tbfs/state+headers"
    (check-equal? (tabulate-tbfs/state+headers
                   (list (tbf/state (hash 'a 1 'b 2) 2)
                         (tbf/state (hash 'a -2 'b 2) 1)))
                  '((a b f1 f2)
                    (0 0 0 0)
                    (0 1 0 1)
                    (1 0 0 0)
                    (1 1 1 0)))))

;; Tabulates a single TBF (see tabulate-tbfs/state).
(: tabulate-tbf/state (-> TBF/State (Listof (Listof Real))))
(define (tabulate-tbf/state tbf)
  (tabulate-tbfs/state (list tbf)))

(module+ test
  (test-case "tabulate-tbf/state"
    (check-equal? (tabulate-tbf/state (tbf/state (hash 'a 1 'b 2) 2))
                  '((0 0 0) (0 1 0) (1 0 0) (1 1 1)))))

;; Tabulates a single TBF with a header row (see tabulate-tbfs/state+headers).
(: tabulate-tbf/state+headers (-> TBF/State
                                  (Pairof (Listof Variable) (Listof (Listof Real)))))
(define (tabulate-tbf/state+headers tbf)
  (tabulate-tbfs/state+headers (list tbf)))

(module+ test
  (test-case "tabulate-tbf/state+headers"
    (check-equal? (tabulate-tbf/state+headers (tbf/state (hash 'a 1 'b 2) 2))
                  '((a b f1)
                    (0 0 0)
                    (0 1 0)
                    (1 0 0)
                    (1 1 1)))))

;; Groups the rows of the truth table `tt` by the sum of their input part
;; (every column but the last, which holds the output).  For 0/1 tables
;; this is the number of active inputs.
(: group-truth-table-by-nai (-> (Listof (Listof Integer))
                                (Listof (Listof (Listof Integer)))))
(define (group-truth-table-by-nai tt)
  (: sum (-> (Listof Integer) Integer))
  (define (sum xs) (foldl + 0 xs))
  (group-by (λ ([row : (Listof Integer)])
              (drop-right row 1))
            tt
            (λ ([in1 : (Listof Integer)] [in2 : (Listof Integer)])
              (= (sum in1) (sum in2)))))

(module+ test
  (test-case "group-truth-table-by-nai"
    (check-equal? (group-truth-table-by-nai '((0 0 0 1)
                                              (0 0 1 1)
                                              (0 1 0 0)
                                              (0 1 1 1)
                                              (1 0 0 0)
                                              (1 0 1 0)
                                              (1 1 0 1)
                                              (1 1 1 0)))
                  '(((0 0 0 1))
                    ((0 0 1 1) (0 1 0 0) (1 0 0 0))
                    ((0 1 1 1) (1 0 1 0) (1 1 0 1))
                    ((1 1 1 0))))))

;; A threshold Boolean network: maps each variable to the TBF/State that
;; computes its next value (see tbn->network).
(define-type TBN (HashTable Variable TBF/State))

;; True when every function of `tbn` is an SBF (threshold 0).
(: sbn? (-> TBN Boolean))
(define (sbn? tbn)
  (andmap sbf/state? (hash-values tbn)))

(module+ test
  (test-case "sbn?"
    (define f1 (tbf/state (hash 'a -1 'b 1) 0))
    (define f2 (tbf/state (hash 'a -1 'b 1) 1))
    (check-true (sbn? (hash 'a f1 'b f1)))
    (check-false (sbn? (hash 'a f1 'b f2)))))

;; Converts `tbn` to a Network whose update function for each variable
;; applies that variable's TBF with apply-tbf/state.  The variables get the
;; domain (0 1) (make-01-network; see test).
(: tbn->network (-> TBN (Network (U Zero One))))
(define (tbn->network tbn)
  (make-01-network
   (for/hash : (VariableMapping (UpdateFunction (U Zero One)))
             ([(x tbfx) (in-hash tbn)])
     (values x (λ ([s : (State (U Zero One))])
                 (apply-tbf/state tbfx s))))))

(module+ test
  (test-case "tbn->network"
    (define tbn-form (hash 'a (tbf/state (hash 'a -1 'b 1) 0)
                           'b (tbf/state (hash 'a -1 'b 1) 1)))
    (define tbn (tbn->network tbn-form))
    (define s (hash 'a 0 'b 1))
    (check-equal? (update tbn s '(a b)) #hash((a . 1) (b . 0)))
    (check-equal? (network-domains tbn) #hash((a . (0 1)) (b . (0 1))))))

;; Builds the full state graph of `tbn` under make-syn-dynamics (presumably
;; the synchronous update mode — TODO confirm), with pretty-printed state
;; labels.
(: build-tbn-state-graph (-> TBN Graph))
(define (build-tbn-state-graph tbn)
  (pretty-print-state-graph
   ((inst build-full-state-graph (U Zero One))
    ((inst make-syn-dynamics (U Zero One))
     (tbn->network tbn)))))

(module+ test
  (test-case "build-tbn-state-graph"
    (check-equal? (graphviz (build-tbn-state-graph
                             (hash 'a (tbf/state (hash 'a -1 'b 1) 0)
                                   'b (tbf/state (hash 'a -1 'b 1) 1))))
                  "digraph G {\n\tnode0 [label=\"a:0 b:0\"];\n\tnode1 [label=\"a:1 b:1\"];\n\tnode2 [label=\"a:0 b:1\"];\n\tnode3 [label=\"a:1 b:0\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [];\n\t}\n\tsubgraph D {\n\t\tnode1 -> node0 [];\n\t\tnode2 -> node3 [];\n\t\tnode3 -> node0 [];\n\t}\n}\n")))

;; True when the inputs of every TBF of `tbn` are exactly the variables of
;; `tbn` (compared as sets).
(: normalized-tbn? (-> TBN Boolean))
(define (normalized-tbn? tbn)
  (define tbn-vars (hash-keys tbn))
  (for/and ([tbf (in-list (hash-values tbn))])
    (set=? tbn-vars (hash-keys (tbf/state-w tbf)))))

(module+ test
  (test-case "normalized-tbn?"
    (check-true (normalized-tbn?
                 (hash 'x (tbf/state (hash 'x 0 'y -1) -1)
                       'y (tbf/state (hash 'x -1 'y 0) -1))))
    (check-false (normalized-tbn?
                  (hash 'x (tbf/state (hash 'x 0) -1)
                        'y (tbf/state (hash 'y 0) -1))))))

;; Returns a TBN in which every TBF's inputs are exactly the variables of
;; `tbn`.  Inputs that are not network variables are dropped, and missing
;; variables are added with weight 0.  Thresholds are unchanged.
(: normalize-tbn (-> TBN TBN))
(define (normalize-tbn tbn)
  (define vars-0 (for/hash : (VariableMapping Real)
                           ([(x _) (in-hash tbn)])
                   (values x 0)))
  (: normalize-tbf (-> TBF/State TBF/State))
  (define (normalize-tbf tbf)
    ;; Only keep the inputs which are also the variables of tbn.
    (define w-pruned (hash-intersect/tbn-tbf
                      tbn
                      (tbf/state-w tbf)
                      #:combine (λ (_ w) w)))
    ;; Put in the missing inputs with weight 0.
    (define w-complete
      (assert-type (hash-union vars-0 w-pruned #:combine (λ (_ w) w))
                   (VariableMapping Real)))
    (tbf/state w-complete (tbf/state-θ tbf)))
  (for/hash : TBN ([(x tbf) (in-hash tbn)])
    (values x (normalize-tbf tbf))))

(module+ test
  (test-case "normalize-tbn"
    (check-equal? (normalize-tbn (hash 'x (tbf/state (hash 'x 2) -1)
                                       'y (tbf/state (hash 'y 3) 1)))
                  (hash 'x (tbf/state (hash 'x 2 'y 0) -1)
                        'y (tbf/state (hash 'x 0 'y 3) 1)))))

;; Removes, from every TBF of `tbn`, the inputs whose weight is 0 and the
;; inputs that are not variables of `tbn`.  Thresholds are unchanged.
(: compact-tbn (-> TBN TBN))
(define (compact-tbn tbn)
  (: remove-0-non-var (-> TBF/State TBF/State))
  (define (remove-0-non-var tbf)
    (tbf/state
     (for/hash : (VariableMapping Real)
               ([(x w) (in-hash (tbf/state-w tbf))]
                #:when (hash-has-key? tbn x)
                #:unless (zero? w))
       (values x w))
     (tbf/state-θ tbf)))
  (for/hash : TBN ([(x tbf) (in-hash tbn)])
    (values x (remove-0-non-var tbf))))

(module+ test
  (test-case "compact-tbn"
    (check-equal?
     (compact-tbn (hash 'a (tbf/state (hash 'a 0 'b 1 'c 3 'd 0) 0)
                        'b (tbf/state (hash 'a -1 'b 1) -1)))
     (hash 'a (tbf/state '#hash((b . 1)) 0)
           'b (tbf/state '#hash((a . -1) (b . 1)) -1)))))

;; Builds the interaction graph of `tbn` after normalize-tbn: a directed
;; edge src -> tar weighted w for every weight w of src in the TBF of tar.
;; Edges of weight 0 are included only when #:zero-edges is true (the
;; default).  Each vertex is labelled with the pair (variable . threshold).
(: tbn-interaction-graph (->* (TBN) (#:zero-edges Boolean) Graph))
(define (tbn-interaction-graph tbn #:zero-edges [zero-edges #t])
  (define ntbn (normalize-tbn tbn))
  (define ig
    (weighted-graph/directed
     (for*/list : (Listof (List Real Variable Variable))
                ([(tar tbf) (in-hash ntbn)]
                 [(src w) (in-hash (tbf/state-w tbf))]
                 #:when (or zero-edges (not (zero? w))))
       (list w src tar))))
  (update-graph
   ig
   #:v-func (λ (x) (cons x (tbf/state-θ (hash-ref ntbn (assert-type x Variable)))))))

(module+ test
  (test-case "tbn-interaction-graph"
    (define tbn (hash 'a (tbf/state (hash 'b 1) 0)
                      'b (tbf/state (hash 'a -1) -1)))
    (check-equal? (graphviz (tbn-interaction-graph tbn))
                  "digraph G {\n\tnode0 [label=\"'(b . -1)\"];\n\tnode1 [label=\"'(a . 0)\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")
    (check-equal? (graphviz (tbn-interaction-graph tbn #:zero-edges #f))
                  "digraph G {\n\tnode0 [label=\"'(b . -1)\"];\n\tnode1 [label=\"'(a . 0)\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")))

;; Relabels the (variable . threshold) vertices of a graph built by
;; tbn-interaction-graph as "variable:threshold" strings.
(: pretty-print-tbn-interaction-graph (-> Graph Graph))
(define (pretty-print-tbn-interaction-graph ig)
  (update-graph ig #:v-func (match-lambda
                              [(cons var weight) (~a var ":" weight)])))

(module+ test
  (test-case "pretty-print-tbn-interaction-graph"
    (define tbn (hash 'a (tbf/state (hash 'b 1) 0)
                      'b (tbf/state (hash 'a -1) -1)))
    (check-equal? (graphviz (pretty-print-tbn-interaction-graph
                             (tbn-interaction-graph tbn)))
                  "digraph G {\n\tnode0 [label=\"a:0\"];\n\tnode1 [label=\"b:-1\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"-1\"];\n\t\tnode1 -> node0 [label=\"1\"];\n\t}\n}\n")))

;; Like tbn-interaction-graph, but each vertex is labelled with its variable
;; only, since the thresholds of an SBN are all 0.  #:zero-edges is passed
;; through to tbn-interaction-graph.
(: sbn-interaction-graph (->* (TBN) (#:zero-edges Boolean) Graph))
(define (sbn-interaction-graph sbn #:zero-edges [zero-edges #t])
  (update-graph (tbn-interaction-graph sbn #:zero-edges zero-edges)
                #:v-func (match-lambda [(cons var _) var])))

(module+ test
  (test-case "sbn-interaction-graph"
    (define sbn (hash 'a (tbf/state (hash 'b 2) 0)
                      'b (tbf/state (hash 'a 2) 0)))
    (check-equal? (graphviz (sbn-interaction-graph sbn))
                  "digraph G {\n\tnode0 [label=\"a\"];\n\tnode1 [label=\"b\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node1 [label=\"2\"];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t}\n}\n")))

;; Builds a TBN from the table `tab` (a list of rows).
;;
;; #:func-names — when true (default), the first column of every row names
;;   the variable whose TBF the row describes.  When false, the variables
;;   are taken from the input names of the first TBF, in sorted order.
;; #:headers — when true (default), the first row is a header row naming
;;   the inputs.  Its last cell (the threshold column) is ignored, and so is
;;   its first cell when #:func-names is true.  When false, inputs are named
;;   x0, x1, ... by column index.
(: parse-org-tbn (->* ((Listof (Listof (U Symbol Real))))
                      (#:headers Boolean #:func-names Boolean)
                      TBN))
(define (parse-org-tbn tab #:headers [headers #t] #:func-names [func-names #t])
  (cond
    [func-names
     (define-values (vars rows) (multi-split-at tab 1))
     (define tbfs (lists->tbfs/state/opt-headers rows #:headers headers))
     ;; Only a header row carries a (dummy) cell in the name column that must
     ;; be skipped; without headers, every row names a function.
     (define func-vars (if headers (cdr vars) vars))
     (for/hash : TBN ([tbf (in-list tbfs)]
                      [var (in-list func-vars)])
       (values (assert-type (car var) Variable) tbf))]
    [else
     (define tbfs (lists->tbfs/state/opt-headers tab #:headers headers))
     (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t))
     (for/hash : TBN ([tbf (in-list tbfs)]
                      [var (in-list vars)])
       (values (assert-type var Variable) tbf))]))

(module+ test
  (test-case "parse-org-tbn"
    (check-equal?
     (parse-org-tbn '((1 2 3) (3 2 1)) #:headers #f #:func-names #f)
     (hash 'x0 (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
           'x1 (tbf/state '#hash((x0 . 3) (x1 . 2)) 1)))
    (check-equal?
     (parse-org-tbn '((a b θ) (1 2 3) (3 2 1)) #:headers #t #:func-names #f)
     (hash 'a (tbf/state '#hash((a . 1) (b . 2)) 3)
           'b (tbf/state '#hash((a . 3) (b . 2)) 1)))
    (check-equal?
     (parse-org-tbn '((dummy a b θ) (b 3 2 1) (a 1 2 3))
                    #:headers #t #:func-names #t)
     (hash 'a (tbf/state '#hash((a . 1) (b . 2)) 3)
           'b (tbf/state '#hash((a . 3) (b . 2)) 1)))
    ;; Without a header row, no function name may be dropped.
    (check-equal?
     (parse-org-tbn '((b 3 2 1) (a 1 2 3)) #:headers #f #:func-names #t)
     (hash 'b (tbf/state '#hash((x0 . 3) (x1 . 2)) 1)
           'a (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)))))

;; Reads a table from `str` with read-org-sexp and parses it with
;; parse-org-tbn, passing the options through.
(: read-org-tbn (->* (String)
                     (#:headers Boolean #:func-names Boolean)
                     TBN))
(define (read-org-tbn str #:headers [headers #t] #:func-names [func-names #t])
  (parse-org-tbn (assert-type (read-org-sexp str)
                              (Listof (Listof (U Symbol Real))))
                 #:headers headers
                 #:func-names func-names))

(module+ test
  (test-case "read-org-tbn"
    (check-equal?
     (read-org-tbn "((\"-\" \"x\" \"y\" \"θ\") (\"y\" -1 0 -1) (\"x\" 0 -1 -1))")
     (hash 'x (tbf/state '#hash((x . 0) (y . -1)) -1)
           'y (tbf/state '#hash((x . -1) (y . 0)) -1)))
    (check-equal?
     (read-org-tbn "((\"x\" \"y\" \"θ\") (-1 0 -1) (0 -1 -1))" #:func-names #f)
     (hash 'x (tbf/state '#hash((x . -1) (y . 0)) -1)
           'y (tbf/state '#hash((x . 0) (y . -1)) -1)))
    (check-equal?
     (read-org-tbn "((-1 0 -1) (0 -1 -1))" #:headers #f #:func-names #f)
     (hash 'x0 (tbf/state '#hash((x0 . -1) (x1 . 0)) -1)
           'x1 (tbf/state '#hash((x0 . 0) (x1 . -1)) -1)))))

;; Like read-org-tbn, but the table has no threshold column.  A 0 threshold
;; is appended to every row (and a θ cell to the header row, if any) before
;; parsing.
(: read-org-sbn (->* (String)
                     (#:headers Boolean #:func-names Boolean)
                     TBN))
(define (read-org-sbn str #:headers [headers #t] #:func-names [func-names #t])
  (define sexp (assert-type (read-org-sexp str)
                            (Listof (Listof (U Symbol Real)))))
  ;; Inject the 0 thresholds into the rows of the sexp we have just read.
  (: inject-0 (-> (Listof (Listof (U Symbol Real)))
                  (Listof (Listof (U Symbol Real)))))
  (define (inject-0 rows)
    (for/list : (Listof (Listof (U Symbol Real)))
              ([row (in-list rows)])
      (append row '(0))))
  (define sexp-ready (if headers
                         (cons (append (car sexp) '(θ))
                               (inject-0 (cdr sexp)))
                         (inject-0 sexp)))
  (parse-org-tbn sexp-ready #:headers headers #:func-names func-names))

(module+ test
  (test-case "read-org-sbn"
    (check-equal?
     (read-org-sbn "((\"-\" \"x\" \"y\") (\"y\" -1 0) (\"x\" 0 -1))")
     (hash 'x (tbf/state '#hash((x . 0) (y . -1)) 0)
           'y (tbf/state '#hash((x . -1) (y . 0)) 0)))
    (check-equal?
     (read-org-sbn "((\"x\" \"y\") (-1 0) (0 -1))" #:func-names #f)
     (hash 'x (tbf/state '#hash((x . -1) (y . 0)) 0)
           'y (tbf/state '#hash((x . 0) (y . -1)) 0)))
    (check-equal?
     (read-org-sbn "((-1 0) (0 -1))" #:headers #f #:func-names #f)
     (hash 'x0 (tbf/state '#hash((x0 . -1) (x1 . 0)) 0)
           'x1 (tbf/state '#hash((x0 . 0) (x1 . -1)) 0)))))

;; Converts `tbn`, after normalize-tbn, to a table with one row per
;; variable in sorted order: its weights, then its threshold.
;; #:headers — prepend a header row (input names, then θ).
;; #:func-names — prepend each row with its variable.  When both options
;;   are set, the header row's first cell is '-.
(: tbn->lists (->* (TBN)
                   (#:headers Boolean #:func-names Boolean)
                   (Listof (Listof (U Symbol Real)))))
(define (tbn->lists tbn #:headers [headers #t] #:func-names [func-names #t])
  (define ntbn (normalize-tbn tbn))
  (define vars-tbfs (hash-map ntbn (λ ([x : Variable] [tbf : TBF/State])
                                     (cons x tbf))
                              #t))
  (define tbfs (map (inst cdr Variable TBF/State) vars-tbfs))
  (define tbfs-table ((if headers
                          tbfs/state->lists+headers
                          tbfs/state->lists) tbfs))
  (cond
    [func-names
     (define vars (map (inst car Variable TBF/State) vars-tbfs))
     (define col-1 (if headers (cons '- vars) vars))
     (for/list ([var (in-list col-1)]
                [row (in-list tbfs-table)])
       (cons var row))]
    [else tbfs-table]))

(module+ test
  (test-case "tbn->lists"
    (define tbn (hash 'a (tbf/state (hash 'b 1) 0)
                      'b (tbf/state (hash 'a -1) -1)))
    (check-equal? (tbn->lists tbn)
                  '((- a b θ) (a 0 1 0) (b -1 0 -1)))
    (check-equal? (tbn->lists tbn #:headers #f)
                  '((a 0 1 0) (b -1 0 -1)))
    (check-equal? (tbn->lists tbn #:func-names #f)
                  '((a b θ) (0 1 0) (-1 0 -1)))
    (check-equal? (tbn->lists tbn #:headers #f #:func-names #f)
                  '((0 1 0) (-1 0 -1)))))

;; Like tbn->lists, but with the last (threshold) column removed.
(: sbn->lists (->* (TBN)
                   (#:headers Boolean #:func-names Boolean)
                   (Listof (Listof (U Symbol Real)))))
(define (sbn->lists sbn #:headers [headers #t] #:func-names [func-names #t])
  (define tab (tbn->lists sbn #:headers headers #:func-names func-names))
  (define-values (tab-no-θ _)
    (multi-split-at tab (sub1 (length (car tab)))))
  tab-no-θ)

(module+ test
  (test-case "sbn->lists"
    (define sbn (hash 'a (tbf/state (hash 'b 2) 0)
                      'b (tbf/state (hash 'a 2) 0)))
    (check-equal? (sbn->lists sbn)
                  '((- a b) (a 0 2) (b 2 0)))
    (check-equal? (sbn->lists sbn #:headers #f)
                  '((a 0 2) (b 2 0)))
    (check-equal? (sbn->lists sbn #:func-names #f)
                  '((a b) (0 2) (2 0)))
    (check-equal? (sbn->lists sbn #:headers #f #:func-names #f)
                  '((0 2) (2 0)))))