diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs index 3a840e8..3cc9406 100644 --- a/tensorflow/src/TensorFlow/Internal/FFI.hs +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI ( TensorFlowException(..) , Raw.Session , withSession - , extendGraph , run + + , SessionAction + + , Raw.SessionOptions + + , Raw.Graph + , extendGraph + , TensorData(..) , setSessionConfig , setSessionTarget @@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask import Control.Monad.IO.Class (MonadIO, liftIO) import Data.Bits (Bits, toIntegralSized) import Data.Int (Int64) +import Data.Foldable (for_) import Data.Maybe (fromMaybe) import Data.Typeable (Typeable) import Data.Word (Word8) -import Foreign (Ptr, FunPtr, nullPtr, castPtr) -import Foreign.C.String (CString) -import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr) +import Foreign (Ptr, FunPtr, nullPtr, castPtr, with) +import Foreign.ForeignPtr (newForeignPtr_) import Foreign.Marshal.Alloc (free) import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray) import System.IO.Unsafe (unsafePerformIO) import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as C import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding.Error as T @@ -87,15 +95,26 @@ data TensorData = TensorData } deriving (Show, Eq) +-- | The action can spawn concurrent tasks which will be canceled before +-- withSession returns. +type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a + -- | Runs the given action after creating a session with options -- populated by the given optionSetter. withSession :: (MonadIO m, MonadMask m) => (Raw.SessionOptions -> IO ()) - -> ((IO () -> IO ()) -> Raw.Session -> m a) - -- ^ The action can spawn concurrent tasks which will - -- be canceled before withSession returns. + -> SessionAction m a -> m a -withSession optionSetter action = do +withSession = withSession_ Raw.newSession + +withSession_ :: (MonadIO m, MonadMask m) + => (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session) + -- ^ mkSession + -> (Raw.SessionOptions -> IO ()) + -- ^ optionSetter + -> SessionAction m a + -> m a +withSession_ mkSession optionSetter action = do drain <- liftIO $ newMVar [] let cleanup s = -- Closes the session to nudge the pending run calls to fail and exit. @@ -105,11 +124,12 @@ withSession optionSetter action = do mapM_ shutDownRunner runners checkStatus (Raw.deleteSession s) let bracketIO x y = bracket (liftIO x) (liftIO . y) - bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do - bracketIO - (optionSetter options >> checkStatus (Raw.newSession options)) - cleanup - (action (asyncCollector drain)) + bracketIO Raw.newGraph Raw.deleteGraph $ \graph -> + bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do + bracketIO + (optionSetter options >> checkStatus (mkSession graph options)) + cleanup + (\session -> action (asyncCollector drain) session graph) asyncCollector :: MVar [Async ()] -> IO () -> IO () asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord @@ -122,43 +142,103 @@ shutDownRunner r = do -- TODO(gnezdo): manage exceptions better than print. either print (const (return ())) =<< waitCatch r -extendGraph :: Raw.Session -> GraphDef -> IO () -extendGraph session pb = - useProtoAsVoidPtrLen pb $ \ptr len -> - checkStatus $ Raw.extendGraph session ptr len +graphImportGraphDef :: Raw.Graph + -> GraphDef + -> (Raw.ImportGraphDefOptions -> IO ()) + -> IO () +graphImportGraphDef graph pb optionSetter = + useProtoAsBuffer pb $ \buffer -> + bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do + optionSetter importGraphDefOptions + checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions +forGraphOperations_ :: Raw.Graph + -> (Raw.Operation -> IO b) + -> IO () +forGraphOperations_ graph f = with 0 go + where + go indexPtr = do + op <- Raw.graphNextOperation graph indexPtr + case op of + Raw.Operation ptr | ptr == nullPtr -> return () + _ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation. + +extendGraph :: Raw.Graph -> GraphDef -> IO () +extendGraph graph graphDef = + graphImportGraphDef graph graphDef $ \opts -> + -- All inputs of the nodes in the GraphDef should either refer to + -- other nodes in the GraphDef, or be mapped to nodes already in + -- the Graph by adding an input mapping. + -- We add an input mapping for all existing nodes in the Graph in + -- case they are referenced in the GraphDef. + forGraphOperations_ graph $ \op -> do + srcName <- Raw.operationName op + numOutputs <- Raw.operationNumOutputs op + for_ [0..numOutputs] $ \srcIndex -> do + let dst = Raw.Output op (safeConvert srcIndex) + with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex run :: Raw.Session - -> [(B.ByteString, TensorData)] -- ^ Feeds. - -> [B.ByteString] -- ^ Fetches. - -> [B.ByteString] -- ^ Targets. + -> Raw.Graph + -> [(B.ByteString, TensorData)] -- ^ Inputs. + -> [B.ByteString] -- ^ Outputs. + -> [B.ByteString] -- ^ Target operations. -> IO [TensorData] -run session feeds fetches targets = do - let nullTensor = Raw.Tensor nullPtr +run session graph inputNamesData outputNames targetNames = do -- Use mask to avoid leaking input tensors before they are passed to 'run' -- and output tensors before they are passed to 'createTensorData'. mask_ $ - -- Feeds - withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames -> - mapM (createRawTensor . snd) feeds >>= \feedTensors -> - withArrayLen feedTensors $ \_ cFeedTensors -> - -- Fetches. - withStringArrayLen fetches $ \fetchesLen fetchNames -> - -- tensorOuts is an array of null Tensor pointers that will be filled + -- Inputs. + mapM (resolveOutput graph . fst) inputNamesData >>= \inputs -> + withArrayLen inputs $ \nInputs cInputs -> + mapM (createRawTensor . snd) inputNamesData >>= \inputTensors -> + withArrayLen inputTensors $ \_ cInputTensors -> + -- Outputs. + mapM (resolveOutput graph) outputNames >>= \outputs -> + withArrayLen outputs $ \nOutputs cOutputs -> + -- outputTensors is an array of null Tensor pointers that will be filled -- by the call to Raw.run. - withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts -> - -- Targets. - withStringArrayLen targets $ \targetsLen ctargets -> do + withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors -> + -- Target operations. + mapM (resolveOperation graph) targetNames >>= \targets -> + withArrayLen targets $ \nTargets cTargets -> do checkStatus $ Raw.run session - nullPtr - feedNames cFeedTensors (safeConvert feedsLen) - fetchNames tensorOuts (safeConvert fetchesLen) - ctargets (safeConvert targetsLen) - nullPtr - mapM_ Raw.deleteTensor feedTensors - outTensors <- peekArray fetchesLen tensorOuts + nullPtr -- RunOptions proto. + cInputs cInputTensors (safeConvert nInputs) + cOutputs cOutputTensors (safeConvert nOutputs) + cTargets (safeConvert nTargets) + nullPtr -- RunMetadata. + mapM_ Raw.deleteTensor inputTensors + outTensors <- peekArray nOutputs cOutputTensors mapM createTensorData outTensors + where + + nullTensor = Raw.Tensor nullPtr + +resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output +resolveOutput graph name = do + let (opName, idx) = parseName name + op <- resolveOperation graph opName + pure $ Raw.Output op (safeConvert idx) + where + parseName :: B.ByteString -> (B.ByteString, Int) + parseName opName = + case break (== ':') (C.unpack opName) of + (opName_, ':':idxStr) | idx <- read idxStr + -> (C.pack opName_, idx) + _ -> (opName, 0) + +resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation +resolveOperation graph name = do + op <- Raw.graphOperationByName graph name + case op of + Raw.Operation ptr | ptr == nullPtr -> throwM exception + _ -> pure op + where + exception = + let msg = "Operation not found in graph: " <> (T.pack $ show name) + in TensorFlowException Raw.TF_INVALID_ARGUMENT msg -- Internal. @@ -174,21 +254,6 @@ safeConvert x = show (fromIntegral x :: b))) (toIntegralSized x) - --- | Use a list of ByteString as a list of CString. -withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a -withStringList strings fn = go strings [] - where - go [] cs = fn (reverse cs) - -- TODO(fmayle): Is it worth using unsafeAsCString here? - go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs) - - --- | Use a list of ByteString as an array of CString. -withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a -withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn) - - -- | Create a Raw.Tensor from a TensorData. createRawTensor :: TensorData -> IO Raw.Tensor createRawTensor (TensorData dims dt byteVec) = @@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) => useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) -> f (castPtr bytes) (safeConvert len) +-- | Serializes the given msg and provides it as BufferPtr argument +-- to the given action. +useProtoAsBuffer :: (Message msg) => + msg -> (Raw.BufferPtr -> IO a) -> IO a +useProtoAsBuffer msg f = + B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) -> + bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len)) + Raw.deleteBuffer + f + -- | Returns the serialized OpList of all OpDefs defined in this -- address space. getAllOpList :: IO B.ByteString -getAllOpList = do - foreignPtr <- - mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall) - -- Makes a copy because it is more reliable than eviscerating - -- Buffer to steal its memory (including custom deallocator). - withForeignPtr foreignPtr $ - \ptr -> B.packCStringLen =<< (,) - <$> (castPtr <$> Raw.getBufferData ptr) - <*> (safeConvert <$> Raw.getBufferLength ptr) +getAllOpList = + bracket checkCall Raw.deleteBuffer $ \buffer -> + -- Makes a copy because it is more reliable than eviscerating + -- Buffer to steal its memory (including custom deallocator). + B.packCStringLen =<< (,) + <$> (castPtr <$> Raw.getBufferData buffer) + <*> (safeConvert <$> Raw.getBufferLength buffer) where checkCall = do p <- Raw.getAllOpList diff --git a/tensorflow/src/TensorFlow/Internal/Raw.chs b/tensorflow/src/TensorFlow/Internal/Raw.chs index f1fa7a0..28fabf2 100644 --- a/tensorflow/src/TensorFlow/Internal/Raw.chs +++ b/tensorflow/src/TensorFlow/Internal/Raw.chs @@ -18,6 +18,7 @@ module TensorFlow.Internal.Raw where #include "third_party/tensorflow/c/c_api.h" +import Data.ByteString (ByteString, packCString, useAsCString) import Foreign import Foreign.C @@ -61,6 +62,35 @@ stringGetSize :: TString -> IO CULong stringGetSize = {# call TF_StringGetSize as ^ #} +-- Operation. +{# pointer *TF_Operation as Operation newtype #} +{# fun TF_OperationName as operationName { `Operation' } -> `ByteString' packCString* #} +{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #} + +instance Storable Operation where + sizeOf (Operation t) = sizeOf t + alignment (Operation t) = alignment t + peek p = fmap Operation (peek (castPtr p)) + poke p (Operation t) = poke (castPtr p) t + + +-- Output. +data Output = Output + { outputOperation :: Operation + , outputIndex :: CInt + } +{# pointer *TF_Output as OutputPtr -> Output #} + +instance Storable Output where + sizeOf _ = {# sizeof TF_Output #} + alignment _ = {# alignof TF_Output #} + peek p = Output <$> {# get TF_Output->oper #} p + <*> (fromIntegral <$> {# get TF_Output->index #} p) + poke p (Output oper index) = do + {# set TF_Output->oper #} p oper + {# set TF_Output->index #} p $ fromIntegral index + + -- Buffer. data Buffer {# pointer *TF_Buffer as BufferPtr -> Buffer #} @@ -71,6 +101,12 @@ getBufferData = {# get TF_Buffer->data #} getBufferLength :: BufferPtr -> IO CULong getBufferLength = {# get TF_Buffer->length #} +newBufferFromString :: Ptr () -> CULong -> IO BufferPtr +newBufferFromString = {# call TF_NewBufferFromString as ^ #} + +deleteBuffer :: BufferPtr -> IO () +deleteBuffer = {# call TF_DeleteBuffer as ^ #} + -- Tensor. {# pointer *TF_Tensor as Tensor newtype #} @@ -86,6 +122,8 @@ instance Storable Tensor where -- `CLLong`). type CInt64 = {#type int64_t #} +{# pointer *size_t as CSizePtr -> CSize #} + newTensor :: DataType -> Ptr CInt64 -- dimensions array -> CInt -- num dimensions @@ -114,6 +152,31 @@ tensorByteSize = {# call TF_TensorByteSize as ^ #} tensorData :: Tensor -> IO (Ptr ()) tensorData = {# call TF_TensorData as ^ #} +-- ImportGraphDefOptions. +{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #} +{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #} +{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #} +{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping + { `ImportGraphDefOptions' + , useAsCString* `ByteString' + , `Int' + , %`OutputPtr' + } -> `()' +#} + + +-- Graph. +{# pointer *TF_Graph as Graph newtype #} +{# fun TF_NewGraph as newGraph { } -> `Graph' #} +{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #} +{# fun TF_GraphOperationByName as graphOperationByName + { `Graph' + , useAsCString* `ByteString' + } -> `Operation' +#} +{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #} +{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #} + -- Session Options. {# pointer *TF_SessionOptions as SessionOptions newtype #} @@ -132,29 +195,27 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #} -- Session. -{# pointer *TF_DeprecatedSession as Session newtype #} +{# pointer *TF_Session as Session newtype #} + +newSession :: Graph -> SessionOptions -> Status -> IO Session +newSession = {# call TF_NewSession as ^ #} -newSession :: SessionOptions -> Status -> IO Session -newSession = {# call TF_NewDeprecatedSession as ^ #} closeSession :: Session -> Status -> IO () -closeSession = {# call TF_CloseDeprecatedSession as ^ #} +closeSession = {# call TF_CloseSession as ^ #} deleteSession :: Session -> Status -> IO () -deleteSession = {# call TF_DeleteDeprecatedSession as ^ #} - -extendGraph :: Session -> Ptr () -> CULong -> Status -> IO () -extendGraph = {# call TF_ExtendGraph as ^ #} +deleteSession = {# call TF_DeleteSession as ^ #} run :: Session - -> BufferPtr -- RunOptions proto. - -> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count). - -> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count). - -> Ptr CString -> CInt -- Target nodes (names, count). - -> BufferPtr -- RunMetadata proto. + -> BufferPtr -- RunOptions proto. + -> OutputPtr -> Ptr Tensor -> CInt -- Input (names, tensors, count). + -> OutputPtr -> Ptr Tensor -> CInt -- Output (names, tensors, count). + -> Ptr Operation -> CInt -- Target operations (ops, count). + -> BufferPtr -- RunMetadata proto. -> Status -> IO () -run = {# call TF_Run as ^ #} +run = {# call TF_SessionRun as ^ #} -- FFI helpers. type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO () @@ -170,6 +231,3 @@ foreign import ccall "wrapper" -- in this address space. getAllOpList :: IO BufferPtr getAllOpList = {# call TF_GetAllOpList as ^ #} - -foreign import ccall "&TF_DeleteBuffer" - deleteBuffer :: FunPtr (BufferPtr -> IO ()) diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 53afefd..4b55fb2 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -121,7 +121,7 @@ opControlInputs = lens _opControlInputs (\o x -> o {_opControlInputs = x}) -- code into a Build function instance IsString Output where fromString s = case break (==':') s of - (n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr - -> Output (fromInteger ix) $ assigned n + (n, ':':ixStr) | ix <- read ixStr + -> Output (fromInteger ix) $ assigned n _ -> Output 0 $ assigned s where assigned = NodeName . Text.pack diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index f09998a..62d74de 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -55,7 +55,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node) import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import TensorFlow.Build import TensorFlow.Nodes -import TensorFlow.Output (NodeName, unNodeName) +import TensorFlow.Output (NodeName(..), unNodeName) import TensorFlow.Tensor import qualified Data.ByteString.Builder as Builder @@ -67,13 +67,13 @@ import qualified TensorFlow.Internal.FFI as FFI type Tracer = Builder.Builder -> IO () -- Common state threaded through the session. -data SessionState - = SessionState { - rawSession :: FFI.Session - , asyncCollector :: IO () -> IO () - -- ^ Starts the given action concurrently. - , tracer :: Tracer - } +data SessionState = SessionState + { rawSession :: FFI.Session + , rawGraph :: FFI.Graph + , asyncCollector :: IO () -> IO () + -- ^ Starts the given action concurrently. + , tracer :: Tracer + } newtype SessionT m a = Session (ReaderT SessionState (BuildT m) a) @@ -120,14 +120,23 @@ sessionTracer = lens _sessionTracer (\g x -> g { _sessionTracer = x }) -- | Run 'Session' actions in a new TensorFlow session created with -- the given option setter actions ('sessionTarget', 'sessionConfig'). runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a -runSessionWithOptions options (Session m) = - FFI.withSession applyOptions $ - \as rs -> - let initState = SessionState rs as (options ^. sessionTracer) +runSessionWithOptions options session = + _runSessionWithOptions session options $ FFI.withSession + +_runSessionWithOptions :: (MonadMask m, MonadIO m) + => SessionT m a + -> Options + -> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a) + -> m a +_runSessionWithOptions (Session m) options withSession = + withSession applyOptions $ + \ac rSession rGraph -> + let initState = SessionState rSession rGraph ac (options ^. sessionTracer) in evalBuildT (runReaderT m initState) - where applyOptions opt = do - FFI.setSessionTarget (options ^. sessionTarget) opt - FFI.setSessionConfig (options ^. sessionConfig) opt + where + applyOptions opt = do + FFI.setSessionTarget (options ^. sessionTarget) opt + FFI.setSessionConfig (options ^. sessionConfig) opt instance Monad m => MonadBuild (SessionT m) where build = Session . lift . build @@ -139,16 +148,17 @@ instance Monad m => MonadBuild (SessionT m) where extend :: MonadIO m => SessionT m () extend = do session <- Session (asks rawSession) + graph <- Session (asks rawGraph) trace <- Session (asks tracer) nodesToExtend <- build flushNodeBuffer unless (null nodesToExtend) $ liftIO $ do let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend trace ("Session.extend " <> Builder.string8 (showMessage graphDef)) - FFI.extendGraph session graphDef + FFI.extendGraph graph graphDef -- Now that all the nodes are created, run the initializers. initializers <- build flushInitializers unless (null initializers) $ - void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) + void $ liftIO $ FFI.run session graph [] [] (toNodeNames initializers) -- | Run a subgraph 't', rendering any dependent nodes that aren't already -- rendered, and fetch the corresponding values for 'a'. @@ -173,8 +183,9 @@ runFetchWithFeeds feeds target (Fetch fetch restore) = do let feeds' = fixFeeds feeds let fetchNames = encodeUtf8 <$> Set.toList fetch targetNames = toNodeNames $ Set.toList target - session <- Session (asks rawSession) - runResult <- liftIO $ FFI.run session + state <- Session ask + runResult <- liftIO $ FFI.run (rawSession state) + (rawGraph state) feeds' fetchNames targetNames @@ -214,5 +225,5 @@ asyncProdNodes nodes = do extend let targetNames = toNodeNames $ Set.toList target state <- Session ask - let loop = forever (void (FFI.run (rawSession state) [] [] targetNames)) + let loop = forever (void (FFI.run (rawSession state) (rawGraph state) [] [] targetNames)) liftIO (asyncCollector state loop)