mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
imports and haddock
This commit is contained in:
parent
3b450787ae
commit
3b63a065a6
|
@ -19,17 +19,61 @@ module TensorFlow.NN
|
|||
( sigmoidCrossEntropyWithLogits
|
||||
) where
|
||||
|
||||
import Prelude hiding (log, exp)
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.GenOps.Core (greaterEqual, select, log, exp)
|
||||
import Prelude hiding ( log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Build ( Build(..)
|
||||
, render
|
||||
, withNameScope
|
||||
)
|
||||
import TensorFlow.GenOps.Core ( greaterEqual
|
||||
, select
|
||||
, log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Tensor ( Tensor(..)
|
||||
, Value(..)
|
||||
)
|
||||
import TensorFlow.Types ( TensorType(..)
|
||||
, OneOf
|
||||
)
|
||||
import TensorFlow.Ops ( zerosLike
|
||||
, add
|
||||
)
|
||||
|
||||
-- | Computes sigmoid cross entropy given `logits`.
|
||||
--
|
||||
-- Measures the probability error in discrete classification tasks in which each
|
||||
-- class is independent and not mutually exclusive. For instance, one could
|
||||
-- perform multilabel classification where a picture can contain both an elephant
|
||||
-- and a dog at the same time.
|
||||
--
|
||||
-- For brevity, let `x = logits`, `z = targets`. The logistic loss is
|
||||
--
|
||||
-- z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
-- = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
|
||||
-- = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
|
||||
-- = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
|
||||
-- = (1 - z) * x + log(1 + exp(-x))
|
||||
-- = x - x * z + log(1 + exp(-x))
|
||||
--
|
||||
-- For x < 0, to avoid overflow in exp(-x), we reformulate the above
|
||||
--
|
||||
-- x - x * z + log(1 + exp(-x))
|
||||
-- = log(exp(x)) - x * z + log(1 + exp(-x))
|
||||
-- = - x * z + log(1 + exp(x))
|
||||
--
|
||||
-- Hence, to ensure stability and avoid overflow, the implementation uses this
|
||||
-- equivalent formulation
|
||||
--
|
||||
-- max(x, 0) - x * z + log(1 + exp(-abs(x)))
|
||||
--
|
||||
-- `logits` and `targets` must have the same type and shape.
|
||||
sigmoidCrossEntropyWithLogits
|
||||
:: (OneOf '[Float, Double] a, TensorType a, Num a) =>
|
||||
Tensor Value a
|
||||
-> Tensor Value a -> Build (Tensor Value a)
|
||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
=> Tensor Value a -- ^ __logits__
|
||||
-> Tensor Value a -- ^ __targets__
|
||||
-> Build (Tensor Value a)
|
||||
sigmoidCrossEntropyWithLogits logits targets = do
|
||||
let zeros = zerosLike logits
|
||||
cond = logits `greaterEqual` zeros
|
||||
|
|
Loading…
Reference in New Issue
Block a user