Migrate from TF_DeprecatedSession to TF_Session

Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and
pass an input mapping for every output of the nodes already in the graph,
so the imported GraphDef can refer to existing nodes.
Bart Schuurmans 2023-02-02 11:30:30 +01:00 committed by fkm3
parent 30a12d7776
commit fb629d1207
4 changed files with 244 additions and 102 deletions
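
The public, user-facing API is unchanged by this migration; code built on TensorFlow.Core keeps working, because Session.extend (see the Session changes in the last file below) now funnels new nodes through TF_GraphImportGraphDef instead of TF_ExtendGraph. A minimal sanity check of that kind (a hedged sketch, not part of the commit; it assumes the tensorflow-ops package for TF.vector and TF.add):

import qualified Data.Vector as V
import TensorFlow.Core (run, runSession)
import qualified TensorFlow.Ops as TF

main :: IO ()
main = do
    -- Each 'run' flushes the nodes built so far via Session.extend, which
    -- after this commit imports them with TF_GraphImportGraphDef.
    v <- runSession $ run (TF.vector [1, 2, 3 :: Float] `TF.add` TF.vector [4, 5, 6])
    print (v :: V.Vector Float)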

View File

@@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, SessionAction
, Raw.SessionOptions
, Raw.Graph
, extendGraph
, TensorData(..)
, setSessionConfig
, setSessionTarget
@@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Foldable (for_)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
@@ -87,15 +95,26 @@ data TensorData = TensorData
}
deriving (Show, Eq)
-- | The action can spawn concurrent tasks which will be canceled before
-- withSession returns.
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a
-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
withSession :: (MonadIO m, MonadMask m)
=> (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> m a)
-- ^ The action can spawn concurrent tasks which will
-- be canceled before withSession returns.
-> SessionAction m a
-> m a
withSession optionSetter action = do
withSession = withSession_ Raw.newSession
withSession_ :: (MonadIO m, MonadMask m)
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
-- ^ mkSession
-> (Raw.SessionOptions -> IO ())
-- ^ optionSetter
-> SessionAction m a
-> m a
withSession_ mkSession optionSetter action = do
drain <- liftIO $ newMVar []
let cleanup s =
-- Closes the session to nudge the pending run calls to fail and exit.
@@ -105,11 +124,12 @@ withSession optionSetter action = do
mapM_ shutDownRunner runners
checkStatus (Raw.deleteSession s)
let bracketIO x y = bracket (liftIO x) (liftIO . y)
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
bracketIO
(optionSetter options >> checkStatus (Raw.newSession options))
cleanup
(action (asyncCollector drain))
bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
bracketIO
(optionSetter options >> checkStatus (mkSession graph options))
cleanup
(\session -> action (asyncCollector drain) session graph)
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
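
For callers, the visible change in withSession is that the action now also receives the Raw.Graph, per the SessionAction alias above. A hedged sketch of a direct caller (hypothetical; the real caller lives in TensorFlow.Session):

import Data.ProtoLens (defMessage)
import TensorFlow.Internal.FFI (extendGraph, withSession)

-- Hypothetical direct use of the new withSession: the action's third
-- argument is the graph that the session was created against.
demo :: IO ()
demo =
    withSession
        (\_options -> return ())           -- leave the SessionOptions at their defaults
        (\_asyncCollector _session graph ->
            extendGraph graph defMessage)  -- import an empty GraphDef (a no-op)
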
@@ -122,43 +142,103 @@ shutDownRunner r = do
-- TODO(gnezdo): manage exceptions better than print.
either print (const (return ())) =<< waitCatch r
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb =
useProtoAsVoidPtrLen pb $ \ptr len ->
checkStatus $ Raw.extendGraph session ptr len
graphImportGraphDef :: Raw.Graph
-> GraphDef
-> (Raw.ImportGraphDefOptions -> IO ())
-> IO ()
graphImportGraphDef graph pb optionSetter =
useProtoAsBuffer pb $ \buffer ->
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
optionSetter importGraphDefOptions
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions
forGraphOperations_ :: Raw.Graph
-> (Raw.Operation -> IO b)
-> IO ()
forGraphOperations_ graph f = with 0 go
where
go indexPtr = do
op <- Raw.graphNextOperation graph indexPtr
case op of
Raw.Operation ptr | ptr == nullPtr -> return ()
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
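
forGraphOperations_ walks the graph via TF_GraphNextOperation until it returns a null operation. For example (a hypothetical helper written as if inside this module, not part of the commit), dumping every operation name:

-- Hypothetical debugging aid built on forGraphOperations_.
printOperationNames :: Raw.Graph -> IO ()
printOperationNames graph =
    forGraphOperations_ graph $ \op ->
        Raw.operationName op >>= C.putStrLn
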
extendGraph :: Raw.Graph -> GraphDef -> IO ()
extendGraph graph graphDef =
graphImportGraphDef graph graphDef $ \opts ->
-- All inputs of the nodes in the GraphDef should either refer to
-- other nodes in the GraphDef, or be mapped to nodes already in
-- the Graph by adding an input mapping.
-- We add an input mapping for all existing nodes in the Graph in
-- case they are referenced in the GraphDef.
forGraphOperations_ graph $ \op -> do
srcName <- Raw.operationName op
numOutputs <- Raw.operationNumOutputs op
for_ [0 .. numOutputs - 1] $ \srcIndex -> do  -- valid output indices are 0 .. numOutputs - 1
let dst = Raw.Output op (safeConvert srcIndex)
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex
run :: Raw.Session
-> [(B.ByteString, TensorData)] -- ^ Feeds.
-> [B.ByteString] -- ^ Fetches.
-> [B.ByteString] -- ^ Targets.
-> Raw.Graph
-> [(B.ByteString, TensorData)] -- ^ Inputs.
-> [B.ByteString] -- ^ Outputs.
-> [B.ByteString] -- ^ Target operations.
-> IO [TensorData]
run session feeds fetches targets = do
let nullTensor = Raw.Tensor nullPtr
run session graph inputNamesData outputNames targetNames = do
-- Use mask to avoid leaking input tensors before they are passed to 'run'
-- and output tensors before they are passed to 'createTensorData'.
mask_ $
-- Feeds
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
withArrayLen feedTensors $ \_ cFeedTensors ->
-- Fetches.
withStringArrayLen fetches $ \fetchesLen fetchNames ->
-- tensorOuts is an array of null Tensor pointers that will be filled
-- Inputs.
mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
withArrayLen inputs $ \nInputs cInputs ->
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
withArrayLen inputTensors $ \_ cInputTensors ->
-- Outputs.
mapM (resolveOutput graph) outputNames >>= \outputs ->
withArrayLen outputs $ \nOutputs cOutputs ->
-- outputTensors is an array of null Tensor pointers that will be filled
-- by the call to Raw.run.
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
-- Targets.
withStringArrayLen targets $ \targetsLen ctargets -> do
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
-- Target operations.
mapM (resolveOperation graph) targetNames >>= \targets ->
withArrayLen targets $ \nTargets cTargets -> do
checkStatus $ Raw.run
session
nullPtr
feedNames cFeedTensors (safeConvert feedsLen)
fetchNames tensorOuts (safeConvert fetchesLen)
ctargets (safeConvert targetsLen)
nullPtr
mapM_ Raw.deleteTensor feedTensors
outTensors <- peekArray fetchesLen tensorOuts
nullPtr -- RunOptions proto.
cInputs cInputTensors (safeConvert nInputs)
cOutputs cOutputTensors (safeConvert nOutputs)
cTargets (safeConvert nTargets)
nullPtr -- RunMetadata.
mapM_ Raw.deleteTensor inputTensors
outTensors <- peekArray nOutputs cOutputTensors
mapM createTensorData outTensors
where
nullTensor = Raw.Tensor nullPtr
resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output
resolveOutput graph name = do
let (opName, idx) = parseName name
op <- resolveOperation graph opName
pure $ Raw.Output op (safeConvert idx)
where
parseName :: B.ByteString -> (B.ByteString, Int)
parseName opName =
case break (== ':') (C.unpack opName) of
(opName_, ':':idxStr) | [(idx, "")] <- reads idxStr
-> (C.pack opName_, idx)
_ -> (opName, 0)
resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation
resolveOperation graph name = do
op <- Raw.graphOperationByName graph name
case op of
Raw.Operation ptr | ptr == nullPtr -> throwM exception
_ -> pure op
where
exception =
let msg = "Operation not found in graph: " <> (T.pack $ show name)
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
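
For reference, a hedged sketch of calling the new run directly, written as if inside this module; the operation names "x", "y" and "init" are made up and would have to exist in the graph already. Output names use the "operation:index" form that resolveOutput parses, and a bare name means output 0:

-- Hypothetical direct use of the new run signature.
runOnce :: Raw.Session -> Raw.Graph -> TensorData -> IO TensorData
runOnce session graph xData = do
    [yData] <- run session graph
        [(C.pack "x", xData)]  -- inputs: feed this tensor for "x:0"
        [C.pack "y:0"]         -- outputs to fetch, resolved by resolveOutput
        [C.pack "init"]        -- target operations, run only for their effects
    return yData
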
-- Internal.
@@ -174,21 +254,6 @@ safeConvert x =
show (fromIntegral x :: b)))
(toIntegralSized x)
-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
-- | Use a list of ByteString as an array of CString.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
-- | Create a Raw.Tensor from a TensorData.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
@@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
-- | Serializes the given msg and provides it as BufferPtr argument
-- to the given action.
useProtoAsBuffer :: (Message msg) =>
msg -> (Raw.BufferPtr -> IO a) -> IO a
useProtoAsBuffer msg f =
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
Raw.deleteBuffer
f
-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
getAllOpList :: IO B.ByteString
getAllOpList = do
foreignPtr <-
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
-- Makes a copy because it is more reliable than eviscerating
-- Buffer to steal its memory (including custom deallocator).
withForeignPtr foreignPtr $
\ptr -> B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData ptr)
<*> (safeConvert <$> Raw.getBufferLength ptr)
getAllOpList =
bracket checkCall Raw.deleteBuffer $ \buffer ->
-- Makes a copy because it is more reliable than eviscerating
-- Buffer to steal its memory (including custom deallocator).
B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData buffer)
<*> (safeConvert <$> Raw.getBufferLength buffer)
where
checkCall = do
p <- Raw.getAllOpList
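
The buffer still holds a serialized OpList proto; a hedged example of consuming it (it assumes proto-lens' decodeMessage, the OpList type from tensorflow-proto, and that getAllOpList stays exported):

import Data.ProtoLens (decodeMessage)
import Proto.Tensorflow.Core.Framework.OpDef (OpList)
import TensorFlow.Internal.FFI (getAllOpList)

-- Hypothetical: parse the serialized OpList returned by getAllOpList.
loadOpList :: IO (Either String OpList)
loadOpList = decodeMessage <$> getAllOpList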

View File

@@ -18,6 +18,7 @@ module TensorFlow.Internal.Raw where
#include "third_party/tensorflow/c/c_api.h"
import Data.ByteString (ByteString, packCString, useAsCString)
import Foreign
import Foreign.C
@@ -61,6 +62,35 @@ stringGetSize :: TString -> IO CULong
stringGetSize = {# call TF_StringGetSize as ^ #}
-- Operation.
{# pointer *TF_Operation as Operation newtype #}
{# fun TF_OperationName as operationName { `Operation' } -> `ByteString' packCString* #}
{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #}
instance Storable Operation where
sizeOf (Operation t) = sizeOf t
alignment (Operation t) = alignment t
peek p = fmap Operation (peek (castPtr p))
poke p (Operation t) = poke (castPtr p) t
-- Output.
data Output = Output
{ outputOperation :: Operation
, outputIndex :: CInt
}
{# pointer *TF_Output as OutputPtr -> Output #}
instance Storable Output where
sizeOf _ = {# sizeof TF_Output #}
alignment _ = {# alignof TF_Output #}
peek p = Output <$> {# get TF_Output->oper #} p
<*> (fromIntegral <$> {# get TF_Output->index #} p)
poke p (Output oper index) = do
{# set TF_Output->oper #} p oper
{# set TF_Output->index #} p $ fromIntegral index
-- Buffer.
data Buffer
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
@@ -71,6 +101,12 @@ getBufferData = {# get TF_Buffer->data #}
getBufferLength :: BufferPtr -> IO CULong
getBufferLength = {# get TF_Buffer->length #}
newBufferFromString :: Ptr () -> CULong -> IO BufferPtr
newBufferFromString = {# call TF_NewBufferFromString as ^ #}
deleteBuffer :: BufferPtr -> IO ()
deleteBuffer = {# call TF_DeleteBuffer as ^ #}
-- Tensor.
{# pointer *TF_Tensor as Tensor newtype #}
@@ -86,6 +122,8 @@ instance Storable Tensor where
-- `CLLong`).
type CInt64 = {#type int64_t #}
{# pointer *size_t as CSizePtr -> CSize #}
newTensor :: DataType
-> Ptr CInt64 -- dimensions array
-> CInt -- num dimensions
@@ -114,6 +152,31 @@ tensorByteSize = {# call TF_TensorByteSize as ^ #}
tensorData :: Tensor -> IO (Ptr ())
tensorData = {# call TF_TensorData as ^ #}
-- ImportGraphDefOptions.
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #}
{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #}
{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping
{ `ImportGraphDefOptions'
, useAsCString* `ByteString'
, `Int'
, %`OutputPtr'
} -> `()'
#}
-- Graph.
{# pointer *TF_Graph as Graph newtype #}
{# fun TF_NewGraph as newGraph { } -> `Graph' #}
{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #}
{# fun TF_GraphOperationByName as graphOperationByName
{ `Graph'
, useAsCString* `ByteString'
} -> `Operation'
#}
{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #}
{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #}
-- Session Options.
{# pointer *TF_SessionOptions as SessionOptions newtype #}
@@ -132,29 +195,27 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
-- Session.
{# pointer *TF_DeprecatedSession as Session newtype #}
{# pointer *TF_Session as Session newtype #}
newSession :: Graph -> SessionOptions -> Status -> IO Session
newSession = {# call TF_NewSession as ^ #}
newSession :: SessionOptions -> Status -> IO Session
newSession = {# call TF_NewDeprecatedSession as ^ #}
closeSession :: Session -> Status -> IO ()
closeSession = {# call TF_CloseDeprecatedSession as ^ #}
closeSession = {# call TF_CloseSession as ^ #}
deleteSession :: Session -> Status -> IO ()
deleteSession = {# call TF_DeleteDeprecatedSession as ^ #}
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
extendGraph = {# call TF_ExtendGraph as ^ #}
deleteSession = {# call TF_DeleteSession as ^ #}
run :: Session
-> BufferPtr -- RunOptions proto.
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
-> Ptr CString -> CInt -- Target nodes (names, count).
-> BufferPtr -- RunMetadata proto.
-> BufferPtr -- RunOptions proto.
-> OutputPtr -> Ptr Tensor -> CInt -- Input (names, tensors, count).
-> OutputPtr -> Ptr Tensor -> CInt -- Output (names, tensors, count).
-> Ptr Operation -> CInt -- Target operations (ops, count).
-> BufferPtr -- RunMetadata proto.
-> Status
-> IO ()
run = {# call TF_Run as ^ #}
run = {# call TF_SessionRun as ^ #}
-- FFI helpers.
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
@@ -170,6 +231,3 @@ foreign import ccall "wrapper"
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}
foreign import ccall "&TF_DeleteBuffer"
deleteBuffer :: FunPtr (BufferPtr -> IO ())

View File

@@ -121,7 +121,7 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x})
-- code into a Build function
instance IsString Output where
fromString s = case break (==':') s of
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
-> Output (fromInteger ix) $ assigned n
(n, ':':ixStr) | ix <- read ixStr
-> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s
where assigned = NodeName . Text.pack

View File

@@ -55,7 +55,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Output (NodeName(..), unNodeName)
import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder
@@ -67,13 +67,13 @@ import qualified TensorFlow.Internal.FFI as FFI
type Tracer = Builder.Builder -> IO ()
-- Common state threaded through the session.
data SessionState
= SessionState {
rawSession :: FFI.Session
, asyncCollector :: IO () -> IO ()
-- ^ Starts the given action concurrently.
, tracer :: Tracer
}
data SessionState = SessionState
{ rawSession :: FFI.Session
, rawGraph :: FFI.Graph
, asyncCollector :: IO () -> IO ()
-- ^ Starts the given action concurrently.
, tracer :: Tracer
}
newtype SessionT m a
= Session (ReaderT SessionState (BuildT m) a)
@@ -120,14 +120,23 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x })
-- | Run 'Session' actions in a new TensorFlow session created with
-- the given option setter actions ('sessionTarget', 'sessionConfig').
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) =
FFI.withSession applyOptions $
\as rs ->
let initState = SessionState rs as (options ^. sessionTracer)
runSessionWithOptions options session =
_runSessionWithOptions session options $ FFI.withSession
_runSessionWithOptions :: (MonadMask m, MonadIO m)
=> SessionT m a
-> Options
-> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a)
-> m a
_runSessionWithOptions (Session m) options withSession =
withSession applyOptions $
\ac rSession rGraph ->
let initState = SessionState rSession rGraph ac (options ^. sessionTracer)
in evalBuildT (runReaderT m initState)
where applyOptions opt = do
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
where
applyOptions opt = do
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
instance Monad m => MonadBuild (SessionT m) where
build = Session . lift . build
@@ -139,16 +148,17 @@ instance Monad m => MonadBuild (SessionT m) where
extend :: MonadIO m => SessionT m ()
extend = do
session <- Session (asks rawSession)
graph <- Session (asks rawGraph)
trace <- Session (asks tracer)
nodesToExtend <- build flushNodeBuffer
unless (null nodesToExtend) $ liftIO $ do
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
FFI.extendGraph session graphDef
FFI.extendGraph graph graphDef
-- Now that all the nodes are created, run the initializers.
initializers <- build flushInitializers
unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
void $ liftIO $ FFI.run session graph [] [] (toNodeNames initializers)
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'.
@@ -173,8 +183,9 @@ runFetchWithFeeds feeds target (Fetch fetch restore) = do
let feeds' = fixFeeds feeds
let fetchNames = encodeUtf8 <$> Set.toList fetch
targetNames = toNodeNames $ Set.toList target
session <- Session (asks rawSession)
runResult <- liftIO $ FFI.run session
state <- Session ask
runResult <- liftIO $ FFI.run (rawSession state)
(rawGraph state)
feeds'
fetchNames
targetNames
@@ -214,5 +225,5 @@ asyncProdNodes nodes = do
extend
let targetNames = toNodeNames $ Set.toList target
state <- Session ask
let loop = forever (void (FFI.run (rawSession state) [] [] targetNames))
let loop = forever (void (FFI.run (rawSession state) (rawGraph state) [] [] targetNames))
liftIO (asyncCollector state loop)