Safe Haskell | None |
---|---|
Language | Haskell2010 |
- data Variable a
- variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
- variable' :: forall m a. (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Variable a)
- readValue :: TensorType a => Variable a -> Tensor Build a
- initializedValue :: Variable a -> Maybe (Tensor Value a)
- initializedVariable :: (MonadBuild m, TensorType a) => Tensor v a -> m (Variable a)
- initializedVariable' :: forall a m v. (MonadBuild m, TensorType a) => OpParams -> Tensor v a -> m (Variable a)
- zeroInitializedVariable :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
- zeroInitializedVariable' :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
- assign :: (MonadBuild m, TensorType a) => Variable a -> Tensor v a -> m ControlNode
- assign' :: (MonadBuild m, TensorType a) => OpParams -> Variable a -> Tensor v a -> m ControlNode
- assignAdd :: (MonadBuild m, TensorType a) => Variable a -> Tensor v a -> m ControlNode
- assignAdd' :: (MonadBuild m, TensorType a) => OpParams -> Variable a -> Tensor v a -> m ControlNode
- resourceApplyAdam :: (MonadBuild m, OneOf '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) => Variable t -> Variable t -> Variable t -> Tensor v1 t -> Tensor v2 t -> Tensor v3 t -> Tensor v4 t -> Tensor v5 t -> Tensor v6 t -> Tensor v7 t -> m ControlNode
- resourceApplyAdam' :: (MonadBuild m, OneOf '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) => OpParams -> Variable t -> Variable t -> Variable t -> Tensor v1 t -> Tensor v2 t -> Tensor v3 t -> Tensor v4 t -> Tensor v5 t -> Tensor v6 t -> Tensor v7 t -> m ControlNode
Documentation
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a) Source #
Creates a new, uninitialized variable.
variable' :: forall m a. (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Variable a) Source #
readValue :: TensorType a => Variable a -> Tensor Build a Source #
Gets the value stored in a variable.
Note that this op is stateful, since it depends on the current value of the variable. However, it may be common-subexpression-eliminated (CSE'd) with other reads in the same context. The context can be fixed by using `render` along with, for example, `withControlDependencies`.
For example:
```haskell
runSession $ do
    v <- variable []
    a <- assign v 24
    r <- withControlDependencies a $ render $ readValue v + 18
    result <- run r
    liftIO $ (42 :: Float) @=? unScalar result
```
initializedValue :: Variable a -> Maybe (Tensor Value a) Source #
The initial value of a `Variable` created with `initializedVariable`.
initializedVariable :: (MonadBuild m, TensorType a) => Tensor v a -> m (Variable a) Source #
Creates a variable initialized to the given value. Initialization happens the next time the session runs.
initializedVariable' :: forall a m v. (MonadBuild m, TensorType a) => OpParams -> Tensor v a -> m (Variable a) Source #
zeroInitializedVariable :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a) Source #
Creates a zero-initialized variable with the given shape.
zeroInitializedVariable' :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a) Source #
assign :: (MonadBuild m, TensorType a) => Variable a -> Tensor v a -> m ControlNode Source #
Sets the value of a variable.
assign' :: (MonadBuild m, TensorType a) => OpParams -> Variable a -> Tensor v a -> m ControlNode Source #
assignAdd :: (MonadBuild m, TensorType a) => Variable a -> Tensor v a -> m ControlNode Source #
Increments the value of a variable.
assignAdd' :: (MonadBuild m, TensorType a) => OpParams -> Variable a -> Tensor v a -> m ControlNode Source #
resourceApplyAdam :: (MonadBuild m, OneOf '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) | |
=> Variable t | var: Should be from a Variable(). |
-> Variable t | m: Should be from a Variable(). |
-> Variable t | v: Should be from a Variable(). |
-> Tensor v1 t | beta1_power: Must be a scalar. |
-> Tensor v2 t | beta2_power: Must be a scalar. |
-> Tensor v3 t | lr: Scaling factor. Must be a scalar. |
-> Tensor v4 t | beta1: Momentum factor. Must be a scalar. |
-> Tensor v5 t | beta2: Momentum factor. Must be a scalar. |
-> Tensor v6 t | epsilon: Ridge term. Must be a scalar. |
-> Tensor v7 t | grad: The gradient. |
-> m ControlNode |
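Updates `var` according to the Adam algorithm:

$$
\begin{aligned}
lr_t &\leftarrow \text{learning\_rate} \cdot \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}} \\
m_t &\leftarrow \beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t \\
v_t &\leftarrow \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2 \\
\text{variable} &\leftarrow \text{variable} - lr_t \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}
\end{aligned}
$$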
resourceApplyAdam' :: (MonadBuild m, OneOf '[Complex Double, Complex Float, Int16, Int32, Int64, Int8, Word16, Word8, Double, Float] t) | |
=> OpParams | |
-> Variable t | var: Should be from a Variable(). |
-> Variable t | m: Should be from a Variable(). |
-> Variable t | v: Should be from a Variable(). |
-> Tensor v1 t | beta1_power: Must be a scalar. |
-> Tensor v2 t | beta2_power: Must be a scalar. |
-> Tensor v3 t | lr: Scaling factor. Must be a scalar. |
-> Tensor v4 t | beta1: Momentum factor. Must be a scalar. |
-> Tensor v5 t | beta2: Momentum factor. Must be a scalar. |
-> Tensor v6 t | epsilon: Ridge term. Must be a scalar. |
-> Tensor v7 t | grad: The gradient. |
-> m ControlNode |