diff --git a/tensorflow-nn/src/TensorFlow/NN.hs b/tensorflow-nn/src/TensorFlow/NN.hs index 5cb5fda..c7d7c28 100644 --- a/tensorflow-nn/src/TensorFlow/NN.hs +++ b/tensorflow-nn/src/TensorFlow/NN.hs @@ -35,6 +35,7 @@ import TensorFlow.Tensor ( Tensor(..) , Value(..) ) import TensorFlow.Types ( TensorType(..) + , TensorProtoLens , OneOf ) import TensorFlow.Ops ( zerosLike @@ -70,7 +71,7 @@ import TensorFlow.Ops ( zerosLike -- -- `logits` and `targets` must have the same type and shape. sigmoidCrossEntropyWithLogits - :: (OneOf '[Float, Double] a, TensorType a, Num a) + :: (OneOf '[Float, Double] a, TensorType a, TensorProtoLens a, Num a) => Tensor Value a -- ^ __logits__ -> Tensor Value a -- ^ __targets__ -> Build (Tensor Value a) diff --git a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs index 9eb396b..24ceb91 100644 --- a/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs +++ b/tensorflow-ops/src/TensorFlow/EmbeddingOps.hs @@ -27,7 +27,7 @@ import Data.List (genericLength) import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Ops () -- Num instance for Tensor import TensorFlow.Tensor (Tensor, Value) -import TensorFlow.Types (OneOf, TensorType) +import TensorFlow.Types (OneOf, TensorType, TensorProtoLens) import qualified TensorFlow.GenOps.Core as CoreOps -- | Looks up `ids` in a list of embedding tensors. @@ -46,6 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. embeddingLookup :: forall a b v . ( TensorType a + , TensorProtoLens b , OneOf '[Int64, Int32] b , Num b ) diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index f863e36..f629510 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -94,13 +94,13 @@ import TensorFlow.Tensor , tensorOutput , tensorAttr ) -import TensorFlow.Types (OneOf, TensorType, attrLens) +import TensorFlow.Types (OneOf, TensorType, TensorProtoLens, attrLens) import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef, attr, input, op, name) type GradientCompatible a = -- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason. - (Num a, OneOf '[ Float, Complex Float, Complex Double ] a) + (Num a, OneOf '[ Float, Complex Float, Complex Double ] a, TensorProtoLens a) -- TODO(fmayle): Support control flow. -- TODO(fmayle): Support gate_gradients-like option to avoid race conditions. diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 3fff01d..af69c88 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -138,6 +138,7 @@ import qualified Prelude (abs) -- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression -- "1". instance ( TensorType a + , TensorProtoLens a , Num a , v ~ Value , OneOf '[ Double, Float, Int32, Int64 @@ -191,7 +192,7 @@ initializedVariable initializer = do -- | Creates a zero-initialized variable with the given shape. zeroInitializedVariable - :: (TensorType a, Num a) => + :: (TensorType a, TensorProtoLens a, Num a) => TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a) zeroInitializedVariable = initializedVariable . zeros @@ -238,7 +239,8 @@ restore path x = do -- element 0: index (0, ..., 0) -- element 1: index (0, ..., 1) -- ... -constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a +constant :: forall a . (TensorType a, TensorProtoLens a) + => Shape -> [a] -> Tensor Value a constant (Shape shape') values | invalidLength = error invalidLengthMsg | otherwise = buildOp $ opDef "Const" @@ -258,11 +260,11 @@ constant (Shape shape') values & tensorVal .~ values -- | Create a constant vector. -vector :: TensorType a => [a] -> Tensor Value a +vector :: (TensorType a, TensorProtoLens a) => [a] -> Tensor Value a vector xs = constant [fromIntegral $ length xs] xs -- | Create a constant scalar. -scalar :: forall a . TensorType a => a -> Tensor Value a +scalar :: (TensorType a, TensorProtoLens a) => a -> Tensor Value a scalar x = constant [] [x] -- Random tensor from the unit normal distribution with bounded values. @@ -273,7 +275,8 @@ truncatedNormal = buildOp $ opDef "TruncatedNormal" & opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "T" .~ tensorType (undefined :: Int64) -zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a +zeros :: (Num a, TensorType a, TensorProtoLens a) + => Shape -> Tensor Value a zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32 diff --git a/tensorflow-ops/tests/DataFlowOpsTest.hs b/tensorflow-ops/tests/DataFlowOpsTest.hs index cd362c9..cb4f612 100644 --- a/tensorflow-ops/tests/DataFlowOpsTest.hs +++ b/tensorflow-ops/tests/DataFlowOpsTest.hs @@ -31,8 +31,9 @@ import qualified TensorFlow.Types as TF -- DynamicSplit is undone with DynamicStitch to get the original input -- back. -testDynamicPartitionStitchInverse :: forall a. - (TF.TensorType a, Show a, Eq a) => StitchExample a -> Property +testDynamicPartitionStitchInverse :: + forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a) + => StitchExample a -> Property testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = let splitParts :: [TF.Tensor TF.Value a] = CoreOps.dynamicPartition numParts (TF.vector values) partTensor diff --git a/tensorflow-ops/tests/EmbeddingOpsTest.hs b/tensorflow-ops/tests/EmbeddingOpsTest.hs index 0a6b97d..4a7dc94 100644 --- a/tensorflow-ops/tests/EmbeddingOpsTest.hs +++ b/tensorflow-ops/tests/EmbeddingOpsTest.hs @@ -36,8 +36,9 @@ import qualified TensorFlow.Types as TF -- Verifies that direct gather is the same as dynamic split into -- partitions, followed by embedding lookup. -testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a) - => LookupExample a -> Property +testEmbeddingLookupUndoesSplit :: + forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a) + => LookupExample a -> Property testEmbeddingLookupUndoesSplit (LookupExample numParts shape@(TF.Shape (firstDim : restDims)) diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index 3d47f39..604e06d 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -28,6 +28,7 @@ module TensorFlow.Types ( TensorType(..) + , TensorProtoLens(..) , TensorData(..) , Shape(..) , protoShape @@ -76,13 +77,12 @@ import Proto.Tensorflow.Core.Framework.AttrValue ) import Proto.Tensorflow.Core.Framework.Tensor as Tensor ( TensorProto(..) - , floatVal + , boolVal , doubleVal + , floatVal + , int64Val , intVal , stringVal - , int64Val - , stringVal - , boolVal ) import Proto.Tensorflow.Core.Framework.TensorShape ( TensorShapeProto(..) @@ -101,7 +101,6 @@ newtype TensorData a = TensorData { unTensorData :: FFI.TensorData } class TensorType a where tensorType :: a -> DataType tensorRefType :: a -> DataType - tensorVal :: Lens' TensorProto [a] -- | Decode the bytes of a TensorData into a Vector. decodeTensorData :: TensorData a -> V.Vector a -- | Encode a Vector into a TensorData. @@ -113,6 +112,25 @@ class TensorType a where -- ... encodeTensorData :: Shape -> V.Vector a -> TensorData a +-- | Class of types that can be used for constructing constant tensors. +class TensorProtoLens a where + tensorVal :: Lens' TensorProto [a] + +instance TensorProtoLens Float where + tensorVal = floatVal + +instance TensorProtoLens Double where + tensorVal = doubleVal + +instance TensorProtoLens Int32 where + tensorVal = intVal + +instance TensorProtoLens Int64 where + tensorVal = int64Val + +instance TensorProtoLens ByteString where + tensorVal = stringVal + -- All types, besides ByteString, are encoded as simple arrays and we can use -- Vector.Storable to encode/decode by type casting pointers. @@ -130,28 +148,24 @@ simpleEncode (Shape xs) instance TensorType Float where tensorType _ = DT_FLOAT tensorRefType _ = DT_FLOAT_REF - tensorVal = floatVal decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType Double where tensorType _ = DT_DOUBLE tensorRefType _ = DT_DOUBLE_REF - tensorVal = doubleVal decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType Int32 where tensorType _ = DT_INT32 tensorRefType _ = DT_INT32_REF - tensorVal = intVal decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType Int64 where tensorType _ = DT_INT64 tensorRefType _ = DT_INT64_REF - tensorVal = int64Val decodeTensorData = simpleDecode encodeTensorData = simpleEncode @@ -161,40 +175,36 @@ integral = iso (fmap fromIntegral) (fmap fromIntegral) instance TensorType Word8 where tensorType _ = DT_UINT8 tensorRefType _ = DT_UINT8_REF - tensorVal = intVal . integral decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType Word16 where tensorType _ = DT_UINT16 tensorRefType _ = DT_UINT16_REF - tensorVal = intVal . integral decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType Int16 where tensorType _ = DT_INT16 tensorRefType _ = DT_INT16_REF - tensorVal = intVal . integral decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType Int8 where tensorType _ = DT_INT8 tensorRefType _ = DT_INT8_REF - tensorVal = intVal . integral decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType ByteString where tensorType _ = DT_STRING tensorRefType _ = DT_STRING_REF - tensorVal = stringVal -- Encoded data layout (described in third_party/tensorflow/c/c_api.h): -- table offsets for each element :: [Word64] -- at each element offset: -- string length :: VarInt64 -- string data :: [Word8] + -- C++ counterparts of these are TF_Tensor_{En,De}codeStrings. -- TODO(fmayle): Benchmark these functions. decodeTensorData tensorData = either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ @@ -244,21 +254,18 @@ instance TensorType ByteString where instance TensorType Bool where tensorType _ = DT_BOOL tensorRefType _ = DT_BOOL_REF - tensorVal = boolVal decodeTensorData = simpleDecode encodeTensorData = simpleEncode instance TensorType (Complex Float) where tensorType _ = DT_COMPLEX64 tensorRefType _ = DT_COMPLEX64 - tensorVal = error "TODO (Complex Float)" decodeTensorData = error "TODO (Complex Float)" encodeTensorData = error "TODO (Complex Float)" instance TensorType (Complex Double) where tensorType _ = DT_COMPLEX128 tensorRefType _ = DT_COMPLEX128 - tensorVal = error "TODO (Complex Double)" decodeTensorData = error "TODO (Complex Double)" encodeTensorData = error "TODO (Complex Double)"