-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.NN
    ( sigmoidCrossEntropyWithLogits
    ) where

import Prelude hiding           ( log
                                , exp
                                )
import TensorFlow.Build         ( MonadBuild
                                , withNameScope
                                )
import TensorFlow.GenOps.Core   ( greaterEqual
                                , select
                                , log
                                , exp
                                )
import TensorFlow.Tensor        ( Tensor(..)
                                , render
                                , Value
                                )
import TensorFlow.Types         ( TensorType(..)
                                , OneOf
                                )
import TensorFlow.Ops           ( zerosLike
                                , add
                                , mul
                                , neg
                                )

-- | Computes sigmoid cross entropy given `logits`.
--
-- Measures the probability error in discrete classification tasks in which each
-- class is independent and not mutually exclusive.  For instance, one could
-- perform multilabel classification where a picture can contain both an elephant
-- and a dog at the same time.
--
-- For brevity, let `x = logits`, `z = targets`.  The logistic loss is
--
--        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
--      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
--      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
--      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
--      = (1 - z) * x + log(1 + exp(-x))
--      = x - x * z + log(1 + exp(-x))
--
--  For x < 0, to avoid overflow in exp(-x), we reformulate the above
--
--        x - x * z + log(1 + exp(-x))
--      = log(exp(x)) - x * z + log(1 + exp(-x))
--      = - x * z + log(1 + exp(x))
--
--  Hence, to ensure stability and avoid overflow, the implementation uses this
--  equivalent formulation
--
--      max(x, 0) - x * z + log(1 + exp(-abs(x)))
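--
--  As a concrete check, take x = -100 and z = 1: the naive form needs
--  exp(100), which overflows a Float, while the stable form evaluates to
--  0 - (-100) * 1 + log(1 + exp(-100)), which is approximately 100, the
--  correct loss.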
--
--  `logits` and `targets` must have the same type and shape.
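--
--  A minimal usage sketch (assumes the @tensorflow@ package's session API
--  from 'TensorFlow.Core'; exact imports may vary by version):
--
-- > import Control.Monad.IO.Class (liftIO)
-- > import qualified Data.Vector as V
-- > import qualified TensorFlow.Core as TF
-- > import qualified TensorFlow.Ops as TF
-- > import TensorFlow.NN (sigmoidCrossEntropyWithLogits)
-- >
-- > main :: IO ()
-- > main = TF.runSession $ do
-- >     logits <- TF.render $ TF.constant (TF.Shape [3]) [-2, 0, 2 :: Float]
-- >     targets <- TF.render $ TF.constant (TF.Shape [3]) [0, 0.5, 1]
-- >     loss <- sigmoidCrossEntropyWithLogits logits targets
-- >     result <- TF.run loss
-- >     liftIO $ print (result :: V.Vector Float)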
sigmoidCrossEntropyWithLogits
  :: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
     => Tensor Value a          -- ^ __logits__
     -> Tensor Value a          -- ^ __targets__
     -> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
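    -- One comparison drives both selects: where logits >= 0, the first
    -- select keeps logits (giving max(logits, 0)) and the second negates
    -- them; otherwise the first yields 0 and the second keeps logits,
    -- giving -abs(logits) in both cases.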
    let zeros = zerosLike logits
        cond = logits `greaterEqual` zeros
        relu_logits = select cond logits zeros
        neg_abs_logits = select cond (neg logits) logits
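    -- Assemble max(x, 0) - x * z + log(1 + exp(-abs(x))); rendering each
    -- half pins its nodes under the "logistic_loss" name scope before the
    -- final add.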
    withNameScope "logistic_loss" $ do
        left <- render $ relu_logits - logits `mul` targets
        right <- render $ log (1 + exp neg_abs_logits)
        withNameScope "sigmoid_add" $ render $ left `add` right