module TensorFlow.Session (
Session,
SessionOption,
sessionConfig,
sessionTarget,
runSession,
runSessionWithOptions,
build,
buildAnd,
buildWithSummary,
extend,
addGraphDef,
run,
runWithFeeds,
run_,
runWithFeeds_,
asyncProdNodes,
) where
import Control.Monad (forever, unless, void)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Functor.Identity (runIdentity)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Text.Encoding (encodeUtf8)
import Data.ProtoLens (def)
import Lens.Family2 ((&), (.~))
import Proto.Tensorflow.Core.Framework.Graph (node)
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import TensorFlow.Build
import qualified TensorFlow.Internal.FFI as FFI
import qualified TensorFlow.Internal.Raw as Raw
import TensorFlow.Nodes
import TensorFlow.Output (NodeName, unNodeName)
import TensorFlow.Tensor
-- | Environment shared by every action in a 'Session'; it is read through
-- the 'ReaderT' layer of 'Session'.
data SessionState = SessionState
    { rawSession :: FFI.Session
      -- ^ The FFI session handle that graphs are extended into and run on.
    , asyncCollector :: IO () -> IO ()
      -- ^ Callback supplied by 'FFI.withSession'. It receives 'IO' actions
      -- (see 'asyncProdNodes'). NOTE(review): how it schedules them is
      -- defined in the FFI layer, not here.
    }
-- | The session monad: a reader over 'SessionState' stacked on 'BuildT'
-- 'IO'. This lets graph construction and graph execution be interleaved.
newtype Session a = Session (ReaderT SessionState (BuildT IO) a)
    deriving (Functor, Applicative, Monad, MonadIO)
-- | Run a 'Session' to completion without any extra 'SessionOption's.
runSession :: Session a -> IO a
runSession session = runSessionWithOptions [] session
-- | A customization applied to the raw session options that
-- 'runSessionWithOptions' passes to 'FFI.withSession'. Construct one with
-- 'sessionTarget' or 'sessionConfig'.
newtype SessionOption = SessionOption
    { unSesssionOption :: Raw.SessionOptions -> IO ()
      -- ^ Action that mutates the raw options. (The field name's spelling
      -- is relied on elsewhere in this module.)
    }
-- | A 'SessionOption' that sets the session target to the given bytes
-- via 'FFI.setSessionTarget'.
sessionTarget :: ByteString -> SessionOption
sessionTarget target = SessionOption (FFI.setSessionTarget target)
-- | A 'SessionOption' that installs the given 'ConfigProto' via
-- 'FFI.setSessionConfig'.
sessionConfig :: ConfigProto -> SessionOption
sessionConfig config = SessionOption (FFI.setSessionConfig config)
-- | Run a 'Session' inside 'FFI.withSession'.
--
-- Each 'SessionOption' is applied, in list order, to the raw options
-- object given to 'FFI.withSession'. The session body then runs with a
-- fresh build state (via 'evalBuildT').
runSessionWithOptions :: [SessionOption] -> Session a -> IO a
runSessionWithOptions options (Session action) =
    FFI.withSession configure body
  where
    configure rawOpts = mapM_ (\opt -> unSesssionOption opt rawOpts) options
    body collector sess =
        evalBuildT (runReaderT action (SessionState sess collector))
-- | Embed a pure 'Build' computation into 'Session'. It runs against the
-- session's 'BuildT' 'IO' state by hoisting 'Identity' into 'IO'.
build :: Build a -> Session a
build b = Session (lift (hoistBuildT (return . runIdentity) b))
-- | Like 'build', but also returns the result of 'collectAllSummaries',
-- which is evaluated after @b@.
buildWithSummary :: forall a . Build a -> Session (a, [SummaryTensor])
buildWithSummary b = Session (lift paired)
  where
    paired :: BuildT IO (a, [SummaryTensor])
    paired = do
        result <- hoistBuildT (return . runIdentity) b
        summaries <- collectAllSummaries
        return (result, summaries)
-- | Push pending graph changes to the session.
--
-- First, any nodes left in the build buffer are sent with
-- 'FFI.extendGraph'. Then any pending initializers are run once with
-- 'FFI.run'. Each step talks to the FFI only when its list is non-empty.
extend :: Session ()
extend = do
    pending <- build flushNodeBuffer
    whenNonEmpty pending $ \sess ->
        liftIO (FFI.extendGraph sess (def & node .~ pending))
    inits <- build flushInitializers
    whenNonEmpty inits $ \sess ->
        void (liftIO (FFI.run sess [] [] (toNodeNames inits)))
  where
    -- Fetch the raw session and run the action only when there is work.
    whenNonEmpty xs act =
        unless (null xs) (Session (asks rawSession) >>= act)
-- | Run the 'Build' action @m@, then pass its result to the session
-- action @f@.
buildAnd :: (a -> Session b) -> Build a -> Session b
buildAnd f m = f =<< build m
-- | Evaluate a fetchable value in the session with no feeds.
run :: Fetchable t a => t -> Session a
run t = runWithFeeds [] t
-- | Evaluate a fetchable value in the session, substituting the given
-- feeds. The target node set and the fetch are built (in that order)
-- before running.
runWithFeeds :: Fetchable t a => [Feed] -> t -> Session a
runWithFeeds feeds t = do
    (targets, fetcher) <- build ((,) <$> getNodes t <*> getFetch t)
    runFetchWithFeeds feeds targets fetcher
-- | Run the session graph once.
--
-- Feeds are substituted in, the @target@ nodes are run, and the tensors
-- named by the 'Fetch' are returned. Those tensors are decoded with the
-- fetch's restore function.
--
-- Feeds are rendered /before/ 'extend'. 'fixFeeds' runs in 'Build' (it
-- calls 'renderOutput'), so it may add nodes to the build state. Flushing
-- afterwards guarantees that every node named in the feed list is already
-- in the session graph when 'FFI.run' refers to it. In the usual case the
-- fed outputs were rendered earlier, and this ordering changes nothing.
-- NOTE(review): confirm in TensorFlow.Build whether 'renderOutput' can
-- add an unrendered op.
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
    feeds' <- build $ fixFeeds feeds
    extend
    let fetchList = Set.toList fetch
        fetchNames = encodeUtf8 <$> fetchList
        targetNames = toNodeNames $ Set.toList target
    session <- Session (asks rawSession)
    runResult <- liftIO $ FFI.run session feeds' fetchNames targetNames
    -- Results are paired positionally with the fetch names.
    -- NOTE(review): assumes FFI.run returns one tensor per fetch name, in
    -- request order.
    let resultTensorsMap = Map.fromList $ zip fetchList runResult
    return $ restore resultTensorsMap
-- | Convert node names to the UTF-8 'ByteString' form the FFI expects.
toNodeNames :: [NodeName] -> [ByteString]
toNodeNames names = [encodeUtf8 (unNodeName n) | n <- names]
-- | Run the given nodes with no feeds, discarding any output.
run_ :: Nodes t => t -> Session ()
run_ t = runWithFeeds_ [] t
-- | Run the given nodes with the supplied feeds, fetching nothing.
runWithFeeds_ :: Nodes t => [Feed] -> t -> Session ()
runWithFeeds_ feeds t =
    build (getNodes t) >>= \targets ->
        runFetchWithFeeds feeds targets (pure ())
-- | Render each feed's output to its UTF-8 name and pair it with the
-- feed's tensor data, in list order.
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
fixFeeds = traverse renderFeed
  where
    renderFeed (Feed o d) = do
        name <- renderOutput o
        return (encodeUtf8 name, d)
-- | Pass the session's 'asyncCollector' an action that runs the given
-- nodes repeatedly, forever, discarding results.
--
-- The graph is extended first, so the nodes exist in the session.
-- NOTE(review): whether the loop actually runs in the background depends
-- on the collector supplied by 'FFI.withSession'.
asyncProdNodes :: Nodes t
               => t
               -> Session ()
asyncProdNodes nodes = do
    target <- build (getNodes nodes)
    extend
    sess <- Session (asks rawSession)
    collector <- Session (asks asyncCollector)
    let targetNames = toNodeNames (Set.toList target)
        prodOnce = void (FFI.run sess [] [] targetNames)
    liftIO (collector (forever prodOnce))