Introduce a MonadBuild class, and remove `buildAnd`. (#83)

This change adds a class that both `Build` and `Session` are instances of:

    class Monad m => MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now have a signature that returns
its result in any monad `m` that is an instance of `MonadBuild` (rather than just in `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when foo is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular a few tests had `buildAnd run . pure` which is
actually equivalent to just `run`.
This commit is contained in:
Judah Jacobson 2017-03-18 12:08:53 -07:00 committed by GitHub
parent 9209dfc4c4
commit 2c5c879037
22 changed files with 152 additions and 162 deletions

View File

@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
let x = TF.vector xData
y = TF.vector yData
-- Create scalar variables for slope and intercept.
w <- TF.build (TF.initializedVariable 0)
b <- TF.build (TF.initializedVariable 0)
w <- TF.initializedVariable 0
b <- TF.initializedVariable 0
-- Define the loss function.
let yHat = (x `TF.mul` w) `TF.add` b
loss = TF.square (yHat `TF.sub` y)
-- Optimize with gradient descent.
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
trainStep <- gradientDescent 0.001 loss [w, b]
replicateM_ 1000 (TF.run trainStep)
-- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
gradientDescent :: Float
-> TF.Tensor TF.Value Float
-> [TF.Tensor TF.Ref Float]
-> TF.Build TF.ControlNode
-> TF.Session TF.ControlNode
gradientDescent alpha loss params = do
let applyGrad param grad =
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))

View File

@ -49,7 +49,7 @@ import TensorFlow.Tensor
)
import TensorFlow.Ops
import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd)
(runSession, run, run_, runWithFeeds, build)
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
@ -108,7 +108,7 @@ testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do
build $ addGraphDef graphDef
addGraphDef graphDef
x <- run $ tensorFromName ValueKind "Mul_2"
liftIO $ (50 :: Float) @=? unScalar x
@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
wtsCkptPath <- liftIO wtsCkpt
biasCkptPath <- liftIO biasCkpt
-- Run those restoring nodes on the graph in the current session.
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
run_ =<< (sequence :: Monad m => [m a] -> m [a])
[ restore wtsCkptPath wts
, restoreFromName biasCkptPath "bias" bias
]

View File

@ -23,7 +23,7 @@ module TensorFlow.NN
import Prelude hiding ( log
, exp
)
import TensorFlow.Build ( Build
import TensorFlow.Build ( MonadBuild
, render
, withNameScope
)
@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
--
-- `logits` and `targets` must have the same type and shape.
sigmoidCrossEntropyWithLogits
:: (OneOf '[Float, Double] a, TensorType a, Num a)
:: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
=> Tensor Value a -- ^ __logits__
-> Tensor Value a -- ^ __targets__
-> Build (Tensor Value a)
-> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
logits' <- render logits
targets' <- render targets

View File

@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF
@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
assertAllClose (head r) (V.fromList [0.5, -0.5])
run :: TF.Fetchable t a => TF.Build t -> IO a
run = TF.runSession . TF.buildAnd TF.run
run :: TF.Fetchable t a => TF.Session t -> IO a
run = TF.runSession . (>>= TF.run)
main :: IO ()
main = googleTest [ testGradientAtZero

View File

@ -220,9 +220,12 @@ renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
where
maybeLift
| parsedOpIsMonadic pOp = "build $"
| otherwise = ""
buildFunction
| null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+>
@ -277,13 +280,18 @@ typeSig pOp = constraints
++ [outputs])
where
constraints
| null (inferredTypeAttrs pOp) = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
| null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ map tensorArgConstraint
$ inferredTypeAttrs pOp
++ if parsedOpIsMonadic pOp then ["m'"] else []
-- Use m' as the type parameter to avoid clashing with an attribute name.
monadConstraint
| parsedOpIsMonadic pOp = ["MonadBuild m'"]
| otherwise = []
classConstraints = monadConstraint ++ map tensorArgConstraint
(inferredTypeAttrs pOp)
signatureFold = folddoc (\x y -> x </> "->" <+> y)
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
renderAttrType (AttrSingle a) = renderAttrBaseType a
@ -304,7 +312,7 @@ typeSig pOp = constraints
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
wrapOutput o
| parsedOpIsMonadic pOp = "Build" <+> parens o
| parsedOpIsMonadic pOp = "m'" <+> parens o
| otherwise = o
-- | Render an op input or output.

View File

@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Build (MonadBuild, colocateWith, render)
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v .
( TensorType a
embeddingLookup :: forall a b v m .
( MonadBuild m
, TensorType a
, OneOf '[Int64, Int32] b
, Num b
)
@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
-- containing the ids to be looked up in `params`.
-- The ids are required to have fewer than 2^31
-- entries.
-> Build (Tensor Value a)
-> m (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do

View File

@ -56,7 +56,9 @@ import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
( Build
( MonadBuild
, Build
, build
, render
, renderNodeName
, renderedNodeDefs
@ -111,16 +113,17 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
gradients :: forall a v1 v2 m . (MonadBuild m
, Num (Tensor v1 a)
-- TODO(gnezdo): remove indirect constraint.
-- It's a wart inherited from Num instance.
-- It's a wart inherited from Num instance.
, v1 ~ Value
, GradientCompatible a
)
=> Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> Build [Tensor Value a]
gradients y xs = do
-> m [Tensor Value a]
gradients y xs = build $ do
-- The gradients are computed using "reverse accumulation", similarly to
-- what is described here:
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation

View File

@ -155,20 +155,20 @@ matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder shape' =
buildOp $ opDef "Placeholder"
build $ buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape'
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a
=> Tensor Value a -> Build (Tensor Ref a)
initializedVariable :: forall a m . (MonadBuild m, TensorType a)
=> Tensor Value a -> m (Tensor Ref a)
initializedVariable initializer = do
v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <-
buildOp (opDef "Assign"
build $ buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
& opAttr "validate_shape" .~ False
@ -179,32 +179,32 @@ initializedVariable initializer = do
-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
:: (TensorType a, Num a) =>
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
:: (MonadBuild m, TensorType a, Num a) =>
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros
-- TODO: Support heterogeneous list of tensors.
save :: forall a v . TensorType a
save :: forall a m v . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save.
-> Build ControlNode
-> m ControlNode
save path xs = do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save"
& opAttr "T" .~ types
saveOp (scalar path) (CoreOps.pack names) xs
build $ saveOp (scalar path) (CoreOps.pack names) xs
-- | Restore a tensor's value from a checkpoint file.
--
-- This version allows restoring from a checkpoint file that uses a different
-- tensor name than the variable.
restoreFromName :: forall a . TensorType a
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode
-> m ControlNode
restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a)
@ -212,12 +212,12 @@ restoreFromName path name x = do
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a
restore :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode
-> m ControlNode
restore path x = do
name <- encodeUtf8 . unNodeName <$> renderNodeName x
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
restoreFromName path name x
-- | Create a constant tensor.
@ -264,12 +264,13 @@ scalar :: forall a . TensorType a => a -> Tensor Value a
scalar x = constant [] [x]
-- Random tensor from the unit normal distribution with bounded values.
truncatedNormal :: forall a v . TensorType a
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
=> Tensor v Int64 -- ^ Shape.
-> Build (Tensor Value a)
truncatedNormal = buildOp $ opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64)
-> m (Tensor Value a)
truncatedNormal
= build . buildOp (opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64))
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)

View File

@ -19,7 +19,6 @@
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Functor.Identity (runIdentity)
import Lens.Family2 ((^.))
import Data.List (sort)
import Proto.Tensorflow.Core.Framework.Graph
@ -35,7 +34,6 @@ import TensorFlow.Build
, asGraphDef
, evalBuildT
, flushNodeBuffer
, hoistBuildT
, render
, withDevice
, colocateWith
@ -53,9 +51,7 @@ import TensorFlow.Ops
import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref)
import TensorFlow.Session
( build
, buildAnd
, run
( run
, runSession
, run_
)
@ -82,7 +78,7 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
assign v 5
-- TODO: Implement TensorFlow get_variable and test it.
runSession $ do
out <- buildAnd run graph
out <- graph >>= run
liftIO $ 5 @=? (unScalar out :: Float)
-- | Test that "run" will render and extend any pure ops that haven't already
@ -96,7 +92,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
testInitializedVariable :: Test
testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do
(formula, reset) <- do
v <- initializedVariable 42
r <- assign v 24
return (1 `add` v, r)
@ -109,7 +105,7 @@ testInitializedVariable =
testInitializedVariableShape :: Test
testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float])
vector <- initializedVariable (constant [1] [42 :: Float])
result <- run vector
liftIO $ [42] @=? (result :: V.Vector Float)
@ -132,23 +128,19 @@ testNamedAndScoped = testCase "testNamedAndScoped" $ do
"RefIdentity" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name)
-- | Lift a Build action into a context for HUnit to run.
liftBuild :: Build a -> BuildT IO a
liftBuild = hoistBuildT (return . runIdentity)
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer
flushed field = sort . map field <$> flushNodeBuffer
-- | Test the interaction of rendering, CSE and scoping.
testRenderDedup :: Test
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes
renderNodes
names <- flushed (^. name)
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
-- Render the nodes in a different scope, which should cause them
-- to be distinct from the previous ones.
liftBuild $ withNameScope "foo" renderNodes
withNameScope "foo" renderNodes
scopedNames <- flushed (^. name)
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
where
@ -165,7 +157,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
-- | Test the interaction of rendering, CSE and device placement.
testDeviceColocation :: Test
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
liftBuild renderNodes
renderNodes
devices <- flushed (\x -> (x ^. name, x ^. device))
liftIO $ [ ("Add_2","dev0")
, ("Const_1","dev0")

View File

@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
restitch = CoreOps.dynamicStitch restitchIndices splitParts
in monadicIO $ run $ do
fromIntegral numParts @=? length splitParts
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
valuesOut <- TF.runSession $ TF.run restitch
V.fromList values @=? valuesOut
data StitchExample a = StitchExample Int64 [a] [Int32]

View File

@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Nodes as TF
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions.
@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition =
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids
(values, shape) <- buildAndRun $ do
(values, shape) <- TF.runSession $ do
vs <- op
return (vs, TF.shape vs)
TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [1, 2, 3]
@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape =
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids
(values, shape) <- buildAndRun $ do
(values, shape) <- TF.runSession $ do
vs <- op
return (vs, TF.shape vs)
TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [1, 2, 3]
@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
let shape = TF.Shape [2]
gs <- TF.runSession $ do
grads <- TF.build $ do
let embShape = TF.Shape [2, 1]
let embeddingInit = [1, 20 ::Float]
let idValues = [1, 1 :: Int32]
@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding])
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
TF.runWithFeeds
[TF.feed x $ TF.encodeTensorData shape xVals]
grad
-- Gradients should be zero (or close)
assertAllClose gs (V.fromList ([0, 0 :: Float]))
@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit
shapedValues = TF.constant shape values
in monadicIO $ run $ do
(shapeOut, got, want :: V.Vector a) <-
TF.runSession $ TF.buildAnd TF.run $ do
TF.runSession $ TF.run =<< do
embeddings <- embeddingLookup modShardedValues indicesVector
return (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup.

View File

@ -13,6 +13,7 @@
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32)
@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do
y = x*x + b
grads = TF.gradients y [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
[dx, db] <- TF.runSession $ grads >>= TF.run
6 @=? TF.unScalar dx
1 @=? TF.unScalar db
-- Assert that the graph has the expected ops.
@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
b = TF.scalar (4 :: Float)
grads = TF.gradients x [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
[dx, db] <- TF.runSession $ grads >>= TF.run
1 @=? TF.unScalar dx
0 @=? TF.unScalar db
-- Assert that the graph has the expected ops.
@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
-- Test that identical "stateful" ops work with createGraph.
testCreateGraphStateful :: Test
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx, dy] <- TF.runSession $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
TF.gradients (x + y*3) [x, y]
TF.gradients (x + y*3) [x, y] >>= TF.run
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. These asserts are extra.
1 @=? TF.unScalar dx
@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
-- Test that name scopes work with createGraph.
testCreateGraphNameScopes :: Test
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx] <- TF.runSession $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <-
TF.withNameScope "foo" (TF.truncatedNormal shape)
TF.gradients x [x]
TF.gradients x [x] >>= TF.run
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. This assert is extra.
1 @=? TF.unScalar dx
@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
-- Test that createGraph can handle graphs with diamond shapes.
testDiamond :: Test
testDiamond = testCase "testDiamond" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx] <- TF.runSession $ do
let x = TF.vector [1]
y = x*x
z = y*y
TF.gradients z [x]
TF.gradients z [x] >>= TF.run
(4 :: Float) @=? TF.unScalar dx
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx] <- TF.runSession $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
TF.gradients y [x]
TF.gradients y [x] >>= TF.run
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx

View File

@ -41,7 +41,7 @@ testSize = testCase "testSize" $ do
TF.Scalar (2 * 3 :: Int32) @=? x
eval :: TF.Fetchable t a => t -> IO a
eval = TF.runSession . TF.buildAnd TF.run . return
eval = TF.runSession . TF.run
-- | Confirms that the original example from Python code works.
testReducedShape :: Test
@ -54,16 +54,16 @@ testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float)
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.assign v 134
TF.buildAnd TF.run_ $ TF.save path [v]
v <- var
TF.assign v 134 >>= TF.run_
TF.save path [v] >>= TF.run_
result <- TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.restore path v
v <- var
TF.restore path v >>= TF.run_
TF.run v
liftIO $ TF.Scalar 134 @=? result

View File

@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do
let x = TF.vector xData
y = TF.vector yData
-- Create scalar variables for slope and intercept.
w <- TF.build (TF.initializedVariable 0)
b <- TF.build (TF.initializedVariable 0)
w <- TF.initializedVariable 0
b <- TF.initializedVariable 0
-- Define the loss function.
let yHat = (x `TF.mul` w) `TF.add` b
loss = TF.square (yHat `TF.sub` y)
-- Optimize with gradient descent.
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
trainStep <- gradientDescent 0.001 loss [w, b]
replicateM_ 1000 (TF.run trainStep)
-- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do
gradientDescent :: Float
-> TF.Tensor TF.Value Float
-> [TF.Tensor TF.Ref Float]
-> TF.Build TF.ControlNode
-> TF.Session TF.ControlNode
gradientDescent alpha loss params = do
let applyGrad param grad =
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))

View File

@ -35,7 +35,7 @@ testTracing = do
loggedValue <- newEmptyMVar
TF.runSessionWithOptions
(def & TF.sessionTracer .~ putMVar loggedValue)
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
(TF.run_ (TF.scalar (0 :: Float)))
tryReadMVar loggedValue >>=
maybe (assertFailure "Logging never happened") expectedFormat
where expectedFormat x =

View File

@ -24,10 +24,11 @@ import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor (Ref, Tensor, TensorList)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A queue carrying tuples.
@ -36,36 +37,30 @@ data Queue (as :: [*]) = Queue { handle :: Handle }
type Handle = Tensor Ref ByteString
-- | Adds the given values to the queue.
enqueue :: forall as v . TensorTypes as
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
=> Queue as
-> TensorList v as
-> Build ControlNode
enqueue q =
buildOp (opDef "QueueEnqueue"
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
(handle q)
-> m ControlNode
enqueue = CoreOps.queueEnqueue . handle
-- | Retrieves the values from the queue.
dequeue :: forall as . TensorTypes as
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
=> Queue as
-> Build (TensorList Ref as)
-> m (TensorList Value as)
-- ^ Dequeued tensors. They are coupled in a sense
-- that values appear together, even if they are
-- not consumed together.
dequeue q =
buildOp (opDef "QueueDequeue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
(handle q)
dequeue = CoreOps.queueDequeue . handle
-- | Creates a new queue with the given capacity and shared name.
makeQueue :: forall as . TensorTypes as
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
=> Int64 -- ^ The upper bound on the number of elements in
-- this queue. Negative numbers mean no limit.
-> ByteString -- ^ If non-empty, this queue will be shared
-- under the given name across multiple sessions.
-> Build (Queue as)
-> m (Queue as)
makeQueue capacity sharedName = do
q <- buildOp (opDef "FIFOQueue"
q <- build $ buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
& opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity

View File

@ -39,6 +39,7 @@ Test-Suite QueueTest
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-queue
, test-framework

View File

@ -27,7 +27,6 @@ import TensorFlow.Queue
import TensorFlow.Session
( asyncProdNodes
, build
, buildAnd
, run
, runSession
, run_
@ -41,12 +40,12 @@ import qualified Data.ByteString as BS
testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
x <- buildAnd run (dequeue q)
run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)
x <- run =<< dequeue q
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
y <- buildAnd run (dequeue q)
run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil)
y <- run =<< dequeue q
-- Note: we use explicit "Scalar" here to specify the type that was
-- fetched. Equivalently we could write
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
@ -74,7 +73,7 @@ testPump = testCase "testPump" $ runSession $ do
testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do
(deq, pump) <- do
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
(,) <$> dequeue q
<*> enqueue q (10 :/ scalar "Async" :/ Nil)

View File

@ -37,6 +37,7 @@ module TensorFlow.Build
, renderedNodeDefs
, BuildT
, Build
, MonadBuild(..)
, addInitializer
, hoistBuildT
, evalBuildT
@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState
evalBuildT :: Monad m => BuildT m a -> m a
evalBuildT (BuildT f) = evalStateT f initGraphState
-- | Lift a 'Build' action into a monad, including any explicit op renderings.
class Monad m => MonadBuild m where
build :: Build a -> m a
instance Monad m => MonadBuild (BuildT m) where
build = hoistBuildT $ return . runIdentity
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
flushNodeBuffer = do
flushNodeBuffer :: MonadBuild m => m [NodeDef]
flushNodeBuffer = build $ do
ns <- use nodeBuffer
nodeBuffer .= []
return ns
@ -229,8 +237,8 @@ flushInitializers = do
-- | Registers the given node to be executed before the next
-- 'TensorFlow.Session.run'.
addInitializer :: ControlNode -> Build ()
addInitializer (ControlNode o) = do
addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode o) = build $ do
i <- getOrAddOp o
initializationNodes %= (i:)
@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer
gs = snd $ runIdentity $ runBuildT b
-- TODO: check against existing nodes for conflicts?
addGraphDef :: GraphDef -> Build ()
addGraphDef g = nodeBuffer <>= g ^. node
addGraphDef :: MonadBuild m => GraphDef -> m ()
addGraphDef g = build $ nodeBuffer <>= g ^. node
-- | Render the given op if it hasn't been rendered already, and return its
-- name.
@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do
-- | Modify some part of the state, run an action, and restore the state
-- after that action is done.
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
withStateLens accessor f act = do
old <- use accessor
accessor %= f
old <- build $ use accessor
build $ accessor %= f
result <- act
accessor .= old
build $ accessor .= old
return result
-- | Set a device for all nodes rendered in the given 'MonadBuild' action
-- (unless further overridden by another use of withDevice).
withDevice :: Maybe Device -> Build a -> Build a
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d = withStateLens defaultDevice (const d)
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
colocateWith t x = do
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
withDevice (Just d) x
-- | Prepend a scope to all nodes rendered in the given 'MonadBuild' action.
withNameScope :: Text -> Build a -> Build a
withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s = withStateLens currentScope (Scope s :)
-- | Add control inputs to all nodes rendered in the given 'MonadBuild' action.
withNodeDependencies :: Set NodeName -> Build a -> Build a
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- This operation is idempotent; @render >=> render === render@. However,
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
-- may result in two different 'Tensor's.
render :: Tensor v a -> Build (Tensor v a)
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
-- | Render a 'Tensor' and get its node's name.
renderNodeName :: Tensor v a -> Build NodeName

View File

@ -40,9 +40,9 @@ import TensorFlow.Types
-- | Modify a 'MonadBuild' action, such that all new ops rendered in it will
-- depend on the nodes in the first argument.
withControlDependencies :: Nodes t => t -> Build a -> Build a
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
withControlDependencies deps act = do
nodes <- getNodes deps
nodes <- build $ getNodes deps
withNodeDependencies nodes act
-- TODO(judahjacobson): Reimplement withDependencies.
@ -51,9 +51,9 @@ withControlDependencies deps act = do
--
-- When this op finishes, all ops in the input @n@ have finished. This op has
-- no output.
group :: Nodes t => t -> Build ControlNode
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
group deps = do
nodes <- Set.toList <$> getNodes deps
nodes <- build $ Set.toList <$> getNodes deps
-- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes

View File

@ -31,8 +31,7 @@ module TensorFlow.Core
, runSession
, runSessionWithOptions
-- ** Building graphs
, build
, buildAnd
, MonadBuild(..)
-- ** Running graphs
, Fetchable
, Nodes

View File

@ -26,8 +26,7 @@ module TensorFlow.Session (
sessionTracer,
runSession,
runSessionWithOptions,
build,
buildAnd,
MonadBuild(..),
extend,
addGraphDef,
run,
@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>))
import Data.ProtoLens (showMessage)
import Data.Set (Set)
@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) =
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings.
build :: Build a -> Session a
build = Session . lift . hoistBuildT (return . runIdentity)
instance MonadBuild Session where
build = Session . lift . build
-- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers.
@ -147,13 +143,6 @@ extend = do
unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Helper combinator for doing something with the result of a 'Build' action.
-- Example usage:
--
-- > buildAnd run :: Fetchable t a => Build t -> Session a
buildAnd :: (a -> Session b) -> Build a -> Session b
buildAnd f m = build m >>= f
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'.
run :: Fetchable t a => t -> Session a