mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 11:29:47 +01:00
Support fetching storable vectors + use them in benchmark (#50)
In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, like RNN. Before: benchmarking feedFetch/4 byte time 83.31 μs (81.88 μs .. 84.75 μs) 0.997 R² (0.994 R² .. 0.998 R²) mean 87.32 μs (86.06 μs .. 88.83 μs) std dev 4.580 μs (3.698 μs .. 5.567 μs) variance introduced by outliers: 55% (severely inflated) benchmarking feedFetch/4 KiB time 114.9 μs (111.5 μs .. 118.2 μs) 0.996 R² (0.994 R² .. 0.998 R²) mean 117.3 μs (116.2 μs .. 118.6 μs) std dev 3.877 μs (3.058 μs .. 5.565 μs) variance introduced by outliers: 31% (moderately inflated) benchmarking feedFetch/4 MiB time 109.0 ms (107.9 ms .. 110.7 ms) 1.000 R² (0.999 R² .. 1.000 R²) mean 108.6 ms (108.2 ms .. 109.2 ms) std dev 740.2 μs (353.2 μs .. 1.186 ms) After: benchmarking feedFetch/4 byte time 82.92 μs (80.55 μs .. 85.24 μs) 0.996 R² (0.993 R² .. 0.998 R²) mean 83.58 μs (82.34 μs .. 84.89 μs) std dev 4.327 μs (3.664 μs .. 5.375 μs) variance introduced by outliers: 54% (severely inflated) benchmarking feedFetch/4 KiB time 85.69 μs (83.81 μs .. 87.30 μs) 0.997 R² (0.996 R² .. 0.999 R²) mean 86.99 μs (86.11 μs .. 88.15 μs) std dev 3.608 μs (2.854 μs .. 5.273 μs) variance introduced by outliers: 43% (moderately inflated) benchmarking feedFetch/4 MiB time 1.582 ms (1.509 ms .. 1.677 ms) 0.970 R² (0.936 R² .. 0.993 R²) mean 1.645 ms (1.554 ms .. 1.981 ms) std dev 490.6 μs (138.9 μs .. 1.067 ms) variance introduced by outliers: 97% (severely inflated)
This commit is contained in:
parent
91f508eb5c
commit
f170df9d13
11 changed files with 154 additions and 114 deletions
|
@ -22,14 +22,9 @@ import Data.List (genericLength)
|
|||
import qualified Data.Text.IO as T
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
import TensorFlow.Examples.MNIST.InputData
|
||||
import TensorFlow.Examples.MNIST.Parse
|
||||
|
|
|
@ -48,10 +48,9 @@ import TensorFlow.Tensor
|
|||
, tensorFromName
|
||||
)
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.Nodes (unScalar)
|
||||
import TensorFlow.Session
|
||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||
import TensorFlow.Types (TensorType(..), Shape(..))
|
||||
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?), Assertion)
|
||||
|
|
|
@ -42,7 +42,7 @@ import TensorFlow.Build
|
|||
, withNameScope
|
||||
)
|
||||
import TensorFlow.ControlFlow (named)
|
||||
import TensorFlow.Nodes (unScalar)
|
||||
import TensorFlow.Types (unScalar)
|
||||
import TensorFlow.Ops
|
||||
( add
|
||||
, assign
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
import Data.Int (Int32, Int64)
|
||||
|
@ -32,7 +33,7 @@ import qualified TensorFlow.Types as TF
|
|||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||
-- back.
|
||||
testDynamicPartitionStitchInverse :: forall a.
|
||||
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
|
||||
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
|
||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||
let splitParts :: [TF.Tensor TF.Value a] =
|
||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
|
@ -129,8 +130,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
|
||||
-- Verifies that direct gather is the same as dynamic split into
|
||||
-- partitions, followed by embedding lookup.
|
||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
||||
=> LookupExample a -> Property
|
||||
testEmbeddingLookupUndoesSplit ::
|
||||
forall a. (TF.TensorDataType V.Vector a, Show a, Eq a)
|
||||
=> LookupExample a -> Property
|
||||
testEmbeddingLookupUndoesSplit
|
||||
(LookupExample numParts
|
||||
shape@(TF.Shape (firstDim : restDims))
|
||||
|
|
|
@ -5,7 +5,7 @@ import Control.Exception (evaluate)
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import Criterion.Main (defaultMain, bgroup, bench)
|
||||
import Criterion.Types (Benchmarkable(..))
|
||||
import qualified Data.Vector as V
|
||||
import qualified Data.Vector.Storable as S
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
|
||||
|
@ -23,12 +23,12 @@ nfSession init x = Benchmarkable $ \m -> TF.runSession $ do
|
|||
go m
|
||||
|
||||
-- | Benchmark feeding and fetching a vector.
|
||||
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
|
||||
feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float))
|
||||
feedFetchBenchmark = do
|
||||
input <- TF.build (TF.placeholder (TF.Shape [-1]))
|
||||
output <- TF.build (TF.render (TF.identity input))
|
||||
return $ \v -> do
|
||||
let shape = TF.Shape [fromIntegral (V.length v)]
|
||||
let shape = TF.Shape [fromIntegral (S.length v)]
|
||||
inputData = TF.encodeTensorData shape v
|
||||
feeds = [TF.feed input inputData]
|
||||
TF.runWithFeeds feeds output
|
||||
|
@ -36,8 +36,8 @@ feedFetchBenchmark = do
|
|||
main :: IO ()
|
||||
main = defaultMain
|
||||
[ bgroup "feedFetch"
|
||||
[ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0)
|
||||
, bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0)
|
||||
, bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0)
|
||||
[ bench "4 byte" $ nfSession feedFetchBenchmark (S.replicate 1 0)
|
||||
, bench "4 KiB" $ nfSession feedFetchBenchmark (S.replicate 1024 0)
|
||||
, bench "4 MiB" $ nfSession feedFetchBenchmark (S.replicate (1024^2) 0)
|
||||
]
|
||||
]
|
||||
|
|
|
@ -85,7 +85,7 @@ instance Arbitrary a => Arbitrary (TensorDataInputs a) where
|
|||
return $ TensorDataInputs sizes (V.fromList elems)
|
||||
|
||||
-- Test that a vector is unchanged after being encoded and decoded.
|
||||
encodeDecodeProp :: (TF.TensorType a, Eq a) => TensorDataInputs a -> Bool
|
||||
encodeDecodeProp :: (TF.TensorDataType V.Vector a, Eq a) => TensorDataInputs a -> Bool
|
||||
encodeDecodeProp (TensorDataInputs shape vec) =
|
||||
TF.decodeTensorData (TF.encodeTensorData (TF.Shape shape) vec) == vec
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ module Main where
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Nodes (Scalar(..))
|
||||
import TensorFlow.Types (Scalar(..))
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Queue
|
||||
import TensorFlow.Session
|
||||
|
|
|
@ -36,7 +36,6 @@ module TensorFlow.Core
|
|||
, buildWithSummary
|
||||
-- ** Running graphs
|
||||
, Fetchable
|
||||
, Scalar(..)
|
||||
, Nodes
|
||||
, run
|
||||
, run_
|
||||
|
@ -64,8 +63,10 @@ module TensorFlow.Core
|
|||
, value
|
||||
, tensorFromName
|
||||
-- ** Element types
|
||||
, TensorType
|
||||
, TensorData
|
||||
, TensorType(decodeTensorData, encodeTensorData)
|
||||
, TensorDataType(decodeTensorData, encodeTensorData)
|
||||
, Scalar(..)
|
||||
, Shape(..)
|
||||
, OneOf
|
||||
, type (/=)
|
||||
|
|
|
@ -12,8 +12,8 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
@ -24,12 +24,10 @@ import Control.Applicative (liftA2, liftA3)
|
|||
import Data.Map.Strict (Map)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Set (Set)
|
||||
import Data.String (IsString)
|
||||
import Data.Text (Text)
|
||||
import Lens.Family2 ((^.))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Output
|
||||
|
@ -101,18 +99,12 @@ instance a ~ () => Fetchable ControlNode a where
|
|||
instance Nodes (Tensor v a) where
|
||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
fetchTensorList :: TensorType a => Tensor v a -> Build (Fetch (Shape, [a]))
|
||||
fetchTensorList t = fmap (fmap V.toList) <$> fetchTensorVector t
|
||||
|
||||
fetchTensorVector :: forall a v . TensorType a
|
||||
=> Tensor v a -> Build (Fetch (Shape, V.Vector a))
|
||||
=> Tensor v a -> Build (Fetch (TensorData a))
|
||||
fetchTensorVector (Tensor _ o) = do
|
||||
outputName <- renderOutput o
|
||||
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||
let tensorData = tensors Map.! outputName
|
||||
shape = Shape $ FFI.tensorDataDimensions tensorData
|
||||
vec = decodeTensorData $ TensorData tensorData
|
||||
|
||||
expectedType = tensorType (undefined :: a)
|
||||
actualType = FFI.tensorDataType tensorData
|
||||
badTypeError = error $ "Bad tensor type: expected "
|
||||
|
@ -121,21 +113,12 @@ fetchTensorVector (Tensor _ o) = do
|
|||
++ show actualType
|
||||
in if expectedType /= actualType
|
||||
then badTypeError
|
||||
else (shape, vec)
|
||||
else TensorData tensorData
|
||||
|
||||
-- The constraint "a ~ a'" means that the input/output of fetch can constrain
|
||||
-- the TensorType of each other.
|
||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (V.Vector a') where
|
||||
getFetch t = fmap snd <$> fetchTensorVector t
|
||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (TensorData a') where
|
||||
getFetch = fetchTensorVector
|
||||
|
||||
newtype Scalar a = Scalar {unScalar :: a}
|
||||
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
|
||||
RealFrac, IsString)
|
||||
|
||||
instance (TensorType a, a ~ a') => Fetchable (Tensor v a) (Scalar a') where
|
||||
getFetch t = fmap (Scalar . headFromSingleton . snd) <$> fetchTensorList t
|
||||
where
|
||||
headFromSingleton [x] = x
|
||||
headFromSingleton xs
|
||||
= error $ "Unable to extract singleton from tensor of length "
|
||||
++ show (length xs)
|
||||
instance (TensorType a, TensorDataType s a, a ~ a') => Fetchable (Tensor v a) (s a') where
|
||||
getFetch t = fmap decodeTensorData <$> fetchTensorVector t
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
@ -30,6 +31,8 @@
|
|||
module TensorFlow.Types
|
||||
( TensorType(..)
|
||||
, TensorData(..)
|
||||
, TensorDataType(..)
|
||||
, Scalar(..)
|
||||
, Shape(..)
|
||||
, protoShape
|
||||
, Attribute(..)
|
||||
|
@ -50,11 +53,13 @@ import Data.Complex (Complex)
|
|||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.String (IsString)
|
||||
import Data.Word (Word8, Word16, Word64)
|
||||
import Foreign.Storable (Storable)
|
||||
import GHC.Exts (Constraint, IsList(..))
|
||||
import Lens.Family2 (Lens', view, (&), (.~))
|
||||
import Lens.Family2.Unchecked (iso)
|
||||
import Text.Printf (printf)
|
||||
import qualified Data.Attoparsec.ByteString as Atto
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
|
@ -95,66 +100,31 @@ import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
|||
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
-- | Data about a tensor that is encoded for the TensorFlow APIs.
|
||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||
|
||||
-- | The class of scalar types supported by tensorflow.
|
||||
class TensorType a where
|
||||
tensorType :: a -> DataType
|
||||
tensorRefType :: a -> DataType
|
||||
tensorVal :: Lens' TensorProto [a]
|
||||
-- | Decode the bytes of a TensorData into a Vector.
|
||||
decodeTensorData :: TensorData a -> V.Vector a
|
||||
-- | Encode a Vector into a TensorData.
|
||||
--
|
||||
-- The values should be in row major order, e.g.,
|
||||
--
|
||||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
||||
|
||||
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
||||
-- Vector.Storable to encode/decode by type casting pointers.
|
||||
|
||||
-- TODO(fmayle): Assert that the data type matches the return type.
|
||||
simpleDecode :: Storable a => TensorData a -> V.Vector a
|
||||
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData
|
||||
|
||||
simpleEncode :: forall a . (TensorType a, Storable a)
|
||||
=> Shape -> V.Vector a -> TensorData a
|
||||
simpleEncode (Shape xs)
|
||||
= TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
|
||||
where
|
||||
dt = tensorType (undefined :: a)
|
||||
|
||||
instance TensorType Float where
|
||||
tensorType _ = DT_FLOAT
|
||||
tensorRefType _ = DT_FLOAT_REF
|
||||
tensorVal = floatVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Double where
|
||||
tensorType _ = DT_DOUBLE
|
||||
tensorRefType _ = DT_DOUBLE_REF
|
||||
tensorVal = doubleVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int32 where
|
||||
tensorType _ = DT_INT32
|
||||
tensorRefType _ = DT_INT32_REF
|
||||
tensorVal = intVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int64 where
|
||||
tensorType _ = DT_INT64
|
||||
tensorRefType _ = DT_INT64_REF
|
||||
tensorVal = int64Val
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
integral :: Integral a => Lens' [Int32] [a]
|
||||
integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
||||
|
@ -163,40 +133,140 @@ instance TensorType Word8 where
|
|||
tensorType _ = DT_UINT8
|
||||
tensorRefType _ = DT_UINT8_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Word16 where
|
||||
tensorType _ = DT_UINT16
|
||||
tensorRefType _ = DT_UINT16_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int16 where
|
||||
tensorType _ = DT_INT16
|
||||
tensorRefType _ = DT_INT16_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int8 where
|
||||
tensorType _ = DT_INT8
|
||||
tensorRefType _ = DT_INT8_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType ByteString where
|
||||
tensorType _ = DT_STRING
|
||||
tensorRefType _ = DT_STRING_REF
|
||||
tensorVal = stringVal
|
||||
|
||||
instance TensorType Bool where
|
||||
tensorType _ = DT_BOOL
|
||||
tensorRefType _ = DT_BOOL_REF
|
||||
tensorVal = boolVal
|
||||
|
||||
instance TensorType (Complex Float) where
|
||||
tensorType _ = DT_COMPLEX64
|
||||
tensorRefType _ = DT_COMPLEX64
|
||||
tensorVal = error "TODO (Complex Float)"
|
||||
|
||||
instance TensorType (Complex Double) where
|
||||
tensorType _ = DT_COMPLEX128
|
||||
tensorRefType _ = DT_COMPLEX128
|
||||
tensorVal = error "TODO (Complex Double)"
|
||||
|
||||
|
||||
-- | Tensor data with the correct memory layout for tensorflow.
|
||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||
|
||||
-- | Types that can be converted to and from 'TensorData'.
|
||||
--
|
||||
-- 'S.Vector' is the most efficient to encode/decode for most element types.
|
||||
class TensorType a => TensorDataType s a where
|
||||
-- | Decode the bytes of a 'TensorData' into an 's'.
|
||||
decodeTensorData :: TensorData a -> s a
|
||||
-- | Encode an 's' into a 'TensorData'.
|
||||
--
|
||||
-- The values should be in row major order, e.g.,
|
||||
--
|
||||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
encodeTensorData :: Shape -> s a -> TensorData a
|
||||
|
||||
-- All types, besides ByteString and Bool, are encoded as simple arrays and we
|
||||
-- can use Vector.Storable to encode/decode by type casting pointers.
|
||||
|
||||
-- TODO(fmayle): Assert that the data type matches the return type.
|
||||
simpleDecode :: Storable a => TensorData a -> S.Vector a
|
||||
simpleDecode = S.unsafeCast . FFI.tensorDataBytes . unTensorData
|
||||
|
||||
simpleEncode :: forall a . (TensorType a, Storable a)
|
||||
=> Shape -> S.Vector a -> TensorData a
|
||||
simpleEncode (Shape xs) v =
|
||||
if product xs /= fromIntegral (S.length v)
|
||||
then error $ printf
|
||||
"simpleEncode: bad vector length for shape %v: expected=%d got=%d"
|
||||
(show xs) (product xs) (S.length v)
|
||||
else TensorData (FFI.TensorData xs dt (S.unsafeCast v))
|
||||
where
|
||||
dt = tensorType (undefined :: a)
|
||||
|
||||
instance TensorDataType S.Vector Float where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Double where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Int8 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Int16 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Int32 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Int64 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Word8 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType S.Vector Word16 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
||||
-- encoding more expensive. It may make sense to define a custom boolean type.
|
||||
instance TensorDataType S.Vector Bool where
|
||||
decodeTensorData =
|
||||
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
|
||||
encodeTensorData (Shape xs) =
|
||||
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
|
||||
where
|
||||
fromBool x = if x then 1 else 0 :: Word8
|
||||
|
||||
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a)
|
||||
=> TensorDataType V.Vector a where
|
||||
decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
|
||||
encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)
|
||||
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
|
||||
decodeTensorData = error "TODO (Complex Float)"
|
||||
encodeTensorData = error "TODO (Complex Float)"
|
||||
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
|
||||
decodeTensorData = error "TODO (Complex Double)"
|
||||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
|
||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||
-- table offsets for each element :: [Word64]
|
||||
-- at each element offset:
|
||||
-- string length :: VarInt64
|
||||
-- string data :: [Word8]
|
||||
-- TODO(fmayle): Benchmark these functions.
|
||||
decodeTensorData tensorData =
|
||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||
if expected /= count
|
||||
|
@ -241,32 +311,21 @@ instance TensorType ByteString where
|
|||
-- Convert to Vector Word8.
|
||||
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
||||
|
||||
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
||||
-- encoding more expensive. It may make sense to define a custom boolean type.
|
||||
instance TensorType Bool where
|
||||
tensorType _ = DT_BOOL
|
||||
tensorRefType _ = DT_BOOL_REF
|
||||
tensorVal = boolVal
|
||||
decodeTensorData =
|
||||
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
|
||||
encodeTensorData (Shape xs) =
|
||||
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
|
||||
where
|
||||
fromBool x = if x then 1 else 0 :: Word8
|
||||
newtype Scalar a = Scalar {unScalar :: a}
|
||||
deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
|
||||
RealFrac, IsString)
|
||||
|
||||
instance TensorType (Complex Float) where
|
||||
tensorType _ = DT_COMPLEX64
|
||||
tensorRefType _ = DT_COMPLEX64
|
||||
tensorVal = error "TODO (Complex Float)"
|
||||
decodeTensorData = error "TODO (Complex Float)"
|
||||
encodeTensorData = error "TODO (Complex Float)"
|
||||
instance TensorDataType V.Vector a => TensorDataType Scalar a where
|
||||
decodeTensorData = Scalar . headFromSingleton . decodeTensorData
|
||||
encodeTensorData x (Scalar y) = encodeTensorData x (V.fromList [y])
|
||||
|
||||
headFromSingleton :: V.Vector a -> a
|
||||
headFromSingleton x
|
||||
| V.length x == 1 = V.head x
|
||||
| otherwise = error $
|
||||
"Unable to extract singleton from tensor of length "
|
||||
++ show (V.length x)
|
||||
|
||||
instance TensorType (Complex Double) where
|
||||
tensorType _ = DT_COMPLEX128
|
||||
tensorRefType _ = DT_COMPLEX128
|
||||
tensorVal = error "TODO (Complex Double)"
|
||||
decodeTensorData = error "TODO (Complex Double)"
|
||||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
-- | Shape (dimensions) of a tensor.
|
||||
newtype Shape = Shape [Int64] deriving Show
|
||||
|
|
Loading…
Reference in a new issue