-- tensorflow-haskell/tensorflow-ops/tests/OpsTest.hs
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Google.Test (googleTest)
import System.IO.Temp (withSystemTempDirectory)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Output as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Types as TF
-- | Test that one can easily determine the number of elements in a tensor:
-- a constant of shape [2, 3] holding six floats must report a 'TF.size'
-- of 2 * 3.
testSize :: Test
testSize = testCase "testSize" $ do
    let shaped = TF.constant (TF.Shape [2, 3]) [0 .. 5 :: Float]
    elementCount <- eval (TF.size shaped)
    TF.Scalar (2 * 3 :: Int32) @=? elementCount
-- | Test helper: wraps a fetchable value with 'return', fetches it with
-- 'TF.run' inside its own 'TF.runSession', and hands back the result in 'IO'.
eval :: TF.Fetchable t a => t -> IO a
eval fetchable = TF.runSession (TF.buildAnd TF.run (return fetchable))
-- | Confirms that the original example from the Python code works:
-- reducing an input shape of [2, 3, 5, 7] over axes 1 and 2 yields
-- [2, 1, 1, 7], i.e. the reduced axes are kept with size 1.
testReducedShape :: Test
testReducedShape = testCase "testReducedShape" $ do
    let inputShape = TF.vector [2, 3, 5, 7 :: Int64]
        reductionAxes = TF.vector [1, 2 :: Int32]
    reduced <- eval (TF.reducedShape inputShape reductionAxes)
    V.fromList [2, 1, 1, 7 :: Int32] @=? reduced
-- | Round-trips a scalar 'Float' resource variable through a checkpoint file.
-- One session assigns 134 and saves the variable. A second session builds
-- the variable again, restores it from the same path, and reads it back.
testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $
    withSystemTempDirectory "" $ \tmpDir -> do
        let checkpointPath = B8.pack (tmpDir ++ "/checkpoint")
            scalarVar :: TF.Build (TF.ResourceHandle Float)
            scalarVar = TF.zeroInitializedVariable (TF.Shape [])
        -- Writer session: give the variable a known value, then checkpoint it.
        TF.runSession $ do
            saved <- TF.build scalarVar
            TF.buildAnd TF.run_ (TF.group (CoreOps.assignVariableOp saved 134))
            TF.buildAnd TF.run_ (TF.save checkpointPath [saved])
        -- Reader session: the value can only come from the checkpoint, because
        -- the variable is built anew here and never assigned directly.
        restored <- TF.runSession $ do
            loaded <- TF.build scalarVar
            TF.buildAnd TF.run_ (TF.restore checkpointPath loaded)
            TF.run (CoreOps.readVariableOp loaded)
        liftIO (TF.Scalar 134 @=? restored)
-- | Entry point: runs every test in this module through 'googleTest'.
main :: IO ()
main = googleTest allTests
  where
    allTests = [testSaveRestore, testSize, testReducedShape]