{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module TensorFlow.OpGen
( OpGenFlags(..)
, docOpList
, flagParser)
where
import Data.Foldable (toList)
import Data.Maybe (fromMaybe)
import Data.ProtoLens (def, showMessage)
import Data.List (sortOn)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef
( OpList
, OpDef
, attr
, inputArg
, name
, op
, outputArg
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName)
import TensorFlow.OpGen.ParsedOp
import Text.PrettyPrint.Mainland
( Doc
, (<>)
, (<+>)
, (</>)
, (<+/>)
, brackets
, comma
, commasep
, dquotes
, empty
, enclose
, flatten
, folddoc
, hang
, indent
, parens
, sep
, stack
, strictText
, tuple
)
import qualified Data.Set as Set
import qualified Data.Text as Text
-- | Settings for the op generator, filled in by 'flagParser'.
data OpGenFlags = OpGenFlags
    { outputFile :: String
      -- ^ Value of @--output@. Its base name is also used by 'docOpList'
      -- to derive the name of the generated module.
    , prefix :: String
      -- ^ Value of @--prefix@; prepended (followed by a dot) to the
      -- generated module name.
    , excludeList :: String
      -- ^ Value of @--exclude_list@: op names separated by commas.
      -- Defaults to the empty string.
    }
-- | Command-line parser producing an 'OpGenFlags'.
--
-- @--output@ and @--prefix@ have no default; @--exclude_list@ falls back
-- to the empty string.
flagParser :: Parser OpGenFlags
flagParser = OpGenFlags <$> outputOpt <*> prefixOpt <*> excludeOpt
  where
    outputOpt = strOption $ mconcat
        [ long "output"
        , help "File to write."
        ]
    prefixOpt = strOption $ mconcat
        [ long "prefix"
        , help "Haskell package prefix to use"
        ]
    excludeOpt = strOption $ mconcat
        [ long "exclude_list"
        , value ""
        , help "Comma separated Ops names to ignore"
        ]
-- | Render a complete Haskell module containing bindings for the ops in
-- the given 'OpList'.
--
-- The module consists of a fixed pragma header, a @module ... where@ line,
-- the fixed import list ('imports') and then, for each op, the output of
-- 'renderOp' followed by the 'extras' comment block. Ops are emitted
-- sorted by name, separated by blank lines.
--
-- The module name is the @prefix@ flag, a dot, and the camel-cased base
-- name of the output file with a trailing @_ops_op_lib@ removed if present.
--
-- Ops whose name appears in the comma separated @excludeList@ flag are
-- skipped. Whitespace around each listed name is ignored, so
-- @"Foo, Bar"@ excludes both @Foo@ and @Bar@.
docOpList :: OpGenFlags -> OpList -> Doc
docOpList flags opList =
    stack [ "{-# LANGUAGE ConstraintKinds #-}"
          , "{-# LANGUAGE DataKinds #-}"
          , "{-# LANGUAGE FlexibleContexts #-}"
          , "{-# LANGUAGE FlexibleInstances #-}"
          , "{-# LANGUAGE OverloadedStrings #-}"
          , "{-# LANGUAGE ScopedTypeVariables #-}"
          , "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
          , "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}"
          , "module" <+> strictText moduleName <+> "where"
          , empty
          , imports
          , empty
          , folddoc (\x y -> x </> empty </> y) renderedOps
          ]
  where
    moduleName =
        Text.pack (prefix flags) <> "." <> camelCase
            (fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
    shortName = Text.pack (takeBaseName $ outputFile flags)
    -- Names are trimmed so that a list written as "A, B" still matches B;
    -- a Set keeps the per-op membership test cheap.
    exclusions = Set.fromList
               $ map Text.strip
               $ Text.splitOn ","
               $ Text.pack
               $ excludeList flags
    isExcluded o = (o ^. name) `Set.member` exclusions
    renderedOps = map renderOpAndExtras
                $ sortOn (view name)
                $ filter (not . isExcluded)
                $ toList
                $ opList ^. op
    renderOpAndExtras o = renderOp (parseOp o) </> extras o
-- | The fixed import section written at the top of every generated module,
-- one import statement per line.
imports :: Doc
imports = stack importLines
  where
    importLines :: [Doc]
    importLines =
        [ "import Data.ByteString (ByteString)"
        , "import Data.Complex (Complex)"
        , "import Data.Int (Int8, Int16, Int32, Int64)"
        , "import Data.Proxy (Proxy(Proxy))"
        , "import Data.Word (Word8, Word16)"
        , "import Lens.Family2 ((.~), (&))"
        , "import TensorFlow.Build"
        , "import TensorFlow.BuildOp"
        , "import TensorFlow.Tensor"
        , "import TensorFlow.Types"
        ]
-- | Renderers for the two spellings carried by a 'Name':
--
-- * 'renderHaskellName' emits the Haskell-side identifier,
-- * 'renderTFName' emits the TensorFlow-side name,
-- * 'renderQuotedTFName' emits the TensorFlow-side name in double quotes.
renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc
renderHaskellName nm = strictText (unHaskellName (haskellName nm))
renderTFName nm = strictText (unTFName (tfName nm))
renderQuotedTFName nm = dquotes (renderTFName nm)
-- | Render the generated Haskell declarations for a single op.
--
-- The output is, in order: a haddock comment built from the op summary and
-- description; on compilers older than GHC 8.0 a NOINLINE pragma; the
-- signature and definition of the plain function (defined as the primed
-- function applied to @id@); the signature of the primed function, which
-- takes an extra @OpParams ->@ argument; the primed definition guarded by
-- 'funcGuard'; and finally the optional 'whereClause' that defines the
-- inferred list-size attributes.
renderOp :: ParsedOp -> Doc
renderOp pOp = stack (declarations ++ whereClause listSizeAttrs)
  where
    declarations =
        [ haddocks
#if __GLASGOW_HASKELL__ < 800
        , "{-# NOINLINE" <+> plainName <+> "#-}"
#endif
        , plainName <+> "::" <+> hang 0 (typeSig empty pOp)
        , plainName <+> "=" <+> plainName <> "' id"
        , primedName <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
        , primedName <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
              <+> "=" </>
          indent indentation (functionBody pOp)
        ]
    plainName = renderHaskellName (parsedOpName pOp)
    primedName = plainName <> "'"
    listSizeAttrs = inferredListSizeAttrs pOp
    -- Explicit attributes come first, then the tensor inputs.
    argNames = map attrName (explicitInputAttrs pOp)
            ++ map parsedArgName (parsedInputs pOp)
    args = sep ("op'options" : map renderHaskellName argNames)
    haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp)
                                           (parsedOpDescription pOp)
-- | Render the guard expression of a generated op function: a call to
-- @eqLengthGuard@ with one entry per list-size attribute.
--
-- Each entry pairs the quoted TensorFlow attribute name with a list of
-- @(quoted tensor name, length haskellName)@ tuples, one for every tensor
-- argument recorded for that attribute.
funcGuard :: [Attr (NonEmpty Name)] -> Doc
funcGuard attrs = "eqLengthGuard" <+> brackets (commasep (map guardEntry attrs))
  where
    guardEntry a =
        parens $ renderQuotedTFName (attrName a) <> comma <+>
                 brackets (commasep (map tensorLength (toList (attrInfo a))))
    tensorLength t = parens $ renderQuotedTFName t <> comma <+>
                              "length" <+> renderHaskellName t
-- | Render the @where@ clause that defines each inferred list-size
-- attribute as the length (converted to @Int64@) of the first tensor list
-- associated with it.
--
-- Returns no lines at all when there are no such attributes; otherwise a
-- single indented document.
whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause attrs
    | null attrs = []
    | otherwise = [indent 2 $ "where" </> indent 2 (stack lengthDefs)]
  where
    lengthDefs =
        [ renderHaskellAttrName a <+> "="
            <+> "fromIntegral (length"
            <+> renderHaskellName (NE.head $ attrInfo a)
            <> ") :: Int64"
        | a <- attrs
        ]
-- | Render the Haskell-side identifier of an attribute.
renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName a = renderHaskellName (attrName a)
-- | Render the right-hand side of a generated op function.
--
-- Monadic ops are emitted as a @build $ do@ block ending in @buildOp@;
-- all other ops are emitted as @pureOp ... $ do@ ending in @return@.
-- In both cases the block first binds @op'inputs@ to the concatenated
-- results of @buildInputs@ on every tensor input, and then builds an
-- @opDef@ expression that sets, in order: the inferred type attributes,
-- the explicit input attributes, the inferred list-size attributes, and
-- finally applies @op'options@ and sets @opInputs@.
functionBody :: ParsedOp -> Doc
functionBody pOp =
    if parsedOpIsMonadic pOp
        then "build $ do"
                </> indent indentation (bindOpInputsVar
                        </> "buildOp" <+> outputListsSizes <+> opDef)
        else "pureOp" <+> outputListsSizes <+> "$ do"
                </> indent indentation (bindOpInputsVar </> "return" <+> opDef)
  where
    -- Length attributes of every list-valued output, in output order.
    listLengths =
        [ len
        | ParsedArg { parsedArgCase = ListArg { argLength = len } }
              <- parsedOutputs pOp
        ]
    outputListsSizes = brackets (commasep (map renderHaskellName listLengths))
    opInputsVar = "op'inputs"
    bindOpInputsVar =
        opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence"
            <+> brackets (commasep $ map ("buildInputs" <+>) tensorArgs)
    opDef = parens $ hang 0 $ stack $
        "opDef" <+> renderQuotedTFName (parsedOpName pOp) :
        concat
            [ map typeAttrSetter (inferredTypeAttrs pOp)
            , map (valueAttrSetter . attrName) (explicitInputAttrs pOp)
            , map (valueAttrSetter . attrName) (inferredListSizeAttrs pOp)
            , ["& op'options & opInputs .~" <+> opInputsVar]
            ]
    -- One "& opAttr <name> .~ <value>" line of the opDef expression.
    setAttr n rhs = "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> rhs
    typeAttrSetter a = setAttr (attrName a) (inferredTypeExpr a)
    valueAttrSetter n = setAttr n (renderHaskellName n)
    tensorArgs = map (renderHaskellName . parsedArgName) (parsedInputs pOp)
    inferredTypeExpr a =
        if typeParamIsList (attrInfo a)
            then "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
                    <> ")"
            else "tensorType (undefined ::" <+> renderHaskellAttrName a
                    <> ")"
-- | Render a block comment containing the text form of the op definition,
-- restricted to its input arguments, output arguments and attributes
-- (all other fields are left at their defaults).
extras :: OpDef -> Doc
extras d = enclose "{-\n" "\n-}" (strictText (Text.pack (showMessage trimmed)))
  where
    trimmed = (def :: OpDef)
        & inputArg .~ (d ^. inputArg)
        & outputArg .~ (d ^. outputArg)
        & attr .~ (d ^. attr)
-- | Render the type signature of a generated op function (everything after
-- the @::@).
--
-- The first argument is extra text inserted between the constraints and
-- the first parameter (for example @"OpParams ->"@, or 'empty' for none).
-- Parameters are the explicit input attributes, then the tensor inputs,
-- each followed by a @-- ^@ comment, and finally the result type.
--
-- A @forall ... . (...) =>@ prefix is emitted only when there is at least
-- one class constraint.
typeSig :: Doc -> ParsedOp -> Doc
typeSig pre pOp = constraints <+/> pre </> signatureFold signatureParts
  where
    signatureParts = concat
        [ map attrInput (explicitInputAttrs pOp)
        , map tensorArgAndComment (parsedInputs pOp)
        , [outputs]
        ]
    constraints =
        if null classConstraints
            then empty
            else "forall" <+> sep typeParams <+> "."
                     <+> tuple classConstraints <+> "=>"
    -- Type variables: tensor-kind variables of inputs and outputs, then
    -- the inferred type attributes, then the monad variable if needed.
    typeParams = tensorKindVars
              ++ map renderHaskellAttrName (inferredTypeAttrs pOp)
              ++ ["m'" | parsedOpIsMonadic pOp]
    tensorKindVars =
        [ strictText v
        | arg <- parsedInputs pOp ++ parsedOutputs pOp
        , ArgSomeTensor v <- [argKind $ parsedArgCase arg]
        ]
    monadConstraint = ["MonadBuild m'" | parsedOpIsMonadic pOp]
    classConstraints = monadConstraint
                    ++ map tensorArgConstraint (inferredTypeAttrs pOp)
    signatureFold = folddoc (\lhs rhs -> lhs </> "->" <+> rhs)
    attrInput a =
        renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
    renderAttrType ty = case ty of
        AttrSingle base -> renderAttrBaseType base
        AttrList base -> brackets (renderAttrBaseType base)
    renderAttrBaseType base = case base of
        AttrBytes -> "ByteString"
        AttrInt64 -> "Data.Int.Int64"
        AttrFloat -> "Float"
        AttrBool -> "Bool"
        AttrType -> "DataType"
        AttrShape -> "Shape"
        AttrTensor -> "TensorProto"
    tensorArgAndComment t =
        tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
    -- No outputs yields a ControlNode, one output is rendered inline with
    -- its comment, several outputs become a tuple with a result comment.
    outputs = case parsedOutputs pOp of
        [] -> wrapOutput "ControlNode"
        [single] -> wrapOutput (tensorArg single) <+> "-- ^" <+> argComment single
        many -> wrapOutput (tuple (map tensorArg many)) <+/> resultComment many
    wrapOutput o =
        if parsedOpIsMonadic pOp
            then "m'" <+> parens o
            else o
-- | Render the Haskell type of one tensor argument.
--
-- A simple argument becomes @Tensor kind type@, a list argument the same
-- type in brackets, and a mixed list argument
-- @TensorList (kind) typeAttr@.
tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of
    SimpleArg { argType = t, argKind = k } -> tensorOf t k
    ListArg { argType = t, argKind = k } -> brackets (tensorOf t k)
    MixedListArg { argTypeAttr = t, argKind = k }
        -> "TensorList" <+> parens (kindDoc k) <+> renderHaskellName t
  where
    kindDoc ArgTensorRef = "Ref"
    kindDoc ArgTensorValue = "Value"
    kindDoc ArgTensorBuild = "Build"
    kindDoc (ArgSomeTensor v) = strictText v
    elemType (ArgTypeFixed dt) = strictText (dtTypeToHaskell dt)
    elemType (ArgTypeAttr n) = renderHaskellName n
    tensorOf t k = "Tensor" <+> kindDoc k <+> elemType t
-- | Comment text for an attribute: its name and description, formatted by
-- @argComment'@.
attrComment :: Attr a -> Doc
attrComment a = argComment' nm desc
  where
    nm = attrName a
    desc = attrDescription a
-- | Comment text for a tensor argument: its name and description,
-- formatted by @argComment'@.
argComment :: ParsedArg -> Doc
argComment a = argComment' nm desc
  where
    nm = parsedArgName a
    desc = parsedArgDescription a
-- | Format a name/description pair for use in a haddock comment: the
-- TensorFlow name in bold, directly followed by the description. The
-- first description line is prefixed with a colon and later lines are
-- continued as comment lines (see 'splitMultilineText'). An empty
-- description adds nothing after the name.
argComment' :: Name -> Text.Text -> Doc
argComment' argName argDesc = boldName <> description
  where
    boldName = bold (renderTFName argName)
    description = splitMultilineText (\firstLine -> ":" <+> firstLine) argDesc
-- | Surround a document with double underscores (haddock bold markup).
bold :: Doc -> Doc
bold = enclose "__" "__"
-- | Render the trailing comment for a tuple of outputs: a flattened
-- @-- ^@ summary line listing every output name in bold, followed for
-- each output by an empty comment line and a bullet with its
-- 'argComment'.
resultComment :: [ParsedArg] -> Doc
resultComment os = stack (flatten commentSummary : map commentDetails os)
  where
    boldNames = map (bold . renderTFName . parsedArgName) os
    commentSummary = "-- ^" <+> tuple boldNames
    commentDetails o = stack ["--", "-- *" <+> argComment o]
-- | Render the class constraint for an inferred type attribute.
--
-- Without a restriction the constraint is @TensorType@ (or @TensorTypes@
-- for list-valued type parameters). With a restriction it is @OneOf@
-- (or @OneOfs@) applied to a promoted list of the allowed Haskell types,
-- de-duplicated and ordered by passing them through a 'Set.Set'.
tensorArgConstraint :: Attr TypeParam -> Doc
tensorArgConstraint a = case attrInfo a of
    TypeParam isList Nothing ->
        (if isList then "TensorTypes" else "TensorType") <+> n
    TypeParam isList (Just allowed) ->
        (if isList then "OneOfs" else "OneOf") <+> typeList allowed <+> n
  where
    n = renderHaskellAttrName a
    typeList ts = "'" <> brackets (commasep (map strictText (uniqueTypes ts)))
    uniqueTypes ts =
        Set.toList (Set.fromList (map dtTypeToHaskell (toList ts)))
-- | Map a TensorFlow 'DataType' to the text of the Haskell type used for
-- it in generated code.
--
-- Types without a mapping do not fail: the returned text is a message
-- naming the unsupported type.
dtTypeToHaskell :: DataType -> Text.Text
dtTypeToHaskell = \case
    DT_BOOL -> "Bool"
    DT_BFLOAT16 -> "Data.Word.Word16"
    DT_COMPLEX128 -> "(Data.Complex.Complex Double)"
    DT_COMPLEX64 -> "(Data.Complex.Complex Float)"
    DT_DOUBLE -> "Double"
    DT_FLOAT -> "Float"
    DT_INT16 -> "Data.Int.Int16"
    DT_INT32 -> "Data.Int.Int32"
    DT_INT64 -> "Data.Int.Int64"
    DT_INT8 -> "Data.Int.Int8"
    DT_QINT32 -> "Data.Int.Int32"
    DT_QINT8 -> "Data.Word.Word8"
    DT_QINT16 -> "Data.Int.Int16"
    DT_QUINT16 -> "Data.Word.Word16"
    DT_QUINT8 -> "Data.Word.Word8"
    DT_STRING -> "Data.ByteString.ByteString"
    DT_UINT16 -> "Data.Word.Word16"
    DT_HALF -> "Data.Word.Word16"
    DT_UINT8 -> "Data.Word.Word8"
    DT_RESOURCE -> "ResourceHandle"
    other ->
        Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show other
-- | Turn comment text into a document. The text is emitted unchanged.
haddockComment :: Text.Text -> Doc
haddockComment txt = strictText txt
-- | Render a summary followed by a detailed description as the body of a
-- line comment. The detail starts a new paragraph: an empty @--@ line is
-- placed before its first line, and each following line is prefixed
-- with @--@.
multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail = haddockComment summary' </> detailDoc
  where
    detailDoc = splitMultilineText startParagraph detail
    startParagraph firstLine = "--" </> "--" <+> firstLine
-- | Lay out possibly multi-line text as comment lines.
--
-- The supplied function decorates the first line; every following line is
-- prefixed with @--@. Text with no lines produces an empty document.
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
splitMultilineText lead detail = fromLines (Text.lines detail)
  where
    fromLines [] = empty
    fromLines (firstLine : rest) =
        stack (lead (haddockComment firstLine) : map continuation rest)
    continuation l = "--" <+> haddockComment l
-- | Number of columns by which nested parts of the generated function
-- bodies are indented (used by 'renderOp' and 'functionBody').
indentation :: Int
indentation = 4