diff --git a/scribblings/tbn.scrbl b/scribblings/tbn.scrbl index 88780be..4dc464b 100644 --- a/scribblings/tbn.scrbl +++ b/scribblings/tbn.scrbl @@ -91,3 +91,24 @@ threshold is 0. (sbf/state? (tbf/state (hash 'a -1 'b 1) 0)) (sbf/state? (tbf/state (hash 'a -1 'b 1) 1)) ]} + +@defproc[(apply-tbf/state [tbfs TBF/State] + [st (State (U Zero One))]) + (U Zero One)]{ + +Applies a @racket[TBF/State] to its inputs given by the state +@racket[st]. + +Applying a TBF consists in multiplying the weights by the +corresponding inputs and comparing the sum of the products to +the threshold. + +This function is similar to @racket[apply-tbf], but because it applies +a @racket[TBF/State] to a @racket[(State (U Zero One))], it avoids +potential mismatches between weights and the corresponding +input values. + +@ex[ +(apply-tbf/state (tbf/state (hash 'a 2 'b -2) 1) + (hash 'a 1 'b 0 'c 1)) +]} diff --git a/tbn.rkt b/tbn.rkt index 4973b12..e3ac60d 100644 --- a/tbn.rkt +++ b/tbn.rkt @@ -9,6 +9,13 @@ "utils.rkt" "functions.rkt" "networks.rkt" typed/graph typed/racket/random) + (require/typed racket/hash + [hash-intersect + (->* ((HashTable Variable Real)) + (#:combine (-> Real Real Real)) + #:rest (HashTable Variable Real) + (HashTable Variable Real))]) + (module+ test (require typed/rackunit)) @@ -16,7 +23,7 @@ apply-tbf-to-state (struct-out tbf/state) TBF/State tbf/state-w tbf/state-θ make-tbf/state - sbf/state? + sbf/state? apply-tbf/state ) (: apply-tbf-to-state (-> TBF (State (U Zero One)) (U Zero One))) @@ -54,6 +61,20 @@ (test-case "sbf/state?" (check-true (sbf/state? (tbf/state (hash 'a -1 'b 1) 0))) (check-false (sbf/state? (tbf/state (hash 'a -1 'b 1) 1))))) + + (: apply-tbf/state (-> TBF/State (State (U Zero One)) (U Zero One))) + (define (apply-tbf/state tbfs st) + (any->01 + (> (apply + (hash-values (hash-intersect (tbf/state-w tbfs) st #:combine *))) + (tbf/state-θ tbfs)))) + + (module+ test + (test-case "apply-tbf/state" + (define st1 (hash 'a 1 'b 0 'c 1)) + (define st2 (hash 'a 1 'b 1 'c 0)) + (define tbf (make-tbf/state '((a . 2) (b . -2)) 1)) + (check-equal? (apply-tbf/state tbf st1) 1) + (check-equal? (apply-tbf/state tbf st2) 0))) ) (module+ test