mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add resource-based variable ops. (#98)
The main difference between these and the `Ref`-bases ops is the explicit `readValue` op. I'm not sure how this should interact with gradients and save/restore, so I'm keeping it as a separate module for now. Once we figure out the details, we can merge it into `TensorFlow.Ops` and replace all uses of the old `Ref`-based ops. (That would also fix #92.) Also replaces our special case newtype `ResourceHandle` to `Tensor Value ResourceHandle`, where `ResourceHandle` is the TF proto corresponding to `DT_RESOURCE`.
This commit is contained in:
parent
21b723d542
commit
42f4fc647e
10 changed files with 231 additions and 48 deletions
|
@ -33,15 +33,14 @@ where:
|
|||
(for example: 'TensorType' or 'OneOf').
|
||||
|
||||
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
||||
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
|
||||
of those types), and @a@ is either a concrete type or a (constrained) type
|
||||
variable.
|
||||
the form @Tensor Ref a@ or @Tensor v a@ (or a list of one of those types),
|
||||
and @a@ is either a concrete type or a (constrained) type variable.
|
||||
|
||||
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
|
||||
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
|
||||
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
|
||||
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
|
||||
it is either @ControlNode@ or @Build ControlNode@.)
|
||||
it takes a @Tensor Ref@ or @Tensor v ResourceHandle@ as input, or if it's
|
||||
explicitly marked \"Stateful\" in its @REGISTER_OP@ definition. (If there
|
||||
are no outputs, it is either @ControlNode@ or @Build ControlNode@.)
|
||||
-}
|
||||
|
||||
module TensorFlow.OpGen
|
||||
|
@ -155,7 +154,6 @@ imports = stack [
|
|||
, "import Lens.Family2 ((.~), (&))"
|
||||
, "import TensorFlow.Build"
|
||||
, "import TensorFlow.BuildOp"
|
||||
, "import TensorFlow.Output (ResourceHandle)"
|
||||
, "import TensorFlow.Tensor"
|
||||
, "import TensorFlow.Types"
|
||||
]
|
||||
|
@ -300,7 +298,7 @@ typeSig pre pOp = constraints
|
|||
| null classConstraints = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
|
||||
ArgSomeTensor v <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||
|
@ -333,13 +331,12 @@ typeSig pre pOp = constraints
|
|||
| otherwise = o
|
||||
|
||||
-- | Render an op input or output.
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
|
||||
-- For example: "Tensor Ref Int64", "Tensor v t"
|
||||
tensorArg :: ParsedArg -> Doc
|
||||
tensorArg p = case parsedArgCase p of
|
||||
ResourceArg -> "ResourceHandle"
|
||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
||||
SimpleArg { argType = t, argKind = k } -> tensorType t k
|
||||
ListArg { argType = t, argKind = k } -> brackets $ tensorType t k
|
||||
MixedListArg {argTypeAttr = t, argKind = k}
|
||||
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
|
||||
where
|
||||
kind k = case k of
|
||||
|
@ -420,8 +417,7 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
|||
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||
dtTypeToHaskell DT_RESOURCE =
|
||||
error "ResourceHandle must be prevented from getting here."
|
||||
dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
|
||||
dtTypeToHaskell x =
|
||||
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ module TensorFlow.OpGen.ParsedOp
|
|||
, ParsedArgCase(..)
|
||||
, ArgType(..)
|
||||
, ArgKind(..)
|
||||
, argKind
|
||||
, parseOp
|
||||
, camelCase
|
||||
) where
|
||||
|
@ -118,19 +117,18 @@ data ParsedArg = ParsedArg
|
|||
}
|
||||
|
||||
data ParsedArgCase
|
||||
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
|
||||
= SimpleArg { argType :: ArgType, argKind :: ArgKind }
|
||||
| ListArg
|
||||
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
||||
, argType :: ArgType
|
||||
, argCaseKind :: ArgKind
|
||||
, argKind :: ArgKind
|
||||
}
|
||||
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
||||
| MixedListArg { argTypeAttr :: Name, argKind :: ArgKind }
|
||||
-- ^ A heterogeneous list.
|
||||
| ResourceArg
|
||||
|
||||
argKind :: ParsedArgCase -> Maybe ArgKind
|
||||
argKind ResourceArg = Nothing
|
||||
argKind a = Just $ argCaseKind a
|
||||
maybeArgType :: ParsedArgCase -> Maybe ArgType
|
||||
maybeArgType MixedListArg{} = Nothing
|
||||
maybeArgType a = Just $ argType a
|
||||
|
||||
-- | The type of an argument.
|
||||
data ArgType
|
||||
|
@ -146,10 +144,10 @@ data ArgKind
|
|||
deriving (Eq)
|
||||
|
||||
isRefCase :: ParsedArgCase -> Bool
|
||||
isRefCase a = case argKind a of
|
||||
Nothing -> True -- Resource
|
||||
Just ArgTensorRef -> True
|
||||
_ -> False
|
||||
isRefCase a
|
||||
| ArgTensorRef <- argKind a = True
|
||||
| Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True
|
||||
| otherwise = False
|
||||
|
||||
makeName :: Text -> Name
|
||||
makeName n = Name
|
||||
|
@ -314,7 +312,6 @@ parseArg a tKind = ParsedArg
|
|||
|
||||
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
|
||||
parseArgCase a tKind
|
||||
| a ^. type' == DT_RESOURCE = ResourceArg
|
||||
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
|
||||
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
|
||||
| otherwise = SimpleArg thisArgType tKind
|
||||
|
|
109
tensorflow-ops/src/TensorFlow/Variable.hs
Normal file
109
tensorflow-ops/src/TensorFlow/Variable.hs
Normal file
|
@ -0,0 +1,109 @@
|
|||
-- | An implementation of ResourceHandle-based variables.
|
||||
--
|
||||
-- The main difference between this and 'Ref'-based variables is
|
||||
-- that reads are explicit, via the 'readValue' op.
|
||||
--
|
||||
-- TODO: given that distinction, figure out a good story around
|
||||
-- gradients and save/restore. Then, merge this module into
|
||||
-- TensorFlow.Ops.
|
||||
{-# LANGUAGE RecursiveDo #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
module TensorFlow.Variable
|
||||
( Variable
|
||||
, variable
|
||||
, variable'
|
||||
, readValue
|
||||
, initializedVariable
|
||||
, initializedVariable'
|
||||
, zeroInitializedVariable
|
||||
, zeroInitializedVariable'
|
||||
, assign
|
||||
, assign'
|
||||
, assignAdd
|
||||
, assignAdd'
|
||||
) where
|
||||
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import TensorFlow.Core
|
||||
import TensorFlow.Build (opDef)
|
||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||
import TensorFlow.Output (opInputs, unNodeName)
|
||||
import TensorFlow.Tensor (tensorNodeName)
|
||||
import TensorFlow.Types (tensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Ops (zeros)
|
||||
|
||||
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
||||
|
||||
-- | Creates a new, uninitialized variable.
|
||||
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
||||
variable = variable' id
|
||||
|
||||
variable' :: forall m a . (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Shape -> m (Variable a)
|
||||
variable' params s = build $ do
|
||||
-- Each variable needs a unique "shared_name". Use MonadFix to
|
||||
-- set the attribute to the same name as the variable itself, without
|
||||
-- exposing more internals of the Build module.
|
||||
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
||||
(tensorType (undefined :: a)) s
|
||||
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
||||
return $ Variable t
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: (MonadBuild m, TensorType a)
|
||||
=> Tensor v a -> m (Variable a)
|
||||
initializedVariable = initializedVariable' id
|
||||
|
||||
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Tensor v a -> m (Variable a)
|
||||
initializedVariable' params initializer = do
|
||||
-- The shape is not known initially.
|
||||
v@(Variable h) <- variable' params (Shape [])
|
||||
i <- CoreOps.assignVariableOp h initializer
|
||||
addInitializer =<< group i
|
||||
return v
|
||||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
:: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
|
||||
zeroInitializedVariable = zeroInitializedVariable' id
|
||||
|
||||
zeroInitializedVariable'
|
||||
:: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
|
||||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||
|
||||
-- | Gets the value stored in a variable.
|
||||
readValue :: TensorType a => Variable a -> Tensor Build a
|
||||
readValue = readValue' id
|
||||
|
||||
readValue' :: forall a . TensorType a
|
||||
=> OpParams -> Variable a -> Tensor Build a
|
||||
readValue' params (Variable h)
|
||||
= pureOp [] $ do
|
||||
os <- buildInputs h
|
||||
pure $ opDef "ReadVariableOp"
|
||||
& (params
|
||||
. (opAttr "dtype" .~ tensorType (undefined :: a))
|
||||
. (opInputs .~ os))
|
||||
|
||||
-- | Sets the value of a variable.
|
||||
assign :: (MonadBuild m, TensorType a)
|
||||
=> Variable a -> Tensor v a -> m ControlNode
|
||||
assign = assign' id
|
||||
|
||||
assign' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||
assign' params (Variable h) v = CoreOps.assignVariableOp' params h v
|
||||
|
||||
-- | Increments the value of a variable.
|
||||
assignAdd :: (MonadBuild m, TensorType a)
|
||||
=> Variable a -> Tensor v a -> m ControlNode
|
||||
assignAdd = assignAdd' id
|
||||
|
||||
assignAdd' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||
assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v
|
|
@ -16,6 +16,7 @@ library
|
|||
exposed-modules: TensorFlow.Gradient
|
||||
, TensorFlow.Ops
|
||||
, TensorFlow.EmbeddingOps
|
||||
, TensorFlow.Variable
|
||||
build-depends: proto-lens == 0.2.*
|
||||
, base >= 4.7 && < 5
|
||||
, bytestring
|
||||
|
@ -37,6 +38,8 @@ Test-Suite RegressionTest
|
|||
hs-source-dirs: tests
|
||||
build-depends: base
|
||||
, HUnit
|
||||
, lens-family
|
||||
, transformers
|
||||
, random
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
|
@ -126,6 +129,23 @@ Test-Suite OpsTest
|
|||
, transformers
|
||||
, vector
|
||||
|
||||
Test-Suite VariableTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: VariableTest.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: HUnit
|
||||
, base
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, test-framework
|
||||
, test-framework-hunit
|
||||
, transformers
|
||||
, vector
|
||||
|
||||
|
||||
Test-Suite DataFlowOpsTest
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
|
|
62
tensorflow-ops/tests/VariableTest.hs
Normal file
62
tensorflow-ops/tests/VariableTest.hs
Normal file
|
@ -0,0 +1,62 @@
|
|||
{-# LANGUAGE OverloadedLists #-}
|
||||
module Main (main) where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import qualified Data.Vector.Storable as V
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Core (unScalar, run_, runSession, run)
|
||||
import qualified TensorFlow.Ops as Ops
|
||||
import TensorFlow.Variable
|
||||
( readValue
|
||||
, initializedVariable
|
||||
, assign
|
||||
, assignAdd
|
||||
)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testInitializedVariable
|
||||
, testInitializedVariableShape
|
||||
, testRereadRef
|
||||
, testAssignAdd
|
||||
]
|
||||
|
||||
testInitializedVariable :: Test
|
||||
testInitializedVariable =
|
||||
testCase "testInitializedVariable" $ runSession $ do
|
||||
(formula, reset) <- do
|
||||
v <- initializedVariable 42
|
||||
r <- assign v 24
|
||||
return (1 + readValue v, r)
|
||||
result <- run formula
|
||||
liftIO $ 43 @=? (unScalar result :: Float)
|
||||
run_ reset -- Updates v to a different value
|
||||
rerunResult <- run formula
|
||||
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||
|
||||
testInitializedVariableShape :: Test
|
||||
testInitializedVariableShape =
|
||||
testCase "testInitializedVariableShape" $ runSession $ do
|
||||
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
|
||||
result <- run (readValue vector)
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
-- | See https://github.com/tensorflow/haskell/issues/92.
|
||||
-- Even though we're not explicitly evaluating `f0` until the end,
|
||||
-- it should hold the earlier value of the variable.
|
||||
testRereadRef :: Test
|
||||
testRereadRef = testCase "testReRunAssign" $ runSession $ do
|
||||
w <- initializedVariable 0
|
||||
f0 <- run (readValue w)
|
||||
run_ =<< assign w (Ops.scalar (0.1 :: Float))
|
||||
f1 <- run (readValue w)
|
||||
liftIO $ (0.0, 0.1) @=? (unScalar f0, unScalar f1)
|
||||
|
||||
testAssignAdd :: Test
|
||||
testAssignAdd = testCase "testAssignAdd" $ runSession $ do
|
||||
w <- initializedVariable 42
|
||||
run_ =<< assignAdd w 17
|
||||
f1 <- run (readValue w)
|
||||
liftIO $ (42 + 17 :: Float) @=? unScalar f1
|
|
@ -62,6 +62,7 @@ module TensorFlow.Build
|
|||
) where
|
||||
|
||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||
import Control.Monad.Fix (MonadFix(..))
|
||||
import Control.Monad.IO.Class (MonadIO(..))
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||
|
@ -192,7 +193,8 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
|||
-- Used to manage build state internally as part of the @Session@ monad.
|
||||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask)
|
||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
|
||||
MonadFix)
|
||||
|
||||
-- | An action for building nodes in a TensorFlow graph.
|
||||
type Build = BuildT Identity
|
||||
|
@ -307,7 +309,6 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
|
|||
nextUnique .= succ u
|
||||
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
||||
|
||||
|
||||
-- | Turn an 'Output' into a string representation for the TensorFlow
|
||||
-- foreign APIs.
|
||||
encodeOutput :: Output -> Text
|
||||
|
|
|
@ -126,9 +126,6 @@ recordResult = do
|
|||
put $! ResultState (i+1) ns
|
||||
return $! output i o
|
||||
|
||||
instance BuildResult ResourceHandle where
|
||||
buildResult = ResourceHandle <$> recordResult
|
||||
|
||||
instance Rendered v => BuildResult (Tensor v a) where
|
||||
buildResult = Tensor . pure <$> recordResult
|
||||
|
||||
|
@ -302,9 +299,6 @@ instance BuildInputs (ListOf (Tensor v) as) where
|
|||
buildInputs Nil = return []
|
||||
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
|
||||
|
||||
instance BuildInputs ResourceHandle where
|
||||
buildInputs (ResourceHandle o) = return [o]
|
||||
|
||||
----
|
||||
|
||||
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||
|
|
|
@ -52,6 +52,7 @@ module TensorFlow.Core
|
|||
, addGraphDef
|
||||
, opName
|
||||
, opAttr
|
||||
, addInitializer
|
||||
-- * Tensor
|
||||
, ControlNode
|
||||
, Tensor
|
||||
|
@ -64,6 +65,7 @@ module TensorFlow.Core
|
|||
, TensorType
|
||||
, TensorData
|
||||
, TensorDataType(decodeTensorData, encodeTensorData)
|
||||
, ResourceHandle
|
||||
, Scalar(..)
|
||||
, Shape(..)
|
||||
, OneOf
|
||||
|
|
|
@ -33,7 +33,6 @@ module TensorFlow.Output
|
|||
, Output(..)
|
||||
, output
|
||||
, PendingNodeName(..)
|
||||
, ResourceHandle(..)
|
||||
) where
|
||||
|
||||
import qualified Data.Map.Strict as Map
|
||||
|
@ -127,8 +126,3 @@ instance IsString Output where
|
|||
-> Output (fromInteger ix) $ assigned n
|
||||
_ -> Output 0 $ assigned s
|
||||
where assigned = NodeName . Text.pack
|
||||
|
||||
|
||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||
-- resources are variables.
|
||||
newtype ResourceHandle = ResourceHandle Output
|
||||
|
|
|
@ -39,6 +39,7 @@ module TensorFlow.Types
|
|||
, protoShape
|
||||
, Attribute(..)
|
||||
, DataType(..)
|
||||
, ResourceHandle
|
||||
-- * Lists
|
||||
, ListOf(..)
|
||||
, List
|
||||
|
@ -94,15 +95,18 @@ import Proto.Tensorflow.Core.Framework.AttrValue
|
|||
, shape
|
||||
, tensor
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||
(ResourceHandle)
|
||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||
( TensorProto(..)
|
||||
, floatVal
|
||||
, doubleVal
|
||||
, intVal
|
||||
, stringVal
|
||||
, int64Val
|
||||
, stringVal
|
||||
, boolVal
|
||||
, doubleVal
|
||||
, floatVal
|
||||
, intVal
|
||||
, int64Val
|
||||
, resourceHandleVal
|
||||
, stringVal
|
||||
, stringVal
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||
( TensorShapeProto(..)
|
||||
|
@ -183,6 +187,10 @@ instance TensorType (Complex Double) where
|
|||
tensorRefType _ = DT_COMPLEX128
|
||||
tensorVal = error "TODO (Complex Double)"
|
||||
|
||||
instance TensorType ResourceHandle where
|
||||
tensorType _ = DT_RESOURCE
|
||||
tensorRefType _ = DT_RESOURCE_REF
|
||||
tensorVal = resourceHandleVal
|
||||
|
||||
-- | Tensor data with the correct memory layout for tensorflow.
|
||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||
|
|
Loading…
Reference in a new issue