mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-25 19:19:45 +01:00
2c5c879037
This change adds a class that both `Build` and `Session` are instances of: class MonadBuild m where build :: Build a -> m a All stateful ops (generated and manually written) now have a signature that returns an instance of `MonadBuild` (rather than just `Build`). For example: assign_ :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v t -> m (Tensor Ref t) This lets us remove a bunch of spurious calls to `build` in user code. It also lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run` (or `run =<< foo`, which is sometimes nicer when foo is a complicated expression). I went ahead and deleted `buildAnd` altogether since it seems to lead to confusion; in particular a few tests had `buildAnd run . pure` which is actually equivalent to just `run`.
67 lines
2.7 KiB
Haskell
67 lines
2.7 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
import Data.Int (Int32, Int64)
|
|
import Data.List (genericLength)
|
|
import Google.Test (googleTest)
|
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
|
import Test.HUnit ((@=?))
|
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
|
|
|
import qualified Data.Vector as V
|
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
|
import qualified TensorFlow.Ops as TF
|
|
import qualified TensorFlow.Session as TF
|
|
import qualified TensorFlow.Tensor as TF
|
|
import qualified TensorFlow.Types as TF
|
|
|
|
-- DynamicPartition is undone with DynamicStitch to get the original input
-- back.
|
|
-- | Round-trip property: the input vector is split with
-- 'CoreOps.dynamicPartition' and reassembled with 'CoreOps.dynamicStitch'.
-- The stitch indices come from partitioning the positions
-- @[0 .. length values - 1]@ with the same assignment. The property checks
-- that the reassembled vector equals the original values.
testDynamicPartitionStitchInverse
    :: forall a. (TF.TensorDataType V.Vector a, Show a, Eq a)
    => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
    monadicIO $ run $ do
        -- Exactly one output tensor is expected per partition.
        fromIntegral numParts @=? length splitParts
        valuesOut <- TF.runSession $ TF.run restitch
        V.fromList values @=? valuesOut
  where
    -- Partition index for each input value.
    assignment = TF.vector partitions
    -- The input values, split into numParts pieces.
    splitParts :: [TF.Tensor TF.Value a]
    splitParts = CoreOps.dynamicPartition numParts (TF.vector values) assignment
    -- The original positions, split exactly like the values.
    restitchIndices =
        CoreOps.dynamicPartition numParts
            (TF.vector [0 .. genericLength values - 1])
            assignment
    -- NOTE: dropping (numParts - 2) elements from both argument lists
    -- exposes b/27343984.
    restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
|
|
|
data StitchExample a = StitchExample Int64 [a] [Int32]
|
|
deriving Show
|
|
|
|
instance Arbitrary a => Arbitrary (StitchExample a) where
    -- Generates 1-100 values, a partition count in [2, 15], and one
    -- partition index in [0, numParts - 1] for every value.
    arbitrary = do
        -- Limits the size of the vector.
        size <- choose (1, 100)
        values <- vectorOf size arbitrary
        numParts <- choose (2, 15)
        partitions <- vectorOf size (choose (0, fromIntegral numParts - 1))
        return $ StitchExample numParts values partitions

    -- Shrinks by deleting one (value, partition index) pair at a time, so a
    -- failing example is reduced toward a minimal one. Each candidate keeps
    -- the invariants 'arbitrary' establishes:
    --   * both lists stay the same length;
    --   * every index stays below numParts (numParts is unchanged);
    --   * at least one value remains (the length > 1 guard).
    shrink (StitchExample numParts values partitions) =
        [ StitchExample numParts (deleteAt i values) (deleteAt i partitions)
        | length values > 1
        , i <- [0 .. length values - 1]
        ]
      where
        -- Removes the element at position i, leaving the rest in order.
        deleteAt :: Int -> [b] -> [b]
        deleteAt i xs = take i xs ++ drop (i + 1) xs
|
|
|
|
-- | Runs the partition/stitch round-trip property.
main :: IO ()
main = googleTest [inverseProperty]
  where
    -- The property is polymorphic in its element type. The annotation fixes
    -- it to Int64 so a concrete Arbitrary instance is used.
    inverseProperty =
        testProperty "DynamicPartitionStitchInverse"
            (testDynamicPartitionStitchInverse :: StitchExample Int64 -> Property)
|