1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-15 09:08:33 +02:00

Fix Ref and Build semantics for generated code. (#37)

Also:
- Make TensorFlow.Ops.{variable,assign} be the Core generated versions.
- Make ops take "Shape" as mandatory input.
This commit is contained in:
Judah Jacobson 2016-11-21 10:19:15 -08:00 committed by Greg Steuck
parent a277c7ddb3
commit cec666e135
3 changed files with 53 additions and 23 deletions

View File

@ -16,7 +16,32 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
-- | Rendering of TensorFlow operations as Haskell functions. {- | Rendering of TensorFlow operations as Haskell functions.
The basic type signature generated for each op is:
> {constraints} => {mandatory attrs} -> {input tensors} -> {output tensors}
where:
* @{mandatory attrs}@ is of the form @A_1 -> ... -> A_N@, where each @A@ is an
op attribute that doesn't have a default and can't be inferred from other
inputs.
* @{constraints}@ restrict the type parameters of the input and output tensors
(for example: 'TensorType' or 'OneOf').
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
of those types), and @a@ is either a concrete type or a (constrained) type
variable.
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
it is either @ControlNode@ or @Build ControlNode@.)
-}
module TensorFlow.OpGen module TensorFlow.OpGen
( OpGenFlags(..) ( OpGenFlags(..)
@ -259,15 +284,18 @@ typeSig pOp = constraints
AttrFloat -> "Float" AttrFloat -> "Float"
AttrBool -> "Bool" AttrBool -> "Bool"
AttrType -> "DataType" AttrType -> "DataType"
AttrShape -> "TensorShapeProto" AttrShape -> "Shape"
AttrTensor -> "TensorProto" AttrTensor -> "TensorProto"
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t) tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
outputs = case parsedOutputs pOp of outputs = case parsedOutputs pOp of
[] -> "ControlNode" [] -> wrapOutput "ControlNode"
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a` -- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
[a] -> tensorArg a <+> "-- ^" <+> argComment a [a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
as -> tuple (map tensorArg as) <+/> resultComment as as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
wrapOutput o
| parsedOpIsMonadic pOp = "Build" <+> parens o
| otherwise = o
-- | Render an op input or output. -- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype" -- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"

View File

@ -41,6 +41,8 @@ import Proto.Tensorflow.Core.Framework.OpDef
, description , description
, name , name
, inputArg , inputArg
, isRef
, isStateful
, outputArg , outputArg
, summary , summary
, typeListAttr , typeListAttr
@ -67,6 +69,10 @@ data ParsedOp = ParsedOp
-- Attributes which are list sizes (ints) that are inferred automatically -- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors. -- from one or more of the input tensors.
-- Associated with the list of tensors whose size it describes. -- Associated with the list of tensors whose size it describes.
, parsedOpIsMonadic :: Bool
-- ^ Whether this op is stateful or takes a stateful input. Such ops
-- should not be CSE'd and must be monadic in our API (i.e., return a
-- Build action).
} }
data Name = Name data Name = Name
@ -127,6 +133,10 @@ data ArgKind
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v` | ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a | ArgResource -- Resource a
isRefKind :: ArgKind -> Bool
isRefKind ArgTensorRef = True
isRefKind ArgResource = True
isRefKind _ = False
makeName :: Text -> Name makeName :: Text -> Name
makeName n = Name makeName n = Name
@ -197,6 +207,8 @@ parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name { parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary , parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description , parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful
|| any (isRefKind . parsedArgKind) parsedInputs
, .. , ..
} }
where where
@ -218,11 +230,13 @@ parseOp o = ParsedOp
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v inputTensorKind a v
| a ^. type' == DT_RESOURCE = ArgResource | a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v | otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a outputTensorKind a
| a ^. type' == DT_RESOURCE = ArgResource | a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue | otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
@ -230,7 +244,7 @@ getExplicitInputAttr implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs | TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing , a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type') , t <- parseAttrType (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
| otherwise = Nothing | otherwise = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]

View File

@ -60,7 +60,7 @@ module TensorFlow.Ops
, CoreOps.abs , CoreOps.abs
, CoreOps.addN , CoreOps.addN
, CoreOps.argMax , CoreOps.argMax
, assign , CoreOps.assign
, CoreOps.broadcastGradientArgs , CoreOps.broadcastGradientArgs
, CoreOps.cast , CoreOps.cast
, CoreOps.concat , CoreOps.concat
@ -97,7 +97,7 @@ module TensorFlow.Ops
, CoreOps.sum , CoreOps.sum
, CoreOps.transpose , CoreOps.transpose
, truncatedNormal , truncatedNormal
, variable , CoreOps.variable
, vector , vector
, zeros , zeros
, CoreOps.zerosLike , CoreOps.zerosLike
@ -154,31 +154,18 @@ matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a => Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32]) matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
-- | Create a new, uninitialized stateful Tensor of the given shape.
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
variable shape' = buildOp $ opDef "Variable"
& opAttr "shape" .~ shape'
& opAttr "dtype" .~ tensorType (undefined :: a)
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a) placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
placeholder shape' = placeholder shape' =
buildOp $ opDef "Placeholder" buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape' & opAttr "shape" .~ shape'
-- Assign returns the input ref.
assign :: forall a v . TensorType a
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
assign = buildOp $ opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a initializedVariable :: forall a . TensorType a
=> Tensor Value a -> Build (Tensor Ref a) => Tensor Value a -> Build (Tensor Ref a)
initializedVariable initializer = do initializedVariable initializer = do
v <- variable [] -- The shape is not known initially. v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <- (i :: Tensor Ref a) <-
buildOp (opDef "Assign" buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a) & opAttr "T" .~ tensorType (undefined :: a)
@ -220,7 +207,8 @@ restoreFromName :: forall a . TensorType a
restoreFromName path name x = do restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore" let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a) & opAttr "dt" .~ tensorType (undefined :: a)
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a) group =<< CoreOps.assign x
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a restore :: forall a . TensorType a