mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-24 18:49:46 +01:00
Fix Ref and Build semantics for generated code. (#37)
Also: - Make TensorFlow.Ops.{variable,assign} be the Core generated versions. - Make ops take "Shape" as mandatory input.
This commit is contained in:
parent
a277c7ddb3
commit
cec666e135
3 changed files with 53 additions and 23 deletions
|
@ -16,7 +16,32 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
-- | Rendering of TensorFlow operations as Haskell functions.
|
||||
{- | Rendering of TensorFlow operations as Haskell functions.
|
||||
|
||||
The basic type signature generated for each op is:
|
||||
|
||||
> {constraints} => {mandatory attrs} -> {input tensors} -> {output tensors}
|
||||
|
||||
where:
|
||||
|
||||
* @{mandatory attrs}@ is of the form @A_1 -> ... -> A_N@, where each @A@ is an
|
||||
op attribute that doesn't have a default and can't be inferred from other
|
||||
inputs.
|
||||
|
||||
* @{constraints}@ restrict the type parameters of the input and output tensors
|
||||
(for example: 'TensorType' or 'OneOf').
|
||||
|
||||
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
||||
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
|
||||
of those types), and @a@ is either a concrete type or a (constrained) type
|
||||
variable.
|
||||
|
||||
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
|
||||
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
|
||||
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
|
||||
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
|
||||
it is either @ControlNode@ or @Build ControlNode@.)
|
||||
-}
|
||||
|
||||
module TensorFlow.OpGen
|
||||
( OpGenFlags(..)
|
||||
|
@ -259,15 +284,18 @@ typeSig pOp = constraints
|
|||
AttrFloat -> "Float"
|
||||
AttrBool -> "Bool"
|
||||
AttrType -> "DataType"
|
||||
AttrShape -> "TensorShapeProto"
|
||||
AttrShape -> "Shape"
|
||||
AttrTensor -> "TensorProto"
|
||||
|
||||
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
|
||||
outputs = case parsedOutputs pOp of
|
||||
[] -> "ControlNode"
|
||||
[] -> wrapOutput "ControlNode"
|
||||
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
|
||||
[a] -> tensorArg a <+> "-- ^" <+> argComment a
|
||||
as -> tuple (map tensorArg as) <+/> resultComment as
|
||||
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
|
||||
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
|
||||
wrapOutput o
|
||||
| parsedOpIsMonadic pOp = "Build" <+> parens o
|
||||
| otherwise = o
|
||||
|
||||
-- | Render an op input or output.
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
||||
|
|
|
@ -41,6 +41,8 @@ import Proto.Tensorflow.Core.Framework.OpDef
|
|||
, description
|
||||
, name
|
||||
, inputArg
|
||||
, isRef
|
||||
, isStateful
|
||||
, outputArg
|
||||
, summary
|
||||
, typeListAttr
|
||||
|
@ -67,6 +69,10 @@ data ParsedOp = ParsedOp
|
|||
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||
-- from one or more of the input tensors.
|
||||
-- Associated with the list of tensors whose size it describes.
|
||||
, parsedOpIsMonadic :: Bool
|
||||
-- ^ Whether this op is stateful or takes a stateful input. Such ops
|
||||
-- should not be CSE'd and must be monadic in our API (i.e., return a
|
||||
-- Build action).
|
||||
}
|
||||
|
||||
data Name = Name
|
||||
|
@ -127,6 +133,10 @@ data ArgKind
|
|||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||
| ArgResource -- Resource a
|
||||
|
||||
isRefKind :: ArgKind -> Bool
|
||||
isRefKind ArgTensorRef = True
|
||||
isRefKind ArgResource = True
|
||||
isRefKind _ = False
|
||||
|
||||
makeName :: Text -> Name
|
||||
makeName n = Name
|
||||
|
@ -197,6 +207,8 @@ parseOp o = ParsedOp
|
|||
{ parsedOpName = makeName $ o ^. name
|
||||
, parsedOpSummary = o ^. summary
|
||||
, parsedOpDescription = o ^. description
|
||||
, parsedOpIsMonadic = o ^. isStateful
|
||||
|| any (isRefKind . parsedArgKind) parsedInputs
|
||||
, ..
|
||||
}
|
||||
where
|
||||
|
@ -218,11 +230,13 @@ parseOp o = ParsedOp
|
|||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||
inputTensorKind a v
|
||||
| a ^. type' == DT_RESOURCE = ArgResource
|
||||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorEither v
|
||||
|
||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||
outputTensorKind a
|
||||
| a ^. type' == DT_RESOURCE = ArgResource
|
||||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorValue
|
||||
|
||||
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
|
@ -230,7 +244,7 @@ getExplicitInputAttr implicitAttrs a
|
|||
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||
, a ^. maybe'defaultValue == Nothing
|
||||
, t <- parseAttrType (a ^. type')
|
||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
|
||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
||||
| otherwise = Nothing
|
||||
|
||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||
|
|
|
@ -60,7 +60,7 @@ module TensorFlow.Ops
|
|||
, CoreOps.abs
|
||||
, CoreOps.addN
|
||||
, CoreOps.argMax
|
||||
, assign
|
||||
, CoreOps.assign
|
||||
, CoreOps.broadcastGradientArgs
|
||||
, CoreOps.cast
|
||||
, CoreOps.concat
|
||||
|
@ -97,7 +97,7 @@ module TensorFlow.Ops
|
|||
, CoreOps.sum
|
||||
, CoreOps.transpose
|
||||
, truncatedNormal
|
||||
, variable
|
||||
, CoreOps.variable
|
||||
, vector
|
||||
, zeros
|
||||
, CoreOps.zerosLike
|
||||
|
@ -154,31 +154,18 @@ matTranspose :: forall a v . TensorType a
|
|||
=> Tensor v a -> Tensor Value a
|
||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||
|
||||
-- | Create a new, uninitialized stateful Tensor of the given shape.
|
||||
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
|
||||
variable shape' = buildOp $ opDef "Variable"
|
||||
& opAttr "shape" .~ shape'
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
|
||||
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
||||
placeholder shape' =
|
||||
buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ shape'
|
||||
|
||||
-- Assign returns the input ref.
|
||||
assign :: forall a v . TensorType a
|
||||
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
|
||||
assign = buildOp $ opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a . TensorType a
|
||||
=> Tensor Value a -> Build (Tensor Ref a)
|
||||
initializedVariable initializer = do
|
||||
v <- variable [] -- The shape is not known initially.
|
||||
v <- CoreOps.variable [] -- The shape is not known initially.
|
||||
(i :: Tensor Ref a) <-
|
||||
buildOp (opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
|
@ -220,7 +207,8 @@ restoreFromName :: forall a . TensorType a
|
|||
restoreFromName path name x = do
|
||||
let restoreOp = buildOp $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
group =<< CoreOps.assign x
|
||||
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
restore :: forall a . TensorType a
|
||||
|
|
Loading…
Reference in a new issue