1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00
This commit is contained in:
Judah Jacobson 2016-11-19 16:49:35 -08:00
parent eb9b91c28c
commit 9c0f1d2f88
2 changed files with 50 additions and 50 deletions

View File

@ -110,8 +110,8 @@ docOpList flags opList =
]
where moduleName =
Text.pack (prefix flags) <> "." <> camelCase
-- Discards the optional trailing _op_lib
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
-- Discards the optional trailing _ops_op_lib
(fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
shortName = Text.pack (takeBaseName $ outputFile flags)
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
renderOpAndExtras o = renderOp (parseOp o) </> extras o
@ -136,26 +136,28 @@ renderTFName = strictText . unTFName . tfName
renderQuotedTFName = dquotes . renderTFName
-- | Generate the source code for a singel op.
-- | Generate the source code for a single op.
-- For example:
--
-- -- | {haddock comment}
-- foo :: {type sig}
-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
renderOp :: ParsedOp -> Doc
renderOp d = stack $
renderOp pOp = stack $
[ haddocks
, n <+> "::" <+> hang 0 (typeSig d)
, n <+> hang 0 args <+> "|" <+> funcGuard (inferredListSizeAttrs d)
, n <+> "::" <+> hang 0 (typeSig pOp)
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
<+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation (functionBody d)
] ++ whereClause (inferredListSizeAttrs d)
indent indentation (functionBody pOp)
] ++ whereClause listSizeAttrs
where
n = renderHaskellName $ parsedOpName d
args = sep $ map (renderHaskellName . attrName) (explicitInputAttrs d)
++ map (renderHaskellName . parsedArgName) (parsedInputs d)
haddocks = "-- |" <+> multilineComment (parsedOpSummary d) (parsedOpDescription d)
n = renderHaskellName $ parsedOpName pOp
listSizeAttrs = inferredListSizeAttrs pOp
args = sep $ map renderHaskellName
$ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp)
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
-- | A check that all lists of the given size have the given length.
-- For example:
@ -172,6 +174,7 @@ funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
]
renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
"length" <+> renderHaskellName x
-- | Define the implicit list length attributes.
-- For example:
-- where
@ -187,7 +190,7 @@ whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr
<> ") :: Int64"
functionBody :: ParsedOp -> Doc
functionBody o = buildFunction <+> parens (hang 0 (stack buildOpParts))
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
where
buildFunction
@ -197,24 +200,24 @@ functionBody o = buildFunction <+> parens (hang 0 (stack buildOpParts))
map renderHaskellName outputListsSizes)
outputListsSizes = [ a
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
<- parsedOutputs o]
<- parsedOutputs pOp]
buildOpParts =
"opDef" <+> renderQuotedTFName (parsedOpName o) :
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders tensor arguments.
[ "& opAttr" <+> renderQuotedTFName n <+>
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
| a <- inferredTypeAttrs o, let n = attrName a
| a <- inferredTypeAttrs pOp, let n = attrName a
] ++
-- Renders mandatory attributes as function parameters.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- explicitInputAttrs o, let n = attrName a
| a <- explicitInputAttrs pOp, let n = attrName a
] ++
-- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs o, let n = attrName a
| a <- inferredListSizeAttrs pOp, let n = attrName a
]
tensorArgs = map (renderHaskellName . parsedArgName) $ parsedInputs o
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
-- | Write a comment with the inputs/outputs/attributes in proto format, for
-- debugging.
@ -233,19 +236,19 @@ extras d = enclose "{-\n" "\n-}" $
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
-- "Tensor t2 v2" is an output.
typeSig :: ParsedOp -> Doc
typeSig o = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs o)
++ map tensorArgAndComment (parsedInputs o)
typeSig pOp = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
++ map tensorArgAndComment (parsedInputs pOp)
++ [outputs])
where
constraints
| null (inferredTypeAttrs o) = empty
| null (inferredTypeAttrs pOp) = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs o ++ parsedOutputs o,
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
ArgTensorEither v <- [parsedArgKind k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs o]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint
$ inferredTypeAttrs o
$ inferredTypeAttrs pOp
signatureFold = folddoc (\x y -> x </> "->" <+> y)
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
renderAttrType (AttrSingle a) = renderAttrBaseType a
@ -260,7 +263,7 @@ typeSig o = constraints
AttrTensor -> "TensorProto"
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
outputs = case parsedOutputs o of
outputs = case parsedOutputs pOp of
[] -> "ControlNode"
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
[a] -> tensorArg a <+> "-- ^" <+> argComment a
@ -286,8 +289,7 @@ tensorArg p = case parsedArgCase p of
in v <+> a
attrComment :: Attr a -> Doc
attrComment a
= argComment' (attrName a) (attrDescription a)
attrComment a = argComment' (attrName a) (attrDescription a)
argComment :: ParsedArg -> Doc
argComment a = argComment' (parsedArgName a) (parsedArgDescription a)

View File

@ -3,6 +3,7 @@
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
( ParsedOp(..)
, Name(..)
@ -119,10 +120,11 @@ data ArgType
= ArgTypeFixed DataType -- ^ A fixed type.
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
-- The kind of an op input or output (not including the argument type `a`).
data ArgKind
= ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a
@ -183,7 +185,6 @@ forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ filter (/= "ops")
$ Text.splitOn "_" s
-- | Upper-case the given text.
@ -196,25 +197,22 @@ parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, parsedInputs = inputs
, parsedOutputs = outputs
, explicitInputAttrs = explicitInputs
, inferredTypeAttrs = inferredTypes
, inferredListSizeAttrs = inferredListSizes
, ..
}
where
inputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
outputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
explicitInputs = sortBy (comparing (tfName . attrName))
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ o ^. attr
inferredTypes = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
inferredListSizes = mapMaybeAttrs (getInferredListSizeAttr inputs) $ o ^. attr
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
$ o ^. attr
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypes
++ map attrName inferredListSizes
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
@ -252,13 +250,13 @@ getInferredListSizeAttr inputs a
-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> case f a of
Nothing -> Nothing
Just x -> Just Attr
{ attrName = makeName (a ^. name)
, attrDescription = a ^. description
, attrInfo = x
}
mapMaybeAttrs f = mapMaybe $ \a -> do
x <- f a
Just Attr
{ attrName = makeName (a ^. name)
, attrDescription = a ^. description
, attrInfo = x
}
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg