mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
Fix initialized variables for tensorflow 1.7 (#184)
* Fix initialized variables for tensorflow 1.7 This is needed to support tensorflow 1.7. The trick of initializing a variable with `Shape []` and then overriding the shape by assigning an initial value no longer works. It seems that we need to explicitly flip the unknown_rank bit in the shape proto. I thought about switching opgen to use `Maybe Shape` when an op requires a shape attribute, but that will cause a lot of api churn, so I chose to hold off for now and just do a spot fix to unblock 1.7.
This commit is contained in:
parent
4c6306d914
commit
e35211d49b
3 changed files with 44 additions and 7 deletions
|
@ -61,12 +61,21 @@ variable = variable' id
|
||||||
|
|
||||||
variable' :: forall m a . (MonadBuild m, TensorType a)
|
variable' :: forall m a . (MonadBuild m, TensorType a)
|
||||||
=> OpParams -> Shape -> m (Variable a)
|
=> OpParams -> Shape -> m (Variable a)
|
||||||
variable' params s = build $ do
|
variable' params s = variableInternal params (Just s)
|
||||||
|
|
||||||
|
variableInternal :: forall m a . (MonadBuild m, TensorType a)
|
||||||
|
=> OpParams -> Maybe Shape -> m (Variable a)
|
||||||
|
variableInternal params s = build $ do
|
||||||
-- Each variable needs a unique "shared_name". Use MonadFix to
|
-- Each variable needs a unique "shared_name". Use MonadFix to
|
||||||
-- set the attribute to the same name as the variable itself, without
|
-- set the attribute to the same name as the variable itself, without
|
||||||
-- exposing more internals of the Build module.
|
-- exposing more internals of the Build module.
|
||||||
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
rec let attrs = params . (opAttr "shared_name" .~ n) . (opAttr "shape" .~ s)
|
||||||
(tensorType (undefined :: a)) s
|
dtype = tensorType (undefined :: a)
|
||||||
|
-- Generated ops don't support unknown shapes. As a workaround, we
|
||||||
|
-- pass in a rank zero shape and then override it using OpParams.
|
||||||
|
-- TODO: Consider supporting this better in op generation.
|
||||||
|
shape = Shape []
|
||||||
|
t <- CoreOps.varHandleOp' attrs dtype shape
|
||||||
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
||||||
return $ Variable t Nothing
|
return $ Variable t Nothing
|
||||||
|
|
||||||
|
@ -80,7 +89,7 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
||||||
=> OpParams -> Tensor v a -> m (Variable a)
|
=> OpParams -> Tensor v a -> m (Variable a)
|
||||||
initializedVariable' params initializer = do
|
initializedVariable' params initializer = do
|
||||||
-- The shape is not known initially.
|
-- The shape is not known initially.
|
||||||
(Variable h Nothing :: Variable a) <- variable' params (Shape [])
|
(Variable h Nothing :: Variable a) <- variableInternal params Nothing
|
||||||
initializer' <- renderValue initializer
|
initializer' <- renderValue initializer
|
||||||
i <- CoreOps.assignVariableOp h initializer'
|
i <- CoreOps.assignVariableOp h initializer'
|
||||||
addInitializer =<< group i
|
addInitializer =<< group i
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
module Main (main) where
|
module Main (main) where
|
||||||
|
|
||||||
|
import Data.Int (Int32)
|
||||||
import Data.Maybe (isJust)
|
import Data.Maybe (isJust)
|
||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
@ -55,6 +56,8 @@ testInitializedVariableShape =
|
||||||
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
|
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
|
||||||
result <- run (readValue vector)
|
result <- run (readValue vector)
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
s <- run (Ops.shape (readValue vector))
|
||||||
|
liftIO $ [1] @=? (s :: V.Vector Int32)
|
||||||
|
|
||||||
testInitializedValue :: Test
|
testInitializedValue :: Test
|
||||||
testInitializedValue =
|
testInitializedValue =
|
||||||
|
|
|
@ -67,13 +67,15 @@ import Data.Functor.Identity (Identity(..))
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Int (Int8, Int16, Int32, Int64)
|
import Data.Int (Int8, Int16, Int32, Int64)
|
||||||
|
import Data.Maybe (fromMaybe)
|
||||||
import Data.Monoid ((<>))
|
import Data.Monoid ((<>))
|
||||||
|
import Data.ProtoLens.TextFormat (showMessageShort)
|
||||||
import Data.Proxy (Proxy(..))
|
import Data.Proxy (Proxy(..))
|
||||||
import Data.String (IsString)
|
import Data.String (IsString)
|
||||||
import Data.Word (Word8, Word16, Word64)
|
import Data.Word (Word8, Word16, Word64)
|
||||||
import Foreign.Storable (Storable)
|
import Foreign.Storable (Storable)
|
||||||
import GHC.Exts (Constraint, IsList(..))
|
import GHC.Exts (Constraint, IsList(..))
|
||||||
import Lens.Family2 (Lens', view, (&), (.~))
|
import Lens.Family2 (Lens', view, (&), (.~), (^..))
|
||||||
import Lens.Family2.Unchecked (iso)
|
import Lens.Family2.Unchecked (iso)
|
||||||
import Text.Printf (printf)
|
import Text.Printf (printf)
|
||||||
import qualified Data.Attoparsec.ByteString as Atto
|
import qualified Data.Attoparsec.ByteString as Atto
|
||||||
|
@ -113,6 +115,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
( TensorShapeProto(..)
|
( TensorShapeProto(..)
|
||||||
, dim
|
, dim
|
||||||
, size
|
, size
|
||||||
|
, unknownRank
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
|
||||||
|
@ -353,6 +356,9 @@ headFromSingleton x
|
||||||
|
|
||||||
|
|
||||||
-- | Shape (dimensions) of a tensor.
|
-- | Shape (dimensions) of a tensor.
|
||||||
|
--
|
||||||
|
-- TensorFlow supports shapes of unknown rank, which are represented as
|
||||||
|
-- @Nothing :: Maybe Shape@ in Haskell.
|
||||||
newtype Shape = Shape [Int64] deriving Show
|
newtype Shape = Shape [Int64] deriving Show
|
||||||
|
|
||||||
instance IsList Shape where
|
instance IsList Shape where
|
||||||
|
@ -363,8 +369,24 @@ instance IsList Shape where
|
||||||
protoShape :: Lens' TensorShapeProto Shape
|
protoShape :: Lens' TensorShapeProto Shape
|
||||||
protoShape = iso protoToShape shapeToProto
|
protoShape = iso protoToShape shapeToProto
|
||||||
where
|
where
|
||||||
protoToShape = Shape . fmap (view size) . view dim
|
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
|
||||||
shapeToProto (Shape ds) = (def :: TensorShapeProto) & dim .~ fmap (\d -> def & size .~ d) ds
|
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
|
||||||
|
++ showMessageShort p
|
||||||
|
shapeToProto s' = def & protoMaybeShape .~ Just s'
|
||||||
|
|
||||||
|
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
|
||||||
|
protoMaybeShape = iso protoToShape shapeToProto
|
||||||
|
where
|
||||||
|
protoToShape :: TensorShapeProto -> Maybe Shape
|
||||||
|
protoToShape p =
|
||||||
|
if view unknownRank p
|
||||||
|
then Nothing
|
||||||
|
else Just (Shape (p ^.. dim . traverse . size))
|
||||||
|
shapeToProto :: Maybe Shape -> TensorShapeProto
|
||||||
|
shapeToProto Nothing =
|
||||||
|
def & unknownRank .~ True
|
||||||
|
shapeToProto (Just (Shape ds)) =
|
||||||
|
def & dim .~ fmap (\d -> def & size .~ d) ds
|
||||||
|
|
||||||
|
|
||||||
class Attribute a where
|
class Attribute a where
|
||||||
|
@ -391,6 +413,9 @@ instance Attribute Bool where
|
||||||
instance Attribute Shape where
|
instance Attribute Shape where
|
||||||
attrLens = shape . protoShape
|
attrLens = shape . protoShape
|
||||||
|
|
||||||
|
instance Attribute (Maybe Shape) where
|
||||||
|
attrLens = shape . protoMaybeShape
|
||||||
|
|
||||||
-- TODO(gnezdo): support generating list(Foo) from [Foo].
|
-- TODO(gnezdo): support generating list(Foo) from [Foo].
|
||||||
instance Attribute AttrValue'ListValue where
|
instance Attribute AttrValue'ListValue where
|
||||||
attrLens = list
|
attrLens = list
|
||||||
|
|
Loading…
Reference in a new issue