1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Support fetching storable vectors + use them in benchmark (#50)

In addition, you can now fetch TensorData directly. This might be useful in
scenarios where you feed the result of a computation back in, like RNN.

Before:

benchmarking feedFetch/4 byte
time                 83.31 μs   (81.88 μs .. 84.75 μs)
                     0.997 R²   (0.994 R² .. 0.998 R²)
mean                 87.32 μs   (86.06 μs .. 88.83 μs)
std dev              4.580 μs   (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time                 114.9 μs   (111.5 μs .. 118.2 μs)
                     0.996 R²   (0.994 R² .. 0.998 R²)
mean                 117.3 μs   (116.2 μs .. 118.6 μs)
std dev              3.877 μs   (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 109.0 ms   (107.9 ms .. 110.7 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 108.6 ms   (108.2 ms .. 109.2 ms)
std dev              740.2 μs   (353.2 μs .. 1.186 ms)

After:

benchmarking feedFetch/4 byte
time                 82.92 μs   (80.55 μs .. 85.24 μs)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 83.58 μs   (82.34 μs .. 84.89 μs)
std dev              4.327 μs   (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time                 85.69 μs   (83.81 μs .. 87.30 μs)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 86.99 μs   (86.11 μs .. 88.15 μs)
std dev              3.608 μs   (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 1.582 ms   (1.509 ms .. 1.677 ms)
                     0.970 R²   (0.936 R² .. 0.993 R²)
mean                 1.645 ms   (1.554 ms .. 1.981 ms)
std dev              490.6 μs   (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
This commit is contained in:
fkm3 2016-12-14 18:53:06 -08:00 committed by Judah Jacobson
parent 91f508eb5c
commit f170df9d13
11 changed files with 154 additions and 114 deletions

View file

@ -22,14 +22,9 @@ import Data.List (genericLength)
import qualified Data.Text.IO as T import qualified Data.Text.IO as T
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Build as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import TensorFlow.Examples.MNIST.InputData import TensorFlow.Examples.MNIST.InputData
import TensorFlow.Examples.MNIST.Parse import TensorFlow.Examples.MNIST.Parse

View file

@ -48,10 +48,9 @@ import TensorFlow.Tensor
, tensorFromName , tensorFromName
) )
import TensorFlow.Ops import TensorFlow.Ops
import TensorFlow.Nodes (unScalar)
import TensorFlow.Session import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd) (runSession, run, run_, runWithFeeds, build, buildAnd)
import TensorFlow.Types (TensorType(..), Shape(..)) import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (Test) import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), Assertion) import Test.HUnit ((@=?), Assertion)

View file

@ -42,7 +42,7 @@ import TensorFlow.Build
, withNameScope , withNameScope
) )
import TensorFlow.ControlFlow (named) import TensorFlow.ControlFlow (named)
import TensorFlow.Nodes (unScalar) import TensorFlow.Types (unScalar)
import TensorFlow.Ops import TensorFlow.Ops
( add ( add
, assign , assign

View file

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
@ -32,7 +33,7 @@ import qualified TensorFlow.Types as TF
-- DynamicSplit is undone with DynamicStitch to get the original input -- DynamicSplit is undone with DynamicStitch to get the original input
-- back. -- back.
testDynamicPartitionStitchInverse :: forall a. testDynamicPartitionStitchInverse :: forall a.
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property (TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] = let splitParts :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor CoreOps.dynamicPartition numParts (TF.vector values) partTensor

View file

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
@ -129,8 +130,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
-- Verifies that direct gather is the same as dynamic split into -- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup. -- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) testEmbeddingLookupUndoesSplit ::
=> LookupExample a -> Property forall a. (TF.TensorDataType V.Vector a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit testEmbeddingLookupUndoesSplit
(LookupExample numParts (LookupExample numParts
shape@(TF.Shape (firstDim : restDims)) shape@(TF.Shape (firstDim : restDims))

View file

@ -5,7 +5,7 @@ import Control.Exception (evaluate)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Criterion.Main (defaultMain, bgroup, bench) import Criterion.Main (defaultMain, bgroup, bench)
import Criterion.Types (Benchmarkable(..)) import Criterion.Types (Benchmarkable(..))
import qualified Data.Vector as V import qualified Data.Vector.Storable as S
import qualified TensorFlow.Core as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
@ -23,12 +23,12 @@ nfSession init x = Benchmarkable $ \m -> TF.runSession $ do
go m go m
-- | Benchmark feeding and fetching a vector. -- | Benchmark feeding and fetching a vector.
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float)) feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float))
feedFetchBenchmark = do feedFetchBenchmark = do
input <- TF.build (TF.placeholder (TF.Shape [-1])) input <- TF.build (TF.placeholder (TF.Shape [-1]))
output <- TF.build (TF.render (TF.identity input)) output <- TF.build (TF.render (TF.identity input))
return $ \v -> do return $ \v -> do
let shape = TF.Shape [fromIntegral (V.length v)] let shape = TF.Shape [fromIntegral (S.length v)]
inputData = TF.encodeTensorData shape v inputData = TF.encodeTensorData shape v
feeds = [TF.feed input inputData] feeds = [TF.feed input inputData]
TF.runWithFeeds feeds output TF.runWithFeeds feeds output
@ -36,8 +36,8 @@ feedFetchBenchmark = do
main :: IO () main :: IO ()
main = defaultMain main = defaultMain
[ bgroup "feedFetch" [ bgroup "feedFetch"
[ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0) [ bench "4 byte" $ nfSession feedFetchBenchmark (S.replicate 1 0)
, bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0) , bench "4 KiB" $ nfSession feedFetchBenchmark (S.replicate 1024 0)
, bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0) , bench "4 MiB" $ nfSession feedFetchBenchmark (S.replicate (1024^2) 0)
] ]
] ]

View file

@ -85,7 +85,7 @@ instance Arbitrary a => Arbitrary (TensorDataInputs a) where
return $ TensorDataInputs sizes (V.fromList elems) return $ TensorDataInputs sizes (V.fromList elems)
-- Test that a vector is unchanged after being encoded and decoded. -- Test that a vector is unchanged after being encoded and decoded.
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool encodeDecodeProp :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool
encodeDecodeProp (TensorDataInputs shape vec) = encodeDecodeProp (TensorDataInputs shape vec) =
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec

View file

@ -20,7 +20,7 @@ module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64) import Data.Int (Int64)
import Google.Test (googleTest) import Google.Test (googleTest)
import TensorFlow.Nodes (Scalar(..)) import TensorFlow.Types (Scalar(..))
import TensorFlow.Ops (scalar) import TensorFlow.Ops (scalar)
import TensorFlow.Queue import TensorFlow.Queue
import TensorFlow.Session import TensorFlow.Session

View file

@ -36,7 +36,6 @@ module TensorFlow.Core
, buildWithSummary , buildWithSummary
-- ** Running graphs -- ** Running graphs
, Fetchable , Fetchable
, Scalar(..)
, Nodes , Nodes
, run , run
, run_ , run_
@ -64,8 +63,10 @@ module TensorFlow.Core
, value , value
, tensorFromName , tensorFromName
-- ** Element types -- ** Element types
, TensorType
, TensorData , TensorData
, TensorType(decodeTensorData, encodeTensorData) , TensorDataType(decodeTensorData, encodeTensorData)
, Scalar(..)
, Shape(..) , Shape(..)
, OneOf , OneOf
, type (/=) , type (/=)

View file

@ -12,8 +12,8 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
@ -24,12 +24,10 @@ import Control.Applicative (liftA2, liftA3)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.Set (Set) import Data.Set (Set)
import Data.String (IsString)
import Data.Text (Text) import Data.Text (Text)
import Lens.Family2 ((^.)) import Lens.Family2 ((^.))
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import qualified Data.Set as Set import qualified Data.Set as Set
import qualified Data.Vector as V
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Output import TensorFlow.Output
@ -101,18 +99,12 @@ instance a ~ () => Fetchable ControlNode a where
instance Nodes (Tensor v a) where instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
fetchTensorVector :: forall a v . TensorType a fetchTensorVector :: forall a v . TensorType a
=> Tensor v a -> Build (Fetch (Shape, V.Vector a)) => Tensor v a -> Build (Fetch (TensorData a))
fetchTensorVector (Tensor _ o) = do fetchTensorVector (Tensor _ o) = do
outputName <- renderOutput o outputName <- renderOutput o
return $ Fetch (Set.singleton outputName) $ \tensors -> return $ Fetch (Set.singleton outputName) $ \tensors ->
let tensorData = tensors Map.! outputName let tensorData = tensors Map.! outputName
shape = Shape $ FFI.tensorDataDimensions tensorData
vec = decodeTensorData $ TensorData tensorData
expectedType = tensorType (undefined :: a) expectedType = tensorType (undefined :: a)
actualType = FFI.tensorDataType tensorData actualType = FFI.tensorDataType tensorData
badTypeError = error $ "Bad tensor type: expected " badTypeError = error $ "Bad tensor type: expected "
@ -121,21 +113,12 @@ fetchTensorVector (Tensor _ o) = do
++ show actualType ++ show actualType
in if expectedType /= actualType in if expectedType /= actualType
then badTypeError then badTypeError
else (shape, vec) else TensorData tensorData
-- The constraint "a ~ a'" means that the input/output of fetch can constrain -- The constraint "a ~ a'" means that the input/output of fetch can constrain
-- the TensorType of each other. -- the TensorType of each other.
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where
getFetch t = fmap snd <$> fetchTensorVector t getFetch = fetchTensorVector
newtype Scalar a = Scalar {unScalar :: a} instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat, getFetch t = fmap decodeTensorData <$> fetchTensorVector t
RealFrac, IsString)
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
where
headFromSingleton [x] = x
headFromSingleton xs
= error $ "Unable to extract singleton from tensor of length "
++ show (length xs)

View file

@ -16,6 +16,7 @@
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
@ -30,6 +31,8 @@
module TensorFlow.Types module TensorFlow.Types
( TensorType(..) ( TensorType(..)
, TensorData(..) , TensorData(..)
, TensorDataType(..)
, Scalar(..)
, Shape(..) , Shape(..)
, protoShape , protoShape
, Attribute(..) , Attribute(..)
@ -50,11 +53,13 @@ import Data.Complex (Complex)
import Data.Default (def) import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64) import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word64) import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable) import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..)) import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~)) import Lens.Family2 (Lens', view, (&), (.~))
import Lens.Family2.Unchecked (iso) import Lens.Family2.Unchecked (iso)
import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import qualified Data.ByteString as B import qualified Data.ByteString as B
@ -95,66 +100,31 @@ import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt) import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
-- | Data about a tensor that is encoded for the TensorFlow APIs.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | The class of scalar types supported by tensorflow. -- | The class of scalar types supported by tensorflow.
class TensorType a where class TensorType a where
tensorType :: a -> DataType tensorType :: a -> DataType
tensorRefType :: a -> DataType tensorRefType :: a -> DataType
tensorVal :: Lens' TensorProto [a] tensorVal :: Lens' TensorProto [a]
-- | Decode the bytes of a TensorData into a Vector.
decodeTensorData :: TensorData a -> V.Vector a
-- | Encode a Vector into a TensorData.
--
-- The values should be in row major order, e.g.,
--
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
encodeTensorData :: Shape -> V.Vector a -> TensorData a
-- All types, besides ByteString, are encoded as simple arrays and we can use
-- Vector.Storable to encode/decode by type casting pointers.
-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> V.Vector a
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
simpleEncode :: forall a . (TensorType a, Storable a)
=> Shape -> V.Vector a -> TensorData a
simpleEncode (Shape xs)
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
where
dt = tensorType (undefined :: a)
instance TensorType Float where instance TensorType Float where
tensorType _ = DT_FLOAT tensorType _ = DT_FLOAT
tensorRefType _ = DT_FLOAT_REF tensorRefType _ = DT_FLOAT_REF
tensorVal = floatVal tensorVal = floatVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Double where instance TensorType Double where
tensorType _ = DT_DOUBLE tensorType _ = DT_DOUBLE
tensorRefType _ = DT_DOUBLE_REF tensorRefType _ = DT_DOUBLE_REF
tensorVal = doubleVal tensorVal = doubleVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int32 where instance TensorType Int32 where
tensorType _ = DT_INT32 tensorType _ = DT_INT32
tensorRefType _ = DT_INT32_REF tensorRefType _ = DT_INT32_REF
tensorVal = intVal tensorVal = intVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int64 where instance TensorType Int64 where
tensorType _ = DT_INT64 tensorType _ = DT_INT64
tensorRefType _ = DT_INT64_REF tensorRefType _ = DT_INT64_REF
tensorVal = int64Val tensorVal = int64Val
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
integral :: Integral a => Lens' [Int32] [a] integral :: Integral a => Lens' [Int32] [a]
integral = iso (fmap fromIntegral) (fmap fromIntegral) integral = iso (fmap fromIntegral) (fmap fromIntegral)
@ -163,40 +133,140 @@ instance TensorType Word8 where
tensorType _ = DT_UINT8 tensorType _ = DT_UINT8
tensorRefType _ = DT_UINT8_REF tensorRefType _ = DT_UINT8_REF
tensorVal = intVal . integral tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Word16 where instance TensorType Word16 where
tensorType _ = DT_UINT16 tensorType _ = DT_UINT16
tensorRefType _ = DT_UINT16_REF tensorRefType _ = DT_UINT16_REF
tensorVal = intVal . integral tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int16 where instance TensorType Int16 where
tensorType _ = DT_INT16 tensorType _ = DT_INT16
tensorRefType _ = DT_INT16_REF tensorRefType _ = DT_INT16_REF
tensorVal = intVal . integral tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int8 where instance TensorType Int8 where
tensorType _ = DT_INT8 tensorType _ = DT_INT8
tensorRefType _ = DT_INT8_REF tensorRefType _ = DT_INT8_REF
tensorVal = intVal . integral tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType ByteString where instance TensorType ByteString where
tensorType _ = DT_STRING tensorType _ = DT_STRING
tensorRefType _ = DT_STRING_REF tensorRefType _ = DT_STRING_REF
tensorVal = stringVal tensorVal = stringVal
instance TensorType Bool where
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64
tensorRefType _ = DT_COMPLEX64
tensorVal = error "TODO (Complex Float)"
instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
-- | Tensor data with the correct memory layout for tensorflow.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | Types that can be converted to and from 'TensorData'.
--
-- 'S.Vector' is the most efficient to encode/decode for most element types.
class TensorType a => TensorDataType s a where
-- | Decode the bytes of a 'TensorData' into an 's'.
decodeTensorData :: TensorData a -> s a
-- | Encode an 's' into a 'TensorData'.
--
-- The values should be in row major order, e.g.,
--
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
encodeTensorData :: Shape -> s a -> TensorData a
-- All types, besides ByteString and Bool, are encoded as simple arrays and we
-- can use Vector.Storable to encode/decode by type casting pointers.
-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode = S.unsafeCast . FFI.tensorDataBytes . unTensorData
simpleEncode :: forall a . (TensorType a, Storable a)
=> Shape -> S.Vector a -> TensorData a
simpleEncode (Shape xs) v =
if product xs /= fromIntegral (S.length v)
then error $ printf
"simpleEncode: bad vector length for shape %v: expected=%d got=%d"
(show xs) (product xs) (S.length v)
else TensorData (FFI.TensorData xs dt (S.unsafeCast v))
where
dt = tensorType (undefined :: a)
instance TensorDataType S.Vector Float where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Double where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int8 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int16 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int32 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Int64 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Word8 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType S.Vector Word16 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
instance TensorDataType S.Vector Bool where
decodeTensorData =
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
encodeTensorData (Shape xs) =
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
where
fromBool x = if x then 1 else 0 :: Word8
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a)
=> TensorDataType V.Vector a where
decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h): -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64] -- table offsets for each element :: [Word64]
-- at each element offset: -- at each element offset:
-- string length :: VarInt64 -- string length :: VarInt64
-- string data :: [Word8] -- string data :: [Word8]
-- TODO(fmayle): Benchmark these functions.
decodeTensorData tensorData = decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
if expected /= count if expected /= count
@ -241,32 +311,21 @@ instance TensorType ByteString where
-- Convert to Vector Word8. -- Convert to Vector Word8.
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes newtype Scalar a = Scalar {unScalar :: a}
-- encoding more expensive. It may make sense to define a custom boolean type. deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
instance TensorType Bool where RealFrac, IsString)
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
decodeTensorData =
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
encodeTensorData (Shape xs) =
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
where
fromBool x = if x then 1 else 0 :: Word8
instance TensorType (Complex Float) where instance TensorDataType V.Vector a => TensorDataType Scalar a where
tensorType _ = DT_COMPLEX64 decodeTensorData = Scalar . headFromSingleton . decodeTensorData
tensorRefType _ = DT_COMPLEX64 encodeTensorData x (Scalar y) = encodeTensorData x (V.fromList [y])
tensorVal = error "TODO (Complex Float)"
decodeTensorData = error "TODO (Complex Float)" headFromSingleton :: V.Vector a -> a
encodeTensorData = error "TODO (Complex Float)" headFromSingleton x
| V.length x == 1 = V.head x
| otherwise = error $
"Unable to extract singleton from tensor of length "
++ show (V.length x)
instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
-- | Shape (dimensions) of a tensor. -- | Shape (dimensions) of a tensor.
newtype Shape = Shape [Int64] deriving Show newtype Shape = Shape [Int64] deriving Show