1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-24 18:49:46 +01:00

Fix Ref and Build semantics for generated code. (#37)

Also:
- Make TensorFlow.Ops.{variable,assign} be the Core generated versions.
- Make ops take "Shape" as mandatory input.
This commit is contained in:
Judah Jacobson 2016-11-21 10:19:15 -08:00 committed by Greg Steuck
parent a277c7ddb3
commit cec666e135
3 changed files with 53 additions and 23 deletions

View file

@ -16,7 +16,32 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
-- | Rendering of TensorFlow operations as Haskell functions.
{- | Rendering of TensorFlow operations as Haskell functions.
The basic type signature generated for each op is:
> {constraints} => {mandatory attrs} -> {input tensors} -> {output tensors}
where:
* @{mandatory attrs}@ is of the form @A_1 -> ... -> A_N@, where each @A@ is an
op attribute that doesn't have a default and can't be inferred from other
inputs.
* @{constraints}@ restrict the type parameters of the input and output tensors
(for example: 'TensorType' or 'OneOf').
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle a@ (or a list of one
of those types), and @a@ is either a concrete type or a (constrained) type
variable.
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
it is either @ControlNode@ or @Build ControlNode@.)
-}
module TensorFlow.OpGen
( OpGenFlags(..)
@ -259,15 +284,18 @@ typeSig pOp = constraints
AttrFloat -> "Float"
AttrBool -> "Bool"
AttrType -> "DataType"
AttrShape -> "TensorShapeProto"
AttrShape -> "Shape"
AttrTensor -> "TensorProto"
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
outputs = case parsedOutputs pOp of
[] -> "ControlNode"
[] -> wrapOutput "ControlNode"
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
[a] -> tensorArg a <+> "-- ^" <+> argComment a
as -> tuple (map tensorArg as) <+/> resultComment as
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
wrapOutput o
| parsedOpIsMonadic pOp = "Build" <+> parens o
| otherwise = o
-- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"

View file

@ -41,6 +41,8 @@ import Proto.Tensorflow.Core.Framework.OpDef
, description
, name
, inputArg
, isRef
, isStateful
, outputArg
, summary
, typeListAttr
@ -67,6 +69,10 @@ data ParsedOp = ParsedOp
-- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors.
-- Associated with the list of tensors whose size it describes.
, parsedOpIsMonadic :: Bool
-- ^ Whether this op is stateful or takes a stateful input. Such ops
-- should not be CSE'd and must be monadic in our API (i.e., return a
-- Build action).
}
data Name = Name
@ -127,6 +133,10 @@ data ArgKind
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a
isRefKind :: ArgKind -> Bool
isRefKind ArgTensorRef = True
isRefKind ArgResource = True
isRefKind _ = False
makeName :: Text -> Name
makeName n = Name
@ -197,6 +207,8 @@ parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful
|| any (isRefKind . parsedArgKind) parsedInputs
, ..
}
where
@ -218,11 +230,13 @@ parseOp o = ParsedOp
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
| a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
| a ^. type' == DT_RESOURCE = ArgResource
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
@ -230,7 +244,7 @@ getExplicitInputAttr implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
| otherwise = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]

View file

@ -60,7 +60,7 @@ module TensorFlow.Ops
, CoreOps.abs
, CoreOps.addN
, CoreOps.argMax
, assign
, CoreOps.assign
, CoreOps.broadcastGradientArgs
, CoreOps.cast
, CoreOps.concat
@ -97,7 +97,7 @@ module TensorFlow.Ops
, CoreOps.sum
, CoreOps.transpose
, truncatedNormal
, variable
, CoreOps.variable
, vector
, zeros
, CoreOps.zerosLike
@ -154,31 +154,18 @@ matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
-- | Create a new, uninitialized stateful Tensor of the given shape.
variable :: forall a . TensorType a => Shape -> Build (Tensor Ref a)
variable shape' = buildOp $ opDef "Variable"
& opAttr "shape" .~ shape'
& opAttr "dtype" .~ tensorType (undefined :: a)
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
placeholder shape' =
buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape'
-- Assign returns the input ref.
assign :: forall a v . TensorType a
=> Tensor Ref a -> Tensor v a -> Build (Tensor Ref a)
assign = buildOp $ opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a
=> Tensor Value a -> Build (Tensor Ref a)
initializedVariable initializer = do
v <- variable [] -- The shape is not known initially.
v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <-
buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
@ -220,7 +207,8 @@ restoreFromName :: forall a . TensorType a
restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a)
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
group =<< CoreOps.assign x
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a