{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
, useProtoAsVoidPtrLen
)
where
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when)
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as M
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- | Exception raised by this module when a TensorFlow call reports failure.
-- It carries the raw status code together with a human-readable message.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)
-- The default 'Exception' methods are used, so the type can be thrown and
-- caught with the usual exception machinery.
instance Exception TensorFlowException
-- | A tensor held entirely on the Haskell side: its dimensions, its element
-- type and its contents as a flat vector of bytes.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension. Unlike the other fields this one is lazy.
    , tensorDataType :: !DataType
      -- ^ Element type of the tensor.
    , tensorDataBytes :: !(S.Vector Word8)
      -- ^ Contents of the tensor as bytes.
    }
    deriving (Show, Eq)
-- | Runs an action against a freshly created session and tears the session
-- down afterwards, whether the action returns normally or throws.
--
-- Teardown closes the session, then takes the list of recorded asynchronous
-- runners, shuts each of them down and finally deletes the session. The
-- runners are reaped and the session deleted even if closing it fails.
withSession :: (MonadIO m, MonadMask m)
            => (Raw.SessionOptions -> IO ())
            -- ^ Applied to the new session options before the session is
            -- created from them.
            -> ((IO () -> IO ()) -> Raw.Session -> m a)
            -- ^ The body. Its first argument launches an 'IO' action
            -- asynchronously and records it so that it is shut down during
            -- teardown; its second argument is the session itself.
            -> m a
withSession optionSetter action = do
    pendingRunners <- liftIO (newMVar [])
    withResource Raw.newSessionOptions Raw.deleteSessionOptions $ \opts ->
        withResource
            (optionSetter opts >> checkStatus (Raw.newSession opts))
            (tearDown pendingRunners)
            (action (asyncCollector pendingRunners))
  where
    -- 'bracket' specialised to acquire/release steps that live in 'IO'.
    withResource acquire release = bracket (liftIO acquire) (liftIO . release)

    tearDown pendingRunners session = closeIt `finally` reapAndDelete
      where
        closeIt = checkStatus (Raw.closeSession session)
        -- Taking the MVar leaves it empty from here on.
        reapAndDelete = do
            runners <- takeMVar pendingRunners
            mapM_ shutDownRunner runners
            checkStatus (Raw.deleteSession session)
-- | Starts the given action asynchronously and prepends its handle to the
-- list stored in the 'MVar'. The launch and the update both happen inside
-- 'modifyMVarMasked_'.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner =
    modifyMVarMasked_ drain $ \alreadyLaunched -> do
        spawned <- async runner
        return (spawned : alreadyLaunched)
-- | Cancels a runner, waits for it to finish and prints the exception it
-- ended with, if any. A runner that completed normally produces no output.
--
-- NOTE(review): presumably the cancellation itself surfaces here as an
-- exception and gets printed too -- confirm against the async library.
shutDownRunner :: Async () -> IO ()
shutDownRunner runner = do
    cancel runner
    outcome <- waitCatch runner
    case outcome of
        Left failure -> print failure
        Right () -> return ()
-- | Serialises the 'GraphDef' and passes the encoded bytes to
-- 'Raw.extendGraph' for the given session. Throws 'TensorFlowException' if
-- the call reports a non-OK status.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session graphDef = useProtoAsVoidPtrLen graphDef sendToSession
  where
    sendToSession bytesPtr byteCount =
        checkStatus (Raw.extendGraph session bytesPtr byteCount)
-- | Runs the session once and returns the fetched tensors.
--
-- Every feed is converted to a raw tensor with 'createRawTensor'; every
-- output tensor is converted back (and released) with 'createTensorData'.
-- The results are returned in the same order as the fetch names.
--
-- The whole operation runs under 'mask_', so an asynchronous exception
-- cannot interrupt it between allocating a raw tensor and releasing it.
--
-- Throws 'TensorFlowException' if the run reports a non-OK status. The raw
-- feed tensors are deleted in that case as well.
run :: Raw.Session
    -> [(B.ByteString, TensorData)]  -- ^ Feeds: name paired with its data.
    -> [B.ByteString]                -- ^ Names of the tensors to fetch.
    -> [B.ByteString]                -- ^ Names of the targets.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    mask_ $
        -- Feeds.
        withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
        mapM (createRawTensor . snd) feeds >>= \feedTensors ->
        withArrayLen feedTensors $ \_ cFeedTensors ->
        -- Fetches: the output array starts out as null tensors and is
        -- filled in by 'Raw.run'.
        withStringArrayLen fetches $ \fetchesLen fetchNames ->
        withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
        -- Targets.
        withStringArrayLen targets $ \targetsLen ctargets -> do
            -- Release the feed tensors whether or not the run succeeded.
            -- Without 'finally', the exception thrown by 'checkStatus' on a
            -- failed run skipped the deletion and leaked every feed tensor.
            --
            -- NOTE(review): the two 'nullPtr' arguments are presumably the
            -- optional run options and run metadata -- confirm against the
            -- 'Raw.run' binding.
            checkStatus (Raw.run
                session
                nullPtr
                feedNames cFeedTensors (safeConvert feedsLen)
                fetchNames tensorOuts (safeConvert fetchesLen)
                ctargets (safeConvert targetsLen)
                nullPtr)
                `finally` mapM_ Raw.deleteTensor feedTensors
            outTensors <- peekArray fetchesLen tensorOuts
            mapM createTensorData outTensors
-- | Converts between integral types, refusing to lose information.
--
-- Returns the converted value when it fits in the target type; otherwise
-- calls 'error' with a message showing the input and the value a plain
-- 'fromIntegral' would have produced.
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x = fromMaybe outOfRange (toIntegralSized x)
  where
    outOfRange = error $ concat
        [ "Failed to convert ", show x
        , ", got ", show (fromIntegral x :: b)
        ]
-- | Exposes every 'B.ByteString' as a C string, in the original order, and
-- runs the callback on the resulting list. Each string is provided by a
-- nested 'B.useAsCString', so the callback runs inside all of those scopes.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = foldr step (fn . reverse) strings []
  where
    -- Collected pointers are accumulated in reverse and flipped at the end.
    step bytes continue collected =
        B.useAsCString bytes $ \cstr -> continue (cstr : collected)
-- | Like 'withStringList', but hands the callback a C array of the string
-- pointers together with the number of elements in it.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn =
    withStringList xs $ \cstrings -> withArrayLen cstrings fn
-- | Builds a raw tensor from a 'TensorData'.
--
-- The bytes are copied into a freshly malloc'ed buffer, which is passed to
-- 'Raw.newTensor' together with 'tensorDeallocFunPtr'.
--
-- NOTE(review): presumably TensorFlow invokes that deallocator once the
-- tensor is deleted, so the buffer is not freed here -- confirm.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData shape elemType payload) =
    withArrayLen (map safeConvert shape) $ \rank shapePtr -> do
        let byteCount = S.length payload
            rawType = toEnum (fromEnum elemType)
        buffer <- mallocArray byteCount
        S.unsafeWith payload $ \src -> copyArray buffer src byteCount
        Raw.newTensor
            rawType
            shapePtr (safeConvert rank)
            (castPtr buffer) (safeConvert byteCount)
            tensorDeallocFunPtr nullPtr
-- | Function pointer to a deallocation callback that calls 'free' on its
-- first argument and ignores the other two.
--
-- The wrapper is created through 'unsafePerformIO' as a top-level constant;
-- the NOINLINE pragma keeps the compiler from duplicating that call at use
-- sites.
{-# NOINLINE tensorDeallocFunPtr #-}
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO (Raw.wrapTensorDealloc releaseBuffer)
  where
    releaseBuffer buffer _ _ = free buffer
-- | Converts a raw tensor into a 'TensorData' and deletes the raw tensor.
--
-- The tensor's bytes are copied into a fresh vector before the tensor is
-- deleted, so the result does not refer to the tensor's memory.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData tensor = do
    rank <- Raw.numDims tensor
    shape <- traverse (Raw.dim tensor) [0 .. rank - 1]
    rawType <- Raw.tensorType tensor
    byteCount <- Raw.tensorByteSize tensor
    rawBytes <- Raw.tensorData tensor
    -- No finalizer is attached: the pointer only borrows the tensor's
    -- memory for the duration of the copy below.
    borrowed <- newForeignPtr_ (castPtr rawBytes :: Ptr Word8)
    payload <- S.freeze
        (M.unsafeFromForeignPtr0 borrowed (safeConvert byteCount))
    -- The copy is complete, so the tensor can be released.
    Raw.deleteTensor tensor
    return TensorData
        { tensorDataDimensions = map safeConvert shape
        , tensorDataType = toEnum (fromEnum rawType)
        , tensorDataBytes = payload
        }
-- | Runs an action that reports its outcome through a 'Raw.Status'.
--
-- A fresh status is allocated for the call and deleted afterwards, including
-- when an exception is thrown. If the status code is anything other than
-- 'Raw.TF_OK' after the action has run, a 'TensorFlowException' carrying the
-- code and the (leniently UTF-8 decoded) status message is thrown; otherwise
-- the action's result is returned.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn = bracket Raw.newStatus Raw.deleteStatus runChecked
  where
    runChecked status = do
        outcome <- fn status
        statusCode <- Raw.getCode status
        when (statusCode /= Raw.TF_OK) (raise status statusCode)
        return outcome

    raise status statusCode = do
        rawMessage <- B.packCString =<< Raw.message status
        throwM $ TensorFlowException
            statusCode
            (T.decodeUtf8With T.lenientDecode rawMessage)
-- | Serialises the 'ConfigProto' and applies it to the session options via
-- 'Raw.setConfig'. Throws 'TensorFlowException' if the call reports a non-OK
-- status.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig config options = useProtoAsVoidPtrLen config applyConfig
  where
    applyConfig bytesPtr byteCount =
        checkStatus (Raw.setConfig options bytesPtr byteCount)
-- | Passes the target, as a C string, to 'Raw.setTarget' for the given
-- session options.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target options =
    B.useAsCString target (Raw.setTarget options)
-- | Encodes a protocol buffer message and runs the callback on a pointer to
-- the encoded bytes together with their length.
--
-- The length is converted with 'safeConvert', so a length that does not fit
-- in the callback's integral type is an error rather than a silent
-- truncation.
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen serialized withBuffer
  where
    serialized = encodeMessage msg
    withBuffer (bytesPtr, byteCount) =
        f (castPtr bytesPtr) (safeConvert byteCount)
-- | Fetches the buffer produced by 'Raw.getAllOpList' and returns a copy of
-- its contents as a 'B.ByteString'.
--
-- Throws a 'TensorFlowException' with code 'Raw.TF_UNKNOWN' if the raw call
-- returns a null pointer. The buffer is wrapped in a foreign pointer with
-- 'Raw.deleteBuffer' as its finalizer; the call and the wrapping happen
-- under 'mask_'.
getAllOpList :: IO B.ByteString
getAllOpList = do
    buffer <- mask_ $ do
        raw <- Raw.getAllOpList
        when (raw == nullPtr) $
            throwM $ TensorFlowException
                Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
        newForeignPtr Raw.deleteBuffer raw
    withForeignPtr buffer $ \ptr -> do
        contents <- Raw.getBufferData ptr
        size <- Raw.getBufferLength ptr
        B.packCStringLen (castPtr contents, safeConvert size)