From c99a23b6a746949136ae193d3982de5ee2a4c414 Mon Sep 17 00:00:00 2001 From: Judah Jacobson Date: Mon, 20 Mar 2017 18:16:38 -0700 Subject: [PATCH] Add versions of each op that take optional params as an extra arg. (#84) Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...` which lets you set optional attributes. `OpParams` is currently a type alias for `OpDef -> OpDef`. In the future we should consider more type safety, e.g., using type-level strings and OverloadedLabels for optional attributes. I used it to replace a few manual `buildOp`s in our code with the codegenerated ops, now that it's easier to set attributes. I also removed `tensorAttr` and `named` since it's now possible to set those op attributes directly. Although this clutters up the API a bit, I think it's simpler than using type classes to implement optional arguments (as in, for example, `Text.Printf`) -- especially in terms of type inference with the rest of the library. --- tensorflow-opgen/src/TensorFlow/OpGen.hs | 33 +++-- tensorflow-ops/src/TensorFlow/Gradient.hs | 53 ++++---- tensorflow-ops/src/TensorFlow/Ops.hs | 152 ++++++++++++++++------ tensorflow-ops/tests/BuildTest.hs | 37 ++---- tensorflow-ops/tests/OpsTest.hs | 5 +- tensorflow-ops/tests/TypesTest.hs | 1 - tensorflow/src/TensorFlow/BuildOp.hs | 5 + tensorflow/src/TensorFlow/ControlFlow.hs | 33 +---- tensorflow/src/TensorFlow/Core.hs | 6 +- tensorflow/src/TensorFlow/Output.hs | 3 + tensorflow/src/TensorFlow/Tensor.hs | 14 +- 11 files changed, 190 insertions(+), 152 deletions(-) diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index bf2cfe8..2a56c7d 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -12,6 +12,7 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} @@ -172,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName renderOp :: ParsedOp -> Doc renderOp pOp = stack $ [ haddocks - , n <+> "::" <+> hang 0 (typeSig pOp) - , n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs + -- Prevent unreasonably long compilation times on ghc-7.10, due + -- to stack calling "-dump-hi" which (unnecessarily) includes the + -- inlining information, and is large for ops with many arguments. +#if __GLASGOW_HASKELL__ < 800 + , "{-# NOINLINE " <> n <> "#-}" +#endif + , n <+> "::" <+> hang 0 (typeSig empty pOp) + , n <+> "=" <+> n <> "' id" + , n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp) + , n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs <+> "=" -- args are indented -- the body needs to be indented wrt the name indent indentation (functionBody pOp) ] ++ whereClause listSizeAttrs where n = renderHaskellName $ parsedOpName pOp + n' = n <> "'" listSizeAttrs = inferredListSizeAttrs pOp - args = sep $ map renderHaskellName - $ map attrName (explicitInputAttrs pOp) - ++ map parsedArgName (parsedInputs pOp) + args = sep $ "op'options" + : (map renderHaskellName + $ map attrName (explicitInputAttrs pOp) + ++ map parsedArgName (parsedInputs pOp)) haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp) -- | A check that all lists of the given size have the given length. @@ -247,7 +258,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp -- Renders sizes of tensor list types having number_attr. [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n | a <- inferredListSizeAttrs pOp, let n = attrName a - ] + ] ++ + ["& op'options"] + tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp inferredTypeExpr a @@ -270,12 +283,12 @@ extras d = enclose "{-\n" "\n-}" $ -- | The type signature for an op. -- Of the form: -- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2) --- => Float -> Tensor t1 v1 -> Tensor t2 v2 +-- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2 -- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and -- "Tensor t2 v2" is an output. -typeSig :: ParsedOp -> Doc -typeSig pOp = constraints - <+/> signatureFold (map attrInput (explicitInputAttrs pOp) +typeSig :: Doc -> ParsedOp -> Doc +typeSig pre pOp = constraints + <+/> pre signatureFold (map attrInput (explicitInputAttrs pOp) ++ map tensorArgAndComment (parsedInputs pOp) ++ [outputs]) where diff --git a/tensorflow-ops/src/TensorFlow/Gradient.hs b/tensorflow-ops/src/TensorFlow/Gradient.hs index f58d4ad..099a196 100644 --- a/tensorflow-ops/src/TensorFlow/Gradient.hs +++ b/tensorflow-ops/src/TensorFlow/Gradient.hs @@ -72,6 +72,7 @@ import TensorFlow.Ops , expandDims , fill , matMul + , matMul' , reducedShape , reluGrad , reshape @@ -95,7 +96,6 @@ import TensorFlow.Tensor , TensorKind (ValueKind) , Value , tensorOutput - , tensorAttr ) import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens) import Proto.Tensorflow.Core.Framework.NodeDef @@ -532,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] = let transposeA = lookupAttr nodeDef "transpose_a" transposeB = lookupAttr nodeDef "transpose_b" transAttrs a b = - (tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b) + (opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b) in case (transposeA, transposeB) of (False, False) -> - [ Just $ (dz `matMul` y) & transAttrs False True - , Just $ (x `matMul` dz) & transAttrs True False ] + [ Just $ matMul' (transAttrs False True) dz y + , Just $ matMul' (transAttrs True False) x dz] (False, True) -> - [ Just $ dz `matMul` y - , Just $ (x `matMul` dz) & transAttrs True False ] + [ Just $ matMul dz y + , Just $ matMul' (transAttrs True False) x dz] (True, False) -> - [ Just $ (dz `matMul` y) & transAttrs False True - , Just $ x `matMul` dz ] + [ Just $ matMul' (transAttrs False True) dz y + , Just $ matMul x dz] (True, True) -> - [ Just $ (dz `matMul` y) & transAttrs True True - , Just $ (x `matMul` dz) & transAttrs True True ] + [ Just $ matMul' (transAttrs True True) dz y + , Just $ matMul' (transAttrs True True) x dz] opGrad "Transpose" _ [_, toT -> p] [dz] = [ Just $ CoreOps.transpose dz @@ -554,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] = ] opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] = - [ Just $ CoreOps.conv2DBackpropInput (shape x) y dz - & tensorAttr "strides" .~ strides - & tensorAttr "padding" .~ padding - & tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu - & tensorAttr "data_format" .~ dataFormat - , Just $ CoreOps.conv2DBackpropFilter x (shape y) dz - & tensorAttr "strides" .~ strides - & tensorAttr "padding" .~ padding - & tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu - & tensorAttr "data_format" .~ dataFormat + [ Just $ CoreOps.conv2DBackpropInput' + ((opAttr "strides" .~ strides) + . (opAttr "padding" .~ padding) + . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) + . (opAttr "data_format" .~ dataFormat)) + (shape x) y dz + , Just $ CoreOps.conv2DBackpropFilter' + ((opAttr "strides" .~ strides) + . (opAttr "padding" .~ padding) + . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) + . (opAttr "data_format" .~ dataFormat)) + x (shape y) dz ] where strides = lookupAttr nodeDef "strides" :: [Int64] @@ -572,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] = dataFormat = lookupAttr nodeDef "data_format" :: ByteString opGrad "MaxPool" nodeDef [toT -> x] [dz] = - [ Just $ CoreOps.maxPoolGrad x output dz - & tensorAttr "ksize" .~ ksize - & tensorAttr "strides" .~ strides - & tensorAttr "padding" .~ padding - & tensorAttr "data_format" .~ dataFormat + [ Just $ CoreOps.maxPoolGrad' + ((opAttr "ksize" .~ ksize) + . (opAttr "strides" .~ strides) + . (opAttr "padding" .~ padding) + . (opAttr "data_format" .~ dataFormat)) + x output dz ] where output :: Tensor Value a diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 086430d..825bc64 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -58,56 +58,99 @@ module TensorFlow.Ops ( CoreOps.add + , CoreOps.add' , CoreOps.abs + , CoreOps.abs' , CoreOps.addN + , CoreOps.addN' , CoreOps.argMax + , CoreOps.argMax' , CoreOps.assign + , CoreOps.assign' , CoreOps.broadcastGradientArgs + , CoreOps.broadcastGradientArgs' , CoreOps.cast + , CoreOps.cast' , CoreOps.concat + , CoreOps.concat' , constant + , constant' , CoreOps.equal + , CoreOps.equal' , expandDims + , expandDims' , initializedVariable + , initializedVariable' , zeroInitializedVariable + , zeroInitializedVariable' , CoreOps.fill - , CoreOps.oneHot + , CoreOps.fill' + , CoreOps.identity + , CoreOps.identity' , CoreOps.matMul + , CoreOps.matMul' , matTranspose + , matTranspose' , CoreOps.mean + , CoreOps.mean' , CoreOps.mul + , CoreOps.mul' , CoreOps.neg + , CoreOps.neg' + , CoreOps.oneHot + , CoreOps.oneHot' , CoreOps.pack + , CoreOps.pack' , placeholder + , placeholder' , CoreOps.range + , CoreOps.range' , reducedShape , CoreOps.relu + , CoreOps.relu' , CoreOps.reluGrad + , CoreOps.reluGrad' , CoreOps.reshape + , CoreOps.reshape' , restore , restoreFromName , save , scalar + , scalar' , shape + , shape' , CoreOps.sign + , CoreOps.sign' , CoreOps.size + , CoreOps.size' , CoreOps.softmax + , CoreOps.softmax' , CoreOps.softmaxCrossEntropyWithLogits + , CoreOps.softmaxCrossEntropyWithLogits' , CoreOps.sparseToDense + , CoreOps.sparseToDense' , CoreOps.sub + , CoreOps.sub' , CoreOps.sum + , CoreOps.sum' , CoreOps.transpose + , CoreOps.transpose' , truncatedNormal + , truncatedNormal' , CoreOps.variable + , CoreOps.variable' , vector + , vector' , zeros , CoreOps.zerosLike + , CoreOps.zerosLike' , scalarize ) where import Data.ByteString (ByteString) import Data.Complex (Complex) import Data.Int (Int32, Int64) +import Data.Word (Word16) import Prelude hiding (abs, sum, concat) import Data.ProtoLens (def) import Data.Text.Encoding (encodeUtf8) @@ -151,29 +194,31 @@ instance ( TensorType a signum = CoreOps.sign negate = CoreOps.neg -matTranspose :: forall a v . TensorType a - => Tensor v a -> Tensor Value a -matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32]) +matTranspose :: TensorType a => Tensor v a -> Tensor Value a +matTranspose = matTranspose' id -placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) -placeholder shape' = - build $ buildOp $ opDef "Placeholder" - & opAttr "dtype" .~ tensorType (undefined :: a) - & opAttr "shape" .~ shape' +matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a +matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32]) + +placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) +placeholder = placeholder' id + +placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a) +placeholder' params pShape + = render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape)) -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. -initializedVariable :: forall a m . (MonadBuild m, TensorType a) +initializedVariable :: (MonadBuild m, TensorType a) => Tensor Value a -> m (Tensor Ref a) -initializedVariable initializer = do - v <- CoreOps.variable [] -- The shape is not known initially. - (i :: Tensor Ref a) <- - build $ buildOp (opDef "Assign" - & opAttr "T" .~ tensorType (undefined :: a) - & opAttr "use_locking" .~ True - & opAttr "validate_shape" .~ False - ) - v initializer +initializedVariable = initializedVariable' id + +initializedVariable' :: (MonadBuild m, TensorType a) + => OpParams -> Tensor Value a -> m (Tensor Ref a) +initializedVariable' params initializer = do + v <- CoreOps.variable' params [] -- The shape is not known initially. + i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v + initializer addInitializer =<< group i return v @@ -181,7 +226,12 @@ initializedVariable initializer = do zeroInitializedVariable :: (MonadBuild m, TensorType a, Num a) => TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a) -zeroInitializedVariable = initializedVariable . zeros +zeroInitializedVariable = zeroInitializedVariable' id + +zeroInitializedVariable' + :: (MonadBuild m, TensorType a, Num a) => + OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a) +zeroInitializedVariable' params = initializedVariable' params . zeros -- TODO: Support heterogeneous list of tensors. save :: forall a m v . (MonadBuild m, TensorType a) @@ -227,27 +277,27 @@ restore path x = do -- element 0: index (0, ..., 0) -- element 1: index (0, ..., 1) -- ... -constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a -constant (Shape shape') values +constant :: TensorType a => Shape -> [a] -> Tensor Value a +constant = constant' id + +constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a +constant' params (Shape cShape) values | invalidLength = error invalidLengthMsg - | otherwise = buildOp $ opDef "Const" - & opAttr "value" .~ typedNode - & opAttr "dtype" .~ nodeType + | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode)) where - invalidLength = product shape' /= fromIntegral (length values) + invalidLength = product cShape /= fromIntegral (length values) invalidLengthMsg = printf "invalid tensor length: expected %d got %d" - (product shape') + (product cShape) (length values) - nodeType = tensorType (undefined :: a) typedNode :: TensorProto typedNode = def - & dtype .~ nodeType + & dtype .~ tensorType (undefined :: a) & tensorShape.TensorShape.dim .~ - [def & TensorShape.size .~ x | x <- shape'] + [def & TensorShape.size .~ x | x <- cShape] & tensorVal .~ values -- | Reshape a N-D tensor down to a scalar. --- +-- -- See `TensorFlow.GenOps.Core.reshape`. scalarize :: (TensorType a) => Tensor v a -> Tensor Value a scalarize t = CoreOps.reshape t (vector scalarShape) @@ -257,30 +307,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape) -- | Create a constant vector. vector :: TensorType a => [a] -> Tensor Value a -vector xs = constant [fromIntegral $ length xs] xs +vector = vector' id + +vector' :: TensorType a => OpParams -> [a] -> Tensor Value a +vector' params xs = constant' params [fromIntegral $ length xs] xs -- | Create a constant scalar. -scalar :: forall a . TensorType a => a -> Tensor Value a -scalar x = constant [] [x] +scalar :: TensorType a => a -> Tensor Value a +scalar = scalar' id --- Random tensor from the unit normal distribution with bounded values. -truncatedNormal :: forall a m v . (MonadBuild m, TensorType a) +scalar' :: TensorType a => OpParams -> a -> Tensor Value a +scalar' params x = constant' params [] [x] + +-- | Random tensor from the unit normal distribution with bounded values. +-- +-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'. +truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a) => Tensor v Int64 -- ^ Shape. -> m (Tensor Value a) -truncatedNormal - = build . buildOp (opDef "TruncatedNormal" - & opAttr "dtype" .~ tensorType (undefined :: a) - & opAttr "T" .~ tensorType (undefined :: Int64)) +truncatedNormal = CoreOps.truncatedNormal + +truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a) + => OpParams -> Tensor v Int64 -- ^ Shape. + -> m (Tensor Value a) +truncatedNormal' = CoreOps.truncatedNormal' zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a -zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) +zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0) -shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32 +shape :: TensorType t => Tensor v1 t -> Tensor Value Int32 shape = CoreOps.shape -expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32 +shape' = CoreOps.shape' + +expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t expandDims = CoreOps.expandDims +expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t +expandDims' = CoreOps.expandDims' + -- | Helper function for reduction ops (translation of math_ops.reduced_shape). reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) => Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32 diff --git a/tensorflow-ops/tests/BuildTest.hs b/tensorflow-ops/tests/BuildTest.hs index c75cf09..9410dab 100644 --- a/tensorflow-ops/tests/BuildTest.hs +++ b/tensorflow-ops/tests/BuildTest.hs @@ -19,7 +19,7 @@ module Main where import Control.Monad.IO.Class (liftIO) -import Lens.Family2 ((^.)) +import Lens.Family2 ((^.), (.~)) import Data.List (sort) import Proto.Tensorflow.Core.Framework.Graph ( node ) @@ -38,8 +38,8 @@ import TensorFlow.Build , withDevice , colocateWith , withNameScope + , opName ) -import TensorFlow.ControlFlow (named) import TensorFlow.Types (unScalar) import TensorFlow.Ops ( add @@ -47,6 +47,7 @@ import TensorFlow.Ops , constant , initializedVariable , variable + , variable' ) import TensorFlow.Output (Device(..)) import TensorFlow.Tensor (Tensor, Value, Ref) @@ -61,26 +62,16 @@ import Test.HUnit ((@=?)) import Google.Test (googleTest) import qualified Data.Vector as V --- | Test named behavior. -testNamed :: Test -testNamed = testCase "testNamed" $ do - let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float) +-- | Test 'opName' behavior. +testOpName :: Test +testOpName = testCase "testOpName" $ do + let graph = variable' (opName .~ "foo") [] + >>= render :: Build (Tensor Ref Float) nodeDef :: NodeDef nodeDef = head $ asGraphDef graph ^. node - "RefIdentity" @=? (nodeDef ^. op) + "Variable" @=? (nodeDef ^. op) "foo" @=? (nodeDef ^. name) --- | Test named deRef behavior. -testNamedDeRef :: Test -testNamedDeRef = testCase "testNamedDeRef" $ do - let graph = named "foo" <$> do - v :: Tensor Ref Float <- variable [] - assign v 5 - -- TODO: Implement TensorFlow get_variable and test it. - runSession $ do - out <- graph >>= run - liftIO $ 5 @=? (unScalar out :: Float) - -- | Test that "run" will render and extend any pure ops that haven't already -- been rendered. testPureRender :: Test @@ -118,14 +109,15 @@ testNameScoped = testCase "testNameScoped" $ do "foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix. "Variable" @=? (nodeDef ^. op) --- | Test combined named and nameScoped behavior. +-- | Test combined opName and nameScoped behavior. testNamedAndScoped :: Test testNamedAndScoped = testCase "testNamedAndScoped" $ do let graph :: Build (Tensor Ref Float) - graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render) + graph = withNameScope "foo1" (variable' (opName .~ "bar1") []) + >>= render nodeDef :: NodeDef nodeDef = head $ asGraphDef graph ^. node - "RefIdentity" @=? (nodeDef ^. op) + "Variable" @=? (nodeDef ^. op) "foo1/bar1" @=? (nodeDef ^. name) -- | Flush the node buffer and sort the nodes by name (for more stable tests). @@ -174,8 +166,7 @@ main :: IO () main = googleTest [ testInitializedVariable , testInitializedVariableShape , testDeviceColocation - , testNamed - , testNamedDeRef + , testOpName , testNameScoped , testNamedAndScoped , testPureRender diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index 20f796e..53b14cb 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -19,6 +19,7 @@ module Main where import Control.Monad.IO.Class (liftIO) import Data.Int (Int32, Int64) import Google.Test (googleTest) +import Lens.Family2 ((.~)) import System.IO.Temp (withSystemTempDirectory) import Test.Framework (Test) import Test.Framework.Providers.HUnit (testCase) @@ -27,7 +28,6 @@ import qualified Data.ByteString.Char8 as B8 import qualified Data.Vector as V import qualified TensorFlow.Build as TF -import qualified TensorFlow.ControlFlow as TF import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Ops as TF import qualified TensorFlow.Session as TF @@ -56,7 +56,8 @@ testSaveRestore = testCase "testSaveRestore" $ let path = B8.pack $ dirPath ++ "/checkpoint" var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float) var = TF.render =<< - TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape []) + TF.zeroInitializedVariable' (TF.opName .~ "a") + (TF.Shape []) TF.runSession $ do v <- var TF.assign v 134 >>= TF.run_ diff --git a/tensorflow-ops/tests/TypesTest.hs b/tensorflow-ops/tests/TypesTest.hs index 2610231..567d232 100644 --- a/tensorflow-ops/tests/TypesTest.hs +++ b/tensorflow-ops/tests/TypesTest.hs @@ -36,7 +36,6 @@ import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as B8 import qualified Data.Vector as V -import qualified TensorFlow.ControlFlow as TF import qualified TensorFlow.GenOps.Core as TF (select) import qualified TensorFlow.Ops as TF import qualified TensorFlow.Session as TF diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 3411e24..ad625d1 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -23,6 +23,7 @@ module TensorFlow.BuildOp , buildOp , buildListOp , eqLengthGuard + , OpParams ) where @@ -238,3 +239,7 @@ eqLengthGuard = all eachOk eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs || error ("number_attr " ++ numberAttrName ++ " contains tensors with different length " ++ show pairs) + +-- | Parameters to build an op (for example, the node name or optional attributes). +-- TODO: be more type safe. +type OpParams = OpDef -> OpDef diff --git a/tensorflow/src/TensorFlow/ControlFlow.hs b/tensorflow/src/TensorFlow/ControlFlow.hs index 2a57b22..2f20b92 100644 --- a/tensorflow/src/TensorFlow/ControlFlow.hs +++ b/tensorflow/src/TensorFlow/ControlFlow.hs @@ -22,21 +22,15 @@ module TensorFlow.ControlFlow withControlDependencies , group -- * Operations - , identity , noOp - , named ) where import qualified Data.Set as Set -import Data.Text (Text) -import Lens.Family2 ((&), (^.), (.~)) +import Lens.Family2 ((&), (.~)) import TensorFlow.BuildOp import TensorFlow.Build import TensorFlow.Nodes -import TensorFlow.Output -import TensorFlow.Tensor -import TensorFlow.Types -- | Modify a 'Build' action, such that all new ops rendered in it will depend -- on the nodes in the first argument. @@ -57,31 +51,6 @@ group deps = do -- TODO: slicker way return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes - --- | Returns a 'Tensor' with the same shape and contents as the input. -identity :: TensorType a => Tensor v a -> Tensor v a -identity = namedIdentity implicitName - --- | Returns a 'Tensor' with a given name and the same shape and contents as --- the input. --- --- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s, --- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into --- whether we can change that op. -named :: TensorType a => Text -> Tensor v a -> Tensor v a -named = namedIdentity . explicitName - --- | An internal version of "identity" that allows setting the name --- of the output Tensor. -namedIdentity :: forall a v . TensorType a - => PendingNodeName -> Tensor v a -> Tensor v a -namedIdentity n t = case t ^. tensorKind of - ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t - RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t - where - setTypeAttr = opAttr "T" .~ tensorType (undefined :: a) - - -- | Does nothing. Only useful as a placeholder for control edges. noOp :: ControlNode noOp = buildOp $ opDef "NoOp" diff --git a/tensorflow/src/TensorFlow/Core.hs b/tensorflow/src/TensorFlow/Core.hs index 67781aa..7764e37 100644 --- a/tensorflow/src/TensorFlow/Core.hs +++ b/tensorflow/src/TensorFlow/Core.hs @@ -50,14 +50,14 @@ module TensorFlow.Core , render , asGraphDef , addGraphDef - + , opName + , opAttr -- * Tensor , ControlNode , Tensor , Value , Ref , TensorKind(..) - , tensorAttr , value , tensorFromName -- ** Element types @@ -74,12 +74,10 @@ module TensorFlow.Core , Device(..) , withDevice , withNameScope - , named -- ** Dependencies , withControlDependencies , group -- ** Misc - , identity , noOp ) where diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 9edd720..9ee31c8 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -124,6 +124,9 @@ data OpDef = OpDef data PendingNodeName = ExplicitName !Text | ImplicitName deriving (Eq, Ord, Show) +instance IsString PendingNodeName where + fromString = ExplicitName . fromString + -- | The name of a node in the graph. This corresponds to the proto field -- NodeDef.name. Includes the scope prefix (if any) and a unique identifier -- (if the node was implicitly named). diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs index 7d6ca4f..e353eb6 100644 --- a/tensorflow/src/TensorFlow/Tensor.hs +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -27,13 +27,12 @@ module TensorFlow.Tensor where import Data.String (IsString(..)) import qualified Data.Text as Text -import Lens.Family2 (Lens', Traversal', (^.)) +import Lens.Family2 (Lens', (^.)) import Lens.Family2.Unchecked (lens) -import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr) +import TensorFlow.Output (Output) import TensorFlow.Types ( TensorData(..) - , Attribute , ListOf(..) ) import qualified TensorFlow.Internal.FFI as FFI @@ -62,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o) tensorOutput :: Lens' (Tensor v a) Output tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o) --- TODO: Come up with a better API for handling attributes. --- | Lens for the attributes of a tensor. --- --- Only valid if the tensor has not yet been rendered. If the tensor has been --- rendered, the traversal will be over nothing (nothing can be read or --- written). -tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr -tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x - -- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a -- Ref into Value. This behaves like a no-op. value :: Tensor v a -> Tensor Value a