Support fetching storable vectors + use them in benchmark (#50)

In addition, you can now fetch TensorData directly. This might be useful in
scenarios where you feed the result of a computation back in, like RNN.

Before:

benchmarking feedFetch/4 byte
time                 83.31 μs   (81.88 μs .. 84.75 μs)
                     0.997 R²   (0.994 R² .. 0.998 R²)
mean                 87.32 μs   (86.06 μs .. 88.83 μs)
std dev              4.580 μs   (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time                 114.9 μs   (111.5 μs .. 118.2 μs)
                     0.996 R²   (0.994 R² .. 0.998 R²)
mean                 117.3 μs   (116.2 μs .. 118.6 μs)
std dev              3.877 μs   (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 109.0 ms   (107.9 ms .. 110.7 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 108.6 ms   (108.2 ms .. 109.2 ms)
std dev              740.2 μs   (353.2 μs .. 1.186 ms)

After:

benchmarking feedFetch/4 byte
time                 82.92 μs   (80.55 μs .. 85.24 μs)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 83.58 μs   (82.34 μs .. 84.89 μs)
std dev              4.327 μs   (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time                 85.69 μs   (83.81 μs .. 87.30 μs)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 86.99 μs   (86.11 μs .. 88.15 μs)
std dev              3.608 μs   (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 1.582 ms   (1.509 ms .. 1.677 ms)
                     0.970 R²   (0.936 R² .. 0.993 R²)
mean                 1.645 ms   (1.554 ms .. 1.981 ms)
std dev              490.6 μs   (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
This commit is contained in:
fkm3 2016-12-14 18:53:06 -08:00 committed by Judah Jacobson
parent 91f508eb5c
commit f170df9d13
11 changed files with 154 additions and 114 deletions

View File

@ -22,14 +22,9 @@ import Data.List (genericLength)
import qualified Data.Text.IO as T
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse

View File

@ -48,10 +48,9 @@ import TensorFlow.Tensor
, tensorFromName
)
import TensorFlow.Ops
import TensorFlow.Nodes (unScalar)
import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd)
import TensorFlow.Types (TensorType(..), Shape(..))
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion)

View File

@ -42,7 +42,7 @@ import TensorFlow.Build
, withNameScope
)
import TensorFlow.ControlFlow (named)
import TensorFlow.Nodes (unScalar)
import TensorFlow.Types (unScalar)
import TensorFlow.Ops
( add
, assign

View File

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32, Int64)
@ -32,7 +33,7 @@ import qualified TensorFlow.Types as TF
-- DynamicSplit is undone with DynamicStitch to get the original input
-- back.
testDynamicPartitionStitchInverse :: forall a.
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor

View File

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
@ -129,8 +130,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
-- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit ::
forall a. (TF.TensorDataType V.Vector a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit
(LookupExample numParts
shape@(TF.Shape (firstDim : restDims))

View File

@ -5,7 +5,7 @@ import Control.Exception (evaluate)
import Control.Monad.IO.Class (liftIO)
import Criterion.Main (defaultMain, bgroup, bench)
import Criterion.Types (Benchmarkable(..))
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF
@ -23,12 +23,12 @@ nfSession init x = Benchmarkable $ \m -> TF.runSession $ do
go m
-- | Benchmark feeding and fetching a vector.
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float))
feedFetchBenchmark = do
input <- TF.build (TF.placeholder (TF.Shape [-1]))
output <- TF.build (TF.render (TF.identity input))
return $ \v -> do
let shape = TF.Shape [fromIntegral (V.length v)]
let shape = TF.Shape [fromIntegral (S.length v)]
inputData = TF.encodeTensorData shape v
feeds = [TF.feed input inputData]
TF.runWithFeeds feeds output
@ -36,8 +36,8 @@ feedFetchBenchmark = do
main :: IO ()
main = defaultMain
[ bgroup "feedFetch"
[ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0)
, bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0)
, bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0)
[ bench "4 byte" $ nfSession feedFetchBenchmark (S.replicate 1 0)
, bench "4 KiB" $ nfSession feedFetchBenchmark (S.replicate 1024 0)
, bench "4 MiB" $ nfSession feedFetchBenchmark (S.replicate (1024^2) 0)
]
]

View File

@ -85,7 +85,7 @@ instance Arbitrary a => Arbitrary (TensorDataInputs a) where
return $ TensorDataInputs sizes (V.fromList elems)
-- Test that a vector is unchanged after being encoded and decoded.
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs shape vec) =
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec

View File

@ -20,7 +20,7 @@ module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import TensorFlow.Nodes (Scalar(..))
import TensorFlow.Types (Scalar(..))
import TensorFlow.Ops (scalar)
import TensorFlow.Queue
import TensorFlow.Session

View File

@ -36,7 +36,6 @@ module TensorFlow.Core
, buildWithSummary
-- ** Running graphs
, Fetchable
, Scalar(..)
, Nodes
, run
, run_
@ -64,8 +63,10 @@ module TensorFlow.Core
, value
, tensorFromName
-- ** Element types
, TensorType
, TensorData
, TensorType(decodeTensorData, encodeTensorData)
, TensorDataType(decodeTensorData, encodeTensorData)
, Scalar(..)
, Shape(..)
, OneOf
, type (/=)

View File

@ -12,8 +12,8 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
@ -24,12 +24,10 @@ import Control.Applicative (liftA2, liftA3)
import Data.Map.Strict (Map)
import Data.Monoid ((<>))
import Data.Set (Set)
import Data.String (IsString)
import Data.Text (Text)
import Lens.Family2 ((^.))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Vector as V
import TensorFlow.Build
import TensorFlow.Output
@ -101,18 +99,12 @@ instance a ~ () => Fetchable ControlNode a where
instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
fetchTensorVector :: forall a v . TensorType a
=> Tensor v a -> Build (Fetch (Shape, V.Vector a))
=> Tensor v a -> Build (Fetch (TensorData a))
fetchTensorVector (Tensor _ o) = do
outputName <- renderOutput o
return $ Fetch (Set.singleton outputName) $ \tensors ->
let tensorData = tensors Map.! outputName
shape = Shape $ FFI.tensorDataDimensions tensorData
vec = decodeTensorData $ TensorData tensorData
expectedType = tensorType (undefined :: a)
actualType = FFI.tensorDataType tensorData
badTypeError = error $ "Bad tensor type: expected "
@ -121,21 +113,12 @@ fetchTensorVector (Tensor _ o) = do
++ show actualType
in if expectedType /= actualType
then badTypeError
else (shape, vec)
else TensorData tensorData
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
-- the TensorType of each other.
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where
getFetch t = fmap snd <$> fetchTensorVector t
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where
getFetch = fetchTensorVector
newtype Scalar a = Scalar {unScalar :: a}
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
RealFrac, IsString)
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
where
headFromSingleton [x] = x
headFromSingleton xs
= error $ "Unable to extract singleton from tensor of length "
++ show (length xs)
instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where
getFetch t = fmap decodeTensorData <$> fetchTensorVector t

View File

@ -16,6 +16,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
@ -30,6 +31,8 @@
module TensorFlow.Types
( TensorType(..)
, TensorData(..)
, TensorDataType(..)
, Scalar(..)
, Shape(..)
, protoShape
, Attribute(..)
@ -50,11 +53,13 @@ import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~))
import Lens.Family2.Unchecked (iso)
import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
@ -95,66 +100,31 @@ import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI
-- | Data about a tensor that is encoded for the TensorFlow APIs.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | The class of scalar types supported by tensorflow.
class TensorType a where
tensorType :: a -> DataType
tensorRefType :: a -> DataType
tensorVal :: Lens' TensorProto [a]
-- | Decode the bytes of a TensorData into a Vector.
decodeTensorData :: TensorData a -> V.Vector a
-- | Encode a Vector into a TensorData.
--
-- The values should be in row major order, e.g.,
--
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
encodeTensorData :: Shape -> V.Vector a -> TensorData a
-- All types, besides ByteString, are encoded as simple arrays and we can use
-- Vector.Storable to encode/decode by type casting pointers.
-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> V.Vector a
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
simpleEncode :: forall a . (TensorType a, Storable a)
=> Shape -> V.Vector a -> TensorData a
simpleEncode (Shape xs)
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
where
dt = tensorType (undefined :: a)
instance TensorType Float where
tensorType _ = DT_FLOAT
tensorRefType _ = DT_FLOAT_REF
tensorVal = floatVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Double where
tensorType _ = DT_DOUBLE
tensorRefType _ = DT_DOUBLE_REF
tensorVal = doubleVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int32 where
tensorType _ = DT_INT32
tensorRefType _ = DT_INT32_REF
tensorVal = intVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int64 where
tensorType _ = DT_INT64
tensorRefType _ = DT_INT64_REF
tensorVal = int64Val
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
integral :: Integral a => Lens' [Int32] [a]
integral = iso (fmap fromIntegral) (fmap fromIntegral)
@ -163,40 +133,140 @@ instance TensorType Word8 where
tensorType _ = DT_UINT8
tensorRefType _ = DT_UINT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Word16 where
tensorType _ = DT_UINT16
tensorRefType _ = DT_UINT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int16 where
tensorType _ = DT_INT16
tensorRefType _ = DT_INT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int8 where
tensorType _ = DT_INT8
tensorRefType _ = DT_INT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType ByteString where
tensorType _ = DT_STRING
tensorRefType _ = DT_STRING_REF
tensorVal = stringVal
instance TensorType Bool where
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64
tensorRefType _ = DT_COMPLEX64
tensorVal = error "TODO (Complex Float)"
instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
-- | Tensor data with the correct memory layout for tensorflow.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | Types that can be converted to and from 'TensorData'.
--
-- 'S.Vector' is the most efficient to encode/decode for most element types.
class TensorType a => TensorDataType s a where
-- | Decode the bytes of a 'TensorData' into an 's'.
decodeTensorData :: TensorData a -> s a
-- | Encode an 's' into a 'TensorData'.
--
-- The values should be in row major order, e.g.,
--
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
encodeTensorData :: Shape -> s a -> TensorData a
-- All types, besides ByteString and Bool, are encoded as simple arrays and we
-- can use Vector.Storable to encode/decode by type casting pointers.
-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode = S.unsafeCast . FFI.tensorDataBytes . unTensorData
simpleEncode :: forall a . (TensorType a, Storable a)
=> Shape -> S.Vector a -> TensorData a
simpleEncode (Shape xs) v =
if product xs /= fromIntegral (S.length v)
then error $ printf
"simpleEncode: bad vector length for shape %v: expected=%d got=%d"
(show xs) (product xs) (S.length v)
else TensorData (FFI.TensorData xs dt (S.unsafeCast v))
where
dt = tensorType (undefined :: a)
instance TensorDataType S.Vector Float where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Double where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int8 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int16 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int32 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int64 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Word8 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Word16 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
instance TensorDataType S.Vector Bool where
decodeTensorData =
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
encodeTensorData (Shape xs) =
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
where
fromBool x = if x then 1 else 0 :: Word8
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a)
=> TensorDataType V.Vector a where
decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64]
-- at each element offset:
-- string length :: VarInt64
-- string data :: [Word8]
-- TODO(fmayle): Benchmark these functions.
decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
if expected /= count
@ -241,32 +311,21 @@ instance TensorType ByteString where
-- Convert to Vector Word8.
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
instance TensorType Bool where
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
decodeTensorData =
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
encodeTensorData (Shape xs) =
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
where
fromBool x = if x then 1 else 0 :: Word8
newtype Scalar a = Scalar {unScalar :: a}
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
RealFrac, IsString)
instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64
tensorRefType _ = DT_COMPLEX64
tensorVal = error "TODO (Complex Float)"
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance TensorDataType V.Vector a => TensorDataType Scalar a where
decodeTensorData = Scalar . headFromSingleton . decodeTensorData
encodeTensorData x (Scalar y) = encodeTensorData x (V.fromList [y])
headFromSingleton :: V.Vector a -> a
headFromSingleton x
| V.length x == 1 = V.head x
| otherwise = error $
"Unable to extract singleton from tensor of length "
++ show (V.length x)
instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
-- | Shape (dimensions) of a tensor.
newtype Shape = Shape [Int64] deriving Show