tensorflow-haskell/tensorflow/src/TensorFlow/Session.hs

261 lines
9.6 KiB
Haskell
Raw Normal View History

2016-10-24 21:26:42 +02:00
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
2016-10-24 21:26:42 +02:00
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.Session (
Session,
SessionT,
Options,
2016-10-24 21:26:42 +02:00
sessionConfig,
sessionTarget,
sessionTracer,
2016-10-24 21:26:42 +02:00
runSession,
runSessionWithOptions,
runSavedModel,
runSavedModelWithOptions,
MonadBuild(..),
2016-10-24 21:26:42 +02:00
extend,
addGraphDef,
run,
runWithFeeds,
run_,
runWithFeeds_,
asyncProdNodes,
SavedModelTag(..),
2016-10-24 21:26:42 +02:00
) where
2019-04-12 04:27:15 +02:00
import Data.ProtoLens.Message(defMessage)
2016-10-24 21:26:42 +02:00
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
2016-10-24 21:26:42 +02:00
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
2016-10-24 21:26:42 +02:00
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.ProtoLens (showMessage)
2016-10-24 21:26:42 +02:00
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
2016-10-24 21:26:42 +02:00
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName(..), unNodeName)
2016-10-24 21:26:42 +02:00
import TensorFlow.Tensor
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as T
import qualified TensorFlow.Internal.FFI as FFI
-- | A logging sink: it receives a rendered 'Builder.Builder' message and
-- performs some 'IO' with it.
type Tracer = Builder.Builder -> IO ()

-- | Per-session state shared, through the reader layer of 'SessionT', by every
-- action running inside a session.
data SessionState = SessionState
    { rawSession     :: FFI.Session
      -- ^ Raw FFI session handle that graphs are run against.
    , rawGraph       :: FFI.Graph
      -- ^ Raw FFI graph handle that new nodes are added to.
    , asyncCollector :: IO () -> IO ()
      -- ^ Starts the given action concurrently.
    , tracer         :: Tracer
      -- ^ Sink for session progress messages.
    }
2016-10-24 21:26:42 +02:00
-- | The session monad transformer: a reader over the per-session
-- 'SessionState', stacked on 'BuildT' so that graph construction and graph
-- execution can be interleaved.
newtype SessionT m a = Session (ReaderT SessionState (BuildT m) a)
    deriving ( Functor
             , Applicative
             , Monad
             , MonadIO
             , MonadThrow
             , MonadCatch
             , MonadMask
             , MonadFail
             )

instance MonadTrans SessionT where
    -- Lift through the 'BuildT' layer first, then through the 'ReaderT' layer.
    lift action = Session (lift (lift action))

-- | 'SessionT' specialised to 'IO'.
type Session = SessionT IO
2016-10-24 21:26:42 +02:00
-- | Run 'Session' actions in a new TensorFlow session that uses the default
-- 'Options'; see 'runSessionWithOptions' for the configurable variant.
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession action = runSessionWithOptions def action
2016-10-24 21:26:42 +02:00
-- | Customization for a session. Start from 'def' and adjust the fields with
-- the lenses 'sessionTarget', 'sessionTracer' and 'sessionConfig'.
data Options = Options
    { _sessionTarget :: ByteString
      -- ^ Session target; see 'sessionTarget'.
    , _sessionConfig :: ConfigProto
      -- ^ Config proto for the created session; see 'sessionConfig'.
    , _sessionTracer :: Tracer
      -- ^ Logger used to monitor progress; see 'sessionTracer'.
    }
-- | Tags selecting which MetaGraph of a SavedModel export to load.
--
-- The 'Ord' instance is required: 'runSavedModel' and
-- 'runSavedModelWithOptions' take a @Set SavedModelTag@, and callers cannot
-- build one with 'Data.Set.fromList' or 'Data.Set.insert' unless the element
-- type is ordered.
data SavedModelTag = GPU | TPU | Serve | Train
    deriving (Eq, Ord, Show, Enum, Bounded)

-- | The string form of a tag that is passed to the SavedModel loader.
savedModelTagValue :: SavedModelTag -> ByteString
savedModelTagValue GPU = "gpu"
savedModelTagValue TPU = "tpu"
savedModelTagValue Serve = "serve"
savedModelTagValue Train = "train"
-- | Default options: an empty target string, the default 'ConfigProto'
-- message, and a tracer that ignores every message.
instance Default Options where
    def = Options { _sessionTarget = ""
                  , _sessionConfig = defMessage
                  , _sessionTracer = \_ -> pure ()
                  }
2016-10-24 21:26:42 +02:00
-- | Target of the session: "local", ip:port or host:port. Which targets are
-- actually supported depends on the libraries linked in.
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget setTarget
  where
    setTarget opts newTarget = opts { _sessionTarget = newTarget }

-- | Config proto used when creating the session.
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig setConfig
  where
    setConfig opts newConfig = opts { _sessionConfig = newConfig }

-- | Logger that receives session progress messages.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer setTracer
  where
    setTracer opts newTracer = opts { _sessionTracer = newTracer }
2016-10-24 21:26:42 +02:00
-- | Run 'Session' actions in a new TensorFlow session configured from the
-- given 'Options' (target, config and tracer).
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options session =
    _runSessionWithOptions session options FFI.withSession
-- | Run 'Session' actions in a session loaded from a SavedModel export
-- directory, using the default 'Options'. This is 'runSavedModelWithOptions'
-- with 'def' as the options.
runSavedModel :: (MonadMask m, MonadIO m)
              => FilePath
              -- ^ Export directory.
              -> Set SavedModelTag
              -- ^ Tags passed to the loader (as their 'savedModelTagValue').
              -> SessionT m a
              -> m a
runSavedModel exportDir tags session =
    runSavedModelWithOptions exportDir tags def session
-- | Run 'Session' actions in a session loaded from a SavedModel export
-- directory, configured from the given 'Options'.
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
                         => FilePath
                         -- ^ Export directory.
                         -> Set SavedModelTag
                         -- ^ Tags passed to the loader (as their 'savedModelTagValue').
                         -> Options
                         -> SessionT m a
                         -> m a
runSavedModelWithOptions exportDir tags options session =
    _runSessionWithOptions session options $
        FFI.withSessionFromSavedModel exportDirBytes tagValues
  where
    -- 'Data.ByteString.Char8.pack' keeps only the low 8 bits of every Char,
    -- which corrupts any non-ASCII path. Encode the path as UTF-8 instead.
    -- NOTE(review): this assumes the filesystem expects UTF-8 paths (the usual
    -- case on POSIX); confirm for non-UTF-8 locales.
    exportDirBytes = encodeUtf8 (T.pack exportDir)
    tagValues = map savedModelTagValue (Set.toList tags)
-- | Shared driver for 'runSessionWithOptions' and 'runSavedModelWithOptions'.
--
-- The supplied @withSession@ opens the raw session. It receives a callback that
-- applies the target and config from the 'Options', and a continuation that
-- runs the 'SessionT' action through 'evalBuildT' with a 'SessionState' built
-- from the handles it is given.
_runSessionWithOptions :: (MonadMask m, MonadIO m)
                       => SessionT m a
                       -> Options
                       -> ((FFI.SessionOptions -> IO ()) -> FFI.SessionAction m a -> m a)
                       -> m a
_runSessionWithOptions (Session action) options withSession =
    withSession configure $ \collector sessionHandle graphHandle ->
        evalBuildT . runReaderT action $ SessionState
            { rawSession = sessionHandle
            , rawGraph = graphHandle
            , asyncCollector = collector
            , tracer = options ^. sessionTracer
            }
  where
    -- Target is applied before config, as the FFI options are being set up.
    configure sessionOpts = do
        FFI.setSessionTarget (options ^. sessionTarget) sessionOpts
        FFI.setSessionConfig (options ^. sessionConfig) sessionOpts
2016-10-24 21:26:42 +02:00
-- | Graph-building actions run in the 'BuildT' layer underneath the session
-- reader.
instance Monad m => MonadBuild (SessionT m) where
    build action = Session (lift (build action))
2016-10-24 21:26:42 +02:00
-- | Add all pending rendered nodes to the TensorFlow graph, then run any
-- pending initializers.
--
-- 'run', 'runWithFeeds', etc. all call this function implicitly. When nodes
-- are added, the resulting 'GraphDef' is also reported to the session tracer.
extend :: MonadIO m => SessionT m ()
extend = do
    st <- Session ask
    pending <- build flushNodeBuffer
    unless (null pending) . liftIO $ do
        let gd = (defMessage :: GraphDef) & node .~ pending
        tracer st ("Session.extend " <> Builder.string8 (showMessage gd))
        FFI.extendGraph (rawGraph st) gd
    -- Every node now exists in the graph, so the initializers can be run.
    inits <- build flushInitializers
    unless (null inits) . void . liftIO $
        FFI.run (rawSession st) (rawGraph st) [] [] (toNodeNames inits)
2016-10-24 21:26:42 +02:00
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'. This is
-- 'runWithFeeds' with no feeds.
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run t = runWithFeeds [] t
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, feeding the given input values, and fetching the corresponding
-- result values for 'a'.
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t = do
    -- Render t together with its inputs/deps first. Every node the fetch
    -- refers to has then already been rendered, so building the fetch below
    -- does not change what the later 'extend' call adds.
    targets <- build (getNodes t)
    fetch <- build (getFetch t)
    runFetchWithFeeds feeds targets fetch
-- | Extend the graph, then run it once. The given feeds are supplied, the
-- tensors named by the 'Fetch' are fetched, and the @target@ nodes are run.
-- The fetched tensors are passed to the fetch's restore function in a map
-- keyed by fetch name.
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    extend
    st <- Session ask
    -- The same list drives both the requested fetch names and the pairing of
    -- results, so the two stay in the same order.
    let fetchKeys = Set.toList fetch
    results <- liftIO $
        FFI.run (rawSession st)
                (rawGraph st)
                (fixFeeds feeds)
                (map encodeUtf8 fetchKeys)
                (toNodeNames (Set.toList target))
    pure . restore . Map.fromList $ zip fetchKeys results
-- | UTF-8 encode node names for the FFI layer.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames names = [encodeUtf8 (unNodeName n) | n <- names]
-- | Run a subgraph 't', rendering and extending any dependent nodes that
-- aren't already rendered. This is like 'run', except that nothing is
-- fetched; it is 'runWithFeeds_' with no feeds.
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ t = runWithFeeds_ [] t
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered and feeding the given input values. Nothing is fetched and the
-- result is @()@. This behaves like 'runWithFeeds' except that it does not
-- fetch any values.
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t = do
    ns <- build $ getNodes t
    -- The trivial @pure ()@ fetch: the targets are run for their effects only.
    runFetchWithFeeds feeds ns (pure ())
-- | Convert each 'Feed' into the (UTF-8 encoded output name, tensor data)
-- pair expected by the FFI layer.
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds feeds = map toPair feeds
  where
    toPair (Feed output tensorData) = (encodeUtf8 (encodeOutput output), tensorData)
2016-10-24 21:26:42 +02:00
-- | Starts a concurrent thread which evaluates the given Nodes forever,
-- until runSession exits or an exception occurs. Graph extension happens
-- synchronously in the calling action; only the repeated run is handed to
-- the session's async collector.
asyncProdNodes :: (MonadIO m, Nodes t)
               => t  -- ^ Node to evaluate concurrently.
               -> SessionT m ()
asyncProdNodes nodes = do
    targets <- build (getNodes nodes)
    extend
    st <- Session ask
    let names = toNodeNames (Set.toList targets)
        prodOnce = FFI.run (rawSession st) (rawGraph st) [] [] names
    liftIO . asyncCollector st $ forever (void prodOnce)