Introduce a MonadBuild class, and remove `buildAnd`. (#83)

This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now have a signature that returns
an instance of `MonadBuild` (rather than just `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when foo is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular a few tests had `buildAnd run . pure` which is
actually equivalent to just `run`.
This commit is contained in:
Judah Jacobson 2017-03-18 12:08:53 -07:00 committed by GitHub
parent 9209dfc4c4
commit 2c5c879037
22 changed files with 152 additions and 162 deletions

View File

@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
let x = TF.vector xData let x = TF.vector xData
y = TF.vector yData y = TF.vector yData
-- Create scalar variables for slope and intercept. -- Create scalar variables for slope and intercept.
w <- TF.build (TF.initializedVariable 0) w <- TF.initializedVariable 0
b <- TF.build (TF.initializedVariable 0) b <- TF.initializedVariable 0
-- Define the loss function. -- Define the loss function.
let yHat = (x `TF.mul` w) `TF.add` b let yHat = (x `TF.mul` w) `TF.add` b
loss = TF.square (yHat `TF.sub` y) loss = TF.square (yHat `TF.sub` y)
-- Optimize with gradient descent. -- Optimize with gradient descent.
trainStep <- TF.build (gradientDescent 0.001 loss [w, b]) trainStep <- gradientDescent 0.001 loss [w, b]
replicateM_ 1000 (TF.run trainStep) replicateM_ 1000 (TF.run trainStep)
-- Return the learned parameters. -- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b) (TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
gradientDescent :: Float gradientDescent :: Float
-> TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
-> [TF.Tensor TF.Ref Float] -> [TF.Tensor TF.Ref Float]
-> TF.Build TF.ControlNode -> TF.Session TF.ControlNode
gradientDescent alpha loss params = do gradientDescent alpha loss params = do
let applyGrad param grad = let applyGrad param grad =
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad)) TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))

View File

@ -49,7 +49,7 @@ import TensorFlow.Tensor
) )
import TensorFlow.Ops import TensorFlow.Ops
import TensorFlow.Session import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd) (runSession, run, run_, runWithFeeds, build)
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar) import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (Test) import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
@ -108,7 +108,7 @@ testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ do testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do runSession $ do
build $ addGraphDef graphDef addGraphDef graphDef
x <- run $ tensorFromName ValueKind "Mul_2" x <- run $ tensorFromName ValueKind "Mul_2"
liftIO $ (50 :: Float) @=? unScalar x liftIO $ (50 :: Float) @=? unScalar x
@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
wtsCkptPath <- liftIO wtsCkpt wtsCkptPath <- liftIO wtsCkpt
biasCkptPath <- liftIO biasCkpt biasCkptPath <- liftIO biasCkpt
-- Run those restoring nodes on the graph in the current session. -- Run those restoring nodes on the graph in the current session.
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a]) run_ =<< (sequence :: Monad m => [m a] -> m [a])
[ restore wtsCkptPath wts [ restore wtsCkptPath wts
, restoreFromName biasCkptPath "bias" bias , restoreFromName biasCkptPath "bias" bias
] ]

View File

@ -23,7 +23,7 @@ module TensorFlow.NN
import Prelude hiding ( log import Prelude hiding ( log
, exp , exp
) )
import TensorFlow.Build ( Build import TensorFlow.Build ( MonadBuild
, render , render
, withNameScope , withNameScope
) )
@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
-- --
-- `logits` and `targets` must have the same type and shape. -- `logits` and `targets` must have the same type and shape.
sigmoidCrossEntropyWithLogits sigmoidCrossEntropyWithLogits
:: (OneOf '[Float, Double] a, TensorType a, Num a) :: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
=> Tensor Value a -- ^ __logits__ => Tensor Value a -- ^ __logits__
-> Tensor Value a -- ^ __targets__ -> Tensor Value a -- ^ __targets__
-> Build (Tensor Value a) -> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do sigmoidCrossEntropyWithLogits logits targets = do
logits' <- render logits logits' <- render logits
targets' <- render targets targets' <- render targets

View File

@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
import Test.Framework (Test) import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF import qualified TensorFlow.NN as TF
@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
assertAllClose (head r) (V.fromList [0.5, -0.5]) assertAllClose (head r) (V.fromList [0.5, -0.5])
run :: TF.Fetchable t a => TF.Build t -> IO a run :: TF.Fetchable t a => TF.Session t -> IO a
run = TF.runSession . TF.buildAnd TF.run run = TF.runSession . (>>= TF.run)
main :: IO () main :: IO ()
main = googleTest [ testGradientAtZero main = googleTest [ testGradientAtZero

View File

@ -220,9 +220,12 @@ renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts)) functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs) </> indent indentation (sep tensorArgs)
where where
maybeLift
| parsedOpIsMonadic pOp = "build $"
| otherwise = ""
buildFunction buildFunction
| null outputListsSizes = "buildOp" | null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+> | otherwise = "buildListOp" <+>
@ -277,13 +280,18 @@ typeSig pOp = constraints
++ [outputs]) ++ [outputs])
where where
constraints constraints
| null (inferredTypeAttrs pOp) = empty | null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ map tensorArgConstraint ++ if parsedOpIsMonadic pOp then ["m'"] else []
$ inferredTypeAttrs pOp -- Use m' as the type parameter to avoid clashing with an attribute name.
monadConstraint
| parsedOpIsMonadic pOp = ["MonadBuild m'"]
| otherwise = []
classConstraints = monadConstraint ++ map tensorArgConstraint
(inferredTypeAttrs pOp)
signatureFold = folddoc (\x y -> x </> "->" <+> y) signatureFold = folddoc (\x y -> x </> "->" <+> y)
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a) attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
renderAttrType (AttrSingle a) = renderAttrBaseType a renderAttrType (AttrSingle a) = renderAttrBaseType a
@ -304,7 +312,7 @@ typeSig pOp = constraints
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a [a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
wrapOutput o wrapOutput o
| parsedOpIsMonadic pOp = "Build" <+> parens o | parsedOpIsMonadic pOp = "m'" <+> parens o
| otherwise = o | otherwise = o
-- | Render an op input or output. -- | Render an op input or output.

View File

@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM) import Control.Monad (zipWithM)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import TensorFlow.Build (Build, colocateWith, render) import TensorFlow.Build (MonadBuild, colocateWith, render)
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType) import TensorFlow.Types (OneOf, TensorType)
@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
-- --
-- The results of the lookup are concatenated into a dense -- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v . embeddingLookup :: forall a b v m .
( TensorType a ( MonadBuild m
, TensorType a
, OneOf '[Int64, Int32] b , OneOf '[Int64, Int32] b
, Num b , Num b
) )
@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
-- containing the ids to be looked up in `params`. -- containing the ids to be looked up in `params`.
-- The ids are required to have fewer than 2^31 -- The ids are required to have fewer than 2^31
-- entries. -- entries.
-> Build (Tensor Value a) -> m (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`. -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids) embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do embeddingLookup params@(p0 : _) ids = do

View File

@ -56,7 +56,9 @@ import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build import TensorFlow.Build
( Build ( MonadBuild
, Build
, build
, render , render
, renderNodeName , renderNodeName
, renderedNodeDefs , renderedNodeDefs
@ -111,16 +113,17 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@. -- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 . ( Num (Tensor v1 a) gradients :: forall a v1 v2 m . (MonadBuild m
, Num (Tensor v1 a)
-- TODO(gnezdo): remove indirect constraint. -- TODO(gnezdo): remove indirect constraint.
-- It's a wart inherited from Num instance. -- It's a wart inherited from Num instance.
, v1 ~ Value , v1 ~ Value
, GradientCompatible a , GradientCompatible a
) )
=> Tensor v1 a -- ^ The output of the graph. => Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed. -> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> Build [Tensor Value a] -> m [Tensor Value a]
gradients y xs = do gradients y xs = build $ do
-- The gradients are computed using "reverse accumulation", similarly to -- The gradients are computed using "reverse accumulation", similarly to
-- what is described here: -- what is described here:
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation

View File

@ -155,20 +155,20 @@ matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a => Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32]) matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a) placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder shape' = placeholder shape' =
buildOp $ opDef "Placeholder" build $ buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape' & opAttr "shape" .~ shape'
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a initializedVariable :: forall a m . (MonadBuild m, TensorType a)
=> Tensor Value a -> Build (Tensor Ref a) => Tensor Value a -> m (Tensor Ref a)
initializedVariable initializer = do initializedVariable initializer = do
v <- CoreOps.variable [] -- The shape is not known initially. v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <- (i :: Tensor Ref a) <-
buildOp (opDef "Assign" build $ buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a) & opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True & opAttr "use_locking" .~ True
& opAttr "validate_shape" .~ False & opAttr "validate_shape" .~ False
@ -179,32 +179,32 @@ initializedVariable initializer = do
-- | Creates a zero-initialized variable with the given shape. -- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable zeroInitializedVariable
:: (TensorType a, Num a) => :: (MonadBuild m, TensorType a, Num a) =>
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a) TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros zeroInitializedVariable = initializedVariable . zeros
-- TODO: Support heterogeneous list of tensors. -- TODO: Support heterogeneous list of tensors.
save :: forall a v . TensorType a save :: forall a m v . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save. -> [Tensor v a] -- ^ Tensors to save.
-> Build ControlNode -> m ControlNode
save path xs = do save path xs = do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName let toByteStringTensor = scalar . encodeUtf8 . unNodeName
names <- mapM (fmap toByteStringTensor . renderNodeName) xs names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
let types = replicate (length xs) (tensorType (undefined :: a)) let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save" let saveOp = buildOp $ opDef "Save"
& opAttr "T" .~ types & opAttr "T" .~ types
saveOp (scalar path) (CoreOps.pack names) xs build $ saveOp (scalar path) (CoreOps.pack names) xs
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
-- --
-- This version allows restoring from a checkpoint file that uses a different -- This version allows restoring from a checkpoint file that uses a different
-- tensor name than the variable. -- tensor name than the variable.
restoreFromName :: forall a . TensorType a restoreFromName :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> ByteString -- ^ Tensor name override. -> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore. -> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode -> m ControlNode
restoreFromName path name x = do restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore" let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a) & opAttr "dt" .~ tensorType (undefined :: a)
@ -212,12 +212,12 @@ restoreFromName path name x = do
(restoreOp (scalar path) (scalar name) :: Tensor Value a) (restoreOp (scalar path) (scalar name) :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a restore :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore. -> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode -> m ControlNode
restore path x = do restore path x = do
name <- encodeUtf8 . unNodeName <$> renderNodeName x name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
restoreFromName path name x restoreFromName path name x
-- | Create a constant tensor. -- | Create a constant tensor.
@ -264,12 +264,13 @@ scalar :: forall a . TensorType a => a -> Tensor Value a
scalar x = constant [] [x] scalar x = constant [] [x]
-- Random tensor from the unit normal distribution with bounded values. -- Random tensor from the unit normal distribution with bounded values.
truncatedNormal :: forall a v . TensorType a truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
=> Tensor v Int64 -- ^ Shape. => Tensor v Int64 -- ^ Shape.
-> Build (Tensor Value a) -> m (Tensor Value a)
truncatedNormal = buildOp $ opDef "TruncatedNormal" truncatedNormal
& opAttr "dtype" .~ tensorType (undefined :: a) = build . buildOp (opDef "TruncatedNormal"
& opAttr "T" .~ tensorType (undefined :: Int64) & opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64))
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)

View File

@ -19,7 +19,6 @@
module Main where module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Functor.Identity (runIdentity)
import Lens.Family2 ((^.)) import Lens.Family2 ((^.))
import Data.List (sort) import Data.List (sort)
import Proto.Tensorflow.Core.Framework.Graph import Proto.Tensorflow.Core.Framework.Graph
@ -35,7 +34,6 @@ import TensorFlow.Build
, asGraphDef , asGraphDef
, evalBuildT , evalBuildT
, flushNodeBuffer , flushNodeBuffer
, hoistBuildT
, render , render
, withDevice , withDevice
, colocateWith , colocateWith
@ -53,9 +51,7 @@ import TensorFlow.Ops
import TensorFlow.Output (Device(..)) import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref) import TensorFlow.Tensor (Tensor, Value, Ref)
import TensorFlow.Session import TensorFlow.Session
( build ( run
, buildAnd
, run
, runSession , runSession
, run_ , run_
) )
@ -82,7 +78,7 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
assign v 5 assign v 5
-- TODO: Implement TensorFlow get_variable and test it. -- TODO: Implement TensorFlow get_variable and test it.
runSession $ do runSession $ do
out <- buildAnd run graph out <- graph >>= run
liftIO $ 5 @=? (unScalar out :: Float) liftIO $ 5 @=? (unScalar out :: Float)
-- | Test that "run" will render and extend any pure ops that haven't already -- | Test that "run" will render and extend any pure ops that haven't already
@ -96,7 +92,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
testInitializedVariable :: Test testInitializedVariable :: Test
testInitializedVariable = testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do (formula, reset) <- do
v <- initializedVariable 42 v <- initializedVariable 42
r <- assign v 24 r <- assign v 24
return (1 `add` v, r) return (1 `add` v, r)
@ -109,7 +105,7 @@ testInitializedVariable =
testInitializedVariableShape :: Test testInitializedVariableShape :: Test
testInitializedVariableShape = testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float]) vector <- initializedVariable (constant [1] [42 :: Float])
result <- run vector result <- run vector
liftIO $ [42] @=? (result :: V.Vector Float) liftIO $ [42] @=? (result :: V.Vector Float)
@ -132,23 +128,19 @@ testNamedAndScoped = testCase "testNamedAndScoped" $ do
"RefIdentity" @=? (nodeDef ^. op) "RefIdentity" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name) "foo1/bar1" @=? (nodeDef ^. name)
-- | Lift a Build action into a context for HUnit to run.
liftBuild :: Build a -> BuildT IO a
liftBuild = hoistBuildT (return . runIdentity)
-- | Flush the node buffer and sort the nodes by name (for more stable tests). -- | Flush the node buffer and sort the nodes by name (for more stable tests).
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a] flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer flushed field = sort . map field <$> flushNodeBuffer
-- | Test the interaction of rendering, CSE and scoping. -- | Test the interaction of rendering, CSE and scoping.
testRenderDedup :: Test testRenderDedup :: Test
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes renderNodes
names <- flushed (^. name) names <- flushed (^. name)
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
-- Render the nodes in a different scope, which should cause them -- Render the nodes in a different scope, which should cause them
-- to be distinct from the previous ones. -- to be distinct from the previous ones.
liftBuild $ withNameScope "foo" renderNodes withNameScope "foo" renderNodes
scopedNames <- flushed (^. name) scopedNames <- flushed (^. name)
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
where where
@ -165,7 +157,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
-- | Test the interaction of rendering, CSE and scoping. -- | Test the interaction of rendering, CSE and scoping.
testDeviceColocation :: Test testDeviceColocation :: Test
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
liftBuild renderNodes renderNodes
devices <- flushed (\x -> (x ^. name, x ^. device)) devices <- flushed (\x -> (x ^. name, x ^. device))
liftIO $ [ ("Add_2","dev0") liftIO $ [ ("Add_2","dev0")
, ("Const_1","dev0") , ("Const_1","dev0")

View File

@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
restitch = CoreOps.dynamicStitch restitchIndices splitParts restitch = CoreOps.dynamicStitch restitchIndices splitParts
in monadicIO $ run $ do in monadicIO $ run $ do
fromIntegral numParts @=? length splitParts fromIntegral numParts @=? length splitParts
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch valuesOut <- TF.runSession $ TF.run restitch
V.fromList values @=? valuesOut V.fromList values @=? valuesOut
data StitchExample a = StitchExample Int64 [a] [Int32] data StitchExample a = StitchExample Int64 [a] [Int32]

View File

@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Build as TF import qualified TensorFlow.Build as TF
import qualified TensorFlow.Nodes as TF
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions. -- | Tries to perform a simple embedding lookup, with two partitions.
@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition =
let ids = TF.constant (TF.Shape [1, 2]) idValues let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids let op = embeddingLookup embedding ids
(values, shape) <- buildAndRun $ do (values, shape) <- TF.runSession $ do
vs <- op vs <- op
return (vs, TF.shape vs) TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python. -- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [1, 2, 3] shape @=? V.fromList [1, 2, 3]
@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape =
let ids = TF.constant (TF.Shape [1, 2]) idValues let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids let op = embeddingLookup [embedding] ids
(values, shape) <- buildAndRun $ do (values, shape) <- TF.runSession $ do
vs <- op vs <- op
return (vs, TF.shape vs) TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python. -- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [1, 2, 3] shape @=? V.fromList [1, 2, 3]
@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
let shape = TF.Shape [2] let shape = TF.Shape [2]
gs <- TF.runSession $ do gs <- TF.runSession $ do
grads <- TF.build $ do
let embShape = TF.Shape [2, 1] let embShape = TF.Shape [2, 1]
let embeddingInit = [1, 20 ::Float] let embeddingInit = [1, 20 ::Float]
let idValues = [1, 1 :: Int32] let idValues = [1, 1 :: Int32]
@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding]) grad <- fmap head (TF.gradients loss [embedding])
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad TF.runWithFeeds
[TF.feed x $ TF.encodeTensorData shape xVals]
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float) grad
-- Gradients should be zero (or close) -- Gradients should be zero (or close)
assertAllClose gs (V.fromList ([0, 0 :: Float])) assertAllClose gs (V.fromList ([0, 0 :: Float]))
@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit
shapedValues = TF.constant shape values shapedValues = TF.constant shape values
in monadicIO $ run $ do in monadicIO $ run $ do
(shapeOut, got, want :: V.Vector a) <- (shapeOut, got, want :: V.Vector a) <-
TF.runSession $ TF.buildAnd TF.run $ do TF.runSession $ TF.run =<< do
embeddings <- embeddingLookup modShardedValues indicesVector embeddings <- embeddingLookup modShardedValues indicesVector
return (TF.cast (TF.shape embeddings), embeddings, directs) return (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup. -- Checks the explicitly documented invariant of embeddingLookup.

View File

@ -13,6 +13,7 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32) import Data.Int (Int32)
@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do
y = x*x + b y = x*x + b
grads = TF.gradients y [x, b] grads = TF.gradients y [x, b]
-- Assert that the gradients are right. -- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads [dx, db] <- TF.runSession $ grads >>= TF.run
6 @=? TF.unScalar dx 6 @=? TF.unScalar dx
1 @=? TF.unScalar db 1 @=? TF.unScalar db
-- Assert that the graph has the expected ops. -- Assert that the graph has the expected ops.
@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
b = TF.scalar (4 :: Float) b = TF.scalar (4 :: Float)
grads = TF.gradients x [x, b] grads = TF.gradients x [x, b]
-- Assert that the gradients are right. -- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads [dx, db] <- TF.runSession $ grads >>= TF.run
1 @=? TF.unScalar dx 1 @=? TF.unScalar dx
0 @=? TF.unScalar db 0 @=? TF.unScalar db
-- Assert that the graph has the expected ops. -- Assert that the graph has the expected ops.
@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
-- Test that identical "stateful" ops work with createGraph. -- Test that identical "stateful" ops work with createGraph.
testCreateGraphStateful :: Test testCreateGraphStateful :: Test
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do [dx, dy] <- TF.runSession $ do
let shape = TF.constant (TF.Shape [1]) [1] let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
TF.gradients (x + y*3) [x, y] TF.gradients (x + y*3) [x, y] >>= TF.run
-- If this test fails, it will likely be caused by an exception within -- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. These asserts are extra. -- `TF.gradients`. These asserts are extra.
1 @=? TF.unScalar dx 1 @=? TF.unScalar dx
@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
-- Test that name scopes work with createGraph. -- Test that name scopes work with createGraph.
testCreateGraphNameScopes :: Test testCreateGraphNameScopes :: Test
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do [dx] <- TF.runSession $ do
let shape = TF.constant (TF.Shape [1]) [1] let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- x :: TF.Tensor TF.Value Float <-
TF.withNameScope "foo" (TF.truncatedNormal shape) TF.withNameScope "foo" (TF.truncatedNormal shape)
TF.gradients x [x] TF.gradients x [x] >>= TF.run
-- If this test fails, it will likely be caused by an exception within -- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. This assert is extra. -- `TF.gradients`. This assert is extra.
1 @=? TF.unScalar dx 1 @=? TF.unScalar dx
@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
-- Test that createGraph can handle graphs with diamond shapes. -- Test that createGraph can handle graphs with diamond shapes.
testDiamond :: Test testDiamond :: Test
testDiamond = testCase "testDiamond" $ do testDiamond = testCase "testDiamond" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do [dx] <- TF.runSession $ do
let x = TF.vector [1] let x = TF.vector [1]
y = x*x y = x*x
z = y*y z = y*y
TF.gradients z [x] TF.gradients z [x] >>= TF.run
(4 :: Float) @=? TF.unScalar dx (4 :: Float) @=? TF.unScalar dx
testMaxGradient :: Test testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do [dx] <- TF.runSession $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float] let x = TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32) y = TF.max x (0 :: TF.Tensor TF.Value Int32)
TF.gradients y [x] TF.gradients y [x] >>= TF.run
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx

View File

@ -41,7 +41,7 @@ testSize = testCase "testSize" $ do
TF.Scalar (2 * 3 :: Int32) @=? x TF.Scalar (2 * 3 :: Int32) @=? x
eval :: TF.Fetchable t a => t -> IO a eval :: TF.Fetchable t a => t -> IO a
eval = TF.runSession . TF.buildAnd TF.run . return eval = TF.runSession . TF.run
-- | Confirms that the original example from Python code works. -- | Confirms that the original example from Python code works.
testReducedShape :: Test testReducedShape :: Test
@ -54,16 +54,16 @@ testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $ testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint" let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float) var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<< var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape []) TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.runSession $ do TF.runSession $ do
v <- TF.build var v <- var
TF.buildAnd TF.run_ $ TF.assign v 134 TF.assign v 134 >>= TF.run_
TF.buildAnd TF.run_ $ TF.save path [v] TF.save path [v] >>= TF.run_
result <- TF.runSession $ do result <- TF.runSession $ do
v <- TF.build var v <- var
TF.buildAnd TF.run_ $ TF.restore path v TF.restore path v >>= TF.run_
TF.run v TF.run v
liftIO $ TF.Scalar 134 @=? result liftIO $ TF.Scalar 134 @=? result

View File

@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do
let x = TF.vector xData let x = TF.vector xData
y = TF.vector yData y = TF.vector yData
-- Create scalar variables for slope and intercept. -- Create scalar variables for slope and intercept.
w <- TF.build (TF.initializedVariable 0) w <- TF.initializedVariable 0
b <- TF.build (TF.initializedVariable 0) b <- TF.initializedVariable 0
-- Define the loss function. -- Define the loss function.
let yHat = (x `TF.mul` w) `TF.add` b let yHat = (x `TF.mul` w) `TF.add` b
loss = TF.square (yHat `TF.sub` y) loss = TF.square (yHat `TF.sub` y)
-- Optimize with gradient descent. -- Optimize with gradient descent.
trainStep <- TF.build (gradientDescent 0.001 loss [w, b]) trainStep <- gradientDescent 0.001 loss [w, b]
replicateM_ 1000 (TF.run trainStep) replicateM_ 1000 (TF.run trainStep)
-- Return the learned parameters. -- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b) (TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do
gradientDescent :: Float gradientDescent :: Float
-> TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
-> [TF.Tensor TF.Ref Float] -> [TF.Tensor TF.Ref Float]
-> TF.Build TF.ControlNode -> TF.Session TF.ControlNode
gradientDescent alpha loss params = do gradientDescent alpha loss params = do
let applyGrad param grad = let applyGrad param grad =
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad)) TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))

View File

@ -35,7 +35,7 @@ testTracing = do
loggedValue <- newEmptyMVar loggedValue <- newEmptyMVar
TF.runSessionWithOptions TF.runSessionWithOptions
(def & TF.sessionTracer .~ putMVar loggedValue) (def & TF.sessionTracer .~ putMVar loggedValue)
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float)))) (TF.run_ (TF.scalar (0 :: Float)))
tryReadMVar loggedValue >>= tryReadMVar loggedValue >>=
maybe (assertFailure "Logging never happened") expectedFormat maybe (assertFailure "Logging never happened") expectedFormat
where expectedFormat x = where expectedFormat x =

View File

@ -24,10 +24,11 @@ import Data.ByteString (ByteString)
import Data.Int (Int64) import Data.Int (Int64)
import Data.Proxy (Proxy(..)) import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&)) import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef) import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp) import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group) import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor (Ref, Tensor, TensorList) import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes) import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A queue carrying tuples. -- | A queue carrying tuples.
@ -36,36 +37,30 @@ data Queue (as :: [*]) = Queue { handle :: Handle }
type Handle = Tensor Ref ByteString type Handle = Tensor Ref ByteString
-- | Adds the given values to the queue. -- | Adds the given values to the queue.
enqueue :: forall as v . TensorTypes as enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
=> Queue as => Queue as
-> TensorList v as -> TensorList v as
-> Build ControlNode -> m ControlNode
enqueue q = enqueue = CoreOps.queueEnqueue . handle
buildOp (opDef "QueueEnqueue"
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
(handle q)
-- | Retrieves the values from the queue. -- | Retrieves the values from the queue.
dequeue :: forall as . TensorTypes as dequeue :: forall as m . (MonadBuild m, TensorTypes as)
=> Queue as => Queue as
-> Build (TensorList Ref as) -> m (TensorList Value as)
-- ^ Dequeued tensors. They are coupled in a sense -- ^ Dequeued tensors. They are coupled in a sense
-- that values appear together, even if they are -- that values appear together, even if they are
-- not consumed together. -- not consumed together.
dequeue q = dequeue = CoreOps.queueDequeue . handle
buildOp (opDef "QueueDequeue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
(handle q)
-- | Creates a new queue with the given capacity and shared name. -- | Creates a new queue with the given capacity and shared name.
makeQueue :: forall as . TensorTypes as makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
=> Int64 -- ^ The upper bound on the number of elements in => Int64 -- ^ The upper bound on the number of elements in
-- this queue. Negative numbers mean no limit. -- this queue. Negative numbers mean no limit.
-> ByteString -- ^ If non-empty, this queue will be shared -> ByteString -- ^ If non-empty, this queue will be shared
-- under the given name across multiple sessions. -- under the given name across multiple sessions.
-> Build (Queue as) -> m (Queue as)
makeQueue capacity sharedName = do makeQueue capacity sharedName = do
q <- buildOp (opDef "FIFOQueue" q <- build $ buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as) & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
& opAttr "shared_name" .~ sharedName & opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity & opAttr "capacity" .~ capacity

View File

@ -39,6 +39,7 @@ Test-Suite QueueTest
, lens-family , lens-family
, google-shim , google-shim
, tensorflow , tensorflow
, tensorflow-core-ops
, tensorflow-ops , tensorflow-ops
, tensorflow-queue , tensorflow-queue
, test-framework , test-framework

View File

@ -27,7 +27,6 @@ import TensorFlow.Queue
import TensorFlow.Session import TensorFlow.Session
( asyncProdNodes ( asyncProdNodes
, build , build
, buildAnd
, run , run
, runSession , runSession
, run_ , run_
@ -41,12 +40,12 @@ import qualified Data.ByteString as BS
testBasic :: Test testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do testBasic = testCase "testBasic" $ runSession $ do
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 "" q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)
x <- buildAnd run (dequeue q) x <- run =<< dequeue q
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil)
y <- buildAnd run (dequeue q) y <- run =<< dequeue q
-- Note: we use explicit "Scalar" here to specify the type that was -- Note: we use explicit "Scalar" here to specify the type that was
-- fetched. Equivalently we could write -- fetched. Equivalently we could write
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString] -- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
@ -74,7 +73,7 @@ testPump = testCase "testPump" $ runSession $ do
testAsync :: Test testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do (deq, pump) <- do
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "" q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
(,) <$> dequeue q (,) <$> dequeue q
<*> enqueue q (10 :/ scalar "Async" :/ Nil) <*> enqueue q (10 :/ scalar "Async" :/ Nil)

View File

@ -37,6 +37,7 @@ module TensorFlow.Build
, renderedNodeDefs , renderedNodeDefs
, BuildT , BuildT
, Build , Build
, MonadBuild(..)
, addInitializer , addInitializer
, hoistBuildT , hoistBuildT
, evalBuildT , evalBuildT
@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState
evalBuildT :: Monad m => BuildT m a -> m a evalBuildT :: Monad m => BuildT m a -> m a
evalBuildT (BuildT f) = evalStateT f initGraphState evalBuildT (BuildT f) = evalStateT f initGraphState
-- | Lift a 'Build' action into a monad, including any explicit op renderings.
class Monad m => MonadBuild m where
build :: Build a -> m a
instance Monad m => MonadBuild (BuildT m) where
build = hoistBuildT $ return . runIdentity
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer. -- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
flushNodeBuffer :: Monad m => BuildT m [NodeDef] flushNodeBuffer :: MonadBuild m => m [NodeDef]
flushNodeBuffer = do flushNodeBuffer = build $ do
ns <- use nodeBuffer ns <- use nodeBuffer
nodeBuffer .= [] nodeBuffer .= []
return ns return ns
@ -229,8 +237,8 @@ flushInitializers = do
-- | Registers the given node to be executed before the next -- | Registers the given node to be executed before the next
-- 'TensorFlow.Session.run'. -- 'TensorFlow.Session.run'.
addInitializer :: ControlNode -> Build () addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode o) = do addInitializer (ControlNode o) = build $ do
i <- getOrAddOp o i <- getOrAddOp o
initializationNodes %= (i:) initializationNodes %= (i:)
@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer
gs = snd $ runIdentity $ runBuildT b gs = snd $ runIdentity $ runBuildT b
-- TODO: check against existing nodes for conflicts? -- TODO: check against existing nodes for conflicts?
addGraphDef :: GraphDef -> Build () addGraphDef :: MonadBuild m => GraphDef -> m ()
addGraphDef g = nodeBuffer <>= g ^. node addGraphDef g = build $ nodeBuffer <>= g ^. node
-- | Render the given op if it hasn't been rendered already, and return its -- | Render the given op if it hasn't been rendered already, and return its
-- name. -- name.
@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do
-- | Modify some part of the state, run an action, and restore the state -- | Modify some part of the state, run an action, and restore the state
-- after that action is done. -- after that action is done.
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
withStateLens accessor f act = do withStateLens accessor f act = do
old <- use accessor old <- build $ use accessor
accessor %= f build $ accessor %= f
result <- act result <- act
accessor .= old build $ accessor .= old
return result return result
-- | Set a device for all nodes rendered in the given 'Build' action -- | Set a device for all nodes rendered in the given 'Build' action
-- (unless further overridden by another use of withDevice). -- (unless further overridden by another use of withDevice).
withDevice :: Maybe Device -> Build a -> Build a withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d = withStateLens defaultDevice (const d) withDevice d = withStateLens defaultDevice (const d)
-- | Places all nodes rendered in the given 'Build' action on the same -- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that -- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure -- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect. -- return would not have the desired effect.
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
colocateWith t x = do colocateWith t x = do
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp) d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
withDevice (Just d) x withDevice (Just d) x
-- | Prepend a scope to all nodes rendered in the given 'Build' action. -- | Prepend a scope to all nodes rendered in the given 'Build' action.
withNameScope :: Text -> Build a -> Build a withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s = withStateLens currentScope (Scope s :) withNameScope s = withStateLens currentScope (Scope s :)
-- | Add control inputs to all nodes rendered in the given 'Build' action. -- | Add control inputs to all nodes rendered in the given 'Build' action.
withNodeDependencies :: Set NodeName -> Build a -> Build a withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from -- | Render a 'Tensor', fixing its name, scope, device and control inputs from
@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- This operation is idempotent; @render >=> render === render@. However, -- This operation is idempotent; @render >=> render === render@. However,
-- rendering a (previously un-rendered) 'Tensor' in two different contexts -- rendering a (previously un-rendered) 'Tensor' in two different contexts
-- may result in two different 'Tensor's. -- may result in two different 'Tensor's.
render :: Tensor v a -> Build (Tensor v a) render :: MonadBuild m => Tensor v a -> m (Tensor v a)
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
-- | Render a 'Tensor' and get its node's name. -- | Render a 'Tensor' and get its node's name.
renderNodeName :: Tensor v a -> Build NodeName renderNodeName :: Tensor v a -> Build NodeName

View File

@ -40,9 +40,9 @@ import TensorFlow.Types
-- | Modify a 'Build' action, such that all new ops rendered in it will depend -- | Modify a 'Build' action, such that all new ops rendered in it will depend
-- on the nodes in the first argument. -- on the nodes in the first argument.
withControlDependencies :: Nodes t => t -> Build a -> Build a withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
withControlDependencies deps act = do withControlDependencies deps act = do
nodes <- getNodes deps nodes <- build $ getNodes deps
withNodeDependencies nodes act withNodeDependencies nodes act
-- TODO(judahjacobson): Reimplement withDependencies. -- TODO(judahjacobson): Reimplement withDependencies.
@ -51,9 +51,9 @@ withControlDependencies deps act = do
-- --
-- When this op finishes, all ops in the input @n@ have finished. This op has -- When this op finishes, all ops in the input @n@ have finished. This op has
-- no output. -- no output.
group :: Nodes t => t -> Build ControlNode group :: (MonadBuild m, Nodes t) => t -> m ControlNode
group deps = do group deps = do
nodes <- Set.toList <$> getNodes deps nodes <- build $ Set.toList <$> getNodes deps
-- TODO: slicker way -- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes

View File

@ -31,8 +31,7 @@ module TensorFlow.Core
, runSession , runSession
, runSessionWithOptions , runSessionWithOptions
-- ** Building graphs -- ** Building graphs
, build , MonadBuild(..)
, buildAnd
-- ** Running graphs -- ** Running graphs
, Fetchable , Fetchable
, Nodes , Nodes

View File

@ -26,8 +26,7 @@ module TensorFlow.Session (
sessionTracer, sessionTracer,
runSession, runSession,
runSessionWithOptions, runSessionWithOptions,
build, MonadBuild(..),
buildAnd,
extend, extend,
addGraphDef, addGraphDef,
run, run,
@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Default (Default, def) import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.ProtoLens (showMessage) import Data.ProtoLens (showMessage)
import Data.Set (Set) import Data.Set (Set)
@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) =
FFI.setSessionTarget (options ^. sessionTarget) opt FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Lift a 'Build' action into a 'Session', including any explicit op instance MonadBuild Session where
-- renderings. build = Session . lift . build
build :: Build a -> Session a
build = Session . lift . hoistBuildT (return . runIdentity)
-- | Add all pending rendered nodes to the TensorFlow graph and runs -- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers. -- any pending initializers.
@ -147,13 +143,6 @@ extend = do
unless (null initializers) $ unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers) void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Helper combinator for doing something with the result of a 'Build' action.
-- Example usage:
--
-- > buildAnd run :: Fetchable t a => Build t -> Session a
buildAnd :: (a -> Session b) -> Build a -> Session b
buildAnd f m = build m >>= f
-- | Run a subgraph 't', rendering any dependent nodes that aren't already -- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'. -- rendered, and fetch the corresponding values for 'a'.
run :: Fetchable t a => t -> Session a run :: Fetchable t a => t -> Session a