diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index 6717b47..41bbab9 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -33,15 +33,14 @@ where: (for example: 'TensorType' or 'OneOf'). * @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of -the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one -of those types), and @a@ is either a concrete type or a (constrained) type -variable. +the form @Tensor Ref a@ or @Tensor v a@ (or a list of one of those types), +and @a@ is either a concrete type or a (constrained) type variable. * @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and @Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if -it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly -marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs, -it is either @ControlNode@ or @Build ControlNode@.) +it takes a @Tensor Ref@ or @Tensor v ResourceHandle@ as input, or if it's +explicitly marked \"Stateful\" in its @REGISTER_OP@ definition. (If there +are no outputs, it is either @ControlNode@ or @Build ControlNode@.) -} module TensorFlow.OpGen @@ -155,7 +154,6 @@ imports = stack [ , "import Lens.Family2 ((.~), (&))" , "import TensorFlow.Build" , "import TensorFlow.BuildOp" - , "import TensorFlow.Output (ResourceHandle)" , "import TensorFlow.Tensor" , "import TensorFlow.Types" ] @@ -300,7 +298,7 @@ typeSig pre pOp = constraints | null classConstraints = empty | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>" typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, - Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]] + ArgSomeTensor v <- [argKind $ parsedArgCase k]] ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] ++ if parsedOpIsMonadic pOp then ["m'"] else [] -- Use m' as the type parameter to avoid clashing with an attribute name. @@ -333,13 +331,12 @@ typeSig pre pOp = constraints | otherwise = o -- | Render an op input or output. --- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle" +-- For example: "Tensor Ref Int64", "Tensor v t" tensorArg :: ParsedArg -> Doc tensorArg p = case parsedArgCase p of - ResourceArg -> "ResourceHandle" - SimpleArg { argType = t, argCaseKind = k } -> tensorType t k - ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k - MixedListArg {argTypeAttr = t, argCaseKind = k} + SimpleArg { argType = t, argKind = k } -> tensorType t k + ListArg { argType = t, argKind = k } -> brackets $ tensorType t k + MixedListArg {argTypeAttr = t, argKind = k} -> "TensorList" <+> parens (kind k) <+> renderHaskellName t where kind k = case k of @@ -420,8 +417,7 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString" dtTypeToHaskell DT_UINT16 = "Data.Word.Word16" dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique dtTypeToHaskell DT_UINT8 = "Data.Word.Word8" -dtTypeToHaskell DT_RESOURCE = - error "ResourceHandle must be prevented from getting here." +dtTypeToHaskell DT_RESOURCE = "ResourceHandle" dtTypeToHaskell x = Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index 2fbe4e4..63c363f 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -17,7 +17,6 @@ module TensorFlow.OpGen.ParsedOp , ParsedArgCase(..) , ArgType(..) , ArgKind(..) - , argKind , parseOp , camelCase ) where @@ -118,19 +117,18 @@ data ParsedArg = ParsedArg } data ParsedArgCase - = SimpleArg { argType :: ArgType, argCaseKind :: ArgKind } + = SimpleArg { argType :: ArgType, argKind :: ArgKind } | ListArg { argLength :: Name -- ^ The attribute that specifies this list's length. , argType :: ArgType - , argCaseKind :: ArgKind + , argKind :: ArgKind } - | MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind } + | MixedListArg { argTypeAttr :: Name, argKind :: ArgKind } -- ^ A heterogeneous list. - | ResourceArg -argKind :: ParsedArgCase -> Maybe ArgKind -argKind ResourceArg = Nothing -argKind a = Just $ argCaseKind a +maybeArgType :: ParsedArgCase -> Maybe ArgType +maybeArgType MixedListArg{} = Nothing +maybeArgType a = Just $ argType a -- | The type of an argument. data ArgType @@ -146,10 +144,10 @@ data ArgKind deriving (Eq) isRefCase :: ParsedArgCase -> Bool -isRefCase a = case argKind a of - Nothing -> True -- Resource - Just ArgTensorRef -> True - _ -> False +isRefCase a + | ArgTensorRef <- argKind a = True + | Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True + | otherwise = False makeName :: Text -> Name makeName n = Name @@ -314,7 +312,6 @@ parseArg a tKind = ParsedArg parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase parseArgCase a tKind - | a ^. type' == DT_RESOURCE = ResourceArg | Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind | Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind | otherwise = SimpleArg thisArgType tKind diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs new file mode 100644 index 0000000..b62082e --- /dev/null +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -0,0 +1,109 @@ +-- | An implementation of ResourceHandle-based variables. +-- +-- The main difference between this and 'Ref'-based variables is +-- that reads are explicit, via the 'readValue' op. +-- +-- TODO: given that distinction, figure out a good story around +-- gradients and save/restore. Then, merge this module into +-- TensorFlow.Ops. +{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE OverloadedStrings #-} +module TensorFlow.Variable + ( Variable + , variable + , variable' + , readValue + , initializedVariable + , initializedVariable' + , zeroInitializedVariable + , zeroInitializedVariable' + , assign + , assign' + , assignAdd + , assignAdd' + ) where + +import Data.Text.Encoding (encodeUtf8) +import Lens.Family2 ((.~), (&)) +import TensorFlow.Core +import TensorFlow.Build (opDef) +import TensorFlow.BuildOp (buildInputs, pureOp, OpParams) +import TensorFlow.Output (opInputs, unNodeName) +import TensorFlow.Tensor (tensorNodeName) +import TensorFlow.Types (tensorType) +import qualified TensorFlow.GenOps.Core as CoreOps +import TensorFlow.Ops (zeros) + +newtype Variable a = Variable (Tensor Value ResourceHandle) + +-- | Creates a new, uninitialized variable. +variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a) +variable = variable' id + +variable' :: forall m a . (MonadBuild m, TensorType a) + => OpParams -> Shape -> m (Variable a) +variable' params s = build $ do + -- Each variable needs a unique "shared_name". Use MonadFix to + -- set the attribute to the same name as the variable itself, without + -- exposing more internals of the Build module. + rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n)) + (tensorType (undefined :: a)) s + let n = encodeUtf8 $ unNodeName $ tensorNodeName t + return $ Variable t + +-- | Creates a variable initialized to the given value. +-- Initialization happens next time session runs. +initializedVariable :: (MonadBuild m, TensorType a) + => Tensor v a -> m (Variable a) +initializedVariable = initializedVariable' id + +initializedVariable' :: forall a m v . (MonadBuild m, TensorType a) + => OpParams -> Tensor v a -> m (Variable a) +initializedVariable' params initializer = do + -- The shape is not known initially. + v@(Variable h) <- variable' params (Shape []) + i <- CoreOps.assignVariableOp h initializer + addInitializer =<< group i + return v + +-- | Creates a zero-initialized variable with the given shape. +zeroInitializedVariable + :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a) +zeroInitializedVariable = zeroInitializedVariable' id + +zeroInitializedVariable' + :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a) +zeroInitializedVariable' params = initializedVariable' params . zeros + +-- | Gets the value stored in a variable. +readValue :: TensorType a => Variable a -> Tensor Build a +readValue = readValue' id + +readValue' :: forall a . TensorType a + => OpParams -> Variable a -> Tensor Build a +readValue' params (Variable h) + = pureOp [] $ do + os <- buildInputs h + pure $ opDef "ReadVariableOp" + & (params + . (opAttr "dtype" .~ tensorType (undefined :: a)) + . (opInputs .~ os)) + +-- | Sets the value of a variable. +assign :: (MonadBuild m, TensorType a) + => Variable a -> Tensor v a -> m ControlNode +assign = assign' id + +assign' :: (MonadBuild m, TensorType a) + => OpParams -> Variable a -> Tensor v a -> m ControlNode +assign' params (Variable h) v = CoreOps.assignVariableOp' params h v + +-- | Increments the value of a variable. +assignAdd :: (MonadBuild m, TensorType a) + => Variable a -> Tensor v a -> m ControlNode +assignAdd = assignAdd' id + +assignAdd' :: (MonadBuild m, TensorType a) + => OpParams -> Variable a -> Tensor v a -> m ControlNode +assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index 40286a7..6f0800f 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -16,6 +16,7 @@ library exposed-modules: TensorFlow.Gradient , TensorFlow.Ops , TensorFlow.EmbeddingOps + , TensorFlow.Variable build-depends: proto-lens == 0.2.* , base >= 4.7 && < 5 , bytestring @@ -37,6 +38,8 @@ Test-Suite RegressionTest hs-source-dirs: tests build-depends: base , HUnit + , lens-family + , transformers , random , tensorflow , tensorflow-core-ops @@ -126,6 +129,23 @@ Test-Suite OpsTest , transformers , vector +Test-Suite VariableTest + default-language: Haskell2010 + type: exitcode-stdio-1.0 + main-is: VariableTest.hs + hs-source-dirs: tests + build-depends: HUnit + , base + , google-shim + , tensorflow + , tensorflow-core-ops + , tensorflow-ops + , test-framework + , test-framework-hunit + , transformers + , vector + + Test-Suite DataFlowOpsTest default-language: Haskell2010 type: exitcode-stdio-1.0 diff --git a/tensorflow-ops/tests/VariableTest.hs b/tensorflow-ops/tests/VariableTest.hs new file mode 100644 index 0000000..eccc76a --- /dev/null +++ b/tensorflow-ops/tests/VariableTest.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE OverloadedLists #-} +module Main (main) where + +import Control.Monad.IO.Class (liftIO) +import qualified Data.Vector.Storable as V +import Google.Test (googleTest) +import TensorFlow.Core (unScalar, run_, runSession, run) +import qualified TensorFlow.Ops as Ops +import TensorFlow.Variable + ( readValue + , initializedVariable + , assign + , assignAdd + ) +import Test.Framework (Test) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@=?)) + +main :: IO () +main = googleTest [ testInitializedVariable + , testInitializedVariableShape + , testRereadRef + , testAssignAdd + ] + +testInitializedVariable :: Test +testInitializedVariable = + testCase "testInitializedVariable" $ runSession $ do + (formula, reset) <- do + v <- initializedVariable 42 + r <- assign v 24 + return (1 + readValue v, r) + result <- run formula + liftIO $ 43 @=? (unScalar result :: Float) + run_ reset -- Updates v to a different value + rerunResult <- run formula + liftIO $ 25 @=? (unScalar rerunResult :: Float) + +testInitializedVariableShape :: Test +testInitializedVariableShape = + testCase "testInitializedVariableShape" $ runSession $ do + vector <- initializedVariable (Ops.constant [1] [42 :: Float]) + result <- run (readValue vector) + liftIO $ [42] @=? (result :: V.Vector Float) + +-- | See https://github.com/tensorflow/haskell/issues/92. +-- Even though we're not explicitly evaluating `f0` until the end, +-- it should hold the earlier value of the variable. +testRereadRef :: Test +testRereadRef = testCase "testReRunAssign" $ runSession $ do + w <- initializedVariable 0 + f0 <- run (readValue w) + run_ =<< assign w (Ops.scalar (0.1 :: Float)) + f1 <- run (readValue w) + liftIO $ (0.0, 0.1) @=? (unScalar f0, unScalar f1) + +testAssignAdd :: Test +testAssignAdd = testCase "testAssignAdd" $ runSession $ do + w <- initializedVariable 42 + run_ =<< assignAdd w 17 + f1 <- run (readValue w) + liftIO $ (42 + 17 :: Float) @=? unScalar f1 diff --git a/tensorflow/src/TensorFlow/Build.hs b/tensorflow/src/TensorFlow/Build.hs index 84c6c09..7f549a3 100644 --- a/tensorflow/src/TensorFlow/Build.hs +++ b/tensorflow/src/TensorFlow/Build.hs @@ -62,6 +62,7 @@ module TensorFlow.Build ) where import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask) +import Control.Monad.Fix (MonadFix(..)) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT) @@ -192,7 +193,8 @@ summaries = lens _summaries (\g x -> g { _summaries = x }) -- Used to manage build state internally as part of the @Session@ monad. newtype BuildT m a = BuildT (StateT GraphState m a) deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, - MonadState GraphState, MonadThrow, MonadCatch, MonadMask) + MonadState GraphState, MonadThrow, MonadCatch, MonadMask, + MonadFix) -- | An action for building nodes in a TensorFlow graph. type Build = BuildT Identity @@ -307,7 +309,6 @@ renderPendingNode (PendingNode scope pendingName nodeDef) nextUnique .= succ u return $ nodeDef ^. op <> "_" <> Text.pack (show k) - -- | Turn an 'Output' into a string representation for the TensorFlow -- foreign APIs. encodeOutput :: Output -> Text diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 9b074ce..50019f7 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -126,9 +126,6 @@ recordResult = do put $! ResultState (i+1) ns return $! output i o -instance BuildResult ResourceHandle where - buildResult = ResourceHandle <$> recordResult - instance Rendered v => BuildResult (Tensor v a) where buildResult = Tensor . pure <$> recordResult @@ -302,9 +299,6 @@ instance BuildInputs (ListOf (Tensor v) as) where buildInputs Nil = return [] buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts) -instance BuildInputs ResourceHandle where - buildInputs (ResourceHandle o) = return [o] - ---- -- | Parameters to build an op (for example, the node name or optional attributes). diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 7834211..6a36c2a 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -52,6 +52,7 @@ module TensorFlow.Core , addGraphDef , opName , opAttr + , addInitializer -- * Tensor , ControlNode , Tensor @@ -64,6 +65,7 @@ module TensorFlow.Core , TensorType , TensorData , TensorDataType(decodeTensorData, encodeTensorData) + , ResourceHandle , Scalar(..) , Shape(..) , OneOf diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index ef6b489..2114dbc 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -33,7 +33,6 @@ module TensorFlow.Output , Output(..) , output , PendingNodeName(..) - , ResourceHandle(..) ) where import qualified Data.Map.Strict as Map @@ -127,8 +126,3 @@ instance IsString Output where -> Output (fromInteger ix) $ assigned n _ -> Output 0 $ assigned s where assigned = NodeName . Text.pack - - --- | Opaque handle to a mutable resource in the graph. Typical such --- resources are variables. -newtype ResourceHandle = ResourceHandle Output diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index 4ed8063..689dad3 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -39,6 +39,7 @@ module TensorFlow.Types , protoShape , Attribute(..) , DataType(..) + , ResourceHandle -- * Lists , ListOf(..) , List @@ -94,15 +95,18 @@ import Proto.Tensorflow.Core.Framework.AttrValue , shape , tensor ) +import Proto.Tensorflow.Core.Framework.ResourceHandle + (ResourceHandle) import Proto.Tensorflow.Core.Framework.Tensor as Tensor ( TensorProto(..) - , floatVal - , doubleVal - , intVal - , stringVal - , int64Val - , stringVal , boolVal + , doubleVal + , floatVal + , intVal + , int64Val + , resourceHandleVal + , stringVal + , stringVal ) import Proto.Tensorflow.Core.Framework.TensorShape ( TensorShapeProto(..) @@ -183,6 +187,10 @@ instance TensorType (Complex Double) where tensorRefType _ = DT_COMPLEX128 tensorVal = error "TODO (Complex Double)" +instance TensorType ResourceHandle where + tensorType _ = DT_RESOURCE + tensorRefType _ = DT_RESOURCE_REF + tensorVal = resourceHandleVal -- | Tensor data with the correct memory layout for tensorflow. newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }