{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.NN
( sigmoidCrossEntropyWithLogits
) where
import Prelude hiding ( log
, exp
)
import TensorFlow.Build ( MonadBuild
, withNameScope
)
import TensorFlow.GenOps.Core ( greaterEqual
, select
, log
, exp
)
import TensorFlow.Tensor ( Tensor(..)
, render
, Value
)
import TensorFlow.Types ( TensorType(..)
, OneOf
)
import TensorFlow.Ops ( zerosLike
, add
, mul
, neg
)
-- | Element-wise logistic (sigmoid cross-entropy) loss between @logits@ and
-- @targets@.
--
-- For each element with logit @x@ and target @z@ this builds the graph for
--
-- > max(x, 0) - x * z + log(1 + exp(-|x|))
--
-- which is algebraically equal to
-- @-(z * log(sigmoid x) + (1 - z) * log(1 - sigmoid x))@, but is arranged so
-- that 'exp' is only ever applied to non-positive values.
--
-- The intermediate terms are rendered under the @logistic_loss@ name scope,
-- and the final sum under a nested @sigmoid_add@ scope.
--
-- NOTE(review): @targets@ is combined with @logits@ via element-wise 'mul';
-- presumably the two tensors share a shape (or are broadcast-compatible) —
-- confirm against callers.
sigmoidCrossEntropyWithLogits
    :: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
    => Tensor Value a          -- ^ __logits__: unscaled scores @x@
    -> Tensor Value a          -- ^ __targets__: labels @z@
    -> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
    let zeros = zerosLike logits
        -- Mask of the elements whose logit is non-negative.
        cond = logits `greaterEqual` zeros
        -- max(x, 0)
        reluLogits = select cond logits zeros
        -- -|x|: always <= 0, so the 'exp' below cannot overflow.
        negAbsLogits = select cond (neg logits) logits
    withNameScope "logistic_loss" $ do
        -- max(x, 0) - x * z  (parenthesised so we don't rely on the default
        -- infixl 9 fixity of the backticked 'mul')
        left <- render $ reluLogits - (logits `mul` targets)
        -- log(1 + exp(-|x|))
        right <- render $ log (1 + exp negAbsLogits)
        withNameScope "sigmoid_add" $ render $ left `add` right