mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Fixes.
This commit is contained in:
parent
eb9b91c28c
commit
9c0f1d2f88
|
@ -110,8 +110,8 @@ docOpList flags opList =
|
|||
]
|
||||
where moduleName =
|
||||
Text.pack (prefix flags) <> "." <> camelCase
|
||||
-- Discards the optional trailing _op_lib
|
||||
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
|
||||
-- Discards the optional trailing _ops_op_lib
|
||||
(fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
|
||||
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||
renderOpAndExtras o = renderOp (parseOp o) </> extras o
|
||||
|
@ -136,26 +136,28 @@ renderTFName = strictText . unTFName . tfName
|
|||
renderQuotedTFName = dquotes . renderTFName
|
||||
|
||||
|
||||
-- | Generate the source code for a singel op.
|
||||
-- | Generate the source code for a single op.
|
||||
-- For example:
|
||||
--
|
||||
-- -- | {haddock comment}
|
||||
-- foo :: {type sig}
|
||||
-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
|
||||
renderOp :: ParsedOp -> Doc
|
||||
renderOp d = stack $
|
||||
renderOp pOp = stack $
|
||||
[ haddocks
|
||||
, n <+> "::" <+> hang 0 (typeSig d)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard (inferredListSizeAttrs d)
|
||||
, n <+> "::" <+> hang 0 (typeSig pOp)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||
<+> "=" </> -- args are indented
|
||||
-- the body needs to be indented wrt the name
|
||||
indent indentation (functionBody d)
|
||||
] ++ whereClause (inferredListSizeAttrs d)
|
||||
indent indentation (functionBody pOp)
|
||||
] ++ whereClause listSizeAttrs
|
||||
where
|
||||
n = renderHaskellName $ parsedOpName d
|
||||
args = sep $ map (renderHaskellName . attrName) (explicitInputAttrs d)
|
||||
++ map (renderHaskellName . parsedArgName) (parsedInputs d)
|
||||
haddocks = "-- |" <+> multilineComment (parsedOpSummary d) (parsedOpDescription d)
|
||||
n = renderHaskellName $ parsedOpName pOp
|
||||
listSizeAttrs = inferredListSizeAttrs pOp
|
||||
args = sep $ map renderHaskellName
|
||||
$ map attrName (explicitInputAttrs pOp)
|
||||
++ map parsedArgName (parsedInputs pOp)
|
||||
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
||||
|
||||
-- | A check that all lists of the given size have the given length.
|
||||
-- For example:
|
||||
|
@ -172,6 +174,7 @@ funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
|
|||
]
|
||||
renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
|
||||
"length" <+> renderHaskellName x
|
||||
|
||||
-- | Define the implicit list length attributes.
|
||||
-- For example:
|
||||
-- where
|
||||
|
@ -187,7 +190,7 @@ whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr
|
|||
<> ") :: Int64"
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody o = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
where
|
||||
buildFunction
|
||||
|
@ -197,24 +200,24 @@ functionBody o = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
|||
map renderHaskellName outputListsSizes)
|
||||
outputListsSizes = [ a
|
||||
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
||||
<- parsedOutputs o]
|
||||
<- parsedOutputs pOp]
|
||||
buildOpParts =
|
||||
"opDef" <+> renderQuotedTFName (parsedOpName o) :
|
||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||
-- Renders tensor arguments.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+>
|
||||
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
|
||||
| a <- inferredTypeAttrs o, let n = attrName a
|
||||
| a <- inferredTypeAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
-- Renders mandatory attributes as function parameters.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- explicitInputAttrs o, let n = attrName a
|
||||
| a <- explicitInputAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
-- Renders sizes of tensor list types having number_attr.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- inferredListSizeAttrs o, let n = attrName a
|
||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||
]
|
||||
|
||||
tensorArgs = map (renderHaskellName . parsedArgName) $ parsedInputs o
|
||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
|
||||
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||
-- debugging.
|
||||
|
@ -233,19 +236,19 @@ extras d = enclose "{-\n" "\n-}" $
|
|||
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
||||
-- "Tensor t2 v2" is an output.
|
||||
typeSig :: ParsedOp -> Doc
|
||||
typeSig o = constraints
|
||||
<+/> signatureFold (map attrInput (explicitInputAttrs o)
|
||||
++ map tensorArgAndComment (parsedInputs o)
|
||||
typeSig pOp = constraints
|
||||
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||
++ map tensorArgAndComment (parsedInputs pOp)
|
||||
++ [outputs])
|
||||
where
|
||||
constraints
|
||||
| null (inferredTypeAttrs o) = empty
|
||||
| null (inferredTypeAttrs pOp) = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs o ++ parsedOutputs o,
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
ArgTensorEither v <- [parsedArgKind k]]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs o]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||
$ inferredTypeAttrs o
|
||||
$ inferredTypeAttrs pOp
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
||||
|
@ -260,7 +263,7 @@ typeSig o = constraints
|
|||
AttrTensor -> "TensorProto"
|
||||
|
||||
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
|
||||
outputs = case parsedOutputs o of
|
||||
outputs = case parsedOutputs pOp of
|
||||
[] -> "ControlNode"
|
||||
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
|
||||
[a] -> tensorArg a <+> "-- ^" <+> argComment a
|
||||
|
@ -286,8 +289,7 @@ tensorArg p = case parsedArgCase p of
|
|||
in v <+> a
|
||||
|
||||
attrComment :: Attr a -> Doc
|
||||
attrComment a
|
||||
= argComment' (attrName a) (attrDescription a)
|
||||
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||
|
||||
argComment :: ParsedArg -> Doc
|
||||
argComment a = argComment' (parsedArgName a) (parsedArgDescription a)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
-- generated code.
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
module TensorFlow.OpGen.ParsedOp
|
||||
( ParsedOp(..)
|
||||
, Name(..)
|
||||
|
@ -119,10 +120,11 @@ data ArgType
|
|||
= ArgTypeFixed DataType -- ^ A fixed type.
|
||||
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
|
||||
|
||||
-- The kind of an op input or output (not including the argument type `a`).
|
||||
data ArgKind
|
||||
= ArgTensorRef -- Tensor Ref a
|
||||
| ArgTensorValue -- Tensor Value a
|
||||
| ArgTensorEither Text -- Tensor v a
|
||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||
| ArgResource -- Resource a
|
||||
|
||||
|
||||
|
@ -183,7 +185,6 @@ forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
|||
|
||||
camelCase :: Text -> Text
|
||||
camelCase s = Text.concat $ map upCase
|
||||
$ filter (/= "ops")
|
||||
$ Text.splitOn "_" s
|
||||
|
||||
-- | Upper-case the given text.
|
||||
|
@ -196,25 +197,22 @@ parseOp o = ParsedOp
|
|||
{ parsedOpName = makeName $ o ^. name
|
||||
, parsedOpSummary = o ^. summary
|
||||
, parsedOpDescription = o ^. description
|
||||
, parsedInputs = inputs
|
||||
, parsedOutputs = outputs
|
||||
, explicitInputAttrs = explicitInputs
|
||||
, inferredTypeAttrs = inferredTypes
|
||||
, inferredListSizeAttrs = inferredListSizes
|
||||
, ..
|
||||
}
|
||||
where
|
||||
inputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
|
||||
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
|
||||
(o ^. inputArg) tensorKindParams
|
||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
outputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
explicitInputs = sortBy (comparing (tfName . attrName))
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||
$ o ^. attr
|
||||
inferredTypes = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
inferredListSizes = mapMaybeAttrs (getInferredListSizeAttr inputs) $ o ^. attr
|
||||
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||
$ o ^. attr
|
||||
implicitAttrs = Set.fromList $ map tfName $
|
||||
map attrName inferredTypes
|
||||
++ map attrName inferredListSizes
|
||||
map attrName inferredTypeAttrs
|
||||
++ map attrName inferredListSizeAttrs
|
||||
|
||||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||
|
@ -252,13 +250,13 @@ getInferredListSizeAttr inputs a
|
|||
|
||||
-- | Like mapMaybe, but associates the attribute name/description with the given info.
|
||||
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
|
||||
mapMaybeAttrs f = mapMaybe $ \a -> case f a of
|
||||
Nothing -> Nothing
|
||||
Just x -> Just Attr
|
||||
{ attrName = makeName (a ^. name)
|
||||
, attrDescription = a ^. description
|
||||
, attrInfo = x
|
||||
}
|
||||
mapMaybeAttrs f = mapMaybe $ \a -> do
|
||||
x <- f a
|
||||
Just Attr
|
||||
{ attrName = makeName (a ^. name)
|
||||
, attrDescription = a ^. description
|
||||
, attrInfo = x
|
||||
}
|
||||
|
||||
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
||||
parseArg a tKind = ParsedArg
|
||||
|
|
Loading…
Reference in New Issue
Block a user