-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.NN
    ( sigmoidCrossEntropyWithLogits
    ) where

import Prelude hiding ( log
                      , exp
                      )
import TensorFlow.Build ( Build
                        , render
                        , withNameScope
                        )
import TensorFlow.GenOps.Core ( greaterEqual
                              , select
                              , log
                              , exp
                              )
import TensorFlow.Tensor ( Tensor(..)
                         , Value
                         )
import TensorFlow.Types ( TensorType(..)
                        , OneOf
                        )
import TensorFlow.Ops ( zerosLike
                      , add
                      )

-- | Computes sigmoid cross entropy given `logits`.
--
-- Measures the probability error in discrete classification tasks in which each
-- class is independent and not mutually exclusive. For instance, one could
-- perform multilabel classification where a picture can contain both an elephant
-- and a dog at the same time.
--
-- For brevity, let `x = logits`, `z = targets`. The logistic loss is
--
--       z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
--     = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
--     = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
--     = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
--     = (1 - z) * x + log(1 + exp(-x))
--     = x - x * z + log(1 + exp(-x))
--
-- For x < 0, to avoid overflow in exp(-x), we reformulate the above as
--
--       x - x * z + log(1 + exp(-x))
--     = log(exp(x)) - x * z + log(1 + exp(-x))
--     = - x * z + log(1 + exp(x))
--
-- Hence, to ensure stability and avoid overflow, the implementation uses this
-- equivalent formulation:
--
--     max(x, 0) - x * z + log(1 + exp(-abs(x)))
--
-- `logits` and `targets` must have the same type and shape.
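--
-- As an illustrative check of the stable form (this example is not part of
-- the upstream comment): with z = 1 and x = -1000, the naive form requires
-- exp(1000), which overflows even Double, while the stable form gives
-- max(-1000, 0) - (-1000) * 1 + log(1 + exp(-1000)) ≈ 0 + 1000 + 0 = 1000.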
sigmoidCrossEntropyWithLogits
    :: (OneOf '[Float, Double] a, TensorType a, Num a)
    => Tensor Value a  -- ^ __logits__
    -> Tensor Value a  -- ^ __targets__
    -> Build (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
    logits'  <- render logits
    targets' <- render targets
    let zeros = zerosLike logits'
        cond = logits' `greaterEqual` zeros
        -- max(x, 0): keep x where x >= 0, otherwise 0.
        relu_logits = select cond logits' zeros
        -- -abs(x): negate x where x >= 0, keep x where x < 0.
        neg_abs_logits = select cond (-logits') logits'
    withNameScope "logistic_loss" $ do
        -- max(x, 0) - x * z
        left  <- render $ relu_logits - logits' * targets'
        -- log(1 + exp(-abs(x)))
        right <- render $ log (1 + exp neg_abs_logits)
        withNameScope "sigmoid_add" $ render $ left `add` right
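
-- A minimal usage sketch, not part of the original module: it assumes the
-- session API from 'TensorFlow.Core' and the 'constant' op from
-- 'TensorFlow.Ops' (recent package versions, where 'TF.render' converts the
-- pure 'constant' tensor into the 'Tensor Value' this function expects), and
-- it fetches the per-element losses into a 'Data.Vector.Vector'.
--
-- > import qualified Data.Vector as V
-- > import qualified TensorFlow.Core as TF
-- > import qualified TensorFlow.Ops as TF
-- > import TensorFlow.NN (sigmoidCrossEntropyWithLogits)
-- >
-- > main :: IO ()
-- > main = do
-- >     losses <- TF.runSession $ do
-- >         logits  <- TF.render $ TF.constant (TF.Shape [3]) [-1000, 0, 2 :: Float]
-- >         targets <- TF.render $ TF.constant (TF.Shape [3]) [1, 0, 1 :: Float]
-- >         loss    <- TF.build $ sigmoidCrossEntropyWithLogits logits targets
-- >         TF.run loss
-- >     -- Expect approximately [1000.0, 0.6931, 0.1269].
-- >     print (losses :: V.Vector Float)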