mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Support lists of tensors in ops. (#79)
Adds a new type `ListOf` which wraps a heterogeneous list; for example, `ListOf (Tensor Value) '[Int32, Float]` represents a list of two elements: a tensor of int32s and a tensor of floats. Also changes the `Queue2` type (which suppored pairs of tensors) to `Queue` (which supports arbitrary lists).
This commit is contained in:
parent
7cc6a69866
commit
9209dfc4c4
9 changed files with 273 additions and 144 deletions
|
@ -64,43 +64,9 @@ generatingOpsWrappers = hooks
|
|||
(prettyLazyText 80 $ docOpList flags x)
|
||||
|
||||
blackList =
|
||||
-- A few data flow ops take a list of heterogeneous
|
||||
-- parameters which we don't support in general form.
|
||||
[ "HashTable"
|
||||
, "MutableDenseHashTable"
|
||||
, "MutableHashTable"
|
||||
, "MutableHashTableOfTensors"
|
||||
, "QueueDequeue"
|
||||
, "QueueDequeueMany"
|
||||
, "QueueDequeueUpTo"
|
||||
, "Stack"
|
||||
, "TensorArray"
|
||||
, "TensorArrayV2"
|
||||
, "QueueEnqueueManyV2"
|
||||
, "QueueDequeueV2"
|
||||
, "QueueDequeueUpToV2"
|
||||
, "QueueEnqueueV2"
|
||||
, "QueueDequeueManyV2"
|
||||
, "Stage"
|
||||
, "Unstage"
|
||||
-- These should be possible to support by adding a bunch of
|
||||
-- overloads with a variable number of tuple arguments.
|
||||
, "Assert"
|
||||
, "BarrierTakeMany"
|
||||
, "Print"
|
||||
, "QueueEnqueue"
|
||||
, "QueueEnqueueMany"
|
||||
-- Need list of types support.
|
||||
, "DecodeCSV"
|
||||
, "ParseExample"
|
||||
, "ParseSingleSequenceExample"
|
||||
, "RestoreV2"
|
||||
, "Save"
|
||||
, "SaveV2"
|
||||
, "SaveSlices"
|
||||
, "SymbolicGradient"
|
||||
, "_ArrayToList"
|
||||
, "_ListToArray"
|
||||
[ -- Requires the "func" type:
|
||||
"SymbolicGradient"
|
||||
-- Easy: support larger result tuples.
|
||||
, "ParseSingleSequenceExample"
|
||||
, "Skipgram"
|
||||
]
|
||||
|
|
|
@ -147,6 +147,7 @@ imports = stack [
|
|||
"import Data.ByteString (ByteString)"
|
||||
, "import Data.Complex (Complex)"
|
||||
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
||||
, "import Data.Proxy (Proxy(Proxy))"
|
||||
, "import Data.Word (Word8, Word16)"
|
||||
, "import Lens.Family2 ((.~), (&))"
|
||||
, "import TensorFlow.Build"
|
||||
|
@ -210,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
|
|||
whereClause [] = []
|
||||
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
||||
where
|
||||
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
|
||||
defineLengthAttr a = renderHaskellAttrName a <+> "="
|
||||
<+> "fromIntegral (length"
|
||||
<+> renderHaskellName (NE.head $ attrInfo a)
|
||||
<> ") :: Int64"
|
||||
|
||||
renderHaskellAttrName :: Attr a -> Doc
|
||||
renderHaskellAttrName = renderHaskellName . attrName
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
|
@ -229,9 +233,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
|||
<- parsedOutputs pOp]
|
||||
buildOpParts =
|
||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||
-- Renders tensor arguments.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+>
|
||||
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
|
||||
-- Renders type parameter arguments.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
|
||||
| a <- inferredTypeAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
-- Renders mandatory attributes as function parameters.
|
||||
|
@ -244,6 +247,12 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
|||
]
|
||||
|
||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
inferredTypeExpr a
|
||||
| typeParamIsList $ attrInfo a
|
||||
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
||||
<> ")"
|
||||
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
|
||||
<> ")"
|
||||
|
||||
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||
-- debugging.
|
||||
|
@ -272,8 +281,8 @@ typeSig pOp = constraints
|
|||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ map tensorArgConstraint
|
||||
$ inferredTypeAttrs pOp
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||
|
@ -305,17 +314,18 @@ tensorArg p = case parsedArgCase p of
|
|||
ResourceArg -> "ResourceHandle"
|
||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
||||
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
||||
-> "TensorList" <+> kind k <+> renderHaskellName t
|
||||
where
|
||||
kind k = case k of
|
||||
ArgTensorRef -> "Ref"
|
||||
ArgTensorValue -> "Value"
|
||||
ArgTensorEither v' -> strictText v'
|
||||
tensorType t k = let
|
||||
v = case k of
|
||||
ArgTensorRef -> "Tensor Ref"
|
||||
ArgTensorValue -> "Tensor Value"
|
||||
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
||||
a = case t of
|
||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||
ArgTypeAttr n -> renderHaskellName n
|
||||
in v <+> a
|
||||
in "Tensor" <+> kind k <+> a
|
||||
|
||||
attrComment :: Attr a -> Doc
|
||||
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||
|
@ -347,18 +357,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os
|
|||
]
|
||||
|
||||
-- | Constraints for a given type parameter.
|
||||
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
|
||||
tensorArgConstraint :: Attr [DataType] -> [Doc]
|
||||
tensorArgConstraint a
|
||||
= ("TensorType" <+> n
|
||||
: if null typeList
|
||||
then []
|
||||
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
|
||||
-- E.g.: "TensorType t" or "OneOf [Int64, Float] t"
|
||||
-- or "TensorTypes ts" or "OneOfs [..] ts".
|
||||
tensorArgConstraint :: Attr TypeParam -> Doc
|
||||
tensorArgConstraint a = case attrInfo a of
|
||||
TypeParam False Nothing -> "TensorType" <+> n
|
||||
TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n
|
||||
TypeParam True Nothing -> "TensorTypes" <+> n
|
||||
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
|
||||
where
|
||||
n = renderHaskellName $ attrName a
|
||||
typeList = map strictText $
|
||||
Set.toList $ Set.fromList $
|
||||
map dtTypeToHaskell $ attrInfo a
|
||||
n = renderHaskellAttrName a
|
||||
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
|
||||
typeList = ("'" <>) . brackets . commasep . map strictText .
|
||||
Set.toList . Set.fromList .
|
||||
map dtTypeToHaskell . toList
|
||||
|
||||
-- NOTE: The cases of this function should be kept in sync with
|
||||
-- TensorFlow.Types.AllTensorTypes.
|
||||
|
|
|
@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp
|
|||
, Attr(..)
|
||||
, AttrType(..)
|
||||
, AttrBaseType(..)
|
||||
, TypeParam(..)
|
||||
, ParsedArg(..)
|
||||
, ParsedArgCase(..)
|
||||
, ArgType(..)
|
||||
|
@ -62,10 +63,8 @@ data ParsedOp = ParsedOp
|
|||
, explicitInputAttrs :: [Attr AttrType]
|
||||
-- ^ Attributes that must be set explicitly when creating the op.
|
||||
-- Associated with the type of the attribute.
|
||||
, inferredTypeAttrs :: [Attr [DataType]]
|
||||
, inferredTypeAttrs :: [Attr TypeParam]
|
||||
-- ^ Attributes that are type parameters.
|
||||
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
|
||||
-- If this list is empty, then any type is acceptable.
|
||||
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
|
||||
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||
-- from one or more of the input tensors.
|
||||
|
@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
|
|||
| AttrType | AttrShape | AttrTensor
|
||||
deriving Eq
|
||||
|
||||
data TypeParam = TypeParam
|
||||
{ typeParamIsList :: Bool
|
||||
, typeParamRestrictions :: Maybe (NonEmpty DataType)
|
||||
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
|
||||
-- If 'Nothing', then any type is acceptable.
|
||||
}
|
||||
|
||||
-- | An input or output argument (Tensor) for an op.
|
||||
data ParsedArg = ParsedArg
|
||||
{ parsedArgName :: Name
|
||||
|
@ -120,7 +126,6 @@ data ParsedArgCase
|
|||
}
|
||||
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
||||
-- ^ A heterogeneous list.
|
||||
-- TODO(judahjacobson): Implement this.
|
||||
| ResourceArg
|
||||
|
||||
argKind :: ParsedArgCase -> Maybe ArgKind
|
||||
|
@ -223,11 +228,6 @@ parseOp o = ParsedOp
|
|||
(o ^. inputArg) tensorKindParams
|
||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
-- Type attributes that can be inferred from at least one input or output.
|
||||
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
|
||||
$ parsedInputs ++ parsedOutputs
|
||||
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
|
||||
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
-- Integer attributes that can be inferred from the size of at least one
|
||||
-- input list.
|
||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||
|
@ -235,10 +235,14 @@ parseOp o = ParsedOp
|
|||
implicitAttrs = Set.fromList $ map tfName $
|
||||
map attrName inferredTypeAttrs
|
||||
++ map attrName inferredListSizeAttrs
|
||||
-- Attributes that can't be inferred and don't have defaults, so must be passed
|
||||
-- as separate arguments to the op.
|
||||
inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
|
||||
argTypeParams = Set.fromList $ map tfName $
|
||||
mapMaybe (getArgTypeParam . parsedArgCase) $
|
||||
parsedInputs ++ parsedOutputs
|
||||
-- Attributes that can't be inferred and don't have defaults, so must be
|
||||
-- passed as separate arguments to the op.
|
||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||
$ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
|
||||
$ o ^. attr
|
||||
|
||||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
|
@ -252,29 +256,30 @@ outputTensorKind a
|
|||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorValue
|
||||
|
||||
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
getExplicitInputAttr implicitAttrs a
|
||||
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
getExplicitInputAttr o implicitAttrs a
|
||||
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||
, a ^. maybe'defaultValue == Nothing
|
||||
, t <- parseAttrType (a ^. type')
|
||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
||||
, t <- parseAttrType o (a ^. type')
|
||||
, t `elem` map AttrSingle
|
||||
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
|
||||
++ [AttrList AttrType] = Just t
|
||||
| otherwise = Nothing
|
||||
|
||||
-- | The type attribute used by this input or output (if any).
|
||||
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
||||
parsedArgTypeAttr p = case parsedArgCase p of
|
||||
ResourceArg -> Nothing
|
||||
SimpleArg {argType = t} -> fromArgType t
|
||||
ListArg {argType = t} -> fromArgType t
|
||||
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
||||
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
|
||||
getInferredTypeAttr argTypeParams a
|
||||
| TFName (a ^. name) `notElem` argTypeParams = Nothing
|
||||
| a ^. type' == "type" = Just $ TypeParam False allowed
|
||||
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
|
||||
| otherwise = Nothing
|
||||
where
|
||||
fromArgType (ArgTypeAttr n) = Just $ tfName n
|
||||
fromArgType _ = Nothing
|
||||
allowed = nonEmpty (a ^. allowedValues . list . type')
|
||||
|
||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||
getInferredTypeAttr a
|
||||
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
||||
| otherwise = Nothing
|
||||
getArgTypeParam :: ParsedArgCase -> Maybe Name
|
||||
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
|
||||
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
|
||||
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
|
||||
getArgTypeParam _ = Nothing
|
||||
|
||||
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
|
||||
getInferredListSizeAttr inputs a
|
||||
|
@ -285,7 +290,7 @@ getInferredListSizeAttr inputs a
|
|||
} <- inputs
|
||||
, TFName (a ^. name) == tfName n]
|
||||
| otherwise = Nothing
|
||||
|
||||
|
||||
-- | Like mapMaybe, but associates the attribute name/description with the given info.
|
||||
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
|
||||
mapMaybeAttrs f = mapMaybe $ \a -> do
|
||||
|
@ -295,7 +300,7 @@ mapMaybeAttrs f = mapMaybe $ \a -> do
|
|||
, attrDescription = a ^. description
|
||||
, attrInfo = x
|
||||
}
|
||||
|
||||
|
||||
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
||||
parseArg a tKind = ParsedArg
|
||||
{ parsedArgName = makeName (a ^. name)
|
||||
|
@ -317,15 +322,15 @@ parseArgCase a tKind
|
|||
maybeAttr "" = Nothing
|
||||
maybeAttr t = Just $ makeName t
|
||||
|
||||
parseAttrType :: Text -> AttrType
|
||||
parseAttrType = \case
|
||||
parseAttrType :: OpDef -> Text -> AttrType
|
||||
parseAttrType o = \case
|
||||
"string" -> AttrSingle AttrBytes
|
||||
"int" -> AttrSingle AttrInt64
|
||||
"float" -> AttrSingle AttrFloat
|
||||
"bool" -> AttrSingle AttrBool
|
||||
"type" -> AttrSingle AttrType
|
||||
"shape" -> AttrSingle AttrShape
|
||||
"tensor" -> AttrSingle AttrTensor
|
||||
"int" -> AttrSingle AttrInt64
|
||||
"float" -> AttrSingle AttrFloat
|
||||
"bool" -> AttrSingle AttrBool
|
||||
"type" -> AttrSingle AttrType
|
||||
"shape" -> AttrSingle AttrShape
|
||||
"tensor" -> AttrSingle AttrTensor
|
||||
"list(string)" -> AttrList AttrBytes
|
||||
"list(int)" -> AttrList AttrInt64
|
||||
"list(float)" -> AttrList AttrFloat
|
||||
|
@ -334,3 +339,4 @@ parseAttrType = \case
|
|||
"list(shape)" -> AttrList AttrShape
|
||||
"list(tensor)" -> AttrList AttrTensor
|
||||
t -> error $ "parseAttrType: unrecognized type " ++ show t
|
||||
++ " for op " ++ show (o ^. name)
|
||||
|
|
|
@ -12,67 +12,65 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
-- | Queues in TensorFlow graph. Very limited support for now.
|
||||
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
|
||||
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
||||
import TensorFlow.BuildOp (buildOp)
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Tensor (Ref, Tensor)
|
||||
import TensorFlow.Types (TensorType, tensorType)
|
||||
import TensorFlow.Tensor (Ref, Tensor, TensorList)
|
||||
import TensorFlow.Types (TensorTypes, fromTensorTypes)
|
||||
|
||||
-- | A queue carrying tuples. The underlying structure is more
|
||||
-- versatile and can be made to support arbitrary tuples.
|
||||
data Queue2 a b = Queue2 { handle :: Handle }
|
||||
-- | A queue carrying tuples.
|
||||
data Queue (as :: [*]) = Queue { handle :: Handle }
|
||||
|
||||
type Handle = Tensor Ref ByteString
|
||||
|
||||
-- | Adds the given values to the queue.
|
||||
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
|
||||
=> Queue2 a b
|
||||
-> Tensor v1 a
|
||||
-> Tensor v2 b
|
||||
enqueue :: forall as v . TensorTypes as
|
||||
=> Queue as
|
||||
-> TensorList v as
|
||||
-> Build ControlNode
|
||||
enqueue q =
|
||||
buildOp (opDef "QueueEnqueue"
|
||||
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)])
|
||||
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
|
||||
(handle q)
|
||||
|
||||
-- | Retrieves the values from the queue.
|
||||
dequeue :: forall a b . (TensorType a, TensorType b)
|
||||
=> Queue2 a b
|
||||
-> Build (Tensor Ref a, Tensor Ref b)
|
||||
-- ^ Dequeued tensors. They are paired in a sense
|
||||
dequeue :: forall as . TensorTypes as
|
||||
=> Queue as
|
||||
-> Build (TensorList Ref as)
|
||||
-- ^ Dequeued tensors. They are coupled in a sense
|
||||
-- that values appear together, even if they are
|
||||
-- not consumed together.
|
||||
dequeue q =
|
||||
buildOp (opDef "QueueDequeue"
|
||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)])
|
||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
|
||||
(handle q)
|
||||
|
||||
-- | Creates a new queue with the given capacity and shared name.
|
||||
makeQueue2 :: forall a b . (TensorType a, TensorType b)
|
||||
makeQueue :: forall as . TensorTypes as
|
||||
=> Int64 -- ^ The upper bound on the number of elements in
|
||||
-- this queue. Negative numbers mean no limit.
|
||||
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||
-- under the given name across multiple sessions.
|
||||
-> Build (Queue2 a b)
|
||||
makeQueue2 capacity sharedName = do
|
||||
-> Build (Queue as)
|
||||
makeQueue capacity sharedName = do
|
||||
q <- buildOp (opDef "FIFOQueue"
|
||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)]
|
||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||
& opAttr "shared_name" .~ sharedName
|
||||
& opAttr "capacity" .~ capacity
|
||||
)
|
||||
group q >>= addInitializer
|
||||
return (Queue2 q)
|
||||
return (Queue q)
|
||||
|
||||
-- TODO(gnezdo): Figure out the closing story for queues.
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
|
@ -20,7 +21,7 @@ module Main where
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Types (Scalar(..))
|
||||
import TensorFlow.Types (ListOf(..), Scalar(..), (/:/))
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Queue
|
||||
import TensorFlow.Session
|
||||
|
@ -39,42 +40,50 @@ import qualified Data.ByteString as BS
|
|||
-- | Test basic queue behaviors.
|
||||
testBasic :: Test
|
||||
testBasic = testCase "testBasic" $ runSession $ do
|
||||
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
||||
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
||||
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
||||
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
|
||||
x <- buildAnd run (dequeue q)
|
||||
liftIO $ (Scalar 42, Scalar "Hi") @=? x
|
||||
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
|
||||
|
||||
buildAnd run_ (enqueue q 56 (scalar "Bar"))
|
||||
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
|
||||
y <- buildAnd run (dequeue q)
|
||||
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
||||
-- Note: we use explicit "Scalar" here to specify the type that was
|
||||
-- fetched. Equivalently we could write
|
||||
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
||||
-- or else allow the types to be determined by future use of the fetched
|
||||
-- value.
|
||||
let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil
|
||||
liftIO $ expected @=? y
|
||||
|
||||
-- | Test queue pumping.
|
||||
testPump :: Test
|
||||
testPump = testCase "testPump" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
||||
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue"
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q 31 (scalar "Baz")
|
||||
<*> enqueue q (31 :/ scalar "Baz" :/ Nil)
|
||||
-- This is a realistic use. The pump inputs are pre-bound to some
|
||||
-- nodes that produce values when pumped (e.g. read from a
|
||||
-- file).
|
||||
run_ (pump, pump)
|
||||
|
||||
(x, y) <- run (deq, deq)
|
||||
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
||||
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
||||
let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil
|
||||
liftIO $ expected @=? x
|
||||
liftIO $ expected @=? y
|
||||
|
||||
testAsync :: Test
|
||||
testAsync = testCase "testAsync" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
||||
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q 10 (scalar "Async")
|
||||
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
|
||||
-- Pumps the queue until canceled by runSession exiting.
|
||||
asyncProdNodes pump
|
||||
-- Picks up a couple values and verifies they are as expected.
|
||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||
let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil
|
||||
run deq >>= liftIO . (expected @=?)
|
||||
run deq >>= liftIO . (expected @=?)
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testBasic
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.BuildOp
|
||||
|
@ -33,6 +35,7 @@ import Lens.Family2 ((&), (<>~), (^.))
|
|||
import TensorFlow.Build
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||
|
||||
|
@ -98,6 +101,22 @@ instance OpResult (Tensor Ref a) where
|
|||
instance OpResult ControlNode where
|
||||
toResult = ControlNode <$> ask
|
||||
|
||||
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
|
||||
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
|
||||
where
|
||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||
loop Nil = return Nil
|
||||
loop (TensorTypeProxy :/ ls) = do
|
||||
t <- tensorResult v
|
||||
ts <- loop ls
|
||||
return (t :/ ts)
|
||||
|
||||
instance TensorTypes as => OpResult (TensorList Value as) where
|
||||
toResult = tensorListResult ValueKind
|
||||
|
||||
instance TensorTypes as => OpResult (TensorList Ref as) where
|
||||
toResult = tensorListResult RefKind
|
||||
|
||||
instance OpResult a => OpResult [a] where
|
||||
toResult = do
|
||||
ResultState i ns <- get
|
||||
|
@ -159,6 +178,12 @@ instance BuildOp (Tensor Value a) where
|
|||
instance BuildOp (Tensor Ref a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance TensorTypes as => BuildOp (TensorList Value as) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance TensorTypes as => BuildOp (TensorList Ref as) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp [Tensor Value a] where
|
||||
buildOp' = pureResult
|
||||
|
||||
|
@ -199,6 +224,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
|||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
||||
|
||||
instance BuildOp f => BuildOp (TensorList v as -> f) where
|
||||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
|
||||
|
||||
-- | Returns true if all the integers in each tuple are identical.
|
||||
-- Throws an error with a descriptive message if not.
|
||||
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
||||
|
|
|
@ -12,15 +12,18 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
module TensorFlow.Nodes where
|
||||
|
||||
import Control.Applicative (liftA2, liftA3)
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Set (Set)
|
||||
|
@ -96,6 +99,19 @@ instance Nodes ControlNode where
|
|||
instance a ~ () => Fetchable ControlNode a where
|
||||
getFetch _ = return $ pure ()
|
||||
|
||||
instance Nodes (ListOf f '[]) where
|
||||
getNodes _ = return Set.empty
|
||||
|
||||
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
|
||||
getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
|
||||
|
||||
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
|
||||
getFetch _ = return $ pure Nil
|
||||
|
||||
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
|
||||
=> Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
|
||||
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
|
||||
|
||||
instance Nodes (Tensor v a) where
|
||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
|
|
|
@ -12,20 +12,30 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
||||
module TensorFlow.Tensor where
|
||||
|
||||
import Data.String (IsString(..))
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', Traversal')
|
||||
import Lens.Family2 (Lens', Traversal', (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
|
||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||
import TensorFlow.Types (TensorData(..), Attribute)
|
||||
import TensorFlow.Types
|
||||
( TensorData(..)
|
||||
, Attribute
|
||||
, ListOf(..)
|
||||
)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
-- | A named output of a TensorFlow operation.
|
||||
|
@ -83,3 +93,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td
|
|||
-- TODO(judahjacobson): add more safety checks here.
|
||||
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
||||
tensorFromName v = Tensor v . fromString . Text.unpack
|
||||
|
||||
type TensorList v = ListOf (Tensor v)
|
||||
|
||||
tensorListOutputs :: TensorList v as -> [Output]
|
||||
tensorListOutputs Nil = []
|
||||
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
|
||||
|
|
|
@ -13,9 +13,11 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
@ -36,23 +38,35 @@ module TensorFlow.Types
|
|||
, Shape(..)
|
||||
, protoShape
|
||||
, Attribute(..)
|
||||
, DataType(..)
|
||||
-- * Lists
|
||||
, ListOf(..)
|
||||
, List
|
||||
, (/:/)
|
||||
, TensorTypeProxy(..)
|
||||
, TensorTypes(..)
|
||||
, TensorTypeList
|
||||
, fromTensorTypeList
|
||||
, fromTensorTypes
|
||||
-- * Type constraints
|
||||
, OneOf
|
||||
, type (/=)
|
||||
, OneOfs
|
||||
-- ** Implementation of constraints
|
||||
, TypeError
|
||||
, ExcludedCase
|
||||
, TensorTypes
|
||||
, NoneOf
|
||||
, type (\\)
|
||||
, Delete
|
||||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Data.String (IsString)
|
||||
import Data.Word (Word8, Word16, Word64)
|
||||
import Foreign.Storable (Storable)
|
||||
|
@ -376,6 +390,44 @@ instance Attribute [DataType] where
|
|||
instance Attribute [Int64] where
|
||||
attrLens = list . i
|
||||
|
||||
-- | A heterogeneous list type.
|
||||
data ListOf f as where
|
||||
Nil :: ListOf f '[]
|
||||
(:/) :: f a -> ListOf f as -> ListOf f (a ': as)
|
||||
|
||||
infixr 5 :/
|
||||
|
||||
type family All f as :: Constraint where
|
||||
All f '[] = ()
|
||||
All f (a ': as) = (f a, All f as)
|
||||
|
||||
type family Map f as where
|
||||
Map f '[] = '[]
|
||||
Map f (a ': as) = f a ': Map f as
|
||||
|
||||
instance All Eq (Map f as) => Eq (ListOf f as) where
|
||||
Nil == Nil = True
|
||||
(x :/ xs) == (y :/ ys) = x == y && xs == ys
|
||||
-- Newer versions of GHC use the GADT to tell that the previous cases are
|
||||
-- exhaustive.
|
||||
#if _GLASGOW_HASKELL__ < 800
|
||||
_ == _ = False
|
||||
#endif
|
||||
|
||||
instance All Show (Map f as) => Show (ListOf f as) where
|
||||
showsPrec _ Nil = showString "Nil"
|
||||
showsPrec d (x :/ xs) = showParen (d > 10)
|
||||
$ showsPrec 6 x . showString " :/ "
|
||||
. showsPrec 6 xs
|
||||
|
||||
type List = ListOf Identity
|
||||
|
||||
-- | Equivalent of ':/' for lists.
|
||||
(/:/) :: a -> List as -> List (a ': as)
|
||||
(/:/) = (:/) . Identity
|
||||
|
||||
infixr 5 /:/
|
||||
|
||||
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
||||
--
|
||||
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
||||
|
@ -393,13 +445,38 @@ instance Attribute [Int64] where
|
|||
--
|
||||
-- using an enumeration of all the possible 'TensorType's.
|
||||
type OneOf ts a
|
||||
-- Assert `TensorTypes ts` to make error messages a little better.
|
||||
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
|
||||
|
||||
-- | A check that the input is a list of 'TensorType's.
|
||||
-- Helps improve error messages when using 'OneOf'.
|
||||
type OneOfs ts as = (TensorTypes as, TensorTypes ts,
|
||||
NoneOfs (AllTensorTypes \\ ts) as)
|
||||
|
||||
type family NoneOfs ts as :: Constraint where
|
||||
NoneOfs ts '[] = ()
|
||||
NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)
|
||||
|
||||
data TensorTypeProxy a where
|
||||
TensorTypeProxy :: TensorType a => TensorTypeProxy a
|
||||
|
||||
type TensorTypeList = ListOf TensorTypeProxy
|
||||
|
||||
fromTensorTypeList :: TensorTypeList ts -> [DataType]
|
||||
fromTensorTypeList Nil = []
|
||||
fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts)
|
||||
= tensorType (undefined :: t) : fromTensorTypeList ts
|
||||
|
||||
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
|
||||
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
|
||||
|
||||
class TensorTypes (ts :: [*]) where
|
||||
instance TensorTypes '[]
|
||||
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts)
|
||||
tensorTypes :: TensorTypeList ts
|
||||
|
||||
instance TensorTypes '[] where
|
||||
tensorTypes = Nil
|
||||
|
||||
-- | A constraint that the input is a list of 'TensorTypes'.
|
||||
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
|
||||
tensorTypes = TensorTypeProxy :/ tensorTypes
|
||||
|
||||
-- | A constraint checking that two types are different.
|
||||
type family a /= b :: Constraint where
|
||||
|
|
Loading…
Reference in a new issue