Support lists of tensors in ops. (#79)

Adds a new type `ListOf` which wraps a heterogeneous list; for example,
`ListOf (Tensor Value) '[Int32, Float]` represents a list of two
elements: a tensor of int32s and a tensor of floats.

Also changes the `Queue2` type (which suppored pairs of tensors) to
`Queue` (which supports arbitrary lists).
This commit is contained in:
Judah Jacobson 2017-03-17 13:53:19 -07:00 committed by Greg Steuck
parent 7cc6a69866
commit 9209dfc4c4
9 changed files with 273 additions and 144 deletions

View File

@ -64,43 +64,9 @@ generatingOpsWrappers = hooks
(prettyLazyText 80 $ docOpList flags x) (prettyLazyText 80 $ docOpList flags x)
blackList = blackList =
-- A few data flow ops take a list of heterogeneous [ -- Requires the "func" type:
-- parameters which we don't support in general form. "SymbolicGradient"
[ "HashTable"
, "MutableDenseHashTable"
, "MutableHashTable"
, "MutableHashTableOfTensors"
, "QueueDequeue"
, "QueueDequeueMany"
, "QueueDequeueUpTo"
, "Stack"
, "TensorArray"
, "TensorArrayV2"
, "QueueEnqueueManyV2"
, "QueueDequeueV2"
, "QueueDequeueUpToV2"
, "QueueEnqueueV2"
, "QueueDequeueManyV2"
, "Stage"
, "Unstage"
-- These should be possible to support by adding a bunch of
-- overloads with a variable number of tuple arguments.
, "Assert"
, "BarrierTakeMany"
, "Print"
, "QueueEnqueue"
, "QueueEnqueueMany"
-- Need list of types support.
, "DecodeCSV"
, "ParseExample"
, "ParseSingleSequenceExample"
, "RestoreV2"
, "Save"
, "SaveV2"
, "SaveSlices"
, "SymbolicGradient"
, "_ArrayToList"
, "_ListToArray"
-- Easy: support larger result tuples. -- Easy: support larger result tuples.
, "ParseSingleSequenceExample"
, "Skipgram" , "Skipgram"
] ]

View File

@ -147,6 +147,7 @@ imports = stack [
"import Data.ByteString (ByteString)" "import Data.ByteString (ByteString)"
, "import Data.Complex (Complex)" , "import Data.Complex (Complex)"
, "import Data.Int (Int8, Int16, Int32, Int64)" , "import Data.Int (Int8, Int16, Int32, Int64)"
, "import Data.Proxy (Proxy(Proxy))"
, "import Data.Word (Word8, Word16)" , "import Data.Word (Word8, Word16)"
, "import Lens.Family2 ((.~), (&))" , "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build" , "import TensorFlow.Build"
@ -210,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = [] whereClause [] = []
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)] whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
where where
defineLengthAttr a = renderHaskellName (attrName a) <+> "=" defineLengthAttr a = renderHaskellAttrName a <+> "="
<+> "fromIntegral (length" <+> "fromIntegral (length"
<+> renderHaskellName (NE.head $ attrInfo a) <+> renderHaskellName (NE.head $ attrInfo a)
<> ") :: Int64" <> ") :: Int64"
renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts)) functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs) </> indent indentation (sep tensorArgs)
@ -229,9 +233,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
<- parsedOutputs pOp] <- parsedOutputs pOp]
buildOpParts = buildOpParts =
"opDef" <+> renderQuotedTFName (parsedOpName pOp) : "opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders tensor arguments. -- Renders type parameter arguments.
[ "& opAttr" <+> renderQuotedTFName n <+> [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
| a <- inferredTypeAttrs pOp, let n = attrName a | a <- inferredTypeAttrs pOp, let n = attrName a
] ++ ] ++
-- Renders mandatory attributes as function parameters. -- Renders mandatory attributes as function parameters.
@ -244,6 +247,12 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
] ]
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a
| typeParamIsList $ attrInfo a
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
<> ")"
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
<> ")"
-- | Write a comment with the inputs/outputs/attributes in proto format, for -- | Write a comment with the inputs/outputs/attributes in proto format, for
-- debugging. -- debugging.
@ -272,8 +281,8 @@ typeSig pOp = constraints
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint classConstraints = tuple $ map tensorArgConstraint
$ inferredTypeAttrs pOp $ inferredTypeAttrs pOp
signatureFold = folddoc (\x y -> x </> "->" <+> y) signatureFold = folddoc (\x y -> x </> "->" <+> y)
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a) attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
@ -305,17 +314,18 @@ tensorArg p = case parsedArgCase p of
ResourceArg -> "ResourceHandle" ResourceArg -> "ResourceHandle"
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}" MixedListArg {argTypeAttr = t, argCaseKind = k}
-> "TensorList" <+> kind k <+> renderHaskellName t
where where
kind k = case k of
ArgTensorRef -> "Ref"
ArgTensorValue -> "Value"
ArgTensorEither v' -> strictText v'
tensorType t k = let tensorType t k = let
v = case k of
ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v'
a = case t of a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n ArgTypeAttr n -> renderHaskellName n
in v <+> a in "Tensor" <+> kind k <+> a
attrComment :: Attr a -> Doc attrComment :: Attr a -> Doc
attrComment a = argComment' (attrName a) (attrDescription a) attrComment a = argComment' (attrName a) (attrDescription a)
@ -347,18 +357,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os
] ]
-- | Constraints for a given type parameter. -- | Constraints for a given type parameter.
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"] -- E.g.: "TensorType t" or "OneOf [Int64, Float] t"
tensorArgConstraint :: Attr [DataType] -> [Doc] -- or "TensorTypes ts" or "OneOfs [..] ts".
tensorArgConstraint a tensorArgConstraint :: Attr TypeParam -> Doc
= ("TensorType" <+> n tensorArgConstraint a = case attrInfo a of
: if null typeList TypeParam False Nothing -> "TensorType" <+> n
then [] TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n]) TypeParam True Nothing -> "TensorTypes" <+> n
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
where where
n = renderHaskellName $ attrName a n = renderHaskellAttrName a
typeList = map strictText $ -- Produces a type-level list, e.g.: '[Int32,Int64,Float]
Set.toList $ Set.fromList $ typeList = ("'" <>) . brackets . commasep . map strictText .
map dtTypeToHaskell $ attrInfo a Set.toList . Set.fromList .
map dtTypeToHaskell . toList
-- NOTE: The cases of this function should be kept in sync with -- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes. -- TensorFlow.Types.AllTensorTypes.

View File

@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp
, Attr(..) , Attr(..)
, AttrType(..) , AttrType(..)
, AttrBaseType(..) , AttrBaseType(..)
, TypeParam(..)
, ParsedArg(..) , ParsedArg(..)
, ParsedArgCase(..) , ParsedArgCase(..)
, ArgType(..) , ArgType(..)
@ -62,10 +63,8 @@ data ParsedOp = ParsedOp
, explicitInputAttrs :: [Attr AttrType] , explicitInputAttrs :: [Attr AttrType]
-- ^ Attributes that must be set explicitly when creating the op. -- ^ Attributes that must be set explicitly when creating the op.
-- Associated with the type of the attribute. -- Associated with the type of the attribute.
, inferredTypeAttrs :: [Attr [DataType]] , inferredTypeAttrs :: [Attr TypeParam]
-- ^ Attributes that are type parameters. -- ^ Attributes that are type parameters.
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
-- If this list is empty, then any type is acceptable.
, inferredListSizeAttrs :: [Attr (NonEmpty Name)] , inferredListSizeAttrs :: [Attr (NonEmpty Name)]
-- Attributes which are list sizes (ints) that are inferred automatically -- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors. -- from one or more of the input tensors.
@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor | AttrType | AttrShape | AttrTensor
deriving Eq deriving Eq
data TypeParam = TypeParam
{ typeParamIsList :: Bool
, typeParamRestrictions :: Maybe (NonEmpty DataType)
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
-- If 'Nothing', then any type is acceptable.
}
-- | An input or output argument (Tensor) for an op. -- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg data ParsedArg = ParsedArg
{ parsedArgName :: Name { parsedArgName :: Name
@ -120,7 +126,6 @@ data ParsedArgCase
} }
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind } | MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
-- ^ A heterogeneous list. -- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this.
| ResourceArg | ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind argKind :: ParsedArgCase -> Maybe ArgKind
@ -223,11 +228,6 @@ parseOp o = ParsedOp
(o ^. inputArg) tensorKindParams (o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
-- Type attributes that can be inferred from at least one input or output.
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
$ parsedInputs ++ parsedOutputs
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
-- Integer attributes that can be inferred from the size of at least one -- Integer attributes that can be inferred from the size of at least one
-- input list. -- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
@ -235,10 +235,14 @@ parseOp o = ParsedOp
implicitAttrs = Set.fromList $ map tfName $ implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs ++ map attrName inferredListSizeAttrs
-- Attributes that can't be inferred and don't have defaults, so must be passed inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
-- as separate arguments to the op. argTypeParams = Set.fromList $ map tfName $
mapMaybe (getArgTypeParam . parsedArgCase) $
parsedInputs ++ parsedOutputs
-- Attributes that can't be inferred and don't have defaults, so must be
-- passed as separate arguments to the op.
explicitInputAttrs = sortBy (comparing (tfName . attrName)) explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs) $ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
$ o ^. attr $ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs. -- TODO(judahjacobson): Some arguments should be refs.
@ -252,29 +256,30 @@ outputTensorKind a
| a ^. isRef = ArgTensorRef | a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue | otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr implicitAttrs a getExplicitInputAttr o implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs | TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing , a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type') , t <- parseAttrType o (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t , t `elem` map AttrSingle
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
++ [AttrList AttrType] = Just t
| otherwise = Nothing | otherwise = Nothing
-- | The type attribute used by this input or output (if any). getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
parsedArgTypeAttr :: ParsedArg -> Maybe TFName getInferredTypeAttr argTypeParams a
parsedArgTypeAttr p = case parsedArgCase p of | TFName (a ^. name) `notElem` argTypeParams = Nothing
ResourceArg -> Nothing | a ^. type' == "type" = Just $ TypeParam False allowed
SimpleArg {argType = t} -> fromArgType t | a ^. type' == "list(type)" = Just $ TypeParam True allowed
ListArg {argType = t} -> fromArgType t | otherwise = Nothing
MixedListArg {argTypeAttr = n} -> Just $ tfName n
where where
fromArgType (ArgTypeAttr n) = Just $ tfName n allowed = nonEmpty (a ^. allowedValues . list . type')
fromArgType _ = Nothing
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] getArgTypeParam :: ParsedArgCase -> Maybe Name
getInferredTypeAttr a getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type' getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
| otherwise = Nothing getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
getArgTypeParam _ = Nothing
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name) getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a getInferredListSizeAttr inputs a
@ -285,7 +290,7 @@ getInferredListSizeAttr inputs a
} <- inputs } <- inputs
, TFName (a ^. name) == tfName n] , TFName (a ^. name) == tfName n]
| otherwise = Nothing | otherwise = Nothing
-- | Like mapMaybe, but associates the attribute name/description with the given info. -- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a] mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do mapMaybeAttrs f = mapMaybe $ \a -> do
@ -295,7 +300,7 @@ mapMaybeAttrs f = mapMaybe $ \a -> do
, attrDescription = a ^. description , attrDescription = a ^. description
, attrInfo = x , attrInfo = x
} }
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name) { parsedArgName = makeName (a ^. name)
@ -317,15 +322,15 @@ parseArgCase a tKind
maybeAttr "" = Nothing maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t maybeAttr t = Just $ makeName t
parseAttrType :: Text -> AttrType parseAttrType :: OpDef -> Text -> AttrType
parseAttrType = \case parseAttrType o = \case
"string" -> AttrSingle AttrBytes "string" -> AttrSingle AttrBytes
"int" -> AttrSingle AttrInt64 "int" -> AttrSingle AttrInt64
"float" -> AttrSingle AttrFloat "float" -> AttrSingle AttrFloat
"bool" -> AttrSingle AttrBool "bool" -> AttrSingle AttrBool
"type" -> AttrSingle AttrType "type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape "shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor "tensor" -> AttrSingle AttrTensor
"list(string)" -> AttrList AttrBytes "list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64 "list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat "list(float)" -> AttrList AttrFloat
@ -334,3 +339,4 @@ parseAttrType = \case
"list(shape)" -> AttrList AttrShape "list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor "list(tensor)" -> AttrList AttrTensor
t -> error $ "parseAttrType: unrecognized type " ++ show t t -> error $ "parseAttrType: unrecognized type " ++ show t
++ " for op " ++ show (o ^. name)

View File

@ -12,67 +12,65 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
-- | Queues in TensorFlow graph. Very limited support for now. -- | Queues in TensorFlow graph. Very limited support for now.
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Int (Int64) import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&)) import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef) import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp) import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group) import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor (Ref, Tensor) import TensorFlow.Tensor (Ref, Tensor, TensorList)
import TensorFlow.Types (TensorType, tensorType) import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A queue carrying tuples. The underlying structure is more -- | A queue carrying tuples.
-- versatile and can be made to support arbitrary tuples. data Queue (as :: [*]) = Queue { handle :: Handle }
data Queue2 a b = Queue2 { handle :: Handle }
type Handle = Tensor Ref ByteString type Handle = Tensor Ref ByteString
-- | Adds the given values to the queue. -- | Adds the given values to the queue.
enqueue :: forall a b v1 v2. (TensorType a, TensorType b) enqueue :: forall as v . TensorTypes as
=> Queue2 a b => Queue as
-> Tensor v1 a -> TensorList v as
-> Tensor v2 b
-> Build ControlNode -> Build ControlNode
enqueue q = enqueue q =
buildOp (opDef "QueueEnqueue" buildOp (opDef "QueueEnqueue"
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a) & opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
, tensorType (undefined :: b)])
(handle q) (handle q)
-- | Retrieves the values from the queue. -- | Retrieves the values from the queue.
dequeue :: forall a b . (TensorType a, TensorType b) dequeue :: forall as . TensorTypes as
=> Queue2 a b => Queue as
-> Build (Tensor Ref a, Tensor Ref b) -> Build (TensorList Ref as)
-- ^ Dequeued tensors. They are paired in a sense -- ^ Dequeued tensors. They are coupled in a sense
-- that values appear together, even if they are -- that values appear together, even if they are
-- not consumed together. -- not consumed together.
dequeue q = dequeue q =
buildOp (opDef "QueueDequeue" buildOp (opDef "QueueDequeue"
& opAttr "component_types" .~ [ tensorType (undefined :: a) & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
, tensorType (undefined :: b)])
(handle q) (handle q)
-- | Creates a new queue with the given capacity and shared name. -- | Creates a new queue with the given capacity and shared name.
makeQueue2 :: forall a b . (TensorType a, TensorType b) makeQueue :: forall as . TensorTypes as
=> Int64 -- ^ The upper bound on the number of elements in => Int64 -- ^ The upper bound on the number of elements in
-- this queue. Negative numbers mean no limit. -- this queue. Negative numbers mean no limit.
-> ByteString -- ^ If non-empty, this queue will be shared -> ByteString -- ^ If non-empty, this queue will be shared
-- under the given name across multiple sessions. -- under the given name across multiple sessions.
-> Build (Queue2 a b) -> Build (Queue as)
makeQueue2 capacity sharedName = do makeQueue capacity sharedName = do
q <- buildOp (opDef "FIFOQueue" q <- buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ [ tensorType (undefined :: a) & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
, tensorType (undefined :: b)]
& opAttr "shared_name" .~ sharedName & opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity & opAttr "capacity" .~ capacity
) )
group q >>= addInitializer group q >>= addInitializer
return (Queue2 q) return (Queue q)
-- TODO(gnezdo): Figure out the closing story for queues. -- TODO(gnezdo): Figure out the closing story for queues.

View File

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
@ -20,7 +21,7 @@ module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64) import Data.Int (Int64)
import Google.Test (googleTest) import Google.Test (googleTest)
import TensorFlow.Types (Scalar(..)) import TensorFlow.Types (ListOf(..), Scalar(..), (/:/))
import TensorFlow.Ops (scalar) import TensorFlow.Ops (scalar)
import TensorFlow.Queue import TensorFlow.Queue
import TensorFlow.Session import TensorFlow.Session
@ -39,42 +40,50 @@ import qualified Data.ByteString as BS
-- | Test basic queue behaviors. -- | Test basic queue behaviors.
testBasic :: Test testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do testBasic = testCase "testBasic" $ runSession $ do
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 "" q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
buildAnd run_ (enqueue q 42 (scalar "Hi")) buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
x <- buildAnd run (dequeue q) x <- buildAnd run (dequeue q)
liftIO $ (Scalar 42, Scalar "Hi") @=? x liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
buildAnd run_ (enqueue q 56 (scalar "Bar")) buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
y <- buildAnd run (dequeue q) y <- buildAnd run (dequeue q)
liftIO $ (Scalar 56, Scalar "Bar") @=? y -- Note: we use explicit "Scalar" here to specify the type that was
-- fetched. Equivalently we could write
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
-- or else allow the types to be determined by future use of the fetched
-- value.
let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil
liftIO $ expected @=? y
-- | Test queue pumping. -- | Test queue pumping.
testPump :: Test testPump :: Test
testPump = testCase "testPump" $ runSession $ do testPump = testCase "testPump" $ runSession $ do
(deq, pump) <- build $ do (deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue" q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue"
(,) <$> dequeue q (,) <$> dequeue q
<*> enqueue q 31 (scalar "Baz") <*> enqueue q (31 :/ scalar "Baz" :/ Nil)
-- This is a realistic use. The pump inputs are pre-bound to some -- This is a realistic use. The pump inputs are pre-bound to some
-- nodes that produce values when pumped (e.g. read from a -- nodes that produce values when pumped (e.g. read from a
-- file). -- file).
run_ (pump, pump) run_ (pump, pump)
(x, y) <- run (deq, deq) (x, y) <- run (deq, deq)
liftIO $ (Scalar 31, Scalar "Baz") @=? x let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil
liftIO $ (Scalar 31, Scalar "Baz") @=? y liftIO $ expected @=? x
liftIO $ expected @=? y
testAsync :: Test testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do (deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "" q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
(,) <$> dequeue q (,) <$> dequeue q
<*> enqueue q 10 (scalar "Async") <*> enqueue q (10 :/ scalar "Async" :/ Nil)
-- Pumps the queue until canceled by runSession exiting. -- Pumps the queue until canceled by runSession exiting.
asyncProdNodes pump asyncProdNodes pump
-- Picks up a couple values and verifies they are as expected. -- Picks up a couple values and verifies they are as expected.
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?) let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?) run deq >>= liftIO . (expected @=?)
run deq >>= liftIO . (expected @=?)
main :: IO () main :: IO ()
main = googleTest [ testBasic main = googleTest [ testBasic

View File

@ -13,6 +13,8 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-} {-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp module TensorFlow.BuildOp
@ -33,6 +35,7 @@ import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Output import TensorFlow.Output
import TensorFlow.Tensor import TensorFlow.Tensor
import TensorFlow.Types
data ResultState = ResultState !OutputIx [Int64] deriving Show data ResultState = ResultState !OutputIx [Int64] deriving Show
@ -98,6 +101,22 @@ instance OpResult (Tensor Ref a) where
instance OpResult ControlNode where instance OpResult ControlNode where
toResult = ControlNode <$> ask toResult = ControlNode <$> ask
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> Result (TensorList v bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- tensorResult v
ts <- loop ls
return (t :/ ts)
instance TensorTypes as => OpResult (TensorList Value as) where
toResult = tensorListResult ValueKind
instance TensorTypes as => OpResult (TensorList Ref as) where
toResult = tensorListResult RefKind
instance OpResult a => OpResult [a] where instance OpResult a => OpResult [a] where
toResult = do toResult = do
ResultState i ns <- get ResultState i ns <- get
@ -159,6 +178,12 @@ instance BuildOp (Tensor Value a) where
instance BuildOp (Tensor Ref a) where instance BuildOp (Tensor Ref a) where
buildOp' = pureResult buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Value as) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Ref as) where
buildOp' = pureResult
instance BuildOp [Tensor Value a] where instance BuildOp [Tensor Value a] where
buildOp' = pureResult buildOp' = pureResult
@ -199,6 +224,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where
buildOp' rf o accum ts buildOp' rf o accum ts
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum) = buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
instance BuildOp f => BuildOp (TensorList v as -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
-- | Returns true if all the integers in each tuple are identical. -- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not. -- Throws an error with a descriptive message if not.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool eqLengthGuard :: [(String, [(String, Int)])] -> Bool

View File

@ -12,15 +12,18 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Nodes where module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3) import Control.Applicative (liftA2, liftA3)
import Data.Functor.Identity (Identity)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.Set (Set) import Data.Set (Set)
@ -96,6 +99,19 @@ instance Nodes ControlNode where
instance a ~ () => Fetchable ControlNode a where instance a ~ () => Fetchable ControlNode a where
getFetch _ = return $ pure () getFetch _ = return $ pure ()
instance Nodes (ListOf f '[]) where
getNodes _ = return Set.empty
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
getFetch _ = return $ pure Nil
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
=> Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
instance Nodes (Tensor v a) where instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)

View File

@ -12,20 +12,30 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-} {-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Tensor where module TensorFlow.Tensor where
import Data.String (IsString(..)) import Data.String (IsString(..))
import qualified Data.Text as Text import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal') import Lens.Family2 (Lens', Traversal', (^.))
import Lens.Family2.Unchecked (lens) import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr) import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
import TensorFlow.Types (TensorData(..), Attribute) import TensorFlow.Types
( TensorData(..)
, Attribute
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation. -- | A named output of a TensorFlow operation.
@ -83,3 +93,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td
-- TODO(judahjacobson): add more safety checks here. -- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
tensorFromName v = Tensor v . fromString . Text.unpack tensorFromName v = Tensor v . fromString . Text.unpack
type TensorList v = ListOf (Tensor v)
tensorListOutputs :: TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts

View File

@ -13,9 +13,11 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
@ -36,23 +38,35 @@ module TensorFlow.Types
, Shape(..) , Shape(..)
, protoShape , protoShape
, Attribute(..) , Attribute(..)
, DataType(..)
-- * Lists
, ListOf(..)
, List
, (/:/)
, TensorTypeProxy(..)
, TensorTypes(..)
, TensorTypeList
, fromTensorTypeList
, fromTensorTypes
-- * Type constraints -- * Type constraints
, OneOf , OneOf
, type (/=) , type (/=)
, OneOfs
-- ** Implementation of constraints -- ** Implementation of constraints
, TypeError , TypeError
, ExcludedCase , ExcludedCase
, TensorTypes
, NoneOf , NoneOf
, type (\\) , type (\\)
, Delete , Delete
, AllTensorTypes , AllTensorTypes
) where ) where
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex) import Data.Complex (Complex)
import Data.Default (def) import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64) import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Data.String (IsString) import Data.String (IsString)
import Data.Word (Word8, Word16, Word64) import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable) import Foreign.Storable (Storable)
@ -376,6 +390,44 @@ instance Attribute [DataType] where
instance Attribute [Int64] where instance Attribute [Int64] where
attrLens = list . i attrLens = list . i
-- | A heterogeneous list type.
data ListOf f as where
Nil :: ListOf f '[]
(:/) :: f a -> ListOf f as -> ListOf f (a ': as)
infixr 5 :/
type family All f as :: Constraint where
All f '[] = ()
All f (a ': as) = (f a, All f as)
type family Map f as where
Map f '[] = '[]
Map f (a ': as) = f a ': Map f as
instance All Eq (Map f as) => Eq (ListOf f as) where
Nil == Nil = True
(x :/ xs) == (y :/ ys) = x == y && xs == ys
-- Newer versions of GHC use the GADT to tell that the previous cases are
-- exhaustive.
#if _GLASGOW_HASKELL__ < 800
_ == _ = False
#endif
instance All Show (Map f as) => Show (ListOf f as) where
showsPrec _ Nil = showString "Nil"
showsPrec d (x :/ xs) = showParen (d > 10)
$ showsPrec 6 x . showString " :/ "
. showsPrec 6 xs
type List = ListOf Identity
-- | Equivalent of ':/' for lists.
(/:/) :: a -> List as -> List (a ': as)
(/:/) = (:/) . Identity
infixr 5 /:/
-- | A 'Constraint' specifying the possible choices of a 'TensorType'. -- | A 'Constraint' specifying the possible choices of a 'TensorType'.
-- --
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the -- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
@ -393,13 +445,38 @@ instance Attribute [Int64] where
-- --
-- using an enumeration of all the possible 'TensorType's. -- using an enumeration of all the possible 'TensorType's.
type OneOf ts a type OneOf ts a
-- Assert `TensorTypes ts` to make error messages a little better.
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a) = (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
-- | A check that the input is a list of 'TensorType's. type OneOfs ts as = (TensorTypes as, TensorTypes ts,
-- Helps improve error messages when using 'OneOf'. NoneOfs (AllTensorTypes \\ ts) as)
type family NoneOfs ts as :: Constraint where
NoneOfs ts '[] = ()
NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)
data TensorTypeProxy a where
TensorTypeProxy :: TensorType a => TensorTypeProxy a
type TensorTypeList = ListOf TensorTypeProxy
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList Nil = []
fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts)
= tensorType (undefined :: t) : fromTensorTypeList ts
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
class TensorTypes (ts :: [*]) where class TensorTypes (ts :: [*]) where
instance TensorTypes '[] tensorTypes :: TensorTypeList ts
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts)
instance TensorTypes '[] where
tensorTypes = Nil
-- | A constraint that the input is a list of 'TensorTypes'.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
tensorTypes = TensorTypeProxy :/ tensorTypes
-- | A constraint checking that two types are different. -- | A constraint checking that two types are different.
type family a /= b :: Constraint where type family a /= b :: Constraint where