diff --git a/scribblings/tbn.scrbl b/scribblings/tbn.scrbl index cfc3eb8..9bf3145 100644 --- a/scribblings/tbn.scrbl +++ b/scribblings/tbn.scrbl @@ -468,6 +468,17 @@ the variables of @racket[tbn]. 'y (tbf/state (hash 'y 0) -1))) ]} +@defproc[(normalize-tbn (tbn TBF)) TBN]{ + +Normalizes @racket[tbn]: for every @racket[TBF/State], removes the +inputs that are not in the variables of @racket[tbn], and adds missing +inputs with 0 weight. + +@ex[ +(normalize-tbn (hash 'x (tbf/state (hash 'x 2) -1) + 'y (tbf/state (hash 'y 3) 1))) +]} + @section{Miscellaneous utilities} @defproc[(group-truth-table-by-nai [tt (Listof (Listof Integer))]) diff --git a/tbn.rkt b/tbn.rkt index 53d3a91..f524cc1 100644 --- a/tbn.rkt +++ b/tbn.rkt @@ -11,6 +11,16 @@ (require/typed racket/hash [hash-intersect + (->* ((HashTable Variable Real)) + (#:combine (-> Real Real Real)) + #:rest (HashTable Variable Real) + (HashTable Variable Real))] + [(hash-intersect hash-intersect/tbn-tbf) + (->* ((HashTable Variable TBF/State)) + (#:combine (-> TBF/State Real Real)) + #:rest (HashTable Variable Real) + (HashTable Variable Real))] + [hash-union (->* ((HashTable Variable Real)) (#:combine (-> Real Real Real)) #:rest (HashTable Variable Real) @@ -38,7 +48,7 @@ TBN sbn? tbn->network parse-org-tbn read-org-tbn read-org-sbn - build-tbn-state-graph normalized-tbn? + build-tbn-state-graph normalized-tbn? normalize-tbn ) (: apply-tbf-to-state (-> TBF (State (U Zero One)) (U Zero One))) @@ -529,6 +539,33 @@ (check-false (normalized-tbn? (hash 'x (tbf/state (hash 'x 0 ) -1) 'y (tbf/state (hash 'y 0) -1)))))) + + (: normalize-tbn (-> TBN TBN)) + (define (normalize-tbn tbn) + (define vars-0 (for/hash : (VariableMapping Real) + ([(x _) (in-hash tbn)]) + (values x 0))) + (: normalize-tbf (-> TBF/State TBF/State)) + (define (normalize-tbf tbf) + ;; Only keep the inputs which are also the variables of tbn. + (define w-pruned (hash-intersect/tbn-tbf + tbn + (tbf/state-w tbf) + #:combine (λ (_ w) w))) + ;; Put in the missing inputs with weight 0. + (define w-complete + (assert-type (hash-union vars-0 w-pruned #:combine (λ (_ w) w)) + (VariableMapping Real))) + (tbf/state w-complete (tbf/state-θ tbf))) + (for/hash : TBN ([(x tbf) (in-hash tbn)]) + (values x (normalize-tbf tbf)))) + + (module+ test + (test-case "normalize-tbn" + (check-equal? (normalize-tbn (hash 'x (tbf/state (hash 'x 2) -1) + 'y (tbf/state (hash 'y 3) 1))) + (hash 'x (tbf/state (hash 'x 2 'y 0) -1) + 'y (tbf/state (hash 'x 0 'y 3) 1))))) ) (module+ test