Create a separate TensorProtoLens class.

Only a handful of types had sensible tensorVal implementations.
This is now evident in type signatures at the expense of them
being more verbose.
This commit is contained in:
Greg Steuck 2016-11-02 16:16:16 -07:00
parent 29f11d351d
commit e511f49828
7 changed files with 44 additions and 30 deletions

View File

@ -35,6 +35,7 @@ import TensorFlow.Tensor ( Tensor(..)
, Value(..)
)
import TensorFlow.Types ( TensorType(..)
, TensorProtoLens
, OneOf
)
import TensorFlow.Ops ( zerosLike
@ -70,7 +71,7 @@ import TensorFlow.Ops ( zerosLike
--
-- `logits` and `targets` must have the same type and shape.
sigmoidCrossEntropyWithLogits
:: (OneOf '[Float, Double] a, TensorType a, Num a)
:: (OneOf '[Float, Double] a, TensorType a, TensorProtoLens a, Num a)
=> Tensor Value a -- ^ __logits__
-> Tensor Value a -- ^ __targets__
-> Build (Tensor Value a)

View File

@ -27,7 +27,7 @@ import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens)
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Looks up `ids` in a list of embedding tensors.
@ -46,6 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v .
( TensorType a
, TensorProtoLens b
, OneOf '[Int64, Int32] b
, Num b
)

View File

@ -94,13 +94,13 @@ import TensorFlow.Tensor
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (OneOf, TensorType, attrLens)
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
(NodeDef, attr, input, op, name)
type GradientCompatible a =
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a, TensorProtoLens a)
-- TODO(fmayle): Support control flow.
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.

View File

@ -138,6 +138,7 @@ import qualified Prelude (abs)
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
-- "1".
instance ( TensorType a
, TensorProtoLens a
, Num a
, v ~ Value
, OneOf '[ Double, Float, Int32, Int64
@ -191,7 +192,7 @@ initializedVariable initializer = do
-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
:: (TensorType a, Num a) =>
:: (TensorType a, TensorProtoLens a, Num a) =>
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros
@ -238,7 +239,8 @@ restore path x = do
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
constant :: forall a . (TensorType a, TensorProtoLens a)
=> Shape -> [a] -> Tensor Value a
constant (Shape shape') values
| invalidLength = error invalidLengthMsg
| otherwise = buildOp $ opDef "Const"
@ -258,11 +260,11 @@ constant (Shape shape') values
& tensorVal .~ values
-- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a
vector :: (TensorType a, TensorProtoLens a) => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs
-- | Create a constant scalar.
scalar :: forall a . TensorType a => a -> Tensor Value a
scalar :: (TensorType a, TensorProtoLens a) => a -> Tensor Value a
scalar x = constant [] [x]
-- Random tensor from the unit normal distribution with bounded values.
@ -273,7 +275,8 @@ truncatedNormal = buildOp $ opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64)
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros :: (Num a, TensorType a, TensorProtoLens a)
=> Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32

View File

@ -31,8 +31,9 @@ import qualified TensorFlow.Types as TF
-- DynamicSplit is undone with DynamicStitch to get the original input
-- back.
testDynamicPartitionStitchInverse :: forall a.
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse ::
forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
=> StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor

View File

@ -36,8 +36,9 @@ import qualified TensorFlow.Types as TF
-- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit ::
forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
=> LookupExample a -> Property
testEmbeddingLookupUndoesSplit
(LookupExample numParts
shape@(TF.Shape (firstDim : restDims))

View File

@ -28,6 +28,7 @@
module TensorFlow.Types
( TensorType(..)
, TensorProtoLens(..)
, TensorData(..)
, Shape(..)
, protoShape
@ -76,13 +77,12 @@ import Proto.Tensorflow.Core.Framework.AttrValue
)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
( TensorProto(..)
, floatVal
, boolVal
, doubleVal
, floatVal
, int64Val
, intVal
, stringVal
, int64Val
, stringVal
, boolVal
)
import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..)
@ -101,7 +101,6 @@ newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
class TensorType a where
tensorType :: a -> DataType
tensorRefType :: a -> DataType
tensorVal :: Lens' TensorProto [a]
-- | Decode the bytes of a TensorData into a Vector.
decodeTensorData :: TensorData a -> V.Vector a
-- | Encode a Vector into a TensorData.
@ -113,6 +112,25 @@ class TensorType a where
-- ...
encodeTensorData :: Shape -> V.Vector a -> TensorData a
-- | Class of types that can be used for constructing constant tensors.
class TensorProtoLens a where
tensorVal :: Lens' TensorProto [a]
instance TensorProtoLens Float where
tensorVal = floatVal
instance TensorProtoLens Double where
tensorVal = doubleVal
instance TensorProtoLens Int32 where
tensorVal = intVal
instance TensorProtoLens Int64 where
tensorVal = int64Val
instance TensorProtoLens ByteString where
tensorVal = stringVal
-- All types, besides ByteString, are encoded as simple arrays and we can use
-- Vector.Storable to encode/decode by type casting pointers.
@ -130,28 +148,24 @@ simpleEncode (Shape xs)
instance TensorType Float where
tensorType _ = DT_FLOAT
tensorRefType _ = DT_FLOAT_REF
tensorVal = floatVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Double where
tensorType _ = DT_DOUBLE
tensorRefType _ = DT_DOUBLE_REF
tensorVal = doubleVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int32 where
tensorType _ = DT_INT32
tensorRefType _ = DT_INT32_REF
tensorVal = intVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int64 where
tensorType _ = DT_INT64
tensorRefType _ = DT_INT64_REF
tensorVal = int64Val
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
@ -161,40 +175,36 @@ integral = iso (fmap fromIntegral) (fmap fromIntegral)
instance TensorType Word8 where
tensorType _ = DT_UINT8
tensorRefType _ = DT_UINT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Word16 where
tensorType _ = DT_UINT16
tensorRefType _ = DT_UINT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int16 where
tensorType _ = DT_INT16
tensorRefType _ = DT_INT16_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType Int8 where
tensorType _ = DT_INT8
tensorRefType _ = DT_INT8_REF
tensorVal = intVal . integral
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType ByteString where
tensorType _ = DT_STRING
tensorRefType _ = DT_STRING_REF
tensorVal = stringVal
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64]
-- at each element offset:
-- string length :: VarInt64
-- string data :: [Word8]
-- C++ counterparts of these are TF_Tensor_{En,De}codeStrings.
-- TODO(fmayle): Benchmark these functions.
decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
@ -244,21 +254,18 @@ instance TensorType ByteString where
instance TensorType Bool where
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64
tensorRefType _ = DT_COMPLEX64
tensorVal = error "TODO (Complex Float)"
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance TensorType (Complex Double) where
tensorType _ = DT_COMPLEX128
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"