From 16d660c3bc39709b8a3afb12ed457d633bdd348e Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Thu, 6 Apr 2017 19:00:18 -0700 Subject: [PATCH] Support a couple more ops by allowing larger tuples. (#93) --- tensorflow-core-ops/Setup.hs | 3 -- tensorflow/src/TensorFlow/BuildOp.hs | 76 ++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/tensorflow-core-ops/Setup.hs b/tensorflow-core-ops/Setup.hs index 79eba2a..449e6d0 100644 --- a/tensorflow-core-ops/Setup.hs +++ b/tensorflow-core-ops/Setup.hs @@ -66,7 +66,4 @@ generatingOpsWrappers = hooks blackList = [ -- Requires the "func" type: "SymbolicGradient" - -- Easy: support larger result tuples. - , "ParseSingleSequenceExample" - , "Skipgram" ] diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 4a6907e..9b074ce 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -81,6 +81,44 @@ instance ( BuildResult a1 <*> buildResult <*> buildResult +instance ( BuildResult a1 + , BuildResult a2 + , BuildResult a3 + , BuildResult a4 + , BuildResult a5 + , BuildResult a6 + , BuildResult a7 + ) + => BuildResult (a1, a2, a3, a4, a5, a6, a7) where + buildResult = (,,,,,,) + <$> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + +instance ( BuildResult a1 + , BuildResult a2 + , BuildResult a3 + , BuildResult a4 + , BuildResult a5 + , BuildResult a6 + , BuildResult a7 + , BuildResult a8 + ) + => BuildResult (a1, a2, a3, a4, a5, a6, a7, a8) where + buildResult = (,,,,,,,) + <$> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + <*> buildResult + recordResult :: Result Output recordResult = do o <- ask @@ -184,6 +222,44 @@ instance ( PureResult a1 <*> pureResult <*> pureResult +instance ( PureResult a1 + , PureResult a2 + , PureResult a3 + , PureResult a4 + , PureResult a5 + , PureResult a6 + , PureResult a7 + ) + => PureResult (a1, a2, a3, a4, a5, a6, a7) where + pureResult = (,,,,,,) + <$> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + +instance ( PureResult a1 + , PureResult a2 + , PureResult a3 + , PureResult a4 + , PureResult a5 + , PureResult a6 + , PureResult a7 + , PureResult a8 + ) + => PureResult (a1, a2, a3, a4, a5, a6, a7, a8) where + pureResult = (,,,,,,,) + <$> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + <*> pureResult + instance PureResult a => PureResult [a] where pureResult = do ResultState i ns <- get