Added ByteString as a possible argument to a function.
This commit is contained in:
parent
b30d9a52c1
commit
43eebd22ad
|
@ -266,7 +266,7 @@ getExplicitInputAttr o implicitAttrs a
|
|||
, a ^. maybe'defaultValue == Nothing
|
||||
, t <- parseAttrType o (a ^. type')
|
||||
, t `elem` map AttrSingle
|
||||
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
|
||||
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
|
||||
++ [AttrList AttrType] = Just t
|
||||
| otherwise = Nothing
|
||||
|
||||
|
|
|
@ -678,16 +678,14 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
|
|||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
[ Just $ CoreOps.conv2DBackpropInput'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
(shape x) y dz
|
||||
padding (shape x) y dz
|
||||
, Just $ CoreOps.conv2DBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x (shape y) dz
|
||||
padding x (shape y) dz
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
|
@ -699,16 +697,14 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
|||
[ Nothing
|
||||
, Just $ CoreOps.conv2DBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
dz (shape x) y
|
||||
padding dz (shape x) y
|
||||
, Just $ CoreOps.conv2D'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
dz x
|
||||
padding dz x
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
|
@ -719,14 +715,12 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
|||
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
(shape x) y dz
|
||||
padding (shape x) y dz
|
||||
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x (shape y) dz
|
||||
padding x (shape y) dz
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
|
@ -737,14 +731,12 @@ opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz]
|
|||
[ Nothing
|
||||
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
dz (shape x) y
|
||||
padding dz (shape x) y
|
||||
, Just $ CoreOps.depthwiseConv2dNative'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
dz x
|
||||
padding dz x
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
|
@ -755,9 +747,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
|||
[ Just $ CoreOps.maxPoolGrad'
|
||||
((opAttr "ksize" .~ ksize)
|
||||
. (opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x output dz
|
||||
padding x output dz
|
||||
]
|
||||
where
|
||||
output :: Tensor Build a
|
||||
|
|
|
@ -89,8 +89,8 @@ module TensorFlow.Ops
|
|||
, CoreOps.identity'
|
||||
, CoreOps.matMul
|
||||
, CoreOps.matMul'
|
||||
, einsum
|
||||
, einsum'
|
||||
, CoreOps.einsum
|
||||
, CoreOps.einsum'
|
||||
, matTranspose
|
||||
, matTranspose'
|
||||
, CoreOps.mean
|
||||
|
@ -204,13 +204,6 @@ instance ( TensorType a
|
|||
signum = CoreOps.sign
|
||||
negate = CoreOps.neg
|
||||
|
||||
-- | Einstein summation
|
||||
einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t
|
||||
einsum = einsum' id
|
||||
|
||||
einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t
|
||||
einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation))
|
||||
|
||||
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
|
||||
matTranspose = matTranspose' id
|
||||
|
||||
|
|
|
@ -716,10 +716,9 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
|||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
|
||||
let y = TF.conv2DBackpropInput'
|
||||
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
|
||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||
)
|
||||
conv_input_shape filter' x
|
||||
"VALID" conv_input_shape filter' x
|
||||
|
||||
[dx] <- TF.gradients y [x]
|
||||
TF.run (dx, TF.shape dx, TF.shape x)
|
||||
|
@ -736,10 +735,9 @@ testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
|
|||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||
let y = TF.depthwiseConv2dNative'
|
||||
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||
)
|
||||
x filter'
|
||||
"VALID" x filter'
|
||||
|
||||
[dx] <- TF.gradients y [x]
|
||||
TF.run (dx, TF.shape dx, TF.shape x)
|
||||
|
@ -758,10 +756,9 @@ testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInpu
|
|||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||
let y = TF.depthwiseConv2dNativeBackpropInput'
|
||||
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||
)
|
||||
conv_input_shape filter' x
|
||||
"VALID" conv_input_shape filter' x
|
||||
|
||||
[dx] <- TF.gradients y [x]
|
||||
TF.run (dx, TF.shape dx, TF.shape x)
|
||||
|
|
Loading…
Reference in New Issue