dds/tbn.rkt
Sergiu Ivanov 738ad858ae Add sbn?.
2023-05-22 16:27:35 +02:00

1054 lines
42 KiB
Racket
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#lang racket
(require (except-in "utils.rkt" lists-transpose) (submod "utils.rkt" untyped)
"functions.rkt" "networks.rkt"
graph racket/random racket/hash)
(module typed typed/racket
(require (except-in "utils.rkt" lists-transpose)
"utils.rkt" "functions.rkt" "networks.rkt"
typed/graph typed/racket/random)
(require/typed racket/hash
[hash-intersect
(->* ((HashTable Variable Real))
(#:combine (-> Real Real Real))
#:rest (HashTable Variable Real)
(HashTable Variable Real))])
(module+ test
(require typed/rackunit))
(provide
apply-tbf-to-state
(struct-out tbf/state) TBF/State tbf/state-w tbf/state-θ make-tbf/state
sbf/state? apply-tbf/state
lists+vars->tbfs/state lists+headers->tbfs/state lists->tbfs/state
lists+vars->sbfs/state lists+headers->sbfs/state lists->sbfs/state
read-org-tbfs/state read-org-tbfs/state+headers
tbfs/state->lists tbfs/state->lists+headers
tabulate-tbfs/state tabulate-tbfs/state+headers
tabulate-tbf/state tabulate-tbf/state+headers
group-truth-table-by-nai
TBN sbn?
)
;;; Applies a TBF to a state.
;;;
;;; The values of the state are listed by hash-map with try-order?
;;; set to #t, which normalizes the order of the variables (e.g.
;;; symbols are sorted), and are fed to the TBF in that order.  The
;;; number of inputs of the TBF must match the number of variables in
;;; the state; this function does not check it.
(: apply-tbf-to-state (-> TBF (State (U Zero One)) (U Zero One)))
(define (apply-tbf-to-state tbf st)
(apply-tbf tbf (list->vector
(hash-map st (λ (_ [val : (U Zero One)]) val) #t))))
(module+ test
(test-case "apply-tbf-to-state"
(define st (hash 'x1 0 'x2 1))
(define f (tbf #(1 1) 1))
(check-equal? (apply-tbf-to-state f st) 0)))
;;; A state TBF: a TBF whose inputs are named by variables.  weights
;;; maps every input variable to its weight, and threshold is the
;;; threshold of the function.
(struct tbf/state ([weights : (VariableMapping Real)]
[threshold : Real])
#:transparent
#:type-name TBF/State)
;;; Short names for the accessors of tbf/state.
(define tbf/state-w tbf/state-weights)
(define tbf/state-θ tbf/state-threshold)
;;; Builds a tbf/state from a list of (variable . weight) pairs and a
;;; threshold.  If a variable occurs several times in pairs, its last
;;; occurrence wins (as with make-immutable-hash).
(: make-tbf/state (-> (Listof (Pairof Variable Real)) Real TBF/State))
(define (make-tbf/state pairs threshold)
  (define weights (make-immutable-hash pairs))
  (tbf/state weights threshold))
(module+ test
  (test-case "tbf/state"
    (define f (make-tbf/state '((x1 . 1) (x2 . 1)) 1))
    (check-equal? (tbf/state-w f) #hash((x1 . 1) (x2 . 1)))
    (check-equal? (tbf/state-θ f) 1)))
;;; A state SBF (sign Boolean function) is a state TBF whose
;;; threshold is 0.
(: sbf/state? (-> TBF/State Boolean))
(define (sbf/state? tbf)
  (define θ (tbf/state-θ tbf))
  (zero? θ))
(module+ test
  (test-case "sbf/state?"
    (check-true (sbf/state? (tbf/state (hash 'a -1 'b 1) 0)))
    (check-false (sbf/state? (tbf/state (hash 'a -1 'b 1) 1)))))
;;; Applies a state TBF to a state with 0/1 values.
;;;
;;; Every weight is multiplied by the value of the same variable in
;;; st (variables present in only one of the two are ignored).  The
;;; result is 1 if the sum of these products is strictly greater than
;;; the threshold, and 0 otherwise.
(: apply-tbf/state (-> TBF/State (State (U Zero One)) (U Zero One)))
(define (apply-tbf/state tbf st)
  (define products (hash-intersect (tbf/state-w tbf) st #:combine *))
  (define weighted-sum (apply + (hash-values products)))
  (any->01 (> weighted-sum (tbf/state-θ tbf))))
(module+ test
  (test-case "apply-tbf/state"
    (define st1 (hash 'a 1 'b 0 'c 1))
    (define st2 (hash 'a 1 'b 1 'c 0))
    (define tbf (make-tbf/state '((a . 2) (b . -2)) 1))
    (check-equal? (apply-tbf/state tbf st1) 1)
    (check-equal? (apply-tbf/state tbf st2) 0)))
;;; Builds a list of state TBFs from rows of numbers.  In every row,
;;; the last number is the threshold and the others are the weights
;;; of the variables in vars, in order.  If vars and the weights have
;;; different lengths, the extra elements of the longer one are
;;; ignored.
(: lists+vars->tbfs/state (-> (Listof Variable) (Listof (Listof Real))
                              (Listof TBF/State)))
(define (lists+vars->tbfs/state vars lsts)
  (for/list ([row (in-list lsts)])
    (define weights (drop-right row 1))
    (define threshold (last row))
    (define pairs
      (for/list ([var (in-list vars)]
                 [w (in-list weights)])
        (cons var w)))
    (make-tbf/state pairs threshold)))
(module+ test
  (test-case "lists+vars->tbfs/state"
    (check-equal? (lists+vars->tbfs/state '(x y) '((1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 3)
                        (tbf/state '#hash((x . 1) (y . 1)) 2)))))
;;; Like lists+vars->tbfs/state, but the variable names come from a
;;; header row.  The last header names the threshold column and is
;;; not used as a variable.
(: lists+headers->tbfs/state (-> (Pairof (Listof Variable) (Listof (Listof Real)))
                                 (Listof TBF/State)))
(define (lists+headers->tbfs/state lsts+headers)
  (define headers (car lsts+headers))
  (define rows (cdr lsts+headers))
  (define vars (drop-right headers 1))
  (lists+vars->tbfs/state vars rows))
(module+ test
  (test-case "lists+headers->tbfs/state"
    (check-equal? (lists+headers->tbfs/state '((x y f) (1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 3)
                        (tbf/state '#hash((x . 1) (y . 1)) 2)))))
;;; Like lists+vars->tbfs/state, but the variables are named x0, x1,
;;; ... by column index.  The last column holds the thresholds, so
;;; exactly one name is generated per weight column.
;;;
;;; An empty list of rows yields an empty list of TBFs, as with
;;; lists+vars->tbfs/state.
(: lists->tbfs/state (-> (Listof (Listof Real)) (Listof TBF/State)))
(define (lists->tbfs/state lsts)
  (if (null? lsts)
      '()
      (lists+vars->tbfs/state
       (for/list ([i (in-range (sub1 (length (car lsts))))])
         (string->symbol (format "x~a" i)))
       lsts)))
(module+ test
  (test-case "lists->tbfs/state"
    (check-equal? (lists->tbfs/state '((1 2 3) (1 1 2)))
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))
    (check-equal? (lists->tbfs/state '()) '())))
;;; Builds a list of state SBFs (threshold 0) from rows of weights.
;;; The i-th weight of every row is associated with the i-th variable
;;; of vars; every row must have as many weights as there are vars,
;;; because map requires lists of equal length.
(: lists+vars->sbfs/state (-> (Listof Variable) (Listof (Listof Real))
                              (Listof TBF/State)))
(define (lists+vars->sbfs/state vars lsts)
  (: row->sbf (-> (Listof Real) TBF/State))
  (define (row->sbf weights)
    (make-tbf/state (map (inst cons Variable Real) vars weights) 0))
  (map row->sbf lsts))
(module+ test
  (test-case "lists+vars->sbfs/state"
    (check-equal? (lists+vars->sbfs/state '(x y) '((1 2) (1 1)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 0)
                        (tbf/state '#hash((x . 1) (y . 1)) 0)))))
;;; Like lists+vars->sbfs/state, but the variable names are given by
;;; the first row.
(: lists+headers->sbfs/state (-> (Pairof (Listof Variable) (Listof (Listof Real)))
                                 (Listof TBF/State)))
(define (lists+headers->sbfs/state lsts)
  (define vars (car lsts))
  (define rows (cdr lsts))
  (lists+vars->sbfs/state vars rows))
(module+ test
  (test-case "lists+headers->sbfs/state"
    (check-equal? (lists+headers->sbfs/state '((x y) (1 2) (1 1)))
                  (list (tbf/state '#hash((x . 1) (y . 2)) 0)
                        (tbf/state '#hash((x . 1) (y . 1)) 0)))))
;;; Like lists+vars->sbfs/state, but the variables are named x0, x1,
;;; ... by column index (every column holds weights).
;;;
;;; An empty list of rows yields an empty list of SBFs instead of
;;; failing on (car '()).
(: lists->sbfs/state (-> (Listof (Listof Real)) (Listof TBF/State)))
(define (lists->sbfs/state lsts)
  (if (null? lsts)
      '()
      (lists+vars->sbfs/state
       (for/list ([i (in-range (length (car lsts)))])
         (string->symbol (format "x~a" i)))
       lsts)))
(module+ test
  (test-case "lists->sbfs/state"
    (check-equal? (lists->sbfs/state '((1 2) (1 1)))
                  (list
                   (tbf/state '#hash((x0 . 1) (x1 . 2)) 0)
                   (tbf/state '#hash((x0 . 1) (x1 . 1)) 0)))
    (check-equal? (lists->sbfs/state '()) '())))
;;; Reads a list of state TBFs from an Org-mode string containing a
;;; sexp: a list of lists of numbers, one per TBF, whose last element
;;; is the threshold.  The inputs are named x0, x1, ...
(: read-org-tbfs/state (-> String (Listof TBF/State)))
(define (read-org-tbfs/state str)
  (define rows (assert-type (read-org-sexp str) (Listof (Listof Real))))
  (lists->tbfs/state rows))
(module+ test
  (test-case "read-org-tbfs/state"
    (check-equal? (read-org-tbfs/state "((1 2 3) (1 1 2))")
                  (list (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                        (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))))
;;; Like read-org-tbfs/state, but the first list of the sexp gives the
;;; names of the inputs; its last element names the threshold column
;;; and is dropped.
(: read-org-tbfs/state+headers (-> String (Listof TBF/State)))
(define (read-org-tbfs/state+headers str)
  (define table
    (assert-type (read-org-sexp str)
                 (Pairof (Listof Variable) (Listof (Listof Real)))))
  (lists+headers->tbfs/state table))
(module+ test
  (test-case "read-org-tbfs/state+headers"
    (check-equal? (read-org-tbfs/state+headers "((a b f) (1 2 3) (1 1 2))")
                  (list (tbf/state '#hash((a . 1) (b . 2)) 3)
                        (tbf/state '#hash((a . 1) (b . 1)) 2)))))
;;; Converts a list of state TBFs into rows of numbers.  Every row
;;; lists the weights (in the order given by hash-map with try-order?)
;;; followed by the threshold.
(: tbfs/state->lists (-> (Listof TBF/State) (Listof (Listof Real))))
(define (tbfs/state->lists tbfs)
  (: tbf->row (-> TBF/State (Listof Real)))
  (define (tbf->row tbf)
    (define weights (hash-map (tbf/state-w tbf) (λ (_ [w : Real]) w) #t))
    (append weights (list (tbf/state-θ tbf))))
  (map tbf->row tbfs))
(module+ test
  (test-case "tbfs/state->lists"
    (check-equal?
     (tbfs/state->lists (list (tbf/state (hash 'a 1 'b 2) 3)
                              (tbf/state (hash 'a -2 'b 1) 1)))
     '((1 2 3) (-2 1 1)))))
;;; Like tbfs/state->lists, but prepends a header row made of the
;;; input names of the first TBF followed by the symbol θ.  All TBFs
;;; are assumed to share the inputs of the first one; this is not
;;; checked.
(: tbfs/state->lists+headers (-> (Listof TBF/State)
                                 (Pairof (Listof Variable)
                                         (Listof (Listof Real)))))
(define (tbfs/state->lists+headers tbfs)
  (define vars (hash-map (tbf/state-w (car tbfs))
                         (λ ([x : Symbol] _) x) #t))
  (define header (append vars '(θ)))
  (cons header (tbfs/state->lists tbfs)))
(module+ test
  (test-case "tbfs/state->list+headers"
    (check-equal?
     (tbfs/state->lists+headers (list (tbf/state (hash 'a 1 'b 2) 3)
                                      (tbf/state (hash 'a -2 'b 1) 1)))
     '((a b θ)
       (1 2 3)
       (-2 1 1)))))
;;; Converts a list of state SBFs into rows of weights.
;;;
;;; SBFs all have threshold 0, so, unlike tbfs/state->lists, the
;;; threshold is not included: every row only lists the weights (in
;;; the order given by hash-map with try-order?).  This makes the
;;; output the inverse of lists->sbfs/state.
(: sbfs/state->lists (-> (Listof TBF/State) (Listof (Listof Real))))
(define (sbfs/state->lists tbfs)
  (for/list ([tbf (in-list tbfs)])
    (hash-map (tbf/state-w tbf) (λ (_ [w : Real]) w) #t)))
(module+ test
  (test-case "sbfs/state->lists"
    (check-equal?
     (sbfs/state->lists (list (tbf/state (hash 'a 1 'b 2) 0)
                              (tbf/state (hash 'a -2 'b 1) 0)))
     '((1 2) (-2 1)))))
;;; Tabulates a list of state TBFs over the 0/1 states of the inputs
;;; of the first TBF.  Each row gives an input state followed by the
;;; value of every TBF on it.  All TBFs are assumed to have the same
;;; inputs; this is not checked.
(: tabulate-tbfs/state (-> (Listof TBF/State) (Listof (Listof Real))))
(define (tabulate-tbfs/state tbfs)
  (define first-weights (tbf/state-w (car tbfs)))
  (define vars (hash-map first-weights (λ ([x : Variable] _) x) #t))
  (tabulate-state* (map (curry apply-tbf/state) tbfs)
                   (make-same-domains vars '(0 1))))
(module+ test
  (test-case "tabulate-tbfs/state"
    (check-equal? (tabulate-tbfs/state
                   (list (tbf/state (hash 'a 1 'b 2) 2)
                         (tbf/state (hash 'a -2 'b 2) 1)))
                  '((0 0 0 0)
                    (0 1 0 1)
                    (1 0 0 0)
                    (1 1 1 0)))))
;;; Like tabulate-tbfs/state, but the table starts with a header row
;;; naming the inputs and the functions (f1, f2, ...).
(: tabulate-tbfs/state+headers (-> (Listof TBF/State) (Pairof (Listof Variable)
                                                              (Listof (Listof Real)))))
(define (tabulate-tbfs/state+headers tbfs)
  (define first-weights (tbf/state-w (car tbfs)))
  (define vars (hash-map first-weights (λ ([x : Variable] _) x) #t))
  (tabulate-state*+headers (map (curry apply-tbf/state) tbfs)
                           (make-same-domains vars '(0 1))))
(module+ test
  (test-case "tabulate-tbfs/state+headers"
    (check-equal? (tabulate-tbfs/state+headers
                   (list (tbf/state (hash 'a 1 'b 2) 2)
                         (tbf/state (hash 'a -2 'b 2) 1)))
                  '((a b f1 f2)
                    (0 0 0 0)
                    (0 1 0 1)
                    (1 0 0 0)
                    (1 1 1 0)))))
;;; Tabulates a single state TBF; see tabulate-tbfs/state.
(: tabulate-tbf/state (-> TBF/State (Listof (Listof Real))))
(define (tabulate-tbf/state tbf)
  (define singleton (list tbf))
  (tabulate-tbfs/state singleton))
(module+ test
  (test-case "tabulate-tbf/state"
    (check-equal? (tabulate-tbf/state (tbf/state (hash 'a 1 'b 2) 2))
                  '((0 0 0)
                    (0 1 0)
                    (1 0 0)
                    (1 1 1)))))
;;; Tabulates a single state TBF with a header row; see
;;; tabulate-tbfs/state+headers.
(: tabulate-tbf/state+headers (-> TBF/State (Pairof (Listof Variable)
                                                    (Listof (Listof Real)))))
(define (tabulate-tbf/state+headers tbf)
  (define singleton (list tbf))
  (tabulate-tbfs/state+headers singleton))
(module+ test
  (test-case "tabulate-tbf/state+headers"
    (check-equal? (tabulate-tbf/state+headers (tbf/state (hash 'a 1 'b 2) 2))
                  '((a b f1)
                    (0 0 0)
                    (0 1 0)
                    (1 0 0)
                    (1 1 1)))))
;;; Groups the rows of a truth table (without header line) by the
;;; number of activated inputs: the number of 1s among all the
;;; elements of a row except the last one.
(: group-truth-table-by-nai (-> (Listof (Listof Integer))
                                (Listof (Listof (Listof Integer)))))
(define (group-truth-table-by-nai tt)
  (: inputs-of (-> (Listof Integer) (Listof Integer)))
  (define (inputs-of row) (drop-right row 1))
  (: n-active (-> (Listof Integer) Integer))
  (define (n-active inputs) (apply + inputs))
  (group-by inputs-of
            tt
            (λ ([xs : (Listof Integer)] [ys : (Listof Integer)])
              (= (n-active xs) (n-active ys)))))
(module+ test
  (test-case "group-truth-table-by-nai"
    (check-equal? (group-truth-table-by-nai '((0 0 0 1)
                                              (0 0 1 1)
                                              (0 1 0 0)
                                              (0 1 1 1)
                                              (1 0 0 0)
                                              (1 0 1 0)
                                              (1 1 0 1)
                                              (1 1 1 0)))
                  '(((0 0 0 1))
                    ((0 0 1 1) (0 1 0 0) (1 0 0 0))
                    ((0 1 1 1) (1 0 1 0) (1 1 0 1))
                    ((1 1 1 0))))))
;;; A TBN maps every variable to the state TBF that updates it.
(define-type TBN (HashTable Variable TBF/State))
;;; An SBN is a TBN all of whose functions are SBFs (threshold 0).
;;; The empty TBN is an SBN.
(: sbn? (-> TBN Boolean))
(define (sbn? tbn)
  (for/and : Boolean ([f (in-hash-values tbn)])
    (sbf/state? f)))
(module+ test
  (test-case "sbn?"
    (define f1 (tbf/state (hash 'a -1 'b 1) 0))
    (define f2 (tbf/state (hash 'a -1 'b 1) 1))
    (check-true (sbn? (hash 'a f1 'b f1)))
    (check-false (sbn? (hash 'a f1 'b f2)))))
)
(module+ test
(require rackunit))
;;; ===================
;;; TBF/TBN and SBF/SBN
;;; ===================
;;; Applies a TBF to a state.
;;;
;;; The values of the state are collected by hash-map with try-order?
;;; set to #t, which puts the variables in a normalized order (e.g.
;;; symbols sorted with symbol<?), and are fed to the TBF in that
;;; order.  Without try-order?, hash-map gives no ordering guarantee,
;;; so the same state could feed its values to different inputs of
;;; the TBF.
;;;
;;; The number of the inputs of the TBF must match the number of
;;; variables in the state.  This function does not check it.
(define (apply-tbf-to-state tbf st)
  (apply-tbf tbf (list->vector (hash-map st (λ (_ val) val) #t))))
(module+ test
  (test-case "apply-tbf-to-state"
    (define st (hash 'x1 0 'x2 1))
    (define f (tbf #(1 1) 1))
    (check-equal? (apply-tbf-to-state f st) 0)
    ;; With distinct weights the result depends on the input order:
    ;; x1 must be fed to the first weight and x2 to the second.
    (check-equal? (apply-tbf-to-state (tbf #(1 0) 0) (hash 'x1 1 'x2 0)) 1)
    (check-equal? (apply-tbf-to-state (tbf #(1 0) 0) (hash 'x1 0 'x2 1)) 0)))
;;; A state TBF is a TBF with named inputs.  A state TBF can be
;;; applied to states in an unambiguous way.
;;;
;;; weights maps input variables to their weights; threshold is the
;;; threshold of the function.
(struct tbf/state (weights threshold) #:transparent)
;;; Shortcuts for accessing the fields of a tbf/state.
(define tbf/state-w tbf/state-weights)
(define tbf/state-θ tbf/state-threshold)
;;; Makes a tbf/state from a list of (variable . weight) pairs and a
;;; threshold.  If a variable appears several times in pairs, its last
;;; occurrence wins.
(define (make-tbf/state pairs threshold)
  (define weights
    (for/hash ([p (in-list pairs)])
      (values (car p) (cdr p))))
  (tbf/state weights threshold))
(module+ test
  (test-case "tbf/state"
    (define f (make-tbf/state '((x1 . 1) (x2 . 1)) 1))
    (check-equal? (tbf/state-w f) #hash((x1 . 1) (x2 . 1)))
    (check-equal? (tbf/state-θ f) 1)))
;;; A sign Boolean function (SBF) is a TBF whose threshold is 0.
;;;
;;; sbf/state? is a flat contract, usable directly as a predicate,
;;; that accepts tbf/state values whose threshold is zero.
(define sbf/state?
  (let ([zero-threshold? (λ (tbf) (zero? (tbf/state-θ tbf)))])
    (and/c tbf/state? zero-threshold?)))
(module+ test
  (test-case "sbf/state?"
    (check-true (sbf/state? (tbf/state #hash((a . -1) (b . 1)) 0)))))
;;; Makes a tbf/state which is an SBF (threshold 0) from a list of
;;; (variable . weight) pairs.
(define (make-sbf/state pairs)
  (define sbf-threshold 0)
  (make-tbf/state pairs sbf-threshold))
(module+ test
  (test-case "make-sbf/state"
    (check-equal? (make-sbf/state '((a . -1) (b . 1)))
                  (make-tbf/state '((a . -1) (b . 1)) 0))))
;;; Applies a state TBF to a state whose values are 0 and 1.
;;;
;;; Every weight is multiplied by the value of the same variable in
;;; st (variables present in only one of the two are ignored).  The
;;; result is 1 if the sum of these products is strictly greater than
;;; the threshold, and 0 otherwise.
;;;
;;; This function is similar to apply-tbf, but the inputs of the TBF
;;; are identified by name rather than by position.
(define (apply-tbf/state tbf st)
  (define products (hash-intersect (tbf/state-w tbf) st #:combine *))
  (define weighted-sum (foldl + 0 (hash-values products)))
  (any->01 (> weighted-sum (tbf/state-θ tbf))))
(module+ test
  (test-case "apply-tbf/state"
    (define st1 (hash 'a 1 'b 0 'c 1))
    (define st2 (hash 'a 1 'b 1 'c 0))
    (define tbf (make-tbf/state '((a . 2) (b . -2)) 1))
    (check-equal? (apply-tbf/state tbf st1) 1)
    (check-equal? (apply-tbf/state tbf st2) 0)))
;;; Reads a list of tbf/state from a list of lists of numbers.
;;;
;;; The last element of each list is taken to be the threshold of the
;;; TBF, and the rest of the elements are taken to be the weights.
;;;
;;; If headers is #t, the names of the variables to appear as the
;;; inputs of the TBF are taken from the first list.  The last element
;;; of this list (naming the threshold column) is not used.
;;;
;;; If headers is #f, the names of the variables are generated as xi,
;;; where i is the index of the variable; exactly one name is
;;; generated per weight column.
;;;
;;; An empty list of lists yields an empty list of TBFs.
(define (lists->tbfs/state lsts #:headers [headers #t])
  (cond
    [(null? lsts) '()]
    [else
     (define-values (var-names rows)
       (if headers
           (values (car lsts) (cdr lsts))
           ;; The last column holds the thresholds, so only the other
           ;; columns need a variable name.
           (values (for/list ([i (in-range (sub1 (length (car lsts))))])
                     (string->symbol (format "x~a" i)))
                   lsts)))
     (for/list ([lst (in-list rows)])
       (define-values (ws θ) (split-at-right lst 1))
       ;; The zip stops at the shorter of var-names and ws.
       (make-tbf/state (for/list ([x (in-list var-names)]
                                  [w (in-list ws)])
                         (cons x w))
                       (car θ)))]))
(module+ test
  (test-case "lists->tbfs/state"
    (define tbfs '((1 2 3) (1 1 2)))
    (check-equal? (lists->tbfs/state tbfs #:headers #f)
                  (list
                   (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                   (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))
    (check-equal? (lists->tbfs/state (cons '(a b f) tbfs))
                  (list
                   (tbf/state '#hash((a . 1) (b . 2)) 3)
                   (tbf/state '#hash((a . 1) (b . 1)) 2)))
    (check-equal? (lists->tbfs/state '()) '())
    (check-equal? (lists->tbfs/state '() #:headers #f) '())))
;;; Like lists->tbfs/state, but does not expect thresholds in the
;;; input: every SBF gets threshold 0.
;;;
;;; Every list in the list contains the weights of an SBF.  If headers
;;; is #t, the names of the variables to appear as the inputs of the
;;; SBF are taken from the first list.
;;;
;;; If headers is #f, the names of the variables are generated as xi,
;;; where i is the index of the variable.
;;;
;;; An empty list of lists yields an empty list of SBFs.
(define (lists->sbfs/state lsts #:headers [headers #t])
  (cond
    [(null? lsts) '()]
    [else
     (define rows (if headers (cdr lsts) lsts))
     ;; Append a 0 threshold to every row of weights.
     (define rows-θ (for/list ([lst (in-list rows)]) (append lst '(0))))
     (lists->tbfs/state (if headers (cons (car lsts) rows-θ) rows-θ)
                        #:headers headers)]))
(module+ test
  (test-case "lists->sbfs/state"
    (define tbfs '((1 2) (1 -1)))
    (check-equal? (lists->sbfs/state tbfs #:headers #f)
                  (list
                   (tbf/state '#hash((x0 . 1) (x1 . 2)) 0)
                   (tbf/state '#hash((x0 . 1) (x1 . -1)) 0)))
    (check-equal? (lists->sbfs/state (cons '(a b) tbfs) #:headers #t)
                  (list
                   (tbf/state '#hash((a . 1) (b . 2)) 0)
                   (tbf/state '#hash((a . 1) (b . -1)) 0)))
    (check-equal? (lists->sbfs/state '()) '())
    (check-equal? (lists->sbfs/state '() #:headers #f) '())))
;;; Reads a list of tbf/state from an Org-mode string containing a
;;; sexp, which is a list of lists of numbers.  As in
;;; lists->tbfs/state, the last element of each list is taken to be
;;; the threshold of the TBF, and the rest of the elements are taken
;;; to be the weights.
;;;
;;; If headers is #t, the names of the variables to appear as the
;;; inputs of the TBF are taken from the first list.  The last element
;;; of this list is not used.
;;;
;;; If headers is #f, the names of the variables are generated as xi,
;;; where i is the index of the variable.
(define (read-org-tbfs/state str #:headers [headers #t])
  (define sexp (read-org-sexp str))
  (lists->tbfs/state sexp #:headers headers))
(module+ test
  (test-case "read-org-tbfs/state"
    (check-equal? (read-org-tbfs/state "((a b f) (1 2 3) (1 1 2))")
                  (list
                   (tbf/state '#hash((a . 1) (b . 2)) 3)
                   (tbf/state '#hash((a . 1) (b . 1)) 2)))
    (check-equal? (read-org-tbfs/state "((1 2 3) (1 1 2))" #:headers #f)
                  (list
                   (tbf/state '#hash((x0 . 1) (x1 . 2)) 3)
                   (tbf/state '#hash((x0 . 1) (x1 . 1)) 2)))))
;;; Like read-org-tbfs/state, but reads a list of SBFs.  Therefore,
;;; the lists of numbers in the sexp are taken to be the weights of
;;; the SBFs, and every SBF gets threshold 0.
;;;
;;; If headers is #t, the names of the variables to appear as the
;;; inputs of the SBF are taken from the first list.  If headers is
;;; #f, the names of the variables are generated as xi, where i is the
;;; index of the variable.
(define (read-org-sbfs/state str #:headers [headers #t])
  (define sexp (read-org-sexp str))
  (lists->sbfs/state sexp #:headers headers))
(module+ test
  (test-case "read-org-sbfs/state"
    (check-equal? (read-org-sbfs/state "((a b) (-1 2) (1 1))")
                  (list
                   (tbf/state '#hash((a . -1) (b . 2)) 0)
                   (tbf/state '#hash((a . 1) (b . 1)) 0)))
    (check-equal? (read-org-sbfs/state "((-1 2) (1 1))" #:headers #f)
                  (list
                   (tbf/state '#hash((x0 . -1) (x1 . 2)) 0)
                   (tbf/state '#hash((x0 . 1) (x1 . 1)) 0)))))
;;; Given a list of tbf/state, produces a sexp that Org-mode can
;;; interpret as a table.  Every row lists the weights of a TBF (in
;;; the order given by hash-map with try-order?) followed by its
;;; threshold.
;;;
;;; All tbf/state in the list must have the same inputs.  The function
;;; does not check this property.
;;;
;;; If #:headers is #f, does not print the names of the inputs of the
;;; TBFs.  If #:headers is #t, the output starts by a list giving the
;;; names of the variables, as well as the symbol 'θ to represent the
;;; column giving the thresholds of the TBF.
(define (print-org-tbfs/state tbfs #:headers [headers #t])
  (define (tbf->row tbf)
    (append (hash-map (tbf/state-w tbf) (λ (_ w) w) #t)
            (list (tbf/state-θ tbf))))
  (define rows (map tbf->row tbfs))
  (cond
    [headers
     (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t))
     (cons (append vars '(θ)) rows)]
    [else rows]))
(module+ test
  (test-case "print-org-tbfs/state"
    (define tbfs (list (make-tbf/state '((a . 1) (b . 2)) 3)
                       (make-tbf/state '((a . -2) (b . 1)) 1)))
    (check-equal? (print-org-tbfs/state tbfs)
                  '((a b θ) (1 2 3) (-2 1 1)))))
;;; Like print-org-tbfs/state, but expects a list of SBFs.  The
;;; thresholds are therefore not included in the output.
;;;
;;; All sbf/state in the list must have the same inputs.  The function
;;; does not check this property.
;;;
;;; If #:headers is #f, does not print the names of the inputs of the
;;; SBFs.  If #:headers is #t, the output starts by a list giving the
;;; names of the variables.
(define (print-org-sbfs/state sbfs #:headers [headers #t])
  (define table (for/list ([sbf (in-list sbfs)])
                  (hash-map (tbf/state-w sbf) (λ (_ w) w) #t)))
  (if headers
      (cons (hash-map (tbf/state-w (car sbfs)) (λ (x _) x) #t)
            table)
      table))
(module+ test
  ;; Wrapped in a test-case like every other test in this file, so that
  ;; sbfs stays local instead of being defined in the shared test
  ;; submodule.
  (test-case "print-org-sbfs/state"
    (define sbfs (list (make-sbf/state '((a . 1) (b . 2)))
                       (make-sbf/state '((a . -2) (b . 1)))))
    (check-equal? (print-org-sbfs/state sbfs)
                  '((a b) (1 2) (-2 1)))
    (check-equal? (print-org-sbfs/state sbfs #:headers #f)
                  '((1 2) (-2 1)))))
;;; Tabulates a list of tbf/state over the 0/1 states of the inputs of
;;; the first TBF.
;;;
;;; As in the case of tbf-tabulate*, the result is a list of lists
;;; giving the truth tables of the given TBFs.  The first elements of
;;; each row give the values of the inputs, while the last elements
;;; give the values of each function corresponding to the input.
;;;
;;; All the TBFs must have exactly the same inputs.  This function
;;; does not check this property.
;;;
;;; If #:headers is #t, the output starts by a list giving the names
;;; of the variables, and then the symbols 'fi, where i is the number
;;; of the TBF in the list.
(define (tbf/state-tabulate* tbfs #:headers [headers #t])
  (define vars (hash-map (tbf/state-w (car tbfs)) (λ (x _) x) #t))
  (define funcs (for/list ([tbf (in-list tbfs)])
                  (curry apply-tbf/state tbf)))
  (define domains (make-same-domains vars '(0 1)))
  (tabulate-state* funcs domains #:headers headers))
(module+ test
  (test-case "tbf/state-tabulate*"
    (define tbfs (list (make-tbf/state '((a . 1) (b . 2)) 1)
                       (make-tbf/state '((a . -2) (b . 3)) 1)))
    (check-equal? (tbf/state-tabulate* tbfs)
                  '((a b f1 f2)
                    (0 0 0 0)
                    (0 1 1 1)
                    (1 0 0 0)
                    (1 1 1 0)))))
;;; Like tbf/state-tabulate*, but only tabulates a single TBF.
(define (tbf/state-tabulate tbf #:headers [headers #t])
  (define singleton (list tbf))
  (tbf/state-tabulate* singleton #:headers headers))
(module+ test
  (test-case "tbf/state-tabulate"
    (define tbf (make-tbf/state '((a . -2) (b . 3)) 1))
    (check-equal? (tbf/state-tabulate tbf)
                  '((a b f1)
                    (0 0 0)
                    (0 1 1)
                    (1 0 0)
                    (1 1 0)))))
;;; Given a truth table of a Boolean function, groups the lines by the
;;; "number of activated inputs": the number of inputs which are 1 in
;;; the input vector (all the elements of a line except the last one).
;;;
;;; The truth table must not include the header line.
(define (group-truth-table-by-nai tt)
  (define (inputs-of row) (drop-right row 1))
  (define (n-active inputs) (apply + inputs))
  (group-by inputs-of
            tt
            (λ (xs ys) (= (n-active xs) (n-active ys)))))
(module+ test
  (test-case "group-truth-table-by-nai"
    (check-equal? (group-truth-table-by-nai '((0 0 0 1)
                                              (0 0 1 1)
                                              (0 1 0 0)
                                              (0 1 1 1)
                                              (1 0 0 0)
                                              (1 0 1 0)
                                              (1 1 0 1)
                                              (1 1 1 0)))
                  '(((0 0 0 1))
                    ((0 0 1 1) (0 1 0 0) (1 0 0 0))
                    ((0 1 1 1) (1 0 1 0) (1 1 0 1))
                    ((1 1 1 0))))))
;;; A TBN is a network form mapping variables to tbf/state.  tbn? is
;;; a contract built with hash/c: symbol keys, tbf/state values.
;;;
;;; The tbf/state must only reference variables appearing in the
;;; network.  This contract does not check this condition.
(define tbn? (hash/c symbol? tbf/state?))
;;; Builds a TBN from a list of pairs (variable, tbf/state).  This is
;;; make-immutable-hash itself, so if a variable appears several
;;; times, its last pair wins.
(define make-tbn make-immutable-hash)
(module+ test
(test-case "make-tbn"
(define tbf-not (make-tbf/state '((a . -1)) -1))
(define tbf-id (make-sbf/state '((a . 1))))
(check-equal? (make-tbn `((a . ,tbf-not) (b . ,tbf-id)))
(hash 'a (tbf/state '#hash((a . -1)) -1)
'b (tbf/state '#hash((a . 1)) 0)))))
;;; An SBN is a network form mapping variables to sbf/state.  sbn? is
;;; a contract built with hash/c: symbol keys, sbf/state values.
;;;
;;; The sbf/state must only reference variables appearing in the
;;; network.  This contract does not check this condition.
(define sbn? (hash/c symbol? sbf/state?))
;;; Builds an SBN from a list of pairs (variable, sbf/state).  This is
;;; make-immutable-hash itself, so if a variable appears several
;;; times, its last pair wins.
(define make-sbn make-immutable-hash)
(module+ test
(test-case "make-sbn"
(define sbf1 (make-sbf/state '((a . -1))))
(define sbf2 (make-sbf/state '((a . 1))))
(check-equal? (make-sbn `((a . ,sbf1) (b . ,sbf2)))
(hash 'a (tbf/state '#hash((a . -1)) 0)
'b (tbf/state '#hash((a . 1)) 0)))))
;;; Constructs a network from a network form defining a TBN (or an
;;; SBN).  Every variable is updated by applying its tbf/state to the
;;; state with apply-tbf/state.  The network is built with
;;; make-01-network.
(define (tbn->network tbn)
  (define update-funcs
    (for/hash ([(var tbf) (in-hash tbn)])
      (values var (curry apply-tbf/state tbf))))
  (make-01-network update-funcs))
(module+ test
  (test-case "tbn->network"
    (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1))))
                            (b . ,(make-tbf/state '((a . -1)) -1)))))
    (define n (tbn->network tbn))
    (define s1 (hash 'a 0 'b 0))
    (check-equal? (update n s1 '(a b))
                  (hash 'a 0 'b 1))
    (check-equal? (network-domains n) #hash((a . (0 1)) (b . (0 1))))
    (define sbn (make-sbn `((a . ,(make-sbf/state '((b . -1))))
                            (b . ,(make-sbf/state '((a . 1)))))))
    (define sn (tbn->network sbn))
    (define s2 (hash 'a 1 'b 1))
    (check-equal? (update sn s2 '(a b))
                  (hash 'a 0 'b 1))
    (check-equal? (network-domains sn) #hash((a . (0 1)) (b . (0 1))))))
;;; A helper function for read-org-tbn and read-org-sbn.  It reads a
;;; TBN from an Org-mode sexp containing a list of lists of numbers.
;;; As in lists->tbfs/state, the last element of each list is taken to
;;; be the threshold of the TBFs, and the rest of the elements are
;;; taken to be the weights.
;;;
;;; As in read-org-tbfs/state, if headers is #t, the names of the
;;; variables to appear as the inputs of the TBF are taken from the
;;; first list.  The last element of this list is not used.
;;; If headers is #f, the names of the variables are generated as xi,
;;; where i is the index of the variable.
;;;
;;; If func-names is exactly #t, the first element of every row except
;;; the first one is taken to be the name of the variable to which the
;;; TBF should be associated.  Otherwise, the functions are assigned to
;;; the inputs of the first TBF, in the order given by hash-map with
;;; try-order? (alphabetical for symbols).
;;;
;;; func-names cannot be #t if headers is #f.  The function does not
;;; check this condition.
(define (parse-org-tbn sexp
                       #:headers [headers #t]
                       #:func-names [func-names #t])
  (define-values (tbfs func-vars)
    (cond
      [(eq? func-names #t)
       ;; Split off the first column, which holds the function names.
       (define-values (name-col rows) (multi-split-at sexp 1))
       (define parsed (lists->tbfs/state rows #:headers headers))
       ;; Skip the header's corner cell and unwrap the one-element rows.
       (values parsed (map car (cdr name-col)))]
      [else
       (define parsed (lists->tbfs/state sexp #:headers headers))
       (values parsed
               (hash-map (tbf/state-w (car parsed)) (λ (x _) x) #t))]))
  (for/hash ([tbf (in-list tbfs)] [var (in-list func-vars)])
    (values var tbf)))
;;; Reads a TBN from an Org-mode string containing a sexp, which is a
;;; list of lists of numbers.  As in lists->tbfs/state, the last
;;; element of each list is taken to be the threshold of the TBFs, and
;;; the rest of the elements are taken to be the weights.
;;;
;;; As in read-org-tbfs/state, if headers is #t, the names of the
;;; variables to appear as the inputs of the TBF are taken from the
;;; first list.  The last element of this list is not used.
;;; If headers is #f, the names of the variables are generated as xi,
;;; where i is the index of the variable.
;;;
;;; If func-names is #t, the first element of every row except the
;;; first one is taken to be the name of the variable to which the
;;; TBF should be associated.  If func-names is #f, the functions are
;;; assigned to variables in alphabetical order.
;;;
;;; func-names cannot be #t if headers is #f.  The function does not
;;; check this condition.
(define (read-org-tbn str
                      #:headers [headers #t]
                      #:func-names [func-names #t])
  (define sexp (read-org-sexp str))
  (parse-org-tbn sexp #:headers headers #:func-names func-names))
(module+ test
  (test-case "read-org-tbn, parse-org-tbn"
    (check-equal? (read-org-tbn "((\"-\" \"x\" \"y\" \"θ\") (\"y\" -1 0 -1) (\"x\" 0 -1 -1))")
                  (hash
                   'x
                   (tbf/state '#hash((x . 0) (y . -1)) -1)
                   'y
                   (tbf/state '#hash((x . -1) (y . 0)) -1)))
    (check-equal? (read-org-tbn "((\"x\" \"y\" \"θ\") (-1 0 -1) (0 -1 -1))" #:func-names #f)
                  (hash
                   'x
                   (tbf/state '#hash((x . -1) (y . 0)) -1)
                   'y
                   (tbf/state '#hash((x . 0) (y . -1)) -1)))
    (check-equal? (read-org-tbn "((-1 0 -1) (0 -1 -1))" #:headers #f #:func-names #f)
                  (hash
                   'x0
                   (tbf/state '#hash((x0 . -1) (x1 . 0)) -1)
                   'x1
                   (tbf/state '#hash((x0 . 0) (x1 . -1)) -1)))))
;;; Like read-org-tbn, but reads an SBN from an Org-mode string
;;; containing a sexp, which is a list of lists of numbers.  The rows
;;; contain only weights: a 0 threshold is appended to every data row
;;; before parsing.
;;;
;;; As in read-org-sbfs/state, if headers is #t, the names of the
;;; variables to appear as the inputs of the SBF are taken from the
;;; first list.  Since SBN tables have no threshold column, all of
;;; these names are used.
;;; If headers is #f, the names of the variables are generated as xi,
;;; where i is the index of the variable.
;;;
;;; If func-names is #t, the first element of every row except the
;;; first one is taken to be the name of the variable to which the
;;; SBF should be associated.  If func-names is #f, the functions are
;;; assigned to variables in alphabetical order.
;;;
;;; func-names cannot be #t if headers is #f.  The function does not
;;; check this condition.
(define (read-org-sbn str
                      #:headers [headers #t]
                      #:func-names [func-names #t])
  (define sexp (read-org-sexp str))
  (define (with-0-threshold row) (append row '(0)))
  ;; The header row, if any, is left untouched.
  (define sexp-ready
    (cond
      [headers (cons (car sexp) (map with-0-threshold (cdr sexp)))]
      [else (map with-0-threshold sexp)]))
  (parse-org-tbn sexp-ready #:headers headers #:func-names func-names))
(module+ test
  (test-case "read-org-sbn, parse-org-tbn"
    (check-equal? (read-org-sbn "((\"-\" \"x\" \"y\") (\"y\" -1 0) (\"x\" 0 -1))")
                  (hash
                   'x
                   (tbf/state '#hash((x . 0) (y . -1)) 0)
                   'y
                   (tbf/state '#hash((x . -1) (y . 0)) 0)))
    (check-equal? (read-org-sbn "((\"x\" \"y\") (-1 0) (0 -1))" #:func-names #f)
                  (hash
                   'x
                   (tbf/state '#hash((x . -1) (y . 0)) 0)
                   'y
                   (tbf/state '#hash((x . 0) (y . -1)) 0)))
    (check-equal? (read-org-sbn "((-1 0) (0 -1))" #:headers #f #:func-names #f)
                  (hash
                   'x0
                   (tbf/state '#hash((x0 . -1) (x1 . 0)) 0)
                   'x1
                   (tbf/state '#hash((x0 . 0) (x1 . -1)) 0)))))
;;; A shortcut for building the state graphs of TBN: converts the TBN
;;; to a network, builds its synchronous dynamics and full state
;;; graph, and pretty-prints the result.
(define (build-tbn-state-graph tbn)
  (pretty-print-state-graph
   (build-full-state-graph
    (make-syn-dynamics
     (tbn->network tbn)))))
;;; Checks whether a TBN is normalized: whether the inputs of every
;;; function are exactly the variables of the TBN (and hence all of
;;; the functions have the same inputs).
(define (normalized-tbn? tbn)
  (define vars (hash-keys tbn))
  (for/and ([tbf (in-hash-values tbn)])
    (set=? vars (hash-keys (tbf/state-w tbf)))))
(module+ test
  (test-case "normalized-tbn?"
    (check-false (normalized-tbn?
                  (make-tbn `((a . ,(make-sbf/state '((b . 1))))
                              (b . ,(make-tbf/state '((a . -1)) -1))))))
    (check-true (normalized-tbn?
                 (make-tbn `((a . ,(make-sbf/state '((a . 1) (b . -1))))
                             (b . ,(make-tbf/state '((a . -1) (b . 1)) -1))))))))
;;; Normalizes a TBN.
;;;
;;; For every TBF, removes the inputs that are not in the variables of
;;; the TBN, and adds missing inputs with 0 weight.  The thresholds
;;; are left unchanged.
(define (normalize-tbn tbn)
  (define (normalize-tbf tbf)
    (define ws (tbf/state-w tbf))
    ;; One entry per variable of tbn: its weight if tbf has one, 0
    ;; otherwise.
    (define complete-ws
      (for/hash ([x (in-hash-keys tbn)])
        (values x (hash-ref ws x 0))))
    (tbf/state complete-ws (tbf/state-θ tbf)))
  (for/hash ([(x tbf) (in-hash tbn)])
    (values x (normalize-tbf tbf))))
(module+ test
  (test-case "normalize-tbn"
    (check-equal? (normalize-tbn
                   (hash 'a (make-sbf/state '((b . 1) (c . 3)))
                         'b (make-tbf/state '((a . -1)) -1)))
                  (hash
                   'a
                   (tbf/state '#hash((a . 0) (b . 1)) 0)
                   'b
                   (tbf/state '#hash((a . -1) (b . 0)) -1)))))
;;; Compacts (and denormalizes) a TBF by removing all inputs whose
;;; weight is 0.  The threshold is kept.
(define (compact-tbf tbf)
  (define nonzero-ws
    (for/hash ([(x w) (in-hash (tbf/state-w tbf))]
               #:when (not (zero? w)))
      (values x w)))
  (tbf/state nonzero-ws (tbf/state-θ tbf)))
(module+ test
  (test-case "compact-tbf"
    (check-equal? (compact-tbf (tbf/state (hash 'a 0 'b 1 'c 2 'd 0) 2))
                  (tbf/state '#hash((b . 1) (c . 2)) 2))))
;;; Compacts a TBN by removing, from every TBF, all inputs whose
;;; weight is 0 or which are not variables of the network.  The
;;; thresholds are kept.
(define (compact-tbn tbn)
  (for/hash ([(x tbf) (in-hash tbn)])
    (define kept-ws
      (for/hash ([(y w) (in-hash (tbf/state-w tbf))]
                 #:when (and (hash-has-key? tbn y) (not (zero? w))))
        (values y w)))
    (values x (tbf/state kept-ws (tbf/state-θ tbf)))))
(module+ test
  (test-case "compact-tbn"
    (check-equal?
     (compact-tbn (hash 'a (tbf/state (hash 'a 0 'b 1 'c 3 'd 0) 0)
                        'b (tbf/state (hash 'a -1 'b 1) -1)))
     (hash
      'a
      (tbf/state '#hash((b . 1)) 0)
      'b
      (tbf/state '#hash((a . -1) (b . 1)) -1)))))
;;; Given a TBN, produces a sexp containing the description of the
;;; functions of the TBN that Org-mode can interpret as a table.  The
;;; TBN is normalized first, so every row lists a weight for every
;;; variable of the network.
;;;
;;; Like print-org-tbfs/state, if #:headers is #f, does not print the
;;; names of the inputs of the TBFs.  If #:headers is #t, the output
;;; starts by a list giving the names of the variables, as well as the
;;; symbol 'θ to represent the column giving the thresholds of the
;;; TBF.
;;;
;;; If #:func-names is #t, the first column of the table gives the
;;; variable which the corresponding TBF updates.
;;;
;;; If both #:func-names and #:headers are #t, the first cell of the
;;; first column contains the symbol '-.
(define (print-org-tbn tbn
                       #:headers [headers #t]
                       #:func-names [func-names #t])
  ;; (variable . tbf) pairs of the normalized TBN, in hash-map
  ;; try-order? order.
  (define vars-tbfs (hash-map (normalize-tbn tbn) cons #t))
  (define tbfs-table
    (print-org-tbfs/state (map cdr vars-tbfs) #:headers headers))
  (if (eq? func-names #t)
      (let* ([vars (map car vars-tbfs)]
             [first-col (if headers (cons '- vars) vars)])
        (for/list ([var (in-list first-col)] [row (in-list tbfs-table)])
          (cons var row)))
      tbfs-table))
(module+ test
  (test-case "print-org-tbn"
    (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1))))
                            (b . ,(make-tbf/state '((a . -1)) -1)))))
    (check-equal? (print-org-tbn tbn)
                  '((- a b θ) (a 0 1 0) (b -1 0 -1)))))
;;; Given an SBN, produces a sexp containing the description of the
;;; functions of the SBN that Org-mode can interpret as a table.
;;; This is the table of print-org-tbn with its last (threshold)
;;; column removed.
;;;
;;; Like print-org-tbfs/state, if #:headers is #f, does not print the
;;; names of the inputs of the SBFs.  If #:headers is #t, the output
;;; starts by a list giving the names of the variables.
;;;
;;; If #:func-names is #t, the first column of the table gives the
;;; variable which the corresponding SBF updates.
;;;
;;; If both #:func-names and #:headers are #t, the first cell of the
;;; first column contains the symbol '-.
(define (print-org-sbn sbn
                       #:headers [headers #t]
                       #:func-names [func-names #t])
  (define tab (print-org-tbn sbn #:headers headers #:func-names func-names))
  (define n-kept (sub1 (length (car tab))))
  (define-values (without-θ _θ-column) (multi-split-at tab n-kept))
  without-θ)
(module+ test
  (test-case "print-org-sbn"
    (define sbn (hash
                 'a
                 (tbf/state (hash 'b 2) 0)
                 'b
                 (tbf/state (hash 'a 2) 0)))
    (check-equal? (print-org-sbn sbn)
                  '((- a b) (a 0 2) (b 2 0)))))
;;; Given a TBN, constructs its interaction graph.  The TBN is
;;; normalized first.  There is an edge src -> tar labelled w for every
;;; input src of weight w of the function of tar.  The nodes of this
;;; graph are labeled with pairs (variable name . threshold).
;;;
;;; If #:zero-edges is #t, the edges with zero weights will appear in
;;; the interaction graph.
(define (tbn-interaction-graph tbn
                               #:zero-edges [zero-edges #t])
  (define ntbn (normalize-tbn tbn))
  (define edges
    (for*/list ([(tar tbf) (in-hash ntbn)]
                [(src w) (in-hash (tbf/state-w tbf))]
                #:when (or zero-edges (not (zero? w))))
      (list w src tar)))
  (define ig (weighted-graph/directed edges))
  (update-graph ig #:v-func (λ (x) (cons x (tbf/state-θ (hash-ref ntbn x))))))
(module+ test
  (test-case "tbn-interaction-graph"
    (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1))))
                            (b . ,(make-tbf/state '((a . -1)) -1)))))
    (check-equal? (graphviz (tbn-interaction-graph tbn))
                  "digraph G {\n\tnode0 [label=\"'(b . -1)\\n\"];\n\tnode1 [label=\"'(a . 0)\\n\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")
    (check-equal? (graphviz (tbn-interaction-graph tbn #:zero-edges #f))
                  "digraph G {\n\tnode0 [label=\"'(b . -1)\\n\"];\n\tnode1 [label=\"'(a . 0)\\n\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")))
;;; Pretty prints the node labels of the interaction graph of a TBN:
;;; every (variable . threshold) label becomes the string
;;; "variable:threshold".
(define (pretty-print-tbn-interaction-graph ig)
  (define label->string
    (match-lambda
      [(cons var threshold) (~a var ":" threshold)]))
  (update-graph ig #:v-func label->string))
(module+ test
  (test-case "pretty-print-tbn-interaction-graph"
    (define tbn (make-tbn `((a . ,(make-sbf/state '((b . 1))))
                            (b . ,(make-tbf/state '((a . -1)) -1)))))
    (check-equal? (graphviz (pretty-print-tbn-interaction-graph (tbn-interaction-graph tbn)))
                  "digraph G {\n\tnode0 [label=\"b:-1\"];\n\tnode1 [label=\"a:0\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t\tnode0 -> node1 [label=\"1\"];\n\t\tnode1 -> node0 [label=\"-1\"];\n\t}\n}\n")))
;;; Given an SBN, constructs its interaction graph.  This is the graph
;;; of tbn-interaction-graph with the thresholds dropped from the node
;;; labels: the nodes are labeled with the variable names, while the
;;; edges are labelled with the weights.
;;;
;;; If #:zero-edges is #t, the edges with zero weights will appear in
;;; the interaction graph.
(define (sbn-interaction-graph sbn
                               #:zero-edges [zero-edges #t])
  (define ig (tbn-interaction-graph sbn #:zero-edges zero-edges))
  (define drop-threshold (match-lambda [(cons var _) var]))
  (update-graph ig #:v-func drop-threshold))
(module+ test
  (test-case "sbn-interaction-graph"
    (define sbn (hash
                 'a
                 (tbf/state (hash 'b 2) 0)
                 'b
                 (tbf/state (hash 'a 2) 0)))
    (check-equal? (graphviz (sbn-interaction-graph sbn))
                  "digraph G {\n\tnode0 [label=\"b\"];\n\tnode1 [label=\"a\"];\n\tsubgraph U {\n\t\tedge [dir=none];\n\t\tnode0 -> node1 [label=\"2\"];\n\t\tnode0 -> node0 [label=\"0\"];\n\t\tnode1 -> node1 [label=\"0\"];\n\t}\n\tsubgraph D {\n\t}\n}\n")))