110 lines
3.9 KiB
Haskell
110 lines
3.9 KiB
Haskell
|
-- | An implementation of ResourceHandle-based variables.
--
-- The main difference between this and 'Ref'-based variables is
-- that reads are explicit, via the 'readValue' op.
--
-- TODO: given that distinction, work out a good story for
-- gradients and save/restore. Then merge this module into
-- "TensorFlow.Ops".
|
||
|
{-# LANGUAGE RecursiveDo #-}
|
||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||
|
{-# LANGUAGE OverloadedStrings #-}
|
||
|
module TensorFlow.Variable
|
||
|
( Variable
|
||
|
, variable
|
||
|
, variable'
|
||
|
, readValue
|
||
|
, initializedVariable
|
||
|
, initializedVariable'
|
||
|
, zeroInitializedVariable
|
||
|
, zeroInitializedVariable'
|
||
|
, assign
|
||
|
, assign'
|
||
|
, assignAdd
|
||
|
, assignAdd'
|
||
|
) where
|
||
|
|
||
|
import Data.Text.Encoding (encodeUtf8)
|
||
|
import Lens.Family2 ((.~), (&))
|
||
|
import TensorFlow.Core
|
||
|
import TensorFlow.Build (opDef)
|
||
|
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||
|
import TensorFlow.Output (opInputs, unNodeName)
|
||
|
import TensorFlow.Tensor (tensorNodeName)
|
||
|
import TensorFlow.Types (tensorType)
|
||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||
|
import TensorFlow.Ops (zeros)
|
||
|
|
||
|
-- | A resource-handle-based variable whose stored values have element
-- type @a@.
--
-- The type parameter is phantom: the wrapped tensor is an untyped
-- 'ResourceHandle', and @a@ only records the element type so that
-- 'readValue', 'assign' and 'assignAdd' can be given precise types.
newtype Variable a =
    Variable (Tensor Value ResourceHandle)
|
||
|
|
||
|
-- | Creates a new, uninitialized variable with the given shape.
--
-- This is 'variable'' with no extra op parameters.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable varShape = variable' id varShape
|
||
|
|
||
|
-- | Like 'variable', but applies the given 'OpParams' to the underlying
-- variable-handle op (built with @CoreOps.varHandleOp'@).
variable' :: forall m a . (MonadBuild m, TensorType a)
          => OpParams -> Shape -> m (Variable a)
variable' params s = build $ do
    -- Every variable needs a unique "shared_name". RecursiveDo ties the
    -- knot: the attribute is set to the node name of the very op being
    -- created, without exposing more internals of the Build module.
    -- The shared_name setter runs before 'params' in the composition, so
    -- a "shared_name" supplied through 'params' takes precedence.
    rec handle <- CoreOps.varHandleOp'
                      (params . (opAttr "shared_name" .~ sharedName))
                      (tensorType (undefined :: a))
                      s
        let sharedName = encodeUtf8 (unNodeName (tensorNodeName handle))
    pure (Variable handle)
|
||
|
|
||
|
-- | Creates a variable initialized to the given value.
--
-- Initialization happens the next time the session runs. This is
-- 'initializedVariable'' with no extra op parameters.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Variable a)
initializedVariable initialValue = initializedVariable' id initialValue
|
||
|
|
||
|
-- | Like 'initializedVariable', but applies the given 'OpParams' to the
-- underlying variable op.
--
-- The assignment of @initializer@ is grouped into a control node and
-- registered with 'addInitializer'; the variable is returned unread.
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
    -- The shape is not known initially.
    -- NOTE(review): 'Shape []' looks like it declares a rank-0 (scalar)
    -- shape rather than an unknown one, which contradicts the comment
    -- above for non-scalar initializers; confirm how the handle op's
    -- "shape" attribute is encoded.
    var@(Variable handle) <- variable' params (Shape [])
    addInitializer =<< group =<< CoreOps.assignVariableOp handle initializer
    pure var
|
||
|
|
||
|
-- | Creates a zero-initialized variable with the given shape.
--
-- This is 'zeroInitializedVariable'' with no extra op parameters.
zeroInitializedVariable
    :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable varShape = zeroInitializedVariable' id varShape
|
||
|
|
||
|
-- | Like 'zeroInitializedVariable', but applies the given 'OpParams' to
-- the underlying variable op.
--
-- The initializer is a 'zeros' tensor of the requested shape.
zeroInitializedVariable'
    :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params varShape =
    initializedVariable' params (zeros varShape)
|
||
|
|
||
|
-- | Gets the value stored in a variable.
--
-- Reads are explicit: the result is a @ReadVariableOp@ tensor built by
-- 'readValue'' with no extra op parameters.
readValue :: TensorType a => Variable a -> Tensor Build a
readValue var = readValue' id var
|
||
|
|
||
|
-- | Like 'readValue', but applies the given 'OpParams' to the
-- @ReadVariableOp@ node.
--
-- The op is constructed with 'pureOp', so the result is a 'Build' tensor
-- rather than a monadic action.
readValue' :: forall a . TensorType a
           => OpParams -> Variable a -> Tensor Build a
readValue' params (Variable handle) = pureOp [] (readOp <$> buildInputs handle)
  where
    -- 'params' is applied last, so it can override the "dtype" attribute
    -- and the inputs set here.
    readOp inputs =
        params
            . (opAttr "dtype" .~ tensorType (undefined :: a))
            . (opInputs .~ inputs)
            $ opDef "ReadVariableOp"
|
||
|
|
||
|
-- | Sets the value of a variable.
--
-- Returns the 'ControlNode' of the assignment op. This is 'assign'' with
-- no extra op parameters.
assign :: (MonadBuild m, TensorType a)
       => Variable a -> Tensor v a -> m ControlNode
assign var newValue = assign' id var newValue
|
||
|
|
||
|
-- | Like 'assign', but applies the given 'OpParams' to the op built by
-- @CoreOps.assignVariableOp'@.
assign' :: (MonadBuild m, TensorType a)
        => OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable handle) = CoreOps.assignVariableOp' params handle
|
||
|
|
||
|
-- | Increments the value of a variable by the given tensor.
--
-- Returns the 'ControlNode' of the update op. This is 'assignAdd'' with
-- no extra op parameters.
assignAdd :: (MonadBuild m, TensorType a)
          => Variable a -> Tensor v a -> m ControlNode
assignAdd var delta = assignAdd' id var delta
|
||
|
|
||
|
-- | Like 'assignAdd', but applies the given 'OpParams' to the op built by
-- @CoreOps.assignAddVariableOp'@.
assignAdd' :: (MonadBuild m, TensorType a)
           => OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable handle) = CoreOps.assignAddVariableOp' params handle
|