mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add versions of each op that take optional params as an extra arg. (#84)
Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...` which lets you set optional attributes. `OpParams` is currently a type alias for `OpDef -> OpDef`. In the future we should consider more type safety, e.g., using type-level strings and OverloadedLabels for optional attributes. I used it to replace a few manual `buildOp`s in our code with the codegenerated ops, now that it's easier to set attributes. I also removed `tensorAttr` and `named` since it's now possible to set those op attributes directly. Although this clutters up the API a bit, I think it's simpler than using type classes to implement optional arguments (as in, for example, `Text.Printf`) -- especially in terms of type inference with the rest of the library.
This commit is contained in:
parent
2c5c879037
commit
c99a23b6a7
11 changed files with 190 additions and 152 deletions
|
@ -12,6 +12,7 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE CPP #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
@ -172,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName
|
||||||
renderOp :: ParsedOp -> Doc
|
renderOp :: ParsedOp -> Doc
|
||||||
renderOp pOp = stack $
|
renderOp pOp = stack $
|
||||||
[ haddocks
|
[ haddocks
|
||||||
, n <+> "::" <+> hang 0 (typeSig pOp)
|
-- Prevent unreasonably long compilation times on ghc-7.10, due
|
||||||
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
-- to stack calling "-dump-hi" which (unnecessarily) includes the
|
||||||
|
-- inlining information, and is large for ops with many arguments.
|
||||||
|
#if __GLASGOW_HASKELL__ < 800
|
||||||
|
, "{-# NOINLINE " <> n <> "#-}"
|
||||||
|
#endif
|
||||||
|
, n <+> "::" <+> hang 0 (typeSig empty pOp)
|
||||||
|
, n <+> "=" <+> n <> "' id"
|
||||||
|
, n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
|
||||||
|
, n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||||
<+> "=" </> -- args are indented
|
<+> "=" </> -- args are indented
|
||||||
-- the body needs to be indented wrt the name
|
-- the body needs to be indented wrt the name
|
||||||
indent indentation (functionBody pOp)
|
indent indentation (functionBody pOp)
|
||||||
] ++ whereClause listSizeAttrs
|
] ++ whereClause listSizeAttrs
|
||||||
where
|
where
|
||||||
n = renderHaskellName $ parsedOpName pOp
|
n = renderHaskellName $ parsedOpName pOp
|
||||||
|
n' = n <> "'"
|
||||||
listSizeAttrs = inferredListSizeAttrs pOp
|
listSizeAttrs = inferredListSizeAttrs pOp
|
||||||
args = sep $ map renderHaskellName
|
args = sep $ "op'options"
|
||||||
|
: (map renderHaskellName
|
||||||
$ map attrName (explicitInputAttrs pOp)
|
$ map attrName (explicitInputAttrs pOp)
|
||||||
++ map parsedArgName (parsedInputs pOp)
|
++ map parsedArgName (parsedInputs pOp))
|
||||||
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
||||||
|
|
||||||
-- | A check that all lists of the given size have the given length.
|
-- | A check that all lists of the given size have the given length.
|
||||||
|
@ -247,7 +258,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
|
||||||
-- Renders sizes of tensor list types having number_attr.
|
-- Renders sizes of tensor list types having number_attr.
|
||||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||||
]
|
] ++
|
||||||
|
["& op'options"]
|
||||||
|
|
||||||
|
|
||||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||||
inferredTypeExpr a
|
inferredTypeExpr a
|
||||||
|
@ -270,12 +283,12 @@ extras d = enclose "{-\n" "\n-}" $
|
||||||
-- | The type signature for an op.
|
-- | The type signature for an op.
|
||||||
-- Of the form:
|
-- Of the form:
|
||||||
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
|
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
|
||||||
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
|
-- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||||
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
||||||
-- "Tensor t2 v2" is an output.
|
-- "Tensor t2 v2" is an output.
|
||||||
typeSig :: ParsedOp -> Doc
|
typeSig :: Doc -> ParsedOp -> Doc
|
||||||
typeSig pOp = constraints
|
typeSig pre pOp = constraints
|
||||||
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
|
<+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||||
++ map tensorArgAndComment (parsedInputs pOp)
|
++ map tensorArgAndComment (parsedInputs pOp)
|
||||||
++ [outputs])
|
++ [outputs])
|
||||||
where
|
where
|
||||||
|
|
|
@ -72,6 +72,7 @@ import TensorFlow.Ops
|
||||||
, expandDims
|
, expandDims
|
||||||
, fill
|
, fill
|
||||||
, matMul
|
, matMul
|
||||||
|
, matMul'
|
||||||
, reducedShape
|
, reducedShape
|
||||||
, reluGrad
|
, reluGrad
|
||||||
, reshape
|
, reshape
|
||||||
|
@ -95,7 +96,6 @@ import TensorFlow.Tensor
|
||||||
, TensorKind (ValueKind)
|
, TensorKind (ValueKind)
|
||||||
, Value
|
, Value
|
||||||
, tensorOutput
|
, tensorOutput
|
||||||
, tensorAttr
|
|
||||||
)
|
)
|
||||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
@ -532,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
let transposeA = lookupAttr nodeDef "transpose_a"
|
let transposeA = lookupAttr nodeDef "transpose_a"
|
||||||
transposeB = lookupAttr nodeDef "transpose_b"
|
transposeB = lookupAttr nodeDef "transpose_b"
|
||||||
transAttrs a b =
|
transAttrs a b =
|
||||||
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
|
(opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
|
||||||
in case (transposeA, transposeB) of
|
in case (transposeA, transposeB) of
|
||||||
(False, False) ->
|
(False, False) ->
|
||||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
[ Just $ matMul' (transAttrs False True) dz y
|
||||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
, Just $ matMul' (transAttrs True False) x dz]
|
||||||
(False, True) ->
|
(False, True) ->
|
||||||
[ Just $ dz `matMul` y
|
[ Just $ matMul dz y
|
||||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
, Just $ matMul' (transAttrs True False) x dz]
|
||||||
(True, False) ->
|
(True, False) ->
|
||||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
[ Just $ matMul' (transAttrs False True) dz y
|
||||||
, Just $ x `matMul` dz ]
|
, Just $ matMul x dz]
|
||||||
(True, True) ->
|
(True, True) ->
|
||||||
[ Just $ (dz `matMul` y) & transAttrs True True
|
[ Just $ matMul' (transAttrs True True) dz y
|
||||||
, Just $ (x `matMul` dz) & transAttrs True True ]
|
, Just $ matMul' (transAttrs True True) x dz]
|
||||||
|
|
||||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
[ Just $ CoreOps.transpose dz
|
[ Just $ CoreOps.transpose dz
|
||||||
|
@ -554,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
]
|
]
|
||||||
|
|
||||||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
|
[ Just $ CoreOps.conv2DBackpropInput'
|
||||||
& tensorAttr "strides" .~ strides
|
((opAttr "strides" .~ strides)
|
||||||
& tensorAttr "padding" .~ padding
|
. (opAttr "padding" .~ padding)
|
||||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
& tensorAttr "data_format" .~ dataFormat
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
|
(shape x) y dz
|
||||||
& tensorAttr "strides" .~ strides
|
, Just $ CoreOps.conv2DBackpropFilter'
|
||||||
& tensorAttr "padding" .~ padding
|
((opAttr "strides" .~ strides)
|
||||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
. (opAttr "padding" .~ padding)
|
||||||
& tensorAttr "data_format" .~ dataFormat
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
|
x (shape y) dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -572,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
[ Just $ CoreOps.maxPoolGrad x output dz
|
[ Just $ CoreOps.maxPoolGrad'
|
||||||
& tensorAttr "ksize" .~ ksize
|
((opAttr "ksize" .~ ksize)
|
||||||
& tensorAttr "strides" .~ strides
|
. (opAttr "strides" .~ strides)
|
||||||
& tensorAttr "padding" .~ padding
|
. (opAttr "padding" .~ padding)
|
||||||
& tensorAttr "data_format" .~ dataFormat
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
|
x output dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
output :: Tensor Value a
|
output :: Tensor Value a
|
||||||
|
|
|
@ -58,56 +58,99 @@
|
||||||
|
|
||||||
module TensorFlow.Ops
|
module TensorFlow.Ops
|
||||||
( CoreOps.add
|
( CoreOps.add
|
||||||
|
, CoreOps.add'
|
||||||
, CoreOps.abs
|
, CoreOps.abs
|
||||||
|
, CoreOps.abs'
|
||||||
, CoreOps.addN
|
, CoreOps.addN
|
||||||
|
, CoreOps.addN'
|
||||||
, CoreOps.argMax
|
, CoreOps.argMax
|
||||||
|
, CoreOps.argMax'
|
||||||
, CoreOps.assign
|
, CoreOps.assign
|
||||||
|
, CoreOps.assign'
|
||||||
, CoreOps.broadcastGradientArgs
|
, CoreOps.broadcastGradientArgs
|
||||||
|
, CoreOps.broadcastGradientArgs'
|
||||||
, CoreOps.cast
|
, CoreOps.cast
|
||||||
|
, CoreOps.cast'
|
||||||
, CoreOps.concat
|
, CoreOps.concat
|
||||||
|
, CoreOps.concat'
|
||||||
, constant
|
, constant
|
||||||
|
, constant'
|
||||||
, CoreOps.equal
|
, CoreOps.equal
|
||||||
|
, CoreOps.equal'
|
||||||
, expandDims
|
, expandDims
|
||||||
|
, expandDims'
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
|
, initializedVariable'
|
||||||
, zeroInitializedVariable
|
, zeroInitializedVariable
|
||||||
|
, zeroInitializedVariable'
|
||||||
, CoreOps.fill
|
, CoreOps.fill
|
||||||
, CoreOps.oneHot
|
, CoreOps.fill'
|
||||||
|
, CoreOps.identity
|
||||||
|
, CoreOps.identity'
|
||||||
, CoreOps.matMul
|
, CoreOps.matMul
|
||||||
|
, CoreOps.matMul'
|
||||||
, matTranspose
|
, matTranspose
|
||||||
|
, matTranspose'
|
||||||
, CoreOps.mean
|
, CoreOps.mean
|
||||||
|
, CoreOps.mean'
|
||||||
, CoreOps.mul
|
, CoreOps.mul
|
||||||
|
, CoreOps.mul'
|
||||||
, CoreOps.neg
|
, CoreOps.neg
|
||||||
|
, CoreOps.neg'
|
||||||
|
, CoreOps.oneHot
|
||||||
|
, CoreOps.oneHot'
|
||||||
, CoreOps.pack
|
, CoreOps.pack
|
||||||
|
, CoreOps.pack'
|
||||||
, placeholder
|
, placeholder
|
||||||
|
, placeholder'
|
||||||
, CoreOps.range
|
, CoreOps.range
|
||||||
|
, CoreOps.range'
|
||||||
, reducedShape
|
, reducedShape
|
||||||
, CoreOps.relu
|
, CoreOps.relu
|
||||||
|
, CoreOps.relu'
|
||||||
, CoreOps.reluGrad
|
, CoreOps.reluGrad
|
||||||
|
, CoreOps.reluGrad'
|
||||||
, CoreOps.reshape
|
, CoreOps.reshape
|
||||||
|
, CoreOps.reshape'
|
||||||
, restore
|
, restore
|
||||||
, restoreFromName
|
, restoreFromName
|
||||||
, save
|
, save
|
||||||
, scalar
|
, scalar
|
||||||
|
, scalar'
|
||||||
, shape
|
, shape
|
||||||
|
, shape'
|
||||||
, CoreOps.sign
|
, CoreOps.sign
|
||||||
|
, CoreOps.sign'
|
||||||
, CoreOps.size
|
, CoreOps.size
|
||||||
|
, CoreOps.size'
|
||||||
, CoreOps.softmax
|
, CoreOps.softmax
|
||||||
|
, CoreOps.softmax'
|
||||||
, CoreOps.softmaxCrossEntropyWithLogits
|
, CoreOps.softmaxCrossEntropyWithLogits
|
||||||
|
, CoreOps.softmaxCrossEntropyWithLogits'
|
||||||
, CoreOps.sparseToDense
|
, CoreOps.sparseToDense
|
||||||
|
, CoreOps.sparseToDense'
|
||||||
, CoreOps.sub
|
, CoreOps.sub
|
||||||
|
, CoreOps.sub'
|
||||||
, CoreOps.sum
|
, CoreOps.sum
|
||||||
|
, CoreOps.sum'
|
||||||
, CoreOps.transpose
|
, CoreOps.transpose
|
||||||
|
, CoreOps.transpose'
|
||||||
, truncatedNormal
|
, truncatedNormal
|
||||||
|
, truncatedNormal'
|
||||||
, CoreOps.variable
|
, CoreOps.variable
|
||||||
|
, CoreOps.variable'
|
||||||
, vector
|
, vector
|
||||||
|
, vector'
|
||||||
, zeros
|
, zeros
|
||||||
, CoreOps.zerosLike
|
, CoreOps.zerosLike
|
||||||
|
, CoreOps.zerosLike'
|
||||||
, scalarize
|
, scalarize
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.Word (Word16)
|
||||||
import Prelude hiding (abs, sum, concat)
|
import Prelude hiding (abs, sum, concat)
|
||||||
import Data.ProtoLens (def)
|
import Data.ProtoLens (def)
|
||||||
import Data.Text.Encoding (encodeUtf8)
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
|
@ -151,29 +194,31 @@ instance ( TensorType a
|
||||||
signum = CoreOps.sign
|
signum = CoreOps.sign
|
||||||
negate = CoreOps.neg
|
negate = CoreOps.neg
|
||||||
|
|
||||||
matTranspose :: forall a v . TensorType a
|
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
|
||||||
=> Tensor v a -> Tensor Value a
|
matTranspose = matTranspose' id
|
||||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
|
||||||
|
|
||||||
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
|
||||||
placeholder shape' =
|
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
||||||
build $ buildOp $ opDef "Placeholder"
|
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||||
& opAttr "shape" .~ shape'
|
placeholder = placeholder' id
|
||||||
|
|
||||||
|
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a)
|
||||||
|
placeholder' params pShape
|
||||||
|
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape))
|
||||||
|
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
initializedVariable :: forall a m . (MonadBuild m, TensorType a)
|
initializedVariable :: (MonadBuild m, TensorType a)
|
||||||
=> Tensor Value a -> m (Tensor Ref a)
|
=> Tensor Value a -> m (Tensor Ref a)
|
||||||
initializedVariable initializer = do
|
initializedVariable = initializedVariable' id
|
||||||
v <- CoreOps.variable [] -- The shape is not known initially.
|
|
||||||
(i :: Tensor Ref a) <-
|
initializedVariable' :: (MonadBuild m, TensorType a)
|
||||||
build $ buildOp (opDef "Assign"
|
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
|
||||||
& opAttr "T" .~ tensorType (undefined :: a)
|
initializedVariable' params initializer = do
|
||||||
& opAttr "use_locking" .~ True
|
v <- CoreOps.variable' params [] -- The shape is not known initially.
|
||||||
& opAttr "validate_shape" .~ False
|
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
|
||||||
)
|
initializer
|
||||||
v initializer
|
|
||||||
addInitializer =<< group i
|
addInitializer =<< group i
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@ -181,7 +226,12 @@ initializedVariable initializer = do
|
||||||
zeroInitializedVariable
|
zeroInitializedVariable
|
||||||
:: (MonadBuild m, TensorType a, Num a) =>
|
:: (MonadBuild m, TensorType a, Num a) =>
|
||||||
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||||
zeroInitializedVariable = initializedVariable . zeros
|
zeroInitializedVariable = zeroInitializedVariable' id
|
||||||
|
|
||||||
|
zeroInitializedVariable'
|
||||||
|
:: (MonadBuild m, TensorType a, Num a) =>
|
||||||
|
OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||||
|
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||||
|
|
||||||
-- TODO: Support heterogeneous list of tensors.
|
-- TODO: Support heterogeneous list of tensors.
|
||||||
save :: forall a m v . (MonadBuild m, TensorType a)
|
save :: forall a m v . (MonadBuild m, TensorType a)
|
||||||
|
@ -227,23 +277,23 @@ restore path x = do
|
||||||
-- element 0: index (0, ..., 0)
|
-- element 0: index (0, ..., 0)
|
||||||
-- element 1: index (0, ..., 1)
|
-- element 1: index (0, ..., 1)
|
||||||
-- ...
|
-- ...
|
||||||
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
constant :: TensorType a => Shape -> [a] -> Tensor Value a
|
||||||
constant (Shape shape') values
|
constant = constant' id
|
||||||
|
|
||||||
|
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
|
||||||
|
constant' params (Shape cShape) values
|
||||||
| invalidLength = error invalidLengthMsg
|
| invalidLength = error invalidLengthMsg
|
||||||
| otherwise = buildOp $ opDef "Const"
|
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
|
||||||
& opAttr "value" .~ typedNode
|
|
||||||
& opAttr "dtype" .~ nodeType
|
|
||||||
where
|
where
|
||||||
invalidLength = product shape' /= fromIntegral (length values)
|
invalidLength = product cShape /= fromIntegral (length values)
|
||||||
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
|
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
|
||||||
(product shape')
|
(product cShape)
|
||||||
(length values)
|
(length values)
|
||||||
nodeType = tensorType (undefined :: a)
|
|
||||||
typedNode :: TensorProto
|
typedNode :: TensorProto
|
||||||
typedNode = def
|
typedNode = def
|
||||||
& dtype .~ nodeType
|
& dtype .~ tensorType (undefined :: a)
|
||||||
& tensorShape.TensorShape.dim .~
|
& tensorShape.TensorShape.dim .~
|
||||||
[def & TensorShape.size .~ x | x <- shape']
|
[def & TensorShape.size .~ x | x <- cShape]
|
||||||
& tensorVal .~ values
|
& tensorVal .~ values
|
||||||
|
|
||||||
-- | Reshape a N-D tensor down to a scalar.
|
-- | Reshape a N-D tensor down to a scalar.
|
||||||
|
@ -257,30 +307,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||||
|
|
||||||
-- | Create a constant vector.
|
-- | Create a constant vector.
|
||||||
vector :: TensorType a => [a] -> Tensor Value a
|
vector :: TensorType a => [a] -> Tensor Value a
|
||||||
vector xs = constant [fromIntegral $ length xs] xs
|
vector = vector' id
|
||||||
|
|
||||||
|
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
|
||||||
|
vector' params xs = constant' params [fromIntegral $ length xs] xs
|
||||||
|
|
||||||
-- | Create a constant scalar.
|
-- | Create a constant scalar.
|
||||||
scalar :: forall a . TensorType a => a -> Tensor Value a
|
scalar :: TensorType a => a -> Tensor Value a
|
||||||
scalar x = constant [] [x]
|
scalar = scalar' id
|
||||||
|
|
||||||
-- Random tensor from the unit normal distribution with bounded values.
|
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
|
||||||
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
|
scalar' params x = constant' params [] [x]
|
||||||
|
|
||||||
|
-- | Random tensor from the unit normal distribution with bounded values.
|
||||||
|
--
|
||||||
|
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
|
||||||
|
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||||
=> Tensor v Int64 -- ^ Shape.
|
=> Tensor v Int64 -- ^ Shape.
|
||||||
-> m (Tensor Value a)
|
-> m (Tensor Value a)
|
||||||
truncatedNormal
|
truncatedNormal = CoreOps.truncatedNormal
|
||||||
= build . buildOp (opDef "TruncatedNormal"
|
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||||
& opAttr "T" .~ tensorType (undefined :: Int64))
|
=> OpParams -> Tensor v Int64 -- ^ Shape.
|
||||||
|
-> m (Tensor Value a)
|
||||||
|
truncatedNormal' = CoreOps.truncatedNormal'
|
||||||
|
|
||||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
||||||
|
|
||||||
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||||
shape = CoreOps.shape
|
shape = CoreOps.shape
|
||||||
|
|
||||||
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
|
||||||
|
shape' = CoreOps.shape'
|
||||||
|
|
||||||
|
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||||
expandDims = CoreOps.expandDims
|
expandDims = CoreOps.expandDims
|
||||||
|
|
||||||
|
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||||
|
expandDims' = CoreOps.expandDims'
|
||||||
|
|
||||||
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||||
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Lens.Family2 ((^.))
|
import Lens.Family2 ((^.), (.~))
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
import Proto.Tensorflow.Core.Framework.Graph
|
import Proto.Tensorflow.Core.Framework.Graph
|
||||||
( node )
|
( node )
|
||||||
|
@ -38,8 +38,8 @@ import TensorFlow.Build
|
||||||
, withDevice
|
, withDevice
|
||||||
, colocateWith
|
, colocateWith
|
||||||
, withNameScope
|
, withNameScope
|
||||||
|
, opName
|
||||||
)
|
)
|
||||||
import TensorFlow.ControlFlow (named)
|
|
||||||
import TensorFlow.Types (unScalar)
|
import TensorFlow.Types (unScalar)
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
( add
|
( add
|
||||||
|
@ -47,6 +47,7 @@ import TensorFlow.Ops
|
||||||
, constant
|
, constant
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, variable
|
, variable
|
||||||
|
, variable'
|
||||||
)
|
)
|
||||||
import TensorFlow.Output (Device(..))
|
import TensorFlow.Output (Device(..))
|
||||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||||
|
@ -61,26 +62,16 @@ import Test.HUnit ((@=?))
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
-- | Test named behavior.
|
-- | Test 'opName' behavior.
|
||||||
testNamed :: Test
|
testOpName :: Test
|
||||||
testNamed = testCase "testNamed" $ do
|
testOpName = testCase "testOpName" $ do
|
||||||
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
let graph = variable' (opName .~ "foo") []
|
||||||
|
>>= render :: Build (Tensor Ref Float)
|
||||||
nodeDef :: NodeDef
|
nodeDef :: NodeDef
|
||||||
nodeDef = head $ asGraphDef graph ^. node
|
nodeDef = head $ asGraphDef graph ^. node
|
||||||
"RefIdentity" @=? (nodeDef ^. op)
|
"Variable" @=? (nodeDef ^. op)
|
||||||
"foo" @=? (nodeDef ^. name)
|
"foo" @=? (nodeDef ^. name)
|
||||||
|
|
||||||
-- | Test named deRef behavior.
|
|
||||||
testNamedDeRef :: Test
|
|
||||||
testNamedDeRef = testCase "testNamedDeRef" $ do
|
|
||||||
let graph = named "foo" <$> do
|
|
||||||
v :: Tensor Ref Float <- variable []
|
|
||||||
assign v 5
|
|
||||||
-- TODO: Implement TensorFlow get_variable and test it.
|
|
||||||
runSession $ do
|
|
||||||
out <- graph >>= run
|
|
||||||
liftIO $ 5 @=? (unScalar out :: Float)
|
|
||||||
|
|
||||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||||
-- been rendered.
|
-- been rendered.
|
||||||
testPureRender :: Test
|
testPureRender :: Test
|
||||||
|
@ -118,14 +109,15 @@ testNameScoped = testCase "testNameScoped" $ do
|
||||||
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
|
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
|
||||||
"Variable" @=? (nodeDef ^. op)
|
"Variable" @=? (nodeDef ^. op)
|
||||||
|
|
||||||
-- | Test combined named and nameScoped behavior.
|
-- | Test combined opName and nameScoped behavior.
|
||||||
testNamedAndScoped :: Test
|
testNamedAndScoped :: Test
|
||||||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||||
let graph :: Build (Tensor Ref Float)
|
let graph :: Build (Tensor Ref Float)
|
||||||
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
|
||||||
|
>>= render
|
||||||
nodeDef :: NodeDef
|
nodeDef :: NodeDef
|
||||||
nodeDef = head $ asGraphDef graph ^. node
|
nodeDef = head $ asGraphDef graph ^. node
|
||||||
"RefIdentity" @=? (nodeDef ^. op)
|
"Variable" @=? (nodeDef ^. op)
|
||||||
"foo1/bar1" @=? (nodeDef ^. name)
|
"foo1/bar1" @=? (nodeDef ^. name)
|
||||||
|
|
||||||
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||||
|
@ -174,8 +166,7 @@ main :: IO ()
|
||||||
main = googleTest [ testInitializedVariable
|
main = googleTest [ testInitializedVariable
|
||||||
, testInitializedVariableShape
|
, testInitializedVariableShape
|
||||||
, testDeviceColocation
|
, testDeviceColocation
|
||||||
, testNamed
|
, testOpName
|
||||||
, testNamedDeRef
|
|
||||||
, testNameScoped
|
, testNameScoped
|
||||||
, testNamedAndScoped
|
, testNamedAndScoped
|
||||||
, testPureRender
|
, testPureRender
|
||||||
|
|
|
@ -19,6 +19,7 @@ module Main where
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
import Lens.Family2 ((.~))
|
||||||
import System.IO.Temp (withSystemTempDirectory)
|
import System.IO.Temp (withSystemTempDirectory)
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
@ -27,7 +28,6 @@ import qualified Data.ByteString.Char8 as B8
|
||||||
|
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Build as TF
|
||||||
import qualified TensorFlow.ControlFlow as TF
|
|
||||||
import qualified TensorFlow.Nodes as TF
|
import qualified TensorFlow.Nodes as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
|
@ -56,7 +56,8 @@ testSaveRestore = testCase "testSaveRestore" $
|
||||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||||
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||||
var = TF.render =<<
|
var = TF.render =<<
|
||||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
TF.zeroInitializedVariable' (TF.opName .~ "a")
|
||||||
|
(TF.Shape [])
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
v <- var
|
v <- var
|
||||||
TF.assign v 134 >>= TF.run_
|
TF.assign v 134 >>= TF.run_
|
||||||
|
|
|
@ -36,7 +36,6 @@ import qualified Data.ByteString as B
|
||||||
import qualified Data.ByteString.Char8 as B8
|
import qualified Data.ByteString.Char8 as B8
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.ControlFlow as TF
|
|
||||||
import qualified TensorFlow.GenOps.Core as TF (select)
|
import qualified TensorFlow.GenOps.Core as TF (select)
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
|
|
|
@ -23,6 +23,7 @@ module TensorFlow.BuildOp
|
||||||
, buildOp
|
, buildOp
|
||||||
, buildListOp
|
, buildListOp
|
||||||
, eqLengthGuard
|
, eqLengthGuard
|
||||||
|
, OpParams
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
|
|
||||||
|
@ -238,3 +239,7 @@ eqLengthGuard = all eachOk
|
||||||
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
|
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
|
||||||
error ("number_attr " ++ numberAttrName ++
|
error ("number_attr " ++ numberAttrName ++
|
||||||
" contains tensors with different length " ++ show pairs)
|
" contains tensors with different length " ++ show pairs)
|
||||||
|
|
||||||
|
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||||
|
-- TODO: be more type safe.
|
||||||
|
type OpParams = OpDef -> OpDef
|
||||||
|
|
|
@ -22,21 +22,15 @@ module TensorFlow.ControlFlow
|
||||||
withControlDependencies
|
withControlDependencies
|
||||||
, group
|
, group
|
||||||
-- * Operations
|
-- * Operations
|
||||||
, identity
|
|
||||||
, noOp
|
, noOp
|
||||||
, named
|
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import Data.Text (Text)
|
import Lens.Family2 ((&), (.~))
|
||||||
import Lens.Family2 ((&), (^.), (.~))
|
|
||||||
|
|
||||||
import TensorFlow.BuildOp
|
import TensorFlow.BuildOp
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Nodes
|
import TensorFlow.Nodes
|
||||||
import TensorFlow.Output
|
|
||||||
import TensorFlow.Tensor
|
|
||||||
import TensorFlow.Types
|
|
||||||
|
|
||||||
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||||
-- on the nodes in the first argument.
|
-- on the nodes in the first argument.
|
||||||
|
@ -57,31 +51,6 @@ group deps = do
|
||||||
-- TODO: slicker way
|
-- TODO: slicker way
|
||||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||||
|
|
||||||
|
|
||||||
-- | Returns a 'Tensor' with the same shape and contents as the input.
|
|
||||||
identity :: TensorType a => Tensor v a -> Tensor v a
|
|
||||||
identity = namedIdentity implicitName
|
|
||||||
|
|
||||||
-- | Returns a 'Tensor' with a given name and the same shape and contents as
|
|
||||||
-- the input.
|
|
||||||
--
|
|
||||||
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
|
|
||||||
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
|
|
||||||
-- whether we can change that op.
|
|
||||||
named :: TensorType a => Text -> Tensor v a -> Tensor v a
|
|
||||||
named = namedIdentity . explicitName
|
|
||||||
|
|
||||||
-- | An internal version of "identity" that allows setting the name
|
|
||||||
-- of the output Tensor.
|
|
||||||
namedIdentity :: forall a v . TensorType a
|
|
||||||
=> PendingNodeName -> Tensor v a -> Tensor v a
|
|
||||||
namedIdentity n t = case t ^. tensorKind of
|
|
||||||
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
|
|
||||||
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
|
|
||||||
where
|
|
||||||
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
|
|
||||||
|
|
||||||
|
|
||||||
-- | Does nothing. Only useful as a placeholder for control edges.
|
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||||
noOp :: ControlNode
|
noOp :: ControlNode
|
||||||
noOp = buildOp $ opDef "NoOp"
|
noOp = buildOp $ opDef "NoOp"
|
||||||
|
|
|
@ -50,14 +50,14 @@ module TensorFlow.Core
|
||||||
, render
|
, render
|
||||||
, asGraphDef
|
, asGraphDef
|
||||||
, addGraphDef
|
, addGraphDef
|
||||||
|
, opName
|
||||||
|
, opAttr
|
||||||
-- * Tensor
|
-- * Tensor
|
||||||
, ControlNode
|
, ControlNode
|
||||||
, Tensor
|
, Tensor
|
||||||
, Value
|
, Value
|
||||||
, Ref
|
, Ref
|
||||||
, TensorKind(..)
|
, TensorKind(..)
|
||||||
, tensorAttr
|
|
||||||
, value
|
, value
|
||||||
, tensorFromName
|
, tensorFromName
|
||||||
-- ** Element types
|
-- ** Element types
|
||||||
|
@ -74,12 +74,10 @@ module TensorFlow.Core
|
||||||
, Device(..)
|
, Device(..)
|
||||||
, withDevice
|
, withDevice
|
||||||
, withNameScope
|
, withNameScope
|
||||||
, named
|
|
||||||
-- ** Dependencies
|
-- ** Dependencies
|
||||||
, withControlDependencies
|
, withControlDependencies
|
||||||
, group
|
, group
|
||||||
-- ** Misc
|
-- ** Misc
|
||||||
, identity
|
|
||||||
, noOp
|
, noOp
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
|
|
@ -124,6 +124,9 @@ data OpDef = OpDef
|
||||||
data PendingNodeName = ExplicitName !Text | ImplicitName
|
data PendingNodeName = ExplicitName !Text | ImplicitName
|
||||||
deriving (Eq, Ord, Show)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
instance IsString PendingNodeName where
|
||||||
|
fromString = ExplicitName . fromString
|
||||||
|
|
||||||
-- | The name of a node in the graph. This corresponds to the proto field
|
-- | The name of a node in the graph. This corresponds to the proto field
|
||||||
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
|
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
|
||||||
-- (if the node was implicitly named).
|
-- (if the node was implicitly named).
|
||||||
|
|
|
@ -27,13 +27,12 @@ module TensorFlow.Tensor where
|
||||||
|
|
||||||
import Data.String (IsString(..))
|
import Data.String (IsString(..))
|
||||||
import qualified Data.Text as Text
|
import qualified Data.Text as Text
|
||||||
import Lens.Family2 (Lens', Traversal', (^.))
|
import Lens.Family2 (Lens', (^.))
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
|
||||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
import TensorFlow.Output (Output)
|
||||||
import TensorFlow.Types
|
import TensorFlow.Types
|
||||||
( TensorData(..)
|
( TensorData(..)
|
||||||
, Attribute
|
|
||||||
, ListOf(..)
|
, ListOf(..)
|
||||||
)
|
)
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
@ -62,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
||||||
tensorOutput :: Lens' (Tensor v a) Output
|
tensorOutput :: Lens' (Tensor v a) Output
|
||||||
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||||
|
|
||||||
-- TODO: Come up with a better API for handling attributes.
|
|
||||||
-- | Lens for the attributes of a tensor.
|
|
||||||
--
|
|
||||||
-- Only valid if the tensor has not yet been rendered. If the tensor has been
|
|
||||||
-- rendered, the traversal will be over nothing (nothing can be read or
|
|
||||||
-- written).
|
|
||||||
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
|
||||||
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
|
||||||
|
|
||||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||||
-- Ref into Value. This behaves like a no-op.
|
-- Ref into Value. This behaves like a no-op.
|
||||||
value :: Tensor v a -> Tensor Value a
|
value :: Tensor v a -> Tensor Value a
|
||||||
|
|
Loading…
Reference in a new issue