{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Variable
( Variable
, variable
, variable'
, readValue
, initializedValue
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, assign
, assign'
, assignAdd
, assignAdd'
, resourceApplyAdam
, resourceApplyAdam'
) where
import qualified Data.Complex
import qualified Data.Int
import qualified Data.Word
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
-- | A resource variable: the handle tensor created by 'variableInternal',
-- together with the variable's initial value when one was supplied.
data Variable a = Variable
    { variableHandle :: Tensor Value ResourceHandle
      -- ^ The tensor returned by @CoreOps.varHandleOp'@ for this variable.
    , initializedValue :: Maybe (Tensor Value a)
      -- ^ The rendered initializer for variables built by
      -- 'initializedVariable'' (and the wrappers over it); 'Nothing' for
      -- variables built by 'variable' / 'variable''.
    }
-- | A variable renders as its underlying resource handle.
instance Rendered Variable where
    renderedOutput var = renderedOutput (variableHandle var)
-- | Treating a variable as a tensor yields a read of its value
-- (a @ReadVariableOp@ built by 'readValue').
instance ToTensor Variable where
    toTensor var = readValue var
-- | Creates a variable of the given shape without registering any
-- initializer, using default 'OpParams'.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable shape = variable' id shape
-- | Like 'variable', but applies the given 'OpParams' to the underlying
-- handle op. The shape is passed on as a known ('Just') shape.
variable' :: forall m a . (MonadBuild m, TensorType a)
          => OpParams -> Shape -> m (Variable a)
variable' params = variableInternal params . Just
-- | Builds the handle op for a new variable whose shape may be unknown
-- ('Nothing'). The returned 'Variable' has no 'initializedValue'.
--
-- The op's @shared_name@ attribute is set to the op's own node name. That
-- name only exists after the op is created, hence the recursive @rec@ block.
variableInternal :: forall m a . (MonadBuild m, TensorType a)
                 => OpParams -> Maybe Shape -> m (Variable a)
variableInternal params s = build $ do
    -- 'n' is bound after 't' below; 'rec' lets 'attrs' refer to it lazily.
    -- 'params' is composed outermost, so it is applied after the
    -- shared_name and shape attributes are set.
    rec let attrs = params . (opAttr "shared_name" .~ n) . (opAttr "shape" .~ s)
            dtype = tensorType (undefined :: a)
            -- A rank-0 placeholder shape for the generated op's 'Shape'
            -- argument. NOTE(review): presumably the "shape" attribute set in
            -- 'attrs' overrides this -- confirm against CoreOps.varHandleOp'.
            shape = Shape []
        t <- CoreOps.varHandleOp' attrs dtype shape
        let n = encodeUtf8 $ unNodeName $ tensorNodeName t
    return $ Variable t Nothing
-- | Creates a variable whose initial value is the given tensor, using
-- default 'OpParams'.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Variable a)
initializedVariable initializer = initializedVariable' id initializer
-- | Like 'initializedVariable', but applies the given 'OpParams' to the
-- underlying handle op.
--
-- The variable is created with an unknown shape. The initializer is rendered,
-- an assignment of it to the variable (built with
-- @CoreOps.assignVariableOp@) is grouped and registered via
-- 'addInitializer', and the rendered initializer is stored as the result's
-- 'initializedValue'.
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
    -- Bind with a pattern that cannot fail. Matching 'Nothing' here would be
    -- a failable do-pattern and demand a 'MonadFail' instance for @m@, which
    -- is not among this function's constraints. 'variableInternal' never
    -- sets 'initializedValue', so ignoring it loses nothing. The signature
    -- pins the element type @a@ for 'variableInternal'.
    (Variable h _ :: Variable a) <- variableInternal params Nothing
    initializer' <- renderValue initializer
    i <- CoreOps.assignVariableOp h initializer'
    addInitializer =<< group i
    return (Variable h (Just initializer'))
-- | Creates a variable of the given shape initialized to all zeros, using
-- default 'OpParams'.
zeroInitializedVariable
    :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable shape = zeroInitializedVariable' id shape
-- | Like 'zeroInitializedVariable', but applies the given 'OpParams' to the
-- underlying handle op. The initializer is @'zeros' shape@.
zeroInitializedVariable'
    :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params shape =
    initializedVariable' params (zeros shape)
-- | A tensor that reads the variable's value, using default 'OpParams'.
readValue :: TensorType a => Variable a -> Tensor Build a
readValue var = readValue' id var
-- | Like 'readValue', but applies the given 'OpParams' to the generated
-- @ReadVariableOp@ node. The op's inputs are built from the variable's
-- handle and its @dtype@ attribute is set from @a@. The user params are
-- applied last.
readValue' :: forall a . TensorType a
           => OpParams -> Variable a -> Tensor Build a
readValue' params (Variable handle _) = pureOp [] (mkReadOp <$> buildInputs handle)
  where
    mkReadOp inputs =
        opDef "ReadVariableOp"
            & (params
               . (opAttr "dtype" .~ tensorType (undefined :: a))
               . (opInputs .~ inputs))
-- | Assigns a new value to the variable, using default 'OpParams'.
assign :: (MonadBuild m, TensorType a)
       => Variable a -> Tensor v a -> m ControlNode
assign var value = assign' id var value
-- | Like 'assign', but forwards the given 'OpParams' to
-- @CoreOps.assignVariableOp'@, applied to the variable's handle.
assign' :: (MonadBuild m, TensorType a)
        => OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable handle _) = CoreOps.assignVariableOp' params handle
-- | Adds the given tensor to the variable's value, using default 'OpParams'.
assignAdd :: (MonadBuild m, TensorType a)
          => Variable a -> Tensor v a -> m ControlNode
assignAdd var delta = assignAdd' id var delta
-- | Like 'assignAdd', but forwards the given 'OpParams' to
-- @CoreOps.assignAddVariableOp'@, applied to the variable's handle.
assignAdd' :: (MonadBuild m, TensorType a)
           => OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable handle _) = CoreOps.assignAddVariableOp' params handle
-- | 'resourceApplyAdam'' with default 'OpParams'.
--
-- The three 'Variable' arguments are passed through as @var@, @m@ and @v@.
-- NOTE(review): the seven tensors are presumably beta1_power, beta2_power,
-- lr, beta1, beta2, epsilon and grad (in that order), following the
-- TensorFlow ResourceApplyAdam op -- confirm against CoreOps.
resourceApplyAdam ::
    ( MonadBuild m
    , OneOf '[ (Data.Complex.Complex Double)
             , (Data.Complex.Complex Float)
             , Data.Int.Int16
             , Data.Int.Int32
             , Data.Int.Int64
             , Data.Int.Int8
             , Data.Word.Word16
             , Data.Word.Word8
             , Double
             , Float
             ] t
    )
    => Variable t  -- ^ var
    -> Variable t  -- ^ m
    -> Variable t  -- ^ v
    -> Tensor v1 t
    -> Tensor v2 t
    -> Tensor v3 t
    -> Tensor v4 t
    -> Tensor v5 t
    -> Tensor v6 t
    -> Tensor v7 t
    -> m ControlNode
resourceApplyAdam var m v = resourceApplyAdam' id var m v
-- | Calls @CoreOps.resourceApplyAdam'@ with the given 'OpParams', passing
-- the resource handles of the three variables (@var@, @m@, @v@). The seven
-- tensor arguments are forwarded unchanged.
resourceApplyAdam' ::
    ( MonadBuild m
    , OneOf '[ (Data.Complex.Complex Double)
             , (Data.Complex.Complex Float)
             , Data.Int.Int16
             , Data.Int.Int32
             , Data.Int.Int64
             , Data.Int.Int8
             , Data.Word.Word16
             , Data.Word.Word8
             , Double
             , Float
             ] t
    )
    => OpParams
    -> Variable t  -- ^ var
    -> Variable t  -- ^ m
    -> Variable t  -- ^ v
    -> Tensor v1 t
    -> Tensor v2 t
    -> Tensor v3 t
    -> Tensor v4 t
    -> Tensor v5 t
    -> Tensor v6 t
    -> Tensor v7 t
    -> m ControlNode
resourceApplyAdam' params (Variable varHandle _) (Variable mHandle _) (Variable vHandle _) =
    CoreOps.resourceApplyAdam' params varHandle mHandle vHandle