{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A TensorFlow FIFO queue whose elements consist of one tensor per type in
-- the type-level list @as@.
--
-- The constructor is not exported; queues are obtained from 'makeQueue'.
-- Declared as a @newtype@ since it only wraps the underlying 'Handle'.
newtype Queue (as :: [*]) = Queue { handle :: Handle }

-- | The queue handle: the reference tensor (of 'ByteString') produced by the
-- queue-creating op.
type Handle = Tensor Ref ByteString
-- | Builds an enqueue operation that adds the given tensors to the queue.
--
-- The result is the 'ControlNode' returned by 'CoreOps.queueEnqueue'.
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
        => Queue as         -- ^ Target queue.
        -> TensorList v as  -- ^ Values to enqueue, one tensor per component
                            -- type in @as@.
        -> m ControlNode
enqueue queue = CoreOps.queueEnqueue (handle queue)
-- | Builds a dequeue operation on the queue.
--
-- Delegates to 'CoreOps.queueDequeue' on the queue's handle.
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
        => Queue as                -- ^ Source queue.
        -> m (TensorList Value as) -- ^ Dequeued values, one tensor per
                                   -- component type in @as@.
dequeue queue = CoreOps.queueDequeue (handle queue)
-- | Creates a new FIFO queue whose elements have the component types @as@.
--
-- Builds a @FIFOQueue@ op whose @component_types@ attribute is derived from
-- @as@. The resulting handle is wrapped with 'group' and registered via
-- 'addInitializer' before being returned as a 'Queue'.
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
          => Int64       -- ^ Passed as the op's @capacity@ attribute.
                         -- NOTE(review): TensorFlow documents negative
                         -- values as "no limit" — confirm for this version.
          -> ByteString  -- ^ Passed as the op's @shared_name@ attribute.
          -> m (Queue as)
makeQueue capacity sharedName = do
    queueHandle <- build (buildOp [] fifoQueueDef)
    addInitializer =<< group queueHandle
    pure (Queue queueHandle)
  where
    -- 'as' is visible here because of ScopedTypeVariables and the explicit
    -- forall in the signature.
    fifoQueueDef = opDef "FIFOQueue"
        & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
        & opAttr "shared_name" .~ sharedName
        & opAttr "capacity" .~ capacity