-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.Output
    ( ControlNode(..)
    , Device(..)
    -- * Ops
    , NodeName(..)
    , Op(..)
    , opUnrendered
    , OpDef(..)
    , opName
    , opType
    , opAttr
    , opInputs
    , opControlInputs
    , OpType(..)
    , OutputIx(..)
    , Output(..)
    , output
    , outputIndex
    , outputOp
    , PendingNodeName(..)
    )  where

import qualified Data.Map.Strict as Map
import Data.ProtoLens.TextFormat (showMessage)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens)
import TensorFlow.Orphans ()

-- | A type of graph node which has no outputs. These nodes are
-- valuable for causing side effects when they are run.
newtype ControlNode = ControlNode { unControlNode :: Op }

-- | The type of op of a node in the graph.  This corresponds to the proto field
-- NodeDef.op.
newtype OpType = OpType { unOpType :: Text }
    deriving (Eq, Ord, Show)

instance IsString OpType where
    fromString = OpType . Text.pack

-- | An output of a TensorFlow node.
data Output = Output !OutputIx !Op
    deriving (Eq, Ord, Show)

output :: OutputIx -> Op -> Output
output = Output

outputOp :: Lens' Output Op
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)

outputIndex :: Lens' Output OutputIx
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)

newtype OutputIx = OutputIx { unOutputIx :: Int }
    deriving (Eq, Ord, Num, Enum, Show)

-- | A device that a node can be assigned to.
-- There's a naming convention where the device names
-- are constructed from job and replica names.
newtype Device = Device {deviceName :: Text}
    deriving (Eq, Ord, IsString)

instance Show Device where
    show (Device d) = show d

-- | The representation of a node in a TensorFlow graph.
data Op
    = Rendered !NodeDef  -- ^ Properties are fixed, including the
                         -- device, name, and scope.
    | Unrendered !OpDef  -- ^ Properties are not fixed, and may change depending
                         -- on which context this op is rendered in.
    deriving (Eq, Ord)

instance Show Op where
    show (Rendered n) = "Rendered " ++ showMessage n
    show (Unrendered o) = "Unrendered " ++ show (o ^. opName)

-- | Traverse on the 'Unrendered' of an 'Op'.
--
-- Same implementation as _Left.
opUnrendered :: Traversal' Op OpDef
opUnrendered f (Unrendered a) = Unrendered <$> f a
opUnrendered _ (Rendered b) = pure (Rendered b)

-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
data OpDef = OpDef
    { _opName :: !PendingNodeName
    , _opType :: !OpType
    , _opAttrs :: !(Map.Map Text AttrValue)
    , _opInputs :: [Output]
    , _opControlInputs :: [NodeName]
    }  deriving (Eq, Ord)

-- | The name specified for an unrendered Op.  If an Op has an
-- ImplicitName, it will be assigned based on the opType plus a
-- unique identifier.  Does not contain the "scope" prefix.
data PendingNodeName = ExplicitName !Text | ImplicitName
    deriving (Eq, Ord, Show)

-- | The name of a node in the graph.  This corresponds to the proto field
-- NodeDef.name.  Includes the scope prefix (if any) and a unique identifier
-- (if the node was implicitly named).
newtype NodeName = NodeName { unNodeName :: Text }
    deriving (Eq, Ord, Show)

opName :: Lens' OpDef PendingNodeName
opName = lens _opName (\o x -> o {_opName = x})

opType :: Lens' OpDef OpType
opType = lens _opType (\o x -> o { _opType = x})

opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr n = lens _opAttrs (\o x -> o {_opAttrs = x})
              . lens (Map.findWithDefault def n) (flip (Map.insert n))
              . attrLens

opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs (\o x -> o {_opInputs = x})

opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})

-- TODO(gnezdo): IsString instance is weird and we should move that
-- code into a Build function
instance IsString Output where
    fromString s = case break (==':') s of
        (n, ':':ixStr)
            | [(ix, "")] <- read ixStr -> Output (fromInteger ix) $ assigned n
        _ -> Output 0 $ assigned s
        where assigned n = Rendered $ def & name .~ Text.pack n