mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Starting NN library (#11)
* Starting NN library - Added "sigmoidCrossEntropyWithLogits" - Ported across a single test
This commit is contained in:
parent
03a3a6d086
commit
b2795d7518
5 changed files with 254 additions and 0 deletions
|
@ -10,6 +10,7 @@ packages:
|
||||||
- tensorflow-mnist
|
- tensorflow-mnist
|
||||||
- tensorflow-mnist-input-data
|
- tensorflow-mnist-input-data
|
||||||
- tensorflow-queue
|
- tensorflow-queue
|
||||||
|
- tensorflow-nn
|
||||||
|
|
||||||
extra-deps:
|
extra-deps:
|
||||||
# proto-lens is not yet in Stackage.
|
# proto-lens is not yet in Stackage.
|
||||||
|
|
3
tensorflow-nn/Setup.hs
Normal file
3
tensorflow-nn/Setup.hs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import Distribution.Simple
|
||||||
|
|
||||||
|
main = defaultMain
|
87
tensorflow-nn/src/TensorFlow/NN.hs
Normal file
87
tensorflow-nn/src/TensorFlow/NN.hs
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module TensorFlow.NN
|
||||||
|
( sigmoidCrossEntropyWithLogits
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Prelude hiding ( log
|
||||||
|
, exp
|
||||||
|
)
|
||||||
|
import TensorFlow.Build ( Build(..)
|
||||||
|
, render
|
||||||
|
, withNameScope
|
||||||
|
)
|
||||||
|
import TensorFlow.GenOps.Core ( greaterEqual
|
||||||
|
, select
|
||||||
|
, log
|
||||||
|
, exp
|
||||||
|
)
|
||||||
|
import TensorFlow.Tensor ( Tensor(..)
|
||||||
|
, Value(..)
|
||||||
|
)
|
||||||
|
import TensorFlow.Types ( TensorType(..)
|
||||||
|
, OneOf
|
||||||
|
)
|
||||||
|
import TensorFlow.Ops ( zerosLike
|
||||||
|
, add
|
||||||
|
)
|
||||||
|
|
||||||
|
-- | Computes sigmoid cross entropy given `logits`.
|
||||||
|
--
|
||||||
|
-- Measures the probability error in discrete classification tasks in which each
|
||||||
|
-- class is independent and not mutually exclusive. For instance, one could
|
||||||
|
-- perform multilabel classification where a picture can contain both an elephant
|
||||||
|
-- and a dog at the same time.
|
||||||
|
--
|
||||||
|
-- For brevity, let `x = logits`, `z = targets`. The logistic loss is
|
||||||
|
--
|
||||||
|
-- z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
|
-- = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
|
||||||
|
-- = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
|
||||||
|
-- = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
|
||||||
|
-- = (1 - z) * x + log(1 + exp(-x))
|
||||||
|
-- = x - x * z + log(1 + exp(-x))
|
||||||
|
--
|
||||||
|
-- For x < 0, to avoid overflow in exp(-x), we reformulate the above
|
||||||
|
--
|
||||||
|
-- x - x * z + log(1 + exp(-x))
|
||||||
|
-- = log(exp(x)) - x * z + log(1 + exp(-x))
|
||||||
|
-- = - x * z + log(1 + exp(x))
|
||||||
|
--
|
||||||
|
-- Hence, to ensure stability and avoid overflow, the implementation uses this
|
||||||
|
-- equivalent formulation
|
||||||
|
--
|
||||||
|
-- max(x, 0) - x * z + log(1 + exp(-abs(x)))
|
||||||
|
--
|
||||||
|
-- `logits` and `targets` must have the same type and shape.
|
||||||
|
sigmoidCrossEntropyWithLogits
|
||||||
|
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
||||||
|
=> Tensor Value a -- ^ __logits__
|
||||||
|
-> Tensor Value a -- ^ __targets__
|
||||||
|
-> Build (Tensor Value a)
|
||||||
|
sigmoidCrossEntropyWithLogits logits targets = do
|
||||||
|
logits' <- render logits
|
||||||
|
targets' <- render targets
|
||||||
|
let zeros = zerosLike logits'
|
||||||
|
cond = logits' `greaterEqual` zeros
|
||||||
|
relu_logits = select cond logits' zeros
|
||||||
|
neg_abs_logits = select cond (-logits') logits'
|
||||||
|
withNameScope "logistic_loss" $ do
|
||||||
|
left <- render $ relu_logits - logits' * targets'
|
||||||
|
right <- render $ log (1 + exp neg_abs_logits)
|
||||||
|
withNameScope "sigmoid_add" $ render $ left `add` right
|
44
tensorflow-nn/tensorflow-nn.cabal
Normal file
44
tensorflow-nn/tensorflow-nn.cabal
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
name: tensorflow-nn
|
||||||
|
version: 0.1.0.0
|
||||||
|
synopsis: Friendly layer around TensorFlow bindings.
|
||||||
|
description: Please see README.md
|
||||||
|
homepage: https://github.com/tensorflow/haskell#readme
|
||||||
|
license: Apache
|
||||||
|
author: TensorFlow authors
|
||||||
|
maintainer: tensorflow-haskell@googlegroups.com
|
||||||
|
copyright: Google Inc.
|
||||||
|
category: Machine Learning
|
||||||
|
build-type: Simple
|
||||||
|
cabal-version: >=1.22
|
||||||
|
|
||||||
|
library
|
||||||
|
hs-source-dirs: src
|
||||||
|
exposed-modules: TensorFlow.NN
|
||||||
|
build-depends: base >= 4.7 && < 5
|
||||||
|
, tensorflow-core-ops == 0.1.*
|
||||||
|
, tensorflow == 0.1.*
|
||||||
|
, tensorflow-ops == 0.1.*
|
||||||
|
default-language: Haskell2010
|
||||||
|
|
||||||
|
|
||||||
|
Test-Suite NNTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: NNTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, QuickCheck
|
||||||
|
, base
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-nn
|
||||||
|
, google-shim
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, test-framework-quickcheck2
|
||||||
|
, vector
|
||||||
|
|
||||||
|
|
||||||
|
source-repository head
|
||||||
|
type: git
|
||||||
|
location: https://github.com/tensorflow/haskell
|
119
tensorflow-nn/tests/NNTest.hs
Normal file
119
tensorflow-nn/tests/NNTest.hs
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
-- Copyright 2016 TensorFlow authors.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
import Data.Maybe (fromMaybe)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@?))
|
||||||
|
import Test.HUnit.Lang (Assertion(..))
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
import qualified TensorFlow.Build as TF
|
||||||
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
import qualified TensorFlow.NN as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
|
-- | These tests are ported from:
|
||||||
|
--
|
||||||
|
-- <tensorflow>/tensorflow/python/ops/nn_xent_tests.py
|
||||||
|
--
|
||||||
|
-- This is the implementation we use to check the implementation we
|
||||||
|
-- wrote in `TensorFlow.NN.sigmoidCrossEntropyWithLogits`.
|
||||||
|
--
|
||||||
|
sigmoidXentWithLogits :: Floating a => Ord a => [a] -> [a] -> [a]
|
||||||
|
sigmoidXentWithLogits logits' targets' =
|
||||||
|
let sig = map (\x -> 1 / (1 + exp (-x))) logits'
|
||||||
|
eps = 0.0001
|
||||||
|
pred = map (\p -> min (max p eps) (1 - eps)) sig
|
||||||
|
xent y z = (-z) * (log y) - (1 - z) * log (1 - y)
|
||||||
|
in zipWith xent pred targets'
|
||||||
|
|
||||||
|
|
||||||
|
data Inputs = Inputs {
|
||||||
|
logits :: [Float]
|
||||||
|
, targets :: [Float]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
defInputs :: Inputs
|
||||||
|
defInputs = Inputs {
|
||||||
|
logits = [-100, -2, -2, 0, 2, 2, 2, 100]
|
||||||
|
, targets = [ 0, 0, 1, 0, 0, 1, 0.5, 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
assertAllClose :: V.Vector Float -> V.Vector Float -> Assertion
|
||||||
|
assertAllClose xs ys = all (<= tol) (V.zipWith absDiff xs ys) @?
|
||||||
|
("Difference > tolerance: \nxs: " ++ show xs ++ "\nys: " ++ show ys
|
||||||
|
++ "\ntolerance: " ++ show tol)
|
||||||
|
where
|
||||||
|
absDiff x y = abs (x - y)
|
||||||
|
tol = 0.001 :: Float
|
||||||
|
|
||||||
|
|
||||||
|
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||||
|
let inputs = defInputs
|
||||||
|
vLogits = TF.vector $ logits inputs
|
||||||
|
vTargets = TF.vector $ targets inputs
|
||||||
|
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||||
|
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||||
|
|
||||||
|
r <- run tfLoss
|
||||||
|
assertAllClose r ourLoss
|
||||||
|
|
||||||
|
|
||||||
|
testLogisticOutputMultipleDim =
|
||||||
|
testCase "testLogisticOutputMultipleDim" $ do
|
||||||
|
let inputs = defInputs
|
||||||
|
shape = [2, 2, 2]
|
||||||
|
vLogits = TF.constant shape (logits inputs)
|
||||||
|
vTargets = TF.constant shape (targets inputs)
|
||||||
|
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||||
|
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||||
|
|
||||||
|
r <- run tfLoss
|
||||||
|
assertAllClose r ourLoss
|
||||||
|
|
||||||
|
|
||||||
|
testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||||
|
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
||||||
|
vLogits = TF.vector $ logits inputs
|
||||||
|
vTargets = TF.vector $ targets inputs
|
||||||
|
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||||
|
|
||||||
|
r <- run $ do
|
||||||
|
l <- tfLoss
|
||||||
|
TF.gradients l [vLogits]
|
||||||
|
|
||||||
|
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
||||||
|
|
||||||
|
|
||||||
|
run = TF.runSession . TF.buildAnd TF.run
|
||||||
|
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testLogisticOutput
|
||||||
|
, testLogisticOutputMultipleDim
|
||||||
|
, testGradientAtZero
|
||||||
|
]
|
Loading…
Reference in a new issue