From f170df9d13dacdfb8ef75ea25070adec91923fc9 Mon Sep 17 00:00:00 2001 From: fkm3 Date: Wed, 14 Dec 2016 18:53:06 -0800 Subject: [PATCH] Support fetching storable vectors + use them in benchmark (#50) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, as in an RNN. Before: benchmarking feedFetch/4 byte time 83.31 μs (81.88 μs .. 84.75 μs) 0.997 R² (0.994 R² .. 0.998 R²) mean 87.32 μs (86.06 μs .. 88.83 μs) std dev 4.580 μs (3.698 μs .. 5.567 μs) variance introduced by outliers: 55% (severely inflated) benchmarking feedFetch/4 KiB time 114.9 μs (111.5 μs .. 118.2 μs) 0.996 R² (0.994 R² .. 0.998 R²) mean 117.3 μs (116.2 μs .. 118.6 μs) std dev 3.877 μs (3.058 μs .. 5.565 μs) variance introduced by outliers: 31% (moderately inflated) benchmarking feedFetch/4 MiB time 109.0 ms (107.9 ms .. 110.7 ms) 1.000 R² (0.999 R² .. 1.000 R²) mean 108.6 ms (108.2 ms .. 109.2 ms) std dev 740.2 μs (353.2 μs .. 1.186 ms) After: benchmarking feedFetch/4 byte time 82.92 μs (80.55 μs .. 85.24 μs) 0.996 R² (0.993 R² .. 0.998 R²) mean 83.58 μs (82.34 μs .. 84.89 μs) std dev 4.327 μs (3.664 μs .. 5.375 μs) variance introduced by outliers: 54% (severely inflated) benchmarking feedFetch/4 KiB time 85.69 μs (83.81 μs .. 87.30 μs) 0.997 R² (0.996 R² .. 0.999 R²) mean 86.99 μs (86.11 μs .. 88.15 μs) std dev 3.608 μs (2.854 μs .. 5.273 μs) variance introduced by outliers: 43% (moderately inflated) benchmarking feedFetch/4 MiB time 1.582 ms (1.509 ms .. 1.677 ms) 0.970 R² (0.936 R² .. 0.993 R²) mean 1.645 ms (1.554 ms .. 1.981 ms) std dev 490.6 μs (138.9 μs .. 
1.067 ms) variance introduced by outliers: 97% (severely inflated) --- tensorflow-mnist/app/Main.hs | 7 +- tensorflow-mnist/tests/ParseTest.hs | 3 +- tensorflow-ops/tests/BuildTest.hs | 2 +- tensorflow-ops/tests/DataFlowOpsTest.hs | 3 +- tensorflow-ops/tests/EmbeddingOpsTest.hs | 6 +- tensorflow-ops/tests/FeedFetchBench.hs | 12 +- tensorflow-ops/tests/TypesTest.hs | 2 +- tensorflow-queue/tests/QueueTest.hs | 2 +- tensorflow/src/TensorFlow/Core.hs | 5 +- tensorflow/src/TensorFlow/Nodes.hs | 31 +--- tensorflow/src/TensorFlow/Types.hs | 195 +++++++++++++++-------- 11 files changed, 154 insertions(+), 114 deletions(-) diff --git a/tensorflow-mnist/app/Main.hs b/tensorflow-mnist/app/Main.hs index cdf3fe1..5475469 100644 --- a/tensorflow-mnist/app/Main.hs +++ b/tensorflow-mnist/app/Main.hs @@ -22,14 +22,9 @@ import Data.List (genericLength) import qualified Data.Text.IO as T import qualified Data.Vector as V -import qualified TensorFlow.Build as TF -import qualified TensorFlow.ControlFlow as TF +import qualified TensorFlow.Core as TF import qualified TensorFlow.Gradient as TF -import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Ops as TF -import qualified TensorFlow.Session as TF -import qualified TensorFlow.Tensor as TF -import qualified TensorFlow.Types as TF import TensorFlow.Examples.MNIST.InputData import TensorFlow.Examples.MNIST.Parse diff --git a/tensorflow-mnist/tests/ParseTest.hs b/tensorflow-mnist/tests/ParseTest.hs index d7a428a..89b3d4a 100644 --- a/tensorflow-mnist/tests/ParseTest.hs +++ b/tensorflow-mnist/tests/ParseTest.hs @@ -48,10 +48,9 @@ import TensorFlow.Tensor , tensorFromName ) import TensorFlow.Ops -import TensorFlow.Nodes (unScalar) import TensorFlow.Session (runSession, run, run_, runWithFeeds, build, buildAnd) -import TensorFlow.Types (TensorType(..), Shape(..)) +import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) import Test.HUnit 
((@=?), Assertion) diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs index 2689dd6..d8bf859 100644 --- a/tensorflow-ops/tests/BuildTest.hs +++ b/tensorflow-ops/tests/BuildTest.hs @@ -42,7 +42,7 @@ import TensorFlow.Build , withNameScope ) import TensorFlow.ControlFlow (named) -import TensorFlow.Nodes (unScalar) +import TensorFlow.Types (unScalar) import TensorFlow.Ops ( add , assign diff --git a/tensorflow-ops/tests/DataFlowOpsTest.hs b/tensorflow-ops/tests/DataFlowOpsTest.hs index cd362c9..789df51 100644 --- a/tensorflow-ops/tests/DataFlowOpsTest.hs +++ b/tensorflow-ops/tests/DataFlowOpsTest.hs @@ -12,6 +12,7 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} import Data.Int (Int32, Int64) @@ -32,7 +33,7 @@ import qualified TensorFlow.Types as TF -- DynamicSplit is undone with DynamicStitch to get the original input -- back. testDynamicPartitionStitchInverse :: forall a. - (TF.TensorType a, Show a, Eq a) => StitchExample a -> Property + (TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = let splitParts :: [TF.Tensor TF.Value a] = CoreOps.dynamicPartition numParts (TF.vector values) partTensor diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 5bd3247..45e5647 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -12,6 +12,7 @@ -- See the License for the specific language governing permissions and -- limitations under the License. 
+{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -129,8 +130,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do -- Verifies that direct gather is the same as dynamic split into -- partitions, followed by embedding lookup. -testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) - => LookupExample a -> Property +testEmbeddingLookupUndoesSplit :: + forall a. (TF.TensorDataType V.Vector a, Show a, Eq a) + => LookupExample a -> Property testEmbeddingLookupUndoesSplit (LookupExample numParts shape@(TF.Shape (firstDim : restDims)) diff --git a/tensorflow-ops/tests/FeedFetchBench.hs b/tensorflow-ops/tests/FeedFetchBench.hs index c7877b3..04465ef 100644 --- a/tensorflow-ops/tests/FeedFetchBench.hs +++ b/tensorflow-ops/tests/FeedFetchBench.hs @@ -5,7 +5,7 @@ import Control.Exception (evaluate) import Control.Monad.IO.Class (liftIO) import Criterion.Main (defaultMain, bgroup, bench) import Criterion.Types (Benchmarkable(..)) -import qualified Data.Vector as V +import qualified Data.Vector.Storable as S import qualified TensorFlow.Core as TF import qualified TensorFlow.Ops as TF @@ -23,12 +23,12 @@ nfSession init x = Benchmarkable $ \m -> TF.runSession $ do go m -- | Benchmark feeding and fetching a vector. 
-feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float)) +feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float)) feedFetchBenchmark = do input <- TF.build (TF.placeholder (TF.Shape [-1])) output <- TF.build (TF.render (TF.identity input)) return $ \v -> do - let shape = TF.Shape [fromIntegral (V.length v)] + let shape = TF.Shape [fromIntegral (S.length v)] inputData = TF.encodeTensorData shape v feeds = [TF.feed input inputData] TF.runWithFeeds feeds output @@ -36,8 +36,8 @@ feedFetchBenchmark = do main :: IO () main = defaultMain [ bgroup "feedFetch" - [ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0) - , bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0) - , bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0) + [ bench "4 byte" $ nfSession feedFetchBenchmark (S.replicate 1 0) + , bench "4 KiB" $ nfSession feedFetchBenchmark (S.replicate 1024 0) + , bench "4 MiB" $ nfSession feedFetchBenchmark (S.replicate (1024^2) 0) ] ] diff --git a/tensorflow-ops/tests/TypesTest.hs b/tensorflow-ops/tests/TypesTest.hs index b0c8579..2610231 100644 --- a/tensorflow-ops/tests/TypesTest.hs +++ b/tensorflow-ops/tests/TypesTest.hs @@ -85,7 +85,7 @@ instance Arbitrary a => Arbitrary (TensorDataInputs a) where return $ TensorDataInputs sizes (V.fromList elems) -- Test that a vector is unchanged after being encoded and decoded. 
-encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool +encodeDecodeProp :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool encodeDecodeProp (TensorDataInputs shape vec) = TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec diff --git a/tensorflow-queue/tests/QueueTest.hs b/tensorflow-queue/tests/QueueTest.hs index b92c5f3..f3b38eb 100644 --- a/tensorflow-queue/tests/QueueTest.hs +++ b/tensorflow-queue/tests/QueueTest.hs @@ -20,7 +20,7 @@ module Main where import Control.Monad.IO.Class (liftIO) import Data.Int (Int64) import Google.Test (googleTest) -import TensorFlow.Nodes (Scalar(..)) +import TensorFlow.Types (Scalar(..)) import TensorFlow.Ops (scalar) import TensorFlow.Queue import TensorFlow.Session diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 0fc0590..3938e89 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -36,7 +36,6 @@ module TensorFlow.Core , buildWithSummary -- ** Running graphs , Fetchable - , Scalar(..) , Nodes , run , run_ @@ -64,8 +63,10 @@ module TensorFlow.Core , value , tensorFromName -- ** Element types + , TensorType , TensorData - , TensorType(decodeTensorData, encodeTensorData) + , TensorDataType(decodeTensorData, encodeTensorData) + , Scalar(..) , Shape(..) , OneOf , type (/=) diff --git a/tensorflow/src/TensorFlow/Nodes.hs b/tensorflow/src/TensorFlow/Nodes.hs index 730c9e5..5e8c62d 100644 --- a/tensorflow/src/TensorFlow/Nodes.hs +++ b/tensorflow/src/TensorFlow/Nodes.hs @@ -12,8 +12,8 @@ -- See the License for the specific language governing permissions and -- limitations under the License. 
+{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -24,12 +24,10 @@ import Control.Applicative (liftA2, liftA3) import Data.Map.Strict (Map) import Data.Monoid ((<>)) import Data.Set (Set) -import Data.String (IsString) import Data.Text (Text) import Lens.Family2 ((^.)) import qualified Data.Map.Strict as Map import qualified Data.Set as Set -import qualified Data.Vector as V import TensorFlow.Build import TensorFlow.Output @@ -101,18 +99,12 @@ instance a ~ () => Fetchable ControlNode a where instance Nodes (Tensor v a) where getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) -fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a])) -fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t - fetchTensorVector :: forall a v . TensorType a - => Tensor v a -> Build (Fetch (Shape, V.Vector a)) + => Tensor v a -> Build (Fetch (TensorData a)) fetchTensorVector (Tensor _ o) = do outputName <- renderOutput o return $ Fetch (Set.singleton outputName) $ \tensors -> let tensorData = tensors Map.! outputName - shape = Shape $ FFI.tensorDataDimensions tensorData - vec = decodeTensorData $ TensorData tensorData - expectedType = tensorType (undefined :: a) actualType = FFI.tensorDataType tensorData badTypeError = error $ "Bad tensor type: expected " @@ -121,21 +113,12 @@ fetchTensorVector (Tensor _ o) = do ++ show actualType in if expectedType /= actualType then badTypeError - else (shape, vec) + else TensorData tensorData -- The constraint "a ~ a'" means that the input/output of fetch can constrain -- the TensorType of each other. 
-instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where - getFetch t = fmap snd <$> fetchTensorVector t +instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where + getFetch = fetchTensorVector -newtype Scalar a = Scalar {unScalar :: a} - deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat, - RealFrac, IsString) - -instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where - getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t - where - headFromSingleton [x] = x - headFromSingleton xs - = error $ "Unable to extract singleton from tensor of length " - ++ show (length xs) +instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where + getFetch t = fmap decodeTensorData <$> fetchTensorVector t diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index 497942b..3ed9cec 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -16,6 +16,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} @@ -30,6 +31,8 @@ module TensorFlow.Types ( TensorType(..) , TensorData(..) + , TensorDataType(..) + , Scalar(..) , Shape(..) , protoShape , Attribute(..) 
@@ -50,11 +53,13 @@ import Data.Complex (Complex) import Data.Default (def) import Data.Int (Int8, Int16, Int32, Int64) import Data.Monoid ((<>)) +import Data.String (IsString) import Data.Word (Word8, Word16, Word64) import Foreign.Storable (Storable) import GHC.Exts (Constraint, IsList(..)) import Lens.Family2 (Lens', view, (&), (.~)) import Lens.Family2.Unchecked (iso) +import Text.Printf (printf) import qualified Data.Attoparsec.ByteString as Atto import Data.ByteString (ByteString) import qualified Data.ByteString as B @@ -95,66 +100,31 @@ import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import TensorFlow.Internal.VarInt (getVarInt, putVarInt) import qualified TensorFlow.Internal.FFI as FFI --- | Data about a tensor that is encoded for the TensorFlow APIs. -newtype TensorData a = TensorData { unTensorData :: FFI.TensorData } - -- | The class of scalar types supported by tensorflow. class TensorType a where tensorType :: a -> DataType tensorRefType :: a -> DataType tensorVal :: Lens' TensorProto [a] - -- | Decode the bytes of a TensorData into a Vector. - decodeTensorData :: TensorData a -> V.Vector a - -- | Encode a Vector into a TensorData. - -- - -- The values should be in row major order, e.g., - -- - -- element 0: index (0, ..., 0) - -- element 1: index (0, ..., 1) - -- ... - encodeTensorData :: Shape -> V.Vector a -> TensorData a - --- All types, besides ByteString, are encoded as simple arrays and we can use --- Vector.Storable to encode/decode by type casting pointers. - --- TODO(fmayle): Assert that the data type matches the return type. -simpleDecode :: Storable a => TensorData a -> V.Vector a -simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData - -simpleEncode :: forall a . (TensorType a, Storable a) - => Shape -> V.Vector a -> TensorData a -simpleEncode (Shape xs) - = TensorData . FFI.TensorData xs dt . S.unsafeCast . 
S.convert - where - dt = tensorType (undefined :: a) instance TensorType Float where tensorType _ = DT_FLOAT tensorRefType _ = DT_FLOAT_REF tensorVal = floatVal - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType Double where tensorType _ = DT_DOUBLE tensorRefType _ = DT_DOUBLE_REF tensorVal = doubleVal - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType Int32 where tensorType _ = DT_INT32 tensorRefType _ = DT_INT32_REF tensorVal = intVal - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType Int64 where tensorType _ = DT_INT64 tensorRefType _ = DT_INT64_REF tensorVal = int64Val - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode integral :: Integral a => Lens' [Int32] [a] integral = iso (fmap fromIntegral) (fmap fromIntegral) @@ -163,40 +133,140 @@ instance TensorType Word8 where tensorType _ = DT_UINT8 tensorRefType _ = DT_UINT8_REF tensorVal = intVal . integral - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType Word16 where tensorType _ = DT_UINT16 tensorRefType _ = DT_UINT16_REF tensorVal = intVal . integral - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType Int16 where tensorType _ = DT_INT16 tensorRefType _ = DT_INT16_REF tensorVal = intVal . integral - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType Int8 where tensorType _ = DT_INT8 tensorRefType _ = DT_INT8_REF tensorVal = intVal . 
integral - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode instance TensorType ByteString where tensorType _ = DT_STRING tensorRefType _ = DT_STRING_REF tensorVal = stringVal + +instance TensorType Bool where + tensorType _ = DT_BOOL + tensorRefType _ = DT_BOOL_REF + tensorVal = boolVal + +instance TensorType (Complex Float) where + tensorType _ = DT_COMPLEX64 + tensorRefType _ = DT_COMPLEX64 + tensorVal = error "TODO (Complex Float)" + +instance TensorType (Complex Double) where + tensorType _ = DT_COMPLEX128 + tensorRefType _ = DT_COMPLEX128 + tensorVal = error "TODO (Complex Double)" + + +-- | Tensor data with the correct memory layout for tensorflow. +newtype TensorData a = TensorData { unTensorData :: FFI.TensorData } + +-- | Types that can be converted to and from 'TensorData'. +-- +-- 'S.Vector' is the most efficient to encode/decode for most element types. +class TensorType a => TensorDataType s a where + -- | Decode the bytes of a 'TensorData' into an 's'. + decodeTensorData :: TensorData a -> s a + -- | Encode an 's' into a 'TensorData'. + -- + -- The values should be in row major order, e.g., + -- + -- element 0: index (0, ..., 0) + -- element 1: index (0, ..., 1) + -- ... + encodeTensorData :: Shape -> s a -> TensorData a + +-- All types, besides ByteString and Bool, are encoded as simple arrays and we +-- can use Vector.Storable to encode/decode by type casting pointers. + +-- TODO(fmayle): Assert that the data type matches the return type. +simpleDecode :: Storable a => TensorData a -> S.Vector a +simpleDecode = S.unsafeCast . FFI.tensorDataBytes . unTensorData + +simpleEncode :: forall a . 
(TensorType a, Storable a) + => Shape -> S.Vector a -> TensorData a +simpleEncode (Shape xs) v = + if product xs /= fromIntegral (S.length v) + then error $ printf + "simpleEncode: bad vector length for shape %v: expected=%d got=%d" + (show xs) (product xs) (S.length v) + else TensorData (FFI.TensorData xs dt (S.unsafeCast v)) + where + dt = tensorType (undefined :: a) + +instance TensorDataType S.Vector Float where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Double where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Int8 where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Int16 where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Int32 where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Int64 where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Word8 where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +instance TensorDataType S.Vector Word16 where + decodeTensorData = simpleDecode + encodeTensorData = simpleEncode + +-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes +-- encoding more expensive. It may make sense to define a custom boolean type. +instance TensorDataType S.Vector Bool where + decodeTensorData = + S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData + encodeTensorData (Shape xs) = + TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert + where + fromBool x = if x then 1 else 0 :: Word8 + +instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a) + => TensorDataType V.Vector a where + decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData + encodeTensorData x = encodeTensorData x . 
(S.convert :: V.Vector a -> S.Vector a) + +instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where + decodeTensorData = error "TODO (Complex Float)" + encodeTensorData = error "TODO (Complex Float)" + +instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where + decodeTensorData = error "TODO (Complex Double)" + encodeTensorData = error "TODO (Complex Double)" + +instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where -- Encoded data layout (described in third_party/tensorflow/c/c_api.h): -- table offsets for each element :: [Word64] -- at each element offset: -- string length :: VarInt64 -- string data :: [Word8] - -- TODO(fmayle): Benchmark these functions. decodeTensorData tensorData = either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ if expected /= count @@ -241,32 +311,21 @@ instance TensorType ByteString where -- Convert to Vector Word8. byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes --- TODO: Haskell and tensorflow use different byte sizes for bools, which makes --- encoding more expensive. It may make sense to define a custom boolean type. -instance TensorType Bool where - tensorType _ = DT_BOOL - tensorRefType _ = DT_BOOL_REF - tensorVal = boolVal - decodeTensorData = - S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData - encodeTensorData (Shape xs) = - TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert - where - fromBool x = if x then 1 else 0 :: Word8 +newtype Scalar a = Scalar {unScalar :: a} + deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat, + RealFrac, IsString) -instance TensorType (Complex Float) where - tensorType _ = DT_COMPLEX64 - tensorRefType _ = DT_COMPLEX64 - tensorVal = error "TODO (Complex Float)" - decodeTensorData = error "TODO (Complex Float)" - encodeTensorData = error "TODO (Complex Float)" +instance TensorDataType V.Vector a => TensorDataType Scalar a where + decodeTensorData = Scalar . 
headFromSingleton . decodeTensorData + encodeTensorData x (Scalar y) = encodeTensorData x (V.fromList [y]) + +headFromSingleton :: V.Vector a -> a +headFromSingleton x + | V.length x == 1 = V.head x + | otherwise = error $ + "Unable to extract singleton from tensor of length " + ++ show (V.length x) -instance TensorType (Complex Double) where - tensorType _ = DT_COMPLEX128 - tensorRefType _ = DT_COMPLEX128 - tensorVal = error "TODO (Complex Double)" - decodeTensorData = error "TODO (Complex Double)" - encodeTensorData = error "TODO (Complex Double)" -- | Shape (dimensions) of a tensor. newtype Shape = Shape [Int64] deriving Show