mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Migrate from TF_DeprecatedSession to TF_Session
Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and pass an input map for all existing nodes in the graph.
This commit is contained in:
parent
30a12d7776
commit
fb629d1207
4 changed files with 244 additions and 102 deletions
|
@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
|
||||||
( TensorFlowException(..)
|
( TensorFlowException(..)
|
||||||
, Raw.Session
|
, Raw.Session
|
||||||
, withSession
|
, withSession
|
||||||
, extendGraph
|
|
||||||
, run
|
, run
|
||||||
|
|
||||||
|
, SessionAction
|
||||||
|
|
||||||
|
, Raw.SessionOptions
|
||||||
|
|
||||||
|
, Raw.Graph
|
||||||
|
, extendGraph
|
||||||
|
|
||||||
, TensorData(..)
|
, TensorData(..)
|
||||||
, setSessionConfig
|
, setSessionConfig
|
||||||
, setSessionTarget
|
, setSessionTarget
|
||||||
|
@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
|
||||||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||||
import Data.Bits (Bits, toIntegralSized)
|
import Data.Bits (Bits, toIntegralSized)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
|
import Data.Foldable (for_)
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.Typeable (Typeable)
|
import Data.Typeable (Typeable)
|
||||||
import Data.Word (Word8)
|
import Data.Word (Word8)
|
||||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
|
||||||
import Foreign.C.String (CString)
|
import Foreign.ForeignPtr (newForeignPtr_)
|
||||||
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
|
|
||||||
import Foreign.Marshal.Alloc (free)
|
import Foreign.Marshal.Alloc (free)
|
||||||
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
|
import qualified Data.ByteString.Char8 as C
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import qualified Data.Text.Encoding as T
|
import qualified Data.Text.Encoding as T
|
||||||
import qualified Data.Text.Encoding.Error as T
|
import qualified Data.Text.Encoding.Error as T
|
||||||
|
@ -87,15 +95,26 @@ data TensorData = TensorData
|
||||||
}
|
}
|
||||||
deriving (Show, Eq)
|
deriving (Show, Eq)
|
||||||
|
|
||||||
|
-- | The action can spawn concurrent tasks which will be canceled before
|
||||||
|
-- withSession returns.
|
||||||
|
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a
|
||||||
|
|
||||||
-- | Runs the given action after creating a session with options
|
-- | Runs the given action after creating a session with options
|
||||||
-- populated by the given optionSetter.
|
-- populated by the given optionSetter.
|
||||||
withSession :: (MonadIO m, MonadMask m)
|
withSession :: (MonadIO m, MonadMask m)
|
||||||
=> (Raw.SessionOptions -> IO ())
|
=> (Raw.SessionOptions -> IO ())
|
||||||
-> ((IO () -> IO ()) -> Raw.Session -> m a)
|
-> SessionAction m a
|
||||||
-- ^ The action can spawn concurrent tasks which will
|
|
||||||
-- be canceled before withSession returns.
|
|
||||||
-> m a
|
-> m a
|
||||||
withSession optionSetter action = do
|
withSession = withSession_ Raw.newSession
|
||||||
|
|
||||||
|
withSession_ :: (MonadIO m, MonadMask m)
|
||||||
|
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
|
||||||
|
-- ^ mkSession
|
||||||
|
-> (Raw.SessionOptions -> IO ())
|
||||||
|
-- ^ optionSetter
|
||||||
|
-> SessionAction m a
|
||||||
|
-> m a
|
||||||
|
withSession_ mkSession optionSetter action = do
|
||||||
drain <- liftIO $ newMVar []
|
drain <- liftIO $ newMVar []
|
||||||
let cleanup s =
|
let cleanup s =
|
||||||
-- Closes the session to nudge the pending run calls to fail and exit.
|
-- Closes the session to nudge the pending run calls to fail and exit.
|
||||||
|
@ -105,11 +124,12 @@ withSession optionSetter action = do
|
||||||
mapM_ shutDownRunner runners
|
mapM_ shutDownRunner runners
|
||||||
checkStatus (Raw.deleteSession s)
|
checkStatus (Raw.deleteSession s)
|
||||||
let bracketIO x y = bracket (liftIO x) (liftIO . y)
|
let bracketIO x y = bracket (liftIO x) (liftIO . y)
|
||||||
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
|
||||||
bracketIO
|
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||||
(optionSetter options >> checkStatus (Raw.newSession options))
|
bracketIO
|
||||||
cleanup
|
(optionSetter options >> checkStatus (mkSession graph options))
|
||||||
(action (asyncCollector drain))
|
cleanup
|
||||||
|
(\session -> action (asyncCollector drain) session graph)
|
||||||
|
|
||||||
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
|
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
|
||||||
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
|
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
|
||||||
|
@ -122,43 +142,103 @@ shutDownRunner r = do
|
||||||
-- TODO(gnezdo): manage exceptions better than print.
|
-- TODO(gnezdo): manage exceptions better than print.
|
||||||
either print (const (return ())) =<< waitCatch r
|
either print (const (return ())) =<< waitCatch r
|
||||||
|
|
||||||
extendGraph :: Raw.Session -> GraphDef -> IO ()
|
graphImportGraphDef :: Raw.Graph
|
||||||
extendGraph session pb =
|
-> GraphDef
|
||||||
useProtoAsVoidPtrLen pb $ \ptr len ->
|
-> (Raw.ImportGraphDefOptions -> IO ())
|
||||||
checkStatus $ Raw.extendGraph session ptr len
|
-> IO ()
|
||||||
|
graphImportGraphDef graph pb optionSetter =
|
||||||
|
useProtoAsBuffer pb $ \buffer ->
|
||||||
|
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
|
||||||
|
optionSetter importGraphDefOptions
|
||||||
|
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions
|
||||||
|
|
||||||
|
forGraphOperations_ :: Raw.Graph
|
||||||
|
-> (Raw.Operation -> IO b)
|
||||||
|
-> IO ()
|
||||||
|
forGraphOperations_ graph f = with 0 go
|
||||||
|
where
|
||||||
|
go indexPtr = do
|
||||||
|
op <- Raw.graphNextOperation graph indexPtr
|
||||||
|
case op of
|
||||||
|
Raw.Operation ptr | ptr == nullPtr -> return ()
|
||||||
|
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
|
||||||
|
|
||||||
|
extendGraph :: Raw.Graph -> GraphDef -> IO ()
|
||||||
|
extendGraph graph graphDef =
|
||||||
|
graphImportGraphDef graph graphDef $ \opts ->
|
||||||
|
-- All inputs of the nodes in the GraphDef should either refer to
|
||||||
|
-- other nodes in the GraphDef, or be mapped to nodes already in
|
||||||
|
-- the Graph by adding an input mapping.
|
||||||
|
-- We add an input mapping for all existing nodes in the Graph in
|
||||||
|
-- case they are referenced in the GraphDef.
|
||||||
|
forGraphOperations_ graph $ \op -> do
|
||||||
|
srcName <- Raw.operationName op
|
||||||
|
numOutputs <- Raw.operationNumOutputs op
|
||||||
|
for_ [0..numOutputs] $ \srcIndex -> do
|
||||||
|
let dst = Raw.Output op (safeConvert srcIndex)
|
||||||
|
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex
|
||||||
|
|
||||||
run :: Raw.Session
|
run :: Raw.Session
|
||||||
-> [(B.ByteString, TensorData)] -- ^ Feeds.
|
-> Raw.Graph
|
||||||
-> [B.ByteString] -- ^ Fetches.
|
-> [(B.ByteString, TensorData)] -- ^ Inputs.
|
||||||
-> [B.ByteString] -- ^ Targets.
|
-> [B.ByteString] -- ^ Outputs.
|
||||||
|
-> [B.ByteString] -- ^ Target operations.
|
||||||
-> IO [TensorData]
|
-> IO [TensorData]
|
||||||
run session feeds fetches targets = do
|
run session graph inputNamesData outputNames targetNames = do
|
||||||
let nullTensor = Raw.Tensor nullPtr
|
|
||||||
-- Use mask to avoid leaking input tensors before they are passed to 'run'
|
-- Use mask to avoid leaking input tensors before they are passed to 'run'
|
||||||
-- and output tensors before they are passed to 'createTensorData'.
|
-- and output tensors before they are passed to 'createTensorData'.
|
||||||
mask_ $
|
mask_ $
|
||||||
-- Feeds
|
-- Inputs.
|
||||||
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
|
mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
|
||||||
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
|
withArrayLen inputs $ \nInputs cInputs ->
|
||||||
withArrayLen feedTensors $ \_ cFeedTensors ->
|
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
|
||||||
-- Fetches.
|
withArrayLen inputTensors $ \_ cInputTensors ->
|
||||||
withStringArrayLen fetches $ \fetchesLen fetchNames ->
|
-- Outputs.
|
||||||
-- tensorOuts is an array of null Tensor pointers that will be filled
|
mapM (resolveOutput graph) outputNames >>= \outputs ->
|
||||||
|
withArrayLen outputs $ \nOutputs cOutputs ->
|
||||||
|
-- outputTensors is an array of null Tensor pointers that will be filled
|
||||||
-- by the call to Raw.run.
|
-- by the call to Raw.run.
|
||||||
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
|
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
|
||||||
-- Targets.
|
-- Target operations.
|
||||||
withStringArrayLen targets $ \targetsLen ctargets -> do
|
mapM (resolveOperation graph) targetNames >>= \targets ->
|
||||||
|
withArrayLen targets $ \nTargets cTargets -> do
|
||||||
checkStatus $ Raw.run
|
checkStatus $ Raw.run
|
||||||
session
|
session
|
||||||
nullPtr
|
nullPtr -- RunOptions proto.
|
||||||
feedNames cFeedTensors (safeConvert feedsLen)
|
cInputs cInputTensors (safeConvert nInputs)
|
||||||
fetchNames tensorOuts (safeConvert fetchesLen)
|
cOutputs cOutputTensors (safeConvert nOutputs)
|
||||||
ctargets (safeConvert targetsLen)
|
cTargets (safeConvert nTargets)
|
||||||
nullPtr
|
nullPtr -- RunMetadata.
|
||||||
mapM_ Raw.deleteTensor feedTensors
|
mapM_ Raw.deleteTensor inputTensors
|
||||||
outTensors <- peekArray fetchesLen tensorOuts
|
outTensors <- peekArray nOutputs cOutputTensors
|
||||||
mapM createTensorData outTensors
|
mapM createTensorData outTensors
|
||||||
|
where
|
||||||
|
|
||||||
|
nullTensor = Raw.Tensor nullPtr
|
||||||
|
|
||||||
|
resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output
|
||||||
|
resolveOutput graph name = do
|
||||||
|
let (opName, idx) = parseName name
|
||||||
|
op <- resolveOperation graph opName
|
||||||
|
pure $ Raw.Output op (safeConvert idx)
|
||||||
|
where
|
||||||
|
parseName :: B.ByteString -> (B.ByteString, Int)
|
||||||
|
parseName opName =
|
||||||
|
case break (== ':') (C.unpack opName) of
|
||||||
|
(opName_, ':':idxStr) | idx <- read idxStr
|
||||||
|
-> (C.pack opName_, idx)
|
||||||
|
_ -> (opName, 0)
|
||||||
|
|
||||||
|
resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation
|
||||||
|
resolveOperation graph name = do
|
||||||
|
op <- Raw.graphOperationByName graph name
|
||||||
|
case op of
|
||||||
|
Raw.Operation ptr | ptr == nullPtr -> throwM exception
|
||||||
|
_ -> pure op
|
||||||
|
where
|
||||||
|
exception =
|
||||||
|
let msg = "Operation not found in graph: " <> (T.pack $ show name)
|
||||||
|
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
|
||||||
|
|
||||||
|
|
||||||
-- Internal.
|
-- Internal.
|
||||||
|
@ -174,21 +254,6 @@ safeConvert x =
|
||||||
show (fromIntegral x :: b)))
|
show (fromIntegral x :: b)))
|
||||||
(toIntegralSized x)
|
(toIntegralSized x)
|
||||||
|
|
||||||
|
|
||||||
-- | Use a list of ByteString as a list of CString.
|
|
||||||
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
|
||||||
withStringList strings fn = go strings []
|
|
||||||
where
|
|
||||||
go [] cs = fn (reverse cs)
|
|
||||||
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
|
||||||
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
|
|
||||||
|
|
||||||
|
|
||||||
-- | Use a list of ByteString as an array of CString.
|
|
||||||
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
|
|
||||||
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
|
||||||
|
|
||||||
|
|
||||||
-- | Create a Raw.Tensor from a TensorData.
|
-- | Create a Raw.Tensor from a TensorData.
|
||||||
createRawTensor :: TensorData -> IO Raw.Tensor
|
createRawTensor :: TensorData -> IO Raw.Tensor
|
||||||
createRawTensor (TensorData dims dt byteVec) =
|
createRawTensor (TensorData dims dt byteVec) =
|
||||||
|
@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
|
||||||
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
||||||
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
|
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
|
||||||
|
|
||||||
|
-- | Serializes the given msg and provides it as BufferPtr argument
|
||||||
|
-- to the given action.
|
||||||
|
useProtoAsBuffer :: (Message msg) =>
|
||||||
|
msg -> (Raw.BufferPtr -> IO a) -> IO a
|
||||||
|
useProtoAsBuffer msg f =
|
||||||
|
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
|
||||||
|
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
|
||||||
|
Raw.deleteBuffer
|
||||||
|
f
|
||||||
|
|
||||||
-- | Returns the serialized OpList of all OpDefs defined in this
|
-- | Returns the serialized OpList of all OpDefs defined in this
|
||||||
-- address space.
|
-- address space.
|
||||||
getAllOpList :: IO B.ByteString
|
getAllOpList :: IO B.ByteString
|
||||||
getAllOpList = do
|
getAllOpList =
|
||||||
foreignPtr <-
|
bracket checkCall Raw.deleteBuffer $ \buffer ->
|
||||||
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
|
-- Makes a copy because it is more reliable than eviscerating
|
||||||
-- Makes a copy because it is more reliable than eviscerating
|
-- Buffer to steal its memory (including custom deallocator).
|
||||||
-- Buffer to steal its memory (including custom deallocator).
|
B.packCStringLen =<< (,)
|
||||||
withForeignPtr foreignPtr $
|
<$> (castPtr <$> Raw.getBufferData buffer)
|
||||||
\ptr -> B.packCStringLen =<< (,)
|
<*> (safeConvert <$> Raw.getBufferLength buffer)
|
||||||
<$> (castPtr <$> Raw.getBufferData ptr)
|
|
||||||
<*> (safeConvert <$> Raw.getBufferLength ptr)
|
|
||||||
where
|
where
|
||||||
checkCall = do
|
checkCall = do
|
||||||
p <- Raw.getAllOpList
|
p <- Raw.getAllOpList
|
||||||
|
|
|
@ -18,6 +18,7 @@ module TensorFlow.Internal.Raw where
|
||||||
|
|
||||||
#include "third_party/tensorflow/c/c_api.h"
|
#include "third_party/tensorflow/c/c_api.h"
|
||||||
|
|
||||||
|
import Data.ByteString (ByteString, packCString, useAsCString)
|
||||||
import Foreign
|
import Foreign
|
||||||
import Foreign.C
|
import Foreign.C
|
||||||
|
|
||||||
|
@ -61,6 +62,35 @@ stringGetSize :: TString -> IO CULong
|
||||||
stringGetSize = {# call TF_StringGetSize as ^ #}
|
stringGetSize = {# call TF_StringGetSize as ^ #}
|
||||||
|
|
||||||
|
|
||||||
|
-- Operation.
|
||||||
|
{# pointer *TF_Operation as Operation newtype #}
|
||||||
|
{# fun TF_OperationName as operationName { `Operation' } -> `ByteString' packCString* #}
|
||||||
|
{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #}
|
||||||
|
|
||||||
|
instance Storable Operation where
|
||||||
|
sizeOf (Operation t) = sizeOf t
|
||||||
|
alignment (Operation t) = alignment t
|
||||||
|
peek p = fmap Operation (peek (castPtr p))
|
||||||
|
poke p (Operation t) = poke (castPtr p) t
|
||||||
|
|
||||||
|
|
||||||
|
-- Output.
|
||||||
|
data Output = Output
|
||||||
|
{ outputOperation :: Operation
|
||||||
|
, outputIndex :: CInt
|
||||||
|
}
|
||||||
|
{# pointer *TF_Output as OutputPtr -> Output #}
|
||||||
|
|
||||||
|
instance Storable Output where
|
||||||
|
sizeOf _ = {# sizeof TF_Output #}
|
||||||
|
alignment _ = {# alignof TF_Output #}
|
||||||
|
peek p = Output <$> {# get TF_Output->oper #} p
|
||||||
|
<*> (fromIntegral <$> {# get TF_Output->index #} p)
|
||||||
|
poke p (Output oper index) = do
|
||||||
|
{# set TF_Output->oper #} p oper
|
||||||
|
{# set TF_Output->index #} p $ fromIntegral index
|
||||||
|
|
||||||
|
|
||||||
-- Buffer.
|
-- Buffer.
|
||||||
data Buffer
|
data Buffer
|
||||||
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
|
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
|
||||||
|
@ -71,6 +101,12 @@ getBufferData = {# get TF_Buffer->data #}
|
||||||
getBufferLength :: BufferPtr -> IO CULong
|
getBufferLength :: BufferPtr -> IO CULong
|
||||||
getBufferLength = {# get TF_Buffer->length #}
|
getBufferLength = {# get TF_Buffer->length #}
|
||||||
|
|
||||||
|
newBufferFromString :: Ptr () -> CULong -> IO BufferPtr
|
||||||
|
newBufferFromString = {# call TF_NewBufferFromString as ^ #}
|
||||||
|
|
||||||
|
deleteBuffer :: BufferPtr -> IO ()
|
||||||
|
deleteBuffer = {# call TF_DeleteBuffer as ^ #}
|
||||||
|
|
||||||
-- Tensor.
|
-- Tensor.
|
||||||
{# pointer *TF_Tensor as Tensor newtype #}
|
{# pointer *TF_Tensor as Tensor newtype #}
|
||||||
|
|
||||||
|
@ -86,6 +122,8 @@ instance Storable Tensor where
|
||||||
-- `CLLong`).
|
-- `CLLong`).
|
||||||
type CInt64 = {#type int64_t #}
|
type CInt64 = {#type int64_t #}
|
||||||
|
|
||||||
|
{# pointer *size_t as CSizePtr -> CSize #}
|
||||||
|
|
||||||
newTensor :: DataType
|
newTensor :: DataType
|
||||||
-> Ptr CInt64 -- dimensions array
|
-> Ptr CInt64 -- dimensions array
|
||||||
-> CInt -- num dimensions
|
-> CInt -- num dimensions
|
||||||
|
@ -114,6 +152,31 @@ tensorByteSize = {# call TF_TensorByteSize as ^ #}
|
||||||
tensorData :: Tensor -> IO (Ptr ())
|
tensorData :: Tensor -> IO (Ptr ())
|
||||||
tensorData = {# call TF_TensorData as ^ #}
|
tensorData = {# call TF_TensorData as ^ #}
|
||||||
|
|
||||||
|
-- ImportGraphDefOptions.
|
||||||
|
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
|
||||||
|
{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #}
|
||||||
|
{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #}
|
||||||
|
{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping
|
||||||
|
{ `ImportGraphDefOptions'
|
||||||
|
, useAsCString* `ByteString'
|
||||||
|
, `Int'
|
||||||
|
, %`OutputPtr'
|
||||||
|
} -> `()'
|
||||||
|
#}
|
||||||
|
|
||||||
|
|
||||||
|
-- Graph.
|
||||||
|
{# pointer *TF_Graph as Graph newtype #}
|
||||||
|
{# fun TF_NewGraph as newGraph { } -> `Graph' #}
|
||||||
|
{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #}
|
||||||
|
{# fun TF_GraphOperationByName as graphOperationByName
|
||||||
|
{ `Graph'
|
||||||
|
, useAsCString* `ByteString'
|
||||||
|
} -> `Operation'
|
||||||
|
#}
|
||||||
|
{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #}
|
||||||
|
{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #}
|
||||||
|
|
||||||
|
|
||||||
-- Session Options.
|
-- Session Options.
|
||||||
{# pointer *TF_SessionOptions as SessionOptions newtype #}
|
{# pointer *TF_SessionOptions as SessionOptions newtype #}
|
||||||
|
@ -132,29 +195,27 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
|
||||||
|
|
||||||
|
|
||||||
-- Session.
|
-- Session.
|
||||||
{# pointer *TF_DeprecatedSession as Session newtype #}
|
{# pointer *TF_Session as Session newtype #}
|
||||||
|
|
||||||
|
newSession :: Graph -> SessionOptions -> Status -> IO Session
|
||||||
|
newSession = {# call TF_NewSession as ^ #}
|
||||||
|
|
||||||
newSession :: SessionOptions -> Status -> IO Session
|
|
||||||
newSession = {# call TF_NewDeprecatedSession as ^ #}
|
|
||||||
|
|
||||||
closeSession :: Session -> Status -> IO ()
|
closeSession :: Session -> Status -> IO ()
|
||||||
closeSession = {# call TF_CloseDeprecatedSession as ^ #}
|
closeSession = {# call TF_CloseSession as ^ #}
|
||||||
|
|
||||||
deleteSession :: Session -> Status -> IO ()
|
deleteSession :: Session -> Status -> IO ()
|
||||||
deleteSession = {# call TF_DeleteDeprecatedSession as ^ #}
|
deleteSession = {# call TF_DeleteSession as ^ #}
|
||||||
|
|
||||||
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
|
|
||||||
extendGraph = {# call TF_ExtendGraph as ^ #}
|
|
||||||
|
|
||||||
run :: Session
|
run :: Session
|
||||||
-> BufferPtr -- RunOptions proto.
|
-> BufferPtr -- RunOptions proto.
|
||||||
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
|
-> OutputPtr -> Ptr Tensor -> CInt -- Input (names, tensors, count).
|
||||||
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
|
-> OutputPtr -> Ptr Tensor -> CInt -- Output (names, tensors, count).
|
||||||
-> Ptr CString -> CInt -- Target nodes (names, count).
|
-> Ptr Operation -> CInt -- Target operations (ops, count).
|
||||||
-> BufferPtr -- RunMetadata proto.
|
-> BufferPtr -- RunMetadata proto.
|
||||||
-> Status
|
-> Status
|
||||||
-> IO ()
|
-> IO ()
|
||||||
run = {# call TF_Run as ^ #}
|
run = {# call TF_SessionRun as ^ #}
|
||||||
|
|
||||||
-- FFI helpers.
|
-- FFI helpers.
|
||||||
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
|
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
|
||||||
|
@ -170,6 +231,3 @@ foreign import ccall "wrapper"
|
||||||
-- in this address space.
|
-- in this address space.
|
||||||
getAllOpList :: IO BufferPtr
|
getAllOpList :: IO BufferPtr
|
||||||
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
||||||
|
|
||||||
foreign import ccall "&TF_DeleteBuffer"
|
|
||||||
deleteBuffer :: FunPtr (BufferPtr -> IO ())
|
|
||||||
|
|
|
@ -121,7 +121,7 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
|
||||||
-- code into a Build function
|
-- code into a Build function
|
||||||
instance IsString Output where
|
instance IsString Output where
|
||||||
fromString s = case break (==':') s of
|
fromString s = case break (==':') s of
|
||||||
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
(n, ':':ixStr) | ix <- read ixStr
|
||||||
-> Output (fromInteger ix) $ assigned n
|
-> Output (fromInteger ix) $ assigned n
|
||||||
_ -> Output 0 $ assigned s
|
_ -> Output 0 $ assigned s
|
||||||
where assigned = NodeName . Text.pack
|
where assigned = NodeName . Text.pack
|
||||||
|
|
|
@ -55,7 +55,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
|
||||||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Nodes
|
import TensorFlow.Nodes
|
||||||
import TensorFlow.Output (NodeName, unNodeName)
|
import TensorFlow.Output (NodeName(..), unNodeName)
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
import qualified Data.ByteString.Builder as Builder
|
import qualified Data.ByteString.Builder as Builder
|
||||||
|
@ -67,13 +67,13 @@ import qualified TensorFlow.Internal.FFI as FFI
|
||||||
type Tracer = Builder.Builder -> IO ()
|
type Tracer = Builder.Builder -> IO ()
|
||||||
|
|
||||||
-- Common state threaded through the session.
|
-- Common state threaded through the session.
|
||||||
data SessionState
|
data SessionState = SessionState
|
||||||
= SessionState {
|
{ rawSession :: FFI.Session
|
||||||
rawSession :: FFI.Session
|
, rawGraph :: FFI.Graph
|
||||||
, asyncCollector :: IO () -> IO ()
|
, asyncCollector :: IO () -> IO ()
|
||||||
-- ^ Starts the given action concurrently.
|
-- ^ Starts the given action concurrently.
|
||||||
, tracer :: Tracer
|
, tracer :: Tracer
|
||||||
}
|
}
|
||||||
|
|
||||||
newtype SessionT m a
|
newtype SessionT m a
|
||||||
= Session (ReaderT SessionState (BuildT m) a)
|
= Session (ReaderT SessionState (BuildT m) a)
|
||||||
|
@ -120,14 +120,23 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
||||||
-- | Run 'Session' actions in a new TensorFlow session created with
|
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||||
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
||||||
runSessionWithOptions options (Session m) =
|
runSessionWithOptions options session =
|
||||||
FFI.withSession applyOptions $
|
_runSessionWithOptions session options $ FFI.withSession
|
||||||
\as rs ->
|
|
||||||
let initState = SessionState rs as (options ^. sessionTracer)
|
_runSessionWithOptions :: (MonadMask m, MonadIO m)
|
||||||
|
=> SessionT m a
|
||||||
|
-> Options
|
||||||
|
-> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a)
|
||||||
|
-> m a
|
||||||
|
_runSessionWithOptions (Session m) options withSession =
|
||||||
|
withSession applyOptions $
|
||||||
|
\ac rSession rGraph ->
|
||||||
|
let initState = SessionState rSession rGraph ac (options ^. sessionTracer)
|
||||||
in evalBuildT (runReaderT m initState)
|
in evalBuildT (runReaderT m initState)
|
||||||
where applyOptions opt = do
|
where
|
||||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
applyOptions opt = do
|
||||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||||
|
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||||
|
|
||||||
instance Monad m => MonadBuild (SessionT m) where
|
instance Monad m => MonadBuild (SessionT m) where
|
||||||
build = Session . lift . build
|
build = Session . lift . build
|
||||||
|
@ -139,16 +148,17 @@ instance Monad m => MonadBuild (SessionT m) where
|
||||||
extend :: MonadIO m => SessionT m ()
|
extend :: MonadIO m => SessionT m ()
|
||||||
extend = do
|
extend = do
|
||||||
session <- Session (asks rawSession)
|
session <- Session (asks rawSession)
|
||||||
|
graph <- Session (asks rawGraph)
|
||||||
trace <- Session (asks tracer)
|
trace <- Session (asks tracer)
|
||||||
nodesToExtend <- build flushNodeBuffer
|
nodesToExtend <- build flushNodeBuffer
|
||||||
unless (null nodesToExtend) $ liftIO $ do
|
unless (null nodesToExtend) $ liftIO $ do
|
||||||
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
|
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
|
||||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||||
FFI.extendGraph session graphDef
|
FFI.extendGraph graph graphDef
|
||||||
-- Now that all the nodes are created, run the initializers.
|
-- Now that all the nodes are created, run the initializers.
|
||||||
initializers <- build flushInitializers
|
initializers <- build flushInitializers
|
||||||
unless (null initializers) $
|
unless (null initializers) $
|
||||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
void $ liftIO $ FFI.run session graph [] [] (toNodeNames initializers)
|
||||||
|
|
||||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
-- rendered, and fetch the corresponding values for 'a'.
|
-- rendered, and fetch the corresponding values for 'a'.
|
||||||
|
@ -173,8 +183,9 @@ runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||||
let feeds' = fixFeeds feeds
|
let feeds' = fixFeeds feeds
|
||||||
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
||||||
targetNames = toNodeNames $ Set.toList target
|
targetNames = toNodeNames $ Set.toList target
|
||||||
session <- Session (asks rawSession)
|
state <- Session ask
|
||||||
runResult <- liftIO $ FFI.run session
|
runResult <- liftIO $ FFI.run (rawSession state)
|
||||||
|
(rawGraph state)
|
||||||
feeds'
|
feeds'
|
||||||
fetchNames
|
fetchNames
|
||||||
targetNames
|
targetNames
|
||||||
|
@ -214,5 +225,5 @@ asyncProdNodes nodes = do
|
||||||
extend
|
extend
|
||||||
let targetNames = toNodeNames $ Set.toList target
|
let targetNames = toNodeNames $ Set.toList target
|
||||||
state <- Session ask
|
state <- Session ask
|
||||||
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
|
let loop = forever (void (FFI.run (rawSession state) (rawGraph state) [] [] targetNames))
|
||||||
liftIO (asyncCollector state loop)
|
liftIO (asyncCollector state loop)
|
||||||
|
|
Loading…
Reference in a new issue