dds/tbn.rkt

716 lines
28 KiB
Racket

#lang typed/racket
(require (except-in "utils.rkt" lists-transpose)
"utils.rkt" "functions.rkt" "networks.rkt"
typed/graph typed/racket/random)
(require/typed racket/hash
[hash-intersect
(->* ((HashTable Variable Real))
(#:combine (-> Real Real Real))
#:rest (HashTable Variable Real)
(HashTable Variable Real))]
[(hash-intersect hash-intersect/tbn-tbf)
(->* ((HashTable Variable TBF/State))
(#:combine (-> TBF/State Real Real))
#:rest (HashTable Variable Real)
(HashTable Variable Real))]
[hash-union
(->* ((HashTable Variable Real))
(#:combine (-> Real Real Real))
#:rest (HashTable Variable Real)
(HashTable Variable Real))])
(module+ test
(require typed/rackunit))
(provide
apply-tbf-to-state
(struct-out tbf/state) TBF/State tbf/state-w tbf/state-θ make-tbf/state
sbf/state? apply-tbf/state compact-tbf
lists+vars->tbfs/state lists+headers->tbfs/state lists->tbfs/state
lists+vars->sbfs/state lists+headers->sbfs/state lists->sbfs/state
read-org-tbfs/state read-org-tbfs/state+headers
tbfs/state->lists tbfs/state->lists+headers lists->tbfs/state/opt-headers
sbfs/state->lists sbfs/state->lists+headers
tabulate-tbfs/state tabulate-tbfs/state+headers
tabulate-tbf/state tabulate-tbf/state+headers
group-truth-table-by-nai
TBN sbn? tbn->network
build-tbn-state-graph normalized-tbn? normalize-tbn compact-tbn
tbn-interaction-graph pretty-print-tbn-interaction-graph
sbn-interaction-graph
parse-org-tbn read-org-tbn read-org-sbn tbn->lists sbn->lists
)
;; Applies the vector-based TBF `tbf` to the state `st`.
;;
;; The values of `st` are collected by hash-map with try-order? = #t, which
;; normalizes the key order, and handed to apply-tbf as a vector.  The i-th
;; weight of `tbf` is therefore paired with the i-th variable in that order.
(: apply-tbf-to-state (-> TBF (State (U Zero One)) (U Zero One)))
(define (apply-tbf-to-state tbf st)
  (: value-only (-> Any (U Zero One) (U Zero One)))
  (define (value-only _var val) val)
  (apply-tbf tbf (list->vector (hash-map st value-only #t))))

(module+ test
  (test-case "apply-tbf-to-state"
    (define st (hash 'x1 0 'x2 1))
    (define f (tbf #(1 1) 1))
    (check-equal? (apply-tbf-to-state f st) 0)))
;; A threshold Boolean function with named inputs.  `weights` maps every
;; input variable to its weight and `threshold` is the threshold the
;; weighted sum is compared against (see apply-tbf/state).
(struct tbf/state ([weights : (VariableMapping Real)]
                   [threshold : Real])
  #:transparent
  #:type-name TBF/State)

;; Short aliases for the two field accessors.
(define tbf/state-w tbf/state-weights)
(define tbf/state-θ tbf/state-threshold)

;; Builds a TBF/State from an association list of (variable . weight)
;; pairs and a threshold.
(: make-tbf/state (-> (Listof (Pairof Variable Real)) Real TBF/State))
(define (make-tbf/state pairs threshold)
  (define weights : (VariableMapping Real) (make-immutable-hash pairs))
  (tbf/state weights threshold))

(module+ test
  (test-case "tbf/state"
    (define f (make-tbf/state '((x1 . 1) (x2 . 1)) 1))
    (check-equal? (tbf/state-w f) #hash((x1 . 1) (x2 . 1)))
    (check-equal? (tbf/state-θ f) 1)))
;; A TBF/State counts as an SBF exactly when its threshold is zero.
(: sbf/state? (-> TBF/State Boolean))
(define (sbf/state? tbfs)
  (= 0 (tbf/state-θ tbfs)))

(module+ test
  (test-case "sbf/state?"
    (check-true (sbf/state? (tbf/state (hash 'a -1 'b 1) 0)))
    (check-false (sbf/state? (tbf/state (hash 'a -1 'b 1) 1)))))
;; Evaluates the TBF `tbfs` on the state `st`.
;;
;; Only the variables present both in the weights and in `st` contribute.
;; Their weighted sum is compared to the threshold with a strict `>`, and
;; the resulting boolean is converted to 0/1 by any->01.
(: apply-tbf/state (-> TBF/State (State (U Zero One)) (U Zero One)))
(define (apply-tbf/state tbfs st)
  (define weighted-inputs
    (hash-intersect (tbf/state-w tbfs) st #:combine *))
  (define total (apply + (hash-values weighted-inputs)))
  (any->01 (> total (tbf/state-θ tbfs))))

(module+ test
  (test-case "apply-tbf/state"
    (define st1 (hash 'a 1 'b 0 'c 1))
    (define st2 (hash 'a 1 'b 1 'c 0))
    (define tbf (make-tbf/state '((a . 2) (b . -2)) 1))
    (check-equal? (apply-tbf/state tbf st1) 1)
    (check-equal? (apply-tbf/state tbf st2) 0)))
;; Removes from `tbf` every input whose weight is zero; the threshold is
;; kept unchanged.
(: compact-tbf (-> TBF/State TBF/State))
(define (compact-tbf tbf)
  (define nonzero-weights
    (for/hash : (VariableMapping Real)
              ([(var w) (in-hash (tbf/state-w tbf))]
               #:when (not (zero? w)))
      (values var w)))
  (tbf/state nonzero-weights (tbf/state-θ tbf)))

(module+ test
  (test-case "compact-tbf"
    (check-equal? (compact-tbf (tbf/state (hash 'a 0 'b 1 'c 2 'd 0) 2))
                  (tbf/state '#hash((b . 1) (c . 2)) 2))))
;; Converts each row of `lsts` into a TBF/State.
;;
;; Every row lists the weights followed by the threshold.  The weights are
;; paired positionally with `vars`, and pairing stops at the shorter of the
;; two lists.
(: lists+vars->tbfs/state (-> (Listof Variable) (Listof (Listof Real))
                              (Listof TBF/State)))
(define (lists+vars->tbfs/state vars lsts)
  (for/list : (Listof TBF/State) ([row (in-list lsts)])
    (define weights (drop-right row 1))
    (define θ (last row))
    (make-tbf/state (for/list : (Listof (Pairof Variable Real))
                              ([var (in-list vars)]
                               [w (in-list weights)])
                      (cons var w))
                    θ)))

(module+ test
  (test-case "lists+vars->tbfs/state"
    (check-equal? (lists+vars->tbfs/state '(x y) '((1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 3)
                        (tbf/state '#hash((x . 1) (y . 1)) 2)))))
;; Like lists+vars->tbfs/state, but the variable names come from the first
;; row of the table.  The last header names the threshold column and is
;; not an input variable.
(: lists+headers->tbfs/state (-> (Pairof (Listof Variable) (Listof (Listof Real)))
                                 (Listof TBF/State)))
(define (lists+headers->tbfs/state lsts+headers)
  (define header (car lsts+headers))
  (define rows (cdr lsts+headers))
  (lists+vars->tbfs/state (drop-right header 1) rows))

(module+ test
  (test-case "lists+headers->tbfs/state"
    (check-equal? (lists+headers->tbfs/state '((x y f) (1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 3)
                        (tbf/state '#hash((x . 1) (y . 1)) 2)))))
;; Like lists+vars->tbfs/state, with the inputs named x0, x1, ...
;;
;; Every row is w_0 ... w_{n-1} θ, so exactly one name is generated per
;; weight column; the threshold column gets no name.  The number of
;; columns is read from the first row.  An empty list of rows yields an
;; empty list of TBFs.
(: lists->tbfs/state (-> (Listof (Listof Real)) (Listof TBF/State)))
(define (lists->tbfs/state lsts)
  (cond
    [(null? lsts) '()]
    [else
     (define n-weights (sub1 (length (car lsts))))
     (lists+vars->tbfs/state
      (for/list : (Listof Variable) ([i (in-range n-weights)])
        (string->symbol (format "x~a" i)))
      lsts)]))

(module+ test
  (test-case "lists->tbfs/state"
    (check-equal? (lists->tbfs/state '((1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))
    (check-equal? (lists->tbfs/state '()) '())))
;; Dispatches on `#:headers`.  When it is true, the first row of `lsts` is a
;; header (see lists+headers->tbfs/state).  Otherwise the inputs are named
;; x0, x1, ... (see lists->tbfs/state).
(: lists->tbfs/state/opt-headers (-> (Listof (Listof (U Variable Real)))
                                     #:headers Boolean
                                     (Listof TBF/State)))
(define (lists->tbfs/state/opt-headers lsts #:headers hdr)
  (cond
    [hdr
     (lists+headers->tbfs/state
      (assert-type lsts (Pairof (Listof Variable) (Listof (Listof Real)))))]
    [else
     (lists->tbfs/state
      (assert-type lsts (Listof (Listof Real))))]))

(module+ test
  (test-case "lists->tbfs/state/opt-headers"
    (check-equal?
     (lists->tbfs/state/opt-headers '((1 2 3) (1 1 2)) #:headers #f)
     (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
           (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))
    (check-equal?
     (lists->tbfs/state/opt-headers '((x y f) (1 2 3) (1 1 2)) #:headers #t)
     (list (tbf/state '#hash((x . 1) (y . 2)) 3)
           (tbf/state '#hash((x . 1) (y . 1)) 2)))))
;; Converts each row of weights in `lsts` into a TBF/State with threshold 0.
;; The weights are paired with `vars` by `map`.
(: lists+vars->sbfs/state (-> (Listof Variable) (Listof (Listof Real))
                              (Listof TBF/State)))
(define (lists+vars->sbfs/state vars lsts)
  (define (row->sbf [row : (Listof Real)]) : TBF/State
    (make-tbf/state (map (inst cons Variable Real) vars row) 0))
  (map row->sbf lsts))

(module+ test
  (test-case "lists+vars->sbfs/state"
    (check-equal? (lists+vars->sbfs/state '(x y) '((1 2) (1 1)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 0)
                        (tbf/state '#hash((x . 1) (y . 1)) 0)))))
;; Like lists+vars->sbfs/state, with the variable names taken from the
;; first row.  There is no threshold column here.
(: lists+headers->sbfs/state (-> (Pairof (Listof Variable) (Listof (Listof Real)))
                                 (Listof TBF/State)))
(define (lists+headers->sbfs/state lsts)
  (let ([header (car lsts)]
        [rows (cdr lsts)])
    (lists+vars->sbfs/state header rows)))

(module+ test
  (test-case "lists+headers->sbfs/state"
    (check-equal? (lists+headers->sbfs/state '((x y) (1 2) (1 1)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 0)
                        (tbf/state '#hash((x . 1) (y . 1)) 0)))))
;; Like lists+vars->sbfs/state, with the inputs named x0, x1, ...
;; The number of inputs is read from the first row.  An empty list of rows
;; yields an empty list of SBFs.
(: lists->sbfs/state (-> (Listof (Listof Real)) (Listof TBF/State)))
(define (lists->sbfs/state lsts)
  (if (null? lsts)
      '()
      (lists+vars->sbfs/state
       (for/list : (Listof Variable) ([i (in-range (length (car lsts)))])
         (string->symbol (format "x~a" i)))
       lsts)))

(module+ test
  (test-case "lists->sbfs/state"
    (check-equal? (lists->sbfs/state '((1 2) (1 1)))
                  (list
                   (tbf/state '#hash((x0 . 1) (x1 . 2)) 0)
                   (tbf/state '#hash((x0 . 1) (x1 . 1)) 0)))
    (check-equal? (lists->sbfs/state '()) '())))
;; Reads a list of TBFs from the string `str` via read-org-sexp; the table
;; has no header row (see lists->tbfs/state).
(: read-org-tbfs/state (-> String (Listof TBF/State)))
(define (read-org-tbfs/state str)
  (define rows (assert-type (read-org-sexp str)
                            (Listof (Listof Real))))
  (lists->tbfs/state rows))

(module+ test
  (test-case "read-org-tbfs/state"
    (check-equal? (read-org-tbfs/state "((1 2 3) (1 1 2))")
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))))
;; Reads a list of TBFs from the string `str` via read-org-sexp; the first
;; row of the table is a header (see lists+headers->tbfs/state).
(: read-org-tbfs/state+headers (-> String (Listof TBF/State)))
(define (read-org-tbfs/state+headers str)
  (define table (assert-type (read-org-sexp str)
                             (Pairof (Listof Variable) (Listof (Listof Real)))))
  (lists+headers->tbfs/state table))

(module+ test
  (test-case "read-org-tbfs/state+headers"
    (check-equal? (read-org-tbfs/state+headers "((a b f) (1 2 3) (1 1 2))")
                  (list (tbf/state '#hash((a . 1) (b . 2)) 3)
                        (tbf/state '#hash((a . 1) (b . 1)) 2)))))
;; Converts every TBF into a row: its weights in normalized key order
;; (hash-map with try-order? = #t), followed by its threshold.
(: tbfs/state->lists (-> (Listof TBF/State) (Listof (Listof Real))))
(define (tbfs/state->lists tbfs)
  (: ordered-weights (-> TBF/State (Listof Real)))
  (define (ordered-weights tbf)
    (hash-map (tbf/state-w tbf) (λ (_ [w : Real]) w) #t))
  (for/list : (Listof (Listof Real)) ([tbf (in-list tbfs)])
    (append (ordered-weights tbf) (list (tbf/state-θ tbf)))))

(module+ test
  (test-case "tbfs/state->lists"
    (check-equal?
     (tbfs/state->lists (list (tbf/state (hash 'a 1 'b 2) 3)
                              (tbf/state (hash 'a -2 'b 1) 1)))
     '((1 2 3) (-2 1 1)))))
;; Like tbfs/state->lists, with a header row in front.  The header holds the
;; input names of the first TBF in key order, followed by θ for the
;; threshold column.  `tbfs` must be non-empty.
(: tbfs/state->lists+headers (-> (Listof TBF/State)
                                 (Pairof (Listof Variable)
                                         (Listof (Listof Real)))))
(define (tbfs/state->lists+headers tbfs)
  (define input-names
    (hash-map (tbf/state-w (car tbfs)) (λ ([x : Symbol] _) x) #t))
  (define header (append input-names '(θ)))
  (cons header (tbfs/state->lists tbfs)))

(module+ test
  (test-case "tbfs/state->list+headers"
    (check-equal?
     (tbfs/state->lists+headers (list (tbf/state (hash 'a 1 'b 2) 3)
                                      (tbf/state (hash 'a -2 'b 1) 1)))
     '((a b θ)
       (1 2 3)
       (-2 1 1)))))
;; Converts every SBF into the row of its weights in normalized key order.
;; Thresholds are not included.
(: sbfs/state->lists (-> (Listof TBF/State) (Listof (Listof Real))))
(define (sbfs/state->lists tbfs)
  (map (λ ([tbf : TBF/State])
         (hash-map (tbf/state-w tbf) (λ (_ [w : Real]) w) #t))
       tbfs))

(module+ test
  (test-case "sbfs/state->lists"
    (check-equal?
     (sbfs/state->lists (list (tbf/state (hash 'a 1 'b 2) 0)
                              (tbf/state (hash 'a -2 'b 1) 0)))
     '((1 2) (-2 1)))))
;; Like sbfs/state->lists, with a header row of the first SBF's input names
;; in key order.  `tbfs` must be non-empty.
(: sbfs/state->lists+headers (-> (Listof TBF/State)
                                 (Pairof (Listof Variable)
                                         (Listof (Listof Real)))))
(define (sbfs/state->lists+headers tbfs)
  (define header
    (hash-map (tbf/state-w (car tbfs)) (λ ([x : Symbol] _) x) #t))
  (define rows (sbfs/state->lists tbfs))
  (cons header rows))

(module+ test
  (test-case "sbfs/state->list+headers"
    (check-equal?
     (sbfs/state->lists+headers (list (tbf/state (hash 'a 1 'b 2) 0)
                                      (tbf/state (hash 'a -2 'b 1) 0)))
     '((a b)
       (1 2)
       (-2 1)))))
;; Tabulates the TBFs `tbfs` over every 0/1 assignment of the inputs of the
;; first TBF, delegating the enumeration to tabulate-state*.
(: tabulate-tbfs/state (-> (Listof TBF/State) (Listof (Listof Real))))
(define (tabulate-tbfs/state tbfs)
  (define inputs
    (hash-map (tbf/state-w (car tbfs)) (λ ([x : Variable] _) x) #t))
  (define funcs (map (curry apply-tbf/state) tbfs))
  (tabulate-state* funcs (make-same-domains inputs '(0 1))))

(module+ test
  (test-case "tabulate-tbfs/state"
    (check-equal? (tabulate-tbfs/state
                   (list (tbf/state (hash 'a 1 'b 2) 2)
                         (tbf/state (hash 'a -2 'b 2) 1)))
                  '((0 0 0 0)
                    (0 1 0 1)
                    (1 0 0 0)
                    (1 1 1 0)))))
;; Like tabulate-tbfs/state, but with a header row produced by
;; tabulate-state*+headers.
(: tabulate-tbfs/state+headers (-> (Listof TBF/State) (Pairof (Listof Variable)
                                                              (Listof (Listof Real)))))
(define (tabulate-tbfs/state+headers tbfs)
  (define inputs
    (hash-map (tbf/state-w (car tbfs)) (λ ([x : Variable] _) x) #t))
  (define funcs (map (curry apply-tbf/state) tbfs))
  (tabulate-state*+headers funcs (make-same-domains inputs '(0 1))))

(module+ test
  (test-case "tabulate-tbfs/state+headers"
    (check-equal? (tabulate-tbfs/state+headers
                   (list (tbf/state (hash 'a 1 'b 2) 2)
                         (tbf/state (hash 'a -2 'b 2) 1)))
                  '((a b f1 f2)
                    (0 0 0 0)
                    (0 1 0 1)
                    (1 0 0 0)
                    (1 1 1 0)))))
;; Tabulates a single TBF by treating it as a one-element list of TBFs.
(: tabulate-tbf/state (-> TBF/State (Listof (Listof Real))))
(define (tabulate-tbf/state tbf)
  (define singleton (list tbf))
  (tabulate-tbfs/state singleton))

(module+ test
  (test-case "tabulate-tbf/state"
    (check-equal? (tabulate-tbf/state (tbf/state (hash 'a 1 'b 2) 2))
                  '((0 0 0)
                    (0 1 0)
                    (1 0 0)
                    (1 1 1)))))
;; Tabulates a single TBF, with headers, by treating it as a one-element
;; list of TBFs.
(: tabulate-tbf/state+headers (-> TBF/State (Pairof (Listof Variable)
                                                    (Listof (Listof Real)))))
(define (tabulate-tbf/state+headers tbf)
  (define singleton (list tbf))
  (tabulate-tbfs/state+headers singleton))

(module+ test
  (test-case "tabulate-tbf/state+headers"
    (check-equal? (tabulate-tbf/state+headers (tbf/state (hash 'a 1 'b 2) 2))
                  '((a b f1)
                    (0 0 0)
                    (0 1 0)
                    (1 0 0)
                    (1 1 1)))))
;; Groups the rows of the truth table `tt` by number of activated inputs
;; (nai).  Every column but the last is an input; the last one is the
;; output.  Two rows land in the same group when the sums of their input
;; columns are equal, which for 0/1 entries is the number of inputs set to 1.
(: group-truth-table-by-nai (-> (Listof (Listof Integer))
                                (Listof (Listof (Listof Integer)))))
(define (group-truth-table-by-nai tt)
  (: inputs (-> (Listof Integer) (Listof Integer)))
  (define (inputs row) (drop-right row 1))
  (: same-nai? (-> (Listof Integer) (Listof Integer) Boolean))
  (define (same-nai? xs ys)
    (= (apply + xs) (apply + ys)))
  (group-by inputs tt same-nai?))

(module+ test
  (test-case "group-truth-table-by-nai"
    (check-equal? (group-truth-table-by-nai '((0 0 0 1)
                                              (0 0 1 1)
                                              (0 1 0 0)
                                              (0 1 1 1)
                                              (1 0 0 0)
                                              (1 0 1 0)
                                              (1 1 0 1)
                                              (1 1 1 0)))
                  '(((0 0 0 1))
                    ((0 0 1 1) (0 1 0 0) (1 0 0 0))
                    ((0 1 1 1) (1 0 1 0) (1 1 0 1))
                    ((1 1 1 0))))))
;; A threshold Boolean network: every variable is mapped to the TBF/State
;; computing its next value (see tbn->network).
(define-type TBN (HashTable Variable TBF/State))

;; A TBN is an SBN when every one of its functions satisfies sbf/state?.
(: sbn? (-> TBN Boolean))
(define (sbn? tbn)
  (for/and : Boolean ([f (in-hash-values tbn)])
    (sbf/state? f)))

(module+ test
  (test-case "sbn?"
    (define f1 (tbf/state (hash 'a -1 'b 1) 0))
    (define f2 (tbf/state (hash 'a -1 'b 1) 1))
    (check-true (sbn? (hash 'a f1 'b f1)))
    (check-false (sbn? (hash 'a f1 'b f2)))))
;; Turns `tbn` into a 0/1 network.  The update function of every variable
;; evaluates that variable's TBF on the state with apply-tbf/state.
(: tbn->network (-> TBN (Network (U Zero One))))
(define (tbn->network tbn)
  (: tbf->update-function (-> TBF/State (UpdateFunction (U Zero One))))
  (define (tbf->update-function f)
    (λ ([s : (State (U Zero One))])
      (apply-tbf/state f s)))
  (make-01-network
   (for/hash : (VariableMapping (UpdateFunction (U Zero One)))
             ([(x f) (in-hash tbn)])
     (values x (tbf->update-function f)))))

(module+ test
  (test-case "tbn->network"
    (define tbn-form (hash 'a (tbf/state (hash 'a -1 'b 1) 0)
                           'b (tbf/state (hash 'a -1 'b 1) 1)))
    (define tbn (tbn->network tbn-form))
    (define s (hash 'a 0 'b 1))
    (check-equal? (update tbn s '(a b))
                  #hash((a . 1) (b . 0)))
    (check-equal? (network-domains tbn)
                  #hash((a . (0 1)) (b . (0 1))))))
;; Builds the pretty-printed full state graph of `tbn` under synchronous
;; dynamics (make-syn-dynamics).
(: build-tbn-state-graph (-> TBN Graph))
(define (build-tbn-state-graph tbn)
  (define network (tbn->network tbn))
  (define dynamics ((inst make-syn-dynamics (U Zero One)) network))
  (define full-graph ((inst build-full-state-graph (U Zero One)) dynamics))
  (pretty-print-state-graph full-graph))

(module+ test
  (test-case "build-tbn-state-graph"
    (check-equal? (graphviz
                   (build-tbn-state-graph
                    (hash 'a (tbf/state (hash 'a -1 'b 1) 0)
                          'b (tbf/state (hash 'a -1 'b 1) 1))))
                  "digraph G {\n\tnode0 [label=\"a:0 b:0\"];\n\tnode1 [label=\"a:1 b:1\"];\n\tnode2 [label=\"a:0 b:1\"];\n\tnode3 [label=\"a:1 b:0\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [];\n\t}\n\tsubgraph D {\n\t\tnode1 -> node0 [];\n\t\tnode2 -> node3 [];\n\t\tnode3 -> node0 [];\n\t}\n}\n")))
;; A TBN is normalized when the inputs of each of its TBFs are exactly the
;; variables of the network, compared as sets with set=?.
(: normalized-tbn? (-> TBN Boolean))
(define (normalized-tbn? tbn)
  (define network-vars (hash-keys tbn))
  (for/and ([f (in-hash-values tbn)])
    (set=? network-vars (hash-keys (tbf/state-w f)))))

(module+ test
  (test-case "normalized-tbn?"
    (check-true (normalized-tbn?
                 (hash 'x (tbf/state (hash 'x 0 'y -1) -1)
                       'y (tbf/state (hash 'x -1 'y 0) -1))))
    (check-false (normalized-tbn?
                  (hash 'x (tbf/state (hash 'x 0 ) -1)
                        'y (tbf/state (hash 'y 0) -1))))))
;; Normalizes `tbn` so that every TBF has exactly the network's variables
;; as inputs.  Inputs that are not variables of the network are dropped;
;; missing variables get weight 0.  Thresholds are unchanged.
(: normalize-tbn (-> TBN TBN))
(define (normalize-tbn tbn)
  ;; Weight 0 for every variable of the network.
  (define zero-weights
    (for/hash : (VariableMapping Real) ([x (in-hash-keys tbn)])
      (values x 0)))
  (: normalize-one (-> TBF/State TBF/State))
  (define (normalize-one f)
    ;; Keep only the inputs that are variables of the network ...
    (define kept
      (hash-intersect/tbn-tbf tbn (tbf/state-w f) #:combine (λ (_ w) w)))
    ;; ... and fill in the remaining variables with weight 0.
    (define completed
      (assert-type (hash-union zero-weights kept #:combine (λ (_ w) w))
                   (VariableMapping Real)))
    (tbf/state completed (tbf/state-θ f)))
  (for/hash : TBN ([(x f) (in-hash tbn)])
    (values x (normalize-one f))))

(module+ test
  (test-case "normalize-tbn"
    (check-equal? (normalize-tbn (hash 'x (tbf/state (hash 'x 2) -1)
                                       'y (tbf/state (hash 'y 3) 1)))
                  (hash 'x (tbf/state (hash 'x 2 'y 0) -1)
                        'y (tbf/state (hash 'x 0 'y 3) 1)))))
;; Compacts every TBF of `tbn`.  An input is removed when it is not a
;; variable of the network or when its weight is zero.  Thresholds are
;; unchanged.
(: compact-tbn (-> TBN TBN))
(define (compact-tbn tbn)
  (: keep-input? (-> Variable Real Boolean))
  (define (keep-input? x w)
    (and (hash-has-key? tbn x) (not (zero? w))))
  (: compact-one (-> TBF/State TBF/State))
  (define (compact-one f)
    (define kept
      (for/hash : (VariableMapping Real)
                ([(x w) (in-hash (tbf/state-w f))]
                 #:when (keep-input? x w))
        (values x w)))
    (tbf/state kept (tbf/state-θ f)))
  (for/hash : TBN ([(x f) (in-hash tbn)])
    (values x (compact-one f))))

(module+ test
  (test-case "compact-tbn"
    (check-equal?
     (compact-tbn (hash 'a (tbf/state (hash 'a 0 'b 1 'c 3 'd 0) 0)
                        'b (tbf/state (hash 'a -1 'b 1) -1)))
     (hash
      'a
      (tbf/state '#hash((b . 1)) 0)
      'b
      (tbf/state '#hash((a . -1) (b . 1)) -1)))))
;; Builds the weighted interaction graph of `tbn`.  There is an edge
;; src -> tar labelled w whenever src has weight w in the TBF of tar.
;; Vertices are labelled with the pair (variable . threshold).
;;
;; The TBN is normalized first, so every variable of the network is an
;; input of every TBF.  When #:zero-edges is #f, the edges of weight 0
;; introduced this way (or already present) are omitted.
(: tbn-interaction-graph (->* (TBN) (#:zero-edges Boolean) Graph))
(define (tbn-interaction-graph tbn #:zero-edges [zero-edges #t])
  (define normalized (normalize-tbn tbn))
  (define edges
    (for*/list : (Listof (List Real Variable Variable))
               ([(target f) (in-hash normalized)]
                [(source weight) (in-hash (tbf/state-w f))]
                #:when (or zero-edges (not (zero? weight))))
      (list weight source target)))
  (update-graph
   (weighted-graph/directed edges)
   #:v-func (λ (v)
              (cons v (tbf/state-θ (hash-ref normalized (assert-type v Variable)))))))

(module+ test
  (test-case "tbn-interaction-graph"
    (define tbn (hash 'a (tbf/state (hash 'b 1) 0)
                      'b (tbf/state (hash 'a -1) -1)))
    (check-equal? (graphviz (tbn-interaction-graph tbn))
                  "digraph G {\n\tnode0 [label=\"'(b . -1)\"];\n\tnode1 [label=\"'(a . 0)\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")
    (check-equal? (graphviz (tbn-interaction-graph tbn #:zero-edges #f))
                  "digraph G {\n\tnode0 [label=\"'(b . -1)\"];\n\tnode1 [label=\"'(a . 0)\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")))
;; Relabels the vertices of an interaction graph built by
;; tbn-interaction-graph, whose vertices are (variable . threshold) pairs,
;; as the strings "variable:threshold".
(: pretty-print-tbn-interaction-graph (-> Graph Graph))
(define (pretty-print-tbn-interaction-graph ig)
  (update-graph ig
                #:v-func (λ (v)
                           (match v
                             [(cons var θ) (~a var ":" θ)]))))

(module+ test
  (test-case "pretty-print-tbn-interaction-graph"
    (define tbn (hash 'a (tbf/state (hash 'b 1) 0)
                      'b (tbf/state (hash 'a -1) -1)))
    (check-equal? (graphviz (pretty-print-tbn-interaction-graph
                             (tbn-interaction-graph tbn)))
                  "digraph G {\n\tnode0 [label=\"a:0\"];\n\tnode1 [label=\"b:-1\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"-1\"];\n\t\tnode1 -> node0 [label=\"1\"];\n\t}\n}\n")))
;; Builds the interaction graph of the SBN `sbn`.  This is the same as
;; tbn-interaction-graph, but vertices are labelled with the bare variable,
;; since all thresholds of an SBN are 0.
;;
;; The type exposes the optional #:zero-edges keyword accepted by the
;; definition, matching tbn-interaction-graph, so typed callers can pass it.
(: sbn-interaction-graph (->* (TBN) (#:zero-edges Boolean) Graph))
(define (sbn-interaction-graph sbn #:zero-edges [zero-edges #t])
  (update-graph (tbn-interaction-graph sbn #:zero-edges zero-edges)
                #:v-func (match-lambda
                           [(cons var _) var])))

(module+ test
  (test-case "sbn-interaction-graph"
    (define sbn (hash 'a (tbf/state (hash 'b 2) 0)
                      'b (tbf/state (hash 'a 2) 0)))
    (check-equal? (graphviz (sbn-interaction-graph sbn))
                  "digraph G {\n\tnode0 [label=\"a\"];\n\tnode1 [label=\"b\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node1 [label=\"2\"];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t}\n}\n")
    (check-not-exn (λ () (sbn-interaction-graph sbn #:zero-edges #f)))))
;; Builds a TBN from a table given as a list of rows, e.g. read from an Org
;; table.  Every row holds the weights of one TBF followed by its threshold.
;;
;; #:headers    When true (default), the first row names the columns, and
;;              its last cell names the threshold column.  When false, the
;;              inputs are called x0, x1, ...
;; #:func-names When true (default), the first cell of every row names the
;;              variable whose function the row describes.  The first cell
;;              of the header row, if there is one, is a placeholder.  When
;;              false, the i-th row describes the i-th input variable, in
;;              column order.
(: parse-org-tbn (->* ((Listof (Listof (U Symbol Real))))
                      (#:headers Boolean
                       #:func-names Boolean)
                      TBN))
(define (parse-org-tbn tab #:headers [headers #t] #:func-names [func-names #t])
  (cond [func-names
         (define-values (vars rows) (multi-split-at tab 1))
         (define tbfs (lists->tbfs/state/opt-headers rows #:headers headers))
         ;; Only a header row carries a placeholder in its first cell.
         ;; Without headers, every row starts with a function name and none
         ;; may be skipped.
         (define names (if headers (cdr vars) vars))
         (for/hash : TBN
                   ([tbf (in-list tbfs)]
                    [var (in-list names)])
           (values (assert-type (car var) Variable) tbf))]
        [else
         (define tbfs (lists->tbfs/state/opt-headers tab #:headers headers))
         ;; Take the variables in column order.  Sorting the keys of a
         ;; weight map instead would put x10 before x2 and would ignore the
         ;; order of the header row, mismatching rows and variables.
         (define vars : (Listof Variable)
           (cond
             [headers
              (for/list : (Listof Variable)
                        ([h (in-list (drop-right (car tab) 1))])
                (assert-type h Variable))]
             [(null? tbfs) '()]
             [else
              (for/list : (Listof Variable)
                        ([i (in-range (hash-count (tbf/state-w (car tbfs))))])
                (string->symbol (format "x~a" i)))]))
         (for/hash : TBN ([tbf (in-list tbfs)] [var (in-list vars)])
           (values var tbf))]))

(module+ test
  (test-case "parse-org-tbn"
    (check-equal?
     (parse-org-tbn '((1 2 3) (3 2 1)) #:headers #f #:func-names #f)
     (hash 'x0 (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
           'x1 (tbf/state '#hash((x0 . 3) (x1 . 2)) 1)))
    (check-equal?
     (parse-org-tbn '((a b θ) (1 2 3) (3 2 1)) #:headers #t #:func-names #f)
     (hash
      'a
      (tbf/state '#hash((a . 1) (b . 2)) 3)
      'b
      (tbf/state '#hash((a . 3) (b . 2)) 1)))
    (check-equal?
     (parse-org-tbn '((dummy a b θ) (b 3 2 1) (a 1 2 3))
                    #:headers #t
                    #:func-names #t)
     (hash 'a (tbf/state '#hash((a . 1) (b . 2)) 3)
           'b (tbf/state '#hash((a . 3) (b . 2)) 1)))
    ;; Function names without a header row: no row must be skipped.
    (check-equal?
     (parse-org-tbn '((a 1 2 3) (b 3 2 1)) #:headers #f #:func-names #t)
     (hash 'a (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
           'b (tbf/state '#hash((x0 . 3) (x1 . 2)) 1)))
    ;; Rows follow the column order of the header, not the sorted order.
    (check-equal?
     (parse-org-tbn '((b a θ) (1 2 3) (3 2 1)) #:headers #t #:func-names #f)
     (hash 'b (tbf/state '#hash((a . 2) (b . 1)) 3)
           'a (tbf/state '#hash((a . 2) (b . 3)) 1)))))
;; Reads a TBN from the string `str` via read-org-sexp and hands the table
;; to parse-org-tbn with the same #:headers and #:func-names options.
(: read-org-tbn (->* (String) (#:headers Boolean #:func-names Boolean) TBN))
(define (read-org-tbn str
                      #:headers [headers #t]
                      #:func-names [func-names #t])
  (define table (assert-type (read-org-sexp str)
                             (Listof (Listof (U Symbol Real)))))
  (parse-org-tbn table #:headers headers #:func-names func-names))

(module+ test
  (test-case "read-org-tbn"
    (check-equal?
     (read-org-tbn "((\"-\" \"x\" \"y\" \"θ\") (\"y\" -1 0 -1) (\"x\" 0 -1 -1))")
     (hash
      'x
      (tbf/state '#hash((x . 0) (y . -1)) -1)
      'y
      (tbf/state '#hash((x . -1) (y . 0)) -1)))
    (check-equal?
     (read-org-tbn "((\"x\" \"y\" \"θ\") (-1 0 -1) (0 -1 -1))" #:func-names #f)
     (hash
      'x
      (tbf/state '#hash((x . -1) (y . 0)) -1)
      'y
      (tbf/state '#hash((x . 0) (y . -1)) -1)))
    (check-equal?
     (read-org-tbn "((-1 0 -1) (0 -1 -1))" #:headers #f #:func-names #f)
     (hash
      'x0
      (tbf/state '#hash((x0 . -1) (x1 . 0)) -1)
      'x1
      (tbf/state '#hash((x0 . 0) (x1 . -1)) -1)))))
;; Reads an SBN from the string `str`.  The table has no threshold column.
;; A 0 threshold is appended to every data row, and θ to the header row if
;; there is one, before delegating to parse-org-tbn.
(: read-org-sbn (->* (String) (#:headers Boolean #:func-names Boolean) TBN))
(define (read-org-sbn str
                      #:headers [headers #t]
                      #:func-names [func-names #t])
  (define sexp (assert-type (read-org-sexp str)
                            (Listof (Listof (U Symbol Real)))))
  (: with-0-threshold (-> (Listof (U Symbol Real)) (Listof (U Symbol Real))))
  (define (with-0-threshold row) (append row '(0)))
  (define sexp-ready
    (cond
      [headers (cons (append (car sexp) '(θ))
                     (map with-0-threshold (cdr sexp)))]
      [else (map with-0-threshold sexp)]))
  (parse-org-tbn sexp-ready #:headers headers #:func-names func-names))

(module+ test
  (test-case "read-org-sbn"
    (check-equal? (read-org-sbn "((\"-\" \"x\" \"y\") (\"y\" -1 0) (\"x\" 0 -1))")
                  (hash
                   'x
                   (tbf/state '#hash((x . 0) (y . -1)) 0)
                   'y
                   (tbf/state '#hash((x . -1) (y . 0)) 0)))
    (check-equal? (read-org-sbn "((\"x\" \"y\") (-1 0) (0 -1))" #:func-names #f)
                  (hash
                   'x
                   (tbf/state '#hash((x . -1) (y . 0)) 0)
                   'y
                   (tbf/state '#hash((x . 0) (y . -1)) 0)))
    (check-equal? (read-org-sbn "((-1 0) (0 -1))" #:headers #f #:func-names #f)
                  (hash
                   'x0
                   (tbf/state '#hash((x0 . -1) (x1 . 0)) 0)
                   'x1
                   (tbf/state '#hash((x0 . 0) (x1 . -1)) 0)))))
;; Renders `tbn` as a table (list of rows).  The TBN is normalized first,
;; so every row holds one weight per network variable, in key order,
;; followed by the threshold.  Rows are in the key order of the variables.
;;
;; #:headers    When true (default), prepend a header row naming the input
;;              columns, with the threshold column called θ.
;; #:func-names When true (default), prepend to every row the name of the
;;              variable it describes, and '- to the header row.
(: tbn->lists (->* (TBN) (#:headers Boolean
                          #:func-names Boolean)
                   (Listof (Listof (U Symbol Real)))))
(define (tbn->lists tbn
                    #:headers [headers #t]
                    #:func-names [func-names #t])
  (define named-tbfs
    (hash-map (normalize-tbn tbn)
              (λ ([x : Variable] [tbf : TBF/State]) (cons x tbf))
              #t))
  (define tbfs (map (inst cdr Variable TBF/State) named-tbfs))
  (define table
    (if headers
        (tbfs/state->lists+headers tbfs)
        (tbfs/state->lists tbfs)))
  (cond
    [func-names
     (define names (map (inst car Variable TBF/State) named-tbfs))
     (define first-column (if headers (cons '- names) names))
     (for/list ([name (in-list first-column)]
                [row (in-list table)])
       (cons name row))]
    [else table]))

(module+ test
  (test-case "tbn->lists"
    (define tbn (hash 'a (tbf/state (hash 'b 1) 0)
                      'b (tbf/state (hash 'a -1) -1)))
    (check-equal? (tbn->lists tbn)
                  '((- a b θ) (a 0 1 0) (b -1 0 -1)))
    (check-equal? (tbn->lists tbn #:headers #f)
                  '((a 0 1 0) (b -1 0 -1)))
    (check-equal? (tbn->lists tbn #:func-names #f)
                  '((a b θ) (0 1 0) (-1 0 -1)))
    (check-equal? (tbn->lists tbn #:headers #f #:func-names #f)
                  '((0 1 0) (-1 0 -1)))))
;; Renders the SBN `sbn` as a table, like tbn->lists, but without the
;; threshold column.  The last column of every row is cut off with
;; multi-split-at.
(: sbn->lists (->* (TBN) (#:headers Boolean
                          #:func-names Boolean)
                   (Listof (Listof (U Symbol Real)))))
(define (sbn->lists sbn
                    #:headers [headers #t]
                    #:func-names [func-names #t])
  (define tab (tbn->lists sbn #:headers headers #:func-names func-names))
  (define n-kept-columns (sub1 (length (car tab))))
  (define-values (without-θ _θ-column) (multi-split-at tab n-kept-columns))
  without-θ)

(module+ test
  (test-case "sbn->lists"
    (define sbn (hash 'a (tbf/state (hash 'b 2) 0)
                      'b (tbf/state (hash 'a 2) 0)))
    (check-equal? (sbn->lists sbn)
                  '((- a b) (a 0 2) (b 2 0)))
    (check-equal? (sbn->lists sbn #:headers #f)
                  '((a 0 2) (b 2 0)))
    (check-equal? (sbn->lists sbn #:func-names #f)
                  '((a b) (0 2) (2 0)))
    (check-equal? (sbn->lists sbn #:headers #f #:func-names #f)
                  '((0 2) (2 0)))))