mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Introduce a MonadBuild class, and remove buildAnd
. (#83)
This change adds a class that both `Build` and `Session` are instances of: class MonadBuild m where build :: Build a -> m a All stateful ops (generated and manually written) now have a signature that returns an instance of `MonadBuild` (rather than just `Build`). For example: assign_ :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v t -> m (Tensor Ref t) This lets us remove a bunch of spurious calls to `build` in user code. It also lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run` (or `run =<< foo`, which is sometimes nicer when foo is a complicated expression). I went ahead and deleted `buildAnd` altogether since it seems to lead to confusion; in particular a few tests had `buildAnd run . pure` which is actually equivalent to just `run`.
This commit is contained in:
parent
9209dfc4c4
commit
2c5c879037
22 changed files with 152 additions and 162 deletions
|
@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
|
||||||
let x = TF.vector xData
|
let x = TF.vector xData
|
||||||
y = TF.vector yData
|
y = TF.vector yData
|
||||||
-- Create scalar variables for slope and intercept.
|
-- Create scalar variables for slope and intercept.
|
||||||
w <- TF.build (TF.initializedVariable 0)
|
w <- TF.initializedVariable 0
|
||||||
b <- TF.build (TF.initializedVariable 0)
|
b <- TF.initializedVariable 0
|
||||||
-- Define the loss function.
|
-- Define the loss function.
|
||||||
let yHat = (x `TF.mul` w) `TF.add` b
|
let yHat = (x `TF.mul` w) `TF.add` b
|
||||||
loss = TF.square (yHat `TF.sub` y)
|
loss = TF.square (yHat `TF.sub` y)
|
||||||
-- Optimize with gradient descent.
|
-- Optimize with gradient descent.
|
||||||
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
|
trainStep <- gradientDescent 0.001 loss [w, b]
|
||||||
replicateM_ 1000 (TF.run trainStep)
|
replicateM_ 1000 (TF.run trainStep)
|
||||||
-- Return the learned parameters.
|
-- Return the learned parameters.
|
||||||
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
||||||
|
@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
|
||||||
gradientDescent :: Float
|
gradientDescent :: Float
|
||||||
-> TF.Tensor TF.Value Float
|
-> TF.Tensor TF.Value Float
|
||||||
-> [TF.Tensor TF.Ref Float]
|
-> [TF.Tensor TF.Ref Float]
|
||||||
-> TF.Build TF.ControlNode
|
-> TF.Session TF.ControlNode
|
||||||
gradientDescent alpha loss params = do
|
gradientDescent alpha loss params = do
|
||||||
let applyGrad param grad =
|
let applyGrad param grad =
|
||||||
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||||
|
|
|
@ -49,7 +49,7 @@ import TensorFlow.Tensor
|
||||||
)
|
)
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
(runSession, run, run_, runWithFeeds, build)
|
||||||
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
|
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
@ -108,7 +108,7 @@ testGraphDefExec :: Test
|
||||||
testGraphDefExec = testCase "testGraphDefExec" $ do
|
testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||||
runSession $ do
|
runSession $ do
|
||||||
build $ addGraphDef graphDef
|
addGraphDef graphDef
|
||||||
x <- run $ tensorFromName ValueKind "Mul_2"
|
x <- run $ tensorFromName ValueKind "Mul_2"
|
||||||
liftIO $ (50 :: Float) @=? unScalar x
|
liftIO $ (50 :: Float) @=? unScalar x
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
||||||
wtsCkptPath <- liftIO wtsCkpt
|
wtsCkptPath <- liftIO wtsCkpt
|
||||||
biasCkptPath <- liftIO biasCkpt
|
biasCkptPath <- liftIO biasCkpt
|
||||||
-- Run those restoring nodes on the graph in the current session.
|
-- Run those restoring nodes on the graph in the current session.
|
||||||
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
|
run_ =<< (sequence :: Monad m => [m a] -> m [a])
|
||||||
[ restore wtsCkptPath wts
|
[ restore wtsCkptPath wts
|
||||||
, restoreFromName biasCkptPath "bias" bias
|
, restoreFromName biasCkptPath "bias" bias
|
||||||
]
|
]
|
||||||
|
|
|
@ -23,7 +23,7 @@ module TensorFlow.NN
|
||||||
import Prelude hiding ( log
|
import Prelude hiding ( log
|
||||||
, exp
|
, exp
|
||||||
)
|
)
|
||||||
import TensorFlow.Build ( Build
|
import TensorFlow.Build ( MonadBuild
|
||||||
, render
|
, render
|
||||||
, withNameScope
|
, withNameScope
|
||||||
)
|
)
|
||||||
|
@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
|
||||||
--
|
--
|
||||||
-- `logits` and `targets` must have the same type and shape.
|
-- `logits` and `targets` must have the same type and shape.
|
||||||
sigmoidCrossEntropyWithLogits
|
sigmoidCrossEntropyWithLogits
|
||||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
:: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
|
||||||
=> Tensor Value a -- ^ __logits__
|
=> Tensor Value a -- ^ __logits__
|
||||||
-> Tensor Value a -- ^ __targets__
|
-> Tensor Value a -- ^ __targets__
|
||||||
-> Build (Tensor Value a)
|
-> m (Tensor Value a)
|
||||||
sigmoidCrossEntropyWithLogits logits targets = do
|
sigmoidCrossEntropyWithLogits logits targets = do
|
||||||
logits' <- render logits
|
logits' <- render logits
|
||||||
targets' <- render targets
|
targets' <- render targets
|
||||||
|
|
|
@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
|
||||||
import Test.Framework (Test)
|
import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.Build as TF
|
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Nodes as TF
|
import qualified TensorFlow.Nodes as TF
|
||||||
import qualified TensorFlow.NN as TF
|
import qualified TensorFlow.NN as TF
|
||||||
|
@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||||
|
|
||||||
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
||||||
|
|
||||||
run :: TF.Fetchable t a => TF.Build t -> IO a
|
run :: TF.Fetchable t a => TF.Session t -> IO a
|
||||||
run = TF.runSession . TF.buildAnd TF.run
|
run = TF.runSession . (>>= TF.run)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testGradientAtZero
|
main = googleTest [ testGradientAtZero
|
||||||
|
|
|
@ -220,9 +220,12 @@ renderHaskellAttrName :: Attr a -> Doc
|
||||||
renderHaskellAttrName = renderHaskellName . attrName
|
renderHaskellAttrName = renderHaskellName . attrName
|
||||||
|
|
||||||
functionBody :: ParsedOp -> Doc
|
functionBody :: ParsedOp -> Doc
|
||||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||||
</> indent indentation (sep tensorArgs)
|
</> indent indentation (sep tensorArgs)
|
||||||
where
|
where
|
||||||
|
maybeLift
|
||||||
|
| parsedOpIsMonadic pOp = "build $"
|
||||||
|
| otherwise = ""
|
||||||
buildFunction
|
buildFunction
|
||||||
| null outputListsSizes = "buildOp"
|
| null outputListsSizes = "buildOp"
|
||||||
| otherwise = "buildListOp" <+>
|
| otherwise = "buildListOp" <+>
|
||||||
|
@ -277,13 +280,18 @@ typeSig pOp = constraints
|
||||||
++ [outputs])
|
++ [outputs])
|
||||||
where
|
where
|
||||||
constraints
|
constraints
|
||||||
| null (inferredTypeAttrs pOp) = empty
|
| null classConstraints = empty
|
||||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||||
classConstraints = tuple $ map tensorArgConstraint
|
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||||
$ inferredTypeAttrs pOp
|
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||||
|
monadConstraint
|
||||||
|
| parsedOpIsMonadic pOp = ["MonadBuild m'"]
|
||||||
|
| otherwise = []
|
||||||
|
classConstraints = monadConstraint ++ map tensorArgConstraint
|
||||||
|
(inferredTypeAttrs pOp)
|
||||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||||
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
||||||
|
@ -304,7 +312,7 @@ typeSig pOp = constraints
|
||||||
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
|
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
|
||||||
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
|
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
|
||||||
wrapOutput o
|
wrapOutput o
|
||||||
| parsedOpIsMonadic pOp = "Build" <+> parens o
|
| parsedOpIsMonadic pOp = "m'" <+> parens o
|
||||||
| otherwise = o
|
| otherwise = o
|
||||||
|
|
||||||
-- | Render an op input or output.
|
-- | Render an op input or output.
|
||||||
|
|
|
@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
|
||||||
|
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import TensorFlow.Build (Build, colocateWith, render)
|
import TensorFlow.Build (MonadBuild, colocateWith, render)
|
||||||
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||||
import TensorFlow.Tensor (Tensor, Value)
|
import TensorFlow.Tensor (Tensor, Value)
|
||||||
import TensorFlow.Types (OneOf, TensorType)
|
import TensorFlow.Types (OneOf, TensorType)
|
||||||
|
@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
--
|
--
|
||||||
-- The results of the lookup are concatenated into a dense
|
-- The results of the lookup are concatenated into a dense
|
||||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||||
embeddingLookup :: forall a b v .
|
embeddingLookup :: forall a b v m .
|
||||||
( TensorType a
|
( MonadBuild m
|
||||||
|
, TensorType a
|
||||||
, OneOf '[Int64, Int32] b
|
, OneOf '[Int64, Int32] b
|
||||||
, Num b
|
, Num b
|
||||||
)
|
)
|
||||||
|
@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
|
||||||
-- containing the ids to be looked up in `params`.
|
-- containing the ids to be looked up in `params`.
|
||||||
-- The ids are required to have fewer than 2^31
|
-- The ids are required to have fewer than 2^31
|
||||||
-- entries.
|
-- entries.
|
||||||
-> Build (Tensor Value a)
|
-> m (Tensor Value a)
|
||||||
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||||
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
|
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
|
||||||
embeddingLookup params@(p0 : _) ids = do
|
embeddingLookup params@(p0 : _) ids = do
|
||||||
|
|
|
@ -56,7 +56,9 @@ import qualified Data.Text as Text
|
||||||
|
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
( Build
|
( MonadBuild
|
||||||
|
, Build
|
||||||
|
, build
|
||||||
, render
|
, render
|
||||||
, renderNodeName
|
, renderNodeName
|
||||||
, renderedNodeDefs
|
, renderedNodeDefs
|
||||||
|
@ -111,16 +113,17 @@ type GradientCompatible a =
|
||||||
|
|
||||||
|
|
||||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||||
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
|
gradients :: forall a v1 v2 m . (MonadBuild m
|
||||||
|
, Num (Tensor v1 a)
|
||||||
-- TODO(gnezdo): remove indirect constraint.
|
-- TODO(gnezdo): remove indirect constraint.
|
||||||
-- It's a wart inherited from Num instance.
|
-- It's a wart inherited from Num instance.
|
||||||
, v1 ~ Value
|
, v1 ~ Value
|
||||||
, GradientCompatible a
|
, GradientCompatible a
|
||||||
)
|
)
|
||||||
=> Tensor v1 a -- ^ The output of the graph.
|
=> Tensor v1 a -- ^ The output of the graph.
|
||||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||||
-> Build [Tensor Value a]
|
-> m [Tensor Value a]
|
||||||
gradients y xs = do
|
gradients y xs = build $ do
|
||||||
-- The gradients are computed using "reverse accumulation", similarly to
|
-- The gradients are computed using "reverse accumulation", similarly to
|
||||||
-- what is described here:
|
-- what is described here:
|
||||||
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
|
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
|
||||||
|
|
|
@ -155,20 +155,20 @@ matTranspose :: forall a v . TensorType a
|
||||||
=> Tensor v a -> Tensor Value a
|
=> Tensor v a -> Tensor Value a
|
||||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||||
|
|
||||||
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||||
placeholder shape' =
|
placeholder shape' =
|
||||||
buildOp $ opDef "Placeholder"
|
build $ buildOp $ opDef "Placeholder"
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
& opAttr "shape" .~ shape'
|
& opAttr "shape" .~ shape'
|
||||||
|
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
initializedVariable :: forall a . TensorType a
|
initializedVariable :: forall a m . (MonadBuild m, TensorType a)
|
||||||
=> Tensor Value a -> Build (Tensor Ref a)
|
=> Tensor Value a -> m (Tensor Ref a)
|
||||||
initializedVariable initializer = do
|
initializedVariable initializer = do
|
||||||
v <- CoreOps.variable [] -- The shape is not known initially.
|
v <- CoreOps.variable [] -- The shape is not known initially.
|
||||||
(i :: Tensor Ref a) <-
|
(i :: Tensor Ref a) <-
|
||||||
buildOp (opDef "Assign"
|
build $ buildOp (opDef "Assign"
|
||||||
& opAttr "T" .~ tensorType (undefined :: a)
|
& opAttr "T" .~ tensorType (undefined :: a)
|
||||||
& opAttr "use_locking" .~ True
|
& opAttr "use_locking" .~ True
|
||||||
& opAttr "validate_shape" .~ False
|
& opAttr "validate_shape" .~ False
|
||||||
|
@ -179,32 +179,32 @@ initializedVariable initializer = do
|
||||||
|
|
||||||
-- | Creates a zero-initialized variable with the given shape.
|
-- | Creates a zero-initialized variable with the given shape.
|
||||||
zeroInitializedVariable
|
zeroInitializedVariable
|
||||||
:: (TensorType a, Num a) =>
|
:: (MonadBuild m, TensorType a, Num a) =>
|
||||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||||
zeroInitializedVariable = initializedVariable . zeros
|
zeroInitializedVariable = initializedVariable . zeros
|
||||||
|
|
||||||
-- TODO: Support heterogeneous list of tensors.
|
-- TODO: Support heterogeneous list of tensors.
|
||||||
save :: forall a v . TensorType a
|
save :: forall a m v . (MonadBuild m, TensorType a)
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> [Tensor v a] -- ^ Tensors to save.
|
-> [Tensor v a] -- ^ Tensors to save.
|
||||||
-> Build ControlNode
|
-> m ControlNode
|
||||||
save path xs = do
|
save path xs = do
|
||||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||||
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
|
||||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||||
let saveOp = buildOp $ opDef "Save"
|
let saveOp = buildOp $ opDef "Save"
|
||||||
& opAttr "T" .~ types
|
& opAttr "T" .~ types
|
||||||
saveOp (scalar path) (CoreOps.pack names) xs
|
build $ saveOp (scalar path) (CoreOps.pack names) xs
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
--
|
--
|
||||||
-- This version allows restoring from a checkpoint file that uses a different
|
-- This version allows restoring from a checkpoint file that uses a different
|
||||||
-- tensor name than the variable.
|
-- tensor name than the variable.
|
||||||
restoreFromName :: forall a . TensorType a
|
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> ByteString -- ^ Tensor name override.
|
-> ByteString -- ^ Tensor name override.
|
||||||
-> Tensor Ref a -- ^ Tensor to restore.
|
-> Tensor Ref a -- ^ Tensor to restore.
|
||||||
-> Build ControlNode
|
-> m ControlNode
|
||||||
restoreFromName path name x = do
|
restoreFromName path name x = do
|
||||||
let restoreOp = buildOp $ opDef "Restore"
|
let restoreOp = buildOp $ opDef "Restore"
|
||||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||||
|
@ -212,12 +212,12 @@ restoreFromName path name x = do
|
||||||
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
restore :: forall a . TensorType a
|
restore :: forall a m . (MonadBuild m, TensorType a)
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> Tensor Ref a -- ^ Tensor to restore.
|
-> Tensor Ref a -- ^ Tensor to restore.
|
||||||
-> Build ControlNode
|
-> m ControlNode
|
||||||
restore path x = do
|
restore path x = do
|
||||||
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
|
||||||
restoreFromName path name x
|
restoreFromName path name x
|
||||||
|
|
||||||
-- | Create a constant tensor.
|
-- | Create a constant tensor.
|
||||||
|
@ -264,12 +264,13 @@ scalar :: forall a . TensorType a => a -> Tensor Value a
|
||||||
scalar x = constant [] [x]
|
scalar x = constant [] [x]
|
||||||
|
|
||||||
-- Random tensor from the unit normal distribution with bounded values.
|
-- Random tensor from the unit normal distribution with bounded values.
|
||||||
truncatedNormal :: forall a v . TensorType a
|
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
|
||||||
=> Tensor v Int64 -- ^ Shape.
|
=> Tensor v Int64 -- ^ Shape.
|
||||||
-> Build (Tensor Value a)
|
-> m (Tensor Value a)
|
||||||
truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
truncatedNormal
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
= build . buildOp (opDef "TruncatedNormal"
|
||||||
& opAttr "T" .~ tensorType (undefined :: Int64)
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
|
& opAttr "T" .~ tensorType (undefined :: Int64))
|
||||||
|
|
||||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Functor.Identity (runIdentity)
|
|
||||||
import Lens.Family2 ((^.))
|
import Lens.Family2 ((^.))
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
import Proto.Tensorflow.Core.Framework.Graph
|
import Proto.Tensorflow.Core.Framework.Graph
|
||||||
|
@ -35,7 +34,6 @@ import TensorFlow.Build
|
||||||
, asGraphDef
|
, asGraphDef
|
||||||
, evalBuildT
|
, evalBuildT
|
||||||
, flushNodeBuffer
|
, flushNodeBuffer
|
||||||
, hoistBuildT
|
|
||||||
, render
|
, render
|
||||||
, withDevice
|
, withDevice
|
||||||
, colocateWith
|
, colocateWith
|
||||||
|
@ -53,9 +51,7 @@ import TensorFlow.Ops
|
||||||
import TensorFlow.Output (Device(..))
|
import TensorFlow.Output (Device(..))
|
||||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
( build
|
( run
|
||||||
, buildAnd
|
|
||||||
, run
|
|
||||||
, runSession
|
, runSession
|
||||||
, run_
|
, run_
|
||||||
)
|
)
|
||||||
|
@ -82,7 +78,7 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||||
assign v 5
|
assign v 5
|
||||||
-- TODO: Implement TensorFlow get_variable and test it.
|
-- TODO: Implement TensorFlow get_variable and test it.
|
||||||
runSession $ do
|
runSession $ do
|
||||||
out <- buildAnd run graph
|
out <- graph >>= run
|
||||||
liftIO $ 5 @=? (unScalar out :: Float)
|
liftIO $ 5 @=? (unScalar out :: Float)
|
||||||
|
|
||||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||||
|
@ -96,7 +92,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
|
||||||
testInitializedVariable :: Test
|
testInitializedVariable :: Test
|
||||||
testInitializedVariable =
|
testInitializedVariable =
|
||||||
testCase "testInitializedVariable" $ runSession $ do
|
testCase "testInitializedVariable" $ runSession $ do
|
||||||
(formula, reset) <- build $ do
|
(formula, reset) <- do
|
||||||
v <- initializedVariable 42
|
v <- initializedVariable 42
|
||||||
r <- assign v 24
|
r <- assign v 24
|
||||||
return (1 `add` v, r)
|
return (1 `add` v, r)
|
||||||
|
@ -109,7 +105,7 @@ testInitializedVariable =
|
||||||
testInitializedVariableShape :: Test
|
testInitializedVariableShape :: Test
|
||||||
testInitializedVariableShape =
|
testInitializedVariableShape =
|
||||||
testCase "testInitializedVariableShape" $ runSession $ do
|
testCase "testInitializedVariableShape" $ runSession $ do
|
||||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
vector <- initializedVariable (constant [1] [42 :: Float])
|
||||||
result <- run vector
|
result <- run vector
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
|
@ -132,23 +128,19 @@ testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||||
"RefIdentity" @=? (nodeDef ^. op)
|
"RefIdentity" @=? (nodeDef ^. op)
|
||||||
"foo1/bar1" @=? (nodeDef ^. name)
|
"foo1/bar1" @=? (nodeDef ^. name)
|
||||||
|
|
||||||
-- | Lift a Build action into a context for HUnit to run.
|
|
||||||
liftBuild :: Build a -> BuildT IO a
|
|
||||||
liftBuild = hoistBuildT (return . runIdentity)
|
|
||||||
|
|
||||||
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||||
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
||||||
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
flushed field = sort . map field <$> flushNodeBuffer
|
||||||
|
|
||||||
-- | Test the interaction of rendering, CSE and scoping.
|
-- | Test the interaction of rendering, CSE and scoping.
|
||||||
testRenderDedup :: Test
|
testRenderDedup :: Test
|
||||||
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||||
liftBuild renderNodes
|
renderNodes
|
||||||
names <- flushed (^. name)
|
names <- flushed (^. name)
|
||||||
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
|
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
|
||||||
-- Render the nodes in a different scope, which should cause them
|
-- Render the nodes in a different scope, which should cause them
|
||||||
-- to be distinct from the previous ones.
|
-- to be distinct from the previous ones.
|
||||||
liftBuild $ withNameScope "foo" renderNodes
|
withNameScope "foo" renderNodes
|
||||||
scopedNames <- flushed (^. name)
|
scopedNames <- flushed (^. name)
|
||||||
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
|
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
|
||||||
where
|
where
|
||||||
|
@ -165,7 +157,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||||
-- | Test the interaction of rendering, CSE and scoping.
|
-- | Test the interaction of rendering, CSE and scoping.
|
||||||
testDeviceColocation :: Test
|
testDeviceColocation :: Test
|
||||||
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
||||||
liftBuild renderNodes
|
renderNodes
|
||||||
devices <- flushed (\x -> (x ^. name, x ^. device))
|
devices <- flushed (\x -> (x ^. name, x ^. device))
|
||||||
liftIO $ [ ("Add_2","dev0")
|
liftIO $ [ ("Add_2","dev0")
|
||||||
, ("Const_1","dev0")
|
, ("Const_1","dev0")
|
||||||
|
|
|
@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||||
restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
||||||
in monadicIO $ run $ do
|
in monadicIO $ run $ do
|
||||||
fromIntegral numParts @=? length splitParts
|
fromIntegral numParts @=? length splitParts
|
||||||
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
|
valuesOut <- TF.runSession $ TF.run restitch
|
||||||
V.fromList values @=? valuesOut
|
V.fromList values @=? valuesOut
|
||||||
|
|
||||||
data StitchExample a = StitchExample Int64 [a] [Int32]
|
data StitchExample a = StitchExample Int64 [a] [Int32]
|
||||||
|
|
|
@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Build as TF
|
||||||
import qualified TensorFlow.Nodes as TF
|
|
||||||
|
|
||||||
|
|
||||||
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
|
|
||||||
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
|
||||||
|
|
||||||
|
|
||||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||||
|
@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition =
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
let op = embeddingLookup embedding ids
|
let op = embeddingLookup embedding ids
|
||||||
|
|
||||||
(values, shape) <- buildAndRun $ do
|
(values, shape) <- TF.runSession $ do
|
||||||
vs <- op
|
vs <- op
|
||||||
return (vs, TF.shape vs)
|
TF.run (vs, TF.shape vs)
|
||||||
|
|
||||||
-- This is the shape that is returned in the equiv. Python.
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
shape @=? V.fromList [1, 2, 3]
|
shape @=? V.fromList [1, 2, 3]
|
||||||
|
@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape =
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
let op = embeddingLookup [embedding] ids
|
let op = embeddingLookup [embedding] ids
|
||||||
|
|
||||||
(values, shape) <- buildAndRun $ do
|
(values, shape) <- TF.runSession $ do
|
||||||
vs <- op
|
vs <- op
|
||||||
return (vs, TF.shape vs)
|
TF.run (vs, TF.shape vs)
|
||||||
|
|
||||||
-- This is the shape that is returned in the equiv. Python.
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
shape @=? V.fromList [1, 2, 3]
|
shape @=? V.fromList [1, 2, 3]
|
||||||
|
@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
let shape = TF.Shape [2]
|
let shape = TF.Shape [2]
|
||||||
|
|
||||||
gs <- TF.runSession $ do
|
gs <- TF.runSession $ do
|
||||||
grads <- TF.build $ do
|
|
||||||
let embShape = TF.Shape [2, 1]
|
let embShape = TF.Shape [2, 1]
|
||||||
let embeddingInit = [1, 20 ::Float]
|
let embeddingInit = [1, 20 ::Float]
|
||||||
let idValues = [1, 1 :: Int32]
|
let idValues = [1, 1 :: Int32]
|
||||||
|
@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||||
|
|
||||||
grad <- fmap head (TF.gradients loss [embedding])
|
grad <- fmap head (TF.gradients loss [embedding])
|
||||||
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
|
TF.runWithFeeds
|
||||||
|
[TF.feed x $ TF.encodeTensorData shape xVals]
|
||||||
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
|
grad
|
||||||
-- Gradients should be zero (or close)
|
-- Gradients should be zero (or close)
|
||||||
assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
||||||
|
|
||||||
|
@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit
|
||||||
shapedValues = TF.constant shape values
|
shapedValues = TF.constant shape values
|
||||||
in monadicIO $ run $ do
|
in monadicIO $ run $ do
|
||||||
(shapeOut, got, want :: V.Vector a) <-
|
(shapeOut, got, want :: V.Vector a) <-
|
||||||
TF.runSession $ TF.buildAnd TF.run $ do
|
TF.runSession $ TF.run =<< do
|
||||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||||
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
import Data.Int (Int32)
|
import Data.Int (Int32)
|
||||||
|
@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do
|
||||||
y = x*x + b
|
y = x*x + b
|
||||||
grads = TF.gradients y [x, b]
|
grads = TF.gradients y [x, b]
|
||||||
-- Assert that the gradients are right.
|
-- Assert that the gradients are right.
|
||||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||||
6 @=? TF.unScalar dx
|
6 @=? TF.unScalar dx
|
||||||
1 @=? TF.unScalar db
|
1 @=? TF.unScalar db
|
||||||
-- Assert that the graph has the expected ops.
|
-- Assert that the graph has the expected ops.
|
||||||
|
@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||||
b = TF.scalar (4 :: Float)
|
b = TF.scalar (4 :: Float)
|
||||||
grads = TF.gradients x [x, b]
|
grads = TF.gradients x [x, b]
|
||||||
-- Assert that the gradients are right.
|
-- Assert that the gradients are right.
|
||||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||||
1 @=? TF.unScalar dx
|
1 @=? TF.unScalar dx
|
||||||
0 @=? TF.unScalar db
|
0 @=? TF.unScalar db
|
||||||
-- Assert that the graph has the expected ops.
|
-- Assert that the graph has the expected ops.
|
||||||
|
@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||||
-- Test that identical "stateful" ops work with createGraph.
|
-- Test that identical "stateful" ops work with createGraph.
|
||||||
testCreateGraphStateful :: Test
|
testCreateGraphStateful :: Test
|
||||||
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||||
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
|
[dx, dy] <- TF.runSession $ do
|
||||||
let shape = TF.constant (TF.Shape [1]) [1]
|
let shape = TF.constant (TF.Shape [1]) [1]
|
||||||
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||||
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||||
TF.gradients (x + y*3) [x, y]
|
TF.gradients (x + y*3) [x, y] >>= TF.run
|
||||||
-- If this test fails, it will likely be caused by an exception within
|
-- If this test fails, it will likely be caused by an exception within
|
||||||
-- `TF.gradients`. These asserts are extra.
|
-- `TF.gradients`. These asserts are extra.
|
||||||
1 @=? TF.unScalar dx
|
1 @=? TF.unScalar dx
|
||||||
|
@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||||
-- Test that name scopes work with createGraph.
|
-- Test that name scopes work with createGraph.
|
||||||
testCreateGraphNameScopes :: Test
|
testCreateGraphNameScopes :: Test
|
||||||
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
[dx] <- TF.runSession $ do
|
||||||
let shape = TF.constant (TF.Shape [1]) [1]
|
let shape = TF.constant (TF.Shape [1]) [1]
|
||||||
x :: TF.Tensor TF.Value Float <-
|
x :: TF.Tensor TF.Value Float <-
|
||||||
TF.withNameScope "foo" (TF.truncatedNormal shape)
|
TF.withNameScope "foo" (TF.truncatedNormal shape)
|
||||||
TF.gradients x [x]
|
TF.gradients x [x] >>= TF.run
|
||||||
-- If this test fails, it will likely be caused by an exception within
|
-- If this test fails, it will likely be caused by an exception within
|
||||||
-- `TF.gradients`. This assert is extra.
|
-- `TF.gradients`. This assert is extra.
|
||||||
1 @=? TF.unScalar dx
|
1 @=? TF.unScalar dx
|
||||||
|
@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||||
-- Test that createGraph can handle graphs with diamond shapes.
|
-- Test that createGraph can handle graphs with diamond shapes.
|
||||||
testDiamond :: Test
|
testDiamond :: Test
|
||||||
testDiamond = testCase "testDiamond" $ do
|
testDiamond = testCase "testDiamond" $ do
|
||||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
[dx] <- TF.runSession $ do
|
||||||
let x = TF.vector [1]
|
let x = TF.vector [1]
|
||||||
y = x*x
|
y = x*x
|
||||||
z = y*y
|
z = y*y
|
||||||
TF.gradients z [x]
|
TF.gradients z [x] >>= TF.run
|
||||||
(4 :: Float) @=? TF.unScalar dx
|
(4 :: Float) @=? TF.unScalar dx
|
||||||
|
|
||||||
|
|
||||||
testMaxGradient :: Test
|
testMaxGradient :: Test
|
||||||
testMaxGradient = testCase "testMaxGradient" $ do
|
testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
[dx] <- TF.runSession $ do
|
||||||
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||||
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
||||||
TF.gradients y [x]
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ testSize = testCase "testSize" $ do
|
||||||
TF.Scalar (2 * 3 :: Int32) @=? x
|
TF.Scalar (2 * 3 :: Int32) @=? x
|
||||||
|
|
||||||
eval :: TF.Fetchable t a => t -> IO a
|
eval :: TF.Fetchable t a => t -> IO a
|
||||||
eval = TF.runSession . TF.buildAnd TF.run . return
|
eval = TF.runSession . TF.run
|
||||||
|
|
||||||
-- | Confirms that the original example from Python code works.
|
-- | Confirms that the original example from Python code works.
|
||||||
testReducedShape :: Test
|
testReducedShape :: Test
|
||||||
|
@ -54,16 +54,16 @@ testSaveRestore :: Test
|
||||||
testSaveRestore = testCase "testSaveRestore" $
|
testSaveRestore = testCase "testSaveRestore" $
|
||||||
withSystemTempDirectory "" $ \dirPath -> do
|
withSystemTempDirectory "" $ \dirPath -> do
|
||||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||||
var = TF.render =<<
|
var = TF.render =<<
|
||||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
v <- TF.build var
|
v <- var
|
||||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
TF.assign v 134 >>= TF.run_
|
||||||
TF.buildAnd TF.run_ $ TF.save path [v]
|
TF.save path [v] >>= TF.run_
|
||||||
result <- TF.runSession $ do
|
result <- TF.runSession $ do
|
||||||
v <- TF.build var
|
v <- var
|
||||||
TF.buildAnd TF.run_ $ TF.restore path v
|
TF.restore path v >>= TF.run_
|
||||||
TF.run v
|
TF.run v
|
||||||
liftIO $ TF.Scalar 134 @=? result
|
liftIO $ TF.Scalar 134 @=? result
|
||||||
|
|
||||||
|
|
|
@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do
|
||||||
let x = TF.vector xData
|
let x = TF.vector xData
|
||||||
y = TF.vector yData
|
y = TF.vector yData
|
||||||
-- Create scalar variables for slope and intercept.
|
-- Create scalar variables for slope and intercept.
|
||||||
w <- TF.build (TF.initializedVariable 0)
|
w <- TF.initializedVariable 0
|
||||||
b <- TF.build (TF.initializedVariable 0)
|
b <- TF.initializedVariable 0
|
||||||
-- Define the loss function.
|
-- Define the loss function.
|
||||||
let yHat = (x `TF.mul` w) `TF.add` b
|
let yHat = (x `TF.mul` w) `TF.add` b
|
||||||
loss = TF.square (yHat `TF.sub` y)
|
loss = TF.square (yHat `TF.sub` y)
|
||||||
-- Optimize with gradient descent.
|
-- Optimize with gradient descent.
|
||||||
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
|
trainStep <- gradientDescent 0.001 loss [w, b]
|
||||||
replicateM_ 1000 (TF.run trainStep)
|
replicateM_ 1000 (TF.run trainStep)
|
||||||
-- Return the learned parameters.
|
-- Return the learned parameters.
|
||||||
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
||||||
|
@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do
|
||||||
gradientDescent :: Float
|
gradientDescent :: Float
|
||||||
-> TF.Tensor TF.Value Float
|
-> TF.Tensor TF.Value Float
|
||||||
-> [TF.Tensor TF.Ref Float]
|
-> [TF.Tensor TF.Ref Float]
|
||||||
-> TF.Build TF.ControlNode
|
-> TF.Session TF.ControlNode
|
||||||
gradientDescent alpha loss params = do
|
gradientDescent alpha loss params = do
|
||||||
let applyGrad param grad =
|
let applyGrad param grad =
|
||||||
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||||
|
|
|
@ -35,7 +35,7 @@ testTracing = do
|
||||||
loggedValue <- newEmptyMVar
|
loggedValue <- newEmptyMVar
|
||||||
TF.runSessionWithOptions
|
TF.runSessionWithOptions
|
||||||
(def & TF.sessionTracer .~ putMVar loggedValue)
|
(def & TF.sessionTracer .~ putMVar loggedValue)
|
||||||
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
|
(TF.run_ (TF.scalar (0 :: Float)))
|
||||||
tryReadMVar loggedValue >>=
|
tryReadMVar loggedValue >>=
|
||||||
maybe (assertFailure "Logging never happened") expectedFormat
|
maybe (assertFailure "Logging never happened") expectedFormat
|
||||||
where expectedFormat x =
|
where expectedFormat x =
|
||||||
|
|
|
@ -24,10 +24,11 @@ import Data.ByteString (ByteString)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Data.Proxy (Proxy(..))
|
import Data.Proxy (Proxy(..))
|
||||||
import Lens.Family2 ((.~), (&))
|
import Lens.Family2 ((.~), (&))
|
||||||
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
|
||||||
import TensorFlow.BuildOp (buildOp)
|
import TensorFlow.BuildOp (buildOp)
|
||||||
import TensorFlow.ControlFlow (group)
|
import TensorFlow.ControlFlow (group)
|
||||||
import TensorFlow.Tensor (Ref, Tensor, TensorList)
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
|
||||||
import TensorFlow.Types (TensorTypes, fromTensorTypes)
|
import TensorFlow.Types (TensorTypes, fromTensorTypes)
|
||||||
|
|
||||||
-- | A queue carrying tuples.
|
-- | A queue carrying tuples.
|
||||||
|
@ -36,36 +37,30 @@ data Queue (as :: [*]) = Queue { handle :: Handle }
|
||||||
type Handle = Tensor Ref ByteString
|
type Handle = Tensor Ref ByteString
|
||||||
|
|
||||||
-- | Adds the given values to the queue.
|
-- | Adds the given values to the queue.
|
||||||
enqueue :: forall as v . TensorTypes as
|
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
|
||||||
=> Queue as
|
=> Queue as
|
||||||
-> TensorList v as
|
-> TensorList v as
|
||||||
-> Build ControlNode
|
-> m ControlNode
|
||||||
enqueue q =
|
enqueue = CoreOps.queueEnqueue . handle
|
||||||
buildOp (opDef "QueueEnqueue"
|
|
||||||
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
|
|
||||||
(handle q)
|
|
||||||
|
|
||||||
-- | Retrieves the values from the queue.
|
-- | Retrieves the values from the queue.
|
||||||
dequeue :: forall as . TensorTypes as
|
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||||
=> Queue as
|
=> Queue as
|
||||||
-> Build (TensorList Ref as)
|
-> m (TensorList Value as)
|
||||||
-- ^ Dequeued tensors. They are coupled in a sense
|
-- ^ Dequeued tensors. They are coupled in a sense
|
||||||
-- that values appear together, even if they are
|
-- that values appear together, even if they are
|
||||||
-- not consumed together.
|
-- not consumed together.
|
||||||
dequeue q =
|
dequeue = CoreOps.queueDequeue . handle
|
||||||
buildOp (opDef "QueueDequeue"
|
|
||||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
|
|
||||||
(handle q)
|
|
||||||
|
|
||||||
-- | Creates a new queue with the given capacity and shared name.
|
-- | Creates a new queue with the given capacity and shared name.
|
||||||
makeQueue :: forall as . TensorTypes as
|
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||||
=> Int64 -- ^ The upper bound on the number of elements in
|
=> Int64 -- ^ The upper bound on the number of elements in
|
||||||
-- this queue. Negative numbers mean no limit.
|
-- this queue. Negative numbers mean no limit.
|
||||||
-> ByteString -- ^ If non-empty, this queue will be shared
|
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||||
-- under the given name across multiple sessions.
|
-- under the given name across multiple sessions.
|
||||||
-> Build (Queue as)
|
-> m (Queue as)
|
||||||
makeQueue capacity sharedName = do
|
makeQueue capacity sharedName = do
|
||||||
q <- buildOp (opDef "FIFOQueue"
|
q <- build $ buildOp (opDef "FIFOQueue"
|
||||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||||
& opAttr "shared_name" .~ sharedName
|
& opAttr "shared_name" .~ sharedName
|
||||||
& opAttr "capacity" .~ capacity
|
& opAttr "capacity" .~ capacity
|
||||||
|
|
|
@ -39,6 +39,7 @@ Test-Suite QueueTest
|
||||||
, lens-family
|
, lens-family
|
||||||
, google-shim
|
, google-shim
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
, tensorflow-queue
|
, tensorflow-queue
|
||||||
, test-framework
|
, test-framework
|
||||||
|
|
|
@ -27,7 +27,6 @@ import TensorFlow.Queue
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
( asyncProdNodes
|
( asyncProdNodes
|
||||||
, build
|
, build
|
||||||
, buildAnd
|
|
||||||
, run
|
, run
|
||||||
, runSession
|
, runSession
|
||||||
, run_
|
, run_
|
||||||
|
@ -41,12 +40,12 @@ import qualified Data.ByteString as BS
|
||||||
testBasic :: Test
|
testBasic :: Test
|
||||||
testBasic = testCase "testBasic" $ runSession $ do
|
testBasic = testCase "testBasic" $ runSession $ do
|
||||||
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
||||||
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
|
run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)
|
||||||
x <- buildAnd run (dequeue q)
|
x <- run =<< dequeue q
|
||||||
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
|
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
|
||||||
|
|
||||||
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
|
run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil)
|
||||||
y <- buildAnd run (dequeue q)
|
y <- run =<< dequeue q
|
||||||
-- Note: we use explicit "Scalar" here to specify the type that was
|
-- Note: we use explicit "Scalar" here to specify the type that was
|
||||||
-- fetched. Equivalently we could write
|
-- fetched. Equivalently we could write
|
||||||
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
||||||
|
@ -74,7 +73,7 @@ testPump = testCase "testPump" $ runSession $ do
|
||||||
|
|
||||||
testAsync :: Test
|
testAsync :: Test
|
||||||
testAsync = testCase "testAsync" $ runSession $ do
|
testAsync = testCase "testAsync" $ runSession $ do
|
||||||
(deq, pump) <- build $ do
|
(deq, pump) <- do
|
||||||
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
|
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
|
||||||
(,) <$> dequeue q
|
(,) <$> dequeue q
|
||||||
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
|
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
|
||||||
|
|
|
@ -37,6 +37,7 @@ module TensorFlow.Build
|
||||||
, renderedNodeDefs
|
, renderedNodeDefs
|
||||||
, BuildT
|
, BuildT
|
||||||
, Build
|
, Build
|
||||||
|
, MonadBuild(..)
|
||||||
, addInitializer
|
, addInitializer
|
||||||
, hoistBuildT
|
, hoistBuildT
|
||||||
, evalBuildT
|
, evalBuildT
|
||||||
|
@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState
|
||||||
evalBuildT :: Monad m => BuildT m a -> m a
|
evalBuildT :: Monad m => BuildT m a -> m a
|
||||||
evalBuildT (BuildT f) = evalStateT f initGraphState
|
evalBuildT (BuildT f) = evalStateT f initGraphState
|
||||||
|
|
||||||
|
-- | Lift a 'Build' action into a monad, including any explicit op renderings.
|
||||||
|
class Monad m => MonadBuild m where
|
||||||
|
build :: Build a -> m a
|
||||||
|
|
||||||
|
instance Monad m => MonadBuild (BuildT m) where
|
||||||
|
build = hoistBuildT $ return . runIdentity
|
||||||
|
|
||||||
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
||||||
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
|
flushNodeBuffer :: MonadBuild m => m [NodeDef]
|
||||||
flushNodeBuffer = do
|
flushNodeBuffer = build $ do
|
||||||
ns <- use nodeBuffer
|
ns <- use nodeBuffer
|
||||||
nodeBuffer .= []
|
nodeBuffer .= []
|
||||||
return ns
|
return ns
|
||||||
|
@ -229,8 +237,8 @@ flushInitializers = do
|
||||||
|
|
||||||
-- | Registers the given node to be executed before the next
|
-- | Registers the given node to be executed before the next
|
||||||
-- 'TensorFlow.Session.run'.
|
-- 'TensorFlow.Session.run'.
|
||||||
addInitializer :: ControlNode -> Build ()
|
addInitializer :: MonadBuild m => ControlNode -> m ()
|
||||||
addInitializer (ControlNode o) = do
|
addInitializer (ControlNode o) = build $ do
|
||||||
i <- getOrAddOp o
|
i <- getOrAddOp o
|
||||||
initializationNodes %= (i:)
|
initializationNodes %= (i:)
|
||||||
|
|
||||||
|
@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer
|
||||||
gs = snd $ runIdentity $ runBuildT b
|
gs = snd $ runIdentity $ runBuildT b
|
||||||
|
|
||||||
-- TODO: check against existing nodes for conflicts?
|
-- TODO: check against existing nodes for conflicts?
|
||||||
addGraphDef :: GraphDef -> Build ()
|
addGraphDef :: MonadBuild m => GraphDef -> m ()
|
||||||
addGraphDef g = nodeBuffer <>= g ^. node
|
addGraphDef g = build $ nodeBuffer <>= g ^. node
|
||||||
|
|
||||||
-- | Render the given op if it hasn't been rendered already, and return its
|
-- | Render the given op if it hasn't been rendered already, and return its
|
||||||
-- name.
|
-- name.
|
||||||
|
@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do
|
||||||
|
|
||||||
-- | Modify some part of the state, run an action, and restore the state
|
-- | Modify some part of the state, run an action, and restore the state
|
||||||
-- after that action is done.
|
-- after that action is done.
|
||||||
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
|
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
|
||||||
withStateLens accessor f act = do
|
withStateLens accessor f act = do
|
||||||
old <- use accessor
|
old <- build $ use accessor
|
||||||
accessor %= f
|
build $ accessor %= f
|
||||||
result <- act
|
result <- act
|
||||||
accessor .= old
|
build $ accessor .= old
|
||||||
return result
|
return result
|
||||||
|
|
||||||
-- | Set a device for all nodes rendered in the given 'Build' action
|
-- | Set a device for all nodes rendered in the given 'Build' action
|
||||||
-- (unless further overridden by another use of withDevice).
|
-- (unless further overridden by another use of withDevice).
|
||||||
withDevice :: Maybe Device -> Build a -> Build a
|
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
||||||
withDevice d = withStateLens defaultDevice (const d)
|
withDevice d = withStateLens defaultDevice (const d)
|
||||||
|
|
||||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||||
-- the action has side effects of rendering the desired tensors. A pure
|
-- the action has side effects of rendering the desired tensors. A pure
|
||||||
-- return would not have the desired effect.
|
-- return would not have the desired effect.
|
||||||
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
|
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
|
||||||
colocateWith t x = do
|
colocateWith t x = do
|
||||||
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||||
withDevice (Just d) x
|
withDevice (Just d) x
|
||||||
|
|
||||||
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||||
withNameScope :: Text -> Build a -> Build a
|
withNameScope :: MonadBuild m => Text -> m a -> m a
|
||||||
withNameScope s = withStateLens currentScope (Scope s :)
|
withNameScope s = withStateLens currentScope (Scope s :)
|
||||||
|
|
||||||
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||||
withNodeDependencies :: Set NodeName -> Build a -> Build a
|
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
||||||
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||||
|
|
||||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||||
|
@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||||
-- This operation is idempotent; @render >=> render === render@. However,
|
-- This operation is idempotent; @render >=> render === render@. However,
|
||||||
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
||||||
-- may result in two different 'Tensor's.
|
-- may result in two different 'Tensor's.
|
||||||
render :: Tensor v a -> Build (Tensor v a)
|
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
|
||||||
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
|
||||||
|
|
||||||
-- | Render a 'Tensor' and get its node's name.
|
-- | Render a 'Tensor' and get its node's name.
|
||||||
renderNodeName :: Tensor v a -> Build NodeName
|
renderNodeName :: Tensor v a -> Build NodeName
|
||||||
|
|
|
@ -40,9 +40,9 @@ import TensorFlow.Types
|
||||||
|
|
||||||
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||||
-- on the nodes in the first argument.
|
-- on the nodes in the first argument.
|
||||||
withControlDependencies :: Nodes t => t -> Build a -> Build a
|
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
|
||||||
withControlDependencies deps act = do
|
withControlDependencies deps act = do
|
||||||
nodes <- getNodes deps
|
nodes <- build $ getNodes deps
|
||||||
withNodeDependencies nodes act
|
withNodeDependencies nodes act
|
||||||
|
|
||||||
-- TODO(judahjacobson): Reimplement withDependencies.
|
-- TODO(judahjacobson): Reimplement withDependencies.
|
||||||
|
@ -51,9 +51,9 @@ withControlDependencies deps act = do
|
||||||
--
|
--
|
||||||
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||||
-- no output.
|
-- no output.
|
||||||
group :: Nodes t => t -> Build ControlNode
|
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
||||||
group deps = do
|
group deps = do
|
||||||
nodes <- Set.toList <$> getNodes deps
|
nodes <- build $ Set.toList <$> getNodes deps
|
||||||
-- TODO: slicker way
|
-- TODO: slicker way
|
||||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||||
|
|
||||||
|
|
|
@ -31,8 +31,7 @@ module TensorFlow.Core
|
||||||
, runSession
|
, runSession
|
||||||
, runSessionWithOptions
|
, runSessionWithOptions
|
||||||
-- ** Building graphs
|
-- ** Building graphs
|
||||||
, build
|
, MonadBuild(..)
|
||||||
, buildAnd
|
|
||||||
-- ** Running graphs
|
-- ** Running graphs
|
||||||
, Fetchable
|
, Fetchable
|
||||||
, Nodes
|
, Nodes
|
||||||
|
|
|
@ -26,8 +26,7 @@ module TensorFlow.Session (
|
||||||
sessionTracer,
|
sessionTracer,
|
||||||
runSession,
|
runSession,
|
||||||
runSessionWithOptions,
|
runSessionWithOptions,
|
||||||
build,
|
MonadBuild(..),
|
||||||
buildAnd,
|
|
||||||
extend,
|
extend,
|
||||||
addGraphDef,
|
addGraphDef,
|
||||||
run,
|
run,
|
||||||
|
@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift)
|
||||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||||
import Data.ByteString (ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Data.Default (Default, def)
|
import Data.Default (Default, def)
|
||||||
import Data.Functor.Identity (runIdentity)
|
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
import Data.ProtoLens (showMessage)
|
import Data.ProtoLens (showMessage)
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
|
@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) =
|
||||||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||||
|
|
||||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
instance MonadBuild Session where
|
||||||
-- renderings.
|
build = Session . lift . build
|
||||||
build :: Build a -> Session a
|
|
||||||
build = Session . lift . hoistBuildT (return . runIdentity)
|
|
||||||
|
|
||||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||||
-- any pending initializers.
|
-- any pending initializers.
|
||||||
|
@ -147,13 +143,6 @@ extend = do
|
||||||
unless (null initializers) $
|
unless (null initializers) $
|
||||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||||
|
|
||||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
|
||||||
-- Example usage:
|
|
||||||
--
|
|
||||||
-- > buildAnd run :: Fetchable t a => Build t -> Session a
|
|
||||||
buildAnd :: (a -> Session b) -> Build a -> Session b
|
|
||||||
buildAnd f m = build m >>= f
|
|
||||||
|
|
||||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||||
-- rendered, and fetch the corresponding values for 'a'.
|
-- rendered, and fetch the corresponding values for 'a'.
|
||||||
run :: Fetchable t a => t -> Session a
|
run :: Fetchable t a => t -> Session a
|
||||||
|
|
Loading…
Reference in a new issue