mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 03:05:01 +01:00
Fix initialized variables for tensorflow 1.7 (#184)
* Fix initialized variables for tensorflow 1.7 This is needed to support tensorflow 1.7. The trick of initializing a variable with `Shape []` and then overriding the shape by assigning an initial value no longer works. It seems that we need to explicitly flip the unknown_rank bit in the shape proto. I thought about switching opgen to use `Maybe Shape` when an op requires a shape attribute, but that will cause a lot of api churn, so I chose to hold off for now and just do a spot fix to unblock 1.7.
This commit is contained in:
parent
4c6306d914
commit
e35211d49b
3 changed files with 44 additions and 7 deletions
|
@ -61,12 +61,21 @@ variable = variable' id
|
|||
|
||||
variable' :: forall m a . (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Shape -> m (Variable a)
|
||||
variable' params s = build $ do
|
||||
variable' params s = variableInternal params (Just s)
|
||||
|
||||
variableInternal :: forall m a . (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Maybe Shape -> m (Variable a)
|
||||
variableInternal params s = build $ do
|
||||
-- Each variable needs a unique "shared_name". Use MonadFix to
|
||||
-- set the attribute to the same name as the variable itself, without
|
||||
-- exposing more internals of the Build module.
|
||||
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
||||
(tensorType (undefined :: a)) s
|
||||
rec let attrs = params . (opAttr "shared_name" .~ n) . (opAttr "shape" .~ s)
|
||||
dtype = tensorType (undefined :: a)
|
||||
-- Generated ops don't support unknown shapes. As a workaround, we
|
||||
-- pass in a rank zero shape and then override it using OpParams.
|
||||
-- TODO: Consider supporting this better in op generation.
|
||||
shape = Shape []
|
||||
t <- CoreOps.varHandleOp' attrs dtype shape
|
||||
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
||||
return $ Variable t Nothing
|
||||
|
||||
|
@ -80,7 +89,7 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
|||
=> OpParams -> Tensor v a -> m (Variable a)
|
||||
initializedVariable' params initializer = do
|
||||
-- The shape is not known initially.
|
||||
(Variable h Nothing :: Variable a) <- variable' params (Shape [])
|
||||
(Variable h Nothing :: Variable a) <- variableInternal params Nothing
|
||||
initializer' <- renderValue initializer
|
||||
i <- CoreOps.assignVariableOp h initializer'
|
||||
addInitializer =<< group i
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
{-# LANGUAGE OverloadedLists #-}
|
||||
module Main (main) where
|
||||
|
||||
import Data.Int (Int32)
|
||||
import Data.Maybe (isJust)
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
|
@ -55,6 +56,8 @@ testInitializedVariableShape =
|
|||
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
|
||||
result <- run (readValue vector)
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
s <- run (Ops.shape (readValue vector))
|
||||
liftIO $ [1] @=? (s :: V.Vector Int32)
|
||||
|
||||
testInitializedValue :: Test
|
||||
testInitializedValue =
|
||||
|
|
|
@ -67,13 +67,15 @@ import Data.Functor.Identity (Identity(..))
|
|||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.ProtoLens.TextFormat (showMessageShort)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Data.String (IsString)
|
||||
import Data.Word (Word8, Word16, Word64)
|
||||
import Foreign.Storable (Storable)
|
||||
import GHC.Exts (Constraint, IsList(..))
|
||||
import Lens.Family2 (Lens', view, (&), (.~))
|
||||
import Lens.Family2 (Lens', view, (&), (.~), (^..))
|
||||
import Lens.Family2.Unchecked (iso)
|
||||
import Text.Printf (printf)
|
||||
import qualified Data.Attoparsec.ByteString as Atto
|
||||
|
@ -113,6 +115,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape
|
|||
( TensorShapeProto(..)
|
||||
, dim
|
||||
, size
|
||||
, unknownRank
|
||||
)
|
||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
|
||||
|
@ -353,6 +356,9 @@ headFromSingleton x
|
|||
|
||||
|
||||
-- | Shape (dimensions) of a tensor.
|
||||
--
|
||||
-- TensorFlow supports shapes of unknown rank, which are represented as
|
||||
-- @Nothing :: Maybe Shape@ in Haskell.
|
||||
newtype Shape = Shape [Int64] deriving Show
|
||||
|
||||
instance IsList Shape where
|
||||
|
@ -363,8 +369,24 @@ instance IsList Shape where
|
|||
protoShape :: Lens' TensorShapeProto Shape
|
||||
protoShape = iso protoToShape shapeToProto
|
||||
where
|
||||
protoToShape = Shape . fmap (view size) . view dim
|
||||
shapeToProto (Shape ds) = (def :: TensorShapeProto) & dim .~ fmap (\d -> def & size .~ d) ds
|
||||
protoToShape p = fromMaybe (error msg) (view protoMaybeShape p)
|
||||
where msg = "Can't convert TensorShapeProto with unknown rank to Shape: "
|
||||
++ showMessageShort p
|
||||
shapeToProto s' = def & protoMaybeShape .~ Just s'
|
||||
|
||||
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
|
||||
protoMaybeShape = iso protoToShape shapeToProto
|
||||
where
|
||||
protoToShape :: TensorShapeProto -> Maybe Shape
|
||||
protoToShape p =
|
||||
if view unknownRank p
|
||||
then Nothing
|
||||
else Just (Shape (p ^.. dim . traverse . size))
|
||||
shapeToProto :: Maybe Shape -> TensorShapeProto
|
||||
shapeToProto Nothing =
|
||||
def & unknownRank .~ True
|
||||
shapeToProto (Just (Shape ds)) =
|
||||
def & dim .~ fmap (\d -> def & size .~ d) ds
|
||||
|
||||
|
||||
class Attribute a where
|
||||
|
@ -391,6 +413,9 @@ instance Attribute Bool where
|
|||
instance Attribute Shape where
|
||||
attrLens = shape . protoShape
|
||||
|
||||
instance Attribute (Maybe Shape) where
|
||||
attrLens = shape . protoMaybeShape
|
||||
|
||||
-- TODO(gnezdo): support generating list(Foo) from [Foo].
|
||||
instance Attribute AttrValue'ListValue where
|
||||
attrLens = list
|
||||
|
|
Loading…
Add table
Reference in a new issue