mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-12 20:59:48 +01:00
9209dfc4c4
Adds a new type `ListOf` which wraps a heterogeneous list; for example, `ListOf (Tensor Value) '[Int32, Float]` represents a list of two elements: a tensor of int32s and a tensor of floats. Also changes the `Queue2` type (which supported pairs of tensors) to `Queue` (which supports arbitrary lists).
342 lines
11 KiB
Haskell
342 lines
11 KiB
Haskell
-- | This module helps parse the proto OpDef into a Haskell type which is more
|
|
-- descriptive of how the attributes and arguments will be used in the
|
|
-- generated code.
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RecordWildCards #-}
|
|
module TensorFlow.OpGen.ParsedOp
|
|
( ParsedOp(..)
|
|
, Name(..)
|
|
, HaskellName(..)
|
|
, TFName(..)
|
|
, Attr(..)
|
|
, AttrType(..)
|
|
, AttrBaseType(..)
|
|
, TypeParam(..)
|
|
, ParsedArg(..)
|
|
, ParsedArgCase(..)
|
|
, ArgType(..)
|
|
, ArgKind(..)
|
|
, argKind
|
|
, parseOp
|
|
, camelCase
|
|
) where
|
|
|
|
import Data.Char (toUpper, toLower)
|
|
import Data.List (sortBy)
|
|
import Data.List.NonEmpty (NonEmpty, nonEmpty)
|
|
import Data.Maybe (mapMaybe)
|
|
import Data.Monoid ((<>))
|
|
import Data.Ord (comparing)
|
|
import qualified Data.Set as Set
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as Text
|
|
import Lens.Family2 ((^.))
|
|
import Proto.Tensorflow.Core.Framework.AttrValue (list)
|
|
import Proto.Tensorflow.Core.Framework.OpDef
|
|
( OpDef
|
|
, OpDef'ArgDef
|
|
, OpDef'AttrDef
|
|
, allowedValues
|
|
, attr
|
|
, maybe'defaultValue
|
|
, description
|
|
, name
|
|
, inputArg
|
|
, isRef
|
|
, isStateful
|
|
, outputArg
|
|
, summary
|
|
, typeListAttr
|
|
, numberAttr
|
|
, typeAttr
|
|
, type'
|
|
)
|
|
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))
|
|
|
|
-- | A Haskell-oriented description of a single op, built from its 'OpDef'
-- proto by 'parseOp'.
data ParsedOp = ParsedOp
    { parsedOpName :: Name
        -- ^ The op's name (both the Haskell identifier and the raw proto name).
    , parsedOpSummary :: Text
        -- ^ The @summary@ field of the 'OpDef'.
    , parsedOpDescription :: Text
        -- ^ The @description@ field of the 'OpDef'.
    , parsedInputs :: [ParsedArg]
        -- ^ The op's input arguments, in proto order.
    , parsedOutputs :: [ParsedArg]
        -- ^ The op's output arguments, in proto order.
    , explicitInputAttrs :: [Attr AttrType]
        -- ^ Attributes that must be set explicitly when creating the op.
        -- Associated with the type of the attribute.
    , inferredTypeAttrs :: [Attr TypeParam]
        -- ^ Attributes that are type parameters.
    , inferredListSizeAttrs :: [Attr (NonEmpty Name)]
        -- ^ Attributes which are list sizes (ints) that are inferred
        -- automatically from one or more of the input tensors.
        -- Associated with the list of tensors whose size it describes.
    , parsedOpIsMonadic :: Bool
        -- ^ Whether this op is stateful or takes a stateful input. Such ops
        -- should not be CSE'd and must be monadic in our API (i.e., return a
        -- Build action).
    }
|
|
|
|
-- | An identifier in both of its forms: the Haskell-safe spelling used in
-- generated code and the raw spelling taken from the proto.
data Name = Name
    { haskellName :: HaskellName
        -- ^ The spelling to use in Haskell source.
    , tfName :: TFName
        -- ^ The spelling as it appears in the OpDef proto.
    }
|
|
|
|
-- | A raw name, exactly as it is written in the OpDef proto.
--
-- The 'Ord' instance lets these names be stored in a 'Set.Set' and used as a
-- sort key.
newtype TFName = TFName
    { unTFName :: Text
    }
    deriving (Eq, Ord)
|
|
|
|
-- | A name that is suitable for use as a variable in a Haskell source file.
newtype HaskellName = HaskellName
    { unHaskellName :: Text
    }
|
|
|
|
-- | A named attribute, paired with some extra information of type @a@
-- describing it.
data Attr a = Attr
    { attrName :: Name
        -- ^ The attribute's name.
    , attrDescription :: Text
        -- ^ The attribute's description text from the proto.
    , attrInfo :: a
        -- ^ Caller-chosen information attached to the attribute.
    }
|
|
|
|
-- | The type of an attribute: either one value of a base type or a list of
-- values of a base type.
data AttrType
    = AttrSingle AttrBaseType
        -- ^ A single value, e.g. the proto type @int@.
    | AttrList AttrBaseType
        -- ^ A list of values, e.g. the proto type @list(int)@.
    deriving (Eq)
|
|
|
|
-- | The element types an attribute can have. 'parseAttrType' maps the proto
-- type names @string@, @int@, @float@, @bool@, @type@, @shape@ and @tensor@
-- onto these constructors, in that order.
data AttrBaseType
    = AttrBytes
    | AttrInt64
    | AttrFloat
    | AttrBool
    | AttrType
    | AttrShape
    | AttrTensor
    deriving (Eq)
|
|
|
|
-- | Information about an attribute that acts as a type parameter of an op.
data TypeParam = TypeParam
    { typeParamIsList :: Bool
        -- ^ 'True' when the attribute stands for a list of types rather than
        -- a single type.
    , typeParamRestrictions :: Maybe (NonEmpty DataType)
        -- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
        -- If 'Nothing', then any type is acceptable.
    }
|
|
|
|
-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
    { parsedArgName :: Name
        -- ^ The argument's name.
    , parsedArgDescription :: Text
        -- ^ The argument's description text from the proto.
    , parsedArgCase :: ParsedArgCase
        -- ^ What shape of argument this is (single tensor, list, ...).
    }
|
|
|
|
-- | The different shapes an op argument can take.
--
-- Note that 'argType' and 'argCaseKind' are partial record selectors: not
-- every constructor carries them ('ResourceArg' carries neither). Use
-- 'argKind' for a total way of reading the kind.
data ParsedArgCase
    = SimpleArg
        { argType :: ArgType
        , argCaseKind :: ArgKind
        }
        -- ^ An argument that is neither a list nor a resource.
    | ListArg
        { argLength :: Name
            -- ^ The attribute that specifies this list's length.
        , argType :: ArgType
        , argCaseKind :: ArgKind
        }
        -- ^ A list whose length is given by an attribute.
    | MixedListArg
        { argTypeAttr :: Name
        , argCaseKind :: ArgKind
        }
        -- ^ A heterogeneous list.
    | ResourceArg
        -- ^ An argument whose proto type is @DT_RESOURCE@.
|
|
|
|
-- | The 'ArgKind' of an argument, or 'Nothing' for a 'ResourceArg' (which
-- does not carry one).
argKind :: ParsedArgCase -> Maybe ArgKind
argKind = \case
    ResourceArg -> Nothing
    other -> Just (argCaseKind other)
|
|
|
|
-- | The type of an argument.
data ArgType
    = ArgTypeFixed DataType
        -- ^ A fixed type.
    | ArgTypeAttr Name
        -- ^ A type that depends on an attribute.
|
|
|
|
-- | The kind of an op input or output (not including the argument type @a@).
data ArgKind
    = ArgTensorRef
        -- ^ @Tensor Ref a@
    | ArgTensorValue
        -- ^ @Tensor Value a@
    | ArgTensorEither Text
        -- ^ @Tensor v a@; the 'Text' is the variable @v@
    deriving (Eq)
|
|
|
|
-- | Whether an argument is a reference: 'True' for resources (which have no
-- 'ArgKind') and for 'ArgTensorRef' tensors, 'False' for everything else.
isRefCase :: ParsedArgCase -> Bool
isRefCase = maybe True (== ArgTensorRef) . argKind
|
|
|
|
-- | Build a 'Name' from a raw proto name. The Haskell spelling has its first
-- character lower-cased and is then adjusted so that it does not collide
-- with a reserved keyword; the raw spelling is kept unchanged.
makeName :: Text -> Name
makeName raw =
    Name
        { haskellName = HaskellName haskellSafe
        , tfName = TFName raw
        }
  where
    haskellSafe = fixReservedName (lowCase raw)
|
|
|
|
-- | Change a name so it doesn't conflict with any Haskell keywords: names
-- found in 'reservedKeywords' get a trailing prime, all others are returned
-- unchanged.
fixReservedName :: Text -> Text
fixReservedName n =
    if Set.member n reservedKeywords
        then n <> "'"
        else n
|
|
|
|
-- | The identifiers that generated code must not use as variable names.
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList (haskell2010Keywords ++ extensionKeywords)
  where
    -- Haskell2010 keywords:
    -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
    -- We don't include keywords that are allowed to be variable names,
    -- in particular: "as", "forall", and "hiding".
    haskell2010Keywords =
        [ "case"
        , "class"
        , "data"
        , "default"
        , "deriving"
        , "do"
        , "else"
        , "foreign"
        , "if"
        , "import"
        , "in"
        , "infix"
        , "infixl"
        , "infixr"
        , "instance"
        , "let"
        , "module"
        , "newtype"
        , "of"
        , "then"
        , "type"
        , "where"
        ]
    -- Keywords that only exist under nonstandard extensions.
    extensionKeywords =
        [ "mdo" -- RecursiveDo
        , "rec" -- Arrows, RecursiveDo
        , "proc" -- Arrows
        ]
|
|
|
|
-- | Lower-case the first character of the given text; the remaining
-- characters are left as they are.
lowCase :: Text -> Text
lowCase s = forceCase toLower s
|
|
|
|
-- | Apply a character conversion to the first character of a text, leaving
-- the rest untouched. Empty input yields empty output.
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s =
    case Text.uncons s of
        Nothing -> Text.empty
        Just (firstChar, rest) -> Text.cons (convert firstChar) rest
|
|
|
|
-- | Convert an underscore-separated name to camel case: the text is split on
-- @_@, the first character of every piece is upper-cased, and the pieces are
-- concatenated (so @foo_bar@ becomes @FooBar@).
camelCase :: Text -> Text
camelCase = Text.concat . map upCase . Text.splitOn "_"
|
|
|
|
-- | Upper-case the first character of the given text; the remaining
-- characters are left as they are.
upCase :: Text -> Text
upCase s = forceCase toUpper s
|
|
|
|
|
|
-- | Convert an 'OpDef' proto into a 'ParsedOp'.
--
-- The fields not spelled out in the record expression ('parsedInputs',
-- 'parsedOutputs', 'explicitInputAttrs', 'inferredTypeAttrs' and
-- 'inferredListSizeAttrs') are filled in by the record wildcard from the
-- identically named bindings in the @where@ clause.
parseOp :: OpDef -> ParsedOp
parseOp o =
    ParsedOp
        { parsedOpName = makeName (o ^. name)
        , parsedOpSummary = o ^. summary
        , parsedOpDescription = o ^. description
        , parsedOpIsMonadic =
            o ^. isStateful
                || any (isRefCase . parsedArgCase) parsedInputs
        , ..
        }
  where
    attrDefs = o ^. attr

    -- Every input is paired with its own kind variable (v1, v2, ...), which
    -- is only used when the input is not a ref.
    parsedInputs =
        [ parseArg argDef (inputTensorKind argDef kindVar)
        | (argDef, kindVar) <- zip (o ^. inputArg) kindVars
        ]
    kindVars = ["v" <> Text.pack (show i) | i <- [1 :: Integer ..]]

    parsedOutputs =
        [ parseArg argDef (outputTensorKind argDef)
        | argDef <- o ^. outputArg
        ]

    -- Names of the attributes that some input or output uses as its type.
    typeParamNames =
        Set.fromList
            . map tfName
            . mapMaybe (getArgTypeParam . parsedArgCase)
            $ parsedInputs ++ parsedOutputs

    inferredTypeAttrs =
        mapMaybeAttrs (getInferredTypeAttr typeParamNames) attrDefs

    -- Integer attributes that can be inferred from the size of at least one
    -- input list.
    inferredListSizeAttrs =
        mapMaybeAttrs (getInferredListSizeAttr parsedInputs) attrDefs

    -- Names of all attributes that are inferred one way or the other.
    inferredNames =
        Set.fromList $
            map (tfName . attrName) inferredTypeAttrs
                ++ map (tfName . attrName) inferredListSizeAttrs

    -- Attributes that can't be inferred and don't have defaults, so must be
    -- passed as separate arguments to the op. Sorted by raw proto name.
    explicitInputAttrs =
        sortBy
            (comparing (tfName . attrName))
            (mapMaybeAttrs (getExplicitInputAttr o inferredNames) attrDefs)
|
|
|
|
-- | The kind of an input argument: a ref tensor when the proto marks the
-- argument as a ref, otherwise a tensor whose kind is the given type
-- variable name.
--
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v =
    if a ^. isRef
        then ArgTensorRef
        else ArgTensorEither v
|
|
|
|
-- | The kind of an output argument: a ref tensor when the proto marks the
-- argument as a ref, otherwise a value tensor.
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a =
    if a ^. isRef
        then ArgTensorRef
        else ArgTensorValue
|
|
|
|
-- | Decide whether an attribute has to be passed explicitly by the caller.
--
-- The result is 'Just' the attribute's type when the attribute is not in the
-- set of implicit (inferred) attributes, has no default value, and its type
-- is one of the supported ones listed below; otherwise 'Nothing'.
--
-- The attribute's type is only parsed once the first two conditions hold, so
-- 'parseAttrType' cannot fail for attributes that are skipped anyway.
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a
    | isImplicit || hasDefault = Nothing
    | parsedType `elem` supportedTypes = Just parsedType
    | otherwise = Nothing
  where
    isImplicit = TFName (a ^. name) `Set.member` implicitAttrs
    hasDefault = case a ^. maybe'defaultValue of
        Nothing -> False
        Just _ -> True
    parsedType = parseAttrType o (a ^. type')
    supportedTypes =
        map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
            ++ [AttrList AttrType]
|
|
|
|
-- | Decide whether an attribute is a type parameter that can be inferred
-- from the op's arguments.
--
-- The first argument is the set of attribute names that some input or output
-- uses as its type. An attribute outside that set is never inferred. For
-- attributes inside it, the proto type @type@ gives a single type parameter
-- and @list(type)@ a list of types; any other attribute type gives 'Nothing'.
--
-- The restriction stored in the result is the attribute's list of allowed
-- types, or 'Nothing' when that list is empty.
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
getInferredTypeAttr argTypeParams a
    -- Use the set's own membership test (as 'getExplicitInputAttr' does)
    -- rather than the linear 'Foldable' traversal.
    | TFName (a ^. name) `Set.notMember` argTypeParams = Nothing
    | otherwise = case a ^. type' of
        "type" -> Just (TypeParam False allowed)
        "list(type)" -> Just (TypeParam True allowed)
        _ -> Nothing
  where
    allowed = nonEmpty (a ^. allowedValues . list . type')
|
|
|
|
-- | The attribute that determines an argument's type, if there is one.
--
-- Simple and list arguments have one only when their type depends on an
-- attribute; heterogeneous lists always have one; everything else has none.
getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam = \case
    SimpleArg { argType = ArgTypeAttr n } -> Just n
    ListArg { argType = ArgTypeAttr n } -> Just n
    MixedListArg { argTypeAttr = n } -> Just n
    _ -> Nothing
|
|
|
|
-- | Find the input lists whose length is given by this attribute.
--
-- Only attributes of proto type @int@ are considered. The result holds the
-- names of all 'ListArg' inputs whose length attribute is this attribute, in
-- input order, or 'Nothing' when there is no such input.
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
    | a ^. type' /= "int" = Nothing
    | otherwise = nonEmpty (mapMaybe sizedByThisAttr inputs)
  where
    thisAttr = TFName (a ^. name)
    sizedByThisAttr arg = case parsedArgCase arg of
        ListArg { argLength = len }
            | tfName len == thisAttr -> Just (parsedArgName arg)
        _ -> Nothing
|
|
|
|
-- | Like 'mapMaybe', but every kept result is wrapped in an 'Attr' carrying
-- the name and description of the attribute it was computed from.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe (\attrDef -> fmap (describe attrDef) (f attrDef))
  where
    describe attrDef info =
        Attr
            { attrName = makeName (attrDef ^. name)
            , attrDescription = attrDef ^. description
            , attrInfo = info
            }
|
|
|
|
-- | Convert an argument definition from the proto into a 'ParsedArg', using
-- the given tensor kind for the argument.
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind =
    ParsedArg
        { parsedArgName = argName
        , parsedArgDescription = argDescription
        , parsedArgCase = argCase
        }
  where
    argName = makeName (a ^. name)
    argDescription = a ^. description
    argCase = parseArgCase a tKind
|
|
|
|
-- | Classify an argument definition. The checks are made in this order:
--
-- 1. proto type @DT_RESOURCE@ gives 'ResourceArg';
-- 2. a non-empty type-list attribute gives 'MixedListArg';
-- 3. a non-empty number attribute gives 'ListArg';
-- 4. anything else gives 'SimpleArg'.
--
-- For list and simple arguments the element type comes from the argument's
-- type attribute when that is non-empty, and from its fixed proto type
-- otherwise.
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind
    | a ^. type' == DT_RESOURCE = ResourceArg
    | otherwise =
        case (attrNamed (a ^. typeListAttr), attrNamed (a ^. numberAttr)) of
            (Just typeList, _) -> MixedListArg typeList tKind
            (Nothing, Just len) -> ListArg len elemType tKind
            (Nothing, Nothing) -> SimpleArg elemType tKind
  where
    elemType = case attrNamed (a ^. typeAttr) of
        Just n -> ArgTypeAttr n
        Nothing -> ArgTypeFixed (a ^. type')

    -- The proto uses the empty string to mean "no attribute".
    attrNamed :: Text -> Maybe Name
    attrNamed t
        | Text.null t = Nothing
        | otherwise = Just (makeName t)
|
|
|
|
-- | Translate the textual attribute type used in the proto (e.g. @int@ or
-- @list(shape)@) into an 'AttrType'.
--
-- Calls 'error' for a type name that is not in the table below; the 'OpDef'
-- is only used to mention the op's name in that error message.
parseAttrType :: OpDef -> Text -> AttrType
parseAttrType o t =
    case lookup t knownTypes of
        Just parsed -> parsed
        Nothing ->
            error $ "parseAttrType: unrecognized type " ++ show t
                ++ " for op " ++ show (o ^. name)
  where
    knownTypes :: [(Text, AttrType)]
    knownTypes =
        [ ("string", AttrSingle AttrBytes)
        , ("int", AttrSingle AttrInt64)
        , ("float", AttrSingle AttrFloat)
        , ("bool", AttrSingle AttrBool)
        , ("type", AttrSingle AttrType)
        , ("shape", AttrSingle AttrShape)
        , ("tensor", AttrSingle AttrTensor)
        , ("list(string)", AttrList AttrBytes)
        , ("list(int)", AttrList AttrInt64)
        , ("list(float)", AttrList AttrFloat)
        , ("list(bool)", AttrList AttrBool)
        , ("list(type)", AttrList AttrType)
        , ("list(shape)", AttrList AttrShape)
        , ("list(tensor)", AttrList AttrTensor)
        ]
|