1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Prevent CSE of placeholder ops. (#86)

The bug was introduced in #84.
This commit is contained in:
Judah Jacobson 2017-03-22 22:47:42 -07:00 committed by fkm3
parent c99a23b6a7
commit fdbfd050f8
2 changed files with 20 additions and 2 deletions

View file

@ -203,9 +203,15 @@ matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder = placeholder' id placeholder = placeholder' id
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a) placeholder' :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape placeholder' params pShape
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape)) -- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
-- and thus would be CSE'd.
= build $ buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ pShape
& params
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.

View file

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Main where module Main where
@ -68,9 +69,20 @@ testSaveRestore = testCase "testSaveRestore" $
TF.run v TF.run v
liftIO $ TF.Scalar 134 @=? result liftIO $ TF.Scalar 134 @=? result
-- | Test that 'placeholder' is not CSE'd.
testPlaceholderCse :: Test
testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
p1 <- TF.placeholder []
p2 <- TF.placeholder []
let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
liftIO $ result @=? TF.Scalar 5
main :: IO () main :: IO ()
main = googleTest [ testSaveRestore main = googleTest [ testSaveRestore
, testSize , testSize
, testReducedShape , testReducedShape
, testPlaceholderCse
] ]