1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Add initializedValue function for Variable (#124)

This commit is contained in:
fkm3 2017-05-20 21:42:45 -07:00 committed by GitHub
parent 8e136d3a9c
commit a86d424cac
4 changed files with 42 additions and 12 deletions

View file

@ -4,6 +4,7 @@
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
`Variables`.
- Add `initializedValue` function for `Variable`.
## v0.1.0.2
- Add extra-lib-dirs for OS X in the Hackage release (#122).

View file

@ -14,6 +14,7 @@ module TensorFlow.Variable
, variable
, variable'
, readValue
, initializedValue
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
@ -30,15 +31,19 @@ import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), renderValue, tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
newtype Variable a = Variable (Tensor Value ResourceHandle)
data Variable a = Variable
{ variableHandle :: Tensor Value ResourceHandle
, initializedValue :: Maybe (Tensor Value a)
-- ^ The initial value of a 'Variable' created with 'initializedVariable'.
}
instance Rendered Variable where
renderedOutput (Variable v) = renderedOutput v
renderedOutput = renderedOutput . variableHandle
instance ToTensor Variable where
toTensor = readValue
@ -56,7 +61,7 @@ variable' params s = build $ do
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
(tensorType (undefined :: a)) s
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
return $ Variable t
return $ Variable t Nothing
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
@ -68,10 +73,11 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
=> OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
-- The shape is not known initially.
v@(Variable h) <- variable' params (Shape [])
i <- CoreOps.assignVariableOp h initializer
(Variable h Nothing :: Variable a) <- variable' params (Shape [])
initializer' <- renderValue initializer
i <- CoreOps.assignVariableOp h initializer'
addInitializer =<< group i
return v
return (Variable h (Just initializer'))
-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
@ -102,7 +108,7 @@ readValue = readValue' id
readValue' :: forall a . TensorType a
=> OpParams -> Variable a -> Tensor Build a
readValue' params (Variable h)
readValue' params (Variable h _)
= pureOp [] $ do
os <- buildInputs h
pure $ opDef "ReadVariableOp"
@ -117,7 +123,7 @@ assign = assign' id
assign' :: (MonadBuild m, TensorType a)
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
assign' params (Variable h) v = CoreOps.assignVariableOp' params h v
assign' params (Variable h _) v = CoreOps.assignVariableOp' params h v
-- | Increments the value of a variable.
assignAdd :: (MonadBuild m, TensorType a)
@ -126,4 +132,4 @@ assignAdd = assignAdd' id
assignAdd' :: (MonadBuild m, TensorType a)
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v
assignAdd' params (Variable h _) v = CoreOps.assignAddVariableOp' params h v

View file

@ -1,6 +1,8 @@
{-# LANGUAGE OverloadedLists #-}
module Main (main) where
import Data.Maybe (isJust)
import Control.Monad (when)
import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector.Storable as V
import TensorFlow.Core
@ -12,7 +14,9 @@ import TensorFlow.Core
, withControlDependencies)
import qualified TensorFlow.Ops as Ops
import TensorFlow.Variable
( readValue
( Variable
, readValue
, initializedValue
, initializedVariable
, assign
, assignAdd
@ -20,12 +24,13 @@ import TensorFlow.Variable
)
import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))
import Test.HUnit ((@=?), assertFailure)
main :: IO ()
main = defaultMain
[ testInitializedVariable
, testInitializedVariableShape
, testInitializedValue
, testDependency
, testRereadRef
, testAssignAdd
@ -51,6 +56,18 @@ testInitializedVariableShape =
result <- run (readValue vector)
liftIO $ [42] @=? (result :: V.Vector Float)
testInitializedValue :: Test
testInitializedValue =
testCase "testInitializedValue" $ runSession $ do
initialized <- initializedVariable (Ops.constant [1] [42 :: Float])
result <- run (initializedValue initialized)
liftIO $ Just [42] @=? (result :: Maybe (V.Vector Float))
uninitialized <- variable [1]
-- Can't use @=? because there is no Show instance for Tensor.
when (isJust (initializedValue (uninitialized :: Variable Float))) $
liftIO $ assertFailure "initializedValue should be Nothing, got Just"
testDependency :: Test
testDependency =
testCase "testDependency" $ runSession $ do

View file

@ -89,6 +89,12 @@ instance Nodes t => Nodes [t] where
instance Fetchable t a => Fetchable [t] [a] where
getFetch ts = sequenceA <$> mapM getFetch ts
instance Nodes t => Nodes (Maybe t) where
getNodes = nodesUnion . fmap getNodes
instance Fetchable t a => Fetchable (Maybe t) (Maybe a) where
getFetch = fmap sequenceA . mapM getFetch
instance Nodes ControlNode where
getNodes (ControlNode o) = pure $ Set.singleton o