1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-03 16:29:46 +01:00

Fix initialized variables for tensorflow 1.7 (#184)

* Fix initialized variables for tensorflow 1.7

This is needed to support tensorflow 1.7. The trick of initializing a
variable with `Shape []` and then overriding the shape by assigning an
initial value no longer works. It seems that we need to explicitly flip
the unknown_rank bit in the shape proto.

I thought about switching opgen to use `Maybe Shape` when an op requires
a shape attribute, but that will cause a lot of api churn, so I chose to
hold off for now and just do a spot fix to unblock 1.7.
This commit is contained in:
fkm3 2018-04-16 10:48:05 -04:00 committed by Greg Steuck
parent 4c6306d914
commit e35211d49b
3 changed files with 44 additions and 7 deletions

View file

@ -61,12 +61,21 @@ variable = variable' id
variable' :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Shape -> m (Variable a)
variable' params s = build $ do
variable' params s = variableInternal params (Just s)
variableInternal :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Maybe Shape -> m (Variable a)
variableInternal params s = build $ do
-- Each variable needs a unique "shared_name". Use MonadFix to
-- set the attribute to the same name as the variable itself, without
-- exposing more internals of the Build module.
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
(tensorType (undefined :: a)) s
rec let attrs = params . (opAttr "shared_name" .~ n) . (opAttr "shape" .~ s)
dtype = tensorType (undefined :: a)
-- Generated ops don't support unknown shapes. As a workaround, we
-- pass in a rank zero shape and then override it using OpParams.
-- TODO: Consider supporting this better in op generation.
shape = Shape []
t <- CoreOps.varHandleOp' attrs dtype shape
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
return $ Variable t Nothing
@ -80,7 +89,7 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
=> OpParams -> Tensor v a -> m (Variable a)
initializedVariable' params initializer = do
-- The shape is not known initially.
(Variable h Nothing :: Variable a) <- variable' params (Shape [])
(Variable h Nothing :: Variable a) <- variableInternal params Nothing
initializer' <- renderValue initializer
i <- CoreOps.assignVariableOp h initializer'
addInitializer =<< group i

View file

@ -1,6 +1,7 @@
{-# LANGUAGE OverloadedLists #-}
module Main (main) where
import Data.Int (Int32)
import Data.Maybe (isJust)
import Control.Monad (when)
import Control.Monad.IO.Class (liftIO)
@ -55,6 +56,8 @@ testInitializedVariableShape =
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
result <- run (readValue vector)
liftIO $ [42] @=? (result :: V.Vector Float)
s <- run (Ops.shape (readValue vector))
liftIO $ [1] @=? (s :: V.Vector Int32)
testInitializedValue :: Test
testInitializedValue =

View file

@ -67,13 +67,15 @@ import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~))
import Lens.Family2 (Lens', view, (&), (.~), (^..))
import Lens.Family2.Unchecked (iso)
import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto
@ -113,6 +115,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..)
, dim
, size
, unknownRank
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
@ -353,6 +356,9 @@ headFromSingleton x
-- | Shape (dimensions) of a tensor.
--
-- TensorFlow supports shapes of unknown rank, which are represented as
-- @Nothing :: Maybe Shape@ in Haskell.
newtype Shape = Shape [Int64] deriving Show
instance IsList Shape where
@ -363,8 +369,24 @@ instance IsList Shape where
protoShape :: Lens' TensorShapeProto Shape
protoShape = iso protoToShape shapeToProto
where
protoToShape = Shape . fmap (view size) . view dim
shapeToProto (Shape ds) = (def :: TensorShapeProto) & dim .~ fmap (\d -> def & size .~ d) ds
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
++ showMessageShort p
shapeToProto s' = def & protoMaybeShape .~ Just s'
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = iso protoToShape shapeToProto
where
protoToShape :: TensorShapeProto -> Maybe Shape
protoToShape p =
if view unknownRank p
then Nothing
else Just (Shape (p ^.. dim . traverse . size))
shapeToProto :: Maybe Shape -> TensorShapeProto
shapeToProto Nothing =
def & unknownRank .~ True
shapeToProto (Just (Shape ds)) =
def & dim .~ fmap (\d -> def & size .~ d) ds
class Attribute a where
@ -391,6 +413,9 @@ instance Attribute Bool where
instance Attribute Shape where
attrLens = shape . protoShape
instance Attribute (Maybe Shape) where
attrLens = shape . protoMaybeShape
-- TODO(gnezdo): support generating list(Foo) from [Foo].
instance Attribute AttrValue'ListValue where
attrLens = list