From a277c7ddb36ab8879dec557ddbfa671a4b314ee9 Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Sun, 20 Nov 2016 10:00:22 -0800 Subject: [PATCH] Refactor OpGen. (#36) Also fixes op lists when the same attribute specifies the length of both an input and an output. I added a test of "shapeN" which previously failed with the following error: ERROR: Ran out of counts in toResult. Likely misuse of buildListOp. --- tensorflow-opgen/src/TensorFlow/OpGen.hs | 480 +++++++----------- .../src/TensorFlow/OpGen/AttrVal.hs | 120 ----- .../src/TensorFlow/OpGen/ParsedOp.hs | 299 +++++++++++ tensorflow-opgen/tensorflow-opgen.cabal | 2 +- tensorflow-ops/tests/ArrayOpsTest.hs | 14 +- 5 files changed, 502 insertions(+), 413 deletions(-) delete mode 100644 tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs create mode 100644 tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index 446e049..a825ec7 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -13,6 +13,7 @@ -- limitations under the License. {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeFamilies #-} -- | Rendering of TensorFlow operations as Haskell functions. @@ -23,46 +24,25 @@ module TensorFlow.OpGen , flagParser) where -import Prelude hiding (head, tail) - -import Control.Applicative ((<**>)) -import Control.Monad (guard) -import Data.Char (toLower, toUpper) import Data.Foldable (toList) -import Data.Maybe (catMaybes, fromMaybe, maybeToList) +import Data.Maybe (fromMaybe) import Data.ProtoLens (def, showMessage) -import Data.List.NonEmpty (NonEmpty((:|)), head) +import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as NE import Lens.Family2 ((^.), (.~), (&), view) import Options.Applicative (Parser, help, long, strOption, value) import Proto.Tensorflow.Core.Framework.OpDef ( OpList , OpDef - , OpDef'ArgDef , attr - , description , inputArg , name - , numberAttr , op , outputArg - , summary - , type' - , typeAttr ) import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import System.FilePath (takeBaseName) -import TensorFlow.OpGen.AttrVal - (AttrDef - , AttrCase(..) - , AttrTemplate(..) - , Template - , attrDef - , attrOriginal - , attrTemplate - , templateDefault - , templateRestrictions - ) +import TensorFlow.OpGen.ParsedOp import Text.PrettyPrint.Mainland ( Doc , (<>) @@ -79,18 +59,14 @@ import Text.PrettyPrint.Mainland , folddoc , hang , indent - , int , parens , sep , stack , strictText , tuple ) -import qualified Data.Map.Strict as Map import qualified Data.Set as Set import qualified Data.Text as Text -import qualified Data.Semigroup as Semigroup -import Data.Text (Text) data OpGenFlags = OpGenFlags { outputFile :: String @@ -118,7 +94,6 @@ docOpList flags opList = , "{-# LANGUAGE DataKinds #-}" , "{-# LANGUAGE FlexibleInstances #-}" , "{-# LANGUAGE OverloadedStrings #-}" - , "{-# LANGUAGE RankNTypes #-}" , "{-# LANGUAGE ScopedTypeVariables #-}" -- Avoids reports about shadowing standard library names. , "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}" @@ -129,34 +104,17 @@ docOpList flags opList = , imports , empty , folddoc (\x y -> x empty y) - (map renderDef $ + (map renderOpAndExtras $ filter (not . flip elem exclusions . view name) $ toList $ opList ^. op) ] where moduleName = Text.pack (prefix flags) <> "." <> camelCase - -- Discards the optional trailing _op_lib - (fromMaybe shortName (Text.stripSuffix "_op_lib" shortName)) + -- Discards the optional trailing _ops_op_lib + (fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName)) shortName = Text.pack (takeBaseName $ outputFile flags) exclusions = Text.splitOn "," $ Text.pack $ excludeList flags - -camelCase :: Text -> Text -camelCase s = Text.concat $ map upCase - $ filter (/= "ops") - $ Text.splitOn "_" s - --- | Upper-case the given text. -upCase :: Text -> Text -upCase = forceCase toUpper - --- | Lower-case the given name, and prevent it from overlapping with a reserved --- Haskell name. -lowCase :: Text -> Text -lowCase = replaceReservedName . forceCase toLower - -forceCase :: (Char -> Char) -> Text -> Text -forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs) - (Text.uncons s) + renderOpAndExtras o = renderOp (parseOp o) extras o imports :: Doc imports = stack [ @@ -170,232 +128,208 @@ imports = stack [ , "import TensorFlow.Output (ResourceHandle)" , "import TensorFlow.Tensor" , "import TensorFlow.Types" - ] - -renderDef :: OpDef -> Doc -renderDef d = - stack [ - haddocks - , n <+> "::" <+> hang 0 (typeSig d) - , n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" -- args are indented - -- the body needs to be indented wrt the name - indent indentation functionBody - , extras -- just for debug ] + +renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc +renderHaskellName = strictText . unHaskellName . haskellName +renderTFName = strictText . unTFName . tfName +renderQuotedTFName = dquotes . renderTFName + + +-- | Generate the source code for a single op. +-- For example: +-- +-- -- | {haddock comment} +-- foo :: {type sig} +-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body} +renderOp :: ParsedOp -> Doc +renderOp pOp = stack $ + [ haddocks + , n <+> "::" <+> hang 0 (typeSig pOp) + , n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs + <+> "=" -- args are indented + -- the body needs to be indented wrt the name + indent indentation (functionBody pOp) + ] ++ whereClause listSizeAttrs where - n = strictText $ fixOpName (d ^. name) - args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs - tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg] - fixOpName = lowCase - funcGuard = "eqLengthGuard" <+> brackets (commasep entries) + n = renderHaskellName $ parsedOpName pOp + listSizeAttrs = inferredListSizeAttrs pOp + args = sep $ map renderHaskellName + $ map attrName (explicitInputAttrs pOp) + ++ map parsedArgName (parsedInputs pOp) + haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp) + +-- | A check that all lists of the given size have the given length. +-- For example: +-- eqLengthGuard [("N", [("input1", length input1), ("input2", length input2)])] +funcGuard :: [Attr (NonEmpty Name)] -> Doc +funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries) where entries = - [ parens $ quotedText nAttr <> comma <+> + [ parens $ nAttr <> comma <+> brackets (commasep $ toList $ - NE.map renderTensorName tensorNames) - | (nAttr, tensorNames) <- Map.toList $ numberAttrMap d + map renderTensorName (toList $ attrInfo a)) + | a <- attrs + , let nAttr = renderQuotedTFName (attrName a) ] - renderTensorName x = parens $ quotedText x <> comma <+> - "length" <+> strictText x - -- Uses hang 0 to align the argument vertically on multiple lines. - functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts)) - indent indentation (sep tensorArgs) + renderTensorName x = parens $ renderQuotedTFName x <> comma <+> + "length" <+> renderHaskellName x + +-- | Define the implicit list length attributes. +-- For example: +-- where +-- n1 = fromIntegral (length input1) :: Int64 +-- n2 = fromIntegral (length input2) :: Int64 +whereClause :: [Attr (NonEmpty Name)] -> [Doc] +whereClause [] = [] +whereClause as = [indent 2 $ "where" indent 2 (stack $ map defineLengthAttr as)] + where + defineLengthAttr a = renderHaskellName (attrName a) <+> "=" + <+> "fromIntegral (length" + <+> renderHaskellName (NE.head $ attrInfo a) + <> ") :: Int64" + +functionBody :: ParsedOp -> Doc +functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts)) + indent indentation (sep tensorArgs) + where buildFunction | null outputListsSizes = "buildOp" - | otherwise = "buildListOp" <+> brackets (commasep outputListsSizes) - outputListsSizes = [ strictText numberAttrName - | o <- d ^. outputArg - , let numberAttrName = o ^. numberAttr - , not (Text.null numberAttrName) && - numberAttrName `Map.member` mandatoryAttrMap d - ] + | otherwise = "buildListOp" <+> + brackets (commasep $ + map renderHaskellName outputListsSizes) + outputListsSizes = [ a + | ParsedArg { parsedArgCase = ListArg { argLength = a } } + <- parsedOutputs pOp] buildOpParts = - "opDef" <+> quotedText (d ^. name) : + "opDef" <+> renderQuotedTFName (parsedOpName pOp) : -- Renders tensor arguments. - [ "& opAttr" <+> quotedText tfName <+> - ".~ tensorType (undefined ::" <+> strictText hsName <> ")" - | (tfName, (hsName, _)) <- Map.toList typeMap + [ "& opAttr" <+> renderQuotedTFName n <+> + ".~ tensorType (undefined ::" <+> renderHaskellName n <> ")" + | a <- inferredTypeAttrs pOp, let n = attrName a ] ++ -- Renders mandatory attributes as function parameters. - [ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName - | (tfName, hsName) <- mandatoryAttrs + [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n + | a <- explicitInputAttrs pOp, let n = attrName a ] ++ -- Renders sizes of tensor list types having number_attr. - [ "& opAttr" <+> quotedText nAttr <+> ".~" <+> - "(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)" - | (nAttr, tensorNames) <- Map.toList $ numberAttrMap d + [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n + | a <- inferredListSizeAttrs pOp, let n = attrName a ] - mandatoryAttrs = [(strictText tf, strictText hs) - | (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d) - ] - haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description) - extras = enclose "{-\n" "\n-}" $ - strictText $ Text.pack $ - showMessage ((def :: OpDef) - & inputArg .~ (d ^. inputArg) - & outputArg .~ (d ^. outputArg) - & attr .~ (d ^. attr)) - typeMap = opDefTypeMap d --- | Makes a quoted string doc out of the given text value. -quotedText :: Text.Text -> Doc -quotedText = dquotes . strictText + tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp --- | typeSig renders the type signature of the given OpDef. -typeSig :: OpDef -> Doc -typeSig d = - foralls <+> constraints <+/> - signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs]) +-- | Write a comment with the inputs/outputs/attributes in proto format, for +-- debugging. +extras :: OpDef -> Doc +extras d = enclose "{-\n" "\n-}" $ + strictText $ Text.pack $ + showMessage ((def :: OpDef) + & inputArg .~ (d ^. inputArg) + & outputArg .~ (d ^. outputArg) + & attr .~ (d ^. attr)) + +-- | The type signature for an op. +-- Of the form: +-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2) +-- => Float -> Tensor t1 v1 -> Tensor t2 v2 +-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and +-- "Tensor t2 v2" is an output. +typeSig :: ParsedOp -> Doc +typeSig pOp = constraints + <+/> signatureFold (map attrInput (explicitInputAttrs pOp) + ++ map tensorArgAndComment (parsedInputs pOp) + ++ [outputs]) where - foralls | null typeMap = empty - | otherwise = - "forall" - <+> sep (refVariableNames ++ typeMapTypeNames) - <+> "." - typeMapTypeNames = map (strictText . fst) (Map.elems typeMap) - constraints | null typeMap = empty - | otherwise = - tuple (concatMap - (\(t, aDef) -> - "TensorType" <+> strictText t - : maybeToList (oneOfRestrictions aDef t)) - (Map.elems typeMap)) <+> "=>" - refVariableNames = catMaybes (map fst tensorInputs) - tensorInputs = zipWith tensorArg refTypes (d ^. inputArg) - refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]] - tensorArg refType arg = wrapArg refType arg <**> - pure (<+> hang 0 ("-- ^" <+> argComment arg)) - -- Argument type is a list of tensors if number_attr is set; - -- otherwise it's a single Tensor. - wrapArg refType arg = - if Text.null (arg ^. numberAttr) then typ else brackets <$> typ - where typ = tensorType refType arg - -- The result is (reference type variable if any, type representing the arg) - tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc) - tensorType refType arg - -- Deals with resource handles that are unlike tensors and - -- have their own type level representation. The magic "dtype" - -- name is the name of the attribute specifying the type of - -- the resource handle. It is not referenced by resource input - -- or output because it isn't expressible in TF operation - -- signatures. - | (arg ^. type' == DT_RESOURCE) = - (Nothing, strictText "ResourceHandle dtype") - | otherwise = - (Just refType, - "Tensor" <+> refType <+> maybe directType strictText indirectType) - where - -- This is the case of a named parameter type which is - -- constrained as a OneOf. - indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap) - -- The nominal case when the type name is given directly. - directType = dtTypeToDoc (arg ^. type') - outputs = - case d ^. outputArg of - [] -> "ControlNode" - [o] -> wrappedOutput o <+> "-- ^" <+> argComment o - os -> renderTupleResult os - wrappedOutput = snd . wrapArg "Value" - -- Tuple result case is rendered differently to give - -- individual elements their own comments. - renderTupleResult os = - stack $ [ tuple (map wrappedOutput os) - , flatten commentSummary - ] ++ map commentDetails os - where - commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os] - commentDetails o = - stack [ "--" - , "-- *" <+> argComment o - ] + constraints + | null (inferredTypeAttrs pOp) = empty + | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" + typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, + ArgTensorEither v <- [parsedArgKind k]] + ++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp] + classConstraints = tuple $ concatMap tensorArgConstraint + $ inferredTypeAttrs pOp signatureFold = folddoc (\x y -> x "->" <+> y) - mandatoryAttrInputs = [ - dtTypeToDoc dtType <+> - hang 0 ("-- ^" <+> argComment' tfName descr) - | (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d] - typeMap = opDefTypeMap d + attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a) + renderAttrType (AttrSingle a) = renderAttrBaseType a + renderAttrType (AttrList a) = brackets $ renderAttrBaseType a + renderAttrBaseType = \case + AttrBytes -> "ByteString" + AttrInt64 -> "Data.Int.Int64" + AttrFloat -> "Float" + AttrBool -> "Bool" + AttrType -> "DataType" + AttrShape -> "TensorShapeProto" + AttrTensor -> "TensorProto" --- | Returns the type restriction for the given tensor type if the --- set of allowed types is not empty (i.e. restricted). -oneOfRestrictions :: AttrDef -> Text -> Maybe Doc -oneOfRestrictions aDef tName = do - typs <- onAttrType (^. templateRestrictions) aDef - guard $ not $ null typs - let typeList = commasep $ map strictText $ - Set.toList $ Set.fromList $ - map dtTypeToHaskell typs - return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName + tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t) + outputs = case parsedOutputs pOp of + [] -> "ControlNode" + -- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a` + [a] -> tensorArg a <+> "-- ^" <+> argComment a + as -> tuple (map tensorArg as) <+/> resultComment as + +-- | Render an op input or output. +-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype" +tensorArg :: ParsedArg -> Doc +tensorArg p = case parsedArgCase p of + SimpleArg { argType = t } -> tensorType t + ListArg { argType = t } -> brackets $ tensorType t + MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}" + where + tensorType t = let + v = case parsedArgKind p of + ArgTensorRef -> "Tensor Ref" + ArgTensorValue -> "Tensor Value" + ArgTensorEither v' -> "Tensor" <+> strictText v' + ArgResource -> "ResourceHandle" + a = case t of + ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt + ArgTypeAttr n -> renderHaskellName n + in v <+> a --- | Identifies the attributes used as tensor cardinalities. In such --- cases a list of tensors is supplied as an input_arg. The number of --- such inputs is communicated as a separate opAttr. --- The result key is TensorFlow attribute name and the value is the --- tensor names which have number_attr set to the result key. -numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text) -numberAttrMap d = Map.fromListWith (Semigroup.<>) [ - (nAttr, replaceReservedName (inp ^. name) :| []) - | inp <- d ^. inputArg - , nAttr <- [inp ^. numberAttr] - , not (Text.null nAttr) - ] +attrComment :: Attr a -> Doc +attrComment a = argComment' (attrName a) (attrDescription a) + +argComment :: ParsedArg -> Doc +argComment a = argComment' (parsedArgName a) (parsedArgDescription a) -argComment :: OpDef'ArgDef -> Doc -argComment arg = argComment' (arg ^. name) (arg ^. description) - -argComment' :: Text.Text -> Text.Text -> Doc +argComment' :: Name -> Text.Text -> Doc argComment' argName argDesc = - bold argName <> splitMultilineText (":" <+>) argDesc + bold (renderTFName argName) <> splitMultilineText (":" <+>) argDesc -bold :: Text.Text -> Doc -bold n = strictText ("__" <> n <> "__") +bold :: Doc -> Doc +bold n = "__" <> n <> "__" -type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef) +-- | Comment for the outputs of an op. +-- For example: +-- -- ^ (__output1__, __output2__) +-- -- +-- -- * __output1__: description1 +-- -- +-- -- * __output2__: description2 +resultComment :: [ParsedArg] -> Doc +resultComment os = stack $ flatten commentSummary : map commentDetails os + where + commentSummary = "-- ^" <+> tuple [bold (renderTFName $ parsedArgName o) | o <- os] + commentDetails o = + stack [ "--" + , "-- *" <+> argComment o + ] --- | Returns the map of type parameters from OpDef type name to --- (haskell friendly type name, the type's attribute definition). -opDefTypeMap :: OpDef -> OpDefTypeMap -opDefTypeMap d = - Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a] - -attrList :: OpDef -> [(Text.Text, AttrDef)] -attrList d = [(a ^. name, attrDef a) | a <- d ^. attr] - -isType :: AttrDef -> Bool -isType = fromMaybe False . onAttrType (const True) - --- | Applies the given function to the data type. Is this a Prism? -onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a -onAttrType f x = case x ^. attrTemplate of - AttrSingle (AttrType a) -> Just (f a) - _ -> Nothing - --- | mandatoryAttrMap contains the attributes chosen by --- isMandatoryAttr, excluding those which are derived from list of --- tensor arguments. The key is the TF name of the attribute. The --- value tuple is (haskell name, TF type, attribute comment). -mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text) -mandatoryAttrMap d = - Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description)) - | (n, a) <- attrList d - , Just dtType <- [isMandatoryAttr a] - -- Excludes the attributes rendered as list lengths. - , n `Map.notMember` numberAttrMap d - ] - --- | Inspects the attribute and if it is one of the implemented --- non-tensor values lacking default, then returns Just the TF type. -isMandatoryAttr :: AttrDef -> Maybe DataType -isMandatoryAttr x = - case x ^. attrTemplate of - AttrSingle (AttrBool y) -> noDefault DT_BOOL y - AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y - AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y - _ -> Nothing - where - noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault) - -dtTypeToDoc :: DataType -> Doc -dtTypeToDoc = strictText . dtTypeToHaskell +-- | Constraints for a given type parameter. +-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"] +tensorArgConstraint :: Attr [DataType] -> [Doc] +tensorArgConstraint a + = ("TensorType" <+> n + : if null typeList + then [] + else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n]) + where + n = renderHaskellName $ attrName a + typeList = map strictText $ + Set.toList $ Set.fromList $ + map dtTypeToHaskell $ attrInfo a -- NOTE: The cases of this function should be kept in sync with -- TensorFlow.Types.AllTensorTypes. @@ -429,6 +363,12 @@ dtTypeToHaskell x = haddockComment :: Text.Text -> Doc haddockComment = strictText +-- | Generate a multiline comment. For example: +-- summary' +-- -- +-- -- detail_line1 +-- -- detail_line2 +-- -- ... multilineComment :: Text.Text -> Text.Text -> Doc multilineComment summary' detail = haddockComment summary' @@ -445,45 +385,5 @@ splitMultilineText lead detail = (l : ls) -> stack $ lead (haddockComment l) : map (("--" <+>) . haddockComment) ls -replaceReservedName :: Text -> Text -replaceReservedName n - | n `Set.member` reservedKeywords = n <> "'" - | otherwise = n - indentation :: Int indentation = 4 - -reservedKeywords :: Set.Set Text -reservedKeywords = Set.fromList $ - -- Haskell2010 keywords: - -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4 - -- We don't include keywords that are allowed to be variable names, - -- in particular: "as", "forall", and "hiding". - [ "case" - , "class" - , "data" - , "default" - , "deriving" - , "do" - , "else" - , "foreign" - , "if" - , "import" - , "in" - , "infix" - , "infixl" - , "infixr" - , "instance" - , "let" - , "module" - , "newtype" - , "of" - , "then" - , "type" - , "where" - ] - ++ -- Nonstandard extensions - [ "mdo" -- RecursiveDo - , "rec" -- Arrows, RecursiveDo - , "proc" -- Arrows - ] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs b/tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs deleted file mode 100644 index 997a908..0000000 --- a/tensorflow-opgen/src/TensorFlow/OpGen/AttrVal.hs +++ /dev/null @@ -1,120 +0,0 @@ --- Copyright 2016 TensorFlow authors. --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. - -{-# LANGUAGE OverloadedStrings #-} - --- | Wrapping of TensorFlow attributes into Haskell entities. -module TensorFlow.OpGen.AttrVal - (AttrDef - , AttrCase(..) - , AttrTemplate(..) - , Template - , attrDef - , attrOriginal - , attrTemplate - , templateDefault - , templateRestrictions - ) where - -import Data.Int (Int64) -import Data.Monoid ((<>)) -import Lens.Family2 (Lens', (^.)) -import Lens.Family2.Unchecked (lens) -import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue -import Proto.Tensorflow.Core.Framework.OpDef as OpDef -import Proto.Tensorflow.Core.Framework.Types (DataType(..)) -import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto) -import qualified Data.ByteString as B -import qualified Data.Text as Text - --- | Specifies the optional default value and a set of allowed values --- for the given type. -data Template a = Template { - _templateDefault :: Maybe a - -- ^ The default value (mandatory if unspecified) - , _templateRestrictions :: [a] - -- ^ The allowed set of values, empty if no restrictions - } - -templateDefault :: Lens' (Template a) (Maybe a) -templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x }) - -templateRestrictions :: Lens' (Template a) [a] -templateRestrictions = lens _templateRestrictions - (\g x -> g { _templateRestrictions = x }) - -data UnusedTensor - -data AttrCase f - = AttrBytes (f B.ByteString) -- bytes s = 2; // "string" - | AttrInt64 (f Int64) -- int64 i = 3; // "int" - | AttrFloat (f Float) -- float f = 4; // "float" - | AttrBool (f Bool) -- bool b = 5; // "bool" - | AttrType (f DataType) -- type = 6; // "type" - -- To be translated into TensorFlow.Types.Shape before use. - -- Leaving as a proto to reduce dependencies. - | AttrShape (f TensorShapeProto) -- shape = 7; // "shape" - --- | Type-reified representation of TensorFlow AttrDef. --- Initially limited to just the types in Op descriptors. -data AttrTemplate - = AttrSingle (AttrCase Template) - | AttrList (AttrCase []) - | AttrTensor UnusedTensor -- tensor = 8; // "tensor" - -data AttrDef = AttrDef { - _attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from - , _attrTemplate :: AttrTemplate -- ^ the type of the attribute - } - -attrTemplate :: Lens' AttrDef AttrTemplate -attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x }) - -attrOriginal :: Lens' AttrDef OpDef'AttrDef -attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x }) - -attrDef :: OpDef'AttrDef -> AttrDef -attrDef a = AttrDef a - $ translate (a^.OpDef.type') - (a^.OpDef.defaultValue) - (a^.allowedValues) - --- | Converts the given AttrValue with the type given by the string --- into the AttrVal if the type is known. -translate :: Text.Text -- ^ one of the TensorFlow type strings - -> AttrValue -- ^ default value - -> AttrValue -- ^ allowed values - -> AttrTemplate -translate t defaults allowed - | t == "string" = makeVal AttrBytes maybe's s - | t == "int" = makeVal AttrInt64 maybe'i i - | t == "float" = makeVal AttrFloat maybe'f f - | t == "bool" = makeVal AttrBool maybe'b b - | t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type' - | t == "shape" = makeVal AttrShape maybe'shape shape - | t == "tensor" = AttrTensor $ error "tensor is unimplemented" - | t == "list(string)" = makeList AttrBytes $ list.s - | t == "list(int)" = makeList AttrInt64 $ list.i - | t == "list(float)" = makeList AttrFloat $ list.f - | t == "list(bool)" = makeList AttrBool $ list.b - | t == "list(type)" = makeList AttrType $ list.AttrValue.type' - | t == "list(shape)" = makeList AttrShape $ list.shape - | t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented" - | t == "func" = AttrTensor $ error "func is unimplemented" - | otherwise = error $ show ("Unknown attribute type " <> t) ++ - "," ++ show defaults ++ - "," ++ show allowed - where makeVal c x y = AttrSingle $ c $ - Template (defaults^.x) (allowed^.list.y) - makeList c y = AttrList $ c $ defaults^.y diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs new file mode 100644 index 0000000..1d69fea --- /dev/null +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -0,0 +1,299 @@ +-- | This module helps parse the proto OpDef into a Haskell type which is more +-- descriptive of how the attributes and arguments will be used in the +-- generated code. +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +module TensorFlow.OpGen.ParsedOp + ( ParsedOp(..) + , Name(..) + , HaskellName(..) + , TFName(..) + , Attr(..) + , AttrType(..) + , AttrBaseType(..) + , ParsedArg(..) + , ParsedArgCase(..) + , ArgType(..) + , ArgKind(..) + , parseOp + , camelCase + ) where + +import Data.Char (toUpper, toLower) +import Data.List (sortBy) +import Data.List.NonEmpty (NonEmpty, nonEmpty) +import Data.Maybe (mapMaybe) +import Data.Monoid ((<>)) +import Data.Ord (comparing) +import qualified Data.Set as Set +import Data.Text (Text) +import qualified Data.Text as Text +import Lens.Family2 ((^.)) +import Proto.Tensorflow.Core.Framework.AttrValue (list) +import Proto.Tensorflow.Core.Framework.OpDef + ( OpDef + , OpDef'ArgDef + , OpDef'AttrDef + , allowedValues + , attr + , maybe'defaultValue + , description + , name + , inputArg + , outputArg + , summary + , typeListAttr + , numberAttr + , typeAttr + , type' + ) +import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE)) + +data ParsedOp = ParsedOp + { parsedOpName :: Name + , parsedOpSummary :: Text + , parsedOpDescription :: Text + , parsedInputs :: [ParsedArg] + , parsedOutputs :: [ParsedArg] + , explicitInputAttrs :: [Attr AttrType] + -- ^ Attributes that must be set explicitly when creating the op. + -- Associated with the type of the attribute. + , inferredTypeAttrs :: [Attr [DataType]] + -- ^ Attributes that are type parameters. + -- Associated with the list of allowed types (see: TensorFlow.Types.OneOf). + -- If this list is empty, then any type is acceptable. + , inferredListSizeAttrs :: [Attr (NonEmpty Name)] + -- Attributes which are list sizes (ints) that are inferred automatically + -- from one or more of the input tensors. + -- Associated with the list of tensors whose size it describes. + } + +data Name = Name + { haskellName :: HaskellName + , tfName :: TFName + } + +-- | A raw name as specified in the OpDef proto. +newtype TFName = TFName { unTFName :: Text } + deriving (Eq, Ord) + +-- | A name that's appropriate for a variable in a Haskell source file. +newtype HaskellName = HaskellName { unHaskellName :: Text } + +-- | A named attribute, associated with some information about it. +data Attr a = Attr + { attrName :: Name + , attrDescription :: Text + , attrInfo :: a + } + +-- | The type of an attribute. +data AttrType = AttrSingle AttrBaseType + | AttrList AttrBaseType + deriving Eq + +data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool + | AttrType | AttrShape | AttrTensor + deriving Eq + +-- | An input or output argument (Tensor) for an op. +data ParsedArg = ParsedArg + { parsedArgName :: Name + , parsedArgDescription :: Text + , parsedArgCase :: ParsedArgCase + , parsedArgKind :: ArgKind + } + +data ParsedArgCase + = SimpleArg { argType :: ArgType } + | ListArg + { argLength :: Name -- ^ The attribute that specifies this list's length. + , argType :: ArgType + } + | MixedListArg { argTypeAttr :: Name } + -- ^ A heterogeneous list. + -- TODO(judahjacobson): Implement this. + +-- | The type of an argument. +data ArgType + = ArgTypeFixed DataType -- ^ A fixed type. + | ArgTypeAttr Name -- ^ A type that depends on an attribute. + +-- The kind of an op input or output (not including the argument type `a`). +data ArgKind + = ArgTensorRef -- Tensor Ref a + | ArgTensorValue -- Tensor Value a + | ArgTensorEither Text -- Tensor v a; the Text is the variable `v` + | ArgResource -- Resource a + + +makeName :: Text -> Name +makeName n = Name + { haskellName = HaskellName $ fixReservedName $ lowCase n + , tfName = TFName n + } + +-- | Change a name so it doesn't conflict with any Haskell keywords. +fixReservedName :: Text -> Text +fixReservedName n + | n `Set.member` reservedKeywords = n <> "'" + | otherwise = n + +reservedKeywords :: Set.Set Text +reservedKeywords = Set.fromList $ + -- Haskell2010 keywords: + -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4 + -- We don't include keywords that are allowed to be variable names, + -- in particular: "as", "forall", and "hiding". + [ "case" + , "class" + , "data" + , "default" + , "deriving" + , "do" + , "else" + , "foreign" + , "if" + , "import" + , "in" + , "infix" + , "infixl" + , "infixr" + , "instance" + , "let" + , "module" + , "newtype" + , "of" + , "then" + , "type" + , "where" + ] + ++ -- Nonstandard extensions + [ "mdo" -- RecursiveDo + , "rec" -- Arrows, RecursiveDo + , "proc" -- Arrows + ] + +-- | Lower-case the given text. +lowCase :: Text -> Text +lowCase = forceCase toLower + +forceCase :: (Char -> Char) -> Text -> Text +forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs) + (Text.uncons s) + +camelCase :: Text -> Text +camelCase s = Text.concat $ map upCase + $ Text.splitOn "_" s + +-- | Upper-case the given text. +upCase :: Text -> Text +upCase = forceCase toUpper + + +parseOp :: OpDef -> ParsedOp +parseOp o = ParsedOp + { parsedOpName = makeName $ o ^. name + , parsedOpSummary = o ^. summary + , parsedOpDescription = o ^. description + , .. + } + where + parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v)) + (o ^. inputArg) tensorKindParams + tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] + parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) + explicitInputAttrs = sortBy (comparing (tfName . attrName)) + $ mapMaybeAttrs (getExplicitInputAttr implicitAttrs) + $ o ^. attr + inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr + inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) + $ o ^. attr + implicitAttrs = Set.fromList $ map tfName $ + map attrName inferredTypeAttrs + ++ map attrName inferredListSizeAttrs + +-- TODO(judahjacobson): Some arguments should be refs. +inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind +inputTensorKind a v + | a ^. type' == DT_RESOURCE = ArgResource + | otherwise = ArgTensorEither v + +outputTensorKind :: OpDef'ArgDef -> ArgKind +outputTensorKind a + | a ^. type' == DT_RESOURCE = ArgResource + | otherwise = ArgTensorValue + +getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType +getExplicitInputAttr implicitAttrs a + | TFName (a ^. name) `Set.notMember` implicitAttrs + , a ^. maybe'defaultValue == Nothing + , t <- parseAttrType (a ^. type') + , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t + | otherwise = Nothing + +getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] +getInferredTypeAttr a + | a ^. type' == "type" = Just $ a ^. allowedValues . list . type' + | otherwise = Nothing + +getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name) +getInferredListSizeAttr inputs a + | a ^. type' == "int" + = nonEmpty [t | ParsedArg { parsedArgName = t + , parsedArgCase + = ListArg { argLength = n } + } <- inputs + , TFName (a ^. name) == tfName n] + | otherwise = Nothing + +-- | Like mapMaybe, but associates the attribute name/description with the given info. +mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a] +mapMaybeAttrs f = mapMaybe $ \a -> do + x <- f a + Just Attr + { attrName = makeName (a ^. name) + , attrDescription = a ^. description + , attrInfo = x + } + +parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg +parseArg a tKind = ParsedArg + { parsedArgName = makeName (a ^. name) + , parsedArgDescription = a ^. description + , parsedArgCase = parseArgCase a + , parsedArgKind = tKind + } + +parseArgCase :: OpDef'ArgDef -> ParsedArgCase +parseArgCase a + | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n + | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType + | otherwise = SimpleArg thisArgType + where + thisArgType + | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n + | a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype") + | otherwise = ArgTypeFixed (a ^. type') + maybeAttr :: Text -> Maybe Name + maybeAttr "" = Nothing + maybeAttr t = Just $ makeName t + +parseAttrType :: Text -> AttrType +parseAttrType = \case + "string" -> AttrSingle AttrBytes + "int" -> AttrSingle AttrInt64 + "float" -> AttrSingle AttrFloat + "bool" -> AttrSingle AttrBool + "type" -> AttrSingle AttrType + "shape" -> AttrSingle AttrShape + "tensor" -> AttrSingle AttrTensor + "list(string)" -> AttrList AttrBytes + "list(int)" -> AttrList AttrInt64 + "list(float)" -> AttrList AttrFloat + "list(bool)" -> AttrList AttrBool + "list(type)" -> AttrList AttrType + "list(shape)" -> AttrList AttrShape + "list(tensor)" -> AttrList AttrTensor + t -> error $ "parseAttrType: unrecognized type " ++ show t diff --git a/tensorflow-opgen/tensorflow-opgen.cabal b/tensorflow-opgen/tensorflow-opgen.cabal index 4028799..39289bd 100644 --- a/tensorflow-opgen/tensorflow-opgen.cabal +++ b/tensorflow-opgen/tensorflow-opgen.cabal @@ -13,7 +13,7 @@ cabal-version: >=1.22 library hs-source-dirs: src - exposed-modules: TensorFlow.OpGen.AttrVal + exposed-modules: TensorFlow.OpGen.ParsedOp , TensorFlow.OpGen build-depends: proto-lens == 0.1.* , tensorflow-proto == 0.1.* diff --git a/tensorflow-ops/tests/ArrayOpsTest.hs b/tensorflow-ops/tests/ArrayOpsTest.hs index 0801e11..1b32b03 100644 --- a/tensorflow-ops/tests/ArrayOpsTest.hs +++ b/tensorflow-ops/tests/ArrayOpsTest.hs @@ -16,6 +16,7 @@ module Main where import Control.Monad.IO.Class (liftIO) +import Data.Int (Int64) import Google.Test (googleTest) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) @@ -24,6 +25,8 @@ import qualified Data.Vector as V import qualified TensorFlow.Ops as TF import qualified TensorFlow.Session as TF +import qualified TensorFlow.Tensor as TF +import qualified TensorFlow.Types as TF import qualified TensorFlow.GenOps.Core as CoreOps -- | Test split and concat are inverses. @@ -34,11 +37,18 @@ testSplit = testCase "testSplit" $ TF.runSession $ do restored = CoreOps.concat dim splitList dim = 1 -- dimension to split along (with size of 3 in original) liftIO $ 3 @=? length splitList - (x, y, z) <- - TF.buildAnd TF.run $ return (original, restored, splitList !! 1) + (x, y, z) <- TF.run (original, restored, splitList !! 1) liftIO $ x @=? (y :: V.Vector Float) liftIO $ V.fromList [1, 4] @=? z +testShapeN :: Test +testShapeN = testCase "testShapeN" $ TF.runSession $ do + let shapes = map TF.Shape [[1],[2,3]] + let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float] + result <- TF.run $ CoreOps.shapeN tensors + liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64]) + main :: IO () main = googleTest [ testSplit + , testShapeN ]