88 lines
2.9 KiB
Haskell
88 lines
2.9 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE GADTs #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
module TensorFlow.ControlFlow
|
|
( -- * Dependencies
|
|
withControlDependencies
|
|
, group
|
|
-- * Operations
|
|
, identity
|
|
, noOp
|
|
, named
|
|
) where
|
|
|
|
import qualified Data.Set as Set
|
|
import Data.Text (Text)
|
|
import Lens.Family2 ((&), (^.), (.~))
|
|
|
|
import TensorFlow.BuildOp
|
|
import TensorFlow.Build
|
|
import TensorFlow.Nodes
|
|
import TensorFlow.Output
|
|
import TensorFlow.Tensor
|
|
import TensorFlow.Types
|
|
|
|
-- | Wrap a 'Build' action so that every new op it renders gets a control
-- dependency on each of the nodes collected from @deps@.
withControlDependencies :: Nodes t => t -> Build a -> Build a
withControlDependencies deps act =
    getNodes deps >>= \depNodes -> withNodeDependencies depNodes act
|
|
|
|
-- TODO(judahjacobson): Reimplement withDependencies.
|
|
|
|
-- | Create an op that groups multiple operations.
--
-- When this op finishes, all ops in the input @deps@ have finished. This op
-- has no output.
group :: Nodes t => t -> Build ControlNode
group deps = mkGroupOp . Set.toList <$> getNodes deps
  where
    -- A @NoOp@ whose only inputs are control edges from the given nodes.
    mkGroupOp nodes = buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
|
|
|
|
|
-- | Returns a 'Tensor' with the same shape and contents as the input; the
-- underlying op receives an implicitly chosen name.
identity :: TensorType a => Tensor v a -> Tensor v a
identity tensor = namedIdentity implicitName tensor
|
|
|
|
-- | Returns a 'Tensor' with a given name and the same shape and contents as
-- the input.
--
-- TODO(judahjacobson): This breaks when used with uninitialized @Tensor Ref@s,
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
-- whether we can change that op.
named :: TensorType a => Text -> Tensor v a -> Tensor v a
named name = namedIdentity (explicitName name)
|
|
|
|
-- | An internal version of 'identity' that allows setting the name of the
-- output 'Tensor'.
--
-- Value tensors are passed through an @Identity@ op and ref tensors through a
-- @RefIdentity@ op; in both cases the op's @T@ attribute is set to the
-- element type @a@.
namedIdentity :: forall a v . TensorType a
              => PendingNodeName -> Tensor v a -> Tensor v a
namedIdentity n t =
    -- Matching on the tensor kind refines @v@ in each branch, which is what
    -- lets each 'buildOp' call pick the right result type.
    case t ^. tensorKind of
        ValueKind -> buildOp (withElemType (opDefWithName n "Identity")) t
        RefKind -> buildOp (withElemType (opDefWithName n "RefIdentity")) t
  where
    -- Record the element type of @a@ in the op's @T@ attribute.
    withElemType od = od & opAttr "T" .~ tensorType (undefined :: a)
|
|
|
|
|
|
-- | An op that does nothing; it is only useful as a placeholder for control
-- edges.
noOp :: ControlNode
noOp = buildOp (opDef "NoOp")
|