1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-24 02:29:46 +01:00

Work around #92 by always copying TensorData when fetching.

It would be better to avoid the copy when it's not necessary, but
that will require more involved changes to the internal API.  (For example,
Fetchable might need to allow IO or ST actions.)
This commit is contained in:
Judah Jacobson 2017-05-08 17:45:56 -07:00 committed by fkm3
parent 37e3c9b084
commit a64af5076a
2 changed files with 23 additions and 4 deletions

View file

@ -94,10 +94,22 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
$ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5
-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.
testRereadRef :: Test
testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
w <- TF.initializedVariable 0
f0 <- TF.run w
TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float))
f1 <- TF.run w
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
main :: IO ()
main = googleTest [ testSaveRestore
, testSize
, testReducedShape
, testPlaceholderCse
, testScalarFeedCse
, testRereadRef
]

View file

@ -42,7 +42,7 @@ import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
@ -51,7 +51,7 @@ import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Foreign.Concurrent as ForeignC
import qualified Data.Vector.Storable.Mutable as M
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
@ -193,6 +193,10 @@ tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor.
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
-- since the raw pointer may refer to storage inside a mutable TensorFlow
-- variable. We should avoid that copy when it's not needed; for example,
-- by making TensorData wrap an IOVector, and changing the code that uses it.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do
-- Read dimensions.
@ -203,8 +207,11 @@ createTensorData t = do
-- Read data.
len <- safeConvert <$> Raw.tensorByteSize t
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
let v = S.unsafeFromForeignPtr0 fp len
fp <- newForeignPtr_ bytes
-- Make an explicit copy of the raw data, since it might point
-- to a mutable variable's memory.
v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
Raw.deleteTensor t
return $ TensorData (map safeConvert dims) dtype v
-- | Runs the given action which does FFI calls updating a provided