1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-24 02:29:46 +01:00

Add versions of each op that take optional params as an extra arg. (#84)

Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...`
which lets you set optional attributes.  `OpParams` is currently a type alias for
`OpDef -> OpDef`.  In the future we should consider more type safety, e.g.,
using type-level strings and OverloadedLabels for optional attributes.

I used it to replace a few manual `buildOp`s in our code with the codegenerated
ops, now that it's easier to set attributes.  I also removed `tensorAttr` and
`named` since it's now possible to set those op attributes directly.

Although this clutters up the API a bit, I think it's simpler than using type
classes to implement optional arguments (as in, for example, `Text.Printf`) --
especially in terms of type inference with the rest of the library.
This commit is contained in:
Judah Jacobson 2017-03-20 18:16:38 -07:00 committed by GitHub
parent 2c5c879037
commit c99a23b6a7
11 changed files with 190 additions and 152 deletions

View file

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
@ -172,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName
renderOp :: ParsedOp -> Doc
renderOp pOp = stack $
[ haddocks
, n <+> "::" <+> hang 0 (typeSig pOp)
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
-- Prevent unreasonably long compilation times on ghc-7.10, due
-- to stack calling "-dump-hi" which (unnecessarily) includes the
-- inlining information, and is large for ops with many arguments.
#if __GLASGOW_HASKELL__ < 800
, "{-# NOINLINE " <> n <> "#-}"
#endif
, n <+> "::" <+> hang 0 (typeSig empty pOp)
, n <+> "=" <+> n <> "' id"
, n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
, n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
<+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation (functionBody pOp)
] ++ whereClause listSizeAttrs
where
n = renderHaskellName $ parsedOpName pOp
n' = n <> "'"
listSizeAttrs = inferredListSizeAttrs pOp
args = sep $ map renderHaskellName
$ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp)
args = sep $ "op'options"
: (map renderHaskellName
$ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp))
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
-- | A check that all lists of the given size have the given length.
@ -247,7 +258,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
-- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs pOp, let n = attrName a
]
] ++
["& op'options"]
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a
@ -270,12 +283,12 @@ extras d = enclose "{-\n" "\n-}" $
-- | The type signature for an op.
-- Of the form:
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
-- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
-- "Tensor t2 v2" is an output.
typeSig :: ParsedOp -> Doc
typeSig pOp = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
typeSig :: Doc -> ParsedOp -> Doc
typeSig pre pOp = constraints
<+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
++ map tensorArgAndComment (parsedInputs pOp)
++ [outputs])
where

View file

@ -72,6 +72,7 @@ import TensorFlow.Ops
, expandDims
, fill
, matMul
, matMul'
, reducedShape
, reluGrad
, reshape
@ -95,7 +96,6 @@ import TensorFlow.Tensor
, TensorKind (ValueKind)
, Value
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
@ -532,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
let transposeA = lookupAttr nodeDef "transpose_a"
transposeB = lookupAttr nodeDef "transpose_b"
transAttrs a b =
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
(opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
in case (transposeA, transposeB) of
(False, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ (x `matMul` dz) & transAttrs True False ]
[ Just $ matMul' (transAttrs False True) dz y
, Just $ matMul' (transAttrs True False) x dz]
(False, True) ->
[ Just $ dz `matMul` y
, Just $ (x `matMul` dz) & transAttrs True False ]
[ Just $ matMul dz y
, Just $ matMul' (transAttrs True False) x dz]
(True, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ x `matMul` dz ]
[ Just $ matMul' (transAttrs False True) dz y
, Just $ matMul x dz]
(True, True) ->
[ Just $ (dz `matMul` y) & transAttrs True True
, Just $ (x `matMul` dz) & transAttrs True True ]
[ Just $ matMul' (transAttrs True True) dz y
, Just $ matMul' (transAttrs True True) x dz]
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
@ -554,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
]
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
[ Just $ CoreOps.conv2DBackpropInput'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
(shape x) y dz
, Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
@ -572,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad x output dz
& tensorAttr "ksize" .~ ksize
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "data_format" .~ dataFormat
[ Just $ CoreOps.maxPoolGrad'
((opAttr "ksize" .~ ksize)
. (opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
x output dz
]
where
output :: Tensor Value a

View file

@ -58,56 +58,99 @@
module TensorFlow.Ops
( CoreOps.add
, CoreOps.add'
, CoreOps.abs
, CoreOps.abs'
, CoreOps.addN
, CoreOps.addN'
, CoreOps.argMax
, CoreOps.argMax'
, CoreOps.assign
, CoreOps.assign'
, CoreOps.broadcastGradientArgs
, CoreOps.broadcastGradientArgs'
, CoreOps.cast
, CoreOps.cast'
, CoreOps.concat
, CoreOps.concat'
, constant
, constant'
, CoreOps.equal
, CoreOps.equal'
, expandDims
, expandDims'
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, CoreOps.fill
, CoreOps.oneHot
, CoreOps.fill'
, CoreOps.identity
, CoreOps.identity'
, CoreOps.matMul
, CoreOps.matMul'
, matTranspose
, matTranspose'
, CoreOps.mean
, CoreOps.mean'
, CoreOps.mul
, CoreOps.mul'
, CoreOps.neg
, CoreOps.neg'
, CoreOps.oneHot
, CoreOps.oneHot'
, CoreOps.pack
, CoreOps.pack'
, placeholder
, placeholder'
, CoreOps.range
, CoreOps.range'
, reducedShape
, CoreOps.relu
, CoreOps.relu'
, CoreOps.reluGrad
, CoreOps.reluGrad'
, CoreOps.reshape
, CoreOps.reshape'
, restore
, restoreFromName
, save
, scalar
, scalar'
, shape
, shape'
, CoreOps.sign
, CoreOps.sign'
, CoreOps.size
, CoreOps.size'
, CoreOps.softmax
, CoreOps.softmax'
, CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.softmaxCrossEntropyWithLogits'
, CoreOps.sparseToDense
, CoreOps.sparseToDense'
, CoreOps.sub
, CoreOps.sub'
, CoreOps.sum
, CoreOps.sum'
, CoreOps.transpose
, CoreOps.transpose'
, truncatedNormal
, truncatedNormal'
, CoreOps.variable
, CoreOps.variable'
, vector
, vector'
, zeros
, CoreOps.zerosLike
, CoreOps.zerosLike'
, scalarize
) where
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8)
@ -151,29 +194,31 @@ instance ( TensorType a
signum = CoreOps.sign
negate = CoreOps.neg
matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
matTranspose = matTranspose' id
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder shape' =
build $ buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape'
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder = placeholder' id
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape))
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: forall a m . (MonadBuild m, TensorType a)
initializedVariable :: (MonadBuild m, TensorType a)
=> Tensor Value a -> m (Tensor Ref a)
initializedVariable initializer = do
v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <-
build $ buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
& opAttr "validate_shape" .~ False
)
v initializer
initializedVariable = initializedVariable' id
initializedVariable' :: (MonadBuild m, TensorType a)
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
initializedVariable' params initializer = do
v <- CoreOps.variable' params [] -- The shape is not known initially.
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
initializer
addInitializer =<< group i
return v
@ -181,7 +226,12 @@ initializedVariable initializer = do
zeroInitializedVariable
:: (MonadBuild m, TensorType a, Num a) =>
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros
zeroInitializedVariable = zeroInitializedVariable' id
zeroInitializedVariable'
:: (MonadBuild m, TensorType a, Num a) =>
OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors.
save :: forall a m v . (MonadBuild m, TensorType a)
@ -227,27 +277,27 @@ restore path x = do
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
constant (Shape shape') values
constant :: TensorType a => Shape -> [a] -> Tensor Value a
constant = constant' id
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
constant' params (Shape cShape) values
| invalidLength = error invalidLengthMsg
| otherwise = buildOp $ opDef "Const"
& opAttr "value" .~ typedNode
& opAttr "dtype" .~ nodeType
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
where
invalidLength = product shape' /= fromIntegral (length values)
invalidLength = product cShape /= fromIntegral (length values)
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
(product shape')
(product cShape)
(length values)
nodeType = tensorType (undefined :: a)
typedNode :: TensorProto
typedNode = def
& dtype .~ nodeType
& dtype .~ tensorType (undefined :: a)
& tensorShape.TensorShape.dim .~
[def & TensorShape.size .~ x | x <- shape']
[def & TensorShape.size .~ x | x <- cShape]
& tensorVal .~ values
-- | Reshape a N-D tensor down to a scalar.
--
--
-- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t (vector scalarShape)
@ -257,30 +307,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
-- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs
vector = vector' id
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
vector' params xs = constant' params [fromIntegral $ length xs] xs
-- | Create a constant scalar.
scalar :: forall a . TensorType a => a -> Tensor Value a
scalar x = constant [] [x]
scalar :: TensorType a => a -> Tensor Value a
scalar = scalar' id
-- Random tensor from the unit normal distribution with bounded values.
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
scalar' params x = constant' params [] [x]
-- | Random tensor from the unit normal distribution with bounded values.
--
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
=> Tensor v Int64 -- ^ Shape.
-> m (Tensor Value a)
truncatedNormal
= build . buildOp (opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64))
truncatedNormal = CoreOps.truncatedNormal
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
=> OpParams -> Tensor v Int64 -- ^ Shape.
-> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal'
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
shape = CoreOps.shape
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
shape' = CoreOps.shape'
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims = CoreOps.expandDims
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims' = CoreOps.expandDims'
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32

View file

@ -19,7 +19,7 @@
module Main where
import Control.Monad.IO.Class (liftIO)
import Lens.Family2 ((^.))
import Lens.Family2 ((^.), (.~))
import Data.List (sort)
import Proto.Tensorflow.Core.Framework.Graph
( node )
@ -38,8 +38,8 @@ import TensorFlow.Build
, withDevice
, colocateWith
, withNameScope
, opName
)
import TensorFlow.ControlFlow (named)
import TensorFlow.Types (unScalar)
import TensorFlow.Ops
( add
@ -47,6 +47,7 @@ import TensorFlow.Ops
, constant
, initializedVariable
, variable
, variable'
)
import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref)
@ -61,26 +62,16 @@ import Test.HUnit ((@=?))
import Google.Test (googleTest)
import qualified Data.Vector as V
-- | Test named behavior.
testNamed :: Test
testNamed = testCase "testNamed" $ do
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
-- | Test 'opName' behavior.
testOpName :: Test
testOpName = testCase "testOpName" $ do
let graph = variable' (opName .~ "foo") []
>>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op)
"Variable" @=? (nodeDef ^. op)
"foo" @=? (nodeDef ^. name)
-- | Test named deRef behavior.
testNamedDeRef :: Test
testNamedDeRef = testCase "testNamedDeRef" $ do
let graph = named "foo" <$> do
v :: Tensor Ref Float <- variable []
assign v 5
-- TODO: Implement TensorFlow get_variable and test it.
runSession $ do
out <- graph >>= run
liftIO $ 5 @=? (unScalar out :: Float)
-- | Test that "run" will render and extend any pure ops that haven't already
-- been rendered.
testPureRender :: Test
@ -118,14 +109,15 @@ testNameScoped = testCase "testNameScoped" $ do
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
"Variable" @=? (nodeDef ^. op)
-- | Test combined named and nameScoped behavior.
-- | Test combined opName and nameScoped behavior.
testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
>>= render
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op)
"Variable" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name)
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
@ -174,8 +166,7 @@ main :: IO ()
main = googleTest [ testInitializedVariable
, testInitializedVariableShape
, testDeviceColocation
, testNamed
, testNamedDeRef
, testOpName
, testNameScoped
, testNamedAndScoped
, testPureRender

View file

@ -19,6 +19,7 @@ module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Google.Test (googleTest)
import Lens.Family2 ((.~))
import System.IO.Temp (withSystemTempDirectory)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
@ -27,7 +28,6 @@ import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
@ -56,7 +56,8 @@ testSaveRestore = testCase "testSaveRestore" $
let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.zeroInitializedVariable' (TF.opName .~ "a")
(TF.Shape [])
TF.runSession $ do
v <- var
TF.assign v 134 >>= TF.run_

View file

@ -36,7 +36,6 @@ import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF

View file

@ -23,6 +23,7 @@ module TensorFlow.BuildOp
, buildOp
, buildListOp
, eqLengthGuard
, OpParams
)
where
@ -238,3 +239,7 @@ eqLengthGuard = all eachOk
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
error ("number_attr " ++ numberAttrName ++
" contains tensors with different length " ++ show pairs)
-- | Parameters to build an op (for example, the node name or optional attributes).
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef

View file

@ -22,21 +22,15 @@ module TensorFlow.ControlFlow
withControlDependencies
, group
-- * Operations
, identity
, noOp
, named
) where
import qualified Data.Set as Set
import Data.Text (Text)
import Lens.Family2 ((&), (^.), (.~))
import Lens.Family2 ((&), (.~))
import TensorFlow.BuildOp
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
-- on the nodes in the first argument.
@ -57,31 +51,6 @@ group deps = do
-- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
-- | Returns a 'Tensor' with the same shape and contents as the input.
identity :: TensorType a => Tensor v a -> Tensor v a
identity = namedIdentity implicitName
-- | Returns a 'Tensor' with a given name and the same shape and contents as
-- the input.
--
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
-- whether we can change that op.
named :: TensorType a => Text -> Tensor v a -> Tensor v a
named = namedIdentity . explicitName
-- | An internal version of "identity" that allows setting the name
-- of the output Tensor.
namedIdentity :: forall a v . TensorType a
=> PendingNodeName -> Tensor v a -> Tensor v a
namedIdentity n t = case t ^. tensorKind of
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
where
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
-- | Does nothing. Only useful as a placeholder for control edges.
noOp :: ControlNode
noOp = buildOp $ opDef "NoOp"

View file

@ -50,14 +50,14 @@ module TensorFlow.Core
, render
, asGraphDef
, addGraphDef
, opName
, opAttr
-- * Tensor
, ControlNode
, Tensor
, Value
, Ref
, TensorKind(..)
, tensorAttr
, value
, tensorFromName
-- ** Element types
@ -74,12 +74,10 @@ module TensorFlow.Core
, Device(..)
, withDevice
, withNameScope
, named
-- ** Dependencies
, withControlDependencies
, group
-- ** Misc
, identity
, noOp
) where

View file

@ -124,6 +124,9 @@ data OpDef = OpDef
data PendingNodeName = ExplicitName !Text | ImplicitName
deriving (Eq, Ord, Show)
instance IsString PendingNodeName where
fromString = ExplicitName . fromString
-- | The name of a node in the graph. This corresponds to the proto field
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
-- (if the node was implicitly named).

View file

@ -27,13 +27,12 @@ module TensorFlow.Tensor where
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (^.))
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
import TensorFlow.Output (Output)
import TensorFlow.Types
( TensorData(..)
, Attribute
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI
@ -62,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
tensorOutput :: Lens' (Tensor v a) Output
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
-- TODO: Come up with a better API for handling attributes.
-- | Lens for the attributes of a tensor.
--
-- Only valid if the tensor has not yet been rendered. If the tensor has been
-- rendered, the traversal will be over nothing (nothing can be read or
-- written).
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
-- Ref into Value. This behaves like a no-op.
value :: Tensor v a -> Tensor Value a