From a64af5076af0768c35218fabbecb6c5263ee25ab Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Mon, 8 May 2017 17:45:56 -0700 Subject: [PATCH] Work around #92 by always copying TensorData when fetching. It would be better to avoid the copy when it's not necessary, but that will require more involved changes to the internal API. (For example, Fetchable might need to allow IO or ST actions.) --- tensorflow-ops/tests/OpsTest.hs | 12 ++++++++++++ tensorflow/src/TensorFlow/Internal/FFI.hs | 15 +++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index de35828..4824d6f 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -94,10 +94,22 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do $ p1 `TF.add` p2 liftIO $ result @=? TF.Scalar 5 +-- | See https://github.com/tensorflow/haskell/issues/92. +-- Even though we're not explicitly evaluating `f0` until the end, +-- it should hold the earlier value of the variable. +testRereadRef :: Test +testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do + w <- TF.initializedVariable 0 + f0 <- TF.run w + TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float)) + f1 <- TF.run w + liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1) + main :: IO () main = googleTest [ testSaveRestore , testSize , testReducedShape , testPlaceholderCse , testScalarFeedCse + , testRereadRef ] diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs index 03efe4d..59a79b3 100644 --- a/tensorflow/src/TensorFlow/Internal/FFI.hs +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -42,7 +42,7 @@ import Data.Typeable (Typeable) import Data.Word (Word8) import Foreign (Ptr, FunPtr, nullPtr, castPtr) import Foreign.C.String (CString) -import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) +import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr) import Foreign.Marshal.Alloc (free) import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray) import System.IO.Unsafe (unsafePerformIO) @@ -51,7 +51,7 @@ import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding.Error as T import qualified Data.Vector.Storable as S -import qualified Foreign.Concurrent as ForeignC +import qualified Data.Vector.Storable.Mutable as M import Data.ProtoLens (Message, encodeMessage) import Proto.Tensorflow.Core.Framework.Graph (GraphDef) @@ -193,6 +193,10 @@ tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x -- | Create a TensorData from a Raw.Tensor. -- -- Takes ownership of the Raw.Tensor. +-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it), +-- since the raw pointer may refer to storage inside a mutable TensorFlow +-- variable. We should avoid that copy when it's not needed; for example, +-- by making TensorData wrap an IOVector, and changing the code that uses it. createTensorData :: Raw.Tensor -> IO TensorData createTensorData t = do -- Read dimensions. @@ -203,8 +207,11 @@ createTensorData t = do -- Read data. len <- safeConvert <$> Raw.tensorByteSize t bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8) - fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t) - let v = S.unsafeFromForeignPtr0 fp len + fp <- newForeignPtr_ bytes + -- Make an explicit copy of the raw data, since it might point + -- to a mutable variable's memory. + v <- S.freeze (M.unsafeFromForeignPtr0 fp len) + Raw.deleteTensor t return $ TensorData (map safeConvert dims) dtype v -- | Runs the given action which does FFI calls updating a provided