1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Distinguish between "rendered" and "unrendered" Tensors. (#88)

Distinguish between "rendered" and "unrendered" Tensors.

There are now three types of `Tensor`:

- `Tensor Value a`: rendered value
- `Tensor Ref a`: rendered reference
- `Tensor Build a` : unrendered value

The extra bookkeeping makes it easier to track (and enforce) which tensors are
rendered or not.  For examples where this has been confusing in the past, see

With this change, pure ops look similar to before, returning `Tensor Build`
instead of `Tensor Value`.  "Stateful" (monadic) ops are unchanged.  For
example:

    add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
    assign :: (MonadBuild m, TensorType t)
           => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t)

The `gradients` function now requires that the variables over which it's
differentiating are pre-rendered:

    gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a]
              -> m [Tensor Value a]

(`Rendered v2` means that `v2` is either a `Ref` or a `Value`.)

Additionally, the implementation of `gradients` now takes care to render every
intermediate value when performing the reverse accumulation.  I suspect this
fixes an exponential blowup for complicated expressions.
This commit is contained in:
Judah Jacobson 2017-04-06 15:10:33 -07:00 committed by fkm3
parent d71f48090a
commit d62c614695
29 changed files with 636 additions and 608 deletions

View file

@ -58,7 +58,7 @@ fit xData yData = TF.runSession $ do
return (w', b')
gradientDescent :: Float
-> TF.Tensor TF.Value Float
-> TF.Tensor TF.Build Float
-> [TF.Tensor TF.Ref Float]
-> TF.Session TF.ControlNode
gradientDescent alpha loss params = do

View file

@ -67,10 +67,10 @@ import Proto.Tensorflow.Core.Framework.Summary (Summary)
import Proto.Tensorflow.Core.Util.Event (Event, fileVersion, step, summary, wallTime)
import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>))
import TensorFlow.Build (Build, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (scalar)
import TensorFlow.Records.Conduit (sinkTFRecords)
import TensorFlow.Tensor (Tensor)
import TensorFlow.Tensor (Tensor, render, SummaryTensor, addSummary, collectAllSummaries)
import TensorFlow.Types (TensorType, type(/=))
import Text.Printf (printf)
import qualified Data.ByteString.Lazy as L
@ -141,19 +141,19 @@ doubleWallTime = asDouble <$> getCurrentTime
-- | Adds a 'CoreOps.histogramSummary' node. The tag argument is intentionally
-- limited to a single value for simplicity.
histogramSummary ::
(TensorType t, t /= ByteString, t /= Bool)
(MonadBuild m, TensorType t, t /= ByteString, t /= Bool)
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
=> ByteString -> Tensor v t -> Build ()
=> ByteString -> Tensor v t -> m ()
histogramSummary tag = addSummary . CoreOps.histogramSummary (scalar tag)
-- | Adds a 'CoreOps.scalarSummary' node.
scalarSummary ::
(TensorType t, t /= ByteString, t /= Bool)
(TensorType t, t /= ByteString, t /= Bool, MonadBuild m)
-- (TensorType t,
-- OneOf '[Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t)
=> ByteString -> Tensor v t -> Build ()
=> ByteString -> Tensor v t -> m ()
scalarSummary tag = addSummary . CoreOps.scalarSummary (scalar tag)
-- | Merge all summaries accumulated in the 'Build' into one summary.
mergeAllSummaries :: Build SummaryTensor
mergeAllSummaries :: MonadBuild m => m SummaryTensor
mergeAllSummaries = collectAllSummaries >>= render . CoreOps.mergeSummary

View file

@ -34,13 +34,13 @@ numPixels = 28*28 :: Int64
numLabels = 10 :: Int64
-- | Create tensor with random values where the stddev depends on the width.
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Value Float)
randomParam :: Int64 -> TF.Shape -> TF.Build (TF.Tensor TF.Build Float)
randomParam width (TF.Shape shape) =
(* stddev) <$> TF.truncatedNormal (TF.vector shape)
(`TF.mul` stddev) <$> TF.truncatedNormal (TF.vector shape)
where
stddev = TF.scalar (1 / sqrt (fromIntegral width))
reduceMean :: TF.Tensor TF.Value Float -> TF.Tensor TF.Value Float
reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
-- Types must match due to model structure.
@ -87,7 +87,7 @@ createModel = do
grads <- TF.gradients loss params
let lr = TF.scalar 0.00001
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
applyGrad param grad = TF.assign param $ param `TF.sub` (lr `TF.mul` grad)
trainStep <- TF.group =<< zipWithM applyGrad params grads
let correctPredictions = TF.equal predict labels

View file

@ -37,15 +37,15 @@ import TensorFlow.Examples.MNIST.TrainedGraph
import TensorFlow.Build
( asGraphDef
, addGraphDef
, render
, Build
)
import TensorFlow.Tensor
( Tensor(..)
, Ref
, Value
, feed
, TensorKind(..)
, render
, tensorFromName
, tensorValueFromName
)
import TensorFlow.Ops
import TensorFlow.Session
@ -80,7 +80,7 @@ testReadMNIST = testCase "testReadMNIST" $ do
labelData <- readMNISTLabels =<< testLabelData
10000 @=? length labelData
testNodeName :: Text -> Tensor v a -> Assertion
testNodeName :: Text -> Tensor Build a -> Assertion
testNodeName n g = n @=? opName
where
opName = head (gDef^.node)^.op
@ -89,7 +89,7 @@ testNodeName n g = n @=? opName
testGraphDefGen :: Test
testGraphDefGen = testCase "testGraphDefGen" $ do
-- Test the inferred operation type.
let f0 :: Tensor Value Float
let f0 :: Tensor Build Float
f0 = 0
testNodeName "Const" f0
testNodeName "Add" $ 1 + f0
@ -109,7 +109,7 @@ testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do
addGraphDef graphDef
x <- run $ tensorFromName ValueKind "Mul_2"
x <- run $ tensorValueFromName "Mul_2"
liftIO $ (50 :: Float) @=? unScalar x
-- | Load MNIST from a GraphDef and the weights from a checkpoint and run on
@ -142,8 +142,8 @@ testMNISTExec = testCase "testMNISTExec" $ do
build $ addGraphDef $ mnist & version .~ 0
-- Define nodes that restore saved weights and biases.
let bias, wts :: Tensor Ref Float
bias = tensorFromName RefKind "Variable"
wts = tensorFromName RefKind "weights"
bias = tensorFromName "Variable"
wts = tensorFromName "weights"
wtsCkptPath <- liftIO wtsCkpt
biasCkptPath <- liftIO biasCkpt
-- Run those restoring nodes on the graph in the current session.
@ -155,12 +155,12 @@ testMNISTExec = testCase "testMNISTExec" $ do
let ty = encodeTensorData [10] oneHotLabels
where oneHotLabels = V.replicate 10 (0 :: Float) V.// updates
updates = [(fromIntegral label, 1)]
let feeds = [ feed (tensorFromName ValueKind "x-input") tensorSample
, feed (tensorFromName ValueKind "y-input") ty
let feeds = [ feed (tensorValueFromName "x-input") tensorSample
, feed (tensorValueFromName "y-input") ty
]
-- Run the graph with the input feeds and read the ArgMax'd result from
-- the test (not training) side of the evaluation.
x <- runWithFeeds feeds $ tensorFromName ValueKind "test/ArgMax"
x <- runWithFeeds feeds $ tensorValueFromName "test/ArgMax"
-- Print the trained model's predicted outcome.
liftIO $ putStrLn $ "Expectation: " ++ show label ++ "\n"
++ "Prediction: " ++ show (unScalar x :: Int64)

View file

@ -24,7 +24,6 @@ import Prelude hiding ( log
, exp
)
import TensorFlow.Build ( MonadBuild
, render
, withNameScope
)
import TensorFlow.GenOps.Core ( greaterEqual
@ -33,6 +32,7 @@ import TensorFlow.GenOps.Core ( greaterEqual
, exp
)
import TensorFlow.Tensor ( Tensor(..)
, render
, Value
)
import TensorFlow.Types ( TensorType(..)
@ -40,6 +40,8 @@ import TensorFlow.Types ( TensorType(..)
)
import TensorFlow.Ops ( zerosLike
, add
, mul
, neg
)
-- | Computes sigmoid cross entropy given `logits`.
@ -76,13 +78,11 @@ sigmoidCrossEntropyWithLogits
-> Tensor Value a -- ^ __targets__
-> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
logits' <- render logits
targets' <- render targets
let zeros = zerosLike logits'
cond = logits' `greaterEqual` zeros
relu_logits = select cond logits' zeros
neg_abs_logits = select cond (-logits') logits'
let zeros = zerosLike logits
cond = logits `greaterEqual` zeros
relu_logits = select cond logits zeros
neg_abs_logits = select cond (neg logits) logits
withNameScope "logistic_loss" $ do
left <- render $ relu_logits - logits' * targets'
left <- render $ relu_logits - logits `mul` targets
right <- render $ log (1 + exp neg_abs_logits)
withNameScope "sigmoid_add" $ render $ left `add` right

View file

@ -23,10 +23,9 @@ import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import qualified Data.Vector as V
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Core as TF
-- | These tests are ported from:
--
@ -60,12 +59,11 @@ defInputs = Inputs {
testLogisticOutput :: Test
testLogisticOutput = testCase "testLogisticOutput" $ do
let inputs = defInputs
vLogits = TF.vector $ logits inputs
vTargets = TF.vector $ targets inputs
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
r <- run tfLoss
r <- run $ do
vLogits <- TF.render $ TF.vector $ logits inputs
vTargets <- TF.render $ TF.vector $ targets inputs
TF.sigmoidCrossEntropyWithLogits vLogits vTargets
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
assertAllClose r ourLoss
@ -74,23 +72,22 @@ testLogisticOutputMultipleDim =
testCase "testLogisticOutputMultipleDim" $ do
let inputs = defInputs
shape = [2, 2, 2]
vLogits = TF.constant shape (logits inputs)
vTargets = TF.constant shape (targets inputs)
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
r <- run tfLoss
r <- run $ do
vLogits <- TF.render $ TF.constant shape (logits inputs)
vTargets <- TF.render $ TF.constant shape (targets inputs)
TF.sigmoidCrossEntropyWithLogits vLogits vTargets
let ourLoss = V.fromList $ sigmoidXentWithLogits (logits inputs) (targets inputs)
assertAllClose r ourLoss
testGradientAtZero :: Test
testGradientAtZero = testCase "testGradientAtZero" $ do
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
vLogits = TF.vector $ logits inputs
vTargets = TF.vector $ targets inputs
tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
r <- run $ do
let inputs = defInputs { logits = [0, 0], targets = [0, 1] }
vTargets <- TF.render $ TF.vector $ targets inputs
vLogits <- TF.render $ TF.vector $ logits inputs
let tfLoss = TF.sigmoidCrossEntropyWithLogits vLogits vTargets
l <- tfLoss
TF.gradients l [vLogits]

View file

@ -231,21 +231,24 @@ renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
functionBody pOp
| parsedOpIsMonadic pOp
= "build $ do"
</> indent indentation (bindOpInputsVar
</> "buildOp" <+> outputListsSizes <+> opDef)
| otherwise
= "pureOp" <+> outputListsSizes <+> "$ do"
</> indent indentation (bindOpInputsVar </> "return" <+> opDef)
where
maybeLift
| parsedOpIsMonadic pOp = "build $"
| otherwise = ""
buildFunction
| null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+>
brackets (commasep $
map renderHaskellName outputListsSizes)
outputListsSizes = [ a
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
<- parsedOutputs pOp]
buildOpParts =
outputListsSizes = brackets $ commasep
[ renderHaskellName a
| ParsedArg { parsedArgCase = ListArg { argLength = a } }
<- parsedOutputs pOp
]
opInputsVar = "op'inputs"
bindOpInputsVar = opInputsVar <+> "<- fmap Prelude.concat $ Prelude.sequence"
<+> brackets (commasep $ map (\a -> "buildInputs" <+> a) tensorArgs)
opDef = parens $ hang 0 $ stack $
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders type parameter arguments.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
@ -259,10 +262,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs pOp, let n = attrName a
] ++
["& op'options"]
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
["& op'options & opInputs .~" <+> opInputsVar]
tensorArgs = renderTensorArg <$> parsedInputs pOp
renderTensorArg = renderHaskellName . parsedArgName
inferredTypeExpr a
| typeParamIsList $ attrInfo a
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
@ -296,7 +298,7 @@ typeSig pre pOp = constraints
| null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
++ if parsedOpIsMonadic pOp then ["m'"] else []
-- Use m' as the type parameter to avoid clashing with an attribute name.
@ -336,12 +338,13 @@ tensorArg p = case parsedArgCase p of
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {argTypeAttr = t, argCaseKind = k}
-> "TensorList" <+> kind k <+> renderHaskellName t
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
where
kind k = case k of
ArgTensorRef -> "Ref"
ArgTensorValue -> "Value"
ArgTensorEither v' -> strictText v'
ArgTensorBuild -> "Build"
ArgSomeTensor v -> strictText v
tensorType t k = let
a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt

View file

@ -141,7 +141,8 @@ data ArgType
data ArgKind
= ArgTensorRef -- Tensor Ref a
| ArgTensorValue -- Tensor Value a
| ArgTensorEither Text -- Tensor v a; the Text is the variable `v`
| ArgTensorBuild -- Tensor Build a
| ArgSomeTensor Text -- Tensor v a; the Text is the variable 'v'.
deriving (Eq)
isRefCase :: ParsedArgCase -> Bool
@ -219,15 +220,17 @@ parseOp o = ParsedOp
{ parsedOpName = makeName $ o ^. name
, parsedOpSummary = o ^. summary
, parsedOpDescription = o ^. description
, parsedOpIsMonadic = o ^. isStateful
|| any (isRefCase . parsedArgCase) parsedInputs
, ..
}
where
parsedInputs = zipWith (\a v -> parseArg a (inputTensorKind a v))
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
parsedOpIsMonadic = o ^. isStateful
|| any (isRefCase . parsedArgCase) parsedInputs
|| null (o ^. outputArg)
parsedInputs = zipWith (\t a -> parseArg a (inputTensorKind t a))
tensorKindParams (o ^. inputArg)
tensorKindParams = ["v'" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind parsedOpIsMonadic a))
(o ^. outputArg)
-- Integer attributes that can be inferred from the size of at least one
-- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
@ -246,15 +249,16 @@ parseOp o = ParsedOp
$ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs.
inputTensorKind :: OpDef'ArgDef -> Text -> ArgKind
inputTensorKind a v
inputTensorKind :: Text -> OpDef'ArgDef -> ArgKind
inputTensorKind v a
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorEither v
| otherwise = ArgSomeTensor v
outputTensorKind :: OpDef'ArgDef -> ArgKind
outputTensorKind a
outputTensorKind :: Bool -> OpDef'ArgDef -> ArgKind
outputTensorKind isMonadic a
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue
| isMonadic = ArgTensorValue
| otherwise = ArgTensorBuild
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a

View file

@ -24,9 +24,9 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (MonadBuild, colocateWith, render)
import TensorFlow.Build (MonadBuild)
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Tensor (Tensor, Value, Rendered, colocateWith, render)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
@ -44,17 +44,18 @@ import qualified TensorFlow.GenOps.Core as CoreOps
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v m .
embeddingLookup :: forall a b v1 v2 m .
( MonadBuild m
, Rendered v1
, TensorType a
, OneOf '[Int64, Int32] b
, Num b
)
=> [Tensor v a]
=> [Tensor v1 a]
-- ^ A list of tensors which can be concatenated along
-- dimension 0. Each `Tensor` must be appropriately
-- sized for `mod` partition strategy.
-> Tensor Value b
-> Tensor v2 b
-- ^ A `Tensor` with type `int32` or `int64`
-- containing the ids to be looked up in `params`.
-- The ids are required to have fewer than 2^31

View file

@ -31,6 +31,7 @@ import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int32, Int64)
import Data.Foldable (foldlM)
import Data.List (foldl', sortBy)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, maybeToList, mapMaybe)
@ -39,7 +40,7 @@ import Data.ProtoLens.TextFormat (showMessage)
import Data.Set (Set)
import Data.Text (Text)
import Data.Tuple (swap)
import Lens.Family2 (Lens', (&), (^.), (.~), (%~))
import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
import Lens.Family2.State.Strict (uses)
import Lens.Family2.Stock (at, intAt)
import Lens.Family2.Unchecked (lens, iso)
@ -59,11 +60,10 @@ import TensorFlow.Build
( MonadBuild
, Build
, build
, render
, renderNodeName
, renderedNodeDefs
, opDef
, opAttr
, opInputs
)
import TensorFlow.BuildOp
import TensorFlow.Ops
@ -86,16 +86,19 @@ import TensorFlow.Ops
)
import TensorFlow.Output
( NodeName(..)
, Op (Rendered)
, Output(..)
, OutputIx(..)
, outputIndex
)
import TensorFlow.Tensor
( Tensor(..)
, TensorKind (ValueKind)
, Value
, tensorOutput
, render
, expr
, Rendered
, tensorNodeName
, renderedOutput
, renderValue
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
@ -114,10 +117,7 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 m . (MonadBuild m
, Num (Tensor v1 a)
-- TODO(gnezdo): remove indirect constraint.
-- It's a wart inherited from Num instance.
, v1 ~ Value
, Rendered v2
, GradientCompatible a
)
=> Tensor v1 a -- ^ The output of the graph.
@ -150,27 +150,31 @@ gradients y xs = build $ do
--
-- 4. Lookup the recorded gradient for each x in xs.
yName <- renderNodeName y
y' <- renderValue y
let yName = tensorNodeName y'
yOne <- render $ fill (shape y') (scalar 1)
-- TODO(fmayle): Move this into Build.hs and call it unsafeNodeDefFromName?
nodeDefLookup :: (NodeName -> NodeDef) <- uses renderedNodeDefs $
(\f x -> fromMaybe (error $ "no NodeDef found for " ++ show x) (f x))
. flip Map.lookup
let (gr, nodeMap) = createGraph yName nodeDefLookup
-- Set gradient of y to one.
-- TODO: nicer
let initPending :: Map.Map FGL.Node (PendingGradients a)
initPending = Map.empty & at (nodeMap Map.! yName)
= Map.empty & (at (nodeMap Map.! yName)
. nonEmpty
. outputIxAt (y ^. tensorOutput . outputIndex)
. outputIxAt (outputIndex $ renderedOutput y')
. nonEmpty
.~ [fill (shape y) (scalar 1)]
.~ [yOne]
)
-- Calculate the gradients of y w.r.t. each node in the graph.
gradientMap <- graphGrads gr initPending
-- Lookup the gradients for each x.
forM xs $ \x -> do
xName <- renderNodeName x
render $ fromMaybe (zerosLike x) $ do
forM xs $ \x ->
let xName = tensorNodeName x
in maybe (render $ zerosLike x) return $ do
n <- nodeMap ^. at xName
let i = x ^. tensorOutput . outputIndex
let i = outputIndex $ renderedOutput x
gradientMap ^. at n . nonEmpty . outputIxAt i
outputIxAt :: OutputIx -> Lens' (IntMap.IntMap v) (Maybe v)
@ -182,6 +186,7 @@ outputIxAt = intAt . unOutputIx
type PendingGradients a = IntMap.IntMap [Tensor Value a]
-- | Gradients of a node's outputs. The key is an OutputIx sans newtype.
-- TODO: precache the rendering?
type Gradients a = IntMap.IntMap (Tensor Value a)
-- | Graph of TensorFlow operations.
@ -229,42 +234,43 @@ non a = anon a (a==)
nonEmpty :: (Monoid (t v), Foldable t) => Lens' (Maybe (t v)) (t v)
nonEmpty = anon mempty null
-- TODO: strictness (e.g., foldlM')
-- | Calculate the gradients for every node in a graph.
graphGrads :: forall a. GradientCompatible a
=> Graph
-> Map FGL.Node (PendingGradients a)
-- ^ Initial gradients (usually just 1 for the node of interest).
-> Build (Map FGL.Node (Gradients a))
graphGrads gr initPending = pure (foldl' go initState nodeOrder ^. gradientsResult)
graphGrads gr initPending = view gradientsResult <$> foldlM go initState nodeOrder
where
initState = GradientsState initPending Map.empty
-- Reverse topological sort.
-- TODO(fmayle): Filter out nodes that are not successors of any x in xs to
-- avoid calculating gradients that won't be used.
nodeOrder = FGL.topsort $ FGL.grev gr
go state node =
go :: GradientsState a -> Int -> Build (GradientsState a)
go state node = do
-- Aggregate the accumulated gradients for this node.
let outputGrads =
outputGrads <-
sumPendingGradient (state ^. gradientsPending . at node . nonEmpty)
in if null outputGrads
then state
else
if null outputGrads
then pure state
else do
let ctx = FGL.context gr node
inputGrads <- calculateInputGrads ctx outputGrads gr
-- Calculate the gradients for each of the node's inputs.
let nextState = state & gradientsResult %~ Map.insert node outputGrads
ctx = FGL.context gr node
in updatePendingGradients
ctx
(calculateInputGrads ctx outputGrads gr)
nextState
pure $ updatePendingGradients ctx inputGrads nextState
-- | Reduce accumulated gradients for each output to one Tensor.
sumPendingGradient :: GradientCompatible a
=> PendingGradients a -> Gradients a
sumPendingGradient = IntMap.mapMaybe f
=> PendingGradients a -> Build (Gradients a)
sumPendingGradient = sequence . IntMap.mapMaybe f
where
f [] = Nothing
f [x] = Just x
f xs = Just (addN xs)
f [x] = Just (pure x)
f xs = Just (render $ addN xs)
-- | Calculate the gradients of a node's input tensors.
@ -274,18 +280,18 @@ calculateInputGrads :: forall a. GradientCompatible a
=> FGL.Context NodeDef EdgeLabel
-> Gradients a -- ^ Output gradients of the node.
-> Graph
-> [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
-> Build [Maybe (Tensor Value a)]
calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr = do
fullOutGrads <- fullOutputGrads (numOutputs nodeDef) (nodeDefName nodeDef)
outputGrads
traverse (traverse render) $ opGrad (nodeDef ^. op) nodeDef inputTensors fullOutGrads
where
fullOutGrads =
fullOutputGrads (numOutputs nodeDef) (Rendered nodeDef) outputGrads
-- Create a tensor from an edge (technically an Output, but it seems less
-- confusing to refer to it as a tensor here).
edgeToTensor :: (EdgeLabel, FGL.Node) -> Output
edgeToTensor ((i, _), n) =
case FGL.lab gr n of
Just edgeNodeDef -> Output i (Rendered edgeNodeDef)
Just edgeNodeDef -> Output i (NodeName $ edgeNodeDef ^. name)
Nothing -> error $ "calculateInputGrads: missing input node for "
++ Text.unpack (nodeDef ^. name)
-- Input tensors, sorted by input index.
@ -294,11 +300,11 @@ calculateInputGrads (inputEdges, _, nodeDef, _) outputGrads gr =
-- | Convert a Map of gradients to a list, with zeros for missing outputs.
fullOutputGrads :: (TensorType a, Num a)
=> OutputIx -- ^ Number of outputs.
-> Op
-> NodeName
-> Gradients a
-> [Tensor Value a]
-> Build [Tensor Value a]
fullOutputGrads n o gs =
map (\i -> fromMaybe (zero i) (gs ^. outputIxAt i)) [0..n-1]
mapM (\i -> maybe (render $ zero i) return (gs ^. outputIxAt i)) [0..n-1]
where
-- A tensor of zeros with the same shape as the i'th output.
zero i = zerosLike $ toT (Output i o)
@ -397,19 +403,19 @@ type GradientFunc a = NodeDef
-- ^ Input tensors.
-> [Tensor Value a]
-- ^ Gradient of y w.r.t. each output tensor.
-> [Maybe (Tensor Value a)]
-> [Maybe (Tensor Build a)]
-- ^ Gradient of y w.r.t. each input tensor.
-- TODO(fmayle): Assert the type is correct.
-- | Create a Tensor from an Output.
toT :: Output -> Tensor Value a
toT = Tensor ValueKind
toT :: Output -> Tensor Build a
toT = Tensor . pure
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations.
flatSlice :: forall v1 t . (TensorType t)
flatSlice :: forall v1 t . TensorType t
=> Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from.
@ -417,9 +423,12 @@ flatSlice :: forall v1 t . (TensorType t)
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
-- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin).
-> Tensor Value t -- ^ __output__
-> Tensor Build t -- ^ __output__
flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
nodeDefName :: NodeDef -> NodeName
nodeDefName = NodeName . view name
-- | The gradient function for an op type.
--
@ -427,8 +436,8 @@ flatSlice t begin size = CoreOps.slice t (vector [begin]) (vector [size])
-- third_party/tensorflow/python/ops/*_grad.py
opGrad :: forall a . GradientCompatible a => Text -> GradientFunc a
opGrad "Abs" _ [toT -> x] [dz] = [Just $ dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ -dz]
opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
opGrad "Square" _ [toT -> x] [dz] =
@ -436,7 +445,7 @@ opGrad "Square" _ [toT -> x] [dz] =
-- TODO(fmayle): The python code makes dz a control dependency of the 2*x
-- (for performance reasons?). Will need to put these functions in the Build
-- monad to replicate that.
[Just $ dz * (2 * x)]
[Just $ dz `CoreOps.mul` (2 * x)]
opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
-- TODO(fmayle): The python version uses a better performance implementation
@ -448,20 +457,20 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
]
where
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
denseShape = shape (x :: Tensor Value a)
denseShape = shape (x :: Tensor Build a)
numRows = scalarize $ flatSlice denseShape 0 1
valuesShape = CoreOps.concat 0 [ allDimensions
, flatSlice denseShape 1 (-1)
]
values = reshape dz valuesShape
-- TODO(fmayle): This could be either Int32 or Int64.
indices' = reshape indices allDimensions :: Tensor Value Int32
indices' = reshape indices allDimensions :: Tensor Build Int32
opGrad "Max" _ [toT -> x, toT -> indices] [dz] =
[Just $ indicators `CoreOps.div` numSelected * dz', Nothing]
where
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
sx = shape (x :: Tensor Build a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
y = CoreOps.max x indices
y' = reshape y outputShapeKeptDims
dz' = reshape dz outputShapeKeptDims
@ -475,8 +484,8 @@ opGrad "Sum" _ [toT -> x, toT -> indices] [dz] =
[ Just $ CoreOps.tile grad tileScaling, Nothing ]
where
-- TODO(gnezdo): Implement the fast-path from math_grad._SumGrad.
sx = shape (x :: Tensor Value a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Value Int32)
sx = shape (x :: Tensor Build a)
outputShapeKeptDims = reducedShape sx (indices :: Tensor Build Int32)
tileScaling = safeShapeDiv sx outputShapeKeptDims
grad = reshape dz outputShapeKeptDims
@ -484,8 +493,8 @@ opGrad "Mean" u v@[toT -> x, _] w =
[Just $ dz `CoreOps.div` CoreOps.cast factor, Nothing]
where
[Just dz, Nothing] = opGrad "Sum" u v w
inputShape = shape (x :: Tensor Value a)
outputShape = shape (dz :: Tensor Value a)
inputShape = shape (x :: Tensor Build a)
outputShape = shape (dz :: Tensor Build a)
-- TODO(fmayle): Add fast path when shape is known.
inputSize = CoreOps.prod inputShape $ rangeOfRank inputShape
outputSize = CoreOps.prod outputShape $ rangeOfRank outputShape
@ -495,8 +504,8 @@ opGrad "Add" _ [toT -> x, toT -> y] [dz] =
[ Just $ reshape (sum dz rx) sx
, Just $ reshape (sum dz ry) sy ]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Sub" u v w =
@ -510,22 +519,24 @@ opGrad "SoftmaxCrossEntropyWithLogits" _ [toT -> x, toT -> y] [dz, _] =
opGrad "Mul" _ [toT -> x, toT -> y] [dz] =
-- TODO(fmayle): Handle complex numbers.
[ Just $ reshape (sum (dz * y) rx) sx
, Just $ reshape (sum (x * dz) ry) sy ]
[ Just $ reshape (sum (dz `CoreOps.mul` y) rx) sx
, Just $ reshape (sum (x `CoreOps.mul` dz) ry) sy ]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "Div" _ [toT -> x, toT -> y] [dz] =
-- TODO(fmayle): Handle complex numbers.
-- TODO(gnezdo): Provide Fractional instance and use '/' instead of div.
[ Just $ reshape (sum (dz `CoreOps.div` y) rx) sx
, Just $ reshape (sum (dz * (negate x `CoreOps.div` (y * y))) ry) sy
, Just $ reshape (sum (dz `CoreOps.mul` (negate x `CoreOps.div` (y * y)))
ry)
sy
]
where
sx = shape (x :: Tensor Value a)
sy = shape (y :: Tensor Value a)
sx = shape (x :: Tensor Build a)
sy = shape (y :: Tensor Build a)
(rx, ry) = broadcastGradientArgs sx sy
opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
@ -549,7 +560,7 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Value Int32)
(CoreOps.invertPermutation p :: Tensor Build Int32)
, Nothing
]
@ -582,28 +593,28 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
x output dz
]
where
output :: Tensor Value a
output = toT $ Output 0 (Rendered nodeDef)
output :: Tensor Build a
output = toT $ Output 0 (nodeDefName nodeDef)
ksize = lookupAttr nodeDef "ksize" :: [Int64]
strides = lookupAttr nodeDef "strides" :: [Int64]
padding = lookupAttr nodeDef "padding" :: ByteString
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "Reshape" _ [toT -> x, _] [dz] =
[Just $ reshape dz $ shape (x :: Tensor Value a), Nothing]
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
opGrad "TruncatedNormal" _ _ _ = [Nothing]
opGrad "RefIdentity" _ _ [dz] = [Just dz]
opGrad "RefIdentity" _ _ [dz] = [Just $ expr dz]
opGrad "Cast" nodeDef _ [dz] = [Just reverseCast]
where
-- TODO(gnezdo): too permissive, python only allows float types as src_type.
reverseCast =
buildOp (opDef "Cast"
pureOp [] $ pure (opDef "Cast"
& opAttr "DstT" .~ (lookupAttr nodeDef "SrcT" :: ByteString)
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString))
dz
& opAttr "SrcT" .~ (lookupAttr nodeDef "DstT" :: ByteString)
& opInputs .~ [renderedOutput dz])
opGrad "DynamicStitch" nodeDef inputs [dz] =
replicate halfLen Nothing ++ valuesGrads
@ -614,7 +625,7 @@ opGrad "DynamicStitch" nodeDef inputs [dz] =
in if 2 * half == len
then half
else error ("Uneven input size " ++ show (len, showMessage nodeDef))
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Value Int32)
valuesGrads = [ Just $ CoreOps.gather dz (toT idx :: Tensor Build Int32)
| idx <- take halfLen inputs
]
@ -622,14 +633,14 @@ opGrad "DynamicPartition" nodeDef [toT -> xs, toT -> indices] dz =
[ Just reconstructed, Nothing ]
where
reconstructed = CoreOps.reshape stitched
(CoreOps.shape (xs :: Tensor Value a) :: Tensor Value Int32)
(CoreOps.shape (xs :: Tensor Build a) :: Tensor Build Int32)
stitched = CoreOps.dynamicStitch partitionedIndices dz
partitionedIndices = CoreOps.dynamicPartition np originalIndices indices
np = lookupAttr nodeDef "num_partitions" :: Int64
originalIndices =
CoreOps.reshape (CoreOps.range 0 (CoreOps.size indices) 1) prefixShape
prefixShape = shapeInt32 indices
shapeInt32 = CoreOps.shape :: Tensor Value Int32 -> Tensor Value Int32
shapeInt32 t = CoreOps.shape t :: Tensor Build Int32
opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
[ Nothing
@ -639,18 +650,18 @@ opGrad "Select" _ [toT -> c, toT -> x, _] [dz] =
where zeros = CoreOps.zerosLike x
-- TODO(gnezdo): Unlike Python, no control dependency on dz.
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.inv x ]
opGrad "Log" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.inv x ]
-- TODO(gnezdo): Reuse the output instead of doing another exp,
-- though, it is probably CSE'd away anyway.
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz * CoreOps.exp x ]
opGrad "Exp" _ [toT -> x] [dz] = [ Just $ dz `CoreOps.mul` CoreOps.exp x ]
opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
[ Just $ CoreOps.unsortedSegmentSum
(CoreOps.gather dz (t :: Tensor Value Int32))
(y :: Tensor Value Int32) inputRows
(CoreOps.gather dz (t :: Tensor Build Int32))
(y :: Tensor Build Int32) inputRows
, Nothing
, Nothing
]
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
where inputRows = flatSlice (shape (x :: Tensor Build a)) 0 1
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]
@ -710,13 +721,13 @@ numOutputs o =
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
-- Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Value Int32
safeShapeDiv :: Tensor v1 Int32 -> Tensor v2 Int32 -> Tensor Build Int32
safeShapeDiv x y = x `CoreOps.div` (CoreOps.maximum y 1)
allDimensions :: Tensor Value Int32
allDimensions :: Tensor Build Int32
allDimensions = vector [-1 :: Int32]
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Value Int32
rangeOfRank :: forall v1 t. TensorType t => Tensor v1 t -> Tensor Build Int32
rangeOfRank x = CoreOps.range 0 (CoreOps.rank x) 1
lookupAttr :: Attribute a1 => NodeDef -> Text -> a1

View file

@ -166,7 +166,6 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
import TensorFlow.Build
import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group)
import TensorFlow.Output (unNodeName)
import TensorFlow.Tensor
import TensorFlow.Types
@ -183,7 +182,7 @@ import qualified Prelude (abs)
-- "1".
instance ( TensorType a
, Num a
, v ~ Value
, v ~ Build
, OneOf '[ Double, Float, Int32, Int64
, Complex Float, Complex Double] a) => Num (Tensor v a) where
(+) = CoreOps.add
@ -194,10 +193,10 @@ instance ( TensorType a
signum = CoreOps.sign
negate = CoreOps.neg
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose = matTranspose' id
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
@ -208,7 +207,7 @@ placeholder' :: forall m a . (MonadBuild m, TensorType a)
placeholder' params pShape
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
-- and thus would be CSE'd.
= build $ buildOp $ opDef "Placeholder"
= build $ buildOp [] $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ pShape
& params
@ -216,11 +215,11 @@ placeholder' params pShape
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: (MonadBuild m, TensorType a)
=> Tensor Value a -> m (Tensor Ref a)
=> Tensor v a -> m (Tensor Ref a)
initializedVariable = initializedVariable' id
initializedVariable' :: (MonadBuild m, TensorType a)
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
=> OpParams -> Tensor v a -> m (Tensor Ref a)
initializedVariable' params initializer = do
v <- CoreOps.variable' params [] -- The shape is not known initially.
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
@ -240,17 +239,20 @@ zeroInitializedVariable'
zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors.
save :: forall a m v . (MonadBuild m, TensorType a)
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save.
-> m ControlNode
save path xs = do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
save path xs = build $ do
let toByteStringTensor = scalar . encodeUtf8 . encodeOutput . renderedOutput
let names = fmap toByteStringTensor xs
let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save"
& opAttr "T" .~ types
build $ saveOp (scalar path) (CoreOps.pack names) xs
names' <- buildInputs $ CoreOps.pack names
xs' <- buildInputs xs
path' <- buildInputs $ scalar path
buildOp [] $ opDef "Save"
& opAttr "T" .~ types
& opInputs .~ (path' ++ names' ++ xs')
-- | Restore a tensor's value from a checkpoint file.
--
@ -261,20 +263,22 @@ restoreFromName :: forall a m . (MonadBuild m, TensorType a)
-> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore.
-> m ControlNode
restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a)
group =<< CoreOps.assign x
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
restoreFromName path name x = build $ do
path' <- buildInputs $ scalar path
name' <- buildInputs $ scalar name
restoreOp <- buildOp [] $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a)
& opInputs .~ (path' ++ name')
group =<< CoreOps.assign x (restoreOp :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file.
restore :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore.
-> m ControlNode
restore path x = do
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
restoreFromName path name x
restore path x = restoreFromName path name x
where
name = encodeUtf8 $ encodeOutput $ renderedOutput x
-- | Create a constant tensor.
--
@ -283,10 +287,10 @@ restore path x = do
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
constant :: TensorType a => Shape -> [a] -> Tensor Value a
constant :: TensorType a => Shape -> [a] -> Tensor Build a
constant = constant' id
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
constant' params (Shape cShape) values
| invalidLength = error invalidLengthMsg
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
@ -305,24 +309,24 @@ constant' params (Shape cShape) values
-- | Reshape a N-D tensor down to a scalar.
--
-- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize :: TensorType a => Tensor v a -> Tensor Build a
scalarize t = CoreOps.reshape t (vector scalarShape)
where
scalarShape = [] :: [Int32]
-- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a
vector :: TensorType a => [a] -> Tensor Build a
vector = vector' id
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
vector' params xs = constant' params [fromIntegral $ length xs] xs
-- | Create a constant scalar.
scalar :: TensorType a => a -> Tensor Value a
scalar :: TensorType a => a -> Tensor Build a
scalar = scalar' id
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
scalar' :: TensorType a => OpParams -> a -> Tensor Build a
scalar' params x = constant' params [] [x]
-- | Random tensor from the unit normal distribution with bounded values.
@ -338,28 +342,28 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
-> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal'
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
shape :: TensorType t => Tensor v t -> Tensor Build Int32
shape = CoreOps.shape
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
shape' = CoreOps.shape'
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims = CoreOps.expandDims
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims' = CoreOps.expandDims'
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
reducedShape inputShape axes =
let inputShape32 = toInt32 inputShape -- [2, 3, 5, 7]
axes32 = toInt32 axes -- [1, 2]
toInt32 x = CoreOps.cast x :: Tensor Value Int32
toInt32 x = CoreOps.cast x :: Tensor Build Int32
inputRank = CoreOps.size inputShape32 -- 4
axesMod = (axes32 + inputRank) `CoreOps.mod` inputRank
axesShape = shape axesMod -- [2]

View file

@ -79,6 +79,7 @@ Test-Suite EmbeddingOpsTest
, test-framework
, test-framework-hunit
, test-framework-quickcheck2
, transformers
, vector
Test-Suite ArrayOpsTest

View file

@ -24,9 +24,7 @@ import Test.HUnit ((@=?))
import qualified Data.Vector as V
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as CoreOps
-- | Test split and concat are inverses.
@ -44,7 +42,7 @@ testSplit = testCase "testSplit" $ TF.runSession $ do
testShapeN :: Test
testShapeN = testCase "testShapeN" $ TF.runSession $ do
let shapes = map TF.Shape [[1],[2,3]]
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Value Float]
let tensors = map TF.zeros shapes :: [TF.Tensor TF.Build Float]
result <- TF.run $ CoreOps.shapeN tensors
liftIO $ [V.fromList [1], V.fromList [2,3]] @=? (result :: [V.Vector Int64])

View file

@ -34,9 +34,7 @@ import TensorFlow.Build
, asGraphDef
, evalBuildT
, flushNodeBuffer
, render
, withDevice
, colocateWith
, withNameScope
, opName
)
@ -50,7 +48,13 @@ import TensorFlow.Ops
, variable'
)
import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref)
import TensorFlow.Tensor
( colocateWith
, render
, Tensor
, Value
, Ref
)
import TensorFlow.Session
( run
, runSession
@ -65,8 +69,7 @@ import qualified Data.Vector as V
-- | Test 'opName' behavior.
testOpName :: Test
testOpName = testCase "testOpName" $ do
let graph = variable' (opName .~ "foo") []
>>= render :: Build (Tensor Ref Float)
let graph = variable' (opName .~ "foo") [] :: Build (Tensor Ref Float)
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"Variable" @=? (nodeDef ^. op)
@ -114,7 +117,6 @@ testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
>>= render
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"Variable" @=? (nodeDef ^. op)

View file

@ -26,16 +26,14 @@ import Test.QuickCheck.Monadic (monadicIO, run)
import qualified Data.Vector as V
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Core as TF
-- DynamicSplit is undone with DynamicStitch to get the original input
-- back.
testDynamicPartitionStitchInverse :: forall a.
(TF.TensorDataType V.Vector a, Show a, Eq a) => StitchExample a -> Property
testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
let splitParts :: [TF.Tensor TF.Value a] =
let splitParts :: [TF.Tensor TF.Build a] =
CoreOps.dynamicPartition numParts (TF.vector values) partTensor
partTensor = TF.vector partitions
restitchIndices = CoreOps.dynamicPartition numParts

View file

@ -19,6 +19,7 @@
-- | Tests for EmbeddingOps.
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import Google.Test (googleTest)
@ -48,16 +49,15 @@ testEmbeddingLookupHasRightShapeWithPartition =
let embShape = TF.Shape [1, 3] -- Consider a 3-dim embedding of two items.
let embedding1 = [1, 1, 1 :: Int32]
let embedding2 = [0, 0, 0 :: Int32]
let embedding = [ TF.constant embShape embedding1
, TF.constant embShape embedding2
]
let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids
(values, shape) <- TF.runSession $ do
vs <- op
embedding <- mapM TF.render [ TF.constant embShape embedding1
, TF.constant embShape embedding2
]
let ids = TF.constant (TF.Shape [1, 2]) idValues
vs <- embeddingLookup embedding ids
TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
@ -77,13 +77,12 @@ testEmbeddingLookupHasRightShape =
, 0, 0, 0 :: Int32
]
let embedding = TF.constant embShape embeddingInit
let idValues = [0, 1 :: Int32]
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids
(values, shape) <- TF.runSession $ do
vs <- op
embedding <- TF.render $ TF.constant embShape embeddingInit
let ids = TF.constant (TF.Shape [1, 2]) idValues
vs <- embeddingLookup [embedding] ids
TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
@ -92,7 +91,6 @@ testEmbeddingLookupHasRightShape =
-- "[0, 1]" should pull out the resulting vector.
values @=? V.fromList [1, 1, 1, 0, 0, 0]
-- | Check that we can calculate gradients w.r.t embeddings.
testEmbeddingLookupGradients :: Test
testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
@ -108,10 +106,10 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
x <- TF.placeholder (TF.Shape [2])
embedding <- TF.initializedVariable
=<< TF.render (TF.constant embShape embeddingInit)
(TF.constant embShape embeddingInit)
op <- embeddingLookup [embedding] ids
let twoNorm = CoreOps.square $ TF.abs (op - x)
let twoNorm = CoreOps.square $ TF.abs (op `TF.sub` x)
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding])
@ -131,23 +129,21 @@ testEmbeddingLookupUndoesSplit
(LookupExample numParts
shape@(TF.Shape (firstDim : restDims))
values
indices) =
let modShardedValues :: [TF.Tensor TF.Value a] =
CoreOps.dynamicPartition numParts shapedValues cyclicCounter
cyclicCounter :: TF.Tensor TF.Value Int32 =
indices) = monadicIO $ run $ TF.runSession $ do
let shapedValues = TF.constant shape values
indicesVector <- TF.render $ TF.vector indices
let directs = CoreOps.gather shapedValues indicesVector
let cyclicCounter :: TF.Tensor TF.Build Int32 =
TF.vector [0..fromIntegral firstDim-1]
`CoreOps.mod` fromIntegral numParts
indicesVector = TF.vector indices
directs = CoreOps.gather shapedValues indicesVector
shapedValues = TF.constant shape values
in monadicIO $ run $ do
(shapeOut, got, want :: V.Vector a) <-
TF.runSession $ TF.run =<< do
embeddings <- embeddingLookup modShardedValues indicesVector
return (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup.
shapeOut @=? V.fromList (genericLength indices : restDims)
got @=? want
modShardedValues :: [TF.Tensor TF.Value a] <-
mapM TF.render $ CoreOps.dynamicPartition numParts shapedValues cyclicCounter
embeddings <- embeddingLookup modShardedValues indicesVector
(shapeOut, got, want :: V.Vector a) <-
TF.run (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup.
liftIO $ shapeOut @=? V.fromList (genericLength indices : restDims)
liftIO $ got @=? want
testEmbeddingLookupUndoesSplit _ = error "Bug in Arbitrary (LookupExample)"
-- | Consistent set of parameters for EmbeddingLookupUndoesSplit.

View file

@ -36,10 +36,11 @@ import Proto.Tensorflow.Core.Framework.NodeDef (op)
testGradientSimple :: Test
testGradientSimple = testCase "testGradientSimple" $ do
let x = TF.scalar (3 :: Float)
b = TF.scalar (4 :: Float)
y = x*x + b
grads = TF.gradients y [x, b]
let grads = do
x <- TF.render $ TF.scalar (3 :: Float)
b <- TF.render $ TF.scalar (4 :: Float)
let y = x `TF.mul` x `TF.add` b
TF.gradients y [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ grads >>= TF.run
6 @=? TF.unScalar dx
@ -88,9 +89,10 @@ testGradientSimple = testCase "testGradientSimple" $ do
testGradientDisconnected :: Test
testGradientDisconnected = testCase "testGradientDisconnected" $ do
let x = TF.scalar (3 :: Float)
b = TF.scalar (4 :: Float)
grads = TF.gradients x [x, b]
let grads = do
x <- TF.render $ TF.scalar (3 :: Float)
b <- TF.render $ TF.scalar (4 :: Float)
TF.gradients x [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ grads >>= TF.run
1 @=? TF.unScalar dx
@ -118,7 +120,7 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
TF.gradients (x + y*3) [x, y] >>= TF.run
TF.gradients (TF.expr x + TF.expr y * 3) [x, y] >>= TF.run
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. These asserts are extra.
1 @=? TF.unScalar dx
@ -142,8 +144,8 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
testDiamond :: Test
testDiamond = testCase "testDiamond" $ do
[dx] <- TF.runSession $ do
let x = TF.vector [1]
y = x*x
x <- TF.render $ TF.vector [1]
let y = x `TF.mul` x
z = y*y
TF.gradients z [x] >>= TF.run
(4 :: Float) @=? TF.unScalar dx
@ -152,8 +154,8 @@ testDiamond = testCase "testDiamond" $ do
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
x <- TF.render $ TF.vector [1, 2, 3, 0, 1 :: Float]
let y = TF.max x (0 :: TF.Tensor TF.Build Int32)
TF.gradients y [x] >>= TF.run
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx

View file

@ -56,8 +56,7 @@ testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<<
TF.zeroInitializedVariable' (TF.opName .~ "a")
var = TF.zeroInitializedVariable' (TF.opName .~ "a")
(TF.Shape [])
TF.runSession $ do
v <- var
@ -76,7 +75,8 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
p2 <- TF.placeholder []
let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
$ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5
-- | Test that regular tensors can also be used for feeds, as long as they each
@ -90,7 +90,8 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)]
$ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5
main :: IO ()

View file

@ -38,7 +38,7 @@ fit xData yData = TF.runSession $ do
return (w', b')
gradientDescent :: Float
-> TF.Tensor TF.Value Float
-> TF.Tensor TF.Build Float
-> [TF.Tensor TF.Ref Float]
-> TF.Session TF.ControlNode
gradientDescent alpha loss params = do

View file

@ -53,9 +53,9 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
boolData = V.fromList [True, True, False, True, False, False]
f <- TF.build $ TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3]
b <- TF.build $ TF.placeholder [2,3]
f <- TF.placeholder [2,3]
s <- TF.placeholder [2,3]
b <- TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
, TF.feed s (TF.encodeTensorData [2,3] stringData)
, TF.feed b (TF.encodeTensorData [2,3] boolData)
@ -63,7 +63,8 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
-- Do something idempotent to the tensors to verify that tensorflow can
-- handle the encoding. Originally this used `TF.identity`, but that
-- wasn't enough to catch a bug in the encoding of Bool.
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
(f', s', b') <- TF.runWithFeeds feeds
(f `TF.add` 0, TF.identity s, TF.select b b b)
liftIO $ do
floatData @=? f'
stringData @=? s'

View file

@ -60,7 +60,7 @@ makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
-- under the given name across multiple sessions.
-> m (Queue as)
makeQueue capacity sharedName = do
q <- build $ buildOp (opDef "FIFOQueue"
q <- build $ buildOp [] (opDef "FIFOQueue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
& opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity

View file

@ -13,9 +13,13 @@
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module TensorFlow.Build
( -- * Graph node types
ControlNode(..)
@ -32,8 +36,6 @@ module TensorFlow.Build
, opControlInputs
-- * The Build monad
, GraphState
, render
, renderNodeName
, renderedNodeDefs
, BuildT
, Build
@ -46,27 +48,23 @@ module TensorFlow.Build
, addGraphDef
, flushInitializers
, flushNodeBuffer
, summaries
-- * Creating and looking up Ops
, getOrAddOp
, addNewOp
, renderOutput
, encodeOutput
, lookupNode
-- * Modifying all nodes in a Build action
, colocateWith
, withStateLens
, withDevice
, withNameScope
, withNodeDependencies
-- * Internal Summary related bits.
, addSummary
, SummaryTensor
, collectAllSummaries
) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
import Data.ByteString (ByteString)
import Data.Default (def)
import Data.Functor.Identity (Identity(..))
import qualified Data.Map.Strict as Map
@ -94,7 +92,6 @@ import Proto.Tensorflow.Core.Framework.NodeDef
import TensorFlow.Orphans ()
import TensorFlow.Output
import TensorFlow.Tensor
newtype Unique = Unique Int
deriving (Eq, Ord, Enum)
@ -125,9 +122,6 @@ opDefWithName n t = OpDef
, _opControlInputs = []
}
-- | Synonym for the tensors that return serialized Summary proto.
type SummaryTensor = Tensor Value ByteString
data GraphState = GraphState
{ _renderedNodes :: !(Map.Map PendingNode NodeDef)
-- ^ Nodes which have been rendered. Keeps track of the unique ID we
@ -148,8 +142,8 @@ data GraphState = GraphState
, _initializationNodes :: [NodeName]
-- ^ The nodes to run next time a TF.run is issued, typically
-- variable initializers.
, _summaries :: [SummaryTensor]
-- ^ The tensors for summary
, _summaries :: [Output]
-- ^ The tensors for summary (ByteString type)
}
-- | A node definition without its final name. Used as a key in the
@ -191,7 +185,7 @@ defaultControlInputs = lens _defaultControlInputs
initializationNodes :: Lens' GraphState [NodeName]
initializationNodes = lens _initializationNodes (\g x -> g { _initializationNodes = x })
summaries :: Lens' GraphState [SummaryTensor]
summaries :: Lens' GraphState [Output]
summaries = lens _summaries (\g x -> g { _summaries = x })
-- | An action for building nodes in a TensorFlow graph.
@ -238,9 +232,7 @@ flushInitializers = do
-- | Registers the given node to be executed before the next
-- 'TensorFlow.Session.run'.
addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode o) = build $ do
i <- getOrAddOp o
initializationNodes %= (i:)
addInitializer (ControlNode i) = build $ initializationNodes %= (i:)
-- | Produce a GraphDef proto representation of the nodes that are rendered in
-- the given 'Build' action.
@ -255,30 +247,31 @@ addGraphDef g = build $ nodeBuffer <>= g ^. node
-- | Render the given op if it hasn't been rendered already, and return its
-- name.
getOrAddOp :: Op -> Build NodeName
getOrAddOp o = NodeName . (^. name) <$> resolveOp o
resolveOp :: Op -> Build NodeDef
resolveOp (Rendered n) = return n
resolveOp (Unrendered o) = do
getOrAddOp :: OpDef -> Build NodeName
getOrAddOp o = do
pending <- getPendingNode o
uses renderedNodes (Map.lookup pending) >>= \case
Just n -> return n
Just n -> return $ NodeName $ n ^. name
Nothing -> addNewOpFromPending pending
lookupNode :: NodeName -> Build NodeDef
lookupNode n = uses renderedNodeDefs (Map.lookup n) >>= \case
Just n' -> return n'
Nothing -> error $ "lookupNode: unknown node name " ++ show n
-- | Add a new node for a given 'OpDef'. This is used for making "stateful" ops
-- which are not safe to dedup (e.g, "variable" and "assign").
addNewOp :: OpDef -> Build NodeDef
addNewOp :: OpDef -> Build NodeName
addNewOp o = getPendingNode o >>= addNewOpFromPending
addNewOpFromPending :: PendingNode -> Build NodeDef
addNewOpFromPending :: PendingNode -> Build NodeName
addNewOpFromPending pending = do
nodeName <- renderPendingNode pending
let nodeDef = pendingNodeDef pending & name .~ unNodeName nodeName
nodeBuffer %= (nodeDef :)
renderedNodes %= Map.insert pending nodeDef
renderedNodeDefs %= Map.insert nodeName nodeDef
return nodeDef
return nodeName
-- | Get the pending node corresponding to an OpDef, which may or may not have
-- been rendered before. Implicitly renders all of this node's inputs.
@ -287,20 +280,18 @@ getPendingNode o = do
-- An empty string in the proto field means that no specific
-- device is specified.
dev <- maybe "" deviceName <$> use defaultDevice
inputs <- mapM getInput (o ^. opInputs)
scope <- use currentScope
controls <- use defaultControlInputs
let inputs = map encodeOutput (o ^. opInputs)
let controlInputs
= map getDep (o ^. opControlInputs ++ Set.toList controls)
= map makeDep (o ^. opControlInputs ++ Set.toList controls)
return $ PendingNode scope (o ^. opName)
$ def & op .~ (unOpType (o ^. opType) :: Text)
& attr .~ _opAttrs o
& input .~ (inputs ++ controlInputs)
& device .~ dev
where
getInput (Output (OutputIx k) subOp)
= (<> ":" <> Text.pack (show k)) . unNodeName <$> getOrAddOp subOp
getDep = ("^" <>) . unNodeName
makeDep = ("^" <>) . unNodeName
-- | Pick a name for a pending node. If it has an explicit name, just use that;
-- if the name is implicit, assign a new unique name based on the op type.
@ -317,12 +308,11 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
-- | Render an 'Output' and return a string representation for the TensorFlow
-- | Turn an 'Output' into a string representation for the TensorFlow
-- foreign APIs.
renderOutput :: Output -> Build Text
renderOutput (Output (OutputIx i) o) = do
n <- getOrAddOp o
return $ unNodeName n <> Text.pack (":" ++ show i)
encodeOutput :: Output -> Text
encodeOutput (Output (OutputIx 0) n) = unNodeName n
encodeOutput (Output (OutputIx i) n) = unNodeName n <> Text.pack (':' : show i)
-- | Modify some part of the state, run an action, and restore the state
-- after that action is done.
@ -339,15 +329,6 @@ withStateLens accessor f act = do
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d = withStateLens defaultDevice (const d)
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
colocateWith t x = do
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
withDevice (Just d) x
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s = withStateLens currentScope (Scope s :)
@ -355,31 +336,3 @@ withNameScope s = withStateLens currentScope (Scope s :)
-- | Add control inputs to all nodes rendered in the given 'Build' action.
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
-- the 'Build' context. Also renders any dependencies of the 'Tensor' that
-- weren't already rendered.
--
-- This operation is idempotent; @render >=> render === render@. However,
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
-- may result in two different 'Tensor's.
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
-- | Render a 'Tensor' and get its node's name.
renderNodeName :: Tensor v a -> Build NodeName
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
-- | Records the given summary action in Build for retrieval with
-- 'collectAllSummaries'. The summary op is required to produce a
-- Summary protocol buffer in string form. For safety, use the
-- pre-composed functions: Logging.scalarSummary and
-- Logging.histogramSummary.
addSummary :: SummaryTensor -> Build ()
addSummary t = summaries %= (t :)
-- | Retrieves the summary ops collected thus far. Typically this only
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
-- repeatedly, the values accumulate.
collectAllSummaries :: Monad m => BuildT m [SummaryTensor]
collectAllSummaries = use summaries

View file

@ -12,26 +12,27 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
( OpResult
, BuildOp
( BuildResult(..)
, buildOp
, buildListOp
, PureResult(..)
, pureOp
, eqLengthGuard
, BuildInputs(..)
, OpParams
)
where
import Control.Monad (replicateM)
import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, runState, get, put)
import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64)
import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build
import TensorFlow.Output
@ -40,48 +41,45 @@ import TensorFlow.Types
data ResultState = ResultState !OutputIx [Int64] deriving Show
type Result = ReaderT Op (State ResultState)
type Result = ReaderT NodeName (State ResultState)
-- | Class of types that can be used as op outputs.
class OpResult a where
toResult :: Result a
class BuildResult a where
buildResult :: Result a
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
toResult = (,) <$> toResult <*> toResult
instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
buildResult = (,) <$> buildResult <*> buildResult
instance (OpResult a1, OpResult a2, OpResult a3) => OpResult (a1, a2, a3) where
toResult = (,,) <$> toResult <*> toResult <*> toResult
instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where
buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
=> OpResult (a1, a2, a3, a4) where
toResult = (,,,) <$> toResult <*> toResult <*> toResult <*> toResult
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
=> BuildResult (a1, a2, a3, a4) where
buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
=> OpResult (a1, a2, a3, a4, a5) where
toResult = (,,,,) <$> toResult
<*> toResult
<*> toResult
<*> toResult
<*> toResult
instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
=> BuildResult (a1, a2, a3, a4, a5) where
buildResult = (,,,,) <$> buildResult
<*> buildResult
<*> buildResult
<*> buildResult
<*> buildResult
instance ( OpResult a1
, OpResult a2
, OpResult a3
, OpResult a4
, OpResult a5
, OpResult a6
instance ( BuildResult a1
, BuildResult a2
, BuildResult a3
, BuildResult a4
, BuildResult a5
, BuildResult a6
)
=> OpResult (a1, a2, a3, a4, a5, a6) where
toResult = (,,,,,)
<$> toResult
<*> toResult
<*> toResult
<*> toResult
<*> toResult
<*> toResult
tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult v = Tensor v <$> recordResult
=> BuildResult (a1, a2, a3, a4, a5, a6) where
buildResult = (,,,,,)
<$> buildResult
<*> buildResult
<*> buildResult
<*> buildResult
<*> buildResult
<*> buildResult
recordResult :: Result Output
recordResult = do
@ -90,144 +88,39 @@ recordResult = do
put $! ResultState (i+1) ns
return $! output i o
instance OpResult ResourceHandle where
toResult = ResourceHandle <$> recordResult
instance BuildResult ResourceHandle where
buildResult = ResourceHandle <$> recordResult
instance OpResult (Tensor Value a) where
toResult = tensorResult ValueKind
instance Rendered v => BuildResult (Tensor v a) where
buildResult = Tensor . pure <$> recordResult
instance OpResult (Tensor Ref a) where
toResult = tensorResult RefKind
instance BuildResult ControlNode where
buildResult = ControlNode <$> ask
instance OpResult ControlNode where
toResult = ControlNode <$> ask
instance (Rendered v, TensorTypes as) => BuildResult (TensorList v as) where
buildResult = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> Result (TensorList v bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- buildResult
ts <- loop ls
return (t :/ ts)
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> Result (TensorList v bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- tensorResult v
ts <- loop ls
return (t :/ ts)
instance TensorTypes as => OpResult (TensorList Value as) where
toResult = tensorListResult ValueKind
instance TensorTypes as => OpResult (TensorList Ref as) where
toResult = tensorListResult RefKind
instance OpResult a => OpResult [a] where
toResult = do
instance BuildResult a => BuildResult [a] where
buildResult = do
ResultState i ns <- get
case ns of
[] -> error $ "Ran out of counts in toResult. " ++
"Likely misuse of buildListOp."
[] -> error $ "Ran out of counts in buildResult. " ++
"Likely misuse of buildOp."
(n : rest) -> do
put $! ResultState i rest
replicateM (fromIntegral n) toResult
replicateM (fromIntegral n) buildResult
runResult :: OpResult a => [Int64] -> Op -> a
runResult ns o =
case runState (runReaderT toResult o) (ResultState 0 ns) of
(x, ResultState _ []) -> x
(_, ns') -> error $ "Ununsed length in runResult attributes: " ++
show (ns, ns')
-- | Make a new "pure" op, which may be deduped with identical ops within
-- the same scope.
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
pureResult ns o ts = runResult ns $ Unrendered $ addReversedInputs o ts
-- | Make a new "stateful" op, which will not be deduped with otherwise
-- identical ops.
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
buildResult ns o ts
= runResult ns . Rendered <$> addNewOp (addReversedInputs o ts)
addReversedInputs :: OpDef -> [Output] -> OpDef
addReversedInputs o ts = o & opInputs <>~ reverse ts
-- | Class of types that can be used as op functions.
class BuildOp f where
buildOp' :: [Int64] -- ^ Sizes of list results (having number_attr)
-> OpDef
-> [Output] -- ^ Accumulator for inputs to the op.
-> f
-- | Starts an operation that returns a structured set of tensors
-- (singletons or tuples).
buildOp :: BuildOp f => OpDef -> f
buildOp o = buildOp' [] o []
-- | Starts an operation that returns a list of tensors.
buildListOp :: BuildOp f => [Int64]
-- ^ Cardinality of the corresponding list of tensors output.
-> OpDef -> f
buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp ResourceHandle where
buildOp' = pureResult
instance BuildOp (Tensor Value a) where
buildOp' = pureResult
instance BuildOp (Tensor Ref a) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Value as) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Ref as) where
buildOp' = pureResult
instance BuildOp [Tensor Value a] where
buildOp' = pureResult
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3) => BuildOp (t1, t2, t3) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
=> BuildOp (t1, t2, t3, t4) where
buildOp' = pureResult
instance (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
=> BuildOp (t1, t2, t3, t4, t5) where
buildOp' = pureResult
instance ( OpResult t1
, OpResult t2
, OpResult t3
, OpResult t4
, OpResult t5
, OpResult t6
)
=> BuildOp (t1, t2, t3, t4, t5, t6) where
buildOp' = pureResult
instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult
instance BuildOp f => BuildOp (ResourceHandle -> f) where
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
instance BuildOp f => BuildOp (Tensor v a -> f) where
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
instance BuildOp f => BuildOp (TensorList v as -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
buildOp sizes o = do
n <- addNewOp o
return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)
-- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not.
@ -240,6 +133,104 @@ eqLengthGuard = all eachOk
error ("number_attr " ++ numberAttrName ++
" contains tensors with different length " ++ show pairs)
-----------
-- | Class of types that can be used as op outputs.
class PureResult a where
pureResult :: ReaderT (Build OpDef) (State ResultState) a
instance PureResult (Tensor Build a) where
pureResult = do
ResultState i ns <- get
put $! ResultState (i+1) ns
makeOp <- ask
return $ Tensor $ do
o <- makeOp
-- TODO: unify with BuildResult (Tensor v)
output i <$> getOrAddOp o
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
pureResult = (,) <$> pureResult <*> pureResult
instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where
pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
=> PureResult (a1, a2, a3, a4) where
pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult
instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
=> PureResult (a1, a2, a3, a4, a5) where
pureResult = (,,,,) <$> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
instance ( PureResult a1
, PureResult a2
, PureResult a3
, PureResult a4
, PureResult a5
, PureResult a6
)
=> PureResult (a1, a2, a3, a4, a5, a6) where
pureResult = (,,,,,)
<$> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
<*> pureResult
instance PureResult a => PureResult [a] where
pureResult = do
ResultState i ns <- get
case ns of
[] -> error $ "Ran out of counts in pureResult. " ++
"Likely misuse of pureOp with output lists."
n : rest -> do
put $! ResultState i rest
replicateM (fromIntegral n) pureResult
instance TensorTypes as => PureResult (TensorList Build as) where
pureResult = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> ReaderT (Build OpDef) (State ResultState)
(TensorList Build bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- pureResult
ts <- loop ls
return (t :/ ts)
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o)
-----
-- Class of types that can be used as arguments
class BuildInputs a where
buildInputs :: a -> Build [Output]
instance BuildInputs a => BuildInputs [a] where
buildInputs = fmap concat . mapM buildInputs
instance BuildInputs (Tensor v a) where
buildInputs (Tensor t) = do
o <- toBuild t
return [o]
instance BuildInputs (ListOf (Tensor v) as) where
buildInputs Nil = return []
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
instance BuildInputs ResourceHandle where
buildInputs (ResourceHandle o) = return [o]
----
-- | Parameters to build an op (for example, the node name or optional attributes).
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef

View file

@ -25,9 +25,6 @@ module TensorFlow.ControlFlow
, noOp
) where
import qualified Data.Set as Set
import Lens.Family2 ((&), (.~))
import TensorFlow.BuildOp
import TensorFlow.Build
import TensorFlow.Nodes
@ -46,11 +43,8 @@ withControlDependencies deps act = do
-- When this op finishes, all ops in the input @n@ have finished. This op has
-- no output.
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
group deps = do
nodes <- build $ Set.toList <$> getNodes deps
-- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
group deps = withControlDependencies deps noOp
-- | Does nothing. Only useful as a placeholder for control edges.
noOp :: ControlNode
noOp = buildOp $ opDef "NoOp"
noOp :: MonadBuild m => m ControlNode
noOp = build $ buildOp [] $ opDef "NoOp"

View file

@ -57,9 +57,9 @@ module TensorFlow.Core
, Tensor
, Value
, Ref
, TensorKind(..)
, value
, tensorFromName
, expr
-- ** Element types
, TensorType
, TensorData

View file

@ -20,6 +20,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} -- For Fetchable (TensorExpr a)
module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3)
@ -28,7 +29,6 @@ import Data.Map.Strict (Map)
import Data.Monoid ((<>))
import Data.Set (Set)
import Data.Text (Text)
import Lens.Family2 ((^.))
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
@ -90,7 +90,7 @@ instance Fetchable t a => Fetchable [t] [a] where
getFetch ts = sequenceA <$> mapM getFetch ts
instance Nodes ControlNode where
getNodes (ControlNode o) = Set.singleton <$> getOrAddOp o
getNodes (ControlNode o) = pure $ Set.singleton o
-- We use the constraint @(a ~ ())@ to help with type inference. For example,
-- if @t :: ControlNode@, then this constraint ensures that @run t :: Session
@ -113,13 +113,13 @@ instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
getNodes (Tensor o) = Set.singleton . outputNodeName <$> toBuild o
fetchTensorVector :: forall a v . TensorType a
fetchTensorVector :: forall a v . (TensorType a)
=> Tensor v a -> Build (Fetch (TensorData a))
fetchTensorVector (Tensor _ o) = do
outputName <- renderOutput o
return $ Fetch (Set.singleton outputName) $ \tensors ->
fetchTensorVector (Tensor o) = do
outputName <- encodeOutput <$> toBuild o
pure $ Fetch (Set.singleton outputName) $ \tensors ->
let tensorData = tensors Map.! outputName
expectedType = tensorType (undefined :: a)
actualType = FFI.tensorDataType tensorData

View file

@ -22,8 +22,6 @@ module TensorFlow.Output
, Device(..)
-- * Ops
, NodeName(..)
, Op(..)
, opUnrendered
, OpDef(..)
, opName
, opType
@ -34,28 +32,24 @@ module TensorFlow.Output
, OutputIx(..)
, Output(..)
, output
, outputIndex
, outputOp
, PendingNodeName(..)
, ResourceHandle(..)
) where
import qualified Data.Map.Strict as Map
import Data.ProtoLens.TextFormat (showMessage)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
import Lens.Family2 (Lens')
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens)
import TensorFlow.Orphans ()
-- | A type of graph node which has no outputs. These nodes are
-- valuable for causing side effects when they are run.
newtype ControlNode = ControlNode { unControlNode :: Op }
newtype ControlNode = ControlNode { unControlNode :: NodeName }
-- | The type of op of a node in the graph. This corresponds to the proto field
-- NodeDef.op.
@ -66,18 +60,12 @@ instance IsString OpType where
fromString = OpType . Text.pack
-- | An output of a TensorFlow node.
data Output = Output !OutputIx !Op
data Output = Output {outputIndex :: !OutputIx, outputNodeName :: !NodeName}
deriving (Eq, Ord, Show)
output :: OutputIx -> Op -> Output
output :: OutputIx -> NodeName -> Output
output = Output
outputOp :: Lens' Output Op
outputOp = lens (\(Output _ o) -> o) (\(Output i _) o -> Output i o)
outputIndex :: Lens' Output OutputIx
outputIndex = lens (\(Output i _) -> i) (\(Output _ o) i -> Output i o)
newtype OutputIx = OutputIx { unOutputIx :: Int }
deriving (Eq, Ord, Num, Enum, Show)
@ -90,25 +78,6 @@ newtype Device = Device {deviceName :: Text}
instance Show Device where
show (Device d) = show d
-- | The representation of a node in a TensorFlow graph.
data Op
= Rendered !NodeDef -- ^ Properties are fixed, including the
-- device, name, and scope.
| Unrendered !OpDef -- ^ Properties are not fixed, and may change depending
-- on which context this op is rendered in.
deriving (Eq, Ord)
instance Show Op where
show (Rendered n) = "Rendered " ++ showMessage n
show (Unrendered o) = "Unrendered " ++ show (o ^. opName)
-- | Traverse on the 'Unrendered' of an 'Op'.
--
-- Same implementation as _Left.
opUnrendered :: Traversal' Op OpDef
opUnrendered f (Unrendered a) = Unrendered <$> f a
opUnrendered _ (Rendered b) = pure (Rendered b)
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
data OpDef = OpDef
{ _opName :: !PendingNodeName
@ -157,7 +126,7 @@ instance IsString Output where
(n, ':':ixStr) | [(ix, "" :: String)] <- read ixStr
-> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n
where assigned = NodeName . Text.pack
-- | Opaque handle to a mutable resource in the graph. Typical such

View file

@ -163,7 +163,7 @@ runWithFeeds feeds t = do
runFetchWithFeeds :: [Feed] -> Set NodeName -> Fetch a -> Session a
runFetchWithFeeds feeds target (Fetch fetch restore) = do
extend
feeds' <- build $ fixFeeds feeds
let feeds' = fixFeeds feeds
let fetchNames = encodeUtf8 <$> Set.toList fetch
targetNames = toNodeNames $ Set.toList target
session <- Session (asks rawSession)
@ -192,8 +192,8 @@ runWithFeeds_ feeds t = do
ns <- build $ getNodes t
runFetchWithFeeds feeds ns (pure ())
fixFeeds :: [Feed] -> Build [(ByteString, FFI.TensorData)]
fixFeeds = mapM $ \(Feed o d) -> (,d) . encodeUtf8 <$> renderOutput o
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
-- | Starts a concurrent thread which evaluates the given Nodes
-- forever until runSession exits or an exception occurs. Graph

View file

@ -16,21 +16,26 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-} -- For the Render class
module TensorFlow.Tensor where
import Data.ByteString (ByteString)
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import Lens.Family2 ((^.))
import Lens.Family2.State ((%=), use)
import TensorFlow.Output (Output)
import Proto.Tensorflow.Core.Framework.NodeDef (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types
( TensorData(..)
, ListOf(..)
@ -40,52 +45,149 @@ import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation.
--
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
-- variable). Note that a @Tensor Ref@ can be casted into a @Tensor Value@ via
-- 'value'.
data Tensor v a = Tensor (TensorKind v) Output
-- parameter @v@ is either:
--
-- * 'Build': An unrendered, immutable value.
-- * 'Value': A rendered, immutable value.
-- * 'Ref': A rendered stateful handle (e.g., a variable).
--
-- Note that 'expr', 'value', 'render' and 'renderValue' can help convert between
-- the different types of 'Tensor'.
data Tensor v a where
Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
data Value
data Ref
newtype Value a = Value {runValue :: a}
deriving Functor
-- | This class provides a runtime switch on whether a 'Tensor' should be
-- treated as a 'Value' or as a 'Ref'.
data TensorKind v where
ValueKind :: TensorKind Value
RefKind :: TensorKind Ref
instance Applicative Value where
pure = Value
Value f <*> Value x = Value $ f x
tensorKind :: Lens' (Tensor v a) (TensorKind v)
tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
instance Monad Value where
f >>= g = g $ runValue f
tensorOutput :: Lens' (Tensor v a) Output
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
newtype Ref a = Ref {runRef :: a}
deriving Functor
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
-- Ref into Value. This behaves like a no-op.
value :: Tensor v a -> Tensor Value a
value (Tensor _ o) = Tensor ValueKind o
instance Applicative Ref where
pure = Ref
Ref f <*> Ref x = Ref $ f x
instance Monad Ref where
f >>= g = g $ runRef f
-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op.
value :: Tensor Ref a -> Tensor Value a
value (Tensor o) = Tensor $ Value $ runRef o
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
renderValue (Tensor o) = render $ Tensor $ toBuild o
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
-- when running the graph.
data Feed = Feed Output FFI.TensorData
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
-- name, device, etc.
class TensorKind v => Rendered v where
rendered :: v a -> a
instance Rendered Value where
rendered = runValue
instance Rendered Ref where
rendered = runRef
renderedOutput :: Rendered v => Tensor v a -> Output
renderedOutput = rendered . tensorOutput
tensorNodeName :: Rendered v => Tensor v a -> NodeName
tensorNodeName = outputNodeName . renderedOutput
-- | Create a 'Feed' for feeding the given data into a 'Tensor' when running
-- the graph.
--
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding the
-- rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Tensor v a -> TensorData a -> Feed
feed (Tensor _ o) (TensorData td) = Feed o td
feed :: Rendered v => Tensor v a -> TensorData a -> Feed
feed t (TensorData td) = Feed (renderedOutput t) td
-- | Create a 'Tensor' for a given name. This can be used to reference nodes
-- in a 'GraphDef' that was loaded via 'addGraphDef'.
-- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
tensorFromName v = Tensor v . fromString . Text.unpack
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
tensorFromName = Tensor . pure . fromString . Text.unpack
-- | Like 'tensorFromName', but type-restricted to 'Value'.
tensorValueFromName :: Text.Text -> Tensor Value a
tensorValueFromName = tensorFromName
-- | Like 'tensorFromName', but type-restricted to 'Ref'.
tensorRefFromName :: Text.Text -> Tensor Ref a
tensorRefFromName = tensorFromName
type TensorList v = ListOf (Tensor v)
tensorListOutputs :: TensorList v as -> [Output]
tensorListOutputs :: Rendered v => TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
tensorListOutputs (t :/ ts) = renderedOutput t : tensorListOutputs ts
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: (MonadBuild m, Rendered v) => Tensor v b -> m a -> m a
colocateWith t x = do
d <- build $ Device . (^. device)
<$> lookupNode (outputNodeName $ renderedOutput t)
withDevice (Just d) x
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
-- the 'MonadBuild' context. Also renders any dependencies of the 'Tensor' that
-- weren't already rendered.
--
-- This operation is idempotent; calling 'render' on the same input in the same
-- context will produce the same result. However, rendering the same
-- @Tensor Build@ in two different contexts may result in two different
-- @Tensor Value@s.
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
render (Tensor t) = Tensor . Value <$> build t
-- TODO: better name.
expr :: TensorKind v => Tensor v a -> Tensor Build a
expr (Tensor o) = Tensor $ toBuild o
-- | Records the given summary action in Build for retrieval with
-- Summary protocol buffer in string form. For safety, use the
-- pre-composed functions: Logging.scalarSummary and
-- Logging.histogramSummary.
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -- ^ A 'SummaryTensor'
-> m ()
addSummary t = build $ do
-- TODO: more generic way
o <- toBuild $ tensorOutput t
summaries %= (o :)
-- | Retrieves the summary ops collected thus far. Typically this only
-- happens once, but if 'TensorFlow.Session.buildWithSummary' is used
-- repeatedly, the values accumulate.
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
collectAllSummaries = build $ map (Tensor . Value) <$> use summaries
-- | Synonym for the tensors that return serialized Summary proto.
type SummaryTensor = Tensor Value ByteString
-- | An internal class for kinds of Tensors.
class Monad v => TensorKind v where
toBuild :: v a -> Build a
instance TensorKind Value where
toBuild = return . rendered
instance TensorKind Ref where
toBuild = return . rendered
instance TensorKind Build where
toBuild = id