mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Refactor OpGen. (#36)
Also fixes op lists when the same attribute specifies the length of both an input and an output. I added a test of "shapeN" which previously failed with the following error: ERROR: Ran out of counts in toResult. Likely misuse of buildListOp.
This commit is contained in:
parent
2b5e41ffeb
commit
a277c7ddb3
5 changed files with 502 additions and 413 deletions
|
@ -13,6 +13,7 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
-- | Rendering of TensorFlow operations as Haskell functions.
|
||||
|
@ -23,46 +24,25 @@ module TensorFlow.OpGen
|
|||
, flagParser)
|
||||
where
|
||||
|
||||
import Prelude hiding (head, tail)
|
||||
|
||||
import Control.Applicative ((<**>))
|
||||
import Control.Monad (guard)
|
||||
import Data.Char (toLower, toUpper)
|
||||
import Data.Foldable (toList)
|
||||
import Data.Maybe (catMaybes, fromMaybe, maybeToList)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as NE
|
||||
import Lens.Family2 ((^.), (.~), (&), view)
|
||||
import Options.Applicative (Parser, help, long, strOption, value)
|
||||
import Proto.Tensorflow.Core.Framework.OpDef
|
||||
( OpList
|
||||
, OpDef
|
||||
, OpDef'ArgDef
|
||||
, attr
|
||||
, description
|
||||
, inputArg
|
||||
, name
|
||||
, numberAttr
|
||||
, op
|
||||
, outputArg
|
||||
, summary
|
||||
, type'
|
||||
, typeAttr
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
import System.FilePath (takeBaseName)
|
||||
import TensorFlow.OpGen.AttrVal
|
||||
(AttrDef
|
||||
, AttrCase(..)
|
||||
, AttrTemplate(..)
|
||||
, Template
|
||||
, attrDef
|
||||
, attrOriginal
|
||||
, attrTemplate
|
||||
, templateDefault
|
||||
, templateRestrictions
|
||||
)
|
||||
import TensorFlow.OpGen.ParsedOp
|
||||
import Text.PrettyPrint.Mainland
|
||||
( Doc
|
||||
, (<>)
|
||||
|
@ -79,18 +59,14 @@ import Text.PrettyPrint.Mainland
|
|||
, folddoc
|
||||
, hang
|
||||
, indent
|
||||
, int
|
||||
, parens
|
||||
, sep
|
||||
, stack
|
||||
, strictText
|
||||
, tuple
|
||||
)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
import qualified Data.Text as Text
|
||||
import qualified Data.Semigroup as Semigroup
|
||||
import Data.Text (Text)
|
||||
|
||||
data OpGenFlags = OpGenFlags
|
||||
{ outputFile :: String
|
||||
|
@ -118,7 +94,6 @@ docOpList flags opList =
|
|||
, "{-# LANGUAGE DataKinds #-}"
|
||||
, "{-# LANGUAGE FlexibleInstances #-}"
|
||||
, "{-# LANGUAGE OverloadedStrings #-}"
|
||||
, "{-# LANGUAGE RankNTypes #-}"
|
||||
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
||||
-- Avoids reports about shadowing standard library names.
|
||||
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
|
||||
|
@ -129,34 +104,17 @@ docOpList flags opList =
|
|||
, imports
|
||||
, empty
|
||||
, folddoc (\x y -> x </> empty </> y)
|
||||
(map renderDef $
|
||||
(map renderOpAndExtras $
|
||||
filter (not . flip elem exclusions . view name) $
|
||||
toList $ opList ^. op)
|
||||
]
|
||||
where moduleName =
|
||||
Text.pack (prefix flags) <> "." <> camelCase
|
||||
-- Discards the optional trailing _op_lib
|
||||
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
|
||||
-- Discards the optional trailing _ops_op_lib
|
||||
(fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
|
||||
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||
|
||||
camelCase :: Text -> Text
|
||||
camelCase s = Text.concat $ map upCase
|
||||
$ filter (/= "ops")
|
||||
$ Text.splitOn "_" s
|
||||
|
||||
-- | Upper-case the given text.
|
||||
upCase :: Text -> Text
|
||||
upCase = forceCase toUpper
|
||||
|
||||
-- | Lower-case the given name, and prevent it from overlapping with a reserved
|
||||
-- Haskell name.
|
||||
lowCase :: Text -> Text
|
||||
lowCase = replaceReservedName . forceCase toLower
|
||||
|
||||
forceCase :: (Char -> Char) -> Text -> Text
|
||||
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||
(Text.uncons s)
|
||||
renderOpAndExtras o = renderOp (parseOp o) </> extras o
|
||||
|
||||
imports :: Doc
|
||||
imports = stack [
|
||||
|
@ -172,230 +130,206 @@ imports = stack [
|
|||
, "import TensorFlow.Types"
|
||||
]
|
||||
|
||||
renderDef :: OpDef -> Doc
|
||||
renderDef d =
|
||||
stack [
|
||||
haddocks
|
||||
, n <+> "::" <+> hang 0 (typeSig d)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
|
||||
renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc
|
||||
renderHaskellName = strictText . unHaskellName . haskellName
|
||||
renderTFName = strictText . unTFName . tfName
|
||||
renderQuotedTFName = dquotes . renderTFName
|
||||
|
||||
|
||||
-- | Generate the source code for a single op.
|
||||
-- For example:
|
||||
--
|
||||
-- -- | {haddock comment}
|
||||
-- foo :: {type sig}
|
||||
-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
|
||||
renderOp :: ParsedOp -> Doc
|
||||
renderOp pOp = stack $
|
||||
[ haddocks
|
||||
, n <+> "::" <+> hang 0 (typeSig pOp)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||
<+> "=" </> -- args are indented
|
||||
-- the body needs to be indented wrt the name
|
||||
indent indentation functionBody
|
||||
, extras -- just for debug
|
||||
]
|
||||
indent indentation (functionBody pOp)
|
||||
] ++ whereClause listSizeAttrs
|
||||
where
|
||||
n = strictText $ fixOpName (d ^. name)
|
||||
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
|
||||
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
|
||||
fixOpName = lowCase
|
||||
funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
|
||||
n = renderHaskellName $ parsedOpName pOp
|
||||
listSizeAttrs = inferredListSizeAttrs pOp
|
||||
args = sep $ map renderHaskellName
|
||||
$ map attrName (explicitInputAttrs pOp)
|
||||
++ map parsedArgName (parsedInputs pOp)
|
||||
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
||||
|
||||
-- | A check that all lists of the given size have the given length.
|
||||
-- For example:
|
||||
-- eqLengthGuard [("N", [("input1", length input1), ("input2", length input2)])]
|
||||
funcGuard :: [Attr (NonEmpty Name)] -> Doc
|
||||
funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
|
||||
where
|
||||
entries =
|
||||
[ parens $ quotedText nAttr <> comma <+>
|
||||
[ parens $ nAttr <> comma <+>
|
||||
brackets (commasep $ toList $
|
||||
NE.map renderTensorName tensorNames)
|
||||
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
||||
map renderTensorName (toList $ attrInfo a))
|
||||
| a <- attrs
|
||||
, let nAttr = renderQuotedTFName (attrName a)
|
||||
]
|
||||
renderTensorName x = parens $ quotedText x <> comma <+>
|
||||
"length" <+> strictText x
|
||||
-- Uses hang 0 to align the argument vertically on multiple lines.
|
||||
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
|
||||
"length" <+> renderHaskellName x
|
||||
|
||||
-- | Define the implicit list length attributes.
|
||||
-- For example:
|
||||
-- where
|
||||
-- n1 = fromIntegral (length input1) :: Int64
|
||||
-- n2 = fromIntegral (length input2) :: Int64
|
||||
whereClause :: [Attr (NonEmpty Name)] -> [Doc]
|
||||
whereClause [] = []
|
||||
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
||||
where
|
||||
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
|
||||
<+> "fromIntegral (length"
|
||||
<+> renderHaskellName (NE.head $ attrInfo a)
|
||||
<> ") :: Int64"
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
where
|
||||
buildFunction
|
||||
| null outputListsSizes = "buildOp"
|
||||
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
|
||||
outputListsSizes = [ strictText numberAttrName
|
||||
| o <- d ^. outputArg
|
||||
, let numberAttrName = o ^. numberAttr
|
||||
, not (Text.null numberAttrName) &&
|
||||
numberAttrName `Map.member` mandatoryAttrMap d
|
||||
]
|
||||
| otherwise = "buildListOp" <+>
|
||||
brackets (commasep $
|
||||
map renderHaskellName outputListsSizes)
|
||||
outputListsSizes = [ a
|
||||
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
||||
<- parsedOutputs pOp]
|
||||
buildOpParts =
|
||||
"opDef" <+> quotedText (d ^. name) :
|
||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||
-- Renders tensor arguments.
|
||||
[ "& opAttr" <+> quotedText tfName <+>
|
||||
".~ tensorType (undefined ::" <+> strictText hsName <> ")"
|
||||
| (tfName, (hsName, _)) <- Map.toList typeMap
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+>
|
||||
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
|
||||
| a <- inferredTypeAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
-- Renders mandatory attributes as function parameters.
|
||||
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
|
||||
| (tfName, hsName) <- mandatoryAttrs
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- explicitInputAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
-- Renders sizes of tensor list types having number_attr.
|
||||
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
|
||||
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
|
||||
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||
]
|
||||
mandatoryAttrs = [(strictText tf, strictText hs)
|
||||
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
|
||||
]
|
||||
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
|
||||
extras = enclose "{-\n" "\n-}" $
|
||||
|
||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
|
||||
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||
-- debugging.
|
||||
extras :: OpDef -> Doc
|
||||
extras d = enclose "{-\n" "\n-}" $
|
||||
strictText $ Text.pack $
|
||||
showMessage ((def :: OpDef)
|
||||
& inputArg .~ (d ^. inputArg)
|
||||
& outputArg .~ (d ^. outputArg)
|
||||
& attr .~ (d ^. attr))
|
||||
typeMap = opDefTypeMap d
|
||||
|
||||
-- | Makes a quoted string doc out of the given text value.
|
||||
quotedText :: Text.Text -> Doc
|
||||
quotedText = dquotes . strictText
|
||||
-- | The type signature for an op.
|
||||
-- Of the form:
|
||||
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
|
||||
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
||||
-- "Tensor t2 v2" is an output.
|
||||
typeSig :: ParsedOp -> Doc
|
||||
typeSig pOp = constraints
|
||||
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||
++ map tensorArgAndComment (parsedInputs pOp)
|
||||
++ [outputs])
|
||||
where
|
||||
constraints
|
||||
| null (inferredTypeAttrs pOp) = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
ArgTensorEither v <- [parsedArgKind k]]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||
$ inferredTypeAttrs pOp
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
||||
renderAttrType (AttrList a) = brackets $ renderAttrBaseType a
|
||||
renderAttrBaseType = \case
|
||||
AttrBytes -> "ByteString"
|
||||
AttrInt64 -> "Data.Int.Int64"
|
||||
AttrFloat -> "Float"
|
||||
AttrBool -> "Bool"
|
||||
AttrType -> "DataType"
|
||||
AttrShape -> "TensorShapeProto"
|
||||
AttrTensor -> "TensorProto"
|
||||
|
||||
-- | typeSig renders the type signature of the given OpDef.
|
||||
typeSig :: OpDef -> Doc
|
||||
typeSig d =
|
||||
foralls <+> constraints <+/>
|
||||
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
|
||||
where
|
||||
foralls | null typeMap = empty
|
||||
| otherwise =
|
||||
"forall"
|
||||
<+> sep (refVariableNames ++ typeMapTypeNames)
|
||||
<+> "."
|
||||
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
|
||||
constraints | null typeMap = empty
|
||||
| otherwise =
|
||||
tuple (concatMap
|
||||
(\(t, aDef) ->
|
||||
"TensorType" <+> strictText t
|
||||
: maybeToList (oneOfRestrictions aDef t))
|
||||
(Map.elems typeMap)) <+> "=>"
|
||||
refVariableNames = catMaybes (map fst tensorInputs)
|
||||
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
||||
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
|
||||
tensorArg refType arg = wrapArg refType arg <**>
|
||||
pure (<+> hang 0 ("-- ^" <+> argComment arg))
|
||||
-- Argument type is a list of tensors if number_attr is set;
|
||||
-- otherwise it's a single Tensor.
|
||||
wrapArg refType arg =
|
||||
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
|
||||
where typ = tensorType refType arg
|
||||
-- The result is (reference type variable if any, type representing the arg)
|
||||
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
|
||||
tensorType refType arg
|
||||
-- Deals with resource handles that are unlike tensors and
|
||||
-- have their own type level representation. The magic "dtype"
|
||||
-- name is the name of the attribute specifying the type of
|
||||
-- the resource handle. It is not referenced by resource input
|
||||
-- or output because it isn't expressible in TF operation
|
||||
-- signatures.
|
||||
| (arg ^. type' == DT_RESOURCE) =
|
||||
(Nothing, strictText "ResourceHandle dtype")
|
||||
| otherwise =
|
||||
(Just refType,
|
||||
"Tensor" <+> refType <+> maybe directType strictText indirectType)
|
||||
where
|
||||
-- This is the case of a named parameter type which is
|
||||
-- constrained as a OneOf.
|
||||
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
|
||||
-- The nominal case when the type name is given directly.
|
||||
directType = dtTypeToDoc (arg ^. type')
|
||||
outputs =
|
||||
case d ^. outputArg of
|
||||
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
|
||||
outputs = case parsedOutputs pOp of
|
||||
[] -> "ControlNode"
|
||||
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
||||
os -> renderTupleResult os
|
||||
wrappedOutput = snd . wrapArg "Value"
|
||||
-- Tuple result case is rendered differently to give
|
||||
-- individual elements their own comments.
|
||||
renderTupleResult os =
|
||||
stack $ [ tuple (map wrappedOutput os)
|
||||
, flatten commentSummary
|
||||
] ++ map commentDetails os
|
||||
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
|
||||
[a] -> tensorArg a <+> "-- ^" <+> argComment a
|
||||
as -> tuple (map tensorArg as) <+/> resultComment as
|
||||
|
||||
-- | Render an op input or output.
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
|
||||
tensorArg :: ParsedArg -> Doc
|
||||
tensorArg p = case parsedArgCase p of
|
||||
SimpleArg { argType = t } -> tensorType t
|
||||
ListArg { argType = t } -> brackets $ tensorType t
|
||||
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
||||
where
|
||||
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
|
||||
tensorType t = let
|
||||
v = case parsedArgKind p of
|
||||
ArgTensorRef -> "Tensor Ref"
|
||||
ArgTensorValue -> "Tensor Value"
|
||||
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
||||
ArgResource -> "ResourceHandle"
|
||||
a = case t of
|
||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||
ArgTypeAttr n -> renderHaskellName n
|
||||
in v <+> a
|
||||
|
||||
attrComment :: Attr a -> Doc
|
||||
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||
|
||||
argComment :: ParsedArg -> Doc
|
||||
argComment a = argComment' (parsedArgName a) (parsedArgDescription a)
|
||||
|
||||
argComment' :: Name -> Text.Text -> Doc
|
||||
argComment' argName argDesc =
|
||||
bold (renderTFName argName) <> splitMultilineText (":" <+>) argDesc
|
||||
|
||||
bold :: Doc -> Doc
|
||||
bold n = "__" <> n <> "__"
|
||||
|
||||
-- | Comment for the outputs of an op.
|
||||
-- For example:
|
||||
-- -- ^ (__output1__, __output2__)
|
||||
-- --
|
||||
-- -- * __output1__: description1
|
||||
-- --
|
||||
-- -- * __output2__: description2
|
||||
resultComment :: [ParsedArg] -> Doc
|
||||
resultComment os = stack $ flatten commentSummary : map commentDetails os
|
||||
where
|
||||
commentSummary = "-- ^" <+> tuple [bold (renderTFName $ parsedArgName o) | o <- os]
|
||||
commentDetails o =
|
||||
stack [ "--"
|
||||
, "-- *" <+> argComment o
|
||||
]
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
mandatoryAttrInputs = [
|
||||
dtTypeToDoc dtType <+>
|
||||
hang 0 ("-- ^" <+> argComment' tfName descr)
|
||||
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d]
|
||||
typeMap = opDefTypeMap d
|
||||
|
||||
-- | Returns the type restriction for the given tensor type if the
|
||||
-- set of allowed types is not empty (i.e. restricted).
|
||||
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
|
||||
oneOfRestrictions aDef tName = do
|
||||
typs <- onAttrType (^. templateRestrictions) aDef
|
||||
guard $ not $ null typs
|
||||
let typeList = commasep $ map strictText $
|
||||
Set.toList $ Set.fromList $
|
||||
map dtTypeToHaskell typs
|
||||
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName
|
||||
|
||||
-- | Identifies the attributes used as tensor cardinalities. In such
|
||||
-- cases a list of tensors is supplied as an input_arg. The number of
|
||||
-- such inputs is communicated as a separate opAttr.
|
||||
-- The result key is TensorFlow attribute name and the value is the
|
||||
-- tensor names which have number_attr set to the result key.
|
||||
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
|
||||
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
|
||||
(nAttr, replaceReservedName (inp ^. name) :| [])
|
||||
| inp <- d ^. inputArg
|
||||
, nAttr <- [inp ^. numberAttr]
|
||||
, not (Text.null nAttr)
|
||||
]
|
||||
|
||||
argComment :: OpDef'ArgDef -> Doc
|
||||
argComment arg = argComment' (arg ^. name) (arg ^. description)
|
||||
|
||||
argComment' :: Text.Text -> Text.Text -> Doc
|
||||
argComment' argName argDesc =
|
||||
bold argName <> splitMultilineText (":" <+>) argDesc
|
||||
|
||||
bold :: Text.Text -> Doc
|
||||
bold n = strictText ("__" <> n <> "__")
|
||||
|
||||
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
|
||||
|
||||
-- | Returns the map of type parameters from OpDef type name to
|
||||
-- (haskell friendly type name, the type's attribute definition).
|
||||
opDefTypeMap :: OpDef -> OpDefTypeMap
|
||||
opDefTypeMap d =
|
||||
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
||||
|
||||
attrList :: OpDef -> [(Text.Text, AttrDef)]
|
||||
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr]
|
||||
|
||||
isType :: AttrDef -> Bool
|
||||
isType = fromMaybe False . onAttrType (const True)
|
||||
|
||||
-- | Applies the given function to the data type. Is this a Prism?
|
||||
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
|
||||
onAttrType f x = case x ^. attrTemplate of
|
||||
AttrSingle (AttrType a) -> Just (f a)
|
||||
_ -> Nothing
|
||||
|
||||
-- | mandatoryAttrMap contains the attributes chosen by
|
||||
-- isMandatoryAttr, excluding those which are derived from list of
|
||||
-- tensor arguments. The key is the TF name of the attribute. The
|
||||
-- value tuple is (haskell name, TF type, attribute comment).
|
||||
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
|
||||
mandatoryAttrMap d =
|
||||
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
|
||||
| (n, a) <- attrList d
|
||||
, Just dtType <- [isMandatoryAttr a]
|
||||
-- Excludes the attributes rendered as list lengths.
|
||||
, n `Map.notMember` numberAttrMap d
|
||||
]
|
||||
|
||||
-- | Inspects the attribute and if it is one of the implemented
|
||||
-- non-tensor values lacking default, then returns Just the TF type.
|
||||
isMandatoryAttr :: AttrDef -> Maybe DataType
|
||||
isMandatoryAttr x =
|
||||
case x ^. attrTemplate of
|
||||
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
|
||||
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
|
||||
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
|
||||
_ -> Nothing
|
||||
-- | Constraints for a given type parameter.
|
||||
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
|
||||
tensorArgConstraint :: Attr [DataType] -> [Doc]
|
||||
tensorArgConstraint a
|
||||
= ("TensorType" <+> n
|
||||
: if null typeList
|
||||
then []
|
||||
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
|
||||
where
|
||||
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
|
||||
|
||||
dtTypeToDoc :: DataType -> Doc
|
||||
dtTypeToDoc = strictText . dtTypeToHaskell
|
||||
n = renderHaskellName $ attrName a
|
||||
typeList = map strictText $
|
||||
Set.toList $ Set.fromList $
|
||||
map dtTypeToHaskell $ attrInfo a
|
||||
|
||||
-- NOTE: The cases of this function should be kept in sync with
|
||||
-- TensorFlow.Types.AllTensorTypes.
|
||||
|
@ -429,6 +363,12 @@ dtTypeToHaskell x =
|
|||
haddockComment :: Text.Text -> Doc
|
||||
haddockComment = strictText
|
||||
|
||||
-- | Generate a multiline comment. For example:
|
||||
-- summary'
|
||||
-- --
|
||||
-- -- detail_line1
|
||||
-- -- detail_line2
|
||||
-- -- ...
|
||||
multilineComment :: Text.Text -> Text.Text -> Doc
|
||||
multilineComment summary' detail =
|
||||
haddockComment summary' </>
|
||||
|
@ -445,45 +385,5 @@ splitMultilineText lead detail =
|
|||
(l : ls) -> stack $ lead (haddockComment l)
|
||||
: map (("--" <+>) . haddockComment) ls
|
||||
|
||||
replaceReservedName :: Text -> Text
|
||||
replaceReservedName n
|
||||
| n `Set.member` reservedKeywords = n <> "'"
|
||||
| otherwise = n
|
||||
|
||||
indentation :: Int
|
||||
indentation = 4
|
||||
|
||||
reservedKeywords :: Set.Set Text
|
||||
reservedKeywords = Set.fromList $
|
||||
-- Haskell2010 keywords:
|
||||
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
|
||||
-- We don't include keywords that are allowed to be variable names,
|
||||
-- in particular: "as", "forall", and "hiding".
|
||||
[ "case"
|
||||
, "class"
|
||||
, "data"
|
||||
, "default"
|
||||
, "deriving"
|
||||
, "do"
|
||||
, "else"
|
||||
, "foreign"
|
||||
, "if"
|
||||
, "import"
|
||||
, "in"
|
||||
, "infix"
|
||||
, "infixl"
|
||||
, "infixr"
|
||||
, "instance"
|
||||
, "let"
|
||||
, "module"
|
||||
, "newtype"
|
||||
, "of"
|
||||
, "then"
|
||||
, "type"
|
||||
, "where"
|
||||
]
|
||||
++ -- Nonstandard extensions
|
||||
[ "mdo" -- RecursiveDo
|
||||
, "rec" -- Arrows, RecursiveDo
|
||||
, "proc" -- Arrows
|
||||
]
|
||||
|
|
|
@ -1,120 +0,0 @@
|
|||
-- Copyright 2016 TensorFlow authors.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
-- | Wrapping of TensorFlow attributes into Haskell entities.
|
||||
module TensorFlow.OpGen.AttrVal
|
||||
(AttrDef
|
||||
, AttrCase(..)
|
||||
, AttrTemplate(..)
|
||||
, Template
|
||||
, attrDef
|
||||
, attrOriginal
|
||||
, attrTemplate
|
||||
, templateDefault
|
||||
, templateRestrictions
|
||||
) where
|
||||
|
||||
import Data.Int (Int64)
|
||||
import Data.Monoid ((<>))
|
||||
import Lens.Family2 (Lens', (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
|
||||
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.Text as Text
|
||||
|
||||
-- | Specifies the optional default value and a set of allowed values
|
||||
-- for the given type.
|
||||
data Template a = Template {
|
||||
_templateDefault :: Maybe a
|
||||
-- ^ The default value (mandatory if unspecified)
|
||||
, _templateRestrictions :: [a]
|
||||
-- ^ The allowed set of values, empty if no restrictions
|
||||
}
|
||||
|
||||
templateDefault :: Lens' (Template a) (Maybe a)
|
||||
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
|
||||
|
||||
templateRestrictions :: Lens' (Template a) [a]
|
||||
templateRestrictions = lens _templateRestrictions
|
||||
(\g x -> g { _templateRestrictions = x })
|
||||
|
||||
data UnusedTensor
|
||||
|
||||
data AttrCase f
|
||||
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
|
||||
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
|
||||
| AttrFloat (f Float) -- float f = 4; // "float"
|
||||
| AttrBool (f Bool) -- bool b = 5; // "bool"
|
||||
| AttrType (f DataType) -- type = 6; // "type"
|
||||
-- To be translated into TensorFlow.Types.Shape before use.
|
||||
-- Leaving as a proto to reduce dependencies.
|
||||
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
|
||||
|
||||
-- | Type-reified representation of TensorFlow AttrDef.
|
||||
-- Initially limited to just the types in Op descriptors.
|
||||
data AttrTemplate
|
||||
= AttrSingle (AttrCase Template)
|
||||
| AttrList (AttrCase [])
|
||||
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
|
||||
|
||||
data AttrDef = AttrDef {
|
||||
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
|
||||
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
|
||||
}
|
||||
|
||||
attrTemplate :: Lens' AttrDef AttrTemplate
|
||||
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
|
||||
|
||||
attrOriginal :: Lens' AttrDef OpDef'AttrDef
|
||||
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
|
||||
|
||||
attrDef :: OpDef'AttrDef -> AttrDef
|
||||
attrDef a = AttrDef a
|
||||
$ translate (a^.OpDef.type')
|
||||
(a^.OpDef.defaultValue)
|
||||
(a^.allowedValues)
|
||||
|
||||
-- | Converts the given AttrValue with the type given by the string
|
||||
-- into the AttrVal if the type is known.
|
||||
translate :: Text.Text -- ^ one of the TensorFlow type strings
|
||||
-> AttrValue -- ^ default value
|
||||
-> AttrValue -- ^ allowed values
|
||||
-> AttrTemplate
|
||||
translate t defaults allowed
|
||||
| t == "string" = makeVal AttrBytes maybe's s
|
||||
| t == "int" = makeVal AttrInt64 maybe'i i
|
||||
| t == "float" = makeVal AttrFloat maybe'f f
|
||||
| t == "bool" = makeVal AttrBool maybe'b b
|
||||
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
|
||||
| t == "shape" = makeVal AttrShape maybe'shape shape
|
||||
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
|
||||
| t == "list(string)" = makeList AttrBytes $ list.s
|
||||
| t == "list(int)" = makeList AttrInt64 $ list.i
|
||||
| t == "list(float)" = makeList AttrFloat $ list.f
|
||||
| t == "list(bool)" = makeList AttrBool $ list.b
|
||||
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
|
||||
| t == "list(shape)" = makeList AttrShape $ list.shape
|
||||
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
|
||||
| t == "func" = AttrTensor $ error "func is unimplemented"
|
||||
| otherwise = error $ show ("Unknown attribute type " <> t) ++
|
||||
"," ++ show defaults ++
|
||||
"," ++ show allowed
|
||||
where makeVal c x y = AttrSingle $ c $
|
||||
Template (defaults^.x) (allowed^.list.y)
|
||||
makeList c y = AttrList $ c $ defaults^.y
|
299
tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs
Normal file
299
tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs
Normal file
|
@ -0,0 +1,299 @@
|
|||
-- | This module helps parse the proto OpDef into a Haskell type which is more
|
||||
-- descriptive of how the attributes and arguments will be used in the
|
||||
-- generated code.
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
module TensorFlow.OpGen.ParsedOp
|
||||
( ParsedOp(..)
|
||||
, Name(..)
|
||||
, HaskellName(..)
|
||||
, TFName(..)
|
||||
, Attr(..)
|
||||
, AttrType(..)
|
||||
, AttrBaseType(..)
|
||||
, ParsedArg(..)
|
||||
, ParsedArgCase(..)
|
||||
, ArgType(..)
|
||||
, ArgKind(..)
|
||||
, parseOp
|
||||
, camelCase
|
||||
) where
|
||||
|
||||
import Data.Char (toUpper, toLower)
|
||||
import Data.List (sortBy)
|
||||
import Data.List.NonEmpty (NonEmpty, nonEmpty)
|
||||
import Data.Maybe (mapMaybe)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Ord (comparing)
|
||||
import qualified Data.Set as Set
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 ((^.))
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (list)
|
||||
import Proto.Tensorflow.Core.Framework.OpDef
|
||||
( OpDef
|
||||
, OpDef'ArgDef
|
||||
, OpDef'AttrDef
|
||||
, allowedValues
|
||||
, attr
|
||||
, maybe'defaultValue
|
||||
, description
|
||||
, name
|
||||
, inputArg
|
||||
, outputArg
|
||||
, summary
|
||||
, typeListAttr
|
||||
, numberAttr
|
||||
, typeAttr
|
||||
, type'
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))
|
||||
|
||||
data ParsedOp = ParsedOp
|
||||
{ parsedOpName :: Name
|
||||
, parsedOpSummary :: Text
|
||||
, parsedOpDescription :: Text
|
||||
, parsedInputs :: [ParsedArg]
|
||||
, parsedOutputs :: [ParsedArg]
|
||||
, explicitInputAttrs :: [Attr AttrType]
|
||||
-- ^ Attributes that must be set explicitly when creating the op.
|
||||
-- Associated with the type of the attribute.
|
||||
, inferredTypeAttrs :: [Attr [DataType]]
|
||||
-- ^ Attributes that are type parameters.
|
||||
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
|
||||
-- If this list is empty, then any type is acceptable.
|
||||
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
|
||||
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||
-- from one or more of the input tensors.
|
||||
-- Associated with the list of tensors whose size it describes.
|
||||
}
|
||||
|
||||
data Name = Name
|
||||
{ haskellName :: HaskellName
|
||||
, tfName :: TFName
|
||||
}
|
||||
|
||||
-- | A raw name as specified in the OpDef proto.
|
||||
newtype TFName = TFName { unTFName :: Text }
|
||||
deriving (Eq, Ord)
|
||||
|
||||
-- | A name that's appropriate for a variable in a Haskell source file.
|
||||
newtype HaskellName = HaskellName { unHaskellName :: Text }
|
||||
|
||||
-- | A named attribute, associated with some information about it.
|
||||
data Attr a = Attr
|
||||
{ attrName :: Name
|
||||
, attrDescription :: Text
|
||||
, attrInfo :: a
|
||||
}
|
||||
|
||||
-- | The type of an attribute.
|
||||
data AttrType = AttrSingle AttrBaseType
|
||||
| AttrList AttrBaseType
|
||||
deriving Eq
|
||||
|
||||
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
|
||||
| AttrType | AttrShape | AttrTensor
|
||||
deriving Eq
|
||||
|
||||
-- | An input or output argument (Tensor) for an op.
|
||||
data ParsedArg = ParsedArg
|
||||
{ parsedArgName :: Name
|
||||
, parsedArgDescription :: Text
|
||||
, parsedArgCase :: ParsedArgCase
|
||||
, parsedArgKind :: ArgKind
|
||||
}
|
||||
|
||||
data ParsedArgCase
|
||||
= SimpleArg { argType :: ArgType }
|
||||
| ListArg
|
||||
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
||||
, argType :: ArgType
|
||||
}
|
||||
| MixedListArg { argTypeAttr :: Name }
|
||||
-- ^ A heterogeneous list.
|
||||
-- TODO(judahjacobson): Implement this.
|
||||
|
||||
-- | The type of an argument.
|
||||
data ArgType
|
||||
= ArgTypeFixed DataType -- ^ A fixed type.
|
||||
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
|
||||
|
||||
-- The kind of an op input or output (not including the argument type `a`).
|
||||
data ArgKind
|
||||
= ArgTensorRef -- Tensor Ref a
|
||||
| ArgTensorValue -- Tensor Value a
|
||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||
| ArgResource -- Resource a
|
||||
|
||||
|
||||
makeName :: Text -> Name
|
||||
makeName n = Name
|
||||
{ haskellName = HaskellName $ fixReservedName $ lowCase n
|
||||
, tfName = TFName n
|
||||
}
|
||||
|
||||
-- | Change a name so it doesn't conflict with any Haskell keywords.
|
||||
fixReservedName :: Text -> Text
|
||||
fixReservedName n
|
||||
| n `Set.member` reservedKeywords = n <> "'"
|
||||
| otherwise = n
|
||||
|
||||
reservedKeywords :: Set.Set Text
|
||||
reservedKeywords = Set.fromList $
|
||||
-- Haskell2010 keywords:
|
||||
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
|
||||
-- We don't include keywords that are allowed to be variable names,
|
||||
-- in particular: "as", "forall", and "hiding".
|
||||
[ "case"
|
||||
, "class"
|
||||
, "data"
|
||||
, "default"
|
||||
, "deriving"
|
||||
, "do"
|
||||
, "else"
|
||||
, "foreign"
|
||||
, "if"
|
||||
, "import"
|
||||
, "in"
|
||||
, "infix"
|
||||
, "infixl"
|
||||
, "infixr"
|
||||
, "instance"
|
||||
, "let"
|
||||
, "module"
|
||||
, "newtype"
|
||||
, "of"
|
||||
, "then"
|
||||
, "type"
|
||||
, "where"
|
||||
]
|
||||
++ -- Nonstandard extensions
|
||||
[ "mdo" -- RecursiveDo
|
||||
, "rec" -- Arrows, RecursiveDo
|
||||
, "proc" -- Arrows
|
||||
]
|
||||
|
||||
-- | Lower-case the given text.
|
||||
lowCase :: Text -> Text
|
||||
lowCase = forceCase toLower
|
||||
|
||||
forceCase :: (Char -> Char) -> Text -> Text
|
||||
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||
(Text.uncons s)
|
||||
|
||||
camelCase :: Text -> Text
|
||||
camelCase s = Text.concat $ map upCase
|
||||
$ Text.splitOn "_" s
|
||||
|
||||
-- | Upper-case the given text.
|
||||
upCase :: Text -> Text
|
||||
upCase = forceCase toUpper
|
||||
|
||||
|
||||
parseOp :: OpDef -> ParsedOp
|
||||
parseOp o = ParsedOp
|
||||
{ parsedOpName = makeName $ o ^. name
|
||||
, parsedOpSummary = o ^. summary
|
||||
, parsedOpDescription = o ^. description
|
||||
, ..
|
||||
}
|
||||
where
|
||||
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
|
||||
(o ^. inputArg) tensorKindParams
|
||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||
$ o ^. attr
|
||||
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||
$ o ^. attr
|
||||
implicitAttrs = Set.fromList $ map tfName $
|
||||
map attrName inferredTypeAttrs
|
||||
++ map attrName inferredListSizeAttrs
|
||||
|
||||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||
inputTensorKind a v
|
||||
| a ^. type' == DT_RESOURCE = ArgResource
|
||||
| otherwise = ArgTensorEither v
|
||||
|
||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||
outputTensorKind a
|
||||
| a ^. type' == DT_RESOURCE = ArgResource
|
||||
| otherwise = ArgTensorValue
|
||||
|
||||
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
getExplicitInputAttr implicitAttrs a
|
||||
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||
, a ^. maybe'defaultValue == Nothing
|
||||
, t <- parseAttrType (a ^. type')
|
||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
|
||||
| otherwise = Nothing
|
||||
|
||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||
getInferredTypeAttr a
|
||||
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
||||
| otherwise = Nothing
|
||||
|
||||
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
|
||||
getInferredListSizeAttr inputs a
|
||||
| a ^. type' == "int"
|
||||
= nonEmpty [t | ParsedArg { parsedArgName = t
|
||||
, parsedArgCase
|
||||
= ListArg { argLength = n }
|
||||
} <- inputs
|
||||
, TFName (a ^. name) == tfName n]
|
||||
| otherwise = Nothing
|
||||
|
||||
-- | Like mapMaybe, but associates the attribute name/description with the given info.
|
||||
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
|
||||
mapMaybeAttrs f = mapMaybe $ \a -> do
|
||||
x <- f a
|
||||
Just Attr
|
||||
{ attrName = makeName (a ^. name)
|
||||
, attrDescription = a ^. description
|
||||
, attrInfo = x
|
||||
}
|
||||
|
||||
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
||||
parseArg a tKind = ParsedArg
|
||||
{ parsedArgName = makeName (a ^. name)
|
||||
, parsedArgDescription = a ^. description
|
||||
, parsedArgCase = parseArgCase a
|
||||
, parsedArgKind = tKind
|
||||
}
|
||||
|
||||
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
|
||||
parseArgCase a
|
||||
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
|
||||
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
|
||||
| otherwise = SimpleArg thisArgType
|
||||
where
|
||||
thisArgType
|
||||
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
|
||||
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
|
||||
| otherwise = ArgTypeFixed (a ^. type')
|
||||
maybeAttr :: Text -> Maybe Name
|
||||
maybeAttr "" = Nothing
|
||||
maybeAttr t = Just $ makeName t
|
||||
|
||||
parseAttrType :: Text -> AttrType
|
||||
parseAttrType = \case
|
||||
"string" -> AttrSingle AttrBytes
|
||||
"int" -> AttrSingle AttrInt64
|
||||
"float" -> AttrSingle AttrFloat
|
||||
"bool" -> AttrSingle AttrBool
|
||||
"type" -> AttrSingle AttrType
|
||||
"shape" -> AttrSingle AttrShape
|
||||
"tensor" -> AttrSingle AttrTensor
|
||||
"list(string)" -> AttrList AttrBytes
|
||||
"list(int)" -> AttrList AttrInt64
|
||||
"list(float)" -> AttrList AttrFloat
|
||||
"list(bool)" -> AttrList AttrBool
|
||||
"list(type)" -> AttrList AttrType
|
||||
"list(shape)" -> AttrList AttrShape
|
||||
"list(tensor)" -> AttrList AttrTensor
|
||||
t -> error $ "parseAttrType: unrecognized type " ++ show t
|
|
@ -13,7 +13,7 @@ cabal-version: >=1.22
|
|||
|
||||
library
|
||||
hs-source-dirs: src
|
||||
exposed-modules: TensorFlow.OpGen.AttrVal
|
||||
exposed-modules: TensorFlow.OpGen.ParsedOp
|
||||
, TensorFlow.OpGen
|
||||
build-depends: proto-lens == 0.1.*
|
||||
, tensorflow-proto == 0.1.*
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Google.Test (googleTest)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -24,6 +25,8 @@ import qualified Data.Vector as V
|
|||
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Test split and concat are inverses.
|
||||
|
@ -34,11 +37,18 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
|
|||
restored = CoreOps.concat dim splitList
|
||||
dim = 1 -- dimension to split along (with size of 3 in original)
|
||||
liftIO $ 3 @=? length splitList
|
||||
(x, y, z) <-
|
||||
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
|
||||
(x, y, z) <- TF.run (original, restored, splitList !! 1)
|
||||
liftIO $ x @=? (y :: V.Vector Float)
|
||||
liftIO $ V.fromList [1, 4] @=? z
|
||||
|
||||
testShapeN :: Test
|
||||
testShapeN = testCase "testShapeN" $ TF.runSession $ do
|
||||
let shapes = map TF.Shape [[1],[2,3]]
|
||||
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
|
||||
result <- TF.run $ CoreOps.shapeN tensors
|
||||
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testSplit
|
||||
, testShapeN
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue