mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Add versions of each op that take optional params as an extra arg. (#84)
Each op `foo :: ...` now has a corresponding `foo' :: OpParams -> ...` which lets you set optional attributes. `OpParams` is currently a type alias for `OpDef -> OpDef`. In the future we should consider more type safety, e.g., using type-level strings and OverloadedLabels for optional attributes. I used it to replace a few manual `buildOp`s in our code with the codegenerated ops, now that it's easier to set attributes. I also removed `tensorAttr` and `named` since it's now possible to set those op attributes directly. Although this clutters up the API a bit, I think it's simpler than using type classes to implement optional arguments (as in, for example, `Text.Printf`) -- especially in terms of type inference with the rest of the library.
This commit is contained in:
parent
2c5c879037
commit
c99a23b6a7
11 changed files with 190 additions and 152 deletions
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
@ -172,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName
|
|||
renderOp :: ParsedOp -> Doc
|
||||
renderOp pOp = stack $
|
||||
[ haddocks
|
||||
, n <+> "::" <+> hang 0 (typeSig pOp)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||
-- Prevent unreasonably long compilation times on ghc-7.10, due
|
||||
-- to stack calling "-dump-hi" which (unnecessarily) includes the
|
||||
-- inlining information, and is large for ops with many arguments.
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
, "{-# NOINLINE " <> n <> "#-}"
|
||||
#endif
|
||||
, n <+> "::" <+> hang 0 (typeSig empty pOp)
|
||||
, n <+> "=" <+> n <> "' id"
|
||||
, n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
|
||||
, n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||
<+> "=" </> -- args are indented
|
||||
-- the body needs to be indented wrt the name
|
||||
indent indentation (functionBody pOp)
|
||||
] ++ whereClause listSizeAttrs
|
||||
where
|
||||
n = renderHaskellName $ parsedOpName pOp
|
||||
n' = n <> "'"
|
||||
listSizeAttrs = inferredListSizeAttrs pOp
|
||||
args = sep $ map renderHaskellName
|
||||
args = sep $ "op'options"
|
||||
: (map renderHaskellName
|
||||
$ map attrName (explicitInputAttrs pOp)
|
||||
++ map parsedArgName (parsedInputs pOp)
|
||||
++ map parsedArgName (parsedInputs pOp))
|
||||
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
||||
|
||||
-- | A check that all lists of the given size have the given length.
|
||||
|
@ -247,7 +258,9 @@ functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOp
|
|||
-- Renders sizes of tensor list types having number_attr.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||
]
|
||||
] ++
|
||||
["& op'options"]
|
||||
|
||||
|
||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
inferredTypeExpr a
|
||||
|
@ -270,12 +283,12 @@ extras d = enclose "{-\n" "\n-}" $
|
|||
-- | The type signature for an op.
|
||||
-- Of the form:
|
||||
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
|
||||
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||
-- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
||||
-- "Tensor t2 v2" is an output.
|
||||
typeSig :: ParsedOp -> Doc
|
||||
typeSig pOp = constraints
|
||||
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||
typeSig :: Doc -> ParsedOp -> Doc
|
||||
typeSig pre pOp = constraints
|
||||
<+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||
++ map tensorArgAndComment (parsedInputs pOp)
|
||||
++ [outputs])
|
||||
where
|
||||
|
|
|
@ -72,6 +72,7 @@ import TensorFlow.Ops
|
|||
, expandDims
|
||||
, fill
|
||||
, matMul
|
||||
, matMul'
|
||||
, reducedShape
|
||||
, reluGrad
|
||||
, reshape
|
||||
|
@ -95,7 +96,6 @@ import TensorFlow.Tensor
|
|||
, TensorKind (ValueKind)
|
||||
, Value
|
||||
, tensorOutput
|
||||
, tensorAttr
|
||||
)
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
|
@ -532,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
let transposeA = lookupAttr nodeDef "transpose_a"
|
||||
transposeB = lookupAttr nodeDef "transpose_b"
|
||||
transAttrs a b =
|
||||
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
|
||||
(opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
|
||||
in case (transposeA, transposeB) of
|
||||
(False, False) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||
[ Just $ matMul' (transAttrs False True) dz y
|
||||
, Just $ matMul' (transAttrs True False) x dz]
|
||||
(False, True) ->
|
||||
[ Just $ dz `matMul` y
|
||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||
[ Just $ matMul dz y
|
||||
, Just $ matMul' (transAttrs True False) x dz]
|
||||
(True, False) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||
, Just $ x `matMul` dz ]
|
||||
[ Just $ matMul' (transAttrs False True) dz y
|
||||
, Just $ matMul x dz]
|
||||
(True, True) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs True True
|
||||
, Just $ (x `matMul` dz) & transAttrs True True ]
|
||||
[ Just $ matMul' (transAttrs True True) dz y
|
||||
, Just $ matMul' (transAttrs True True) x dz]
|
||||
|
||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||
[ Just $ CoreOps.transpose dz
|
||||
|
@ -554,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
|
|||
]
|
||||
|
||||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
[ Just $ CoreOps.conv2DBackpropInput'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
(shape x) y dz
|
||||
, Just $ CoreOps.conv2DBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x (shape y) dz
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
|
@ -572,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||
[ Just $ CoreOps.maxPoolGrad x output dz
|
||||
& tensorAttr "ksize" .~ ksize
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
[ Just $ CoreOps.maxPoolGrad'
|
||||
((opAttr "ksize" .~ ksize)
|
||||
. (opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x output dz
|
||||
]
|
||||
where
|
||||
output :: Tensor Value a
|
||||
|
|
|
@ -58,56 +58,99 @@
|
|||
|
||||
module TensorFlow.Ops
|
||||
( CoreOps.add
|
||||
, CoreOps.add'
|
||||
, CoreOps.abs
|
||||
, CoreOps.abs'
|
||||
, CoreOps.addN
|
||||
, CoreOps.addN'
|
||||
, CoreOps.argMax
|
||||
, CoreOps.argMax'
|
||||
, CoreOps.assign
|
||||
, CoreOps.assign'
|
||||
, CoreOps.broadcastGradientArgs
|
||||
, CoreOps.broadcastGradientArgs'
|
||||
, CoreOps.cast
|
||||
, CoreOps.cast'
|
||||
, CoreOps.concat
|
||||
, CoreOps.concat'
|
||||
, constant
|
||||
, constant'
|
||||
, CoreOps.equal
|
||||
, CoreOps.equal'
|
||||
, expandDims
|
||||
, expandDims'
|
||||
, initializedVariable
|
||||
, initializedVariable'
|
||||
, zeroInitializedVariable
|
||||
, zeroInitializedVariable'
|
||||
, CoreOps.fill
|
||||
, CoreOps.oneHot
|
||||
, CoreOps.fill'
|
||||
, CoreOps.identity
|
||||
, CoreOps.identity'
|
||||
, CoreOps.matMul
|
||||
, CoreOps.matMul'
|
||||
, matTranspose
|
||||
, matTranspose'
|
||||
, CoreOps.mean
|
||||
, CoreOps.mean'
|
||||
, CoreOps.mul
|
||||
, CoreOps.mul'
|
||||
, CoreOps.neg
|
||||
, CoreOps.neg'
|
||||
, CoreOps.oneHot
|
||||
, CoreOps.oneHot'
|
||||
, CoreOps.pack
|
||||
, CoreOps.pack'
|
||||
, placeholder
|
||||
, placeholder'
|
||||
, CoreOps.range
|
||||
, CoreOps.range'
|
||||
, reducedShape
|
||||
, CoreOps.relu
|
||||
, CoreOps.relu'
|
||||
, CoreOps.reluGrad
|
||||
, CoreOps.reluGrad'
|
||||
, CoreOps.reshape
|
||||
, CoreOps.reshape'
|
||||
, restore
|
||||
, restoreFromName
|
||||
, save
|
||||
, scalar
|
||||
, scalar'
|
||||
, shape
|
||||
, shape'
|
||||
, CoreOps.sign
|
||||
, CoreOps.sign'
|
||||
, CoreOps.size
|
||||
, CoreOps.size'
|
||||
, CoreOps.softmax
|
||||
, CoreOps.softmax'
|
||||
, CoreOps.softmaxCrossEntropyWithLogits
|
||||
, CoreOps.softmaxCrossEntropyWithLogits'
|
||||
, CoreOps.sparseToDense
|
||||
, CoreOps.sparseToDense'
|
||||
, CoreOps.sub
|
||||
, CoreOps.sub'
|
||||
, CoreOps.sum
|
||||
, CoreOps.sum'
|
||||
, CoreOps.transpose
|
||||
, CoreOps.transpose'
|
||||
, truncatedNormal
|
||||
, truncatedNormal'
|
||||
, CoreOps.variable
|
||||
, CoreOps.variable'
|
||||
, vector
|
||||
, vector'
|
||||
, zeros
|
||||
, CoreOps.zerosLike
|
||||
, CoreOps.zerosLike'
|
||||
, scalarize
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Complex (Complex)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.Word (Word16)
|
||||
import Prelude hiding (abs, sum, concat)
|
||||
import Data.ProtoLens (def)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
|
@ -151,29 +194,31 @@ instance ( TensorType a
|
|||
signum = CoreOps.sign
|
||||
negate = CoreOps.neg
|
||||
|
||||
matTranspose :: forall a v . TensorType a
|
||||
=> Tensor v a -> Tensor Value a
|
||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
|
||||
matTranspose = matTranspose' id
|
||||
|
||||
placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||
placeholder shape' =
|
||||
build $ buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ shape'
|
||||
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
|
||||
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
||||
|
||||
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||
placeholder = placeholder' id
|
||||
|
||||
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a)
|
||||
placeholder' params pShape
|
||||
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape))
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a m . (MonadBuild m, TensorType a)
|
||||
initializedVariable :: (MonadBuild m, TensorType a)
|
||||
=> Tensor Value a -> m (Tensor Ref a)
|
||||
initializedVariable initializer = do
|
||||
v <- CoreOps.variable [] -- The shape is not known initially.
|
||||
(i :: Tensor Ref a) <-
|
||||
build $ buildOp (opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
& opAttr "validate_shape" .~ False
|
||||
)
|
||||
v initializer
|
||||
initializedVariable = initializedVariable' id
|
||||
|
||||
initializedVariable' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
|
||||
initializedVariable' params initializer = do
|
||||
v <- CoreOps.variable' params [] -- The shape is not known initially.
|
||||
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
|
||||
initializer
|
||||
addInitializer =<< group i
|
||||
return v
|
||||
|
||||
|
@ -181,7 +226,12 @@ initializedVariable initializer = do
|
|||
zeroInitializedVariable
|
||||
:: (MonadBuild m, TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable = initializedVariable . zeros
|
||||
zeroInitializedVariable = zeroInitializedVariable' id
|
||||
|
||||
zeroInitializedVariable'
|
||||
:: (MonadBuild m, TensorType a, Num a) =>
|
||||
OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a m v . (MonadBuild m, TensorType a)
|
||||
|
@ -227,23 +277,23 @@ restore path x = do
|
|||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant (Shape shape') values
|
||||
constant :: TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant = constant' id
|
||||
|
||||
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
|
||||
constant' params (Shape cShape) values
|
||||
| invalidLength = error invalidLengthMsg
|
||||
| otherwise = buildOp $ opDef "Const"
|
||||
& opAttr "value" .~ typedNode
|
||||
& opAttr "dtype" .~ nodeType
|
||||
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
|
||||
where
|
||||
invalidLength = product shape' /= fromIntegral (length values)
|
||||
invalidLength = product cShape /= fromIntegral (length values)
|
||||
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
|
||||
(product shape')
|
||||
(product cShape)
|
||||
(length values)
|
||||
nodeType = tensorType (undefined :: a)
|
||||
typedNode :: TensorProto
|
||||
typedNode = def
|
||||
& dtype .~ nodeType
|
||||
& dtype .~ tensorType (undefined :: a)
|
||||
& tensorShape.TensorShape.dim .~
|
||||
[def & TensorShape.size .~ x | x <- shape']
|
||||
[def & TensorShape.size .~ x | x <- cShape]
|
||||
& tensorVal .~ values
|
||||
|
||||
-- | Reshape a N-D tensor down to a scalar.
|
||||
|
@ -257,30 +307,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
|
|||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Value a
|
||||
vector xs = constant [fromIntegral $ length xs] xs
|
||||
vector = vector' id
|
||||
|
||||
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
|
||||
vector' params xs = constant' params [fromIntegral $ length xs] xs
|
||||
|
||||
-- | Create a constant scalar.
|
||||
scalar :: forall a . TensorType a => a -> Tensor Value a
|
||||
scalar x = constant [] [x]
|
||||
scalar :: TensorType a => a -> Tensor Value a
|
||||
scalar = scalar' id
|
||||
|
||||
-- Random tensor from the unit normal distribution with bounded values.
|
||||
truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
|
||||
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
|
||||
scalar' params x = constant' params [] [x]
|
||||
|
||||
-- | Random tensor from the unit normal distribution with bounded values.
|
||||
--
|
||||
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
|
||||
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||
=> Tensor v Int64 -- ^ Shape.
|
||||
-> m (Tensor Value a)
|
||||
truncatedNormal
|
||||
= build . buildOp (opDef "TruncatedNormal"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "T" .~ tensorType (undefined :: Int64))
|
||||
truncatedNormal = CoreOps.truncatedNormal
|
||||
|
||||
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||
=> OpParams -> Tensor v Int64 -- ^ Shape.
|
||||
-> m (Tensor Value a)
|
||||
truncatedNormal' = CoreOps.truncatedNormal'
|
||||
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
||||
|
||||
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
||||
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||
shape = CoreOps.shape
|
||||
|
||||
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
|
||||
shape' = CoreOps.shape'
|
||||
|
||||
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims = CoreOps.expandDims
|
||||
|
||||
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims' = CoreOps.expandDims'
|
||||
|
||||
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Lens.Family2 ((^.))
|
||||
import Lens.Family2 ((^.), (.~))
|
||||
import Data.List (sort)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( node )
|
||||
|
@ -38,8 +38,8 @@ import TensorFlow.Build
|
|||
, withDevice
|
||||
, colocateWith
|
||||
, withNameScope
|
||||
, opName
|
||||
)
|
||||
import TensorFlow.ControlFlow (named)
|
||||
import TensorFlow.Types (unScalar)
|
||||
import TensorFlow.Ops
|
||||
( add
|
||||
|
@ -47,6 +47,7 @@ import TensorFlow.Ops
|
|||
, constant
|
||||
, initializedVariable
|
||||
, variable
|
||||
, variable'
|
||||
)
|
||||
import TensorFlow.Output (Device(..))
|
||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||
|
@ -61,26 +62,16 @@ import Test.HUnit ((@=?))
|
|||
import Google.Test (googleTest)
|
||||
import qualified Data.Vector as V
|
||||
|
||||
-- | Test named behavior.
|
||||
testNamed :: Test
|
||||
testNamed = testCase "testNamed" $ do
|
||||
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
||||
-- | Test 'opName' behavior.
|
||||
testOpName :: Test
|
||||
testOpName = testCase "testOpName" $ do
|
||||
let graph = variable' (opName .~ "foo") []
|
||||
>>= render :: Build (Tensor Ref Float)
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
"foo" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Test named deRef behavior.
|
||||
testNamedDeRef :: Test
|
||||
testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||
let graph = named "foo" <$> do
|
||||
v :: Tensor Ref Float <- variable []
|
||||
assign v 5
|
||||
-- TODO: Implement TensorFlow get_variable and test it.
|
||||
runSession $ do
|
||||
out <- graph >>= run
|
||||
liftIO $ 5 @=? (unScalar out :: Float)
|
||||
|
||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||
-- been rendered.
|
||||
testPureRender :: Test
|
||||
|
@ -118,14 +109,15 @@ testNameScoped = testCase "testNameScoped" $ do
|
|||
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
|
||||
-- | Test combined named and nameScoped behavior.
|
||||
-- | Test combined opName and nameScoped behavior.
|
||||
testNamedAndScoped :: Test
|
||||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||
let graph :: Build (Tensor Ref Float)
|
||||
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
||||
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
|
||||
>>= render
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
"foo1/bar1" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||
|
@ -174,8 +166,7 @@ main :: IO ()
|
|||
main = googleTest [ testInitializedVariable
|
||||
, testInitializedVariableShape
|
||||
, testDeviceColocation
|
||||
, testNamed
|
||||
, testNamedDeRef
|
||||
, testOpName
|
||||
, testNameScoped
|
||||
, testNamedAndScoped
|
||||
, testPureRender
|
||||
|
|
|
@ -19,6 +19,7 @@ module Main where
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Google.Test (googleTest)
|
||||
import Lens.Family2 ((.~))
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -27,7 +28,6 @@ import qualified Data.ByteString.Char8 as B8
|
|||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
|
@ -56,7 +56,8 @@ testSaveRestore = testCase "testSaveRestore" $
|
|||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||
var = TF.render =<<
|
||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
||||
TF.zeroInitializedVariable' (TF.opName .~ "a")
|
||||
(TF.Shape [])
|
||||
TF.runSession $ do
|
||||
v <- var
|
||||
TF.assign v 134 >>= TF.run_
|
||||
|
|
|
@ -36,7 +36,6 @@ import qualified Data.ByteString as B
|
|||
import qualified Data.ByteString.Char8 as B8
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (select)
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
|
|
|
@ -23,6 +23,7 @@ module TensorFlow.BuildOp
|
|||
, buildOp
|
||||
, buildListOp
|
||||
, eqLengthGuard
|
||||
, OpParams
|
||||
)
|
||||
where
|
||||
|
||||
|
@ -238,3 +239,7 @@ eqLengthGuard = all eachOk
|
|||
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
|
||||
error ("number_attr " ++ numberAttrName ++
|
||||
" contains tensors with different length " ++ show pairs)
|
||||
|
||||
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||
-- TODO: be more type safe.
|
||||
type OpParams = OpDef -> OpDef
|
||||
|
|
|
@ -22,21 +22,15 @@ module TensorFlow.ControlFlow
|
|||
withControlDependencies
|
||||
, group
|
||||
-- * Operations
|
||||
, identity
|
||||
, noOp
|
||||
, named
|
||||
) where
|
||||
|
||||
import qualified Data.Set as Set
|
||||
import Data.Text (Text)
|
||||
import Lens.Family2 ((&), (^.), (.~))
|
||||
import Lens.Family2 ((&), (.~))
|
||||
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||
-- on the nodes in the first argument.
|
||||
|
@ -57,31 +51,6 @@ group deps = do
|
|||
-- TODO: slicker way
|
||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||
|
||||
|
||||
-- | Returns a 'Tensor' with the same shape and contents as the input.
|
||||
identity :: TensorType a => Tensor v a -> Tensor v a
|
||||
identity = namedIdentity implicitName
|
||||
|
||||
-- | Returns a 'Tensor' with a given name and the same shape and contents as
|
||||
-- the input.
|
||||
--
|
||||
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
|
||||
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
|
||||
-- whether we can change that op.
|
||||
named :: TensorType a => Text -> Tensor v a -> Tensor v a
|
||||
named = namedIdentity . explicitName
|
||||
|
||||
-- | An internal version of "identity" that allows setting the name
|
||||
-- of the output Tensor.
|
||||
namedIdentity :: forall a v . TensorType a
|
||||
=> PendingNodeName -> Tensor v a -> Tensor v a
|
||||
namedIdentity n t = case t ^. tensorKind of
|
||||
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
|
||||
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
|
||||
where
|
||||
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
|
||||
|
||||
|
||||
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||
noOp :: ControlNode
|
||||
noOp = buildOp $ opDef "NoOp"
|
||||
|
|
|
@ -50,14 +50,14 @@ module TensorFlow.Core
|
|||
, render
|
||||
, asGraphDef
|
||||
, addGraphDef
|
||||
|
||||
, opName
|
||||
, opAttr
|
||||
-- * Tensor
|
||||
, ControlNode
|
||||
, Tensor
|
||||
, Value
|
||||
, Ref
|
||||
, TensorKind(..)
|
||||
, tensorAttr
|
||||
, value
|
||||
, tensorFromName
|
||||
-- ** Element types
|
||||
|
@ -74,12 +74,10 @@ module TensorFlow.Core
|
|||
, Device(..)
|
||||
, withDevice
|
||||
, withNameScope
|
||||
, named
|
||||
-- ** Dependencies
|
||||
, withControlDependencies
|
||||
, group
|
||||
-- ** Misc
|
||||
, identity
|
||||
, noOp
|
||||
) where
|
||||
|
||||
|
|
|
@ -124,6 +124,9 @@ data OpDef = OpDef
|
|||
data PendingNodeName = ExplicitName !Text | ImplicitName
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString PendingNodeName where
|
||||
fromString = ExplicitName . fromString
|
||||
|
||||
-- | The name of a node in the graph. This corresponds to the proto field
|
||||
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
|
||||
-- (if the node was implicitly named).
|
||||
|
|
|
@ -27,13 +27,12 @@ module TensorFlow.Tensor where
|
|||
|
||||
import Data.String (IsString(..))
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', Traversal', (^.))
|
||||
import Lens.Family2 (Lens', (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
|
||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||
import TensorFlow.Output (Output)
|
||||
import TensorFlow.Types
|
||||
( TensorData(..)
|
||||
, Attribute
|
||||
, ListOf(..)
|
||||
)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
@ -62,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
|||
tensorOutput :: Lens' (Tensor v a) Output
|
||||
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||
|
||||
-- TODO: Come up with a better API for handling attributes.
|
||||
-- | Lens for the attributes of a tensor.
|
||||
--
|
||||
-- Only valid if the tensor has not yet been rendered. If the tensor has been
|
||||
-- rendered, the traversal will be over nothing (nothing can be read or
|
||||
-- written).
|
||||
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
||||
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
||||
|
||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||
-- Ref into Value. This behaves like a no-op.
|
||||
value :: Tensor v a -> Tensor Value a
|
||||
|
|
Loading…
Reference in a new issue