mirror of
https://github.com/tensorflow/haskell.git
synced 2025-02-17 05:25:05 +01:00
Fix Ref and Build semantics for generated code. (#37)
Also: - Make TensorFlow.Ops.{variable,assign} be the Core generated versions. - Make ops take "Shape" as mandatory input.
This commit is contained in:
parent
a277c7ddb3
commit
cec666e135
3 changed files with 53 additions and 23 deletions
|
@ -16,7 +16,32 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
-- | Rendering of TensorFlow operations as Haskell functions.
|
{- | Rendering of TensorFlow operations as Haskell functions.
|
||||||
|
|
||||||
|
The basic type signature generated for each op is:
|
||||||
|
|
||||||
|
> {constraints} => {mandatory attrs} -> {input tensors} -> {output tensors}
|
||||||
|
|
||||||
|
where:
|
||||||
|
|
||||||
|
* @{mandatory attrs}@ is of the form @A_1 -> ... -> A_N@, where each @A@ is an
|
||||||
|
op attribute that doesn't have a default and can't be inferred from other
|
||||||
|
inputs.
|
||||||
|
|
||||||
|
* @{constraints}@ restrict the type parameters of the input and output tensors
|
||||||
|
(for example: 'TensorType' or 'OneOf').
|
||||||
|
|
||||||
|
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
||||||
|
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
|
||||||
|
of those types), and @a@ is either a concrete type or a (constrained) type
|
||||||
|
variable.
|
||||||
|
|
||||||
|
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
|
||||||
|
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
|
||||||
|
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
|
||||||
|
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
|
||||||
|
it is either @ControlNode@ or @Build ControlNode@.)
|
||||||
|
-}
|
||||||
|
|
||||||
module TensorFlow.OpGen
|
module TensorFlow.OpGen
|
||||||
( OpGenFlags(..)
|
( OpGenFlags(..)
|
||||||
|
@ -259,15 +284,18 @@ typeSig pOp = constraints
|
||||||
AttrFloat -> "Float"
|
AttrFloat -> "Float"
|
||||||
AttrBool -> "Bool"
|
AttrBool -> "Bool"
|
||||||
AttrType -> "DataType"
|
AttrType -> "DataType"
|
||||||
AttrShape -> "TensorShapeProto"
|
AttrShape -> "Shape"
|
||||||
AttrTensor -> "TensorProto"
|
AttrTensor -> "TensorProto"
|
||||||
|
|
||||||
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
|
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
|
||||||
outputs = case parsedOutputs pOp of
|
outputs = case parsedOutputs pOp of
|
||||||
[] -> "ControlNode"
|
[] -> wrapOutput "ControlNode"
|
||||||
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
|
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
|
||||||
[a] -> tensorArg a <+> "-- ^" <+> argComment a
|
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
|
||||||
as -> tuple (map tensorArg as) <+/> resultComment as
|
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
|
||||||
|
wrapOutput o
|
||||||
|
| parsedOpIsMonadic pOp = "Build" <+> parens o
|
||||||
|
| otherwise = o
|
||||||
|
|
||||||
-- | Render an op input or output.
|
-- | Render an op input or output.
|
||||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
||||||
|
|
|
@ -41,6 +41,8 @@ import Proto.Tensorflow.Core.Framework.OpDef
|
||||||
, description
|
, description
|
||||||
, name
|
, name
|
||||||
, inputArg
|
, inputArg
|
||||||
|
, isRef
|
||||||
|
, isStateful
|
||||||
, outputArg
|
, outputArg
|
||||||
, summary
|
, summary
|
||||||
, typeListAttr
|
, typeListAttr
|
||||||
|
@ -67,6 +69,10 @@ data ParsedOp = ParsedOp
|
||||||
-- Attributes which are list sizes (ints) that are inferred automatically
|
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||||
-- from one or more of the input tensors.
|
-- from one or more of the input tensors.
|
||||||
-- Associated with the list of tensors whose size it describes.
|
-- Associated with the list of tensors whose size it describes.
|
||||||
|
, parsedOpIsMonadic :: Bool
|
||||||
|
-- ^ Whether this op is stateful or takes a stateful input. Such ops
|
||||||
|
-- should not be CSE'd and must be monadic in our API (i.e., return a
|
||||||
|
-- Build action).
|
||||||
}
|
}
|
||||||
|
|
||||||
data Name = Name
|
data Name = Name
|
||||||
|
@ -127,6 +133,10 @@ data ArgKind
|
||||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||||
| ArgResource -- Resource a
|
| ArgResource -- Resource a
|
||||||
|
|
||||||
|
isRefKind :: ArgKind -> Bool
|
||||||
|
isRefKind ArgTensorRef = True
|
||||||
|
isRefKind ArgResource = True
|
||||||
|
isRefKind _ = False
|
||||||
|
|
||||||
makeName :: Text -> Name
|
makeName :: Text -> Name
|
||||||
makeName n = Name
|
makeName n = Name
|
||||||
|
@ -197,6 +207,8 @@ parseOp o = ParsedOp
|
||||||
{ parsedOpName = makeName $ o ^. name
|
{ parsedOpName = makeName $ o ^. name
|
||||||
, parsedOpSummary = o ^. summary
|
, parsedOpSummary = o ^. summary
|
||||||
, parsedOpDescription = o ^. description
|
, parsedOpDescription = o ^. description
|
||||||
|
, parsedOpIsMonadic = o ^. isStateful
|
||||||
|
|| any (isRefKind . parsedArgKind) parsedInputs
|
||||||
, ..
|
, ..
|
||||||
}
|
}
|
||||||
where
|
where
|
||||||
|
@ -218,11 +230,13 @@ parseOp o = ParsedOp
|
||||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||||
inputTensorKind a v
|
inputTensorKind a v
|
||||||
| a ^. type' == DT_RESOURCE = ArgResource
|
| a ^. type' == DT_RESOURCE = ArgResource
|
||||||
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorEither v
|
| otherwise = ArgTensorEither v
|
||||||
|
|
||||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||||
outputTensorKind a
|
outputTensorKind a
|
||||||
| a ^. type' == DT_RESOURCE = ArgResource
|
| a ^. type' == DT_RESOURCE = ArgResource
|
||||||
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorValue
|
| otherwise = ArgTensorValue
|
||||||
|
|
||||||
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||||
|
@ -230,7 +244,7 @@ getExplicitInputAttr implicitAttrs a
|
||||||
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||||
, a ^. maybe'defaultValue == Nothing
|
, a ^. maybe'defaultValue == Nothing
|
||||||
, t <- parseAttrType (a ^. type')
|
, t <- parseAttrType (a ^. type')
|
||||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
|
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
||||||
| otherwise = Nothing
|
| otherwise = Nothing
|
||||||
|
|
||||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||||
|
|
|
@ -60,7 +60,7 @@ module TensorFlow.Ops
|
||||||
, CoreOps.abs
|
, CoreOps.abs
|
||||||
, CoreOps.addN
|
, CoreOps.addN
|
||||||
, CoreOps.argMax
|
, CoreOps.argMax
|
||||||
, assign
|
, CoreOps.assign
|
||||||
, CoreOps.broadcastGradientArgs
|
, CoreOps.broadcastGradientArgs
|
||||||
, CoreOps.cast
|
, CoreOps.cast
|
||||||
, CoreOps.concat
|
, CoreOps.concat
|
||||||
|
@ -97,7 +97,7 @@ module TensorFlow.Ops
|
||||||
, CoreOps.sum
|
, CoreOps.sum
|
||||||
, CoreOps.transpose
|
, CoreOps.transpose
|
||||||
, truncatedNormal
|
, truncatedNormal
|
||||||
, variable
|
, CoreOps.variable
|
||||||
, vector
|
, vector
|
||||||
, zeros
|
, zeros
|
||||||
, CoreOps.zerosLike
|
, CoreOps.zerosLike
|
||||||
|
@ -154,31 +154,18 @@ matTranspose :: forall a v . TensorType a
|
||||||
=> Tensor v a -> Tensor Value a
|
=> Tensor v a -> Tensor Value a
|
||||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||||
|
|
||||||
-- | Create a new, uninitialized stateful Tensor of the given shape.
|
|
||||||
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
|
|
||||||
variable shape' = buildOp $ opDef "Variable"
|
|
||||||
& opAttr "shape" .~ shape'
|
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
|
||||||
|
|
||||||
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
||||||
placeholder shape' =
|
placeholder shape' =
|
||||||
buildOp $ opDef "Placeholder"
|
buildOp $ opDef "Placeholder"
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
& opAttr "shape" .~ shape'
|
& opAttr "shape" .~ shape'
|
||||||
|
|
||||||
-- Assign returns the input ref.
|
|
||||||
assign :: forall a v . TensorType a
|
|
||||||
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
|
|
||||||
assign = buildOp $ opDef "Assign"
|
|
||||||
& opAttr "T" .~ tensorType (undefined :: a)
|
|
||||||
& opAttr "use_locking" .~ True
|
|
||||||
|
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
initializedVariable :: forall a . TensorType a
|
initializedVariable :: forall a . TensorType a
|
||||||
=> Tensor Value a -> Build (Tensor Ref a)
|
=> Tensor Value a -> Build (Tensor Ref a)
|
||||||
initializedVariable initializer = do
|
initializedVariable initializer = do
|
||||||
v <- variable [] -- The shape is not known initially.
|
v <- CoreOps.variable [] -- The shape is not known initially.
|
||||||
(i :: Tensor Ref a) <-
|
(i :: Tensor Ref a) <-
|
||||||
buildOp (opDef "Assign"
|
buildOp (opDef "Assign"
|
||||||
& opAttr "T" .~ tensorType (undefined :: a)
|
& opAttr "T" .~ tensorType (undefined :: a)
|
||||||
|
@ -220,7 +207,8 @@ restoreFromName :: forall a . TensorType a
|
||||||
restoreFromName path name x = do
|
restoreFromName path name x = do
|
||||||
let restoreOp = buildOp $ opDef "Restore"
|
let restoreOp = buildOp $ opDef "Restore"
|
||||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||||
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
group =<< CoreOps.assign x
|
||||||
|
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
restore :: forall a . TensorType a
|
restore :: forall a . TensorType a
|
||||||
|
|
Loading…
Add table
Reference in a new issue