{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A queue whose elements are tuples of tensors.
--
-- The type-level list @as@ records the element type of each tuple
-- component. It is used by 'makeQueue' to fill in the op's
-- @component_types@ attribute, and it ties 'enqueue' / 'dequeue' to
-- matching 'TensorList' shapes.
--
-- A 'newtype' is used because this is a single-field wrapper. The
-- constructor is not exported, so callers cannot observe the change from
-- a @data@ declaration.
newtype Queue (as :: [*]) = Queue { handle :: Handle }

-- | The underlying queue resource: a 'Ref' tensor of 'ByteString',
-- as produced by the @FIFOQueue@ op built in 'makeQueue'.
type Handle = Tensor Ref ByteString
-- | Enqueue one tuple of tensors onto the queue.
--
-- This builds a queue-enqueue op (via 'CoreOps.queueEnqueue') against the
-- queue's underlying 'Handle'. The 'TensorList' must match the queue's
-- component types @as@; the type system enforces this.
--
-- Returns the op's 'ControlNode' so callers can sequence other work after
-- the enqueue.
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
        => Queue as
        -- ^ Queue to push onto.
        -> TensorList v as
        -- ^ One value per queue component, in component order.
        -> m ControlNode
enqueue = CoreOps.queueEnqueue . handle
-- | Dequeue one tuple of tensors from the queue.
--
-- This builds a queue-dequeue op (via 'CoreOps.queueDequeue') against the
-- queue's underlying 'Handle'. The result has one 'Value' tensor per
-- component, typed by the queue's component list @as@.
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
        => Queue as
        -- ^ Queue to pop from.
        -> m (TensorList Value as)
        -- ^ The dequeued components. They are produced together by a
        -- single op, even if callers consume them separately.
dequeue = CoreOps.queueDequeue . handle
-- | Create a new @FIFOQueue@ whose tuple components have the types in @as@.
--
-- The queue op is built with three attributes:
--
-- * @component_types@, derived from the type-level list @as@;
-- * @shared_name@;
-- * @capacity@.
--
-- The op is then registered with 'addInitializer', and the resulting
-- 'Handle' is returned wrapped in a 'Queue'.
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
          => Int64
          -- ^ Capacity, passed as the op's @capacity@ attribute.
          -- NOTE(review): presumably a negative value means "unbounded",
          -- per TensorFlow's FIFOQueue op; confirm.
          -> ByteString
          -- ^ Passed as the op's @shared_name@ attribute.
          -- NOTE(review): presumably a non-empty name shares the queue
          -- across sessions; confirm against the FIFOQueue op docs.
          -> m (Queue as)
makeQueue capacity sharedName = do
    q <- build $ buildOp []
        ( opDef "FIFOQueue"
            -- 'Proxy :: Proxy as' relies on ScopedTypeVariables and the
            -- explicit 'forall' above to reach the signature's @as@.
            & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
            & opAttr "shared_name" .~ sharedName
            & opAttr "capacity" .~ capacity
        )
    -- Wrap the queue op in a ControlNode and register it as an
    -- initializer. NOTE(review): presumably this makes the queue get
    -- created once at graph initialization rather than on every run;
    -- confirm against 'TensorFlow.Build.addInitializer'.
    group q >>= addInitializer
    pure (Queue q)