1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Add support for loading a Session from a SavedModel

This commit is contained in:
Bart Schuurmans 2023-01-13 17:05:16 +01:00 committed by fkm3
parent fb629d1207
commit dad19fde31
4 changed files with 82 additions and 0 deletions

View file

@ -30,6 +30,9 @@ module TensorFlow.Core
, sessionTracer , sessionTracer
, runSession , runSession
, runSessionWithOptions , runSessionWithOptions
, SavedModelTag(..)
, runSavedModel
, runSavedModelWithOptions
-- ** Building graphs -- ** Building graphs
, MonadBuild(..) , MonadBuild(..)
-- ** Running graphs -- ** Running graphs

View file

@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI
( TensorFlowException(..) ( TensorFlowException(..)
, Raw.Session , Raw.Session
, withSession , withSession
, withSessionFromSavedModel
, run , run
, SessionAction , SessionAction
@ -107,6 +108,28 @@ withSession :: (MonadIO m, MonadMask m)
-> m a -> m a
withSession = withSession_ Raw.newSession withSession = withSession_ Raw.newSession
withSessionFromSavedModel :: (MonadIO m, MonadMask m)
=> B.ByteString
-- ^ exportDir
-> [B.ByteString]
-- ^ Tags.
-> (Raw.SessionOptions -> IO ())
-- ^ optionSetter
-> SessionAction m a
-> m a
withSessionFromSavedModel exportDir tags =
withSession_ $ \graph options status ->
Raw.loadSessionFromSavedModel options
runOptions
exportDir
tags
graph
metaGraphDef
status
where
runOptions = nullPtr
metaGraphDef = nullPtr
withSession_ :: (MonadIO m, MonadMask m) withSession_ :: (MonadIO m, MonadMask m)
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session) => (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
-- ^ mkSession -- ^ mkSession

View file

@ -200,6 +200,16 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
newSession :: Graph -> SessionOptions -> Status -> IO Session newSession :: Graph -> SessionOptions -> Status -> IO Session
newSession = {# call TF_NewSession as ^ #} newSession = {# call TF_NewSession as ^ #}
{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel
{ `SessionOptions'
, `BufferPtr' -- RunOptions proto.
, useAsCString* `ByteString' -- Export directory.
, withStringArrayLen* `[ByteString]'& -- Tags.
, `Graph'
, `BufferPtr' -- MetaGraphDef.
, `Status'
} -> `Session'
#}
closeSession :: Session -> Status -> IO () closeSession :: Session -> Status -> IO ()
closeSession = {# call TF_CloseSession as ^ #} closeSession = {# call TF_CloseSession as ^ #}
@ -231,3 +241,18 @@ foreign import ccall "wrapper"
-- in this address space. -- in this address space.
getAllOpList :: IO BufferPtr getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #} getAllOpList = {# call TF_GetAllOpList as ^ #}
-- | Use a list of ByteString as a list of CString.
withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = useAsCString x $ \c -> go xs (c:cs)
-- | Use a list of ByteString as an array of CString with its length.
withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a
withStringArrayLen xs fn =
withStringList xs $ \strings ->
withArrayLen strings $ \len ptr -> fn (ptr, fromIntegral len)

View file

@ -27,6 +27,8 @@ module TensorFlow.Session (
sessionTracer, sessionTracer,
runSession, runSession,
runSessionWithOptions, runSessionWithOptions,
runSavedModel,
runSavedModelWithOptions,
MonadBuild(..), MonadBuild(..),
extend, extend,
addGraphDef, addGraphDef,
@ -35,6 +37,7 @@ module TensorFlow.Session (
run_, run_,
runWithFeeds_, runWithFeeds_,
asyncProdNodes, asyncProdNodes,
SavedModelTag(..),
) where ) where
import Data.ProtoLens.Message(defMessage) import Data.ProtoLens.Message(defMessage)
@ -58,6 +61,7 @@ import TensorFlow.Nodes
import TensorFlow.Output (NodeName(..), unNodeName) import TensorFlow.Output (NodeName(..), unNodeName)
import TensorFlow.Tensor import TensorFlow.Tensor
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Builder as Builder import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import qualified Data.Set as Set import qualified Data.Set as Set
@ -97,6 +101,14 @@ data Options = Options
, _sessionTracer :: Tracer , _sessionTracer :: Tracer
} }
data SavedModelTag = GPU | TPU | Serve | Train
savedModelTagValue :: SavedModelTag -> ByteString
savedModelTagValue GPU = "gpu"
savedModelTagValue TPU = "tpu"
savedModelTagValue Serve = "serve"
savedModelTagValue Train = "train"
instance Default Options where instance Default Options where
def = Options def = Options
{ _sessionTarget = "" { _sessionTarget = ""
@ -123,6 +135,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
runSessionWithOptions options session = runSessionWithOptions options session =
_runSessionWithOptions session options $ FFI.withSession _runSessionWithOptions session options $ FFI.withSession
runSavedModel :: (MonadMask m, MonadIO m)
=> FilePath
-- ^ Export directory.
-> Set SavedModelTag
-> SessionT m a
-> m a
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
=> FilePath
-- ^ Export directory.
-> Set SavedModelTag
-> Options
-> SessionT m a
-> m a
runSavedModelWithOptions exportDir tags options session =
_runSessionWithOptions session options $
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)
_runSessionWithOptions :: (MonadMask m, MonadIO m) _runSessionWithOptions :: (MonadMask m, MonadIO m)
=> SessionT m a => SessionT m a
-> Options -> Options