2016-10-24 21:26:42 +02:00
|
|
|
-- Copyright 2016 TensorFlow authors.
|
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
2017-03-23 06:47:42 +01:00
|
|
|
{-# LANGUAGE OverloadedLists #-}
|
2016-10-24 21:26:42 +02:00
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
|
|
|
|
module Main where
|
|
|
|
|
|
|
|
import Control.Monad.IO.Class (liftIO)
|
|
|
|
import Data.Int (Int32, Int64)
|
|
|
|
import Google.Test (googleTest)
|
2017-03-21 02:16:38 +01:00
|
|
|
import Lens.Family2 ((.~))
|
2016-10-24 21:26:42 +02:00
|
|
|
import System.IO.Temp (withSystemTempDirectory)
|
2016-11-18 19:42:02 +01:00
|
|
|
import Test.Framework (Test)
|
2016-10-24 21:26:42 +02:00
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
|
|
import Test.HUnit ((@=?))
|
|
|
|
import qualified Data.ByteString.Char8 as B8
|
|
|
|
|
|
|
|
import qualified Data.Vector as V
|
|
|
|
import qualified TensorFlow.Build as TF
|
|
|
|
import qualified TensorFlow.Nodes as TF
|
|
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
import qualified TensorFlow.Session as TF
|
|
|
|
import qualified TensorFlow.Tensor as TF
|
2016-11-18 19:42:02 +01:00
|
|
|
import qualified TensorFlow.Types as TF
|
2016-10-24 21:26:42 +02:00
|
|
|
|
|
|
|
-- | Test that one can easily determine number of elements in the tensor.
--
-- A constant of shape [2, 3] is built and 'TF.size' is expected to report
-- 2 * 3 elements as an 'Int32' scalar.
testSize :: Test
testSize = testCase "testSize" $ do
    let shape = TF.Shape [2, 3]
        values = [0 .. 5] :: [Float]
    elementCount <- eval (TF.size (TF.constant shape values))
    TF.Scalar (2 * 3 :: Int32) @=? elementCount
|
|
|
|
|
2016-11-18 19:42:02 +01:00
|
|
|
-- | Runs the given fetchable in a new 'TF.runSession' and returns the
-- fetched result in 'IO'.
eval :: TF.Fetchable t a => t -> IO a
eval fetchable = TF.runSession (TF.run fetchable)
|
2016-10-24 21:26:42 +02:00
|
|
|
|
|
|
|
-- | Confirms that the original example from Python code works.
--
-- For an input shape of [2, 3, 5, 7] and axes [1, 2], the expected result
-- is [2, 1, 1, 7]: the listed axes appear with size 1.
testReducedShape :: Test
testReducedShape = testCase "testReducedShape" $ do
    let inputShape = TF.vector [2, 3, 5, 7 :: Int64]
        axes = TF.vector [1, 2 :: Int32]
    reduced <- eval (TF.reducedShape inputShape axes)
    V.fromList [2, 1, 1, 7 :: Int32] @=? reduced
|
|
|
|
|
2016-11-18 19:42:02 +01:00
|
|
|
-- | Saves a scalar variable to a checkpoint in one session, then restores it
-- into a freshly zero-initialized variable in a second session and checks
-- that the saved value (134) comes back.
testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $
    withSystemTempDirectory "" $ \dirPath -> do
        let checkpointPath = B8.pack (dirPath ++ "/checkpoint")
            -- Both sessions build the variable with the same op name "a";
            -- presumably the checkpoint entry is matched by that name —
            -- confirm against TF.save/TF.restore.
            mkVar :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
            mkVar = TF.zeroInitializedVariable' (TF.opName .~ "a") (TF.Shape [])
        -- Session 1: set the variable to 134 and write the checkpoint.
        TF.runSession $ do
            saved <- mkVar
            TF.run_ =<< TF.assign saved 134
            TF.run_ =<< TF.save checkpointPath [saved]
        -- Session 2: new variable, overwritten from the checkpoint, then read.
        restored <- TF.runSession $ do
            fresh <- mkVar
            TF.run_ =<< TF.restore checkpointPath fresh
            TF.run fresh
        TF.Scalar 134 @=? restored
|
|
|
|
|
2017-03-23 06:47:42 +01:00
|
|
|
-- | Test that 'placeholder' is not CSE'd.
--
-- Two placeholders are created with identical arguments and fed different
-- values (2 and 3); their sum must be 5, which only holds if they stayed
-- distinct nodes.
testPlaceholderCse :: Test
testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
    p1 <- TF.placeholder []
    p2 <- TF.placeholder []
    let enc :: Float -> TF.TensorData Float
        enc n = TF.encodeTensorData [] (V.fromList [n])
    result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
        $ p1 `TF.add` p2
    -- HUnit's (@=?) takes the expected value first and the actual value
    -- second; keeping that order makes failure messages read correctly.
    liftIO $ TF.Scalar 5 @=? result
|
|
|
|
|
2017-03-23 20:58:40 +01:00
|
|
|
-- | Test that regular tensors can also be used for feeds, as long as they each
-- have a different name.
testScalarFeedCse :: Test
testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
    p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0
    -- The second op is identical to the first other than its name; make sure
    -- we don't aggressively CSE them together and prevent feeding them
    -- separately.
    p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
    let enc :: Float -> TF.TensorData Float
        enc n = TF.encodeTensorData [] (V.fromList [n])
    result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
        $ p1 `TF.add` p2
    -- HUnit's (@=?) takes the expected value first and the actual value
    -- second; keeping that order makes failure messages read correctly.
    liftIO $ TF.Scalar 5 @=? result
|
2016-10-24 21:26:42 +02:00
|
|
|
|
2017-05-09 02:45:56 +02:00
|
|
|
-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.
--
-- The test label matches the binding name so failure reports point
-- directly at this definition.
testRereadRef :: Test
testRereadRef = testCase "testRereadRef" $ TF.runSession $ do
    w <- TF.initializedVariable 0
    -- Read before the assignment: expected to stay 0.0.
    f0 <- TF.run w
    TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float))
    -- Read after the assignment: expected to be 0.1.
    f1 <- TF.run w
    liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
|
|
|
|
|
2016-10-24 21:26:42 +02:00
|
|
|
-- | Runs every test defined in this module via 'googleTest'.
main :: IO ()
main = googleTest
    [ testSaveRestore
    , testSize
    , testReducedShape
    , testPlaceholderCse
    , testScalarFeedCse
    , testRereadRef
    ]
|