1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Refactor OpGen. (#36)

Also fixes op lists when the same attribute specifies the length of
both an input and an output.  I added a test of "shapeN" which
previously failed with the following error:

    ERROR: Ran out of counts in toResult. Likely misuse of buildListOp.
This commit is contained in:
Judah Jacobson 2016-11-20 10:00:22 -08:00 committed by Greg Steuck
parent 2b5e41ffeb
commit a277c7ddb3
5 changed files with 502 additions and 413 deletions

View file

@ -13,6 +13,7 @@
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
-- | Rendering of TensorFlow operations as Haskell functions.
@ -23,46 +24,25 @@ module TensorFlow.OpGen
, flagParser)
where
import Prelude hiding (head, tail)
import Control.Applicative ((<**>))
import Control.Monad (guard)
import Data.Char (toLower, toUpper)
import Data.Foldable (toList)
import Data.Maybe (catMaybes, fromMaybe, maybeToList)
import Data.Maybe (fromMaybe)
import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef
( OpList
, OpDef
, OpDef'ArgDef
, attr
, description
, inputArg
, name
, numberAttr
, op
, outputArg
, summary
, type'
, typeAttr
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName)
import TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
)
import TensorFlow.OpGen.ParsedOp
import Text.PrettyPrint.Mainland
( Doc
, (<>)
@ -79,18 +59,14 @@ import Text.PrettyPrint.Mainland
, folddoc
, hang
, indent
, int
, parens
, sep
, stack
, strictText
, tuple
)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified Data.Semigroup as Semigroup
import Data.Text (Text)
data OpGenFlags = OpGenFlags
{ outputFile :: String
@ -118,7 +94,6 @@ docOpList flags opList =
, "{-# LANGUAGE DataKinds #-}"
, "{-# LANGUAGE FlexibleInstances #-}"
, "{-# LANGUAGE OverloadedStrings #-}"
, "{-# LANGUAGE RankNTypes #-}"
, "{-# LANGUAGE ScopedTypeVariables #-}"
-- Avoids reports about shadowing standard library names.
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
@ -129,34 +104,17 @@ docOpList flags opList =
, imports
, empty
, folddoc (\x y -> x </> empty </> y)
(map renderDef $
(map renderOpAndExtras $
filter (not . flip elem exclusions . view name) $
toList $ opList ^. op)
]
where moduleName =
Text.pack (prefix flags) <> "." <> camelCase
-- Discards the optional trailing _op_lib
(fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
-- Discards the optional trailing _ops_op_lib
(fromMaybe shortName (Text.stripSuffix "_ops_op_lib" shortName))
shortName = Text.pack (takeBaseName $ outputFile flags)
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ filter (/= "ops")
$ Text.splitOn "_" s
-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper
-- | Lower-case the given name, and prevent it from overlapping with a reserved
-- Haskell name.
lowCase :: Text -> Text
lowCase = replaceReservedName . forceCase toLower
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
renderOpAndExtras o = renderOp (parseOp o) </> extras o
imports :: Doc
imports = stack [
@ -170,232 +128,208 @@ imports = stack [
, "import TensorFlow.Output (ResourceHandle)"
, "import TensorFlow.Tensor"
, "import TensorFlow.Types"
]
renderDef :: OpDef -> Doc
renderDef d =
stack [
haddocks
, n <+> "::" <+> hang 0 (typeSig d)
, n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation functionBody
, extras -- just for debug
]
renderHaskellName, renderTFName, renderQuotedTFName :: Name -> Doc
renderHaskellName = strictText . unHaskellName . haskellName
renderTFName = strictText . unTFName . tfName
renderQuotedTFName = dquotes . renderTFName
-- | Generate the source code for a single op.
-- For example:
--
-- -- | {haddock comment}
-- foo :: {type sig}
-- foo attr1 attr2 input1 input2 | eqLengthGuard [...] = {function body}
renderOp :: ParsedOp -> Doc
renderOp pOp = stack $
[ haddocks
, n <+> "::" <+> hang 0 (typeSig pOp)
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
<+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation (functionBody pOp)
] ++ whereClause listSizeAttrs
where
n = strictText $ fixOpName (d ^. name)
args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
fixOpName = lowCase
funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
n = renderHaskellName $ parsedOpName pOp
listSizeAttrs = inferredListSizeAttrs pOp
args = sep $ map renderHaskellName
$ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp)
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
-- | A check that all lists of the given size have the given length.
-- For example:
-- eqLengthGuard [("N", [("input1", length input1), ("input2", length input2)])]
funcGuard :: [Attr (NonEmpty Name)] -> Doc
funcGuard attrs = "eqLengthGuard" <+> brackets (commasep entries)
where
entries =
[ parens $ quotedText nAttr <> comma <+>
[ parens $ nAttr <> comma <+>
brackets (commasep $ toList $
NE.map renderTensorName tensorNames)
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
map renderTensorName (toList $ attrInfo a))
| a <- attrs
, let nAttr = renderQuotedTFName (attrName a)
]
renderTensorName x = parens $ quotedText x <> comma <+>
"length" <+> strictText x
-- Uses hang 0 to align the argument vertically on multiple lines.
functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
renderTensorName x = parens $ renderQuotedTFName x <> comma <+>
"length" <+> renderHaskellName x
-- | Define the implicit list length attributes.
-- For example:
-- where
-- n1 = fromIntegral (length input1) :: Int64
-- n2 = fromIntegral (length input2) :: Int64
whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = []
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
where
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
<+> "fromIntegral (length"
<+> renderHaskellName (NE.head $ attrInfo a)
<> ") :: Int64"
functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
where
buildFunction
| null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
outputListsSizes = [ strictText numberAttrName
| o <- d ^. outputArg
, let numberAttrName = o ^. numberAttr
, not (Text.null numberAttrName) &&
numberAttrName `Map.member` mandatoryAttrMap d
]
| otherwise = "buildListOp" <+>
brackets (commasep $
map renderHaskellName outputListsSizes)
outputListsSizes = [ a
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
<- parsedOutputs pOp]
buildOpParts =
"opDef" <+> quotedText (d ^. name) :
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders tensor arguments.
[ "& opAttr" <+> quotedText tfName <+>
".~ tensorType (undefined ::" <+> strictText hsName <> ")"
| (tfName, (hsName, _)) <- Map.toList typeMap
[ "& opAttr" <+> renderQuotedTFName n <+>
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
| a <- inferredTypeAttrs pOp, let n = attrName a
] ++
-- Renders mandatory attributes as function parameters.
[ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
| (tfName, hsName) <- mandatoryAttrs
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- explicitInputAttrs pOp, let n = attrName a
] ++
-- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
"(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
| (nAttr, tensorNames) <- Map.toList $ numberAttrMap d
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs pOp, let n = attrName a
]
mandatoryAttrs = [(strictText tf, strictText hs)
| (tf, (hs, _, _)) <- Map.toList (mandatoryAttrMap d)
]
haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
extras = enclose "{-\n" "\n-}" $
strictText $ Text.pack $
showMessage ((def :: OpDef)
& inputArg .~ (d ^. inputArg)
& outputArg .~ (d ^. outputArg)
& attr .~ (d ^. attr))
typeMap = opDefTypeMap d
-- | Makes a quoted string doc out of the given text value.
quotedText :: Text.Text -> Doc
quotedText = dquotes . strictText
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
-- | typeSig renders the type signature of the given OpDef.
typeSig :: OpDef -> Doc
typeSig d =
foralls <+> constraints <+/>
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
-- | Write a comment with the inputs/outputs/attributes in proto format, for
-- debugging.
extras :: OpDef -> Doc
extras d = enclose "{-\n" "\n-}" $
strictText $ Text.pack $
showMessage ((def :: OpDef)
& inputArg .~ (d ^. inputArg)
& outputArg .~ (d ^. outputArg)
& attr .~ (d ^. attr))
-- | The type signature for an op.
-- Of the form:
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
-- "Tensor t2 v2" is an output.
typeSig :: ParsedOp -> Doc
typeSig pOp = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
++ map tensorArgAndComment (parsedInputs pOp)
++ [outputs])
where
foralls | null typeMap = empty
| otherwise =
"forall"
<+> sep (refVariableNames ++ typeMapTypeNames)
<+> "."
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
constraints | null typeMap = empty
| otherwise =
tuple (concatMap
(\(t, aDef) ->
"TensorType" <+> strictText t
: maybeToList (oneOfRestrictions aDef t))
(Map.elems typeMap)) <+> "=>"
refVariableNames = catMaybes (map fst tensorInputs)
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
tensorArg refType arg = wrapArg refType arg <**>
pure (<+> hang 0 ("-- ^" <+> argComment arg))
-- Argument type is a list of tensors if number_attr is set;
-- otherwise it's a single Tensor.
wrapArg refType arg =
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
where typ = tensorType refType arg
-- The result is (reference type variable if any, type representing the arg)
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
tensorType refType arg
-- Deals with resource handles that are unlike tensors and
-- have their own type level representation. The magic "dtype"
-- name is the name of the attribute specifying the type of
-- the resource handle. It is not referenced by resource input
-- or output because it isn't expressible in TF operation
-- signatures.
| (arg ^. type' == DT_RESOURCE) =
(Nothing, strictText "ResourceHandle dtype")
| otherwise =
(Just refType,
"Tensor" <+> refType <+> maybe directType strictText indirectType)
where
-- This is the case of a named parameter type which is
-- constrained as a OneOf.
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
-- The nominal case when the type name is given directly.
directType = dtTypeToDoc (arg ^. type')
outputs =
case d ^. outputArg of
[] -> "ControlNode"
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
os -> renderTupleResult os
wrappedOutput = snd . wrapArg "Value"
-- Tuple result case is rendered differently to give
-- individual elements their own comments.
renderTupleResult os =
stack $ [ tuple (map wrappedOutput os)
, flatten commentSummary
] ++ map commentDetails os
where
commentSummary = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
commentDetails o =
stack [ "--"
, "-- *" <+> argComment o
]
constraints
| null (inferredTypeAttrs pOp) = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
ArgTensorEither v <- [parsedArgKind k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint
$ inferredTypeAttrs pOp
signatureFold = folddoc (\x y -> x </> "->" <+> y)
mandatoryAttrInputs = [
dtTypeToDoc dtType <+>
hang 0 ("-- ^" <+> argComment' tfName descr)
| (tfName, (_, dtType, descr)) <- Map.toList $ mandatoryAttrMap d]
typeMap = opDefTypeMap d
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
renderAttrType (AttrSingle a) = renderAttrBaseType a
renderAttrType (AttrList a) = brackets $ renderAttrBaseType a
renderAttrBaseType = \case
AttrBytes -> "ByteString"
AttrInt64 -> "Data.Int.Int64"
AttrFloat -> "Float"
AttrBool -> "Bool"
AttrType -> "DataType"
AttrShape -> "TensorShapeProto"
AttrTensor -> "TensorProto"
-- | Returns the type restriction for the given tensor type if the
-- set of allowed types is not empty (i.e. restricted).
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
oneOfRestrictions aDef tName = do
typs <- onAttrType (^. templateRestrictions) aDef
guard $ not $ null typs
let typeList = commasep $ map strictText $
Set.toList $ Set.fromList $
map dtTypeToHaskell typs
return $ "OneOf" <+> "'" <> brackets typeList <+> strictText tName
tensorArgAndComment t = tensorArg t <+> hang 0 ("-- ^" <+> argComment t)
outputs = case parsedOutputs pOp of
[] -> "ControlNode"
-- TODO(judahjacobson): To improve indentation: `tensorArgAndComment a`
[a] -> tensorArg a <+> "-- ^" <+> argComment a
as -> tuple (map tensorArg as) <+/> resultComment as
-- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle dtype"
tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of
SimpleArg { argType = t } -> tensorType t
ListArg { argType = t } -> brackets $ tensorType t
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
where
tensorType t = let
v = case parsedArgKind p of
ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v'
ArgResource -> "ResourceHandle"
a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n
in v <+> a
-- | Identifies the attributes used as tensor cardinalities. In such
-- cases a list of tensors is supplied as an input_arg. The number of
-- such inputs is communicated as a separate opAttr.
-- The result key is TensorFlow attribute name and the value is the
-- tensor names which have number_attr set to the result key.
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
(nAttr, replaceReservedName (inp ^. name) :| [])
| inp <- d ^. inputArg
, nAttr <- [inp ^. numberAttr]
, not (Text.null nAttr)
]
attrComment :: Attr a -> Doc
attrComment a = argComment' (attrName a) (attrDescription a)
argComment :: ParsedArg -> Doc
argComment a = argComment' (parsedArgName a) (parsedArgDescription a)
argComment :: OpDef'ArgDef -> Doc
argComment arg = argComment' (arg ^. name) (arg ^. description)
argComment' :: Text.Text -> Text.Text -> Doc
argComment' :: Name -> Text.Text -> Doc
argComment' argName argDesc =
bold argName <> splitMultilineText (":" <+>) argDesc
bold (renderTFName argName) <> splitMultilineText (":" <+>) argDesc
bold :: Text.Text -> Doc
bold n = strictText ("__" <> n <> "__")
bold :: Doc -> Doc
bold n = "__" <> n <> "__"
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
-- | Comment for the outputs of an op.
-- For example:
-- -- ^ (__output1__, __output2__)
-- --
-- -- * __output1__: description1
-- --
-- -- * __output2__: description2
resultComment :: [ParsedArg] -> Doc
resultComment os = stack $ flatten commentSummary : map commentDetails os
where
commentSummary = "-- ^" <+> tuple [bold (renderTFName $ parsedArgName o) | o <- os]
commentDetails o =
stack [ "--"
, "-- *" <+> argComment o
]
-- | Returns the map of type parameters from OpDef type name to
-- (haskell friendly type name, the type's attribute definition).
opDefTypeMap :: OpDef -> OpDefTypeMap
opDefTypeMap d =
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
attrList :: OpDef -> [(Text.Text, AttrDef)]
attrList d = [(a ^. name, attrDef a) | a <- d ^. attr]
isType :: AttrDef -> Bool
isType = fromMaybe False . onAttrType (const True)
-- | Applies the given function to the data type. Is this a Prism?
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
onAttrType f x = case x ^. attrTemplate of
AttrSingle (AttrType a) -> Just (f a)
_ -> Nothing
-- | mandatoryAttrMap contains the attributes chosen by
-- isMandatoryAttr, excluding those which are derived from list of
-- tensor arguments. The key is the TF name of the attribute. The
-- value tuple is (haskell name, TF type, attribute comment).
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
mandatoryAttrMap d =
Map.fromList [ (n, (lowCase n, dtType, a ^. attrOriginal.description))
| (n, a) <- attrList d
, Just dtType <- [isMandatoryAttr a]
-- Excludes the attributes rendered as list lengths.
, n `Map.notMember` numberAttrMap d
]
-- | Inspects the attribute and if it is one of the implemented
-- non-tensor values lacking default, then returns Just the TF type.
isMandatoryAttr :: AttrDef -> Maybe DataType
isMandatoryAttr x =
case x ^. attrTemplate of
AttrSingle (AttrBool y) -> noDefault DT_BOOL y
AttrSingle (AttrInt64 y) -> noDefault DT_INT64 y
AttrSingle (AttrFloat y) -> noDefault DT_FLOAT y
_ -> Nothing
where
noDefault typ y = maybe (Just typ) (const Nothing) (y ^. templateDefault)
dtTypeToDoc :: DataType -> Doc
dtTypeToDoc = strictText . dtTypeToHaskell
-- | Constraints for a given type parameter.
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
tensorArgConstraint :: Attr [DataType] -> [Doc]
tensorArgConstraint a
= ("TensorType" <+> n
: if null typeList
then []
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
where
n = renderHaskellName $ attrName a
typeList = map strictText $
Set.toList $ Set.fromList $
map dtTypeToHaskell $ attrInfo a
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.
@ -429,6 +363,12 @@ dtTypeToHaskell x =
haddockComment :: Text.Text -> Doc
haddockComment = strictText
-- | Generate a multiline comment. For example:
-- summary'
-- --
-- -- detail_line1
-- -- detail_line2
-- -- ...
multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail =
haddockComment summary' </>
@ -445,45 +385,5 @@ splitMultilineText lead detail =
(l : ls) -> stack $ lead (haddockComment l)
: map (("--" <+>) . haddockComment) ls
replaceReservedName :: Text -> Text
replaceReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
indentation :: Int
indentation = 4
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
-- Haskell2010 keywords:
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
-- We don't include keywords that are allowed to be variable names,
-- in particular: "as", "forall", and "hiding".
[ "case"
, "class"
, "data"
, "default"
, "deriving"
, "do"
, "else"
, "foreign"
, "if"
, "import"
, "in"
, "infix"
, "infixl"
, "infixr"
, "instance"
, "let"
, "module"
, "newtype"
, "of"
, "then"
, "type"
, "where"
]
++ -- Nonstandard extensions
[ "mdo" -- RecursiveDo
, "rec" -- Arrows, RecursiveDo
, "proc" -- Arrows
]

View file

@ -1,120 +0,0 @@
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
-- | Wrapping of TensorFlow attributes into Haskell entities.
module TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
) where
import Data.Int (Int64)
import Data.Monoid ((<>))
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
import qualified Data.ByteString as B
import qualified Data.Text as Text
-- | Specifies the optional default value and a set of allowed values
-- for the given type.
data Template a = Template {
_templateDefault :: Maybe a
-- ^ The default value (mandatory if unspecified)
, _templateRestrictions :: [a]
-- ^ The allowed set of values, empty if no restrictions
}
templateDefault :: Lens' (Template a) (Maybe a)
templateDefault = lens _templateDefault (\g x -> g { _templateDefault = x })
templateRestrictions :: Lens' (Template a) [a]
templateRestrictions = lens _templateRestrictions
(\g x -> g { _templateRestrictions = x })
data UnusedTensor
data AttrCase f
= AttrBytes (f B.ByteString) -- bytes s = 2; // "string"
| AttrInt64 (f Int64) -- int64 i = 3; // "int"
| AttrFloat (f Float) -- float f = 4; // "float"
| AttrBool (f Bool) -- bool b = 5; // "bool"
| AttrType (f DataType) -- type = 6; // "type"
-- To be translated into TensorFlow.Types.Shape before use.
-- Leaving as a proto to reduce dependencies.
| AttrShape (f TensorShapeProto) -- shape = 7; // "shape"
-- | Type-reified representation of TensorFlow AttrDef.
-- Initially limited to just the types in Op descriptors.
data AttrTemplate
= AttrSingle (AttrCase Template)
| AttrList (AttrCase [])
| AttrTensor UnusedTensor -- tensor = 8; // "tensor"
data AttrDef = AttrDef {
_attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
, _attrTemplate :: AttrTemplate -- ^ the type of the attribute
}
attrTemplate :: Lens' AttrDef AttrTemplate
attrTemplate = lens _attrTemplate (\g x -> g { _attrTemplate = x })
attrOriginal :: Lens' AttrDef OpDef'AttrDef
attrOriginal = lens _attrOriginal (\g x -> g { _attrOriginal = x })
attrDef :: OpDef'AttrDef -> AttrDef
attrDef a = AttrDef a
$ translate (a^.OpDef.type')
(a^.OpDef.defaultValue)
(a^.allowedValues)
-- | Converts the given AttrValue with the type given by the string
-- into the AttrVal if the type is known.
translate :: Text.Text -- ^ one of the TensorFlow type strings
-> AttrValue -- ^ default value
-> AttrValue -- ^ allowed values
-> AttrTemplate
translate t defaults allowed
| t == "string" = makeVal AttrBytes maybe's s
| t == "int" = makeVal AttrInt64 maybe'i i
| t == "float" = makeVal AttrFloat maybe'f f
| t == "bool" = makeVal AttrBool maybe'b b
| t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
| t == "shape" = makeVal AttrShape maybe'shape shape
| t == "tensor" = AttrTensor $ error "tensor is unimplemented"
| t == "list(string)" = makeList AttrBytes $ list.s
| t == "list(int)" = makeList AttrInt64 $ list.i
| t == "list(float)" = makeList AttrFloat $ list.f
| t == "list(bool)" = makeList AttrBool $ list.b
| t == "list(type)" = makeList AttrType $ list.AttrValue.type'
| t == "list(shape)" = makeList AttrShape $ list.shape
| t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
| t == "func" = AttrTensor $ error "func is unimplemented"
| otherwise = error $ show ("Unknown attribute type " <> t) ++
"," ++ show defaults ++
"," ++ show allowed
where makeVal c x y = AttrSingle $ c $
Template (defaults^.x) (allowed^.list.y)
makeList c y = AttrList $ c $ defaults^.y

View file

@ -0,0 +1,299 @@
-- | This module helps parse the proto OpDef into a Haskell type which is more
-- descriptive of how the attributes and arguments will be used in the
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
( ParsedOp(..)
, Name(..)
, HaskellName(..)
, TFName(..)
, Attr(..)
, AttrType(..)
, AttrBaseType(..)
, ParsedArg(..)
, ParsedArgCase(..)
, ArgType(..)
, ArgKind(..)
, parseOp
, camelCase
) where
import Data.Char (toUpper, toLower)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Maybe (mapMaybe)
import Data.Monoid ((<>))
import Data.Ord (comparing)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Proto.Tensorflow.Core.Framework.AttrValue (list)
import Proto.Tensorflow.Core.Framework.OpDef
( OpDef
, OpDef'ArgDef
, OpDef'AttrDef
, allowedValues
, attr
, maybe'defaultValue
, description
, name
, inputArg
, outputArg
, summary
, typeListAttr
, numberAttr
, typeAttr
, type'
)
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))
data ParsedOp = ParsedOp
{ parsedOpName :: Name
, parsedOpSummary :: Text
, parsedOpDescription :: Text
, parsedInputs :: [ParsedArg]
, parsedOutputs :: [ParsedArg]
, explicitInputAttrs :: [Attr AttrType]
-- ^ Attributes that must be set explicitly when creating the op.
-- Associated with the type of the attribute.
, inferredTypeAttrs :: [Attr [DataType]]
-- ^ Attributes that are type parameters.
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
-- If this list is empty, then any type is acceptable.
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
-- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors.
-- Associated with the list of tensors whose size it describes.
}
data Name = Name
{ haskellName :: HaskellName
, tfName :: TFName
}
-- | A raw name as specified in the OpDef proto.
newtype TFName = TFName { unTFName :: Text }
deriving (Eq, Ord)
-- | A name that's appropriate for a variable in a Haskell source file.
newtype HaskellName = HaskellName { unHaskellName :: Text }
-- | A named attribute, associated with some information about it.
data Attr a = Attr
{ attrName :: Name
, attrDescription :: Text
, attrInfo :: a
}
-- | The type of an attribute.
data AttrType = AttrSingle AttrBaseType
| AttrList AttrBaseType
deriving Eq
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor
deriving Eq
-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
{ parsedArgName :: Name
, parsedArgDescription :: Text
, parsedArgCase :: ParsedArgCase
, parsedArgKind :: ArgKind
}
data ParsedArgCase
= SimpleArg { argType :: ArgType }
| ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType
}
| MixedListArg { argTypeAttr :: Name }
-- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this.
-- | The type of an argument.
data ArgType
= ArgTypeFixed DataType -- ^ A fixed type.
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
-- The kind of an op input or output (not including the argument type `a`).
data ArgKind
= ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgResource -- Resource a
makeName :: Text -> Name
makeName n = Name
{ haskellName = HaskellName $ fixReservedName $ lowCase n
, tfName = TFName n
}
-- | Change a name so it doesn't conflict with any Haskell keywords.
fixReservedName :: Text -> Text
fixReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
-- Haskell2010 keywords:
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
-- We don't include keywords that are allowed to be variable names,
-- in particular: "as", "forall", and "hiding".
[ "case"
, "class"
, "data"
, "default"
, "deriving"
, "do"
, "else"
, "foreign"
, "if"
, "import"
, "in"
, "infix"
, "infixl"
, "infixr"
, "instance"
, "let"
, "module"
, "newtype"
, "of"
, "then"
, "type"
, "where"
]
++ -- Nonstandard extensions
[ "mdo" -- RecursiveDo
, "rec" -- Arrows, RecursiveDo
, "proc" -- Arrows
]
-- | Lower-case the given text.
lowCase :: Text -> Text
lowCase = forceCase toLower
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ Text.splitOn "_" s
-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper
parseOp :: OpDef -> ParsedOp
parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, ..
}
where
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ o ^. attr
inferredTypeAttrs = mapMaybeAttrs getInferredTypeAttr $ o ^. attr
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
$ o ^. attr
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
| a ^. type' == DT_RESOURCE = ArgResource
| otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
| a ^. type' == DT_RESOURCE = ArgResource
| otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat] = Just t
| otherwise = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
| otherwise = Nothing
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
| a ^. type' == "int"
= nonEmpty [t | ParsedArg { parsedArgName = t
, parsedArgCase
= ListArg { argLength = n }
} <- inputs
, TFName (a ^. name) == tfName n]
| otherwise = Nothing
-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
x <- f a
Just Attr
{ attrName = makeName (a ^. name)
, attrDescription = a ^. description
, attrInfo = x
}
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name)
, parsedArgDescription = a ^. description
, parsedArgCase = parseArgCase a
, parsedArgKind = tKind
}
parseArgCase :: OpDef'ArgDef -> ParsedArgCase
parseArgCase a
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType
| otherwise = SimpleArg thisArgType
where
thisArgType
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
| a ^. type' == DT_RESOURCE = ArgTypeAttr (makeName "dtype")
| otherwise = ArgTypeFixed (a ^. type')
maybeAttr :: Text -> Maybe Name
maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t
parseAttrType :: Text -> AttrType
parseAttrType = \case
"string" -> AttrSingle AttrBytes
"int" -> AttrSingle AttrInt64
"float" -> AttrSingle AttrFloat
"bool" -> AttrSingle AttrBool
"type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor
"list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat
"list(bool)" -> AttrList AttrBool
"list(type)" -> AttrList AttrType
"list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor
t -> error $ "parseAttrType: unrecognized type " ++ show t

View file

@ -13,7 +13,7 @@ cabal-version: >=1.22
library
hs-source-dirs: src
exposed-modules: TensorFlow.OpGen.AttrVal
exposed-modules: TensorFlow.OpGen.ParsedOp
, TensorFlow.OpGen
build-depends: proto-lens == 0.1.*
, tensorflow-proto == 0.1.*

View file

@ -16,6 +16,7 @@
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
@ -24,6 +25,8 @@ import qualified Data.Vector as V
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test split and concat are inverses.
@ -34,11 +37,18 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
restored = CoreOps.concat dim splitList
dim = 1 -- dimension to split along (with size of 3 in original)
liftIO $ 3 @=? length splitList
(x, y, z) <-
TF.buildAnd TF.run $ return (original, restored, splitList !! 1)
(x, y, z) <- TF.run (original, restored, splitList !! 1)
liftIO $ x @=? (y :: V.Vector Float)
liftIO $ V.fromList [1, 4] @=? z
testShapeN :: Test
testShapeN = testCase "testShapeN" $ TF.runSession $ do
let shapes = map TF.Shape [[1],[2,3]]
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
result <- TF.run $ CoreOps.shapeN tensors
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
main :: IO ()
main = googleTest [ testSplit
, testShapeN
]