{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
-- | State and monad transformer used to accumulate graph nodes ('NodeDef's),
-- together with helpers for naming, scoping, device placement and control
-- dependencies of newly added ops.
module TensorFlow.Build
(
-- * Graph node types
ControlNode(..)
, Unique
-- * Ops
, explicitName
, implicitName
, opDef
, opDefWithName
, opName
, opType
, opAttr
, opInputs
, opControlInputs
-- * The Build monad
, GraphState
, renderedNodeDefs
, BuildT
, Build
, MonadBuild(..)
, addInitializer
, hoistBuildT
, evalBuildT
, runBuildT
, asGraphDef
, addGraphDef
, flushInitializers
, flushNodeBuffer
, summaries
-- * Creating and looking up ops in the graph
, getOrAddOp
, addNewOp
, encodeOutput
, lookupNode
-- * Modifying all nodes in a Build action
, withStateLens
, withDevice
, withNameScope
, withNodeDependencies
) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
import Data.Default (def)
import Data.Functor.Identity (Identity(..))
import qualified Data.Map.Strict as Map
import Data.Monoid ((<>))
import qualified Data.Set as Set
import Data.Set (Set)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', (.~), (^.), (&))
import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.Graph
( GraphDef
, node
)
import Proto.Tensorflow.Core.Framework.NodeDef
( NodeDef
, attr
, input
, device
, name
, op
)
import TensorFlow.Output
-- | Counter value used to generate names for nodes that were not given an
-- explicit name (see 'renderPendingNode').
newtype Unique = Unique Int
    deriving (Eq, Ord, Enum)
-- | A pending name that leaves the choice of node name to the build state:
-- when rendered, the name is derived from the op type and a unique counter.
implicitName :: PendingNodeName
implicitName = ImplicitName
-- | A pending name that uses the given text as the node's own name (any
-- scope prefix is still applied when the node is rendered).
explicitName :: Text -> PendingNodeName
explicitName nm = ExplicitName nm
-- | One component of a name scope.  When a node is rendered, each component
-- contributes its text followed by a @/@ to the node-name prefix.
newtype Scope = Scope { unScope :: Text }
    deriving (Eq, Ord, IsString)
-- | Shows a 'Scope' exactly as its underlying 'Text' would be shown,
-- without the constructor wrapper.
instance Show Scope where
    show (Scope txt) = show txt
-- | An 'OpDef' of the given type with an implicit name and no attributes,
-- inputs or control inputs.
opDef :: OpType -> OpDef
opDef ty = opDefWithName ImplicitName ty
-- | An 'OpDef' with the given pending name and op type.  Attributes, inputs
-- and control inputs all start out empty.
opDefWithName :: PendingNodeName -> OpType -> OpDef
opDefWithName pendingName ty = OpDef
    { _opType = ty
    , _opName = pendingName
    , _opInputs = []
    , _opControlInputs = []
    , _opAttrs = Map.empty
    }
-- | The mutable state threaded through a 'BuildT' computation.
data GraphState = GraphState
    { _renderedNodes :: !(Map.Map PendingNode NodeDef)
      -- ^ Nodes already rendered, keyed by the pending node they came from.
      -- 'getOrAddOp' consults this map to reuse an existing node.
    , _renderedNodeDefs :: !(Map.Map NodeName NodeDef)
      -- ^ The same rendered nodes, keyed by their final name
      -- (used by 'lookupNode').
    , _nodeBuffer :: [NodeDef]
      -- ^ Nodes accumulated since the buffer was last cleared by
      -- 'flushNodeBuffer'.
    , _nextUnique :: !Unique
      -- ^ Next counter value for implicitly named nodes.
    , _defaultDevice :: !(Maybe Device)
      -- ^ Device assigned to newly created nodes; 'Nothing' leaves the
      -- node's device field as the empty string.
    , _currentScope :: [Scope]
      -- ^ Active name scopes; 'withNameScope' pushes onto the front.
    , _defaultControlInputs :: !(Set NodeName)
      -- ^ Nodes added as control inputs to every newly created node.
    , _initializationNodes :: [NodeName]
      -- ^ Nodes registered through 'addInitializer' and not yet flushed by
      -- 'flushInitializers'.
    , _summaries :: [Output]
      -- ^ A list of outputs exposed through the 'summaries' lens; nothing
      -- in this module reads or writes it.
    }
-- | A node that has not yet been given its final name: the scopes active
-- when it was created, its pending name, and the 'NodeDef' built so far.
-- The derived 'Ord' instance lets it serve as a key in '_renderedNodes'.
data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef
    deriving (Eq, Ord)
-- | Extract the 'NodeDef' carried by a pending node, ignoring its scope
-- and pending name.
pendingNodeDef :: PendingNode -> NodeDef
pendingNodeDef pending = case pending of
    PendingNode _ _ nodeDef -> nodeDef
-- | The state a build starts from: nothing rendered or buffered, the unique
-- counter at zero, no default device, no scopes, no default control inputs,
-- no initializers and no summaries.
initGraphState :: GraphState
initGraphState = GraphState
    { _renderedNodes = Map.empty
    , _renderedNodeDefs = Map.empty
    , _nodeBuffer = []
    , _nextUnique = Unique 0
    , _defaultDevice = Nothing
    , _currentScope = []
    , _defaultControlInputs = Set.empty
    , _initializationNodes = []
    , _summaries = []
    }
-- | Lens onto the map from pending nodes to their rendered 'NodeDef's.
renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef)
renderedNodes = lens _renderedNodes update
  where
    update st v = st { _renderedNodes = v }
-- | Lens onto the map from final node names to their rendered 'NodeDef's.
renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef)
renderedNodeDefs = lens _renderedNodeDefs update
  where
    update st v = st { _renderedNodeDefs = v }
-- | Lens onto the buffer of accumulated 'NodeDef's.
nodeBuffer :: Lens' GraphState [NodeDef]
nodeBuffer = lens _nodeBuffer update
  where
    update st v = st { _nodeBuffer = v }
-- | Lens onto the counter used for implicitly named nodes.
nextUnique :: Lens' GraphState Unique
nextUnique = lens _nextUnique update
  where
    update st v = st { _nextUnique = v }
-- | Lens onto the device assigned to newly created nodes, if any.
defaultDevice :: Lens' GraphState (Maybe Device)
defaultDevice = lens _defaultDevice update
  where
    update st v = st { _defaultDevice = v }
-- | Lens onto the list of currently active name scopes.
currentScope :: Lens' GraphState [Scope]
currentScope = lens _currentScope update
  where
    update st v = st { _currentScope = v }
-- | Lens onto the set of nodes added as control inputs to every new node.
defaultControlInputs :: Lens' GraphState (Set NodeName)
defaultControlInputs = lens _defaultControlInputs update
  where
    update st v = st { _defaultControlInputs = v }
-- | Lens onto the list of registered initializer nodes.
initializationNodes :: Lens' GraphState [NodeName]
initializationNodes = lens _initializationNodes update
  where
    update st v = st { _initializationNodes = v }
-- | Lens onto the '_summaries' field of the graph state (a list of
-- 'Output's).
summaries :: Lens' GraphState [Output]
summaries = lens _summaries update
  where
    update st v = st { _summaries = v }
-- | A monad transformer for building graphs: a strict 'StateT' over
-- 'GraphState', with its instances derived from the underlying transformer.
newtype BuildT m a = BuildT (StateT GraphState m a)
    deriving ( Functor, Applicative, Monad, MonadIO, MonadTrans
             , MonadState GraphState, MonadThrow, MonadCatch, MonadMask
             , MonadFix
             )
-- | 'BuildT' over 'Identity': a graph-building computation with no
-- underlying effects.
type Build = BuildT Identity
-- | Change the underlying monad of a 'BuildT' by applying the given
-- transformation to the wrapped state computation (via 'mapStateT').
hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b
hoistBuildT nat (BuildT inner) = BuildT (mapStateT nat inner)
-- | Run a 'BuildT' from 'initGraphState', returning both the result and the
-- final graph state.
runBuildT :: BuildT m a -> m (a, GraphState)
runBuildT (BuildT inner) = inner `runStateT` initGraphState
-- | Run a 'BuildT' from 'initGraphState', returning only the result and
-- discarding the final graph state.
evalBuildT :: Monad m => BuildT m a -> m a
evalBuildT (BuildT inner) = inner `evalStateT` initGraphState
-- | Monads into which a pure 'Build' action can be lifted.
class Monad m => MonadBuild m where
    -- | Lift a pure 'Build' action into @m@.
    build :: Build a -> m a
-- | A pure 'Build' action is lifted into any 'BuildT' by unwrapping the
-- 'Identity' layer and re-injecting the value into the underlying monad.
instance Monad m => MonadBuild (BuildT m) where
    build pureAction = hoistBuildT (\(Identity x) -> return x) pureAction
-- | Return every 'NodeDef' currently in the node buffer and reset the
-- buffer to empty.
flushNodeBuffer :: MonadBuild m => m [NodeDef]
flushNodeBuffer = build (use nodeBuffer <* (nodeBuffer .= []))
-- | Return every initializer node registered so far and reset the list to
-- empty.  Because 'addInitializer' prepends, the most recently added node
-- comes first.
flushInitializers :: Monad m => BuildT m [NodeName]
flushInitializers = use initializationNodes <* (initializationNodes .= [])
-- | Register a control node as an initializer by prepending its name to the
-- list later returned by 'flushInitializers'.
addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode initName) =
    build (initializationNodes %= \registered -> initName : registered)
-- | Run the given 'Build' action from the initial state and produce a
-- 'GraphDef' whose @node@ field holds the resulting node buffer.  The
-- action's own result is discarded.
asGraphDef :: Build a -> GraphDef
asGraphDef b = def & node .~ bufferedNodes
  where
    (_, finalState) = runIdentity (runBuildT b)
    bufferedNodes = finalState ^. nodeBuffer
-- | Append all nodes of the given 'GraphDef' to the node buffer.
--
-- Only the buffer is updated: the nodes are not entered into
-- 'renderedNodeDefs', so 'lookupNode' will not find them.
addGraphDef :: MonadBuild m => GraphDef -> m ()
addGraphDef g = build (nodeBuffer <>= importedNodes)
  where
    importedNodes = g ^. node
-- | Return the name of the node for the given 'OpDef', creating it only if
-- an identical pending node (same scope, pending name and 'NodeDef') has
-- not already been rendered.
getOrAddOp :: OpDef -> Build NodeName
getOrAddOp o = do
    pending <- getPendingNode o
    existing <- uses renderedNodes (Map.lookup pending)
    case existing of
        Nothing -> addNewOpFromPending pending
        Just rendered -> return (NodeName (rendered ^. name))
-- | Look up the rendered 'NodeDef' for a node name.
--
-- Calls 'error' if the name is not present in 'renderedNodeDefs'.
lookupNode :: NodeName -> Build NodeDef
lookupNode n = do
    found <- uses renderedNodeDefs (Map.lookup n)
    case found of
        Nothing -> error ("lookupNode: unknown node name " ++ show n)
        Just nodeDef -> return nodeDef
-- | Always add a new node for the given 'OpDef', without first checking
-- whether an identical one has already been rendered.
addNewOp :: OpDef -> Build NodeName
addNewOp o = do
    pending <- getPendingNode o
    addNewOpFromPending pending
-- | Render a pending node under a concrete name and record it: the named
-- 'NodeDef' is pushed onto the front of the node buffer and inserted into
-- both lookup maps.  Returns the assigned name.
addNewOpFromPending :: PendingNode -> Build NodeName
addNewOpFromPending pending = do
    rendered <- renderPendingNode pending
    let namedDef = (name .~ unNodeName rendered) (pendingNodeDef pending)
    nodeBuffer %= \buffered -> namedDef : buffered
    renderedNodes %= Map.insert pending namedDef
    renderedNodeDefs %= Map.insert rendered namedDef
    return rendered
-- | Turn an 'OpDef' into a 'PendingNode' using the current build state.
--
-- The resulting 'NodeDef' carries the op type, the op's attributes, the
-- current default device (empty string when none is set), and an input
-- list made of the encoded data inputs followed by the control inputs.
-- Control inputs are the op's own, followed by the default ones from the
-- state, each prefixed with @^@.
getPendingNode :: OpDef -> Build PendingNode
getPendingNode o = do
    scope <- use currentScope
    dev <- uses defaultDevice (maybe "" deviceName)
    implicitDeps <- uses defaultControlInputs Set.toList
    let dataInputs = encodeOutput <$> o ^. opInputs
        controlDeps =
            asControlInput <$> (o ^. opControlInputs ++ implicitDeps)
        nodeDef = def
            & op .~ (unOpType (o ^. opType) :: Text)
            & attr .~ _opAttrs o
            & input .~ (dataInputs ++ controlDeps)
            & device .~ dev
    return (PendingNode scope (o ^. opName) nodeDef)
  where
    -- Control dependencies are written as the node name prefixed by "^".
    asControlInput dep = "^" <> unNodeName dep
-- | Assign a concrete 'NodeName' to a pending node.
--
-- The name is the scope prefix followed by either the explicit name or, for
-- implicitly named nodes, the op type, an underscore and the current value
-- of the unique counter (which is then advanced).
--
-- The scope list is stored innermost-first, because 'withNameScope' conses
-- each newly entered scope onto the front.  It is therefore reversed before
-- rendering, so that nested scopes come out as @outer/inner/name@ rather
-- than @inner/outer/name@.
renderPendingNode :: PendingNode -> Build NodeName
renderPendingNode (PendingNode scope pendingName nodeDef)
    = NodeName . (scopePrefix <>) <$> getName
  where
    -- Outermost scope first, each component terminated by a slash.
    scopePrefix = Text.concat $ fmap ((<> "/") . unScope) (reverse scope)
    getName = case pendingName of
        ExplicitName n -> return n
        ImplicitName -> do
            u@(Unique k) <- use nextUnique
            nextUnique .= succ u
            return $ nodeDef ^. op <> "_" <> Text.pack (show k)
-- | Render an 'Output' as text for use in a node's input list: just the
-- node name for output index 0, otherwise the node name followed by a
-- colon and the index.
encodeOutput :: Output -> Text
encodeOutput (Output (OutputIx ix) n)
    | ix == 0 = unNodeName n
    | otherwise = unNodeName n <> Text.pack (':' : show ix)
-- | Run an action with part of the graph state temporarily modified.
--
-- The value behind the lens is saved, the function is applied to it, the
-- action is run, and then the saved value is written back; any change the
-- action itself made to that part of the state is overwritten.
--
-- There is no exception handling here, so if the action aborts with an
-- exception the restore step does not run.
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
withStateLens accessor f act = do
    saved <- build (use accessor)
    build (accessor %= f)
    outcome <- act
    outcome <$ build (accessor .= saved)
-- | Run an action with the default device replaced by the given one
-- ('Nothing' clears it); the previous default is restored afterwards.
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d act = withStateLens defaultDevice (\_ -> d) act
-- | Run an action with an additional name scope pushed onto the front of
-- the current scope list; the previous scope list is restored afterwards.
withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s act =
    withStateLens currentScope (\scopes -> Scope s : scopes) act
-- | Run an action with the given nodes added to the default control
-- inputs, so that every node created inside it depends on them.  The
-- previous set is restored afterwards.
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes act =
    withStateLens defaultControlInputs (\existing -> Set.union existing nodes) act