Create a separate TensorProtoLens class.

Only a handful of types had sensible tensorVal implementations.
This is now evident in type signatures at the expense of them
being more verbose.
This commit is contained in:
Greg Steuck 2016-11-02 16:16:16 -07:00
parent 29f11d351d
commit e511f49828
7 changed files with 44 additions and 30 deletions

View File

@ -35,6 +35,7 @@ import TensorFlow.Tensor ( Tensor(..)
, Value(..) , Value(..)
) )
import TensorFlow.Types ( TensorType(..) import TensorFlow.Types ( TensorType(..)
, TensorProtoLens
, OneOf , OneOf
) )
import TensorFlow.Ops ( zerosLike import TensorFlow.Ops ( zerosLike
@ -70,7 +71,7 @@ import TensorFlow.Ops ( zerosLike
-- --
-- `logits` and `targets` must have the same type and shape. -- `logits` and `targets` must have the same type and shape.
sigmoidCrossEntropyWithLogits sigmoidCrossEntropyWithLogits
:: (OneOf '[Float, Double] a, TensorType a, Num a) :: (OneOf '[Float, Double] a, TensorType a, TensorProtoLens a, Num a)
=> Tensor Value a -- ^ __logits__ => Tensor Value a -- ^ __logits__
-> Tensor Value a -- ^ __targets__ -> Tensor Value a -- ^ __targets__
-> Build (Tensor Value a) -> Build (Tensor Value a)

View File

@ -27,7 +27,7 @@ import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor import TensorFlow.Ops () -- Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType) import TensorFlow.Types (OneOf, TensorType, TensorProtoLens)
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
-- | Looks up `ids` in a list of embedding tensors. -- | Looks up `ids` in a list of embedding tensors.
@ -46,6 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v . embeddingLookup :: forall a b v .
( TensorType a ( TensorType a
, TensorProtoLens b
, OneOf '[Int64, Int32] b , OneOf '[Int64, Int32] b
, Num b , Num b
) )

View File

@ -94,13 +94,13 @@ import TensorFlow.Tensor
, tensorOutput , tensorOutput
, tensorAttr , tensorAttr
) )
import TensorFlow.Types (OneOf, TensorType, attrLens) import TensorFlow.Types (OneOf, TensorType, TensorProtoLens, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name) (NodeDef, attr, input, op, name)
type GradientCompatible a = type GradientCompatible a =
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason. -- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a) (Num a, OneOf '[ Float, Complex Float, Complex Double ] a, TensorProtoLens a)
-- TODO(fmayle): Support control flow. -- TODO(fmayle): Support control flow.
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions. -- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.

View File

@ -138,6 +138,7 @@ import qualified Prelude (abs)
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression -- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
-- "1". -- "1".
instance ( TensorType a instance ( TensorType a
, TensorProtoLens a
, Num a , Num a
, v ~ Value , v ~ Value
, OneOf '[ Double, Float, Int32, Int64 , OneOf '[ Double, Float, Int32, Int64
@ -191,7 +192,7 @@ initializedVariable initializer = do
-- | Creates a zero-initialized variable with the given shape. -- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable zeroInitializedVariable
:: (TensorType a, Num a) => :: (TensorType a, TensorProtoLens a, Num a) =>
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a) TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros zeroInitializedVariable = initializedVariable . zeros
@ -238,7 +239,8 @@ restore path x = do
-- element 0: index (0, ..., 0) -- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1) -- element 1: index (0, ..., 1)
-- ... -- ...
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a constant :: forall a . (TensorType a, TensorProtoLens a)
=> Shape -> [a] -> Tensor Value a
constant (Shape shape') values constant (Shape shape') values
| invalidLength = error invalidLengthMsg | invalidLength = error invalidLengthMsg
| otherwise = buildOp $ opDef "Const" | otherwise = buildOp $ opDef "Const"
@ -258,11 +260,11 @@ constant (Shape shape') values
& tensorVal .~ values & tensorVal .~ values
-- | Create a constant vector. -- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a vector :: (TensorType a, TensorProtoLens a) => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs vector xs = constant [fromIntegral $ length xs] xs
-- | Create a constant scalar. -- | Create a constant scalar.
scalar :: forall a . TensorType a => a -> Tensor Value a scalar :: (TensorType a, TensorProtoLens a) => a -> Tensor Value a
scalar x = constant [] [x] scalar x = constant [] [x]
-- Random tensor from the unit normal distribution with bounded values. -- Random tensor from the unit normal distribution with bounded values.
@ -273,7 +275,8 @@ truncatedNormal = buildOp $ opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64) & opAttr "T" .~ tensorType (undefined :: Int64)
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a zeros :: (Num a, TensorType a, TensorProtoLens a)
=> Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32 shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32

View File

@ -31,8 +31,9 @@ import qualified TensorFlow.Types as TF
-- DynamicSplit is undone with DynamicStitch to get the original input -- DynamicSplit is undone with DynamicStitch to get the original input
-- back. -- back.
testDynamicPartitionStitchInverse :: forall a. testDynamicPartitionStitchInverse ::
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
=> StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] = let splitParts :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor CoreOps.dynamicPartition numParts (TF.vector values) partTensor

View File

@ -36,8 +36,9 @@ import qualified TensorFlow.Types as TF
-- Verifies that direct gather is the same as dynamic split into -- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup. -- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) testEmbeddingLookupUndoesSplit ::
=> LookupExample a -> Property forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit testEmbeddingLookupUndoesSplit
(LookupExample numParts (LookupExample numParts
shape@(TF.Shape (firstDim : restDims)) shape@(TF.Shape (firstDim : restDims))

View File

@ -28,6 +28,7 @@
module TensorFlow.Types module TensorFlow.Types
( TensorType(..) ( TensorType(..)
, TensorProtoLens(..)
, TensorData(..) , TensorData(..)
, Shape(..) , Shape(..)
, protoShape , protoShape
@ -76,13 +77,12 @@ import Proto.Tensorflow.Core.Framework.AttrValue
) )
import Proto.Tensorflow.Core.Framework.Tensor as Tensor import Proto.Tensorflow.Core.Framework.Tensor as Tensor
( TensorProto(..) ( TensorProto(..)
, floatVal , boolVal
, doubleVal , doubleVal
, floatVal
, int64Val
, intVal , intVal
, stringVal , stringVal
, int64Val
, stringVal
, boolVal
) )
import Proto.Tensorflow.Core.Framework.TensorShape import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..) ( TensorShapeProto(..)
@ -101,7 +101,6 @@ newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
class TensorType a where class TensorType a where
tensorType :: a -> DataType tensorType :: a -> DataType
tensorRefType :: a -> DataType tensorRefType :: a -> DataType
tensorVal :: Lens' TensorProto [a]
-- | Decode the bytes of a TensorData into a Vector. -- | Decode the bytes of a TensorData into a Vector.
decodeTensorData :: TensorData a -> V.Vector a decodeTensorData :: TensorData a -> V.Vector a
-- | Encode a Vector into a TensorData. -- | Encode a Vector into a TensorData.
@ -113,6 +112,25 @@ class TensorType a where
-- ... -- ...
encodeTensorData :: Shape -> V.Vector a -> TensorData a encodeTensorData :: Shape -> V.Vector a -> TensorData a
-- | Class of types that can be used for constructing constant tensors.
class TensorProtoLens a where
tensorVal :: Lens' TensorProto [a]
instance TensorProtoLens Float where
tensorVal = floatVal
instance TensorProtoLens Double where
tensorVal = doubleVal
instance TensorProtoLens Int32 where
tensorVal = intVal
instance TensorProtoLens Int64 where
tensorVal = int64Val
instance TensorProtoLens ByteString where
tensorVal = stringVal
-- All types, besides ByteString, are encoded as simple arrays and we can use -- All types, besides ByteString, are encoded as simple arrays and we can use
-- Vector.Storable to encode/decode by type casting pointers. -- Vector.Storable to encode/decode by type casting pointers.
@ -130,28 +148,24 @@ simpleEncode (Shape xs)
instance TensorType Float where instance TensorType Float where
tensorType _ = DT_FLOAT tensorType _ = DT_FLOAT
tensorRefType _ = DT_FLOAT_REF tensorRefType _ = DT_FLOAT_REF
tensorVal = floatVal
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType Double where instance TensorType Double where
tensorType _ = DT_DOUBLE tensorType _ = DT_DOUBLE
tensorRefType _ = DT_DOUBLE_REF tensorRefType _ = DT_DOUBLE_REF
tensorVal = doubleVal
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType Int32 where instance TensorType Int32 where
tensorType _ = DT_INT32 tensorType _ = DT_INT32
tensorRefType _ = DT_INT32_REF tensorRefType _ = DT_INT32_REF
tensorVal = intVal
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType Int64 where instance TensorType Int64 where
tensorType _ = DT_INT64 tensorType _ = DT_INT64
tensorRefType _ = DT_INT64_REF tensorRefType _ = DT_INT64_REF
tensorVal = int64Val
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
@ -161,40 +175,36 @@ integral = iso (fmap fromIntegral) (fmap fromIntegral)
instance TensorType Word8 where instance TensorType Word8 where
tensorType _ = DT_UINT8 tensorType _ = DT_UINT8
tensorRefType _ = DT_UINT8_REF tensorRefType _ = DT_UINT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType Word16 where instance TensorType Word16 where
tensorType _ = DT_UINT16 tensorType _ = DT_UINT16
tensorRefType _ = DT_UINT16_REF tensorRefType _ = DT_UINT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType Int16 where instance TensorType Int16 where
tensorType _ = DT_INT16 tensorType _ = DT_INT16
tensorRefType _ = DT_INT16_REF tensorRefType _ = DT_INT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType Int8 where instance TensorType Int8 where
tensorType _ = DT_INT8 tensorType _ = DT_INT8
tensorRefType _ = DT_INT8_REF tensorRefType _ = DT_INT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType ByteString where instance TensorType ByteString where
tensorType _ = DT_STRING tensorType _ = DT_STRING
tensorRefType _ = DT_STRING_REF tensorRefType _ = DT_STRING_REF
tensorVal = stringVal
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h): -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64] -- table offsets for each element :: [Word64]
-- at each element offset: -- at each element offset:
-- string length :: VarInt64 -- string length :: VarInt64
-- string data :: [Word8] -- string data :: [Word8]
-- C++ counterparts of these are TF_Tensor_{En,De}codeStrings.
-- TODO(fmayle): Benchmark these functions. -- TODO(fmayle): Benchmark these functions.
decodeTensorData tensorData = decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
@ -244,21 +254,18 @@ instance TensorType ByteString where
instance TensorType Bool where instance TensorType Bool where
tensorType _ = DT_BOOL tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
decodeTensorData = simpleDecode decodeTensorData = simpleDecode
encodeTensorData = simpleEncode encodeTensorData = simpleEncode
instance TensorType (Complex Float) where instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64 tensorType _ = DT_COMPLEX64
tensorRefType _ = DT_COMPLEX64 tensorRefType _ = DT_COMPLEX64
tensorVal = error "TODO (Complex Float)"
decodeTensorData = error "TODO (Complex Float)" decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)" encodeTensorData = error "TODO (Complex Float)"
instance TensorType (Complex Double) where instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128 tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128 tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
decodeTensorData = error "TODO (Complex Double)" decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)" encodeTensorData = error "TODO (Complex Double)"