mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Create a separate TensorProtoLens class.
Only a handful of types had sensible tensorVal implementations. This is now evident in type signatures at the expense of them being more verbose.
This commit is contained in:
parent
29f11d351d
commit
e511f49828
7 changed files with 44 additions and 30 deletions
|
@ -35,6 +35,7 @@ import TensorFlow.Tensor ( Tensor(..)
|
||||||
, Value(..)
|
, Value(..)
|
||||||
)
|
)
|
||||||
import TensorFlow.Types ( TensorType(..)
|
import TensorFlow.Types ( TensorType(..)
|
||||||
|
, TensorProtoLens
|
||||||
, OneOf
|
, OneOf
|
||||||
)
|
)
|
||||||
import TensorFlow.Ops ( zerosLike
|
import TensorFlow.Ops ( zerosLike
|
||||||
|
@ -70,7 +71,7 @@ import TensorFlow.Ops ( zerosLike
|
||||||
--
|
--
|
||||||
-- `logits` and `targets` must have the same type and shape.
|
-- `logits` and `targets` must have the same type and shape.
|
||||||
sigmoidCrossEntropyWithLogits
|
sigmoidCrossEntropyWithLogits
|
||||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
:: (OneOf '[Float, Double] a, TensorType a, TensorProtoLens a, Num a)
|
||||||
=> Tensor Value a -- ^ __logits__
|
=> Tensor Value a -- ^ __logits__
|
||||||
-> Tensor Value a -- ^ __targets__
|
-> Tensor Value a -- ^ __targets__
|
||||||
-> Build (Tensor Value a)
|
-> Build (Tensor Value a)
|
||||||
|
|
|
@ -27,7 +27,7 @@ import Data.List (genericLength)
|
||||||
import TensorFlow.Build (Build, colocateWith, render)
|
import TensorFlow.Build (Build, colocateWith, render)
|
||||||
import TensorFlow.Ops () -- Num instance for Tensor
|
import TensorFlow.Ops () -- Num instance for Tensor
|
||||||
import TensorFlow.Tensor (Tensor, Value)
|
import TensorFlow.Tensor (Tensor, Value)
|
||||||
import TensorFlow.Types (OneOf, TensorType)
|
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
-- | Looks up `ids` in a list of embedding tensors.
|
-- | Looks up `ids` in a list of embedding tensors.
|
||||||
|
@ -46,6 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||||
embeddingLookup :: forall a b v .
|
embeddingLookup :: forall a b v .
|
||||||
( TensorType a
|
( TensorType a
|
||||||
|
, TensorProtoLens b
|
||||||
, OneOf '[Int64, Int32] b
|
, OneOf '[Int64, Int32] b
|
||||||
, Num b
|
, Num b
|
||||||
)
|
)
|
||||||
|
|
|
@ -94,13 +94,13 @@ import TensorFlow.Tensor
|
||||||
, tensorOutput
|
, tensorOutput
|
||||||
, tensorAttr
|
, tensorAttr
|
||||||
)
|
)
|
||||||
import TensorFlow.Types (OneOf, TensorType, attrLens)
|
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens, attrLens)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
(NodeDef, attr, input, op, name)
|
(NodeDef, attr, input, op, name)
|
||||||
|
|
||||||
type GradientCompatible a =
|
type GradientCompatible a =
|
||||||
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
|
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
|
||||||
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
|
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a, TensorProtoLens a)
|
||||||
|
|
||||||
-- TODO(fmayle): Support control flow.
|
-- TODO(fmayle): Support control flow.
|
||||||
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
|
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
|
||||||
|
|
|
@ -138,6 +138,7 @@ import qualified Prelude (abs)
|
||||||
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
|
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
|
||||||
-- "1".
|
-- "1".
|
||||||
instance ( TensorType a
|
instance ( TensorType a
|
||||||
|
, TensorProtoLens a
|
||||||
, Num a
|
, Num a
|
||||||
, v ~ Value
|
, v ~ Value
|
||||||
, OneOf '[ Double, Float, Int32, Int64
|
, OneOf '[ Double, Float, Int32, Int64
|
||||||
|
@ -191,7 +192,7 @@ initializedVariable initializer = do
|
||||||
|
|
||||||
-- | Creates a zero-initialized variable with the given shape.
|
-- | Creates a zero-initialized variable with the given shape.
|
||||||
zeroInitializedVariable
|
zeroInitializedVariable
|
||||||
:: (TensorType a, Num a) =>
|
:: (TensorType a, TensorProtoLens a, Num a) =>
|
||||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||||
zeroInitializedVariable = initializedVariable . zeros
|
zeroInitializedVariable = initializedVariable . zeros
|
||||||
|
|
||||||
|
@ -238,7 +239,8 @@ restore path x = do
|
||||||
-- element 0: index (0, ..., 0)
|
-- element 0: index (0, ..., 0)
|
||||||
-- element 1: index (0, ..., 1)
|
-- element 1: index (0, ..., 1)
|
||||||
-- ...
|
-- ...
|
||||||
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
constant :: forall a . (TensorType a, TensorProtoLens a)
|
||||||
|
=> Shape -> [a] -> Tensor Value a
|
||||||
constant (Shape shape') values
|
constant (Shape shape') values
|
||||||
| invalidLength = error invalidLengthMsg
|
| invalidLength = error invalidLengthMsg
|
||||||
| otherwise = buildOp $ opDef "Const"
|
| otherwise = buildOp $ opDef "Const"
|
||||||
|
@ -258,11 +260,11 @@ constant (Shape shape') values
|
||||||
& tensorVal .~ values
|
& tensorVal .~ values
|
||||||
|
|
||||||
-- | Create a constant vector.
|
-- | Create a constant vector.
|
||||||
vector :: TensorType a => [a] -> Tensor Value a
|
vector :: (TensorType a, TensorProtoLens a) => [a] -> Tensor Value a
|
||||||
vector xs = constant [fromIntegral $ length xs] xs
|
vector xs = constant [fromIntegral $ length xs] xs
|
||||||
|
|
||||||
-- | Create a constant scalar.
|
-- | Create a constant scalar.
|
||||||
scalar :: forall a . TensorType a => a -> Tensor Value a
|
scalar :: (TensorType a, TensorProtoLens a) => a -> Tensor Value a
|
||||||
scalar x = constant [] [x]
|
scalar x = constant [] [x]
|
||||||
|
|
||||||
-- Random tensor from the unit normal distribution with bounded values.
|
-- Random tensor from the unit normal distribution with bounded values.
|
||||||
|
@ -273,7 +275,8 @@ truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
& opAttr "T" .~ tensorType (undefined :: Int64)
|
& opAttr "T" .~ tensorType (undefined :: Int64)
|
||||||
|
|
||||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
zeros :: (Num a, TensorType a, TensorProtoLens a)
|
||||||
|
=> Shape -> Tensor Value a
|
||||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||||
|
|
||||||
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
||||||
|
|
|
@ -31,8 +31,9 @@ import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||||
-- back.
|
-- back.
|
||||||
testDynamicPartitionStitchInverse :: forall a.
|
testDynamicPartitionStitchInverse ::
|
||||||
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
|
forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
|
||||||
|
=> StitchExample a -> Property
|
||||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||||
let splitParts :: [TF.Tensor TF.Value a] =
|
let splitParts :: [TF.Tensor TF.Value a] =
|
||||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||||
|
|
|
@ -36,7 +36,8 @@ import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
-- Verifies that direct gather is the same as dynamic split into
|
-- Verifies that direct gather is the same as dynamic split into
|
||||||
-- partitions, followed by embedding lookup.
|
-- partitions, followed by embedding lookup.
|
||||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
testEmbeddingLookupUndoesSplit ::
|
||||||
|
forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
|
||||||
=> LookupExample a -> Property
|
=> LookupExample a -> Property
|
||||||
testEmbeddingLookupUndoesSplit
|
testEmbeddingLookupUndoesSplit
|
||||||
(LookupExample numParts
|
(LookupExample numParts
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
|
|
||||||
module TensorFlow.Types
|
module TensorFlow.Types
|
||||||
( TensorType(..)
|
( TensorType(..)
|
||||||
|
, TensorProtoLens(..)
|
||||||
, TensorData(..)
|
, TensorData(..)
|
||||||
, Shape(..)
|
, Shape(..)
|
||||||
, protoShape
|
, protoShape
|
||||||
|
@ -76,13 +77,12 @@ import Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||||
( TensorProto(..)
|
( TensorProto(..)
|
||||||
, floatVal
|
, boolVal
|
||||||
, doubleVal
|
, doubleVal
|
||||||
|
, floatVal
|
||||||
|
, int64Val
|
||||||
, intVal
|
, intVal
|
||||||
, stringVal
|
, stringVal
|
||||||
, int64Val
|
|
||||||
, stringVal
|
|
||||||
, boolVal
|
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
( TensorShapeProto(..)
|
( TensorShapeProto(..)
|
||||||
|
@ -101,7 +101,6 @@ newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||||
class TensorType a where
|
class TensorType a where
|
||||||
tensorType :: a -> DataType
|
tensorType :: a -> DataType
|
||||||
tensorRefType :: a -> DataType
|
tensorRefType :: a -> DataType
|
||||||
tensorVal :: Lens' TensorProto [a]
|
|
||||||
-- | Decode the bytes of a TensorData into a Vector.
|
-- | Decode the bytes of a TensorData into a Vector.
|
||||||
decodeTensorData :: TensorData a -> V.Vector a
|
decodeTensorData :: TensorData a -> V.Vector a
|
||||||
-- | Encode a Vector into a TensorData.
|
-- | Encode a Vector into a TensorData.
|
||||||
|
@ -113,6 +112,25 @@ class TensorType a where
|
||||||
-- ...
|
-- ...
|
||||||
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
||||||
|
|
||||||
|
-- | Class of types that can be used for constructing constant tensors.
|
||||||
|
class TensorProtoLens a where
|
||||||
|
tensorVal :: Lens' TensorProto [a]
|
||||||
|
|
||||||
|
instance TensorProtoLens Float where
|
||||||
|
tensorVal = floatVal
|
||||||
|
|
||||||
|
instance TensorProtoLens Double where
|
||||||
|
tensorVal = doubleVal
|
||||||
|
|
||||||
|
instance TensorProtoLens Int32 where
|
||||||
|
tensorVal = intVal
|
||||||
|
|
||||||
|
instance TensorProtoLens Int64 where
|
||||||
|
tensorVal = int64Val
|
||||||
|
|
||||||
|
instance TensorProtoLens ByteString where
|
||||||
|
tensorVal = stringVal
|
||||||
|
|
||||||
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
||||||
-- Vector.Storable to encode/decode by type casting pointers.
|
-- Vector.Storable to encode/decode by type casting pointers.
|
||||||
|
|
||||||
|
@ -130,28 +148,24 @@ simpleEncode (Shape xs)
|
||||||
instance TensorType Float where
|
instance TensorType Float where
|
||||||
tensorType _ = DT_FLOAT
|
tensorType _ = DT_FLOAT
|
||||||
tensorRefType _ = DT_FLOAT_REF
|
tensorRefType _ = DT_FLOAT_REF
|
||||||
tensorVal = floatVal
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType Double where
|
instance TensorType Double where
|
||||||
tensorType _ = DT_DOUBLE
|
tensorType _ = DT_DOUBLE
|
||||||
tensorRefType _ = DT_DOUBLE_REF
|
tensorRefType _ = DT_DOUBLE_REF
|
||||||
tensorVal = doubleVal
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType Int32 where
|
instance TensorType Int32 where
|
||||||
tensorType _ = DT_INT32
|
tensorType _ = DT_INT32
|
||||||
tensorRefType _ = DT_INT32_REF
|
tensorRefType _ = DT_INT32_REF
|
||||||
tensorVal = intVal
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType Int64 where
|
instance TensorType Int64 where
|
||||||
tensorType _ = DT_INT64
|
tensorType _ = DT_INT64
|
||||||
tensorRefType _ = DT_INT64_REF
|
tensorRefType _ = DT_INT64_REF
|
||||||
tensorVal = int64Val
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
|
@ -161,40 +175,36 @@ integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
||||||
instance TensorType Word8 where
|
instance TensorType Word8 where
|
||||||
tensorType _ = DT_UINT8
|
tensorType _ = DT_UINT8
|
||||||
tensorRefType _ = DT_UINT8_REF
|
tensorRefType _ = DT_UINT8_REF
|
||||||
tensorVal = intVal . integral
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType Word16 where
|
instance TensorType Word16 where
|
||||||
tensorType _ = DT_UINT16
|
tensorType _ = DT_UINT16
|
||||||
tensorRefType _ = DT_UINT16_REF
|
tensorRefType _ = DT_UINT16_REF
|
||||||
tensorVal = intVal . integral
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType Int16 where
|
instance TensorType Int16 where
|
||||||
tensorType _ = DT_INT16
|
tensorType _ = DT_INT16
|
||||||
tensorRefType _ = DT_INT16_REF
|
tensorRefType _ = DT_INT16_REF
|
||||||
tensorVal = intVal . integral
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType Int8 where
|
instance TensorType Int8 where
|
||||||
tensorType _ = DT_INT8
|
tensorType _ = DT_INT8
|
||||||
tensorRefType _ = DT_INT8_REF
|
tensorRefType _ = DT_INT8_REF
|
||||||
tensorVal = intVal . integral
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType ByteString where
|
instance TensorType ByteString where
|
||||||
tensorType _ = DT_STRING
|
tensorType _ = DT_STRING
|
||||||
tensorRefType _ = DT_STRING_REF
|
tensorRefType _ = DT_STRING_REF
|
||||||
tensorVal = stringVal
|
|
||||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||||
-- table offsets for each element :: [Word64]
|
-- table offsets for each element :: [Word64]
|
||||||
-- at each element offset:
|
-- at each element offset:
|
||||||
-- string length :: VarInt64
|
-- string length :: VarInt64
|
||||||
-- string data :: [Word8]
|
-- string data :: [Word8]
|
||||||
|
-- C++ counterparts of these are TF_Tensor_{En,De}codeStrings.
|
||||||
-- TODO(fmayle): Benchmark these functions.
|
-- TODO(fmayle): Benchmark these functions.
|
||||||
decodeTensorData tensorData =
|
decodeTensorData tensorData =
|
||||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||||
|
@ -244,21 +254,18 @@ instance TensorType ByteString where
|
||||||
instance TensorType Bool where
|
instance TensorType Bool where
|
||||||
tensorType _ = DT_BOOL
|
tensorType _ = DT_BOOL
|
||||||
tensorRefType _ = DT_BOOL_REF
|
tensorRefType _ = DT_BOOL_REF
|
||||||
tensorVal = boolVal
|
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData = simpleDecode
|
||||||
encodeTensorData = simpleEncode
|
encodeTensorData = simpleEncode
|
||||||
|
|
||||||
instance TensorType (Complex Float) where
|
instance TensorType (Complex Float) where
|
||||||
tensorType _ = DT_COMPLEX64
|
tensorType _ = DT_COMPLEX64
|
||||||
tensorRefType _ = DT_COMPLEX64
|
tensorRefType _ = DT_COMPLEX64
|
||||||
tensorVal = error "TODO (Complex Float)"
|
|
||||||
decodeTensorData = error "TODO (Complex Float)"
|
decodeTensorData = error "TODO (Complex Float)"
|
||||||
encodeTensorData = error "TODO (Complex Float)"
|
encodeTensorData = error "TODO (Complex Float)"
|
||||||
|
|
||||||
instance TensorType (Complex Double) where
|
instance TensorType (Complex Double) where
|
||||||
tensorType _ = DT_COMPLEX128
|
tensorType _ = DT_COMPLEX128
|
||||||
tensorRefType _ = DT_COMPLEX128
|
tensorRefType _ = DT_COMPLEX128
|
||||||
tensorVal = error "TODO (Complex Double)"
|
|
||||||
decodeTensorData = error "TODO (Complex Double)"
|
decodeTensorData = error "TODO (Complex Double)"
|
||||||
encodeTensorData = error "TODO (Complex Double)"
|
encodeTensorData = error "TODO (Complex Double)"
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue