-- | This module helps parse the proto OpDef into a Haskell type which is more
-- descriptive of how the attributes and arguments will be used in the
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
    ( ParsedOp(..)
    , Name(..)
    , HaskellName(..)
    , TFName(..)
    , Attr(..)
    , AttrType(..)
    , AttrBaseType(..)
    , ParsedArg(..)
    , ParsedArgCase(..)
    , ArgType(..)
    , ArgKind(..)
    , parseOp
    , camelCase
    ) where

import Data.Char (toUpper, toLower)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Maybe (mapMaybe)
import Data.Monoid ((<>))
import Data.Ord (comparing)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Proto.Tensorflow.Core.Framework.AttrValue (list)
import Proto.Tensorflow.Core.Framework.OpDef
    ( OpDef
    , OpDef'ArgDef
    , OpDef'AttrDef
    , allowedValues
    , attr
    , maybe'defaultValue
    , description
    , name
    , inputArg
    , outputArg
    , summary
    , typeListAttr
    , numberAttr
    , typeAttr
    , type'
    )
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))

-- | An op definition, reorganized around how its attributes and arguments
-- are used by the code generator.
data ParsedOp = ParsedOp
    { parsedOpName :: Name
    , parsedOpSummary :: Text
    , parsedOpDescription :: Text
    , parsedInputs :: [ParsedArg]
    , parsedOutputs :: [ParsedArg]
    , explicitInputAttrs :: [Attr AttrType]
    -- ^ Attributes that must be set explicitly when creating the op.
    -- Associated with the type of the attribute.
    , inferredTypeAttrs :: [Attr [DataType]]
    -- ^ Attributes that are type parameters.
    -- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
    -- If this list is empty, then any type is acceptable.
    , inferredListSizeAttrs :: [Attr (NonEmpty Name)]
    -- ^ Attributes which are list sizes (ints) that are inferred automatically
    -- from one or more of the input tensors.
    -- Associated with the list of tensors whose size it describes.
    }

-- | A name in two spellings: the one used in generated Haskell code and the
-- raw one from the OpDef proto.
data Name = Name
    { haskellName :: HaskellName
    , tfName :: TFName
    }

-- | A raw name as specified in the OpDef proto.
newtype TFName = TFName { unTFName :: Text }
    deriving (Eq, Ord)

-- | A name that's appropriate for a variable in a Haskell source file.
newtype HaskellName = HaskellName { unHaskellName :: Text }

-- | A named attribute, associated with some information about it.
data Attr a = Attr
    { attrName :: Name
    , attrDescription :: Text
    , attrInfo :: a
    }

-- | The type of an attribute.
data AttrType = AttrSingle AttrBaseType
              | AttrList AttrBaseType
              deriving Eq

-- | The element type of an attribute, shared by single and list attributes.
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
                  | AttrType | AttrShape | AttrTensor
                  deriving Eq

-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
    { parsedArgName :: Name
    , parsedArgDescription :: Text
    , parsedArgCase :: ParsedArgCase
    , parsedArgKind :: ArgKind
    }

-- | Whether an argument is a single tensor or a list of tensors.
data ParsedArgCase
    = SimpleArg { argType :: ArgType }
    | ListArg
        { argLength :: Name  -- ^ The attribute that specifies this list's length.
        , argType :: ArgType
        }
    | MixedListArg { argTypeAttr :: Name }
        -- ^ A heterogeneous list.
        -- TODO(judahjacobson): Implement this.

-- | The type of an argument.
data ArgType
    = ArgTypeFixed DataType -- ^ A fixed type.
    | ArgTypeAttr Name  -- ^ A type that depends on an attribute.

-- | The kind of an op input or output (not including the argument type @a@).
data ArgKind
    = ArgTensorRef  -- ^ @Tensor Ref a@
    | ArgTensorValue  -- ^ @Tensor Value a@
    | ArgTensorEither Text  -- ^ @Tensor v a@; the 'Text' is the variable @v@.
    | ArgResource  -- ^ @Resource a@

-- | Build a 'Name' from a raw OpDef name. The Haskell spelling has its first
-- character lower-cased, and a prime appended if it would otherwise clash
-- with a reserved keyword.
makeName :: Text -> Name
makeName n = Name
    { haskellName = HaskellName $ fixReservedName $ lowCase n
    , tfName = TFName n
    }

-- | Change a name so it doesn't conflict with any Haskell keywords.
--
-- A reserved word gets a trailing prime (for example @type@ becomes
-- @type'@). Any other name is returned unchanged.
fixReservedName :: Text -> Text
fixReservedName n =
    if Set.member n reservedKeywords then n <> "'" else n

-- | Words that cannot be used as Haskell variable names.
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList (haskell2010 ++ extensionKeywords)
  where
    haskell2010, extensionKeywords :: [Text]
    -- Haskell2010 keywords:
    -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
    -- Keywords that are still allowed as variable names ("as", "forall",
    -- and "hiding") are deliberately left out.
    haskell2010 =
        [ "case", "class", "data", "default", "deriving", "do", "else"
        , "foreign", "if", "import", "in", "infix", "infixl", "infixr"
        , "instance", "let", "module", "newtype", "of", "then", "type"
        , "where"
        ]
    -- Keywords introduced by nonstandard extensions.
    extensionKeywords =
        [ "mdo"   -- RecursiveDo
        , "rec"   -- Arrows, RecursiveDo
        , "proc"  -- Arrows
        ]

-- | Lower-case the first character of the given text.
lowCase :: Text -> Text
lowCase = forceCase toLower

-- | Apply a character conversion to the first character only; the remainder
-- is left as-is, and empty text stays empty.
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = case Text.uncons s of
    Nothing -> ""
    Just (firstChar, rest) -> Text.cons (convert firstChar) rest

-- | Turn an underscore-separated name into CamelCase: each segment between
-- underscores has its first character upper-cased, and the underscores are
-- dropped (e.g. @foo_bar@ becomes @FooBar@).
camelCase :: Text -> Text
camelCase s = Text.concat [upCase part | part <- Text.splitOn "_" s]

-- | Upper-case the first character of the given text.
upCase :: Text -> Text
upCase = forceCase toUpper

-- | Convert an 'OpDef' proto into a 'ParsedOp'.
--
-- Each input gets its own tensor-kind variable (@v1@, @v2@, ...). Attributes
-- that can be inferred from the inputs (type parameters and list sizes) are
-- excluded from the explicit attributes, which are sorted by their raw name.
parseOp :: OpDef -> ParsedOp
parseOp op = ParsedOp
    { parsedOpName = makeName (op ^. name)
    , parsedOpSummary = op ^. summary
    , parsedOpDescription = op ^. description
    , parsedInputs = inputs
    , parsedOutputs = outputs
    , explicitInputAttrs = explicitAttrs
    , inferredTypeAttrs = typeAttrs
    , inferredListSizeAttrs = listSizeAttrs
    }
  where
    attrDefs = op ^. attr
    kindVars = map (\i -> "v" <> Text.pack (show (i :: Integer))) [1 ..]
    inputs =
        [ parseArg arg (inputTensorKind arg v)
        | (arg, v) <- zip (op ^. inputArg) kindVars
        ]
    outputs = [parseArg arg (outputTensorKind arg) | arg <- op ^. outputArg]
    typeAttrs = mapMaybeAttrs getInferredTypeAttr attrDefs
    listSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr inputs) attrDefs
    implicitNames = Set.fromList $
        map (tfName . attrName) typeAttrs
            ++ map (tfName . attrName) listSizeAttrs
    explicitAttrs =
        sortBy (comparing (tfName . attrName))
            (mapMaybeAttrs (getExplicitInputAttr implicitNames) attrDefs)

-- TODO(judahjacobson): Some arguments should be refs.
-- | The kind of an op input. Resource inputs become 'ArgResource'; any other
-- input is polymorphic in its tensor kind, using the given variable name.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
    | a ^. type' == DT_RESOURCE = ArgResource
    | otherwise = ArgTensorEither v

-- | The kind of an op output. Resource outputs become 'ArgResource'; any
-- other output is an 'ArgTensorValue'.
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
    | a ^. type' == DT_RESOURCE = ArgResource
    | otherwise = ArgTensorValue

-- | Decide whether an attribute must be passed explicitly when creating the
-- op, returning its type if so.
--
-- An attribute is explicit when all of the following hold:
--
-- * its name is not in the given set of implicit (inferred) attributes;
-- * it has no default value;
-- * it is a scalar @bool@, @int@ or @float@.
--
-- The raw type string is checked before calling 'parseAttrType'. That
-- function is partial and calls 'error' on types it does not list (such as
-- @func@). Such attributes can never be explicit anyway, so they now yield
-- 'Nothing' instead of aborting code generation.
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr implicitAttrs a
    | TFName (a ^. name) `Set.notMember` implicitAttrs
    , a ^. maybe'defaultValue == Nothing
    , a ^. type' `elem` ["bool", "int", "float"]
    = Just (parseAttrType (a ^. type'))
    | otherwise = Nothing

-- | An attribute of type @type@ is a type parameter. Returns its allowed
-- types; an empty list means any type is acceptable.
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a
    | a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
    | otherwise = Nothing

-- | An @int@ attribute is an inferred list size when it is the length of one
-- or more 'ListArg' inputs. Returns the names of those inputs, or 'Nothing'
-- if there are none (or the attribute is not an @int@).
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
    | a ^. type' == "int"
        = nonEmpty [t | ParsedArg { parsedArgName = t
                                  , parsedArgCase = ListArg { argLength = n }
                                  } <- inputs
                      , TFName (a ^. name) == tfName n]
    | otherwise = Nothing

-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
    x <- f a
    Just Attr
        { attrName = makeName (a ^. name)
        , attrDescription = a ^. description
        , attrInfo = x
        }

-- | Convert an input or output 'OpDef'ArgDef' into a 'ParsedArg' with the
-- given kind.
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
    { parsedArgName = makeName (a ^. name)
    , parsedArgDescription = a ^. description
    , parsedArgCase = parseArgCase a
    , parsedArgKind = tKind
    }

-- | Classify an argument by its fields:
--
-- * a type-list attribute makes it a 'MixedListArg';
-- * a number attribute makes it a homogeneous 'ListArg';
-- * otherwise it is a 'SimpleArg'.
--
-- The element type comes from the type attribute when one is set.
-- Otherwise resource arguments are typed by an attribute named @dtype@
-- (NOTE(review): assumes resource ops carry a @dtype@ attr; confirm).
-- Any other argument has a fixed type.
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
parseArgCase a
    | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
    | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
    | otherwise = SimpleArg thisArgType
  where
    thisArgType
        | Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
        | a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
        | otherwise = ArgTypeFixed (a ^. type')

-- | Turn an attribute name into a 'Name', treating empty text as absent.
maybeAttr :: Text -> Maybe Name
maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t

-- | Parse an OpDef attribute type string. Calls 'error' on any type not
-- listed here, so callers should only pass strings known to be handled.
parseAttrType :: Text -> AttrType
parseAttrType = \case
    "string" -> AttrSingle AttrBytes
    "int" -> AttrSingle AttrInt64
    "float" -> AttrSingle AttrFloat
    "bool" -> AttrSingle AttrBool
    "type" -> AttrSingle AttrType
    "shape" -> AttrSingle AttrShape
    "tensor" -> AttrSingle AttrTensor
    "list(string)" -> AttrList AttrBytes
    "list(int)" -> AttrList AttrInt64
    "list(float)" -> AttrList AttrFloat
    "list(bool)" -> AttrList AttrBool
    "list(type)" -> AttrList AttrType
    "list(shape)" -> AttrList AttrShape
    "list(tensor)" -> AttrList AttrTensor
    t -> error $ "parseAttrType: unrecognized type " ++ show t