Added ByteString as a possible argument to a function.

This commit is contained in:
jcmartin 2020-09-04 20:17:17 +00:00
parent b30d9a52c1
commit 43eebd22ad
4 changed files with 15 additions and 34 deletions

View File

@ -266,7 +266,7 @@ getExplicitInputAttr o implicitAttrs a
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType o (a ^. type')
, t `elem` map AttrSingle
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
++ [AttrList AttrType] = Just t
| otherwise = Nothing

View File

@ -678,16 +678,14 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
(shape x) y dz
padding (shape x) y dz
, Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
padding x (shape y) dz
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
@ -699,16 +697,14 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
[ Nothing
, Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
dz (shape x) y
padding dz (shape x) y
, Just $ CoreOps.conv2D'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
dz x
padding dz x
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
@ -719,14 +715,12 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
(shape x) y dz
padding (shape x) y dz
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
padding x (shape y) dz
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
@ -737,14 +731,12 @@ opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz]
[ Nothing
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
dz (shape x) y
padding dz (shape x) y
, Just $ CoreOps.depthwiseConv2dNative'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
dz x
padding dz x
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
@ -755,9 +747,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad'
((opAttr "ksize" .~ ksize)
. (opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
x output dz
padding x output dz
]
where
output :: Tensor Build a

View File

@ -89,8 +89,8 @@ module TensorFlow.Ops
, CoreOps.identity'
, CoreOps.matMul
, CoreOps.matMul'
, einsum
, einsum'
, CoreOps.einsum
, CoreOps.einsum'
, matTranspose
, matTranspose'
, CoreOps.mean
@ -204,13 +204,6 @@ instance ( TensorType a
signum = CoreOps.sign
negate = CoreOps.neg
-- | Einstein summation
einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t
einsum = einsum' id
einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t
einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation))
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose = matTranspose' id

View File

@ -716,10 +716,9 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
let y = TF.conv2DBackpropInput'
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
conv_input_shape filter' x
"VALID" conv_input_shape filter' x
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)
@ -736,10 +735,9 @@ testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNative'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
x filter'
"VALID" x filter'
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)
@ -758,10 +756,9 @@ testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInpu
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNativeBackpropInput'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
)
conv_input_shape filter' x
"VALID" conv_input_shape filter' x
[dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x)