diff --git a/tbn.rkt b/tbn.rkt new file mode 100644 index 0000000..0f3d852 --- /dev/null +++ b/tbn.rkt @@ -0,0 +1,724 @@ +#lang racket + +(require (except-in "utils.rkt" lists-transpose) (submod "utils.rkt" untyped) + "functions.rkt" "networks.rkt" + graph racket/random racket/hash) + +(module+ test + (require rackunit)) + +;;; =================== +;;; TBF/TBN and SBF/SBN +;;; =================== + +;;; Applies a TBF to a state. +;;; +;;; The values of the variables of the state are ordered by hash-map +;;; and fed to the TBF in order. The number of the inputs of the TBF +;;; must match the number of variables in the state. +(define (apply-tbf-to-state tbf st) + (apply-tbf tbf (list->vector (hash-map st (λ (_ val) val))))) + +(module+ test + (test-case "apply-tbf-to-state" + (define st (hash 'x1 0 'x2 1)) + (define f (tbf #(1 1) 1)) + (check-equal? (apply-tbf-to-state f st) 0))) + +;;; A state TBF is a TBF with named inputs. A state TBF can be +;;; applied to states in an unambiguous ways. +(struct tbf/state (weights threshold) #:transparent) + +;;; Shortcuts for acessing fields of a state/tbf. +(define tbf/state-w tbf/state-weights) +(define tbf/state-θ tbf/state-threshold) + +;;; Makes a state/tbf from a list of pairs of names of variables and +;;; weights, as well as a threshold. +(define (make-tbf/state pairs threshold) + (tbf/state (make-immutable-hash pairs) threshold)) + +(module+ test + (test-case "tbf/state" + (define f (make-tbf/state '((x1 . 1) (x2 . 1)) 1)) + (check-equal? (tbf/state-w f) #hash((x1 . 1) (x2 . 1))) + (check-equal? (tbf/state-θ f) 1))) + +;;; A sign Boolean function (SBF) is a TBF whose threshold is 0. +(define sbf/state? (and/c tbf/state? (λ (tbf) (zero? (tbf/state-θ tbf))))) + +(module+ test + (test-case "sbf/state?" + (check-true (sbf/state? (tbf/state #hash((a . -1) (b . 1)) 0))))) + +;;; Makes a state/tbf which is an SBF from a list of pairs of names of +;;; variables and weights. +(define (make-sbf/state pairs) + (make-tbf/state pairs 0)) + +(module+ test + (test-case "make-sbf/state" + (check-equal? (make-sbf/state '((a . -1) (b . 1))) + (make-tbf/state '((a . -1) (b . 1)) 0)))) + +;;; Applies a state TBF to its inputs. +;;; +;;; Applying a TBF consists in multiplying the weights by the +;;; corresponding inputs and comparing the sum of the products to the +;;; threshold. +;;; +;;; This function is similar to apply-tbf, but applies a state TBF (a +;;; TBF with explicitly named inputs) to a state whose values are 0 +;;; and 1. +(define (apply-tbf/state tbf st) + (any->01 (> (foldl + 0 (hash-values + (hash-intersect (tbf/state-w tbf) + st + #:combine *))) + (tbf/state-θ tbf)))) + +(module+ test + (test-case "apply-tbf/state" + (define st1 (hash 'a 1 'b 0 'c 1)) + (define st2 (hash 'a 1 'b 1 'c 0)) + (define tbf (make-tbf/state '((a . 2) (b . -2)) 1)) + (check-equal? (apply-tbf/state tbf st1) 1) + (check-equal? (apply-tbf/state tbf st2) 0))) + +;;; Reads a list of tbf/state from a list of list of numbers. +;;; +;;; The last element of each list is taken to be the threshold of the +;;; TBFs, and the rest of the elements are taken to be the weights. +;;; +;;; If headers is #t, the names of the variables to appear as the +;;; inputs of the TBF are taken from the first list. The last element +;;; of this list is discarded. +;;; +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +(define (lists->tbfs/state lsts #:headers [headers #t]) + (define-values (var-names rows) + (if headers + (values (car lsts) (cdr lsts)) + (values (for/list ([i (in-range (length (car lsts)))]) + (string->symbol (format "x~a" i))) + lsts))) + (for/list ([lst (in-list rows)]) + (define-values (ws θ) (split-at-right lst 1)) + (make-tbf/state (for/list ([x (in-list var-names)] + [w (in-list ws)]) + (cons x w)) + (car θ)))) + +(module+ test + (test-case "lists->tbfs/state" + (define tbfs '((1 2 3) (1 1 2))) + (check-equal? (lists->tbfs/state tbfs #:headers #f) + (list + (tbf/state '#hash((x0 . 1) (x1 . 2)) 3) + (tbf/state '#hash((x0 . 1) (x1 . 1)) 2))) + (check-equal? (lists->tbfs/state (cons '(a b f) tbfs)) + (list + (tbf/state '#hash((a . 1) (b . 2)) 3) + (tbf/state '#hash((a . 1) (b . 1)) 2))))) + +;;; Like lists->tbfs/state, but does not expect thresholds in the +;;; input. +;;; +;;; Every lists in the list contains the weights of the SBF. If +;;; headers is #t, the names of the variables to appear as the inputs +;;; of the TBF are taken from the first list. +;;; +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +(define (lists->sbfs/state lsts #:headers [headers #t]) + (define rows (if headers (cdr lsts) lsts)) + (define rows-θ (for/list ([lst (in-list rows)]) (append lst '(0)))) + (lists->tbfs/state (if headers (cons (car lsts) rows-θ) rows-θ) + #:headers headers)) + +(module+ test + (test-case "lists->sbfs/state" + (define tbfs '((1 2) (1 -1))) + (check-equal? (lists->sbfs/state tbfs #:headers #f) + (list + (tbf/state '#hash((x0 . 1) (x1 . 2)) 0) + (tbf/state '#hash((x0 . 1) (x1 . -1)) 0))) + (check-equal? (lists->sbfs/state (cons '(a b) tbfs) #:headers #t) + (list + (tbf/state '#hash((a . 1) (b . 2)) 0) + (tbf/state '#hash((a . 1) (b . -1)) 0))))) + +;;; Reads a list of tbf/state from an Org-mode string containing a +;;; sexp, containing a list of lists of numbers. As in +;;; lists->tbfs/state, the last element of each list is taken to be +;;; the threshold of the TBFs, and the rest of the elements are taken +;;; to be the weights. +;;; +;;; If headers is #t, the names of the variables to appear as the +;;; inputs of the TBF are taken from the first list. The last element +;;; of this list is discarded. +;;; +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +(define (read-org-tbfs/state str #:headers [headers #t]) + (lists->tbfs/state (read-org-sexp str) #:headers headers)) + +(module+ test + (test-case "read-org-tbfs/state" + (check-equal? (read-org-tbfs/state "((a b f) (1 2 3) (1 1 2))") + (list + (tbf/state '#hash((a . 1) (b . 2)) 3) + (tbf/state '#hash((a . 1) (b . 1)) 2))) + (check-equal? (read-org-tbfs/state "((1 2 3) (1 1 2))" #:headers #f) + (list + (tbf/state '#hash((x0 . 1) (x1 . 2)) 3) + (tbf/state '#hash((x0 . 1) (x1 . 1)) 2))))) + +;;; Like read-org-tbfs/state, but reads a list of SBFs. Therefore, +;;; the lists of numbers in the sexp are taken to be the weights of +;;; the SBFs. +;;; +;;; If headers is #t, the names of the variables to appear as the +;;; inputs of the TBF are taken from the first list. If headers is +;;; #f, the names of the variables are generated as xi, where i is the +;;; index of the variable. +(define (read-org-sbfs/state str #:headers [headers #t]) + (lists->sbfs/state (read-org-sexp str) #:headers headers)) + +(module+ test + (test-case "read-org-sbfs/state" + (check-equal? (read-org-sbfs/state "((a b) (-1 2) (1 1))") + (list + (tbf/state '#hash((a . -1) (b . 2)) 0) + (tbf/state '#hash((a . 1) (b . 1)) 0))) + (check-equal? (read-org-sbfs/state "((-1 2) (1 1))" #:headers #f) + (list + (tbf/state '#hash((x0 . -1) (x1 . 2)) 0) + (tbf/state '#hash((x0 . 1) (x1 . 1)) 0))))) + +;;; Given a list of tbf/state, produces a sexp that Org-mode can +;;; interpret as a table. +;;; +;;; All tbf/state in the list must have the same inputs. The function +;;; does not check this property. +;;; +;;; If #:headers is #f, does not print the names of the inputs of the +;;; TBFs. If #:headers is #t, the output starts by a list giving the +;;; names of the variables, as well as the symbol 'θ to represent the +;;; column giving the thresholds of the TBF. +(define (print-org-tbfs/state tbfs #:headers [headers #t]) + (define table (for/list ([tbf (in-list tbfs)]) + (append (hash-map (tbf/state-w tbf) (λ (_ w) w) #t) + (list (tbf/state-θ tbf))))) + (if headers + (cons (append (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t) '(θ)) + table) + table)) + +(module+ test + (test-case "print-org-tbfs/state" + (define tbfs (list (make-tbf/state '((a . 1) (b . 2)) 3) + (make-tbf/state '((a . -2) (b . 1)) 1))) + (check-equal? (print-org-tbfs/state tbfs) + '((a b θ) (1 2 3) (-2 1 1))))) + +;;; Like print-org-tbfs/state, but expects a list of SBFs. The +;;; thresholds are therefore not included in the output. +;;; +;;; All sbf/state in the list must have the same inputs. The function +;;; does not check this property. +;;; +;;; If #:headers is #f, does not print the names of the inputs of the +;;; TBFs. If #:headers is #t, the output starts by a list giving the +;;; names of the variables. +(define (print-org-sbfs/state sbfs #:headers [headers #t]) + (define table (for/list ([sbf (in-list sbfs)]) + (hash-map (tbf/state-w sbf) (λ (_ w) w) #t))) + (if headers + (cons (hash-map (tbf/state-w (car sbfs)) (λ (x _) x) #t) + table) + table)) + +(module+ test + (define sbfs (list (make-sbf/state '((a . 1) (b . 2))) + (make-sbf/state '((a . -2) (b . 1))))) + (check-equal? (print-org-sbfs/state sbfs) + '((a b) (1 2) (-2 1))) + (check-equal? (print-org-sbfs/state sbfs #:headers #f) + '((1 2) (-2 1)))) + +;;; Tabulates a list of tbf/state. +;;; +;;; As in the case of tbf-tabulate*, the result is a list of lists +;;; giving the truth tables of the given TBFs. The first elements of +;;; each row give the values of the inputs, while the last elements +;;; give the values of each function corresponding to the input. +;;; +;;; All the TBFs must have exactly the same inputs. This function +;;; does not check this property. +;;; +;;; If #:headers is #t, the output starts by a list giving the names +;;; of the variables, and then the symbols 'fi, where i is the number +;;; of the TBF in the list. +(define (tbf/state-tabulate* tbfs #:headers [headers #t]) + (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t)) + (tabulate-state* (map (curry apply-tbf/state) tbfs) + (make-same-domains vars '(0 1)) + #:headers headers)) + +(module+ test + (test-case "tbf/state-tabulate*" + (define tbfs (list (make-tbf/state '((a . 1) (b . 2)) 1) + (make-tbf/state '((a . -2) (b . 3)) 1))) + (check-equal? (tbf/state-tabulate* tbfs) + '((a b f1 f2) + (0 0 0 0) + (0 1 1 1) + (1 0 0 0) + (1 1 1 0))))) + +;;; Like tbf/state-tabulate*, but only tabulates a single TBF. +(define (tbf/state-tabulate tbf #:headers [headers #t]) + (tbf/state-tabulate* (list tbf) #:headers headers)) + +(module+ test + (test-case "tbf/state-tabulate" + (define tbf (make-tbf/state '((a . -2) (b . 3)) 1)) + (check-equal? (tbf/state-tabulate tbf) + '((a b f1) + (0 0 0) + (0 1 1) + (1 0 0) + (1 1 0))))) + +;;; Given a truth table of a Boolean function, groups the lines by the +;;; "number of activated inputs"—the number of inputs which are 1 in +;;; the input vector. +;;; +;;; The truth table must not include the header line. +(define (group-truth-table-by-nai tt) + (define sum (((curry foldl) +) 0)) + (group-by (λ (row) (drop-right row 1)) + tt + (λ (in1 in2) (= (sum in1) (sum in2))))) + +(module+ test + (test-case "group-truth-table-by-nai" + (check-equal? (group-truth-table-by-nai '((0 0 0 1) + (0 0 1 1) + (0 1 0 0) + (0 1 1 1) + (1 0 0 0) + (1 0 1 0) + (1 1 0 1) + (1 1 1 0))) + '(((0 0 0 1)) + ((0 0 1 1) (0 1 0 0) (1 0 0 0)) + ((0 1 1 1) (1 0 1 0) (1 1 0 1)) + ((1 1 1 0)))))) + +;;; A TBN is a network form mapping variables to tbf/state. +;;; +;;; The tbf/state must only reference variables appearing in the +;;; network. This contract does not check this condition. +(define tbn? (hash/c symbol? tbf/state?)) + +;;; Builds a TBN from a list of pairs (variable, tbf/state). +(define make-tbn make-immutable-hash) + +(module+ test + (test-case "make-tbn" + (define tbf-not (make-tbf/state '((a . -1)) -1)) + (define tbf-id (make-sbf/state '((a . 1)))) + (check-equal? (make-tbn `((a . ,tbf-not) (b . ,tbf-id))) + (hash 'a (tbf/state '#hash((a . -1)) -1) + 'b (tbf/state '#hash((a . 1)) 0))))) + +;;; A SBN is a network form mapping variables to sbf/state. +;;; +;;; The tbf/state must only reference variables appearing in the +;;; network. This contract does not check this condition. +(define sbn? (hash/c symbol? sbf/state?)) + +;;; Builds an SBN from a list of pairs (variable, sbf/state). +(define make-sbn make-immutable-hash) + +(module+ test + (test-case "make-sbn" + (define sbf1 (make-sbf/state '((a . -1)))) + (define sbf2 (make-sbf/state '((a . 1)))) + (check-equal? (make-sbn `((a . ,sbf1) (b . ,sbf2))) + (hash 'a (tbf/state '#hash((a . -1)) 0) + 'b (tbf/state '#hash((a . 1)) 0))))) + +;;; Constructs a network from a network form defining a TBN. +(define (tbn->network tbn) + (make-01-network (for/hash ([(var tbf) (in-hash tbn)]) + (values var ((curry apply-tbf/state) tbf))))) + +(module+ test + (test-case "tbn->network" + (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1)))) + (b . ,(make-tbf/state '((a . -1)) -1))))) + (define n (tbn->network tbn)) + (define s1 (hash 'a 0 'b 0)) + (check-equal? (update n s1 '(a b)) + (hash 'a 0 'b 1)) + (check-equal? (network-domains n) #hash((a . (0 1)) (b . (0 1)))) + + (define sbn (make-sbn `((a . ,(make-sbf/state '((b . -1)))) + (b . ,(make-sbf/state '((a . 1))))))) + (define sn (tbn->network sbn)) + (define s2 (hash 'a 1 'b 1)) + (check-equal? (update sn s2 '(a b)) + (hash 'a 0 'b 1)) + (check-equal? (network-domains sn) #hash((a . (0 1)) (b . (0 1)))))) + +;;; A helper function for read-org-tbn and read-org-sbn. It reads a +;;; TBN from an Org-mode sexp containing a list of lists of numbers. +;;; As in lists->tbfs/state, the last element of each list is taken to +;;; be the threshold of the TBFs, and the rest of the elements are +;;; taken to be the weights. +;;; +;;; As in read-org-tbfs/state, if headers is #t, the names of the +;;; variables to appear as the inputs of the TBF are taken from the +;;; first list. The last element of this list is discarded. +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +;;; +;;; If func-names is #t, the first element in every row except the +;;; first one, are taken to be the name of the variable to which the +;;; TBF should be associated. If func-names is #f, the functions are +;;; assigned to variables in alphabetical order. +;;; +;;; func-names cannot be #t if headers is #f. The function does not +;;; check this condition. +(define (parse-org-tbn sexp + #:headers [headers #t] + #:func-names [func-names #t]) + (cond + [(eq? func-names #t) + (define-values (vars rows) (multi-split-at sexp 1)) + (define tbfs (lists->tbfs/state rows #:headers headers)) + (for/hash ([tbf (in-list tbfs)] [var (in-list (cdr vars))]) + (values (car var) tbf))] + [else + (define tbfs (lists->tbfs/state sexp #:headers headers)) + (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t)) + (for/hash ([tbf (in-list tbfs)] [var (in-list vars)]) + (values var tbf))])) + +;;; Reads a TBN from an Org-mode string containing a sexp, containing +;;; a list of lists of numbers. As in lists->tbfs/state, the last +;;; element of each list is taken to be the threshold of the TBFs, and +;;; the rest of the elements are taken to be the weights. +;;; +;;; As in read-org-tbfs/state, if headers is #t, the names of the +;;; variables to appear as the inputs of the TBF are taken from the +;;; first list. The last element of this list is discarded. +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +;;; +;;; If func-names is #t, the first element in every row except the +;;; first one, are taken to be the name of the variable to which the +;;; TBF should be associated. If func-names is #f, the functions are +;;; assigned to variables in alphabetical order. +;;; +;;; func-names cannot be #t if headers is #f. The function does not +;;; check this condition. +(define (read-org-tbn str + #:headers [headers #t] + #:func-names [func-names #t]) + (parse-org-tbn (read-org-sexp str) + #:headers headers + #:func-names func-names)) + +(module+ test + (test-case "read-org-tbn, parse-org-tbn" + (check-equal? (read-org-tbn "((\"-\" \"x\" \"y\" \"θ\") (\"y\" -1 0 -1) (\"x\" 0 -1 -1))") + (hash + 'x + (tbf/state '#hash((x . 0) (y . -1)) -1) + 'y + (tbf/state '#hash((x . -1) (y . 0)) -1))) + (check-equal? (read-org-tbn "((\"x\" \"y\" \"θ\") (-1 0 -1) (0 -1 -1))" #:func-names #f) + (hash + 'x + (tbf/state '#hash((x . -1) (y . 0)) -1) + 'y + (tbf/state '#hash((x . 0) (y . -1)) -1))) + (check-equal? (read-org-tbn "((-1 0 -1) (0 -1 -1))" #:headers #f #:func-names #f) + (hash + 'x0 + (tbf/state '#hash((x0 . -1) (x1 . 0)) -1) + 'x1 + (tbf/state '#hash((x0 . 0) (x1 . -1)) -1))))) + +;;; Like read-org-tbn, but reads an SBN from an Org-mode string +;;; containing a sexp, containing a list of lists of numbers. +;;; +;;; As in read-org-sbfs/state, if headers is #t, the names of the +;;; variables to appear as the inputs of the SBF are taken from the +;;; first list. The last element of this list is discarded. +;;; If headers is #f, the names of the variables are generated as xi, +;;; where i is the index of the variable. +;;; +;;; If func-names is #t, the first element in every row except the +;;; first one, are taken to be the name of the variable to which the +;;; TBF should be associated. If func-names is #f, the functions are +;;; assigned to variables in alphabetical order. +;;; +;;; func-names cannot be #t if headers is #f. The function does not +;;; check this condition. +(define (read-org-sbn str + #:headers [headers #t] + #:func-names [func-names #t]) + (define sexp (read-org-sexp str)) + ;; Inject the 0 thresholds into the rows of the sexp we have just read. + (define (inject-0 rows) (for/list ([row (in-list rows)]) (append row '(0)))) + (define sexp-ready (if headers + (cons (car sexp) (inject-0 (cdr sexp))) + (inject-0 sexp))) + (parse-org-tbn sexp-ready #:headers headers #:func-names func-names)) + +(module+ test + (test-case "read-org-sbn, parse-org-tbn" + (check-equal? (read-org-sbn "((\"-\" \"x\" \"y\") (\"y\" -1 0) (\"x\" 0 -1))") + (hash + 'x + (tbf/state '#hash((x . 0) (y . -1)) 0) + 'y + (tbf/state '#hash((x . -1) (y . 0)) 0))) + (check-equal? (read-org-sbn "((\"x\" \"y\") (-1 0) (0 -1))" #:func-names #f) + (hash + 'x + (tbf/state '#hash((x . -1) (y . 0)) 0) + 'y + (tbf/state '#hash((x . 0) (y . -1)) 0))) + (check-equal? (read-org-sbn "((-1 0) (0 -1))" #:headers #f #:func-names #f) + (hash + 'x0 + (tbf/state '#hash((x0 . -1) (x1 . 0)) 0) + 'x1 + (tbf/state '#hash((x0 . 0) (x1 . -1)) 0))))) + +;;; A shortcut for building the state graphs of TBN. +(define build-tbn-state-graph + (compose pretty-print-state-graph + build-full-state-graph + make-syn-dynamics + tbn->network)) + +;;; Checks whether a TBN is normalized: whether all of the functions +;;; have the same inputs, and whether these inputs are exactly the +;;; variables of the TBN. +(define (normalized-tbn? tbn) + (define tbn-vars (hash-keys tbn)) + (for/and ([tbf (in-list (hash-values tbn))]) + (set=? tbn-vars (hash-keys (tbf/state-w tbf))))) + +(module+ test + (test-case "normalized-tbn?" + (check-false (normalized-tbn? + (make-tbn `((a . ,(make-sbf/state '((b . 1)))) + (b . ,(make-tbf/state '((a . -1)) -1)))))) + (check-true (normalized-tbn? + (make-tbn `((a . ,(make-sbf/state '((a . 1) (b . -1)))) + (b . ,(make-tbf/state '((a . -1) (b . 1)) -1)))))))) + +;;; Normalizes a TBN. +;;; +;;; For every TBF, removes the inputs that are not in the variables of +;;; the TBN, and adds missing inputs with 0 weight. +(define (normalize-tbn tbn) + (define vars-0 (for/hash ([(x _) (in-hash tbn)]) (values x 0))) + (define (normalize-tbf tbf) + ;; Only keep the inputs which are also the variables of tbn. + (define w-pruned (hash-intersect tbn (tbf/state-w tbf) + #:combine (λ (_ w) w))) + ;; Put in the missing inputs with weight 0. + (define w-complete (hash-union vars-0 w-pruned #:combine (λ (_ w) w))) + (tbf/state w-complete (tbf/state-θ tbf))) + (for/hash ([(x tbf) (in-hash tbn)]) (values x (normalize-tbf tbf)))) + +(module+ test + (test-case "normalize-tbn" + (check-equal? (normalize-tbn + (hash 'a (make-sbf/state '((b . 1) (c . 3))) + 'b (make-tbf/state '((a . -1)) -1))) + (hash + 'a + (tbf/state '#hash((a . 0) (b . 1)) 0) + 'b + (tbf/state '#hash((a . -1) (b . 0)) -1))))) + +;;; Compacts (and denormalizes) a TBF by removing all inputs which +;;; are 0. +(define (compact-tbf tbf) + (tbf/state + (for/hash ([(k v) (in-hash (tbf/state-w tbf))] + #:unless (zero? v)) + (values k v)) + (tbf/state-θ tbf))) + +(module+ test + (test-case "compact-tbf" + (check-equal? (compact-tbf (tbf/state (hash 'a 0 'b 1 'c 2 'd 0) 2)) + (tbf/state '#hash((b . 1) (c . 2)) 2)))) + +;;; Compacts a TBN by removing all inputs which are 0 or which are not +;;; variables of the network. +(define (compact-tbn tbn) + (define (remove-0-non-var tbf) + (tbf/state + (for/hash ([(x w) (in-hash (tbf/state-w tbf))] + #:when (hash-has-key? tbn x) + #:unless (zero? w)) + (values x w)) + (tbf/state-θ tbf))) + (for/hash ([(x tbf) (in-hash tbn)]) + (values x (remove-0-non-var tbf)))) + +(module+ test + (test-case "compact-tbn" + (check-equal? + (compact-tbn (hash 'a (tbf/state (hash 'a 0 'b 1 'c 3 'd 0) 0) + 'b (tbf/state (hash 'a -1 'b 1) -1))) + (hash + 'a + (tbf/state '#hash((b . 1)) 0) + 'b + (tbf/state '#hash((a . -1) (b . 1)) -1))))) + +;;; Given TBN, produces a sexp containing the description of the +;;; functions of the TBN that Org-mode can interpret as a table. +;;; +;;; Like print-org-tbfs/state, if #:headers is #f, does not print the +;;; names of the inputs of the TBFs. If #:headers is #t, the output +;;; starts by a list giving the names of the variables, as well as the +;;; symbol 'θ to represent the column giving the thresholds of the +;;; TBF. +;;; +;;; If #:func-names is #t, the first column of the table gives the +;;; variable which the corresponding TBF updates. +;;; +;;; If both #:func-names and #:headers are #t, the first cell of the +;;; first column contains the symbol '-. +(define (print-org-tbn tbn + #:headers [headers #t] + #:func-names [func-names #t]) + (define ntbn (normalize-tbn tbn)) + (define vars-tbfs (hash-map ntbn (λ (x tbf) (cons x tbf)) #t)) + (define tbfs (map cdr vars-tbfs)) + (define tbfs-table (print-org-tbfs/state tbfs #:headers headers)) + (cond + [(eq? func-names #t) + (define vars (map car vars-tbfs)) + (define col-1 (if headers (cons '- vars) vars)) + (for/list ([var (in-list col-1)] [row (in-list tbfs-table)]) + (cons var row))] + [else + tbfs-table])) + +(module+ test + (test-case "print-org-tbn" + (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1)))) + (b . ,(make-tbf/state '((a . -1)) -1))))) + (check-equal? (print-org-tbn tbn) + '((- a b θ) (a 0 1 0) (b -1 0 -1))))) + +;;; Given an SBN, produces a sexp containing the description of the +;;; functions of the SBN that Org-mode can interpret as a table. +;;; This function is therefore very similar to print-org-tbn. +;;; +;;; Like print-org-tbfs/state, if #:headers is #f, does not print the +;;; names of the inputs of the TBFs. If #:headers is #t, the output +;;; starts by a list giving the names of the variables. +;;; +;;; If #:func-names is #t, the first column of the table gives the +;;; variable which the corresponding TBF updates. +;;; +;;; If both #:func-names and #:headers are #t, the first cell of the +;;; first column contains the symbol '-. +(define (print-org-sbn sbn + #:headers [headers #t] + #:func-names [func-names #t]) + (define tab (print-org-tbn sbn #:headers headers #:func-names func-names)) + (define-values (tab-no-θ _) (multi-split-at + tab + (- (length (car tab)) 1))) + tab-no-θ) + +(module+ test + (test-case "print-org-sbn" + (define sbn (hash + 'a + (tbf/state (hash 'b 2) 0) + 'b + (tbf/state (hash 'a 2) 0))) + (check-equal? (print-org-sbn sbn) + '((- a b) (a 0 2) (b 2 0))))) + +;;; Given a TBN, constructs its interaction graph. The nodes of this +;;; graph are labeled with pairs (variable name . threshold), while +;;; the edges are labelled with the weights. +;;; +;;; If #:zero-edges is #t, the edges with zero weights will appear in +;;; the interaction graph. +(define (tbn-interaction-graph tbn + #:zero-edges [zero-edges #t]) + (define ntbn (normalize-tbn tbn)) + (define ig (weighted-graph/directed + (if zero-edges + (for*/list ([(tar tbf) (in-hash ntbn)] + [(src w) (in-hash (tbf/state-w tbf))]) + (list w src tar)) + (for*/list ([(tar tbf) (in-hash ntbn)] + [(src w) (in-hash (tbf/state-w tbf))] + #:unless (zero? w)) + (list w src tar))))) + (update-graph ig #:v-func (λ (x) (cons x (tbf/state-θ (hash-ref ntbn x)))))) + +(module+ test + (test-case "tbn-interaction-graph" + (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1)))) + (b . ,(make-tbf/state '((a . -1)) -1))))) + (check-equal? (graphviz (tbn-interaction-graph tbn)) + "digraph G {\n\tnode0 [label=\"'(b . -1)\\n\"];\n\tnode1 [label=\"'(a . 0)\\n\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n") + (check-equal? (graphviz (tbn-interaction-graph tbn #:zero-edges #f)) + "digraph G {\n\tnode0 [label=\"'(b . -1)\\n\"];\n\tnode1 [label=\"'(a . 0)\\n\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n"))) + +;;; Pretty prints the node labels of the interaction graph of a TBN. +(define (pretty-print-tbn-interaction-graph ig) + (update-graph ig #:v-func (match-lambda + [(cons var weight) (~a var ":" weight)]))) + +(module+ test + (test-case "pretty-print-tbn-interaction-graph" + (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1)))) + (b . ,(make-tbf/state '((a . -1)) -1))))) + (check-equal? (graphviz (pretty-print-tbn-interaction-graph (tbn-interaction-graph tbn))) + "digraph G {\n\tnode0 [label=\"b:-1\"];\n\tnode1 [label=\"a:0\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n"))) + +;;; Given an SBN, constructs its interaction graph. As in +;;; tbn-interaction-graph, the nodes of this graph are labeled with +;;; the variable names, while the edges are labelled with the weights. +;;; +;;; If #:zero-edges is #t, the edges with zero weights will appear in +;;; the interaction graph. +(define (sbn-interaction-graph sbn + #:zero-edges [zero-edges #t]) + (update-graph (tbn-interaction-graph sbn #:zero-edges zero-edges) + #:v-func (match-lambda + [(cons var _) var]))) + +(module+ test + (test-case "sbn-interaction-graph" + (define sbn (hash + 'a + (tbf/state (hash 'b 2) 0) + 'b + (tbf/state (hash 'a 2) 0))) + (check-equal? (graphviz (sbn-interaction-graph sbn)) + "digraph G {\n\tnode0 [label=\"b\"];\n\tnode1 [label=\"a\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node1 [label=\"2\"];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t}\n}\n")))