mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Refactor OpGen. (#36)
Also fixes op lists when the same attribute specifies the length of both an input and an output. I added a test of "shapeN" which previously failed with the following error: ERROR: Ran out of counts in toResult. Likely misuse of buildListOp.
This commit is contained in:
parent
2b5e41ffeb
commit
a277c7ddb3
5 changed files with 502 additions and 413 deletions
|
@ -13,6 +13,7 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
-- | Rendering of TensorFlow operations as Haskell functions.
|
-- | Rendering of TensorFlow operations as Haskell functions.
|
||||||
|
@ -23,46 +24,25 @@ module TensorFlow.OpGen
|
||||||
, flagParser)
|
, flagParser)
|
||||||
where
|
where
|
||||||
|
|
||||||
import Prelude hiding (head, tail)
|
|
||||||
|
|
||||||
import Control.Applicative ((<**>))
|
|
||||||
import Control.Monad (guard)
|
|
||||||
import Data.Char (toLower, toUpper)
|
|
||||||
import Data.Foldable (toList)
|
import Data.Foldable (toList)
|
||||||
import Data.Maybe (catMaybes, fromMaybe, maybeToList)
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.ProtoLens (def, showMessage)
|
import Data.ProtoLens (def, showMessage)
|
||||||
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
import Data.List.NonEmpty (NonEmpty)
|
||||||
import qualified Data.List.NonEmpty as NE
|
import qualified Data.List.NonEmpty as NE
|
||||||
import Lens.Family2 ((^.), (.~), (&), view)
|
import Lens.Family2 ((^.), (.~), (&), view)
|
||||||
import Options.Applicative (Parser, help, long, strOption, value)
|
import Options.Applicative (Parser, help, long, strOption, value)
|
||||||
import Proto.Tensorflow.Core.Framework.OpDef
|
import Proto.Tensorflow.Core.Framework.OpDef
|
||||||
( OpList
|
( OpList
|
||||||
, OpDef
|
, OpDef
|
||||||
, OpDef'ArgDef
|
|
||||||
, attr
|
, attr
|
||||||
, description
|
|
||||||
, inputArg
|
, inputArg
|
||||||
, name
|
, name
|
||||||
, numberAttr
|
|
||||||
, op
|
, op
|
||||||
, outputArg
|
, outputArg
|
||||||
, summary
|
|
||||||
, type'
|
|
||||||
, typeAttr
|
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
import System.FilePath (takeBaseName)
|
import System.FilePath (takeBaseName)
|
||||||
import TensorFlow.OpGen.AttrVal
|
import TensorFlow.OpGen.ParsedOp
|
||||||
(AttrDef
|
|
||||||
, AttrCase(..)
|
|
||||||
, AttrTemplate(..)
|
|
||||||
, Template
|
|
||||||
, attrDef
|
|
||||||
, attrOriginal
|
|
||||||
, attrTemplate
|
|
||||||
, templateDefault
|
|
||||||
, templateRestrictions
|
|
||||||
)
|
|
||||||
import Text.PrettyPrint.Mainland
|
import Text.PrettyPrint.Mainland
|
||||||
( Doc
|
( Doc
|
||||||
, (<>)
|
, (<>)
|
||||||
|
@ -79,18 +59,14 @@ import Text.PrettyPrint.Mainland
|
||||||
, folddoc
|
, folddoc
|
||||||
, hang
|
, hang
|
||||||
, indent
|
, indent
|
||||||
, int
|
|
||||||
, parens
|
, parens
|
||||||
, sep
|
, sep
|
||||||
, stack
|
, stack
|
||||||
, strictText
|
, strictText
|
||||||
, tuple
|
, tuple
|
||||||
)
|
)
|
||||||
import qualified Data.Map.Strict as Map
|
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import qualified Data.Text as Text
|
import qualified Data.Text as Text
|
||||||
import qualified Data.Semigroup as Semigroup
|
|
||||||
import Data.Text (Text)
|
|
||||||
|
|
||||||
data OpGenFlags = OpGenFlags
|
data OpGenFlags = OpGenFlags
|
||||||
{ outputFile :: String
|
{ outputFile :: String
|
||||||
|
@ -118,7 +94,6 @@ docOpList flags opList =
|
||||||
, "{-# LANGUAGE DataKinds #-}"
|
, "{-# LANGUAGE DataKinds #-}"
|
||||||
, "{-# LANGUAGE FlexibleInstances #-}"
|
, "{-# LANGUAGE FlexibleInstances #-}"
|
||||||
, "{-# LANGUAGE OverloadedStrings #-}"
|
, "{-# LANGUAGE OverloadedStrings #-}"
|
||||||
, "{-# LANGUAGE RankNTypes #-}"
|
|
||||||
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
||||||
-- Avoids reports about shadowing standard library names.
|
-- Avoids reports about shadowing standard library names.
|
||||||
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
|
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
|
||||||
|
@ -129,34 +104,17 @@ docOpList flags opList =
|
||||||
, imports
|
, imports
|
||||||
, empty
|
, empty
|
||||||
, folddoc (\x y -> x </> empty </> y)
|
, folddoc (\x y -> x </> empty </> y)
|
||||||
(map renderDef $
|
(map renderOpAndExtras $
|
||||||
filter (not . flip elem exclusions . view name) $
|
filter (not . flip elem exclusions . view name) $
|
||||||
toList $ opList ^. op)
|
toList $ opList ^. op)
|
||||||
]
|
]
|
||||||
where moduleName =
|
where moduleName =
|
||||||
Text.pack (prefix flags) <> "." <> camelCase
|
Text.pack (prefix flags) <> "." <> camelCase
|
||||||
-- Discards the optional trailing _op_lib
|
-- Discards the optional trailing _ops_op_lib
|
||||||
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
|
(fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
|
||||||
shortName = Text.pack (takeBaseName $ outputFile flags)
|
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||||
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||||
|
renderOpAndExtras o = renderOp (parseOp o) </> extras o
|
||||||
camelCase :: Text -> Text
|
|
||||||
camelCase s = Text.concat $ map upCase
|
|
||||||
$ filter (/= "ops")
|
|
||||||
$ Text.splitOn "_" s
|
|
||||||
|
|
||||||
-- | Upper-case the given text.
|
|
||||||
upCase :: Text -> Text
|
|
||||||
upCase = forceCase toUpper
|
|
||||||
|
|
||||||
-- | Lower-case the given name, and prevent it from overlapping with a reserved
|
|
||||||
-- Haskell name.
|
|
||||||
lowCase :: Text -> Text
|
|
||||||
lowCase = replaceReservedName . forceCase toLower
|
|
||||||
|
|
||||||
forceCase :: (Char -> Char) -> Text -> Text
|
|
||||||
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
|
||||||
(Text.uncons s)
|
|
||||||
|
|
||||||
imports :: Doc
|
imports :: Doc
|
||||||
imports = stack [
|
imports = stack [
|
||||||
|
@ -170,232 +128,208 @@ imports = stack [
|
||||||
, "import TensorFlow.Output (ResourceHandle)"
|
, "import TensorFlow.Output (ResourceHandle)"
|
||||||
, "import TensorFlow.Tensor"
|
, "import TensorFlow.Tensor"
|
||||||
, "import TensorFlow.Types"
|
, "import TensorFlow.Types"
|
||||||
]
|
|
||||||
|
|
||||||
renderDef :: OpDef -> Doc
|
|
||||||
renderDef d =
|
|
||||||
stack [
|
|
||||||
haddocks
|
|
||||||
, n <+> "::" <+> hang 0 (typeSig d)
|
|
||||||
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
|
|
||||||
-- the body needs to be indented wrt the name
|
|
||||||
indent indentation functionBody
|
|
||||||
, extras -- just for debug
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc
|
||||||
|
renderHaskellName = strictText . unHaskellName . haskellName
|
||||||
|
renderTFName = strictText . unTFName . tfName
|
||||||
|
renderQuotedTFName = dquotes . renderTFName
|
||||||
|
|
||||||
|
|
||||||
|
-- | Generate the source code for a single op.
|
||||||
|
-- For example:
|
||||||
|
--
|
||||||
|
-- -- | {haddock comment}
|
||||||
|
-- foo :: {type sig}
|
||||||
|
-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
|
||||||
|
renderOp :: ParsedOp -> Doc
|
||||||
|
renderOp pOp = stack $
|
||||||
|
[ haddocks
|
||||||
|
, n <+> "::" <+> hang 0 (typeSig pOp)
|
||||||
|
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||||
|
<+> "=" </> -- args are indented
|
||||||
|
-- the body needs to be indented wrt the name
|
||||||
|
indent indentation (functionBody pOp)
|
||||||
|
] ++ whereClause listSizeAttrs
|
||||||
where
|
where
|
||||||
n = strictText $ fixOpName (d ^. name)
|
n = renderHaskellName $ parsedOpName pOp
|
||||||
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
|
listSizeAttrs = inferredListSizeAttrs pOp
|
||||||
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
|
args = sep $ map renderHaskellName
|
||||||
fixOpName = lowCase
|
$ map attrName (explicitInputAttrs pOp)
|
||||||
funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
|
++ map parsedArgName (parsedInputs pOp)
|
||||||
|
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
||||||
|
|
||||||
|
-- | A check that all lists of the given size have the given length.
|
||||||
|
-- For example:
|
||||||
|
-- eqLengthGuard [("N", [("input1", length input1), ("input2", length input2)])]
|
||||||
|
funcGuard :: [Attr (NonEmpty Name)] -> Doc
|
||||||
|
funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
|
||||||
where
|
where
|
||||||
entries =
|
entries =
|
||||||
[ parens $ quotedText nAttr <> comma <+>
|
[ parens $ nAttr <> comma <+>
|
||||||
brackets (commasep $ toList $
|
brackets (commasep $ toList $
|
||||||
NE.map renderTensorName tensorNames)
|
map renderTensorName (toList $ attrInfo a))
|
||||||
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
| a <- attrs
|
||||||
|
, let nAttr = renderQuotedTFName (attrName a)
|
||||||
]
|
]
|
||||||
renderTensorName x = parens $ quotedText x <> comma <+>
|
renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
|
||||||
"length" <+> strictText x
|
"length" <+> renderHaskellName x
|
||||||
-- Uses hang 0 to align the argument vertically on multiple lines.
|
|
||||||
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
-- | Define the implicit list length attributes.
|
||||||
</> indent indentation (sep tensorArgs)
|
-- For example:
|
||||||
|
-- where
|
||||||
|
-- n1 = fromIntegral (length input1) :: Int64
|
||||||
|
-- n2 = fromIntegral (length input2) :: Int64
|
||||||
|
whereClause :: [Attr (NonEmpty Name)] -> [Doc]
|
||||||
|
whereClause [] = []
|
||||||
|
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
||||||
|
where
|
||||||
|
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
|
||||||
|
<+> "fromIntegral (length"
|
||||||
|
<+> renderHaskellName (NE.head $ attrInfo a)
|
||||||
|
<> ") :: Int64"
|
||||||
|
|
||||||
|
functionBody :: ParsedOp -> Doc
|
||||||
|
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||||
|
</> indent indentation (sep tensorArgs)
|
||||||
|
where
|
||||||
buildFunction
|
buildFunction
|
||||||
| null outputListsSizes = "buildOp"
|
| null outputListsSizes = "buildOp"
|
||||||
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
|
| otherwise = "buildListOp" <+>
|
||||||
outputListsSizes = [ strictText numberAttrName
|
brackets (commasep $
|
||||||
| o <- d ^. outputArg
|
map renderHaskellName outputListsSizes)
|
||||||
, let numberAttrName = o ^. numberAttr
|
outputListsSizes = [ a
|
||||||
, not (Text.null numberAttrName) &&
|
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
||||||
numberAttrName `Map.member` mandatoryAttrMap d
|
<- parsedOutputs pOp]
|
||||||
]
|
|
||||||
buildOpParts =
|
buildOpParts =
|
||||||
"opDef" <+> quotedText (d ^. name) :
|
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||||
-- Renders tensor arguments.
|
-- Renders tensor arguments.
|
||||||
[ "& opAttr" <+> quotedText tfName <+>
|
[ "& opAttr" <+> renderQuotedTFName n <+>
|
||||||
".~ tensorType (undefined ::" <+> strictText hsName <> ")"
|
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
|
||||||
| (tfName, (hsName, _)) <- Map.toList typeMap
|
| a <- inferredTypeAttrs pOp, let n = attrName a
|
||||||
] ++
|
] ++
|
||||||
-- Renders mandatory attributes as function parameters.
|
-- Renders mandatory attributes as function parameters.
|
||||||
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
|
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||||
| (tfName, hsName) <- mandatoryAttrs
|
| a <- explicitInputAttrs pOp, let n = attrName a
|
||||||
] ++
|
] ++
|
||||||
-- Renders sizes of tensor list types having number_attr.
|
-- Renders sizes of tensor list types having number_attr.
|
||||||
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
|
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||||
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
|
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||||
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
|
||||||
]
|
]
|
||||||
mandatoryAttrs = [(strictText tf, strictText hs)
|
|
||||||
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
|
|
||||||
]
|
|
||||||
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
|
|
||||||
extras = enclose "{-\n" "\n-}" $
|
|
||||||
strictText $ Text.pack $
|
|
||||||
showMessage ((def :: OpDef)
|
|
||||||
& inputArg .~ (d ^. inputArg)
|
|
||||||
& outputArg .~ (d ^. outputArg)
|
|
||||||
& attr .~ (d ^. attr))
|
|
||||||
typeMap = opDefTypeMap d
|
|
||||||
|
|
||||||
-- | Makes a quoted string doc out of the given text value.
|
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||||
quotedText :: Text.Text -> Doc
|
|
||||||
quotedText = dquotes . strictText
|
|
||||||
|
|
||||||
-- | typeSig renders the type signature of the given OpDef.
|
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||||
typeSig :: OpDef -> Doc
|
-- debugging.
|
||||||
typeSig d =
|
extras :: OpDef -> Doc
|
||||||
foralls <+> constraints <+/>
|
extras d = enclose "{-\n" "\n-}" $
|
||||||
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
|
strictText $ Text.pack $
|
||||||
|
showMessage ((def :: OpDef)
|
||||||
|
& inputArg .~ (d ^. inputArg)
|
||||||
|
& outputArg .~ (d ^. outputArg)
|
||||||
|
& attr .~ (d ^. attr))
|
||||||
|
|
||||||
|
-- | The type signature for an op.
|
||||||
|
-- Of the form:
|
||||||
|
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
|
||||||
|
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||||
|
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
||||||
|
-- "Tensor t2 v2" is an output.
|
||||||
|
typeSig :: ParsedOp -> Doc
|
||||||
|
typeSig pOp = constraints
|
||||||
|
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||||
|
++ map tensorArgAndComment (parsedInputs pOp)
|
||||||
|
++ [outputs])
|
||||||
where
|
where
|
||||||
foralls | null typeMap = empty
|
constraints
|
||||||
| otherwise =
|
| null (inferredTypeAttrs pOp) = empty
|
||||||
"forall"
|
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||||
<+> sep (refVariableNames ++ typeMapTypeNames)
|
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||||
<+> "."
|
ArgTensorEither v <- [parsedArgKind k]]
|
||||||
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
|
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||||
constraints | null typeMap = empty
|
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||||
| otherwise =
|
$ inferredTypeAttrs pOp
|
||||||
tuple (concatMap
|
|
||||||
(\(t, aDef) ->
|
|
||||||
"TensorType" <+> strictText t
|
|
||||||
: maybeToList (oneOfRestrictions aDef t))
|
|
||||||
(Map.elems typeMap)) <+> "=>"
|
|
||||||
refVariableNames = catMaybes (map fst tensorInputs)
|
|
||||||
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
|
||||||
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
|
|
||||||
tensorArg refType arg = wrapArg refType arg <**>
|
|
||||||
pure (<+> hang 0 ("-- ^" <+> argComment arg))
|
|
||||||
-- Argument type is a list of tensors if number_attr is set;
|
|
||||||
-- otherwise it's a single Tensor.
|
|
||||||
wrapArg refType arg =
|
|
||||||
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
|
|
||||||
where typ = tensorType refType arg
|
|
||||||
-- The result is (reference type variable if any, type representing the arg)
|
|
||||||
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
|
|
||||||
tensorType refType arg
|
|
||||||
-- Deals with resource handles that are unlike tensors and
|
|
||||||
-- have their own type level representation. The magic "dtype"
|
|
||||||
-- name is the name of the attribute specifying the type of
|
|
||||||
-- the resource handle. It is not referenced by resource input
|
|
||||||
-- or output because it isn't expressible in TF operation
|
|
||||||
-- signatures.
|
|
||||||
| (arg ^. type' == DT_RESOURCE) =
|
|
||||||
(Nothing, strictText "ResourceHandle dtype")
|
|
||||||
| otherwise =
|
|
||||||
(Just refType,
|
|
||||||
"Tensor" <+> refType <+> maybe directType strictText indirectType)
|
|
||||||
where
|
|
||||||
-- This is the case of a named parameter type which is
|
|
||||||
-- constrained as a OneOf.
|
|
||||||
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
|
|
||||||
-- The nominal case when the type name is given directly.
|
|
||||||
directType = dtTypeToDoc (arg ^. type')
|
|
||||||
outputs =
|
|
||||||
case d ^. outputArg of
|
|
||||||
[] -> "ControlNode"
|
|
||||||
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
|
||||||
os -> renderTupleResult os
|
|
||||||
wrappedOutput = snd . wrapArg "Value"
|
|
||||||
-- Tuple result case is rendered differently to give
|
|
||||||
-- individual elements their own comments.
|
|
||||||
renderTupleResult os =
|
|
||||||
stack $ [ tuple (map wrappedOutput os)
|
|
||||||
, flatten commentSummary
|
|
||||||
] ++ map commentDetails os
|
|
||||||
where
|
|
||||||
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
|
|
||||||
commentDetails o =
|
|
||||||
stack [ "--"
|
|
||||||
, "-- *" <+> argComment o
|
|
||||||
]
|
|
||||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||||
mandatoryAttrInputs = [
|
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||||
dtTypeToDoc dtType <+>
|
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
||||||
hang 0 ("-- ^" <+> argComment' tfName descr)
|
renderAttrType (AttrList a) = brackets $ renderAttrBaseType a
|
||||||
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d]
|
renderAttrBaseType = \case
|
||||||
typeMap = opDefTypeMap d
|
AttrBytes -> "ByteString"
|
||||||
|
AttrInt64 -> "Data.Int.Int64"
|
||||||
|
AttrFloat -> "Float"
|
||||||
|
AttrBool -> "Bool"
|
||||||
|
AttrType -> "DataType"
|
||||||
|
AttrShape -> "TensorShapeProto"
|
||||||
|
AttrTensor -> "TensorProto"
|
||||||
|
|
||||||
-- | Returns the type restriction for the given tensor type if the
|
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
|
||||||
-- set of allowed types is not empty (i.e. restricted).
|
outputs = case parsedOutputs pOp of
|
||||||
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
|
[] -> "ControlNode"
|
||||||
oneOfRestrictions aDef tName = do
|
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
|
||||||
typs <- onAttrType (^. templateRestrictions) aDef
|
[a] -> tensorArg a <+> "-- ^" <+> argComment a
|
||||||
guard $ not $ null typs
|
as -> tuple (map tensorArg as) <+/> resultComment as
|
||||||
let typeList = commasep $ map strictText $
|
|
||||||
Set.toList $ Set.fromList $
|
-- | Render an op input or output.
|
||||||
map dtTypeToHaskell typs
|
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
||||||
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName
|
tensorArg :: ParsedArg -> Doc
|
||||||
|
tensorArg p = case parsedArgCase p of
|
||||||
|
SimpleArg { argType = t } -> tensorType t
|
||||||
|
ListArg { argType = t } -> brackets $ tensorType t
|
||||||
|
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
||||||
|
where
|
||||||
|
tensorType t = let
|
||||||
|
v = case parsedArgKind p of
|
||||||
|
ArgTensorRef -> "Tensor Ref"
|
||||||
|
ArgTensorValue -> "Tensor Value"
|
||||||
|
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
||||||
|
ArgResource -> "ResourceHandle"
|
||||||
|
a = case t of
|
||||||
|
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||||
|
ArgTypeAttr n -> renderHaskellName n
|
||||||
|
in v <+> a
|
||||||
|
|
||||||
-- | Identifies the attributes used as tensor cardinalities. In such
|
attrComment :: Attr a -> Doc
|
||||||
-- cases a list of tensors is supplied as an input_arg. The number of
|
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||||
-- such inputs is communicated as a separate opAttr.
|
|
||||||
-- The result key is TensorFlow attribute name and the value is the
|
argComment :: ParsedArg -> Doc
|
||||||
-- tensor names which have number_attr set to the result key.
|
argComment a = argComment' (parsedArgName a) (parsedArgDescription a)
|
||||||
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
|
|
||||||
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
|
|
||||||
(nAttr, replaceReservedName (inp ^. name) :| [])
|
|
||||||
| inp <- d ^. inputArg
|
|
||||||
, nAttr <- [inp ^. numberAttr]
|
|
||||||
, not (Text.null nAttr)
|
|
||||||
]
|
|
||||||
|
|
||||||
argComment :: OpDef'ArgDef -> Doc
|
argComment' :: Name -> Text.Text -> Doc
|
||||||
argComment arg = argComment' (arg ^. name) (arg ^. description)
|
|
||||||
|
|
||||||
argComment' :: Text.Text -> Text.Text -> Doc
|
|
||||||
argComment' argName argDesc =
|
argComment' argName argDesc =
|
||||||
bold argName <> splitMultilineText (":" <+>) argDesc
|
bold (renderTFName argName) <> splitMultilineText (":" <+>) argDesc
|
||||||
|
|
||||||
bold :: Text.Text -> Doc
|
bold :: Doc -> Doc
|
||||||
bold n = strictText ("__" <> n <> "__")
|
bold n = "__" <> n <> "__"
|
||||||
|
|
||||||
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
|
-- | Comment for the outputs of an op.
|
||||||
|
-- For example:
|
||||||
|
-- -- ^ (__output1__, __output2__)
|
||||||
|
-- --
|
||||||
|
-- -- * __output1__: description1
|
||||||
|
-- --
|
||||||
|
-- -- * __output2__: description2
|
||||||
|
resultComment :: [ParsedArg] -> Doc
|
||||||
|
resultComment os = stack $ flatten commentSummary : map commentDetails os
|
||||||
|
where
|
||||||
|
commentSummary = "-- ^" <+> tuple [bold (renderTFName $ parsedArgName o) | o <- os]
|
||||||
|
commentDetails o =
|
||||||
|
stack [ "--"
|
||||||
|
, "-- *" <+> argComment o
|
||||||
|
]
|
||||||
|
|
||||||
-- | Returns the map of type parameters from OpDef type name to
|
-- | Constraints for a given type parameter.
|
||||||
-- (haskell friendly type name, the type's attribute definition).
|
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
|
||||||
opDefTypeMap :: OpDef -> OpDefTypeMap
|
tensorArgConstraint :: Attr [DataType] -> [Doc]
|
||||||
opDefTypeMap d =
|
tensorArgConstraint a
|
||||||
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
= ("TensorType" <+> n
|
||||||
|
: if null typeList
|
||||||
attrList :: OpDef -> [(Text.Text, AttrDef)]
|
then []
|
||||||
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr]
|
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
|
||||||
|
where
|
||||||
isType :: AttrDef -> Bool
|
n = renderHaskellName $ attrName a
|
||||||
isType = fromMaybe False . onAttrType (const True)
|
typeList = map strictText $
|
||||||
|
Set.toList $ Set.fromList $
|
||||||
-- | Applies the given function to the data type. Is this a Prism?
|
map dtTypeToHaskell $ attrInfo a
|
||||||
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
|
|
||||||
onAttrType f x = case x ^. attrTemplate of
|
|
||||||
AttrSingle (AttrType a) -> Just (f a)
|
|
||||||
_ -> Nothing
|
|
||||||
|
|
||||||
-- | mandatoryAttrMap contains the attributes chosen by
|
|
||||||
-- isMandatoryAttr, excluding those which are derived from list of
|
|
||||||
-- tensor arguments. The key is the TF name of the attribute. The
|
|
||||||
-- value tuple is (haskell name, TF type, attribute comment).
|
|
||||||
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
|
|
||||||
mandatoryAttrMap d =
|
|
||||||
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
|
|
||||||
| (n, a) <- attrList d
|
|
||||||
, Just dtType <- [isMandatoryAttr a]
|
|
||||||
-- Excludes the attributes rendered as list lengths.
|
|
||||||
, n `Map.notMember` numberAttrMap d
|
|
||||||
]
|
|
||||||
|
|
||||||
-- | Inspects the attribute and if it is one of the implemented
|
|
||||||
-- non-tensor values lacking default, then returns Just the TF type.
|
|
||||||
isMandatoryAttr :: AttrDef -> Maybe DataType
|
|
||||||
isMandatoryAttr x =
|
|
||||||
case x ^. attrTemplate of
|
|
||||||
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
|
|
||||||
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
|
|
||||||
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
|
|
||||||
_ -> Nothing
|
|
||||||
where
|
|
||||||
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
|
|
||||||
|
|
||||||
dtTypeToDoc :: DataType -> Doc
|
|
||||||
dtTypeToDoc = strictText . dtTypeToHaskell
|
|
||||||
|
|
||||||
-- NOTE: The cases of this function should be kept in sync with
|
-- NOTE: The cases of this function should be kept in sync with
|
||||||
-- TensorFlow.Types.AllTensorTypes.
|
-- TensorFlow.Types.AllTensorTypes.
|
||||||
|
@ -429,6 +363,12 @@ dtTypeToHaskell x =
|
||||||
haddockComment :: Text.Text -> Doc
|
haddockComment :: Text.Text -> Doc
|
||||||
haddockComment = strictText
|
haddockComment = strictText
|
||||||
|
|
||||||
|
-- | Generate a multiline comment. For example:
|
||||||
|
-- summary'
|
||||||
|
-- --
|
||||||
|
-- -- detail_line1
|
||||||
|
-- -- detail_line2
|
||||||
|
-- -- ...
|
||||||
multilineComment :: Text.Text -> Text.Text -> Doc
|
multilineComment :: Text.Text -> Text.Text -> Doc
|
||||||
multilineComment summary' detail =
|
multilineComment summary' detail =
|
||||||
haddockComment summary' </>
|
haddockComment summary' </>
|
||||||
|
@ -445,45 +385,5 @@ splitMultilineText lead detail =
|
||||||
(l : ls) -> stack $ lead (haddockComment l)
|
(l : ls) -> stack $ lead (haddockComment l)
|
||||||
: map (("--" <+>) . haddockComment) ls
|
: map (("--" <+>) . haddockComment) ls
|
||||||
|
|
||||||
replaceReservedName :: Text -> Text
|
|
||||||
replaceReservedName n
|
|
||||||
| n `Set.member` reservedKeywords = n <> "'"
|
|
||||||
| otherwise = n
|
|
||||||
|
|
||||||
indentation :: Int
|
indentation :: Int
|
||||||
indentation = 4
|
indentation = 4
|
||||||
|
|
||||||
reservedKeywords :: Set.Set Text
|
|
||||||
reservedKeywords = Set.fromList $
|
|
||||||
-- Haskell2010 keywords:
|
|
||||||
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
|
|
||||||
-- We don't include keywords that are allowed to be variable names,
|
|
||||||
-- in particular: "as", "forall", and "hiding".
|
|
||||||
[ "case"
|
|
||||||
, "class"
|
|
||||||
, "data"
|
|
||||||
, "default"
|
|
||||||
, "deriving"
|
|
||||||
, "do"
|
|
||||||
, "else"
|
|
||||||
, "foreign"
|
|
||||||
, "if"
|
|
||||||
, "import"
|
|
||||||
, "in"
|
|
||||||
, "infix"
|
|
||||||
, "infixl"
|
|
||||||
, "infixr"
|
|
||||||
, "instance"
|
|
||||||
, "let"
|
|
||||||
, "module"
|
|
||||||
, "newtype"
|
|
||||||
, "of"
|
|
||||||
, "then"
|
|
||||||
, "type"
|
|
||||||
, "where"
|
|
||||||
]
|
|
||||||
++ -- Nonstandard extensions
|
|
||||||
[ "mdo" -- RecursiveDo
|
|
||||||
, "rec" -- Arrows, RecursiveDo
|
|
||||||
, "proc" -- Arrows
|
|
||||||
]
|
|
||||||
|
|
|
@ -1,120 +0,0 @@
|
||||||
-- Copyright 2016 TensorFlow authors.
|
|
||||||
--
|
|
||||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
-- you may not use this file except in compliance with the License.
|
|
||||||
-- You may obtain a copy of the License at
|
|
||||||
--
|
|
||||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
--
|
|
||||||
-- Unless required by applicable law or agreed to in writing, software
|
|
||||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
-- See the License for the specific language governing permissions and
|
|
||||||
-- limitations under the License.
|
|
||||||
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
|
||||||
|
|
||||||
-- | Wrapping of TensorFlow attributes into Haskell entities.
|
|
||||||
module TensorFlow.OpGen.AttrVal
|
|
||||||
(AttrDef
|
|
||||||
, AttrCase(..)
|
|
||||||
, AttrTemplate(..)
|
|
||||||
, Template
|
|
||||||
, attrDef
|
|
||||||
, attrOriginal
|
|
||||||
, attrTemplate
|
|
||||||
, templateDefault
|
|
||||||
, templateRestrictions
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Data.Int (Int64)
|
|
||||||
import Data.Monoid ((<>))
|
|
||||||
import Lens.Family2 (Lens', (^.))
|
|
||||||
import Lens.Family2.Unchecked (lens)
|
|
||||||
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
|
|
||||||
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
|
|
||||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
|
||||||
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
|
|
||||||
import qualified Data.ByteString as B
|
|
||||||
import qualified Data.Text as Text
|
|
||||||
|
|
||||||
-- | Specifies the optional default value and a set of allowed values
|
|
||||||
-- for the given type.
|
|
||||||
data Template a = Template {
|
|
||||||
_templateDefault :: Maybe a
|
|
||||||
-- ^ The default value (mandatory if unspecified)
|
|
||||||
, _templateRestrictions :: [a]
|
|
||||||
-- ^ The allowed set of values, empty if no restrictions
|
|
||||||
}
|
|
||||||
|
|
||||||
templateDefault :: Lens' (Template a) (Maybe a)
|
|
||||||
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
|
|
||||||
|
|
||||||
templateRestrictions :: Lens' (Template a) [a]
|
|
||||||
templateRestrictions = lens _templateRestrictions
|
|
||||||
(\g x -> g { _templateRestrictions = x })
|
|
||||||
|
|
||||||
data UnusedTensor
|
|
||||||
|
|
||||||
data AttrCase f
|
|
||||||
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
|
|
||||||
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
|
|
||||||
| AttrFloat (f Float) -- float f = 4; // "float"
|
|
||||||
| AttrBool (f Bool) -- bool b = 5; // "bool"
|
|
||||||
| AttrType (f DataType) -- type = 6; // "type"
|
|
||||||
-- To be translated into TensorFlow.Types.Shape before use.
|
|
||||||
-- Leaving as a proto to reduce dependencies.
|
|
||||||
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
|
|
||||||
|
|
||||||
-- | Type-reified representation of TensorFlow AttrDef.
|
|
||||||
-- Initially limited to just the types in Op descriptors.
|
|
||||||
data AttrTemplate
|
|
||||||
= AttrSingle (AttrCase Template)
|
|
||||||
| AttrList (AttrCase [])
|
|
||||||
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
|
|
||||||
|
|
||||||
data AttrDef = AttrDef {
|
|
||||||
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
|
|
||||||
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
|
|
||||||
}
|
|
||||||
|
|
||||||
attrTemplate :: Lens' AttrDef AttrTemplate
|
|
||||||
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
|
|
||||||
|
|
||||||
attrOriginal :: Lens' AttrDef OpDef'AttrDef
|
|
||||||
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
|
|
||||||
|
|
||||||
attrDef :: OpDef'AttrDef -> AttrDef
|
|
||||||
attrDef a = AttrDef a
|
|
||||||
$ translate (a^.OpDef.type')
|
|
||||||
(a^.OpDef.defaultValue)
|
|
||||||
(a^.allowedValues)
|
|
||||||
|
|
||||||
-- | Converts the given AttrValue with the type given by the string
|
|
||||||
-- into the AttrVal if the type is known.
|
|
||||||
translate :: Text.Text -- ^ one of the TensorFlow type strings
|
|
||||||
-> AttrValue -- ^ default value
|
|
||||||
-> AttrValue -- ^ allowed values
|
|
||||||
-> AttrTemplate
|
|
||||||
translate t defaults allowed
|
|
||||||
| t == "string" = makeVal AttrBytes maybe's s
|
|
||||||
| t == "int" = makeVal AttrInt64 maybe'i i
|
|
||||||
| t == "float" = makeVal AttrFloat maybe'f f
|
|
||||||
| t == "bool" = makeVal AttrBool maybe'b b
|
|
||||||
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
|
|
||||||
| t == "shape" = makeVal AttrShape maybe'shape shape
|
|
||||||
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
|
|
||||||
| t == "list(string)" = makeList AttrBytes $ list.s
|
|
||||||
| t == "list(int)" = makeList AttrInt64 $ list.i
|
|
||||||
| t == "list(float)" = makeList AttrFloat $ list.f
|
|
||||||
| t == "list(bool)" = makeList AttrBool $ list.b
|
|
||||||
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
|
|
||||||
| t == "list(shape)" = makeList AttrShape $ list.shape
|
|
||||||
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
|
|
||||||
| t == "func" = AttrTensor $ error "func is unimplemented"
|
|
||||||
| otherwise = error $ show ("Unknown attribute type " <> t) ++
|
|
||||||
"," ++ show defaults ++
|
|
||||||
"," ++ show allowed
|
|
||||||
where makeVal c x y = AttrSingle $ c $
|
|
||||||
Template (defaults^.x) (allowed^.list.y)
|
|
||||||
makeList c y = AttrList $ c $ defaults^.y
|
|
299
tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs
Normal file
299
tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs
Normal file
|
@ -0,0 +1,299 @@
|
||||||
|
-- | This module helps parse the proto OpDef into a Haskell type which is more
|
||||||
|
-- descriptive of how the attributes and arguments will be used in the
|
||||||
|
-- generated code.
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RecordWildCards #-}
|
||||||
|
module TensorFlow.OpGen.ParsedOp
|
||||||
|
( ParsedOp(..)
|
||||||
|
, Name(..)
|
||||||
|
, HaskellName(..)
|
||||||
|
, TFName(..)
|
||||||
|
, Attr(..)
|
||||||
|
, AttrType(..)
|
||||||
|
, AttrBaseType(..)
|
||||||
|
, ParsedArg(..)
|
||||||
|
, ParsedArgCase(..)
|
||||||
|
, ArgType(..)
|
||||||
|
, ArgKind(..)
|
||||||
|
, parseOp
|
||||||
|
, camelCase
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Char (toUpper, toLower)
|
||||||
|
import Data.List (sortBy)
|
||||||
|
import Data.List.NonEmpty (NonEmpty, nonEmpty)
|
||||||
|
import Data.Maybe (mapMaybe)
|
||||||
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Ord (comparing)
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import Data.Text (Text)
|
||||||
|
import qualified Data.Text as Text
|
||||||
|
import Lens.Family2 ((^.))
|
||||||
|
import Proto.Tensorflow.Core.Framework.AttrValue (list)
|
||||||
|
import Proto.Tensorflow.Core.Framework.OpDef
|
||||||
|
( OpDef
|
||||||
|
, OpDef'ArgDef
|
||||||
|
, OpDef'AttrDef
|
||||||
|
, allowedValues
|
||||||
|
, attr
|
||||||
|
, maybe'defaultValue
|
||||||
|
, description
|
||||||
|
, name
|
||||||
|
, inputArg
|
||||||
|
, outputArg
|
||||||
|
, summary
|
||||||
|
, typeListAttr
|
||||||
|
, numberAttr
|
||||||
|
, typeAttr
|
||||||
|
, type'
|
||||||
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))
|
||||||
|
|
||||||
|
data ParsedOp = ParsedOp
|
||||||
|
{ parsedOpName :: Name
|
||||||
|
, parsedOpSummary :: Text
|
||||||
|
, parsedOpDescription :: Text
|
||||||
|
, parsedInputs :: [ParsedArg]
|
||||||
|
, parsedOutputs :: [ParsedArg]
|
||||||
|
, explicitInputAttrs :: [Attr AttrType]
|
||||||
|
-- ^ Attributes that must be set explicitly when creating the op.
|
||||||
|
-- Associated with the type of the attribute.
|
||||||
|
, inferredTypeAttrs :: [Attr [DataType]]
|
||||||
|
-- ^ Attributes that are type parameters.
|
||||||
|
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
|
||||||
|
-- If this list is empty, then any type is acceptable.
|
||||||
|
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
|
||||||
|
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||||
|
-- from one or more of the input tensors.
|
||||||
|
-- Associated with the list of tensors whose size it describes.
|
||||||
|
}
|
||||||
|
|
||||||
|
data Name = Name
|
||||||
|
{ haskellName :: HaskellName
|
||||||
|
, tfName :: TFName
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | A raw name as specified in the OpDef proto.
|
||||||
|
newtype TFName = TFName { unTFName :: Text }
|
||||||
|
deriving (Eq, Ord)
|
||||||
|
|
||||||
|
-- | A name that's appropriate for a variable in a Haskell source file.
|
||||||
|
newtype HaskellName = HaskellName { unHaskellName :: Text }
|
||||||
|
|
||||||
|
-- | A named attribute, associated with some information about it.
|
||||||
|
data Attr a = Attr
|
||||||
|
{ attrName :: Name
|
||||||
|
, attrDescription :: Text
|
||||||
|
, attrInfo :: a
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | The type of an attribute.
|
||||||
|
data AttrType = AttrSingle AttrBaseType
|
||||||
|
| AttrList AttrBaseType
|
||||||
|
deriving Eq
|
||||||
|
|
||||||
|
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
|
||||||
|
| AttrType | AttrShape | AttrTensor
|
||||||
|
deriving Eq
|
||||||
|
|
||||||
|
-- | An input or output argument (Tensor) for an op.
|
||||||
|
data ParsedArg = ParsedArg
|
||||||
|
{ parsedArgName :: Name
|
||||||
|
, parsedArgDescription :: Text
|
||||||
|
, parsedArgCase :: ParsedArgCase
|
||||||
|
, parsedArgKind :: ArgKind
|
||||||
|
}
|
||||||
|
|
||||||
|
data ParsedArgCase
|
||||||
|
= SimpleArg { argType :: ArgType }
|
||||||
|
| ListArg
|
||||||
|
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
||||||
|
, argType :: ArgType
|
||||||
|
}
|
||||||
|
| MixedListArg { argTypeAttr :: Name }
|
||||||
|
-- ^ A heterogeneous list.
|
||||||
|
-- TODO(judahjacobson): Implement this.
|
||||||
|
|
||||||
|
-- | The type of an argument.
|
||||||
|
data ArgType
|
||||||
|
= ArgTypeFixed DataType -- ^ A fixed type.
|
||||||
|
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
|
||||||
|
|
||||||
|
-- The kind of an op input or output (not including the argument type `a`).
|
||||||
|
data ArgKind
|
||||||
|
= ArgTensorRef -- Tensor Ref a
|
||||||
|
| ArgTensorValue -- Tensor Value a
|
||||||
|
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||||
|
| ArgResource -- Resource a
|
||||||
|
|
||||||
|
|
||||||
|
makeName :: Text -> Name
|
||||||
|
makeName n = Name
|
||||||
|
{ haskellName = HaskellName $ fixReservedName $ lowCase n
|
||||||
|
, tfName = TFName n
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | Change a name so it doesn't conflict with any Haskell keywords.
|
||||||
|
fixReservedName :: Text -> Text
|
||||||
|
fixReservedName n
|
||||||
|
| n `Set.member` reservedKeywords = n <> "'"
|
||||||
|
| otherwise = n
|
||||||
|
|
||||||
|
reservedKeywords :: Set.Set Text
|
||||||
|
reservedKeywords = Set.fromList $
|
||||||
|
-- Haskell2010 keywords:
|
||||||
|
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
|
||||||
|
-- We don't include keywords that are allowed to be variable names,
|
||||||
|
-- in particular: "as", "forall", and "hiding".
|
||||||
|
[ "case"
|
||||||
|
, "class"
|
||||||
|
, "data"
|
||||||
|
, "default"
|
||||||
|
, "deriving"
|
||||||
|
, "do"
|
||||||
|
, "else"
|
||||||
|
, "foreign"
|
||||||
|
, "if"
|
||||||
|
, "import"
|
||||||
|
, "in"
|
||||||
|
, "infix"
|
||||||
|
, "infixl"
|
||||||
|
, "infixr"
|
||||||
|
, "instance"
|
||||||
|
, "let"
|
||||||
|
, "module"
|
||||||
|
, "newtype"
|
||||||
|
, "of"
|
||||||
|
, "then"
|
||||||
|
, "type"
|
||||||
|
, "where"
|
||||||
|
]
|
||||||
|
++ -- Nonstandard extensions
|
||||||
|
[ "mdo" -- RecursiveDo
|
||||||
|
, "rec" -- Arrows, RecursiveDo
|
||||||
|
, "proc" -- Arrows
|
||||||
|
]
|
||||||
|
|
||||||
|
-- | Lower-case the given text.
|
||||||
|
lowCase :: Text -> Text
|
||||||
|
lowCase = forceCase toLower
|
||||||
|
|
||||||
|
forceCase :: (Char -> Char) -> Text -> Text
|
||||||
|
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||||
|
(Text.uncons s)
|
||||||
|
|
||||||
|
camelCase :: Text -> Text
|
||||||
|
camelCase s = Text.concat $ map upCase
|
||||||
|
$ Text.splitOn "_" s
|
||||||
|
|
||||||
|
-- | Upper-case the given text.
|
||||||
|
upCase :: Text -> Text
|
||||||
|
upCase = forceCase toUpper
|
||||||
|
|
||||||
|
|
||||||
|
parseOp :: OpDef -> ParsedOp
|
||||||
|
parseOp o = ParsedOp
|
||||||
|
{ parsedOpName = makeName $ o ^. name
|
||||||
|
, parsedOpSummary = o ^. summary
|
||||||
|
, parsedOpDescription = o ^. description
|
||||||
|
, ..
|
||||||
|
}
|
||||||
|
where
|
||||||
|
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
|
||||||
|
(o ^. inputArg) tensorKindParams
|
||||||
|
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||||
|
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||||
|
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||||
|
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||||
|
$ o ^. attr
|
||||||
|
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||||
|
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||||
|
$ o ^. attr
|
||||||
|
implicitAttrs = Set.fromList $ map tfName $
|
||||||
|
map attrName inferredTypeAttrs
|
||||||
|
++ map attrName inferredListSizeAttrs
|
||||||
|
|
||||||
|
-- TODO(judahjacobson): Some arguments should be refs.
|
||||||
|
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||||
|
inputTensorKind a v
|
||||||
|
| a ^. type' == DT_RESOURCE = ArgResource
|
||||||
|
| otherwise = ArgTensorEither v
|
||||||
|
|
||||||
|
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||||
|
outputTensorKind a
|
||||||
|
| a ^. type' == DT_RESOURCE = ArgResource
|
||||||
|
| otherwise = ArgTensorValue
|
||||||
|
|
||||||
|
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||||
|
getExplicitInputAttr implicitAttrs a
|
||||||
|
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||||
|
, a ^. maybe'defaultValue == Nothing
|
||||||
|
, t <- parseAttrType (a ^. type')
|
||||||
|
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
|
||||||
|
| otherwise = Nothing
|
||||||
|
|
||||||
|
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||||
|
getInferredTypeAttr a
|
||||||
|
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
||||||
|
| otherwise = Nothing
|
||||||
|
|
||||||
|
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
|
||||||
|
getInferredListSizeAttr inputs a
|
||||||
|
| a ^. type' == "int"
|
||||||
|
= nonEmpty [t | ParsedArg { parsedArgName = t
|
||||||
|
, parsedArgCase
|
||||||
|
= ListArg { argLength = n }
|
||||||
|
} <- inputs
|
||||||
|
, TFName (a ^. name) == tfName n]
|
||||||
|
| otherwise = Nothing
|
||||||
|
|
||||||
|
-- | Like mapMaybe, but associates the attribute name/description with the given info.
|
||||||
|
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
|
||||||
|
mapMaybeAttrs f = mapMaybe $ \a -> do
|
||||||
|
x <- f a
|
||||||
|
Just Attr
|
||||||
|
{ attrName = makeName (a ^. name)
|
||||||
|
, attrDescription = a ^. description
|
||||||
|
, attrInfo = x
|
||||||
|
}
|
||||||
|
|
||||||
|
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
||||||
|
parseArg a tKind = ParsedArg
|
||||||
|
{ parsedArgName = makeName (a ^. name)
|
||||||
|
, parsedArgDescription = a ^. description
|
||||||
|
, parsedArgCase = parseArgCase a
|
||||||
|
, parsedArgKind = tKind
|
||||||
|
}
|
||||||
|
|
||||||
|
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
|
||||||
|
parseArgCase a
|
||||||
|
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
|
||||||
|
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
|
||||||
|
| otherwise = SimpleArg thisArgType
|
||||||
|
where
|
||||||
|
thisArgType
|
||||||
|
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
|
||||||
|
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
|
||||||
|
| otherwise = ArgTypeFixed (a ^. type')
|
||||||
|
maybeAttr :: Text -> Maybe Name
|
||||||
|
maybeAttr "" = Nothing
|
||||||
|
maybeAttr t = Just $ makeName t
|
||||||
|
|
||||||
|
parseAttrType :: Text -> AttrType
|
||||||
|
parseAttrType = \case
|
||||||
|
"string" -> AttrSingle AttrBytes
|
||||||
|
"int" -> AttrSingle AttrInt64
|
||||||
|
"float" -> AttrSingle AttrFloat
|
||||||
|
"bool" -> AttrSingle AttrBool
|
||||||
|
"type" -> AttrSingle AttrType
|
||||||
|
"shape" -> AttrSingle AttrShape
|
||||||
|
"tensor" -> AttrSingle AttrTensor
|
||||||
|
"list(string)" -> AttrList AttrBytes
|
||||||
|
"list(int)" -> AttrList AttrInt64
|
||||||
|
"list(float)" -> AttrList AttrFloat
|
||||||
|
"list(bool)" -> AttrList AttrBool
|
||||||
|
"list(type)" -> AttrList AttrType
|
||||||
|
"list(shape)" -> AttrList AttrShape
|
||||||
|
"list(tensor)" -> AttrList AttrTensor
|
||||||
|
t -> error $ "parseAttrType: unrecognized type " ++ show t
|
|
@ -13,7 +13,7 @@ cabal-version: >=1.22
|
||||||
|
|
||||||
library
|
library
|
||||||
hs-source-dirs: src
|
hs-source-dirs: src
|
||||||
exposed-modules: TensorFlow.OpGen.AttrVal
|
exposed-modules: TensorFlow.OpGen.ParsedOp
|
||||||
, TensorFlow.OpGen
|
, TensorFlow.OpGen
|
||||||
build-depends: proto-lens == 0.1.*
|
build-depends: proto-lens == 0.1.*
|
||||||
, tensorflow-proto == 0.1.*
|
, tensorflow-proto == 0.1.*
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Data.Int (Int64)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
@ -24,6 +25,8 @@ import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
import qualified TensorFlow.Types as TF
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
-- | Test split and concat are inverses.
|
-- | Test split and concat are inverses.
|
||||||
|
@ -34,11 +37,18 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
|
||||||
restored = CoreOps.concat dim splitList
|
restored = CoreOps.concat dim splitList
|
||||||
dim = 1 -- dimension to split along (with size of 3 in original)
|
dim = 1 -- dimension to split along (with size of 3 in original)
|
||||||
liftIO $ 3 @=? length splitList
|
liftIO $ 3 @=? length splitList
|
||||||
(x, y, z) <-
|
(x, y, z) <- TF.run (original, restored, splitList !! 1)
|
||||||
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
|
|
||||||
liftIO $ x @=? (y :: V.Vector Float)
|
liftIO $ x @=? (y :: V.Vector Float)
|
||||||
liftIO $ V.fromList [1, 4] @=? z
|
liftIO $ V.fromList [1, 4] @=? z
|
||||||
|
|
||||||
|
testShapeN :: Test
|
||||||
|
testShapeN = testCase "testShapeN" $ TF.runSession $ do
|
||||||
|
let shapes = map TF.Shape [[1],[2,3]]
|
||||||
|
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
|
||||||
|
result <- TF.run $ CoreOps.shapeN tensors
|
||||||
|
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testSplit
|
main = googleTest [ testSplit
|
||||||
|
, testShapeN
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue