{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.Session (
Session,
SessionT,
Options,
sessionConfig,
sessionTarget,
sessionTracer,
runSession,
runSessionWithOptions,
MonadBuild(..),
extend,
addGraphDef,
run,
runWithFeeds,
run_,
runWithFeeds_,
asyncProdNodes,
) where
import Control.Monad (forever, unless, void)
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Monoid ((<>))
import Data.ProtoLens (showMessage)
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 (Lens', (^.), (&), (.~))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef, node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor
import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified TensorFlow.Internal.FFI as FFI
-- | A sink for trace output: it receives a 'Builder.Builder' and performs an
-- 'IO' action with it. In this module, 'extend' calls it with a textual dump
-- of each graph definition it sends to the session.
type Tracer = Builder.Builder -> IO ()
-- | The read-only environment carried by every 'SessionT' computation.
data SessionState
    = SessionState {
          rawSession :: FFI.Session
          -- ^ Session handle passed to @FFI.run@ and @FFI.extendGraph@.
        , asyncCollector :: IO () -> IO ()
          -- ^ Receives an 'IO' action to be run on the session's behalf;
          -- 'asyncProdNodes' hands it a never-ending loop. The function is
          -- supplied by @FFI.withSession@.
          -- NOTE(review): presumably runs the action in the background --
          -- confirm against the FFI module.
        , tracer :: Tracer
          -- ^ Trace sink, taken from the 'sessionTracer' option.
        }
-- | Monad transformer for session computations: a 'ReaderT' carrying the
-- 'SessionState' layered on top of 'BuildT', so graph-building actions and
-- session runs can be interleaved. Note that the data constructor is named
-- @Session@, while the type synonym @Session@ is the 'IO' specialisation.
newtype SessionT m a
    = Session (ReaderT SessionState (BuildT m) a)
    deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch,
              MonadMask)
-- | Lifts an action of the base monad through both wrapped layers
-- ('BuildT', then 'ReaderT') into 'SessionT'.
instance MonadTrans SessionT where
    lift action = Session (lift (lift action))
-- | 'SessionT' specialised to 'IO' as the base monad.
type Session = SessionT IO
-- | Runs a session computation with the default 'Options' ('def').
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
runSession action = runSessionWithOptions def action
-- | Settings consumed by 'runSessionWithOptions'. The fields are read and
-- updated through the lenses 'sessionTarget', 'sessionConfig' and
-- 'sessionTracer'.
data Options = Options
    { _sessionTarget :: ByteString
      -- ^ Handed to @FFI.setSessionTarget@ when the session is created.
    , _sessionConfig :: ConfigProto
      -- ^ Handed to @FFI.setSessionConfig@ when the session is created.
    , _sessionTracer :: Tracer
      -- ^ Stored in the session state and used by 'extend' for tracing.
    }
-- | Default options: an empty target, the default 'ConfigProto', and a
-- tracer that discards everything it is given.
instance Default Options where
    def = Options
        { _sessionTracer = \_ -> return ()
        , _sessionConfig = def
        , _sessionTarget = ""
        }
-- | Lens onto the target string that 'runSessionWithOptions' passes to
-- @FFI.setSessionTarget@.
sessionTarget :: Lens' Options ByteString
sessionTarget = lens _sessionTarget setTarget
  where
    setTarget opts newTarget = opts { _sessionTarget = newTarget }
-- | Lens onto the 'ConfigProto' that 'runSessionWithOptions' passes to
-- @FFI.setSessionConfig@.
sessionConfig :: Lens' Options ConfigProto
sessionConfig = lens _sessionConfig setConfig
  where
    setConfig opts newConfig = opts { _sessionConfig = newConfig }
-- | Lens onto the trace sink that is installed in the session state and
-- invoked by 'extend'.
sessionTracer :: Lens' Options Tracer
sessionTracer = lens _sessionTracer setTracer
  where
    setTracer opts newTracer = opts { _sessionTracer = newTracer }
-- | Runs a session computation under the given 'Options'.
--
-- @FFI.withSession@ is given a configuration action, which applies the
-- target and config from the options, and a body. The body receives the
-- collector function and the raw session handle, assembles the
-- 'SessionState' from them, and evaluates the wrapped reader/build stack.
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
runSessionWithOptions options (Session m) =
    FFI.withSession configure body
  where
    -- Applies the user-supplied target and config to the raw options value.
    configure rawOptions = do
        FFI.setSessionTarget (options ^. sessionTarget) rawOptions
        FFI.setSessionConfig (options ^. sessionConfig) rawOptions
    -- Builds the reader environment and runs the computation in it.
    body collector handle =
        evalBuildT $ runReaderT m SessionState
            { rawSession = handle
            , asyncCollector = collector
            , tracer = options ^. sessionTracer
            }
-- | Graph-building actions run in the 'BuildT' layer underneath the reader.
instance Monad m => MonadBuild (SessionT m) where
    build action = Session (lift (build action))
-- | Brings the session up to date with the graph built so far.
--
-- First flushes the buffer of pending nodes and, if there are any, traces
-- them and sends them to the session as a 'GraphDef'. Then flushes the
-- pending initializers and, if there are any, runs them as targets with no
-- feeds and no fetches.
extend :: MonadIO m => SessionT m ()
extend = do
    env <- Session ask
    let session = rawSession env
    pendingNodes <- build flushNodeBuffer
    unless (null pendingNodes) $
        liftIO (pushNodes (tracer env) session pendingNodes)
    -- Initializers are flushed only after the node buffer has been pushed.
    pendingInits <- build flushInitializers
    unless (null pendingInits) $
        void (liftIO (FFI.run session [] [] (toNodeNames pendingInits)))
  where
    -- Wraps the nodes in a fresh GraphDef, traces it, and extends the graph.
    pushNodes emit handle nodes = do
        let graphDef = (def :: GraphDef) & node .~ nodes
        emit ("Session.extend " <> Builder.string8 (showMessage graphDef))
        FFI.extendGraph handle graphDef
-- | Runs the given fetchable value with no feeds; see 'runWithFeeds'.
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
run fetchable = runWithFeeds [] fetchable
-- | Runs the given fetchable value with the supplied feeds.
--
-- Collects the target nodes and the fetch description of @t@ in the build
-- layer, then hands both to 'runFetchWithFeeds'.
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
runWithFeeds feeds t =
    build (getNodes t) >>= \targets ->
    build (getFetch t) >>= \fetcher ->
    runFetchWithFeeds feeds targets fetcher
-- | Extends the session, performs one @FFI.run@, and decodes the result.
--
-- The set of names in the 'Fetch' is turned into a list, which is used both
-- as the fetch argument and as the keys that are zipped with the returned
-- tensors. The resulting map is passed to the restore function of the
-- 'Fetch' to produce the final value.
runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT m a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    extend
    session <- Session (asks rawSession)
    outputs <- liftIO (FFI.run session feedPairs fetchNames targetNames)
    return (restore (Map.fromList (zip fetchKeys outputs)))
  where
    -- Reused for the run arguments and for keying the result map, so the
    -- returned tensors are paired with the names in the same order.
    fetchKeys = Set.toList fetch
    feedPairs = fixFeeds feeds
    fetchNames = map encodeUtf8 fetchKeys
    targetNames = toNodeNames (Set.toList target)
-- | Unwraps each 'NodeName' and encodes it as UTF-8 bytes.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames names = [encodeUtf8 (unNodeName name) | name <- names]
-- | Runs the given nodes with no feeds and discards any result; see
-- 'runWithFeeds_'.
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
run_ nodes = runWithFeeds_ [] nodes
-- | Runs the nodes of @t@ as targets with the supplied feeds, fetching
-- nothing: the fetch passed on is the trivial @pure ()@.
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
runWithFeeds_ feeds t =
    build (getNodes t) >>= \targets ->
    runFetchWithFeeds feeds targets (pure ())
-- | Converts each 'Feed' into the pair expected by @FFI.run@: the UTF-8
-- bytes of the encoded output name, together with the unchanged tensor data.
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = map toPair
  where
    toPair (Feed output tensorData) =
        (encodeUtf8 (encodeOutput output), tensorData)
-- | Repeatedly runs the given nodes as targets, with no feeds or fetches.
--
-- The nodes are collected and the session is extended in the calling
-- computation. An endless loop of @FFI.run@ calls is then handed to the
-- session's 'asyncCollector', which decides how that loop is executed.
-- NOTE(review): presumably the collector runs the loop in the background
-- until the session ends -- confirm against @FFI.withSession@.
asyncProdNodes :: (MonadIO m, Nodes t) => t -> SessionT m ()
asyncProdNodes nodes = do
    targetNames <- toNodeNames . Set.toList <$> build (getNodes nodes)
    extend
    env <- Session ask
    liftIO $ asyncCollector env $ forever $ void $
        FFI.run (rawSession env) [] [] targetNames