-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module TensorFlow.Internal.FFI
    ( TensorFlowException(..)
    , Raw.Session
    , withSession
    , extendGraph
    , run
    , TensorData(..)
    , setSessionConfig
    , setSessionTarget
    , getAllOpList
      -- * Internal helper.
    , useProtoAsVoidPtrLen
    )
    where

import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Foreign.Concurrent as ForeignC

import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)

import qualified TensorFlow.Internal.Raw as Raw

-- | Exception thrown when a TensorFlow C API call reports a non-OK status
-- (see 'checkStatus'), or when 'getAllOpList' gets a null buffer back.
--
-- Carries the status code and the status message text.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)

instance Exception TensorFlowException

-- | All of the data needed to represent a tensor.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension; the length of the list is the rank.
    , tensorDataType       :: !DataType
      -- ^ Element type of the tensor.
    , tensorDataBytes      :: !(S.Vector Word8)
      -- ^ Raw element bytes, exchanged with the C tensor's data buffer
      -- without any conversion (byte layout is whatever TensorFlow
      -- expects — TODO confirm endianness/packing assumptions).
    }
  deriving (Show, Eq)

-- | Creates a session whose options are first filled in by @optionSetter@,
-- then runs @action@ on it.
--
-- @action@ is given a spawner for background tasks. When the action ends
-- (normally or via an exception) the session is closed, every spawned
-- task is canceled and awaited, and only after that is the session
-- deleted. The session options are deleted last.
withSession :: (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> IO a)
            -- ^ The action can spawn concurrent tasks which will
            -- be canceled before withSession returns.
            -> IO a
withSession optionSetter action = do
    runnersVar <- newMVar []
    bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
        optionSetter options
        bracket (checkStatus (Raw.newSession options))
                (tearDown runnersVar)
                (action (asyncCollector runnersVar))
  where
    -- Closing first nudges pending run calls to fail and exit. The runners
    -- are then collected, even if closing failed, so that none of them
    -- outlives the deleted session.
    tearDown runnersVar session =
        checkStatus (Raw.closeSession session) `finally` do
            pending <- takeMVar runnersVar
            mapM_ shutDownRunner pending
            checkStatus (Raw.deleteSession session)

-- | Starts @runner@ asynchronously and prepends its handle to the list in
-- @drain@, so session teardown can cancel it later. The launch runs inside
-- 'modifyMVarMasked_' (masked), so a freshly started task cannot be lost
-- before its handle is stored.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner =
    modifyMVarMasked_ drain $ \running -> do
        spawned <- async runner
        return (spawned : running)

-- | Cancels a runner and waits for it to terminate. Any exception it ended
-- with, including the one delivered by 'cancel', is printed, not rethrown.
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
    cancel r
    -- TODO(gnezdo): manage exceptions better than print.
    outcome <- waitCatch r
    case outcome of
        Left err -> print err
        Right () -> return ()

-- | Serializes the 'GraphDef' and passes its bytes to 'Raw.extendGraph' for
-- this session. Throws 'TensorFlowException' on a non-OK status.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb = useProtoAsVoidPtrLen pb extend
  where
    extend ptr len = checkStatus (Raw.extendGraph session ptr len)


-- | Runs one step of the session. It feeds the named input tensors,
-- computes the fetches and runs the targets. The result has one
-- 'TensorData' per fetch, read back from the output array in fetch order.
--
-- Throws 'TensorFlowException' if the run reports a non-OK status. The
-- feed tensors created for this call are always deleted, whether the run
-- succeeds or fails, and also if creating a later feed tensor fails.
run :: Raw.Session
    -> [(B.ByteString, TensorData)] -- ^ Feeds.
    -> [B.ByteString]               -- ^ Fetches.
    -> [B.ByteString]               -- ^ Targets.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Use mask to avoid leaking input tensors before they are passed to 'run'
    -- and output tensors before they are passed to 'createTensorData'.
    mask_ $ do
        outTensors <-
            -- Feeds. The raw feed tensors are deleted when this scope
            -- exits, which is after Raw.run returns or throws.
            withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
            withRawTensors (snd <$> feeds) $ \feedTensors ->
            withArrayLen feedTensors $ \_ cFeedTensors ->
            -- Fetches.
            withStringArrayLen fetches $ \fetchesLen fetchNames ->
            -- tensorOuts is an array of null Tensor pointers that will be
            -- filled by the call to Raw.run.
            withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
            -- Targets.
            withStringArrayLen targets $ \targetsLen ctargets -> do
                checkStatus $ Raw.run
                    session
                    nullPtr
                    feedNames cFeedTensors (safeConvert feedsLen)
                    fetchNames tensorOuts (safeConvert fetchesLen)
                    ctargets (safeConvert targetsLen)
                    nullPtr
                peekArray fetchesLen tensorOuts
        -- Still masked here, so no output tensor can leak before it is
        -- wrapped with its finalizer.
        mapM createTensorData outTensors
  where
    -- Creates one Raw.Tensor per TensorData and passes them, in order, to
    -- the continuation. Each created tensor is deleted with 'finally', so
    -- nothing leaks if a later creation or the continuation throws.
    withRawTensors :: [TensorData] -> ([Raw.Tensor] -> IO a) -> IO a
    withRawTensors [] k = k []
    withRawTensors (d:ds) k = do
        t <- createRawTensor d
        withRawTensors ds (k . (t :)) `finally` Raw.deleteTensor t


-- Internal.


-- | Works like 'fromIntegral', but calls 'error' when the value does not
-- fit in the target type, instead of silently wrapping or truncating.
-- The error message shows the original value and the value that
-- 'fromIntegral' would have produced.
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x = case toIntegralSized x of
    Just converted -> converted
    Nothing -> error $ "Failed to convert " ++ show x ++ ", got "
                   ++ show (fromIntegral x :: b)


-- | Pins each ByteString as a NUL-terminated CString and passes all of them,
-- in input order, to @fn@. The CStrings are only valid while @fn@ runs.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = foldr pin finish strings id
  where
    -- Nests one 'B.useAsCString' scope per string. The difference list
    -- collects the pointers in input order.
    -- TODO(fmayle): Is it worth using unsafeAsCString here?
    pin s rest acc = B.useAsCString s $ \c -> rest (acc . (c :))
    finish acc = fn (acc [])


-- | Pins every ByteString as a CString and packs the pointers into a
-- temporary C array. Passes the element count and the array to @fn@.
-- The array and the strings are only valid while @fn@ runs.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn =
    withStringList xs $ \cstrs -> withArrayLen cstrs fn


-- | Builds a 'Raw.Tensor' from a 'TensorData'.
--
-- The bytes are copied into a freshly malloc'ed buffer, which is handed to
-- the tensor together with 'tensorDeallocFunPtr' (a deallocator that calls
-- 'free'). Callers in this module delete the returned tensor with
-- 'Raw.deleteTensor'.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    withArrayLen (map safeConvert dims) $ \rank shapePtr -> do
        let byteCount = S.length byteVec
        buffer <- mallocArray byteCount
        S.unsafeWith byteVec $ \src -> copyArray buffer src byteCount
        Raw.newTensor (toEnum (fromEnum dt))
                      shapePtr (safeConvert rank)
                      (castPtr buffer) (safeConvert byteCount)
                      tensorDeallocFunPtr nullPtr

{-# NOINLINE tensorDeallocFunPtr #-}
-- | C-callable deallocator passed to 'Raw.newTensor' by 'createRawTensor'.
-- It 'free's its first (data pointer) argument and ignores the other two.
--
-- 'unsafePerformIO' builds this wrapper once as a top-level constant. The
-- NOINLINE pragma stops GHC from duplicating that effect, which would
-- create extra FunPtr wrappers that this module never releases.
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO (Raw.wrapTensorDealloc freeData)
  where
    freeData dataPtr _ _ = free dataPtr

-- | Builds a 'TensorData' from a 'Raw.Tensor' without copying the data.
--
-- Takes ownership of the Raw.Tensor. The returned byte vector points
-- directly at the tensor's buffer. A finalizer deletes the tensor once
-- that buffer's ForeignPtr becomes unreachable.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do
    -- Shape.
    rank <- Raw.numDims t
    shape <- traverse (Raw.dim t) [0 .. rank - 1]
    -- Element type.
    dtype <- toEnum . fromEnum <$> Raw.tensorType t
    -- Payload: wrap the tensor's own buffer, tying the tensor's lifetime
    -- to it.
    byteCount <- safeConvert <$> Raw.tensorByteSize t
    rawBytes <- Raw.tensorData t
    owner <- ForeignC.newForeignPtr (castPtr rawBytes :: Ptr Word8)
                                    (Raw.deleteTensor t)
    pure (TensorData (safeConvert <$> shape) dtype
                     (S.unsafeFromForeignPtr0 owner byteCount))

-- | Allocates a fresh status object, runs @fn@ with it, and inspects the
-- resulting code. On 'Raw.TF_OK' the action's result is returned.
-- Otherwise the status message is decoded as UTF-8 (invalid bytes are
-- replaced) and thrown as a 'TensorFlowException'. The status object is
-- always released.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn =
    bracket Raw.newStatus Raw.deleteStatus $ \status -> do
        result <- fn status
        code <- Raw.getCode status
        if code == Raw.TF_OK
            then return result
            else do
                rawMsg <- B.packCString =<< Raw.message status
                throwIO (TensorFlowException code
                            (T.decodeUtf8With T.lenientDecode rawMsg))

-- | Serializes the 'ConfigProto' and installs it on the session options via
-- 'Raw.setConfig'. Throws 'TensorFlowException' on a non-OK status.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt = useProtoAsVoidPtrLen pb applyConfig
  where
    applyConfig ptr len = checkStatus (Raw.setConfig opt ptr len)

-- | Sets the session target on the options. The target is passed to
-- 'Raw.setTarget' as a NUL-terminated C string, which stays valid only
-- for the duration of that call.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target opt = B.useAsCString target (Raw.setTarget opt)

-- | Serializes @msg@ and calls @f@ with a pointer to the encoded bytes and
-- their length, converted with 'safeConvert'. The buffer is only valid
-- while @f@ runs.
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen serialized callWithBuffer
  where
    serialized = encodeMessage msg
    callWithBuffer (bytes, len) = f (castPtr bytes) (safeConvert len)

-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
--
-- Throws 'TensorFlowException' with code 'Raw.TF_UNKNOWN' if the C call
-- returns a null buffer.
getAllOpList :: IO B.ByteString
getAllOpList = do
    -- Masked so the buffer cannot leak between its return from the C call
    -- and its attachment to the deleteBuffer finalizer.
    buffer <- mask_ $ do
        p <- Raw.getAllOpList
        when (p == nullPtr) $
            throwIO (TensorFlowException Raw.TF_UNKNOWN
                                         "GetAllOpList failure, check logs")
        newForeignPtr Raw.deleteBuffer p
    -- The bytes are copied out. Stealing the Buffer's memory is less
    -- reliable because it would mean taking over its custom deallocator.
    withForeignPtr buffer $ \ptr -> do
        dataPtr <- Raw.getBufferData ptr
        len <- Raw.getBufferLength ptr
        B.packCStringLen (castPtr dataPtr, safeConvert len)