mirror of https://github.com/tensorflow/haskell.git synced 2024-06-26 14:38:34 +02:00
Judah Jacobson d62c614695 Distinguish between "rendered" and "unrendered" Tensors. (#88)
Distinguish between "rendered" and "unrendered" Tensors.

There are now three types of `Tensor`:

- `Tensor Value a`: rendered value
- `Tensor Ref a`: rendered reference
- `Tensor Build a` : unrendered value

The extra bookkeeping makes it easier to track (and enforce) which tensors are
rendered or not.  For examples where this has been confusing in the past, see

With this change, pure ops look similar to before, returning `Tensor Build`
instead of `Tensor Value`.  "Stateful" (monadic) ops are unchanged.  For example:

    add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
    assign :: (MonadBuild m, TensorType t)
           => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t)

The `gradients` function now requires that the variables over which it's
differentiating are pre-rendered:

    gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a]
              -> m [Tensor Value a]

(`Rendered v2` means that `v2` is either a `Ref` or a `Value`.)

Additionally, the implementation of `gradients` now takes care to render every
intermediate value when performing the reverse accumulation.  I suspect this
fixes an exponential blowup for complicated expressions.
2017-04-06 15:10:33 -07:00

375 lines
12 KiB

-- Copyright 2016 TensorFlow authors.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | This module contains definitions for some built-in TensorFlow operations.
-- Note that certain "stateful" ops like 'variable' and 'assign' return a
-- 'Build' action (e.g., @Build (Tensor Ref a)@) instead of a pure value; the
-- returned 'Tensor's are always rendered in the current 'Build' context. This
-- approach helps us avoid problems with inlining or common subexpression
-- elimination, by writing
-- > do
-- > v <- variable []
-- > w <- assign v 3
-- > render $ w * w
-- instead of
-- > let
-- > v = variable []
-- > w = assign v 3
-- > in w * w
-- since the latter could be reasonably transformed by the compiler into (or
-- vice versa)
-- > let
-- > v = variable []
-- > w = assign v 3
-- > w' = assign v 3
-- > in w * w'
-- Ops should return a 'Build' action if their original 'OpDef' marks them as
-- stateful, or if they take any Refs as input. (This mirrors the rules that
-- TensorFlow uses to avoid common subexpression elimination.)
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module TensorFlow.Ops
( CoreOps.add
, CoreOps.add'
, CoreOps.abs
, CoreOps.abs'
, CoreOps.addN
, CoreOps.addN'
, CoreOps.argMax
, CoreOps.argMax'
, CoreOps.assign
, CoreOps.assign'
, CoreOps.broadcastGradientArgs
, CoreOps.broadcastGradientArgs'
, CoreOps.cast
, CoreOps.cast'
, CoreOps.concat
, CoreOps.concat'
, constant
, constant'
, CoreOps.equal
, CoreOps.equal'
, expandDims
, expandDims'
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, CoreOps.fill
, CoreOps.fill'
, CoreOps.identity
, CoreOps.identity'
, CoreOps.matMul
, CoreOps.matMul'
, matTranspose
, matTranspose'
, CoreOps.mean
, CoreOps.mean'
, CoreOps.mul
, CoreOps.mul'
, CoreOps.neg
, CoreOps.neg'
, CoreOps.oneHot
, CoreOps.oneHot'
, CoreOps.pack
, CoreOps.pack'
, placeholder
, placeholder'
, CoreOps.range
, CoreOps.range'
, reducedShape
, CoreOps.relu
, CoreOps.relu'
, CoreOps.reluGrad
, CoreOps.reluGrad'
, CoreOps.reshape
, CoreOps.reshape'
, restore
, restoreFromName
, save
, scalar
, scalar'
, shape
, shape'
, CoreOps.sign
, CoreOps.sign'
, CoreOps.size
, CoreOps.size'
, CoreOps.softmax
, CoreOps.softmax'
, CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.softmaxCrossEntropyWithLogits'
, CoreOps.sparseToDense
, CoreOps.sparseToDense'
, CoreOps.sub
, CoreOps.sub'
, CoreOps.sum
, CoreOps.sum'
, CoreOps.transpose
, CoreOps.transpose'
, truncatedNormal
, truncatedNormal'
, CoreOps.variable
, CoreOps.variable'
, vector
, vector'
, zeros
, CoreOps.zerosLike
, CoreOps.zerosLike'
, scalarize
) where
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor
    ( TensorProto
    , dtype
    , tensorShape
    )
import qualified Proto.Tensorflow.Core.Framework.TensorShape
  as TensorShape
import TensorFlow.Build
import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified Prelude (abs)
-- TODO: Look into hs-boot refactoring to allow mutually recursive imports.
-- | Must be defined as an orphan because of the dependency order between Ops
-- and Tensor.
--
-- The indirect constraint "v ~ Build" helps disambiguate types, for example in
-- "neg 1 :: Tensor Build Float", it helps find the type of the subexpression
-- "1".
instance ( TensorType a
         , Num a
         , v ~ Build
         , OneOf '[ Double, Float, Int32, Int64
                  , Complex Float, Complex Double] a) => Num (Tensor v a) where
    -- Every method delegates to the corresponding generated core op.
    (+) = CoreOps.add
    (*) = CoreOps.mul
    (-) = CoreOps.sub
    abs = CoreOps.abs
    -- Integer literals become scalar constant tensors.
    fromInteger = scalar . fromInteger
    signum = CoreOps.sign
    negate = CoreOps.neg
-- | Same as 'matTranspose'', with no extra op parameters (@id@).
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose t = matTranspose' id t
-- | Applies 'CoreOps.transpose'' to the tensor with the fixed permutation
-- @[1, 0]@, passing the given op parameters through.
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Build a
matTranspose' params t = CoreOps.transpose' params t (vector [1, 0 :: Int32])
-- | Same as 'placeholder'', with no extra op parameters (@id@).
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder pShape = placeholder' id pShape
-- | Adds a "Placeholder" op to the graph. Its "dtype" attribute is taken from
-- the element type @a@ and its "shape" attribute from the given 'Shape'; the
-- caller's op parameters are applied last.
placeholder' :: forall m a . (MonadBuild m, TensorType a)
             => OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape = build (buildOp [] placeholderDef)
  where
    -- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
    -- and thus would be CSE'd.
    placeholderDef =
        opDef "Placeholder"
            & opAttr "dtype" .~ tensorType (undefined :: a)
            & opAttr "shape" .~ pShape
            & params
-- | Creates a variable whose initial value is the given tensor.
-- The initialization itself runs the next time the session runs.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Tensor Ref a)
initializedVariable initializer = initializedVariable' id initializer
-- | Like 'initializedVariable', but applies the given op parameters to the
-- underlying variable op.
--
-- Creates the variable, assigns @initializer@ to it (with shape validation
-- disabled, since the variable is created with an empty shape), registers the
-- assignment as an initializer, and returns the variable reference.
initializedVariable' :: (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Tensor Ref a)
initializedVariable' params initializer = do
    v <- CoreOps.variable' params []  -- The shape is not known initially.
    -- The assign op needs both the target variable and the value to store;
    -- the value argument was previously missing.
    i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v initializer
    addInitializer =<< group i
    return v
-- | Creates a zero-initialized variable with the given shape.
-- Same as 'zeroInitializedVariable'', with no extra op parameters (@id@).
zeroInitializedVariable
  :: (MonadBuild m, TensorType a, Num a) =>
     TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = zeroInitializedVariable' id
-- | Like 'zeroInitializedVariable', but applies the given op parameters.
-- The initial value is @zeros shape@, handed to 'initializedVariable''.
zeroInitializedVariable'
  :: (MonadBuild m, TensorType a, Num a) =>
     OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors.
-- | Adds a "Save" op to the graph. Its inputs are, in order: the file path,
-- a packed tensor holding the encoded output name of every given tensor, and
-- the tensors themselves. The "T" attribute lists the element type once per
-- tensor.
save :: forall a m v . (Rendered v, MonadBuild m, TensorType a)
     => ByteString    -- ^ File path.
     -> [Tensor v a]  -- ^ Tensors to save.
     -> m ControlNode
save path xs = build $ do
    -- The inputs are built in this order: names, tensors, then path.
    nameInputs <- buildInputs (CoreOps.pack (map outputName xs))
    tensorInputs <- buildInputs xs
    pathInput <- buildInputs (scalar path)
    let saveDef = opDef "Save"
                    & opAttr "T" .~ dtypes
                    & opInputs .~ (pathInput ++ nameInputs ++ tensorInputs)
    buildOp [] saveDef
  where
    -- Scalar string tensor holding the encoded name of a rendered output.
    outputName t = scalar (encodeUtf8 (encodeOutput (renderedOutput t)))
    -- One copy of the element type per tensor being saved.
    dtypes = tensorType (undefined :: a) <$ xs
-- | Restore a tensor's value from a checkpoint file.
-- Unlike 'restore', the name looked up in the checkpoint is supplied by the
-- caller, so it may differ from the variable's own name.
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
                => ByteString    -- ^ File path.
                -> ByteString    -- ^ Tensor name override.
                -> Tensor Ref a  -- ^ Tensor to restore.
                -> m ControlNode
restoreFromName path name x = build $ do
    pathInput <- buildInputs (scalar path)
    nameInput <- buildInputs (scalar name)
    let restoreDef = opDef "Restore"
                       & opAttr "dt" .~ tensorType (undefined :: a)
                       & opInputs .~ (pathInput ++ nameInput)
    restored <- buildOp [] restoreDef
    -- Assign the restored value to the target and expose it as a control node.
    assigned <- CoreOps.assign x (restored :: Tensor Value a)
    group assigned
-- | Restore a tensor's value from a checkpoint file.
-- The name looked up in the checkpoint is the encoded output name of the
-- tensor being restored; see 'restoreFromName' to supply a different name.
restore :: forall a m . (MonadBuild m, TensorType a)
        => ByteString    -- ^ File path.
        -> Tensor Ref a  -- ^ Tensor to restore.
        -> m ControlNode
restore path x = restoreFromName path name x
  where
    -- Must be local: it depends on the argument @x@.
    name = encodeUtf8 $ encodeOutput $ renderedOutput x
-- | Create a constant tensor of the given shape.
-- The values are listed in row major order, e.g.,
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
constant :: TensorType a => Shape -> [a] -> Tensor Build a
constant cShape values = constant' id cShape values
-- | Like 'constant', but applies the given op parameters.
--
-- Calls 'error' when the number of supplied values differs from the product
-- of the shape's dimensions. Otherwise builds a const op whose "value"
-- attribute is a 'TensorProto' carrying the element type, the shape and the
-- values.
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Build a
constant' params (Shape cShape) values
    | invalidLength = error invalidLengthMsg
    | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
  where
    -- These helpers must be local: they depend on @cShape@ and @values@.
    invalidLength = product cShape /= fromIntegral (length values)
    invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
                              (product cShape)
                              (length values)
    typedNode :: TensorProto
    typedNode = def
                & dtype .~ tensorType (undefined :: a)
                & tensorShape.TensorShape.dim .~
                      [def & TensorShape.size .~ x | x <- cShape]
                & tensorVal .~ values
-- | Reshape an N-D tensor down to a scalar by reshaping it to the empty
-- shape vector.
-- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: TensorType a => Tensor v a -> Tensor Build a
scalarize t = CoreOps.reshape t (vector noDims)
  where
    noDims :: [Int32]
    noDims = []
-- | Create a constant vector (rank-1 tensor) from a list of values.
vector :: TensorType a => [a] -> Tensor Build a
vector xs = vector' id xs
-- | Like 'vector', but applies the given op parameters. The single
-- dimension of the resulting shape is the length of the list.
vector' :: TensorType a => OpParams -> [a] -> Tensor Build a
vector' params xs = constant' params [fromIntegral (length xs)] xs
-- | Create a constant scalar (a tensor with an empty shape) from one value.
scalar :: TensorType a => a -> Tensor Build a
scalar x = scalar' id x
-- | Like 'scalar', but applies the given op parameters. Builds a constant
-- with an empty shape holding exactly one value.
scalar' :: TensorType a => OpParams -> a -> Tensor Build a
scalar' params x = constant' params emptyShape [x]
  where
    emptyShape :: Shape
    emptyShape = []
-- | Random tensor from the unit normal distribution with bounded values.
-- A type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
                => Tensor v Int64  -- ^ Shape.
                -> m (Tensor Value a)
truncatedNormal shapeTensor = CoreOps.truncatedNormal shapeTensor
-- | Like 'truncatedNormal', but applies the given op parameters.
-- A type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal''.
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
                 => OpParams -> Tensor v Int64  -- ^ Shape.
                 -> m (Tensor Value a)
truncatedNormal' params shapeTensor = CoreOps.truncatedNormal' params shapeTensor
-- | A tensor of the given shape filled with the scalar @0@.
-- The dimensions are converted with 'fromIntegral' before being passed to
-- 'CoreOps.fill' as a vector.
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
zeros (Shape dims) = CoreOps.fill (vector (map fromIntegral dims)) (scalar 0)
-- | Type-restricted version of 'CoreOps.shape' whose result is an 'Int32'
-- tensor.
shape :: TensorType t => Tensor v t -> Tensor Build Int32
shape t = CoreOps.shape t
-- | Like 'shape', but applies the given op parameters.
shape' :: TensorType t => OpParams -> Tensor v t -> Tensor Build Int32
shape' params t = CoreOps.shape' params t
-- | Type-restricted version of 'CoreOps.expandDims' whose second argument
-- is an 'Int32' tensor.
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims t dim = CoreOps.expandDims t dim
-- | Like 'expandDims', but applies the given op parameters.
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Build t
expandDims' params t dim = CoreOps.expandDims' params t dim
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
--
-- Both arguments are cast to 'Int32'. The axes are wrapped into range by
-- adding the input rank and taking the result modulo the rank. The result
-- stitches the input shape together with ones placed at the wrapped axes.
-- The trailing comments show a worked example.
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
                Tensor v1 t1 -> Tensor v2 t2 -> Tensor Build Int32
reducedShape inputShape axes =
    CoreOps.dynamicStitch                        -- [2, 1, 1, 7]
        [ CoreOps.range 0 rank 1                 -- [0, 1, 2, 3]
        , wrappedAxes ]                          -- [1, 2]
        [ shape32                                -- [2, 3, 5, 7]
        , CoreOps.fill (shape wrappedAxes) 1 ]   -- [1, 1]
  where
    shape32 = CoreOps.cast inputShape :: Tensor Build Int32   -- [2, 3, 5, 7]
    axes32 = CoreOps.cast axes :: Tensor Build Int32          -- [1, 2]
    rank = CoreOps.size shape32                               -- 4
    wrappedAxes = (axes32 + rank) `CoreOps.mod` rank          -- [1, 2]