2016-10-28 03:05:27 +02:00
|
|
|
-- Copyright 2016 TensorFlow authors.
|
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
2016-11-18 19:42:02 +01:00
|
|
|
{-# LANGUAGE OverloadedLists #-}
|
2016-10-28 03:05:27 +02:00
|
|
|
|
|
|
|
module Main where
|
|
|
|
|
|
|
|
import Google.Test (googleTest)
|
2016-11-17 22:54:36 +01:00
|
|
|
import TensorFlow.Test (assertAllClose)
|
2016-11-18 19:42:02 +01:00
|
|
|
import Test.Framework (Test)
|
2016-10-28 03:05:27 +02:00
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
|
|
import qualified Data.Vector as V
|
|
|
|
import qualified TensorFlow.Gradient as TF
|
|
|
|
import qualified TensorFlow.NN as TF
|
|
|
|
import qualified TensorFlow.Ops as TF
|
2017-04-07 00:10:33 +02:00
|
|
|
import qualified TensorFlow.Core as TF
|
2016-10-28 03:05:27 +02:00
|
|
|
|
|
|
|
-- | Pure reference implementation of the element-wise sigmoid cross-entropy
-- loss, used as the expected value when checking
-- 'TF.sigmoidCrossEntropyWithLogits'.
--
-- Both this function and the tests in this module are ported from
-- @<tensorflow>/tensorflow/python/ops/nn_xent_tests.py@.
--
-- Each logit is passed through the sigmoid. The resulting probability is
-- clamped to @[eps, 1 - eps]@ so that neither @log p@ nor @log (1 - p)@
-- diverges. The loss for target @z@ is then @-z * log p - (1 - z) * log (1 - p)@.
-- The two lists are paired with 'zip', so any surplus elements in the longer
-- list are ignored.
sigmoidXentWithLogits :: (Floating a, Ord a) => [a] -> [a] -> [a]
sigmoidXentWithLogits logits' targets' =
    [ negate z * log p - (1 - z) * log (1 - p)
    | (x, z) <- zip logits' targets'
    , let p = clamp (sigmoid x)
    ]
  where
    eps = 0.0001
    sigmoid v = 1 / (1 + exp (negate v))
    clamp q = min (max q eps) (1 - eps)
|
2016-10-28 03:05:27 +02:00
|
|
|
|
|
|
|
|
|
|
|
-- | Paired inputs for the cross-entropy tests: the tests pass one logit and
-- one target per element.
data Inputs = Inputs
    { logits  :: [Float]  -- ^ Raw (pre-sigmoid) scores fed to the loss.
    , targets :: [Float]  -- ^ Targets for each logit (0, 0.5 or 1 in 'defInputs').
    }
|
|
|
|
|
|
|
|
|
|
|
|
-- | Default logits and targets shared by the tests. It covers large positive
-- and negative logits, which hit the clamping in 'sigmoidXentWithLogits', as
-- well as fractional targets.
defInputs :: Inputs
defInputs = Inputs { logits = sampleLogits, targets = sampleTargets }
  where
    sampleLogits  = [-100, -2, -2, 0, 2, 2,   2, 100]
    sampleTargets = [   0,  0,  1, 0, 0, 1, 0.5,   1]
|
|
|
|
|
|
|
|
|
2016-11-18 19:42:02 +01:00
|
|
|
-- | Evaluates 'TF.sigmoidCrossEntropyWithLogits' on 'defInputs' given as
-- rank-1 vectors, and checks the result against the pure reference
-- 'sigmoidXentWithLogits'.
testLogisticOutput :: Test
testLogisticOutput = testCase "testLogisticOutput" $ do
    computed <- run $ do
        xs <- TF.render (TF.vector (logits defInputs))
        zs <- TF.render (TF.vector (targets defInputs))
        TF.sigmoidCrossEntropyWithLogits xs zs
    assertAllClose computed expected
  where
    expected = V.fromList (sigmoidXentWithLogits (logits defInputs) (targets defInputs))
|
|
|
|
|
|
|
|
|
2016-11-18 19:42:02 +01:00
|
|
|
-- | Same check as 'testLogisticOutput', but the logits and targets are built
-- as constants of shape [2, 2, 2]. The fetched result is compared with the
-- flat reference vector from 'sigmoidXentWithLogits'.
testLogisticOutputMultipleDim :: Test
testLogisticOutputMultipleDim =
    testCase "testLogisticOutputMultipleDim" $ do
        computed <- run $ do
            xs <- TF.render (TF.constant dims (logits defInputs))
            zs <- TF.render (TF.constant dims (targets defInputs))
            TF.sigmoidCrossEntropyWithLogits xs zs
        assertAllClose computed expected
  where
    -- Built via OverloadedLists; the concrete type is fixed by 'TF.constant'.
    dims = [2, 2, 2]
    expected = V.fromList (sigmoidXentWithLogits (logits defInputs) (targets defInputs))
|
|
|
|
|
|
|
|
|
2016-11-18 19:42:02 +01:00
|
|
|
-- | Differentiates 'TF.sigmoidCrossEntropyWithLogits' with respect to the
-- logits at logits of zero, for targets 0 and 1. The expected gradients are
-- 0.5 and -0.5, which equal sigmoid 0 minus the target.
testGradientAtZero :: Test
testGradientAtZero = testCase "testGradientAtZero" $ do
    grads <- run $ do
        let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
        vTargets <- TF.render $ TF.vector $ targets inputs
        vLogits <- TF.render $ TF.vector $ logits inputs
        loss <- TF.sigmoidCrossEntropyWithLogits vLogits vTargets
        TF.gradients loss [vLogits]
    -- Gradients were requested for a single tensor. If none come back, fail
    -- with a descriptive message instead of the partial 'head'. A cons
    -- pattern is used because, under OverloadedLists, a @[g]@ pattern would
    -- leave the fetched type ambiguous.
    case grads of
        g : _ -> assertAllClose g (V.fromList [0.5, -0.5])
        _     -> fail "testGradientAtZero: TF.gradients returned no tensors"
|
|
|
|
|
2017-03-18 20:08:53 +01:00
|
|
|
-- | Runs a 'TF.Session' action via 'TF.runSession'. The tensors the action
-- produces are fetched with 'TF.run', and their values are returned in 'IO'.
run :: TF.Fetchable t a => TF.Session t -> IO a
run session = TF.runSession (session >>= TF.run)
|
2016-10-28 03:05:27 +02:00
|
|
|
|
|
|
|
-- | Runs every test in this module through 'googleTest'.
main :: IO ()
main = googleTest allTests
  where
    allTests =
        [ testGradientAtZero
        , testLogisticOutput
        , testLogisticOutputMultipleDim
        ]
|