diff --git a/tensorflow-ops/src/TensorFlow/Variable.hs b/tensorflow-ops/src/TensorFlow/Variable.hs index 6c752db..ce86c55 100644 --- a/tensorflow-ops/src/TensorFlow/Variable.hs +++ b/tensorflow-ops/src/TensorFlow/Variable.hs @@ -61,12 +61,21 @@ variable = variable' id variable' :: forall m a . (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Variable a) -variable' params s = build $ do +variable' params s = variableInternal params (Just s) + +variableInternal :: forall m a . (MonadBuild m, TensorType a) + => OpParams -> Maybe Shape -> m (Variable a) +variableInternal params s = build $ do -- Each variable needs a unique "shared_name". Use MonadFix to -- set the attribute to the same name as the variable itself, without -- exposing more internals of the Build module. - rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n)) - (tensorType (undefined :: a)) s + rec let attrs = params . (opAttr "shared_name" .~ n) . (opAttr "shape" .~ s) + dtype = tensorType (undefined :: a) + -- Generated ops don't support unknown shapes. As a workaround, we + -- pass in a rank zero shape and then override it using OpParams. + -- TODO: Consider supporting this better in op generation. + shape = Shape [] + t <- CoreOps.varHandleOp' attrs dtype shape let n = encodeUtf8 $ unNodeName $ tensorNodeName t return $ Variable t Nothing @@ -80,7 +89,7 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a) => OpParams -> Tensor v a -> m (Variable a) initializedVariable' params initializer = do -- The shape is not known initially. - (Variable h Nothing :: Variable a) <- variable' params (Shape []) + (Variable h Nothing :: Variable a) <- variableInternal params Nothing initializer' <- renderValue initializer i <- CoreOps.assignVariableOp h initializer' addInitializer =<< group i diff --git a/tensorflow-ops/tests/VariableTest.hs b/tensorflow-ops/tests/VariableTest.hs index 0a8f4e9..6726af1 100644 --- a/tensorflow-ops/tests/VariableTest.hs +++ b/tensorflow-ops/tests/VariableTest.hs @@ -1,6 +1,7 @@ {-# LANGUAGE OverloadedLists #-} module Main (main) where +import Data.Int (Int32) import Data.Maybe (isJust) import Control.Monad (when) import Control.Monad.IO.Class (liftIO) @@ -55,6 +56,8 @@ testInitializedVariableShape = vector <- initializedVariable (Ops.constant [1] [42 :: Float]) result <- run (readValue vector) liftIO $ [42] @=? (result :: V.Vector Float) + s <- run (Ops.shape (readValue vector)) + liftIO $ [1] @=? (s :: V.Vector Int32) testInitializedValue :: Test testInitializedValue = diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index 8c82cce..5d8215c 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -67,13 +67,15 @@ import Data.Functor.Identity (Identity(..)) import Data.Complex (Complex) import Data.Default (def) import Data.Int (Int8, Int16, Int32, Int64) +import Data.Maybe (fromMaybe) import Data.Monoid ((<>)) +import Data.ProtoLens.TextFormat (showMessageShort) import Data.Proxy (Proxy(..)) import Data.String (IsString) import Data.Word (Word8, Word16, Word64) import Foreign.Storable (Storable) import GHC.Exts (Constraint, IsList(..)) -import Lens.Family2 (Lens', view, (&), (.~)) +import Lens.Family2 (Lens', view, (&), (.~), (^..)) import Lens.Family2.Unchecked (iso) import Text.Printf (printf) import qualified Data.Attoparsec.ByteString as Atto @@ -113,6 +115,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape ( TensorShapeProto(..) , dim , size + , unknownRank ) import Proto.Tensorflow.Core.Framework.Types (DataType(..)) @@ -353,6 +356,9 @@ headFromSingleton x -- | Shape (dimensions) of a tensor. +-- +-- TensorFlow supports shapes of unknown rank, which are represented as +-- @Nothing :: Maybe Shape@ in Haskell. newtype Shape = Shape [Int64] deriving Show instance IsList Shape where @@ -363,8 +369,24 @@ instance IsList Shape where protoShape :: Lens' TensorShapeProto Shape protoShape = iso protoToShape shapeToProto where - protoToShape = Shape . fmap (view size) . view dim - shapeToProto (Shape ds) = (def :: TensorShapeProto) & dim .~ fmap (\d -> def & size .~ d) ds + protoToShape p = fromMaybe (error msg) (view protoMaybeShape p) + where msg = "Can't convert TensorShapeProto with unknown rank to Shape: " + ++ showMessageShort p + shapeToProto s' = def & protoMaybeShape .~ Just s' + +protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape) +protoMaybeShape = iso protoToShape shapeToProto + where + protoToShape :: TensorShapeProto -> Maybe Shape + protoToShape p = + if view unknownRank p + then Nothing + else Just (Shape (p ^.. dim . traverse . size)) + shapeToProto :: Maybe Shape -> TensorShapeProto + shapeToProto Nothing = + def & unknownRank .~ True + shapeToProto (Just (Shape ds)) = + def & dim .~ fmap (\d -> def & size .~ d) ds class Attribute a where @@ -391,6 +413,9 @@ instance Attribute Bool where instance Attribute Shape where attrLens = shape . protoShape +instance Attribute (Maybe Shape) where + attrLens = shape . protoMaybeShape + -- TODO(gnezdo): support generating list(Foo) from [Foo]. instance Attribute AttrValue'ListValue where attrLens = list