diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index 26adc39..a117015 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -79,10 +79,24 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 liftIO $ result @=? TF.Scalar 5 +-- | Test that regular tensors can also be used for feeds, as long as they each +-- have a different name. +testScalarFeedCse :: Test +testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do + p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0 + -- The second op is identical to the first other than its name; make sure + -- we don't aggressively CSE them together and prevent feeding them + -- separately. + p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0 + let enc :: Float -> TF.TensorData Float + enc n = TF.encodeTensorData [] (V.fromList [n]) + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + liftIO $ result @=? TF.Scalar 5 main :: IO () main = googleTest [ testSaveRestore , testSize , testReducedShape , testPlaceholderCse + , testScalarFeedCse ]