-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module TensorFlow.ControlFlow
    ( -- * Dependencies
      withControlDependencies
    , group
      -- * Operations
    , identity
    , noOp
    , named
    ) where

import qualified Data.Set as Set
import Data.Text (Text)
import Lens.Family2 ((&), (^.), (.~))

import TensorFlow.BuildOp
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types

-- | Run a 'Build' action such that all new ops rendered inside it depend on
-- the nodes in the first argument.
--
-- The dependencies are resolved with 'getNodes' first, and the resulting
-- nodes are handed to 'withNodeDependencies' together with the action.
withControlDependencies :: Nodes t => t -> Build a -> Build a
withControlDependencies deps act =
    getNodes deps >>= \depNodes -> withNodeDependencies depNodes act

-- TODO(judahjacobson): Reimplement withDependencies.

-- | Create an op that groups multiple operations.
--
-- The result is a @NoOp@ whose control inputs are the nodes of the argument.
-- When this op finishes, all ops in the input have finished.  This op has no
-- output.
group :: Nodes t => t -> Build ControlNode
group deps = do
    controlInputs <- getNodes deps
    -- 'opControlInputs' is set from a list, so the node set is flattened here.
    let groupDef = opDef "NoOp" & opControlInputs .~ Set.toList controlInputs
    return (buildOp groupDef)


-- | Returns a 'Tensor' with the same shape and contents as the input.
--
-- The op is built with 'implicitName' rather than a caller-chosen name; use
-- 'named' to pick the name explicitly.
identity :: TensorType a => Tensor v a -> Tensor v a
identity t = namedIdentity implicitName t

-- | Returns a 'Tensor' with the given name and the same shape and contents as
-- the input.
--
-- TODO(judahjacobson): This breaks when used with uninitialized
-- @Tensor Ref@s, since @RefIdentity@ doesn't have
-- SetAllowsUninitializedInput().  Look into whether we can change that op.
named :: TensorType a => Text -> Tensor v a -> Tensor v a
named name = namedIdentity (explicitName name)

-- | An internal version of 'identity' that allows setting the name of the
-- output 'Tensor'.
--
-- Value tensors are passed through an @Identity@ op and reference tensors
-- through a @RefIdentity@ op; in both cases the @T@ attribute is set from the
-- tensor's element type.
namedIdentity :: forall a v . TensorType a
              => PendingNodeName -> Tensor v a -> Tensor v a
namedIdentity n t = case t ^. tensorKind of
    -- 'buildOp' is applied inside each branch rather than once outside,
    -- because matching on the kind is what fixes @v@ for that branch.
    ValueKind -> buildOp (identityDef "Identity") t
    RefKind -> buildOp (identityDef "RefIdentity") t
  where
    -- Op definition shared by both branches; only the op type differs.
    identityDef opType =
        opDefWithName n opType & opAttr "T" .~ tensorType (undefined :: a)


-- | An op that does nothing.  Only useful as a placeholder for control edges.
--
-- Built from a bare @NoOp@ op definition.
noOp :: ControlNode
noOp = buildOp (opDef "NoOp")