mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add initializedValue function for Variable (#124)
This commit is contained in:
parent
8e136d3a9c
commit
a86d424cac
4 changed files with 42 additions and 12 deletions
|
@ -4,6 +4,7 @@
|
|||
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
|
||||
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
|
||||
`Variables`.
|
||||
- Add `initializedValue` function for `Variable`.
|
||||
|
||||
## v0.1.0.2
|
||||
- Add extra-lib-dirs for OS X in the Hackage release (#122).
|
||||
|
|
|
@ -14,6 +14,7 @@ module TensorFlow.Variable
|
|||
, variable
|
||||
, variable'
|
||||
, readValue
|
||||
, initializedValue
|
||||
, initializedVariable
|
||||
, initializedVariable'
|
||||
, zeroInitializedVariable
|
||||
|
@ -30,15 +31,19 @@ import TensorFlow.Core
|
|||
import TensorFlow.Build (opDef)
|
||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||
import TensorFlow.Output (opInputs, unNodeName)
|
||||
import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName)
|
||||
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
|
||||
import TensorFlow.Types (tensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Ops (zeros)
|
||||
|
||||
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
||||
data Variable a = Variable
|
||||
{ variableHandle :: Tensor Value ResourceHandle
|
||||
, initializedValue :: Maybe (Tensor Value a)
|
||||
-- ^ The initial value of a 'Variable' created with 'initializedVariable'.
|
||||
}
|
||||
|
||||
instance Rendered Variable where
|
||||
renderedOutput (Variable v) = renderedOutput v
|
||||
renderedOutput = renderedOutput . variableHandle
|
||||
|
||||
instance ToTensor Variable where
|
||||
toTensor = readValue
|
||||
|
@ -56,7 +61,7 @@ variable' params s = build $ do
|
|||
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
||||
(tensorType (undefined :: a)) s
|
||||
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
||||
return $ Variable t
|
||||
return $ Variable t Nothing
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
|
@ -68,10 +73,11 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
|||
=> OpParams -> Tensor v a -> m (Variable a)
|
||||
initializedVariable' params initializer = do
|
||||
-- The shape is not known initially.
|
||||
v@(Variable h) <- variable' params (Shape [])
|
||||
i <- CoreOps.assignVariableOp h initializer
|
||||
(Variable h Nothing :: Variable a) <- variable' params (Shape [])
|
||||
initializer' <- renderValue initializer
|
||||
i <- CoreOps.assignVariableOp h initializer'
|
||||
addInitializer =<< group i
|
||||
return v
|
||||
return (Variable h (Just initializer'))
|
||||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
|
@ -102,7 +108,7 @@ readValue = readValue' id
|
|||
|
||||
readValue' :: forall a . TensorType a
|
||||
=> OpParams -> Variable a -> Tensor Build a
|
||||
readValue' params (Variable h)
|
||||
readValue' params (Variable h _)
|
||||
= pureOp [] $ do
|
||||
os <- buildInputs h
|
||||
pure $ opDef "ReadVariableOp"
|
||||
|
@ -117,7 +123,7 @@ assign = assign' id
|
|||
|
||||
assign' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||
assign' params (Variable h) v = CoreOps.assignVariableOp' params h v
|
||||
assign' params (Variable h _) v = CoreOps.assignVariableOp' params h v
|
||||
|
||||
-- | Increments the value of a variable.
|
||||
assignAdd :: (MonadBuild m, TensorType a)
|
||||
|
@ -126,4 +132,4 @@ assignAdd = assignAdd' id
|
|||
|
||||
assignAdd' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||
assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v
|
||||
assignAdd' params (Variable h _) v = CoreOps.assignAddVariableOp' params h v
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
{-# LANGUAGE OverloadedLists #-}
|
||||
module Main (main) where
|
||||
|
||||
import Data.Maybe (isJust)
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import qualified Data.Vector.Storable as V
|
||||
import TensorFlow.Core
|
||||
|
@ -12,7 +14,9 @@ import TensorFlow.Core
|
|||
, withControlDependencies)
|
||||
import qualified TensorFlow.Ops as Ops
|
||||
import TensorFlow.Variable
|
||||
( readValue
|
||||
( Variable
|
||||
, readValue
|
||||
, initializedValue
|
||||
, initializedVariable
|
||||
, assign
|
||||
, assignAdd
|
||||
|
@ -20,12 +24,13 @@ import TensorFlow.Variable
|
|||
)
|
||||
import Test.Framework (defaultMain, Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import Test.HUnit ((@=?))
|
||||
import Test.HUnit ((@=?), assertFailure)
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain
|
||||
[ testInitializedVariable
|
||||
, testInitializedVariableShape
|
||||
, testInitializedValue
|
||||
, testDependency
|
||||
, testRereadRef
|
||||
, testAssignAdd
|
||||
|
@ -51,6 +56,18 @@ testInitializedVariableShape =
|
|||
result <- run (readValue vector)
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
testInitializedValue :: Test
|
||||
testInitializedValue =
|
||||
testCase "testInitializedValue" $ runSession $ do
|
||||
initialized <- initializedVariable (Ops.constant [1] [42 :: Float])
|
||||
result <- run (initializedValue initialized)
|
||||
liftIO $ Just [42] @=? (result :: Maybe (V.Vector Float))
|
||||
|
||||
uninitialized <- variable [1]
|
||||
-- Can't use @=? because there is no Show instance for Tensor.
|
||||
when (isJust (initializedValue (uninitialized :: Variable Float))) $
|
||||
liftIO $ assertFailure "initializedValue should be Nothing, got Just"
|
||||
|
||||
testDependency :: Test
|
||||
testDependency =
|
||||
testCase "testDependency" $ runSession $ do
|
||||
|
|
|
@ -89,6 +89,12 @@ instance Nodes t => Nodes [t] where
|
|||
instance Fetchable t a => Fetchable [t] [a] where
|
||||
getFetch ts = sequenceA <$> mapM getFetch ts
|
||||
|
||||
instance Nodes t => Nodes (Maybe t) where
|
||||
getNodes = nodesUnion . fmap getNodes
|
||||
|
||||
instance Fetchable t a => Fetchable (Maybe t) (Maybe a) where
|
||||
getFetch = fmap sequenceA . mapM getFetch
|
||||
|
||||
instance Nodes ControlNode where
|
||||
getNodes (ControlNode o) = pure $ Set.singleton o
|
||||
|
||||
|
|
Loading…
Reference in a new issue