mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
376 lines
13 KiB
Haskell
376 lines
13 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE Rank2Types #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
module TensorFlow.Build
|
|
( -- * Graph node types
|
|
ControlNode(..)
|
|
, Unique
|
|
-- * Ops
|
|
, explicitName
|
|
, implicitName
|
|
, opDef
|
|
, opDefWithName
|
|
, opName
|
|
, opType
|
|
, opAttr
|
|
, opInputs
|
|
, opControlInputs
|
|
-- * The Build monad
|
|
, GraphState
|
|
, render
|
|
, renderNodeName
|
|
, renderedNodeDefs
|
|
, BuildT
|
|
, Build
|
|
, addInitializer
|
|
, hoistBuildT
|
|
, evalBuildT
|
|
, runBuildT
|
|
, asGraphDef
|
|
, addGraphDef
|
|
, flushInitializers
|
|
, flushNodeBuffer
|
|
-- * Creating and looking up Ops
|
|
, getOrAddOp
|
|
, addNewOp
|
|
, renderOutput
|
|
-- * Modifying all nodes in a Build action
|
|
, colocateWith
|
|
, withStateLens
|
|
, withDevice
|
|
, withNameScope
|
|
, withNodeDependencies
|
|
-- * Internal Summary related bits.
|
|
, addSummary
|
|
, SummaryTensor
|
|
, collectAllSummaries
|
|
) where
|
|
|
|
import Control.Monad.IO.Class (MonadIO(..))
|
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
|
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
|
import Data.ByteString (ByteString)
|
|
import Data.Default (def)
|
|
import Data.Functor.Identity (Identity(..))
|
|
import qualified Data.Map.Strict as Map
|
|
import Data.Monoid ((<>))
|
|
import qualified Data.Set as Set
|
|
import Data.Set (Set)
|
|
import Data.String (IsString(..))
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as Text
|
|
import Lens.Family2 (Lens', (.~), (^.), (&))
|
|
import Lens.Family2.State.Strict (MonadState, use, uses, (.=), (<>=), (%=))
|
|
import Lens.Family2.Unchecked (lens)
|
|
import Proto.Tensorflow.Core.Framework.Graph
|
|
( GraphDef
|
|
, node
|
|
)
|
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
|
( NodeDef
|
|
, attr
|
|
, input
|
|
, device
|
|
, name
|
|
, op
|
|
)
|
|
|
|
import TensorFlow.Orphans ()
|
|
import TensorFlow.Output
|
|
import TensorFlow.Tensor
|
|
|
|
newtype Unique = Unique Int
|
|
deriving (Eq, Ord, Enum)
|
|
|
|
--------------
|
|
|
|
implicitName :: PendingNodeName
|
|
implicitName = ImplicitName
|
|
|
|
explicitName :: Text -> PendingNodeName
|
|
explicitName = ExplicitName
|
|
|
|
newtype Scope = Scope {unScope :: Text}
|
|
deriving (Eq, Ord, IsString)
|
|
|
|
instance Show Scope where
|
|
show = show . unScope
|
|
|
|
opDef :: OpType -> OpDef
|
|
opDef = opDefWithName ImplicitName
|
|
|
|
opDefWithName :: PendingNodeName -> OpType -> OpDef
|
|
opDefWithName n t = OpDef
|
|
{ _opName = n
|
|
, _opType = t
|
|
, _opAttrs = Map.empty
|
|
, _opInputs = []
|
|
, _opControlInputs = []
|
|
}
|
|
|
|
-- | Synonym for the tensors that return serialized Summary proto.
|
|
type SummaryTensor = Tensor Value ByteString
|
|
|
|
data GraphState = GraphState
|
|
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
|
|
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
|
|
-- assign each implicitly-named node. Also prevents us from adding the
|
|
-- same node (implicit or explicit) more than once to the nodeBuffer.
|
|
, _renderedNodeDefs :: !(Map.Map NodeName NodeDef)
|
|
-- ^ The NodeDefs of nodes which have been rendered. Used by the
|
|
-- Gradient module to inspect the node graph.
|
|
, _nodeBuffer :: [NodeDef]
|
|
-- ^ A list of nodes that should be passed to TensorFlow during
|
|
-- the next call to Session.extend (TF_ExtendGraph).
|
|
, _nextUnique :: !Unique
|
|
-- ^ Unique ID for the next node
|
|
-- TODO(judahjacobson): watch for clashes between auto and user names.
|
|
, _defaultDevice :: !(Maybe Device)
|
|
, _currentScope :: [Scope]
|
|
, _defaultControlInputs :: !(Set NodeName)
|
|
, _initializationNodes :: [NodeName]
|
|
-- ^ The nodes to run next time a TF.run is issued, typically
|
|
-- variable initializers.
|
|
, _summaries :: [SummaryTensor]
|
|
-- ^ The tensors for summary
|
|
}
|
|
|
|
-- | A node definition without its final name. Used as a key in the
|
|
-- "renderedNodes" map.
|
|
-- The NodeDef contained inside has an empty "name" field.
|
|
data PendingNode = PendingNode [Scope] !PendingNodeName !NodeDef
|
|
deriving (Eq, Ord)
|
|
|
|
-- Returns an _incomplete_ NodeDef. The name is fixed by addNewOpFromPending.
|
|
pendingNodeDef :: PendingNode -> NodeDef
|
|
pendingNodeDef (PendingNode _ _ n) = n
|
|
|
|
initGraphState :: GraphState
|
|
initGraphState =
|
|
GraphState Map.empty Map.empty [] (Unique 0) Nothing [] Set.empty [] []
|
|
|
|
renderedNodes :: Lens' GraphState (Map.Map PendingNode NodeDef)
|
|
renderedNodes = lens _renderedNodes (\g x -> g { _renderedNodes = x })
|
|
|
|
renderedNodeDefs :: Lens' GraphState (Map.Map NodeName NodeDef)
|
|
renderedNodeDefs = lens _renderedNodeDefs (\g x -> g { _renderedNodeDefs = x })
|
|
|
|
nodeBuffer :: Lens' GraphState [NodeDef]
|
|
nodeBuffer = lens _nodeBuffer (\g x -> g { _nodeBuffer = x })
|
|
|
|
nextUnique :: Lens' GraphState Unique
|
|
nextUnique = lens _nextUnique (\g x -> g { _nextUnique = x })
|
|
|
|
defaultDevice :: Lens' GraphState (Maybe Device)
|
|
defaultDevice = lens _defaultDevice (\g x -> g { _defaultDevice = x })
|
|
|
|
currentScope :: Lens' GraphState [Scope]
|
|
currentScope = lens _currentScope (\g x -> g { _currentScope = x })
|
|
|
|
defaultControlInputs :: Lens' GraphState (Set NodeName)
|
|
defaultControlInputs = lens _defaultControlInputs
|
|
(\g x -> g { _defaultControlInputs = x })
|
|
|
|
initializationNodes :: Lens' GraphState [NodeName]
|
|
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
|
|
|
|
summaries :: Lens' GraphState [SummaryTensor]
|
|
summaries = lens _summaries (\g x -> g { _summaries = x })
|
|
|
|
-- | An action for building nodes in a TensorFlow graph.
|
|
-- Used to manage build state internally as part of the @Session@ monad.
|
|
newtype BuildT m a = BuildT (StateT GraphState m a)
|
|
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
|
MonadState GraphState)
|
|
|
|
-- | An action for building nodes in a TensorFlow graph.
|
|
type Build = BuildT Identity
|
|
|
|
-- | This is Control.Monad.Morph.hoist sans the dependency.
|
|
hoistBuildT :: (forall a . m a -> n a) -> BuildT m b -> BuildT n b
|
|
hoistBuildT f (BuildT m) = BuildT $ mapStateT f m
|
|
|
|
runBuildT :: BuildT m a -> m (a, GraphState)
|
|
runBuildT (BuildT f) = runStateT f initGraphState
|
|
|
|
evalBuildT :: Monad m => BuildT m a -> m a
|
|
evalBuildT (BuildT f) = evalStateT f initGraphState
|
|
|
|
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
|
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
|
|
flushNodeBuffer = do
|
|
ns <- use nodeBuffer
|
|
nodeBuffer .= []
|
|
return ns
|
|
|
|
-- | Get all the initializers that have accumulated so far, and clear
|
|
-- that buffer.
|
|
flushInitializers :: Monad m => BuildT m [NodeName]
|
|
flushInitializers = do
|
|
ns <- use initializationNodes
|
|
initializationNodes .= []
|
|
return ns
|
|
|
|
-- | Registers the given node to be executed before the next
|
|
-- 'TensorFlow.Session.run'.
|
|
addInitializer :: ControlNode -> Build ()
|
|
addInitializer (ControlNode o) = do
|
|
i <- getOrAddOp o
|
|
initializationNodes %= (i:)
|
|
|
|
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
|
-- the given 'Build' action.
|
|
asGraphDef :: Build a -> GraphDef
|
|
asGraphDef b = def & node .~ gs ^. nodeBuffer
|
|
where
|
|
gs = snd $ runIdentity $ runBuildT b
|
|
|
|
-- TODO: check against existing nodes for conflicts?
|
|
addGraphDef :: GraphDef -> Build ()
|
|
addGraphDef g = nodeBuffer <>= g ^. node
|
|
|
|
-- | Render the given op if it hasn't been rendered already, and return its
|
|
-- name.
|
|
getOrAddOp :: Op -> Build NodeName
|
|
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
|
|
|
|
resolveOp :: Op -> Build NodeDef
|
|
resolveOp (Rendered n) = return n
|
|
resolveOp (Unrendered o) = do
|
|
pending <- getPendingNode o
|
|
uses renderedNodes (Map.lookup pending) >>= \case
|
|
Just n -> return n
|
|
Nothing -> addNewOpFromPending pending
|
|
|
|
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
|
|
-- which are not safe to dedup (e.g, "variable" and "assign").
|
|
addNewOp :: OpDef -> Build NodeDef
|
|
addNewOp o = getPendingNode o >>= addNewOpFromPending
|
|
|
|
addNewOpFromPending :: PendingNode -> Build NodeDef
|
|
addNewOpFromPending pending = do
|
|
nodeName <- renderPendingNode pending
|
|
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
|
|
nodeBuffer %= (nodeDef :)
|
|
renderedNodes %= Map.insert pending nodeDef
|
|
renderedNodeDefs %= Map.insert nodeName nodeDef
|
|
return nodeDef
|
|
|
|
-- | Get the pending node corresponding to an OpDef, which may or may not have
|
|
-- been rendered before. Implicitly renders all of this node's inputs.
|
|
getPendingNode :: OpDef -> Build PendingNode
|
|
getPendingNode o = do
|
|
-- An empty string in the proto field means that no specific
|
|
-- device is specified.
|
|
dev <- maybe "" deviceName <$> use defaultDevice
|
|
inputs <- mapM getInput (o ^. opInputs)
|
|
scope <- use currentScope
|
|
controls <- use defaultControlInputs
|
|
let controlInputs
|
|
= map getDep (o ^. opControlInputs ++ Set.toList controls)
|
|
return $ PendingNode scope (o ^. opName)
|
|
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
|
& attr .~ _opAttrs o
|
|
& input .~ (inputs ++ controlInputs)
|
|
& device .~ dev
|
|
where
|
|
getInput (Output (OutputIx k) subOp)
|
|
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
|
|
getDep = ("^" <>) . unNodeName
|
|
|
|
-- | Pick a name for a pending node. If it has an explicit name, just use that;
|
|
-- if the name is implicit, assign a new unique name based on the op type.
|
|
renderPendingNode :: PendingNode -> Build NodeName
|
|
renderPendingNode (PendingNode scope pendingName nodeDef)
|
|
= NodeName . (scopePrefix <>) <$> getName
|
|
where
|
|
scopePrefix = Text.concat $ fmap ((<> "/") . unScope) scope
|
|
getName = case pendingName of
|
|
ExplicitName n -> return n
|
|
ImplicitName -> do
|
|
u@(Unique k) <- use nextUnique
|
|
nextUnique .= succ u
|
|
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
|
|
|
|
|
-- | Render an 'Output' and return a string representation for the TensorFlow
|
|
-- foreign APIs.
|
|
renderOutput :: Output -> Build Text
|
|
renderOutput (Output (OutputIx i) o) = do
|
|
n <- getOrAddOp o
|
|
return $ unNodeName n <> Text.pack (":" ++ show i)
|
|
|
|
-- | Modify some part of the state, run an action, and restore the state
|
|
-- after that action is done.
|
|
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
|
|
withStateLens accessor f act = do
|
|
old <- use accessor
|
|
accessor %= f
|
|
result <- act
|
|
accessor .= old
|
|
return result
|
|
|
|
-- | Set a device for all nodes rendered in the given 'Build' action
|
|
-- (unless further overridden by another use of withDevice).
|
|
withDevice :: Maybe Device -> Build a -> Build a
|
|
withDevice d = withStateLens defaultDevice (const d)
|
|
|
|
-- | Places all nodes rendered in the given 'Build' action on the same
|
|
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
|
-- the action has side effects of rendering the desired tensors. A pure
|
|
-- return would not have the desired effect.
|
|
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
|
|
colocateWith t x = do
|
|
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
|
withDevice (Just d) x
|
|
|
|
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
|
withNameScope :: Text -> Build a -> Build a
|
|
withNameScope s = withStateLens currentScope (Scope s :)
|
|
|
|
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
|
withNodeDependencies :: Set NodeName -> Build a -> Build a
|
|
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
|
|
|
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
|
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
|
|
-- weren't already rendered.
|
|
--
|
|
-- This operation is idempotent; @render >=> render === render@. However,
|
|
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
|
-- may result in two different 'Tensor's.
|
|
render :: Tensor v a -> Build (Tensor v a)
|
|
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
|
|
|
-- | Render a 'Tensor' and get its node's name.
|
|
renderNodeName :: Tensor v a -> Build NodeName
|
|
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
|
|
|
-- | Records the given summary action in Build for retrieval with
|
|
-- 'collectAllSummaries'. The summary op is required to produce a
|
|
-- Summary protocol buffer in string form. For safety, use the
|
|
-- pre-composed functions: Logging.scalarSummary and
|
|
-- Logging.histogramSummary.
|
|
addSummary :: SummaryTensor -> Build ()
|
|
addSummary t = summaries %= (t :)
|
|
|
|
-- | Retrieves the summary ops collected thus far. Typically this only
|
|
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
|
-- repeatedly, the values accumulate.
|
|
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
|
|
collectAllSummaries = use summaries
|