-- Copyright 2016 TensorFlow authors. -- -- Licensed under the Apache License, Version 2.0 (the "License"); -- you may not use this file except in compliance with the License. -- You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- -- Unless required by applicable law or agreed to in writing, software -- distributed under the License is distributed on an "AS IS" BASIS, -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -- See the License for the specific language governing permissions and -- limitations under the License. {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} module TensorFlow.Build ( -- * Graph node types ControlNode(..) , Unique -- * Ops , explicitName , implicitName , opDef , opDefWithName , opName , opType , opAttr , opInputs , opControlInputs -- * The Build monad , GraphState , render , renderNodeName , renderedNodeDefs , BuildT , Build , MonadBuild(..) , addInitializer , hoistBuildT , evalBuildT , runBuildT , asGraphDef , addGraphDef , flushInitializers , flushNodeBuffer -- * Creating and looking up Ops , getOrAddOp , addNewOp , renderOutput -- * Modifying all nodes in a Build action , colocateWith , withStateLens , withDevice , withNameScope , withNodeDependencies -- * Internal Summary related bits. , addSummary , SummaryTensor , collectAllSummaries ) where import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) import Data.ByteString (ByteString) import Data.Default (def) import Data.Functor.Identity (Identity(..)) import qualified Data.Map.Strict as Map import Data.Monoid ((<>)) import qualified Data.Set as Set import Data.Set (Set) import Data.String (IsString(..)) import Data.Text (Text) import qualified Data.Text as Text import Lens.Family2 (Lens', (.~), (^.), (&)) import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=)) import Lens.Family2.Unchecked (lens) import Proto.Tensorflow.Core.Framework.Graph ( GraphDef , node ) import Proto.Tensorflow.Core.Framework.NodeDef ( NodeDef , attr , input , device , name , op ) import TensorFlow.Orphans () import TensorFlow.Output import TensorFlow.Tensor newtype Unique = Unique Int deriving (Eq, Ord, Enum) -------------- implicitName :: PendingNodeName implicitName = ImplicitName explicitName :: Text -> PendingNodeName explicitName = ExplicitName newtype Scope = Scope {unScope :: Text} deriving (Eq, Ord, IsString) instance Show Scope where show = show . unScope opDef :: OpType -> OpDef opDef = opDefWithName ImplicitName opDefWithName :: PendingNodeName -> OpType -> OpDef opDefWithName n t = OpDef { _opName = n , _opType = t , _opAttrs = Map.empty , _opInputs = [] , _opControlInputs = [] } -- | Synonym for the tensors that return serialized Summary proto. type SummaryTensor = Tensor Value ByteString data GraphState = GraphState { _renderedNodes :: !(Map.Map PendingNode NodeDef) -- ^ Nodes which have been rendered. Keeps track of the unique ID we -- assign each implicitly-named node. Also prevents us from adding the -- same node (implicit or explicit) more than once to the nodeBuffer. , _renderedNodeDefs :: !(Map.Map NodeName NodeDef) -- ^ The NodeDefs of nodes which have been rendered. Used by the -- Gradient module to inspect the node graph. , _nodeBuffer :: [NodeDef] -- ^ A list of nodes that should be passed to TensorFlow during -- the next call to Session.extend (TF_ExtendGraph). , _nextUnique :: !Unique -- ^ Unique ID for the next node -- TODO(judahjacobson): watch for clashes between auto and user names. , _defaultDevice :: !(Maybe Device) , _currentScope :: [Scope] , _defaultControlInputs :: !(Set NodeName) , _initializationNodes :: [NodeName] -- ^ The nodes to run next time a TF.run is issued, typically -- variable initializers. , _summaries :: [SummaryTensor] -- ^ The tensors for summary } -- | A node definition without its final name. Used as a key in the -- "renderedNodes" map. -- The NodeDef contained inside has an empty "name" field. data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef deriving (Eq, Ord) -- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending. pendingNodeDef :: PendingNode -> NodeDef pendingNodeDef (PendingNode _ _ n) = n initGraphState :: GraphState initGraphState = GraphState Map.empty Map.empty [] (Unique 0) Nothing [] Set.empty [] [] renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef) renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x }) renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef) renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x }) nodeBuffer :: Lens' GraphState [NodeDef] nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x }) nextUnique :: Lens' GraphState Unique nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x }) defaultDevice :: Lens' GraphState (Maybe Device) defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x }) currentScope :: Lens' GraphState [Scope] currentScope = lens _currentScope (\g x -> g { _currentScope = x }) defaultControlInputs :: Lens' GraphState (Set NodeName) defaultControlInputs = lens _defaultControlInputs (\g x -> g { _defaultControlInputs = x }) initializationNodes :: Lens' GraphState [NodeName] initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x }) summaries :: Lens' GraphState [SummaryTensor] summaries = lens _summaries (\g x -> g { _summaries = x }) -- | An action for building nodes in a TensorFlow graph. -- Used to manage build state internally as part of the @Session@ monad. newtype BuildT m a = BuildT (StateT GraphState m a) deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadState GraphState, MonadThrow, MonadCatch, MonadMask) -- | An action for building nodes in a TensorFlow graph. type Build = BuildT Identity -- | This is Control.Monad.Morph.hoist sans the dependency. hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b hoistBuildT f (BuildT m) = BuildT $ mapStateT f m runBuildT :: BuildT m a -> m (a, GraphState) runBuildT (BuildT f) = runStateT f initGraphState evalBuildT :: Monad m => BuildT m a -> m a evalBuildT (BuildT f) = evalStateT f initGraphState -- | Lift a 'Build' action into a monad, including any explicit op renderings. class Monad m => MonadBuild m where build :: Build a -> m a instance Monad m => MonadBuild (BuildT m) where build = hoistBuildT $ return . runIdentity -- | Get all the NodeDefs that have accumulated so far, and clear that buffer. flushNodeBuffer :: MonadBuild m => m [NodeDef] flushNodeBuffer = build $ do ns <- use nodeBuffer nodeBuffer .= [] return ns -- | Get all the initializers that have accumulated so far, and clear -- that buffer. flushInitializers :: Monad m => BuildT m [NodeName] flushInitializers = do ns <- use initializationNodes initializationNodes .= [] return ns -- | Registers the given node to be executed before the next -- 'TensorFlow.Session.run'. addInitializer :: MonadBuild m => ControlNode -> m () addInitializer (ControlNode o) = build $ do i <- getOrAddOp o initializationNodes %= (i:) -- | Produce a GraphDef proto representation of the nodes that are rendered in -- the given 'Build' action. asGraphDef :: Build a -> GraphDef asGraphDef b = def & node .~ gs ^. nodeBuffer where gs = snd $ runIdentity $ runBuildT b -- TODO: check against existing nodes for conflicts? addGraphDef :: MonadBuild m => GraphDef -> m () addGraphDef g = build $ nodeBuffer <>= g ^. node -- | Render the given op if it hasn't been rendered already, and return its -- name. getOrAddOp :: Op -> Build NodeName getOrAddOp o = NodeName . (^. name) <$> resolveOp o resolveOp :: Op -> Build NodeDef resolveOp (Rendered n) = return n resolveOp (Unrendered o) = do pending <- getPendingNode o uses renderedNodes (Map.lookup pending) >>= \case Just n -> return n Nothing -> addNewOpFromPending pending -- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops -- which are not safe to dedup (e.g, "variable" and "assign"). addNewOp :: OpDef -> Build NodeDef addNewOp o = getPendingNode o >>= addNewOpFromPending addNewOpFromPending :: PendingNode -> Build NodeDef addNewOpFromPending pending = do nodeName <- renderPendingNode pending let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName nodeBuffer %= (nodeDef :) renderedNodes %= Map.insert pending nodeDef renderedNodeDefs %= Map.insert nodeName nodeDef return nodeDef -- | Get the pending node corresponding to an OpDef, which may or may not have -- been rendered before. Implicitly renders all of this node's inputs. getPendingNode :: OpDef -> Build PendingNode getPendingNode o = do -- An empty string in the proto field means that no specific -- device is specified. dev <- maybe "" deviceName <$> use defaultDevice inputs <- mapM getInput (o ^. opInputs) scope <- use currentScope controls <- use defaultControlInputs let controlInputs = map getDep (o ^. opControlInputs ++ Set.toList controls) return $ PendingNode scope (o ^. opName) $ def & op .~ (unOpType (o ^. opType) :: Text) & attr .~ _opAttrs o & input .~ (inputs ++ controlInputs) & device .~ dev where getInput (Output (OutputIx k) subOp) = (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp getDep = ("^" <>) . unNodeName -- | Pick a name for a pending node. If it has an explicit name, just use that; -- if the name is implicit, assign a new unique name based on the op type. renderPendingNode :: PendingNode -> Build NodeName renderPendingNode (PendingNode scope pendingName nodeDef) = NodeName . (scopePrefix <>) <$> getName where scopePrefix = Text.concat $ fmap ((<> "/") . unScope) scope getName = case pendingName of ExplicitName n -> return n ImplicitName -> do u@(Unique k) <- use nextUnique nextUnique .= succ u return $ nodeDef ^. op <> "_" <> Text.pack (show k) -- | Render an 'Output' and return a string representation for the TensorFlow -- foreign APIs. renderOutput :: Output -> Build Text renderOutput (Output (OutputIx i) o) = do n <- getOrAddOp o return $ unNodeName n <> Text.pack (":" ++ show i) -- | Modify some part of the state, run an action, and restore the state -- after that action is done. withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b withStateLens accessor f act = do old <- build $ use accessor build $ accessor %= f result <- act build $ accessor .= old return result -- | Set a device for all nodes rendered in the given 'Build' action -- (unless further overridden by another use of withDevice). withDevice :: MonadBuild m => Maybe Device -> m a -> m a withDevice d = withStateLens defaultDevice (const d) -- | Places all nodes rendered in the given 'Build' action on the same -- device as the given Tensor (see also 'withDevice'). Make sure that -- the action has side effects of rendering the desired tensors. A pure -- return would not have the desired effect. colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a colocateWith t x = do d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) withDevice (Just d) x -- | Prepend a scope to all nodes rendered in the given 'Build' action. withNameScope :: MonadBuild m => Text -> m a -> m a withNameScope s = withStateLens currentScope (Scope s :) -- | Add control inputs to all nodes rendered in the given 'Build' action. withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) -- | Render a 'Tensor', fixing its name, scope, device and control inputs from -- the 'Build' context. Also renders any dependencies of the 'Tensor' that -- weren't already rendered. -- -- This operation is idempotent; @render >=> render === render@. However, -- rendering a (previously un-rendered) 'Tensor' in two different contexts -- may result in two different 'Tensor's. render :: MonadBuild m => Tensor v a -> m (Tensor v a) render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp) -- | Render a 'Tensor' and get its node's name. renderNodeName :: Tensor v a -> Build NodeName renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp) -- | Records the given summary action in Build for retrieval with -- 'collectAllSummaries'. The summary op is required to produce a -- Summary protocol buffer in string form. For safety, use the -- pre-composed functions: Logging.scalarSummary and -- Logging.histogramSummary. addSummary :: SummaryTensor -> Build () addSummary t = summaries %= (t :) -- | Retrieves the summary ops collected thus far. Typically this only -- happens once, but if 'TensorFlow.Session.buildWithSummary' is used -- repeatedly, the values accumulate. collectAllSummaries :: Monad m => BuildT m [SummaryTensor] collectAllSummaries = use summaries