mirror of
https://github.com/tensorflow/haskell.git
synced 2025-03-27 16:15:13 +01:00
260 lines
9.6 KiB
Haskell
260 lines
9.6 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE Rank2Types #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TupleSections #-}
|
|
|
|
module TensorFlow.Session (
|
|
Session,
|
|
SessionT,
|
|
Options,
|
|
sessionConfig,
|
|
sessionTarget,
|
|
sessionTracer,
|
|
runSession,
|
|
runSessionWithOptions,
|
|
runSavedModel,
|
|
runSavedModelWithOptions,
|
|
MonadBuild(..),
|
|
extend,
|
|
addGraphDef,
|
|
run,
|
|
runWithFeeds,
|
|
run_,
|
|
runWithFeeds_,
|
|
asyncProdNodes,
|
|
SavedModelTag(..),
|
|
) where
|
|
|
|
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.ProtoLens (showMessage)
import Data.ProtoLens.Message (defMessage)
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName(..), unNodeName)
import TensorFlow.Tensor

import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Char8 as C
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified TensorFlow.Internal.FFI as FFI
|
|
|
|
-- | A logging callback: it receives a message assembled with a
-- 'Builder.Builder' and performs whatever 'IO' it likes with it.
type Tracer = Builder.Builder -> IO ()

-- | Environment shared by every action running inside a 'SessionT'.
--
-- Field order matters: the value is built positionally when a session
-- starts.
data SessionState = SessionState
    { rawSession :: FFI.Session
      -- ^ Handle to the underlying FFI session used for every run.
    , rawGraph :: FFI.Graph
      -- ^ Handle to the FFI graph that newly rendered nodes are added to.
    , asyncCollector :: IO () -> IO ()
      -- ^ Starts the given action concurrently.
    , tracer :: Tracer
      -- ^ Sink for trace messages, taken from the 'Options' tracer.
    }
|
|
|
|
-- | Monad transformer for running TensorFlow computations: a reader over
-- the live 'SessionState' layered on top of graph building ('BuildT').
newtype SessionT m a = Session (ReaderT SessionState (BuildT m) a)
    deriving ( Functor
             , Applicative
             , Monad
             , MonadIO
             , MonadThrow
             , MonadCatch
             , MonadMask
             , MonadFail
             )

-- | Lifting goes through both layers: first into 'BuildT', then into the
-- 'ReaderT'.
instance MonadTrans SessionT where
    lift action = Session (lift (lift action))

-- | 'SessionT' specialised to 'IO'.
type Session = SessionT IO
|
|
|
|
-- | Run 'SessionT' actions in a fresh TensorFlow session configured with the
-- default 'Options' ('def').
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession session = runSessionWithOptions def session
|
|
|
|
-- | Settings applied when a session is created.
--
-- The constructor is not exported; start from 'def' and adjust the fields
-- with the lenses 'sessionTarget', 'sessionConfig' and 'sessionTracer'.
data Options = Options
    { _sessionTarget :: ByteString
      -- ^ Session target string handed to the FFI layer.
    , _sessionConfig :: ConfigProto
      -- ^ Configuration proto handed to the FFI layer.
    , _sessionTracer :: Tracer
      -- ^ Logger installed as the session's 'Tracer'.
    }
|
|
|
|
-- | Tags passed to the SavedModel loader by 'runSavedModel' /
-- 'runSavedModelWithOptions'; 'savedModelTagValue' gives the string each
-- tag is sent as.
--
-- 'Eq' and 'Ord' are derived so callers can actually build the
-- @Set SavedModelTag@ those functions expect (e.g. with 'Set.fromList').
-- Without 'Ord', only the empty set and singleton sets could be
-- constructed.
data SavedModelTag = GPU | TPU | Serve | Train
    deriving (Eq, Ord, Show)
|
|
|
|
-- | The byte string sent to the SavedModel loader for each 'SavedModelTag'.
savedModelTagValue :: SavedModelTag -> ByteString
savedModelTagValue tag =
    case tag of
        GPU   -> "gpu"
        TPU   -> "tpu"
        Serve -> "serve"
        Train -> "train"
|
|
|
|
-- | Default session settings:
--
-- * an empty target string,
-- * a default 'ConfigProto' ('defMessage'),
-- * a tracer that discards every message.
instance Default Options where
    def = Options
        { _sessionTracer = \_ -> pure ()
        , _sessionConfig = defMessage
        , _sessionTarget = ""
        }
|
|
|
|
-- | The session target, e.g. @\"local\"@, @ip:port@ or @host:port@.
-- Which targets are actually supported depends on the linked-in libraries.
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget setTarget
  where
    setTarget opts target = opts { _sessionTarget = target }
|
|
|
|
-- | The 'ConfigProto' the session is created with.
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig setConfig
  where
    setConfig opts cfg = opts { _sessionConfig = cfg }
|
|
|
|
-- | The logger that receives the session's trace messages.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer setTracer
  where
    setTracer opts logger = opts { _sessionTracer = logger }
|
|
|
|
-- | Run 'SessionT' actions in a fresh TensorFlow session created from the
-- given 'Options' (see 'sessionTarget', 'sessionConfig', 'sessionTracer').
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options session =
    _runSessionWithOptions session options FFI.withSession
|
|
|
|
-- | Run 'SessionT' actions in a session loaded from a SavedModel export
-- directory, using the default 'Options'.
runSavedModel :: (MonadMask m, MonadIO m)
    => FilePath
    -- ^ Export directory.
    -> Set SavedModelTag
    -- ^ Tags handed to the SavedModel loader.
    -> SessionT m a
    -> m a
runSavedModel exportDir tags session =
    runSavedModelWithOptions exportDir tags def session
|
|
|
|
-- | Like 'runSavedModel', but the session is created with the given
-- 'Options' (target, config proto and tracer).
--
-- The export directory is encoded as UTF-8 before it is handed to the FFI
-- loader. Non-ASCII paths therefore arrive intact instead of having each
-- character truncated to a single byte.
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
    => FilePath
    -- ^ Export directory.
    -> Set SavedModelTag
    -- ^ Tags handed to the SavedModel loader.
    -> Options
    -> SessionT m a
    -> m a
runSavedModelWithOptions exportDir tags options session =
    _runSessionWithOptions session options $
        FFI.withSessionFromSavedModel exportDirBytes tagValues
  where
    -- 'Data.ByteString.Char8.pack' keeps only the low 8 bits of every Char,
    -- which silently corrupts non-ASCII paths; encode as UTF-8 instead
    -- (the same encoding this module uses for node names).
    exportDirBytes = encodeUtf8 (Text.pack exportDir)
    tagValues = map savedModelTagValue (Set.toList tags)
|
|
|
|
-- | Shared driver for 'runSessionWithOptions' and 'runSavedModelWithOptions'.
--
-- The supplied @withSession@ creates the FFI session. It is handed two
-- things: an action applying the target and config from 'Options' to the
-- raw session options, and a continuation. The continuation receives the
-- async collector, raw session and raw graph, and runs the 'SessionT'
-- computation with a fresh build state.
_runSessionWithOptions :: (MonadMask m, MonadIO m)
    => SessionT m a
    -> Options
    -> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a)
    -> m a
_runSessionWithOptions (Session action) options withSession =
    withSession configure startSession
  where
    configure opt = do
        FFI.setSessionTarget (options ^. sessionTarget) opt
        FFI.setSessionConfig (options ^. sessionConfig) opt
    startSession collector sess graph =
        let env = SessionState sess graph collector (options ^. sessionTracer)
        in evalBuildT (runReaderT action env)
|
|
|
|
-- | Graph-building actions run in the 'BuildT' layer underneath the
-- session reader.
instance Monad m => MonadBuild (SessionT m) where
    build action = Session (lift (build action))
|
|
|
|
-- | Push every pending rendered node into the TensorFlow graph, then run any
-- pending initializers.
--
-- Both steps are skipped when there is nothing pending. The graph update is
-- reported to the session 'Tracer' before it is applied. 'run',
-- 'runWithFeeds' and friends (and 'asyncProdNodes') call this implicitly.
extend :: MonadIO m => SessionT m ()
extend = do
    (sess, graph, logMsg) <-
        Session (asks (\st -> (rawSession st, rawGraph st, tracer st)))
    pendingNodes <- build flushNodeBuffer
    unless (null pendingNodes) $ liftIO $ do
        let graphDef = (defMessage :: GraphDef) & node .~ pendingNodes
        logMsg ("Session.extend " <> Builder.string8 (showMessage graphDef))
        FFI.extendGraph graph graphDef
    -- Initializers may depend on the nodes just added, so flush them only
    -- after the graph has been extended.
    pendingInits <- build flushInitializers
    unless (null pendingInits) $
        void $ liftIO $ FFI.run sess graph [] [] (toNodeNames pendingInits)
|
|
|
|
-- | Run the subgraph @t@ with no feeds and fetch its values as @a@.
--
-- Any dependent nodes that have not been rendered yet are rendered first.
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run fetchable = runWithFeeds [] fetchable
|
|
|
|
-- | Run the subgraph @t@ with the given input 'Feed's and fetch its values
-- as @a@.
--
-- Any dependent nodes that have not been rendered yet are rendered first.
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do
    targets <- build (getNodes t)
    -- getNodes has already rendered every node in t and its inputs/deps, so
    -- building the fetch here does not change what the subsequent 'extend'
    -- (inside runFetchWithFeeds) has to add.
    fetcher <- build (getFetch t)
    runFetchWithFeeds feeds targets fetcher
|
|
|
|
-- | Shared core of the run functions.
--
-- It works in three steps:
--
-- 1. Extend the graph with pending nodes.
-- 2. Execute one FFI run with the given feeds, the fetch's output names and
--    the target nodes.
-- 3. Key the returned tensors by fetch name and hand them to the fetch's
--    restore function.
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    extend
    st <- Session ask
    let fetchKeys = Set.toList fetch
    results <- liftIO $
        FFI.run (rawSession st)
                (rawGraph st)
                (fixFeeds feeds)
                (map encodeUtf8 fetchKeys)
                (toNodeNames (Set.toList target))
    -- Results are paired positionally with the fetch names passed to FFI.run.
    pure (restore (Map.fromList (zip fetchKeys results)))
|
|
|
|
-- | UTF-8 encode node names for the FFI layer, preserving order.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames names = [encodeUtf8 (unNodeName n) | n <- names]
|
|
|
|
-- | Run the subgraph @t@ with no feeds, discarding its outputs.
--
-- This behaves like 'run' but fetches nothing. Dependent nodes are rendered
-- and added to the graph as needed.
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ t = runWithFeeds_ [] t
|
|
|
|
-- | Run the subgraph @t@ with the given input 'Feed's, discarding its
-- outputs.
--
-- This behaves like 'runWithFeeds' but fetches nothing: only the nodes of
-- @t@ are executed as targets. Dependent nodes are rendered first if
-- needed.
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t =
    build (getNodes t) >>= \targets ->
        runFetchWithFeeds feeds targets (pure ())
|
|
|
|
-- | Turn each 'Feed' into the (UTF-8 encoded output name, tensor data) pair
-- the FFI run expects, preserving order.
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = map toPair
  where
    toPair (Feed out tensor) = (encodeUtf8 (encodeOutput out), tensor)
|
|
|
|
-- | Repeatedly evaluate the given nodes in the background.
--
-- The nodes are rendered and the graph is extended synchronously. A loop
-- that runs the nodes over and over is then handed to the session's
-- 'asyncCollector', which starts it concurrently. Outputs of each run are
-- discarded.
--
-- NOTE(review): when the loop stops (the original docs say "until
-- runSession exits or an exception occurs") is decided by the FFI layer's
-- collector; confirm in TensorFlow.Internal.FFI.
asyncProdNodes :: (MonadIO m, Nodes t)
    => t  -- ^ Node to evaluate concurrently.
    -> SessionT m ()
asyncProdNodes nodes = do
    target <- build (getNodes nodes)
    extend
    st <- Session ask
    let targetNames = toNodeNames (Set.toList target)
        prodOnce = void (FFI.run (rawSession st) (rawGraph st) [] [] targetNames)
    liftIO (asyncCollector st (forever prodOnce))
|