From a11a417ad52345978b4e1768c0c59f01476e1703 Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Thu, 23 Mar 2017 12:58:40 -0700 Subject: [PATCH] Add another test of CSE and feeds. (#87) As a follow-up to #86, check that our CSE isn't too aggressive to prevent feeds of pure ops with distinct names. --- tensorflow-ops/tests/OpsTest.hs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index 26adc39..a117015 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -79,10 +79,24 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 liftIO $ result @=? TF.Scalar 5 +-- | Test that regular tensors can also be used for feeds, as long as they each +-- have a different name. +testScalarFeedCse :: Test +testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do + p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0 + -- The second op is identical to the first other than its name; make sure + -- we don't aggressively CSE them together and prevent feeding them + -- separately. + p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0 + let enc :: Float -> TF.TensorData Float + enc n = TF.encodeTensorData [] (V.fromList [n]) + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + liftIO $ result @=? TF.Scalar 5 main :: IO () main = googleTest [ testSaveRestore , testSize , testReducedShape , testPlaceholderCse + , testScalarFeedCse ]