mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Support fetching storable vectors + use them in benchmark (#50)
In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, like RNN. Before: benchmarking feedFetch/4 byte time 83.31 μs (81.88 μs .. 84.75 μs) 0.997 R² (0.994 R² .. 0.998 R²) mean 87.32 μs (86.06 μs .. 88.83 μs) std dev 4.580 μs (3.698 μs .. 5.567 μs) variance introduced by outliers: 55% (severely inflated) benchmarking feedFetch/4 KiB time 114.9 μs (111.5 μs .. 118.2 μs) 0.996 R² (0.994 R² .. 0.998 R²) mean 117.3 μs (116.2 μs .. 118.6 μs) std dev 3.877 μs (3.058 μs .. 5.565 μs) variance introduced by outliers: 31% (moderately inflated) benchmarking feedFetch/4 MiB time 109.0 ms (107.9 ms .. 110.7 ms) 1.000 R² (0.999 R² .. 1.000 R²) mean 108.6 ms (108.2 ms .. 109.2 ms) std dev 740.2 μs (353.2 μs .. 1.186 ms) After: benchmarking feedFetch/4 byte time 82.92 μs (80.55 μs .. 85.24 μs) 0.996 R² (0.993 R² .. 0.998 R²) mean 83.58 μs (82.34 μs .. 84.89 μs) std dev 4.327 μs (3.664 μs .. 5.375 μs) variance introduced by outliers: 54% (severely inflated) benchmarking feedFetch/4 KiB time 85.69 μs (83.81 μs .. 87.30 μs) 0.997 R² (0.996 R² .. 0.999 R²) mean 86.99 μs (86.11 μs .. 88.15 μs) std dev 3.608 μs (2.854 μs .. 5.273 μs) variance introduced by outliers: 43% (moderately inflated) benchmarking feedFetch/4 MiB time 1.582 ms (1.509 ms .. 1.677 ms) 0.970 R² (0.936 R² .. 0.993 R²) mean 1.645 ms (1.554 ms .. 1.981 ms) std dev 490.6 μs (138.9 μs .. 1.067 ms) variance introduced by outliers: 97% (severely inflated)
This commit is contained in:
parent
91f508eb5c
commit
f170df9d13
11 changed files with 154 additions and 114 deletions
|
@ -22,14 +22,9 @@ import Data.List (genericLength)
|
||||||
import qualified Data.Text.IO as T
|
import qualified Data.Text.IO as T
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.ControlFlow as TF
|
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Nodes as TF
|
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
|
||||||
import qualified TensorFlow.Tensor as TF
|
|
||||||
import qualified TensorFlow.Types as TF
|
|
||||||
|
|
||||||
import TensorFlow.Examples.MNIST.InputData
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
import TensorFlow.Examples.MNIST.Parse
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
|
|
@ -48,10 +48,9 @@ import TensorFlow.Tensor
|
||||||
, tensorFromName
|
, tensorFromName
|
||||||
)
|
)
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
import TensorFlow.Nodes (unScalar)
|
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||||
import TensorFlow.Types (TensorType(..), Shape(..))
|
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?), Assertion)
|
import Test.HUnit ((@=?), Assertion)
|
||||||
|
|
|
@ -42,7 +42,7 @@ import TensorFlow.Build
|
||||||
, withNameScope
|
, withNameScope
|
||||||
)
|
)
|
||||||
import TensorFlow.ControlFlow (named)
|
import TensorFlow.ControlFlow (named)
|
||||||
import TensorFlow.Nodes (unScalar)
|
import TensorFlow.Types (unScalar)
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
( add
|
( add
|
||||||
, assign
|
, assign
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
|
@ -32,7 +33,7 @@ import qualified TensorFlow.Types as TF
|
||||||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||||
-- back.
|
-- back.
|
||||||
testDynamicPartitionStitchInverse :: forall a.
|
testDynamicPartitionStitchInverse :: forall a.
|
||||||
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
|
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
|
||||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||||
let splitParts :: [TF.Tensor TF.Value a] =
|
let splitParts :: [TF.Tensor TF.Value a] =
|
||||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
@ -129,7 +130,8 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
|
|
||||||
-- Verifies that direct gather is the same as dynamic split into
|
-- Verifies that direct gather is the same as dynamic split into
|
||||||
-- partitions, followed by embedding lookup.
|
-- partitions, followed by embedding lookup.
|
||||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
testEmbeddingLookupUndoesSplit ::
|
||||||
|
forall a. (TF.TensorDataType V.Vector a, Show a, Eq a)
|
||||||
=> LookupExample a -> Property
|
=> LookupExample a -> Property
|
||||||
testEmbeddingLookupUndoesSplit
|
testEmbeddingLookupUndoesSplit
|
||||||
(LookupExample numParts
|
(LookupExample numParts
|
||||||
|
|
|
@ -5,7 +5,7 @@ import Control.Exception (evaluate)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Criterion.Main (defaultMain, bgroup, bench)
|
import Criterion.Main (defaultMain, bgroup, bench)
|
||||||
import Criterion.Types (Benchmarkable(..))
|
import Criterion.Types (Benchmarkable(..))
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector.Storable as S
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
|
|
||||||
|
@ -23,12 +23,12 @@ nfSession init x = Benchmarkable $ \m -> TF.runSession $ do
|
||||||
go m
|
go m
|
||||||
|
|
||||||
-- | Benchmark feeding and fetching a vector.
|
-- | Benchmark feeding and fetching a vector.
|
||||||
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
|
feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float))
|
||||||
feedFetchBenchmark = do
|
feedFetchBenchmark = do
|
||||||
input <- TF.build (TF.placeholder (TF.Shape [-1]))
|
input <- TF.build (TF.placeholder (TF.Shape [-1]))
|
||||||
output <- TF.build (TF.render (TF.identity input))
|
output <- TF.build (TF.render (TF.identity input))
|
||||||
return $ \v -> do
|
return $ \v -> do
|
||||||
let shape = TF.Shape [fromIntegral (V.length v)]
|
let shape = TF.Shape [fromIntegral (S.length v)]
|
||||||
inputData = TF.encodeTensorData shape v
|
inputData = TF.encodeTensorData shape v
|
||||||
feeds = [TF.feed input inputData]
|
feeds = [TF.feed input inputData]
|
||||||
TF.runWithFeeds feeds output
|
TF.runWithFeeds feeds output
|
||||||
|
@ -36,8 +36,8 @@ feedFetchBenchmark = do
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
[ bgroup "feedFetch"
|
[ bgroup "feedFetch"
|
||||||
[ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0)
|
[ bench "4 byte" $ nfSession feedFetchBenchmark (S.replicate 1 0)
|
||||||
, bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0)
|
, bench "4 KiB" $ nfSession feedFetchBenchmark (S.replicate 1024 0)
|
||||||
, bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0)
|
, bench "4 MiB" $ nfSession feedFetchBenchmark (S.replicate (1024^2) 0)
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
|
@ -85,7 +85,7 @@ instance Arbitrary a => Arbitrary (TensorDataInputs a) where
|
||||||
return $ TensorDataInputs sizes (V.fromList elems)
|
return $ TensorDataInputs sizes (V.fromList elems)
|
||||||
|
|
||||||
-- Test that a vector is unchanged after being encoded and decoded.
|
-- Test that a vector is unchanged after being encoded and decoded.
|
||||||
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
|
encodeDecodeProp :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool
|
||||||
encodeDecodeProp (TensorDataInputs shape vec) =
|
encodeDecodeProp (TensorDataInputs shape vec) =
|
||||||
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ module Main where
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import TensorFlow.Nodes (Scalar(..))
|
import TensorFlow.Types (Scalar(..))
|
||||||
import TensorFlow.Ops (scalar)
|
import TensorFlow.Ops (scalar)
|
||||||
import TensorFlow.Queue
|
import TensorFlow.Queue
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
|
|
|
@ -36,7 +36,6 @@ module TensorFlow.Core
|
||||||
, buildWithSummary
|
, buildWithSummary
|
||||||
-- ** Running graphs
|
-- ** Running graphs
|
||||||
, Fetchable
|
, Fetchable
|
||||||
, Scalar(..)
|
|
||||||
, Nodes
|
, Nodes
|
||||||
, run
|
, run
|
||||||
, run_
|
, run_
|
||||||
|
@ -64,8 +63,10 @@ module TensorFlow.Core
|
||||||
, value
|
, value
|
||||||
, tensorFromName
|
, tensorFromName
|
||||||
-- ** Element types
|
-- ** Element types
|
||||||
|
, TensorType
|
||||||
, TensorData
|
, TensorData
|
||||||
, TensorType(decodeTensorData, encodeTensorData)
|
, TensorDataType(decodeTensorData, encodeTensorData)
|
||||||
|
, Scalar(..)
|
||||||
, Shape(..)
|
, Shape(..)
|
||||||
, OneOf
|
, OneOf
|
||||||
, type (/=)
|
, type (/=)
|
||||||
|
|
|
@ -12,8 +12,8 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
@ -24,12 +24,10 @@ import Control.Applicative (liftA2, liftA3)
|
||||||
import Data.Map.Strict (Map)
|
import Data.Map.Strict (Map)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.String (IsString)
|
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
import Lens.Family2 ((^.))
|
import Lens.Family2 ((^.))
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import qualified Data.Vector as V
|
|
||||||
|
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Output
|
import TensorFlow.Output
|
||||||
|
@ -101,18 +99,12 @@ instance a ~ () => Fetchable ControlNode a where
|
||||||
instance Nodes (Tensor v a) where
|
instance Nodes (Tensor v a) where
|
||||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||||
|
|
||||||
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
|
|
||||||
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
|
|
||||||
|
|
||||||
fetchTensorVector :: forall a v . TensorType a
|
fetchTensorVector :: forall a v . TensorType a
|
||||||
=> Tensor v a -> Build (Fetch (Shape, V.Vector a))
|
=> Tensor v a -> Build (Fetch (TensorData a))
|
||||||
fetchTensorVector (Tensor _ o) = do
|
fetchTensorVector (Tensor _ o) = do
|
||||||
outputName <- renderOutput o
|
outputName <- renderOutput o
|
||||||
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||||
let tensorData = tensors Map.! outputName
|
let tensorData = tensors Map.! outputName
|
||||||
shape = Shape $ FFI.tensorDataDimensions tensorData
|
|
||||||
vec = decodeTensorData $ TensorData tensorData
|
|
||||||
|
|
||||||
expectedType = tensorType (undefined :: a)
|
expectedType = tensorType (undefined :: a)
|
||||||
actualType = FFI.tensorDataType tensorData
|
actualType = FFI.tensorDataType tensorData
|
||||||
badTypeError = error $ "Bad tensor type: expected "
|
badTypeError = error $ "Bad tensor type: expected "
|
||||||
|
@ -121,21 +113,12 @@ fetchTensorVector (Tensor _ o) = do
|
||||||
++ show actualType
|
++ show actualType
|
||||||
in if expectedType /= actualType
|
in if expectedType /= actualType
|
||||||
then badTypeError
|
then badTypeError
|
||||||
else (shape, vec)
|
else TensorData tensorData
|
||||||
|
|
||||||
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
|
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
|
||||||
-- the TensorType of each other.
|
-- the TensorType of each other.
|
||||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where
|
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where
|
||||||
getFetch t = fmap snd <$> fetchTensorVector t
|
getFetch = fetchTensorVector
|
||||||
|
|
||||||
newtype Scalar a = Scalar {unScalar :: a}
|
instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where
|
||||||
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
|
getFetch t = fmap decodeTensorData <$> fetchTensorVector t
|
||||||
RealFrac, IsString)
|
|
||||||
|
|
||||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
|
|
||||||
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
|
|
||||||
where
|
|
||||||
headFromSingleton [x] = x
|
|
||||||
headFromSingleton xs
|
|
||||||
= error $ "Unable to extract singleton from tensor of length "
|
|
||||||
++ show (length xs)
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
@ -30,6 +31,8 @@
|
||||||
module TensorFlow.Types
|
module TensorFlow.Types
|
||||||
( TensorType(..)
|
( TensorType(..)
|
||||||
, TensorData(..)
|
, TensorData(..)
|
||||||
|
, TensorDataType(..)
|
||||||
|
, Scalar(..)
|
||||||
, Shape(..)
|
, Shape(..)
|
||||||
, protoShape
|
, protoShape
|
||||||
, Attribute(..)
|
, Attribute(..)
|
||||||
|
@ -50,11 +53,13 @@ import Data.Complex (Complex)
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Int (Int8, Int16, Int32, Int64)
|
import Data.Int (Int8, Int16, Int32, Int64)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
|
import Data.String (IsString)
|
||||||
import Data.Word (Word8, Word16, Word64)
|
import Data.Word (Word8, Word16, Word64)
|
||||||
import Foreign.Storable (Storable)
|
import Foreign.Storable (Storable)
|
||||||
import GHC.Exts (Constraint, IsList(..))
|
import GHC.Exts (Constraint, IsList(..))
|
||||||
import Lens.Family2 (Lens', view, (&), (.~))
|
import Lens.Family2 (Lens', view, (&), (.~))
|
||||||
import Lens.Family2.Unchecked (iso)
|
import Lens.Family2.Unchecked (iso)
|
||||||
|
import Text.Printf (printf)
|
||||||
import qualified Data.Attoparsec.ByteString as Atto
|
import qualified Data.Attoparsec.ByteString as Atto
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
|
@ -95,66 +100,31 @@ import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
-- | Data about a tensor that is encoded for the TensorFlow APIs.
|
|
||||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
|
||||||
|
|
||||||
-- | The class of scalar types supported by tensorflow.
|
-- | The class of scalar types supported by tensorflow.
|
||||||
class TensorType a where
|
class TensorType a where
|
||||||
tensorType :: a -> DataType
|
tensorType :: a -> DataType
|
||||||
tensorRefType :: a -> DataType
|
tensorRefType :: a -> DataType
|
||||||
tensorVal :: Lens' TensorProto [a]
|
tensorVal :: Lens' TensorProto [a]
|
||||||
-- | Decode the bytes of a TensorData into a Vector.
|
|
||||||
decodeTensorData :: TensorData a -> V.Vector a
|
|
||||||
-- | Encode a Vector into a TensorData.
|
|
||||||
--
|
|
||||||
-- The values should be in row major order, e.g.,
|
|
||||||
--
|
|
||||||
-- element 0: index (0, ..., 0)
|
|
||||||
-- element 1: index (0, ..., 1)
|
|
||||||
-- ...
|
|
||||||
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
|
||||||
|
|
||||||
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
|
||||||
-- Vector.Storable to encode/decode by type casting pointers.
|
|
||||||
|
|
||||||
-- TODO(fmayle): Assert that the data type matches the return type.
|
|
||||||
simpleDecode :: Storable a => TensorData a -> V.Vector a
|
|
||||||
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
|
|
||||||
|
|
||||||
simpleEncode :: forall a . (TensorType a, Storable a)
|
|
||||||
=> Shape -> V.Vector a -> TensorData a
|
|
||||||
simpleEncode (Shape xs)
|
|
||||||
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
|
|
||||||
where
|
|
||||||
dt = tensorType (undefined :: a)
|
|
||||||
|
|
||||||
instance TensorType Float where
|
instance TensorType Float where
|
||||||
tensorType _ = DT_FLOAT
|
tensorType _ = DT_FLOAT
|
||||||
tensorRefType _ = DT_FLOAT_REF
|
tensorRefType _ = DT_FLOAT_REF
|
||||||
tensorVal = floatVal
|
tensorVal = floatVal
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType Double where
|
instance TensorType Double where
|
||||||
tensorType _ = DT_DOUBLE
|
tensorType _ = DT_DOUBLE
|
||||||
tensorRefType _ = DT_DOUBLE_REF
|
tensorRefType _ = DT_DOUBLE_REF
|
||||||
tensorVal = doubleVal
|
tensorVal = doubleVal
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType Int32 where
|
instance TensorType Int32 where
|
||||||
tensorType _ = DT_INT32
|
tensorType _ = DT_INT32
|
||||||
tensorRefType _ = DT_INT32_REF
|
tensorRefType _ = DT_INT32_REF
|
||||||
tensorVal = intVal
|
tensorVal = intVal
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType Int64 where
|
instance TensorType Int64 where
|
||||||
tensorType _ = DT_INT64
|
tensorType _ = DT_INT64
|
||||||
tensorRefType _ = DT_INT64_REF
|
tensorRefType _ = DT_INT64_REF
|
||||||
tensorVal = int64Val
|
tensorVal = int64Val
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
integral :: Integral a => Lens' [Int32] [a]
|
integral :: Integral a => Lens' [Int32] [a]
|
||||||
integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
||||||
|
@ -163,40 +133,140 @@ instance TensorType Word8 where
|
||||||
tensorType _ = DT_UINT8
|
tensorType _ = DT_UINT8
|
||||||
tensorRefType _ = DT_UINT8_REF
|
tensorRefType _ = DT_UINT8_REF
|
||||||
tensorVal = intVal . integral
|
tensorVal = intVal . integral
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType Word16 where
|
instance TensorType Word16 where
|
||||||
tensorType _ = DT_UINT16
|
tensorType _ = DT_UINT16
|
||||||
tensorRefType _ = DT_UINT16_REF
|
tensorRefType _ = DT_UINT16_REF
|
||||||
tensorVal = intVal . integral
|
tensorVal = intVal . integral
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType Int16 where
|
instance TensorType Int16 where
|
||||||
tensorType _ = DT_INT16
|
tensorType _ = DT_INT16
|
||||||
tensorRefType _ = DT_INT16_REF
|
tensorRefType _ = DT_INT16_REF
|
||||||
tensorVal = intVal . integral
|
tensorVal = intVal . integral
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType Int8 where
|
instance TensorType Int8 where
|
||||||
tensorType _ = DT_INT8
|
tensorType _ = DT_INT8
|
||||||
tensorRefType _ = DT_INT8_REF
|
tensorRefType _ = DT_INT8_REF
|
||||||
tensorVal = intVal . integral
|
tensorVal = intVal . integral
|
||||||
decodeTensorData = simpleDecode
|
|
||||||
encodeTensorData = simpleEncode
|
|
||||||
|
|
||||||
instance TensorType ByteString where
|
instance TensorType ByteString where
|
||||||
tensorType _ = DT_STRING
|
tensorType _ = DT_STRING
|
||||||
tensorRefType _ = DT_STRING_REF
|
tensorRefType _ = DT_STRING_REF
|
||||||
tensorVal = stringVal
|
tensorVal = stringVal
|
||||||
|
|
||||||
|
instance TensorType Bool where
|
||||||
|
tensorType _ = DT_BOOL
|
||||||
|
tensorRefType _ = DT_BOOL_REF
|
||||||
|
tensorVal = boolVal
|
||||||
|
|
||||||
|
instance TensorType (Complex Float) where
|
||||||
|
tensorType _ = DT_COMPLEX64
|
||||||
|
tensorRefType _ = DT_COMPLEX64
|
||||||
|
tensorVal = error "TODO (Complex Float)"
|
||||||
|
|
||||||
|
instance TensorType (Complex Double) where
|
||||||
|
tensorType _ = DT_COMPLEX128
|
||||||
|
tensorRefType _ = DT_COMPLEX128
|
||||||
|
tensorVal = error "TODO (Complex Double)"
|
||||||
|
|
||||||
|
|
||||||
|
-- | Tensor data with the correct memory layout for tensorflow.
|
||||||
|
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||||
|
|
||||||
|
-- | Types that can be converted to and from 'TensorData'.
|
||||||
|
--
|
||||||
|
-- 'S.Vector' is the most efficient to encode/decode for most element types.
|
||||||
|
class TensorType a => TensorDataType s a where
|
||||||
|
-- | Decode the bytes of a 'TensorData' into an 's'.
|
||||||
|
decodeTensorData :: TensorData a -> s a
|
||||||
|
-- | Encode an 's' into a 'TensorData'.
|
||||||
|
--
|
||||||
|
-- The values should be in row major order, e.g.,
|
||||||
|
--
|
||||||
|
-- element 0: index (0, ..., 0)
|
||||||
|
-- element 1: index (0, ..., 1)
|
||||||
|
-- ...
|
||||||
|
encodeTensorData :: Shape -> s a -> TensorData a
|
||||||
|
|
||||||
|
-- All types, besides ByteString and Bool, are encoded as simple arrays and we
|
||||||
|
-- can use Vector.Storable to encode/decode by type casting pointers.
|
||||||
|
|
||||||
|
-- TODO(fmayle): Assert that the data type matches the return type.
|
||||||
|
simpleDecode :: Storable a => TensorData a -> S.Vector a
|
||||||
|
simpleDecode = S.unsafeCast . FFI.tensorDataBytes . unTensorData
|
||||||
|
|
||||||
|
simpleEncode :: forall a . (TensorType a, Storable a)
|
||||||
|
=> Shape -> S.Vector a -> TensorData a
|
||||||
|
simpleEncode (Shape xs) v =
|
||||||
|
if product xs /= fromIntegral (S.length v)
|
||||||
|
then error $ printf
|
||||||
|
"simpleEncode: bad vector length for shape %v: expected=%d got=%d"
|
||||||
|
(show xs) (product xs) (S.length v)
|
||||||
|
else TensorData (FFI.TensorData xs dt (S.unsafeCast v))
|
||||||
|
where
|
||||||
|
dt = tensorType (undefined :: a)
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Float where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Double where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Int8 where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Int16 where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Int32 where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Int64 where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Word8 where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
instance TensorDataType S.Vector Word16 where
|
||||||
|
decodeTensorData = simpleDecode
|
||||||
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
||||||
|
-- encoding more expensive. It may make sense to define a custom boolean type.
|
||||||
|
instance TensorDataType S.Vector Bool where
|
||||||
|
decodeTensorData =
|
||||||
|
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
|
||||||
|
encodeTensorData (Shape xs) =
|
||||||
|
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
|
||||||
|
where
|
||||||
|
fromBool x = if x then 1 else 0 :: Word8
|
||||||
|
|
||||||
|
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a)
|
||||||
|
=> TensorDataType V.Vector a where
|
||||||
|
decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
|
||||||
|
encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)
|
||||||
|
|
||||||
|
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
|
||||||
|
decodeTensorData = error "TODO (Complex Float)"
|
||||||
|
encodeTensorData = error "TODO (Complex Float)"
|
||||||
|
|
||||||
|
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
|
||||||
|
decodeTensorData = error "TODO (Complex Double)"
|
||||||
|
encodeTensorData = error "TODO (Complex Double)"
|
||||||
|
|
||||||
|
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
|
||||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||||
-- table offsets for each element :: [Word64]
|
-- table offsets for each element :: [Word64]
|
||||||
-- at each element offset:
|
-- at each element offset:
|
||||||
-- string length :: VarInt64
|
-- string length :: VarInt64
|
||||||
-- string data :: [Word8]
|
-- string data :: [Word8]
|
||||||
-- TODO(fmayle): Benchmark these functions.
|
|
||||||
decodeTensorData tensorData =
|
decodeTensorData tensorData =
|
||||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||||
if expected /= count
|
if expected /= count
|
||||||
|
@ -241,32 +311,21 @@ instance TensorType ByteString where
|
||||||
-- Convert to Vector Word8.
|
-- Convert to Vector Word8.
|
||||||
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
||||||
|
|
||||||
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
newtype Scalar a = Scalar {unScalar :: a}
|
||||||
-- encoding more expensive. It may make sense to define a custom boolean type.
|
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
|
||||||
instance TensorType Bool where
|
RealFrac, IsString)
|
||||||
tensorType _ = DT_BOOL
|
|
||||||
tensorRefType _ = DT_BOOL_REF
|
|
||||||
tensorVal = boolVal
|
|
||||||
decodeTensorData =
|
|
||||||
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
|
|
||||||
encodeTensorData (Shape xs) =
|
|
||||||
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
|
|
||||||
where
|
|
||||||
fromBool x = if x then 1 else 0 :: Word8
|
|
||||||
|
|
||||||
instance TensorType (Complex Float) where
|
instance TensorDataType V.Vector a => TensorDataType Scalar a where
|
||||||
tensorType _ = DT_COMPLEX64
|
decodeTensorData = Scalar . headFromSingleton . decodeTensorData
|
||||||
tensorRefType _ = DT_COMPLEX64
|
encodeTensorData x (Scalar y) = encodeTensorData x (V.fromList [y])
|
||||||
tensorVal = error "TODO (Complex Float)"
|
|
||||||
decodeTensorData = error "TODO (Complex Float)"
|
headFromSingleton :: V.Vector a -> a
|
||||||
encodeTensorData = error "TODO (Complex Float)"
|
headFromSingleton x
|
||||||
|
| V.length x == 1 = V.head x
|
||||||
|
| otherwise = error $
|
||||||
|
"Unable to extract singleton from tensor of length "
|
||||||
|
++ show (V.length x)
|
||||||
|
|
||||||
instance TensorType (Complex Double) where
|
|
||||||
tensorType _ = DT_COMPLEX128
|
|
||||||
tensorRefType _ = DT_COMPLEX128
|
|
||||||
tensorVal = error "TODO (Complex Double)"
|
|
||||||
decodeTensorData = error "TODO (Complex Double)"
|
|
||||||
encodeTensorData = error "TODO (Complex Double)"
|
|
||||||
|
|
||||||
-- | Shape (dimensions) of a tensor.
|
-- | Shape (dimensions) of a tensor.
|
||||||
newtype Shape = Shape [Int64] deriving Show
|
newtype Shape = Shape [Int64] deriving Show
|
||||||
|
|
Loading…
Reference in a new issue