mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-24 02:29:46 +01:00
Introduce a MonadBuild class, and remove buildAnd
. (#83)
This change adds a class that both `Build` and `Session` are instances of: class MonadBuild m where build :: Build a -> m a All stateful ops (generated and manually written) now have a signature that returns an instance of `MonadBuild` (rather than just `Build`). For example: assign_ :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v t -> m (Tensor Ref t) This lets us remove a bunch of spurious calls to `build` in user code. It also lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run` (or `run =<< foo`, which is sometimes nicer when foo is a complicated expression). I went ahead and deleted `buildAnd` altogether since it seems to lead to confusion; in particular a few tests had `buildAnd run . pure` which is actually equivalent to just `run`.
This commit is contained in:
parent
9209dfc4c4
commit
2c5c879037
22 changed files with 152 additions and 162 deletions
|
@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
|
|||
let x = TF.vector xData
|
||||
y = TF.vector yData
|
||||
-- Create scalar variables for slope and intercept.
|
||||
w <- TF.build (TF.initializedVariable 0)
|
||||
b <- TF.build (TF.initializedVariable 0)
|
||||
w <- TF.initializedVariable 0
|
||||
b <- TF.initializedVariable 0
|
||||
-- Define the loss function.
|
||||
let yHat = (x `TF.mul` w) `TF.add` b
|
||||
loss = TF.square (yHat `TF.sub` y)
|
||||
-- Optimize with gradient descent.
|
||||
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
|
||||
trainStep <- gradientDescent 0.001 loss [w, b]
|
||||
replicateM_ 1000 (TF.run trainStep)
|
||||
-- Return the learned parameters.
|
||||
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
||||
|
@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
|
|||
gradientDescent :: Float
|
||||
-> TF.Tensor TF.Value Float
|
||||
-> [TF.Tensor TF.Ref Float]
|
||||
-> TF.Build TF.ControlNode
|
||||
-> TF.Session TF.ControlNode
|
||||
gradientDescent alpha loss params = do
|
||||
let applyGrad param grad =
|
||||
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||
|
|
|
@ -49,7 +49,7 @@ import TensorFlow.Tensor
|
|||
)
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.Session
|
||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||
(runSession, run, run_, runWithFeeds, build)
|
||||
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -108,7 +108,7 @@ testGraphDefExec :: Test
|
|||
testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||
runSession $ do
|
||||
build $ addGraphDef graphDef
|
||||
addGraphDef graphDef
|
||||
x <- run $ tensorFromName ValueKind "Mul_2"
|
||||
liftIO $ (50 :: Float) @=? unScalar x
|
||||
|
||||
|
@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
|||
wtsCkptPath <- liftIO wtsCkpt
|
||||
biasCkptPath <- liftIO biasCkpt
|
||||
-- Run those restoring nodes on the graph in the current session.
|
||||
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
|
||||
run_ =<< (sequence :: Monad m => [m a] -> m [a])
|
||||
[ restore wtsCkptPath wts
|
||||
, restoreFromName biasCkptPath "bias" bias
|
||||
]
|
||||
|
|
|
@ -23,7 +23,7 @@ module TensorFlow.NN
|
|||
import Prelude hiding ( log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Build ( Build
|
||||
import TensorFlow.Build ( MonadBuild
|
||||
, render
|
||||
, withNameScope
|
||||
)
|
||||
|
@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
|
|||
--
|
||||
-- `logits` and `targets` must have the same type and shape.
|
||||
sigmoidCrossEntropyWithLogits
|
||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
:: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
=> Tensor Value a -- ^ __logits__
|
||||
-> Tensor Value a -- ^ __targets__
|
||||
-> Build (Tensor Value a)
|
||||
-> m (Tensor Value a)
|
||||
sigmoidCrossEntropyWithLogits logits targets = do
|
||||
logits' <- render logits
|
||||
targets' <- render targets
|
||||
|
|
|
@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
|
|||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.NN as TF
|
||||
|
@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
|
|||
|
||||
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
||||
|
||||
run :: TF.Fetchable t a => TF.Build t -> IO a
|
||||
run = TF.runSession . TF.buildAnd TF.run
|
||||
run :: TF.Fetchable t a => TF.Session t -> IO a
|
||||
run = TF.runSession . (>>= TF.run)
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testGradientAtZero
|
||||
|
|
|
@ -220,9 +220,12 @@ renderHaskellAttrName :: Attr a -> Doc
|
|||
renderHaskellAttrName = renderHaskellName . attrName
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
where
|
||||
maybeLift
|
||||
| parsedOpIsMonadic pOp = "build $"
|
||||
| otherwise = ""
|
||||
buildFunction
|
||||
| null outputListsSizes = "buildOp"
|
||||
| otherwise = "buildListOp" <+>
|
||||
|
@ -277,13 +280,18 @@ typeSig pOp = constraints
|
|||
++ [outputs])
|
||||
where
|
||||
constraints
|
||||
| null (inferredTypeAttrs pOp) = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
| null classConstraints = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ map tensorArgConstraint
|
||||
$ inferredTypeAttrs pOp
|
||||
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||
monadConstraint
|
||||
| parsedOpIsMonadic pOp = ["MonadBuild m'"]
|
||||
| otherwise = []
|
||||
classConstraints = monadConstraint ++ map tensorArgConstraint
|
||||
(inferredTypeAttrs pOp)
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
||||
|
@ -304,7 +312,7 @@ typeSig pOp = constraints
|
|||
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
|
||||
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
|
||||
wrapOutput o
|
||||
| parsedOpIsMonadic pOp = "Build" <+> parens o
|
||||
| parsedOpIsMonadic pOp = "m'" <+> parens o
|
||||
| otherwise = o
|
||||
|
||||
-- | Render an op input or output.
|
||||
|
|
|
@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
|
|||
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Int (Int32, Int64)
|
||||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Build (MonadBuild, colocateWith, render)
|
||||
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
|
@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
|||
--
|
||||
-- The results of the lookup are concatenated into a dense
|
||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup :: forall a b v .
|
||||
( TensorType a
|
||||
embeddingLookup :: forall a b v m .
|
||||
( MonadBuild m
|
||||
, TensorType a
|
||||
, OneOf '[Int64, Int32] b
|
||||
, Num b
|
||||
)
|
||||
|
@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
|
|||
-- containing the ids to be looked up in `params`.
|
||||
-- The ids are required to have fewer than 2^31
|
||||
-- entries.
|
||||
-> Build (Tensor Value a)
|
||||
-> m (Tensor Value a)
|
||||
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
|
||||
embeddingLookup params@(p0 : _) ids = do
|
||||
|
|
|
@ -56,7 +56,9 @@ import qualified Data.Text as Text
|
|||
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Build
|
||||
( Build
|
||||
( MonadBuild
|
||||
, Build
|
||||
, build
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderedNodeDefs
|
||||
|
@ -111,16 +113,17 @@ type GradientCompatible a =
|
|||
|
||||
|
||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
|
||||
gradients :: forall a v1 v2 m . (MonadBuild m
|
||||
, Num (Tensor v1 a)
|
||||
-- TODO(gnezdo): remove indirect constraint.
|
||||
-- It's a wart inherited from Num instance.
|
||||
-- It's a wart inherited from Num instance.
|
||||
, v1 ~ Value
|
||||
, GradientCompatible a
|
||||
)
|
||||
=> Tensor v1 a -- ^ The output of the graph.
|
||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||
-> Build [Tensor Value a]
|
||||
gradients y xs = do
|
||||
-> m [Tensor Value a]
|
||||
gradients y xs = build $ do
|
||||
-- The gradients are computed using "reverse accumulation", similarly to
|
||||
-- what is described here:
|
||||
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
|
||||
|
|
|
@ -155,20 +155,20 @@ matTranspose :: forall a v . TensorType a
|
|||
=> Tensor v a -> Tensor Value a
|
||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||
|
||||
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
||||
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||
placeholder shape' =
|
||||
buildOp $ opDef "Placeholder"
|
||||
build $ buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ shape'
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a . TensorType a
|
||||
=> Tensor Value a -> Build (Tensor Ref a)
|
||||
initializedVariable :: forall a m . (MonadBuild m, TensorType a)
|
||||
=> Tensor Value a -> m (Tensor Ref a)
|
||||
initializedVariable initializer = do
|
||||
v <- CoreOps.variable [] -- The shape is not known initially.
|
||||
(i :: Tensor Ref a) <-
|
||||
buildOp (opDef "Assign"
|
||||
build $ buildOp (opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
& opAttr "validate_shape" .~ False
|
||||
|
@ -179,32 +179,32 @@ initializedVariable initializer = do
|
|||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
:: (TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||
:: (MonadBuild m, TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable = initializedVariable . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a v . TensorType a
|
||||
save :: forall a m v . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> [Tensor v a] -- ^ Tensors to save.
|
||||
-> Build ControlNode
|
||||
-> m ControlNode
|
||||
save path xs = do
|
||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
||||
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
|
||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||
let saveOp = buildOp $ opDef "Save"
|
||||
& opAttr "T" .~ types
|
||||
saveOp (scalar path) (CoreOps.pack names) xs
|
||||
build $ saveOp (scalar path) (CoreOps.pack names) xs
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
--
|
||||
-- This version allows restoring from a checkpoint file that uses a different
|
||||
-- tensor name than the variable.
|
||||
restoreFromName :: forall a . TensorType a
|
||||
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> ByteString -- ^ Tensor name override.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
-> m ControlNode
|
||||
restoreFromName path name x = do
|
||||
let restoreOp = buildOp $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
|
@ -212,12 +212,12 @@ restoreFromName path name x = do
|
|||
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
restore :: forall a . TensorType a
|
||||
restore :: forall a m . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
-> m ControlNode
|
||||
restore path x = do
|
||||
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
||||
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
|
||||
restoreFromName path name x
|
||||
|
||||
-- | Create a constant tensor.
|
||||
|
@ -264,12 +264,13 @@ scalar :: forall a . TensorType a => a -> Tensor Value a
|
|||
scalar x = constant [] [x]
|
||||
|
||||
-- Random tensor from the unit normal distribution with bounded values.
|
||||
truncatedNormal :: forall a v . TensorType a
|
||||
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
|
||||
=> Tensor v Int64 -- ^ Shape.
|
||||
-> Build (Tensor Value a)
|
||||
truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "T" .~ tensorType (undefined :: Int64)
|
||||
-> m (Tensor Value a)
|
||||
truncatedNormal
|
||||
= build . buildOp (opDef "TruncatedNormal"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "T" .~ tensorType (undefined :: Int64))
|
||||
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Lens.Family2 ((^.))
|
||||
import Data.List (sort)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
|
@ -35,7 +34,6 @@ import TensorFlow.Build
|
|||
, asGraphDef
|
||||
, evalBuildT
|
||||
, flushNodeBuffer
|
||||
, hoistBuildT
|
||||
, render
|
||||
, withDevice
|
||||
, colocateWith
|
||||
|
@ -53,9 +51,7 @@ import TensorFlow.Ops
|
|||
import TensorFlow.Output (Device(..))
|
||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||
import TensorFlow.Session
|
||||
( build
|
||||
, buildAnd
|
||||
, run
|
||||
( run
|
||||
, runSession
|
||||
, run_
|
||||
)
|
||||
|
@ -82,7 +78,7 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
|
|||
assign v 5
|
||||
-- TODO: Implement TensorFlow get_variable and test it.
|
||||
runSession $ do
|
||||
out <- buildAnd run graph
|
||||
out <- graph >>= run
|
||||
liftIO $ 5 @=? (unScalar out :: Float)
|
||||
|
||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||
|
@ -96,7 +92,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
|
|||
testInitializedVariable :: Test
|
||||
testInitializedVariable =
|
||||
testCase "testInitializedVariable" $ runSession $ do
|
||||
(formula, reset) <- build $ do
|
||||
(formula, reset) <- do
|
||||
v <- initializedVariable 42
|
||||
r <- assign v 24
|
||||
return (1 `add` v, r)
|
||||
|
@ -109,7 +105,7 @@ testInitializedVariable =
|
|||
testInitializedVariableShape :: Test
|
||||
testInitializedVariableShape =
|
||||
testCase "testInitializedVariableShape" $ runSession $ do
|
||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||
vector <- initializedVariable (constant [1] [42 :: Float])
|
||||
result <- run vector
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
|
@ -132,23 +128,19 @@ testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
|||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"foo1/bar1" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Lift a Build action into a context for HUnit to run.
|
||||
liftBuild :: Build a -> BuildT IO a
|
||||
liftBuild = hoistBuildT (return . runIdentity)
|
||||
|
||||
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
||||
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
||||
flushed field = sort . map field <$> flushNodeBuffer
|
||||
|
||||
-- | Test the interaction of rendering, CSE and scoping.
|
||||
testRenderDedup :: Test
|
||||
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||
liftBuild renderNodes
|
||||
renderNodes
|
||||
names <- flushed (^. name)
|
||||
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
|
||||
-- Render the nodes in a different scope, which should cause them
|
||||
-- to be distinct from the previous ones.
|
||||
liftBuild $ withNameScope "foo" renderNodes
|
||||
withNameScope "foo" renderNodes
|
||||
scopedNames <- flushed (^. name)
|
||||
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
|
||||
where
|
||||
|
@ -165,7 +157,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
|||
-- | Test the interaction of rendering, CSE and scoping.
|
||||
testDeviceColocation :: Test
|
||||
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
||||
liftBuild renderNodes
|
||||
renderNodes
|
||||
devices <- flushed (\x -> (x ^. name, x ^. device))
|
||||
liftIO $ [ ("Add_2","dev0")
|
||||
, ("Const_1","dev0")
|
||||
|
|
|
@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
|||
restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
||||
in monadicIO $ run $ do
|
||||
fromIntegral numParts @=? length splitParts
|
||||
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
|
||||
valuesOut <- TF.runSession $ TF.run restitch
|
||||
V.fromList values @=? valuesOut
|
||||
|
||||
data StitchExample a = StitchExample Int64 [a] [Int32]
|
||||
|
|
|
@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF
|
|||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
|
||||
|
||||
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
|
||||
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||
|
||||
|
||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||
|
@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition =
|
|||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup embedding ids
|
||||
|
||||
(values, shape) <- buildAndRun $ do
|
||||
(values, shape) <- TF.runSession $ do
|
||||
vs <- op
|
||||
return (vs, TF.shape vs)
|
||||
TF.run (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
shape @=? V.fromList [1, 2, 3]
|
||||
|
@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape =
|
|||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup [embedding] ids
|
||||
|
||||
(values, shape) <- buildAndRun $ do
|
||||
(values, shape) <- TF.runSession $ do
|
||||
vs <- op
|
||||
return (vs, TF.shape vs)
|
||||
TF.run (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
shape @=? V.fromList [1, 2, 3]
|
||||
|
@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
let shape = TF.Shape [2]
|
||||
|
||||
gs <- TF.runSession $ do
|
||||
grads <- TF.build $ do
|
||||
let embShape = TF.Shape [2, 1]
|
||||
let embeddingInit = [1, 20 ::Float]
|
||||
let idValues = [1, 1 :: Int32]
|
||||
|
@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||
|
||||
grad <- fmap head (TF.gradients loss [embedding])
|
||||
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
|
||||
|
||||
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
|
||||
TF.runWithFeeds
|
||||
[TF.feed x $ TF.encodeTensorData shape xVals]
|
||||
grad
|
||||
-- Gradients should be zero (or close)
|
||||
assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
||||
|
||||
|
@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit
|
|||
shapedValues = TF.constant shape values
|
||||
in monadicIO $ run $ do
|
||||
(shapeOut, got, want :: V.Vector a) <-
|
||||
TF.runSession $ TF.buildAnd TF.run $ do
|
||||
TF.runSession $ TF.run =<< do
|
||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
import Data.Int (Int32)
|
||||
|
@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do
|
|||
y = x*x + b
|
||||
grads = TF.gradients y [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||
6 @=? TF.unScalar dx
|
||||
1 @=? TF.unScalar db
|
||||
-- Assert that the graph has the expected ops.
|
||||
|
@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
|||
b = TF.scalar (4 :: Float)
|
||||
grads = TF.gradients x [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||
1 @=? TF.unScalar dx
|
||||
0 @=? TF.unScalar db
|
||||
-- Assert that the graph has the expected ops.
|
||||
|
@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
|||
-- Test that identical "stateful" ops work with createGraph.
|
||||
testCreateGraphStateful :: Test
|
||||
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx, dy] <- TF.runSession $ do
|
||||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
TF.gradients (x + y*3) [x, y]
|
||||
TF.gradients (x + y*3) [x, y] >>= TF.run
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. These asserts are extra.
|
||||
1 @=? TF.unScalar dx
|
||||
|
@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
|||
-- Test that name scopes work with createGraph.
|
||||
testCreateGraphNameScopes :: Test
|
||||
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <-
|
||||
TF.withNameScope "foo" (TF.truncatedNormal shape)
|
||||
TF.gradients x [x]
|
||||
TF.gradients x [x] >>= TF.run
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. This assert is extra.
|
||||
1 @=? TF.unScalar dx
|
||||
|
@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
|||
-- Test that createGraph can handle graphs with diamond shapes.
|
||||
testDiamond :: Test
|
||||
testDiamond = testCase "testDiamond" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let x = TF.vector [1]
|
||||
y = x*x
|
||||
z = y*y
|
||||
TF.gradients z [x]
|
||||
TF.gradients z [x] >>= TF.run
|
||||
(4 :: Float) @=? TF.unScalar dx
|
||||
|
||||
|
||||
testMaxGradient :: Test
|
||||
testMaxGradient = testCase "testMaxGradient" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
||||
TF.gradients y [x]
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ testSize = testCase "testSize" $ do
|
|||
TF.Scalar (2 * 3 :: Int32) @=? x
|
||||
|
||||
eval :: TF.Fetchable t a => t -> IO a
|
||||
eval = TF.runSession . TF.buildAnd TF.run . return
|
||||
eval = TF.runSession . TF.run
|
||||
|
||||
-- | Confirms that the original example from Python code works.
|
||||
testReducedShape :: Test
|
||||
|
@ -54,16 +54,16 @@ testSaveRestore :: Test
|
|||
testSaveRestore = testCase "testSaveRestore" $
|
||||
withSystemTempDirectory "" $ \dirPath -> do
|
||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
||||
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||
var = TF.render =<<
|
||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
||||
TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
||||
TF.buildAnd TF.run_ $ TF.save path [v]
|
||||
v <- var
|
||||
TF.assign v 134 >>= TF.run_
|
||||
TF.save path [v] >>= TF.run_
|
||||
result <- TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.restore path v
|
||||
v <- var
|
||||
TF.restore path v >>= TF.run_
|
||||
TF.run v
|
||||
liftIO $ TF.Scalar 134 @=? result
|
||||
|
||||
|
|
|
@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do
|
|||
let x = TF.vector xData
|
||||
y = TF.vector yData
|
||||
-- Create scalar variables for slope and intercept.
|
||||
w <- TF.build (TF.initializedVariable 0)
|
||||
b <- TF.build (TF.initializedVariable 0)
|
||||
w <- TF.initializedVariable 0
|
||||
b <- TF.initializedVariable 0
|
||||
-- Define the loss function.
|
||||
let yHat = (x `TF.mul` w) `TF.add` b
|
||||
loss = TF.square (yHat `TF.sub` y)
|
||||
-- Optimize with gradient descent.
|
||||
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
|
||||
trainStep <- gradientDescent 0.001 loss [w, b]
|
||||
replicateM_ 1000 (TF.run trainStep)
|
||||
-- Return the learned parameters.
|
||||
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
||||
|
@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do
|
|||
gradientDescent :: Float
|
||||
-> TF.Tensor TF.Value Float
|
||||
-> [TF.Tensor TF.Ref Float]
|
||||
-> TF.Build TF.ControlNode
|
||||
-> TF.Session TF.ControlNode
|
||||
gradientDescent alpha loss params = do
|
||||
let applyGrad param grad =
|
||||
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||
|
|
|
@ -35,7 +35,7 @@ testTracing = do
|
|||
loggedValue <- newEmptyMVar
|
||||
TF.runSessionWithOptions
|
||||
(def & TF.sessionTracer .~ putMVar loggedValue)
|
||||
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
|
||||
(TF.run_ (TF.scalar (0 :: Float)))
|
||||
tryReadMVar loggedValue >>=
|
||||
maybe (assertFailure "Logging never happened") expectedFormat
|
||||
where expectedFormat x =
|
||||
|
|
|
@ -24,10 +24,11 @@ import Data.ByteString (ByteString)
|
|||
import Data.Int (Int64)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
||||
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
|
||||
import TensorFlow.BuildOp (buildOp)
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Tensor (Ref, Tensor, TensorList)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
|
||||
import TensorFlow.Types (TensorTypes, fromTensorTypes)
|
||||
|
||||
-- | A queue carrying tuples.
|
||||
|
@ -36,36 +37,30 @@ data Queue (as :: [*]) = Queue { handle :: Handle }
|
|||
type Handle = Tensor Ref ByteString
|
||||
|
||||
-- | Adds the given values to the queue.
|
||||
enqueue :: forall as v . TensorTypes as
|
||||
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
|
||||
=> Queue as
|
||||
-> TensorList v as
|
||||
-> Build ControlNode
|
||||
enqueue q =
|
||||
buildOp (opDef "QueueEnqueue"
|
||||
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
|
||||
(handle q)
|
||||
-> m ControlNode
|
||||
enqueue = CoreOps.queueEnqueue . handle
|
||||
|
||||
-- | Retrieves the values from the queue.
|
||||
dequeue :: forall as . TensorTypes as
|
||||
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||
=> Queue as
|
||||
-> Build (TensorList Ref as)
|
||||
-> m (TensorList Value as)
|
||||
-- ^ Dequeued tensors. They are coupled in a sense
|
||||
-- that values appear together, even if they are
|
||||
-- not consumed together.
|
||||
dequeue q =
|
||||
buildOp (opDef "QueueDequeue"
|
||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
|
||||
(handle q)
|
||||
dequeue = CoreOps.queueDequeue . handle
|
||||
|
||||
-- | Creates a new queue with the given capacity and shared name.
|
||||
makeQueue :: forall as . TensorTypes as
|
||||
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||
=> Int64 -- ^ The upper bound on the number of elements in
|
||||
-- this queue. Negative numbers mean no limit.
|
||||
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||
-- under the given name across multiple sessions.
|
||||
-> Build (Queue as)
|
||||
-> m (Queue as)
|
||||
makeQueue capacity sharedName = do
|
||||
q <- buildOp (opDef "FIFOQueue"
|
||||
q <- build $ buildOp (opDef "FIFOQueue"
|
||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||
& opAttr "shared_name" .~ sharedName
|
||||
& opAttr "capacity" .~ capacity
|
||||
|
|
|
@ -39,6 +39,7 @@ Test-Suite QueueTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-queue
|
||||
, test-framework
|
||||
|
|
|
@ -27,7 +27,6 @@ import TensorFlow.Queue
|
|||
import TensorFlow.Session
|
||||
( asyncProdNodes
|
||||
, build
|
||||
, buildAnd
|
||||
, run
|
||||
, runSession
|
||||
, run_
|
||||
|
@ -41,12 +40,12 @@ import qualified Data.ByteString as BS
|
|||
testBasic :: Test
|
||||
testBasic = testCase "testBasic" $ runSession $ do
|
||||
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
||||
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
|
||||
x <- buildAnd run (dequeue q)
|
||||
run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)
|
||||
x <- run =<< dequeue q
|
||||
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
|
||||
|
||||
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
|
||||
y <- buildAnd run (dequeue q)
|
||||
run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil)
|
||||
y <- run =<< dequeue q
|
||||
-- Note: we use explicit "Scalar" here to specify the type that was
|
||||
-- fetched. Equivalently we could write
|
||||
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
||||
|
@ -74,7 +73,7 @@ testPump = testCase "testPump" $ runSession $ do
|
|||
|
||||
testAsync :: Test
|
||||
testAsync = testCase "testAsync" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
(deq, pump) <- do
|
||||
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
|
||||
|
|
|
@ -37,6 +37,7 @@ module TensorFlow.Build
|
|||
, renderedNodeDefs
|
||||
, BuildT
|
||||
, Build
|
||||
, MonadBuild(..)
|
||||
, addInitializer
|
||||
, hoistBuildT
|
||||
, evalBuildT
|
||||
|
@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState
|
|||
evalBuildT :: Monad m => BuildT m a -> m a
|
||||
evalBuildT (BuildT f) = evalStateT f initGraphState
|
||||
|
||||
-- | Lift a 'Build' action into a monad, including any explicit op renderings.
|
||||
class Monad m => MonadBuild m where
|
||||
build :: Build a -> m a
|
||||
|
||||
instance Monad m => MonadBuild (BuildT m) where
|
||||
build = hoistBuildT $ return . runIdentity
|
||||
|
||||
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
||||
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
|
||||
flushNodeBuffer = do
|
||||
flushNodeBuffer :: MonadBuild m => m [NodeDef]
|
||||
flushNodeBuffer = build $ do
|
||||
ns <- use nodeBuffer
|
||||
nodeBuffer .= []
|
||||
return ns
|
||||
|
@ -229,8 +237,8 @@ flushInitializers = do
|
|||
|
||||
-- | Registers the given node to be executed before the next
|
||||
-- 'TensorFlow.Session.run'.
|
||||
addInitializer :: ControlNode -> Build ()
|
||||
addInitializer (ControlNode o) = do
|
||||
addInitializer :: MonadBuild m => ControlNode -> m ()
|
||||
addInitializer (ControlNode o) = build $ do
|
||||
i <- getOrAddOp o
|
||||
initializationNodes %= (i:)
|
||||
|
||||
|
@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer
|
|||
gs = snd $ runIdentity $ runBuildT b
|
||||
|
||||
-- TODO: check against existing nodes for conflicts?
|
||||
addGraphDef :: GraphDef -> Build ()
|
||||
addGraphDef g = nodeBuffer <>= g ^. node
|
||||
addGraphDef :: MonadBuild m => GraphDef -> m ()
|
||||
addGraphDef g = build $ nodeBuffer <>= g ^. node
|
||||
|
||||
-- | Render the given op if it hasn't been rendered already, and return its
|
||||
-- name.
|
||||
|
@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do
|
|||
|
||||
-- | Modify some part of the state, run an action, and restore the state
|
||||
-- after that action is done.
|
||||
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
|
||||
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
|
||||
withStateLens accessor f act = do
|
||||
old <- use accessor
|
||||
accessor %= f
|
||||
old <- build $ use accessor
|
||||
build $ accessor %= f
|
||||
result <- act
|
||||
accessor .= old
|
||||
build $ accessor .= old
|
||||
return result
|
||||
|
||||
-- | Set a device for all nodes rendered in the given 'Build' action
|
||||
-- (unless further overridden by another use of withDevice).
|
||||
withDevice :: Maybe Device -> Build a -> Build a
|
||||
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
||||
withDevice d = withStateLens defaultDevice (const d)
|
||||
|
||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||
-- the action has side effects of rendering the desired tensors. A pure
|
||||
-- return would not have the desired effect.
|
||||
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
|
||||
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
|
||||
colocateWith t x = do
|
||||
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||
withDevice (Just d) x
|
||||
|
||||
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||
withNameScope :: Text -> Build a -> Build a
|
||||
withNameScope :: MonadBuild m => Text -> m a -> m a
|
||||
withNameScope s = withStateLens currentScope (Scope s :)
|
||||
|
||||
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||
withNodeDependencies :: Set NodeName -> Build a -> Build a
|
||||
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
||||
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||
|
||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||
|
@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
|||
-- This operation is idempotent; @render >=> render === render@. However,
|
||||
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
||||
-- may result in two different 'Tensor's.
|
||||
render :: Tensor v a -> Build (Tensor v a)
|
||||
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
||||
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
|
||||
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
|
||||
|
||||
-- | Render a 'Tensor' and get its node's name.
|
||||
renderNodeName :: Tensor v a -> Build NodeName
|
||||
|
|
|
@ -40,9 +40,9 @@ import TensorFlow.Types
|
|||
|
||||
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||
-- on the nodes in the first argument.
|
||||
withControlDependencies :: Nodes t => t -> Build a -> Build a
|
||||
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
|
||||
withControlDependencies deps act = do
|
||||
nodes <- getNodes deps
|
||||
nodes <- build $ getNodes deps
|
||||
withNodeDependencies nodes act
|
||||
|
||||
-- TODO(judahjacobson): Reimplement withDependencies.
|
||||
|
@ -51,9 +51,9 @@ withControlDependencies deps act = do
|
|||
--
|
||||
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||
-- no output.
|
||||
group :: Nodes t => t -> Build ControlNode
|
||||
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
||||
group deps = do
|
||||
nodes <- Set.toList <$> getNodes deps
|
||||
nodes <- build $ Set.toList <$> getNodes deps
|
||||
-- TODO: slicker way
|
||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||
|
||||
|
|
|
@ -31,8 +31,7 @@ module TensorFlow.Core
|
|||
, runSession
|
||||
, runSessionWithOptions
|
||||
-- ** Building graphs
|
||||
, build
|
||||
, buildAnd
|
||||
, MonadBuild(..)
|
||||
-- ** Running graphs
|
||||
, Fetchable
|
||||
, Nodes
|
||||
|
|
|
@ -26,8 +26,7 @@ module TensorFlow.Session (
|
|||
sessionTracer,
|
||||
runSession,
|
||||
runSessionWithOptions,
|
||||
build,
|
||||
buildAnd,
|
||||
MonadBuild(..),
|
||||
extend,
|
||||
addGraphDef,
|
||||
run,
|
||||
|
@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift)
|
|||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (Default, def)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.ProtoLens (showMessage)
|
||||
import Data.Set (Set)
|
||||
|
@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) =
|
|||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings.
|
||||
build :: Build a -> Session a
|
||||
build = Session . lift . hoistBuildT (return . runIdentity)
|
||||
instance MonadBuild Session where
|
||||
build = Session . lift . build
|
||||
|
||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||
-- any pending initializers.
|
||||
|
@ -147,13 +143,6 @@ extend = do
|
|||
unless (null initializers) $
|
||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||
|
||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||
-- Example usage:
|
||||
--
|
||||
-- > buildAnd run :: Fetchable t a => Build t -> Session a
|
||||
buildAnd :: (a -> Session b) -> Build a -> Session b
|
||||
buildAnd f m = build m >>= f
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, and fetch the corresponding values for 'a'.
|
||||
run :: Fetchable t a => t -> Session a
|
||||
|
|
Loading…
Reference in a new issue