mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Add initializedValue function for Variable (#124)
This commit is contained in:
parent
8e136d3a9c
commit
a86d424cac
4 changed files with 42 additions and 12 deletions
|
@ -4,6 +4,7 @@
|
||||||
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
|
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
|
||||||
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
|
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
|
||||||
`Variables`.
|
`Variables`.
|
||||||
|
- Add `initializedValue` function for `Variable`.
|
||||||
|
|
||||||
## v0.1.0.2
|
## v0.1.0.2
|
||||||
- Add extra-lib-dirs for OS X in the Hackage release (#122).
|
- Add extra-lib-dirs for OS X in the Hackage release (#122).
|
||||||
|
|
|
@ -14,6 +14,7 @@ module TensorFlow.Variable
|
||||||
, variable
|
, variable
|
||||||
, variable'
|
, variable'
|
||||||
, readValue
|
, readValue
|
||||||
|
, initializedValue
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, initializedVariable'
|
, initializedVariable'
|
||||||
, zeroInitializedVariable
|
, zeroInitializedVariable
|
||||||
|
@ -30,15 +31,19 @@ import TensorFlow.Core
|
||||||
import TensorFlow.Build (opDef)
|
import TensorFlow.Build (opDef)
|
||||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||||
import TensorFlow.Output (opInputs, unNodeName)
|
import TensorFlow.Output (opInputs, unNodeName)
|
||||||
import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName)
|
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
|
||||||
import TensorFlow.Types (tensorType)
|
import TensorFlow.Types (tensorType)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
import TensorFlow.Ops (zeros)
|
import TensorFlow.Ops (zeros)
|
||||||
|
|
||||||
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
data Variable a = Variable
|
||||||
|
{ variableHandle :: Tensor Value ResourceHandle
|
||||||
|
, initializedValue :: Maybe (Tensor Value a)
|
||||||
|
-- ^ The initial value of a 'Variable' created with 'initializedVariable'.
|
||||||
|
}
|
||||||
|
|
||||||
instance Rendered Variable where
|
instance Rendered Variable where
|
||||||
renderedOutput (Variable v) = renderedOutput v
|
renderedOutput = renderedOutput . variableHandle
|
||||||
|
|
||||||
instance ToTensor Variable where
|
instance ToTensor Variable where
|
||||||
toTensor = readValue
|
toTensor = readValue
|
||||||
|
@ -56,7 +61,7 @@ variable' params s = build $ do
|
||||||
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
||||||
(tensorType (undefined :: a)) s
|
(tensorType (undefined :: a)) s
|
||||||
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
||||||
return $ Variable t
|
return $ Variable t Nothing
|
||||||
|
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
|
@ -68,10 +73,11 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
||||||
=> OpParams -> Tensor v a -> m (Variable a)
|
=> OpParams -> Tensor v a -> m (Variable a)
|
||||||
initializedVariable' params initializer = do
|
initializedVariable' params initializer = do
|
||||||
-- The shape is not known initially.
|
-- The shape is not known initially.
|
||||||
v@(Variable h) <- variable' params (Shape [])
|
(Variable h Nothing :: Variable a) <- variable' params (Shape [])
|
||||||
i <- CoreOps.assignVariableOp h initializer
|
initializer' <- renderValue initializer
|
||||||
|
i <- CoreOps.assignVariableOp h initializer'
|
||||||
addInitializer =<< group i
|
addInitializer =<< group i
|
||||||
return v
|
return (Variable h (Just initializer'))
|
||||||
|
|
||||||
-- | Creates a zero-initialized variable with the given shape.
|
-- | Creates a zero-initialized variable with the given shape.
|
||||||
zeroInitializedVariable
|
zeroInitializedVariable
|
||||||
|
@ -102,7 +108,7 @@ readValue = readValue' id
|
||||||
|
|
||||||
readValue' :: forall a . TensorType a
|
readValue' :: forall a . TensorType a
|
||||||
=> OpParams -> Variable a -> Tensor Build a
|
=> OpParams -> Variable a -> Tensor Build a
|
||||||
readValue' params (Variable h)
|
readValue' params (Variable h _)
|
||||||
= pureOp [] $ do
|
= pureOp [] $ do
|
||||||
os <- buildInputs h
|
os <- buildInputs h
|
||||||
pure $ opDef "ReadVariableOp"
|
pure $ opDef "ReadVariableOp"
|
||||||
|
@ -117,7 +123,7 @@ assign = assign' id
|
||||||
|
|
||||||
assign' :: (MonadBuild m, TensorType a)
|
assign' :: (MonadBuild m, TensorType a)
|
||||||
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||||
assign' params (Variable h) v = CoreOps.assignVariableOp' params h v
|
assign' params (Variable h _) v = CoreOps.assignVariableOp' params h v
|
||||||
|
|
||||||
-- | Increments the value of a variable.
|
-- | Increments the value of a variable.
|
||||||
assignAdd :: (MonadBuild m, TensorType a)
|
assignAdd :: (MonadBuild m, TensorType a)
|
||||||
|
@ -126,4 +132,4 @@ assignAdd = assignAdd' id
|
||||||
|
|
||||||
assignAdd' :: (MonadBuild m, TensorType a)
|
assignAdd' :: (MonadBuild m, TensorType a)
|
||||||
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||||
assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v
|
assignAdd' params (Variable h _) v = CoreOps.assignAddVariableOp' params h v
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
module Main (main) where
|
module Main (main) where
|
||||||
|
|
||||||
|
import Data.Maybe (isJust)
|
||||||
|
import Control.Monad (when)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import qualified Data.Vector.Storable as V
|
import qualified Data.Vector.Storable as V
|
||||||
import TensorFlow.Core
|
import TensorFlow.Core
|
||||||
|
@ -12,7 +14,9 @@ import TensorFlow.Core
|
||||||
, withControlDependencies)
|
, withControlDependencies)
|
||||||
import qualified TensorFlow.Ops as Ops
|
import qualified TensorFlow.Ops as Ops
|
||||||
import TensorFlow.Variable
|
import TensorFlow.Variable
|
||||||
( readValue
|
( Variable
|
||||||
|
, readValue
|
||||||
|
, initializedValue
|
||||||
, initializedVariable
|
, initializedVariable
|
||||||
, assign
|
, assign
|
||||||
, assignAdd
|
, assignAdd
|
||||||
|
@ -20,12 +24,13 @@ import TensorFlow.Variable
|
||||||
)
|
)
|
||||||
import Test.Framework (defaultMain, Test)
|
import Test.Framework (defaultMain, Test)
|
||||||
import Test.Framework.Providers.HUnit (testCase)
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?), assertFailure)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = defaultMain
|
main = defaultMain
|
||||||
[ testInitializedVariable
|
[ testInitializedVariable
|
||||||
, testInitializedVariableShape
|
, testInitializedVariableShape
|
||||||
|
, testInitializedValue
|
||||||
, testDependency
|
, testDependency
|
||||||
, testRereadRef
|
, testRereadRef
|
||||||
, testAssignAdd
|
, testAssignAdd
|
||||||
|
@ -51,6 +56,18 @@ testInitializedVariableShape =
|
||||||
result <- run (readValue vector)
|
result <- run (readValue vector)
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
|
testInitializedValue :: Test
|
||||||
|
testInitializedValue =
|
||||||
|
testCase "testInitializedValue" $ runSession $ do
|
||||||
|
initialized <- initializedVariable (Ops.constant [1] [42 :: Float])
|
||||||
|
result <- run (initializedValue initialized)
|
||||||
|
liftIO $ Just [42] @=? (result :: Maybe (V.Vector Float))
|
||||||
|
|
||||||
|
uninitialized <- variable [1]
|
||||||
|
-- Can't use @=? because there is no Show instance for Tensor.
|
||||||
|
when (isJust (initializedValue (uninitialized :: Variable Float))) $
|
||||||
|
liftIO $ assertFailure "initializedValue should be Nothing, got Just"
|
||||||
|
|
||||||
testDependency :: Test
|
testDependency :: Test
|
||||||
testDependency =
|
testDependency =
|
||||||
testCase "testDependency" $ runSession $ do
|
testCase "testDependency" $ runSession $ do
|
||||||
|
|
|
@ -89,6 +89,12 @@ instance Nodes t => Nodes [t] where
|
||||||
instance Fetchable t a => Fetchable [t] [a] where
|
instance Fetchable t a => Fetchable [t] [a] where
|
||||||
getFetch ts = sequenceA <$> mapM getFetch ts
|
getFetch ts = sequenceA <$> mapM getFetch ts
|
||||||
|
|
||||||
|
instance Nodes t => Nodes (Maybe t) where
|
||||||
|
getNodes = nodesUnion . fmap getNodes
|
||||||
|
|
||||||
|
instance Fetchable t a => Fetchable (Maybe t) (Maybe a) where
|
||||||
|
getFetch = fmap sequenceA . mapM getFetch
|
||||||
|
|
||||||
instance Nodes ControlNode where
|
instance Nodes ControlNode where
|
||||||
getNodes (ControlNode o) = pure $ Set.singleton o
|
getNodes (ControlNode o) = pure $ Set.singleton o
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue