1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-12 20:59:48 +01:00
tensorflow-haskell/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs
Judah Jacobson 9209dfc4c4 Support lists of tensors in ops. (#79)
Adds a new type `ListOf` which wraps a heterogeneous list; for example,
`ListOf (Tensor Value) '[Int32, Float]` represents a list of two
elements: a tensor of int32s and a tensor of floats.

Also changes the `Queue2` type (which suppored pairs of tensors) to
`Queue` (which supports arbitrary lists).
2017-03-17 13:53:19 -07:00

342 lines
11 KiB
Haskell

-- | This module helps parse the proto OpDef into a Haskell type which is more
-- descriptive of how the attributes and arguments will be used in the
-- generated code.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module TensorFlow.OpGen.ParsedOp
( ParsedOp(..)
, Name(..)
, HaskellName(..)
, TFName(..)
, Attr(..)
, AttrType(..)
, AttrBaseType(..)
, TypeParam(..)
, ParsedArg(..)
, ParsedArgCase(..)
, ArgType(..)
, ArgKind(..)
, argKind
, parseOp
, camelCase
) where
import Data.Char (toUpper, toLower)
import Data.List (sortBy)
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import Data.Maybe (mapMaybe)
import Data.Monoid ((<>))
import Data.Ord (comparing)
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Proto.Tensorflow.Core.Framework.AttrValue (list)
import Proto.Tensorflow.Core.Framework.OpDef
( OpDef
, OpDef'ArgDef
, OpDef'AttrDef
, allowedValues
, attr
, maybe'defaultValue
, description
, name
, inputArg
, isRef
, isStateful
, outputArg
, summary
, typeListAttr
, numberAttr
, typeAttr
, type'
)
import Proto.Tensorflow.Core.Framework.Types (DataType(DT_RESOURCE))
data ParsedOp = ParsedOp
{ parsedOpName :: Name
, parsedOpSummary :: Text
, parsedOpDescription :: Text
, parsedInputs :: [ParsedArg]
, parsedOutputs :: [ParsedArg]
, explicitInputAttrs :: [Attr AttrType]
-- ^ Attributes that must be set explicitly when creating the op.
-- Associated with the type of the attribute.
, inferredTypeAttrs :: [Attr TypeParam]
-- ^ Attributes that are type parameters.
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
-- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors.
-- Associated with the list of tensors whose size it describes.
, parsedOpIsMonadic :: Bool
-- ^ Whether this op is stateful or takes a stateful input. Such ops
-- should not be CSE'd and must be monadic in our API (i.e., return a
-- Build action).
}
data Name = Name
{ haskellName :: HaskellName
, tfName :: TFName
}
-- | A raw name as specified in the OpDef proto.
newtype TFName = TFName { unTFName :: Text }
deriving (Eq, Ord)
-- | A name that's appropriate for a variable in a Haskell source file.
newtype HaskellName = HaskellName { unHaskellName :: Text }
-- | A named attribute, associated with some information about it.
data Attr a = Attr
{ attrName :: Name
, attrDescription :: Text
, attrInfo :: a
}
-- | The type of an attribute.
data AttrType = AttrSingle AttrBaseType
| AttrList AttrBaseType
deriving Eq
data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor
deriving Eq
data TypeParam = TypeParam
{ typeParamIsList :: Bool
, typeParamRestrictions :: Maybe (NonEmpty DataType)
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
-- If 'Nothing', then any type is acceptable.
}
-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
{ parsedArgName :: Name
, parsedArgDescription :: Text
, parsedArgCase :: ParsedArgCase
}
data ParsedArgCase
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
| ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType
, argCaseKind :: ArgKind
}
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
-- ^ A heterogeneous list.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind
argKind ResourceArg = Nothing
argKind a = Just $ argCaseKind a
-- | The type of an argument.
data ArgType
= ArgTypeFixed DataType -- ^ A fixed type.
| ArgTypeAttr Name -- ^ A type that depends on an attribute.
-- The kind of an op input or output (not including the argument type `a`).
data ArgKind
= ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
deriving (Eq)
isRefCase :: ParsedArgCase -> Bool
isRefCase a = case argKind a of
Nothing -> True -- Resource
Just ArgTensorRef -> True
_ -> False
makeName :: Text -> Name
makeName n = Name
{ haskellName = HaskellName $ fixReservedName $ lowCase n
, tfName = TFName n
}
-- | Change a name so it doesn't conflict with any Haskell keywords.
fixReservedName :: Text -> Text
fixReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
reservedKeywords :: Set.Set Text
reservedKeywords = Set.fromList $
-- Haskell2010 keywords:
-- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
-- We don't include keywords that are allowed to be variable names,
-- in particular: "as", "forall", and "hiding".
[ "case"
, "class"
, "data"
, "default"
, "deriving"
, "do"
, "else"
, "foreign"
, "if"
, "import"
, "in"
, "infix"
, "infixl"
, "infixr"
, "instance"
, "let"
, "module"
, "newtype"
, "of"
, "then"
, "type"
, "where"
]
++ -- Nonstandard extensions
[ "mdo" -- RecursiveDo
, "rec" -- Arrows, RecursiveDo
, "proc" -- Arrows
]
-- | Lower-case the given text.
lowCase :: Text -> Text
lowCase = forceCase toLower
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ Text.splitOn "_" s
-- | Upper-case the given text.
upCase :: Text -> Text
upCase = forceCase toUpper
parseOp :: OpDef -> ParsedOp
parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful
|| any (isRefCase . parsedArgCase) parsedInputs
, ..
}
where
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
-- Integer attributes that can be inferred from the size of at least one
-- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
$ o ^. attr
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
argTypeParams = Set.fromList $ map tfName $
mapMaybe (getArgTypeParam . parsedArgCase) $
parsedInputs ++ parsedOutputs
-- Attributes that can't be inferred and don't have defaults, so must be
-- passed as separate arguments to the op.
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
$ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType o (a ^. type')
, t `elem` map AttrSingle
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
++ [AttrList AttrType] = Just t
| otherwise = Nothing
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
getInferredTypeAttr argTypeParams a
| TFName (a ^. name) `notElem` argTypeParams = Nothing
| a ^. type' == "type" = Just $ TypeParam False allowed
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
| otherwise = Nothing
where
allowed = nonEmpty (a ^. allowedValues . list . type')
getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
getArgTypeParam _ = Nothing
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
| a ^. type' == "int"
= nonEmpty [t | ParsedArg { parsedArgName = t
, parsedArgCase
= ListArg { argLength = n }
} <- inputs
, TFName (a ^. name) == tfName n]
| otherwise = Nothing
-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
x <- f a
Just Attr
{ attrName = makeName (a ^. name)
, attrDescription = a ^. description
, attrInfo = x
}
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name)
, parsedArgDescription = a ^. description
, parsedArgCase = parseArgCase a tKind
}
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind
| a ^. type' == DT_RESOURCE = ResourceArg
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
| otherwise = SimpleArg thisArgType tKind
where
thisArgType
| Just n <- maybeAttr (a ^. typeAttr) = ArgTypeAttr n
| otherwise = ArgTypeFixed (a ^. type')
maybeAttr :: Text -> Maybe Name
maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t
parseAttrType :: OpDef -> Text -> AttrType
parseAttrType o = \case
"string" -> AttrSingle AttrBytes
"int" -> AttrSingle AttrInt64
"float" -> AttrSingle AttrFloat
"bool" -> AttrSingle AttrBool
"type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor
"list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat
"list(bool)" -> AttrList AttrBool
"list(type)" -> AttrList AttrType
"list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor
t -> error $ "parseAttrType: unrecognized type " ++ show t
++ " for op " ++ show (o ^. name)