-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

module TensorFlow.BuildOp
    ( BuildResult(..)
    , buildOp
    , PureResult(..)
    , pureOp
    , eqLengthGuard
    , BuildInputs(..)
    , OpParams
    )
  where

import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64)

import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types

-- | Cursor threaded through result construction: the index of the next
-- output to hand out, followed by the sizes still waiting to be consumed
-- by list-valued results (front of the list first).
data ResultState
    = ResultState
        !OutputIx  -- next output index to claim (strict field)
        [Int64]    -- outstanding element counts for list-valued results
    deriving (Show)

-- | Monad in which a 'BuildResult' is assembled: it reads the name of the
-- node whose outputs are being collected and threads a 'ResultState'.
type Result = ReaderT NodeName (State ResultState)

-- | Types that can be assembled from the outputs of a single op node
-- (see 'buildOp').
--
-- 'buildResult' runs in 'Result': it may read the node's name and consume
-- output indices and list sizes from the 'ResultState'. The tuple
-- instances below build their components in order, left to right, so
-- earlier components claim lower output indices.
class BuildResult a where
    buildResult :: Result a

instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        return (r1, r2)

instance (BuildResult a1, BuildResult a2, BuildResult a3)
         => BuildResult (a1, a2, a3) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        r3 <- buildResult
        return (r1, r2, r3)

instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
         => BuildResult (a1, a2, a3, a4) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        r3 <- buildResult
        r4 <- buildResult
        return (r1, r2, r3, r4)

instance ( BuildResult a1, BuildResult a2, BuildResult a3
         , BuildResult a4, BuildResult a5
         ) => BuildResult (a1, a2, a3, a4, a5) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        r3 <- buildResult
        r4 <- buildResult
        r5 <- buildResult
        return (r1, r2, r3, r4, r5)

instance ( BuildResult a1, BuildResult a2, BuildResult a3
         , BuildResult a4, BuildResult a5, BuildResult a6
         ) => BuildResult (a1, a2, a3, a4, a5, a6) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        r3 <- buildResult
        r4 <- buildResult
        r5 <- buildResult
        r6 <- buildResult
        return (r1, r2, r3, r4, r5, r6)

instance ( BuildResult a1, BuildResult a2, BuildResult a3
         , BuildResult a4, BuildResult a5, BuildResult a6
         , BuildResult a7
         ) => BuildResult (a1, a2, a3, a4, a5, a6, a7) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        r3 <- buildResult
        r4 <- buildResult
        r5 <- buildResult
        r6 <- buildResult
        r7 <- buildResult
        return (r1, r2, r3, r4, r5, r6, r7)

instance ( BuildResult a1, BuildResult a2, BuildResult a3
         , BuildResult a4, BuildResult a5, BuildResult a6
         , BuildResult a7, BuildResult a8
         ) => BuildResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    buildResult = do
        r1 <- buildResult
        r2 <- buildResult
        r3 <- buildResult
        r4 <- buildResult
        r5 <- buildResult
        r6 <- buildResult
        r7 <- buildResult
        r8 <- buildResult
        return (r1, r2, r3, r4, r5, r6, r7, r8)

-- | Claims the next output slot of the current node. It returns the
-- 'Output' at the current index and advances the index by one; the
-- pending list sizes are left untouched.
recordResult :: Result Output
recordResult = do
    ResultState ix counts <- get
    node <- ask
    put $! ResultState (ix + 1) counts
    return $! output ix node

-- | A single tensor occupies exactly one output slot of the node.
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
    buildResult = do
        out <- recordResult
        return (Tensor (pure out))

-- | A control node wraps the node name itself; it consumes no output slot
-- and no list size.
instance BuildResult ControlNode where
    buildResult = fmap ControlNode ask

-- | A heterogeneous tensor list claims one output slot per element, in the
-- order given by the type-level list @as@.
instance (TensorKind v, Rendered (Tensor v), TensorTypes as) => BuildResult (TensorList v as) where
    buildResult = go (tensorTypes :: TensorTypeList as)
      where
        -- Walk the type witnesses, building one tensor per element.
        go :: TensorTypeList bs -> Result (TensorList v bs)
        go Nil = pure Nil
        go (TensorTypeProxy :/ rest) = (:/) <$> buildResult <*> go rest

-- | A list of results pops its length from the front of the pending size
-- list and then builds that many elements in order. It fails with 'error'
-- if no size is left.
instance BuildResult a => BuildResult [a] where
    buildResult = do
        ResultState ix counts <- get
        case counts of
            count : remaining -> do
                put $! ResultState ix remaining
                replicateM (fromIntegral count) buildResult
            [] -> error $ "Ran out of counts in buildResult. " ++
                          "Likely misuse of buildOp."

-- | Registers the op described by the 'OpDef' via 'addNewOp' and assembles
-- the resulting node's outputs into a value of type @a@.
--
-- The size list supplies, in order, the lengths of any list-valued
-- components of @a@. Output indices start at 0.
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
buildOp sizes o = collect <$> addNewOp o
  where
    collect node = evalState (runReaderT buildResult node) (ResultState 0 sizes)

-- | Checks that, for every named @number_attr@ entry, all of the associated
-- lengths agree.
--
-- Returns 'True' when they do; an entry with no pairs trivially passes.
-- Otherwise it calls 'error' with a message that names the attribute and
-- shows the offending pairs. The result is therefore never 'False'.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all check
  where
    -- Equivalent to (== 1) . length . nub over the lengths, but stops at
    -- the first mismatch.
    check (attrName, entries) = case map snd entries of
        [] -> True
        len : lens
            | all (== len) lens -> True
            | otherwise ->
                error ("number_attr " ++ attrName ++
                       " contains tensors with different length " ++
                       show entries)

-----------


-- | Types that can be used as the outputs of a pure op (see 'pureOp').
--
-- Unlike 'BuildResult', the environment is not a node name. It is the
-- 'Build' action that produces the op's 'OpDef'. Output indices and list
-- sizes are drawn from a 'ResultState' just as for 'BuildResult'.
class PureResult a where
    pureResult :: ReaderT (Build OpDef) (State ResultState) a

-- | Each tensor claims the next output index and wraps a 'Build' action.
-- That action runs the 'OpDef' action, passes the result to 'getOrAddOp',
-- and selects the claimed output of the returned node.
instance PureResult (Tensor Build a) where
    pureResult = do
        makeOp <- ask
        ResultState ix counts <- get
        put $! ResultState (ix + 1) counts
        -- TODO: unify with BuildResult (Tensor v)
        return $ Tensor (output ix <$> (makeOp >>= getOrAddOp))

-- Tuple instances: components are built in order, left to right, so
-- earlier components claim lower output indices.

instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        return (p1, p2)

instance (PureResult a1, PureResult a2, PureResult a3)
         => PureResult (a1, a2, a3) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        p3 <- pureResult
        return (p1, p2, p3)

instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
         => PureResult (a1, a2, a3, a4) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        p3 <- pureResult
        p4 <- pureResult
        return (p1, p2, p3, p4)

instance ( PureResult a1, PureResult a2, PureResult a3
         , PureResult a4, PureResult a5
         ) => PureResult (a1, a2, a3, a4, a5) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        p3 <- pureResult
        p4 <- pureResult
        p5 <- pureResult
        return (p1, p2, p3, p4, p5)

instance ( PureResult a1, PureResult a2, PureResult a3
         , PureResult a4, PureResult a5, PureResult a6
         ) => PureResult (a1, a2, a3, a4, a5, a6) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        p3 <- pureResult
        p4 <- pureResult
        p5 <- pureResult
        p6 <- pureResult
        return (p1, p2, p3, p4, p5, p6)

instance ( PureResult a1, PureResult a2, PureResult a3
         , PureResult a4, PureResult a5, PureResult a6
         , PureResult a7
         ) => PureResult (a1, a2, a3, a4, a5, a6, a7) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        p3 <- pureResult
        p4 <- pureResult
        p5 <- pureResult
        p6 <- pureResult
        p7 <- pureResult
        return (p1, p2, p3, p4, p5, p6, p7)

instance ( PureResult a1, PureResult a2, PureResult a3
         , PureResult a4, PureResult a5, PureResult a6
         , PureResult a7, PureResult a8
         ) => PureResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    pureResult = do
        p1 <- pureResult
        p2 <- pureResult
        p3 <- pureResult
        p4 <- pureResult
        p5 <- pureResult
        p6 <- pureResult
        p7 <- pureResult
        p8 <- pureResult
        return (p1, p2, p3, p4, p5, p6, p7, p8)

-- | A list of pure results pops its length from the front of the pending
-- size list and then builds that many elements in order. It fails with
-- 'error' if no size is left.
instance PureResult a => PureResult [a] where
    pureResult = do
        ResultState ix counts <- get
        case counts of
            count : remaining -> do
                put $! ResultState ix remaining
                replicateM (fromIntegral count) pureResult
            [] -> error $ "Ran out of counts in pureResult. " ++
                          "Likely misuse of pureOp with output lists."

-- | A heterogeneous list of pure tensors claims one output index per
-- element, in the order given by the type-level list @as@.
instance TensorTypes as => PureResult (TensorList Build as) where
    pureResult = go (tensorTypes :: TensorTypeList as)
      where
        -- Walk the type witnesses, building one tensor per element.
        go :: TensorTypeList bs
           -> ReaderT (Build OpDef) (State ResultState) (TensorList Build bs)
        go Nil = pure Nil
        go (TensorTypeProxy :/ rest) = (:/) <$> pureResult <*> go rest

-- | Assembles the outputs of a pure op into a value of type @a@.
--
-- The 'Build' 'OpDef' action is not run here; it is handed to each
-- 'PureResult' instance as the reader environment. With the instances in
-- this module, it ends up captured inside each resulting tensor. The size
-- list supplies, in order, the lengths of list-valued components, and
-- output indices start at 0.
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = evalState (runReaderT pureResult o) (ResultState 0 sizes)

-----
-- Class of types that can be used as op inputs (arguments).

-- | Types that can be passed as op inputs. Each value is turned into the
-- list of 'Output's it contributes, in order.
class BuildInputs a where
    buildInputs :: a -> Build [Output]

-- | A list contributes the concatenated inputs of its elements, first to
-- last.
instance BuildInputs a => BuildInputs [a] where
    buildInputs xs = concat <$> mapM buildInputs xs

-- | A single tensor contributes exactly one 'Output', obtained by running
-- its action through 'toBuild'.
instance BuildInputs (Tensor v a) where
    buildInputs (Tensor t) = (: []) <$> toBuild t

-- | A heterogeneous tensor list contributes its elements' inputs, first to
-- last.
instance BuildInputs (ListOf (Tensor v) as) where
    buildInputs ts = case ts of
        Nil       -> return []
        t :/ rest -> liftM2 (++) (buildInputs t) (buildInputs rest)

----

-- | Parameters to build an op (for example, the node name or optional attributes).
-- Represented as a function that transforms an 'OpDef'; several parameters
-- can be combined with ordinary function composition.
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef