mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Distinguish between "rendered" and "unrendered" Tensors. (#88)
Distinguish between "rendered" and "unrendered" Tensors. There are now three types of `Tensor`: - `Tensor Value a`: rendered value - `Tensor Ref a`: rendered reference - `Tensor Build a` : unrendered value The extra bookkeeping makes it easier to track (and enforce) which tensors are rendered or not. For examples where this has been confusing in the past, see With this change, pure ops look similar to before, returning `Tensor Build` instead of `Tensor Value`. "Stateful" (monadic) ops are unchanged. For example: add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t assign :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t) The `gradients` function now requires that the variables over which it's differentiating are pre-rendered: gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a] -> m [Tensor Value a] (`Rendered v2` means that `v2` is either a `Ref` or a `Value`.) Additionally, the implementation of `gradients` now takes care to render every intermediate value when performing the reverse accumulation. I suspect this fixes an exponential blowup for complicated expressions.
This commit is contained in:
parent
d71f48090a
commit
d62c614695
29 changed files with 636 additions and 608 deletions
|
@ -58,7 +58,7 @@ fit xData yData = TF.runSession $ do
|
||||||
return (w', b')
|
return (w', b')
|
||||||
|
|
||||||
gradientDescent :: Float
|
gradientDescent :: Float
|
||||||
-> TF.Tensor TF.Value Float
|
-> TF.Tensor TF.Build Float
|
||||||
-> [TF.Tensor TF.Ref Float]
|
-> [TF.Tensor TF.Ref Float]
|
||||||
-> TF.Session TF.ControlNode
|
-> TF.Session TF.ControlNode
|
||||||
gradientDescent alpha loss params = do
|
gradientDescent alpha loss params = do
|
||||||
|
|
|
@ -67,10 +67,10 @@ import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
||||||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
||||||
import System.Directory (createDirectoryIfMissing)
|
import System.Directory (createDirectoryIfMissing)
|
||||||
import System.FilePath ((</>))
|
import System.FilePath ((</>))
|
||||||
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries)
|
import TensorFlow.Build (MonadBuild)
|
||||||
import TensorFlow.Ops (scalar)
|
import TensorFlow.Ops (scalar)
|
||||||
import TensorFlow.Records.Conduit (sinkTFRecords)
|
import TensorFlow.Records.Conduit (sinkTFRecords)
|
||||||
import TensorFlow.Tensor (Tensor)
|
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||||
import TensorFlow.Types (TensorType, type(/=))
|
import TensorFlow.Types (TensorType, type(/=))
|
||||||
import Text.Printf (printf)
|
import Text.Printf (printf)
|
||||||
import qualified Data.ByteString.Lazy as L
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
@ -141,19 +141,19 @@ doubleWallTime = asDouble <$> getCurrentTime
|
||||||
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
|
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
|
||||||
-- limited to a single value for simplicity.
|
-- limited to a single value for simplicity.
|
||||||
histogramSummary ::
|
histogramSummary ::
|
||||||
(TensorType t, t /= ByteString, t /= Bool)
|
(MonadBuild m, TensorType t, t /= ByteString, t /= Bool)
|
||||||
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||||
=> ByteString -> Tensor v t -> Build ()
|
=> ByteString -> Tensor v t -> m ()
|
||||||
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
|
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
|
||||||
|
|
||||||
-- | Adds a 'CoreOps.scalarSummary' node.
|
-- | Adds a 'CoreOps.scalarSummary' node.
|
||||||
scalarSummary ::
|
scalarSummary ::
|
||||||
(TensorType t, t /= ByteString, t /= Bool)
|
(TensorType t, t /= ByteString, t /= Bool, MonadBuild m)
|
||||||
-- (TensorType t,
|
-- (TensorType t,
|
||||||
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||||
=> ByteString -> Tensor v t -> Build ()
|
=> ByteString -> Tensor v t -> m ()
|
||||||
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
|
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
|
||||||
|
|
||||||
-- | Merge all summaries accumulated in the 'Build' into one summary.
|
-- | Merge all summaries accumulated in the 'Build' into one summary.
|
||||||
mergeAllSummaries :: Build SummaryTensor
|
mergeAllSummaries :: MonadBuild m => m SummaryTensor
|
||||||
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary
|
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary
|
||||||
|
|
|
@ -34,13 +34,13 @@ numPixels = 28*28 :: Int64
|
||||||
numLabels = 10 :: Int64
|
numLabels = 10 :: Int64
|
||||||
|
|
||||||
-- | Create tensor with random values where the stddev depends on the width.
|
-- | Create tensor with random values where the stddev depends on the width.
|
||||||
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float)
|
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Build Float)
|
||||||
randomParam width (TF.Shape shape) =
|
randomParam width (TF.Shape shape) =
|
||||||
(* stddev) <$> TF.truncatedNormal (TF.vector shape)
|
(`TF.mul` stddev) <$> TF.truncatedNormal (TF.vector shape)
|
||||||
where
|
where
|
||||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||||
|
|
||||||
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
|
reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float
|
||||||
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
||||||
|
|
||||||
-- Types must match due to model structure.
|
-- Types must match due to model structure.
|
||||||
|
@ -87,7 +87,7 @@ createModel = do
|
||||||
grads <- TF.gradients loss params
|
grads <- TF.gradients loss params
|
||||||
|
|
||||||
let lr = TF.scalar 0.00001
|
let lr = TF.scalar 0.00001
|
||||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
|
applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
|
||||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||||
|
|
||||||
let correctPredictions = TF.equal predict labels
|
let correctPredictions = TF.equal predict labels
|
||||||
|
|
|
@ -37,15 +37,15 @@ import TensorFlow.Examples.MNIST.TrainedGraph
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
( asGraphDef
|
( asGraphDef
|
||||||
, addGraphDef
|
, addGraphDef
|
||||||
, render
|
, Build
|
||||||
)
|
)
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
( Tensor(..)
|
( Tensor(..)
|
||||||
, Ref
|
, Ref
|
||||||
, Value
|
|
||||||
, feed
|
, feed
|
||||||
, TensorKind(..)
|
, render
|
||||||
, tensorFromName
|
, tensorFromName
|
||||||
|
, tensorValueFromName
|
||||||
)
|
)
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
|
@ -80,7 +80,7 @@ testReadMNIST = testCase "testReadMNIST" $ do
|
||||||
labelData <- readMNISTLabels =<< testLabelData
|
labelData <- readMNISTLabels =<< testLabelData
|
||||||
10000 @=? length labelData
|
10000 @=? length labelData
|
||||||
|
|
||||||
testNodeName :: Text -> Tensor v a -> Assertion
|
testNodeName :: Text -> Tensor Build a -> Assertion
|
||||||
testNodeName n g = n @=? opName
|
testNodeName n g = n @=? opName
|
||||||
where
|
where
|
||||||
opName = head (gDef^.node)^.op
|
opName = head (gDef^.node)^.op
|
||||||
|
@ -89,7 +89,7 @@ testNodeName n g = n @=? opName
|
||||||
testGraphDefGen :: Test
|
testGraphDefGen :: Test
|
||||||
testGraphDefGen = testCase "testGraphDefGen" $ do
|
testGraphDefGen = testCase "testGraphDefGen" $ do
|
||||||
-- Test the inferred operation type.
|
-- Test the inferred operation type.
|
||||||
let f0 :: Tensor Value Float
|
let f0 :: Tensor Build Float
|
||||||
f0 = 0
|
f0 = 0
|
||||||
testNodeName "Const" f0
|
testNodeName "Const" f0
|
||||||
testNodeName "Add" $ 1 + f0
|
testNodeName "Add" $ 1 + f0
|
||||||
|
@ -109,7 +109,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||||
runSession $ do
|
runSession $ do
|
||||||
addGraphDef graphDef
|
addGraphDef graphDef
|
||||||
x <- run $ tensorFromName ValueKind "Mul_2"
|
x <- run $ tensorValueFromName "Mul_2"
|
||||||
liftIO $ (50 :: Float) @=? unScalar x
|
liftIO $ (50 :: Float) @=? unScalar x
|
||||||
|
|
||||||
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
||||||
|
@ -142,8 +142,8 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
||||||
build $ addGraphDef $ mnist & version .~ 0
|
build $ addGraphDef $ mnist & version .~ 0
|
||||||
-- Define nodes that restore saved weights and biases.
|
-- Define nodes that restore saved weights and biases.
|
||||||
let bias, wts :: Tensor Ref Float
|
let bias, wts :: Tensor Ref Float
|
||||||
bias = tensorFromName RefKind "Variable"
|
bias = tensorFromName "Variable"
|
||||||
wts = tensorFromName RefKind "weights"
|
wts = tensorFromName "weights"
|
||||||
wtsCkptPath <- liftIO wtsCkpt
|
wtsCkptPath <- liftIO wtsCkpt
|
||||||
biasCkptPath <- liftIO biasCkpt
|
biasCkptPath <- liftIO biasCkpt
|
||||||
-- Run those restoring nodes on the graph in the current session.
|
-- Run those restoring nodes on the graph in the current session.
|
||||||
|
@ -155,12 +155,12 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
||||||
let ty = encodeTensorData [10] oneHotLabels
|
let ty = encodeTensorData [10] oneHotLabels
|
||||||
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
|
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
|
||||||
updates = [(fromIntegral label, 1)]
|
updates = [(fromIntegral label, 1)]
|
||||||
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
|
let feeds = [ feed (tensorValueFromName "x-input") tensorSample
|
||||||
, feed (tensorFromName ValueKind "y-input") ty
|
, feed (tensorValueFromName "y-input") ty
|
||||||
]
|
]
|
||||||
-- Run the graph with the input feeds and read the ArgMax'd result from
|
-- Run the graph with the input feeds and read the ArgMax'd result from
|
||||||
-- the test (not training) side of the evaluation.
|
-- the test (not training) side of the evaluation.
|
||||||
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
|
x <- runWithFeeds feeds $ tensorValueFromName "test/ArgMax"
|
||||||
-- Print the trained model's predicted outcome.
|
-- Print the trained model's predicted outcome.
|
||||||
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
|
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
|
||||||
++ "Prediction: " ++ show (unScalar x :: Int64)
|
++ "Prediction: " ++ show (unScalar x :: Int64)
|
||||||
|
|
|
@ -24,7 +24,6 @@ import Prelude hiding ( log
|
||||||
, exp
|
, exp
|
||||||
)
|
)
|
||||||
import TensorFlow.Build ( MonadBuild
|
import TensorFlow.Build ( MonadBuild
|
||||||
, render
|
|
||||||
, withNameScope
|
, withNameScope
|
||||||
)
|
)
|
||||||
import TensorFlow.GenOps.Core ( greaterEqual
|
import TensorFlow.GenOps.Core ( greaterEqual
|
||||||
|
@ -33,6 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
|
||||||
, exp
|
, exp
|
||||||
)
|
)
|
||||||
import TensorFlow.Tensor ( Tensor(..)
|
import TensorFlow.Tensor ( Tensor(..)
|
||||||
|
, render
|
||||||
, Value
|
, Value
|
||||||
)
|
)
|
||||||
import TensorFlow.Types ( TensorType(..)
|
import TensorFlow.Types ( TensorType(..)
|
||||||
|
@ -40,6 +40,8 @@ import TensorFlow.Types ( TensorType(..)
|
||||||
)
|
)
|
||||||
import TensorFlow.Ops ( zerosLike
|
import TensorFlow.Ops ( zerosLike
|
||||||
, add
|
, add
|
||||||
|
, mul
|
||||||
|
, neg
|
||||||
)
|
)
|
||||||
|
|
||||||
-- | Computes sigmoid cross entropy given `logits`.
|
-- | Computes sigmoid cross entropy given `logits`.
|
||||||
|
@ -76,13 +78,11 @@ sigmoidCrossEntropyWithLogits
|
||||||
-> Tensor Value a -- ^ __targets__
|
-> Tensor Value a -- ^ __targets__
|
||||||
-> m (Tensor Value a)
|
-> m (Tensor Value a)
|
||||||
sigmoidCrossEntropyWithLogits logits targets = do
|
sigmoidCrossEntropyWithLogits logits targets = do
|
||||||
logits' <- render logits
|
let zeros = zerosLike logits
|
||||||
targets' <- render targets
|
cond = logits `greaterEqual` zeros
|
||||||
let zeros = zerosLike logits'
|
relu_logits = select cond logits zeros
|
||||||
cond = logits' `greaterEqual` zeros
|
neg_abs_logits = select cond (neg logits) logits
|
||||||
relu_logits = select cond logits' zeros
|
|
||||||
neg_abs_logits = select cond (-logits') logits'
|
|
||||||
withNameScope "logistic_loss" $ do
|
withNameScope "logistic_loss" $ do
|
||||||
left <- render $ relu_logits - logits' * targets'
|
left <- render $ relu_logits - logits `mul` targets
|
||||||
right <- render $ log (1 + exp neg_abs_logits)
|
right <- render $ log (1 + exp neg_abs_logits)
|
||||||
withNameScope "sigmoid_add" $ render $ left `add` right
|
withNameScope "sigmoid_add" $ render $ left `add` right
|
||||||
|
|
|
@ -23,10 +23,9 @@ import Test.Framework (Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.Gradient as TF
|
import qualified TensorFlow.Gradient as TF
|
||||||
import qualified TensorFlow.Nodes as TF
|
|
||||||
import qualified TensorFlow.NN as TF
|
import qualified TensorFlow.NN as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Core as TF
|
||||||
|
|
||||||
-- | These tests are ported from:
|
-- | These tests are ported from:
|
||||||
--
|
--
|
||||||
|
@ -60,12 +59,11 @@ defInputs = Inputs {
|
||||||
testLogisticOutput :: Test
|
testLogisticOutput :: Test
|
||||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||||
let inputs = defInputs
|
let inputs = defInputs
|
||||||
vLogits = TF.vector $ logits inputs
|
r <- run $ do
|
||||||
vTargets = TF.vector $ targets inputs
|
vLogits <- TF.render $ TF.vector $ logits inputs
|
||||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
vTargets <- TF.render $ TF.vector $ targets inputs
|
||||||
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||||
|
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||||
r <- run tfLoss
|
|
||||||
assertAllClose r ourLoss
|
assertAllClose r ourLoss
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,23 +72,22 @@ testLogisticOutputMultipleDim =
|
||||||
testCase "testLogisticOutputMultipleDim" $ do
|
testCase "testLogisticOutputMultipleDim" $ do
|
||||||
let inputs = defInputs
|
let inputs = defInputs
|
||||||
shape = [2, 2, 2]
|
shape = [2, 2, 2]
|
||||||
vLogits = TF.constant shape (logits inputs)
|
r <- run $ do
|
||||||
vTargets = TF.constant shape (targets inputs)
|
vLogits <- TF.render $ TF.constant shape (logits inputs)
|
||||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
vTargets <- TF.render $ TF.constant shape (targets inputs)
|
||||||
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||||
|
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||||
r <- run tfLoss
|
|
||||||
assertAllClose r ourLoss
|
assertAllClose r ourLoss
|
||||||
|
|
||||||
|
|
||||||
testGradientAtZero :: Test
|
testGradientAtZero :: Test
|
||||||
testGradientAtZero = testCase "testGradientAtZero" $ do
|
testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||||
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
|
||||||
vLogits = TF.vector $ logits inputs
|
|
||||||
vTargets = TF.vector $ targets inputs
|
|
||||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
|
||||||
|
|
||||||
r <- run $ do
|
r <- run $ do
|
||||||
|
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
||||||
|
vTargets <- TF.render $ TF.vector $ targets inputs
|
||||||
|
vLogits <- TF.render $ TF.vector $ logits inputs
|
||||||
|
let tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||||
|
|
||||||
l <- tfLoss
|
l <- tfLoss
|
||||||
TF.gradients l [vLogits]
|
TF.gradients l [vLogits]
|
||||||
|
|
||||||
|
|
|
@ -231,21 +231,24 @@ renderHaskellAttrName :: Attr a -> Doc
|
||||||
renderHaskellAttrName = renderHaskellName . attrName
|
renderHaskellAttrName = renderHaskellName . attrName
|
||||||
|
|
||||||
functionBody :: ParsedOp -> Doc
|
functionBody :: ParsedOp -> Doc
|
||||||
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
|
functionBody pOp
|
||||||
</> indent indentation (sep tensorArgs)
|
| parsedOpIsMonadic pOp
|
||||||
|
= "build $ do"
|
||||||
|
</> indent indentation (bindOpInputsVar
|
||||||
|
</> "buildOp" <+> outputListsSizes <+> opDef)
|
||||||
|
| otherwise
|
||||||
|
= "pureOp" <+> outputListsSizes <+> "$ do"
|
||||||
|
</> indent indentation (bindOpInputsVar </> "return" <+> opDef)
|
||||||
where
|
where
|
||||||
maybeLift
|
outputListsSizes = brackets $ commasep
|
||||||
| parsedOpIsMonadic pOp = "build $"
|
[ renderHaskellName a
|
||||||
| otherwise = ""
|
|
||||||
buildFunction
|
|
||||||
| null outputListsSizes = "buildOp"
|
|
||||||
| otherwise = "buildListOp" <+>
|
|
||||||
brackets (commasep $
|
|
||||||
map renderHaskellName outputListsSizes)
|
|
||||||
outputListsSizes = [ a
|
|
||||||
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
||||||
<- parsedOutputs pOp]
|
<- parsedOutputs pOp
|
||||||
buildOpParts =
|
]
|
||||||
|
opInputsVar = "op'inputs"
|
||||||
|
bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence"
|
||||||
|
<+> brackets (commasep $ map (\a -> "buildInputs" <+> a) tensorArgs)
|
||||||
|
opDef = parens $ hang 0 $ stack $
|
||||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||||
-- Renders type parameter arguments.
|
-- Renders type parameter arguments.
|
||||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
|
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
|
||||||
|
@ -259,10 +262,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
|
||||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||||
] ++
|
] ++
|
||||||
["& op'options"]
|
["& op'options & opInputs .~" <+> opInputsVar]
|
||||||
|
tensorArgs = renderTensorArg <$> parsedInputs pOp
|
||||||
|
renderTensorArg = renderHaskellName . parsedArgName
|
||||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
|
||||||
inferredTypeExpr a
|
inferredTypeExpr a
|
||||||
| typeParamIsList $ attrInfo a
|
| typeParamIsList $ attrInfo a
|
||||||
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
||||||
|
@ -296,7 +298,7 @@ typeSig pre pOp = constraints
|
||||||
| null classConstraints = empty
|
| null classConstraints = empty
|
||||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
|
||||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||||
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||||
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||||
|
@ -336,12 +338,13 @@ tensorArg p = case parsedArgCase p of
|
||||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||||
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
||||||
-> "TensorList" <+> kind k <+> renderHaskellName t
|
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
|
||||||
where
|
where
|
||||||
kind k = case k of
|
kind k = case k of
|
||||||
ArgTensorRef -> "Ref"
|
ArgTensorRef -> "Ref"
|
||||||
ArgTensorValue -> "Value"
|
ArgTensorValue -> "Value"
|
||||||
ArgTensorEither v' -> strictText v'
|
ArgTensorBuild -> "Build"
|
||||||
|
ArgSomeTensor v -> strictText v
|
||||||
tensorType t k = let
|
tensorType t k = let
|
||||||
a = case t of
|
a = case t of
|
||||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||||
|
|
|
@ -141,7 +141,8 @@ data ArgType
|
||||||
data ArgKind
|
data ArgKind
|
||||||
= ArgTensorRef -- Tensor Ref a
|
= ArgTensorRef -- Tensor Ref a
|
||||||
| ArgTensorValue -- Tensor Value a
|
| ArgTensorValue -- Tensor Value a
|
||||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
| ArgTensorBuild -- Tensor Build a
|
||||||
|
| ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'.
|
||||||
deriving (Eq)
|
deriving (Eq)
|
||||||
|
|
||||||
isRefCase :: ParsedArgCase -> Bool
|
isRefCase :: ParsedArgCase -> Bool
|
||||||
|
@ -219,15 +220,17 @@ parseOp o = ParsedOp
|
||||||
{ parsedOpName = makeName $ o ^. name
|
{ parsedOpName = makeName $ o ^. name
|
||||||
, parsedOpSummary = o ^. summary
|
, parsedOpSummary = o ^. summary
|
||||||
, parsedOpDescription = o ^. description
|
, parsedOpDescription = o ^. description
|
||||||
, parsedOpIsMonadic = o ^. isStateful
|
|
||||||
|| any (isRefCase . parsedArgCase) parsedInputs
|
|
||||||
, ..
|
, ..
|
||||||
}
|
}
|
||||||
where
|
where
|
||||||
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
|
parsedOpIsMonadic = o ^. isStateful
|
||||||
(o ^. inputArg) tensorKindParams
|
|| any (isRefCase . parsedArgCase) parsedInputs
|
||||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
|| null (o ^. outputArg)
|
||||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a))
|
||||||
|
tensorKindParams (o ^. inputArg)
|
||||||
|
tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||||
|
parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a))
|
||||||
|
(o ^. outputArg)
|
||||||
-- Integer attributes that can be inferred from the size of at least one
|
-- Integer attributes that can be inferred from the size of at least one
|
||||||
-- input list.
|
-- input list.
|
||||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||||
|
@ -246,15 +249,16 @@ parseOp o = ParsedOp
|
||||||
$ o ^. attr
|
$ o ^. attr
|
||||||
|
|
||||||
-- TODO(judahjacobson): Some arguments should be refs.
|
-- TODO(judahjacobson): Some arguments should be refs.
|
||||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind
|
||||||
inputTensorKind a v
|
inputTensorKind v a
|
||||||
| a ^. isRef = ArgTensorRef
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorEither v
|
| otherwise = ArgSomeTensor v
|
||||||
|
|
||||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind
|
||||||
outputTensorKind a
|
outputTensorKind isMonadic a
|
||||||
| a ^. isRef = ArgTensorRef
|
| a ^. isRef = ArgTensorRef
|
||||||
| otherwise = ArgTensorValue
|
| isMonadic = ArgTensorValue
|
||||||
|
| otherwise = ArgTensorBuild
|
||||||
|
|
||||||
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||||
getExplicitInputAttr o implicitAttrs a
|
getExplicitInputAttr o implicitAttrs a
|
||||||
|
|
|
@ -24,9 +24,9 @@ module TensorFlow.EmbeddingOps where
|
||||||
|
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import TensorFlow.Build (MonadBuild, colocateWith, render)
|
import TensorFlow.Build (MonadBuild)
|
||||||
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||||
import TensorFlow.Tensor (Tensor, Value)
|
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
|
||||||
import TensorFlow.Types (OneOf, TensorType)
|
import TensorFlow.Types (OneOf, TensorType)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
|
@ -44,17 +44,18 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
--
|
--
|
||||||
-- The results of the lookup are concatenated into a dense
|
-- The results of the lookup are concatenated into a dense
|
||||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||||
embeddingLookup :: forall a b v m .
|
embeddingLookup :: forall a b v1 v2 m .
|
||||||
( MonadBuild m
|
( MonadBuild m
|
||||||
|
, Rendered v1
|
||||||
, TensorType a
|
, TensorType a
|
||||||
, OneOf '[Int64, Int32] b
|
, OneOf '[Int64, Int32] b
|
||||||
, Num b
|
, Num b
|
||||||
)
|
)
|
||||||
=> [Tensor v a]
|
=> [Tensor v1 a]
|
||||||
-- ^ A list of tensors which can be concatenated along
|
-- ^ A list of tensors which can be concatenated along
|
||||||
-- dimension 0. Each `Tensor` must be appropriately
|
-- dimension 0. Each `Tensor` must be appropriately
|
||||||
-- sized for `mod` partition strategy.
|
-- sized for `mod` partition strategy.
|
||||||
-> Tensor Value b
|
-> Tensor v2 b
|
||||||
-- ^ A `Tensor` with type `int32` or `int64`
|
-- ^ A `Tensor` with type `int32` or `int64`
|
||||||
-- containing the ids to be looked up in `params`.
|
-- containing the ids to be looked up in `params`.
|
||||||
-- The ids are required to have fewer than 2^31
|
-- The ids are required to have fewer than 2^31
|
||||||
|
|
|
@ -31,6 +31,7 @@ import Data.ByteString (ByteString)
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
|
import Data.Foldable (foldlM)
|
||||||
import Data.List (foldl', sortBy)
|
import Data.List (foldl', sortBy)
|
||||||
import Data.Map.Strict (Map)
|
import Data.Map.Strict (Map)
|
||||||
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||||
|
@ -39,7 +40,7 @@ import Data.ProtoLens.TextFormat (showMessage)
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
import Data.Tuple (swap)
|
import Data.Tuple (swap)
|
||||||
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
|
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
|
||||||
import Lens.Family2.State.Strict (uses)
|
import Lens.Family2.State.Strict (uses)
|
||||||
import Lens.Family2.Stock (at, intAt)
|
import Lens.Family2.Stock (at, intAt)
|
||||||
import Lens.Family2.Unchecked (lens, iso)
|
import Lens.Family2.Unchecked (lens, iso)
|
||||||
|
@ -59,11 +60,10 @@ import TensorFlow.Build
|
||||||
( MonadBuild
|
( MonadBuild
|
||||||
, Build
|
, Build
|
||||||
, build
|
, build
|
||||||
, render
|
|
||||||
, renderNodeName
|
|
||||||
, renderedNodeDefs
|
, renderedNodeDefs
|
||||||
, opDef
|
, opDef
|
||||||
, opAttr
|
, opAttr
|
||||||
|
, opInputs
|
||||||
)
|
)
|
||||||
import TensorFlow.BuildOp
|
import TensorFlow.BuildOp
|
||||||
import TensorFlow.Ops
|
import TensorFlow.Ops
|
||||||
|
@ -86,16 +86,19 @@ import TensorFlow.Ops
|
||||||
)
|
)
|
||||||
import TensorFlow.Output
|
import TensorFlow.Output
|
||||||
( NodeName(..)
|
( NodeName(..)
|
||||||
, Op (Rendered)
|
|
||||||
, Output(..)
|
, Output(..)
|
||||||
, OutputIx(..)
|
, OutputIx(..)
|
||||||
, outputIndex
|
, outputIndex
|
||||||
)
|
)
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
( Tensor(..)
|
( Tensor(..)
|
||||||
, TensorKind (ValueKind)
|
|
||||||
, Value
|
, Value
|
||||||
, tensorOutput
|
, render
|
||||||
|
, expr
|
||||||
|
, Rendered
|
||||||
|
, tensorNodeName
|
||||||
|
, renderedOutput
|
||||||
|
, renderValue
|
||||||
)
|
)
|
||||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
@ -114,10 +117,7 @@ type GradientCompatible a =
|
||||||
|
|
||||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||||
gradients :: forall a v1 v2 m . (MonadBuild m
|
gradients :: forall a v1 v2 m . (MonadBuild m
|
||||||
, Num (Tensor v1 a)
|
, Rendered v2
|
||||||
-- TODO(gnezdo): remove indirect constraint.
|
|
||||||
-- It's a wart inherited from Num instance.
|
|
||||||
, v1 ~ Value
|
|
||||||
, GradientCompatible a
|
, GradientCompatible a
|
||||||
)
|
)
|
||||||
=> Tensor v1 a -- ^ The output of the graph.
|
=> Tensor v1 a -- ^ The output of the graph.
|
||||||
|
@ -150,27 +150,31 @@ gradients y xs = build $ do
|
||||||
--
|
--
|
||||||
-- 4. Lookup the recorded gradient for each x in xs.
|
-- 4. Lookup the recorded gradient for each x in xs.
|
||||||
|
|
||||||
yName <- renderNodeName y
|
y' <- renderValue y
|
||||||
|
let yName = tensorNodeName y'
|
||||||
|
yOne <- render $ fill (shape y') (scalar 1)
|
||||||
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
|
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
|
||||||
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
|
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
|
||||||
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||||
. flip Map.lookup
|
. flip Map.lookup
|
||||||
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||||
-- Set gradient of y to one.
|
-- Set gradient of y to one.
|
||||||
|
-- TODO: nicer
|
||||||
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||||
initPending = Map.empty & at (nodeMap Map.! yName)
|
= Map.empty & (at (nodeMap Map.! yName)
|
||||||
. nonEmpty
|
. nonEmpty
|
||||||
. outputIxAt (y ^. tensorOutput . outputIndex)
|
. outputIxAt (outputIndex $ renderedOutput y')
|
||||||
. nonEmpty
|
. nonEmpty
|
||||||
.~ [fill (shape y) (scalar 1)]
|
.~ [yOne]
|
||||||
|
)
|
||||||
-- Calculate the gradients of y w.r.t. each node in the graph.
|
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||||
gradientMap <- graphGrads gr initPending
|
gradientMap <- graphGrads gr initPending
|
||||||
-- Lookup the gradients for each x.
|
-- Lookup the gradients for each x.
|
||||||
forM xs $ \x -> do
|
forM xs $ \x ->
|
||||||
xName <- renderNodeName x
|
let xName = tensorNodeName x
|
||||||
render $ fromMaybe (zerosLike x) $ do
|
in maybe (render $ zerosLike x) return $ do
|
||||||
n <- nodeMap ^. at xName
|
n <- nodeMap ^. at xName
|
||||||
let i = x ^. tensorOutput . outputIndex
|
let i = outputIndex $ renderedOutput x
|
||||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||||
|
|
||||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||||
|
@ -182,6 +186,7 @@ outputIxAt = intAt . unOutputIx
|
||||||
type PendingGradients a = IntMap.IntMap [Tensor Value a]
|
type PendingGradients a = IntMap.IntMap [Tensor Value a]
|
||||||
|
|
||||||
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
|
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
|
||||||
|
-- TODO: precache the rendering?
|
||||||
type Gradients a = IntMap.IntMap (Tensor Value a)
|
type Gradients a = IntMap.IntMap (Tensor Value a)
|
||||||
|
|
||||||
-- | Graph of TensorFlow operations.
|
-- | Graph of TensorFlow operations.
|
||||||
|
@ -229,42 +234,43 @@ non a = anon a (a==)
|
||||||
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
|
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
|
||||||
nonEmpty = anon mempty null
|
nonEmpty = anon mempty null
|
||||||
|
|
||||||
|
-- TODO: strictness (e.g., foldlM')
|
||||||
|
|
||||||
-- | Calculate the gradients for every node in a graph.
|
-- | Calculate the gradients for every node in a graph.
|
||||||
graphGrads :: forall a. GradientCompatible a
|
graphGrads :: forall a. GradientCompatible a
|
||||||
=> Graph
|
=> Graph
|
||||||
-> Map FGL.Node (PendingGradients a)
|
-> Map FGL.Node (PendingGradients a)
|
||||||
-- ^ Initial gradients (usually just 1 for the node of interest).
|
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||||
-> Build (Map FGL.Node (Gradients a))
|
-> Build (Map FGL.Node (Gradients a))
|
||||||
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
|
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||||
where
|
where
|
||||||
initState = GradientsState initPending Map.empty
|
initState = GradientsState initPending Map.empty
|
||||||
-- Reverse topological sort.
|
-- Reverse topological sort.
|
||||||
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
||||||
-- avoid calculating gradients that won't be used.
|
-- avoid calculating gradients that won't be used.
|
||||||
nodeOrder = FGL.topsort $ FGL.grev gr
|
nodeOrder = FGL.topsort $ FGL.grev gr
|
||||||
go state node =
|
go :: GradientsState a -> Int -> Build (GradientsState a)
|
||||||
|
go state node = do
|
||||||
-- Aggregate the accumulated gradients for this node.
|
-- Aggregate the accumulated gradients for this node.
|
||||||
let outputGrads =
|
outputGrads <-
|
||||||
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
|
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
|
||||||
in if null outputGrads
|
if null outputGrads
|
||||||
then state
|
then pure state
|
||||||
else
|
else do
|
||||||
|
let ctx = FGL.context gr node
|
||||||
|
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||||
-- Calculate the gradients for each of the node's inputs.
|
-- Calculate the gradients for each of the node's inputs.
|
||||||
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||||
ctx = FGL.context gr node
|
pure $ updatePendingGradients ctx inputGrads nextState
|
||||||
in updatePendingGradients
|
|
||||||
ctx
|
|
||||||
(calculateInputGrads ctx outputGrads gr)
|
|
||||||
nextState
|
|
||||||
|
|
||||||
-- | Reduce accumulated gradients for each output to one Tensor.
|
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||||
sumPendingGradient :: GradientCompatible a
|
sumPendingGradient :: GradientCompatible a
|
||||||
=> PendingGradients a -> Gradients a
|
=> PendingGradients a -> Build (Gradients a)
|
||||||
sumPendingGradient = IntMap.mapMaybe f
|
sumPendingGradient = sequence . IntMap.mapMaybe f
|
||||||
where
|
where
|
||||||
f [] = Nothing
|
f [] = Nothing
|
||||||
f [x] = Just x
|
f [x] = Just (pure x)
|
||||||
f xs = Just (addN xs)
|
f xs = Just (render $ addN xs)
|
||||||
|
|
||||||
|
|
||||||
-- | Calculate the gradients of a node's input tensors.
|
-- | Calculate the gradients of a node's input tensors.
|
||||||
|
@ -274,18 +280,18 @@ calculateInputGrads :: forall a. GradientCompatible a
|
||||||
=> FGL.Context NodeDef EdgeLabel
|
=> FGL.Context NodeDef EdgeLabel
|
||||||
-> Gradients a -- ^ Output gradients of the node.
|
-> Gradients a -- ^ Output gradients of the node.
|
||||||
-> Graph
|
-> Graph
|
||||||
-> [Maybe (Tensor Value a)]
|
-> Build [Maybe (Tensor Value a)]
|
||||||
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
|
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
|
||||||
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
|
fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
|
||||||
|
outputGrads
|
||||||
|
traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
|
||||||
where
|
where
|
||||||
fullOutGrads =
|
|
||||||
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
|
|
||||||
-- Create a tensor from an edge (technically an Output, but it seems less
|
-- Create a tensor from an edge (technically an Output, but it seems less
|
||||||
-- confusing to refer to it as a tensor here).
|
-- confusing to refer to it as a tensor here).
|
||||||
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
|
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
|
||||||
edgeToTensor ((i, _), n) =
|
edgeToTensor ((i, _), n) =
|
||||||
case FGL.lab gr n of
|
case FGL.lab gr n of
|
||||||
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
|
Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
|
||||||
Nothing -> error $ "calculateInputGrads: missing input node for "
|
Nothing -> error $ "calculateInputGrads: missing input node for "
|
||||||
++ Text.unpack (nodeDef ^. name)
|
++ Text.unpack (nodeDef ^. name)
|
||||||
-- Input tensors, sorted by input index.
|
-- Input tensors, sorted by input index.
|
||||||
|
@ -294,11 +300,11 @@ calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
|
||||||
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
|
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
|
||||||
fullOutputGrads :: (TensorType a, Num a)
|
fullOutputGrads :: (TensorType a, Num a)
|
||||||
=> OutputIx -- ^ Number of outputs.
|
=> OutputIx -- ^ Number of outputs.
|
||||||
-> Op
|
-> NodeName
|
||||||
-> Gradients a
|
-> Gradients a
|
||||||
-> [Tensor Value a]
|
-> Build [Tensor Value a]
|
||||||
fullOutputGrads n o gs =
|
fullOutputGrads n o gs =
|
||||||
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
|
mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
|
||||||
where
|
where
|
||||||
-- A tensor of zeros with the same shape as the i'th output.
|
-- A tensor of zeros with the same shape as the i'th output.
|
||||||
zero i = zerosLike $ toT (Output i o)
|
zero i = zerosLike $ toT (Output i o)
|
||||||
|
@ -397,19 +403,19 @@ type GradientFunc a = NodeDef
|
||||||
-- ^ Input tensors.
|
-- ^ Input tensors.
|
||||||
-> [Tensor Value a]
|
-> [Tensor Value a]
|
||||||
-- ^ Gradient of y w.r.t. each output tensor.
|
-- ^ Gradient of y w.r.t. each output tensor.
|
||||||
-> [Maybe (Tensor Value a)]
|
-> [Maybe (Tensor Build a)]
|
||||||
-- ^ Gradient of y w.r.t. each input tensor.
|
-- ^ Gradient of y w.r.t. each input tensor.
|
||||||
|
|
||||||
|
|
||||||
-- TODO(fmayle): Assert the type is correct.
|
-- TODO(fmayle): Assert the type is correct.
|
||||||
-- | Create a Tensor from an Output.
|
-- | Create a Tensor from an Output.
|
||||||
toT :: Output -> Tensor Value a
|
toT :: Output -> Tensor Build a
|
||||||
toT = Tensor ValueKind
|
toT = Tensor . pure
|
||||||
|
|
||||||
|
|
||||||
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
||||||
-- simple slicing operations.
|
-- simple slicing operations.
|
||||||
flatSlice :: forall v1 t . (TensorType t)
|
flatSlice :: forall v1 t . TensorType t
|
||||||
=> Tensor v1 t -- ^ __input__
|
=> Tensor v1 t -- ^ __input__
|
||||||
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||||
-- 'input' to slice from.
|
-- 'input' to slice from.
|
||||||
|
@ -417,9 +423,12 @@ flatSlice :: forall v1 t . (TensorType t)
|
||||||
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
||||||
-- are included in the slice (i.e. this is equivalent to setting
|
-- are included in the slice (i.e. this is equivalent to setting
|
||||||
-- size = input.dim_size(0) - begin).
|
-- size = input.dim_size(0) - begin).
|
||||||
-> Tensor Value t -- ^ __output__
|
-> Tensor Build t -- ^ __output__
|
||||||
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
||||||
|
|
||||||
|
nodeDefName :: NodeDef -> NodeName
|
||||||
|
nodeDefName = NodeName . view name
|
||||||
|
|
||||||
|
|
||||||
-- | The gradient function for an op type.
|
-- | The gradient function for an op type.
|
||||||
--
|
--
|
||||||
|
@ -427,8 +436,8 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
||||||
-- third_party/tensorflow/python/ops/*_grad.py
|
-- third_party/tensorflow/python/ops/*_grad.py
|
||||||
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
|
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
|
||||||
|
|
||||||
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
|
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
|
||||||
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
|
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
||||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||||
|
|
||||||
opGrad "Square" _ [toT -> x] [dz] =
|
opGrad "Square" _ [toT -> x] [dz] =
|
||||||
|
@ -436,7 +445,7 @@ opGrad "Square" _ [toT -> x] [dz] =
|
||||||
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
||||||
-- (for performance reasons?). Will need to put these functions in the Build
|
-- (for performance reasons?). Will need to put these functions in the Build
|
||||||
-- monad to replicate that.
|
-- monad to replicate that.
|
||||||
[Just $ dz * (2 * x)]
|
[Just $ dz `CoreOps.mul` (2 * x)]
|
||||||
|
|
||||||
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
||||||
-- TODO(fmayle): The python version uses a better performance implementation
|
-- TODO(fmayle): The python version uses a better performance implementation
|
||||||
|
@ -448,20 +457,20 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||||
denseShape = shape (x :: Tensor Value a)
|
denseShape = shape (x :: Tensor Build a)
|
||||||
numRows = scalarize $ flatSlice denseShape 0 1
|
numRows = scalarize $ flatSlice denseShape 0 1
|
||||||
valuesShape = CoreOps.concat 0 [ allDimensions
|
valuesShape = CoreOps.concat 0 [ allDimensions
|
||||||
, flatSlice denseShape 1 (-1)
|
, flatSlice denseShape 1 (-1)
|
||||||
]
|
]
|
||||||
values = reshape dz valuesShape
|
values = reshape dz valuesShape
|
||||||
-- TODO(fmayle): This could be either Int32 or Int64.
|
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||||
indices' = reshape indices allDimensions :: Tensor Value Int32
|
indices' = reshape indices allDimensions :: Tensor Build Int32
|
||||||
|
|
||||||
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
||||||
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
|
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
|
||||||
where
|
where
|
||||||
sx = shape (x :: Tensor Value a)
|
sx = shape (x :: Tensor Build a)
|
||||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
|
||||||
y = CoreOps.max x indices
|
y = CoreOps.max x indices
|
||||||
y' = reshape y outputShapeKeptDims
|
y' = reshape y outputShapeKeptDims
|
||||||
dz' = reshape dz outputShapeKeptDims
|
dz' = reshape dz outputShapeKeptDims
|
||||||
|
@ -475,8 +484,8 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
||||||
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
|
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
|
||||||
where
|
where
|
||||||
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
|
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
|
||||||
sx = shape (x :: Tensor Value a)
|
sx = shape (x :: Tensor Build a)
|
||||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
|
||||||
tileScaling = safeShapeDiv sx outputShapeKeptDims
|
tileScaling = safeShapeDiv sx outputShapeKeptDims
|
||||||
grad = reshape dz outputShapeKeptDims
|
grad = reshape dz outputShapeKeptDims
|
||||||
|
|
||||||
|
@ -484,8 +493,8 @@ opGrad "Mean" u v@[toT -> x, _] w =
|
||||||
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
||||||
where
|
where
|
||||||
[Just dz, Nothing] = opGrad "Sum" u v w
|
[Just dz, Nothing] = opGrad "Sum" u v w
|
||||||
inputShape = shape (x :: Tensor Value a)
|
inputShape = shape (x :: Tensor Build a)
|
||||||
outputShape = shape (dz :: Tensor Value a)
|
outputShape = shape (dz :: Tensor Build a)
|
||||||
-- TODO(fmayle): Add fast path when shape is known.
|
-- TODO(fmayle): Add fast path when shape is known.
|
||||||
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
|
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
|
||||||
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
|
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
|
||||||
|
@ -495,8 +504,8 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] =
|
||||||
[ Just $ reshape (sum dz rx) sx
|
[ Just $ reshape (sum dz rx) sx
|
||||||
, Just $ reshape (sum dz ry) sy ]
|
, Just $ reshape (sum dz ry) sy ]
|
||||||
where
|
where
|
||||||
sx = shape (x :: Tensor Value a)
|
sx = shape (x :: Tensor Build a)
|
||||||
sy = shape (y :: Tensor Value a)
|
sy = shape (y :: Tensor Build a)
|
||||||
(rx, ry) = broadcastGradientArgs sx sy
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
opGrad "Sub" u v w =
|
opGrad "Sub" u v w =
|
||||||
|
@ -510,22 +519,24 @@ opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
|
||||||
|
|
||||||
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
|
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
|
||||||
-- TODO(fmayle): Handle complex numbers.
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
[ Just $ reshape (sum (dz * y) rx) sx
|
[ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx
|
||||||
, Just $ reshape (sum (x * dz) ry) sy ]
|
, Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ]
|
||||||
where
|
where
|
||||||
sx = shape (x :: Tensor Value a)
|
sx = shape (x :: Tensor Build a)
|
||||||
sy = shape (y :: Tensor Value a)
|
sy = shape (y :: Tensor Build a)
|
||||||
(rx, ry) = broadcastGradientArgs sx sy
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
|
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
|
||||||
-- TODO(fmayle): Handle complex numbers.
|
-- TODO(fmayle): Handle complex numbers.
|
||||||
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
|
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
|
||||||
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
|
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
|
||||||
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
|
, Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y)))
|
||||||
|
ry)
|
||||||
|
sy
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
sx = shape (x :: Tensor Value a)
|
sx = shape (x :: Tensor Build a)
|
||||||
sy = shape (y :: Tensor Value a)
|
sy = shape (y :: Tensor Build a)
|
||||||
(rx, ry) = broadcastGradientArgs sx sy
|
(rx, ry) = broadcastGradientArgs sx sy
|
||||||
|
|
||||||
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
|
@ -549,7 +560,7 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
|
|
||||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
[ Just $ CoreOps.transpose dz
|
[ Just $ CoreOps.transpose dz
|
||||||
(CoreOps.invertPermutation p :: Tensor Value Int32)
|
(CoreOps.invertPermutation p :: Tensor Build Int32)
|
||||||
, Nothing
|
, Nothing
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -582,28 +593,28 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
x output dz
|
x output dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
output :: Tensor Value a
|
output :: Tensor Build a
|
||||||
output = toT $ Output 0 (Rendered nodeDef)
|
output = toT $ Output 0 (nodeDefName nodeDef)
|
||||||
ksize = lookupAttr nodeDef "ksize" :: [Int64]
|
ksize = lookupAttr nodeDef "ksize" :: [Int64]
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
padding = lookupAttr nodeDef "padding" :: ByteString
|
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||||
|
|
||||||
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
||||||
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
|
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
|
||||||
|
|
||||||
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||||
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||||
|
|
||||||
opGrad "RefIdentity" _ _ [dz] = [Just dz]
|
opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz]
|
||||||
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
|
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
|
||||||
where
|
where
|
||||||
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
|
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
|
||||||
reverseCast =
|
reverseCast =
|
||||||
buildOp (opDef "Cast"
|
pureOp [] $ pure (opDef "Cast"
|
||||||
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
|
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
|
||||||
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
|
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)
|
||||||
dz
|
& opInputs .~ [renderedOutput dz])
|
||||||
|
|
||||||
opGrad "DynamicStitch" nodeDef inputs [dz] =
|
opGrad "DynamicStitch" nodeDef inputs [dz] =
|
||||||
replicate halfLen Nothing ++ valuesGrads
|
replicate halfLen Nothing ++ valuesGrads
|
||||||
|
@ -614,7 +625,7 @@ opGrad "DynamicStitch" nodeDef inputs [dz] =
|
||||||
in if 2 * half == len
|
in if 2 * half == len
|
||||||
then half
|
then half
|
||||||
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
|
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
|
||||||
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
|
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32)
|
||||||
| idx <- take halfLen inputs
|
| idx <- take halfLen inputs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -622,14 +633,14 @@ opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
|
||||||
[ Just reconstructed, Nothing ]
|
[ Just reconstructed, Nothing ]
|
||||||
where
|
where
|
||||||
reconstructed = CoreOps.reshape stitched
|
reconstructed = CoreOps.reshape stitched
|
||||||
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
|
(CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32)
|
||||||
stitched = CoreOps.dynamicStitch partitionedIndices dz
|
stitched = CoreOps.dynamicStitch partitionedIndices dz
|
||||||
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
|
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
|
||||||
np = lookupAttr nodeDef "num_partitions" :: Int64
|
np = lookupAttr nodeDef "num_partitions" :: Int64
|
||||||
originalIndices =
|
originalIndices =
|
||||||
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
|
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
|
||||||
prefixShape = shapeInt32 indices
|
prefixShape = shapeInt32 indices
|
||||||
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
|
shapeInt32 t = CoreOps.shape t :: Tensor Build Int32
|
||||||
|
|
||||||
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
||||||
[ Nothing
|
[ Nothing
|
||||||
|
@ -639,18 +650,18 @@ opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
||||||
where zeros = CoreOps.zerosLike x
|
where zeros = CoreOps.zerosLike x
|
||||||
|
|
||||||
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
|
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
|
||||||
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
|
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ]
|
||||||
-- TODO(gnezdo): Reuse the output instead of doing another exp,
|
-- TODO(gnezdo): Reuse the output instead of doing another exp,
|
||||||
-- though, it is probably CSE'd away anyway.
|
-- though, it is probably CSE'd away anyway.
|
||||||
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
|
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ]
|
||||||
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
||||||
[ Just $ CoreOps.unsortedSegmentSum
|
[ Just $ CoreOps.unsortedSegmentSum
|
||||||
(CoreOps.gather dz (t :: Tensor Value Int32))
|
(CoreOps.gather dz (t :: Tensor Build Int32))
|
||||||
(y :: Tensor Value Int32) inputRows
|
(y :: Tensor Build Int32) inputRows
|
||||||
, Nothing
|
, Nothing
|
||||||
, Nothing
|
, Nothing
|
||||||
]
|
]
|
||||||
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
|
where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1
|
||||||
|
|
||||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||||
|
@ -710,13 +721,13 @@ numOutputs o =
|
||||||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||||
|
|
||||||
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||||
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
|
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
|
||||||
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
||||||
|
|
||||||
allDimensions :: Tensor Value Int32
|
allDimensions :: Tensor Build Int32
|
||||||
allDimensions = vector [-1 :: Int32]
|
allDimensions = vector [-1 :: Int32]
|
||||||
|
|
||||||
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
|
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
|
||||||
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
||||||
|
|
||||||
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
|
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
|
||||||
|
|
|
@ -166,7 +166,6 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.BuildOp
|
import TensorFlow.BuildOp
|
||||||
import TensorFlow.ControlFlow (group)
|
import TensorFlow.ControlFlow (group)
|
||||||
import TensorFlow.Output (unNodeName)
|
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
import TensorFlow.Types
|
import TensorFlow.Types
|
||||||
|
|
||||||
|
@ -183,7 +182,7 @@ import qualified Prelude (abs)
|
||||||
-- "1".
|
-- "1".
|
||||||
instance ( TensorType a
|
instance ( TensorType a
|
||||||
, Num a
|
, Num a
|
||||||
, v ~ Value
|
, v ~ Build
|
||||||
, OneOf '[ Double, Float, Int32, Int64
|
, OneOf '[ Double, Float, Int32, Int64
|
||||||
, Complex Float, Complex Double] a) => Num (Tensor v a) where
|
, Complex Float, Complex Double] a) => Num (Tensor v a) where
|
||||||
(+) = CoreOps.add
|
(+) = CoreOps.add
|
||||||
|
@ -194,10 +193,10 @@ instance ( TensorType a
|
||||||
signum = CoreOps.sign
|
signum = CoreOps.sign
|
||||||
negate = CoreOps.neg
|
negate = CoreOps.neg
|
||||||
|
|
||||||
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
|
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
|
||||||
matTranspose = matTranspose' id
|
matTranspose = matTranspose' id
|
||||||
|
|
||||||
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
|
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
|
||||||
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
||||||
|
|
||||||
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||||
|
@ -208,7 +207,7 @@ placeholder' :: forall m a . (MonadBuild m, TensorType a)
|
||||||
placeholder' params pShape
|
placeholder' params pShape
|
||||||
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
|
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
|
||||||
-- and thus would be CSE'd.
|
-- and thus would be CSE'd.
|
||||||
= build $ buildOp $ opDef "Placeholder"
|
= build $ buildOp [] $ opDef "Placeholder"
|
||||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
& opAttr "shape" .~ pShape
|
& opAttr "shape" .~ pShape
|
||||||
& params
|
& params
|
||||||
|
@ -216,11 +215,11 @@ placeholder' params pShape
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
initializedVariable :: (MonadBuild m, TensorType a)
|
initializedVariable :: (MonadBuild m, TensorType a)
|
||||||
=> Tensor Value a -> m (Tensor Ref a)
|
=> Tensor v a -> m (Tensor Ref a)
|
||||||
initializedVariable = initializedVariable' id
|
initializedVariable = initializedVariable' id
|
||||||
|
|
||||||
initializedVariable' :: (MonadBuild m, TensorType a)
|
initializedVariable' :: (MonadBuild m, TensorType a)
|
||||||
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
|
=> OpParams -> Tensor v a -> m (Tensor Ref a)
|
||||||
initializedVariable' params initializer = do
|
initializedVariable' params initializer = do
|
||||||
v <- CoreOps.variable' params [] -- The shape is not known initially.
|
v <- CoreOps.variable' params [] -- The shape is not known initially.
|
||||||
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
|
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
|
||||||
|
@ -240,17 +239,20 @@ zeroInitializedVariable'
|
||||||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||||
|
|
||||||
-- TODO: Support heterogeneous list of tensors.
|
-- TODO: Support heterogeneous list of tensors.
|
||||||
save :: forall a m v . (MonadBuild m, TensorType a)
|
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> [Tensor v a] -- ^ Tensors to save.
|
-> [Tensor v a] -- ^ Tensors to save.
|
||||||
-> m ControlNode
|
-> m ControlNode
|
||||||
save path xs = do
|
save path xs = build $ do
|
||||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
let toByteStringTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput
|
||||||
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
|
let names = fmap toByteStringTensor xs
|
||||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||||
let saveOp = buildOp $ opDef "Save"
|
names' <- buildInputs $ CoreOps.pack names
|
||||||
|
xs' <- buildInputs xs
|
||||||
|
path' <- buildInputs $ scalar path
|
||||||
|
buildOp [] $ opDef "Save"
|
||||||
& opAttr "T" .~ types
|
& opAttr "T" .~ types
|
||||||
build $ saveOp (scalar path) (CoreOps.pack names) xs
|
& opInputs .~ (path' ++ names' ++ xs')
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
--
|
--
|
||||||
|
@ -261,20 +263,22 @@ restoreFromName :: forall a m . (MonadBuild m, TensorType a)
|
||||||
-> ByteString -- ^ Tensor name override.
|
-> ByteString -- ^ Tensor name override.
|
||||||
-> Tensor Ref a -- ^ Tensor to restore.
|
-> Tensor Ref a -- ^ Tensor to restore.
|
||||||
-> m ControlNode
|
-> m ControlNode
|
||||||
restoreFromName path name x = do
|
restoreFromName path name x = build $ do
|
||||||
let restoreOp = buildOp $ opDef "Restore"
|
path' <- buildInputs $ scalar path
|
||||||
|
name' <- buildInputs $ scalar name
|
||||||
|
restoreOp <- buildOp [] $ opDef "Restore"
|
||||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||||
group =<< CoreOps.assign x
|
& opInputs .~ (path' ++ name')
|
||||||
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
group =<< CoreOps.assign x (restoreOp :: Tensor Value a)
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
restore :: forall a m . (MonadBuild m, TensorType a)
|
restore :: forall a m . (MonadBuild m, TensorType a)
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> Tensor Ref a -- ^ Tensor to restore.
|
-> Tensor Ref a -- ^ Tensor to restore.
|
||||||
-> m ControlNode
|
-> m ControlNode
|
||||||
restore path x = do
|
restore path x = restoreFromName path name x
|
||||||
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
|
where
|
||||||
restoreFromName path name x
|
name = encodeUtf8 $ encodeOutput $ renderedOutput x
|
||||||
|
|
||||||
-- | Create a constant tensor.
|
-- | Create a constant tensor.
|
||||||
--
|
--
|
||||||
|
@ -283,10 +287,10 @@ restore path x = do
|
||||||
-- element 0: index (0, ..., 0)
|
-- element 0: index (0, ..., 0)
|
||||||
-- element 1: index (0, ..., 1)
|
-- element 1: index (0, ..., 1)
|
||||||
-- ...
|
-- ...
|
||||||
constant :: TensorType a => Shape -> [a] -> Tensor Value a
|
constant :: TensorType a => Shape -> [a] -> Tensor Build a
|
||||||
constant = constant' id
|
constant = constant' id
|
||||||
|
|
||||||
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
|
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
|
||||||
constant' params (Shape cShape) values
|
constant' params (Shape cShape) values
|
||||||
| invalidLength = error invalidLengthMsg
|
| invalidLength = error invalidLengthMsg
|
||||||
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
|
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
|
||||||
|
@ -305,24 +309,24 @@ constant' params (Shape cShape) values
|
||||||
-- | Reshape a N-D tensor down to a scalar.
|
-- | Reshape a N-D tensor down to a scalar.
|
||||||
--
|
--
|
||||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
scalarize :: TensorType a => Tensor v a -> Tensor Build a
|
||||||
scalarize t = CoreOps.reshape t (vector scalarShape)
|
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||||
where
|
where
|
||||||
scalarShape = [] :: [Int32]
|
scalarShape = [] :: [Int32]
|
||||||
|
|
||||||
|
|
||||||
-- | Create a constant vector.
|
-- | Create a constant vector.
|
||||||
vector :: TensorType a => [a] -> Tensor Value a
|
vector :: TensorType a => [a] -> Tensor Build a
|
||||||
vector = vector' id
|
vector = vector' id
|
||||||
|
|
||||||
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
|
vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
|
||||||
vector' params xs = constant' params [fromIntegral $ length xs] xs
|
vector' params xs = constant' params [fromIntegral $ length xs] xs
|
||||||
|
|
||||||
-- | Create a constant scalar.
|
-- | Create a constant scalar.
|
||||||
scalar :: TensorType a => a -> Tensor Value a
|
scalar :: TensorType a => a -> Tensor Build a
|
||||||
scalar = scalar' id
|
scalar = scalar' id
|
||||||
|
|
||||||
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
|
scalar' :: TensorType a => OpParams -> a -> Tensor Build a
|
||||||
scalar' params x = constant' params [] [x]
|
scalar' params x = constant' params [] [x]
|
||||||
|
|
||||||
-- | Random tensor from the unit normal distribution with bounded values.
|
-- | Random tensor from the unit normal distribution with bounded values.
|
||||||
|
@ -338,28 +342,28 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||||
-> m (Tensor Value a)
|
-> m (Tensor Value a)
|
||||||
truncatedNormal' = CoreOps.truncatedNormal'
|
truncatedNormal' = CoreOps.truncatedNormal'
|
||||||
|
|
||||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
|
||||||
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
||||||
|
|
||||||
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
|
shape :: TensorType t => Tensor v t -> Tensor Build Int32
|
||||||
shape = CoreOps.shape
|
shape = CoreOps.shape
|
||||||
|
|
||||||
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
|
shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
|
||||||
shape' = CoreOps.shape'
|
shape' = CoreOps.shape'
|
||||||
|
|
||||||
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
|
||||||
expandDims = CoreOps.expandDims
|
expandDims = CoreOps.expandDims
|
||||||
|
|
||||||
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
|
||||||
expandDims' = CoreOps.expandDims'
|
expandDims' = CoreOps.expandDims'
|
||||||
|
|
||||||
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||||
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
|
||||||
reducedShape inputShape axes =
|
reducedShape inputShape axes =
|
||||||
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
|
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
|
||||||
axes32 = toInt32 axes -- [1, 2]
|
axes32 = toInt32 axes -- [1, 2]
|
||||||
toInt32 x = CoreOps.cast x :: Tensor Value Int32
|
toInt32 x = CoreOps.cast x :: Tensor Build Int32
|
||||||
inputRank = CoreOps.size inputShape32 -- 4
|
inputRank = CoreOps.size inputShape32 -- 4
|
||||||
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
|
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
|
||||||
axesShape = shape axesMod -- [2]
|
axesShape = shape axesMod -- [2]
|
||||||
|
|
|
@ -79,6 +79,7 @@ Test-Suite EmbeddingOpsTest
|
||||||
, test-framework
|
, test-framework
|
||||||
, test-framework-hunit
|
, test-framework-hunit
|
||||||
, test-framework-quickcheck2
|
, test-framework-quickcheck2
|
||||||
|
, transformers
|
||||||
, vector
|
, vector
|
||||||
|
|
||||||
Test-Suite ArrayOpsTest
|
Test-Suite ArrayOpsTest
|
||||||
|
|
|
@ -24,9 +24,7 @@ import Test.HUnit ((@=?))
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
|
||||||
import qualified TensorFlow.Types as TF
|
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
-- | Test split and concat are inverses.
|
-- | Test split and concat are inverses.
|
||||||
|
@ -44,7 +42,7 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
|
||||||
testShapeN :: Test
|
testShapeN :: Test
|
||||||
testShapeN = testCase "testShapeN" $ TF.runSession $ do
|
testShapeN = testCase "testShapeN" $ TF.runSession $ do
|
||||||
let shapes = map TF.Shape [[1],[2,3]]
|
let shapes = map TF.Shape [[1],[2,3]]
|
||||||
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
|
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Build Float]
|
||||||
result <- TF.run $ CoreOps.shapeN tensors
|
result <- TF.run $ CoreOps.shapeN tensors
|
||||||
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
|
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
|
||||||
|
|
||||||
|
|
|
@ -34,9 +34,7 @@ import TensorFlow.Build
|
||||||
, asGraphDef
|
, asGraphDef
|
||||||
, evalBuildT
|
, evalBuildT
|
||||||
, flushNodeBuffer
|
, flushNodeBuffer
|
||||||
, render
|
|
||||||
, withDevice
|
, withDevice
|
||||||
, colocateWith
|
|
||||||
, withNameScope
|
, withNameScope
|
||||||
, opName
|
, opName
|
||||||
)
|
)
|
||||||
|
@ -50,7 +48,13 @@ import TensorFlow.Ops
|
||||||
, variable'
|
, variable'
|
||||||
)
|
)
|
||||||
import TensorFlow.Output (Device(..))
|
import TensorFlow.Output (Device(..))
|
||||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
import TensorFlow.Tensor
|
||||||
|
( colocateWith
|
||||||
|
, render
|
||||||
|
, Tensor
|
||||||
|
, Value
|
||||||
|
, Ref
|
||||||
|
)
|
||||||
import TensorFlow.Session
|
import TensorFlow.Session
|
||||||
( run
|
( run
|
||||||
, runSession
|
, runSession
|
||||||
|
@ -65,8 +69,7 @@ import qualified Data.Vector as V
|
||||||
-- | Test 'opName' behavior.
|
-- | Test 'opName' behavior.
|
||||||
testOpName :: Test
|
testOpName :: Test
|
||||||
testOpName = testCase "testOpName" $ do
|
testOpName = testCase "testOpName" $ do
|
||||||
let graph = variable' (opName .~ "foo") []
|
let graph = variable' (opName .~ "foo") [] :: Build (Tensor Ref Float)
|
||||||
>>= render :: Build (Tensor Ref Float)
|
|
||||||
nodeDef :: NodeDef
|
nodeDef :: NodeDef
|
||||||
nodeDef = head $ asGraphDef graph ^. node
|
nodeDef = head $ asGraphDef graph ^. node
|
||||||
"Variable" @=? (nodeDef ^. op)
|
"Variable" @=? (nodeDef ^. op)
|
||||||
|
@ -114,7 +117,6 @@ testNamedAndScoped :: Test
|
||||||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||||
let graph :: Build (Tensor Ref Float)
|
let graph :: Build (Tensor Ref Float)
|
||||||
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
|
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
|
||||||
>>= render
|
|
||||||
nodeDef :: NodeDef
|
nodeDef :: NodeDef
|
||||||
nodeDef = head $ asGraphDef graph ^. node
|
nodeDef = head $ asGraphDef graph ^. node
|
||||||
"Variable" @=? (nodeDef ^. op)
|
"Variable" @=? (nodeDef ^. op)
|
||||||
|
|
|
@ -26,16 +26,14 @@ import Test.QuickCheck.Monadic (monadicIO, run)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Core as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
|
||||||
import qualified TensorFlow.Types as TF
|
|
||||||
|
|
||||||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||||
-- back.
|
-- back.
|
||||||
testDynamicPartitionStitchInverse :: forall a.
|
testDynamicPartitionStitchInverse :: forall a.
|
||||||
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
|
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
|
||||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||||
let splitParts :: [TF.Tensor TF.Value a] =
|
let splitParts :: [TF.Tensor TF.Build a] =
|
||||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||||
partTensor = TF.vector partitions
|
partTensor = TF.vector partitions
|
||||||
restitchIndices = CoreOps.dynamicPartition numParts
|
restitchIndices = CoreOps.dynamicPartition numParts
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
-- | Tests for EmbeddingOps.
|
-- | Tests for EmbeddingOps.
|
||||||
module Main where
|
module Main where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (genericLength)
|
import Data.List (genericLength)
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
|
@ -48,16 +49,15 @@ testEmbeddingLookupHasRightShapeWithPartition =
|
||||||
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||||
let embedding1 = [1, 1, 1 :: Int32]
|
let embedding1 = [1, 1, 1 :: Int32]
|
||||||
let embedding2 = [0, 0, 0 :: Int32]
|
let embedding2 = [0, 0, 0 :: Int32]
|
||||||
let embedding = [ TF.constant embShape embedding1
|
|
||||||
, TF.constant embShape embedding2
|
|
||||||
]
|
|
||||||
|
|
||||||
let idValues = [0, 1 :: Int32]
|
let idValues = [0, 1 :: Int32]
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
|
||||||
let op = embeddingLookup embedding ids
|
|
||||||
|
|
||||||
(values, shape) <- TF.runSession $ do
|
(values, shape) <- TF.runSession $ do
|
||||||
vs <- op
|
embedding <- mapM TF.render [ TF.constant embShape embedding1
|
||||||
|
, TF.constant embShape embedding2
|
||||||
|
]
|
||||||
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
vs <- embeddingLookup embedding ids
|
||||||
TF.run (vs, TF.shape vs)
|
TF.run (vs, TF.shape vs)
|
||||||
|
|
||||||
-- This is the shape that is returned in the equiv. Python.
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
|
@ -77,13 +77,12 @@ testEmbeddingLookupHasRightShape =
|
||||||
, 0, 0, 0 :: Int32
|
, 0, 0, 0 :: Int32
|
||||||
]
|
]
|
||||||
|
|
||||||
let embedding = TF.constant embShape embeddingInit
|
|
||||||
let idValues = [0, 1 :: Int32]
|
let idValues = [0, 1 :: Int32]
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
|
||||||
let op = embeddingLookup [embedding] ids
|
|
||||||
|
|
||||||
(values, shape) <- TF.runSession $ do
|
(values, shape) <- TF.runSession $ do
|
||||||
vs <- op
|
embedding <- TF.render $ TF.constant embShape embeddingInit
|
||||||
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
vs <- embeddingLookup [embedding] ids
|
||||||
TF.run (vs, TF.shape vs)
|
TF.run (vs, TF.shape vs)
|
||||||
|
|
||||||
-- This is the shape that is returned in the equiv. Python.
|
-- This is the shape that is returned in the equiv. Python.
|
||||||
|
@ -92,7 +91,6 @@ testEmbeddingLookupHasRightShape =
|
||||||
-- "[0, 1]" should pull out the resulting vector.
|
-- "[0, 1]" should pull out the resulting vector.
|
||||||
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
||||||
|
|
||||||
|
|
||||||
-- | Check that we can calculate gradients w.r.t embeddings.
|
-- | Check that we can calculate gradients w.r.t embeddings.
|
||||||
testEmbeddingLookupGradients :: Test
|
testEmbeddingLookupGradients :: Test
|
||||||
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
|
@ -108,10 +106,10 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
|
|
||||||
x <- TF.placeholder (TF.Shape [2])
|
x <- TF.placeholder (TF.Shape [2])
|
||||||
embedding <- TF.initializedVariable
|
embedding <- TF.initializedVariable
|
||||||
=<< TF.render (TF.constant embShape embeddingInit)
|
(TF.constant embShape embeddingInit)
|
||||||
|
|
||||||
op <- embeddingLookup [embedding] ids
|
op <- embeddingLookup [embedding] ids
|
||||||
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
let twoNorm = CoreOps.square $ TF.abs (op `TF.sub` x)
|
||||||
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||||
|
|
||||||
grad <- fmap head (TF.gradients loss [embedding])
|
grad <- fmap head (TF.gradients loss [embedding])
|
||||||
|
@ -131,23 +129,21 @@ testEmbeddingLookupUndoesSplit
|
||||||
(LookupExample numParts
|
(LookupExample numParts
|
||||||
shape@(TF.Shape (firstDim : restDims))
|
shape@(TF.Shape (firstDim : restDims))
|
||||||
values
|
values
|
||||||
indices) =
|
indices) = monadicIO $ run $ TF.runSession $ do
|
||||||
let modShardedValues :: [TF.Tensor TF.Value a] =
|
let shapedValues = TF.constant shape values
|
||||||
CoreOps.dynamicPartition numParts shapedValues cyclicCounter
|
indicesVector <- TF.render $ TF.vector indices
|
||||||
cyclicCounter :: TF.Tensor TF.Value Int32 =
|
let directs = CoreOps.gather shapedValues indicesVector
|
||||||
|
let cyclicCounter :: TF.Tensor TF.Build Int32 =
|
||||||
TF.vector [0..fromIntegral firstDim-1]
|
TF.vector [0..fromIntegral firstDim-1]
|
||||||
`CoreOps.mod` fromIntegral numParts
|
`CoreOps.mod` fromIntegral numParts
|
||||||
indicesVector = TF.vector indices
|
modShardedValues :: [TF.Tensor TF.Value a] <-
|
||||||
directs = CoreOps.gather shapedValues indicesVector
|
mapM TF.render $ CoreOps.dynamicPartition numParts shapedValues cyclicCounter
|
||||||
shapedValues = TF.constant shape values
|
|
||||||
in monadicIO $ run $ do
|
|
||||||
(shapeOut, got, want :: V.Vector a) <-
|
|
||||||
TF.runSession $ TF.run =<< do
|
|
||||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||||
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
(shapeOut, got, want :: V.Vector a) <-
|
||||||
|
TF.run (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||||
shapeOut @=? V.fromList (genericLength indices : restDims)
|
liftIO $ shapeOut @=? V.fromList (genericLength indices : restDims)
|
||||||
got @=? want
|
liftIO $ got @=? want
|
||||||
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
|
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
|
||||||
|
|
||||||
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.
|
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.
|
||||||
|
|
|
@ -36,10 +36,11 @@ import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
||||||
|
|
||||||
testGradientSimple :: Test
|
testGradientSimple :: Test
|
||||||
testGradientSimple = testCase "testGradientSimple" $ do
|
testGradientSimple = testCase "testGradientSimple" $ do
|
||||||
let x = TF.scalar (3 :: Float)
|
let grads = do
|
||||||
b = TF.scalar (4 :: Float)
|
x <- TF.render $ TF.scalar (3 :: Float)
|
||||||
y = x*x + b
|
b <- TF.render $ TF.scalar (4 :: Float)
|
||||||
grads = TF.gradients y [x, b]
|
let y = x `TF.mul` x `TF.add` b
|
||||||
|
TF.gradients y [x, b]
|
||||||
-- Assert that the gradients are right.
|
-- Assert that the gradients are right.
|
||||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||||
6 @=? TF.unScalar dx
|
6 @=? TF.unScalar dx
|
||||||
|
@ -88,9 +89,10 @@ testGradientSimple = testCase "testGradientSimple" $ do
|
||||||
|
|
||||||
testGradientDisconnected :: Test
|
testGradientDisconnected :: Test
|
||||||
testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||||
let x = TF.scalar (3 :: Float)
|
let grads = do
|
||||||
b = TF.scalar (4 :: Float)
|
x <- TF.render $ TF.scalar (3 :: Float)
|
||||||
grads = TF.gradients x [x, b]
|
b <- TF.render $ TF.scalar (4 :: Float)
|
||||||
|
TF.gradients x [x, b]
|
||||||
-- Assert that the gradients are right.
|
-- Assert that the gradients are right.
|
||||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||||
1 @=? TF.unScalar dx
|
1 @=? TF.unScalar dx
|
||||||
|
@ -118,7 +120,7 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||||
let shape = TF.constant (TF.Shape [1]) [1]
|
let shape = TF.constant (TF.Shape [1]) [1]
|
||||||
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||||
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||||
TF.gradients (x + y*3) [x, y] >>= TF.run
|
TF.gradients (TF.expr x + TF.expr y * 3) [x, y] >>= TF.run
|
||||||
-- If this test fails, it will likely be caused by an exception within
|
-- If this test fails, it will likely be caused by an exception within
|
||||||
-- `TF.gradients`. These asserts are extra.
|
-- `TF.gradients`. These asserts are extra.
|
||||||
1 @=? TF.unScalar dx
|
1 @=? TF.unScalar dx
|
||||||
|
@ -142,8 +144,8 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||||
testDiamond :: Test
|
testDiamond :: Test
|
||||||
testDiamond = testCase "testDiamond" $ do
|
testDiamond = testCase "testDiamond" $ do
|
||||||
[dx] <- TF.runSession $ do
|
[dx] <- TF.runSession $ do
|
||||||
let x = TF.vector [1]
|
x <- TF.render $ TF.vector [1]
|
||||||
y = x*x
|
let y = x `TF.mul` x
|
||||||
z = y*y
|
z = y*y
|
||||||
TF.gradients z [x] >>= TF.run
|
TF.gradients z [x] >>= TF.run
|
||||||
(4 :: Float) @=? TF.unScalar dx
|
(4 :: Float) @=? TF.unScalar dx
|
||||||
|
@ -152,8 +154,8 @@ testDiamond = testCase "testDiamond" $ do
|
||||||
testMaxGradient :: Test
|
testMaxGradient :: Test
|
||||||
testMaxGradient = testCase "testMaxGradient" $ do
|
testMaxGradient = testCase "testMaxGradient" $ do
|
||||||
[dx] <- TF.runSession $ do
|
[dx] <- TF.runSession $ do
|
||||||
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
x <- TF.render $ TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||||
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
let y = TF.max x (0 :: TF.Tensor TF.Build Int32)
|
||||||
TF.gradients y [x] >>= TF.run
|
TF.gradients y [x] >>= TF.run
|
||||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||||
|
|
||||||
|
|
|
@ -56,8 +56,7 @@ testSaveRestore = testCase "testSaveRestore" $
|
||||||
withSystemTempDirectory "" $ \dirPath -> do
|
withSystemTempDirectory "" $ \dirPath -> do
|
||||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||||
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||||
var = TF.render =<<
|
var = TF.zeroInitializedVariable' (TF.opName .~ "a")
|
||||||
TF.zeroInitializedVariable' (TF.opName .~ "a")
|
|
||||||
(TF.Shape [])
|
(TF.Shape [])
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
v <- var
|
v <- var
|
||||||
|
@ -76,7 +75,8 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
||||||
p2 <- TF.placeholder []
|
p2 <- TF.placeholder []
|
||||||
let enc :: Float -> TF.TensorData Float
|
let enc :: Float -> TF.TensorData Float
|
||||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
|
||||||
|
$ p1 `TF.add` p2
|
||||||
liftIO $ result @=? TF.Scalar 5
|
liftIO $ result @=? TF.Scalar 5
|
||||||
|
|
||||||
-- | Test that regular tensors can also be used for feeds, as long as they each
|
-- | Test that regular tensors can also be used for feeds, as long as they each
|
||||||
|
@ -90,7 +90,8 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
|
||||||
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
|
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
|
||||||
let enc :: Float -> TF.TensorData Float
|
let enc :: Float -> TF.TensorData Float
|
||||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
|
||||||
|
$ p1 `TF.add` p2
|
||||||
liftIO $ result @=? TF.Scalar 5
|
liftIO $ result @=? TF.Scalar 5
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
|
|
|
@ -38,7 +38,7 @@ fit xData yData = TF.runSession $ do
|
||||||
return (w', b')
|
return (w', b')
|
||||||
|
|
||||||
gradientDescent :: Float
|
gradientDescent :: Float
|
||||||
-> TF.Tensor TF.Value Float
|
-> TF.Tensor TF.Build Float
|
||||||
-> [TF.Tensor TF.Ref Float]
|
-> [TF.Tensor TF.Ref Float]
|
||||||
-> TF.Session TF.ControlNode
|
-> TF.Session TF.ControlNode
|
||||||
gradientDescent alpha loss params = do
|
gradientDescent alpha loss params = do
|
||||||
|
|
|
@ -53,9 +53,9 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
||||||
let floatData = V.fromList [1..6 :: Float]
|
let floatData = V.fromList [1..6 :: Float]
|
||||||
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
||||||
boolData = V.fromList [True, True, False, True, False, False]
|
boolData = V.fromList [True, True, False, True, False, False]
|
||||||
f <- TF.build $ TF.placeholder [2,3]
|
f <- TF.placeholder [2,3]
|
||||||
s <- TF.build $ TF.placeholder [2,3]
|
s <- TF.placeholder [2,3]
|
||||||
b <- TF.build $ TF.placeholder [2,3]
|
b <- TF.placeholder [2,3]
|
||||||
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||||
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
||||||
, TF.feed b (TF.encodeTensorData [2,3] boolData)
|
, TF.feed b (TF.encodeTensorData [2,3] boolData)
|
||||||
|
@ -63,7 +63,8 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
||||||
-- Do something idempotent to the tensors to verify that tensorflow can
|
-- Do something idempotent to the tensors to verify that tensorflow can
|
||||||
-- handle the encoding. Originally this used `TF.identity`, but that
|
-- handle the encoding. Originally this used `TF.identity`, but that
|
||||||
-- wasn't enough to catch a bug in the encoding of Bool.
|
-- wasn't enough to catch a bug in the encoding of Bool.
|
||||||
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
|
(f', s', b') <- TF.runWithFeeds feeds
|
||||||
|
(f `TF.add` 0, TF.identity s, TF.select b b b)
|
||||||
liftIO $ do
|
liftIO $ do
|
||||||
floatData @=? f'
|
floatData @=? f'
|
||||||
stringData @=? s'
|
stringData @=? s'
|
||||||
|
|
|
@ -60,7 +60,7 @@ makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||||
-- under the given name across multiple sessions.
|
-- under the given name across multiple sessions.
|
||||||
-> m (Queue as)
|
-> m (Queue as)
|
||||||
makeQueue capacity sharedName = do
|
makeQueue capacity sharedName = do
|
||||||
q <- build $ buildOp (opDef "FIFOQueue"
|
q <- build $ buildOp [] (opDef "FIFOQueue"
|
||||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||||
& opAttr "shared_name" .~ sharedName
|
& opAttr "shared_name" .~ sharedName
|
||||||
& opAttr "capacity" .~ capacity
|
& opAttr "capacity" .~ capacity
|
||||||
|
|
|
@ -13,9 +13,13 @@
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE Rank2Types #-}
|
{-# LANGUAGE Rank2Types #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
module TensorFlow.Build
|
module TensorFlow.Build
|
||||||
( -- * Graph node types
|
( -- * Graph node types
|
||||||
ControlNode(..)
|
ControlNode(..)
|
||||||
|
@ -32,8 +36,6 @@ module TensorFlow.Build
|
||||||
, opControlInputs
|
, opControlInputs
|
||||||
-- * The Build monad
|
-- * The Build monad
|
||||||
, GraphState
|
, GraphState
|
||||||
, render
|
|
||||||
, renderNodeName
|
|
||||||
, renderedNodeDefs
|
, renderedNodeDefs
|
||||||
, BuildT
|
, BuildT
|
||||||
, Build
|
, Build
|
||||||
|
@ -46,27 +48,23 @@ module TensorFlow.Build
|
||||||
, addGraphDef
|
, addGraphDef
|
||||||
, flushInitializers
|
, flushInitializers
|
||||||
, flushNodeBuffer
|
, flushNodeBuffer
|
||||||
|
, summaries
|
||||||
-- * Creating and looking up Ops
|
-- * Creating and looking up Ops
|
||||||
, getOrAddOp
|
, getOrAddOp
|
||||||
, addNewOp
|
, addNewOp
|
||||||
, renderOutput
|
, encodeOutput
|
||||||
|
, lookupNode
|
||||||
-- * Modifying all nodes in a Build action
|
-- * Modifying all nodes in a Build action
|
||||||
, colocateWith
|
|
||||||
, withStateLens
|
, withStateLens
|
||||||
, withDevice
|
, withDevice
|
||||||
, withNameScope
|
, withNameScope
|
||||||
, withNodeDependencies
|
, withNodeDependencies
|
||||||
-- * Internal Summary related bits.
|
|
||||||
, addSummary
|
|
||||||
, SummaryTensor
|
|
||||||
, collectAllSummaries
|
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
import Control.Monad.IO.Class (MonadIO(..))
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||||
import Data.ByteString (ByteString)
|
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Functor.Identity (Identity(..))
|
import Data.Functor.Identity (Identity(..))
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
|
@ -94,7 +92,6 @@ import Proto.Tensorflow.Core.Framework.NodeDef
|
||||||
|
|
||||||
import TensorFlow.Orphans ()
|
import TensorFlow.Orphans ()
|
||||||
import TensorFlow.Output
|
import TensorFlow.Output
|
||||||
import TensorFlow.Tensor
|
|
||||||
|
|
||||||
newtype Unique = Unique Int
|
newtype Unique = Unique Int
|
||||||
deriving (Eq, Ord, Enum)
|
deriving (Eq, Ord, Enum)
|
||||||
|
@ -125,9 +122,6 @@ opDefWithName n t = OpDef
|
||||||
, _opControlInputs = []
|
, _opControlInputs = []
|
||||||
}
|
}
|
||||||
|
|
||||||
-- | Synonym for the tensors that return serialized Summary proto.
|
|
||||||
type SummaryTensor = Tensor Value ByteString
|
|
||||||
|
|
||||||
data GraphState = GraphState
|
data GraphState = GraphState
|
||||||
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
|
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
|
||||||
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
|
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
|
||||||
|
@ -148,8 +142,8 @@ data GraphState = GraphState
|
||||||
, _initializationNodes :: [NodeName]
|
, _initializationNodes :: [NodeName]
|
||||||
-- ^ The nodes to run next time a TF.run is issued, typically
|
-- ^ The nodes to run next time a TF.run is issued, typically
|
||||||
-- variable initializers.
|
-- variable initializers.
|
||||||
, _summaries :: [SummaryTensor]
|
, _summaries :: [Output]
|
||||||
-- ^ The tensors for summary
|
-- ^ The tensors for summary (ByteString type)
|
||||||
}
|
}
|
||||||
|
|
||||||
-- | A node definition without its final name. Used as a key in the
|
-- | A node definition without its final name. Used as a key in the
|
||||||
|
@ -191,7 +185,7 @@ defaultControlInputs = lens _defaultControlInputs
|
||||||
initializationNodes :: Lens' GraphState [NodeName]
|
initializationNodes :: Lens' GraphState [NodeName]
|
||||||
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
|
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
|
||||||
|
|
||||||
summaries :: Lens' GraphState [SummaryTensor]
|
summaries :: Lens' GraphState [Output]
|
||||||
summaries = lens _summaries (\g x -> g { _summaries = x })
|
summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||||
|
|
||||||
-- | An action for building nodes in a TensorFlow graph.
|
-- | An action for building nodes in a TensorFlow graph.
|
||||||
|
@ -238,9 +232,7 @@ flushInitializers = do
|
||||||
-- | Registers the given node to be executed before the next
|
-- | Registers the given node to be executed before the next
|
||||||
-- 'TensorFlow.Session.run'.
|
-- 'TensorFlow.Session.run'.
|
||||||
addInitializer :: MonadBuild m => ControlNode -> m ()
|
addInitializer :: MonadBuild m => ControlNode -> m ()
|
||||||
addInitializer (ControlNode o) = build $ do
|
addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
|
||||||
i <- getOrAddOp o
|
|
||||||
initializationNodes %= (i:)
|
|
||||||
|
|
||||||
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||||
-- the given 'Build' action.
|
-- the given 'Build' action.
|
||||||
|
@ -255,30 +247,31 @@ addGraphDef g = build $ nodeBuffer <>= g ^. node
|
||||||
|
|
||||||
-- | Render the given op if it hasn't been rendered already, and return its
|
-- | Render the given op if it hasn't been rendered already, and return its
|
||||||
-- name.
|
-- name.
|
||||||
getOrAddOp :: Op -> Build NodeName
|
getOrAddOp :: OpDef -> Build NodeName
|
||||||
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
|
getOrAddOp o = do
|
||||||
|
|
||||||
resolveOp :: Op -> Build NodeDef
|
|
||||||
resolveOp (Rendered n) = return n
|
|
||||||
resolveOp (Unrendered o) = do
|
|
||||||
pending <- getPendingNode o
|
pending <- getPendingNode o
|
||||||
uses renderedNodes (Map.lookup pending) >>= \case
|
uses renderedNodes (Map.lookup pending) >>= \case
|
||||||
Just n -> return n
|
Just n -> return $ NodeName $ n ^. name
|
||||||
Nothing -> addNewOpFromPending pending
|
Nothing -> addNewOpFromPending pending
|
||||||
|
|
||||||
|
lookupNode :: NodeName -> Build NodeDef
|
||||||
|
lookupNode n = uses renderedNodeDefs (Map.lookup n) >>= \case
|
||||||
|
Just n' -> return n'
|
||||||
|
Nothing -> error $ "lookupNode: unknown node name " ++ show n
|
||||||
|
|
||||||
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
|
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
|
||||||
-- which are not safe to dedup (e.g, "variable" and "assign").
|
-- which are not safe to dedup (e.g, "variable" and "assign").
|
||||||
addNewOp :: OpDef -> Build NodeDef
|
addNewOp :: OpDef -> Build NodeName
|
||||||
addNewOp o = getPendingNode o >>= addNewOpFromPending
|
addNewOp o = getPendingNode o >>= addNewOpFromPending
|
||||||
|
|
||||||
addNewOpFromPending :: PendingNode -> Build NodeDef
|
addNewOpFromPending :: PendingNode -> Build NodeName
|
||||||
addNewOpFromPending pending = do
|
addNewOpFromPending pending = do
|
||||||
nodeName <- renderPendingNode pending
|
nodeName <- renderPendingNode pending
|
||||||
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
|
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
|
||||||
nodeBuffer %= (nodeDef :)
|
nodeBuffer %= (nodeDef :)
|
||||||
renderedNodes %= Map.insert pending nodeDef
|
renderedNodes %= Map.insert pending nodeDef
|
||||||
renderedNodeDefs %= Map.insert nodeName nodeDef
|
renderedNodeDefs %= Map.insert nodeName nodeDef
|
||||||
return nodeDef
|
return nodeName
|
||||||
|
|
||||||
-- | Get the pending node corresponding to an OpDef, which may or may not have
|
-- | Get the pending node corresponding to an OpDef, which may or may not have
|
||||||
-- been rendered before. Implicitly renders all of this node's inputs.
|
-- been rendered before. Implicitly renders all of this node's inputs.
|
||||||
|
@ -287,20 +280,18 @@ getPendingNode o = do
|
||||||
-- An empty string in the proto field means that no specific
|
-- An empty string in the proto field means that no specific
|
||||||
-- device is specified.
|
-- device is specified.
|
||||||
dev <- maybe "" deviceName <$> use defaultDevice
|
dev <- maybe "" deviceName <$> use defaultDevice
|
||||||
inputs <- mapM getInput (o ^. opInputs)
|
|
||||||
scope <- use currentScope
|
scope <- use currentScope
|
||||||
controls <- use defaultControlInputs
|
controls <- use defaultControlInputs
|
||||||
|
let inputs = map encodeOutput (o ^. opInputs)
|
||||||
let controlInputs
|
let controlInputs
|
||||||
= map getDep (o ^. opControlInputs ++ Set.toList controls)
|
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
|
||||||
return $ PendingNode scope (o ^. opName)
|
return $ PendingNode scope (o ^. opName)
|
||||||
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
||||||
& attr .~ _opAttrs o
|
& attr .~ _opAttrs o
|
||||||
& input .~ (inputs ++ controlInputs)
|
& input .~ (inputs ++ controlInputs)
|
||||||
& device .~ dev
|
& device .~ dev
|
||||||
where
|
where
|
||||||
getInput (Output (OutputIx k) subOp)
|
makeDep = ("^" <>) . unNodeName
|
||||||
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
|
|
||||||
getDep = ("^" <>) . unNodeName
|
|
||||||
|
|
||||||
-- | Pick a name for a pending node. If it has an explicit name, just use that;
|
-- | Pick a name for a pending node. If it has an explicit name, just use that;
|
||||||
-- if the name is implicit, assign a new unique name based on the op type.
|
-- if the name is implicit, assign a new unique name based on the op type.
|
||||||
|
@ -317,12 +308,11 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
|
||||||
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
||||||
|
|
||||||
|
|
||||||
-- | Render an 'Output' and return a string representation for the TensorFlow
|
-- | Turn an 'Output' into a string representation for the TensorFlow
|
||||||
-- foreign APIs.
|
-- foreign APIs.
|
||||||
renderOutput :: Output -> Build Text
|
encodeOutput :: Output -> Text
|
||||||
renderOutput (Output (OutputIx i) o) = do
|
encodeOutput (Output (OutputIx 0) n) = unNodeName n
|
||||||
n <- getOrAddOp o
|
encodeOutput (Output (OutputIx i) n) = unNodeName n <> Text.pack (':' : show i)
|
||||||
return $ unNodeName n <> Text.pack (":" ++ show i)
|
|
||||||
|
|
||||||
-- | Modify some part of the state, run an action, and restore the state
|
-- | Modify some part of the state, run an action, and restore the state
|
||||||
-- after that action is done.
|
-- after that action is done.
|
||||||
|
@ -339,15 +329,6 @@ withStateLens accessor f act = do
|
||||||
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
||||||
withDevice d = withStateLens defaultDevice (const d)
|
withDevice d = withStateLens defaultDevice (const d)
|
||||||
|
|
||||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
|
||||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
|
||||||
-- the action has side effects of rendering the desired tensors. A pure
|
|
||||||
-- return would not have the desired effect.
|
|
||||||
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
|
|
||||||
colocateWith t x = do
|
|
||||||
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
|
||||||
withDevice (Just d) x
|
|
||||||
|
|
||||||
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||||
withNameScope :: MonadBuild m => Text -> m a -> m a
|
withNameScope :: MonadBuild m => Text -> m a -> m a
|
||||||
withNameScope s = withStateLens currentScope (Scope s :)
|
withNameScope s = withStateLens currentScope (Scope s :)
|
||||||
|
@ -355,31 +336,3 @@ withNameScope s = withStateLens currentScope (Scope s :)
|
||||||
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||||
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
||||||
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||||
|
|
||||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
|
||||||
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
|
|
||||||
-- weren't already rendered.
|
|
||||||
--
|
|
||||||
-- This operation is idempotent; @render >=> render === render@. However,
|
|
||||||
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
|
||||||
-- may result in two different 'Tensor's.
|
|
||||||
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
|
|
||||||
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
|
|
||||||
|
|
||||||
-- | Render a 'Tensor' and get its node's name.
|
|
||||||
renderNodeName :: Tensor v a -> Build NodeName
|
|
||||||
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
|
||||||
|
|
||||||
-- | Records the given summary action in Build for retrieval with
|
|
||||||
-- 'collectAllSummaries'. The summary op is required to produce a
|
|
||||||
-- Summary protocol buffer in string form. For safety, use the
|
|
||||||
-- pre-composed functions: Logging.scalarSummary and
|
|
||||||
-- Logging.histogramSummary.
|
|
||||||
addSummary :: SummaryTensor -> Build ()
|
|
||||||
addSummary t = summaries %= (t :)
|
|
||||||
|
|
||||||
-- | Retrieves the summary ops collected thus far. Typically this only
|
|
||||||
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
|
||||||
-- repeatedly, the values accumulate.
|
|
||||||
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
|
|
||||||
collectAllSummaries = use summaries
|
|
||||||
|
|
|
@ -12,26 +12,27 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE GADTs #-}
|
{-# LANGUAGE GADTs #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TupleSections #-}
|
{-# LANGUAGE TupleSections #-}
|
||||||
|
|
||||||
module TensorFlow.BuildOp
|
module TensorFlow.BuildOp
|
||||||
( OpResult
|
( BuildResult(..)
|
||||||
, BuildOp
|
|
||||||
, buildOp
|
, buildOp
|
||||||
, buildListOp
|
, PureResult(..)
|
||||||
|
, pureOp
|
||||||
, eqLengthGuard
|
, eqLengthGuard
|
||||||
|
, BuildInputs(..)
|
||||||
, OpParams
|
, OpParams
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
|
|
||||||
import Control.Monad (replicateM)
|
import Control.Monad (liftM2, replicateM)
|
||||||
import Control.Monad.Reader (ReaderT, runReaderT, ask)
|
import Control.Monad.Reader (ReaderT, runReaderT, ask)
|
||||||
import Control.Monad.State.Strict (State, runState, get, put)
|
import Control.Monad.State.Strict (State, evalState, get, put)
|
||||||
import Data.Int (Int64)
|
import Data.Int (Int64)
|
||||||
import Lens.Family2 ((&), (<>~), (^.))
|
|
||||||
|
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Output
|
import TensorFlow.Output
|
||||||
|
@ -40,48 +41,45 @@ import TensorFlow.Types
|
||||||
|
|
||||||
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||||
|
|
||||||
type Result = ReaderT Op (State ResultState)
|
type Result = ReaderT NodeName (State ResultState)
|
||||||
|
|
||||||
-- | Class of types that can be used as op outputs.
|
-- | Class of types that can be used as op outputs.
|
||||||
class OpResult a where
|
class BuildResult a where
|
||||||
toResult :: Result a
|
buildResult :: Result a
|
||||||
|
|
||||||
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
|
instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
|
||||||
toResult = (,) <$> toResult <*> toResult
|
buildResult = (,) <$> buildResult <*> buildResult
|
||||||
|
|
||||||
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where
|
instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where
|
||||||
toResult = (,,) <$> toResult <*> toResult <*> toResult
|
buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult
|
||||||
|
|
||||||
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
|
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
|
||||||
=> OpResult (a1, a2, a3, a4) where
|
=> BuildResult (a1, a2, a3, a4) where
|
||||||
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult
|
buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult
|
||||||
|
|
||||||
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
|
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
|
||||||
=> OpResult (a1, a2, a3, a4, a5) where
|
=> BuildResult (a1, a2, a3, a4, a5) where
|
||||||
toResult = (,,,,) <$> toResult
|
buildResult = (,,,,) <$> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
|
|
||||||
instance ( OpResult a1
|
instance ( BuildResult a1
|
||||||
, OpResult a2
|
, BuildResult a2
|
||||||
, OpResult a3
|
, BuildResult a3
|
||||||
, OpResult a4
|
, BuildResult a4
|
||||||
, OpResult a5
|
, BuildResult a5
|
||||||
, OpResult a6
|
, BuildResult a6
|
||||||
)
|
)
|
||||||
=> OpResult (a1, a2, a3, a4, a5, a6) where
|
=> BuildResult (a1, a2, a3, a4, a5, a6) where
|
||||||
toResult = (,,,,,)
|
buildResult = (,,,,,)
|
||||||
<$> toResult
|
<$> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
<*> toResult
|
<*> buildResult
|
||||||
|
|
||||||
tensorResult :: TensorKind v -> Result (Tensor v a)
|
|
||||||
tensorResult v = Tensor v <$> recordResult
|
|
||||||
|
|
||||||
recordResult :: Result Output
|
recordResult :: Result Output
|
||||||
recordResult = do
|
recordResult = do
|
||||||
|
@ -90,144 +88,39 @@ recordResult = do
|
||||||
put $! ResultState (i+1) ns
|
put $! ResultState (i+1) ns
|
||||||
return $! output i o
|
return $! output i o
|
||||||
|
|
||||||
instance OpResult ResourceHandle where
|
instance BuildResult ResourceHandle where
|
||||||
toResult = ResourceHandle <$> recordResult
|
buildResult = ResourceHandle <$> recordResult
|
||||||
|
|
||||||
instance OpResult (Tensor Value a) where
|
instance Rendered v => BuildResult (Tensor v a) where
|
||||||
toResult = tensorResult ValueKind
|
buildResult = Tensor . pure <$> recordResult
|
||||||
|
|
||||||
instance OpResult (Tensor Ref a) where
|
instance BuildResult ControlNode where
|
||||||
toResult = tensorResult RefKind
|
buildResult = ControlNode <$> ask
|
||||||
|
|
||||||
instance OpResult ControlNode where
|
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
|
||||||
toResult = ControlNode <$> ask
|
buildResult = loop (tensorTypes :: TensorTypeList as)
|
||||||
|
|
||||||
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
|
|
||||||
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
|
|
||||||
where
|
where
|
||||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||||
loop Nil = return Nil
|
loop Nil = return Nil
|
||||||
loop (TensorTypeProxy :/ ls) = do
|
loop (TensorTypeProxy :/ ls) = do
|
||||||
t <- tensorResult v
|
t <- buildResult
|
||||||
ts <- loop ls
|
ts <- loop ls
|
||||||
return (t :/ ts)
|
return (t :/ ts)
|
||||||
|
|
||||||
instance TensorTypes as => OpResult (TensorList Value as) where
|
instance BuildResult a => BuildResult [a] where
|
||||||
toResult = tensorListResult ValueKind
|
buildResult = do
|
||||||
|
|
||||||
instance TensorTypes as => OpResult (TensorList Ref as) where
|
|
||||||
toResult = tensorListResult RefKind
|
|
||||||
|
|
||||||
instance OpResult a => OpResult [a] where
|
|
||||||
toResult = do
|
|
||||||
ResultState i ns <- get
|
ResultState i ns <- get
|
||||||
case ns of
|
case ns of
|
||||||
[] -> error $ "Ran out of counts in toResult. " ++
|
[] -> error $ "Ran out of counts in buildResult. " ++
|
||||||
"Likely misuse of buildListOp."
|
"Likely misuse of buildOp."
|
||||||
(n : rest) -> do
|
(n : rest) -> do
|
||||||
put $! ResultState i rest
|
put $! ResultState i rest
|
||||||
replicateM (fromIntegral n) toResult
|
replicateM (fromIntegral n) buildResult
|
||||||
|
|
||||||
runResult :: OpResult a => [Int64] -> Op -> a
|
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
|
||||||
runResult ns o =
|
buildOp sizes o = do
|
||||||
case runState (runReaderT toResult o) (ResultState 0 ns) of
|
n <- addNewOp o
|
||||||
(x, ResultState _ []) -> x
|
return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)
|
||||||
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
|
|
||||||
show (ns, ns')
|
|
||||||
|
|
||||||
-- | Make a new "pure" op, which may be deduped with identical ops within
|
|
||||||
-- the same scope.
|
|
||||||
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
|
|
||||||
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
|
|
||||||
|
|
||||||
-- | Make a new "stateful" op, which will not be deduped with otherwise
|
|
||||||
-- identical ops.
|
|
||||||
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
|
|
||||||
buildResult ns o ts
|
|
||||||
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
|
|
||||||
|
|
||||||
addReversedInputs :: OpDef -> [Output] -> OpDef
|
|
||||||
addReversedInputs o ts = o & opInputs <>~ reverse ts
|
|
||||||
|
|
||||||
-- | Class of types that can be used as op functions.
|
|
||||||
class BuildOp f where
|
|
||||||
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
|
|
||||||
-> OpDef
|
|
||||||
-> [Output] -- ^ Accumulator for inputs to the op.
|
|
||||||
-> f
|
|
||||||
|
|
||||||
-- | Starts an operation that returns a structured set of tensors
|
|
||||||
-- (singletons or tuples).
|
|
||||||
buildOp :: BuildOp f => OpDef -> f
|
|
||||||
buildOp o = buildOp' [] o []
|
|
||||||
|
|
||||||
-- | Starts an operation that returns a list of tensors.
|
|
||||||
buildListOp :: BuildOp f => [Int64]
|
|
||||||
-- ^ Cardinality of the corresponding list of tensors output.
|
|
||||||
-> OpDef -> f
|
|
||||||
buildListOp counts o = buildOp' counts o []
|
|
||||||
|
|
||||||
instance BuildOp ControlNode where
|
|
||||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
|
||||||
|
|
||||||
instance BuildOp ResourceHandle where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance BuildOp (Tensor Value a) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance BuildOp (Tensor Ref a) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance TensorTypes as => BuildOp (TensorList Value as) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance TensorTypes as => BuildOp (TensorList Ref as) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance BuildOp [Tensor Value a] where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
|
|
||||||
=> BuildOp (t1, t2, t3, t4) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
|
|
||||||
=> BuildOp (t1, t2, t3, t4, t5) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance ( OpResult t1
|
|
||||||
, OpResult t2
|
|
||||||
, OpResult t3
|
|
||||||
, OpResult t4
|
|
||||||
, OpResult t5
|
|
||||||
, OpResult t6
|
|
||||||
)
|
|
||||||
=> BuildOp (t1, t2, t3, t4, t5, t6) where
|
|
||||||
buildOp' = pureResult
|
|
||||||
|
|
||||||
instance OpResult a => BuildOp (Build a) where
|
|
||||||
buildOp' = buildResult
|
|
||||||
|
|
||||||
instance BuildOp f => BuildOp (ResourceHandle -> f) where
|
|
||||||
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
|
||||||
|
|
||||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
|
||||||
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
|
||||||
|
|
||||||
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
|
||||||
buildOp' rf o accum ts
|
|
||||||
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
|
||||||
|
|
||||||
instance BuildOp f => BuildOp (TensorList v as -> f) where
|
|
||||||
buildOp' rf o accum ts
|
|
||||||
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
|
|
||||||
|
|
||||||
-- | Returns true if all the integers in each tuple are identical.
|
-- | Returns true if all the integers in each tuple are identical.
|
||||||
-- Throws an error with a descriptive message if not.
|
-- Throws an error with a descriptive message if not.
|
||||||
|
@ -240,6 +133,104 @@ eqLengthGuard = all eachOk
|
||||||
error ("number_attr " ++ numberAttrName ++
|
error ("number_attr " ++ numberAttrName ++
|
||||||
" contains tensors with different length " ++ show pairs)
|
" contains tensors with different length " ++ show pairs)
|
||||||
|
|
||||||
|
-----------
|
||||||
|
|
||||||
|
|
||||||
|
-- | Class of types that can be used as op outputs.
|
||||||
|
class PureResult a where
|
||||||
|
pureResult :: ReaderT (Build OpDef) (State ResultState) a
|
||||||
|
|
||||||
|
instance PureResult (Tensor Build a) where
|
||||||
|
pureResult = do
|
||||||
|
ResultState i ns <- get
|
||||||
|
put $! ResultState (i+1) ns
|
||||||
|
makeOp <- ask
|
||||||
|
return $ Tensor $ do
|
||||||
|
o <- makeOp
|
||||||
|
-- TODO: unify with BuildResult (Tensor v)
|
||||||
|
output i <$> getOrAddOp o
|
||||||
|
|
||||||
|
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
|
||||||
|
pureResult = (,) <$> pureResult <*> pureResult
|
||||||
|
|
||||||
|
instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where
|
||||||
|
pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult
|
||||||
|
|
||||||
|
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
|
||||||
|
=> PureResult (a1, a2, a3, a4) where
|
||||||
|
pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult
|
||||||
|
|
||||||
|
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
|
||||||
|
=> PureResult (a1, a2, a3, a4, a5) where
|
||||||
|
pureResult = (,,,,) <$> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
|
||||||
|
instance ( PureResult a1
|
||||||
|
, PureResult a2
|
||||||
|
, PureResult a3
|
||||||
|
, PureResult a4
|
||||||
|
, PureResult a5
|
||||||
|
, PureResult a6
|
||||||
|
)
|
||||||
|
=> PureResult (a1, a2, a3, a4, a5, a6) where
|
||||||
|
pureResult = (,,,,,)
|
||||||
|
<$> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
<*> pureResult
|
||||||
|
|
||||||
|
instance PureResult a => PureResult [a] where
|
||||||
|
pureResult = do
|
||||||
|
ResultState i ns <- get
|
||||||
|
case ns of
|
||||||
|
[] -> error $ "Ran out of counts in pureResult. " ++
|
||||||
|
"Likely misuse of pureOp with output lists."
|
||||||
|
n : rest -> do
|
||||||
|
put $! ResultState i rest
|
||||||
|
replicateM (fromIntegral n) pureResult
|
||||||
|
|
||||||
|
instance TensorTypes as => PureResult (TensorList Build as) where
|
||||||
|
pureResult = loop (tensorTypes :: TensorTypeList as)
|
||||||
|
where
|
||||||
|
loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState)
|
||||||
|
(TensorList Build bs)
|
||||||
|
loop Nil = return Nil
|
||||||
|
loop (TensorTypeProxy :/ ls) = do
|
||||||
|
t <- pureResult
|
||||||
|
ts <- loop ls
|
||||||
|
return (t :/ ts)
|
||||||
|
|
||||||
|
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
|
||||||
|
pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o)
|
||||||
|
|
||||||
|
-----
|
||||||
|
-- Class of types that can be used as arguments
|
||||||
|
|
||||||
|
class BuildInputs a where
|
||||||
|
buildInputs :: a -> Build [Output]
|
||||||
|
|
||||||
|
instance BuildInputs a => BuildInputs [a] where
|
||||||
|
buildInputs = fmap concat . mapM buildInputs
|
||||||
|
|
||||||
|
instance BuildInputs (Tensor v a) where
|
||||||
|
buildInputs (Tensor t) = do
|
||||||
|
o <- toBuild t
|
||||||
|
return [o]
|
||||||
|
|
||||||
|
instance BuildInputs (ListOf (Tensor v) as) where
|
||||||
|
buildInputs Nil = return []
|
||||||
|
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
|
||||||
|
|
||||||
|
instance BuildInputs ResourceHandle where
|
||||||
|
buildInputs (ResourceHandle o) = return [o]
|
||||||
|
|
||||||
|
----
|
||||||
|
|
||||||
-- | Parameters to build an op (for example, the node name or optional attributes).
|
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||||
-- TODO: be more type safe.
|
-- TODO: be more type safe.
|
||||||
type OpParams = OpDef -> OpDef
|
type OpParams = OpDef -> OpDef
|
||||||
|
|
|
@ -25,9 +25,6 @@ module TensorFlow.ControlFlow
|
||||||
, noOp
|
, noOp
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Set as Set
|
|
||||||
import Lens.Family2 ((&), (.~))
|
|
||||||
|
|
||||||
import TensorFlow.BuildOp
|
import TensorFlow.BuildOp
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Nodes
|
import TensorFlow.Nodes
|
||||||
|
@ -46,11 +43,8 @@ withControlDependencies deps act = do
|
||||||
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||||
-- no output.
|
-- no output.
|
||||||
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
||||||
group deps = do
|
group deps = withControlDependencies deps noOp
|
||||||
nodes <- build $ Set.toList <$> getNodes deps
|
|
||||||
-- TODO: slicker way
|
|
||||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
|
||||||
|
|
||||||
-- | Does nothing. Only useful as a placeholder for control edges.
|
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||||
noOp :: ControlNode
|
noOp :: MonadBuild m => m ControlNode
|
||||||
noOp = buildOp $ opDef "NoOp"
|
noOp = build $ buildOp [] $ opDef "NoOp"
|
||||||
|
|
|
@ -57,9 +57,9 @@ module TensorFlow.Core
|
||||||
, Tensor
|
, Tensor
|
||||||
, Value
|
, Value
|
||||||
, Ref
|
, Ref
|
||||||
, TensorKind(..)
|
|
||||||
, value
|
, value
|
||||||
, tensorFromName
|
, tensorFromName
|
||||||
|
, expr
|
||||||
-- ** Element types
|
-- ** Element types
|
||||||
, TensorType
|
, TensorType
|
||||||
, TensorData
|
, TensorData
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE TypeOperators #-}
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
{-# LANGUAGE UndecidableInstances #-} -- For Fetchable (TensorExpr a)
|
||||||
module TensorFlow.Nodes where
|
module TensorFlow.Nodes where
|
||||||
|
|
||||||
import Control.Applicative (liftA2, liftA3)
|
import Control.Applicative (liftA2, liftA3)
|
||||||
|
@ -28,7 +29,6 @@ import Data.Map.Strict (Map)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
import Lens.Family2 ((^.))
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ instance Fetchable t a => Fetchable [t] [a] where
|
||||||
getFetch ts = sequenceA <$> mapM getFetch ts
|
getFetch ts = sequenceA <$> mapM getFetch ts
|
||||||
|
|
||||||
instance Nodes ControlNode where
|
instance Nodes ControlNode where
|
||||||
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o
|
getNodes (ControlNode o) = pure $ Set.singleton o
|
||||||
|
|
||||||
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
|
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
|
||||||
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
|
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
|
||||||
|
@ -113,13 +113,13 @@ instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
|
||||||
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
|
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
|
||||||
|
|
||||||
instance Nodes (Tensor v a) where
|
instance Nodes (Tensor v a) where
|
||||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o
|
||||||
|
|
||||||
fetchTensorVector :: forall a v . TensorType a
|
fetchTensorVector :: forall a v . (TensorType a)
|
||||||
=> Tensor v a -> Build (Fetch (TensorData a))
|
=> Tensor v a -> Build (Fetch (TensorData a))
|
||||||
fetchTensorVector (Tensor _ o) = do
|
fetchTensorVector (Tensor o) = do
|
||||||
outputName <- renderOutput o
|
outputName <- encodeOutput <$> toBuild o
|
||||||
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
pure $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||||
let tensorData = tensors Map.! outputName
|
let tensorData = tensors Map.! outputName
|
||||||
expectedType = tensorType (undefined :: a)
|
expectedType = tensorType (undefined :: a)
|
||||||
actualType = FFI.tensorDataType tensorData
|
actualType = FFI.tensorDataType tensorData
|
||||||
|
|
|
@ -22,8 +22,6 @@ module TensorFlow.Output
|
||||||
, Device(..)
|
, Device(..)
|
||||||
-- * Ops
|
-- * Ops
|
||||||
, NodeName(..)
|
, NodeName(..)
|
||||||
, Op(..)
|
|
||||||
, opUnrendered
|
|
||||||
, OpDef(..)
|
, OpDef(..)
|
||||||
, opName
|
, opName
|
||||||
, opType
|
, opType
|
||||||
|
@ -34,28 +32,24 @@ module TensorFlow.Output
|
||||||
, OutputIx(..)
|
, OutputIx(..)
|
||||||
, Output(..)
|
, Output(..)
|
||||||
, output
|
, output
|
||||||
, outputIndex
|
|
||||||
, outputOp
|
|
||||||
, PendingNodeName(..)
|
, PendingNodeName(..)
|
||||||
, ResourceHandle(..)
|
, ResourceHandle(..)
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import Data.ProtoLens.TextFormat (showMessage)
|
|
||||||
import Data.String (IsString(..))
|
import Data.String (IsString(..))
|
||||||
import Data.Text (Text)
|
import Data.Text (Text)
|
||||||
import qualified Data.Text as Text
|
import qualified Data.Text as Text
|
||||||
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
|
import Lens.Family2 (Lens')
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.Unchecked (lens)
|
||||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
||||||
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
|
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import TensorFlow.Types (Attribute, attrLens)
|
import TensorFlow.Types (Attribute, attrLens)
|
||||||
import TensorFlow.Orphans ()
|
import TensorFlow.Orphans ()
|
||||||
|
|
||||||
-- | A type of graph node which has no outputs. These nodes are
|
-- | A type of graph node which has no outputs. These nodes are
|
||||||
-- valuable for causing side effects when they are run.
|
-- valuable for causing side effects when they are run.
|
||||||
newtype ControlNode = ControlNode { unControlNode :: Op }
|
newtype ControlNode = ControlNode { unControlNode :: NodeName }
|
||||||
|
|
||||||
-- | The type of op of a node in the graph. This corresponds to the proto field
|
-- | The type of op of a node in the graph. This corresponds to the proto field
|
||||||
-- NodeDef.op.
|
-- NodeDef.op.
|
||||||
|
@ -66,18 +60,12 @@ instance IsString OpType where
|
||||||
fromString = OpType . Text.pack
|
fromString = OpType . Text.pack
|
||||||
|
|
||||||
-- | An output of a TensorFlow node.
|
-- | An output of a TensorFlow node.
|
||||||
data Output = Output !OutputIx !Op
|
data Output = Output {outputIndex :: !OutputIx, outputNodeName :: !NodeName}
|
||||||
deriving (Eq, Ord, Show)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
output :: OutputIx -> Op -> Output
|
output :: OutputIx -> NodeName -> Output
|
||||||
output = Output
|
output = Output
|
||||||
|
|
||||||
outputOp :: Lens' Output Op
|
|
||||||
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
|
|
||||||
|
|
||||||
outputIndex :: Lens' Output OutputIx
|
|
||||||
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
|
|
||||||
|
|
||||||
newtype OutputIx = OutputIx { unOutputIx :: Int }
|
newtype OutputIx = OutputIx { unOutputIx :: Int }
|
||||||
deriving (Eq, Ord, Num, Enum, Show)
|
deriving (Eq, Ord, Num, Enum, Show)
|
||||||
|
|
||||||
|
@ -90,25 +78,6 @@ newtype Device = Device {deviceName :: Text}
|
||||||
instance Show Device where
|
instance Show Device where
|
||||||
show (Device d) = show d
|
show (Device d) = show d
|
||||||
|
|
||||||
-- | The representation of a node in a TensorFlow graph.
|
|
||||||
data Op
|
|
||||||
= Rendered !NodeDef -- ^ Properties are fixed, including the
|
|
||||||
-- device, name, and scope.
|
|
||||||
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
|
|
||||||
-- on which context this op is rendered in.
|
|
||||||
deriving (Eq, Ord)
|
|
||||||
|
|
||||||
instance Show Op where
|
|
||||||
show (Rendered n) = "Rendered " ++ showMessage n
|
|
||||||
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
|
|
||||||
|
|
||||||
-- | Traverse on the 'Unrendered' of an 'Op'.
|
|
||||||
--
|
|
||||||
-- Same implementation as _Left.
|
|
||||||
opUnrendered :: Traversal' Op OpDef
|
|
||||||
opUnrendered f (Unrendered a) = Unrendered <$> f a
|
|
||||||
opUnrendered _ (Rendered b) = pure (Rendered b)
|
|
||||||
|
|
||||||
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
|
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
|
||||||
data OpDef = OpDef
|
data OpDef = OpDef
|
||||||
{ _opName :: !PendingNodeName
|
{ _opName :: !PendingNodeName
|
||||||
|
@ -157,7 +126,7 @@ instance IsString Output where
|
||||||
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
||||||
-> Output (fromInteger ix) $ assigned n
|
-> Output (fromInteger ix) $ assigned n
|
||||||
_ -> Output 0 $ assigned s
|
_ -> Output 0 $ assigned s
|
||||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
where assigned = NodeName . Text.pack
|
||||||
|
|
||||||
|
|
||||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||||
|
|
|
@ -163,7 +163,7 @@ runWithFeeds feeds t = do
|
||||||
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
||||||
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||||
extend
|
extend
|
||||||
feeds' <- build $ fixFeeds feeds
|
let feeds' = fixFeeds feeds
|
||||||
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
||||||
targetNames = toNodeNames $ Set.toList target
|
targetNames = toNodeNames $ Set.toList target
|
||||||
session <- Session (asks rawSession)
|
session <- Session (asks rawSession)
|
||||||
|
@ -192,8 +192,8 @@ runWithFeeds_ feeds t = do
|
||||||
ns <- build $ getNodes t
|
ns <- build $ getNodes t
|
||||||
runFetchWithFeeds feeds ns (pure ())
|
runFetchWithFeeds feeds ns (pure ())
|
||||||
|
|
||||||
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
|
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
|
||||||
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o
|
fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
|
||||||
|
|
||||||
-- | Starts a concurrent thread which evaluates the given Nodes
|
-- | Starts a concurrent thread which evaluates the given Nodes
|
||||||
-- forever until runSession exits or an exception occurs. Graph
|
-- forever until runSession exits or an exception occurs. Graph
|
||||||
|
|
|
@ -16,21 +16,26 @@
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE FunctionalDependencies #-}
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
{-# LANGUAGE GADTs #-}
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE DeriveFunctor #-}
|
||||||
{-# LANGUAGE KindSignatures #-}
|
{-# LANGUAGE KindSignatures #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE Rank2Types #-}
|
{-# LANGUAGE Rank2Types #-}
|
||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE TypeOperators #-}
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
{-# LANGUAGE UndecidableInstances #-} -- For the Render class
|
||||||
|
|
||||||
module TensorFlow.Tensor where
|
module TensorFlow.Tensor where
|
||||||
|
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
import Data.String (IsString(..))
|
import Data.String (IsString(..))
|
||||||
import qualified Data.Text as Text
|
import qualified Data.Text as Text
|
||||||
import Lens.Family2 (Lens', (^.))
|
import Lens.Family2 ((^.))
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.State ((%=), use)
|
||||||
|
|
||||||
import TensorFlow.Output (Output)
|
import Proto.Tensorflow.Core.Framework.NodeDef (device)
|
||||||
|
import TensorFlow.Build
|
||||||
|
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
|
||||||
import TensorFlow.Types
|
import TensorFlow.Types
|
||||||
( TensorData(..)
|
( TensorData(..)
|
||||||
, ListOf(..)
|
, ListOf(..)
|
||||||
|
@ -40,52 +45,149 @@ import qualified TensorFlow.Internal.FFI as FFI
|
||||||
-- | A named output of a TensorFlow operation.
|
-- | A named output of a TensorFlow operation.
|
||||||
--
|
--
|
||||||
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
|
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
|
||||||
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
|
-- parameter @v@ is either:
|
||||||
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
|
--
|
||||||
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via
|
-- * 'Build': An unrendered, immutable value.
|
||||||
-- 'value'.
|
-- * 'Value': A rendered, immutable value.
|
||||||
data Tensor v a = Tensor (TensorKind v) Output
|
-- * 'Ref': A rendered stateful handle (e.g., a variable).
|
||||||
|
--
|
||||||
|
-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert between
|
||||||
|
-- the different types of 'Tensor'.
|
||||||
|
data Tensor v a where
|
||||||
|
Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
|
||||||
|
|
||||||
data Value
|
newtype Value a = Value {runValue :: a}
|
||||||
data Ref
|
deriving Functor
|
||||||
|
|
||||||
-- | This class provides a runtime switch on whether a 'Tensor' should be
|
instance Applicative Value where
|
||||||
-- treated as a 'Value' or as a 'Ref'.
|
pure = Value
|
||||||
data TensorKind v where
|
Value f <*> Value x = Value $ f x
|
||||||
ValueKind :: TensorKind Value
|
|
||||||
RefKind :: TensorKind Ref
|
|
||||||
|
|
||||||
tensorKind :: Lens' (Tensor v a) (TensorKind v)
|
instance Monad Value where
|
||||||
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
f >>= g = g $ runValue f
|
||||||
|
|
||||||
tensorOutput :: Lens' (Tensor v a) Output
|
newtype Ref a = Ref {runRef :: a}
|
||||||
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
deriving Functor
|
||||||
|
|
||||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
instance Applicative Ref where
|
||||||
-- Ref into Value. This behaves like a no-op.
|
pure = Ref
|
||||||
value :: Tensor v a -> Tensor Value a
|
Ref f <*> Ref x = Ref $ f x
|
||||||
value (Tensor _ o) = Tensor ValueKind o
|
|
||||||
|
instance Monad Ref where
|
||||||
|
f >>= g = g $ runRef f
|
||||||
|
|
||||||
|
-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op.
|
||||||
|
value :: Tensor Ref a -> Tensor Value a
|
||||||
|
value (Tensor o) = Tensor $ Value $ runRef o
|
||||||
|
|
||||||
|
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
|
||||||
|
renderValue (Tensor o) = render $ Tensor $ toBuild o
|
||||||
|
|
||||||
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
|
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
|
||||||
-- when running the graph.
|
-- when running the graph.
|
||||||
data Feed = Feed Output FFI.TensorData
|
data Feed = Feed Output FFI.TensorData
|
||||||
|
|
||||||
|
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
|
||||||
|
-- name, device, etc.
|
||||||
|
class TensorKind v => Rendered v where
|
||||||
|
rendered :: v a -> a
|
||||||
|
|
||||||
|
instance Rendered Value where
|
||||||
|
rendered = runValue
|
||||||
|
|
||||||
|
instance Rendered Ref where
|
||||||
|
rendered = runRef
|
||||||
|
|
||||||
|
renderedOutput :: Rendered v => Tensor v a -> Output
|
||||||
|
renderedOutput = rendered . tensorOutput
|
||||||
|
|
||||||
|
tensorNodeName :: Rendered v => Tensor v a -> NodeName
|
||||||
|
tensorNodeName = outputNodeName . renderedOutput
|
||||||
|
|
||||||
|
|
||||||
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
|
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
|
||||||
-- the graph.
|
-- the graph.
|
||||||
--
|
--
|
||||||
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
||||||
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
||||||
feed :: Tensor v a -> TensorData a -> Feed
|
feed :: Rendered v => Tensor v a -> TensorData a -> Feed
|
||||||
feed (Tensor _ o) (TensorData td) = Feed o td
|
feed t (TensorData td) = Feed (renderedOutput t) td
|
||||||
|
|
||||||
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
||||||
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
|
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
|
||||||
-- TODO(judahjacobson): add more safety checks here.
|
-- TODO(judahjacobson): add more safety checks here.
|
||||||
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
|
||||||
tensorFromName v = Tensor v . fromString . Text.unpack
|
tensorFromName = Tensor . pure . fromString . Text.unpack
|
||||||
|
|
||||||
|
-- | Like 'tensorFromName', but type-restricted to 'Value'.
|
||||||
|
tensorValueFromName :: Text.Text -> Tensor Value a
|
||||||
|
tensorValueFromName = tensorFromName
|
||||||
|
|
||||||
|
-- | Like 'tensorFromName', but type-restricted to 'Ref'.
|
||||||
|
tensorRefFromName :: Text.Text -> Tensor Ref a
|
||||||
|
tensorRefFromName = tensorFromName
|
||||||
|
|
||||||
type TensorList v = ListOf (Tensor v)
|
type TensorList v = ListOf (Tensor v)
|
||||||
|
|
||||||
tensorListOutputs :: TensorList v as -> [Output]
|
tensorListOutputs :: Rendered v => TensorList v as -> [Output]
|
||||||
tensorListOutputs Nil = []
|
tensorListOutputs Nil = []
|
||||||
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
|
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
||||||
|
|
||||||
|
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||||
|
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||||
|
-- the action has side effects of rendering the desired tensors. A pure
|
||||||
|
-- return would not have the desired effect.
|
||||||
|
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
|
||||||
|
colocateWith t x = do
|
||||||
|
d <- build $ Device . (^. device)
|
||||||
|
<$> lookupNode (outputNodeName $ renderedOutput t)
|
||||||
|
withDevice (Just d) x
|
||||||
|
|
||||||
|
|
||||||
|
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||||
|
-- the 'MonadBuild' context. Also renders any dependencies of the 'Tensor' that
|
||||||
|
-- weren't already rendered.
|
||||||
|
--
|
||||||
|
-- This operation is idempotent; calling 'render' on the same input in the same
|
||||||
|
-- context will produce the same result. However, rendering the same
|
||||||
|
-- @Tensor Build@ in two different contexts may result in two different
|
||||||
|
-- @Tensor Value@s.
|
||||||
|
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
|
||||||
|
render (Tensor t) = Tensor . Value <$> build t
|
||||||
|
|
||||||
|
-- TODO: better name.
|
||||||
|
expr :: TensorKind v => Tensor v a -> Tensor Build a
|
||||||
|
expr (Tensor o) = Tensor $ toBuild o
|
||||||
|
|
||||||
|
-- | Records the given summary action in Build for retrieval with
|
||||||
|
-- Summary protocol buffer in string form. For safety, use the
|
||||||
|
-- pre-composed functions: Logging.scalarSummary and
|
||||||
|
-- Logging.histogramSummary.
|
||||||
|
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -- ^ A 'SummaryTensor'
|
||||||
|
-> m ()
|
||||||
|
addSummary t = build $ do
|
||||||
|
-- TODO: more generic way
|
||||||
|
o <- toBuild $ tensorOutput t
|
||||||
|
summaries %= (o :)
|
||||||
|
|
||||||
|
-- | Retrieves the summary ops collected thus far. Typically this only
|
||||||
|
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
||||||
|
-- repeatedly, the values accumulate.
|
||||||
|
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
|
||||||
|
collectAllSummaries = build $ map (Tensor . Value) <$> use summaries
|
||||||
|
|
||||||
|
-- | Synonym for the tensors that return serialized Summary proto.
|
||||||
|
type SummaryTensor = Tensor Value ByteString
|
||||||
|
|
||||||
|
-- | An internal class for kinds of Tensors.
|
||||||
|
class Monad v => TensorKind v where
|
||||||
|
toBuild :: v a -> Build a
|
||||||
|
|
||||||
|
instance TensorKind Value where
|
||||||
|
toBuild = return . rendered
|
||||||
|
|
||||||
|
instance TensorKind Ref where
|
||||||
|
toBuild = return . rendered
|
||||||
|
|
||||||
|
instance TensorKind Build where
|
||||||
|
toBuild = id
|
||||||
|
|
Loading…
Reference in a new issue