{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
, useProtoAsVoidPtrLen
)
where
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when)
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as M
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- | Exception raised by this module when TensorFlow reports a failure.
--
-- 'checkStatus' throws it with the non-OK status code and the status
-- message; 'getAllOpList' throws it with 'Raw.TF_UNKNOWN' when the C call
-- returns a null buffer.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Eq, Show, Typeable)

instance Exception TensorFlowException
-- | All of the data needed to represent a tensor: its shape, its element
-- type, and its contents as a flat byte vector.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension.
    , tensorDataType       :: !DataType
      -- ^ Element type.
    , tensorDataBytes      :: !(S.Vector Word8)
      -- ^ Raw bytes of the tensor's contents.
    }
  deriving (Eq, Show)
-- | Creates session options, lets @optionSetter@ configure them, opens a
-- session with them, and runs @action@ on that session. The options and
-- the session are both released afterwards, even if @action@ throws.
--
-- The action receives a spawner: each @IO ()@ given to it is started
-- asynchronously and recorded. On teardown the session is closed, every
-- recorded task is canceled and awaited, and the session is deleted.
withSession :: (MonadIO m, MonadMask m)
            => (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> m a)
            -- ^ The action can spawn concurrent tasks which will
            -- be canceled before withSession returns.
            -> m a
withSession optionSetter action = do
    runnersVar <- liftIO (newMVar [])
    withIOResource Raw.newSessionOptions Raw.deleteSessionOptions $ \opts ->
        withIOResource (openSession opts) (tearDown runnersVar) $
            action (asyncCollector runnersVar)
  where
    -- 'bracket' whose acquire and release steps are plain IO actions.
    withIOResource acquire release =
        bracket (liftIO acquire) (liftIO . release)

    -- Apply the caller's settings, then create the session from them.
    openSession opts =
        optionSetter opts >> checkStatus (Raw.newSession opts)

    -- Close the session first. Whether or not closing succeeds, cancel and
    -- await every recorded runner, then delete the session.
    tearDown runnersVar s =
        finally (checkStatus (Raw.closeSession s)) $ do
            runners <- takeMVar runnersVar
            mapM_ shutDownRunner runners
            checkStatus (Raw.deleteSession s)
-- | Starts @runner@ asynchronously and prepends its 'Async' handle to the
-- list held in @drain@.
--
-- The update goes through 'modifyMVarMasked_', so the launch and the
-- recording of the handle run with asynchronous exceptions masked.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner =
    modifyMVarMasked_ drain $ \running -> do
        handle <- async runner
        return (handle : running)
-- | Cancels a runner and waits for it to finish.
--
-- If the runner ended with an exception, the exception is printed to
-- stdout instead of being rethrown. This may include the exception that
-- 'cancel' delivered to a runner that was still in progress.
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
    cancel r
    outcome <- waitCatch r
    case outcome of
        Left err -> print err
        Right _  -> return ()
-- | Serializes @pb@ and hands the encoded bytes to 'Raw.extendGraph' for
-- @session@. Throws 'TensorFlowException' if the returned status is not OK.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb = useProtoAsVoidPtrLen pb submit
  where
    submit bytes size = checkStatus (Raw.extendGraph session bytes size)
-- | Runs one step of @session@.
--
-- Each feed value is converted to a 'Raw.Tensor' and passed to 'Raw.run'
-- together with the feed, fetch and target names. The fetched outputs come
-- back as one 'TensorData' per entry of @fetches@, in output-slot order.
--
-- The feed tensors are created here and are always deleted before this
-- returns, including when 'Raw.run' reports an error. 'Raw.run' does not
-- take ownership of them; the success path already had to delete them.
--
-- Throws 'TensorFlowException' if TensorFlow reports a non-OK status.
run :: Raw.Session
    -> [(B.ByteString, TensorData)] -- ^ Feeds.
    -> [B.ByteString]               -- ^ Fetches.
    -> [B.ByteString]               -- ^ Targets.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Mask so that async exceptions cannot leak input tensors before they
    -- reach 'Raw.run', or output tensors before 'createTensorData' takes
    -- ownership of them.
    mask_ $
        -- Feeds.
        withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
        mapM (createRawTensor . snd) feeds >>= \feedTensors ->
        -- Release the feed tensors however the rest of the step ends.
        -- Before this fix, a failing 'Raw.run' (e.g. an unknown feed name)
        -- threw past the deletion and leaked every feed tensor.
        flip finally (mapM_ Raw.deleteTensor feedTensors) $
        withArrayLen feedTensors $ \_ cFeedTensors ->
        -- Fetches.
        withStringArrayLen fetches $ \fetchesLen fetchNames ->
        -- tensorOuts starts as an array of null tensors that 'Raw.run'
        -- fills in.
        withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
        -- Targets.
        withStringArrayLen targets $ \targetsLen ctargets -> do
            checkStatus $ Raw.run
                session
                nullPtr
                feedNames cFeedTensors (safeConvert feedsLen)
                fetchNames tensorOuts (safeConvert fetchesLen)
                ctargets (safeConvert targetsLen)
                nullPtr
            outTensors <- peekArray fetchesLen tensorOuts
            mapM createTensorData outTensors
-- | Converts between integral types like 'fromIntegral', but calls 'error'
-- when the value does not fit in the target type instead of silently
-- wrapping. The error message shows the input and the wrapped result that
-- 'fromIntegral' would have produced.
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x = case toIntegralSized x of
    Just converted -> converted
    Nothing ->
        error ("Failed to convert " ++ show x ++ ", got " ++
               show (fromIntegral x :: b))
-- | Makes each 'B.ByteString' available as a NUL-terminated 'CString' and
-- passes all of them to @fn@, in the same order as @strings@.
--
-- The C strings come from nested 'B.useAsCString' calls, so they are only
-- valid while @fn@ runs.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = foldr step (fn . reverse) strings []
  where
    -- Convert one string, push it onto the accumulator (newest first), and
    -- continue inside the 'B.useAsCString' scope.
    step s k acc = B.useAsCString s $ \c -> k (c : acc)
-- | Like 'withStringList', but also copies the C strings into a temporary
-- C array and passes that array's length and pointer to @fn@.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn =
    withStringList xs $ \cstrs -> withArrayLen cstrs fn
-- | Builds a 'Raw.Tensor' from a 'TensorData'.
--
-- The dimensions are marshalled into a temporary C array. The bytes are
-- copied into a freshly 'mallocArray'-ed buffer, which is passed to
-- 'Raw.newTensor' along with 'tensorDeallocFunPtr' (which releases it
-- with 'free').
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    withArrayLen (safeConvert <$> dims) $ \rank shapePtr -> do
        buffer <- copyToMalloc
        Raw.newTensor (toEnum (fromEnum dt))
                      shapePtr (safeConvert rank)
                      (castPtr buffer) (safeConvert byteCount)
                      tensorDeallocFunPtr nullPtr
  where
    byteCount = S.length byteVec

    -- Copy the vector's bytes into a new malloc'ed buffer.
    copyToMalloc = do
        buffer <- mallocArray byteCount
        S.unsafeWith byteVec $ \src -> copyArray buffer src byteCount
        return buffer
{-# NOINLINE tensorDeallocFunPtr #-}
-- | Deallocator passed to 'Raw.newTensor' by 'createRawTensor'. It calls
-- 'free' on the data pointer and ignores its other two arguments.
--
-- The wrapper is built once via 'unsafePerformIO'. The NOINLINE pragma
-- keeps GHC from inlining that call and creating more than one wrapper.
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO (Raw.wrapTensorDealloc freeData)
  where
    freeData ptr _ _ = free ptr
-- | Creates a 'TensorData' from a 'Raw.Tensor'.
--
-- Takes ownership of the 'Raw.Tensor'. It is deleted with 'Raw.deleteTensor'
-- after its contents have been copied, and also if reading them throws
-- (for example, if 'safeConvert' rejects the byte size). Previously the
-- tensor leaked on that path.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = readContents `finally` Raw.deleteTensor t
  where
    readContents = do
        -- Read dimensions.
        numDims <- Raw.numDims t
        dims <- mapM (Raw.dim t) [0..numDims-1]
        -- Read type.
        dtype <- toEnum . fromEnum <$> Raw.tensorType t
        -- Read data.
        len <- safeConvert <$> Raw.tensorByteSize t
        bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
        fp <- newForeignPtr_ bytes
        -- Make an explicit copy of the raw data, because the buffer it
        -- points to may be freed by 'Raw.deleteTensor', which runs right
        -- after this action.
        v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
        return $ TensorData (map safeConvert dims) dtype v
-- | Runs @fn@ with a freshly allocated 'Raw.Status', which 'bracket'
-- deletes afterwards in all cases.
--
-- If the status code left by @fn@ is not 'Raw.TF_OK', throws a
-- 'TensorFlowException' carrying that code and the status message
-- (decoded as UTF-8 with 'T.lenientDecode'). Otherwise returns @fn@'s
-- result.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn = bracket Raw.newStatus Raw.deleteStatus runAndCheck
  where
    runAndCheck status = do
        result <- fn status
        code <- Raw.getCode status
        if code == Raw.TF_OK
            then return result
            else do
                raw <- B.packCString =<< Raw.message status
                throwM (TensorFlowException code
                            (T.decodeUtf8With T.lenientDecode raw))
-- | Serializes @pb@ and installs it on @opt@ through 'Raw.setConfig'.
-- Throws 'TensorFlowException' if the returned status is not OK.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt = useProtoAsVoidPtrLen pb applyConfig
  where
    applyConfig bytes size = checkStatus (Raw.setConfig opt bytes size)
-- | Sets the target of @opts@ to @target@ via 'Raw.setTarget'.
--
-- The 'CString' comes from 'B.useAsCString', so it is only valid for the
-- duration of the 'Raw.setTarget' call.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target opts = B.useAsCString target (Raw.setTarget opts)
-- | Serializes @msg@ with 'encodeMessage' and calls @f@ with a pointer to
-- the encoded bytes and their length. The length is converted with
-- 'safeConvert' to whatever integral type @f@ expects.
--
-- The pointer comes from 'B.useAsCStringLen' and is only valid while @f@
-- runs.
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen serialized callWith
  where
    serialized = encodeMessage msg
    callWith (bytes, size) = f (castPtr bytes) (safeConvert size)
-- | Calls 'Raw.getAllOpList' and returns a copy of the resulting buffer's
-- bytes as a 'B.ByteString'.
--
-- Throws 'TensorFlowException' with code 'Raw.TF_UNKNOWN' if the call
-- returns a null buffer. The buffer is wrapped in a 'ForeignPtr' whose
-- finalizer is 'Raw.deleteBuffer'. The call and the wrapping run under
-- 'mask_' so that an async exception cannot arrive between them and leak
-- the buffer.
getAllOpList :: IO B.ByteString
getAllOpList = do
    buffer <- mask_ (fetchBuffer >>= newForeignPtr Raw.deleteBuffer)
    withForeignPtr buffer $ \ptr -> do
        dataPtr <- Raw.getBufferData ptr
        size <- Raw.getBufferLength ptr
        B.packCStringLen (castPtr dataPtr, safeConvert size)
  where
    fetchBuffer = do
        p <- Raw.getAllOpList
        if p == nullPtr
            then throwM failure
            else return p
    failure = TensorFlowException
                Raw.TF_UNKNOWN "GetAllOpList failure, check logs"