mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 03:05:01 +01:00
Add support for loading a Session from a SavedModel
This commit is contained in:
parent
fb629d1207
commit
dad19fde31
4 changed files with 82 additions and 0 deletions
|
@ -30,6 +30,9 @@ module TensorFlow.Core
|
|||
, sessionTracer
|
||||
, runSession
|
||||
, runSessionWithOptions
|
||||
, SavedModelTag(..)
|
||||
, runSavedModel
|
||||
, runSavedModelWithOptions
|
||||
-- ** Building graphs
|
||||
, MonadBuild(..)
|
||||
-- ** Running graphs
|
||||
|
|
|
@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI
|
|||
( TensorFlowException(..)
|
||||
, Raw.Session
|
||||
, withSession
|
||||
, withSessionFromSavedModel
|
||||
, run
|
||||
|
||||
, SessionAction
|
||||
|
@ -107,6 +108,28 @@ withSession :: (MonadIO m, MonadMask m)
|
|||
-> m a
|
||||
withSession = withSession_ Raw.newSession
|
||||
|
||||
withSessionFromSavedModel :: (MonadIO m, MonadMask m)
|
||||
=> B.ByteString
|
||||
-- ^ exportDir
|
||||
-> [B.ByteString]
|
||||
-- ^ Tags.
|
||||
-> (Raw.SessionOptions -> IO ())
|
||||
-- ^ optionSetter
|
||||
-> SessionAction m a
|
||||
-> m a
|
||||
withSessionFromSavedModel exportDir tags =
|
||||
withSession_ $ \graph options status ->
|
||||
Raw.loadSessionFromSavedModel options
|
||||
runOptions
|
||||
exportDir
|
||||
tags
|
||||
graph
|
||||
metaGraphDef
|
||||
status
|
||||
where
|
||||
runOptions = nullPtr
|
||||
metaGraphDef = nullPtr
|
||||
|
||||
withSession_ :: (MonadIO m, MonadMask m)
|
||||
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
|
||||
-- ^ mkSession
|
||||
|
|
|
@ -200,6 +200,16 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
|
|||
newSession :: Graph -> SessionOptions -> Status -> IO Session
|
||||
newSession = {# call TF_NewSession as ^ #}
|
||||
|
||||
{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel
|
||||
{ `SessionOptions'
|
||||
, `BufferPtr' -- RunOptions proto.
|
||||
, useAsCString* `ByteString' -- Export directory.
|
||||
, withStringArrayLen* `[ByteString]'& -- Tags.
|
||||
, `Graph'
|
||||
, `BufferPtr' -- MetaGraphDef.
|
||||
, `Status'
|
||||
} -> `Session'
|
||||
#}
|
||||
|
||||
closeSession :: Session -> Status -> IO ()
|
||||
closeSession = {# call TF_CloseSession as ^ #}
|
||||
|
@ -231,3 +241,18 @@ foreign import ccall "wrapper"
|
|||
-- in this address space.
|
||||
getAllOpList :: IO BufferPtr
|
||||
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
||||
|
||||
-- | Use a list of ByteString as a list of CString.
|
||||
withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a
|
||||
withStringList strings fn = go strings []
|
||||
where
|
||||
go [] cs = fn (reverse cs)
|
||||
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
||||
go (x:xs) cs = useAsCString x $ \c -> go xs (c:cs)
|
||||
|
||||
|
||||
-- | Use a list of ByteString as an array of CString with its length.
|
||||
withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a
|
||||
withStringArrayLen xs fn =
|
||||
withStringList xs $ \strings ->
|
||||
withArrayLen strings $ \len ptr -> fn (ptr, fromIntegral len)
|
||||
|
|
|
@ -27,6 +27,8 @@ module TensorFlow.Session (
|
|||
sessionTracer,
|
||||
runSession,
|
||||
runSessionWithOptions,
|
||||
runSavedModel,
|
||||
runSavedModelWithOptions,
|
||||
MonadBuild(..),
|
||||
extend,
|
||||
addGraphDef,
|
||||
|
@ -35,6 +37,7 @@ module TensorFlow.Session (
|
|||
run_,
|
||||
runWithFeeds_,
|
||||
asyncProdNodes,
|
||||
SavedModelTag(..),
|
||||
) where
|
||||
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
|
@ -58,6 +61,7 @@ import TensorFlow.Nodes
|
|||
import TensorFlow.Output (NodeName(..), unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
|
||||
import qualified Data.ByteString.Char8 as C
|
||||
import qualified Data.ByteString.Builder as Builder
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
|
@ -97,6 +101,14 @@ data Options = Options
|
|||
, _sessionTracer :: Tracer
|
||||
}
|
||||
|
||||
data SavedModelTag = GPU | TPU | Serve | Train
|
||||
|
||||
savedModelTagValue :: SavedModelTag -> ByteString
|
||||
savedModelTagValue GPU = "gpu"
|
||||
savedModelTagValue TPU = "tpu"
|
||||
savedModelTagValue Serve = "serve"
|
||||
savedModelTagValue Train = "train"
|
||||
|
||||
instance Default Options where
|
||||
def = Options
|
||||
{ _sessionTarget = ""
|
||||
|
@ -123,6 +135,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
|
|||
runSessionWithOptions options session =
|
||||
_runSessionWithOptions session options $ FFI.withSession
|
||||
|
||||
runSavedModel :: (MonadMask m, MonadIO m)
|
||||
=> FilePath
|
||||
-- ^ Export directory.
|
||||
-> Set SavedModelTag
|
||||
-> SessionT m a
|
||||
-> m a
|
||||
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def
|
||||
|
||||
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
|
||||
=> FilePath
|
||||
-- ^ Export directory.
|
||||
-> Set SavedModelTag
|
||||
-> Options
|
||||
-> SessionT m a
|
||||
-> m a
|
||||
runSavedModelWithOptions exportDir tags options session =
|
||||
_runSessionWithOptions session options $
|
||||
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)
|
||||
|
||||
_runSessionWithOptions :: (MonadMask m, MonadIO m)
|
||||
=> SessionT m a
|
||||
-> Options
|
||||
|
|
Loading…
Add table
Reference in a new issue