module TensorFlow.OpGen
( OpGenFlags(..)
, docOpList
, flagParser)
where
import Prelude hiding (head, tail)
import Control.Monad (guard)
import Data.Char (toLower, toUpper)
import Data.Foldable (toList)
import Data.Maybe (fromMaybe, maybeToList)
import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head)
import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef
( OpList
, OpDef
, OpDef'ArgDef
, attr
, description
, inputArg
, name
, numberAttr
, op
, outputArg
, summary
, type'
, typeAttr
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName)
import TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
)
import Text.PrettyPrint.Mainland
( Doc
, (<>)
, (<+>)
, (</>)
, (<+/>)
, brackets
, comma
, commasep
, dquotes
, empty
, enclose
, flatten
, folddoc
, hang
, indent
, int
, parens
, sep
, stack
, strictText
, tuple
)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified Data.Semigroup as Semigroup
import Data.Text (Text)
-- | Command-line options that drive the op code generator.
data OpGenFlags = OpGenFlags
    { outputFile :: String
      -- ^ File to write. Its base name is also used by 'docOpList' to
      -- derive the generated module's name.
    , prefix :: String
      -- ^ Prefix placed in front of the derived module name.
    , excludeList :: String
      -- ^ Comma-separated op names that 'docOpList' leaves out.
    }
-- | Command-line parser for 'OpGenFlags'. It accepts @--output@,
-- @--prefix@ and @--exclude_list@; the latter defaults to the empty
-- string when it is not given.
flagParser :: Parser OpGenFlags
flagParser = OpGenFlags <$> outputOpt <*> prefixOpt <*> excludeOpt
  where
    outputOpt, prefixOpt, excludeOpt :: Parser String
    outputOpt = strOption $ mconcat
        [ long "output"
        , help "File to write."
        ]
    prefixOpt = strOption $ mconcat
        [ long "prefix"
        , help "Haskell package prefix to use"
        ]
    excludeOpt = strOption $ mconcat
        [ long "exclude_list"
        , value ""
        , help "Comma separated Ops names to ignore"
        ]
-- | Render a complete generated Haskell module for the given op list.
--
-- The output consists of a fixed set of LANGUAGE pragmas, the module
-- header, the fixed 'imports' block and one 'renderDef' rendering per op,
-- separated by blank lines.
--
-- The module name is the 'prefix' flag followed by the camel-cased base
-- name of 'outputFile' (with an optional trailing @_op_lib@ removed).
--
-- Ops whose name appears in the comma-separated 'excludeList' flag are
-- skipped. Whitespace around each entry is ignored, so @"Foo, Bar"@
-- excludes both @Foo@ and @Bar@.
docOpList :: OpGenFlags -> OpList -> Doc
docOpList flags opList =
    stack [ "{-# LANGUAGE ConstraintKinds #-}"
          , "{-# LANGUAGE DataKinds #-}"
          , "{-# LANGUAGE FlexibleInstances #-}"
          , "{-# LANGUAGE OverloadedStrings #-}"
          , "{-# LANGUAGE RankNTypes #-}"
          , "{-# LANGUAGE ScopedTypeVariables #-}"
          , "module" <+> strictText moduleName <+> "where"
          , empty
          , imports
          , empty
          , folddoc (\x y -> x </> empty </> y) (map renderDef keptOps)
          ]
  where
    moduleName =
        Text.pack (prefix flags) <> "." <> camelCase
            -- Discard the optional trailing _op_lib.
            (fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
    shortName = Text.pack (takeBaseName $ outputFile flags)
    -- A set gives logarithmic membership tests; stripping makes the flag
    -- tolerant of spaces after the commas.
    exclusions = Set.fromList $ map Text.strip $
        Text.splitOn "," $ Text.pack $ excludeList flags
    keptOps = filter (\o -> (o ^. name) `Set.notMember` exclusions)
                     (toList $ opList ^. op)
-- | Convert an underscore-separated name to CamelCase: split on @_@,
-- drop every component equal to @"ops"@, upper-case the first character
-- of each remaining component and concatenate the results.
camelCase :: Text -> Text
camelCase = Text.concat . map upCase . filter (/= "ops") . Text.splitOn "_"
-- | Upper-case the first character of the given text; the rest of the
-- text is left as it is.
upCase :: Text -> Text
upCase txt = forceCase toUpper txt
-- | Lower-case the first character of the given name, then pass the
-- result through 'replaceReservedName' so that it cannot collide with a
-- reserved keyword.
lowCase :: Text -> Text
lowCase txt = replaceReservedName (forceCase toLower txt)
-- | Apply the given character conversion to the first character of the
-- text only. Empty input yields empty output.
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s =
    case Text.uncons s of
        Nothing -> Text.empty
        Just (c, cs) -> Text.cons (convert c) cs
-- | The fixed block of import statements emitted at the top of every
-- generated module, one statement per line.
imports :: Doc
imports = stack
    [ "import Data.ByteString (ByteString)"
    , "import Data.Complex (Complex)"
    , "import Data.Int (Int8, Int16, Int32, Int64)"
    , "import Data.Word (Word8, Word16)"
    , "import Lens.Family2 ((.~), (&))"
    , "import TensorFlow.Build"
    , "import TensorFlow.BuildOp"
    , "import TensorFlow.Tensor"
    , "import TensorFlow.Types"
    ]
-- | Render the generated code for a single op.
--
-- The result is, in order: a haddock comment built from the op's summary
-- and description, the type signature (see 'typeSig'), the function
-- definition, and a block comment containing the op's input args, output
-- args and attrs as printed by 'showMessage'.
--
-- The generated function takes the mandatory attributes first (in the
-- key order of 'mandatoryAttrMap') followed by one argument per input
-- tensor. Its single equation is guarded by @eqLengthGuard@, which is
-- handed, for every number attribute, the names and lengths of the
-- tensor lists that share that attribute.
--
-- When an output is a list whose size comes from a mandatory attribute,
-- the body uses @buildListOp@ with those sizes. The sizes are emitted
-- under the Haskell argument name of the attribute (the same name the
-- generated function binds), not the raw TensorFlow attribute name, so
-- the reference resolves even when the two names differ.
renderDef :: OpDef -> Doc
renderDef d =
    stack [ haddocks
          , n <+> "::" <+> hang 0 (typeSig d)
          , n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </>
                indent indentation functionBody
          , extras
          ]
  where
    n = strictText $ lowCase (d ^. name)
    args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
    tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
    funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
      where
        entries =
            [ parens $ quotedText nAttr <> comma <+>
                  brackets (commasep $ toList $
                            NE.map renderTensorName tensorNames)
            | (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
            ]
        renderTensorName x = parens $ quotedText x <> comma <+>
            "length" <+> strictText x
    functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
        </> indent indentation (sep tensorArgs)
    buildFunction
        | null outputListsSizes = "buildOp"
        | otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
    -- Use the Haskell-side name stored in the mandatory attribute map:
    -- that is the name bound by @args@ above.
    outputListsSizes =
        [ strictText hsName
        | o <- d ^. outputArg
        , let numberAttrName = o ^. numberAttr
        , not (Text.null numberAttrName)
        , Just (hsName, _, _) <- [Map.lookup numberAttrName mandatoryMap]
        ]
    buildOpParts =
        "opDef" <+> quotedText (d ^. name) :
        -- Type attributes are derived from the type variables.
        [ "& opAttr" <+> quotedText tfName <+>
          ".~ tensorType (undefined ::" <+> strictText hsName <> ")"
        | (tfName, (hsName, _)) <- Map.toList typeMap
        ] ++
        -- Mandatory attributes are passed through from the arguments.
        [ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
        | (tfName, hsName) <- mandatoryAttrs
        ] ++
        -- Number attributes are computed from the length of the first
        -- tensor list recorded for them.
        [ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
          "(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
        | (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
        ]
    mandatoryMap = mandatoryAttrMap d
    mandatoryAttrs = [ (strictText tf, strictText hs)
                     | (tf, (hs, _, _)) <- Map.toList mandatoryMap
                     ]
    haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
    extras = enclose "{-\n" "\n-}" $
        strictText $ Text.pack $
        showMessage ((def :: OpDef)
                     & inputArg .~ (d ^. inputArg)
                     & outputArg .~ (d ^. outputArg)
                     & attr .~ (d ^. attr))
    typeMap = opDefTypeMap d
-- | Render the given text surrounded by double quotes.
quotedText :: Text.Text -> Doc
quotedText txt = dquotes (strictText txt)
-- | Render the type signature (everything after @::@) of the function
-- generated for an op.
--
-- The signature consists of an optional @forall@ and constraint tuple
-- (both only present when the op has type attributes), followed by the
-- mandatory attribute arguments, the tensor arguments and the result,
-- joined by @->@. Each argument carries a @-- ^@ comment built from its
-- name and description.
typeSig :: OpDef -> Doc
typeSig d =
    foralls <+> constraints <+/>
    folddoc (\lhs rhs -> lhs </> "->" <+> rhs)
            (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
  where
    typeMap = opDefTypeMap d
    typeVars = Map.elems typeMap

    -- One reference-type variable per input tensor: v1, v2, ...
    refTypes = [ "v" <> int i | i <- [1..length (d ^. inputArg)] ]

    foralls
        | Map.null typeMap = empty
        | otherwise =
            "forall"
            <+> sep (refTypes ++ [strictText hsVar | (hsVar, _) <- typeVars])
            <+> "."

    -- Every type variable gets a TensorType constraint, plus a OneOf
    -- constraint when its attribute restricts the allowed types.
    constraints
        | Map.null typeMap = empty
        | otherwise =
            tuple [ c
                  | (hsVar, aDef) <- typeVars
                  , c <- ("TensorType" <+> strictText hsVar)
                         : maybeToList (oneOfRestrictions aDef hsVar)
                  ] <+> "=>"

    mandatoryAttrInputs =
        [ dtTypeToDoc dtType <+>
              hang 0 ("-- ^" <+> argComment' tfName descr)
        | (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d
        ]

    tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
    tensorArg refType arg =
        wrapArg refType arg <+> hang 0 ("-- ^" <+> argComment arg)

    -- The argument is a list of tensors when its number attribute is
    -- set, and a single tensor otherwise.
    wrapArg refType arg
        | Text.null (arg ^. numberAttr) = single
        | otherwise = brackets single
      where
        single = tensorType refType arg

    -- The element type is the type variable of the argument's type
    -- attribute when there is one in the type map, and the argument's
    -- fixed data type otherwise.
    tensorType refType arg = "Tensor" <+> refType <+> elemType
      where
        elemType = case Map.lookup (arg ^. typeAttr) typeMap of
            Just (hsVar, _) -> strictText hsVar
            Nothing -> dtTypeToDoc (arg ^. type')

    outputs = renderOutputs (d ^. outputArg)
    renderOutputs [] = "ControlNode"
    renderOutputs [o] = wrappedOutput o <+> "-- ^" <+> argComment o
    renderOutputs os = renderTupleResult os

    wrappedOutput = wrapArg "Value"

    -- Several outputs become a tuple, followed by a one-line summary of
    -- the output names and one bullet per output.
    renderTupleResult os =
        stack (tuple (map wrappedOutput os)
               : flatten commentSummary
               : map commentDetails os)
      where
        commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
        commentDetails o = stack ["--", "-- *" <+> argComment o]
-- | Build the @OneOf '[...] t@ constraint for a type attribute.
--
-- Returns 'Nothing' when the attribute is not a single type attribute or
-- when it has no restrictions. The listed type names are de-duplicated
-- and sorted by going through a 'Set.Set'.
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
oneOfRestrictions aDef tName =
    case onAttrType (^. templateRestrictions) aDef of
        Just typs | not (null typs) -> Just (constraintFor typs)
        _ -> Nothing
  where
    constraintFor typs =
        "OneOf" <+> "'" <> brackets (typeList typs) <+> strictText tName
    typeList typs =
        commasep $ map strictText $
        Set.toList $ Set.fromList $
        map dtTypeToHaskell typs
-- | Map every number attribute of the op to the names of the input
-- tensor lists whose length it describes.
--
-- Inputs without a number attribute are skipped. The names are passed
-- through 'lowCase', the same transformation 'renderDef' applies when it
-- binds the tensor arguments, so the recorded names always match the
-- generated argument names.
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
numberAttrMap d = Map.fromListWith (Semigroup.<>)
    [ (nAttr, lowCase (inp ^. name) :| [])
    | inp <- d ^. inputArg
    , let nAttr = inp ^. numberAttr
    , not (Text.null nAttr)
    ]
-- | Comment text for an input or output argument, built from the
-- argument's name and description via 'argComment''.
argComment :: OpDef'ArgDef -> Doc
argComment arg = argComment' (view name arg) (view description arg)
-- | Comment text for a named argument: the name in bold, followed by the
-- description. The first description line is introduced by a colon;
-- further lines are handled by 'splitMultilineText'.
argComment' :: Text.Text -> Text.Text -> Doc
argComment' argName argDesc =
    bold argName <> splitMultilineText (\firstLine -> ":" <+> firstLine) argDesc
-- | Wrap the given text in double underscores on both sides.
bold :: Text.Text -> Doc
bold n = strictText (Text.concat ["__", n, "__"])
-- | Map the name of every type attribute of the op to the pair of its
-- 'lowCase'd name and its attribute definition. Attributes for which
-- 'isType' is false are left out.
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
opDefTypeMap d =
    Map.fromList (map withHsName (filter (isType . snd) (attrList d)))
  where
    withHsName (tfName, aDef) = (tfName, (lowCase tfName, aDef))
-- | All attributes of the op, in declaration order, each paired with its
-- name and converted with 'attrDef'.
attrList :: OpDef -> [(Text.Text, AttrDef)]
attrList d = map (\a -> (a ^. name, attrDef a)) (d ^. attr)
-- | Whether the attribute is a single type attribute, i.e. whether
-- 'onAttrType' produces a result for it.
isType :: AttrDef -> Bool
isType aDef =
    case onAttrType (const ()) aDef of
        Just () -> True
        Nothing -> False
-- | Apply the function to the template of a single type attribute.
-- Yields 'Nothing' for every other kind of attribute.
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
onAttrType f x
    | AttrSingle (AttrType a) <- x ^. attrTemplate = Just (f a)
    | otherwise = Nothing
-- | Map the name of every mandatory attribute of the op to a triple of
-- its 'lowCase'd name, its data type and its description.
--
-- An attribute is included when 'isMandatoryAttr' yields a type for it
-- and it is not one of the number attributes in 'numberAttrMap'.
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
mandatoryAttrMap d = Map.fromList
    [ (tfName, (lowCase tfName, dtType, aDef ^. attrOriginal.description))
    | (tfName, aDef) <- attrList d
    , tfName `Map.notMember` numbered
    , Just dtType <- [isMandatoryAttr aDef]
    ]
  where
    numbered = numberAttrMap d
-- | The data type of a mandatory attribute, if the attribute is one.
--
-- Only single bool, int64 and float attributes without a default value
-- count as mandatory; everything else yields 'Nothing'.
isMandatoryAttr :: AttrDef -> Maybe DataType
isMandatoryAttr aDef = classify (aDef ^. attrTemplate)
  where
    classify (AttrSingle (AttrBool t)) = unlessDefaulted DT_BOOL t
    classify (AttrSingle (AttrInt64 t)) = unlessDefaulted DT_INT64 t
    classify (AttrSingle (AttrFloat t)) = unlessDefaulted DT_FLOAT t
    classify _ = Nothing

    -- Yield the type only when the template carries no default.
    unlessDefaulted typ t =
        case t ^. templateDefault of
            Nothing -> Just typ
            Just _ -> Nothing
-- | The text produced by 'dtTypeToHaskell' for the data type, as a 'Doc'.
dtTypeToDoc :: DataType -> Doc
dtTypeToDoc dt = strictText (dtTypeToHaskell dt)
-- | The Haskell type name emitted into generated code for a data type.
--
-- Data types without a mapping do not fail: they yield a message of the
-- form @Unsupported type in dtTypeToHaskell: ...@ in place of a type name.
--
-- NOTE(review): DT_QINT8 maps to the unsigned @Word8@ while DT_QINT16 and
-- DT_QINT32 map to signed types -- confirm that this is intended.
dtTypeToHaskell :: DataType -> Text.Text
dtTypeToHaskell dt =
    case dt of
        DT_BOOL -> "Bool"
        DT_BFLOAT16 -> "Data.Word.Word16"
        DT_COMPLEX128 -> "(Data.Complex.Complex Double)"
        DT_COMPLEX64 -> "(Data.Complex.Complex Float)"
        DT_DOUBLE -> "Double"
        DT_FLOAT -> "Float"
        DT_INT16 -> "Data.Int.Int16"
        DT_INT32 -> "Data.Int.Int32"
        DT_INT64 -> "Data.Int.Int64"
        DT_INT8 -> "Data.Int.Int8"
        DT_QINT32 -> "Data.Int.Int32"
        DT_QINT8 -> "Data.Word.Word8"
        DT_QINT16 -> "Data.Int.Int16"
        DT_QUINT16 -> "Data.Word.Word16"
        DT_QUINT8 -> "Data.Word.Word8"
        DT_STRING -> "Data.ByteString.ByteString"
        DT_UINT16 -> "Data.Word.Word16"
        DT_HALF -> "Data.Word.Word16"
        DT_UINT8 -> "Data.Word.Word8"
        other ->
            Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show other
-- | Render one piece of comment text. The text is emitted as given.
haddockComment :: Text.Text -> Doc
haddockComment txt = strictText txt
-- | Render a summary followed by a detailed description as the body of
-- a comment. The first description line is preceded by a bare @--@ line
-- and prefixed with @--@ itself; the remaining lines are handled by
-- 'splitMultilineText'.
multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail =
    haddockComment summary' </>
    splitMultilineText (\firstLine -> "--" </> "--" <+> firstLine) detail
-- | Render possibly multi-line text as stacked comment lines.
--
-- The given function decorates the first line; every later line is
-- prefixed with @--@. Text without any lines renders as 'empty'.
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
splitMultilineText lead detail = render (Text.lines detail)
  where
    render [] = empty
    render (firstLine : rest) =
        stack (lead (haddockComment firstLine) : map continuation rest)
    continuation l = "--" <+> haddockComment l
-- | Append an apostrophe to the name when it is one of the
-- 'reservedKeywords'; any other name is returned unchanged.
replaceReservedName :: Text -> Text
replaceReservedName n =
    if Set.member n reservedKeywords
        then n <> "'"
        else n
-- | Number of columns by which 'renderDef' indents the body of a
-- generated function.
indentation :: Int
indentation = 4
-- | Names that 'replaceReservedName' refuses to emit unchanged.
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList (standardKeywords ++ extensionKeywords)
  where
    standardKeywords, extensionKeywords :: [Text]
    -- Reserved words of the Haskell language itself.
    standardKeywords =
        [ "case", "class", "data", "default", "deriving", "do", "else"
        , "foreign", "if", "import", "in", "infix", "infixl", "infixr"
        , "instance", "let", "module", "newtype", "of", "then", "type"
        , "where"
        ]
    -- Words that are only reserved under some language extensions.
    extensionKeywords =
        [ "mdo"
        , "rec"
        , "proc"
        ]