1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-27 03:05:01 +01:00

Add support for loading a Session from a SavedModel

This commit is contained in:
Bart Schuurmans 2023-01-13 17:05:16 +01:00 committed by fkm3
parent fb629d1207
commit dad19fde31
4 changed files with 82 additions and 0 deletions

View file

@ -30,6 +30,9 @@ module TensorFlow.Core
, sessionTracer
, runSession
, runSessionWithOptions
, SavedModelTag(..)
, runSavedModel
, runSavedModelWithOptions
-- ** Building graphs
, MonadBuild(..)
-- ** Running graphs

View file

@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, withSessionFromSavedModel
, run
, SessionAction
@ -107,6 +108,28 @@ withSession :: (MonadIO m, MonadMask m)
-> m a
withSession = withSession_ Raw.newSession
withSessionFromSavedModel :: (MonadIO m, MonadMask m)
=> B.ByteString
-- ^ exportDir
-> [B.ByteString]
-- ^ Tags.
-> (Raw.SessionOptions -> IO ())
-- ^ optionSetter
-> SessionAction m a
-> m a
withSessionFromSavedModel exportDir tags =
withSession_ $ \graph options status ->
Raw.loadSessionFromSavedModel options
runOptions
exportDir
tags
graph
metaGraphDef
status
where
runOptions = nullPtr
metaGraphDef = nullPtr
withSession_ :: (MonadIO m, MonadMask m)
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
-- ^ mkSession

View file

@ -200,6 +200,16 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
newSession :: Graph -> SessionOptions -> Status -> IO Session
newSession = {# call TF_NewSession as ^ #}
{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel
{ `SessionOptions'
, `BufferPtr' -- RunOptions proto.
, useAsCString* `ByteString' -- Export directory.
, withStringArrayLen* `[ByteString]'& -- Tags.
, `Graph'
, `BufferPtr' -- MetaGraphDef.
, `Status'
} -> `Session'
#}
closeSession :: Session -> Status -> IO ()
closeSession = {# call TF_CloseSession as ^ #}
@ -231,3 +241,18 @@ foreign import ccall "wrapper"
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}
-- | Use a list of ByteString as a list of CString.
withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = useAsCString x $ \c -> go xs (c:cs)
-- | Use a list of ByteString as an array of CString with its length.
withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a
withStringArrayLen xs fn =
withStringList xs $ \strings ->
withArrayLen strings $ \len ptr -> fn (ptr, fromIntegral len)

View file

@ -27,6 +27,8 @@ module TensorFlow.Session (
sessionTracer,
runSession,
runSessionWithOptions,
runSavedModel,
runSavedModelWithOptions,
MonadBuild(..),
extend,
addGraphDef,
@ -35,6 +37,7 @@ module TensorFlow.Session (
run_,
runWithFeeds_,
asyncProdNodes,
SavedModelTag(..),
) where
import Data.ProtoLens.Message(defMessage)
@ -58,6 +61,7 @@ import TensorFlow.Nodes
import TensorFlow.Output (NodeName(..), unNodeName)
import TensorFlow.Tensor
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
@ -97,6 +101,14 @@ data Options = Options
, _sessionTracer :: Tracer
}
data SavedModelTag = GPU | TPU | Serve | Train
savedModelTagValue :: SavedModelTag -> ByteString
savedModelTagValue GPU = "gpu"
savedModelTagValue TPU = "tpu"
savedModelTagValue Serve = "serve"
savedModelTagValue Train = "train"
instance Default Options where
def = Options
{ _sessionTarget = ""
@ -123,6 +135,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
runSessionWithOptions options session =
_runSessionWithOptions session options $ FFI.withSession
runSavedModel :: (MonadMask m, MonadIO m)
=> FilePath
-- ^ Export directory.
-> Set SavedModelTag
-> SessionT m a
-> m a
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
=> FilePath
-- ^ Export directory.
-> Set SavedModelTag
-> Options
-> SessionT m a
-> m a
runSavedModelWithOptions exportDir tags options session =
_runSessionWithOptions session options $
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)
_runSessionWithOptions :: (MonadMask m, MonadIO m)
=> SessionT m a
-> Options