mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Support lists of tensors in ops. (#79)
Adds a new type `ListOf` which wraps a heterogeneous list; for example, `ListOf (Tensor Value) '[Int32, Float]` represents a list of two elements: a tensor of int32s and a tensor of floats. Also changes the `Queue2` type (which suppored pairs of tensors) to `Queue` (which supports arbitrary lists).
This commit is contained in:
parent
7cc6a69866
commit
9209dfc4c4
9 changed files with 273 additions and 144 deletions
|
@ -64,43 +64,9 @@ generatingOpsWrappers = hooks
|
||||||
(prettyLazyText 80 $ docOpList flags x)
|
(prettyLazyText 80 $ docOpList flags x)
|
||||||
|
|
||||||
blackList =
|
blackList =
|
||||||
-- A few data flow ops take a list of heterogeneous
|
[ -- Requires the "func" type:
|
||||||
-- parameters which we don't support in general form.
|
"SymbolicGradient"
|
||||||
[ "HashTable"
|
|
||||||
, "MutableDenseHashTable"
|
|
||||||
, "MutableHashTable"
|
|
||||||
, "MutableHashTableOfTensors"
|
|
||||||
, "QueueDequeue"
|
|
||||||
, "QueueDequeueMany"
|
|
||||||
, "QueueDequeueUpTo"
|
|
||||||
, "Stack"
|
|
||||||
, "TensorArray"
|
|
||||||
, "TensorArrayV2"
|
|
||||||
, "QueueEnqueueManyV2"
|
|
||||||
, "QueueDequeueV2"
|
|
||||||
, "QueueDequeueUpToV2"
|
|
||||||
, "QueueEnqueueV2"
|
|
||||||
, "QueueDequeueManyV2"
|
|
||||||
, "Stage"
|
|
||||||
, "Unstage"
|
|
||||||
-- These should be possible to support by adding a bunch of
|
|
||||||
-- overloads with a variable number of tuple arguments.
|
|
||||||
, "Assert"
|
|
||||||
, "BarrierTakeMany"
|
|
||||||
, "Print"
|
|
||||||
, "QueueEnqueue"
|
|
||||||
, "QueueEnqueueMany"
|
|
||||||
-- Need list of types support.
|
|
||||||
, "DecodeCSV"
|
|
||||||
, "ParseExample"
|
|
||||||
, "ParseSingleSequenceExample"
|
|
||||||
, "RestoreV2"
|
|
||||||
, "Save"
|
|
||||||
, "SaveV2"
|
|
||||||
, "SaveSlices"
|
|
||||||
, "SymbolicGradient"
|
|
||||||
, "_ArrayToList"
|
|
||||||
, "_ListToArray"
|
|
||||||
-- Easy: support larger result tuples.
|
-- Easy: support larger result tuples.
|
||||||
|
, "ParseSingleSequenceExample"
|
||||||
, "Skipgram"
|
, "Skipgram"
|
||||||
]
|
]
|
||||||
|
|
|
@ -147,6 +147,7 @@ imports = stack [
|
||||||
"import Data.ByteString (ByteString)"
|
"import Data.ByteString (ByteString)"
|
||||||
, "import Data.Complex (Complex)"
|
, "import Data.Complex (Complex)"
|
||||||
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
||||||
|
, "import Data.Proxy (Proxy(Proxy))"
|
||||||
, "import Data.Word (Word8, Word16)"
|
, "import Data.Word (Word8, Word16)"
|
||||||
, "import Lens.Family2 ((.~), (&))"
|
, "import Lens.Family2 ((.~), (&))"
|
||||||
, "import TensorFlow.Build"
|
, "import TensorFlow.Build"
|
||||||
|
@ -210,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
|
||||||
whereClause [] = []
|
whereClause [] = []
|
||||||
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
||||||
where
|
where
|
||||||
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
|
defineLengthAttr a = renderHaskellAttrName a <+> "="
|
||||||
<+> "fromIntegral (length"
|
<+> "fromIntegral (length"
|
||||||
<+> renderHaskellName (NE.head $ attrInfo a)
|
<+> renderHaskellName (NE.head $ attrInfo a)
|
||||||
<> ") :: Int64"
|
<> ") :: Int64"
|
||||||
|
|
||||||
|
renderHaskellAttrName :: Attr a -> Doc
|
||||||
|
renderHaskellAttrName = renderHaskellName . attrName
|
||||||
|
|
||||||
functionBody :: ParsedOp -> Doc
|
functionBody :: ParsedOp -> Doc
|
||||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||||
</> indent indentation (sep tensorArgs)
|
</> indent indentation (sep tensorArgs)
|
||||||
|
@ -229,9 +233,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||||
<- parsedOutputs pOp]
|
<- parsedOutputs pOp]
|
||||||
buildOpParts =
|
buildOpParts =
|
||||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||||
-- Renders tensor arguments.
|
-- Renders type parameter arguments.
|
||||||
[ "& opAttr" <+> renderQuotedTFName n <+>
|
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
|
||||||
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
|
|
||||||
| a <- inferredTypeAttrs pOp, let n = attrName a
|
| a <- inferredTypeAttrs pOp, let n = attrName a
|
||||||
] ++
|
] ++
|
||||||
-- Renders mandatory attributes as function parameters.
|
-- Renders mandatory attributes as function parameters.
|
||||||
|
@ -244,6 +247,12 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||||
]
|
]
|
||||||
|
|
||||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||||
|
inferredTypeExpr a
|
||||||
|
| typeParamIsList $ attrInfo a
|
||||||
|
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
||||||
|
<> ")"
|
||||||
|
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
|
||||||
|
<> ")"
|
||||||
|
|
||||||
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||||
-- debugging.
|
-- debugging.
|
||||||
|
@ -272,8 +281,8 @@ typeSig pOp = constraints
|
||||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
classConstraints = tuple $ map tensorArgConstraint
|
||||||
$ inferredTypeAttrs pOp
|
$ inferredTypeAttrs pOp
|
||||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||||
|
@ -305,17 +314,18 @@ tensorArg p = case parsedArgCase p of
|
||||||
ResourceArg -> "ResourceHandle"
|
ResourceArg -> "ResourceHandle"
|
||||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||||
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
||||||
|
-> "TensorList" <+> kind k <+> renderHaskellName t
|
||||||
where
|
where
|
||||||
|
kind k = case k of
|
||||||
|
ArgTensorRef -> "Ref"
|
||||||
|
ArgTensorValue -> "Value"
|
||||||
|
ArgTensorEither v' -> strictText v'
|
||||||
tensorType t k = let
|
tensorType t k = let
|
||||||
v = case k of
|
|
||||||
ArgTensorRef -> "Tensor Ref"
|
|
||||||
ArgTensorValue -> "Tensor Value"
|
|
||||||
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
|
||||||
a = case t of
|
a = case t of
|
||||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||||
ArgTypeAttr n -> renderHaskellName n
|
ArgTypeAttr n -> renderHaskellName n
|
||||||
in v <+> a
|
in "Tensor" <+> kind k <+> a
|
||||||
|
|
||||||
attrComment :: Attr a -> Doc
|
attrComment :: Attr a -> Doc
|
||||||
attrComment a = argComment' (attrName a) (attrDescription a)
|
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||||
|
@ -347,18 +357,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os
|
||||||
]
|
]
|
||||||
|
|
||||||
-- | Constraints for a given type parameter.
|
-- | Constraints for a given type parameter.
|
||||||
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
|
-- E.g.: "TensorType t" or "OneOf [Int64, Float] t"
|
||||||
tensorArgConstraint :: Attr [DataType] -> [Doc]
|
-- or "TensorTypes ts" or "OneOfs [..] ts".
|
||||||
tensorArgConstraint a
|
tensorArgConstraint :: Attr TypeParam -> Doc
|
||||||
= ("TensorType" <+> n
|
tensorArgConstraint a = case attrInfo a of
|
||||||
: if null typeList
|
TypeParam False Nothing -> "TensorType" <+> n
|
||||||
then []
|
TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n
|
||||||
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
|
TypeParam True Nothing -> "TensorTypes" <+> n
|
||||||
|
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
|
||||||
where
|
where
|
||||||
n = renderHaskellName $ attrName a
|
n = renderHaskellAttrName a
|
||||||
typeList = map strictText $
|
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
|
||||||
Set.toList $ Set.fromList $
|
typeList = ("'" <>) . brackets . commasep . map strictText .
|
||||||
map dtTypeToHaskell $ attrInfo a
|
Set.toList . Set.fromList .
|
||||||
|
map dtTypeToHaskell . toList
|
||||||
|
|
||||||
-- NOTE: The cases of this function should be kept in sync with
|
-- NOTE: The cases of this function should be kept in sync with
|
||||||
-- TensorFlow.Types.AllTensorTypes.
|
-- TensorFlow.Types.AllTensorTypes.
|
||||||
|
|
|
@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp
|
||||||
, Attr(..)
|
, Attr(..)
|
||||||
, AttrType(..)
|
, AttrType(..)
|
||||||
, AttrBaseType(..)
|
, AttrBaseType(..)
|
||||||
|
, TypeParam(..)
|
||||||
, ParsedArg(..)
|
, ParsedArg(..)
|
||||||
, ParsedArgCase(..)
|
, ParsedArgCase(..)
|
||||||
, ArgType(..)
|
, ArgType(..)
|
||||||
|
@ -62,10 +63,8 @@ data ParsedOp = ParsedOp
|
||||||
, explicitInputAttrs :: [Attr AttrType]
|
, explicitInputAttrs :: [Attr AttrType]
|
||||||
-- ^ Attributes that must be set explicitly when creating the op.
|
-- ^ Attributes that must be set explicitly when creating the op.
|
||||||
-- Associated with the type of the attribute.
|
-- Associated with the type of the attribute.
|
||||||
, inferredTypeAttrs :: [Attr [DataType]]
|
, inferredTypeAttrs :: [Attr TypeParam]
|
||||||
-- ^ Attributes that are type parameters.
|
-- ^ Attributes that are type parameters.
|
||||||
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
|
|
||||||
-- If this list is empty, then any type is acceptable.
|
|
||||||
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
|
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
|
||||||
-- Attributes which are list sizes (ints) that are inferred automatically
|
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||||
-- from one or more of the input tensors.
|
-- from one or more of the input tensors.
|
||||||
|
@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
|
||||||
| AttrType | AttrShape | AttrTensor
|
| AttrType | AttrShape | AttrTensor
|
||||||
deriving Eq
|
deriving Eq
|
||||||
|
|
||||||
|
data TypeParam = TypeParam
|
||||||
|
{ typeParamIsList :: Bool
|
||||||
|
, typeParamRestrictions :: Maybe (NonEmpty DataType)
|
||||||
|
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
|
||||||
|
-- If 'Nothing', then any type is acceptable.
|
||||||
|
}
|
||||||
|
|
||||||
-- | An input or output argument (Tensor) for an op.
|
-- | An input or output argument (Tensor) for an op.
|
||||||
data ParsedArg = ParsedArg
|
data ParsedArg = ParsedArg
|
||||||
{ parsedArgName :: Name
|
{ parsedArgName :: Name
|
||||||
|
@ -120,7 +126,6 @@ data ParsedArgCase
|
||||||
}
|
}
|
||||||
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
||||||
-- ^ A heterogeneous list.
|
-- ^ A heterogeneous list.
|
||||||
-- TODO(judahjacobson): Implement this.
|
|
||||||
| ResourceArg
|
| ResourceArg
|
||||||
|
|
||||||
argKind :: ParsedArgCase -> Maybe ArgKind
|
argKind :: ParsedArgCase -> Maybe ArgKind
|
||||||
|
@ -223,11 +228,6 @@ parseOp o = ParsedOp
|
||||||
(o ^. inputArg) tensorKindParams
|
(o ^. inputArg) tensorKindParams
|
||||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||||
-- Type attributes that can be inferred from at least one input or output.
|
|
||||||
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
|
|
||||||
$ parsedInputs ++ parsedOutputs
|
|
||||||
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
|
|
||||||
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
|
||||||
-- Integer attributes that can be inferred from the size of at least one
|
-- Integer attributes that can be inferred from the size of at least one
|
||||||
-- input list.
|
-- input list.
|
||||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||||
|
@ -235,10 +235,14 @@ parseOp o = ParsedOp
|
||||||
implicitAttrs = Set.fromList $ map tfName $
|
implicitAttrs = Set.fromList $ map tfName $
|
||||||
map attrName inferredTypeAttrs
|
map attrName inferredTypeAttrs
|
||||||
++ map attrName inferredListSizeAttrs
|
++ map attrName inferredListSizeAttrs
|
||||||
-- Attributes that can't be inferred and don't have defaults, so must be passed
|
inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
|
||||||
-- as separate arguments to the op.
|
argTypeParams = Set.fromList $ map tfName $
|
||||||
|
mapMaybe (getArgTypeParam . parsedArgCase) $
|
||||||
|
parsedInputs ++ parsedOutputs
|
||||||
|
-- Attributes that can't be inferred and don't have defaults, so must be
|
||||||
|
-- passed as separate arguments to the op.
|
||||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
$ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
|
||||||
$ o ^. attr
|
$ o ^. attr
|
||||||
|
|
||||||
-- TODO(judahjacobson): Some arguments should be refs.
|
-- TODO(judahjacobson): Some arguments should be refs.
|
||||||
|
@ -252,29 +256,30 @@ outputTensorKind a
|
||||||
| a ^. isRef = ArgTensorRef
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorValue
|
| otherwise = ArgTensorValue
|
||||||
|
|
||||||
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||||
getExplicitInputAttr implicitAttrs a
|
getExplicitInputAttr o implicitAttrs a
|
||||||
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||||
, a ^. maybe'defaultValue == Nothing
|
, a ^. maybe'defaultValue == Nothing
|
||||||
, t <- parseAttrType (a ^. type')
|
, t <- parseAttrType o (a ^. type')
|
||||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
, t `elem` map AttrSingle
|
||||||
|
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
|
||||||
|
++ [AttrList AttrType] = Just t
|
||||||
| otherwise = Nothing
|
| otherwise = Nothing
|
||||||
|
|
||||||
-- | The type attribute used by this input or output (if any).
|
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
|
||||||
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
getInferredTypeAttr argTypeParams a
|
||||||
parsedArgTypeAttr p = case parsedArgCase p of
|
| TFName (a ^. name) `notElem` argTypeParams = Nothing
|
||||||
ResourceArg -> Nothing
|
| a ^. type' == "type" = Just $ TypeParam False allowed
|
||||||
SimpleArg {argType = t} -> fromArgType t
|
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
|
||||||
ListArg {argType = t} -> fromArgType t
|
| otherwise = Nothing
|
||||||
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
|
||||||
where
|
where
|
||||||
fromArgType (ArgTypeAttr n) = Just $ tfName n
|
allowed = nonEmpty (a ^. allowedValues . list . type')
|
||||||
fromArgType _ = Nothing
|
|
||||||
|
|
||||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
getArgTypeParam :: ParsedArgCase -> Maybe Name
|
||||||
getInferredTypeAttr a
|
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
|
||||||
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
|
||||||
| otherwise = Nothing
|
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
|
||||||
|
getArgTypeParam _ = Nothing
|
||||||
|
|
||||||
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
|
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
|
||||||
getInferredListSizeAttr inputs a
|
getInferredListSizeAttr inputs a
|
||||||
|
@ -317,8 +322,8 @@ parseArgCase a tKind
|
||||||
maybeAttr "" = Nothing
|
maybeAttr "" = Nothing
|
||||||
maybeAttr t = Just $ makeName t
|
maybeAttr t = Just $ makeName t
|
||||||
|
|
||||||
parseAttrType :: Text -> AttrType
|
parseAttrType :: OpDef -> Text -> AttrType
|
||||||
parseAttrType = \case
|
parseAttrType o = \case
|
||||||
"string" -> AttrSingle AttrBytes
|
"string" -> AttrSingle AttrBytes
|
||||||
"int" -> AttrSingle AttrInt64
|
"int" -> AttrSingle AttrInt64
|
||||||
"float" -> AttrSingle AttrFloat
|
"float" -> AttrSingle AttrFloat
|
||||||
|
@ -334,3 +339,4 @@ parseAttrType = \case
|
||||||
"list(shape)" -> AttrList AttrShape
|
"list(shape)" -> AttrList AttrShape
|
||||||
"list(tensor)" -> AttrList AttrTensor
|
"list(tensor)" -> AttrList AttrTensor
|
||||||
t -> error $ "parseAttrType: unrecognized type " ++ show t
|
t -> error $ "parseAttrType: unrecognized type " ++ show t
|
||||||
|
++ " for op " ++ show (o ^. name)
|
||||||
|
|
|
@ -12,67 +12,65 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE KindSignatures #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
-- | Queues in TensorFlow graph. Very limited support for now.
|
-- | Queues in TensorFlow graph. Very limited support for now.
|
||||||
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
|
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
|
||||||
|
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
|
import Data.Proxy (Proxy(..))
|
||||||
import Lens.Family2 ((.~), (&))
|
import Lens.Family2 ((.~), (&))
|
||||||
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
||||||
import TensorFlow.BuildOp (buildOp)
|
import TensorFlow.BuildOp (buildOp)
|
||||||
import TensorFlow.ControlFlow (group)
|
import TensorFlow.ControlFlow (group)
|
||||||
import TensorFlow.Tensor (Ref, Tensor)
|
import TensorFlow.Tensor (Ref, Tensor, TensorList)
|
||||||
import TensorFlow.Types (TensorType, tensorType)
|
import TensorFlow.Types (TensorTypes, fromTensorTypes)
|
||||||
|
|
||||||
-- | A queue carrying tuples. The underlying structure is more
|
-- | A queue carrying tuples.
|
||||||
-- versatile and can be made to support arbitrary tuples.
|
data Queue (as :: [*]) = Queue { handle :: Handle }
|
||||||
data Queue2 a b = Queue2 { handle :: Handle }
|
|
||||||
|
|
||||||
type Handle = Tensor Ref ByteString
|
type Handle = Tensor Ref ByteString
|
||||||
|
|
||||||
-- | Adds the given values to the queue.
|
-- | Adds the given values to the queue.
|
||||||
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
|
enqueue :: forall as v . TensorTypes as
|
||||||
=> Queue2 a b
|
=> Queue as
|
||||||
-> Tensor v1 a
|
-> TensorList v as
|
||||||
-> Tensor v2 b
|
|
||||||
-> Build ControlNode
|
-> Build ControlNode
|
||||||
enqueue q =
|
enqueue q =
|
||||||
buildOp (opDef "QueueEnqueue"
|
buildOp (opDef "QueueEnqueue"
|
||||||
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
|
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
|
||||||
, tensorType (undefined :: b)])
|
|
||||||
(handle q)
|
(handle q)
|
||||||
|
|
||||||
-- | Retrieves the values from the queue.
|
-- | Retrieves the values from the queue.
|
||||||
dequeue :: forall a b . (TensorType a, TensorType b)
|
dequeue :: forall as . TensorTypes as
|
||||||
=> Queue2 a b
|
=> Queue as
|
||||||
-> Build (Tensor Ref a, Tensor Ref b)
|
-> Build (TensorList Ref as)
|
||||||
-- ^ Dequeued tensors. They are paired in a sense
|
-- ^ Dequeued tensors. They are coupled in a sense
|
||||||
-- that values appear together, even if they are
|
-- that values appear together, even if they are
|
||||||
-- not consumed together.
|
-- not consumed together.
|
||||||
dequeue q =
|
dequeue q =
|
||||||
buildOp (opDef "QueueDequeue"
|
buildOp (opDef "QueueDequeue"
|
||||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
|
||||||
, tensorType (undefined :: b)])
|
|
||||||
(handle q)
|
(handle q)
|
||||||
|
|
||||||
-- | Creates a new queue with the given capacity and shared name.
|
-- | Creates a new queue with the given capacity and shared name.
|
||||||
makeQueue2 :: forall a b . (TensorType a, TensorType b)
|
makeQueue :: forall as . TensorTypes as
|
||||||
=> Int64 -- ^ The upper bound on the number of elements in
|
=> Int64 -- ^ The upper bound on the number of elements in
|
||||||
-- this queue. Negative numbers mean no limit.
|
-- this queue. Negative numbers mean no limit.
|
||||||
-> ByteString -- ^ If non-empty, this queue will be shared
|
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||||
-- under the given name across multiple sessions.
|
-- under the given name across multiple sessions.
|
||||||
-> Build (Queue2 a b)
|
-> Build (Queue as)
|
||||||
makeQueue2 capacity sharedName = do
|
makeQueue capacity sharedName = do
|
||||||
q <- buildOp (opDef "FIFOQueue"
|
q <- buildOp (opDef "FIFOQueue"
|
||||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||||
, tensorType (undefined :: b)]
|
|
||||||
& opAttr "shared_name" .~ sharedName
|
& opAttr "shared_name" .~ sharedName
|
||||||
& opAttr "capacity" .~ capacity
|
& opAttr "capacity" .~ capacity
|
||||||
)
|
)
|
||||||
group q >>= addInitializer
|
group q >>= addInitializer
|
||||||
return (Queue2 q)
|
return (Queue q)
|
||||||
|
|
||||||
-- TODO(gnezdo): Figure out the closing story for queues.
|
-- TODO(gnezdo): Figure out the closing story for queues.
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ module Main where
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import TensorFlow.Types (Scalar(..))
|
import TensorFlow.Types (ListOf(..), Scalar(..), (/:/))
|
||||||
import TensorFlow.Ops (scalar)
|
import TensorFlow.Ops (scalar)
|
||||||
import TensorFlow.Queue
|
import TensorFlow.Queue
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
|
@ -39,42 +40,50 @@ import qualified Data.ByteString as BS
|
||||||
-- | Test basic queue behaviors.
|
-- | Test basic queue behaviors.
|
||||||
testBasic :: Test
|
testBasic :: Test
|
||||||
testBasic = testCase "testBasic" $ runSession $ do
|
testBasic = testCase "testBasic" $ runSession $ do
|
||||||
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
||||||
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
|
||||||
x <- buildAnd run (dequeue q)
|
x <- buildAnd run (dequeue q)
|
||||||
liftIO $ (Scalar 42, Scalar "Hi") @=? x
|
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
|
||||||
|
|
||||||
buildAnd run_ (enqueue q 56 (scalar "Bar"))
|
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
|
||||||
y <- buildAnd run (dequeue q)
|
y <- buildAnd run (dequeue q)
|
||||||
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
-- Note: we use explicit "Scalar" here to specify the type that was
|
||||||
|
-- fetched. Equivalently we could write
|
||||||
|
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
||||||
|
-- or else allow the types to be determined by future use of the fetched
|
||||||
|
-- value.
|
||||||
|
let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil
|
||||||
|
liftIO $ expected @=? y
|
||||||
|
|
||||||
-- | Test queue pumping.
|
-- | Test queue pumping.
|
||||||
testPump :: Test
|
testPump :: Test
|
||||||
testPump = testCase "testPump" $ runSession $ do
|
testPump = testCase "testPump" $ runSession $ do
|
||||||
(deq, pump) <- build $ do
|
(deq, pump) <- build $ do
|
||||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue"
|
||||||
(,) <$> dequeue q
|
(,) <$> dequeue q
|
||||||
<*> enqueue q 31 (scalar "Baz")
|
<*> enqueue q (31 :/ scalar "Baz" :/ Nil)
|
||||||
-- This is a realistic use. The pump inputs are pre-bound to some
|
-- This is a realistic use. The pump inputs are pre-bound to some
|
||||||
-- nodes that produce values when pumped (e.g. read from a
|
-- nodes that produce values when pumped (e.g. read from a
|
||||||
-- file).
|
-- file).
|
||||||
run_ (pump, pump)
|
run_ (pump, pump)
|
||||||
|
|
||||||
(x, y) <- run (deq, deq)
|
(x, y) <- run (deq, deq)
|
||||||
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil
|
||||||
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
liftIO $ expected @=? x
|
||||||
|
liftIO $ expected @=? y
|
||||||
|
|
||||||
testAsync :: Test
|
testAsync :: Test
|
||||||
testAsync = testCase "testAsync" $ runSession $ do
|
testAsync = testCase "testAsync" $ runSession $ do
|
||||||
(deq, pump) <- build $ do
|
(deq, pump) <- build $ do
|
||||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
|
||||||
(,) <$> dequeue q
|
(,) <$> dequeue q
|
||||||
<*> enqueue q 10 (scalar "Async")
|
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
|
||||||
-- Pumps the queue until canceled by runSession exiting.
|
-- Pumps the queue until canceled by runSession exiting.
|
||||||
asyncProdNodes pump
|
asyncProdNodes pump
|
||||||
-- Picks up a couple values and verifies they are as expected.
|
-- Picks up a couple values and verifies they are as expected.
|
||||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil
|
||||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
run deq >>= liftIO . (expected @=?)
|
||||||
|
run deq >>= liftIO . (expected @=?)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testBasic
|
main = googleTest [ testBasic
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TupleSections #-}
|
{-# LANGUAGE TupleSections #-}
|
||||||
|
|
||||||
module TensorFlow.BuildOp
|
module TensorFlow.BuildOp
|
||||||
|
@ -33,6 +35,7 @@ import Lens.Family2 ((&), (<>~), (^.))
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Output
|
import TensorFlow.Output
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
|
import TensorFlow.Types
|
||||||
|
|
||||||
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||||
|
|
||||||
|
@ -98,6 +101,22 @@ instance OpResult (Tensor Ref a) where
|
||||||
instance OpResult ControlNode where
|
instance OpResult ControlNode where
|
||||||
toResult = ControlNode <$> ask
|
toResult = ControlNode <$> ask
|
||||||
|
|
||||||
|
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
|
||||||
|
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
|
||||||
|
where
|
||||||
|
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||||
|
loop Nil = return Nil
|
||||||
|
loop (TensorTypeProxy :/ ls) = do
|
||||||
|
t <- tensorResult v
|
||||||
|
ts <- loop ls
|
||||||
|
return (t :/ ts)
|
||||||
|
|
||||||
|
instance TensorTypes as => OpResult (TensorList Value as) where
|
||||||
|
toResult = tensorListResult ValueKind
|
||||||
|
|
||||||
|
instance TensorTypes as => OpResult (TensorList Ref as) where
|
||||||
|
toResult = tensorListResult RefKind
|
||||||
|
|
||||||
instance OpResult a => OpResult [a] where
|
instance OpResult a => OpResult [a] where
|
||||||
toResult = do
|
toResult = do
|
||||||
ResultState i ns <- get
|
ResultState i ns <- get
|
||||||
|
@ -159,6 +178,12 @@ instance BuildOp (Tensor Value a) where
|
||||||
instance BuildOp (Tensor Ref a) where
|
instance BuildOp (Tensor Ref a) where
|
||||||
buildOp' = pureResult
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance TensorTypes as => BuildOp (TensorList Value as) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
instance TensorTypes as => BuildOp (TensorList Ref as) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
instance BuildOp [Tensor Value a] where
|
instance BuildOp [Tensor Value a] where
|
||||||
buildOp' = pureResult
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
@ -199,6 +224,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
||||||
buildOp' rf o accum ts
|
buildOp' rf o accum ts
|
||||||
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
||||||
|
|
||||||
|
instance BuildOp f => BuildOp (TensorList v as -> f) where
|
||||||
|
buildOp' rf o accum ts
|
||||||
|
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
|
||||||
|
|
||||||
-- | Returns true if all the integers in each tuple are identical.
|
-- | Returns true if all the integers in each tuple are identical.
|
||||||
-- Throws an error with a descriptive message if not.
|
-- Throws an error with a descriptive message if not.
|
||||||
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
||||||
|
|
|
@ -12,15 +12,18 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE TypeOperators #-}
|
||||||
module TensorFlow.Nodes where
|
module TensorFlow.Nodes where
|
||||||
|
|
||||||
import Control.Applicative (liftA2, liftA3)
|
import Control.Applicative (liftA2, liftA3)
|
||||||
|
import Data.Functor.Identity (Identity)
|
||||||
import Data.Map.Strict (Map)
|
import Data.Map.Strict (Map)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
|
@ -96,6 +99,19 @@ instance Nodes ControlNode where
|
||||||
instance a ~ () => Fetchable ControlNode a where
|
instance a ~ () => Fetchable ControlNode a where
|
||||||
getFetch _ = return $ pure ()
|
getFetch _ = return $ pure ()
|
||||||
|
|
||||||
|
instance Nodes (ListOf f '[]) where
|
||||||
|
getNodes _ = return Set.empty
|
||||||
|
|
||||||
|
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
|
||||||
|
getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
|
||||||
|
|
||||||
|
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
|
||||||
|
getFetch _ = return $ pure Nil
|
||||||
|
|
||||||
|
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
|
||||||
|
=> Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
|
||||||
|
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
|
||||||
|
|
||||||
instance Nodes (Tensor v a) where
|
instance Nodes (Tensor v a) where
|
||||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||||
|
|
||||||
|
|
|
@ -12,20 +12,30 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
{-# LANGUAGE GADTs #-}
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE KindSignatures #-}
|
||||||
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE Rank2Types #-}
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
|
||||||
module TensorFlow.Tensor where
|
module TensorFlow.Tensor where
|
||||||
|
|
||||||
import Data.String (IsString(..))
|
import Data.String (IsString(..))
|
||||||
import qualified Data.Text as Text
|
import qualified Data.Text as Text
|
||||||
import Lens.Family2 (Lens', Traversal')
|
import Lens.Family2 (Lens', Traversal', (^.))
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
|
||||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||||
import TensorFlow.Types (TensorData(..), Attribute)
|
import TensorFlow.Types
|
||||||
|
( TensorData(..)
|
||||||
|
, Attribute
|
||||||
|
, ListOf(..)
|
||||||
|
)
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
-- | A named output of a TensorFlow operation.
|
-- | A named output of a TensorFlow operation.
|
||||||
|
@ -83,3 +93,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td
|
||||||
-- TODO(judahjacobson): add more safety checks here.
|
-- TODO(judahjacobson): add more safety checks here.
|
||||||
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
||||||
tensorFromName v = Tensor v . fromString . Text.unpack
|
tensorFromName v = Tensor v . fromString . Text.unpack
|
||||||
|
|
||||||
|
type TensorList v = ListOf (Tensor v)
|
||||||
|
|
||||||
|
tensorListOutputs :: TensorList v as -> [Output]
|
||||||
|
tensorListOutputs Nil = []
|
||||||
|
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
|
||||||
|
|
|
@ -13,9 +13,11 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE ConstraintKinds #-}
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE CPP #-}
|
||||||
{-# LANGUAGE DataKinds #-}
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
@ -36,23 +38,35 @@ module TensorFlow.Types
|
||||||
, Shape(..)
|
, Shape(..)
|
||||||
, protoShape
|
, protoShape
|
||||||
, Attribute(..)
|
, Attribute(..)
|
||||||
|
, DataType(..)
|
||||||
|
-- * Lists
|
||||||
|
, ListOf(..)
|
||||||
|
, List
|
||||||
|
, (/:/)
|
||||||
|
, TensorTypeProxy(..)
|
||||||
|
, TensorTypes(..)
|
||||||
|
, TensorTypeList
|
||||||
|
, fromTensorTypeList
|
||||||
|
, fromTensorTypes
|
||||||
-- * Type constraints
|
-- * Type constraints
|
||||||
, OneOf
|
, OneOf
|
||||||
, type (/=)
|
, type (/=)
|
||||||
|
, OneOfs
|
||||||
-- ** Implementation of constraints
|
-- ** Implementation of constraints
|
||||||
, TypeError
|
, TypeError
|
||||||
, ExcludedCase
|
, ExcludedCase
|
||||||
, TensorTypes
|
|
||||||
, NoneOf
|
, NoneOf
|
||||||
, type (\\)
|
, type (\\)
|
||||||
, Delete
|
, Delete
|
||||||
, AllTensorTypes
|
, AllTensorTypes
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.Functor.Identity (Identity(..))
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Int (Int8, Int16, Int32, Int64)
|
import Data.Int (Int8, Int16, Int32, Int64)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
|
import Data.Proxy (Proxy(..))
|
||||||
import Data.String (IsString)
|
import Data.String (IsString)
|
||||||
import Data.Word (Word8, Word16, Word64)
|
import Data.Word (Word8, Word16, Word64)
|
||||||
import Foreign.Storable (Storable)
|
import Foreign.Storable (Storable)
|
||||||
|
@ -376,6 +390,44 @@ instance Attribute [DataType] where
|
||||||
instance Attribute [Int64] where
|
instance Attribute [Int64] where
|
||||||
attrLens = list . i
|
attrLens = list . i
|
||||||
|
|
||||||
|
-- | A heterogeneous list type.
|
||||||
|
data ListOf f as where
|
||||||
|
Nil :: ListOf f '[]
|
||||||
|
(:/) :: f a -> ListOf f as -> ListOf f (a ': as)
|
||||||
|
|
||||||
|
infixr 5 :/
|
||||||
|
|
||||||
|
type family All f as :: Constraint where
|
||||||
|
All f '[] = ()
|
||||||
|
All f (a ': as) = (f a, All f as)
|
||||||
|
|
||||||
|
type family Map f as where
|
||||||
|
Map f '[] = '[]
|
||||||
|
Map f (a ': as) = f a ': Map f as
|
||||||
|
|
||||||
|
instance All Eq (Map f as) => Eq (ListOf f as) where
|
||||||
|
Nil == Nil = True
|
||||||
|
(x :/ xs) == (y :/ ys) = x == y && xs == ys
|
||||||
|
-- Newer versions of GHC use the GADT to tell that the previous cases are
|
||||||
|
-- exhaustive.
|
||||||
|
#if _GLASGOW_HASKELL__ < 800
|
||||||
|
_ == _ = False
|
||||||
|
#endif
|
||||||
|
|
||||||
|
instance All Show (Map f as) => Show (ListOf f as) where
|
||||||
|
showsPrec _ Nil = showString "Nil"
|
||||||
|
showsPrec d (x :/ xs) = showParen (d > 10)
|
||||||
|
$ showsPrec 6 x . showString " :/ "
|
||||||
|
. showsPrec 6 xs
|
||||||
|
|
||||||
|
type List = ListOf Identity
|
||||||
|
|
||||||
|
-- | Equivalent of ':/' for lists.
|
||||||
|
(/:/) :: a -> List as -> List (a ': as)
|
||||||
|
(/:/) = (:/) . Identity
|
||||||
|
|
||||||
|
infixr 5 /:/
|
||||||
|
|
||||||
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
||||||
--
|
--
|
||||||
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
||||||
|
@ -393,13 +445,38 @@ instance Attribute [Int64] where
|
||||||
--
|
--
|
||||||
-- using an enumeration of all the possible 'TensorType's.
|
-- using an enumeration of all the possible 'TensorType's.
|
||||||
type OneOf ts a
|
type OneOf ts a
|
||||||
|
-- Assert `TensorTypes ts` to make error messages a little better.
|
||||||
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
|
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
|
||||||
|
|
||||||
-- | A check that the input is a list of 'TensorType's.
|
type OneOfs ts as = (TensorTypes as, TensorTypes ts,
|
||||||
-- Helps improve error messages when using 'OneOf'.
|
NoneOfs (AllTensorTypes \\ ts) as)
|
||||||
|
|
||||||
|
type family NoneOfs ts as :: Constraint where
|
||||||
|
NoneOfs ts '[] = ()
|
||||||
|
NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)
|
||||||
|
|
||||||
|
data TensorTypeProxy a where
|
||||||
|
TensorTypeProxy :: TensorType a => TensorTypeProxy a
|
||||||
|
|
||||||
|
type TensorTypeList = ListOf TensorTypeProxy
|
||||||
|
|
||||||
|
fromTensorTypeList :: TensorTypeList ts -> [DataType]
|
||||||
|
fromTensorTypeList Nil = []
|
||||||
|
fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts)
|
||||||
|
= tensorType (undefined :: t) : fromTensorTypeList ts
|
||||||
|
|
||||||
|
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
|
||||||
|
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
|
||||||
|
|
||||||
class TensorTypes (ts :: [*]) where
|
class TensorTypes (ts :: [*]) where
|
||||||
instance TensorTypes '[]
|
tensorTypes :: TensorTypeList ts
|
||||||
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts)
|
|
||||||
|
instance TensorTypes '[] where
|
||||||
|
tensorTypes = Nil
|
||||||
|
|
||||||
|
-- | A constraint that the input is a list of 'TensorTypes'.
|
||||||
|
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
|
||||||
|
tensorTypes = TensorTypeProxy :/ tensorTypes
|
||||||
|
|
||||||
-- | A constraint checking that two types are different.
|
-- | A constraint checking that two types are different.
|
||||||
type family a /= b :: Constraint where
|
type family a /= b :: Constraint where
|
||||||
|
|
Loading…
Reference in a new issue