1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Migrate from TF_DeprecatedSession to TF_Session

Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and
pass an input map for all existing nodes in the graph.
This commit is contained in:
Bart Schuurmans 2023-02-02 11:30:30 +01:00 committed by fkm3
parent 30a12d7776
commit fb629d1207
4 changed files with 244 additions and 102 deletions

View file

@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
( TensorFlowException(..) ( TensorFlowException(..)
, Raw.Session , Raw.Session
, withSession , withSession
, extendGraph
, run , run
, SessionAction
, Raw.SessionOptions
, Raw.Graph
, extendGraph
, TensorData(..) , TensorData(..)
, setSessionConfig , setSessionConfig
, setSessionTarget , setSessionTarget
@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized) import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64) import Data.Int (Int64)
import Data.Foldable (for_)
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
import Data.Word (Word8) import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr) import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
import Foreign.C.String (CString) import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free) import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray) import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO) import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T import qualified Data.Text.Encoding.Error as T
@ -87,15 +95,26 @@ data TensorData = TensorData
} }
deriving (Show, Eq) deriving (Show, Eq)
-- | The action can spawn concurrent tasks which will be canceled before
-- withSession returns.
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a
-- | Runs the given action after creating a session with options -- | Runs the given action after creating a session with options
-- populated by the given optionSetter. -- populated by the given optionSetter.
withSession :: (MonadIO m, MonadMask m) withSession :: (MonadIO m, MonadMask m)
=> (Raw.SessionOptions -> IO ()) => (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> m a) -> SessionAction m a
-- ^ The action can spawn concurrent tasks which will
-- be canceled before withSession returns.
-> m a -> m a
withSession optionSetter action = do withSession = withSession_ Raw.newSession
withSession_ :: (MonadIO m, MonadMask m)
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
-- ^ mkSession
-> (Raw.SessionOptions -> IO ())
-- ^ optionSetter
-> SessionAction m a
-> m a
withSession_ mkSession optionSetter action = do
drain <- liftIO $ newMVar [] drain <- liftIO $ newMVar []
let cleanup s = let cleanup s =
-- Closes the session to nudge the pending run calls to fail and exit. -- Closes the session to nudge the pending run calls to fail and exit.
@ -105,11 +124,12 @@ withSession optionSetter action = do
mapM_ shutDownRunner runners mapM_ shutDownRunner runners
checkStatus (Raw.deleteSession s) checkStatus (Raw.deleteSession s)
let bracketIO x y = bracket (liftIO x) (liftIO . y) let bracketIO x y = bracket (liftIO x) (liftIO . y)
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
bracketIO bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
(optionSetter options >> checkStatus (Raw.newSession options)) bracketIO
cleanup (optionSetter options >> checkStatus (mkSession graph options))
(action (asyncCollector drain)) cleanup
(\session -> action (asyncCollector drain) session graph)
asyncCollector :: MVar [Async ()] -> IO () -> IO () asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
@ -122,43 +142,103 @@ shutDownRunner r = do
-- TODO(gnezdo): manage exceptions better than print. -- TODO(gnezdo): manage exceptions better than print.
either print (const (return ())) =<< waitCatch r either print (const (return ())) =<< waitCatch r
extendGraph :: Raw.Session -> GraphDef -> IO () graphImportGraphDef :: Raw.Graph
extendGraph session pb = -> GraphDef
useProtoAsVoidPtrLen pb $ \ptr len -> -> (Raw.ImportGraphDefOptions -> IO ())
checkStatus $ Raw.extendGraph session ptr len -> IO ()
graphImportGraphDef graph pb optionSetter =
useProtoAsBuffer pb $ \buffer ->
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
optionSetter importGraphDefOptions
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions
forGraphOperations_ :: Raw.Graph
-> (Raw.Operation -> IO b)
-> IO ()
forGraphOperations_ graph f = with 0 go
where
go indexPtr = do
op <- Raw.graphNextOperation graph indexPtr
case op of
Raw.Operation ptr | ptr == nullPtr -> return ()
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
extendGraph :: Raw.Graph -> GraphDef -> IO ()
extendGraph graph graphDef =
graphImportGraphDef graph graphDef $ \opts ->
-- All inputs of the nodes in the GraphDef should either refer to
-- other nodes in the GraphDef, or be mapped to nodes already in
-- the Graph by adding an input mapping.
-- We add an input mapping for all existing nodes in the Graph in
-- case they are referenced in the GraphDef.
forGraphOperations_ graph $ \op -> do
srcName <- Raw.operationName op
numOutputs <- Raw.operationNumOutputs op
for_ [0..numOutputs] $ \srcIndex -> do
let dst = Raw.Output op (safeConvert srcIndex)
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex
run :: Raw.Session run :: Raw.Session
-> [(B.ByteString, TensorData)] -- ^ Feeds. -> Raw.Graph
-> [B.ByteString] -- ^ Fetches. -> [(B.ByteString, TensorData)] -- ^ Inputs.
-> [B.ByteString] -- ^ Targets. -> [B.ByteString] -- ^ Outputs.
-> [B.ByteString] -- ^ Target operations.
-> IO [TensorData] -> IO [TensorData]
run session feeds fetches targets = do run session graph inputNamesData outputNames targetNames = do
let nullTensor = Raw.Tensor nullPtr
-- Use mask to avoid leaking input tensors before they are passed to 'run' -- Use mask to avoid leaking input tensors before they are passed to 'run'
-- and output tensors before they are passed to 'createTensorData'. -- and output tensors before they are passed to 'createTensorData'.
mask_ $ mask_ $
-- Feeds -- Inputs.
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames -> mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
mapM (createRawTensor . snd) feeds >>= \feedTensors -> withArrayLen inputs $ \nInputs cInputs ->
withArrayLen feedTensors $ \_ cFeedTensors -> mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
-- Fetches. withArrayLen inputTensors $ \_ cInputTensors ->
withStringArrayLen fetches $ \fetchesLen fetchNames -> -- Outputs.
-- tensorOuts is an array of null Tensor pointers that will be filled mapM (resolveOutput graph) outputNames >>= \outputs ->
withArrayLen outputs $ \nOutputs cOutputs ->
-- outputTensors is an array of null Tensor pointers that will be filled
-- by the call to Raw.run. -- by the call to Raw.run.
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts -> withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
-- Targets. -- Target operations.
withStringArrayLen targets $ \targetsLen ctargets -> do mapM (resolveOperation graph) targetNames >>= \targets ->
withArrayLen targets $ \nTargets cTargets -> do
checkStatus $ Raw.run checkStatus $ Raw.run
session session
nullPtr nullPtr -- RunOptions proto.
feedNames cFeedTensors (safeConvert feedsLen) cInputs cInputTensors (safeConvert nInputs)
fetchNames tensorOuts (safeConvert fetchesLen) cOutputs cOutputTensors (safeConvert nOutputs)
ctargets (safeConvert targetsLen) cTargets (safeConvert nTargets)
nullPtr nullPtr -- RunMetadata.
mapM_ Raw.deleteTensor feedTensors mapM_ Raw.deleteTensor inputTensors
outTensors <- peekArray fetchesLen tensorOuts outTensors <- peekArray nOutputs cOutputTensors
mapM createTensorData outTensors mapM createTensorData outTensors
where
nullTensor = Raw.Tensor nullPtr
resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output
resolveOutput graph name = do
let (opName, idx) = parseName name
op <- resolveOperation graph opName
pure $ Raw.Output op (safeConvert idx)
where
parseName :: B.ByteString -> (B.ByteString, Int)
parseName opName =
case break (== ':') (C.unpack opName) of
(opName_, ':':idxStr) | idx <- read idxStr
-> (C.pack opName_, idx)
_ -> (opName, 0)
resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation
resolveOperation graph name = do
op <- Raw.graphOperationByName graph name
case op of
Raw.Operation ptr | ptr == nullPtr -> throwM exception
_ -> pure op
where
exception =
let msg = "Operation not found in graph: " <> (T.pack $ show name)
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
-- Internal. -- Internal.
@ -174,21 +254,6 @@ safeConvert x =
show (fromIntegral x :: b))) show (fromIntegral x :: b)))
(toIntegralSized x) (toIntegralSized x)
-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
-- | Use a list of ByteString as an array of CString.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
-- | Create a Raw.Tensor from a TensorData. -- | Create a Raw.Tensor from a TensorData.
createRawTensor :: TensorData -> IO Raw.Tensor createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) = createRawTensor (TensorData dims dt byteVec) =
@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $ useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
\(bytes, len) -> f (castPtr bytes) (safeConvert len) \(bytes, len) -> f (castPtr bytes) (safeConvert len)
-- | Serializes the given msg and provides it as BufferPtr argument
-- to the given action.
useProtoAsBuffer :: (Message msg) =>
msg -> (Raw.BufferPtr -> IO a) -> IO a
useProtoAsBuffer msg f =
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
Raw.deleteBuffer
f
-- | Returns the serialized OpList of all OpDefs defined in this -- | Returns the serialized OpList of all OpDefs defined in this
-- address space. -- address space.
getAllOpList :: IO B.ByteString getAllOpList :: IO B.ByteString
getAllOpList = do getAllOpList =
foreignPtr <- bracket checkCall Raw.deleteBuffer $ \buffer ->
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall) -- Makes a copy because it is more reliable than eviscerating
-- Makes a copy because it is more reliable than eviscerating -- Buffer to steal its memory (including custom deallocator).
-- Buffer to steal its memory (including custom deallocator). B.packCStringLen =<< (,)
withForeignPtr foreignPtr $ <$> (castPtr <$> Raw.getBufferData buffer)
\ptr -> B.packCStringLen =<< (,) <*> (safeConvert <$> Raw.getBufferLength buffer)
<$> (castPtr <$> Raw.getBufferData ptr)
<*> (safeConvert <$> Raw.getBufferLength ptr)
where where
checkCall = do checkCall = do
p <- Raw.getAllOpList p <- Raw.getAllOpList

View file

@ -18,6 +18,7 @@ module TensorFlow.Internal.Raw where
#include "third_party/tensorflow/c/c_api.h" #include "third_party/tensorflow/c/c_api.h"
import Data.ByteString (ByteString, packCString, useAsCString)
import Foreign import Foreign
import Foreign.C import Foreign.C
@ -61,6 +62,35 @@ stringGetSize :: TString -> IO CULong
stringGetSize = {# call TF_StringGetSize as ^ #} stringGetSize = {# call TF_StringGetSize as ^ #}
-- Operation.
{# pointer *TF_Operation as Operation newtype #}
{# fun TF_OperationName as operationName { `Operation' } -> `ByteString' packCString* #}
{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #}
instance Storable Operation where
sizeOf (Operation t) = sizeOf t
alignment (Operation t) = alignment t
peek p = fmap Operation (peek (castPtr p))
poke p (Operation t) = poke (castPtr p) t
-- Output.
data Output = Output
{ outputOperation :: Operation
, outputIndex :: CInt
}
{# pointer *TF_Output as OutputPtr -> Output #}
instance Storable Output where
sizeOf _ = {# sizeof TF_Output #}
alignment _ = {# alignof TF_Output #}
peek p = Output <$> {# get TF_Output->oper #} p
<*> (fromIntegral <$> {# get TF_Output->index #} p)
poke p (Output oper index) = do
{# set TF_Output->oper #} p oper
{# set TF_Output->index #} p $ fromIntegral index
-- Buffer. -- Buffer.
data Buffer data Buffer
{# pointer *TF_Buffer as BufferPtr -> Buffer #} {# pointer *TF_Buffer as BufferPtr -> Buffer #}
@ -71,6 +101,12 @@ getBufferData = {# get TF_Buffer->data #}
getBufferLength :: BufferPtr -> IO CULong getBufferLength :: BufferPtr -> IO CULong
getBufferLength = {# get TF_Buffer->length #} getBufferLength = {# get TF_Buffer->length #}
newBufferFromString :: Ptr () -> CULong -> IO BufferPtr
newBufferFromString = {# call TF_NewBufferFromString as ^ #}
deleteBuffer :: BufferPtr -> IO ()
deleteBuffer = {# call TF_DeleteBuffer as ^ #}
-- Tensor. -- Tensor.
{# pointer *TF_Tensor as Tensor newtype #} {# pointer *TF_Tensor as Tensor newtype #}
@ -86,6 +122,8 @@ instance Storable Tensor where
-- `CLLong`). -- `CLLong`).
type CInt64 = {#type int64_t #} type CInt64 = {#type int64_t #}
{# pointer *size_t as CSizePtr -> CSize #}
newTensor :: DataType newTensor :: DataType
-> Ptr CInt64 -- dimensions array -> Ptr CInt64 -- dimensions array
-> CInt -- num dimensions -> CInt -- num dimensions
@ -114,6 +152,31 @@ tensorByteSize = {# call TF_TensorByteSize as ^ #}
tensorData :: Tensor -> IO (Ptr ()) tensorData :: Tensor -> IO (Ptr ())
tensorData = {# call TF_TensorData as ^ #} tensorData = {# call TF_TensorData as ^ #}
-- ImportGraphDefOptions.
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #}
{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #}
{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping
{ `ImportGraphDefOptions'
, useAsCString* `ByteString'
, `Int'
, %`OutputPtr'
} -> `()'
#}
-- Graph.
{# pointer *TF_Graph as Graph newtype #}
{# fun TF_NewGraph as newGraph { } -> `Graph' #}
{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #}
{# fun TF_GraphOperationByName as graphOperationByName
{ `Graph'
, useAsCString* `ByteString'
} -> `Operation'
#}
{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #}
{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #}
-- Session Options. -- Session Options.
{# pointer *TF_SessionOptions as SessionOptions newtype #} {# pointer *TF_SessionOptions as SessionOptions newtype #}
@ -132,29 +195,27 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
-- Session. -- Session.
{# pointer *TF_DeprecatedSession as Session newtype #} {# pointer *TF_Session as Session newtype #}
newSession :: Graph -> SessionOptions -> Status -> IO Session
newSession = {# call TF_NewSession as ^ #}
newSession :: SessionOptions -> Status -> IO Session
newSession = {# call TF_NewDeprecatedSession as ^ #}
closeSession :: Session -> Status -> IO () closeSession :: Session -> Status -> IO ()
closeSession = {# call TF_CloseDeprecatedSession as ^ #} closeSession = {# call TF_CloseSession as ^ #}
deleteSession :: Session -> Status -> IO () deleteSession :: Session -> Status -> IO ()
deleteSession = {# call TF_DeleteDeprecatedSession as ^ #} deleteSession = {# call TF_DeleteSession as ^ #}
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
extendGraph = {# call TF_ExtendGraph as ^ #}
run :: Session run :: Session
-> BufferPtr -- RunOptions proto. -> BufferPtr -- RunOptions proto.
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count). -> OutputPtr -> Ptr Tensor -> CInt -- Input (names, tensors, count).
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count). -> OutputPtr -> Ptr Tensor -> CInt -- Output (names, tensors, count).
-> Ptr CString -> CInt -- Target nodes (names, count). -> Ptr Operation -> CInt -- Target operations (ops, count).
-> BufferPtr -- RunMetadata proto. -> BufferPtr -- RunMetadata proto.
-> Status -> Status
-> IO () -> IO ()
run = {# call TF_Run as ^ #} run = {# call TF_SessionRun as ^ #}
-- FFI helpers. -- FFI helpers.
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO () type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
@ -170,6 +231,3 @@ foreign import ccall "wrapper"
-- in this address space. -- in this address space.
getAllOpList :: IO BufferPtr getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #} getAllOpList = {# call TF_GetAllOpList as ^ #}
foreign import ccall "&TF_DeleteBuffer"
deleteBuffer :: FunPtr (BufferPtr -> IO ())

View file

@ -121,7 +121,7 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
-- code into a Build function -- code into a Build function
instance IsString Output where instance IsString Output where
fromString s = case break (==':') s of fromString s = case break (==':') s of
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr (n, ':':ixStr) | ix <- read ixStr
-> Output (fromInteger ix) $ assigned n -> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s _ -> Output 0 $ assigned s
where assigned = NodeName . Text.pack where assigned = NodeName . Text.pack

View file

@ -55,7 +55,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Nodes import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName) import TensorFlow.Output (NodeName(..), unNodeName)
import TensorFlow.Tensor import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder import qualified Data.ByteString.Builder as Builder
@ -67,13 +67,13 @@ import qualified TensorFlow.Internal.FFI as FFI
type Tracer = Builder.Builder -> IO () type Tracer = Builder.Builder -> IO ()
-- Common state threaded through the session. -- Common state threaded through the session.
data SessionState data SessionState = SessionState
= SessionState { { rawSession :: FFI.Session
rawSession :: FFI.Session , rawGraph :: FFI.Graph
, asyncCollector :: IO () -> IO () , asyncCollector :: IO () -> IO ()
-- ^ Starts the given action concurrently. -- ^ Starts the given action concurrently.
, tracer :: Tracer , tracer :: Tracer
} }
newtype SessionT m a newtype SessionT m a
= Session (ReaderT SessionState (BuildT m) a) = Session (ReaderT SessionState (BuildT m) a)
@ -120,14 +120,23 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run 'Session' actions in a new TensorFlow session created with -- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig'). -- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) = runSessionWithOptions options session =
FFI.withSession applyOptions $ _runSessionWithOptions session options $ FFI.withSession
\as rs ->
let initState = SessionState rs as (options ^. sessionTracer) _runSessionWithOptions :: (MonadMask m, MonadIO m)
=> SessionT m a
-> Options
-> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a)
-> m a
_runSessionWithOptions (Session m) options withSession =
withSession applyOptions $
\ac rSession rGraph ->
let initState = SessionState rSession rGraph ac (options ^. sessionTracer)
in evalBuildT (runReaderT m initState) in evalBuildT (runReaderT m initState)
where applyOptions opt = do where
FFI.setSessionTarget (options ^. sessionTarget) opt applyOptions opt = do
FFI.setSessionConfig (options ^. sessionConfig) opt FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
instance Monad m => MonadBuild (SessionT m) where instance Monad m => MonadBuild (SessionT m) where
build = Session . lift . build build = Session . lift . build
@ -139,16 +148,17 @@ instance Monad m => MonadBuild (SessionT m) where
extend :: MonadIO m => SessionT m () extend :: MonadIO m => SessionT m ()
extend = do extend = do
session <- Session (asks rawSession) session <- Session (asks rawSession)
graph <- Session (asks rawGraph)
trace <- Session (asks tracer) trace <- Session (asks tracer)
nodesToExtend <- build flushNodeBuffer nodesToExtend <- build flushNodeBuffer
unless (null nodesToExtend) $ liftIO $ do unless (null nodesToExtend) $ liftIO $ do
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
trace ("Session.extend " <> Builder.string8 (showMessage graphDef)) trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
FFI.extendGraph session graphDef FFI.extendGraph graph graphDef
-- Now that all the nodes are created, run the initializers. -- Now that all the nodes are created, run the initializers.
initializers <- build flushInitializers initializers <- build flushInitializers
unless (null initializers) $ unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) void $ liftIO $ FFI.run session graph [] [] (toNodeNames initializers)
-- | Run a subgraph 't', rendering any dependent nodes that aren't already -- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'. -- rendered, and fetch the corresponding values for 'a'.
@ -173,8 +183,9 @@ runFetchWithFeeds feeds target (Fetch fetch restore) = do
let feeds' = fixFeeds feeds let feeds' = fixFeeds feeds
let fetchNames = encodeUtf8 <$> Set.toList fetch let fetchNames = encodeUtf8 <$> Set.toList fetch
targetNames = toNodeNames $ Set.toList target targetNames = toNodeNames $ Set.toList target
session <- Session (asks rawSession) state <- Session ask
runResult <- liftIO $ FFI.run session runResult <- liftIO $ FFI.run (rawSession state)
(rawGraph state)
feeds' feeds'
fetchNames fetchNames
targetNames targetNames
@ -214,5 +225,5 @@ asyncProdNodes nodes = do
extend extend
let targetNames = toNodeNames $ Set.toList target let targetNames = toNodeNames $ Set.toList target
state <- Session ask state <- Session ask
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames)) let loop = forever (void (FFI.run (rawSession state) (rawGraph state) [] [] targetNames))
liftIO (asyncCollector state loop) liftIO (asyncCollector state loop)