Add resource-based variable ops. (#98)

The main difference between these and the `Ref`-based ops is the explicit
`readValue` op.  I'm not sure how this should interact with gradients
and save/restore, so I'm keeping it as a separate module for now.  Once we
figure out the details, we can merge it into `TensorFlow.Ops` and replace
all uses of the old `Ref`-based ops.  (That would also fix #92.)

Also replaces our special-case newtype `ResourceHandle` with
`Tensor Value ResourceHandle`, where `ResourceHandle` is the TF proto
corresponding to `DT_RESOURCE`.
This commit is contained in:
Judah Jacobson 2017-04-16 09:24:02 -07:00 committed by Greg Steuck
parent 21b723d542
commit 42f4fc647e
10 changed files with 231 additions and 48 deletions

View File

@ -33,15 +33,14 @@ where:
(for example: 'TensorType' or 'OneOf'). (for example: 'TensorType' or 'OneOf').
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of * @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one the form @Tensor Ref a@ or @Tensor v a@ (or a list of one of those types),
of those types), and @a@ is either a concrete type or a (constrained) type and @a@ is either a concrete type or a (constrained) type variable.
variable.
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and * @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if @Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly it takes a @Tensor Ref@ or @Tensor v ResourceHandle@ as input, or if it's
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs, explicitly marked \"Stateful\" in its @REGISTER_OP@ definition. (If there
it is either @ControlNode@ or @Build ControlNode@.) are no outputs, it is either @ControlNode@ or @Build ControlNode@.)
-} -}
module TensorFlow.OpGen module TensorFlow.OpGen
@ -155,7 +154,6 @@ imports = stack [
, "import Lens.Family2 ((.~), (&))" , "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build" , "import TensorFlow.Build"
, "import TensorFlow.BuildOp" , "import TensorFlow.BuildOp"
, "import TensorFlow.Output (ResourceHandle)"
, "import TensorFlow.Tensor" , "import TensorFlow.Tensor"
, "import TensorFlow.Types" , "import TensorFlow.Types"
] ]
@ -300,7 +298,7 @@ typeSig pre pOp = constraints
| null classConstraints = empty | null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>" | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]] ArgSomeTensor v <- [argKind $ parsedArgCase k]]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
++ if parsedOpIsMonadic pOp then ["m'"] else [] ++ if parsedOpIsMonadic pOp then ["m'"] else []
-- Use m' as the type parameter to avoid clashing with an attribute name. -- Use m' as the type parameter to avoid clashing with an attribute name.
@ -333,13 +331,12 @@ typeSig pre pOp = constraints
| otherwise = o | otherwise = o
-- | Render an op input or output. -- | Render an op input or output.
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle" -- For example: "Tensor Ref Int64", "Tensor v t"
tensorArg :: ParsedArg -> Doc tensorArg :: ParsedArg -> Doc
tensorArg p = case parsedArgCase p of tensorArg p = case parsedArgCase p of
ResourceArg -> "ResourceHandle" SimpleArg { argType = t, argKind = k } -> tensorType t k
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k ListArg { argType = t, argKind = k } -> brackets $ tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k MixedListArg {argTypeAttr = t, argKind = k}
MixedListArg {argTypeAttr = t, argCaseKind = k}
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t -> "TensorList" <+> parens (kind k) <+> renderHaskellName t
where where
kind k = case k of kind k = case k of
@ -420,8 +417,7 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16" dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8" dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
dtTypeToHaskell DT_RESOURCE = dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
error "ResourceHandle must be prevented from getting here."
dtTypeToHaskell x = dtTypeToHaskell x =
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x

View File

@ -17,7 +17,6 @@ module TensorFlow.OpGen.ParsedOp
, ParsedArgCase(..) , ParsedArgCase(..)
, ArgType(..) , ArgType(..)
, ArgKind(..) , ArgKind(..)
, argKind
, parseOp , parseOp
, camelCase , camelCase
) where ) where
@ -118,19 +117,18 @@ data ParsedArg = ParsedArg
} }
data ParsedArgCase data ParsedArgCase
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind } = SimpleArg { argType :: ArgType, argKind :: ArgKind }
| ListArg | ListArg
{ argLength :: Name -- ^ The attribute that specifies this list's length. { argLength :: Name -- ^ The attribute that specifies this list's length.
, argType :: ArgType , argType :: ArgType
, argCaseKind :: ArgKind , argKind :: ArgKind
} }
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind } | MixedListArg { argTypeAttr :: Name, argKind :: ArgKind }
-- ^ A heterogeneous list. -- ^ A heterogeneous list.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind maybeArgType :: ParsedArgCase -> Maybe ArgType
argKind ResourceArg = Nothing maybeArgType MixedListArg{} = Nothing
argKind a = Just $ argCaseKind a maybeArgType a = Just $ argType a
-- | The type of an argument. -- | The type of an argument.
data ArgType data ArgType
@ -146,10 +144,10 @@ data ArgKind
deriving (Eq) deriving (Eq)
isRefCase :: ParsedArgCase -> Bool isRefCase :: ParsedArgCase -> Bool
isRefCase a = case argKind a of isRefCase a
Nothing -> True -- Resource | ArgTensorRef <- argKind a = True
Just ArgTensorRef -> True | Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True
_ -> False | otherwise = False
makeName :: Text -> Name makeName :: Text -> Name
makeName n = Name makeName n = Name
@ -314,7 +312,6 @@ parseArg a tKind = ParsedArg
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
parseArgCase a tKind parseArgCase a tKind
| a ^. type' == DT_RESOURCE = ResourceArg
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
| otherwise = SimpleArg thisArgType tKind | otherwise = SimpleArg thisArgType tKind

View File

@ -0,0 +1,109 @@
-- | An implementation of ResourceHandle-based variables.
--
-- The main difference between this and 'Ref'-based variables is
-- that reads are explicit, via the 'readValue' op.
--
-- TODO: given that distinction, figure out a good story around
-- gradients and save/restore. Then, merge this module into
-- TensorFlow.Ops.
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Variable
( Variable
, variable
, variable'
, readValue
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, assign
, assign'
, assignAdd
, assignAdd'
) where
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
-- | A variable backed by a resource handle.
--
-- The type parameter @a@ is phantom: it does not appear in the wrapped
-- tensor, which always has element type 'ResourceHandle'.
newtype Variable a
    = Variable (Tensor Value ResourceHandle)
-- | Create a fresh variable of the given shape without giving it a value.
--
-- This is 'variable'' with no extra op parameters.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable shape = variable' id shape
-- | Like 'variable', but applies the given 'OpParams' to the node created
-- by 'CoreOps.varHandleOp''.
variable' :: forall m a . (MonadBuild m, TensorType a)
          => OpParams -> Shape -> m (Variable a)
variable' params shape = build $ do
    -- Every variable must carry a distinct "shared_name" attribute.  The
    -- recursive binding (MonadFix) lets that attribute be the node's own
    -- name, so no further internals of the Build module need exposing.
    rec handle <- CoreOps.varHandleOp'
                      (params . (opAttr "shared_name" .~ sharedName))
                      dtype
                      shape
        let sharedName = encodeUtf8 . unNodeName . tensorNodeName $ handle
    pure (Variable handle)
  where
    -- The element type is taken from the phantom parameter @a@.
    dtype = tensorType (undefined :: a)
-- | Create a variable whose initial contents are the given tensor.
-- Initialization happens next time session runs.
--
-- This is 'initializedVariable'' with no extra op parameters.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Variable a)
initializedVariable initializer = initializedVariable' id initializer
-- | Like 'initializedVariable', but passes the given 'OpParams' on to
-- 'variable''.
--
-- The assignment of the initial value is wrapped with 'group' and
-- registered through 'addInitializer'.
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
    -- The shape is not known at this point, so 'Shape []' is passed.
    -- NOTE(review): confirm 'Shape []' is treated as "unknown" rather than
    -- as a scalar shape by the underlying op.
    var <- variable' params (Shape [])
    let Variable handle = var
    initOp <- CoreOps.assignVariableOp handle initializer
    group initOp >>= addInitializer
    pure var
-- | Create a variable of the given shape whose initial value is 'zeros'.
--
-- This is 'zeroInitializedVariable'' with no extra op parameters.
zeroInitializedVariable
    :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable shape = zeroInitializedVariable' id shape
-- | Like 'zeroInitializedVariable', but with extra 'OpParams' that are
-- forwarded to 'initializedVariable''.
zeroInitializedVariable'
    :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params shape =
    initializedVariable' params (zeros shape)
-- | Read the value currently stored in a variable.
--
-- This is 'readValue'' with no extra op parameters.
readValue :: TensorType a => Variable a -> Tensor Build a
readValue var = readValue' id var
-- | Like 'readValue', but applies the given 'OpParams' to the
-- @ReadVariableOp@ node definition.
readValue' :: forall a . TensorType a
           => OpParams -> Variable a -> Tensor Build a
readValue' params (Variable handle) =
    pureOp [] (readOpDef <$> buildInputs handle)
  where
    -- Node definition for the read: the handle's outputs become the op
    -- inputs, "dtype" is set from the phantom type @a@, and the caller's
    -- parameters are applied last.
    readOpDef inputs =
        (params . withDtype . (opInputs .~ inputs)) (opDef "ReadVariableOp")
    withDtype = opAttr "dtype" .~ tensorType (undefined :: a)
-- | Build a node that overwrites the value stored in a variable.
--
-- This is 'assign'' with no extra op parameters.
assign :: (MonadBuild m, TensorType a)
       => Variable a -> Tensor v a -> m ControlNode
assign var newValue = assign' id var newValue
-- | Like 'assign', but forwards the given 'OpParams' to
-- 'CoreOps.assignVariableOp''.
assign' :: (MonadBuild m, TensorType a)
        => OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable handle) newValue =
    CoreOps.assignVariableOp' params handle newValue
-- | Build a node that adds the given tensor to a variable's stored value.
--
-- This is 'assignAdd'' with no extra op parameters.
assignAdd :: (MonadBuild m, TensorType a)
          => Variable a -> Tensor v a -> m ControlNode
assignAdd var delta = assignAdd' id var delta
-- | Like 'assignAdd', but forwards the given 'OpParams' to
-- 'CoreOps.assignAddVariableOp''.
assignAdd' :: (MonadBuild m, TensorType a)
           => OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable handle) delta =
    CoreOps.assignAddVariableOp' params handle delta

View File

@ -16,6 +16,7 @@ library
exposed-modules: TensorFlow.Gradient exposed-modules: TensorFlow.Gradient
, TensorFlow.Ops , TensorFlow.Ops
, TensorFlow.EmbeddingOps , TensorFlow.EmbeddingOps
, TensorFlow.Variable
build-depends: proto-lens == 0.2.* build-depends: proto-lens == 0.2.*
, base >= 4.7 && < 5 , base >= 4.7 && < 5
, bytestring , bytestring
@ -37,6 +38,8 @@ Test-Suite RegressionTest
hs-source-dirs: tests hs-source-dirs: tests
build-depends: base build-depends: base
, HUnit , HUnit
, lens-family
, transformers
, random , random
, tensorflow , tensorflow
, tensorflow-core-ops , tensorflow-core-ops
@ -126,6 +129,23 @@ Test-Suite OpsTest
, transformers , transformers
, vector , vector
-- Test suite whose entry point is tests/VariableTest.hs.
Test-Suite VariableTest
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: VariableTest.hs
hs-source-dirs: tests
build-depends: HUnit
, base
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, test-framework
, test-framework-hunit
, transformers
, vector
Test-Suite DataFlowOpsTest Test-Suite DataFlowOpsTest
default-language: Haskell2010 default-language: Haskell2010
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0

View File

@ -0,0 +1,62 @@
{-# LANGUAGE OverloadedLists #-}
module Main (main) where
import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector.Storable as V
import Google.Test (googleTest)
import TensorFlow.Core (unScalar, run_, runSession, run)
import qualified TensorFlow.Ops as Ops
import TensorFlow.Variable
( readValue
, initializedVariable
, assign
, assignAdd
)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
-- | Entry point: hands every test in this module to 'googleTest'.
main :: IO ()
main = googleTest allTests
  where
    allTests =
        [ testInitializedVariable
        , testInitializedVariableShape
        , testRereadRef
        , testAssignAdd
        ]
-- | Checks that an expression reading a variable sees its initial value,
-- and that running a previously built 'assign' node changes what a later
-- run of the same expression observes.
testInitializedVariable :: Test
testInitializedVariable = testCase "testInitializedVariable" . runSession $ do
    var <- initializedVariable 42
    reset <- assign var 24
    let formula = 1 + readValue var
    before <- run formula
    liftIO (43 @=? (unScalar before :: Float))
    -- Running the assignment updates the variable to a different value.
    run_ reset
    after <- run formula
    liftIO (25 @=? (unScalar after :: Float))
-- | Checks that a variable initialized from a one-element vector constant
-- reads back as that vector.
testInitializedVariableShape :: Test
testInitializedVariableShape =
    testCase "testInitializedVariableShape" . runSession $ do
        let initial = Ops.constant [1] [42 :: Float]
        var <- initializedVariable initial
        contents <- run (readValue var)
        liftIO ([42] @=? (contents :: V.Vector Float))
-- | Regression test for https://github.com/tensorflow/haskell/issues/92.
-- The first read result @f0@ is only inspected at the very end, after the
-- variable has been reassigned; it must still hold the earlier value.
--
-- NOTE(review): the 'testCase' label below ("testReRunAssign") differs
-- from this binding's name; it is kept unchanged here.
testRereadRef :: Test
testRereadRef = testCase "testReRunAssign" . runSession $ do
    w <- initializedVariable 0
    f0 <- run (readValue w)
    assign w (Ops.scalar (0.1 :: Float)) >>= run_
    f1 <- run (readValue w)
    liftIO ((0.0, 0.1) @=? (unScalar f0, unScalar f1))
-- | Checks that running an 'assignAdd' node increases the stored value by
-- the given amount.
testAssignAdd :: Test
testAssignAdd = testCase "testAssignAdd" . runSession $ do
    w <- initializedVariable 42
    assignAdd w 17 >>= run_
    total <- run (readValue w)
    liftIO ((42 + 17 :: Float) @=? unScalar total)

View File

@ -62,6 +62,7 @@ module TensorFlow.Build
) where ) where
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
import Control.Monad.Fix (MonadFix(..))
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
@ -192,7 +193,8 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
-- Used to manage build state internally as part of the @Session@ monad. -- Used to manage build state internally as part of the @Session@ monad.
newtype BuildT m a = BuildT (StateT GraphState m a) newtype BuildT m a = BuildT (StateT GraphState m a)
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
MonadState GraphState, MonadThrow, MonadCatch, MonadMask) MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
MonadFix)
-- | An action for building nodes in a TensorFlow graph. -- | An action for building nodes in a TensorFlow graph.
type Build = BuildT Identity type Build = BuildT Identity
@ -307,7 +309,6 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
nextUnique .= succ u nextUnique .= succ u
return $ nodeDef ^. op <> "_" <> Text.pack (show k) return $ nodeDef ^. op <> "_" <> Text.pack (show k)
-- | Turn an 'Output' into a string representation for the TensorFlow -- | Turn an 'Output' into a string representation for the TensorFlow
-- foreign APIs. -- foreign APIs.
encodeOutput :: Output -> Text encodeOutput :: Output -> Text

View File

@ -126,9 +126,6 @@ recordResult = do
put $! ResultState (i+1) ns put $! ResultState (i+1) ns
return $! output i o return $! output i o
instance BuildResult ResourceHandle where
buildResult = ResourceHandle <$> recordResult
instance Rendered v => BuildResult (Tensor v a) where instance Rendered v => BuildResult (Tensor v a) where
buildResult = Tensor . pure <$> recordResult buildResult = Tensor . pure <$> recordResult
@ -302,9 +299,6 @@ instance BuildInputs (ListOf (Tensor v) as) where
buildInputs Nil = return [] buildInputs Nil = return []
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts) buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
instance BuildInputs ResourceHandle where
buildInputs (ResourceHandle o) = return [o]
---- ----
-- | Parameters to build an op (for example, the node name or optional attributes). -- | Parameters to build an op (for example, the node name or optional attributes).

View File

@ -52,6 +52,7 @@ module TensorFlow.Core
, addGraphDef , addGraphDef
, opName , opName
, opAttr , opAttr
, addInitializer
-- * Tensor -- * Tensor
, ControlNode , ControlNode
, Tensor , Tensor
@ -64,6 +65,7 @@ module TensorFlow.Core
, TensorType , TensorType
, TensorData , TensorData
, TensorDataType(decodeTensorData, encodeTensorData) , TensorDataType(decodeTensorData, encodeTensorData)
, ResourceHandle
, Scalar(..) , Scalar(..)
, Shape(..) , Shape(..)
, OneOf , OneOf

View File

@ -33,7 +33,6 @@ module TensorFlow.Output
, Output(..) , Output(..)
, output , output
, PendingNodeName(..) , PendingNodeName(..)
, ResourceHandle(..)
) where ) where
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
@ -127,8 +126,3 @@ instance IsString Output where
-> Output (fromInteger ix) $ assigned n -> Output (fromInteger ix) $ assigned n
_ -> Output 0 $ assigned s _ -> Output 0 $ assigned s
where assigned = NodeName . Text.pack where assigned = NodeName . Text.pack
-- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables.
newtype ResourceHandle = ResourceHandle Output

View File

@ -39,6 +39,7 @@ module TensorFlow.Types
, protoShape , protoShape
, Attribute(..) , Attribute(..)
, DataType(..) , DataType(..)
, ResourceHandle
-- * Lists -- * Lists
, ListOf(..) , ListOf(..)
, List , List
@ -94,15 +95,18 @@ import Proto.Tensorflow.Core.Framework.AttrValue
, shape , shape
, tensor , tensor
) )
import Proto.Tensorflow.Core.Framework.ResourceHandle
(ResourceHandle)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor import Proto.Tensorflow.Core.Framework.Tensor as Tensor
( TensorProto(..) ( TensorProto(..)
, floatVal
, doubleVal
, intVal
, stringVal
, int64Val
, stringVal
, boolVal , boolVal
, doubleVal
, floatVal
, intVal
, int64Val
, resourceHandleVal
, stringVal
, stringVal
) )
import Proto.Tensorflow.Core.Framework.TensorShape import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..) ( TensorShapeProto(..)
@ -183,6 +187,10 @@ instance TensorType (Complex Double) where
tensorRefType _ = DT_COMPLEX128 tensorRefType _ = DT_COMPLEX128
tensorVal = error "TODO (Complex Double)" tensorVal = error "TODO (Complex Double)"
instance TensorType ResourceHandle where
tensorType _ = DT_RESOURCE
tensorRefType _ = DT_RESOURCE_REF
tensorVal = resourceHandleVal
-- | Tensor data with the correct memory layout for tensorflow. -- | Tensor data with the correct memory layout for tensorflow.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData } newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }