mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 14:59:44 +01:00
f170df9d13
In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, like RNN. Before: benchmarking feedFetch/4 byte time 83.31 μs (81.88 μs .. 84.75 μs) 0.997 R² (0.994 R² .. 0.998 R²) mean 87.32 μs (86.06 μs .. 88.83 μs) std dev 4.580 μs (3.698 μs .. 5.567 μs) variance introduced by outliers: 55% (severely inflated) benchmarking feedFetch/4 KiB time 114.9 μs (111.5 μs .. 118.2 μs) 0.996 R² (0.994 R² .. 0.998 R²) mean 117.3 μs (116.2 μs .. 118.6 μs) std dev 3.877 μs (3.058 μs .. 5.565 μs) variance introduced by outliers: 31% (moderately inflated) benchmarking feedFetch/4 MiB time 109.0 ms (107.9 ms .. 110.7 ms) 1.000 R² (0.999 R² .. 1.000 R²) mean 108.6 ms (108.2 ms .. 109.2 ms) std dev 740.2 μs (353.2 μs .. 1.186 ms) After: benchmarking feedFetch/4 byte time 82.92 μs (80.55 μs .. 85.24 μs) 0.996 R² (0.993 R² .. 0.998 R²) mean 83.58 μs (82.34 μs .. 84.89 μs) std dev 4.327 μs (3.664 μs .. 5.375 μs) variance introduced by outliers: 54% (severely inflated) benchmarking feedFetch/4 KiB time 85.69 μs (83.81 μs .. 87.30 μs) 0.997 R² (0.996 R² .. 0.999 R²) mean 86.99 μs (86.11 μs .. 88.15 μs) std dev 3.608 μs (2.854 μs .. 5.273 μs) variance introduced by outliers: 43% (moderately inflated) benchmarking feedFetch/4 MiB time 1.582 ms (1.509 ms .. 1.677 ms) 0.970 R² (0.936 R² .. 0.993 R²) mean 1.645 ms (1.554 ms .. 1.981 ms) std dev 490.6 μs (138.9 μs .. 1.067 ms) variance introduced by outliers: 97% (severely inflated)
67 lines
2.7 KiB
Haskell
67 lines
2.7 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
import Data.Int (Int32, Int64)
|
|
import Data.List (genericLength)
|
|
import Google.Test (googleTest)
|
|
import Test.Framework.Providers.QuickCheck2 (testProperty)
|
|
import Test.HUnit ((@=?))
|
|
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
|
|
import Test.QuickCheck.Monadic (monadicIO, run)
|
|
|
|
import qualified Data.Vector as V
|
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
|
import qualified TensorFlow.Ops as TF
|
|
import qualified TensorFlow.Session as TF
|
|
import qualified TensorFlow.Tensor as TF
|
|
import qualified TensorFlow.Types as TF
|
|
|
|
-- | DynamicPartition is undone with DynamicStitch to get the original input
-- back.
--
-- The values are split into @numParts@ tensors according to @partitions@.
-- Partitioning the index vector @[0 .. length values - 1]@ with the same
-- @partitions@ records the original position of every value, so stitching
-- the split values back together at those indices must reproduce the
-- original vector. The property also checks that DynamicPartition produced
-- exactly @numParts@ output tensors.
testDynamicPartitionStitchInverse :: forall a.
    (TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
    let splitParts :: [TF.Tensor TF.Value a] =
            CoreOps.dynamicPartition numParts (TF.vector values) partTensor
        partTensor = TF.vector partitions
        -- Same partitioning applied to positions, so each split part is
        -- paired with the indices its elements originally occupied.
        restitchIndices = CoreOps.dynamicPartition numParts
                              (TF.vector [0 .. genericLength values - 1])
                              partTensor
        -- NOTE(debugging): per the original author, dropping (numParts - 2)
        -- parts from both arguments below exposes b/27343984.
        restitch = CoreOps.dynamicStitch restitchIndices splitParts
    in monadicIO $ run $ do
        fromIntegral numParts @=? length splitParts
        valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
        V.fromList values @=? valuesOut
|
|
|
|
-- | Input for the DynamicPartition/DynamicStitch round-trip property.
data StitchExample a = StitchExample
    Int64    -- ^ Number of partitions to split the values into.
    [a]      -- ^ Values to partition and then stitch back together.
    [Int32]  -- ^ Partition id for each value (generated with one id per value).
    deriving (Show)
|
|
|
|
-- | Generates a vector of 1 to 100 values together with a partition id in
-- @[0, numParts)@ for each value, where @numParts@ is drawn from @[2, 15]@.
--
-- Shrinking keeps @numParts@ fixed and preserves the invariants the
-- round-trip property relies on. @values@ and @partitions@ keep the same
-- length, never become empty, and every partition id stays in range.
instance Arbitrary a => Arbitrary (StitchExample a) where
    arbitrary = do
        -- Limits the size of the vector.
        size <- choose (1, 100)
        values <- vectorOf size arbitrary
        numParts <- choose (2, 15)
        partitions <- vectorOf size (choose (0, fromIntegral numParts - 1))
        return $ StitchExample numParts values partitions

    shrink (StitchExample numParts values partitions) =
        droppedPairs ++ shrunkValues
      where
        -- Remove the i-th (value, partition id) pair from both lists at once
        -- so they stay aligned; never shrink below one element.
        droppedPairs =
            [ StitchExample numParts (deleteAt i values) (deleteAt i partitions)
            | length values > 1
            , i <- [0 .. length values - 1]
            ]
        -- Shrink a single value in place; partition ids are left untouched,
        -- so they remain valid for the unchanged numParts.
        shrunkValues =
            [ StitchExample numParts (replaceAt i v' values) partitions
            | (i, v) <- zip [0 ..] values
            , v' <- shrink v
            ]
        deleteAt i xs = take i xs ++ drop (i + 1) xs
        replaceAt i y xs = take i xs ++ [y] ++ drop (i + 1) xs
|
|
|
|
-- | Runs the DynamicPartition/DynamicStitch round-trip property with the
-- element type fixed to 'Int64'.
main :: IO ()
main = googleTest [roundTrip]
  where
    roundTrip = testProperty "DynamicPartitionStitchInverse" prop

    prop :: StitchExample Int64 -> Property
    prop = testDynamicPartitionStitchInverse
|