Fix initialized variables for tensorflow 1.7 (#184)

* Fix initialized variables for tensorflow 1.7

This is needed to support tensorflow 1.7. The trick of initializing a
variable with `Shape []` and then overriding the shape by assigning an
initial value no longer works. It seems that we need to explicitly flip
the unknown_rank bit in the shape proto.

I thought about switching opgen to use `Maybe Shape` when an op requires
a shape attribute, but that will cause a lot of api churn, so I chose to
hold off for now and just do a spot fix to unblock 1.7.
This commit is contained in:
fkm3 2018-04-16 10:48:05 -04:00 committed by Greg Steuck
parent 4c6306d914
commit e35211d49b
3 changed files with 44 additions and 7 deletions

View File

@ -61,12 +61,21 @@ variable = variable' id
variable' :: forall m a . (MonadBuild m, TensorType a) variable' :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Shape -> m (Variable a) => OpParams -> Shape -> m (Variable a)
variable' params s = build $ do variable' params s = variableInternal params (Just s)
variableInternal :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Maybe Shape -> m (Variable a)
variableInternal params s = build $ do
-- Each variable needs a unique "shared_name". Use MonadFix to -- Each variable needs a unique "shared_name". Use MonadFix to
-- set the attribute to the same name as the variable itself, without -- set the attribute to the same name as the variable itself, without
-- exposing more internals of the Build module. -- exposing more internals of the Build module.
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n)) rec let attrs = params . (opAttr "shared_name" .~ n) . (opAttr "shape" .~ s)
(tensorType (undefined :: a)) s dtype = tensorType (undefined :: a)
-- Generated ops don't support unknown shapes. As a workaround, we
-- pass in a rank zero shape and then override it using OpParams.
-- TODO: Consider supporting this better in op generation.
shape = Shape []
t <- CoreOps.varHandleOp' attrs dtype shape
let n = encodeUtf8 $ unNodeName $ tensorNodeName t let n = encodeUtf8 $ unNodeName $ tensorNodeName t
return $ Variable t Nothing return $ Variable t Nothing
@ -80,7 +89,7 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
=> OpParams -> Tensor v a -> m (Variable a) => OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do initializedVariable' params initializer = do
-- The shape is not known initially. -- The shape is not known initially.
(Variable h Nothing :: Variable a) <- variable' params (Shape []) (Variable h Nothing :: Variable a) <- variableInternal params Nothing
initializer' <- renderValue initializer initializer' <- renderValue initializer
i <- CoreOps.assignVariableOp h initializer' i <- CoreOps.assignVariableOp h initializer'
addInitializer =<< group i addInitializer =<< group i

View File

@ -1,6 +1,7 @@
{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedLists #-}
module Main (main) where module Main (main) where
import Data.Int (Int32)
import Data.Maybe (isJust) import Data.Maybe (isJust)
import Control.Monad (when) import Control.Monad (when)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
@ -55,6 +56,8 @@ testInitializedVariableShape =
vector <- initializedVariable (Ops.constant [1] [42 :: Float]) vector <- initializedVariable (Ops.constant [1] [42 :: Float])
result <- run (readValue vector) result <- run (readValue vector)
liftIO $ [42] @=? (result :: V.Vector Float) liftIO $ [42] @=? (result :: V.Vector Float)
s <- run (Ops.shape (readValue vector))
liftIO $ [1] @=? (s :: V.Vector Int32)
testInitializedValue :: Test testInitializedValue :: Test
testInitializedValue = testInitializedValue =

View File

@ -67,13 +67,15 @@ import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex) import Data.Complex (Complex)
import Data.Default (def) import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64) import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>)) import Data.Monoid ((<>))
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..)) import Data.Proxy (Proxy(..))
import Data.String (IsString) import Data.String (IsString)
import Data.Word (Word8, Word16, Word64) import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable) import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..)) import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~)) import Lens.Family2 (Lens', view, (&), (.~), (^..))
import Lens.Family2.Unchecked (iso) import Lens.Family2.Unchecked (iso)
import Text.Printf (printf) import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto import qualified Data.Attoparsec.ByteString as Atto
@ -113,6 +115,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..) ( TensorShapeProto(..)
, dim , dim
, size , size
, unknownRank
) )
import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import Proto.Tensorflow.Core.Framework.Types (DataType(..))
@ -353,6 +356,9 @@ headFromSingleton x
-- | Shape (dimensions) of a tensor. -- | Shape (dimensions) of a tensor.
--
-- TensorFlow supports shapes of unknown rank, which are represented as
-- @Nothing :: Maybe Shape@ in Haskell.
newtype Shape = Shape [Int64] deriving Show newtype Shape = Shape [Int64] deriving Show
instance IsList Shape where instance IsList Shape where
@ -363,8 +369,24 @@ instance IsList Shape where
protoShape :: Lens' TensorShapeProto Shape protoShape :: Lens' TensorShapeProto Shape
protoShape = iso protoToShape shapeToProto protoShape = iso protoToShape shapeToProto
where where
protoToShape = Shape . fmap (view size) . view dim protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
shapeToProto (Shape ds) = (def :: TensorShapeProto) & dim .~ fmap (\d -> def & size .~ d) ds where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
++ showMessageShort p
shapeToProto s' = def & protoMaybeShape .~ Just s'
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = iso protoToShape shapeToProto
where
protoToShape :: TensorShapeProto -> Maybe Shape
protoToShape p =
if view unknownRank p
then Nothing
else Just (Shape (p ^.. dim . traverse . size))
shapeToProto :: Maybe Shape -> TensorShapeProto
shapeToProto Nothing =
def & unknownRank .~ True
shapeToProto (Just (Shape ds)) =
def & dim .~ fmap (\d -> def & size .~ d) ds
class Attribute a where class Attribute a where
@ -391,6 +413,9 @@ instance Attribute Bool where
instance Attribute Shape where instance Attribute Shape where
attrLens = shape . protoShape attrLens = shape . protoShape
instance Attribute (Maybe Shape) where
attrLens = shape . protoMaybeShape
-- TODO(gnezdo): support generating list(Foo) from [Foo]. -- TODO(gnezdo): support generating list(Foo) from [Foo].
instance Attribute AttrValue'ListValue where instance Attribute AttrValue'ListValue where
attrLens = list attrLens = list