{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Variable
( Variable
, variable
, variable'
, readValue
, initializedValue
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, assign
, assign'
, assignAdd
, assignAdd'
, resourceApplyAdam
, resourceApplyAdam'
) where
import qualified Data.Complex
import qualified Data.Int
import qualified Data.Word
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
-- | A resource variable in the graph.
--
-- Pairs the handle output of the variable's @VarHandleOp@ node with an
-- optional initial value.
data Variable a = Variable
    { variableHandle :: Tensor Value ResourceHandle
      -- ^ The resource handle. The read and write operations in this module
      -- ('readValue', 'assign', 'assignAdd', 'resourceApplyAdam') act on
      -- the variable through this handle.
    , initializedValue :: Maybe (Tensor Value a)
      -- ^ The rendered initializer for variables created by
      -- 'initializedVariable'; 'Nothing' for variables created by
      -- 'variable'.
    }
-- | A 'Variable' renders to whatever its resource handle renders to.
instance Rendered Variable where
    renderedOutput var = renderedOutput (variableHandle var)
-- | Treating a 'Variable' as a tensor reads its current value via
-- 'readValue'.
instance ToTensor Variable where
    toTensor var = readValue var
-- | Creates a new resource variable of the given shape without registering
-- any initializer.
--
-- Same as 'variable'' with no extra 'OpParams'.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable shape = variable' id shape
-- | Like 'variable', but applies extra 'OpParams' to the underlying
-- @VarHandleOp@ node.
--
-- The node's @shared_name@ attribute is set to the node's own name
-- (presumably so every variable gets a distinct resource name). That
-- self-reference is why the body is written with recursive @do@.
-- The returned 'Variable' carries no initial value.
variable' :: forall m a . (MonadBuild m, TensorType a)
          => OpParams -> Shape -> m (Variable a)
variable' params s = build $ mdo
    handle <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ sharedName))
                                   (tensorType (undefined :: a)) s
    -- Lazily refers back to the handle produced just above.
    let sharedName = encodeUtf8 (unNodeName (tensorNodeName handle))
    pure (Variable handle Nothing)
-- | Creates a variable whose initializer assigns it the given tensor.
--
-- Same as 'initializedVariable'' with no extra 'OpParams'.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Variable a)
initializedVariable initializer = initializedVariable' id initializer
-- | Like 'initializedVariable', but applies extra 'OpParams' to the
-- underlying @VarHandleOp@ node.
--
-- Creates the variable handle and renders the initializer. It then
-- registers an @AssignVariableOp@ of that value with 'addInitializer'.
-- The rendered initializer is kept in the result's 'initializedValue'.
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
    -- Irrefutable single-constructor pattern: matching on 'Nothing' here
    -- would make the bind failable and demand a 'MonadFail' instance the
    -- signature does not provide ('variable'' never fills that field anyway).
    -- The annotation fixes the element type of the discarded 'Variable'.
    --
    -- NOTE(review): the handle is declared with @Shape []@ regardless of
    -- the initializer's shape; confirm the runtime does not rely on it.
    (Variable h _ :: Variable a) <- variable' params (Shape [])
    initializer' <- renderValue initializer
    i <- CoreOps.assignVariableOp h initializer'
    addInitializer =<< group i
    return (Variable h (Just initializer'))
-- | Creates a variable initialized to zeros of the given shape.
--
-- Same as 'zeroInitializedVariable'' with no extra 'OpParams'.
zeroInitializedVariable
    :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable shape = zeroInitializedVariable' id shape
-- | Like 'zeroInitializedVariable', but with extra 'OpParams'.
--
-- The shape is used only to build the all-zeros initializer tensor, which is
-- then handed to 'initializedVariable''.
zeroInitializedVariable'
    :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params shape = initializedVariable' params (zeros shape)
-- | Reads the variable's current value through a @ReadVariableOp@ node.
readValue :: TensorType a => Variable a -> Tensor Build a
readValue var = readValue' id var
-- | Like 'readValue', but applies extra 'OpParams' to the
-- @ReadVariableOp@ node.
--
-- The params are applied after the @dtype@ attribute and the inputs have
-- been set, so they see (and may replace) those settings.
readValue' :: forall a . TensorType a
           => OpParams -> Variable a -> Tensor Build a
readValue' params (Variable { variableHandle = handle }) = pureOp [] buildReadOp
  where
    buildReadOp = do
        handleInputs <- buildInputs handle
        let configure = params
                      . (opAttr "dtype" .~ tensorType (undefined :: a))
                      . (opInputs .~ handleInputs)
        pure (opDef "ReadVariableOp" & configure)
-- | Assigns a new value to the variable; same as 'assign'' with no extra
-- 'OpParams'.
assign :: (MonadBuild m, TensorType a)
       => Variable a -> Tensor v a -> m ControlNode
assign var value = assign' id var value
-- | Builds an @AssignVariableOp@ on the variable's handle, with extra
-- 'OpParams'.
assign' :: (MonadBuild m, TensorType a)
        => OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable { variableHandle = handle }) value =
    CoreOps.assignVariableOp' params handle value
-- | Adds a value to the variable in place; same as 'assignAdd'' with no
-- extra 'OpParams'.
assignAdd :: (MonadBuild m, TensorType a)
          => Variable a -> Tensor v a -> m ControlNode
assignAdd var delta = assignAdd' id var delta
-- | Builds an @AssignAddVariableOp@ on the variable's handle, with extra
-- 'OpParams'.
assignAdd' :: (MonadBuild m, TensorType a)
           => OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable { variableHandle = handle }) delta =
    CoreOps.assignAddVariableOp' params handle delta
-- | Builds a @ResourceApplyAdam@ op acting on the handles of the three
-- given variables.
--
-- Same as 'resourceApplyAdam'' with no extra 'OpParams'.
resourceApplyAdam ::
    ( MonadBuild m
    , OneOf '[ (Data.Complex.Complex Double)
             , (Data.Complex.Complex Float)
             , Data.Int.Int16
             , Data.Int.Int32
             , Data.Int.Int64
             , Data.Int.Int8
             , Data.Word.Word16
             , Data.Word.Word8
             , Double
             , Float
             ] t
    )
    => Variable t -- ^ __var__: Should be from a Variable().
    -> Variable t -- ^ __m__: Should be from a Variable().
    -> Variable t -- ^ __v__: Should be from a Variable().
    -> Tensor v1 t -- ^ __beta1_power__: Must be a scalar.
    -> Tensor v2 t -- ^ __beta2_power__: Must be a scalar.
    -> Tensor v3 t -- ^ __lr__: Scaling factor. Must be a scalar.
    -> Tensor v4 t -- ^ __beta1__: Momentum factor. Must be a scalar.
    -> Tensor v5 t -- ^ __beta2__: Momentum factor. Must be a scalar.
    -> Tensor v6 t -- ^ __epsilon__: Ridge term. Must be a scalar.
    -> Tensor v7 t -- ^ __grad__: The gradient.
    -> m (ControlNode)
resourceApplyAdam var mVar vVar = resourceApplyAdam' id var mVar vVar
-- | Like 'resourceApplyAdam', but applies extra 'OpParams' to the
-- generated @ResourceApplyAdam@ node.
--
-- Each 'Variable' argument is replaced by its resource handle. The scalar
-- hyper-parameters and the gradient are passed through unchanged.
resourceApplyAdam' ::
    ( MonadBuild m
    , OneOf '[ (Data.Complex.Complex Double)
             , (Data.Complex.Complex Float)
             , Data.Int.Int16
             , Data.Int.Int32
             , Data.Int.Int64
             , Data.Int.Int8
             , Data.Word.Word16
             , Data.Word.Word8
             , Double
             , Float
             ] t
    )
    => OpParams
    -> Variable t -- ^ __var__: Should be from a Variable().
    -> Variable t -- ^ __m__: Should be from a Variable().
    -> Variable t -- ^ __v__: Should be from a Variable().
    -> Tensor v1 t -- ^ __beta1_power__: Must be a scalar.
    -> Tensor v2 t -- ^ __beta2_power__: Must be a scalar.
    -> Tensor v3 t -- ^ __lr__: Scaling factor. Must be a scalar.
    -> Tensor v4 t -- ^ __beta1__: Momentum factor. Must be a scalar.
    -> Tensor v5 t -- ^ __beta2__: Momentum factor. Must be a scalar.
    -> Tensor v6 t -- ^ __epsilon__: Ridge term. Must be a scalar.
    -> Tensor v7 t -- ^ __grad__: The gradient.
    -> m (ControlNode)
resourceApplyAdam' params
                   (Variable { variableHandle = varHandle })
                   (Variable { variableHandle = mHandle })
                   (Variable { variableHandle = vHandle }) =
    CoreOps.resourceApplyAdam' params varHandle mHandle vHandle