mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-22 19:09:43 +01:00
Migrate from TF_DeprecatedSession to TF_Session
Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and pass an input map for all existing nodes in the graph.
This commit is contained in:
parent
30a12d7776
commit
fb629d1207
4 changed files with 244 additions and 102 deletions
|
@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
|
|||
( TensorFlowException(..)
|
||||
, Raw.Session
|
||||
, withSession
|
||||
, extendGraph
|
||||
, run
|
||||
|
||||
, SessionAction
|
||||
|
||||
, Raw.SessionOptions
|
||||
|
||||
, Raw.Graph
|
||||
, extendGraph
|
||||
|
||||
, TensorData(..)
|
||||
, setSessionConfig
|
||||
, setSessionTarget
|
||||
|
@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
|
|||
import Control.Monad.IO.Class (MonadIO, liftIO)
|
||||
import Data.Bits (Bits, toIntegralSized)
|
||||
import Data.Int (Int64)
|
||||
import Data.Foldable (for_)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Word (Word8)
|
||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
||||
import Foreign.C.String (CString)
|
||||
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
|
||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
|
||||
import Foreign.ForeignPtr (newForeignPtr_)
|
||||
import Foreign.Marshal.Alloc (free)
|
||||
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
||||
import System.IO.Unsafe (unsafePerformIO)
|
||||
import qualified Data.ByteString as B
|
||||
import qualified Data.ByteString.Char8 as C
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.Text.Encoding as T
|
||||
import qualified Data.Text.Encoding.Error as T
|
||||
|
@ -87,15 +95,26 @@ data TensorData = TensorData
|
|||
}
|
||||
deriving (Show, Eq)
|
||||
|
||||
-- | The action can spawn concurrent tasks which will be canceled before
|
||||
-- withSession returns.
|
||||
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a
|
||||
|
||||
-- | Runs the given action after creating a session with options
|
||||
-- populated by the given optionSetter.
|
||||
withSession :: (MonadIO m, MonadMask m)
|
||||
=> (Raw.SessionOptions -> IO ())
|
||||
-> ((IO () -> IO ()) -> Raw.Session -> m a)
|
||||
-- ^ The action can spawn concurrent tasks which will
|
||||
-- be canceled before withSession returns.
|
||||
-> SessionAction m a
|
||||
-> m a
|
||||
withSession optionSetter action = do
|
||||
withSession = withSession_ Raw.newSession
|
||||
|
||||
withSession_ :: (MonadIO m, MonadMask m)
|
||||
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
|
||||
-- ^ mkSession
|
||||
-> (Raw.SessionOptions -> IO ())
|
||||
-- ^ optionSetter
|
||||
-> SessionAction m a
|
||||
-> m a
|
||||
withSession_ mkSession optionSetter action = do
|
||||
drain <- liftIO $ newMVar []
|
||||
let cleanup s =
|
||||
-- Closes the session to nudge the pending run calls to fail and exit.
|
||||
|
@ -105,11 +124,12 @@ withSession optionSetter action = do
|
|||
mapM_ shutDownRunner runners
|
||||
checkStatus (Raw.deleteSession s)
|
||||
let bracketIO x y = bracket (liftIO x) (liftIO . y)
|
||||
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||
bracketIO
|
||||
(optionSetter options >> checkStatus (Raw.newSession options))
|
||||
cleanup
|
||||
(action (asyncCollector drain))
|
||||
bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
|
||||
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
||||
bracketIO
|
||||
(optionSetter options >> checkStatus (mkSession graph options))
|
||||
cleanup
|
||||
(\session -> action (asyncCollector drain) session graph)
|
||||
|
||||
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
|
||||
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
|
||||
|
@ -122,43 +142,103 @@ shutDownRunner r = do
|
|||
-- TODO(gnezdo): manage exceptions better than print.
|
||||
either print (const (return ())) =<< waitCatch r
|
||||
|
||||
extendGraph :: Raw.Session -> GraphDef -> IO ()
|
||||
extendGraph session pb =
|
||||
useProtoAsVoidPtrLen pb $ \ptr len ->
|
||||
checkStatus $ Raw.extendGraph session ptr len
|
||||
graphImportGraphDef :: Raw.Graph
|
||||
-> GraphDef
|
||||
-> (Raw.ImportGraphDefOptions -> IO ())
|
||||
-> IO ()
|
||||
graphImportGraphDef graph pb optionSetter =
|
||||
useProtoAsBuffer pb $ \buffer ->
|
||||
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
|
||||
optionSetter importGraphDefOptions
|
||||
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions
|
||||
|
||||
forGraphOperations_ :: Raw.Graph
|
||||
-> (Raw.Operation -> IO b)
|
||||
-> IO ()
|
||||
forGraphOperations_ graph f = with 0 go
|
||||
where
|
||||
go indexPtr = do
|
||||
op <- Raw.graphNextOperation graph indexPtr
|
||||
case op of
|
||||
Raw.Operation ptr | ptr == nullPtr -> return ()
|
||||
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
|
||||
|
||||
extendGraph :: Raw.Graph -> GraphDef -> IO ()
|
||||
extendGraph graph graphDef =
|
||||
graphImportGraphDef graph graphDef $ \opts ->
|
||||
-- All inputs of the nodes in the GraphDef should either refer to
|
||||
-- other nodes in the GraphDef, or be mapped to nodes already in
|
||||
-- the Graph by adding an input mapping.
|
||||
-- We add an input mapping for all existing nodes in the Graph in
|
||||
-- case they are referenced in the GraphDef.
|
||||
forGraphOperations_ graph $ \op -> do
|
||||
srcName <- Raw.operationName op
|
||||
numOutputs <- Raw.operationNumOutputs op
|
||||
for_ [0..numOutputs] $ \srcIndex -> do
|
||||
let dst = Raw.Output op (safeConvert srcIndex)
|
||||
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex
|
||||
|
||||
run :: Raw.Session
|
||||
-> [(B.ByteString, TensorData)] -- ^ Feeds.
|
||||
-> [B.ByteString] -- ^ Fetches.
|
||||
-> [B.ByteString] -- ^ Targets.
|
||||
-> Raw.Graph
|
||||
-> [(B.ByteString, TensorData)] -- ^ Inputs.
|
||||
-> [B.ByteString] -- ^ Outputs.
|
||||
-> [B.ByteString] -- ^ Target operations.
|
||||
-> IO [TensorData]
|
||||
run session feeds fetches targets = do
|
||||
let nullTensor = Raw.Tensor nullPtr
|
||||
run session graph inputNamesData outputNames targetNames = do
|
||||
-- Use mask to avoid leaking input tensors before they are passed to 'run'
|
||||
-- and output tensors before they are passed to 'createTensorData'.
|
||||
mask_ $
|
||||
-- Feeds
|
||||
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
|
||||
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
|
||||
withArrayLen feedTensors $ \_ cFeedTensors ->
|
||||
-- Fetches.
|
||||
withStringArrayLen fetches $ \fetchesLen fetchNames ->
|
||||
-- tensorOuts is an array of null Tensor pointers that will be filled
|
||||
-- Inputs.
|
||||
mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
|
||||
withArrayLen inputs $ \nInputs cInputs ->
|
||||
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
|
||||
withArrayLen inputTensors $ \_ cInputTensors ->
|
||||
-- Outputs.
|
||||
mapM (resolveOutput graph) outputNames >>= \outputs ->
|
||||
withArrayLen outputs $ \nOutputs cOutputs ->
|
||||
-- outputTensors is an array of null Tensor pointers that will be filled
|
||||
-- by the call to Raw.run.
|
||||
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
|
||||
-- Targets.
|
||||
withStringArrayLen targets $ \targetsLen ctargets -> do
|
||||
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
|
||||
-- Target operations.
|
||||
mapM (resolveOperation graph) targetNames >>= \targets ->
|
||||
withArrayLen targets $ \nTargets cTargets -> do
|
||||
checkStatus $ Raw.run
|
||||
session
|
||||
nullPtr
|
||||
feedNames cFeedTensors (safeConvert feedsLen)
|
||||
fetchNames tensorOuts (safeConvert fetchesLen)
|
||||
ctargets (safeConvert targetsLen)
|
||||
nullPtr
|
||||
mapM_ Raw.deleteTensor feedTensors
|
||||
outTensors <- peekArray fetchesLen tensorOuts
|
||||
nullPtr -- RunOptions proto.
|
||||
cInputs cInputTensors (safeConvert nInputs)
|
||||
cOutputs cOutputTensors (safeConvert nOutputs)
|
||||
cTargets (safeConvert nTargets)
|
||||
nullPtr -- RunMetadata.
|
||||
mapM_ Raw.deleteTensor inputTensors
|
||||
outTensors <- peekArray nOutputs cOutputTensors
|
||||
mapM createTensorData outTensors
|
||||
where
|
||||
|
||||
nullTensor = Raw.Tensor nullPtr
|
||||
|
||||
resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output
|
||||
resolveOutput graph name = do
|
||||
let (opName, idx) = parseName name
|
||||
op <- resolveOperation graph opName
|
||||
pure $ Raw.Output op (safeConvert idx)
|
||||
where
|
||||
parseName :: B.ByteString -> (B.ByteString, Int)
|
||||
parseName opName =
|
||||
case break (== ':') (C.unpack opName) of
|
||||
(opName_, ':':idxStr) | idx <- read idxStr
|
||||
-> (C.pack opName_, idx)
|
||||
_ -> (opName, 0)
|
||||
|
||||
resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation
|
||||
resolveOperation graph name = do
|
||||
op <- Raw.graphOperationByName graph name
|
||||
case op of
|
||||
Raw.Operation ptr | ptr == nullPtr -> throwM exception
|
||||
_ -> pure op
|
||||
where
|
||||
exception =
|
||||
let msg = "Operation not found in graph: " <> (T.pack $ show name)
|
||||
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
|
||||
|
||||
|
||||
-- Internal.
|
||||
|
@ -174,21 +254,6 @@ safeConvert x =
|
|||
show (fromIntegral x :: b)))
|
||||
(toIntegralSized x)
|
||||
|
||||
|
||||
-- | Use a list of ByteString as a list of CString.
|
||||
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
||||
withStringList strings fn = go strings []
|
||||
where
|
||||
go [] cs = fn (reverse cs)
|
||||
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
||||
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
|
||||
|
||||
|
||||
-- | Use a list of ByteString as an array of CString.
|
||||
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
|
||||
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
||||
|
||||
|
||||
-- | Create a Raw.Tensor from a TensorData.
|
||||
createRawTensor :: TensorData -> IO Raw.Tensor
|
||||
createRawTensor (TensorData dims dt byteVec) =
|
||||
|
@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
|
|||
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
||||
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
|
||||
|
||||
-- | Serializes the given msg and provides it as BufferPtr argument
|
||||
-- to the given action.
|
||||
useProtoAsBuffer :: (Message msg) =>
|
||||
msg -> (Raw.BufferPtr -> IO a) -> IO a
|
||||
useProtoAsBuffer msg f =
|
||||
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
|
||||
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
|
||||
Raw.deleteBuffer
|
||||
f
|
||||
|
||||
-- | Returns the serialized OpList of all OpDefs defined in this
|
||||
-- address space.
|
||||
getAllOpList :: IO B.ByteString
|
||||
getAllOpList = do
|
||||
foreignPtr <-
|
||||
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
|
||||
-- Makes a copy because it is more reliable than eviscerating
|
||||
-- Buffer to steal its memory (including custom deallocator).
|
||||
withForeignPtr foreignPtr $
|
||||
\ptr -> B.packCStringLen =<< (,)
|
||||
<$> (castPtr <$> Raw.getBufferData ptr)
|
||||
<*> (safeConvert <$> Raw.getBufferLength ptr)
|
||||
getAllOpList =
|
||||
bracket checkCall Raw.deleteBuffer $ \buffer ->
|
||||
-- Makes a copy because it is more reliable than eviscerating
|
||||
-- Buffer to steal its memory (including custom deallocator).
|
||||
B.packCStringLen =<< (,)
|
||||
<$> (castPtr <$> Raw.getBufferData buffer)
|
||||
<*> (safeConvert <$> Raw.getBufferLength buffer)
|
||||
where
|
||||
checkCall = do
|
||||
p <- Raw.getAllOpList
|
||||
|
|
|
@ -18,6 +18,7 @@ module TensorFlow.Internal.Raw where
|
|||
|
||||
#include "third_party/tensorflow/c/c_api.h"
|
||||
|
||||
import Data.ByteString (ByteString, packCString, useAsCString)
|
||||
import Foreign
|
||||
import Foreign.C
|
||||
|
||||
|
@ -61,6 +62,35 @@ stringGetSize :: TString -> IO CULong
|
|||
stringGetSize = {# call TF_StringGetSize as ^ #}
|
||||
|
||||
|
||||
-- Operation.
|
||||
{# pointer *TF_Operation as Operation newtype #}
|
||||
{# fun TF_OperationName as operationName { `Operation' } -> `ByteString' packCString* #}
|
||||
{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #}
|
||||
|
||||
instance Storable Operation where
|
||||
sizeOf (Operation t) = sizeOf t
|
||||
alignment (Operation t) = alignment t
|
||||
peek p = fmap Operation (peek (castPtr p))
|
||||
poke p (Operation t) = poke (castPtr p) t
|
||||
|
||||
|
||||
-- Output.
|
||||
data Output = Output
|
||||
{ outputOperation :: Operation
|
||||
, outputIndex :: CInt
|
||||
}
|
||||
{# pointer *TF_Output as OutputPtr -> Output #}
|
||||
|
||||
instance Storable Output where
|
||||
sizeOf _ = {# sizeof TF_Output #}
|
||||
alignment _ = {# alignof TF_Output #}
|
||||
peek p = Output <$> {# get TF_Output->oper #} p
|
||||
<*> (fromIntegral <$> {# get TF_Output->index #} p)
|
||||
poke p (Output oper index) = do
|
||||
{# set TF_Output->oper #} p oper
|
||||
{# set TF_Output->index #} p $ fromIntegral index
|
||||
|
||||
|
||||
-- Buffer.
|
||||
data Buffer
|
||||
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
|
||||
|
@ -71,6 +101,12 @@ getBufferData = {# get TF_Buffer->data #}
|
|||
getBufferLength :: BufferPtr -> IO CULong
|
||||
getBufferLength = {# get TF_Buffer->length #}
|
||||
|
||||
newBufferFromString :: Ptr () -> CULong -> IO BufferPtr
|
||||
newBufferFromString = {# call TF_NewBufferFromString as ^ #}
|
||||
|
||||
deleteBuffer :: BufferPtr -> IO ()
|
||||
deleteBuffer = {# call TF_DeleteBuffer as ^ #}
|
||||
|
||||
-- Tensor.
|
||||
{# pointer *TF_Tensor as Tensor newtype #}
|
||||
|
||||
|
@ -86,6 +122,8 @@ instance Storable Tensor where
|
|||
-- `CLLong`).
|
||||
type CInt64 = {#type int64_t #}
|
||||
|
||||
{# pointer *size_t as CSizePtr -> CSize #}
|
||||
|
||||
newTensor :: DataType
|
||||
-> Ptr CInt64 -- dimensions array
|
||||
-> CInt -- num dimensions
|
||||
|
@ -114,6 +152,31 @@ tensorByteSize = {# call TF_TensorByteSize as ^ #}
|
|||
tensorData :: Tensor -> IO (Ptr ())
|
||||
tensorData = {# call TF_TensorData as ^ #}
|
||||
|
||||
-- ImportGraphDefOptions.
|
||||
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
|
||||
{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #}
|
||||
{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #}
|
||||
{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping
|
||||
{ `ImportGraphDefOptions'
|
||||
, useAsCString* `ByteString'
|
||||
, `Int'
|
||||
, %`OutputPtr'
|
||||
} -> `()'
|
||||
#}
|
||||
|
||||
|
||||
-- Graph.
|
||||
{# pointer *TF_Graph as Graph newtype #}
|
||||
{# fun TF_NewGraph as newGraph { } -> `Graph' #}
|
||||
{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #}
|
||||
{# fun TF_GraphOperationByName as graphOperationByName
|
||||
{ `Graph'
|
||||
, useAsCString* `ByteString'
|
||||
} -> `Operation'
|
||||
#}
|
||||
{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #}
|
||||
{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #}
|
||||
|
||||
|
||||
-- Session Options.
|
||||
{# pointer *TF_SessionOptions as SessionOptions newtype #}
|
||||
|
@ -132,29 +195,27 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
|
|||
|
||||
|
||||
-- Session.
|
||||
{# pointer *TF_DeprecatedSession as Session newtype #}
|
||||
{# pointer *TF_Session as Session newtype #}
|
||||
|
||||
newSession :: Graph -> SessionOptions -> Status -> IO Session
|
||||
newSession = {# call TF_NewSession as ^ #}
|
||||
|
||||
newSession :: SessionOptions -> Status -> IO Session
|
||||
newSession = {# call TF_NewDeprecatedSession as ^ #}
|
||||
|
||||
closeSession :: Session -> Status -> IO ()
|
||||
closeSession = {# call TF_CloseDeprecatedSession as ^ #}
|
||||
closeSession = {# call TF_CloseSession as ^ #}
|
||||
|
||||
deleteSession :: Session -> Status -> IO ()
|
||||
deleteSession = {# call TF_DeleteDeprecatedSession as ^ #}
|
||||
|
||||
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
|
||||
extendGraph = {# call TF_ExtendGraph as ^ #}
|
||||
deleteSession = {# call TF_DeleteSession as ^ #}
|
||||
|
||||
run :: Session
|
||||
-> BufferPtr -- RunOptions proto.
|
||||
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
|
||||
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
|
||||
-> Ptr CString -> CInt -- Target nodes (names, count).
|
||||
-> BufferPtr -- RunMetadata proto.
|
||||
-> BufferPtr -- RunOptions proto.
|
||||
-> OutputPtr -> Ptr Tensor -> CInt -- Input (names, tensors, count).
|
||||
-> OutputPtr -> Ptr Tensor -> CInt -- Output (names, tensors, count).
|
||||
-> Ptr Operation -> CInt -- Target operations (ops, count).
|
||||
-> BufferPtr -- RunMetadata proto.
|
||||
-> Status
|
||||
-> IO ()
|
||||
run = {# call TF_Run as ^ #}
|
||||
run = {# call TF_SessionRun as ^ #}
|
||||
|
||||
-- FFI helpers.
|
||||
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
|
||||
|
@ -170,6 +231,3 @@ foreign import ccall "wrapper"
|
|||
-- in this address space.
|
||||
getAllOpList :: IO BufferPtr
|
||||
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
||||
|
||||
foreign import ccall "&TF_DeleteBuffer"
|
||||
deleteBuffer :: FunPtr (BufferPtr -> IO ())
|
||||
|
|
|
@ -121,7 +121,7 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
|
|||
-- code into a Build function
|
||||
instance IsString Output where
|
||||
fromString s = case break (==':') s of
|
||||
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
||||
-> Output (fromInteger ix) $ assigned n
|
||||
(n, ':':ixStr) | ix <- read ixStr
|
||||
-> Output (fromInteger ix) $ assigned n
|
||||
_ -> Output 0 $ assigned s
|
||||
where assigned = NodeName . Text.pack
|
||||
|
|
|
@ -55,7 +55,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
|
|||
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output (NodeName, unNodeName)
|
||||
import TensorFlow.Output (NodeName(..), unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
|
||||
import qualified Data.ByteString.Builder as Builder
|
||||
|
@ -67,13 +67,13 @@ import qualified TensorFlow.Internal.FFI as FFI
|
|||
type Tracer = Builder.Builder -> IO ()
|
||||
|
||||
-- Common state threaded through the session.
|
||||
data SessionState
|
||||
= SessionState {
|
||||
rawSession :: FFI.Session
|
||||
, asyncCollector :: IO () -> IO ()
|
||||
-- ^ Starts the given action concurrently.
|
||||
, tracer :: Tracer
|
||||
}
|
||||
data SessionState = SessionState
|
||||
{ rawSession :: FFI.Session
|
||||
, rawGraph :: FFI.Graph
|
||||
, asyncCollector :: IO () -> IO ()
|
||||
-- ^ Starts the given action concurrently.
|
||||
, tracer :: Tracer
|
||||
}
|
||||
|
||||
newtype SessionT m a
|
||||
= Session (ReaderT SessionState (BuildT m) a)
|
||||
|
@ -120,14 +120,23 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
|
|||
-- | Run 'Session' actions in a new TensorFlow session created with
|
||||
-- the given option setter actions ('sessionTarget', 'sessionConfig').
|
||||
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
||||
runSessionWithOptions options (Session m) =
|
||||
FFI.withSession applyOptions $
|
||||
\as rs ->
|
||||
let initState = SessionState rs as (options ^. sessionTracer)
|
||||
runSessionWithOptions options session =
|
||||
_runSessionWithOptions session options $ FFI.withSession
|
||||
|
||||
_runSessionWithOptions :: (MonadMask m, MonadIO m)
|
||||
=> SessionT m a
|
||||
-> Options
|
||||
-> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a)
|
||||
-> m a
|
||||
_runSessionWithOptions (Session m) options withSession =
|
||||
withSession applyOptions $
|
||||
\ac rSession rGraph ->
|
||||
let initState = SessionState rSession rGraph ac (options ^. sessionTracer)
|
||||
in evalBuildT (runReaderT m initState)
|
||||
where applyOptions opt = do
|
||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
where
|
||||
applyOptions opt = do
|
||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
|
||||
instance Monad m => MonadBuild (SessionT m) where
|
||||
build = Session . lift . build
|
||||
|
@ -139,16 +148,17 @@ instance Monad m => MonadBuild (SessionT m) where
|
|||
extend :: MonadIO m => SessionT m ()
|
||||
extend = do
|
||||
session <- Session (asks rawSession)
|
||||
graph <- Session (asks rawGraph)
|
||||
trace <- Session (asks tracer)
|
||||
nodesToExtend <- build flushNodeBuffer
|
||||
unless (null nodesToExtend) $ liftIO $ do
|
||||
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
|
||||
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
|
||||
FFI.extendGraph session graphDef
|
||||
FFI.extendGraph graph graphDef
|
||||
-- Now that all the nodes are created, run the initializers.
|
||||
initializers <- build flushInitializers
|
||||
unless (null initializers) $
|
||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||
void $ liftIO $ FFI.run session graph [] [] (toNodeNames initializers)
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, and fetch the corresponding values for 'a'.
|
||||
|
@ -173,8 +183,9 @@ runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
|||
let feeds' = fixFeeds feeds
|
||||
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
||||
targetNames = toNodeNames $ Set.toList target
|
||||
session <- Session (asks rawSession)
|
||||
runResult <- liftIO $ FFI.run session
|
||||
state <- Session ask
|
||||
runResult <- liftIO $ FFI.run (rawSession state)
|
||||
(rawGraph state)
|
||||
feeds'
|
||||
fetchNames
|
||||
targetNames
|
||||
|
@ -214,5 +225,5 @@ asyncProdNodes nodes = do
|
|||
extend
|
||||
let targetNames = toNodeNames $ Set.toList target
|
||||
state <- Session ask
|
||||
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
|
||||
let loop = forever (void (FFI.run (rawSession state) (rawGraph state) [] [] targetNames))
|
||||
liftIO (asyncCollector state loop)
|
||||
|
|
Loading…
Reference in a new issue