1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-25 10:15:00 +01:00
tensorflow-haskell/tensorflow-nn/tests/NNTest.hs
Judah Jacobson 2c5c879037 Introduce a MonadBuild class, and remove buildAnd. (#83)
This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now have a signature that returns
an action in any `MonadBuild` instance (rather than only in `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when foo is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular, a few tests had `buildAnd run . pure`, which is
actually equivalent to just `run`.
2017-03-18 12:08:53 -07:00

106 lines
3.6 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}
module Main where
import Google.Test (googleTest)
import TensorFlow.Test (assertAllClose)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import qualified Data.Vector as V
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
-- | Pure Haskell reference for 'TF.sigmoidCrossEntropyWithLogits'.
--
-- The tests in this module are ported from
-- @<tensorflow>/tensorflow/python/ops/nn_xent_tests.py@. This function plays
-- the role of that file's reference implementation: the TensorFlow op's
-- output is compared against it.
--
-- Logits and targets are paired positionally with 'zipWith', so any extra
-- elements in the longer list are ignored. For each pair @(x, z)@ it computes
--
-- > -z * log p - (1 - z) * log (1 - p)
--
-- Here @p@ is @sigmoid x@ clamped to @[eps, 1 - eps]@ with @eps = 0.0001@,
-- so neither logarithm is ever taken of zero.
sigmoidXentWithLogits :: (Floating a, Ord a) => [a] -> [a] -> [a]
sigmoidXentWithLogits logits' targets' = zipWith lossAt logits' targets'
  where
    eps = 0.0001
    sigmoid x = 1 / (1 + exp (negate x))
    -- Keep the same max-then-min order as the Python reference.
    clamp p = min (max p eps) (1 - eps)
    lossAt x z =
        let p = clamp (sigmoid x)
        in negate z * log p - (1 - z) * log (1 - p)
-- | Test fixture for the cross-entropy tests. It holds the raw logits fed
-- through the sigmoid and the targets they are scored against. The two
-- lists are presumably meant to be the same length (the reference
-- truncates to the shorter one).
data Inputs = Inputs
    { logits  :: [Float]  -- ^ Pre-sigmoid scores.
    , targets :: [Float]  -- ^ Target labels; the test data uses values in @[0, 1]@.
    }
-- | Shared fixture with eight logits, ranging from strongly negative through
-- zero to strongly positive. The targets include the fractional label 0.5.
-- Eight elements also match the @[2, 2, 2]@ shape used by the
-- multi-dimensional test.
defInputs :: Inputs
defInputs = Inputs { logits = scores, targets = labels }
  where
    scores = [-100, -2, -2, 0, 2, 2, 2, 100]
    labels = [0, 0, 1, 0, 0, 1, 0.5, 1]
-- | Compares 'TF.sigmoidCrossEntropyWithLogits' on rank-1 inputs built from
-- 'defInputs' against the pure reference 'sigmoidXentWithLogits'.
testLogisticOutput :: Test
testLogisticOutput = testCase "testLogisticOutput" $ do
    let xs = logits defInputs
        zs = targets defInputs
        expected = V.fromList (sigmoidXentWithLogits xs zs)
    actual <- run $
        TF.sigmoidCrossEntropyWithLogits (TF.vector xs) (TF.vector zs)
    assertAllClose actual expected
-- | Same check as the rank-1 test, but both inputs are built as constants of
-- shape @[2, 2, 2]@. The fetched result is compared element by element with
-- the flat reference vector.
testLogisticOutputMultipleDim :: Test
testLogisticOutputMultipleDim =
    testCase "testLogisticOutputMultipleDim" $ do
        let xs = logits defInputs
            zs = targets defInputs
            asCube = TF.constant [2, 2, 2]
            expected = V.fromList (sigmoidXentWithLogits xs zs)
        actual <- run $
            TF.sigmoidCrossEntropyWithLogits (asCube xs) (asCube zs)
        assertAllClose actual expected
-- | Checks the gradient of the sigmoid cross-entropy at logits of zero.
--
-- Per element, the derivative of the loss with respect to its logit is
-- @sigmoid x - z@. At @x = 0@ this is @0.5 - z@, so targets @[0, 1]@ should
-- give @[0.5, -0.5]@.
testGradientAtZero :: Test
testGradientAtZero = testCase "testGradientAtZero" $ do
    let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
        vLogits = TF.vector $ logits inputs
        vTargets = TF.vector $ targets inputs
    grads <- run $ do
        loss <- TF.sigmoidCrossEntropyWithLogits vLogits vTargets
        TF.gradients loss [vLogits]
    -- Differentiating with respect to a single tensor should yield exactly
    -- one gradient. Match on that instead of using the partial 'head', so a
    -- mismatch fails with a descriptive message rather than
    -- "Prelude.head: empty list".
    case grads of
        [dLogits] -> assertAllClose dLogits (V.fromList [0.5, -0.5])
        _ -> fail $ "testGradientAtZero: expected exactly one gradient, got "
                    ++ show (length grads)
-- | Runs the given session action under 'TF.runSession', then fetches
-- whatever it returns with 'TF.run'.
run :: TF.Fetchable t a => TF.Session t -> IO a
run action = TF.runSession $ do
    fetchable <- action
    TF.run fetchable
-- | Test-suite entry point: hands every test in this module to 'googleTest'.
main :: IO ()
main = googleTest allTests
  where
    allTests =
        [ testGradientAtZero
        , testLogisticOutput
        , testLogisticOutputMultipleDim
        ]