diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 6a36c2a..87003c3 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -30,6 +30,9 @@ module TensorFlow.Core , sessionTracer , runSession , runSessionWithOptions + , SavedModelTag(..) + , runSavedModel + , runSavedModelWithOptions -- ** Building graphs , MonadBuild(..) -- ** Running graphs diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs index 3cc9406..bcf5b43 100644 --- a/tensorflow/src/TensorFlow/Internal/FFI.hs +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI ( TensorFlowException(..) , Raw.Session , withSession + , withSessionFromSavedModel , run , SessionAction @@ -107,6 +108,28 @@ withSession :: (MonadIO m, MonadMask m) -> m a withSession = withSession_ Raw.newSession +withSessionFromSavedModel :: (MonadIO m, MonadMask m) + => B.ByteString + -- ^ exportDir + -> [B.ByteString] + -- ^ Tags. + -> (Raw.SessionOptions -> IO ()) + -- ^ optionSetter + -> SessionAction m a + -> m a +withSessionFromSavedModel exportDir tags = + withSession_ $ \graph options status -> + Raw.loadSessionFromSavedModel options + runOptions + exportDir + tags + graph + metaGraphDef + status + where + runOptions = nullPtr + metaGraphDef = nullPtr + withSession_ :: (MonadIO m, MonadMask m) => (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session) -- ^ mkSession diff --git a/tensorflow/src/TensorFlow/Internal/Raw.chs b/tensorflow/src/TensorFlow/Internal/Raw.chs index 28fabf2..b9382dc 100644 --- a/tensorflow/src/TensorFlow/Internal/Raw.chs +++ b/tensorflow/src/TensorFlow/Internal/Raw.chs @@ -200,6 +200,16 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #} newSession :: Graph -> SessionOptions -> Status -> IO Session newSession = {# call TF_NewSession as ^ #} +{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel + { `SessionOptions' + , `BufferPtr' -- RunOptions proto. + , useAsCString* `ByteString' -- Export directory. + , withStringArrayLen* `[ByteString]'& -- Tags. + , `Graph' + , `BufferPtr' -- MetaGraphDef. + , `Status' + } -> `Session' +#} closeSession :: Session -> Status -> IO () closeSession = {# call TF_CloseSession as ^ #} @@ -231,3 +241,18 @@ foreign import ccall "wrapper" -- in this address space. getAllOpList :: IO BufferPtr getAllOpList = {# call TF_GetAllOpList as ^ #} + +-- | Use a list of ByteString as a list of CString. +withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a +withStringList strings fn = go strings [] + where + go [] cs = fn (reverse cs) + -- TODO(fmayle): Is it worth using unsafeAsCString here? + go (x:xs) cs = useAsCString x $ \c -> go xs (c:cs) + + +-- | Use a list of ByteString as an array of CString with its length. +withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a +withStringArrayLen xs fn = + withStringList xs $ \strings -> + withArrayLen strings $ \len ptr -> fn (ptr, fromIntegral len) diff --git a/tensorflow/src/TensorFlow/Session.hs b/tensorflow/src/TensorFlow/Session.hs index 62d74de..26ffe6c 100644 --- a/tensorflow/src/TensorFlow/Session.hs +++ b/tensorflow/src/TensorFlow/Session.hs @@ -27,6 +27,8 @@ module TensorFlow.Session ( sessionTracer, runSession, runSessionWithOptions, + runSavedModel, + runSavedModelWithOptions, MonadBuild(..), extend, addGraphDef, @@ -35,6 +37,7 @@ module TensorFlow.Session ( run_, runWithFeeds_, asyncProdNodes, + SavedModelTag(..), ) where import Data.ProtoLens.Message(defMessage) @@ -58,6 +61,7 @@ import TensorFlow.Nodes import TensorFlow.Output (NodeName(..), unNodeName) import TensorFlow.Tensor +import qualified Data.ByteString.Char8 as C import qualified Data.ByteString.Builder as Builder import qualified Data.Map.Strict as Map import qualified Data.Set as Set @@ -97,6 +101,14 @@ data Options = Options , _sessionTracer :: Tracer } +data SavedModelTag = GPU | TPU | Serve | Train + +savedModelTagValue :: SavedModelTag -> ByteString +savedModelTagValue GPU = "gpu" +savedModelTagValue TPU = "tpu" +savedModelTagValue Serve = "serve" +savedModelTagValue Train = "train" + instance Default Options where def = Options { _sessionTarget = "" @@ -123,6 +135,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> runSessionWithOptions options session = _runSessionWithOptions session options $ FFI.withSession +runSavedModel :: (MonadMask m, MonadIO m) + => FilePath + -- ^ Export directory. + -> Set SavedModelTag + -> SessionT m a + -> m a +runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def + +runSavedModelWithOptions :: (MonadMask m, MonadIO m) + => FilePath + -- ^ Export directory. + -> Set SavedModelTag + -> Options + -> SessionT m a + -> m a +runSavedModelWithOptions exportDir tags options session = + _runSessionWithOptions session options $ + FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags) + _runSessionWithOptions :: (MonadMask m, MonadIO m) => SessionT m a -> Options