Add resource-based variable ops. (#98)

The main difference between these and the `Ref`-bases ops is the explicit
`readValue` op.  I'm not sure how this should interact with gradients
and save/restore, so I'm keeping it as a separate module for now.  Once we
figure out the details, we can merge it into `TensorFlow.Ops` and replace
all uses of the old `Ref`-based ops.  (That would also fix #92.)

Also replaces our special case newtype `ResourceHandle` to
`Tensor Value ResourceHandle`, where `ResourceHandle` is the TF proto
corresponding to `DT_RESOURCE`.
This commit is contained in:
Judah Jacobson 2017-04-16 09:24:02 -07:00 committed by Greg Steuck
parent 21b723d542
commit 42f4fc647e
10 changed files with 231 additions and 48 deletions

View File

@ -33,15 +33,14 @@ where:
(for example: 'TensorType' or 'OneOf').
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
of those types), and @a@ is either a concrete type or a (constrained) type
variable.
the form @Tensor Ref a@ or @Tensor v a@ (or a list of one of those types),
and @a@ is either a concrete type or a (constrained) type variable.
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
it is either @ControlNode@ or @Build ControlNode@.)
it takes a @Tensor Ref@ or @Tensor v ResourceHandle@ as input, or if it's
explicitly marked \"Stateful\" in its @REGISTER_OP@ definition. (If there
are no outputs, it is either @ControlNode@ or @Build ControlNode@.)
-}
module TensorFlow.OpGen
@ -155,7 +154,6 @@ imports = stack [
, "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build"
, "import TensorFlow.BuildOp"
, "import TensorFlow.Output (ResourceHandle)"
, "import TensorFlow.Tensor"
, "import TensorFlow.Types"
]
@ -300,7 +298,7 @@ typeSig pre pOp = constraints
| null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
ArgSomeTensor v <- [argKind $ parsedArgCase k]]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
++ if parsedOpIsMonadic pOp then ["m'"] else []
-- Use m' as the type parameter to avoid clashing with an attribute name.
@ -333,13 +331,12 @@ typeSig pre pOp = constraints
| otherwise = o
-- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
-- For example: "Tensor Ref Int64", "Tensor v t"
tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of
ResourceArg -> "ResourceHandle"
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {argTypeAttr = t, argCaseKind = k}
SimpleArg { argType = t, argKind = k } -> tensorType t k
ListArg { argType = t, argKind = k } -> brackets $ tensorType t k
MixedListArg {argTypeAttr = t, argKind = k}
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
where
kind k = case k of
@ -420,8 +417,7 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
dtTypeToHaskell DT_RESOURCE =
error "ResourceHandle must be prevented from getting here."
dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
dtTypeToHaskell x =
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x

View File

@ -17,7 +17,6 @@ module TensorFlow.OpGen.ParsedOp
, ParsedArgCase(..)
, ArgType(..)
, ArgKind(..)
, argKind
, parseOp
, camelCase
) where
@ -118,19 +117,18 @@ data ParsedArg = ParsedArg
}
data ParsedArgCase
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
= SimpleArg { argType :: ArgType, argKind :: ArgKind }
| ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType
, argCaseKind :: ArgKind
, argKind :: ArgKind
}
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
| MixedListArg { argTypeAttr :: Name, argKind :: ArgKind }
-- ^ A heterogeneous list.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind
argKind ResourceArg = Nothing
argKind a = Just $ argCaseKind a
maybeArgType :: ParsedArgCase -> Maybe ArgType
maybeArgType MixedListArg{} = Nothing
maybeArgType a = Just $ argType a
-- | The type of an argument.
data ArgType
@ -146,10 +144,10 @@ data ArgKind
deriving (Eq)
isRefCase :: ParsedArgCase -> Bool
isRefCase a = case argKind a of
Nothing -> True -- Resource
Just ArgTensorRef -> True
_ -> False
isRefCase a
| ArgTensorRef <- argKind a = True
| Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True
| otherwise = False
makeName :: Text -> Name
makeName n = Name
@ -314,7 +312,6 @@ parseArg a tKind = ParsedArg
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind
| a ^. type' == DT_RESOURCE = ResourceArg
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
| otherwise = SimpleArg thisArgType tKind

View File

@ -0,0 +1,109 @@
-- | An implementation of ResourceHandle-based variables.
--
-- The main difference between this and 'Ref'-based variables is
-- that reads are explicit, via the 'readValue' op.
--
-- TODO: given that distinction, figure out a good story around
-- gradients and save/restore. Then, merge this module into
-- TensorFlow.Ops.
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Variable
( Variable
, variable
, variable'
, readValue
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, assign
, assign'
, assignAdd
, assignAdd'
) where
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
newtype Variable a = Variable (Tensor Value ResourceHandle)
-- | Creates a new, uninitialized variable.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable = variable' id
variable' :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Shape -> m (Variable a)
variable' params s = build $ do
-- Each variable needs a unique "shared_name". Use MonadFix to
-- set the attribute to the same name as the variable itself, without
-- exposing more internals of the Build module.
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
(tensorType (undefined :: a)) s
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
return $ Variable t
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: (MonadBuild m, TensorType a)
=> Tensor v a -> m (Variable a)
initializedVariable = initializedVariable' id
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
=> OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
-- The shape is not known initially.
v@(Variable h) <- variable' params (Shape [])
i <- CoreOps.assignVariableOp h initializer
addInitializer =<< group i
return v
-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
:: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable = zeroInitializedVariable' id
zeroInitializedVariable'
:: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params = initializedVariable' params . zeros
-- | Gets the value stored in a variable.
readValue :: TensorType a => Variable a -> Tensor Build a
readValue = readValue' id
readValue' :: forall a . TensorType a
=> OpParams -> Variable a -> Tensor Build a
readValue' params (Variable h)
= pureOp [] $ do
os <- buildInputs h
pure $ opDef "ReadVariableOp"
& (params
. (opAttr "dtype" .~ tensorType (undefined :: a))
. (opInputs .~ os))
-- | Sets the value of a variable.
assign :: (MonadBuild m, TensorType a)
=> Variable a -> Tensor v a -> m ControlNode
assign = assign' id
assign' :: (MonadBuild m, TensorType a)
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable h) v = CoreOps.assignVariableOp' params h v
-- | Increments the value of a variable.
assignAdd :: (MonadBuild m, TensorType a)
=> Variable a -> Tensor v a -> m ControlNode
assignAdd = assignAdd' id
assignAdd' :: (MonadBuild m, TensorType a)
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v

View File

@ -16,6 +16,7 @@ library
exposed-modules: TensorFlow.Gradient
, TensorFlow.Ops
, TensorFlow.EmbeddingOps
, TensorFlow.Variable
build-depends: proto-lens == 0.2.*
, base >= 4.7 && < 5
, bytestring
@ -37,6 +38,8 @@ Test-Suite RegressionTest
hs-source-dirs: tests
build-depends: base
, HUnit
, lens-family
, transformers
, random
, tensorflow
, tensorflow-core-ops
@ -126,6 +129,23 @@ Test-Suite OpsTest
, transformers
, vector
Test-Suite VariableTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: VariableTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, test-framework
, test-framework-hunit
, transformers
, vector
Test-Suite DataFlowOpsTest
default-language: Haskell2010
type: exitcode-stdio-1.0

View File

@ -0,0 +1,62 @@
{-# LANGUAGE OverloadedLists #-}
module Main (main) where
import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector.Storable as V
import Google.Test (googleTest)
import TensorFlow.Core (unScalar, run_, runSession, run)
import qualified TensorFlow.Ops as Ops
import TensorFlow.Variable
( readValue
, initializedVariable
, assign
, assignAdd
)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
main :: IO ()
main = googleTest [ testInitializedVariable
, testInitializedVariableShape
, testRereadRef
, testAssignAdd
]
testInitializedVariable :: Test
testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- do
v <- initializedVariable 42
r <- assign v 24
return (1 + readValue v, r)
result <- run formula
liftIO $ 43 @=? (unScalar result :: Float)
run_ reset -- Updates v to a different value
rerunResult <- run formula
liftIO $ 25 @=? (unScalar rerunResult :: Float)
testInitializedVariableShape :: Test
testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
result <- run (readValue vector)
liftIO $ [42] @=? (result :: V.Vector Float)
-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.
testRereadRef :: Test
testRereadRef = testCase "testReRunAssign" $ runSession $ do
w <- initializedVariable 0
f0 <- run (readValue w)
run_ =<< assign w (Ops.scalar (0.1 :: Float))
f1 <- run (readValue w)
liftIO $ (0.0, 0.1) @=? (unScalar f0, unScalar f1)
testAssignAdd :: Test
testAssignAdd = testCase "testAssignAdd" $ runSession $ do
w <- initializedVariable 42
run_ =<< assignAdd w 17
f1 <- run (readValue w)
liftIO $ (42 + 17 :: Float) @=? unScalar f1

View File

@ -62,6 +62,7 @@ module TensorFlow.Build
) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
@ -192,7 +193,8 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
-- Used to manage build state internally as part of the @Session@ monad.
newtype BuildT m a = BuildT (StateT GraphState m a)
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
MonadState GraphState, MonadThrow, MonadCatch, MonadMask)
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
MonadFix)
-- | An action for building nodes in a TensorFlow graph.
type Build = BuildT Identity
@ -307,7 +309,6 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
nextUnique .= succ u
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
-- | Turn an 'Output' into a string representation for the TensorFlow
-- foreign APIs.
encodeOutput :: Output -> Text

View File

@ -126,9 +126,6 @@ recordResult = do
put $! ResultState (i+1) ns
return $! output i o
instance BuildResult ResourceHandle where
buildResult = ResourceHandle <$> recordResult
instance Rendered v => BuildResult (Tensor v a) where
buildResult = Tensor . pure <$> recordResult
@ -302,9 +299,6 @@ instance BuildInputs (ListOf (Tensor v) as) where
buildInputs Nil = return []
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
instance BuildInputs ResourceHandle where
buildInputs (ResourceHandle o) = return [o]
----
-- | Parameters to build an op (for example, the node name or optional attributes).

View File

@ -52,6 +52,7 @@ module TensorFlow.Core
, addGraphDef
, opName
, opAttr
, addInitializer
-- * Tensor
, ControlNode
, Tensor
@ -64,6 +65,7 @@ module TensorFlow.Core
, TensorType
, TensorData
, TensorDataType(decodeTensorData, encodeTensorData)
, ResourceHandle
, Scalar(..)
, Shape(..)
, OneOf

View File

@ -33,7 +33,6 @@ module TensorFlow.Output
, Output(..)
, output
, PendingNodeName(..)
, ResourceHandle(..)
) where
import qualified Data.Map.Strict as Map
@ -127,8 +126,3 @@ instance IsString Output where
-> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s
where assigned = NodeName . Text.pack
-- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables.
newtype ResourceHandle = ResourceHandle Output

View File

@ -39,6 +39,7 @@ module TensorFlow.Types
, protoShape
, Attribute(..)
, DataType(..)
, ResourceHandle
-- * Lists
, ListOf(..)
, List
@ -94,15 +95,18 @@ import Proto.Tensorflow.Core.Framework.AttrValue
, shape
, tensor
)
import Proto.Tensorflow.Core.Framework.ResourceHandle
(ResourceHandle)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
( TensorProto(..)
, floatVal
, doubleVal
, intVal
, stringVal
, int64Val
, stringVal
, boolVal
, doubleVal
, floatVal
, intVal
, int64Val
, resourceHandleVal
, stringVal
, stringVal
)
import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..)
@ -183,6 +187,10 @@ instance TensorType (Complex Double) where
tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)"
instance TensorType ResourceHandle where
tensorType _ = DT_RESOURCE
tensorRefType _ = DT_RESOURCE_REF
tensorVal = resourceHandleVal
-- | Tensor data with the correct memory layout for tensorflow.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }