1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Distinguish between "rendered" and "unrendered" Tensors. (#88)

Distinguish between "rendered" and "unrendered" Tensors.

There are now three types of `Tensor`:

- `Tensor Value a`: rendered value
- `Tensor Ref a`: rendered reference
- `Tensor Build a` : unrendered value

The extra bookkeeping makes it easier to track (and enforce) which tensors are
rendered or not.  For examples where this has been confusing in the past, see

With this change, pure ops look similar to before, returning `Tensor Build`
instead of `Tensor Value`.  "Stateful" (monadic) ops are unchanged.  For
example:

    add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
    assign :: (MonadBuild m, TensorType t)
           => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t)

The `gradients` function now requires that the variables over which it's
differentiating are pre-rendered:

    gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a]
              -> m [Tensor Value a]

(`Rendered v2` means that `v2` is either a `Ref` or a `Value`.)

Additionally, the implementation of `gradients` now takes care to render every
intermediate value when performing the reverse accumulation.  I suspect this
fixes an exponential blowup for complicated expressions.
This commit is contained in:
Judah Jacobson 2017-04-06 15:10:33 -07:00 committed by fkm3
parent d71f48090a
commit d62c614695
29 changed files with 636 additions and 608 deletions

View file

@ -58,7 +58,7 @@ fit xData yData = TF.runSession $ do
return (w', b') return (w', b')
gradientDescent :: Float gradientDescent :: Float
-> TF.Tensor TF.Value Float -> TF.Tensor TF.Build Float
-> [TF.Tensor TF.Ref Float] -> [TF.Tensor TF.Ref Float]
-> TF.Session TF.ControlNode -> TF.Session TF.ControlNode
gradientDescent alpha loss params = do gradientDescent alpha loss params = do

View file

@ -67,10 +67,10 @@ import Proto.Tensorflow.Core.Framework.Summary (Summary)
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime) import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
import System.Directory (createDirectoryIfMissing) import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>)) import System.FilePath ((</>))
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries) import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (scalar) import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords) import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor) import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Types (TensorType, type(/=)) import TensorFlow.Types (TensorType, type(/=))
import Text.Printf (printf) import Text.Printf (printf)
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
@ -141,19 +141,19 @@ doubleWallTime = asDouble <$> getCurrentTime
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally -- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
-- limited to a single value for simplicity. -- limited to a single value for simplicity.
histogramSummary :: histogramSummary ::
(TensorType t, t /= ByteString, t /= Bool) (MonadBuild m, TensorType t, t /= ByteString, t /= Bool)
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
=> ByteString -> Tensor v t -> Build () => ByteString -> Tensor v t -> m ()
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag) histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
-- | Adds a 'CoreOps.scalarSummary' node. -- | Adds a 'CoreOps.scalarSummary' node.
scalarSummary :: scalarSummary ::
(TensorType t, t /= ByteString, t /= Bool) (TensorType t, t /= ByteString, t /= Bool, MonadBuild m)
-- (TensorType t, -- (TensorType t,
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) -- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
=> ByteString -> Tensor v t -> Build () => ByteString -> Tensor v t -> m ()
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag) scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
-- | Merge all summaries accumulated in the 'Build' into one summary. -- | Merge all summaries accumulated in the 'Build' into one summary.
mergeAllSummaries :: Build SummaryTensor mergeAllSummaries :: MonadBuild m => m SummaryTensor
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary

View file

@ -34,13 +34,13 @@ numPixels = 28*28 :: Int64
numLabels = 10 :: Int64 numLabels = 10 :: Int64
-- | Create tensor with random values where the stddev depends on the width. -- | Create tensor with random values where the stddev depends on the width.
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float) randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Build Float)
randomParam width (TF.Shape shape) = randomParam width (TF.Shape shape) =
(* stddev) <$> TF.truncatedNormal (TF.vector shape) (`TF.mul` stddev) <$> TF.truncatedNormal (TF.vector shape)
where where
stddev = TF.scalar (1 / sqrt (fromIntegral width)) stddev = TF.scalar (1 / sqrt (fromIntegral width))
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32)) reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
-- Types must match due to model structure. -- Types must match due to model structure.
@ -87,7 +87,7 @@ createModel = do
grads <- TF.gradients loss params grads <- TF.gradients loss params
let lr = TF.scalar 0.00001 let lr = TF.scalar 0.00001
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad) applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
trainStep <- TF.group =<< zipWithM applyGrad params grads trainStep <- TF.group =<< zipWithM applyGrad params grads
let correctPredictions = TF.equal predict labels let correctPredictions = TF.equal predict labels

View file

@ -37,15 +37,15 @@ import TensorFlow.Examples.MNIST.TrainedGraph
import TensorFlow.Build import TensorFlow.Build
( asGraphDef ( asGraphDef
, addGraphDef , addGraphDef
, render , Build
) )
import TensorFlow.Tensor import TensorFlow.Tensor
( Tensor(..) ( Tensor(..)
, Ref , Ref
, Value
, feed , feed
, TensorKind(..) , render
, tensorFromName , tensorFromName
, tensorValueFromName
) )
import TensorFlow.Ops import TensorFlow.Ops
import TensorFlow.Session import TensorFlow.Session
@ -80,7 +80,7 @@ testReadMNIST = testCase "testReadMNIST" $ do
labelData <- readMNISTLabels =<< testLabelData labelData <- readMNISTLabels =<< testLabelData
10000 @=? length labelData 10000 @=? length labelData
testNodeName :: Text -> Tensor v a -> Assertion testNodeName :: Text -> Tensor Build a -> Assertion
testNodeName n g = n @=? opName testNodeName n g = n @=? opName
where where
opName = head (gDef^.node)^.op opName = head (gDef^.node)^.op
@ -89,7 +89,7 @@ testNodeName n g = n @=? opName
testGraphDefGen :: Test testGraphDefGen :: Test
testGraphDefGen = testCase "testGraphDefGen" $ do testGraphDefGen = testCase "testGraphDefGen" $ do
-- Test the inferred operation type. -- Test the inferred operation type.
let f0 :: Tensor Value Float let f0 :: Tensor Build Float
f0 = 0 f0 = 0
testNodeName "Const" f0 testNodeName "Const" f0
testNodeName "Add" $ 1 + f0 testNodeName "Add" $ 1 + f0
@ -109,7 +109,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10 let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do runSession $ do
addGraphDef graphDef addGraphDef graphDef
x <- run $ tensorFromName ValueKind "Mul_2" x <- run $ tensorValueFromName "Mul_2"
liftIO $ (50 :: Float) @=? unScalar x liftIO $ (50 :: Float) @=? unScalar x
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on -- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
@ -142,8 +142,8 @@ testMNISTExec = testCase "testMNISTExec" $ do
build $ addGraphDef $ mnist & version .~ 0 build $ addGraphDef $ mnist & version .~ 0
-- Define nodes that restore saved weights and biases. -- Define nodes that restore saved weights and biases.
let bias, wts :: Tensor Ref Float let bias, wts :: Tensor Ref Float
bias = tensorFromName RefKind "Variable" bias = tensorFromName "Variable"
wts = tensorFromName RefKind "weights" wts = tensorFromName "weights"
wtsCkptPath <- liftIO wtsCkpt wtsCkptPath <- liftIO wtsCkpt
biasCkptPath <- liftIO biasCkpt biasCkptPath <- liftIO biasCkpt
-- Run those restoring nodes on the graph in the current session. -- Run those restoring nodes on the graph in the current session.
@ -155,12 +155,12 @@ testMNISTExec = testCase "testMNISTExec" $ do
let ty = encodeTensorData [10] oneHotLabels let ty = encodeTensorData [10] oneHotLabels
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
updates = [(fromIntegral label, 1)] updates = [(fromIntegral label, 1)]
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample let feeds = [ feed (tensorValueFromName "x-input") tensorSample
, feed (tensorFromName ValueKind "y-input") ty , feed (tensorValueFromName "y-input") ty
] ]
-- Run the graph with the input feeds and read the ArgMax'd result from -- Run the graph with the input feeds and read the ArgMax'd result from
-- the test (not training) side of the evaluation. -- the test (not training) side of the evaluation.
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax" x <- runWithFeeds feeds $ tensorValueFromName "test/ArgMax"
-- Print the trained model's predicted outcome. -- Print the trained model's predicted outcome.
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n" liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
++ "Prediction: " ++ show (unScalar x :: Int64) ++ "Prediction: " ++ show (unScalar x :: Int64)

View file

@ -24,7 +24,6 @@ import Prelude hiding ( log
, exp , exp
) )
import TensorFlow.Build ( MonadBuild import TensorFlow.Build ( MonadBuild
, render
, withNameScope , withNameScope
) )
import TensorFlow.GenOps.Core ( greaterEqual import TensorFlow.GenOps.Core ( greaterEqual
@ -33,6 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
, exp , exp
) )
import TensorFlow.Tensor ( Tensor(..) import TensorFlow.Tensor ( Tensor(..)
, render
, Value , Value
) )
import TensorFlow.Types ( TensorType(..) import TensorFlow.Types ( TensorType(..)
@ -40,6 +40,8 @@ import TensorFlow.Types ( TensorType(..)
) )
import TensorFlow.Ops ( zerosLike import TensorFlow.Ops ( zerosLike
, add , add
, mul
, neg
) )
-- | Computes sigmoid cross entropy given `logits`. -- | Computes sigmoid cross entropy given `logits`.
@ -76,13 +78,11 @@ sigmoidCrossEntropyWithLogits
-> Tensor Value a -- ^ __targets__ -> Tensor Value a -- ^ __targets__
-> m (Tensor Value a) -> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do sigmoidCrossEntropyWithLogits logits targets = do
logits' <- render logits let zeros = zerosLike logits
targets' <- render targets cond = logits `greaterEqual` zeros
let zeros = zerosLike logits' relu_logits = select cond logits zeros
cond = logits' `greaterEqual` zeros neg_abs_logits = select cond (neg logits) logits
relu_logits = select cond logits' zeros
neg_abs_logits = select cond (-logits') logits'
withNameScope "logistic_loss" $ do withNameScope "logistic_loss" $ do
left <- render $ relu_logits - logits' * targets' left <- render $ relu_logits - logits `mul` targets
right <- render $ log (1 + exp neg_abs_logits) right <- render $ log (1 + exp neg_abs_logits)
withNameScope "sigmoid_add" $ render $ left `add` right withNameScope "sigmoid_add" $ render $ left `add` right

View file

@ -23,10 +23,9 @@ import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Gradient as TF import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF import qualified TensorFlow.NN as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Core as TF
-- | These tests are ported from: -- | These tests are ported from:
-- --
@ -60,12 +59,11 @@ defInputs = Inputs {
testLogisticOutput :: Test testLogisticOutput :: Test
testLogisticOutput = testCase "testLogisticOutput" $ do testLogisticOutput = testCase "testLogisticOutput" $ do
let inputs = defInputs let inputs = defInputs
vLogits = TF.vector $ logits inputs r <- run $ do
vTargets = TF.vector $ targets inputs vLogits <- TF.render $ TF.vector $ logits inputs
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets vTargets <- TF.render $ TF.vector $ targets inputs
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) TF.sigmoidCrossEntropyWithLogits vLogits vTargets
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
r <- run tfLoss
assertAllClose r ourLoss assertAllClose r ourLoss
@ -74,23 +72,22 @@ testLogisticOutputMultipleDim =
testCase "testLogisticOutputMultipleDim" $ do testCase "testLogisticOutputMultipleDim" $ do
let inputs = defInputs let inputs = defInputs
shape = [2, 2, 2] shape = [2, 2, 2]
vLogits = TF.constant shape (logits inputs) r <- run $ do
vTargets = TF.constant shape (targets inputs) vLogits <- TF.render $ TF.constant shape (logits inputs)
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets vTargets <- TF.render $ TF.constant shape (targets inputs)
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs) TF.sigmoidCrossEntropyWithLogits vLogits vTargets
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
r <- run tfLoss
assertAllClose r ourLoss assertAllClose r ourLoss
testGradientAtZero :: Test testGradientAtZero :: Test
testGradientAtZero = testCase "testGradientAtZero" $ do testGradientAtZero = testCase "testGradientAtZero" $ do
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
vLogits = TF.vector $ logits inputs
vTargets = TF.vector $ targets inputs
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
r <- run $ do r <- run $ do
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
vTargets <- TF.render $ TF.vector $ targets inputs
vLogits <- TF.render $ TF.vector $ logits inputs
let tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
l <- tfLoss l <- tfLoss
TF.gradients l [vLogits] TF.gradients l [vLogits]

View file

@ -231,21 +231,24 @@ renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc functionBody :: ParsedOp -> Doc
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts)) functionBody pOp
</> indent indentation (sep tensorArgs) | parsedOpIsMonadic pOp
= "build $ do"
</> indent indentation (bindOpInputsVar
</> "buildOp" <+> outputListsSizes <+> opDef)
| otherwise
= "pureOp" <+> outputListsSizes <+> "$ do"
</> indent indentation (bindOpInputsVar </> "return" <+> opDef)
where where
maybeLift outputListsSizes = brackets $ commasep
| parsedOpIsMonadic pOp = "build $" [ renderHaskellName a
| otherwise = ""
buildFunction
| null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+>
brackets (commasep $
map renderHaskellName outputListsSizes)
outputListsSizes = [ a
| ParsedArg { parsedArgCase = ListArg { argLength = a } } | ParsedArg { parsedArgCase = ListArg { argLength = a } }
<- parsedOutputs pOp] <- parsedOutputs pOp
buildOpParts = ]
opInputsVar = "op'inputs"
bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence"
<+> brackets (commasep $ map (\a -> "buildInputs" <+> a) tensorArgs)
opDef = parens $ hang 0 $ stack $
"opDef" <+> renderQuotedTFName (parsedOpName pOp) : "opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders type parameter arguments. -- Renders type parameter arguments.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
@ -259,10 +262,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs pOp, let n = attrName a | a <- inferredListSizeAttrs pOp, let n = attrName a
] ++ ] ++
["& op'options"] ["& op'options & opInputs .~" <+> opInputsVar]
tensorArgs = renderTensorArg <$> parsedInputs pOp
renderTensorArg = renderHaskellName . parsedArgName
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a inferredTypeExpr a
| typeParamIsList $ attrInfo a | typeParamIsList $ attrInfo a
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a = "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
@ -296,7 +298,7 @@ typeSig pre pOp = constraints
| null classConstraints = empty | null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>" | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
++ if parsedOpIsMonadic pOp then ["m'"] else [] ++ if parsedOpIsMonadic pOp then ["m'"] else []
-- Use m' as the type parameter to avoid clashing with an attribute name. -- Use m' as the type parameter to avoid clashing with an attribute name.
@ -336,12 +338,13 @@ tensorArg p = case parsedArgCase p of
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {argTypeAttr = t, argCaseKind = k} MixedListArg {argTypeAttr = t, argCaseKind = k}
-> "TensorList" <+> kind k <+> renderHaskellName t -> "TensorList" <+> parens (kind k) <+> renderHaskellName t
where where
kind k = case k of kind k = case k of
ArgTensorRef -> "Ref" ArgTensorRef -> "Ref"
ArgTensorValue -> "Value" ArgTensorValue -> "Value"
ArgTensorEither v' -> strictText v' ArgTensorBuild -> "Build"
ArgSomeTensor v -> strictText v
tensorType t k = let tensorType t k = let
a = case t of a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt

View file

@ -141,7 +141,8 @@ data ArgType
data ArgKind data ArgKind
= ArgTensorRef -- Tensor Ref a = ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a | ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v` | ArgTensorBuild -- Tensor Build a
| ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'.
deriving (Eq) deriving (Eq)
isRefCase :: ParsedArgCase -> Bool isRefCase :: ParsedArgCase -> Bool
@ -219,15 +220,17 @@ parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name { parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary , parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description , parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful
|| any (isRefCase . parsedArgCase) parsedInputs
, .. , ..
} }
where where
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v)) parsedOpIsMonadic = o ^. isStateful
(o ^. inputArg) tensorKindParams || any (isRefCase . parsedArgCase) parsedInputs
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] || null (o ^. outputArg)
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a))
tensorKindParams (o ^. inputArg)
tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a))
(o ^. outputArg)
-- Integer attributes that can be inferred from the size of at least one -- Integer attributes that can be inferred from the size of at least one
-- input list. -- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
@ -246,15 +249,16 @@ parseOp o = ParsedOp
$ o ^. attr $ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs. -- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind
inputTensorKind a v inputTensorKind v a
| a ^. isRef = ArgTensorRef | a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v | otherwise = ArgSomeTensor v
outputTensorKind :: OpDef'ArgDef -> ArgKind outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind
outputTensorKind a outputTensorKind isMonadic a
| a ^. isRef = ArgTensorRef | a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue | isMonadic = ArgTensorValue
| otherwise = ArgTensorBuild
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a getExplicitInputAttr o implicitAttrs a

View file

@ -24,9 +24,9 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM) import Control.Monad (zipWithM)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import TensorFlow.Build (MonadBuild, colocateWith, render) import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value) import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
import TensorFlow.Types (OneOf, TensorType) import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
@ -44,17 +44,18 @@ import qualified TensorFlow.GenOps.Core as CoreOps
-- --
-- The results of the lookup are concatenated into a dense -- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v m . embeddingLookup :: forall a b v1 v2 m .
( MonadBuild m ( MonadBuild m
, Rendered v1
, TensorType a , TensorType a
, OneOf '[Int64, Int32] b , OneOf '[Int64, Int32] b
, Num b , Num b
) )
=> [Tensor v a] => [Tensor v1 a]
-- ^ A list of tensors which can be concatenated along -- ^ A list of tensors which can be concatenated along
-- dimension 0. Each `Tensor` must be appropriately -- dimension 0. Each `Tensor` must be appropriately
-- sized for `mod` partition strategy. -- sized for `mod` partition strategy.
-> Tensor Value b -> Tensor v2 b
-- ^ A `Tensor` with type `int32` or `int64` -- ^ A `Tensor` with type `int32` or `int64`
-- containing the ids to be looked up in `params`. -- containing the ids to be looked up in `params`.
-- The ids are required to have fewer than 2^31 -- The ids are required to have fewer than 2^31

View file

@ -31,6 +31,7 @@ import Data.ByteString (ByteString)
import Data.Complex (Complex) import Data.Complex (Complex)
import Data.Default (def) import Data.Default (def)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy) import Data.List (foldl', sortBy)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe) import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
@ -39,7 +40,7 @@ import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set) import Data.Set (Set)
import Data.Text (Text) import Data.Text (Text)
import Data.Tuple (swap) import Data.Tuple (swap)
import Lens.Family2 (Lens', (&), (^.), (.~), (%~)) import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses) import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt) import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso) import Lens.Family2.Unchecked (lens, iso)
@ -59,11 +60,10 @@ import TensorFlow.Build
( MonadBuild ( MonadBuild
, Build , Build
, build , build
, render
, renderNodeName
, renderedNodeDefs , renderedNodeDefs
, opDef , opDef
, opAttr , opAttr
, opInputs
) )
import TensorFlow.BuildOp import TensorFlow.BuildOp
import TensorFlow.Ops import TensorFlow.Ops
@ -86,16 +86,19 @@ import TensorFlow.Ops
) )
import TensorFlow.Output import TensorFlow.Output
( NodeName(..) ( NodeName(..)
, Op (Rendered)
, Output(..) , Output(..)
, OutputIx(..) , OutputIx(..)
, outputIndex , outputIndex
) )
import TensorFlow.Tensor import TensorFlow.Tensor
( Tensor(..) ( Tensor(..)
, TensorKind (ValueKind)
, Value , Value
, tensorOutput , render
, expr
, Rendered
, tensorNodeName
, renderedOutput
, renderValue
) )
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens) import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef import Proto.Tensorflow.Core.Framework.NodeDef
@ -114,10 +117,7 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@. -- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 m . (MonadBuild m gradients :: forall a v1 v2 m . (MonadBuild m
, Num (Tensor v1 a) , Rendered v2
-- TODO(gnezdo): remove indirect constraint.
-- It's a wart inherited from Num instance.
, v1 ~ Value
, GradientCompatible a , GradientCompatible a
) )
=> Tensor v1 a -- ^ The output of the graph. => Tensor v1 a -- ^ The output of the graph.
@ -150,27 +150,31 @@ gradients y xs = build $ do
-- --
-- 4. Lookup the recorded gradient for each x in xs. -- 4. Lookup the recorded gradient for each x in xs.
yName <- renderNodeName y y' <- renderValue y
let yName = tensorNodeName y'
yOne <- render $ fill (shape y') (scalar 1)
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName? -- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $ nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x)) (\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup . flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup let (gr, nodeMap) = createGraph yName nodeDefLookup
-- Set gradient of y to one. -- Set gradient of y to one.
-- TODO: nicer
let initPending :: Map.Map FGL.Node (PendingGradients a) let initPending :: Map.Map FGL.Node (PendingGradients a)
initPending = Map.empty & at (nodeMap Map.! yName) = Map.empty & (at (nodeMap Map.! yName)
. nonEmpty . nonEmpty
. outputIxAt (y ^. tensorOutput . outputIndex) . outputIxAt (outputIndex $ renderedOutput y')
. nonEmpty . nonEmpty
.~ [fill (shape y) (scalar 1)] .~ [yOne]
)
-- Calculate the gradients of y w.r.t. each node in the graph. -- Calculate the gradients of y w.r.t. each node in the graph.
gradientMap <- graphGrads gr initPending gradientMap <- graphGrads gr initPending
-- Lookup the gradients for each x. -- Lookup the gradients for each x.
forM xs $ \x -> do forM xs $ \x ->
xName <- renderNodeName x let xName = tensorNodeName x
render $ fromMaybe (zerosLike x) $ do in maybe (render $ zerosLike x) return $ do
n <- nodeMap ^. at xName n <- nodeMap ^. at xName
let i = x ^. tensorOutput . outputIndex let i = outputIndex $ renderedOutput x
gradientMap ^. at n . nonEmpty . outputIxAt i gradientMap ^. at n . nonEmpty . outputIxAt i
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v) outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
@ -182,6 +186,7 @@ outputIxAt = intAt . unOutputIx
type PendingGradients a = IntMap.IntMap [Tensor Value a] type PendingGradients a = IntMap.IntMap [Tensor Value a]
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype. -- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
-- TODO: precache the rendering?
type Gradients a = IntMap.IntMap (Tensor Value a) type Gradients a = IntMap.IntMap (Tensor Value a)
-- | Graph of TensorFlow operations. -- | Graph of TensorFlow operations.
@ -229,42 +234,43 @@ non a = anon a (a==)
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v) nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null nonEmpty = anon mempty null
-- TODO: strictness (e.g., foldlM')
-- | Calculate the gradients for every node in a graph. -- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a graphGrads :: forall a. GradientCompatible a
=> Graph => Graph
-> Map FGL.Node (PendingGradients a) -> Map FGL.Node (PendingGradients a)
-- ^ Initial gradients (usually just 1 for the node of interest). -- ^ Initial gradients (usually just 1 for the node of interest).
-> Build (Map FGL.Node (Gradients a)) -> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult) graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
where where
initState = GradientsState initPending Map.empty initState = GradientsState initPending Map.empty
-- Reverse topological sort. -- Reverse topological sort.
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to -- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
-- avoid calculating gradients that won't be used. -- avoid calculating gradients that won't be used.
nodeOrder = FGL.topsort $ FGL.grev gr nodeOrder = FGL.topsort $ FGL.grev gr
go state node = go :: GradientsState a -> Int -> Build (GradientsState a)
go state node = do
-- Aggregate the accumulated gradients for this node. -- Aggregate the accumulated gradients for this node.
let outputGrads = outputGrads <-
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty) sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
in if null outputGrads if null outputGrads
then state then pure state
else else do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
-- Calculate the gradients for each of the node's inputs. -- Calculate the gradients for each of the node's inputs.
let nextState = state & gradientsResult %~ Map.insert node outputGrads let nextState = state & gradientsResult %~ Map.insert node outputGrads
ctx = FGL.context gr node pure $ updatePendingGradients ctx inputGrads nextState
in updatePendingGradients
ctx
(calculateInputGrads ctx outputGrads gr)
nextState
-- | Reduce accumulated gradients for each output to one Tensor. -- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a sumPendingGradient :: GradientCompatible a
=> PendingGradients a -> Gradients a => PendingGradients a -> Build (Gradients a)
sumPendingGradient = IntMap.mapMaybe f sumPendingGradient = sequence . IntMap.mapMaybe f
where where
f [] = Nothing f [] = Nothing
f [x] = Just x f [x] = Just (pure x)
f xs = Just (addN xs) f xs = Just (render $ addN xs)
-- | Calculate the gradients of a node's input tensors. -- | Calculate the gradients of a node's input tensors.
@ -274,18 +280,18 @@ calculateInputGrads :: forall a. GradientCompatible a
=> FGL.Context NodeDef EdgeLabel => FGL.Context NodeDef EdgeLabel
-> Gradients a -- ^ Output gradients of the node. -> Gradients a -- ^ Output gradients of the node.
-> Graph -> Graph
-> [Maybe (Tensor Value a)] -> Build [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
outputGrads
traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
where where
fullOutGrads =
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
-- Create a tensor from an edge (technically an Output, but it seems less -- Create a tensor from an edge (technically an Output, but it seems less
-- confusing to refer to it as a tensor here). -- confusing to refer to it as a tensor here).
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
edgeToTensor ((i, _), n) = edgeToTensor ((i, _), n) =
case FGL.lab gr n of case FGL.lab gr n of
Just edgeNodeDef -> Output i (Rendered edgeNodeDef) Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
Nothing -> error $ "calculateInputGrads: missing input node for " Nothing -> error $ "calculateInputGrads: missing input node for "
++ Text.unpack (nodeDef ^. name) ++ Text.unpack (nodeDef ^. name)
-- Input tensors, sorted by input index. -- Input tensors, sorted by input index.
@ -294,11 +300,11 @@ calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
-- | Convert a Map of gradients to a list, with zeros for missing outputs. -- | Convert a Map of gradients to a list, with zeros for missing outputs.
fullOutputGrads :: (TensorType a, Num a) fullOutputGrads :: (TensorType a, Num a)
=> OutputIx -- ^ Number of outputs. => OutputIx -- ^ Number of outputs.
-> Op -> NodeName
-> Gradients a -> Gradients a
-> [Tensor Value a] -> Build [Tensor Value a]
fullOutputGrads n o gs = fullOutputGrads n o gs =
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1] mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
where where
-- A tensor of zeros with the same shape as the i'th output. -- A tensor of zeros with the same shape as the i'th output.
zero i = zerosLike $ toT (Output i o) zero i = zerosLike $ toT (Output i o)
@ -397,19 +403,19 @@ type GradientFunc a = NodeDef
-- ^ Input tensors. -- ^ Input tensors.
-> [Tensor Value a] -> [Tensor Value a]
-- ^ Gradient of y w.r.t. each output tensor. -- ^ Gradient of y w.r.t. each output tensor.
-> [Maybe (Tensor Value a)] -> [Maybe (Tensor Build a)]
-- ^ Gradient of y w.r.t. each input tensor. -- ^ Gradient of y w.r.t. each input tensor.
-- TODO(fmayle): Assert the type is correct. -- TODO(fmayle): Assert the type is correct.
-- | Create a Tensor from an Output. -- | Create a Tensor from an Output.
toT :: Output -> Tensor Value a toT :: Output -> Tensor Build a
toT = Tensor ValueKind toT = Tensor . pure
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for -- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations. -- simple slicing operations.
flatSlice :: forall v1 t . (TensorType t) flatSlice :: forall v1 t . TensorType t
=> Tensor v1 t -- ^ __input__ => Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of -> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from. -- 'input' to slice from.
@ -417,9 +423,12 @@ flatSlice :: forall v1 t . (TensorType t)
-- of 'input' to slice. If size is -1, all remaining elements in the dimension -- of 'input' to slice. If size is -1, all remaining elements in the dimension
-- are included in the slice (i.e. this is equivalent to setting -- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin). -- size = input.dim_size(0) - begin).
-> Tensor Value t -- ^ __output__ -> Tensor Build t -- ^ __output__
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size]) flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
nodeDefName :: NodeDef -> NodeName
nodeDefName = NodeName . view name
-- | The gradient function for an op type. -- | The gradient function for an op type.
-- --
@ -427,8 +436,8 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
-- third_party/tensorflow/python/ops/*_grad.py -- third_party/tensorflow/python/ops/*_grad.py
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x] opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ -dz] opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x] opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "Square" _ [toT -> x] [dz] = opGrad "Square" _ [toT -> x] [dz] =
@ -436,7 +445,7 @@ opGrad "Square" _ [toT -> x] [dz] =
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x -- TODO(fmayle): The python code makes dz a control dependency of the 2*x
-- (for performance reasons?). Will need to put these functions in the Build -- (for performance reasons?). Will need to put these functions in the Build
-- monad to replicate that. -- monad to replicate that.
[Just $ dz * (2 * x)] [Just $ dz `CoreOps.mul` (2 * x)]
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] = opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
-- TODO(fmayle): The python version uses a better performance implementation -- TODO(fmayle): The python version uses a better performance implementation
@ -448,20 +457,20 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
] ]
where where
-- TODO(gnezdo): Use colocateWith but it requires Build monad. -- TODO(gnezdo): Use colocateWith but it requires Build monad.
denseShape = shape (x :: Tensor Value a) denseShape = shape (x :: Tensor Build a)
numRows = scalarize $ flatSlice denseShape 0 1 numRows = scalarize $ flatSlice denseShape 0 1
valuesShape = CoreOps.concat 0 [ allDimensions valuesShape = CoreOps.concat 0 [ allDimensions
, flatSlice denseShape 1 (-1) , flatSlice denseShape 1 (-1)
] ]
values = reshape dz valuesShape values = reshape dz valuesShape
-- TODO(fmayle): This could be either Int32 or Int64. -- TODO(fmayle): This could be either Int32 or Int64.
indices' = reshape indices allDimensions :: Tensor Value Int32 indices' = reshape indices allDimensions :: Tensor Build Int32
opGrad "Max" _ [toT -> x, toT -> indices] [dz] = opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing] [Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
where where
sx = shape (x :: Tensor Value a) sx = shape (x :: Tensor Build a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
y = CoreOps.max x indices y = CoreOps.max x indices
y' = reshape y outputShapeKeptDims y' = reshape y outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims dz' = reshape dz outputShapeKeptDims
@ -475,8 +484,8 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.tile grad tileScaling, Nothing ] [ Just $ CoreOps.tile grad tileScaling, Nothing ]
where where
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad. -- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
sx = shape (x :: Tensor Value a) sx = shape (x :: Tensor Build a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32) outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
tileScaling = safeShapeDiv sx outputShapeKeptDims tileScaling = safeShapeDiv sx outputShapeKeptDims
grad = reshape dz outputShapeKeptDims grad = reshape dz outputShapeKeptDims
@ -484,8 +493,8 @@ opGrad "Mean" u v@[toT -> x, _] w =
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing] [Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
where where
[Just dz, Nothing] = opGrad "Sum" u v w [Just dz, Nothing] = opGrad "Sum" u v w
inputShape = shape (x :: Tensor Value a) inputShape = shape (x :: Tensor Build a)
outputShape = shape (dz :: Tensor Value a) outputShape = shape (dz :: Tensor Build a)
-- TODO(fmayle): Add fast path when shape is known. -- TODO(fmayle): Add fast path when shape is known.
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
@ -495,8 +504,8 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum dz rx) sx [ Just $ reshape (sum dz rx) sx
, Just $ reshape (sum dz ry) sy ] , Just $ reshape (sum dz ry) sy ]
where where
sx = shape (x :: Tensor Value a) sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Value a) sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy (rx, ry) = broadcastGradientArgs sx sy
opGrad "Sub" u v w = opGrad "Sub" u v w =
@ -510,22 +519,24 @@ opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
opGrad "Mul" _ [toT -> x, toT -> y] [dz] = opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
-- TODO(fmayle): Handle complex numbers. -- TODO(fmayle): Handle complex numbers.
[ Just $ reshape (sum (dz * y) rx) sx [ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx
, Just $ reshape (sum (x * dz) ry) sy ] , Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ]
where where
sx = shape (x :: Tensor Value a) sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Value a) sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy (rx, ry) = broadcastGradientArgs sx sy
opGrad "Div" _ [toT -> x, toT -> y] [dz] = opGrad "Div" _ [toT -> x, toT -> y] [dz] =
-- TODO(fmayle): Handle complex numbers. -- TODO(fmayle): Handle complex numbers.
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div. -- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx [ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy , Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y)))
ry)
sy
] ]
where where
sx = shape (x :: Tensor Value a) sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Value a) sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy (rx, ry) = broadcastGradientArgs sx sy
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
@ -549,7 +560,7 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
opGrad "Transpose" _ [_, toT -> p] [dz] = opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz [ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Value Int32) (CoreOps.invertPermutation p :: Tensor Build Int32)
, Nothing , Nothing
] ]
@ -582,28 +593,28 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
x output dz x output dz
] ]
where where
output :: Tensor Value a output :: Tensor Build a
output = toT $ Output 0 (Rendered nodeDef) output = toT $ Output 0 (nodeDefName nodeDef)
ksize = lookupAttr nodeDef "ksize" :: [Int64] ksize = lookupAttr nodeDef "ksize" :: [Int64]
strides = lookupAttr nodeDef "strides" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Reshape" _ [toT -> x, _] [dz] = opGrad "Reshape" _ [toT -> x, _] [dz] =
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing] [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing] opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing] opGrad "TruncatedNormal" _ _ _ = [Nothing]
opGrad "RefIdentity" _ _ [dz] = [Just dz] opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast] opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
where where
-- TODO(gnezdo): too permissive, python only allows float types as src_type. -- TODO(gnezdo): too permissive, python only allows float types as src_type.
reverseCast = reverseCast =
buildOp (opDef "Cast" pureOp [] $ pure (opDef "Cast"
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString) & opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)) & opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)
dz & opInputs .~ [renderedOutput dz])
opGrad "DynamicStitch" nodeDef inputs [dz] = opGrad "DynamicStitch" nodeDef inputs [dz] =
replicate halfLen Nothing ++ valuesGrads replicate halfLen Nothing ++ valuesGrads
@ -614,7 +625,7 @@ opGrad "DynamicStitch" nodeDef inputs [dz] =
in if 2 * half == len in if 2 * half == len
then half then half
else error ("Uneven input size " ++ show (len, showMessage nodeDef)) else error ("Uneven input size " ++ show (len, showMessage nodeDef))
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32) valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32)
| idx <- take halfLen inputs | idx <- take halfLen inputs
] ]
@ -622,14 +633,14 @@ opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
[ Just reconstructed, Nothing ] [ Just reconstructed, Nothing ]
where where
reconstructed = CoreOps.reshape stitched reconstructed = CoreOps.reshape stitched
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32) (CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32)
stitched = CoreOps.dynamicStitch partitionedIndices dz stitched = CoreOps.dynamicStitch partitionedIndices dz
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
np = lookupAttr nodeDef "num_partitions" :: Int64 np = lookupAttr nodeDef "num_partitions" :: Int64
originalIndices = originalIndices =
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
prefixShape = shapeInt32 indices prefixShape = shapeInt32 indices
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32 shapeInt32 t = CoreOps.shape t :: Tensor Build Int32
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] = opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
[ Nothing [ Nothing
@ -639,18 +650,18 @@ opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
where zeros = CoreOps.zerosLike x where zeros = CoreOps.zerosLike x
-- TODO(gnezdo): Unlike Python, no control dependency on dz. -- TODO(gnezdo): Unlike Python, no control dependency on dz.
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ] opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ]
-- TODO(gnezdo): Reuse the output instead of doing another exp, -- TODO(gnezdo): Reuse the output instead of doing another exp,
-- though, it is probably CSE'd away anyway. -- though, it is probably CSE'd away anyway.
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ] opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] = opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
[ Just $ CoreOps.unsortedSegmentSum [ Just $ CoreOps.unsortedSegmentSum
(CoreOps.gather dz (t :: Tensor Value Int32)) (CoreOps.gather dz (t :: Tensor Build Int32))
(y :: Tensor Value Int32) inputRows (y :: Tensor Build Int32) inputRows
, Nothing , Nothing
, Nothing , Nothing
] ]
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1 where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing] opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing] opGrad "LabelWeights" _ _ _ = [Nothing]
@ -710,13 +721,13 @@ numOutputs o =
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op) _ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0` -- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32 safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1) safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
allDimensions :: Tensor Value Int32 allDimensions :: Tensor Build Int32
allDimensions = vector [-1 :: Int32] allDimensions = vector [-1 :: Int32]
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32 rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1 rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1 lookupAttr :: Attribute a1 => NodeDef -> Text -> a1

View file

@ -166,7 +166,6 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.BuildOp import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group) import TensorFlow.ControlFlow (group)
import TensorFlow.Output (unNodeName)
import TensorFlow.Tensor import TensorFlow.Tensor
import TensorFlow.Types import TensorFlow.Types
@ -183,7 +182,7 @@ import qualified Prelude (abs)
-- "1". -- "1".
instance ( TensorType a instance ( TensorType a
, Num a , Num a
, v ~ Value , v ~ Build
, OneOf '[ Double, Float, Int32, Int64 , OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a) => Num (Tensor v a) where , Complex Float, Complex Double] a) => Num (Tensor v a) where
(+) = CoreOps.add (+) = CoreOps.add
@ -194,10 +193,10 @@ instance ( TensorType a
signum = CoreOps.sign signum = CoreOps.sign
negate = CoreOps.neg negate = CoreOps.neg
matTranspose :: TensorType a => Tensor v a -> Tensor Value a matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose = matTranspose' id matTranspose = matTranspose' id
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32]) matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
@ -208,7 +207,7 @@ placeholder' :: forall m a . (MonadBuild m, TensorType a)
placeholder' params pShape placeholder' params pShape
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful, -- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
-- and thus would be CSE'd. -- and thus would be CSE'd.
= build $ buildOp $ opDef "Placeholder" = build $ buildOp [] $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a) & opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ pShape & opAttr "shape" .~ pShape
& params & params
@ -216,11 +215,11 @@ placeholder' params pShape
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.
initializedVariable :: (MonadBuild m, TensorType a) initializedVariable :: (MonadBuild m, TensorType a)
=> Tensor Value a -> m (Tensor Ref a) => Tensor v a -> m (Tensor Ref a)
initializedVariable = initializedVariable' id initializedVariable = initializedVariable' id
initializedVariable' :: (MonadBuild m, TensorType a) initializedVariable' :: (MonadBuild m, TensorType a)
=> OpParams -> Tensor Value a -> m (Tensor Ref a) => OpParams -> Tensor v a -> m (Tensor Ref a)
initializedVariable' params initializer = do initializedVariable' params initializer = do
v <- CoreOps.variable' params [] -- The shape is not known initially. v <- CoreOps.variable' params [] -- The shape is not known initially.
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
@ -240,17 +239,20 @@ zeroInitializedVariable'
zeroInitializedVariable' params = initializedVariable' params . zeros zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors. -- TODO: Support heterogeneous list of tensors.
save :: forall a m v . (MonadBuild m, TensorType a) save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save. -> [Tensor v a] -- ^ Tensors to save.
-> m ControlNode -> m ControlNode
save path xs = do save path xs = build $ do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName let toByteStringTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs let names = fmap toByteStringTensor xs
let types = replicate (length xs) (tensorType (undefined :: a)) let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save" names' <- buildInputs $ CoreOps.pack names
xs' <- buildInputs xs
path' <- buildInputs $ scalar path
buildOp [] $ opDef "Save"
& opAttr "T" .~ types & opAttr "T" .~ types
build $ saveOp (scalar path) (CoreOps.pack names) xs & opInputs .~ (path' ++ names' ++ xs')
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
-- --
@ -261,20 +263,22 @@ restoreFromName :: forall a m . (MonadBuild m, TensorType a)
-> ByteString -- ^ Tensor name override. -> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore. -> Tensor Ref a -- ^ Tensor to restore.
-> m ControlNode -> m ControlNode
restoreFromName path name x = do restoreFromName path name x = build $ do
let restoreOp = buildOp $ opDef "Restore" path' <- buildInputs $ scalar path
name' <- buildInputs $ scalar name
restoreOp <- buildOp [] $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a) & opAttr "dt" .~ tensorType (undefined :: a)
group =<< CoreOps.assign x & opInputs .~ (path' ++ name')
(restoreOp (scalar path) (scalar name) :: Tensor Value a) group =<< CoreOps.assign x (restoreOp :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file. -- | Restore a tensor's value from a checkpoint file.
restore :: forall a m . (MonadBuild m, TensorType a) restore :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path. => ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore. -> Tensor Ref a -- ^ Tensor to restore.
-> m ControlNode -> m ControlNode
restore path x = do restore path x = restoreFromName path name x
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x) where
restoreFromName path name x name = encodeUtf8 $ encodeOutput $ renderedOutput x
-- | Create a constant tensor. -- | Create a constant tensor.
-- --
@ -283,10 +287,10 @@ restore path x = do
-- element 0: index (0, ..., 0) -- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1) -- element 1: index (0, ..., 1)
-- ... -- ...
constant :: TensorType a => Shape -> [a] -> Tensor Value a constant :: TensorType a => Shape -> [a] -> Tensor Build a
constant = constant' id constant = constant' id
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
constant' params (Shape cShape) values constant' params (Shape cShape) values
| invalidLength = error invalidLengthMsg | invalidLength = error invalidLengthMsg
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode)) | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
@ -305,24 +309,24 @@ constant' params (Shape cShape) values
-- | Reshape a N-D tensor down to a scalar. -- | Reshape a N-D tensor down to a scalar.
-- --
-- See `TensorFlow.GenOps.Core.reshape`. -- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a scalarize :: TensorType a => Tensor v a -> Tensor Build a
scalarize t = CoreOps.reshape t (vector scalarShape) scalarize t = CoreOps.reshape t (vector scalarShape)
where where
scalarShape = [] :: [Int32] scalarShape = [] :: [Int32]
-- | Create a constant vector. -- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a vector :: TensorType a => [a] -> Tensor Build a
vector = vector' id vector = vector' id
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
vector' params xs = constant' params [fromIntegral $ length xs] xs vector' params xs = constant' params [fromIntegral $ length xs] xs
-- | Create a constant scalar. -- | Create a constant scalar.
scalar :: TensorType a => a -> Tensor Value a scalar :: TensorType a => a -> Tensor Build a
scalar = scalar' id scalar = scalar' id
scalar' :: TensorType a => OpParams -> a -> Tensor Value a scalar' :: TensorType a => OpParams -> a -> Tensor Build a
scalar' params x = constant' params [] [x] scalar' params x = constant' params [] [x]
-- | Random tensor from the unit normal distribution with bounded values. -- | Random tensor from the unit normal distribution with bounded values.
@ -338,28 +342,28 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
-> m (Tensor Value a) -> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal' truncatedNormal' = CoreOps.truncatedNormal'
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0) zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32 shape :: TensorType t => Tensor v t -> Tensor Build Int32
shape = CoreOps.shape shape = CoreOps.shape
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32 shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
shape' = CoreOps.shape' shape' = CoreOps.shape'
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims = CoreOps.expandDims expandDims = CoreOps.expandDims
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims' = CoreOps.expandDims' expandDims' = CoreOps.expandDims'
-- | Helper function for reduction ops (translation of math_ops.reduced_shape). -- | Helper function for reduction ops (translation of math_ops.reduced_shape).
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) => reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32 Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
reducedShape inputShape axes = reducedShape inputShape axes =
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7] let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
axes32 = toInt32 axes -- [1, 2] axes32 = toInt32 axes -- [1, 2]
toInt32 x = CoreOps.cast x :: Tensor Value Int32 toInt32 x = CoreOps.cast x :: Tensor Build Int32
inputRank = CoreOps.size inputShape32 -- 4 inputRank = CoreOps.size inputShape32 -- 4
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
axesShape = shape axesMod -- [2] axesShape = shape axesMod -- [2]

View file

@ -79,6 +79,7 @@ Test-Suite EmbeddingOpsTest
, test-framework , test-framework
, test-framework-hunit , test-framework-hunit
, test-framework-quickcheck2 , test-framework-quickcheck2
, transformers
, vector , vector
Test-Suite ArrayOpsTest Test-Suite ArrayOpsTest

View file

@ -24,9 +24,7 @@ import Test.HUnit ((@=?))
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test split and concat are inverses. -- | Test split and concat are inverses.
@ -44,7 +42,7 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
testShapeN :: Test testShapeN :: Test
testShapeN = testCase "testShapeN" $ TF.runSession $ do testShapeN = testCase "testShapeN" $ TF.runSession $ do
let shapes = map TF.Shape [[1],[2,3]] let shapes = map TF.Shape [[1],[2,3]]
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float] let tensors = map TF.zeros shapes :: [TF.Tensor TF.Build Float]
result <- TF.run $ CoreOps.shapeN tensors result <- TF.run $ CoreOps.shapeN tensors
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64]) liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])

View file

@ -34,9 +34,7 @@ import TensorFlow.Build
, asGraphDef , asGraphDef
, evalBuildT , evalBuildT
, flushNodeBuffer , flushNodeBuffer
, render
, withDevice , withDevice
, colocateWith
, withNameScope , withNameScope
, opName , opName
) )
@ -50,7 +48,13 @@ import TensorFlow.Ops
, variable' , variable'
) )
import TensorFlow.Output (Device(..)) import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref) import TensorFlow.Tensor
( colocateWith
, render
, Tensor
, Value
, Ref
)
import TensorFlow.Session import TensorFlow.Session
( run ( run
, runSession , runSession
@ -65,8 +69,7 @@ import qualified Data.Vector as V
-- | Test 'opName' behavior. -- | Test 'opName' behavior.
testOpName :: Test testOpName :: Test
testOpName = testCase "testOpName" $ do testOpName = testCase "testOpName" $ do
let graph = variable' (opName .~ "foo") [] let graph = variable' (opName .~ "foo") [] :: Build (Tensor Ref Float)
>>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node nodeDef = head $ asGraphDef graph ^. node
"Variable" @=? (nodeDef ^. op) "Variable" @=? (nodeDef ^. op)
@ -114,7 +117,6 @@ testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float) let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" (variable' (opName .~ "bar1") []) graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
>>= render
nodeDef :: NodeDef nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node nodeDef = head $ asGraphDef graph ^. node
"Variable" @=? (nodeDef ^. op) "Variable" @=? (nodeDef ^. op)

View file

@ -26,16 +26,14 @@ import Test.QuickCheck.Monadic (monadicIO, run)
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Core as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
-- DynamicSplit is undone with DynamicStitch to get the original input -- DynamicSplit is undone with DynamicStitch to get the original input
-- back. -- back.
testDynamicPartitionStitchInverse :: forall a. testDynamicPartitionStitchInverse :: forall a.
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property (TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) = testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] = let splitParts :: [TF.Tensor TF.Build a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor CoreOps.dynamicPartition numParts (TF.vector values) partTensor
partTensor = TF.vector partitions partTensor = TF.vector partitions
restitchIndices = CoreOps.dynamicPartition numParts restitchIndices = CoreOps.dynamicPartition numParts

View file

@ -19,6 +19,7 @@
-- | Tests for EmbeddingOps. -- | Tests for EmbeddingOps.
module Main where module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.List (genericLength) import Data.List (genericLength)
import Google.Test (googleTest) import Google.Test (googleTest)
@ -48,16 +49,15 @@ testEmbeddingLookupHasRightShapeWithPartition =
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items. let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embedding1 = [1, 1, 1 :: Int32] let embedding1 = [1, 1, 1 :: Int32]
let embedding2 = [0, 0, 0 :: Int32] let embedding2 = [0, 0, 0 :: Int32]
let embedding = [ TF.constant embShape embedding1
, TF.constant embShape embedding2
]
let idValues = [0, 1 :: Int32] let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids
(values, shape) <- TF.runSession $ do (values, shape) <- TF.runSession $ do
vs <- op embedding <- mapM TF.render [ TF.constant embShape embedding1
, TF.constant embShape embedding2
]
let ids = TF.constant (TF.Shape [1, 2]) idValues
vs <- embeddingLookup embedding ids
TF.run (vs, TF.shape vs) TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python. -- This is the shape that is returned in the equiv. Python.
@ -77,13 +77,12 @@ testEmbeddingLookupHasRightShape =
, 0, 0, 0 :: Int32 , 0, 0, 0 :: Int32
] ]
let embedding = TF.constant embShape embeddingInit
let idValues = [0, 1 :: Int32] let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids
(values, shape) <- TF.runSession $ do (values, shape) <- TF.runSession $ do
vs <- op embedding <- TF.render $ TF.constant embShape embeddingInit
let ids = TF.constant (TF.Shape [1, 2]) idValues
vs <- embeddingLookup [embedding] ids
TF.run (vs, TF.shape vs) TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python. -- This is the shape that is returned in the equiv. Python.
@ -92,7 +91,6 @@ testEmbeddingLookupHasRightShape =
-- "[0, 1]" should pull out the resulting vector. -- "[0, 1]" should pull out the resulting vector.
values @=? V.fromList [1, 1, 1, 0, 0, 0] values @=? V.fromList [1, 1, 1, 0, 0, 0]
-- | Check that we can calculate gradients w.r.t embeddings. -- | Check that we can calculate gradients w.r.t embeddings.
testEmbeddingLookupGradients :: Test testEmbeddingLookupGradients :: Test
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
@ -108,10 +106,10 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
x <- TF.placeholder (TF.Shape [2]) x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable embedding <- TF.initializedVariable
=<< TF.render (TF.constant embShape embeddingInit) (TF.constant embShape embeddingInit)
op <- embeddingLookup [embedding] ids op <- embeddingLookup [embedding] ids
let twoNorm = CoreOps.square $ TF.abs (op - x) let twoNorm = CoreOps.square $ TF.abs (op `TF.sub` x)
loss = TF.mean twoNorm (TF.scalar (0 :: Int32)) loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding]) grad <- fmap head (TF.gradients loss [embedding])
@ -131,23 +129,21 @@ testEmbeddingLookupUndoesSplit
(LookupExample numParts (LookupExample numParts
shape@(TF.Shape (firstDim : restDims)) shape@(TF.Shape (firstDim : restDims))
values values
indices) = indices) = monadicIO $ run $ TF.runSession $ do
let modShardedValues :: [TF.Tensor TF.Value a] = let shapedValues = TF.constant shape values
CoreOps.dynamicPartition numParts shapedValues cyclicCounter indicesVector <- TF.render $ TF.vector indices
cyclicCounter :: TF.Tensor TF.Value Int32 = let directs = CoreOps.gather shapedValues indicesVector
let cyclicCounter :: TF.Tensor TF.Build Int32 =
TF.vector [0..fromIntegral firstDim-1] TF.vector [0..fromIntegral firstDim-1]
`CoreOps.mod` fromIntegral numParts `CoreOps.mod` fromIntegral numParts
indicesVector = TF.vector indices modShardedValues :: [TF.Tensor TF.Value a] <-
directs = CoreOps.gather shapedValues indicesVector mapM TF.render $ CoreOps.dynamicPartition numParts shapedValues cyclicCounter
shapedValues = TF.constant shape values
in monadicIO $ run $ do
(shapeOut, got, want :: V.Vector a) <-
TF.runSession $ TF.run =<< do
embeddings <- embeddingLookup modShardedValues indicesVector embeddings <- embeddingLookup modShardedValues indicesVector
return (TF.cast (TF.shape embeddings), embeddings, directs) (shapeOut, got, want :: V.Vector a) <-
TF.run (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup. -- Checks the explicitly documented invariant of embeddingLookup.
shapeOut @=? V.fromList (genericLength indices : restDims) liftIO $ shapeOut @=? V.fromList (genericLength indices : restDims)
got @=? want liftIO $ got @=? want
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)" testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit. -- | Consistent set of parameters for EmbeddingLookupUndoesSplit.

View file

@ -36,10 +36,11 @@ import Proto.Tensorflow.Core.Framework.NodeDef (op)
testGradientSimple :: Test testGradientSimple :: Test
testGradientSimple = testCase "testGradientSimple" $ do testGradientSimple = testCase "testGradientSimple" $ do
let x = TF.scalar (3 :: Float) let grads = do
b = TF.scalar (4 :: Float) x <- TF.render $ TF.scalar (3 :: Float)
y = x*x + b b <- TF.render $ TF.scalar (4 :: Float)
grads = TF.gradients y [x, b] let y = x `TF.mul` x `TF.add` b
TF.gradients y [x, b]
-- Assert that the gradients are right. -- Assert that the gradients are right.
[dx, db] <- TF.runSession $ grads >>= TF.run [dx, db] <- TF.runSession $ grads >>= TF.run
6 @=? TF.unScalar dx 6 @=? TF.unScalar dx
@ -88,9 +89,10 @@ testGradientSimple = testCase "testGradientSimple" $ do
testGradientDisconnected :: Test testGradientDisconnected :: Test
testGradientDisconnected = testCase "testGradientDisconnected" $ do testGradientDisconnected = testCase "testGradientDisconnected" $ do
let x = TF.scalar (3 :: Float) let grads = do
b = TF.scalar (4 :: Float) x <- TF.render $ TF.scalar (3 :: Float)
grads = TF.gradients x [x, b] b <- TF.render $ TF.scalar (4 :: Float)
TF.gradients x [x, b]
-- Assert that the gradients are right. -- Assert that the gradients are right.
[dx, db] <- TF.runSession $ grads >>= TF.run [dx, db] <- TF.runSession $ grads >>= TF.run
1 @=? TF.unScalar dx 1 @=? TF.unScalar dx
@ -118,7 +120,7 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
let shape = TF.constant (TF.Shape [1]) [1] let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
TF.gradients (x + y*3) [x, y] >>= TF.run TF.gradients (TF.expr x + TF.expr y * 3) [x, y] >>= TF.run
-- If this test fails, it will likely be caused by an exception within -- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. These asserts are extra. -- `TF.gradients`. These asserts are extra.
1 @=? TF.unScalar dx 1 @=? TF.unScalar dx
@ -142,8 +144,8 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
testDiamond :: Test testDiamond :: Test
testDiamond = testCase "testDiamond" $ do testDiamond = testCase "testDiamond" $ do
[dx] <- TF.runSession $ do [dx] <- TF.runSession $ do
let x = TF.vector [1] x <- TF.render $ TF.vector [1]
y = x*x let y = x `TF.mul` x
z = y*y z = y*y
TF.gradients z [x] >>= TF.run TF.gradients z [x] >>= TF.run
(4 :: Float) @=? TF.unScalar dx (4 :: Float) @=? TF.unScalar dx
@ -152,8 +154,8 @@ testDiamond = testCase "testDiamond" $ do
testMaxGradient :: Test testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ do [dx] <- TF.runSession $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float] x <- TF.render $ TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32) let y = TF.max x (0 :: TF.Tensor TF.Build Int32)
TF.gradients y [x] >>= TF.run TF.gradients y [x] >>= TF.run
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx

View file

@ -56,8 +56,7 @@ testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint" let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float) var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<< var = TF.zeroInitializedVariable' (TF.opName .~ "a")
TF.zeroInitializedVariable' (TF.opName .~ "a")
(TF.Shape []) (TF.Shape [])
TF.runSession $ do TF.runSession $ do
v <- var v <- var
@ -76,7 +75,8 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
p2 <- TF.placeholder [] p2 <- TF.placeholder []
let enc :: Float -> TF.TensorData Float let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n]) enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
$ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5 liftIO $ result @=? TF.Scalar 5
-- | Test that regular tensors can also be used for feeds, as long as they each -- | Test that regular tensors can also be used for feeds, as long as they each
@ -90,7 +90,8 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0 p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
let enc :: Float -> TF.TensorData Float let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n]) enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
$ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5 liftIO $ result @=? TF.Scalar 5
main :: IO () main :: IO ()

View file

@ -38,7 +38,7 @@ fit xData yData = TF.runSession $ do
return (w', b') return (w', b')
gradientDescent :: Float gradientDescent :: Float
-> TF.Tensor TF.Value Float -> TF.Tensor TF.Build Float
-> [TF.Tensor TF.Ref Float] -> [TF.Tensor TF.Ref Float]
-> TF.Session TF.ControlNode -> TF.Session TF.ControlNode
gradientDescent alpha loss params = do gradientDescent alpha loss params = do

View file

@ -53,9 +53,9 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
let floatData = V.fromList [1..6 :: Float] let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]] stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
boolData = V.fromList [True, True, False, True, False, False] boolData = V.fromList [True, True, False, True, False, False]
f <- TF.build $ TF.placeholder [2,3] f <- TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3] s <- TF.placeholder [2,3]
b <- TF.build $ TF.placeholder [2,3] b <- TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
, TF.feed s (TF.encodeTensorData [2,3] stringData) , TF.feed s (TF.encodeTensorData [2,3] stringData)
, TF.feed b (TF.encodeTensorData [2,3] boolData) , TF.feed b (TF.encodeTensorData [2,3] boolData)
@ -63,7 +63,8 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
-- Do something idempotent to the tensors to verify that tensorflow can -- Do something idempotent to the tensors to verify that tensorflow can
-- handle the encoding. Originally this used `TF.identity`, but that -- handle the encoding. Originally this used `TF.identity`, but that
-- wasn't enough to catch a bug in the encoding of Bool. -- wasn't enough to catch a bug in the encoding of Bool.
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b) (f', s', b') <- TF.runWithFeeds feeds
(f `TF.add` 0, TF.identity s, TF.select b b b)
liftIO $ do liftIO $ do
floatData @=? f' floatData @=? f'
stringData @=? s' stringData @=? s'

View file

@ -60,7 +60,7 @@ makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
-- under the given name across multiple sessions. -- under the given name across multiple sessions.
-> m (Queue as) -> m (Queue as)
makeQueue capacity sharedName = do makeQueue capacity sharedName = do
q <- build $ buildOp (opDef "FIFOQueue" q <- build $ buildOp [] (opDef "FIFOQueue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as) & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
& opAttr "shared_name" .~ sharedName & opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity & opAttr "capacity" .~ capacity

View file

@ -13,9 +13,13 @@
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-} {-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module TensorFlow.Build module TensorFlow.Build
( -- * Graph node types ( -- * Graph node types
ControlNode(..) ControlNode(..)
@ -32,8 +36,6 @@ module TensorFlow.Build
, opControlInputs , opControlInputs
-- * The Build monad -- * The Build monad
, GraphState , GraphState
, render
, renderNodeName
, renderedNodeDefs , renderedNodeDefs
, BuildT , BuildT
, Build , Build
@ -46,27 +48,23 @@ module TensorFlow.Build
, addGraphDef , addGraphDef
, flushInitializers , flushInitializers
, flushNodeBuffer , flushNodeBuffer
, summaries
-- * Creating and looking up Ops -- * Creating and looking up Ops
, getOrAddOp , getOrAddOp
, addNewOp , addNewOp
, renderOutput , encodeOutput
, lookupNode
-- * Modifying all nodes in a Build action -- * Modifying all nodes in a Build action
, colocateWith
, withStateLens , withStateLens
, withDevice , withDevice
, withNameScope , withNameScope
, withNodeDependencies , withNodeDependencies
-- * Internal Summary related bits.
, addSummary
, SummaryTensor
, collectAllSummaries
) where ) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
import Data.ByteString (ByteString)
import Data.Default (def) import Data.Default (def)
import Data.Functor.Identity (Identity(..)) import Data.Functor.Identity (Identity(..))
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
@ -94,7 +92,6 @@ import Proto.Tensorflow.Core.Framework.NodeDef
import TensorFlow.Orphans () import TensorFlow.Orphans ()
import TensorFlow.Output import TensorFlow.Output
import TensorFlow.Tensor
newtype Unique = Unique Int newtype Unique = Unique Int
deriving (Eq, Ord, Enum) deriving (Eq, Ord, Enum)
@ -125,9 +122,6 @@ opDefWithName n t = OpDef
, _opControlInputs = [] , _opControlInputs = []
} }
-- | Synonym for the tensors that return serialized Summary proto.
type SummaryTensor = Tensor Value ByteString
data GraphState = GraphState data GraphState = GraphState
{ _renderedNodes :: !(Map.Map PendingNode NodeDef) { _renderedNodes :: !(Map.Map PendingNode NodeDef)
-- ^ Nodes which have been rendered. Keeps track of the unique ID we -- ^ Nodes which have been rendered. Keeps track of the unique ID we
@ -148,8 +142,8 @@ data GraphState = GraphState
, _initializationNodes :: [NodeName] , _initializationNodes :: [NodeName]
-- ^ The nodes to run next time a TF.run is issued, typically -- ^ The nodes to run next time a TF.run is issued, typically
-- variable initializers. -- variable initializers.
, _summaries :: [SummaryTensor] , _summaries :: [Output]
-- ^ The tensors for summary -- ^ The tensors for summary (ByteString type)
} }
-- | A node definition without its final name. Used as a key in the -- | A node definition without its final name. Used as a key in the
@ -191,7 +185,7 @@ defaultControlInputs = lens _defaultControlInputs
initializationNodes :: Lens' GraphState [NodeName] initializationNodes :: Lens' GraphState [NodeName]
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x }) initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
summaries :: Lens' GraphState [SummaryTensor] summaries :: Lens' GraphState [Output]
summaries = lens _summaries (\g x -> g { _summaries = x }) summaries = lens _summaries (\g x -> g { _summaries = x })
-- | An action for building nodes in a TensorFlow graph. -- | An action for building nodes in a TensorFlow graph.
@ -238,9 +232,7 @@ flushInitializers = do
-- | Registers the given node to be executed before the next -- | Registers the given node to be executed before the next
-- 'TensorFlow.Session.run'. -- 'TensorFlow.Session.run'.
addInitializer :: MonadBuild m => ControlNode -> m () addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode o) = build $ do addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
i <- getOrAddOp o
initializationNodes %= (i:)
-- | Produce a GraphDef proto representation of the nodes that are rendered in -- | Produce a GraphDef proto representation of the nodes that are rendered in
-- the given 'Build' action. -- the given 'Build' action.
@ -255,30 +247,31 @@ addGraphDef g = build $ nodeBuffer <>= g ^. node
-- | Render the given op if it hasn't been rendered already, and return its -- | Render the given op if it hasn't been rendered already, and return its
-- name. -- name.
getOrAddOp :: Op -> Build NodeName getOrAddOp :: OpDef -> Build NodeName
getOrAddOp o = NodeName . (^. name) <$> resolveOp o getOrAddOp o = do
resolveOp :: Op -> Build NodeDef
resolveOp (Rendered n) = return n
resolveOp (Unrendered o) = do
pending <- getPendingNode o pending <- getPendingNode o
uses renderedNodes (Map.lookup pending) >>= \case uses renderedNodes (Map.lookup pending) >>= \case
Just n -> return n Just n -> return $ NodeName $ n ^. name
Nothing -> addNewOpFromPending pending Nothing -> addNewOpFromPending pending
lookupNode :: NodeName -> Build NodeDef
lookupNode n = uses renderedNodeDefs (Map.lookup n) >>= \case
Just n' -> return n'
Nothing -> error $ "lookupNode: unknown node name " ++ show n
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops -- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
-- which are not safe to dedup (e.g, "variable" and "assign"). -- which are not safe to dedup (e.g, "variable" and "assign").
addNewOp :: OpDef -> Build NodeDef addNewOp :: OpDef -> Build NodeName
addNewOp o = getPendingNode o >>= addNewOpFromPending addNewOp o = getPendingNode o >>= addNewOpFromPending
addNewOpFromPending :: PendingNode -> Build NodeDef addNewOpFromPending :: PendingNode -> Build NodeName
addNewOpFromPending pending = do addNewOpFromPending pending = do
nodeName <- renderPendingNode pending nodeName <- renderPendingNode pending
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
nodeBuffer %= (nodeDef :) nodeBuffer %= (nodeDef :)
renderedNodes %= Map.insert pending nodeDef renderedNodes %= Map.insert pending nodeDef
renderedNodeDefs %= Map.insert nodeName nodeDef renderedNodeDefs %= Map.insert nodeName nodeDef
return nodeDef return nodeName
-- | Get the pending node corresponding to an OpDef, which may or may not have -- | Get the pending node corresponding to an OpDef, which may or may not have
-- been rendered before. Implicitly renders all of this node's inputs. -- been rendered before. Implicitly renders all of this node's inputs.
@ -287,20 +280,18 @@ getPendingNode o = do
-- An empty string in the proto field means that no specific -- An empty string in the proto field means that no specific
-- device is specified. -- device is specified.
dev <- maybe "" deviceName <$> use defaultDevice dev <- maybe "" deviceName <$> use defaultDevice
inputs <- mapM getInput (o ^. opInputs)
scope <- use currentScope scope <- use currentScope
controls <- use defaultControlInputs controls <- use defaultControlInputs
let inputs = map encodeOutput (o ^. opInputs)
let controlInputs let controlInputs
= map getDep (o ^. opControlInputs ++ Set.toList controls) = map makeDep (o ^. opControlInputs ++ Set.toList controls)
return $ PendingNode scope (o ^. opName) return $ PendingNode scope (o ^. opName)
$ def & op .~ (unOpType (o ^. opType) :: Text) $ def & op .~ (unOpType (o ^. opType) :: Text)
& attr .~ _opAttrs o & attr .~ _opAttrs o
& input .~ (inputs ++ controlInputs) & input .~ (inputs ++ controlInputs)
& device .~ dev & device .~ dev
where where
getInput (Output (OutputIx k) subOp) makeDep = ("^" <>) . unNodeName
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
getDep = ("^" <>) . unNodeName
-- | Pick a name for a pending node. If it has an explicit name, just use that; -- | Pick a name for a pending node. If it has an explicit name, just use that;
-- if the name is implicit, assign a new unique name based on the op type. -- if the name is implicit, assign a new unique name based on the op type.
@ -317,12 +308,11 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
return $ nodeDef ^. op <> "_" <> Text.pack (show k) return $ nodeDef ^. op <> "_" <> Text.pack (show k)
-- | Render an 'Output' and return a string representation for the TensorFlow -- | Turn an 'Output' into a string representation for the TensorFlow
-- foreign APIs. -- foreign APIs.
renderOutput :: Output -> Build Text encodeOutput :: Output -> Text
renderOutput (Output (OutputIx i) o) = do encodeOutput (Output (OutputIx 0) n) = unNodeName n
n <- getOrAddOp o encodeOutput (Output (OutputIx i) n) = unNodeName n <> Text.pack (':' : show i)
return $ unNodeName n <> Text.pack (":" ++ show i)
-- | Modify some part of the state, run an action, and restore the state -- | Modify some part of the state, run an action, and restore the state
-- after that action is done. -- after that action is done.
@ -339,15 +329,6 @@ withStateLens accessor f act = do
withDevice :: MonadBuild m => Maybe Device -> m a -> m a withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d = withStateLens defaultDevice (const d) withDevice d = withStateLens defaultDevice (const d)
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
colocateWith t x = do
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
withDevice (Just d) x
-- | Prepend a scope to all nodes rendered in the given 'Build' action. -- | Prepend a scope to all nodes rendered in the given 'Build' action.
withNameScope :: MonadBuild m => Text -> m a -> m a withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s = withStateLens currentScope (Scope s :) withNameScope s = withStateLens currentScope (Scope s :)
@ -355,31 +336,3 @@ withNameScope s = withStateLens currentScope (Scope s :)
-- | Add control inputs to all nodes rendered in the given 'Build' action. -- | Add control inputs to all nodes rendered in the given 'Build' action.
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes) withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
-- weren't already rendered.
--
-- This operation is idempotent; @render >=> render === render@. However,
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
-- may result in two different 'Tensor's.
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
-- | Render a 'Tensor' and get its node's name.
renderNodeName :: Tensor v a -> Build NodeName
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
-- | Records the given summary action in Build for retrieval with
-- 'collectAllSummaries'. The summary op is required to produce a
-- Summary protocol buffer in string form. For safety, use the
-- pre-composed functions: Logging.scalarSummary and
-- Logging.histogramSummary.
addSummary :: SummaryTensor -> Build ()
addSummary t = summaries %= (t :)
-- | Retrieves the summary ops collected thus far. Typically this only
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
-- repeatedly, the values accumulate.
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
collectAllSummaries = use summaries

View file

@ -12,26 +12,27 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-} {-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp module TensorFlow.BuildOp
( OpResult ( BuildResult(..)
, BuildOp
, buildOp , buildOp
, buildListOp , PureResult(..)
, pureOp
, eqLengthGuard , eqLengthGuard
, BuildInputs(..)
, OpParams , OpParams
) )
where where
import Control.Monad (replicateM) import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask) import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, runState, get, put) import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64) import Data.Int (Int64)
import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Output import TensorFlow.Output
@ -40,48 +41,45 @@ import TensorFlow.Types
data ResultState = ResultState !OutputIx [Int64] deriving Show data ResultState = ResultState !OutputIx [Int64] deriving Show
type Result = ReaderT Op (State ResultState) type Result = ReaderT NodeName (State ResultState)
-- | Class of types that can be used as op outputs. -- | Class of types that can be used as op outputs.
class OpResult a where class BuildResult a where
toResult :: Result a buildResult :: Result a
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
toResult = (,) <$> toResult <*> toResult buildResult = (,) <$> buildResult <*> buildResult
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where
toResult = (,,) <$> toResult <*> toResult <*> toResult buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4) instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
=> OpResult (a1, a2, a3, a4) where => BuildResult (a1, a2, a3, a4) where
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5) instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
=> OpResult (a1, a2, a3, a4, a5) where => BuildResult (a1, a2, a3, a4, a5) where
toResult = (,,,,) <$> toResult buildResult = (,,,,) <$> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
instance ( OpResult a1 instance ( BuildResult a1
, OpResult a2 , BuildResult a2
, OpResult a3 , BuildResult a3
, OpResult a4 , BuildResult a4
, OpResult a5 , BuildResult a5
, OpResult a6 , BuildResult a6
) )
=> OpResult (a1, a2, a3, a4, a5, a6) where => BuildResult (a1, a2, a3, a4, a5, a6) where
toResult = (,,,,,) buildResult = (,,,,,)
<$> toResult <$> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
<*> toResult <*> buildResult
tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult v = Tensor v <$> recordResult
recordResult :: Result Output recordResult :: Result Output
recordResult = do recordResult = do
@ -90,144 +88,39 @@ recordResult = do
put $! ResultState (i+1) ns put $! ResultState (i+1) ns
return $! output i o return $! output i o
instance OpResult ResourceHandle where instance BuildResult ResourceHandle where
toResult = ResourceHandle <$> recordResult buildResult = ResourceHandle <$> recordResult
instance OpResult (Tensor Value a) where instance Rendered v => BuildResult (Tensor v a) where
toResult = tensorResult ValueKind buildResult = Tensor . pure <$> recordResult
instance OpResult (Tensor Ref a) where instance BuildResult ControlNode where
toResult = tensorResult RefKind buildResult = ControlNode <$> ask
instance OpResult ControlNode where instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
toResult = ControlNode <$> ask buildResult = loop (tensorTypes :: TensorTypeList as)
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
where where
loop :: TensorTypeList bs -> Result (TensorList v bs) loop :: TensorTypeList bs -> Result (TensorList v bs)
loop Nil = return Nil loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do loop (TensorTypeProxy :/ ls) = do
t <- tensorResult v t <- buildResult
ts <- loop ls ts <- loop ls
return (t :/ ts) return (t :/ ts)
instance TensorTypes as => OpResult (TensorList Value as) where instance BuildResult a => BuildResult [a] where
toResult = tensorListResult ValueKind buildResult = do
instance TensorTypes as => OpResult (TensorList Ref as) where
toResult = tensorListResult RefKind
instance OpResult a => OpResult [a] where
toResult = do
ResultState i ns <- get ResultState i ns <- get
case ns of case ns of
[] -> error $ "Ran out of counts in toResult. " ++ [] -> error $ "Ran out of counts in buildResult. " ++
"Likely misuse of buildListOp." "Likely misuse of buildOp."
(n : rest) -> do (n : rest) -> do
put $! ResultState i rest put $! ResultState i rest
replicateM (fromIntegral n) toResult replicateM (fromIntegral n) buildResult
runResult :: OpResult a => [Int64] -> Op -> a buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
runResult ns o = buildOp sizes o = do
case runState (runReaderT toResult o) (ResultState 0 ns) of n <- addNewOp o
(x, ResultState _ []) -> x return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
show (ns, ns')
-- | Make a new "pure" op, which may be deduped with identical ops within
-- the same scope.
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
-- | Make a new "stateful" op, which will not be deduped with otherwise
-- identical ops.
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
buildResult ns o ts
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
addReversedInputs :: OpDef -> [Output] -> OpDef
addReversedInputs o ts = o & opInputs <>~ reverse ts
-- | Class of types that can be used as op functions.
class BuildOp f where
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
-> OpDef
-> [Output] -- ^ Accumulator for inputs to the op.
-> f
-- | Starts an operation that returns a structured set of tensors
-- (singletons or tuples).
buildOp :: BuildOp f => OpDef -> f
buildOp o = buildOp' [] o []
-- | Starts an operation that returns a list of tensors.
buildListOp :: BuildOp f => [Int64]
-- ^ Cardinality of the corresponding list of tensors output.
-> OpDef -> f
buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp ResourceHandle where
buildOp' = pureResult
instance BuildOp (Tensor Value a) where
buildOp' = pureResult
instance BuildOp (Tensor Ref a) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Value as) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Ref as) where
buildOp' = pureResult
instance BuildOp [Tensor Value a] where
buildOp' = pureResult
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
=> BuildOp (t1, t2, t3, t4) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
=> BuildOp (t1, t2, t3, t4, t5) where
buildOp' = pureResult
instance ( OpResult t1
, OpResult t2
, OpResult t3
, OpResult t4
, OpResult t5
, OpResult t6
)
=> BuildOp (t1, t2, t3, t4, t5, t6) where
buildOp' = pureResult
instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult
instance BuildOp f => BuildOp (ResourceHandle -> f) where
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
instance BuildOp f => BuildOp (Tensor v a -> f) where
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
instance BuildOp f => BuildOp (TensorList v as -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
-- | Returns true if all the integers in each tuple are identical. -- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not. -- Throws an error with a descriptive message if not.
@ -240,6 +133,104 @@ eqLengthGuard = all eachOk
error ("number_attr " ++ numberAttrName ++ error ("number_attr " ++ numberAttrName ++
" contains tensors with different length " ++ show pairs) " contains tensors with different length " ++ show pairs)
-----------
-- | Class of types that can be used as op outputs.
class PureResult a where
pureResult :: ReaderT (Build OpDef) (State ResultState) a
instance PureResult (Tensor Build a) where
pureResult = do
ResultState i ns <- get
put $! ResultState (i+1) ns
makeOp <- ask
return $ Tensor $ do
o <- makeOp
-- TODO: unify with BuildResult (Tensor v)
output i <$> getOrAddOp o
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
pureResult = (,) <$> pureResult <*> pureResult
instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where
pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
=> PureResult (a1, a2, a3, a4) where
pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
=> PureResult (a1, a2, a3, a4, a5) where
pureResult = (,,,,) <$> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
instance ( PureResult a1
, PureResult a2
, PureResult a3
, PureResult a4
, PureResult a5
, PureResult a6
)
=> PureResult (a1, a2, a3, a4, a5, a6) where
pureResult = (,,,,,)
<$> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
instance PureResult a => PureResult [a] where
pureResult = do
ResultState i ns <- get
case ns of
[] -> error $ "Ran out of counts in pureResult. " ++
"Likely misuse of pureOp with output lists."
n : rest -> do
put $! ResultState i rest
replicateM (fromIntegral n) pureResult
instance TensorTypes as => PureResult (TensorList Build as) where
pureResult = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState)
(TensorList Build bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- pureResult
ts <- loop ls
return (t :/ ts)
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o)
-----
-- Class of types that can be used as arguments
class BuildInputs a where
buildInputs :: a -> Build [Output]
instance BuildInputs a => BuildInputs [a] where
buildInputs = fmap concat . mapM buildInputs
instance BuildInputs (Tensor v a) where
buildInputs (Tensor t) = do
o <- toBuild t
return [o]
instance BuildInputs (ListOf (Tensor v) as) where
buildInputs Nil = return []
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
instance BuildInputs ResourceHandle where
buildInputs (ResourceHandle o) = return [o]
----
-- | Parameters to build an op (for example, the node name or optional attributes). -- | Parameters to build an op (for example, the node name or optional attributes).
-- TODO: be more type safe. -- TODO: be more type safe.
type OpParams = OpDef -> OpDef type OpParams = OpDef -> OpDef

View file

@ -25,9 +25,6 @@ module TensorFlow.ControlFlow
, noOp , noOp
) where ) where
import qualified Data.Set as Set
import Lens.Family2 ((&), (.~))
import TensorFlow.BuildOp import TensorFlow.BuildOp
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Nodes import TensorFlow.Nodes
@ -46,11 +43,8 @@ withControlDependencies deps act = do
-- When this op finishes, all ops in the input @n@ have finished. This op has -- When this op finishes, all ops in the input @n@ have finished. This op has
-- no output. -- no output.
group :: (MonadBuild m, Nodes t) => t -> m ControlNode group :: (MonadBuild m, Nodes t) => t -> m ControlNode
group deps = do group deps = withControlDependencies deps noOp
nodes <- build $ Set.toList <$> getNodes deps
-- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
-- | Does nothing. Only useful as a placeholder for control edges. -- | Does nothing. Only useful as a placeholder for control edges.
noOp :: ControlNode noOp :: MonadBuild m => m ControlNode
noOp = buildOp $ opDef "NoOp" noOp = build $ buildOp [] $ opDef "NoOp"

View file

@ -57,9 +57,9 @@ module TensorFlow.Core
, Tensor , Tensor
, Value , Value
, Ref , Ref
, TensorKind(..)
, value , value
, tensorFromName , tensorFromName
, expr
-- ** Element types -- ** Element types
, TensorType , TensorType
, TensorData , TensorData

View file

@ -20,6 +20,7 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} -- For Fetchable (TensorExpr a)
module TensorFlow.Nodes where module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3) import Control.Applicative (liftA2, liftA3)
@ -28,7 +29,6 @@ import Data.Map.Strict (Map)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.Set (Set) import Data.Set (Set)
import Data.Text (Text) import Data.Text (Text)
import Lens.Family2 ((^.))
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import qualified Data.Set as Set import qualified Data.Set as Set
@ -90,7 +90,7 @@ instance Fetchable t a => Fetchable [t] [a] where
getFetch ts = sequenceA <$> mapM getFetch ts getFetch ts = sequenceA <$> mapM getFetch ts
instance Nodes ControlNode where instance Nodes ControlNode where
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o getNodes (ControlNode o) = pure $ Set.singleton o
-- We use the constraint @(a ~ ())@ to help with type inference. For example, -- We use the constraint @(a ~ ())@ to help with type inference. For example,
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session -- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
@ -113,13 +113,13 @@ instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
instance Nodes (Tensor v a) where instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o
fetchTensorVector :: forall a v . TensorType a fetchTensorVector :: forall a v . (TensorType a)
=> Tensor v a -> Build (Fetch (TensorData a)) => Tensor v a -> Build (Fetch (TensorData a))
fetchTensorVector (Tensor _ o) = do fetchTensorVector (Tensor o) = do
outputName <- renderOutput o outputName <- encodeOutput <$> toBuild o
return $ Fetch (Set.singleton outputName) $ \tensors -> pure $ Fetch (Set.singleton outputName) $ \tensors ->
let tensorData = tensors Map.! outputName let tensorData = tensors Map.! outputName
expectedType = tensorType (undefined :: a) expectedType = tensorType (undefined :: a)
actualType = FFI.tensorDataType tensorData actualType = FFI.tensorDataType tensorData

View file

@ -22,8 +22,6 @@ module TensorFlow.Output
, Device(..) , Device(..)
-- * Ops -- * Ops
, NodeName(..) , NodeName(..)
, Op(..)
, opUnrendered
, OpDef(..) , OpDef(..)
, opName , opName
, opType , opType
@ -34,28 +32,24 @@ module TensorFlow.Output
, OutputIx(..) , OutputIx(..)
, Output(..) , Output(..)
, output , output
, outputIndex
, outputOp
, PendingNodeName(..) , PendingNodeName(..)
, ResourceHandle(..) , ResourceHandle(..)
) where ) where
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import Data.ProtoLens.TextFormat (showMessage)
import Data.String (IsString(..)) import Data.String (IsString(..))
import Data.Text (Text) import Data.Text (Text)
import qualified Data.Text as Text import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.)) import Lens.Family2 (Lens')
import Lens.Family2.Unchecked (lens) import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..)) import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
import Data.Default (def) import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens) import TensorFlow.Types (Attribute, attrLens)
import TensorFlow.Orphans () import TensorFlow.Orphans ()
-- | A type of graph node which has no outputs. These nodes are -- | A type of graph node which has no outputs. These nodes are
-- valuable for causing side effects when they are run. -- valuable for causing side effects when they are run.
newtype ControlNode = ControlNode { unControlNode :: Op } newtype ControlNode = ControlNode { unControlNode :: NodeName }
-- | The type of op of a node in the graph. This corresponds to the proto field -- | The type of op of a node in the graph. This corresponds to the proto field
-- NodeDef.op. -- NodeDef.op.
@ -66,18 +60,12 @@ instance IsString OpType where
fromString = OpType . Text.pack fromString = OpType . Text.pack
-- | An output of a TensorFlow node. -- | An output of a TensorFlow node.
data Output = Output !OutputIx !Op data Output = Output {outputIndex :: !OutputIx, outputNodeName :: !NodeName}
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
output :: OutputIx -> Op -> Output output :: OutputIx -> NodeName -> Output
output = Output output = Output
outputOp :: Lens' Output Op
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
outputIndex :: Lens' Output OutputIx
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
newtype OutputIx = OutputIx { unOutputIx :: Int } newtype OutputIx = OutputIx { unOutputIx :: Int }
deriving (Eq, Ord, Num, Enum, Show) deriving (Eq, Ord, Num, Enum, Show)
@ -90,25 +78,6 @@ newtype Device = Device {deviceName :: Text}
instance Show Device where instance Show Device where
show (Device d) = show d show (Device d) = show d
-- | The representation of a node in a TensorFlow graph.
data Op
= Rendered !NodeDef -- ^ Properties are fixed, including the
-- device, name, and scope.
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
-- on which context this op is rendered in.
deriving (Eq, Ord)
instance Show Op where
show (Rendered n) = "Rendered " ++ showMessage n
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
-- | Traverse on the 'Unrendered' of an 'Op'.
--
-- Same implementation as _Left.
opUnrendered :: Traversal' Op OpDef
opUnrendered f (Unrendered a) = Unrendered <$> f a
opUnrendered _ (Rendered b) = pure (Rendered b)
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto. -- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
data OpDef = OpDef data OpDef = OpDef
{ _opName :: !PendingNodeName { _opName :: !PendingNodeName
@ -157,7 +126,7 @@ instance IsString Output where
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr (n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
-> Output (fromInteger ix) $ assigned n -> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s _ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n where assigned = NodeName . Text.pack
-- | Opaque handle to a mutable resource in the graph. Typical such -- | Opaque handle to a mutable resource in the graph. Typical such

View file

@ -163,7 +163,7 @@ runWithFeeds feeds t = do
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
runFetchWithFeeds feeds target (Fetch fetch restore) = do runFetchWithFeeds feeds target (Fetch fetch restore) = do
extend extend
feeds' <- build $ fixFeeds feeds let feeds' = fixFeeds feeds
let fetchNames = encodeUtf8 <$> Set.toList fetch let fetchNames = encodeUtf8 <$> Set.toList fetch
targetNames = toNodeNames $ Set.toList target targetNames = toNodeNames $ Set.toList target
session <- Session (asks rawSession) session <- Session (asks rawSession)
@ -192,8 +192,8 @@ runWithFeeds_ feeds t = do
ns <- build $ getNodes t ns <- build $ getNodes t
runFetchWithFeeds feeds ns (pure ()) runFetchWithFeeds feeds ns (pure ())
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)] fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
-- | Starts a concurrent thread which evaluates the given Nodes -- | Starts a concurrent thread which evaluates the given Nodes
-- forever until runSession exits or an exception occurs. Graph -- forever until runSession exits or an exception occurs. Graph

View file

@ -16,21 +16,26 @@
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE KindSignatures #-} {-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-} {-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} -- For the Render class
module TensorFlow.Tensor where module TensorFlow.Tensor where
import Data.ByteString (ByteString)
import Data.String (IsString(..)) import Data.String (IsString(..))
import qualified Data.Text as Text import qualified Data.Text as Text
import Lens.Family2 (Lens', (^.)) import Lens.Family2 ((^.))
import Lens.Family2.Unchecked (lens) import Lens.Family2.State ((%=), use)
import TensorFlow.Output (Output) import Proto.Tensorflow.Core.Framework.NodeDef (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types import TensorFlow.Types
( TensorData(..) ( TensorData(..)
, ListOf(..) , ListOf(..)
@ -40,52 +45,149 @@ import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation. -- | A named output of a TensorFlow operation.
-- --
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The -- The type parameter @a@ is the type of the elements in the 'Tensor'. The
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is -- parameter @v@ is either:
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a --
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via -- * 'Build': An unrendered, immutable value.
-- 'value'. -- * 'Value': A rendered, immutable value.
data Tensor v a = Tensor (TensorKind v) Output -- * 'Ref': A rendered stateful handle (e.g., a variable).
--
-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert between
-- the different types of 'Tensor'.
data Tensor v a where
Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
data Value newtype Value a = Value {runValue :: a}
data Ref deriving Functor
-- | This class provides a runtime switch on whether a 'Tensor' should be instance Applicative Value where
-- treated as a 'Value' or as a 'Ref'. pure = Value
data TensorKind v where Value f <*> Value x = Value $ f x
ValueKind :: TensorKind Value
RefKind :: TensorKind Ref
tensorKind :: Lens' (Tensor v a) (TensorKind v) instance Monad Value where
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o) f >>= g = g $ runValue f
tensorOutput :: Lens' (Tensor v a) Output newtype Ref a = Ref {runRef :: a}
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o) deriving Functor
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a instance Applicative Ref where
-- Ref into Value. This behaves like a no-op. pure = Ref
value :: Tensor v a -> Tensor Value a Ref f <*> Ref x = Ref $ f x
value (Tensor _ o) = Tensor ValueKind o
instance Monad Ref where
f >>= g = g $ runRef f
-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op.
value :: Tensor Ref a -> Tensor Value a
value (Tensor o) = Tensor $ Value $ runRef o
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
renderValue (Tensor o) = render $ Tensor $ toBuild o
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor' -- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
-- when running the graph. -- when running the graph.
data Feed = Feed Output FFI.TensorData data Feed = Feed Output FFI.TensorData
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
-- name, device, etc.
class TensorKind v => Rendered v where
rendered :: v a -> a
instance Rendered Value where
rendered = runValue
instance Rendered Ref where
rendered = runRef
renderedOutput :: Rendered v => Tensor v a -> Output
renderedOutput = rendered . tensorOutput
tensorNodeName :: Rendered v => Tensor v a -> NodeName
tensorNodeName = outputNodeName . renderedOutput
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running -- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
-- the graph. -- the graph.
-- --
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the -- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
-- rendered 'Tensor' may be different than feeding the original 'Tensor'. -- rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Tensor v a -> TensorData a -> Feed feed :: Rendered v => Tensor v a -> TensorData a -> Feed
feed (Tensor _ o) (TensorData td) = Feed o td feed t (TensorData td) = Feed (renderedOutput t) td
-- | Create a 'Tensor' for a given name. This can be used to reference nodes -- | Create a 'Tensor' for a given name. This can be used to reference nodes
-- in a 'GraphDef' that was loaded via 'addGraphDef'. -- in a 'GraphDef' that was loaded via 'addGraphDef'.
-- TODO(judahjacobson): add more safety checks here. -- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a tensorFromName :: TensorKind v => Text.Text -> Tensor v a
tensorFromName v = Tensor v . fromString . Text.unpack tensorFromName = Tensor . pure . fromString . Text.unpack
-- | Like 'tensorFromName', but type-restricted to 'Value'.
tensorValueFromName :: Text.Text -> Tensor Value a
tensorValueFromName = tensorFromName
-- | Like 'tensorFromName', but type-restricted to 'Ref'.
tensorRefFromName :: Text.Text -> Tensor Ref a
tensorRefFromName = tensorFromName
type TensorList v = ListOf (Tensor v) type TensorList v = ListOf (Tensor v)
tensorListOutputs :: TensorList v as -> [Output] tensorListOutputs :: Rendered v => TensorList v as -> [Output]
tensorListOutputs Nil = [] tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
colocateWith t x = do
d <- build $ Device . (^. device)
<$> lookupNode (outputNodeName $ renderedOutput t)
withDevice (Just d) x
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
-- the 'MonadBuild' context. Also renders any dependencies of the 'Tensor' that
-- weren't already rendered.
--
-- This operation is idempotent; calling 'render' on the same input in the same
-- context will produce the same result. However, rendering the same
-- @Tensor Build@ in two different contexts may result in two different
-- @Tensor Value@s.
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
render (Tensor t) = Tensor . Value <$> build t
-- TODO: better name.
expr :: TensorKind v => Tensor v a -> Tensor Build a
expr (Tensor o) = Tensor $ toBuild o
-- | Records the given summary action in Build for retrieval with
-- Summary protocol buffer in string form. For safety, use the
-- pre-composed functions: Logging.scalarSummary and
-- Logging.histogramSummary.
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -- ^ A 'SummaryTensor'
-> m ()
addSummary t = build $ do
-- TODO: more generic way
o <- toBuild $ tensorOutput t
summaries %= (o :)
-- | Retrieves the summary ops collected thus far. Typically this only
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
-- repeatedly, the values accumulate.
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
collectAllSummaries = build $ map (Tensor . Value) <$> use summaries
-- | Synonym for the tensors that return serialized Summary proto.
type SummaryTensor = Tensor Value ByteString
-- | An internal class for kinds of Tensors.
class Monad v => TensorKind v where
toBuild :: v a -> Build a
instance TensorKind Value where
toBuild = return . rendered
instance TensorKind Ref where
toBuild = return . rendered
instance TensorKind Build where
toBuild = id