module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
, useProtoAsVoidPtrLen
)
where
import Control.Concurrent.Async (Async, async, asyncWithUnmask, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Int (Int64)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- | Exception raised when a TensorFlow C API call reports failure.
--
-- Carries the failing status code together with a human-readable message.
-- In this module it is thrown by 'checkStatus' whenever a status code is not
-- 'Raw.TF_OK', and by 'getAllOpList' (with 'Raw.TF_UNKNOWN') when the
-- underlying C call returns a null buffer.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)

instance Exception TensorFlowException
-- | A tensor's contents held on the Haskell side.
--
-- 'createRawTensor' copies 'tensorDataBytes' verbatim into a freshly
-- allocated native buffer, and 'createTensorData' builds a value by copying a
-- native tensor's buffer back out. The bytes are therefore in whatever layout
-- the C library uses. NOTE(review): presumably a dense encoding of
-- 'tensorDataType' elements; confirm against the TF C API.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension; the list length is the tensor's rank.
    , tensorDataType :: !DataType
      -- ^ Element type, mapped to/from the raw enum via 'fromEnum'/'toEnum'.
    , tensorDataBytes :: !(S.Vector Word8)
      -- ^ Raw buffer bytes, copied as-is.
    }
    deriving (Show, Eq)
-- | Creates a session whose options are configured by @optionSetter@, runs
-- @action@ with it, and tears everything down afterwards.
--
-- @action@ receives a launcher that starts an @IO ()@ on a background thread
-- (see 'asyncCollector'). Every thread started that way is cancelled and
-- waited for before the session is deleted.
--
-- Teardown order:
--
-- 1. Close the session.
-- 2. Even if closing fails, reap the background threads and delete the
--    session.
-- 3. Delete the session options (outermost bracket).
withSession :: (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> IO a)
            -> IO a
withSession optionSetter action = do
    registry <- newMVar []
    bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \opts -> do
        optionSetter opts
        bracket (checkStatus (Raw.newSession opts))
                (teardown registry)
                (action (asyncCollector registry))
  where
    -- Closing first presumably nudges in-flight runs to fail so that the
    -- background threads can exit before the session is deleted.
    teardown registry sess =
        finally (checkStatus (Raw.closeSession sess)) $ do
            pending <- takeMVar registry
            mapM_ shutDownRunner pending
            checkStatus (Raw.deleteSession sess)
-- | Launches @runner@ on a new thread and records its 'Async' handle in
-- @drain@, so that 'withSession' can cancel and reap it during cleanup.
--
-- The launch happens inside 'modifyMVarMasked_', with asynchronous
-- exceptions masked, so a handle can never be lost between forking the
-- thread and recording it.
--
-- A forked thread inherits its parent's masking state. The runner is
-- therefore explicitly unmasked via 'asyncWithUnmask'. Otherwise it would
-- execute masked, and 'cancel' during shutdown could only interrupt it at
-- interruptible points.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
  where
    launchAndRecord restRunners =
        (: restRunners) <$> asyncWithUnmask (\unmask -> unmask runner)
-- | Cancels a background runner and waits for it to finish.
--
-- Whatever exception the runner ended with, including the cancellation
-- itself, is printed to stdout. A normal completion is silent.
shutDownRunner :: Async () -> IO ()
shutDownRunner runner = do
    cancel runner
    outcome <- waitCatch runner
    case outcome of
        Left err -> print err
        Right () -> return ()
-- | Serializes @pb@ and adds its nodes to the session's graph through
-- 'Raw.extendGraph'. Throws 'TensorFlowException' if the call reports a
-- non-OK status.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb = useProtoAsVoidPtrLen pb extend
  where
    extend bufPtr bufLen = checkStatus (Raw.extendGraph session bufPtr bufLen)
-- | Runs one step of the session's graph.
--
-- * Each feed is a (node name, value) pair; the value is converted to a
--   native tensor with 'createRawTensor'.
-- * @fetches@ names the outputs to return.
-- * @targets@ are extra node names passed to 'Raw.run' (presumably nodes to
--   execute without fetching their values).
--
-- Returns one 'TensorData' per fetch, in the order of @fetches@. Throws
-- 'TensorFlowException' if 'Raw.run' reports a non-OK status.
run :: Raw.Session
    -> [(B.ByteString, TensorData)]
    -> [B.ByteString]
    -> [B.ByteString]
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Async exceptions are masked for the whole call, presumably so native
    -- tensors cannot leak between their creation and hand-off:
    --   * inputs between 'createRawTensor' and 'Raw.run';
    --   * outputs between 'Raw.run' and 'createTensorData', which deletes
    --     them.
    mask_ $
        -- Feed names and their freshly created native tensors.
        withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
        mapM (createRawTensor . snd) feeds >>= \feedTensors ->
        withArrayLen feedTensors $ \_ cFeedTensors ->
        -- Fetch names, plus an output array pre-filled with null tensors
        -- for Raw.run to fill in.
        withStringArrayLen fetches $ \fetchesLen fetchNames ->
        withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
        -- Target node names.
        withStringArrayLen targets $ \targetsLen ctargets -> do
            -- NOTE(review): the two nullPtr arguments are presumably the
            -- optional run-options and run-metadata buffers; confirm against
            -- Raw.run.
            -- NOTE(review): feedTensors are never deleted here. This
            -- presumably relies on Raw.run taking ownership of its inputs;
            -- confirm for the TF C API version in use.
            checkStatus $ Raw.run
                session
                nullPtr
                feedNames cFeedTensors (fromIntegral feedsLen)
                fetchNames tensorOuts (fromIntegral fetchesLen)
                ctargets (fromIntegral targetsLen)
                nullPtr
            outTensors <- peekArray fetchesLen tensorOuts
            mapM createTensorData outTensors
-- | Exposes every 'B.ByteString' as a NUL-terminated 'CString' and passes
-- them, in their original order, to @fn@.
--
-- The 'B.useAsCString' scopes are nested, one per string, so every pointer
-- stays valid for the whole duration of @fn@ and no longer.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = foldr step (fn . reverse) strings []
  where
    -- Opens one string's C view, then continues with it prepended to the
    -- pointers collected so far (reversed back at the end).
    step str continue acc = B.useAsCString str $ \cstr -> continue (cstr : acc)
-- | Like 'withStringList', but packs the 'CString's into a temporary C array
-- and passes its element count and pointer to @fn@.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn = withStringList xs $ \cstrs -> withArrayLen cstrs fn
-- | Builds a native tensor from a 'TensorData'.
--
-- The bytes are copied into a freshly 'mallocArray'-ed buffer, which is
-- handed to 'Raw.newTensor' together with 'tensorDeallocFunPtr' so that it
-- can later be released with 'free'. The dimension array is only borrowed
-- for the duration of the call.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData shape elemType payload) =
    withArrayLen (fromIntegral <$> shape) $ \rank shapePtr -> do
        let byteCount = S.length payload
        buffer <- mallocArray byteCount
        S.unsafeWith payload $ \src -> copyArray buffer src byteCount
        Raw.newTensor (toEnum (fromEnum elemType))
                      shapePtr (fromIntegral rank)
                      (castPtr buffer) (fromIntegral byteCount)
                      tensorDeallocFunPtr nullPtr
-- | Deallocator passed to 'Raw.newTensor' by 'createRawTensor'. It 'free's
-- the data buffer and ignores the other two callback arguments.
--
-- 'unsafePerformIO' creates the 'FunPtr' wrapper once, as a top-level
-- constant. The NOINLINE pragma is what guarantees "once": without it GHC may
-- inline the definition at use sites, creating a fresh wrapper each time,
-- and nothing in this module ever frees those wrappers.
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
{-# NOINLINE tensorDeallocFunPtr #-}
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
-- | Copies a native tensor's shape, element type and bytes into a Haskell
-- 'TensorData', then deletes the native tensor.
--
-- The tensor is deleted even if reading it throws, so the caller never has
-- to clean it up. All data is copied out (via 'peekArray') before the
-- deletion runs, so the result never references freed memory.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = readContents `finally` Raw.deleteTensor t
  where
    readContents = do
        numDims <- Raw.numDims t
        -- Dimension indices run from 0 to rank-1; a scalar (rank 0) yields [].
        dims <- mapM (Raw.dim t) [0 .. numDims - 1]
        dtype <- toEnum . fromEnum <$> Raw.tensorType t
        len <- fromIntegral <$> Raw.tensorByteSize t
        bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
        v <- S.fromList <$> peekArray len bytes
        return $ TensorData (map fromIntegral dims) dtype v
-- | Runs an FFI action that reports errors through a 'Raw.Status'.
--
-- A fresh status is allocated, passed to @fn@, and inspected afterwards:
--
-- * If its code is 'Raw.TF_OK', @fn@'s result is returned.
-- * Otherwise the status message is decoded as UTF-8 (invalid bytes replaced
--   leniently) and thrown as a 'TensorFlowException'.
--
-- The status is deleted in every case.
--
-- NOTE(review): on failure @fn@'s result is discarded without cleanup, which
-- could leak if it owns a native resource.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn = bracket Raw.newStatus Raw.deleteStatus inspect
  where
    inspect status = do
        result <- fn status
        code <- Raw.getCode status
        if code == Raw.TF_OK
            then return result
            else do
                rawMsg <- B.packCString =<< Raw.message status
                throwIO (TensorFlowException code (T.decodeUtf8With T.lenientDecode rawMsg))
-- | Serializes @pb@ and installs it as the session options' configuration.
-- Throws 'TensorFlowException' if 'Raw.setConfig' reports a non-OK status.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt = useProtoAsVoidPtrLen pb apply
  where
    apply bytesPtr bytesLen = checkStatus (Raw.setConfig opt bytesPtr bytesLen)
-- | Passes @target@ to 'Raw.setTarget' for the given session options, as a
-- NUL-terminated C string valid only for the duration of the call.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target opts = B.useAsCString target (Raw.setTarget opts)
-- | Encodes a protobuf message and lends its bytes to @f@ as an untyped
-- pointer plus a length converted to whatever numeric type @f@ expects.
--
-- The pointer is valid only while @f@ runs.
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
    msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen serialized withBytes
  where
    serialized = encodeMessage msg
    withBytes (bytesPtr, bytesLen) = f (castPtr bytesPtr) (fromIntegral bytesLen)
-- | Returns a copy of the buffer produced by 'Raw.getAllOpList'.
-- NOTE(review): presumably a serialized OpList protobuf; confirm.
--
-- The native buffer is wrapped in a 'ForeignPtr' finalized by
-- 'Raw.deleteBuffer'. The wrapping happens under 'mask_' so an async
-- exception cannot leak the buffer. The contents are copied with
-- 'B.packCStringLen', so the result does not depend on the buffer staying
-- alive.
--
-- Throws 'TensorFlowException' with 'Raw.TF_UNKNOWN' if the C call returns
-- null.
getAllOpList :: IO B.ByteString
getAllOpList = do
    buffer <- mask_ (checkCall >>= newForeignPtr Raw.deleteBuffer)
    withForeignPtr buffer $ \bufPtr -> do
        dataPtr <- castPtr <$> Raw.getBufferData bufPtr
        dataLen <- fromIntegral <$> Raw.getBufferLength bufPtr
        B.packCStringLen (dataPtr, dataLen)
  where
    checkCall = do
        p <- Raw.getAllOpList
        if p == nullPtr
            then throwIO exception
            else return p
    exception = TensorFlowException
        Raw.TF_UNKNOWN "GetAllOpList failure, check logs"