Add versions of each op that take optional params as an extra arg. (#84)

Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...`
which lets you set optional attributes.  `OpParams` is currently a type alias for
`OpDef -> OpDef`.  In the future we should consider more type safety, e.g.,
using type-level strings and OverloadedLabels for optional attributes.

I used it to replace a few manual `buildOp`s in our code with the codegenerated
ops, now that it's easier to set attributes.  I also removed `tensorAttr` and
`named` since it's now possible to set those op attributes directly.

Although this clutters up the API a bit, I think it's simpler than using type
classes to implement optional arguments (as in, for example, `Text.Printf`) --
especially in terms of type inference with the rest of the library.
This commit is contained in:
Judah Jacobson 2017-03-20 18:16:38 -07:00 committed by GitHub
parent 2c5c879037
commit c99a23b6a7
11 changed files with 190 additions and 152 deletions

View File

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and -- See the License for the specific language governing permissions and
-- limitations under the License. -- limitations under the License.
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
@ -172,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName
renderOp :: ParsedOp -> Doc renderOp :: ParsedOp -> Doc
renderOp pOp = stack $ renderOp pOp = stack $
[ haddocks [ haddocks
, n <+> "::" <+> hang 0 (typeSig pOp) -- Prevent unreasonably long compilation times on ghc-7.10, due
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs -- to stack calling "-dump-hi" which (unnecessarily) includes the
-- inlining information, and is large for ops with many arguments.
#if __GLASGOW_HASKELL__ < 800
, "{-# NOINLINE " <> n <> "#-}"
#endif
, n <+> "::" <+> hang 0 (typeSig empty pOp)
, n <+> "=" <+> n <> "' id"
, n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
, n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
<+> "=" </> -- args are indented <+> "=" </> -- args are indented
-- the body needs to be indented wrt the name -- the body needs to be indented wrt the name
indent indentation (functionBody pOp) indent indentation (functionBody pOp)
] ++ whereClause listSizeAttrs ] ++ whereClause listSizeAttrs
where where
n = renderHaskellName $ parsedOpName pOp n = renderHaskellName $ parsedOpName pOp
n' = n <> "'"
listSizeAttrs = inferredListSizeAttrs pOp listSizeAttrs = inferredListSizeAttrs pOp
args = sep $ map renderHaskellName args = sep $ "op'options"
$ map attrName (explicitInputAttrs pOp) : (map renderHaskellName
++ map parsedArgName (parsedInputs pOp) $ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp))
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp) haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
-- | A check that all lists of the given size have the given length. -- | A check that all lists of the given size have the given length.
@ -247,7 +258,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
-- Renders sizes of tensor list types having number_attr. -- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs pOp, let n = attrName a | a <- inferredListSizeAttrs pOp, let n = attrName a
] ] ++
["& op'options"]
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a inferredTypeExpr a
@ -270,12 +283,12 @@ extras d = enclose "{-\n" "\n-}" $
-- | The type signature for an op. -- | The type signature for an op.
-- Of the form: -- Of the form:
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2) -- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
-- => Float -> Tensor t1 v1 -> Tensor t2 v2 -- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and -- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
-- "Tensor t2 v2" is an output. -- "Tensor t2 v2" is an output.
typeSig :: ParsedOp -> Doc typeSig :: Doc -> ParsedOp -> Doc
typeSig pOp = constraints typeSig pre pOp = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs pOp) <+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
++ map tensorArgAndComment (parsedInputs pOp) ++ map tensorArgAndComment (parsedInputs pOp)
++ [outputs]) ++ [outputs])
where where

View File

@ -72,6 +72,7 @@ import TensorFlow.Ops
, expandDims , expandDims
, fill , fill
, matMul , matMul
, matMul'
, reducedShape , reducedShape
, reluGrad , reluGrad
, reshape , reshape
@ -95,7 +96,6 @@ import TensorFlow.Tensor
, TensorKind (ValueKind) , TensorKind (ValueKind)
, Value , Value
, tensorOutput , tensorOutput
, tensorAttr
) )
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens) import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef import Proto.Tensorflow.Core.Framework.NodeDef
@ -532,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
let transposeA = lookupAttr nodeDef "transpose_a" let transposeA = lookupAttr nodeDef "transpose_a"
transposeB = lookupAttr nodeDef "transpose_b" transposeB = lookupAttr nodeDef "transpose_b"
transAttrs a b = transAttrs a b =
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b) (opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
in case (transposeA, transposeB) of in case (transposeA, transposeB) of
(False, False) -> (False, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True [ Just $ matMul' (transAttrs False True) dz y
, Just $ (x `matMul` dz) & transAttrs True False ] , Just $ matMul' (transAttrs True False) x dz]
(False, True) -> (False, True) ->
[ Just $ dz `matMul` y [ Just $ matMul dz y
, Just $ (x `matMul` dz) & transAttrs True False ] , Just $ matMul' (transAttrs True False) x dz]
(True, False) -> (True, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True [ Just $ matMul' (transAttrs False True) dz y
, Just $ x `matMul` dz ] , Just $ matMul x dz]
(True, True) -> (True, True) ->
[ Just $ (dz `matMul` y) & transAttrs True True [ Just $ matMul' (transAttrs True True) dz y
, Just $ (x `matMul` dz) & transAttrs True True ] , Just $ matMul' (transAttrs True True) x dz]
opGrad "Transpose" _ [_, toT -> p] [dz] = opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz [ Just $ CoreOps.transpose dz
@ -554,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
] ]
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] = opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz [ Just $ CoreOps.conv2DBackpropInput'
& tensorAttr "strides" .~ strides ((opAttr "strides" .~ strides)
& tensorAttr "padding" .~ padding . (opAttr "padding" .~ padding)
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
& tensorAttr "data_format" .~ dataFormat . (opAttr "data_format" .~ dataFormat))
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz (shape x) y dz
& tensorAttr "strides" .~ strides , Just $ CoreOps.conv2DBackpropFilter'
& tensorAttr "padding" .~ padding ((opAttr "strides" .~ strides)
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu . (opAttr "padding" .~ padding)
& tensorAttr "data_format" .~ dataFormat . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
] ]
where where
strides = lookupAttr nodeDef "strides" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64]
@ -572,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
dataFormat = lookupAttr nodeDef "data_format" :: ByteString dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] = opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad x output dz [ Just $ CoreOps.maxPoolGrad'
& tensorAttr "ksize" .~ ksize ((opAttr "ksize" .~ ksize)
& tensorAttr "strides" .~ strides . (opAttr "strides" .~ strides)
& tensorAttr "padding" .~ padding . (opAttr "padding" .~ padding)
& tensorAttr "data_format" .~ dataFormat . (opAttr "data_format" .~ dataFormat))
x output dz
] ]
where where
output :: Tensor Value a output :: Tensor Value a

View File

@ -58,56 +58,99 @@
module TensorFlow.Ops module TensorFlow.Ops
( CoreOps.add ( CoreOps.add
, CoreOps.add'
, CoreOps.abs , CoreOps.abs
, CoreOps.abs'
, CoreOps.addN , CoreOps.addN
, CoreOps.addN'
, CoreOps.argMax , CoreOps.argMax
, CoreOps.argMax'
, CoreOps.assign , CoreOps.assign
, CoreOps.assign'
, CoreOps.broadcastGradientArgs , CoreOps.broadcastGradientArgs
, CoreOps.broadcastGradientArgs'
, CoreOps.cast , CoreOps.cast
, CoreOps.cast'
, CoreOps.concat , CoreOps.concat
, CoreOps.concat'
, constant , constant
, constant'
, CoreOps.equal , CoreOps.equal
, CoreOps.equal'
, expandDims , expandDims
, expandDims'
, initializedVariable , initializedVariable
, initializedVariable'
, zeroInitializedVariable , zeroInitializedVariable
, zeroInitializedVariable'
, CoreOps.fill , CoreOps.fill
, CoreOps.oneHot , CoreOps.fill'
, CoreOps.identity
, CoreOps.identity'
, CoreOps.matMul , CoreOps.matMul
, CoreOps.matMul'
, matTranspose , matTranspose
, matTranspose'
, CoreOps.mean , CoreOps.mean
, CoreOps.mean'
, CoreOps.mul , CoreOps.mul
, CoreOps.mul'
, CoreOps.neg , CoreOps.neg
, CoreOps.neg'
, CoreOps.oneHot
, CoreOps.oneHot'
, CoreOps.pack , CoreOps.pack
, CoreOps.pack'
, placeholder , placeholder
, placeholder'
, CoreOps.range , CoreOps.range
, CoreOps.range'
, reducedShape , reducedShape
, CoreOps.relu , CoreOps.relu
, CoreOps.relu'
, CoreOps.reluGrad , CoreOps.reluGrad
, CoreOps.reluGrad'
, CoreOps.reshape , CoreOps.reshape
, CoreOps.reshape'
, restore , restore
, restoreFromName , restoreFromName
, save , save
, scalar , scalar
, scalar'
, shape , shape
, shape'
, CoreOps.sign , CoreOps.sign
, CoreOps.sign'
, CoreOps.size , CoreOps.size
, CoreOps.size'
, CoreOps.softmax , CoreOps.softmax
, CoreOps.softmax'
, CoreOps.softmaxCrossEntropyWithLogits , CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.softmaxCrossEntropyWithLogits'
, CoreOps.sparseToDense , CoreOps.sparseToDense
, CoreOps.sparseToDense'
, CoreOps.sub , CoreOps.sub
, CoreOps.sub'
, CoreOps.sum , CoreOps.sum
, CoreOps.sum'
, CoreOps.transpose , CoreOps.transpose
, CoreOps.transpose'
, truncatedNormal , truncatedNormal
, truncatedNormal'
, CoreOps.variable , CoreOps.variable
, CoreOps.variable'
, vector , vector
, vector'
, zeros , zeros
, CoreOps.zerosLike , CoreOps.zerosLike
, CoreOps.zerosLike'
, scalarize , scalarize
) where ) where
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import Data.Complex (Complex) import Data.Complex (Complex)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat) import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def) import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8) import Data.Text.Encoding (encodeUtf8)
@ -151,29 +194,31 @@ instance ( TensorType a
signum = CoreOps.sign signum = CoreOps.sign
negate = CoreOps.neg negate = CoreOps.neg
matTranspose :: forall a v . TensorType a matTranspose :: TensorType a => Tensor v a -> Tensor Value a
=> Tensor v a -> Tensor Value a matTranspose = matTranspose' id
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
placeholder shape' = matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
build $ buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a) placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
& opAttr "shape" .~ shape' placeholder = placeholder' id
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape))
-- | Creates a variable initialized to the given value. -- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs. -- Initialization happens next time session runs.
initializedVariable :: forall a m . (MonadBuild m, TensorType a) initializedVariable :: (MonadBuild m, TensorType a)
=> Tensor Value a -> m (Tensor Ref a) => Tensor Value a -> m (Tensor Ref a)
initializedVariable initializer = do initializedVariable = initializedVariable' id
v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <- initializedVariable' :: (MonadBuild m, TensorType a)
build $ buildOp (opDef "Assign" => OpParams -> Tensor Value a -> m (Tensor Ref a)
& opAttr "T" .~ tensorType (undefined :: a) initializedVariable' params initializer = do
& opAttr "use_locking" .~ True v <- CoreOps.variable' params [] -- The shape is not known initially.
& opAttr "validate_shape" .~ False i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
) initializer
v initializer
addInitializer =<< group i addInitializer =<< group i
return v return v
@ -181,7 +226,12 @@ initializedVariable initializer = do
zeroInitializedVariable zeroInitializedVariable
:: (MonadBuild m, TensorType a, Num a) => :: (MonadBuild m, TensorType a, Num a) =>
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a) TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros zeroInitializedVariable = zeroInitializedVariable' id
zeroInitializedVariable'
:: (MonadBuild m, TensorType a, Num a) =>
OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors. -- TODO: Support heterogeneous list of tensors.
save :: forall a m v . (MonadBuild m, TensorType a) save :: forall a m v . (MonadBuild m, TensorType a)
@ -227,27 +277,27 @@ restore path x = do
-- element 0: index (0, ..., 0) -- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1) -- element 1: index (0, ..., 1)
-- ... -- ...
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a constant :: TensorType a => Shape -> [a] -> Tensor Value a
constant (Shape shape') values constant = constant' id
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
constant' params (Shape cShape) values
| invalidLength = error invalidLengthMsg | invalidLength = error invalidLengthMsg
| otherwise = buildOp $ opDef "Const" | otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
& opAttr "value" .~ typedNode
& opAttr "dtype" .~ nodeType
where where
invalidLength = product shape' /= fromIntegral (length values) invalidLength = product cShape /= fromIntegral (length values)
invalidLengthMsg = printf "invalid tensor length: expected %d got %d" invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
(product shape') (product cShape)
(length values) (length values)
nodeType = tensorType (undefined :: a)
typedNode :: TensorProto typedNode :: TensorProto
typedNode = def typedNode = def
& dtype .~ nodeType & dtype .~ tensorType (undefined :: a)
& tensorShape.TensorShape.dim .~ & tensorShape.TensorShape.dim .~
[def & TensorShape.size .~ x | x <- shape'] [def & TensorShape.size .~ x | x <- cShape]
& tensorVal .~ values & tensorVal .~ values
-- | Reshape a N-D tensor down to a scalar. -- | Reshape a N-D tensor down to a scalar.
-- --
-- See `TensorFlow.GenOps.Core.reshape`. -- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t (vector scalarShape) scalarize t = CoreOps.reshape t (vector scalarShape)
@ -257,30 +307,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
-- | Create a constant vector. -- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a vector :: TensorType a => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs vector = vector' id
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
vector' params xs = constant' params [fromIntegral $ length xs] xs
-- | Create a constant scalar. -- | Create a constant scalar.
scalar :: forall a . TensorType a => a -> Tensor Value a scalar :: TensorType a => a -> Tensor Value a
scalar x = constant [] [x] scalar = scalar' id
-- Random tensor from the unit normal distribution with bounded values. scalar' :: TensorType a => OpParams -> a -> Tensor Value a
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a) scalar' params x = constant' params [] [x]
-- | Random tensor from the unit normal distribution with bounded values.
--
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
=> Tensor v Int64 -- ^ Shape. => Tensor v Int64 -- ^ Shape.
-> m (Tensor Value a) -> m (Tensor Value a)
truncatedNormal truncatedNormal = CoreOps.truncatedNormal
= build . buildOp (opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a) truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
& opAttr "T" .~ tensorType (undefined :: Int64)) => OpParams -> Tensor v Int64 -- ^ Shape.
-> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal'
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0) zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32 shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
shape = CoreOps.shape shape = CoreOps.shape
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
shape' = CoreOps.shape'
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims = CoreOps.expandDims expandDims = CoreOps.expandDims
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims' = CoreOps.expandDims'
-- | Helper function for reduction ops (translation of math_ops.reduced_shape). -- | Helper function for reduction ops (translation of math_ops.reduced_shape).
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) => reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32 Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32

View File

@ -19,7 +19,7 @@
module Main where module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Lens.Family2 ((^.)) import Lens.Family2 ((^.), (.~))
import Data.List (sort) import Data.List (sort)
import Proto.Tensorflow.Core.Framework.Graph import Proto.Tensorflow.Core.Framework.Graph
( node ) ( node )
@ -38,8 +38,8 @@ import TensorFlow.Build
, withDevice , withDevice
, colocateWith , colocateWith
, withNameScope , withNameScope
, opName
) )
import TensorFlow.ControlFlow (named)
import TensorFlow.Types (unScalar) import TensorFlow.Types (unScalar)
import TensorFlow.Ops import TensorFlow.Ops
( add ( add
@ -47,6 +47,7 @@ import TensorFlow.Ops
, constant , constant
, initializedVariable , initializedVariable
, variable , variable
, variable'
) )
import TensorFlow.Output (Device(..)) import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref) import TensorFlow.Tensor (Tensor, Value, Ref)
@ -61,26 +62,16 @@ import Test.HUnit ((@=?))
import Google.Test (googleTest) import Google.Test (googleTest)
import qualified Data.Vector as V import qualified Data.Vector as V
-- | Test named behavior. -- | Test 'opName' behavior.
testNamed :: Test testOpName :: Test
testNamed = testCase "testNamed" $ do testOpName = testCase "testOpName" $ do
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float) let graph = variable' (opName .~ "foo") []
>>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op) "Variable" @=? (nodeDef ^. op)
"foo" @=? (nodeDef ^. name) "foo" @=? (nodeDef ^. name)
-- | Test named deRef behavior.
testNamedDeRef :: Test
testNamedDeRef = testCase "testNamedDeRef" $ do
let graph = named "foo" <$> do
v :: Tensor Ref Float <- variable []
assign v 5
-- TODO: Implement TensorFlow get_variable and test it.
runSession $ do
out <- graph >>= run
liftIO $ 5 @=? (unScalar out :: Float)
-- | Test that "run" will render and extend any pure ops that haven't already -- | Test that "run" will render and extend any pure ops that haven't already
-- been rendered. -- been rendered.
testPureRender :: Test testPureRender :: Test
@ -118,14 +109,15 @@ testNameScoped = testCase "testNameScoped" $ do
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix. "foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
"Variable" @=? (nodeDef ^. op) "Variable" @=? (nodeDef ^. op)
-- | Test combined named and nameScoped behavior. -- | Test combined opName and nameScoped behavior.
testNamedAndScoped :: Test testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float) let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render) graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
>>= render
nodeDef :: NodeDef nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op) "Variable" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name) "foo1/bar1" @=? (nodeDef ^. name)
-- | Flush the node buffer and sort the nodes by name (for more stable tests). -- | Flush the node buffer and sort the nodes by name (for more stable tests).
@ -174,8 +166,7 @@ main :: IO ()
main = googleTest [ testInitializedVariable main = googleTest [ testInitializedVariable
, testInitializedVariableShape , testInitializedVariableShape
, testDeviceColocation , testDeviceColocation
, testNamed , testOpName
, testNamedDeRef
, testNameScoped , testNameScoped
, testNamedAndScoped , testNamedAndScoped
, testPureRender , testPureRender

View File

@ -19,6 +19,7 @@ module Main where
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Google.Test (googleTest) import Google.Test (googleTest)
import Lens.Family2 ((.~))
import System.IO.Temp (withSystemTempDirectory) import System.IO.Temp (withSystemTempDirectory)
import Test.Framework (Test) import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.HUnit (testCase)
@ -27,7 +28,6 @@ import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.Build as TF import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Nodes as TF import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
@ -56,7 +56,8 @@ testSaveRestore = testCase "testSaveRestore" $
let path = B8.pack $ dirPath ++ "/checkpoint" let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float) var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<< var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape []) TF.zeroInitializedVariable' (TF.opName .~ "a")
(TF.Shape [])
TF.runSession $ do TF.runSession $ do
v <- var v <- var
TF.assign v 134 >>= TF.run_ TF.assign v 134 >>= TF.run_

View File

@ -36,7 +36,6 @@ import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as TF (select) import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF

View File

@ -23,6 +23,7 @@ module TensorFlow.BuildOp
, buildOp , buildOp
, buildListOp , buildListOp
, eqLengthGuard , eqLengthGuard
, OpParams
) )
where where
@ -238,3 +239,7 @@ eqLengthGuard = all eachOk
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs || eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
error ("number_attr " ++ numberAttrName ++ error ("number_attr " ++ numberAttrName ++
" contains tensors with different length " ++ show pairs) " contains tensors with different length " ++ show pairs)
-- | Parameters to build an op (for example, the node name or optional attributes).
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef

View File

@ -22,21 +22,15 @@ module TensorFlow.ControlFlow
withControlDependencies withControlDependencies
, group , group
-- * Operations -- * Operations
, identity
, noOp , noOp
, named
) where ) where
import qualified Data.Set as Set import qualified Data.Set as Set
import Data.Text (Text) import Lens.Family2 ((&), (.~))
import Lens.Family2 ((&), (^.), (.~))
import TensorFlow.BuildOp import TensorFlow.BuildOp
import TensorFlow.Build import TensorFlow.Build
import TensorFlow.Nodes import TensorFlow.Nodes
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
-- | Modify a 'Build' action, such that all new ops rendered in it will depend -- | Modify a 'Build' action, such that all new ops rendered in it will depend
-- on the nodes in the first argument. -- on the nodes in the first argument.
@ -57,31 +51,6 @@ group deps = do
-- TODO: slicker way -- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
-- | Returns a 'Tensor' with the same shape and contents as the input.
identity :: TensorType a => Tensor v a -> Tensor v a
identity = namedIdentity implicitName
-- | Returns a 'Tensor' with a given name and the same shape and contents as
-- the input.
--
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
-- whether we can change that op.
named :: TensorType a => Text -> Tensor v a -> Tensor v a
named = namedIdentity . explicitName
-- | An internal version of "identity" that allows setting the name
-- of the output Tensor.
namedIdentity :: forall a v . TensorType a
=> PendingNodeName -> Tensor v a -> Tensor v a
namedIdentity n t = case t ^. tensorKind of
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
where
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
-- | Does nothing. Only useful as a placeholder for control edges. -- | Does nothing. Only useful as a placeholder for control edges.
noOp :: ControlNode noOp :: ControlNode
noOp = buildOp $ opDef "NoOp" noOp = buildOp $ opDef "NoOp"

View File

@ -50,14 +50,14 @@ module TensorFlow.Core
, render , render
, asGraphDef , asGraphDef
, addGraphDef , addGraphDef
, opName
, opAttr
-- * Tensor -- * Tensor
, ControlNode , ControlNode
, Tensor , Tensor
, Value , Value
, Ref , Ref
, TensorKind(..) , TensorKind(..)
, tensorAttr
, value , value
, tensorFromName , tensorFromName
-- ** Element types -- ** Element types
@ -74,12 +74,10 @@ module TensorFlow.Core
, Device(..) , Device(..)
, withDevice , withDevice
, withNameScope , withNameScope
, named
-- ** Dependencies -- ** Dependencies
, withControlDependencies , withControlDependencies
, group , group
-- ** Misc -- ** Misc
, identity
, noOp , noOp
) where ) where

View File

@ -124,6 +124,9 @@ data OpDef = OpDef
data PendingNodeName = ExplicitName !Text | ImplicitName data PendingNodeName = ExplicitName !Text | ImplicitName
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
instance IsString PendingNodeName where
fromString = ExplicitName . fromString
-- | The name of a node in the graph. This corresponds to the proto field -- | The name of a node in the graph. This corresponds to the proto field
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier -- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
-- (if the node was implicitly named). -- (if the node was implicitly named).

View File

@ -27,13 +27,12 @@ module TensorFlow.Tensor where
import Data.String (IsString(..)) import Data.String (IsString(..))
import qualified Data.Text as Text import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (^.)) import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens) import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr) import TensorFlow.Output (Output)
import TensorFlow.Types import TensorFlow.Types
( TensorData(..) ( TensorData(..)
, Attribute
, ListOf(..) , ListOf(..)
) )
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
@ -62,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
tensorOutput :: Lens' (Tensor v a) Output tensorOutput :: Lens' (Tensor v a) Output
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o) tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
-- TODO: Come up with a better API for handling attributes.
-- | Lens for the attributes of a tensor.
--
-- Only valid if the tensor has not yet been rendered. If the tensor has been
-- rendered, the traversal will be over nothing (nothing can be read or
-- written).
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a -- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
-- Ref into Value. This behaves like a no-op. -- Ref into Value. This behaves like a no-op.
value :: Tensor v a -> Tensor Value a value :: Tensor v a -> Tensor Value a