1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Add another test of CSE and feeds. (#87)

As a follow-up to #86, check that our CSE isn't too aggressive to prevent feeds
of pure ops with distinct names.
This commit is contained in:
Judah Jacobson 2017-03-23 12:58:40 -07:00 committed by fkm3
parent fdbfd050f8
commit a11a417ad5

View file

@ -79,10 +79,24 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
liftIO $ result @=? TF.Scalar 5
-- | Test that regular tensors can also be used for feeds, as long as they each
-- have a different name.
testScalarFeedCse :: Test
testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0
-- The second op is identical to the first other than its name; make sure
-- we don't aggressively CSE them together and prevent feeding them
-- separately.
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
liftIO $ result @=? TF.Scalar 5
main :: IO ()
main = googleTest [ testSaveRestore
, testSize
, testReducedShape
, testPlaceholderCse
, testScalarFeedCse
]