mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-06 00:49:47 +01:00
Distinguish between "rendered" and "unrendered" Tensors. (#88)
Distinguish between "rendered" and "unrendered" Tensors. There are now three types of `Tensor`: - `Tensor Value a`: rendered value - `Tensor Ref a`: rendered reference - `Tensor Build a` : unrendered value The extra bookkeeping makes it easier to track (and enforce) which tensors are rendered or not. For examples where this has been confusing in the past, see With this change, pure ops look similar to before, returning `Tensor Build` instead of `Tensor Value`. "Stateful" (monadic) ops are unchanged. For example: add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t assign :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t) The `gradients` function now requires that the variables over which it's differentiating are pre-rendered: gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a] -> m [Tensor Value a] (`Rendered v2` means that `v2` is either a `Ref` or a `Value`.) Additionally, the implementation of `gradients` now takes care to render every intermediate value when performing the reverse accumulation. I suspect this fixes an exponential blowup for complicated expressions.
This commit is contained in:
parent
b2193712db
commit
a7cbc27d36
29 changed files with 636 additions and 608 deletions
|
@ -58,7 +58,7 @@ fit xData yData = TF.runSession $ do
|
|||
return (w', b')
|
||||
|
||||
gradientDescent :: Float
|
||||
-> TF.Tensor TF.Value Float
|
||||
-> TF.Tensor TF.Build Float
|
||||
-> [TF.Tensor TF.Ref Float]
|
||||
-> TF.Session TF.ControlNode
|
||||
gradientDescent alpha loss params = do
|
||||
|
|
|
@ -67,10 +67,10 @@ import Proto.Tensorflow.Core.Framework.Summary (Summary)
|
|||
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
|
||||
import System.Directory (createDirectoryIfMissing)
|
||||
import System.FilePath ((</>))
|
||||
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||
import TensorFlow.Build (MonadBuild)
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Records.Conduit (sinkTFRecords)
|
||||
import TensorFlow.Tensor (Tensor)
|
||||
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
|
||||
import TensorFlow.Types (TensorType, type(/=))
|
||||
import Text.Printf (printf)
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
|
@ -141,19 +141,19 @@ doubleWallTime = asDouble <$> getCurrentTime
|
|||
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
|
||||
-- limited to a single value for simplicity.
|
||||
histogramSummary ::
|
||||
(TensorType t, t /= ByteString, t /= Bool)
|
||||
(MonadBuild m, TensorType t, t /= ByteString, t /= Bool)
|
||||
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||
=> ByteString -> Tensor v t -> Build ()
|
||||
=> ByteString -> Tensor v t -> m ()
|
||||
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
|
||||
|
||||
-- | Adds a 'CoreOps.scalarSummary' node.
|
||||
scalarSummary ::
|
||||
(TensorType t, t /= ByteString, t /= Bool)
|
||||
(TensorType t, t /= ByteString, t /= Bool, MonadBuild m)
|
||||
-- (TensorType t,
|
||||
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
|
||||
=> ByteString -> Tensor v t -> Build ()
|
||||
=> ByteString -> Tensor v t -> m ()
|
||||
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
|
||||
|
||||
-- | Merge all summaries accumulated in the 'Build' into one summary.
|
||||
mergeAllSummaries :: Build SummaryTensor
|
||||
mergeAllSummaries :: MonadBuild m => m SummaryTensor
|
||||
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary
|
||||
|
|
|
@ -34,13 +34,13 @@ numPixels = 28*28 :: Int64
|
|||
numLabels = 10 :: Int64
|
||||
|
||||
-- | Create tensor with random values where the stddev depends on the width.
|
||||
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float)
|
||||
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Build Float)
|
||||
randomParam width (TF.Shape shape) =
|
||||
(* stddev) <$> TF.truncatedNormal (TF.vector shape)
|
||||
(`TF.mul` stddev) <$> TF.truncatedNormal (TF.vector shape)
|
||||
where
|
||||
stddev = TF.scalar (1 / sqrt (fromIntegral width))
|
||||
|
||||
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
|
||||
reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float
|
||||
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
|
||||
|
||||
-- Types must match due to model structure.
|
||||
|
@ -87,7 +87,7 @@ createModel = do
|
|||
grads <- TF.gradients loss params
|
||||
|
||||
let lr = TF.scalar 0.00001
|
||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
|
||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
|
||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
||||
|
||||
let correctPredictions = TF.equal predict labels
|
||||
|
|
|
@ -37,15 +37,15 @@ import TensorFlow.Examples.MNIST.TrainedGraph
|
|||
import TensorFlow.Build
|
||||
( asGraphDef
|
||||
, addGraphDef
|
||||
, render
|
||||
, Build
|
||||
)
|
||||
import TensorFlow.Tensor
|
||||
( Tensor(..)
|
||||
, Ref
|
||||
, Value
|
||||
, feed
|
||||
, TensorKind(..)
|
||||
, render
|
||||
, tensorFromName
|
||||
, tensorValueFromName
|
||||
)
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.Session
|
||||
|
@ -80,7 +80,7 @@ testReadMNIST = testCase "testReadMNIST" $ do
|
|||
labelData <- readMNISTLabels =<< testLabelData
|
||||
10000 @=? length labelData
|
||||
|
||||
testNodeName :: Text -> Tensor v a -> Assertion
|
||||
testNodeName :: Text -> Tensor Build a -> Assertion
|
||||
testNodeName n g = n @=? opName
|
||||
where
|
||||
opName = head (gDef^.node)^.op
|
||||
|
@ -89,7 +89,7 @@ testNodeName n g = n @=? opName
|
|||
testGraphDefGen :: Test
|
||||
testGraphDefGen = testCase "testGraphDefGen" $ do
|
||||
-- Test the inferred operation type.
|
||||
let f0 :: Tensor Value Float
|
||||
let f0 :: Tensor Build Float
|
||||
f0 = 0
|
||||
testNodeName "Const" f0
|
||||
testNodeName "Add" $ 1 + f0
|
||||
|
@ -109,7 +109,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do
|
|||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||
runSession $ do
|
||||
addGraphDef graphDef
|
||||
x <- run $ tensorFromName ValueKind "Mul_2"
|
||||
x <- run $ tensorValueFromName "Mul_2"
|
||||
liftIO $ (50 :: Float) @=? unScalar x
|
||||
|
||||
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
|
||||
|
@ -142,8 +142,8 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
|||
build $ addGraphDef $ mnist & version .~ 0
|
||||
-- Define nodes that restore saved weights and biases.
|
||||
let bias, wts :: Tensor Ref Float
|
||||
bias = tensorFromName RefKind "Variable"
|
||||
wts = tensorFromName RefKind "weights"
|
||||
bias = tensorFromName "Variable"
|
||||
wts = tensorFromName "weights"
|
||||
wtsCkptPath <- liftIO wtsCkpt
|
||||
biasCkptPath <- liftIO biasCkpt
|
||||
-- Run those restoring nodes on the graph in the current session.
|
||||
|
@ -155,12 +155,12 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
|||
let ty = encodeTensorData [10] oneHotLabels
|
||||
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
|
||||
updates = [(fromIntegral label, 1)]
|
||||
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
|
||||
, feed (tensorFromName ValueKind "y-input") ty
|
||||
let feeds = [ feed (tensorValueFromName "x-input") tensorSample
|
||||
, feed (tensorValueFromName "y-input") ty
|
||||
]
|
||||
-- Run the graph with the input feeds and read the ArgMax'd result from
|
||||
-- the test (not training) side of the evaluation.
|
||||
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
|
||||
x <- runWithFeeds feeds $ tensorValueFromName "test/ArgMax"
|
||||
-- Print the trained model's predicted outcome.
|
||||
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
|
||||
++ "Prediction: " ++ show (unScalar x :: Int64)
|
||||
|
|
|
@ -24,7 +24,6 @@ import Prelude hiding ( log
|
|||
, exp
|
||||
)
|
||||
import TensorFlow.Build ( MonadBuild
|
||||
, render
|
||||
, withNameScope
|
||||
)
|
||||
import TensorFlow.GenOps.Core ( greaterEqual
|
||||
|
@ -33,6 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
|
|||
, exp
|
||||
)
|
||||
import TensorFlow.Tensor ( Tensor(..)
|
||||
, render
|
||||
, Value
|
||||
)
|
||||
import TensorFlow.Types ( TensorType(..)
|
||||
|
@ -40,6 +40,8 @@ import TensorFlow.Types ( TensorType(..)
|
|||
)
|
||||
import TensorFlow.Ops ( zerosLike
|
||||
, add
|
||||
, mul
|
||||
, neg
|
||||
)
|
||||
|
||||
-- | Computes sigmoid cross entropy given `logits`.
|
||||
|
@ -76,13 +78,11 @@ sigmoidCrossEntropyWithLogits
|
|||
-> Tensor Value a -- ^ __targets__
|
||||
-> m (Tensor Value a)
|
||||
sigmoidCrossEntropyWithLogits logits targets = do
|
||||
logits' <- render logits
|
||||
targets' <- render targets
|
||||
let zeros = zerosLike logits'
|
||||
cond = logits' `greaterEqual` zeros
|
||||
relu_logits = select cond logits' zeros
|
||||
neg_abs_logits = select cond (-logits') logits'
|
||||
let zeros = zerosLike logits
|
||||
cond = logits `greaterEqual` zeros
|
||||
relu_logits = select cond logits zeros
|
||||
neg_abs_logits = select cond (neg logits) logits
|
||||
withNameScope "logistic_loss" $ do
|
||||
left <- render $ relu_logits - logits' * targets'
|
||||
left <- render $ relu_logits - logits `mul` targets
|
||||
right <- render $ log (1 + exp neg_abs_logits)
|
||||
withNameScope "sigmoid_add" $ render $ left `add` right
|
||||
|
|
|
@ -23,10 +23,9 @@ import Test.Framework (Test)
|
|||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.NN as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
|
||||
-- | These tests are ported from:
|
||||
--
|
||||
|
@ -60,12 +59,11 @@ defInputs = Inputs {
|
|||
testLogisticOutput :: Test
|
||||
testLogisticOutput = testCase "testLogisticOutput" $ do
|
||||
let inputs = defInputs
|
||||
vLogits = TF.vector $ logits inputs
|
||||
vTargets = TF.vector $ targets inputs
|
||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||
|
||||
r <- run tfLoss
|
||||
r <- run $ do
|
||||
vLogits <- TF.render $ TF.vector $ logits inputs
|
||||
vTargets <- TF.render $ TF.vector $ targets inputs
|
||||
TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||
assertAllClose r ourLoss
|
||||
|
||||
|
||||
|
@ -74,23 +72,22 @@ testLogisticOutputMultipleDim =
|
|||
testCase "testLogisticOutputMultipleDim" $ do
|
||||
let inputs = defInputs
|
||||
shape = [2, 2, 2]
|
||||
vLogits = TF.constant shape (logits inputs)
|
||||
vTargets = TF.constant shape (targets inputs)
|
||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||
|
||||
r <- run tfLoss
|
||||
r <- run $ do
|
||||
vLogits <- TF.render $ TF.constant shape (logits inputs)
|
||||
vTargets <- TF.render $ TF.constant shape (targets inputs)
|
||||
TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
|
||||
assertAllClose r ourLoss
|
||||
|
||||
|
||||
testGradientAtZero :: Test
|
||||
testGradientAtZero = testCase "testGradientAtZero" $ do
|
||||
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
||||
vLogits = TF.vector $ logits inputs
|
||||
vTargets = TF.vector $ targets inputs
|
||||
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
|
||||
r <- run $ do
|
||||
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
|
||||
vTargets <- TF.render $ TF.vector $ targets inputs
|
||||
vLogits <- TF.render $ TF.vector $ logits inputs
|
||||
let tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
|
||||
|
||||
l <- tfLoss
|
||||
TF.gradients l [vLogits]
|
||||
|
||||
|
|
|
@ -231,21 +231,24 @@ renderHaskellAttrName :: Attr a -> Doc
|
|||
renderHaskellAttrName = renderHaskellName . attrName
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
functionBody pOp
|
||||
| parsedOpIsMonadic pOp
|
||||
= "build $ do"
|
||||
</> indent indentation (bindOpInputsVar
|
||||
</> "buildOp" <+> outputListsSizes <+> opDef)
|
||||
| otherwise
|
||||
= "pureOp" <+> outputListsSizes <+> "$ do"
|
||||
</> indent indentation (bindOpInputsVar </> "return" <+> opDef)
|
||||
where
|
||||
maybeLift
|
||||
| parsedOpIsMonadic pOp = "build $"
|
||||
| otherwise = ""
|
||||
buildFunction
|
||||
| null outputListsSizes = "buildOp"
|
||||
| otherwise = "buildListOp" <+>
|
||||
brackets (commasep $
|
||||
map renderHaskellName outputListsSizes)
|
||||
outputListsSizes = [ a
|
||||
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
||||
<- parsedOutputs pOp]
|
||||
buildOpParts =
|
||||
outputListsSizes = brackets $ commasep
|
||||
[ renderHaskellName a
|
||||
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
|
||||
<- parsedOutputs pOp
|
||||
]
|
||||
opInputsVar = "op'inputs"
|
||||
bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence"
|
||||
<+> brackets (commasep $ map (\a -> "buildInputs" <+> a) tensorArgs)
|
||||
opDef = parens $ hang 0 $ stack $
|
||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||
-- Renders type parameter arguments.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
|
||||
|
@ -259,10 +262,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
|
|||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
["& op'options"]
|
||||
|
||||
|
||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
["& op'options & opInputs .~" <+> opInputsVar]
|
||||
tensorArgs = renderTensorArg <$> parsedInputs pOp
|
||||
renderTensorArg = renderHaskellName . parsedArgName
|
||||
inferredTypeExpr a
|
||||
| typeParamIsList $ attrInfo a
|
||||
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
||||
|
@ -296,7 +298,7 @@ typeSig pre pOp = constraints
|
|||
| null classConstraints = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||
|
@ -327,7 +329,7 @@ typeSig pre pOp = constraints
|
|||
wrapOutput o
|
||||
| parsedOpIsMonadic pOp = "m'" <+> parens o
|
||||
| otherwise = o
|
||||
|
||||
|
||||
-- | Render an op input or output.
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
|
||||
tensorArg :: ParsedArg -> Doc
|
||||
|
@ -336,12 +338,13 @@ tensorArg p = case parsedArgCase p of
|
|||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
||||
-> "TensorList" <+> kind k <+> renderHaskellName t
|
||||
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
|
||||
where
|
||||
kind k = case k of
|
||||
ArgTensorRef -> "Ref"
|
||||
ArgTensorValue -> "Value"
|
||||
ArgTensorEither v' -> strictText v'
|
||||
ArgTensorBuild -> "Build"
|
||||
ArgSomeTensor v -> strictText v
|
||||
tensorType t k = let
|
||||
a = case t of
|
||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||
|
@ -350,7 +353,7 @@ tensorArg p = case parsedArgCase p of
|
|||
|
||||
attrComment :: Attr a -> Doc
|
||||
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||
|
||||
|
||||
argComment :: ParsedArg -> Doc
|
||||
argComment a = argComment' (parsedArgName a) (parsedArgDescription a)
|
||||
|
||||
|
@ -364,7 +367,7 @@ bold n = "__" <> n <> "__"
|
|||
-- | Comment for the outputs of an op.
|
||||
-- For example:
|
||||
-- -- ^ (__output1__, __output2__)
|
||||
-- --
|
||||
-- --
|
||||
-- -- * __output1__: description1
|
||||
-- --
|
||||
-- -- * __output2__: description2
|
||||
|
|
|
@ -141,7 +141,8 @@ data ArgType
|
|||
data ArgKind
|
||||
= ArgTensorRef -- Tensor Ref a
|
||||
| ArgTensorValue -- Tensor Value a
|
||||
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
|
||||
| ArgTensorBuild -- Tensor Build a
|
||||
| ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'.
|
||||
deriving (Eq)
|
||||
|
||||
isRefCase :: ParsedArgCase -> Bool
|
||||
|
@ -219,15 +220,17 @@ parseOp o = ParsedOp
|
|||
{ parsedOpName = makeName $ o ^. name
|
||||
, parsedOpSummary = o ^. summary
|
||||
, parsedOpDescription = o ^. description
|
||||
, parsedOpIsMonadic = o ^. isStateful
|
||||
|| any (isRefCase . parsedArgCase) parsedInputs
|
||||
, ..
|
||||
}
|
||||
where
|
||||
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
|
||||
(o ^. inputArg) tensorKindParams
|
||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
parsedOpIsMonadic = o ^. isStateful
|
||||
|| any (isRefCase . parsedArgCase) parsedInputs
|
||||
|| null (o ^. outputArg)
|
||||
parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a))
|
||||
tensorKindParams (o ^. inputArg)
|
||||
tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a))
|
||||
(o ^. outputArg)
|
||||
-- Integer attributes that can be inferred from the size of at least one
|
||||
-- input list.
|
||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||
|
@ -246,15 +249,16 @@ parseOp o = ParsedOp
|
|||
$ o ^. attr
|
||||
|
||||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
|
||||
inputTensorKind a v
|
||||
inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind
|
||||
inputTensorKind v a
|
||||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorEither v
|
||||
| otherwise = ArgSomeTensor v
|
||||
|
||||
outputTensorKind :: OpDef'ArgDef -> ArgKind
|
||||
outputTensorKind a
|
||||
outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind
|
||||
outputTensorKind isMonadic a
|
||||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorValue
|
||||
| isMonadic = ArgTensorValue
|
||||
| otherwise = ArgTensorBuild
|
||||
|
||||
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
getExplicitInputAttr o implicitAttrs a
|
||||
|
|
|
@ -24,9 +24,9 @@ module TensorFlow.EmbeddingOps where
|
|||
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Int (Int32, Int64)
|
||||
import TensorFlow.Build (MonadBuild, colocateWith, render)
|
||||
import TensorFlow.Build (MonadBuild)
|
||||
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
|
@ -44,17 +44,18 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
|||
--
|
||||
-- The results of the lookup are concatenated into a dense
|
||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup :: forall a b v m .
|
||||
embeddingLookup :: forall a b v1 v2 m .
|
||||
( MonadBuild m
|
||||
, Rendered v1
|
||||
, TensorType a
|
||||
, OneOf '[Int64, Int32] b
|
||||
, Num b
|
||||
)
|
||||
=> [Tensor v a]
|
||||
=> [Tensor v1 a]
|
||||
-- ^ A list of tensors which can be concatenated along
|
||||
-- dimension 0. Each `Tensor` must be appropriately
|
||||
-- sized for `mod` partition strategy.
|
||||
-> Tensor Value b
|
||||
-> Tensor v2 b
|
||||
-- ^ A `Tensor` with type `int32` or `int64`
|
||||
-- containing the ids to be looked up in `params`.
|
||||
-- The ids are required to have fewer than 2^31
|
||||
|
|
|
@ -31,6 +31,7 @@ import Data.ByteString (ByteString)
|
|||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.Foldable (foldlM)
|
||||
import Data.List (foldl', sortBy)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
|
||||
|
@ -39,7 +40,7 @@ import Data.ProtoLens.TextFormat (showMessage)
|
|||
import Data.Set (Set)
|
||||
import Data.Text (Text)
|
||||
import Data.Tuple (swap)
|
||||
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
|
||||
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
|
||||
import Lens.Family2.State.Strict (uses)
|
||||
import Lens.Family2.Stock (at, intAt)
|
||||
import Lens.Family2.Unchecked (lens, iso)
|
||||
|
@ -59,11 +60,10 @@ import TensorFlow.Build
|
|||
( MonadBuild
|
||||
, Build
|
||||
, build
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderedNodeDefs
|
||||
, opDef
|
||||
, opAttr
|
||||
, opInputs
|
||||
)
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.Ops
|
||||
|
@ -86,16 +86,19 @@ import TensorFlow.Ops
|
|||
)
|
||||
import TensorFlow.Output
|
||||
( NodeName(..)
|
||||
, Op (Rendered)
|
||||
, Output(..)
|
||||
, OutputIx(..)
|
||||
, outputIndex
|
||||
)
|
||||
import TensorFlow.Tensor
|
||||
( Tensor(..)
|
||||
, TensorKind (ValueKind)
|
||||
, Value
|
||||
, tensorOutput
|
||||
, render
|
||||
, expr
|
||||
, Rendered
|
||||
, tensorNodeName
|
||||
, renderedOutput
|
||||
, renderValue
|
||||
)
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
|
@ -114,10 +117,7 @@ type GradientCompatible a =
|
|||
|
||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||
gradients :: forall a v1 v2 m . (MonadBuild m
|
||||
, Num (Tensor v1 a)
|
||||
-- TODO(gnezdo): remove indirect constraint.
|
||||
-- It's a wart inherited from Num instance.
|
||||
, v1 ~ Value
|
||||
, Rendered v2
|
||||
, GradientCompatible a
|
||||
)
|
||||
=> Tensor v1 a -- ^ The output of the graph.
|
||||
|
@ -150,27 +150,31 @@ gradients y xs = build $ do
|
|||
--
|
||||
-- 4. Lookup the recorded gradient for each x in xs.
|
||||
|
||||
yName <- renderNodeName y
|
||||
y' <- renderValue y
|
||||
let yName = tensorNodeName y'
|
||||
yOne <- render $ fill (shape y') (scalar 1)
|
||||
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
|
||||
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
|
||||
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
|
||||
. flip Map.lookup
|
||||
let (gr, nodeMap) = createGraph yName nodeDefLookup
|
||||
-- Set gradient of y to one.
|
||||
-- TODO: nicer
|
||||
let initPending :: Map.Map FGL.Node (PendingGradients a)
|
||||
initPending = Map.empty & at (nodeMap Map.! yName)
|
||||
= Map.empty & (at (nodeMap Map.! yName)
|
||||
. nonEmpty
|
||||
. outputIxAt (y ^. tensorOutput . outputIndex)
|
||||
. outputIxAt (outputIndex $ renderedOutput y')
|
||||
. nonEmpty
|
||||
.~ [fill (shape y) (scalar 1)]
|
||||
.~ [yOne]
|
||||
)
|
||||
-- Calculate the gradients of y w.r.t. each node in the graph.
|
||||
gradientMap <- graphGrads gr initPending
|
||||
-- Lookup the gradients for each x.
|
||||
forM xs $ \x -> do
|
||||
xName <- renderNodeName x
|
||||
render $ fromMaybe (zerosLike x) $ do
|
||||
forM xs $ \x ->
|
||||
let xName = tensorNodeName x
|
||||
in maybe (render $ zerosLike x) return $ do
|
||||
n <- nodeMap ^. at xName
|
||||
let i = x ^. tensorOutput . outputIndex
|
||||
let i = outputIndex $ renderedOutput x
|
||||
gradientMap ^. at n . nonEmpty . outputIxAt i
|
||||
|
||||
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
|
||||
|
@ -182,6 +186,7 @@ outputIxAt = intAt . unOutputIx
|
|||
type PendingGradients a = IntMap.IntMap [Tensor Value a]
|
||||
|
||||
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
|
||||
-- TODO: precache the rendering?
|
||||
type Gradients a = IntMap.IntMap (Tensor Value a)
|
||||
|
||||
-- | Graph of TensorFlow operations.
|
||||
|
@ -229,42 +234,43 @@ non a = anon a (a==)
|
|||
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
|
||||
nonEmpty = anon mempty null
|
||||
|
||||
-- TODO: strictness (e.g., foldlM')
|
||||
|
||||
-- | Calculate the gradients for every node in a graph.
|
||||
graphGrads :: forall a. GradientCompatible a
|
||||
=> Graph
|
||||
-> Map FGL.Node (PendingGradients a)
|
||||
-- ^ Initial gradients (usually just 1 for the node of interest).
|
||||
-> Build (Map FGL.Node (Gradients a))
|
||||
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
|
||||
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
|
||||
where
|
||||
initState = GradientsState initPending Map.empty
|
||||
-- Reverse topological sort.
|
||||
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
|
||||
-- avoid calculating gradients that won't be used.
|
||||
nodeOrder = FGL.topsort $ FGL.grev gr
|
||||
go state node =
|
||||
go :: GradientsState a -> Int -> Build (GradientsState a)
|
||||
go state node = do
|
||||
-- Aggregate the accumulated gradients for this node.
|
||||
let outputGrads =
|
||||
outputGrads <-
|
||||
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
|
||||
in if null outputGrads
|
||||
then state
|
||||
else
|
||||
if null outputGrads
|
||||
then pure state
|
||||
else do
|
||||
let ctx = FGL.context gr node
|
||||
inputGrads <- calculateInputGrads ctx outputGrads gr
|
||||
-- Calculate the gradients for each of the node's inputs.
|
||||
let nextState = state & gradientsResult %~ Map.insert node outputGrads
|
||||
ctx = FGL.context gr node
|
||||
in updatePendingGradients
|
||||
ctx
|
||||
(calculateInputGrads ctx outputGrads gr)
|
||||
nextState
|
||||
pure $ updatePendingGradients ctx inputGrads nextState
|
||||
|
||||
-- | Reduce accumulated gradients for each output to one Tensor.
|
||||
sumPendingGradient :: GradientCompatible a
|
||||
=> PendingGradients a -> Gradients a
|
||||
sumPendingGradient = IntMap.mapMaybe f
|
||||
=> PendingGradients a -> Build (Gradients a)
|
||||
sumPendingGradient = sequence . IntMap.mapMaybe f
|
||||
where
|
||||
f [] = Nothing
|
||||
f [x] = Just x
|
||||
f xs = Just (addN xs)
|
||||
f [x] = Just (pure x)
|
||||
f xs = Just (render $ addN xs)
|
||||
|
||||
|
||||
-- | Calculate the gradients of a node's input tensors.
|
||||
|
@ -274,18 +280,18 @@ calculateInputGrads :: forall a. GradientCompatible a
|
|||
=> FGL.Context NodeDef EdgeLabel
|
||||
-> Gradients a -- ^ Output gradients of the node.
|
||||
-> Graph
|
||||
-> [Maybe (Tensor Value a)]
|
||||
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
|
||||
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
|
||||
-> Build [Maybe (Tensor Value a)]
|
||||
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
|
||||
fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
|
||||
outputGrads
|
||||
traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
|
||||
where
|
||||
fullOutGrads =
|
||||
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
|
||||
-- Create a tensor from an edge (technically an Output, but it seems less
|
||||
-- confusing to refer to it as a tensor here).
|
||||
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
|
||||
edgeToTensor ((i, _), n) =
|
||||
case FGL.lab gr n of
|
||||
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
|
||||
Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
|
||||
Nothing -> error $ "calculateInputGrads: missing input node for "
|
||||
++ Text.unpack (nodeDef ^. name)
|
||||
-- Input tensors, sorted by input index.
|
||||
|
@ -294,11 +300,11 @@ calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
|
|||
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
|
||||
fullOutputGrads :: (TensorType a, Num a)
|
||||
=> OutputIx -- ^ Number of outputs.
|
||||
-> Op
|
||||
-> NodeName
|
||||
-> Gradients a
|
||||
-> [Tensor Value a]
|
||||
-> Build [Tensor Value a]
|
||||
fullOutputGrads n o gs =
|
||||
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
|
||||
mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
|
||||
where
|
||||
-- A tensor of zeros with the same shape as the i'th output.
|
||||
zero i = zerosLike $ toT (Output i o)
|
||||
|
@ -397,19 +403,19 @@ type GradientFunc a = NodeDef
|
|||
-- ^ Input tensors.
|
||||
-> [Tensor Value a]
|
||||
-- ^ Gradient of y w.r.t. each output tensor.
|
||||
-> [Maybe (Tensor Value a)]
|
||||
-> [Maybe (Tensor Build a)]
|
||||
-- ^ Gradient of y w.r.t. each input tensor.
|
||||
|
||||
|
||||
-- TODO(fmayle): Assert the type is correct.
|
||||
-- | Create a Tensor from an Output.
|
||||
toT :: Output -> Tensor Value a
|
||||
toT = Tensor ValueKind
|
||||
toT :: Output -> Tensor Build a
|
||||
toT = Tensor . pure
|
||||
|
||||
|
||||
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
||||
-- simple slicing operations.
|
||||
flatSlice :: forall v1 t . (TensorType t)
|
||||
flatSlice :: forall v1 t . TensorType t
|
||||
=> Tensor v1 t -- ^ __input__
|
||||
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||
-- 'input' to slice from.
|
||||
|
@ -417,9 +423,12 @@ flatSlice :: forall v1 t . (TensorType t)
|
|||
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
||||
-- are included in the slice (i.e. this is equivalent to setting
|
||||
-- size = input.dim_size(0) - begin).
|
||||
-> Tensor Value t -- ^ __output__
|
||||
-> Tensor Build t -- ^ __output__
|
||||
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
||||
|
||||
nodeDefName :: NodeDef -> NodeName
|
||||
nodeDefName = NodeName . view name
|
||||
|
||||
|
||||
-- | The gradient function for an op type.
|
||||
--
|
||||
|
@ -427,8 +436,8 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
|
|||
-- third_party/tensorflow/python/ops/*_grad.py
|
||||
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
|
||||
|
||||
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
|
||||
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
|
||||
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
|
||||
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
|
||||
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
|
||||
|
||||
opGrad "Square" _ [toT -> x] [dz] =
|
||||
|
@ -436,7 +445,7 @@ opGrad "Square" _ [toT -> x] [dz] =
|
|||
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
|
||||
-- (for performance reasons?). Will need to put these functions in the Build
|
||||
-- monad to replicate that.
|
||||
[Just $ dz * (2 * x)]
|
||||
[Just $ dz `CoreOps.mul` (2 * x)]
|
||||
|
||||
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
||||
-- TODO(fmayle): The python version uses a better performance implementation
|
||||
|
@ -448,20 +457,20 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
|||
]
|
||||
where
|
||||
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||
denseShape = shape (x :: Tensor Value a)
|
||||
denseShape = shape (x :: Tensor Build a)
|
||||
numRows = scalarize $ flatSlice denseShape 0 1
|
||||
valuesShape = CoreOps.concat 0 [ allDimensions
|
||||
, flatSlice denseShape 1 (-1)
|
||||
]
|
||||
values = reshape dz valuesShape
|
||||
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||
indices' = reshape indices allDimensions :: Tensor Value Int32
|
||||
indices' = reshape indices allDimensions :: Tensor Build Int32
|
||||
|
||||
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
|
||||
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||
sx = shape (x :: Tensor Build a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
|
||||
y = CoreOps.max x indices
|
||||
y' = reshape y outputShapeKeptDims
|
||||
dz' = reshape dz outputShapeKeptDims
|
||||
|
@ -475,8 +484,8 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
|
|||
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
|
||||
where
|
||||
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
|
||||
sx = shape (x :: Tensor Value a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
|
||||
sx = shape (x :: Tensor Build a)
|
||||
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
|
||||
tileScaling = safeShapeDiv sx outputShapeKeptDims
|
||||
grad = reshape dz outputShapeKeptDims
|
||||
|
||||
|
@ -484,8 +493,8 @@ opGrad "Mean" u v@[toT -> x, _] w =
|
|||
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
|
||||
where
|
||||
[Just dz, Nothing] = opGrad "Sum" u v w
|
||||
inputShape = shape (x :: Tensor Value a)
|
||||
outputShape = shape (dz :: Tensor Value a)
|
||||
inputShape = shape (x :: Tensor Build a)
|
||||
outputShape = shape (dz :: Tensor Build a)
|
||||
-- TODO(fmayle): Add fast path when shape is known.
|
||||
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
|
||||
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
|
||||
|
@ -495,8 +504,8 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] =
|
|||
[ Just $ reshape (sum dz rx) sx
|
||||
, Just $ reshape (sum dz ry) sy ]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
sy = shape (y :: Tensor Value a)
|
||||
sx = shape (x :: Tensor Build a)
|
||||
sy = shape (y :: Tensor Build a)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
opGrad "Sub" u v w =
|
||||
|
@ -510,22 +519,24 @@ opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
|
|||
|
||||
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
[ Just $ reshape (sum (dz * y) rx) sx
|
||||
, Just $ reshape (sum (x * dz) ry) sy ]
|
||||
[ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx
|
||||
, Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
sy = shape (y :: Tensor Value a)
|
||||
sx = shape (x :: Tensor Build a)
|
||||
sy = shape (y :: Tensor Build a)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
|
||||
-- TODO(fmayle): Handle complex numbers.
|
||||
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
|
||||
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
|
||||
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
|
||||
, Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y)))
|
||||
ry)
|
||||
sy
|
||||
]
|
||||
where
|
||||
sx = shape (x :: Tensor Value a)
|
||||
sy = shape (y :: Tensor Value a)
|
||||
sx = shape (x :: Tensor Build a)
|
||||
sy = shape (y :: Tensor Build a)
|
||||
(rx, ry) = broadcastGradientArgs sx sy
|
||||
|
||||
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
|
@ -549,7 +560,7 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
|
||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||
[ Just $ CoreOps.transpose dz
|
||||
(CoreOps.invertPermutation p :: Tensor Value Int32)
|
||||
(CoreOps.invertPermutation p :: Tensor Build Int32)
|
||||
, Nothing
|
||||
]
|
||||
|
||||
|
@ -582,28 +593,28 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
|||
x output dz
|
||||
]
|
||||
where
|
||||
output :: Tensor Value a
|
||||
output = toT $ Output 0 (Rendered nodeDef)
|
||||
output :: Tensor Build a
|
||||
output = toT $ Output 0 (nodeDefName nodeDef)
|
||||
ksize = lookupAttr nodeDef "ksize" :: [Int64]
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
padding = lookupAttr nodeDef "padding" :: ByteString
|
||||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "Reshape" _ [toT -> x, _] [dz] =
|
||||
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
|
||||
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
|
||||
|
||||
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
|
||||
opGrad "TruncatedNormal" _ _ _ = [Nothing]
|
||||
|
||||
opGrad "RefIdentity" _ _ [dz] = [Just dz]
|
||||
opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz]
|
||||
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
|
||||
where
|
||||
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
|
||||
reverseCast =
|
||||
buildOp (opDef "Cast"
|
||||
pureOp [] $ pure (opDef "Cast"
|
||||
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
|
||||
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
|
||||
dz
|
||||
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)
|
||||
& opInputs .~ [renderedOutput dz])
|
||||
|
||||
opGrad "DynamicStitch" nodeDef inputs [dz] =
|
||||
replicate halfLen Nothing ++ valuesGrads
|
||||
|
@ -614,7 +625,7 @@ opGrad "DynamicStitch" nodeDef inputs [dz] =
|
|||
in if 2 * half == len
|
||||
then half
|
||||
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
|
||||
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
|
||||
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32)
|
||||
| idx <- take halfLen inputs
|
||||
]
|
||||
|
||||
|
@ -622,14 +633,14 @@ opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
|
|||
[ Just reconstructed, Nothing ]
|
||||
where
|
||||
reconstructed = CoreOps.reshape stitched
|
||||
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
|
||||
(CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32)
|
||||
stitched = CoreOps.dynamicStitch partitionedIndices dz
|
||||
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
|
||||
np = lookupAttr nodeDef "num_partitions" :: Int64
|
||||
originalIndices =
|
||||
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
|
||||
prefixShape = shapeInt32 indices
|
||||
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
|
||||
shapeInt32 t = CoreOps.shape t :: Tensor Build Int32
|
||||
|
||||
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
||||
[ Nothing
|
||||
|
@ -639,18 +650,18 @@ opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
|
|||
where zeros = CoreOps.zerosLike x
|
||||
|
||||
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
|
||||
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
|
||||
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ]
|
||||
-- TODO(gnezdo): Reuse the output instead of doing another exp,
|
||||
-- though, it is probably CSE'd away anyway.
|
||||
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
|
||||
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ]
|
||||
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
||||
[ Just $ CoreOps.unsortedSegmentSum
|
||||
(CoreOps.gather dz (t :: Tensor Value Int32))
|
||||
(y :: Tensor Value Int32) inputRows
|
||||
(CoreOps.gather dz (t :: Tensor Build Int32))
|
||||
(y :: Tensor Build Int32) inputRows
|
||||
, Nothing
|
||||
, Nothing
|
||||
]
|
||||
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
|
||||
where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1
|
||||
|
||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||
|
@ -710,13 +721,13 @@ numOutputs o =
|
|||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||
|
||||
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
|
||||
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
|
||||
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
|
||||
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
|
||||
|
||||
allDimensions :: Tensor Value Int32
|
||||
allDimensions :: Tensor Build Int32
|
||||
allDimensions = vector [-1 :: Int32]
|
||||
|
||||
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
|
||||
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
|
||||
|
||||
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1
|
||||
|
|
|
@ -166,7 +166,6 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
|
|||
import TensorFlow.Build
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Output (unNodeName)
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
|
@ -183,7 +182,7 @@ import qualified Prelude (abs)
|
|||
-- "1".
|
||||
instance ( TensorType a
|
||||
, Num a
|
||||
, v ~ Value
|
||||
, v ~ Build
|
||||
, OneOf '[ Double, Float, Int32, Int64
|
||||
, Complex Float, Complex Double] a) => Num (Tensor v a) where
|
||||
(+) = CoreOps.add
|
||||
|
@ -194,10 +193,10 @@ instance ( TensorType a
|
|||
signum = CoreOps.sign
|
||||
negate = CoreOps.neg
|
||||
|
||||
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
|
||||
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
|
||||
matTranspose = matTranspose' id
|
||||
|
||||
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
|
||||
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
|
||||
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
||||
|
||||
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||
|
@ -208,7 +207,7 @@ placeholder' :: forall m a . (MonadBuild m, TensorType a)
|
|||
placeholder' params pShape
|
||||
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
|
||||
-- and thus would be CSE'd.
|
||||
= build $ buildOp $ opDef "Placeholder"
|
||||
= build $ buildOp [] $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ pShape
|
||||
& params
|
||||
|
@ -216,11 +215,11 @@ placeholder' params pShape
|
|||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: (MonadBuild m, TensorType a)
|
||||
=> Tensor Value a -> m (Tensor Ref a)
|
||||
=> Tensor v a -> m (Tensor Ref a)
|
||||
initializedVariable = initializedVariable' id
|
||||
|
||||
initializedVariable' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
|
||||
=> OpParams -> Tensor v a -> m (Tensor Ref a)
|
||||
initializedVariable' params initializer = do
|
||||
v <- CoreOps.variable' params [] -- The shape is not known initially.
|
||||
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
|
||||
|
@ -240,17 +239,20 @@ zeroInitializedVariable'
|
|||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a m v . (MonadBuild m, TensorType a)
|
||||
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> [Tensor v a] -- ^ Tensors to save.
|
||||
-> m ControlNode
|
||||
save path xs = do
|
||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
|
||||
save path xs = build $ do
|
||||
let toByteStringTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput
|
||||
let names = fmap toByteStringTensor xs
|
||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||
let saveOp = buildOp $ opDef "Save"
|
||||
& opAttr "T" .~ types
|
||||
build $ saveOp (scalar path) (CoreOps.pack names) xs
|
||||
names' <- buildInputs $ CoreOps.pack names
|
||||
xs' <- buildInputs xs
|
||||
path' <- buildInputs $ scalar path
|
||||
buildOp [] $ opDef "Save"
|
||||
& opAttr "T" .~ types
|
||||
& opInputs .~ (path' ++ names' ++ xs')
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
--
|
||||
|
@ -261,20 +263,22 @@ restoreFromName :: forall a m . (MonadBuild m, TensorType a)
|
|||
-> ByteString -- ^ Tensor name override.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> m ControlNode
|
||||
restoreFromName path name x = do
|
||||
let restoreOp = buildOp $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
group =<< CoreOps.assign x
|
||||
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
restoreFromName path name x = build $ do
|
||||
path' <- buildInputs $ scalar path
|
||||
name' <- buildInputs $ scalar name
|
||||
restoreOp <- buildOp [] $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
& opInputs .~ (path' ++ name')
|
||||
group =<< CoreOps.assign x (restoreOp :: Tensor Value a)
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
restore :: forall a m . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> m ControlNode
|
||||
restore path x = do
|
||||
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
|
||||
restoreFromName path name x
|
||||
restore path x = restoreFromName path name x
|
||||
where
|
||||
name = encodeUtf8 $ encodeOutput $ renderedOutput x
|
||||
|
||||
-- | Create a constant tensor.
|
||||
--
|
||||
|
@ -283,10 +287,10 @@ restore path x = do
|
|||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
constant :: TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant :: TensorType a => Shape -> [a] -> Tensor Build a
|
||||
constant = constant' id
|
||||
|
||||
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
|
||||
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
|
||||
constant' params (Shape cShape) values
|
||||
| invalidLength = error invalidLengthMsg
|
||||
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
|
||||
|
@ -305,24 +309,24 @@ constant' params (Shape cShape) values
|
|||
-- | Reshape a N-D tensor down to a scalar.
|
||||
--
|
||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||
scalarize :: TensorType a => Tensor v a -> Tensor Build a
|
||||
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||
where
|
||||
scalarShape = [] :: [Int32]
|
||||
|
||||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Value a
|
||||
vector :: TensorType a => [a] -> Tensor Build a
|
||||
vector = vector' id
|
||||
|
||||
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
|
||||
vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
|
||||
vector' params xs = constant' params [fromIntegral $ length xs] xs
|
||||
|
||||
-- | Create a constant scalar.
|
||||
scalar :: TensorType a => a -> Tensor Value a
|
||||
scalar :: TensorType a => a -> Tensor Build a
|
||||
scalar = scalar' id
|
||||
|
||||
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
|
||||
scalar' :: TensorType a => OpParams -> a -> Tensor Build a
|
||||
scalar' params x = constant' params [] [x]
|
||||
|
||||
-- | Random tensor from the unit normal distribution with bounded values.
|
||||
|
@ -338,28 +342,28 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
|||
-> m (Tensor Value a)
|
||||
truncatedNormal' = CoreOps.truncatedNormal'
|
||||
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
|
||||
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
||||
|
||||
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||
shape :: TensorType t => Tensor v t -> Tensor Build Int32
|
||||
shape = CoreOps.shape
|
||||
|
||||
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
|
||||
shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
|
||||
shape' = CoreOps.shape'
|
||||
|
||||
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
|
||||
expandDims = CoreOps.expandDims
|
||||
|
||||
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
|
||||
expandDims' = CoreOps.expandDims'
|
||||
|
||||
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
|
||||
reducedShape inputShape axes =
|
||||
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
|
||||
axes32 = toInt32 axes -- [1, 2]
|
||||
toInt32 x = CoreOps.cast x :: Tensor Value Int32
|
||||
toInt32 x = CoreOps.cast x :: Tensor Build Int32
|
||||
inputRank = CoreOps.size inputShape32 -- 4
|
||||
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
|
||||
axesShape = shape axesMod -- [2]
|
||||
|
|
|
@ -79,6 +79,7 @@ Test-Suite EmbeddingOpsTest
|
|||
, test-framework
|
||||
, test-framework-hunit
|
||||
, test-framework-quickcheck2
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
Test-Suite ArrayOpsTest
|
||||
|
|
|
@ -24,9 +24,7 @@ import Test.HUnit ((@=?))
|
|||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
|
||||
-- | Test split and concat are inverses.
|
||||
|
@ -44,7 +42,7 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
|
|||
testShapeN :: Test
|
||||
testShapeN = testCase "testShapeN" $ TF.runSession $ do
|
||||
let shapes = map TF.Shape [[1],[2,3]]
|
||||
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
|
||||
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Build Float]
|
||||
result <- TF.run $ CoreOps.shapeN tensors
|
||||
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])
|
||||
|
||||
|
|
|
@ -34,9 +34,7 @@ import TensorFlow.Build
|
|||
, asGraphDef
|
||||
, evalBuildT
|
||||
, flushNodeBuffer
|
||||
, render
|
||||
, withDevice
|
||||
, colocateWith
|
||||
, withNameScope
|
||||
, opName
|
||||
)
|
||||
|
@ -50,7 +48,13 @@ import TensorFlow.Ops
|
|||
, variable'
|
||||
)
|
||||
import TensorFlow.Output (Device(..))
|
||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||
import TensorFlow.Tensor
|
||||
( colocateWith
|
||||
, render
|
||||
, Tensor
|
||||
, Value
|
||||
, Ref
|
||||
)
|
||||
import TensorFlow.Session
|
||||
( run
|
||||
, runSession
|
||||
|
@ -65,8 +69,7 @@ import qualified Data.Vector as V
|
|||
-- | Test 'opName' behavior.
|
||||
testOpName :: Test
|
||||
testOpName = testCase "testOpName" $ do
|
||||
let graph = variable' (opName .~ "foo") []
|
||||
>>= render :: Build (Tensor Ref Float)
|
||||
let graph = variable' (opName .~ "foo") [] :: Build (Tensor Ref Float)
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
|
@ -114,7 +117,6 @@ testNamedAndScoped :: Test
|
|||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||
let graph :: Build (Tensor Ref Float)
|
||||
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
|
||||
>>= render
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
|
|
|
@ -26,16 +26,14 @@ import Test.QuickCheck.Monadic (monadicIO, run)
|
|||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
|
||||
-- DynamicSplit is undone with DynamicStitch to get the original input
|
||||
-- back.
|
||||
testDynamicPartitionStitchInverse :: forall a.
|
||||
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
|
||||
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
||||
let splitParts :: [TF.Tensor TF.Value a] =
|
||||
let splitParts :: [TF.Tensor TF.Build a] =
|
||||
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
|
||||
partTensor = TF.vector partitions
|
||||
restitchIndices = CoreOps.dynamicPartition numParts
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
-- | Tests for EmbeddingOps.
|
||||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.List (genericLength)
|
||||
import Google.Test (googleTest)
|
||||
|
@ -48,16 +49,15 @@ testEmbeddingLookupHasRightShapeWithPartition =
|
|||
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
|
||||
let embedding1 = [1, 1, 1 :: Int32]
|
||||
let embedding2 = [0, 0, 0 :: Int32]
|
||||
let embedding = [ TF.constant embShape embedding1
|
||||
, TF.constant embShape embedding2
|
||||
]
|
||||
|
||||
let idValues = [0, 1 :: Int32]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup embedding ids
|
||||
|
||||
(values, shape) <- TF.runSession $ do
|
||||
vs <- op
|
||||
embedding <- mapM TF.render [ TF.constant embShape embedding1
|
||||
, TF.constant embShape embedding2
|
||||
]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
vs <- embeddingLookup embedding ids
|
||||
TF.run (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
|
@ -77,13 +77,12 @@ testEmbeddingLookupHasRightShape =
|
|||
, 0, 0, 0 :: Int32
|
||||
]
|
||||
|
||||
let embedding = TF.constant embShape embeddingInit
|
||||
let idValues = [0, 1 :: Int32]
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup [embedding] ids
|
||||
|
||||
(values, shape) <- TF.runSession $ do
|
||||
vs <- op
|
||||
embedding <- TF.render $ TF.constant embShape embeddingInit
|
||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
vs <- embeddingLookup [embedding] ids
|
||||
TF.run (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
|
@ -92,7 +91,6 @@ testEmbeddingLookupHasRightShape =
|
|||
-- "[0, 1]" should pull out the resulting vector.
|
||||
values @=? V.fromList [1, 1, 1, 0, 0, 0]
|
||||
|
||||
|
||||
-- | Check that we can calculate gradients w.r.t embeddings.
|
||||
testEmbeddingLookupGradients :: Test
|
||||
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||
|
@ -108,10 +106,10 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
|
||||
x <- TF.placeholder (TF.Shape [2])
|
||||
embedding <- TF.initializedVariable
|
||||
=<< TF.render (TF.constant embShape embeddingInit)
|
||||
(TF.constant embShape embeddingInit)
|
||||
|
||||
op <- embeddingLookup [embedding] ids
|
||||
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
||||
let twoNorm = CoreOps.square $ TF.abs (op `TF.sub` x)
|
||||
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||
|
||||
grad <- fmap head (TF.gradients loss [embedding])
|
||||
|
@ -131,23 +129,21 @@ testEmbeddingLookupUndoesSplit
|
|||
(LookupExample numParts
|
||||
shape@(TF.Shape (firstDim : restDims))
|
||||
values
|
||||
indices) =
|
||||
let modShardedValues :: [TF.Tensor TF.Value a] =
|
||||
CoreOps.dynamicPartition numParts shapedValues cyclicCounter
|
||||
cyclicCounter :: TF.Tensor TF.Value Int32 =
|
||||
indices) = monadicIO $ run $ TF.runSession $ do
|
||||
let shapedValues = TF.constant shape values
|
||||
indicesVector <- TF.render $ TF.vector indices
|
||||
let directs = CoreOps.gather shapedValues indicesVector
|
||||
let cyclicCounter :: TF.Tensor TF.Build Int32 =
|
||||
TF.vector [0..fromIntegral firstDim-1]
|
||||
`CoreOps.mod` fromIntegral numParts
|
||||
indicesVector = TF.vector indices
|
||||
directs = CoreOps.gather shapedValues indicesVector
|
||||
shapedValues = TF.constant shape values
|
||||
in monadicIO $ run $ do
|
||||
(shapeOut, got, want :: V.Vector a) <-
|
||||
TF.runSession $ TF.run =<< do
|
||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||
shapeOut @=? V.fromList (genericLength indices : restDims)
|
||||
got @=? want
|
||||
modShardedValues :: [TF.Tensor TF.Value a] <-
|
||||
mapM TF.render $ CoreOps.dynamicPartition numParts shapedValues cyclicCounter
|
||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||
(shapeOut, got, want :: V.Vector a) <-
|
||||
TF.run (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||
liftIO $ shapeOut @=? V.fromList (genericLength indices : restDims)
|
||||
liftIO $ got @=? want
|
||||
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
|
||||
|
||||
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.
|
||||
|
|
|
@ -36,10 +36,11 @@ import Proto.Tensorflow.Core.Framework.NodeDef (op)
|
|||
|
||||
testGradientSimple :: Test
|
||||
testGradientSimple = testCase "testGradientSimple" $ do
|
||||
let x = TF.scalar (3 :: Float)
|
||||
b = TF.scalar (4 :: Float)
|
||||
y = x*x + b
|
||||
grads = TF.gradients y [x, b]
|
||||
let grads = do
|
||||
x <- TF.render $ TF.scalar (3 :: Float)
|
||||
b <- TF.render $ TF.scalar (4 :: Float)
|
||||
let y = x `TF.mul` x `TF.add` b
|
||||
TF.gradients y [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||
6 @=? TF.unScalar dx
|
||||
|
@ -88,9 +89,10 @@ testGradientSimple = testCase "testGradientSimple" $ do
|
|||
|
||||
testGradientDisconnected :: Test
|
||||
testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
||||
let x = TF.scalar (3 :: Float)
|
||||
b = TF.scalar (4 :: Float)
|
||||
grads = TF.gradients x [x, b]
|
||||
let grads = do
|
||||
x <- TF.render $ TF.scalar (3 :: Float)
|
||||
b <- TF.render $ TF.scalar (4 :: Float)
|
||||
TF.gradients x [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||
1 @=? TF.unScalar dx
|
||||
|
@ -118,7 +120,7 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
|||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
TF.gradients (x + y*3) [x, y] >>= TF.run
|
||||
TF.gradients (TF.expr x + TF.expr y * 3) [x, y] >>= TF.run
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. These asserts are extra.
|
||||
1 @=? TF.unScalar dx
|
||||
|
@ -142,8 +144,8 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
|||
testDiamond :: Test
|
||||
testDiamond = testCase "testDiamond" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let x = TF.vector [1]
|
||||
y = x*x
|
||||
x <- TF.render $ TF.vector [1]
|
||||
let y = x `TF.mul` x
|
||||
z = y*y
|
||||
TF.gradients z [x] >>= TF.run
|
||||
(4 :: Float) @=? TF.unScalar dx
|
||||
|
@ -152,8 +154,8 @@ testDiamond = testCase "testDiamond" $ do
|
|||
testMaxGradient :: Test
|
||||
testMaxGradient = testCase "testMaxGradient" $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
||||
x <- TF.render $ TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||
let y = TF.max x (0 :: TF.Tensor TF.Build Int32)
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||
|
||||
|
|
|
@ -56,8 +56,7 @@ testSaveRestore = testCase "testSaveRestore" $
|
|||
withSystemTempDirectory "" $ \dirPath -> do
|
||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||
var = TF.render =<<
|
||||
TF.zeroInitializedVariable' (TF.opName .~ "a")
|
||||
var = TF.zeroInitializedVariable' (TF.opName .~ "a")
|
||||
(TF.Shape [])
|
||||
TF.runSession $ do
|
||||
v <- var
|
||||
|
@ -76,7 +75,8 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
|||
p2 <- TF.placeholder []
|
||||
let enc :: Float -> TF.TensorData Float
|
||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
|
||||
$ p1 `TF.add` p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
-- | Test that regular tensors can also be used for feeds, as long as they each
|
||||
|
@ -90,7 +90,8 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
|
|||
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
|
||||
let enc :: Float -> TF.TensorData Float
|
||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
|
||||
$ p1 `TF.add` p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
main :: IO ()
|
||||
|
|
|
@ -38,7 +38,7 @@ fit xData yData = TF.runSession $ do
|
|||
return (w', b')
|
||||
|
||||
gradientDescent :: Float
|
||||
-> TF.Tensor TF.Value Float
|
||||
-> TF.Tensor TF.Build Float
|
||||
-> [TF.Tensor TF.Ref Float]
|
||||
-> TF.Session TF.ControlNode
|
||||
gradientDescent alpha loss params = do
|
||||
|
|
|
@ -53,9 +53,9 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
|||
let floatData = V.fromList [1..6 :: Float]
|
||||
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
||||
boolData = V.fromList [True, True, False, True, False, False]
|
||||
f <- TF.build $ TF.placeholder [2,3]
|
||||
s <- TF.build $ TF.placeholder [2,3]
|
||||
b <- TF.build $ TF.placeholder [2,3]
|
||||
f <- TF.placeholder [2,3]
|
||||
s <- TF.placeholder [2,3]
|
||||
b <- TF.placeholder [2,3]
|
||||
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
||||
, TF.feed b (TF.encodeTensorData [2,3] boolData)
|
||||
|
@ -63,7 +63,8 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
|||
-- Do something idempotent to the tensors to verify that tensorflow can
|
||||
-- handle the encoding. Originally this used `TF.identity`, but that
|
||||
-- wasn't enough to catch a bug in the encoding of Bool.
|
||||
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
|
||||
(f', s', b') <- TF.runWithFeeds feeds
|
||||
(f `TF.add` 0, TF.identity s, TF.select b b b)
|
||||
liftIO $ do
|
||||
floatData @=? f'
|
||||
stringData @=? s'
|
||||
|
|
|
@ -60,7 +60,7 @@ makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
|
|||
-- under the given name across multiple sessions.
|
||||
-> m (Queue as)
|
||||
makeQueue capacity sharedName = do
|
||||
q <- build $ buildOp (opDef "FIFOQueue"
|
||||
q <- build $ buildOp [] (opDef "FIFOQueue"
|
||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||
& opAttr "shared_name" .~ sharedName
|
||||
& opAttr "capacity" .~ capacity
|
||||
|
|
|
@ -13,9 +13,13 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
module TensorFlow.Build
|
||||
( -- * Graph node types
|
||||
ControlNode(..)
|
||||
|
@ -32,8 +36,6 @@ module TensorFlow.Build
|
|||
, opControlInputs
|
||||
-- * The Build monad
|
||||
, GraphState
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderedNodeDefs
|
||||
, BuildT
|
||||
, Build
|
||||
|
@ -46,27 +48,23 @@ module TensorFlow.Build
|
|||
, addGraphDef
|
||||
, flushInitializers
|
||||
, flushNodeBuffer
|
||||
, summaries
|
||||
-- * Creating and looking up Ops
|
||||
, getOrAddOp
|
||||
, addNewOp
|
||||
, renderOutput
|
||||
, encodeOutput
|
||||
, lookupNode
|
||||
-- * Modifying all nodes in a Build action
|
||||
, colocateWith
|
||||
, withStateLens
|
||||
, withDevice
|
||||
, withNameScope
|
||||
, withNodeDependencies
|
||||
-- * Internal Summary related bits.
|
||||
, addSummary
|
||||
, SummaryTensor
|
||||
, collectAllSummaries
|
||||
) where
|
||||
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (def)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import qualified Data.Map.Strict as Map
|
||||
|
@ -94,7 +92,6 @@ import Proto.Tensorflow.Core.Framework.NodeDef
|
|||
|
||||
import TensorFlow.Orphans ()
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
|
||||
newtype Unique = Unique Int
|
||||
deriving (Eq, Ord, Enum)
|
||||
|
@ -125,9 +122,6 @@ opDefWithName n t = OpDef
|
|||
, _opControlInputs = []
|
||||
}
|
||||
|
||||
-- | Synonym for the tensors that return serialized Summary proto.
|
||||
type SummaryTensor = Tensor Value ByteString
|
||||
|
||||
data GraphState = GraphState
|
||||
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
|
||||
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
|
||||
|
@ -148,8 +142,8 @@ data GraphState = GraphState
|
|||
, _initializationNodes :: [NodeName]
|
||||
-- ^ The nodes to run next time a TF.run is issued, typically
|
||||
-- variable initializers.
|
||||
, _summaries :: [SummaryTensor]
|
||||
-- ^ The tensors for summary
|
||||
, _summaries :: [Output]
|
||||
-- ^ The tensors for summary (ByteString type)
|
||||
}
|
||||
|
||||
-- | A node definition without its final name. Used as a key in the
|
||||
|
@ -191,7 +185,7 @@ defaultControlInputs = lens _defaultControlInputs
|
|||
initializationNodes :: Lens' GraphState [NodeName]
|
||||
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
|
||||
|
||||
summaries :: Lens' GraphState [SummaryTensor]
|
||||
summaries :: Lens' GraphState [Output]
|
||||
summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
|
@ -238,9 +232,7 @@ flushInitializers = do
|
|||
-- | Registers the given node to be executed before the next
|
||||
-- 'TensorFlow.Session.run'.
|
||||
addInitializer :: MonadBuild m => ControlNode -> m ()
|
||||
addInitializer (ControlNode o) = build $ do
|
||||
i <- getOrAddOp o
|
||||
initializationNodes %= (i:)
|
||||
addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
|
||||
|
||||
-- | Produce a GraphDef proto representation of the nodes that are rendered in
|
||||
-- the given 'Build' action.
|
||||
|
@ -255,30 +247,31 @@ addGraphDef g = build $ nodeBuffer <>= g ^. node
|
|||
|
||||
-- | Render the given op if it hasn't been rendered already, and return its
|
||||
-- name.
|
||||
getOrAddOp :: Op -> Build NodeName
|
||||
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
|
||||
|
||||
resolveOp :: Op -> Build NodeDef
|
||||
resolveOp (Rendered n) = return n
|
||||
resolveOp (Unrendered o) = do
|
||||
getOrAddOp :: OpDef -> Build NodeName
|
||||
getOrAddOp o = do
|
||||
pending <- getPendingNode o
|
||||
uses renderedNodes (Map.lookup pending) >>= \case
|
||||
Just n -> return n
|
||||
Just n -> return $ NodeName $ n ^. name
|
||||
Nothing -> addNewOpFromPending pending
|
||||
|
||||
lookupNode :: NodeName -> Build NodeDef
|
||||
lookupNode n = uses renderedNodeDefs (Map.lookup n) >>= \case
|
||||
Just n' -> return n'
|
||||
Nothing -> error $ "lookupNode: unknown node name " ++ show n
|
||||
|
||||
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
|
||||
-- which are not safe to dedup (e.g, "variable" and "assign").
|
||||
addNewOp :: OpDef -> Build NodeDef
|
||||
addNewOp :: OpDef -> Build NodeName
|
||||
addNewOp o = getPendingNode o >>= addNewOpFromPending
|
||||
|
||||
addNewOpFromPending :: PendingNode -> Build NodeDef
|
||||
addNewOpFromPending :: PendingNode -> Build NodeName
|
||||
addNewOpFromPending pending = do
|
||||
nodeName <- renderPendingNode pending
|
||||
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
|
||||
nodeBuffer %= (nodeDef :)
|
||||
renderedNodes %= Map.insert pending nodeDef
|
||||
renderedNodeDefs %= Map.insert nodeName nodeDef
|
||||
return nodeDef
|
||||
return nodeName
|
||||
|
||||
-- | Get the pending node corresponding to an OpDef, which may or may not have
|
||||
-- been rendered before. Implicitly renders all of this node's inputs.
|
||||
|
@ -287,20 +280,18 @@ getPendingNode o = do
|
|||
-- An empty string in the proto field means that no specific
|
||||
-- device is specified.
|
||||
dev <- maybe "" deviceName <$> use defaultDevice
|
||||
inputs <- mapM getInput (o ^. opInputs)
|
||||
scope <- use currentScope
|
||||
controls <- use defaultControlInputs
|
||||
let inputs = map encodeOutput (o ^. opInputs)
|
||||
let controlInputs
|
||||
= map getDep (o ^. opControlInputs ++ Set.toList controls)
|
||||
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
|
||||
return $ PendingNode scope (o ^. opName)
|
||||
$ def & op .~ (unOpType (o ^. opType) :: Text)
|
||||
& attr .~ _opAttrs o
|
||||
& input .~ (inputs ++ controlInputs)
|
||||
& device .~ dev
|
||||
where
|
||||
getInput (Output (OutputIx k) subOp)
|
||||
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
|
||||
getDep = ("^" <>) . unNodeName
|
||||
makeDep = ("^" <>) . unNodeName
|
||||
|
||||
-- | Pick a name for a pending node. If it has an explicit name, just use that;
|
||||
-- if the name is implicit, assign a new unique name based on the op type.
|
||||
|
@ -317,12 +308,11 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
|
|||
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
||||
|
||||
|
||||
-- | Render an 'Output' and return a string representation for the TensorFlow
|
||||
-- | Turn an 'Output' into a string representation for the TensorFlow
|
||||
-- foreign APIs.
|
||||
renderOutput :: Output -> Build Text
|
||||
renderOutput (Output (OutputIx i) o) = do
|
||||
n <- getOrAddOp o
|
||||
return $ unNodeName n <> Text.pack (":" ++ show i)
|
||||
encodeOutput :: Output -> Text
|
||||
encodeOutput (Output (OutputIx 0) n) = unNodeName n
|
||||
encodeOutput (Output (OutputIx i) n) = unNodeName n <> Text.pack (':' : show i)
|
||||
|
||||
-- | Modify some part of the state, run an action, and restore the state
|
||||
-- after that action is done.
|
||||
|
@ -339,15 +329,6 @@ withStateLens accessor f act = do
|
|||
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
||||
withDevice d = withStateLens defaultDevice (const d)
|
||||
|
||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||
-- the action has side effects of rendering the desired tensors. A pure
|
||||
-- return would not have the desired effect.
|
||||
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
|
||||
colocateWith t x = do
|
||||
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||
withDevice (Just d) x
|
||||
|
||||
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||
withNameScope :: MonadBuild m => Text -> m a -> m a
|
||||
withNameScope s = withStateLens currentScope (Scope s :)
|
||||
|
@ -355,31 +336,3 @@ withNameScope s = withStateLens currentScope (Scope s :)
|
|||
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
||||
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||
|
||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
|
||||
-- weren't already rendered.
|
||||
--
|
||||
-- This operation is idempotent; @render >=> render === render@. However,
|
||||
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
||||
-- may result in two different 'Tensor's.
|
||||
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
|
||||
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
|
||||
|
||||
-- | Render a 'Tensor' and get its node's name.
|
||||
renderNodeName :: Tensor v a -> Build NodeName
|
||||
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
-- | Records the given summary action in Build for retrieval with
|
||||
-- 'collectAllSummaries'. The summary op is required to produce a
|
||||
-- Summary protocol buffer in string form. For safety, use the
|
||||
-- pre-composed functions: Logging.scalarSummary and
|
||||
-- Logging.histogramSummary.
|
||||
addSummary :: SummaryTensor -> Build ()
|
||||
addSummary t = summaries %= (t :)
|
||||
|
||||
-- | Retrieves the summary ops collected thus far. Typically this only
|
||||
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
||||
-- repeatedly, the values accumulate.
|
||||
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
|
||||
collectAllSummaries = use summaries
|
||||
|
|
|
@ -12,26 +12,27 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.BuildOp
|
||||
( OpResult
|
||||
, BuildOp
|
||||
( BuildResult(..)
|
||||
, buildOp
|
||||
, buildListOp
|
||||
, PureResult(..)
|
||||
, pureOp
|
||||
, eqLengthGuard
|
||||
, BuildInputs(..)
|
||||
, OpParams
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad (replicateM)
|
||||
import Control.Monad (liftM2, replicateM)
|
||||
import Control.Monad.Reader (ReaderT, runReaderT, ask)
|
||||
import Control.Monad.State.Strict (State, runState, get, put)
|
||||
import Control.Monad.State.Strict (State, evalState, get, put)
|
||||
import Data.Int (Int64)
|
||||
import Lens.Family2 ((&), (<>~), (^.))
|
||||
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Output
|
||||
|
@ -40,48 +41,45 @@ import TensorFlow.Types
|
|||
|
||||
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||
|
||||
type Result = ReaderT Op (State ResultState)
|
||||
type Result = ReaderT NodeName (State ResultState)
|
||||
|
||||
-- | Class of types that can be used as op outputs.
|
||||
class OpResult a where
|
||||
toResult :: Result a
|
||||
class BuildResult a where
|
||||
buildResult :: Result a
|
||||
|
||||
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
|
||||
toResult = (,) <$> toResult <*> toResult
|
||||
instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
|
||||
buildResult = (,) <$> buildResult <*> buildResult
|
||||
|
||||
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where
|
||||
toResult = (,,) <$> toResult <*> toResult <*> toResult
|
||||
instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where
|
||||
buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult
|
||||
|
||||
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
|
||||
=> OpResult (a1, a2, a3, a4) where
|
||||
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult
|
||||
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
|
||||
=> BuildResult (a1, a2, a3, a4) where
|
||||
buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult
|
||||
|
||||
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
|
||||
=> OpResult (a1, a2, a3, a4, a5) where
|
||||
toResult = (,,,,) <$> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
|
||||
=> BuildResult (a1, a2, a3, a4, a5) where
|
||||
buildResult = (,,,,) <$> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
|
||||
instance ( OpResult a1
|
||||
, OpResult a2
|
||||
, OpResult a3
|
||||
, OpResult a4
|
||||
, OpResult a5
|
||||
, OpResult a6
|
||||
instance ( BuildResult a1
|
||||
, BuildResult a2
|
||||
, BuildResult a3
|
||||
, BuildResult a4
|
||||
, BuildResult a5
|
||||
, BuildResult a6
|
||||
)
|
||||
=> OpResult (a1, a2, a3, a4, a5, a6) where
|
||||
toResult = (,,,,,)
|
||||
<$> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
<*> toResult
|
||||
|
||||
tensorResult :: TensorKind v -> Result (Tensor v a)
|
||||
tensorResult v = Tensor v <$> recordResult
|
||||
=> BuildResult (a1, a2, a3, a4, a5, a6) where
|
||||
buildResult = (,,,,,)
|
||||
<$> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
<*> buildResult
|
||||
|
||||
recordResult :: Result Output
|
||||
recordResult = do
|
||||
|
@ -90,144 +88,39 @@ recordResult = do
|
|||
put $! ResultState (i+1) ns
|
||||
return $! output i o
|
||||
|
||||
instance OpResult ResourceHandle where
|
||||
toResult = ResourceHandle <$> recordResult
|
||||
instance BuildResult ResourceHandle where
|
||||
buildResult = ResourceHandle <$> recordResult
|
||||
|
||||
instance OpResult (Tensor Value a) where
|
||||
toResult = tensorResult ValueKind
|
||||
instance Rendered v => BuildResult (Tensor v a) where
|
||||
buildResult = Tensor . pure <$> recordResult
|
||||
|
||||
instance OpResult (Tensor Ref a) where
|
||||
toResult = tensorResult RefKind
|
||||
instance BuildResult ControlNode where
|
||||
buildResult = ControlNode <$> ask
|
||||
|
||||
instance OpResult ControlNode where
|
||||
toResult = ControlNode <$> ask
|
||||
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
|
||||
buildResult = loop (tensorTypes :: TensorTypeList as)
|
||||
where
|
||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||
loop Nil = return Nil
|
||||
loop (TensorTypeProxy :/ ls) = do
|
||||
t <- buildResult
|
||||
ts <- loop ls
|
||||
return (t :/ ts)
|
||||
|
||||
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
|
||||
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
|
||||
where
|
||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||
loop Nil = return Nil
|
||||
loop (TensorTypeProxy :/ ls) = do
|
||||
t <- tensorResult v
|
||||
ts <- loop ls
|
||||
return (t :/ ts)
|
||||
|
||||
instance TensorTypes as => OpResult (TensorList Value as) where
|
||||
toResult = tensorListResult ValueKind
|
||||
|
||||
instance TensorTypes as => OpResult (TensorList Ref as) where
|
||||
toResult = tensorListResult RefKind
|
||||
|
||||
instance OpResult a => OpResult [a] where
|
||||
toResult = do
|
||||
instance BuildResult a => BuildResult [a] where
|
||||
buildResult = do
|
||||
ResultState i ns <- get
|
||||
case ns of
|
||||
[] -> error $ "Ran out of counts in toResult. " ++
|
||||
"Likely misuse of buildListOp."
|
||||
[] -> error $ "Ran out of counts in buildResult. " ++
|
||||
"Likely misuse of buildOp."
|
||||
(n : rest) -> do
|
||||
put $! ResultState i rest
|
||||
replicateM (fromIntegral n) toResult
|
||||
replicateM (fromIntegral n) buildResult
|
||||
|
||||
runResult :: OpResult a => [Int64] -> Op -> a
|
||||
runResult ns o =
|
||||
case runState (runReaderT toResult o) (ResultState 0 ns) of
|
||||
(x, ResultState _ []) -> x
|
||||
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
|
||||
show (ns, ns')
|
||||
|
||||
-- | Make a new "pure" op, which may be deduped with identical ops within
|
||||
-- the same scope.
|
||||
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
|
||||
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
|
||||
|
||||
-- | Make a new "stateful" op, which will not be deduped with otherwise
|
||||
-- identical ops.
|
||||
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
|
||||
buildResult ns o ts
|
||||
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
|
||||
|
||||
addReversedInputs :: OpDef -> [Output] -> OpDef
|
||||
addReversedInputs o ts = o & opInputs <>~ reverse ts
|
||||
|
||||
-- | Class of types that can be used as op functions.
|
||||
class BuildOp f where
|
||||
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
|
||||
-> OpDef
|
||||
-> [Output] -- ^ Accumulator for inputs to the op.
|
||||
-> f
|
||||
|
||||
-- | Starts an operation that returns a structured set of tensors
|
||||
-- (singletons or tuples).
|
||||
buildOp :: BuildOp f => OpDef -> f
|
||||
buildOp o = buildOp' [] o []
|
||||
|
||||
-- | Starts an operation that returns a list of tensors.
|
||||
buildListOp :: BuildOp f => [Int64]
|
||||
-- ^ Cardinality of the corresponding list of tensors output.
|
||||
-> OpDef -> f
|
||||
buildListOp counts o = buildOp' counts o []
|
||||
|
||||
instance BuildOp ControlNode where
|
||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||
|
||||
instance BuildOp ResourceHandle where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp (Tensor Value a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp (Tensor Ref a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance TensorTypes as => BuildOp (TensorList Value as) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance TensorTypes as => BuildOp (TensorList Ref as) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp [Tensor Value a] where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
|
||||
=> BuildOp (t1, t2, t3, t4) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
|
||||
=> BuildOp (t1, t2, t3, t4, t5) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance ( OpResult t1
|
||||
, OpResult t2
|
||||
, OpResult t3
|
||||
, OpResult t4
|
||||
, OpResult t5
|
||||
, OpResult t6
|
||||
)
|
||||
=> BuildOp (t1, t2, t3, t4, t5, t6) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance OpResult a => BuildOp (Build a) where
|
||||
buildOp' = buildResult
|
||||
|
||||
instance BuildOp f => BuildOp (ResourceHandle -> f) where
|
||||
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
||||
|
||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
||||
|
||||
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
||||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
||||
|
||||
instance BuildOp f => BuildOp (TensorList v as -> f) where
|
||||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
|
||||
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
|
||||
buildOp sizes o = do
|
||||
n <- addNewOp o
|
||||
return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)
|
||||
|
||||
-- | Returns true if all the integers in each tuple are identical.
|
||||
-- Throws an error with a descriptive message if not.
|
||||
|
@ -240,6 +133,104 @@ eqLengthGuard = all eachOk
|
|||
error ("number_attr " ++ numberAttrName ++
|
||||
" contains tensors with different length " ++ show pairs)
|
||||
|
||||
-----------
|
||||
|
||||
|
||||
-- | Class of types that can be used as op outputs.
|
||||
class PureResult a where
|
||||
pureResult :: ReaderT (Build OpDef) (State ResultState) a
|
||||
|
||||
instance PureResult (Tensor Build a) where
|
||||
pureResult = do
|
||||
ResultState i ns <- get
|
||||
put $! ResultState (i+1) ns
|
||||
makeOp <- ask
|
||||
return $ Tensor $ do
|
||||
o <- makeOp
|
||||
-- TODO: unify with BuildResult (Tensor v)
|
||||
output i <$> getOrAddOp o
|
||||
|
||||
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
|
||||
pureResult = (,) <$> pureResult <*> pureResult
|
||||
|
||||
instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where
|
||||
pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult
|
||||
|
||||
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
|
||||
=> PureResult (a1, a2, a3, a4) where
|
||||
pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult
|
||||
|
||||
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
|
||||
=> PureResult (a1, a2, a3, a4, a5) where
|
||||
pureResult = (,,,,) <$> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
|
||||
instance ( PureResult a1
|
||||
, PureResult a2
|
||||
, PureResult a3
|
||||
, PureResult a4
|
||||
, PureResult a5
|
||||
, PureResult a6
|
||||
)
|
||||
=> PureResult (a1, a2, a3, a4, a5, a6) where
|
||||
pureResult = (,,,,,)
|
||||
<$> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
<*> pureResult
|
||||
|
||||
instance PureResult a => PureResult [a] where
|
||||
pureResult = do
|
||||
ResultState i ns <- get
|
||||
case ns of
|
||||
[] -> error $ "Ran out of counts in pureResult. " ++
|
||||
"Likely misuse of pureOp with output lists."
|
||||
n : rest -> do
|
||||
put $! ResultState i rest
|
||||
replicateM (fromIntegral n) pureResult
|
||||
|
||||
instance TensorTypes as => PureResult (TensorList Build as) where
|
||||
pureResult = loop (tensorTypes :: TensorTypeList as)
|
||||
where
|
||||
loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState)
|
||||
(TensorList Build bs)
|
||||
loop Nil = return Nil
|
||||
loop (TensorTypeProxy :/ ls) = do
|
||||
t <- pureResult
|
||||
ts <- loop ls
|
||||
return (t :/ ts)
|
||||
|
||||
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
|
||||
pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o)
|
||||
|
||||
-----
|
||||
-- Class of types that can be used as arguments
|
||||
|
||||
class BuildInputs a where
|
||||
buildInputs :: a -> Build [Output]
|
||||
|
||||
instance BuildInputs a => BuildInputs [a] where
|
||||
buildInputs = fmap concat . mapM buildInputs
|
||||
|
||||
instance BuildInputs (Tensor v a) where
|
||||
buildInputs (Tensor t) = do
|
||||
o <- toBuild t
|
||||
return [o]
|
||||
|
||||
instance BuildInputs (ListOf (Tensor v) as) where
|
||||
buildInputs Nil = return []
|
||||
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
|
||||
|
||||
instance BuildInputs ResourceHandle where
|
||||
buildInputs (ResourceHandle o) = return [o]
|
||||
|
||||
----
|
||||
|
||||
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||
-- TODO: be more type safe.
|
||||
type OpParams = OpDef -> OpDef
|
||||
|
|
|
@ -25,9 +25,6 @@ module TensorFlow.ControlFlow
|
|||
, noOp
|
||||
) where
|
||||
|
||||
import qualified Data.Set as Set
|
||||
import Lens.Family2 ((&), (.~))
|
||||
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Nodes
|
||||
|
@ -46,11 +43,8 @@ withControlDependencies deps act = do
|
|||
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||
-- no output.
|
||||
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
||||
group deps = do
|
||||
nodes <- build $ Set.toList <$> getNodes deps
|
||||
-- TODO: slicker way
|
||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||
group deps = withControlDependencies deps noOp
|
||||
|
||||
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||
noOp :: ControlNode
|
||||
noOp = buildOp $ opDef "NoOp"
|
||||
noOp :: MonadBuild m => m ControlNode
|
||||
noOp = build $ buildOp [] $ opDef "NoOp"
|
||||
|
|
|
@ -57,9 +57,9 @@ module TensorFlow.Core
|
|||
, Tensor
|
||||
, Value
|
||||
, Ref
|
||||
, TensorKind(..)
|
||||
, value
|
||||
, tensorFromName
|
||||
, expr
|
||||
-- ** Element types
|
||||
, TensorType
|
||||
, TensorData
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-} -- For Fetchable (TensorExpr a)
|
||||
module TensorFlow.Nodes where
|
||||
|
||||
import Control.Applicative (liftA2, liftA3)
|
||||
|
@ -28,7 +29,6 @@ import Data.Map.Strict (Map)
|
|||
import Data.Monoid ((<>))
|
||||
import Data.Set (Set)
|
||||
import Data.Text (Text)
|
||||
import Lens.Family2 ((^.))
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
|
||||
|
@ -90,7 +90,7 @@ instance Fetchable t a => Fetchable [t] [a] where
|
|||
getFetch ts = sequenceA <$> mapM getFetch ts
|
||||
|
||||
instance Nodes ControlNode where
|
||||
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o
|
||||
getNodes (ControlNode o) = pure $ Set.singleton o
|
||||
|
||||
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
|
||||
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
|
||||
|
@ -113,13 +113,13 @@ instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
|
|||
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
|
||||
|
||||
instance Nodes (Tensor v a) where
|
||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o
|
||||
|
||||
fetchTensorVector :: forall a v . TensorType a
|
||||
fetchTensorVector :: forall a v . (TensorType a)
|
||||
=> Tensor v a -> Build (Fetch (TensorData a))
|
||||
fetchTensorVector (Tensor _ o) = do
|
||||
outputName <- renderOutput o
|
||||
return $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||
fetchTensorVector (Tensor o) = do
|
||||
outputName <- encodeOutput <$> toBuild o
|
||||
pure $ Fetch (Set.singleton outputName) $ \tensors ->
|
||||
let tensorData = tensors Map.! outputName
|
||||
expectedType = tensorType (undefined :: a)
|
||||
actualType = FFI.tensorDataType tensorData
|
||||
|
|
|
@ -22,8 +22,6 @@ module TensorFlow.Output
|
|||
, Device(..)
|
||||
-- * Ops
|
||||
, NodeName(..)
|
||||
, Op(..)
|
||||
, opUnrendered
|
||||
, OpDef(..)
|
||||
, opName
|
||||
, opType
|
||||
|
@ -34,28 +32,24 @@ module TensorFlow.Output
|
|||
, OutputIx(..)
|
||||
, Output(..)
|
||||
, output
|
||||
, outputIndex
|
||||
, outputOp
|
||||
, PendingNodeName(..)
|
||||
, ResourceHandle(..)
|
||||
) where
|
||||
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.ProtoLens.TextFormat (showMessage)
|
||||
import Data.String (IsString(..))
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
|
||||
import Lens.Family2 (Lens')
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
|
||||
import Data.Default (def)
|
||||
import TensorFlow.Types (Attribute, attrLens)
|
||||
import TensorFlow.Orphans ()
|
||||
|
||||
-- | A type of graph node which has no outputs. These nodes are
|
||||
-- valuable for causing side effects when they are run.
|
||||
newtype ControlNode = ControlNode { unControlNode :: Op }
|
||||
newtype ControlNode = ControlNode { unControlNode :: NodeName }
|
||||
|
||||
-- | The type of op of a node in the graph. This corresponds to the proto field
|
||||
-- NodeDef.op.
|
||||
|
@ -66,18 +60,12 @@ instance IsString OpType where
|
|||
fromString = OpType . Text.pack
|
||||
|
||||
-- | An output of a TensorFlow node.
|
||||
data Output = Output !OutputIx !Op
|
||||
data Output = Output {outputIndex :: !OutputIx, outputNodeName :: !NodeName}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
output :: OutputIx -> Op -> Output
|
||||
output :: OutputIx -> NodeName -> Output
|
||||
output = Output
|
||||
|
||||
outputOp :: Lens' Output Op
|
||||
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
|
||||
|
||||
outputIndex :: Lens' Output OutputIx
|
||||
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
|
||||
|
||||
newtype OutputIx = OutputIx { unOutputIx :: Int }
|
||||
deriving (Eq, Ord, Num, Enum, Show)
|
||||
|
||||
|
@ -90,25 +78,6 @@ newtype Device = Device {deviceName :: Text}
|
|||
instance Show Device where
|
||||
show (Device d) = show d
|
||||
|
||||
-- | The representation of a node in a TensorFlow graph.
|
||||
data Op
|
||||
= Rendered !NodeDef -- ^ Properties are fixed, including the
|
||||
-- device, name, and scope.
|
||||
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
|
||||
-- on which context this op is rendered in.
|
||||
deriving (Eq, Ord)
|
||||
|
||||
instance Show Op where
|
||||
show (Rendered n) = "Rendered " ++ showMessage n
|
||||
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
|
||||
|
||||
-- | Traverse on the 'Unrendered' of an 'Op'.
|
||||
--
|
||||
-- Same implementation as _Left.
|
||||
opUnrendered :: Traversal' Op OpDef
|
||||
opUnrendered f (Unrendered a) = Unrendered <$> f a
|
||||
opUnrendered _ (Rendered b) = pure (Rendered b)
|
||||
|
||||
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
|
||||
data OpDef = OpDef
|
||||
{ _opName :: !PendingNodeName
|
||||
|
@ -157,7 +126,7 @@ instance IsString Output where
|
|||
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
|
||||
-> Output (fromInteger ix) $ assigned n
|
||||
_ -> Output 0 $ assigned s
|
||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||
where assigned = NodeName . Text.pack
|
||||
|
||||
|
||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||
|
|
|
@ -163,7 +163,7 @@ runWithFeeds feeds t = do
|
|||
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
|
||||
runFetchWithFeeds feeds target (Fetch fetch restore) = do
|
||||
extend
|
||||
feeds' <- build $ fixFeeds feeds
|
||||
let feeds' = fixFeeds feeds
|
||||
let fetchNames = encodeUtf8 <$> Set.toList fetch
|
||||
targetNames = toNodeNames $ Set.toList target
|
||||
session <- Session (asks rawSession)
|
||||
|
@ -192,8 +192,8 @@ runWithFeeds_ feeds t = do
|
|||
ns <- build $ getNodes t
|
||||
runFetchWithFeeds feeds ns (pure ())
|
||||
|
||||
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
|
||||
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o
|
||||
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
|
||||
fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
|
||||
|
||||
-- | Starts a concurrent thread which evaluates the given Nodes
|
||||
-- forever until runSession exits or an exception occurs. Graph
|
||||
|
|
|
@ -16,21 +16,26 @@
|
|||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE DeriveFunctor #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-} -- For the Render class
|
||||
|
||||
module TensorFlow.Tensor where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.String (IsString(..))
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
import Lens.Family2 ((^.))
|
||||
import Lens.Family2.State ((%=), use)
|
||||
|
||||
import TensorFlow.Output (Output)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef (device)
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
|
||||
import TensorFlow.Types
|
||||
( TensorData(..)
|
||||
, ListOf(..)
|
||||
|
@ -40,52 +45,149 @@ import qualified TensorFlow.Internal.FFI as FFI
|
|||
-- | A named output of a TensorFlow operation.
|
||||
--
|
||||
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
|
||||
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
|
||||
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
|
||||
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via
|
||||
-- 'value'.
|
||||
data Tensor v a = Tensor (TensorKind v) Output
|
||||
-- parameter @v@ is either:
|
||||
--
|
||||
-- * 'Build': An unrendered, immutable value.
|
||||
-- * 'Value': A rendered, immutable value.
|
||||
-- * 'Ref': A rendered stateful handle (e.g., a variable).
|
||||
--
|
||||
-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert between
|
||||
-- the different types of 'Tensor'.
|
||||
data Tensor v a where
|
||||
Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
|
||||
|
||||
data Value
|
||||
data Ref
|
||||
newtype Value a = Value {runValue :: a}
|
||||
deriving Functor
|
||||
|
||||
-- | This class provides a runtime switch on whether a 'Tensor' should be
|
||||
-- treated as a 'Value' or as a 'Ref'.
|
||||
data TensorKind v where
|
||||
ValueKind :: TensorKind Value
|
||||
RefKind :: TensorKind Ref
|
||||
instance Applicative Value where
|
||||
pure = Value
|
||||
Value f <*> Value x = Value $ f x
|
||||
|
||||
tensorKind :: Lens' (Tensor v a) (TensorKind v)
|
||||
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
||||
instance Monad Value where
|
||||
f >>= g = g $ runValue f
|
||||
|
||||
tensorOutput :: Lens' (Tensor v a) Output
|
||||
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||
newtype Ref a = Ref {runRef :: a}
|
||||
deriving Functor
|
||||
|
||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||
-- Ref into Value. This behaves like a no-op.
|
||||
value :: Tensor v a -> Tensor Value a
|
||||
value (Tensor _ o) = Tensor ValueKind o
|
||||
instance Applicative Ref where
|
||||
pure = Ref
|
||||
Ref f <*> Ref x = Ref $ f x
|
||||
|
||||
instance Monad Ref where
|
||||
f >>= g = g $ runRef f
|
||||
|
||||
-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op.
|
||||
value :: Tensor Ref a -> Tensor Value a
|
||||
value (Tensor o) = Tensor $ Value $ runRef o
|
||||
|
||||
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
|
||||
renderValue (Tensor o) = render $ Tensor $ toBuild o
|
||||
|
||||
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
|
||||
-- when running the graph.
|
||||
data Feed = Feed Output FFI.TensorData
|
||||
|
||||
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
|
||||
-- name, device, etc.
|
||||
class TensorKind v => Rendered v where
|
||||
rendered :: v a -> a
|
||||
|
||||
instance Rendered Value where
|
||||
rendered = runValue
|
||||
|
||||
instance Rendered Ref where
|
||||
rendered = runRef
|
||||
|
||||
renderedOutput :: Rendered v => Tensor v a -> Output
|
||||
renderedOutput = rendered . tensorOutput
|
||||
|
||||
tensorNodeName :: Rendered v => Tensor v a -> NodeName
|
||||
tensorNodeName = outputNodeName . renderedOutput
|
||||
|
||||
|
||||
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
|
||||
-- the graph.
|
||||
--
|
||||
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
|
||||
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
|
||||
feed :: Tensor v a -> TensorData a -> Feed
|
||||
feed (Tensor _ o) (TensorData td) = Feed o td
|
||||
feed :: Rendered v => Tensor v a -> TensorData a -> Feed
|
||||
feed t (TensorData td) = Feed (renderedOutput t) td
|
||||
|
||||
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
|
||||
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
|
||||
-- TODO(judahjacobson): add more safety checks here.
|
||||
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
||||
tensorFromName v = Tensor v . fromString . Text.unpack
|
||||
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
|
||||
tensorFromName = Tensor . pure . fromString . Text.unpack
|
||||
|
||||
-- | Like 'tensorFromName', but type-restricted to 'Value'.
|
||||
tensorValueFromName :: Text.Text -> Tensor Value a
|
||||
tensorValueFromName = tensorFromName
|
||||
|
||||
-- | Like 'tensorFromName', but type-restricted to 'Ref'.
|
||||
tensorRefFromName :: Text.Text -> Tensor Ref a
|
||||
tensorRefFromName = tensorFromName
|
||||
|
||||
type TensorList v = ListOf (Tensor v)
|
||||
|
||||
tensorListOutputs :: TensorList v as -> [Output]
|
||||
tensorListOutputs :: Rendered v => TensorList v as -> [Output]
|
||||
tensorListOutputs Nil = []
|
||||
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
|
||||
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
|
||||
|
||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||
-- the action has side effects of rendering the desired tensors. A pure
|
||||
-- return would not have the desired effect.
|
||||
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
|
||||
colocateWith t x = do
|
||||
d <- build $ Device . (^. device)
|
||||
<$> lookupNode (outputNodeName $ renderedOutput t)
|
||||
withDevice (Just d) x
|
||||
|
||||
|
||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||
-- the 'MonadBuild' context. Also renders any dependencies of the 'Tensor' that
|
||||
-- weren't already rendered.
|
||||
--
|
||||
-- This operation is idempotent; calling 'render' on the same input in the same
|
||||
-- context will produce the same result. However, rendering the same
|
||||
-- @Tensor Build@ in two different contexts may result in two different
|
||||
-- @Tensor Value@s.
|
||||
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
|
||||
render (Tensor t) = Tensor . Value <$> build t
|
||||
|
||||
-- TODO: better name.
|
||||
expr :: TensorKind v => Tensor v a -> Tensor Build a
|
||||
expr (Tensor o) = Tensor $ toBuild o
|
||||
|
||||
-- | Records the given summary action in Build for retrieval with
|
||||
-- Summary protocol buffer in string form. For safety, use the
|
||||
-- pre-composed functions: Logging.scalarSummary and
|
||||
-- Logging.histogramSummary.
|
||||
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -- ^ A 'SummaryTensor'
|
||||
-> m ()
|
||||
addSummary t = build $ do
|
||||
-- TODO: more generic way
|
||||
o <- toBuild $ tensorOutput t
|
||||
summaries %= (o :)
|
||||
|
||||
-- | Retrieves the summary ops collected thus far. Typically this only
|
||||
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
|
||||
-- repeatedly, the values accumulate.
|
||||
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
|
||||
collectAllSummaries = build $ map (Tensor . Value) <$> use summaries
|
||||
|
||||
-- | Synonym for the tensors that return serialized Summary proto.
|
||||
type SummaryTensor = Tensor Value ByteString
|
||||
|
||||
-- | An internal class for kinds of Tensors.
|
||||
class Monad v => TensorKind v where
|
||||
toBuild :: v a -> Build a
|
||||
|
||||
instance TensorKind Value where
|
||||
toBuild = return . rendered
|
||||
|
||||
instance TensorKind Ref where
|
||||
toBuild = return . rendered
|
||||
|
||||
instance TensorKind Build where
|
||||
toBuild = id
|
||||
|
|
Loading…
Reference in a new issue