79 lines
3.1 KiB
Haskell
79 lines
3.1 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
-- | Queues in TensorFlow graph. Very limited support for now.
|
|
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
|
|
|
|
import Data.ByteString (ByteString)
|
|
import Data.Int (Int64)
|
|
import Lens.Family2 ((.~), (&))
|
|
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
|
import TensorFlow.BuildOp (buildOp)
|
|
import TensorFlow.ControlFlow (group)
|
|
import TensorFlow.Tensor (Ref, Tensor)
|
|
import TensorFlow.Types (TensorType, tensorType)
|
|
|
|
-- | A queue carrying tuples. The underlying structure is more
-- versatile and can be made to support arbitrary tuples.
--
-- The type parameters @a@ and @b@ are phantom: they do not occur in the
-- stored field and only record, at the type level, the element types of
-- the two tuple components.
--
-- Declared as a @newtype@ because there is exactly one constructor with
-- exactly one field, so no runtime wrapper is needed.
newtype Queue2 a b = Queue2
    { handle :: Handle  -- ^ Reference tensor identifying the queue.
    }

-- | Reference tensor through which queue ops in this module address a
-- queue; it is passed as the first argument when building those ops.
type Handle = Tensor Ref ByteString
|
|
|
|
-- | Adds the given values to the queue.
--
-- Builds a @QueueEnqueue@ op from the queue's handle and the two supplied
-- tensors (in that order). The op's @Tcomponents@ attribute is set to the
-- tensor types of @a@ and @b@.
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
        => Queue2 a b
        -> Tensor v1 a
        -> Tensor v2 b
        -> Build ControlNode
enqueue queue = buildOp enqueueDef (handle queue)
  where
    -- Op definition carrying the element types of both tuple components.
    enqueueDef = opDef "QueueEnqueue" & opAttr "Tcomponents" .~ componentTypes
    -- The 'undefined' values act purely as type witnesses selecting the
    -- 'TensorType' instance.
    -- NOTE(review): relies on 'tensorType' never forcing its argument.
    componentTypes =
        [ tensorType (undefined :: a)
        , tensorType (undefined :: b)
        ]
|
|
|
|
-- | Retrieves the values from the queue.
--
-- Builds a @QueueDequeue@ op from the queue's handle. The op's
-- @component_types@ attribute is set to the tensor types of @a@ and @b@.
dequeue :: forall a b . (TensorType a, TensorType b)
        => Queue2 a b
        -> Build (Tensor Ref a, Tensor Ref b)
        -- ^ Dequeued tensors. They are paired in the sense that the
        -- two values appear together, even if they are not consumed
        -- together.
dequeue queue = buildOp dequeueDef (handle queue)
  where
    -- Op definition carrying the element types of both tuple components.
    dequeueDef =
        opDef "QueueDequeue" & opAttr "component_types" .~ componentTypes
    -- The 'undefined' values act purely as type witnesses selecting the
    -- 'TensorType' instance.
    -- NOTE(review): relies on 'tensorType' never forcing its argument.
    componentTypes =
        [ tensorType (undefined :: a)
        , tensorType (undefined :: b)
        ]
|
|
|
|
-- | Creates a new queue with the given capacity and shared name.
--
-- Builds a @FIFOQueue@ op, registers a control node grouping that op as an
-- initializer, and wraps the resulting handle in a 'Queue2'.
makeQueue2 :: forall a b . (TensorType a, TensorType b)
           => Int64
           -- ^ The upper bound on the number of elements in this queue.
           -- Negative numbers mean no limit.
           -> ByteString
           -- ^ If non-empty, this queue will be shared under the given
           -- name across multiple sessions.
           -> Build (Queue2 a b)
makeQueue2 capacity sharedName = do
    queueHandle <- buildOp fifoQueueDef
    -- Wrap the queue op in a control node and register it as an initializer.
    initNode <- group queueHandle
    addInitializer initNode
    return (Queue2 queueHandle)
  where
    -- Op definition: element types first, then sharing name and capacity.
    fifoQueueDef =
        opDef "FIFOQueue"
            & opAttr "component_types" .~ componentTypes
            & opAttr "shared_name" .~ sharedName
            & opAttr "capacity" .~ capacity
    -- The 'undefined' values act purely as type witnesses selecting the
    -- 'TensorType' instance.
    -- NOTE(review): relies on 'tensorType' never forcing its argument.
    componentTypes =
        [ tensorType (undefined :: a)
        , tensorType (undefined :: b)
        ]
|
|
|
|
-- TODO(gnezdo): Figure out the closing story for queues.
|