mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add support for loading a Session from a SavedModel
This commit is contained in:
parent
fb629d1207
commit
dad19fde31
4 changed files with 82 additions and 0 deletions
|
@ -30,6 +30,9 @@ module TensorFlow.Core
|
||||||
, sessionTracer
|
, sessionTracer
|
||||||
, runSession
|
, runSession
|
||||||
, runSessionWithOptions
|
, runSessionWithOptions
|
||||||
|
, SavedModelTag(..)
|
||||||
|
, runSavedModel
|
||||||
|
, runSavedModelWithOptions
|
||||||
-- ** Building graphs
|
-- ** Building graphs
|
||||||
, MonadBuild(..)
|
, MonadBuild(..)
|
||||||
-- ** Running graphs
|
-- ** Running graphs
|
||||||
|
|
|
@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI
|
||||||
( TensorFlowException(..)
|
( TensorFlowException(..)
|
||||||
, Raw.Session
|
, Raw.Session
|
||||||
, withSession
|
, withSession
|
||||||
|
, withSessionFromSavedModel
|
||||||
, run
|
, run
|
||||||
|
|
||||||
, SessionAction
|
, SessionAction
|
||||||
|
@ -107,6 +108,28 @@ withSession :: (MonadIO m, MonadMask m)
|
||||||
-> m a
|
-> m a
|
||||||
withSession = withSession_ Raw.newSession
|
withSession = withSession_ Raw.newSession
|
||||||
|
|
||||||
|
withSessionFromSavedModel :: (MonadIO m, MonadMask m)
|
||||||
|
=> B.ByteString
|
||||||
|
-- ^ exportDir
|
||||||
|
-> [B.ByteString]
|
||||||
|
-- ^ Tags.
|
||||||
|
-> (Raw.SessionOptions -> IO ())
|
||||||
|
-- ^ optionSetter
|
||||||
|
-> SessionAction m a
|
||||||
|
-> m a
|
||||||
|
withSessionFromSavedModel exportDir tags =
|
||||||
|
withSession_ $ \graph options status ->
|
||||||
|
Raw.loadSessionFromSavedModel options
|
||||||
|
runOptions
|
||||||
|
exportDir
|
||||||
|
tags
|
||||||
|
graph
|
||||||
|
metaGraphDef
|
||||||
|
status
|
||||||
|
where
|
||||||
|
runOptions = nullPtr
|
||||||
|
metaGraphDef = nullPtr
|
||||||
|
|
||||||
withSession_ :: (MonadIO m, MonadMask m)
|
withSession_ :: (MonadIO m, MonadMask m)
|
||||||
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
|
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
|
||||||
-- ^ mkSession
|
-- ^ mkSession
|
||||||
|
|
|
@ -200,6 +200,16 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
|
||||||
newSession :: Graph -> SessionOptions -> Status -> IO Session
|
newSession :: Graph -> SessionOptions -> Status -> IO Session
|
||||||
newSession = {# call TF_NewSession as ^ #}
|
newSession = {# call TF_NewSession as ^ #}
|
||||||
|
|
||||||
|
{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel
|
||||||
|
{ `SessionOptions'
|
||||||
|
, `BufferPtr' -- RunOptions proto.
|
||||||
|
, useAsCString* `ByteString' -- Export directory.
|
||||||
|
, withStringArrayLen* `[ByteString]'& -- Tags.
|
||||||
|
, `Graph'
|
||||||
|
, `BufferPtr' -- MetaGraphDef.
|
||||||
|
, `Status'
|
||||||
|
} -> `Session'
|
||||||
|
#}
|
||||||
|
|
||||||
closeSession :: Session -> Status -> IO ()
|
closeSession :: Session -> Status -> IO ()
|
||||||
closeSession = {# call TF_CloseSession as ^ #}
|
closeSession = {# call TF_CloseSession as ^ #}
|
||||||
|
@ -231,3 +241,18 @@ foreign import ccall "wrapper"
|
||||||
-- in this address space.
|
-- in this address space.
|
||||||
getAllOpList :: IO BufferPtr
|
getAllOpList :: IO BufferPtr
|
||||||
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
getAllOpList = {# call TF_GetAllOpList as ^ #}
|
||||||
|
|
||||||
|
-- | Use a list of ByteString as a list of CString.
|
||||||
|
withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a
|
||||||
|
withStringList strings fn = go strings []
|
||||||
|
where
|
||||||
|
go [] cs = fn (reverse cs)
|
||||||
|
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
||||||
|
go (x:xs) cs = useAsCString x $ \c -> go xs (c:cs)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Use a list of ByteString as an array of CString with its length.
|
||||||
|
withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a
|
||||||
|
withStringArrayLen xs fn =
|
||||||
|
withStringList xs $ \strings ->
|
||||||
|
withArrayLen strings $ \len ptr -> fn (ptr, fromIntegral len)
|
||||||
|
|
|
@ -27,6 +27,8 @@ module TensorFlow.Session (
|
||||||
sessionTracer,
|
sessionTracer,
|
||||||
runSession,
|
runSession,
|
||||||
runSessionWithOptions,
|
runSessionWithOptions,
|
||||||
|
runSavedModel,
|
||||||
|
runSavedModelWithOptions,
|
||||||
MonadBuild(..),
|
MonadBuild(..),
|
||||||
extend,
|
extend,
|
||||||
addGraphDef,
|
addGraphDef,
|
||||||
|
@ -35,6 +37,7 @@ module TensorFlow.Session (
|
||||||
run_,
|
run_,
|
||||||
runWithFeeds_,
|
runWithFeeds_,
|
||||||
asyncProdNodes,
|
asyncProdNodes,
|
||||||
|
SavedModelTag(..),
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.ProtoLens.Message(defMessage)
|
import Data.ProtoLens.Message(defMessage)
|
||||||
|
@ -58,6 +61,7 @@ import TensorFlow.Nodes
|
||||||
import TensorFlow.Output (NodeName(..), unNodeName)
|
import TensorFlow.Output (NodeName(..), unNodeName)
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
|
|
||||||
|
import qualified Data.ByteString.Char8 as C
|
||||||
import qualified Data.ByteString.Builder as Builder
|
import qualified Data.ByteString.Builder as Builder
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
|
@ -97,6 +101,14 @@ data Options = Options
|
||||||
, _sessionTracer :: Tracer
|
, _sessionTracer :: Tracer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data SavedModelTag = GPU | TPU | Serve | Train
|
||||||
|
|
||||||
|
savedModelTagValue :: SavedModelTag -> ByteString
|
||||||
|
savedModelTagValue GPU = "gpu"
|
||||||
|
savedModelTagValue TPU = "tpu"
|
||||||
|
savedModelTagValue Serve = "serve"
|
||||||
|
savedModelTagValue Train = "train"
|
||||||
|
|
||||||
instance Default Options where
|
instance Default Options where
|
||||||
def = Options
|
def = Options
|
||||||
{ _sessionTarget = ""
|
{ _sessionTarget = ""
|
||||||
|
@ -123,6 +135,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
|
||||||
runSessionWithOptions options session =
|
runSessionWithOptions options session =
|
||||||
_runSessionWithOptions session options $ FFI.withSession
|
_runSessionWithOptions session options $ FFI.withSession
|
||||||
|
|
||||||
|
runSavedModel :: (MonadMask m, MonadIO m)
|
||||||
|
=> FilePath
|
||||||
|
-- ^ Export directory.
|
||||||
|
-> Set SavedModelTag
|
||||||
|
-> SessionT m a
|
||||||
|
-> m a
|
||||||
|
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def
|
||||||
|
|
||||||
|
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
|
||||||
|
=> FilePath
|
||||||
|
-- ^ Export directory.
|
||||||
|
-> Set SavedModelTag
|
||||||
|
-> Options
|
||||||
|
-> SessionT m a
|
||||||
|
-> m a
|
||||||
|
runSavedModelWithOptions exportDir tags options session =
|
||||||
|
_runSessionWithOptions session options $
|
||||||
|
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)
|
||||||
|
|
||||||
_runSessionWithOptions :: (MonadMask m, MonadIO m)
|
_runSessionWithOptions :: (MonadMask m, MonadIO m)
|
||||||
=> SessionT m a
|
=> SessionT m a
|
||||||
-> Options
|
-> Options
|
||||||
|
|
Loading…
Reference in a new issue