{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Variable
( Variable
, variable
, variable'
, readValue
, initializedValue
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, assign
, assign'
, assignAdd
, assignAdd'
, resourceApplyAdam
, resourceApplyAdam'
) where
import qualified Data.Complex
import qualified Data.Int
import qualified Data.Word
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
-- | A variable backed by a TensorFlow resource handle.
--
-- Reads are explicit: use 'readValue' (or 'toTensor') to obtain the
-- variable's current value as a tensor.
data Variable a = Variable
    { variableHandle   :: Tensor Value ResourceHandle
      -- ^ Output of the @VarHandleOp@ node created for this variable.
    , initializedValue :: Maybe (Tensor Value a)
      -- ^ The rendered initial value of a 'Variable' created with
      -- 'initializedVariable' (or 'zeroInitializedVariable');
      -- 'Nothing' for variables created with 'variable'.
    }
-- | A variable renders as the output of its underlying resource handle.
instance Rendered Variable where
    renderedOutput = renderedOutput . variableHandle
-- | Converting a variable to a tensor reads its current value
-- (see 'readValue').
instance ToTensor Variable where
    toTensor = readValue
-- | Creates a new, uninitialized variable with the given shape.
--
-- Equivalent to @'variable'' id@.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable = variable' id
-- | Like 'variable', but applies extra 'OpParams' to the underlying
-- @VarHandleOp@ node.
variable' :: forall m a . (MonadBuild m, TensorType a)
          => OpParams -> Shape -> m (Variable a)
variable' params s = variableInternal params (Just s)
-- | Shared implementation of 'variable'' and 'initializedVariable''.
--
-- Builds a @VarHandleOp@ whose @shared_name@ attribute is the op's own node
-- name, and whose @shape@ attribute is set from the 'Maybe' 'Shape' argument
-- ('initializedVariable'' passes 'Nothing').  Within the attribute
-- transformer, the caller's @params@ are applied last, after @shape@ and
-- @shared_name@ have been set.  The returned 'Variable' has no
-- 'initializedValue'.
variableInternal :: forall m a . (MonadBuild m, TensorType a)
                 => OpParams -> Maybe Shape -> m (Variable a)
variableInternal params s = build $ do
    -- Each variable needs a unique "shared_name".  RecursiveDo ties the
    -- knot: 'n' is derived from the node name of 't', but 't' is built
    -- with 'attrs', which already mentions 'n'.
    rec let attrs = params
                  . (opAttr "shared_name" .~ n)
                  . (opAttr "shape" .~ s)
            dtype = tensorType (undefined :: a)
            -- The generated op takes a concrete 'Shape', so pass a rank-0
            -- placeholder; the intended shape travels in the "shape"
            -- attribute set by 'attrs'.
            -- NOTE(review): relies on varHandleOp' applying 'attrs' after
            -- its own shape attribute -- confirm against the generated op.
            shape = Shape []
        t <- CoreOps.varHandleOp' attrs dtype shape
        let n = encodeUtf8 $ unNodeName $ tensorNodeName t
    return $ Variable t Nothing
-- | Creates a variable initialized to the given value.
--
-- The assignment is registered through 'addInitializer' rather than being
-- returned to the caller.  Equivalent to @'initializedVariable'' id@.
initializedVariable :: (MonadBuild m, TensorType a)
                    => Tensor v a -> m (Variable a)
initializedVariable = initializedVariable' id
-- | Like 'initializedVariable', but applies extra 'OpParams' to the
-- underlying @VarHandleOp@ node.
--
-- Creates the variable with no explicit shape ('Nothing'), renders the
-- initializer, builds an @AssignVariableOp@ from the initializer to the
-- handle, and registers that assignment via 'addInitializer'.  The rendered
-- initializer is kept in the result's 'initializedValue'.
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
                     => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
    -- Only the handle is needed here.  The annotation pins the element type
    -- of the fresh variable to 'a' (it would otherwise be ambiguous).
    -- Matching the sole 'Variable' constructor cannot fail, so no partial
    -- fallback branch is required.
    Variable h _ <- variableInternal params Nothing :: m (Variable a)
    initializer' <- renderValue initializer
    i <- CoreOps.assignVariableOp h initializer'
    addInitializer =<< group i
    return (Variable h (Just initializer'))
-- | Creates a variable of the given shape, initialized to all zeros.
--
-- Equivalent to @'zeroInitializedVariable'' id@.
zeroInitializedVariable
  :: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
zeroInitializedVariable = zeroInitializedVariable' id
-- | Like 'zeroInitializedVariable', but applies extra 'OpParams' to the
-- underlying @VarHandleOp@ node.  The initializer is @'zeros' shape@.
zeroInitializedVariable'
  :: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
zeroInitializedVariable' params = initializedVariable' params . zeros
-- | Reads the current value of a variable.
--
-- Equivalent to @'readValue'' id@.
readValue :: TensorType a => Variable a -> Tensor Build a
readValue = readValue' id
-- | Like 'readValue', but applies extra 'OpParams' to the read op.
--
-- Builds a @ReadVariableOp@ whose single input is the variable's handle and
-- whose @dtype@ attribute is derived from the element type @a@.  The inputs
-- are set first, then @dtype@, then the caller's @params@.
--
-- NOTE(review): the op is constructed with 'pureOp', so reads may be treated
-- like pure ops (e.g. shared between identical reads); order them relative
-- to writes explicitly where that matters -- confirm.
readValue' :: forall a . TensorType a
           => OpParams -> Variable a -> Tensor Build a
readValue' params (Variable h _)
    = pureOp [] $ do
        os <- buildInputs h
        pure $ opDef "ReadVariableOp"
                & (params
                    . (opAttr "dtype" .~ tensorType (undefined :: a))
                    . (opInputs .~ os))
-- | Sets the value of a variable.
--
-- Equivalent to @'assign'' id@.
assign :: (MonadBuild m, TensorType a)
       => Variable a -> Tensor v a -> m ControlNode
assign = assign' id
-- | Like 'assign', but applies extra 'OpParams' to the
-- @AssignVariableOp@ built on the variable's handle.
assign' :: (MonadBuild m, TensorType a)
        => OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable h _) v = CoreOps.assignVariableOp' params h v
-- | Increments the value of a variable by the given tensor.
--
-- Equivalent to @'assignAdd'' id@.
assignAdd :: (MonadBuild m, TensorType a)
          => Variable a -> Tensor v a -> m ControlNode
assignAdd = assignAdd' id
-- | Like 'assignAdd', but applies extra 'OpParams' to the
-- @AssignAddVariableOp@ built on the variable's handle.
assignAdd' :: (MonadBuild m, TensorType a)
           => OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable h _) v = CoreOps.assignAddVariableOp' params h v
-- | Applies an Adam update step to a variable via TensorFlow's
-- @ResourceApplyAdam@ op.
--
-- The three 'Variable' arguments contribute only their resource handles; the
-- seven tensors are passed through unchanged, in order.
-- NOTE(review): the tensor roles named below follow TensorFlow's
-- @ResourceApplyAdam@ op definition and are assumptions here -- confirm
-- against the generated 'CoreOps.resourceApplyAdam''.
--
-- Equivalent to @'resourceApplyAdam'' id@.
resourceApplyAdam ::
    (MonadBuild m,
     OneOf '[(Data.Complex.Complex Double),
             (Data.Complex.Complex Float),
             Data.Int.Int16,
             Data.Int.Int32,
             Data.Int.Int64, Data.Int.Int8,
             Data.Word.Word16,
             Data.Word.Word8, Double,
             Float] t)
    => Variable t   -- ^ @var@: the variable being updated.
    -> Variable t   -- ^ @m@: first-moment accumulator variable.
    -> Variable t   -- ^ @v@: second-moment accumulator variable.
    -> Tensor v1 t  -- ^ @beta1_power@ (presumed; see note above).
    -> Tensor v2 t  -- ^ @beta2_power@ (presumed).
    -> Tensor v3 t  -- ^ @lr@: learning rate (presumed).
    -> Tensor v4 t  -- ^ @beta1@ (presumed).
    -> Tensor v5 t  -- ^ @beta2@ (presumed).
    -> Tensor v6 t  -- ^ @epsilon@ (presumed).
    -> Tensor v7 t  -- ^ @grad@: the gradient (presumed).
    -> m ControlNode
resourceApplyAdam = resourceApplyAdam' id
-- | Like 'resourceApplyAdam', but applies extra 'OpParams' to the
-- @ResourceApplyAdam@ node.
--
-- Unwraps the handles of @var@, @m@ and @v@ and forwards them, followed by
-- the remaining seven tensor arguments (via partial application), to
-- 'CoreOps.resourceApplyAdam''.
resourceApplyAdam' ::
    (MonadBuild m,
     OneOf '[(Data.Complex.Complex Double),
             (Data.Complex.Complex Float),
             Data.Int.Int16, Data.Int.Int32,
             Data.Int.Int64, Data.Int.Int8,
             Data.Word.Word16, Data.Word.Word8, Double,
             Float] t)
    => OpParams
    -> Variable t   -- ^ @var@: the variable being updated.
    -> Variable t   -- ^ @m@: first-moment accumulator variable.
    -> Variable t   -- ^ @v@: second-moment accumulator variable.
    -> Tensor v1 t
    -> Tensor v2 t
    -> Tensor v3 t
    -> Tensor v4 t
    -> Tensor v5 t
    -> Tensor v6 t
    -> Tensor v7 t
    -> m ControlNode
resourceApplyAdam' params (Variable var _) (Variable m _) (Variable v _) =
    CoreOps.resourceApplyAdam' params var m v