mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 19:39:49 +01:00
Create a separate TensorProtoLens class.
Only a handful of types had sensible tensorVal implementations. This is now evident in type signatures at the expense of them being more verbose.
This commit is contained in:
parent
29f11d351d
commit
e511f49828
7 changed files with 44 additions and 30 deletions
|
@ -35,6 +35,7 @@ import TensorFlow.Tensor ( Tensor(..)
|
|||
, Value(..)
|
||||
)
|
||||
import TensorFlow.Types ( TensorType(..)
|
||||
, TensorProtoLens
|
||||
, OneOf
|
||||
)
|
||||
import TensorFlow.Ops ( zerosLike
|
||||
|
@ -70,7 +71,7 @@ import TensorFlow.Ops ( zerosLike
|
|||
--
|
||||
-- `logits` and `targets` must have the same type and shape.
|
||||
sigmoidCrossEntropyWithLogits
|
||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
:: (OneOf '[Float, Double] a, TensorType a, TensorProtoLens a, Num a)
|
||||
=> Tensor Value a -- ^ __logits__
|
||||
-> Tensor Value a -- ^ __targets__
|
||||
-> Build (Tensor Value a)
|
||||
|
|
|
@ -27,7 +27,7 @@ import Data.List (genericLength)
|
|||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Ops () -- Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Looks up `ids` in a list of embedding tensors.
|
||||
|
@ -46,6 +46,7 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
|||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup :: forall a b v .
|
||||
( TensorType a
|
||||
, TensorProtoLens b
|
||||
, OneOf '[Int64, Int32] b
|
||||
, Num b
|
||||
)
|
||||
|
|
|
@ -94,13 +94,13 @@ import TensorFlow.Tensor
|
|||
, tensorOutput
|
||||
, tensorAttr
|
||||
)
|
||||
import TensorFlow.Types (OneOf, TensorType, attrLens)
|
||||
import TensorFlow.Types (OneOf, TensorType, TensorProtoLens, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
(NodeDef, attr, input, op, name)
|
||||
|
||||
type GradientCompatible a =
|
||||
-- TODO(fmayle): MaxPoolGrad doesn't support Double for some reason.
|
||||
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a)
|
||||
(Num a, OneOf '[ Float, Complex Float, Complex Double ] a, TensorProtoLens a)
|
||||
|
||||
-- TODO(fmayle): Support control flow.
|
||||
-- TODO(fmayle): Support gate_gradients-like option to avoid race conditions.
|
||||
|
|
|
@ -138,6 +138,7 @@ import qualified Prelude (abs)
|
|||
-- "neg 1 :: Tensor Value Float", it helps find the type of the subexpression
|
||||
-- "1".
|
||||
instance ( TensorType a
|
||||
, TensorProtoLens a
|
||||
, Num a
|
||||
, v ~ Value
|
||||
, OneOf '[ Double, Float, Int32, Int64
|
||||
|
@ -191,7 +192,7 @@ initializedVariable initializer = do
|
|||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
:: (TensorType a, Num a) =>
|
||||
:: (TensorType a, TensorProtoLens a, Num a) =>
|
||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable = initializedVariable . zeros
|
||||
|
||||
|
@ -238,7 +239,8 @@ restore path x = do
|
|||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant :: forall a . (TensorType a, TensorProtoLens a)
|
||||
=> Shape -> [a] -> Tensor Value a
|
||||
constant (Shape shape') values
|
||||
| invalidLength = error invalidLengthMsg
|
||||
| otherwise = buildOp $ opDef "Const"
|
||||
|
@ -258,11 +260,11 @@ constant (Shape shape') values
|
|||
& tensorVal .~ values
|
||||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Value a
|
||||
vector :: (TensorType a, TensorProtoLens a) => [a] -> Tensor Value a
|
||||
vector xs = constant [fromIntegral $ length xs] xs
|
||||
|
||||
-- | Create a constant scalar.
|
||||
scalar :: forall a . TensorType a => a -> Tensor Value a
|
||||
scalar :: (TensorType a, TensorProtoLens a) => a -> Tensor Value a
|
||||
scalar x = constant [] [x]
|
||||
|
||||
-- Random tensor from the unit normal distribution with bounded values.
|
||||
|
@ -273,7 +275,8 @@ truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
|||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "T" .~ tensorType (undefined :: Int64)
|
||||
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||
zeros :: (Num a, TensorType a, TensorProtoLens a)
|
||||
=> Shape -> Tensor Value a
|
||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||
|
||||
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
||||
|
|
|
@ -31,8 +31,9 @@ import qualified TensorFlow.Types as TF
|
|||
|
||||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||
-- back.
|
||||
testDynamicPartitionStitchInverse :: forall a.
|
||||
(TF.TensorType a, Show a, Eq a) => StitchExample a -> Property
|
||||
testDynamicPartitionStitchInverse ::
|
||||
forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
|
||||
=> StitchExample a -> Property
|
||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||
let splitParts :: [TF.Tensor TF.Value a] =
|
||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||
|
|
|
@ -36,8 +36,9 @@ import qualified TensorFlow.Types as TF
|
|||
|
||||
-- Verifies that direct gather is the same as dynamic split into
|
||||
-- partitions, followed by embedding lookup.
|
||||
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
|
||||
=> LookupExample a -> Property
|
||||
testEmbeddingLookupUndoesSplit ::
|
||||
forall a. (TF.TensorType a, TF.TensorProtoLens a, Show a, Eq a)
|
||||
=> LookupExample a -> Property
|
||||
testEmbeddingLookupUndoesSplit
|
||||
(LookupExample numParts
|
||||
shape@(TF.Shape (firstDim : restDims))
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
|
||||
module TensorFlow.Types
|
||||
( TensorType(..)
|
||||
, TensorProtoLens(..)
|
||||
, TensorData(..)
|
||||
, Shape(..)
|
||||
, protoShape
|
||||
|
@ -76,13 +77,12 @@ import Proto.Tensorflow.Core.Framework.AttrValue
|
|||
)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||
( TensorProto(..)
|
||||
, floatVal
|
||||
, boolVal
|
||||
, doubleVal
|
||||
, floatVal
|
||||
, int64Val
|
||||
, intVal
|
||||
, stringVal
|
||||
, int64Val
|
||||
, stringVal
|
||||
, boolVal
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||
( TensorShapeProto(..)
|
||||
|
@ -101,7 +101,6 @@ newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
|||
class TensorType a where
|
||||
tensorType :: a -> DataType
|
||||
tensorRefType :: a -> DataType
|
||||
tensorVal :: Lens' TensorProto [a]
|
||||
-- | Decode the bytes of a TensorData into a Vector.
|
||||
decodeTensorData :: TensorData a -> V.Vector a
|
||||
-- | Encode a Vector into a TensorData.
|
||||
|
@ -113,6 +112,25 @@ class TensorType a where
|
|||
-- ...
|
||||
encodeTensorData :: Shape -> V.Vector a -> TensorData a
|
||||
|
||||
-- | Class of types that can be used for constructing constant tensors.
|
||||
class TensorProtoLens a where
|
||||
tensorVal :: Lens' TensorProto [a]
|
||||
|
||||
instance TensorProtoLens Float where
|
||||
tensorVal = floatVal
|
||||
|
||||
instance TensorProtoLens Double where
|
||||
tensorVal = doubleVal
|
||||
|
||||
instance TensorProtoLens Int32 where
|
||||
tensorVal = intVal
|
||||
|
||||
instance TensorProtoLens Int64 where
|
||||
tensorVal = int64Val
|
||||
|
||||
instance TensorProtoLens ByteString where
|
||||
tensorVal = stringVal
|
||||
|
||||
-- All types, besides ByteString, are encoded as simple arrays and we can use
|
||||
-- Vector.Storable to encode/decode by type casting pointers.
|
||||
|
||||
|
@ -130,28 +148,24 @@ simpleEncode (Shape xs)
|
|||
instance TensorType Float where
|
||||
tensorType _ = DT_FLOAT
|
||||
tensorRefType _ = DT_FLOAT_REF
|
||||
tensorVal = floatVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Double where
|
||||
tensorType _ = DT_DOUBLE
|
||||
tensorRefType _ = DT_DOUBLE_REF
|
||||
tensorVal = doubleVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int32 where
|
||||
tensorType _ = DT_INT32
|
||||
tensorRefType _ = DT_INT32_REF
|
||||
tensorVal = intVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int64 where
|
||||
tensorType _ = DT_INT64
|
||||
tensorRefType _ = DT_INT64_REF
|
||||
tensorVal = int64Val
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
|
@ -161,40 +175,36 @@ integral = iso (fmap fromIntegral) (fmap fromIntegral)
|
|||
instance TensorType Word8 where
|
||||
tensorType _ = DT_UINT8
|
||||
tensorRefType _ = DT_UINT8_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Word16 where
|
||||
tensorType _ = DT_UINT16
|
||||
tensorRefType _ = DT_UINT16_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int16 where
|
||||
tensorType _ = DT_INT16
|
||||
tensorRefType _ = DT_INT16_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType Int8 where
|
||||
tensorType _ = DT_INT8
|
||||
tensorRefType _ = DT_INT8_REF
|
||||
tensorVal = intVal . integral
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType ByteString where
|
||||
tensorType _ = DT_STRING
|
||||
tensorRefType _ = DT_STRING_REF
|
||||
tensorVal = stringVal
|
||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||
-- table offsets for each element :: [Word64]
|
||||
-- at each element offset:
|
||||
-- string length :: VarInt64
|
||||
-- string data :: [Word8]
|
||||
-- C++ counterparts of these are TF_Tensor_{En,De}codeStrings.
|
||||
-- TODO(fmayle): Benchmark these functions.
|
||||
decodeTensorData tensorData =
|
||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||
|
@ -244,21 +254,18 @@ instance TensorType ByteString where
|
|||
instance TensorType Bool where
|
||||
tensorType _ = DT_BOOL
|
||||
tensorRefType _ = DT_BOOL_REF
|
||||
tensorVal = boolVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorType (Complex Float) where
|
||||
tensorType _ = DT_COMPLEX64
|
||||
tensorRefType _ = DT_COMPLEX64
|
||||
tensorVal = error "TODO (Complex Float)"
|
||||
decodeTensorData = error "TODO (Complex Float)"
|
||||
encodeTensorData = error "TODO (Complex Float)"
|
||||
|
||||
instance TensorType (Complex Double) where
|
||||
tensorType _ = DT_COMPLEX128
|
||||
tensorRefType _ = DT_COMPLEX128
|
||||
tensorVal = error "TODO (Complex Double)"
|
||||
decodeTensorData = error "TODO (Complex Double)"
|
||||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
|
|
Loading…
Reference in a new issue