mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Starting NN library (#11)
* Starting NN library - Added "sigmoidCrossEntropyWithLogits" - Ported across a single test
This commit is contained in:
parent
03a3a6d086
commit
b2795d7518
5 changed files with 254 additions and 0 deletions
|
@ -10,6 +10,7 @@ packages:
|
|||
- tensorflow-mnist
|
||||
- tensorflow-mnist-input-data
|
||||
- tensorflow-queue
|
||||
- tensorflow-nn
|
||||
|
||||
extra-deps:
|
||||
# proto-lens is not yet in Stackage.
|
||||
|
|
3
tensorflow-nn/Setup.hs
Normal file
3
tensorflow-nn/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
|||
import Distribution.Simple
|
||||
|
||||
main = defaultMain
|
87
tensorflow-nn/src/TensorFlow/NN.hs
Normal file
87
tensorflow-nn/src/TensorFlow/NN.hs
Normal file
|
@ -0,0 +1,87 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module TensorFlow.NN
|
||||
( sigmoidCrossEntropyWithLogits
|
||||
) where
|
||||
|
||||
import Prelude hiding ( log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Build ( Build(..)
|
||||
, render
|
||||
, withNameScope
|
||||
)
|
||||
import TensorFlow.GenOps.Core ( greaterEqual
|
||||
, select
|
||||
, log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Tensor ( Tensor(..)
|
||||
, Value(..)
|
||||
)
|
||||
import TensorFlow.Types ( TensorType(..)
|
||||
, OneOf
|
||||
)
|
||||
import TensorFlow.Ops ( zerosLike
|
||||
, add
|
||||
)
|
||||
|
||||
-- | Computes sigmoid cross entropy given `logits`.
|
||||
--
|
||||
-- Measures the probability error in discrete classification tasks in which each
|
||||
-- class is independent and not mutually exclusive. For instance, one could
|
||||
-- perform multilabel classification where a picture can contain both an elephant
|
||||
-- and a dog at the same time.
|
||||
--
|
||||
-- For brevity, let `x = logits`, `z = targets`. The logistic loss is
|
||||
--
|
||||
-- z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
-- = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
|
||||
-- = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
|
||||
-- = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
|
||||
-- = (1 - z) * x + log(1 + exp(-x))
|
||||
-- = x - x * z + log(1 + exp(-x))
|
||||
--
|
||||
-- For x < 0, to avoid overflow in exp(-x), we reformulate the above
|
||||
--
|
||||
-- x - x * z + log(1 + exp(-x))
|
||||
-- = log(exp(x)) - x * z + log(1 + exp(-x))
|
||||
-- = - x * z + log(1 + exp(x))
|
||||
--
|
||||
-- Hence, to ensure stability and avoid overflow, the implementation uses this
|
||||
-- equivalent formulation
|
||||
--
|
||||
-- max(x, 0) - x * z + log(1 + exp(-abs(x)))
|
||||
--
|
||||
-- `logits` and `targets` must have the same type and shape.
|
||||
sigmoidCrossEntropyWithLogits
|
||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
=> Tensor Value a -- ^ __logits__
|
||||
-> Tensor Value a -- ^ __targets__
|
||||
-> Build (Tensor Value a)
|
||||
sigmoidCrossEntropyWithLogits logits targets = do
|
||||
logits' <- render logits
|
||||
targets' <- render targets
|
||||
let zeros = zerosLike logits'
|
||||
cond = logits' `greaterEqual` zeros
|
||||
relu_logits = select cond logits' zeros
|
||||
neg_abs_logits = select cond (-logits') logits'
|
||||
withNameScope "logistic_loss" $ do
|
||||
left <- render $ relu_logits - logits' * targets'
|
||||
right <- render $ log (1 + exp neg_abs_logits)
|
||||
withNameScope "sigmoid_add" $ render $ left `add` right
|
44
tensorflow-nn/tensorflow-nn.cabal
Normal file
44
tensorflow-nn/tensorflow-nn.cabal
Normal file
|
@ -0,0 +1,44 @@
|
|||
name: tensorflow-nn
|
||||
version: 0.1.0.0
|
||||
synopsis: Friendly layer around TensorFlow bindings.
|
||||
description: Please see README.md
|
||||
homepage: https://github.com/tensorflow/haskell#readme
|
||||
license: Apache
|
||||
author: TensorFlow authors
|
||||
maintainer: tensorflow-haskell@googlegroups.com
|
||||
copyright: Google Inc.
|
||||
category: Machine Learning
|
||||
build-type: Simple
|
||||
cabal-version: >=1.22
|
||||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.NN
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, tensorflow-core-ops == 0.1.*
|
||||
, tensorflow == 0.1.*
|
||||
, tensorflow-ops == 0.1.*
|
||||
default-language: Haskell2010
|
||||
|
||||
|
||||
Test-Suite NNTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: NNTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, QuickCheck
|
||||
, base
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, tensorflow-nn
|
||||
, google-shim
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, vector
|
||||
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
119
tensorflow-nn/tests/NNTest.hs
Normal file
119
tensorflow-nn/tests/NNTest.hs
Normal file
|
@ -0,0 +1,119 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Google.Test (googleTest)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@?))
|
||||
import Test.HUnit.Lang (Assertion(..))
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.NN as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
-- | These tests are ported from:
|
||||
--
|
||||
-- <tensorflow>/tensorflow/python/ops/nn_xent_tests.py
|
||||
--
|
||||
-- This is the implementation we use to check the implementation we
|
||||
-- wrote in `TensorFlow.NN.sigmoidCrossEntropyWithLogits`.
|
||||
--
|
||||
sigmoidXentWithLogits :: Floating a => Ord a => [a] -> [a] -> [a]
|
||||
sigmoidXentWithLogits logits' targets' =
|
||||
let sig = map (\x -> 1 / (1 + exp (-x))) logits'
|
||||
eps = 0.0001
|
||||
pred = map (\p -> min (max p eps) (1 - eps)) sig
|
||||
xent y z = (-z) * (log y) - (1 - z) * log (1 - y)
|
||||
in zipWith xent pred targets'
|
||||
|
||||
|
||||
data Inputs = Inputs {
|
||||
logits :: [Float]
|
||||
, targets :: [Float]
|
||||
}
|
||||
|
||||
|
||||
defInputs :: Inputs
|
||||
defInputs = Inputs {
|
||||
logits = [-100, -2, -2, 0, 2, 2, 2, 100]
|
||||
, targets = [ 0, 0, 1, 0, 0, 1, 0.5, 1]
|
||||
}
|
||||
|
||||
|
||||
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||
++ "\ntolerance: " ++ show tol)
|
||||
where
|
||||
absDiff x y = abs (x - y)
|
||||
tol = 0.001 :: Float
|
||||
|
||||
|
||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||
let inputs = defInputs
|
||||
vLogits = TF.vector $ logits inputs
|
||||
vTargets = TF.vector $ targets inputs
|
||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||
|
||||
r <- run tfLoss
|
||||
assertAllClose r ourLoss
|
||||
|
||||
|
||||
testLogisticOutputMultipleDim =
|
||||
testCase "testLogisticOutputMultipleDim" $ do
|
||||
let inputs = defInputs
|
||||
shape = [2, 2, 2]
|
||||
vLogits = TF.constant shape (logits inputs)
|
||||
vTargets = TF.constant shape (targets inputs)
|
||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||
|
||||
r <- run tfLoss
|
||||
assertAllClose r ourLoss
|
||||
|
||||
|
||||
testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
||||
vLogits = TF.vector $ logits inputs
|
||||
vTargets = TF.vector $ targets inputs
|
||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
|
||||
r <- run $ do
|
||||
l <- tfLoss
|
||||
TF.gradients l [vLogits]
|
||||
|
||||
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
||||
|
||||
|
||||
run = TF.runSession . TF.buildAnd TF.run
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testLogisticOutput
|
||||
, testLogisticOutputMultipleDim
|
||||
, testGradientAtZero
|
||||
]
|
Loading…
Reference in a new issue