1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Refactor OpGen. (#36)

Also fixes op lists when the same attribute specifies the length of
both an input and an output.  I added a test of "shapeN" which
previously failed with the following error:

    ERROR: Ran out of counts in toResult. Likely misuse of buildListOp.
This commit is contained in:
Judah Jacobson 2016-11-20 10:00:22 -08:00 committed by Greg Steuck
parent 2b5e41ffeb
commit a277c7ddb3
5 changed files with 502 additions and 413 deletions

View file

@ -13,6 +13,7 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
-- | Rendering of TensorFlow operations as Haskell functions. -- | Rendering of TensorFlow operations as Haskell functions.
@ -23,46 +24,25 @@ module TensorFlow.OpGen
, flagParser) , flagParser)
where where
import Prelude hiding (head, tail)
import Control.Applicative ((<**>))
import Control.Monad (guard)
import Data.Char (toLower, toUpper)
import Data.Foldable (toList) import Data.Foldable (toList)
import Data.Maybe (catMaybes, fromMaybe, maybeToList) import Data.Maybe (fromMaybe)
import Data.ProtoLens (def, showMessage) import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head) import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view) import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value) import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef import Proto.Tensorflow.Core.Framework.OpDef
( OpList ( OpList
, OpDef , OpDef
, OpDef'ArgDef
, attr , attr
, description
, inputArg , inputArg
, name , name
, numberAttr
, op , op
, outputArg , outputArg
, summary
, type'
, typeAttr
) )
import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName) import System.FilePath (takeBaseName)
import TensorFlow.OpGen.AttrVal import TensorFlow.OpGen.ParsedOp
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
)
import Text.PrettyPrint.Mainland import Text.PrettyPrint.Mainland
( Doc ( Doc
, (<>) , (<>)
@ -79,18 +59,14 @@ import Text.PrettyPrint.Mainland
, folddoc , folddoc
, hang , hang
, indent , indent
, int
, parens , parens
, sep , sep
, stack , stack
, strictText , strictText
, tuple , tuple
) )
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set import qualified Data.Set as Set
import qualified Data.Text as Text import qualified Data.Text as Text
import qualified Data.Semigroup as Semigroup
import Data.Text (Text)
data OpGenFlags = OpGenFlags data OpGenFlags = OpGenFlags
{ outputFile :: String { outputFile :: String
@ -118,7 +94,6 @@ docOpList flags opList =
, "{-# LANGUAGE DataKinds #-}" , "{-# LANGUAGE DataKinds #-}"
, "{-# LANGUAGE FlexibleInstances #-}" , "{-# LANGUAGE FlexibleInstances #-}"
, "{-# LANGUAGE OverloadedStrings #-}" , "{-# LANGUAGE OverloadedStrings #-}"
, "{-# LANGUAGE RankNTypes #-}"
, "{-# LANGUAGE ScopedTypeVariables #-}" , "{-# LANGUAGE ScopedTypeVariables #-}"
-- Avoids reports about shadowing standard library names. -- Avoids reports about shadowing standard library names.
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}" , "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
@ -129,34 +104,17 @@ docOpList flags opList =
, imports , imports
, empty , empty
, folddoc (\x y -> x </> empty </> y) , folddoc (\x y -> x </> empty </> y)
(map renderDef $ (map renderOpAndExtras $
filter (not . flip elem exclusions . view name) $ filter (not . flip elem exclusions . view name) $
toList $ opList ^. op) toList $ opList ^. op)
] ]
where moduleName = where moduleName =
Text.pack (prefix flags) <> "." <> camelCase Text.pack (prefix flags) <> "." <> camelCase
-- Discards the optional trailing _op_lib -- Discards the optional trailing _ops_op_lib
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName)) (fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
shortName = Text.pack (takeBaseName $ outputFile flags) shortName = Text.pack (takeBaseName $ outputFile flags)
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
renderOpAndExtras o = renderOp (parseOp o) </> extras o
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ filter (/= "ops")
$ Text.splitOn "_" s
-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper
-- | Lower-case the given name, and prevent it from overlapping with a reserved
-- Haskell name.
lowCase :: Text -> Text
lowCase = replaceReservedName . forceCase toLower
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
imports :: Doc imports :: Doc
imports = stack [ imports = stack [
@ -170,232 +128,208 @@ imports = stack [
, "import TensorFlow.Output (ResourceHandle)" , "import TensorFlow.Output (ResourceHandle)"
, "import TensorFlow.Tensor" , "import TensorFlow.Tensor"
, "import TensorFlow.Types" , "import TensorFlow.Types"
]
renderDef :: OpDef -> Doc
renderDef d =
stack [
haddocks
, n <+> "::" <+> hang 0 (typeSig d)
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation functionBody
, extras -- just for debug
] ]
renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc
renderHaskellName = strictText . unHaskellName . haskellName
renderTFName = strictText . unTFName . tfName
renderQuotedTFName = dquotes . renderTFName
-- | Generate the source code for a single op.
-- For example:
--
-- -- | {haddock comment}
-- foo :: {type sig}
-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
renderOp :: ParsedOp -> Doc
renderOp pOp = stack $
[ haddocks
, n <+> "::" <+> hang 0 (typeSig pOp)
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
<+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation (functionBody pOp)
] ++ whereClause listSizeAttrs
where where
n = strictText $ fixOpName (d ^. name) n = renderHaskellName $ parsedOpName pOp
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs listSizeAttrs = inferredListSizeAttrs pOp
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg] args = sep $ map renderHaskellName
fixOpName = lowCase $ map attrName (explicitInputAttrs pOp)
funcGuard = "eqLengthGuard" <+> brackets (commasep entries) ++ map parsedArgName (parsedInputs pOp)
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
-- | A check that all lists of the given size have the given length.
-- For example:
-- eqLengthGuard [("N", [("input1", length input1), ("input2", length input2)])]
funcGuard :: [Attr (NonEmpty Name)] -> Doc
funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
where where
entries = entries =
[ parens $ quotedText nAttr <> comma <+> [ parens $ nAttr <> comma <+>
brackets (commasep $ toList $ brackets (commasep $ toList $
NE.map renderTensorName tensorNames) map renderTensorName (toList $ attrInfo a))
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d | a <- attrs
, let nAttr = renderQuotedTFName (attrName a)
] ]
renderTensorName x = parens $ quotedText x <> comma <+> renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
"length" <+> strictText x "length" <+> renderHaskellName x
-- Uses hang 0 to align the argument vertically on multiple lines.
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts)) -- | Define the implicit list length attributes.
</> indent indentation (sep tensorArgs) -- For example:
-- where
-- n1 = fromIntegral (length input1) :: Int64
-- n2 = fromIntegral (length input2) :: Int64
whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = []
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
where
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
<+> "fromIntegral (length"
<+> renderHaskellName (NE.head $ attrInfo a)
<> ") :: Int64"
functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
where
buildFunction buildFunction
| null outputListsSizes = "buildOp" | null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes) | otherwise = "buildListOp" <+>
outputListsSizes = [ strictText numberAttrName brackets (commasep $
| o <- d ^. outputArg map renderHaskellName outputListsSizes)
, let numberAttrName = o ^. numberAttr outputListsSizes = [ a
, not (Text.null numberAttrName) && | ParsedArg { parsedArgCase = ListArg { argLength = a } }
numberAttrName `Map.member` mandatoryAttrMap d <- parsedOutputs pOp]
]
buildOpParts = buildOpParts =
"opDef" <+> quotedText (d ^. name) : "opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders tensor arguments. -- Renders tensor arguments.
[ "& opAttr" <+> quotedText tfName <+> [ "& opAttr" <+> renderQuotedTFName n <+>
".~ tensorType (undefined ::" <+> strictText hsName <> ")" ".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
| (tfName, (hsName, _)) <- Map.toList typeMap | a <- inferredTypeAttrs pOp, let n = attrName a
] ++ ] ++
-- Renders mandatory attributes as function parameters. -- Renders mandatory attributes as function parameters.
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| (tfName, hsName) <- mandatoryAttrs | a <- explicitInputAttrs pOp, let n = attrName a
] ++ ] ++
-- Renders sizes of tensor list types having number_attr. -- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+> [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)" | a <- inferredListSizeAttrs pOp, let n = attrName a
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
] ]
mandatoryAttrs = [(strictText tf, strictText hs)
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
]
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
extras = enclose "{-\n" "\n-}" $
strictText $ Text.pack $
showMessage ((def :: OpDef)
& inputArg .~ (d ^. inputArg)
& outputArg .~ (d ^. outputArg)
& attr .~ (d ^. attr))
typeMap = opDefTypeMap d
-- | Makes a quoted string doc out of the given text value. tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
quotedText :: Text.Text -> Doc
quotedText = dquotes . strictText
-- | typeSig renders the type signature of the given OpDef. -- | Write a comment with the inputs/outputs/attributes in proto format, for
typeSig :: OpDef -> Doc -- debugging.
typeSig d = extras :: OpDef -> Doc
foralls <+> constraints <+/> extras d = enclose "{-\n" "\n-}" $
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs]) strictText $ Text.pack $
showMessage ((def :: OpDef)
& inputArg .~ (d ^. inputArg)
& outputArg .~ (d ^. outputArg)
& attr .~ (d ^. attr))
-- | The type signature for an op.
-- Of the form:
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
-- "Tensor t2 v2" is an output.
typeSig :: ParsedOp -> Doc
typeSig pOp = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
++ map tensorArgAndComment (parsedInputs pOp)
++ [outputs])
where where
foralls | null typeMap = empty constraints
| otherwise = | null (inferredTypeAttrs pOp) = empty
"forall" | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
<+> sep (refVariableNames ++ typeMapTypeNames) typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
<+> "." ArgTensorEither v <- [parsedArgKind k]]
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap) ++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
constraints | null typeMap = empty classConstraints = tuple $ concatMap tensorArgConstraint
| otherwise = $ inferredTypeAttrs pOp
tuple (concatMap
(\(t, aDef) ->
"TensorType" <+> strictText t
: maybeToList (oneOfRestrictions aDef t))
(Map.elems typeMap)) <+> "=>"
refVariableNames = catMaybes (map fst tensorInputs)
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
tensorArg refType arg = wrapArg refType arg <**>
pure (<+> hang 0 ("-- ^" <+> argComment arg))
-- Argument type is a list of tensors if number_attr is set;
-- otherwise it's a single Tensor.
wrapArg refType arg =
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
where typ = tensorType refType arg
-- The result is (reference type variable if any, type representing the arg)
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
tensorType refType arg
-- Deals with resource handles that are unlike tensors and
-- have their own type level representation. The magic "dtype"
-- name is the name of the attribute specifying the type of
-- the resource handle. It is not referenced by resource input
-- or output because it isn't expressible in TF operation
-- signatures.
| (arg ^. type' == DT_RESOURCE) =
(Nothing, strictText "ResourceHandle dtype")
| otherwise =
(Just refType,
"Tensor" <+> refType <+> maybe directType strictText indirectType)
where
-- This is the case of a named parameter type which is
-- constrained as a OneOf.
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
-- The nominal case when the type name is given directly.
directType = dtTypeToDoc (arg ^. type')
outputs =
case d ^. outputArg of
[] -> "ControlNode"
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
os -> renderTupleResult os
wrappedOutput = snd . wrapArg "Value"
-- Tuple result case is rendered differently to give
-- individual elements their own comments.
renderTupleResult os =
stack $ [ tuple (map wrappedOutput os)
, flatten commentSummary
] ++ map commentDetails os
where
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
commentDetails o =
stack [ "--"
, "-- *" <+> argComment o
]
signatureFold = folddoc (\x y -> x </> "->" <+> y) signatureFold = folddoc (\x y -> x </> "->" <+> y)
mandatoryAttrInputs = [ attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
dtTypeToDoc dtType <+> renderAttrType (AttrSingle a) = renderAttrBaseType a
hang 0 ("-- ^" <+> argComment' tfName descr) renderAttrType (AttrList a) = brackets $ renderAttrBaseType a
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d] renderAttrBaseType = \case
typeMap = opDefTypeMap d AttrBytes -> "ByteString"
AttrInt64 -> "Data.Int.Int64"
AttrFloat -> "Float"
AttrBool -> "Bool"
AttrType -> "DataType"
AttrShape -> "TensorShapeProto"
AttrTensor -> "TensorProto"
-- | Returns the type restriction for the given tensor type if the tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
-- set of allowed types is not empty (i.e. restricted). outputs = case parsedOutputs pOp of
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc [] -> "ControlNode"
oneOfRestrictions aDef tName = do -- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
typs <- onAttrType (^. templateRestrictions) aDef [a] -> tensorArg a <+> "-- ^" <+> argComment a
guard $ not $ null typs as -> tuple (map tensorArg as) <+/> resultComment as
let typeList = commasep $ map strictText $
Set.toList $ Set.fromList $ -- | Render an op input or output.
map dtTypeToHaskell typs -- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of
SimpleArg { argType = t } -> tensorType t
ListArg { argType = t } -> brackets $ tensorType t
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
where
tensorType t = let
v = case parsedArgKind p of
ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v'
ArgResource -> "ResourceHandle"
a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n
in v <+> a
-- | Identifies the attributes used as tensor cardinalities. In such attrComment :: Attr a -> Doc
-- cases a list of tensors is supplied as an input_arg. The number of attrComment a = argComment' (attrName a) (attrDescription a)
-- such inputs is communicated as a separate opAttr.
-- The result key is TensorFlow attribute name and the value is the argComment :: ParsedArg -> Doc
-- tensor names which have number_attr set to the result key. argComment a = argComment' (parsedArgName a) (parsedArgDescription a)
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
(nAttr, replaceReservedName (inp ^. name) :| [])
| inp <- d ^. inputArg
, nAttr <- [inp ^. numberAttr]
, not (Text.null nAttr)
]
argComment :: OpDef'ArgDef -> Doc argComment' :: Name -> Text.Text -> Doc
argComment arg = argComment' (arg ^. name) (arg ^. description)
argComment' :: Text.Text -> Text.Text -> Doc
argComment' argName argDesc = argComment' argName argDesc =
bold argName <> splitMultilineText (":" <+>) argDesc bold (renderTFName argName) <> splitMultilineText (":" <+>) argDesc
bold :: Text.Text -> Doc bold :: Doc -> Doc
bold n = strictText ("__" <> n <> "__") bold n = "__" <> n <> "__"
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef) -- | Comment for the outputs of an op.
-- For example:
-- -- ^ (__output1__, __output2__)
-- --
-- -- * __output1__: description1
-- --
-- -- * __output2__: description2
resultComment :: [ParsedArg] -> Doc
resultComment os = stack $ flatten commentSummary : map commentDetails os
where
commentSummary = "-- ^" <+> tuple [bold (renderTFName $ parsedArgName o) | o <- os]
commentDetails o =
stack [ "--"
, "-- *" <+> argComment o
]
-- | Returns the map of type parameters from OpDef type name to -- | Constraints for a given type parameter.
-- (haskell friendly type name, the type's attribute definition). -- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
opDefTypeMap :: OpDef -> OpDefTypeMap tensorArgConstraint :: Attr [DataType] -> [Doc]
opDefTypeMap d = tensorArgConstraint a
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a] = ("TensorType" <+> n
: if null typeList
attrList :: OpDef -> [(Text.Text, AttrDef)] then []
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr] else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
where
isType :: AttrDef -> Bool n = renderHaskellName $ attrName a
isType = fromMaybe False . onAttrType (const True) typeList = map strictText $
Set.toList $ Set.fromList $
-- | Applies the given function to the data type. Is this a Prism? map dtTypeToHaskell $ attrInfo a
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
onAttrType f x = case x ^. attrTemplate of
AttrSingle (AttrType a) -> Just (f a)
_ -> Nothing
-- | mandatoryAttrMap contains the attributes chosen by
-- isMandatoryAttr, excluding those which are derived from list of
-- tensor arguments. The key is the TF name of the attribute. The
-- value tuple is (haskell name, TF type, attribute comment).
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
mandatoryAttrMap d =
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
| (n, a) <- attrList d
, Just dtType <- [isMandatoryAttr a]
-- Excludes the attributes rendered as list lengths.
, n `Map.notMember` numberAttrMap d
]
-- | Inspects the attribute and if it is one of the implemented
-- non-tensor values lacking default, then returns Just the TF type.
isMandatoryAttr :: AttrDef -> Maybe DataType
isMandatoryAttr x =
case x ^. attrTemplate of
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
_ -> Nothing
where
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
dtTypeToDoc :: DataType -> Doc
dtTypeToDoc = strictText . dtTypeToHaskell
-- NOTE: The cases of this function should be kept in sync with -- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes. -- TensorFlow.Types.AllTensorTypes.
@ -429,6 +363,12 @@ dtTypeToHaskell x =
haddockComment :: Text.Text -> Doc haddockComment :: Text.Text -> Doc
haddockComment = strictText haddockComment = strictText
-- | Generate a multiline comment. For example:
-- summary'
-- --
-- -- detail_line1
-- -- detail_line2
-- -- ...
multilineComment :: Text.Text -> Text.Text -> Doc multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail = multilineComment summary' detail =
haddockComment summary' </> haddockComment summary' </>
@ -445,45 +385,5 @@ splitMultilineText lead detail =
(l : ls) -> stack $ lead (haddockComment l) (l : ls) -> stack $ lead (haddockComment l)
: map (("--" <+>) . haddockComment) ls : map (("--" <+>) . haddockComment) ls
replaceReservedName :: Text -> Text
replaceReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
indentation :: Int indentation :: Int
indentation = 4 indentation = 4
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
-- Haskell2010 keywords:
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
-- We don't include keywords that are allowed to be variable names,
-- in particular: "as", "forall", and "hiding".
[ "case"
, "class"
, "data"
, "default"
, "deriving"
, "do"
, "else"
, "foreign"
, "if"
, "import"
, "in"
, "infix"
, "infixl"
, "infixr"
, "instance"
, "let"
, "module"
, "newtype"
, "of"
, "then"
, "type"
, "where"
]
++ -- Nonstandard extensions
[ "mdo" -- RecursiveDo
, "rec" -- Arrows, RecursiveDo
, "proc" -- Arrows
]

View file

@ -1,120 +0,0 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
-- | Wrapping of TensorFlow attributes into Haskell entities.
module TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
) where
import Data.Int (Int64)
import Data.Monoid ((<>))
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
import qualified Data.ByteString as B
import qualified Data.Text as Text
-- | Specifies the optional default value and a set of allowed values
-- for the given type.
data Template a = Template {
_templateDefault :: Maybe a
-- ^ The default value (mandatory if unspecified)
, _templateRestrictions :: [a]
-- ^ The allowed set of values, empty if no restrictions
}
templateDefault :: Lens' (Template a) (Maybe a)
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
templateRestrictions :: Lens' (Template a) [a]
templateRestrictions = lens _templateRestrictions
(\g x -> g { _templateRestrictions = x })
data UnusedTensor
data AttrCase f
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
| AttrFloat (f Float) -- float f = 4; // "float"
| AttrBool (f Bool) -- bool b = 5; // "bool"
| AttrType (f DataType) -- type = 6; // "type"
-- To be translated into TensorFlow.Types.Shape before use.
-- Leaving as a proto to reduce dependencies.
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
-- | Type-reified representation of TensorFlow AttrDef.
-- Initially limited to just the types in Op descriptors.
data AttrTemplate
= AttrSingle (AttrCase Template)
| AttrList (AttrCase [])
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
data AttrDef = AttrDef {
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
}
attrTemplate :: Lens' AttrDef AttrTemplate
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
attrOriginal :: Lens' AttrDef OpDef'AttrDef
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
attrDef :: OpDef'AttrDef -> AttrDef
attrDef a = AttrDef a
$ translate (a^.OpDef.type')
(a^.OpDef.defaultValue)
(a^.allowedValues)
-- | Converts the given AttrValue with the type given by the string
-- into the AttrVal if the type is known.
translate :: Text.Text -- ^ one of the TensorFlow type strings
-> AttrValue -- ^ default value
-> AttrValue -- ^ allowed values
-> AttrTemplate
translate t defaults allowed
| t == "string" = makeVal AttrBytes maybe's s
| t == "int" = makeVal AttrInt64 maybe'i i
| t == "float" = makeVal AttrFloat maybe'f f
| t == "bool" = makeVal AttrBool maybe'b b
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
| t == "shape" = makeVal AttrShape maybe'shape shape
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
| t == "list(string)" = makeList AttrBytes $ list.s
| t == "list(int)" = makeList AttrInt64 $ list.i
| t == "list(float)" = makeList AttrFloat $ list.f
| t == "list(bool)" = makeList AttrBool $ list.b
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
| t == "list(shape)" = makeList AttrShape $ list.shape
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
| t == "func" = AttrTensor $ error "func is unimplemented"
| otherwise = error $ show ("Unknown attribute type " <> t) ++
"," ++ show defaults ++
"," ++ show allowed
where makeVal c x y = AttrSingle $ c $
Template (defaults^.x) (allowed^.list.y)
makeList c y = AttrList $ c $ defaults^.y

View file

@ -0,0 +1,299 @@
-- | This module helps parse the proto OpDef into a Haskell type which is more
-- descriptive of how the attributes and arguments will be used in the
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
( ParsedOp(..)
, Name(..)
, HaskellName(..)
, TFName(..)
, Attr(..)
, AttrType(..)
, AttrBaseType(..)
, ParsedArg(..)
, ParsedArgCase(..)
, ArgType(..)
, ArgKind(..)
, parseOp
, camelCase
) where
import Data.Char (toUpper, toLower)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Maybe (mapMaybe)
import Data.Monoid ((<>))
import Data.Ord (comparing)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Proto.Tensorflow.Core.Framework.AttrValue (list)
import Proto.Tensorflow.Core.Framework.OpDef
( OpDef
, OpDef'ArgDef
, OpDef'AttrDef
, allowedValues
, attr
, maybe'defaultValue
, description
, name
, inputArg
, outputArg
, summary
, typeListAttr
, numberAttr
, typeAttr
, type'
)
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))
data ParsedOp = ParsedOp
{ parsedOpName :: Name
, parsedOpSummary :: Text
, parsedOpDescription :: Text
, parsedInputs :: [ParsedArg]
, parsedOutputs :: [ParsedArg]
, explicitInputAttrs :: [Attr AttrType]
-- ^ Attributes that must be set explicitly when creating the op.
-- Associated with the type of the attribute.
, inferredTypeAttrs :: [Attr [DataType]]
-- ^ Attributes that are type parameters.
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
-- If this list is empty, then any type is acceptable.
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
-- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors.
-- Associated with the list of tensors whose size it describes.
}
data Name = Name
{ haskellName :: HaskellName
, tfName :: TFName
}
-- | A raw name as specified in the OpDef proto.
newtype TFName = TFName { unTFName :: Text }
deriving (Eq, Ord)
-- | A name that's appropriate for a variable in a Haskell source file.
newtype HaskellName = HaskellName { unHaskellName :: Text }
-- | A named attribute, associated with some information about it.
data Attr a = Attr
{ attrName :: Name
, attrDescription :: Text
, attrInfo :: a
}
-- | The type of an attribute.
data AttrType = AttrSingle AttrBaseType
| AttrList AttrBaseType
deriving Eq
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor
deriving Eq
-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
{ parsedArgName :: Name
, parsedArgDescription :: Text
, parsedArgCase :: ParsedArgCase
, parsedArgKind :: ArgKind
}
data ParsedArgCase
= SimpleArg { argType :: ArgType }
| ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType
}
| MixedListArg { argTypeAttr :: Name }
-- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this.
-- | The type of an argument.
data ArgType
= ArgTypeFixed DataType -- ^ A fixed type.
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
-- The kind of an op input or output (not including the argument type `a`).
data ArgKind
= ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a
makeName :: Text -> Name
makeName n = Name
{ haskellName = HaskellName $ fixReservedName $ lowCase n
, tfName = TFName n
}
-- | Change a name so it doesn't conflict with any Haskell keywords.
fixReservedName :: Text -> Text
fixReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
-- Haskell2010 keywords:
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
-- We don't include keywords that are allowed to be variable names,
-- in particular: "as", "forall", and "hiding".
[ "case"
, "class"
, "data"
, "default"
, "deriving"
, "do"
, "else"
, "foreign"
, "if"
, "import"
, "in"
, "infix"
, "infixl"
, "infixr"
, "instance"
, "let"
, "module"
, "newtype"
, "of"
, "then"
, "type"
, "where"
]
++ -- Nonstandard extensions
[ "mdo" -- RecursiveDo
, "rec" -- Arrows, RecursiveDo
, "proc" -- Arrows
]
-- | Lower-case the given text.
lowCase :: Text -> Text
lowCase = forceCase toLower
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ Text.splitOn "_" s
-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper
parseOp :: OpDef -> ParsedOp
parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, ..
}
where
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ o ^. attr
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
$ o ^. attr
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
| a ^. type' == DT_RESOURCE = ArgResource
| otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
| a ^. type' == DT_RESOURCE = ArgResource
| otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
| otherwise = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
| otherwise = Nothing
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
| a ^. type' == "int"
= nonEmpty [t | ParsedArg { parsedArgName = t
, parsedArgCase
= ListArg { argLength = n }
} <- inputs
, TFName (a ^. name) == tfName n]
| otherwise = Nothing
-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
x <- f a
Just Attr
{ attrName = makeName (a ^. name)
, attrDescription = a ^. description
, attrInfo = x
}
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name)
, parsedArgDescription = a ^. description
, parsedArgCase = parseArgCase a
, parsedArgKind = tKind
}
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
parseArgCase a
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
| otherwise = SimpleArg thisArgType
where
thisArgType
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
| otherwise = ArgTypeFixed (a ^. type')
maybeAttr :: Text -> Maybe Name
maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t
parseAttrType :: Text -> AttrType
parseAttrType = \case
"string" -> AttrSingle AttrBytes
"int" -> AttrSingle AttrInt64
"float" -> AttrSingle AttrFloat
"bool" -> AttrSingle AttrBool
"type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor
"list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat
"list(bool)" -> AttrList AttrBool
"list(type)" -> AttrList AttrType
"list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor
t -> error $ "parseAttrType: unrecognized type " ++ show t

View file

@ -13,7 +13,7 @@ cabal-version: >=1.22
library library
hs-source-dirs: src hs-source-dirs: src
exposed-modules: TensorFlow.OpGen.AttrVal exposed-modules: TensorFlow.OpGen.ParsedOp
, TensorFlow.OpGen , TensorFlow.OpGen
build-depends: proto-lens == 0.1.* build-depends: proto-lens == 0.1.*
, tensorflow-proto == 0.1.* , tensorflow-proto == 0.1.*

View file

@ -16,6 +16,7 @@
module Main where module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest) import Google.Test (googleTest)
import Test.Framework (Test) import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
@ -24,6 +25,8 @@ import qualified Data.Vector as V
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test split and concat are inverses. -- | Test split and concat are inverses.
@ -34,11 +37,18 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
restored = CoreOps.concat dim splitList restored = CoreOps.concat dim splitList
dim = 1 -- dimension to split along (with size of 3 in original) dim = 1 -- dimension to split along (with size of 3 in original)
liftIO $ 3 @=? length splitList liftIO $ 3 @=? length splitList
(x, y, z) <- (x, y, z) <- TF.run (original, restored, splitList !! 1)
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
liftIO $ x @=? (y :: V.Vector Float) liftIO $ x @=? (y :: V.Vector Float)
liftIO $ V.fromList [1, 4] @=? z liftIO $ V.fromList [1, 4] @=? z
testShapeN :: Test
testShapeN = testCase "testShapeN" $ TF.runSession $ do
let shapes = map TF.Shape [[1],[2,3]]
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
result <- TF.run $ CoreOps.shapeN tensors
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
main :: IO () main :: IO ()
main = googleTest [ testSplit main = googleTest [ testSplit
, testShapeN
] ]