mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Added ByteString as a possible argument to a function.
This commit is contained in:
parent
b30d9a52c1
commit
43eebd22ad
|
@ -266,7 +266,7 @@ getExplicitInputAttr o implicitAttrs a
|
||||||
, a ^. maybe'defaultValue == Nothing
|
, a ^. maybe'defaultValue == Nothing
|
||||||
, t <- parseAttrType o (a ^. type')
|
, t <- parseAttrType o (a ^. type')
|
||||||
, t `elem` map AttrSingle
|
, t `elem` map AttrSingle
|
||||||
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
|
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
|
||||||
++ [AttrList AttrType] = Just t
|
++ [AttrList AttrType] = Just t
|
||||||
| otherwise = Nothing
|
| otherwise = Nothing
|
||||||
|
|
||||||
|
|
|
@ -678,16 +678,14 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
[ Just $ CoreOps.conv2DBackpropInput'
|
[ Just $ CoreOps.conv2DBackpropInput'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
(shape x) y dz
|
padding (shape x) y dz
|
||||||
, Just $ CoreOps.conv2DBackpropFilter'
|
, Just $ CoreOps.conv2DBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
x (shape y) dz
|
padding x (shape y) dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -699,16 +697,14 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||||
[ Nothing
|
[ Nothing
|
||||||
, Just $ CoreOps.conv2DBackpropFilter'
|
, Just $ CoreOps.conv2DBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz (shape x) y
|
padding dz (shape x) y
|
||||||
, Just $ CoreOps.conv2D'
|
, Just $ CoreOps.conv2D'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz x
|
padding dz x
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -719,14 +715,12 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
|
||||||
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
|
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
|
||||||
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
|
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
(shape x) y dz
|
padding (shape x) y dz
|
||||||
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
x (shape y) dz
|
padding x (shape y) dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -737,14 +731,12 @@ opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz]
|
||||||
[ Nothing
|
[ Nothing
|
||||||
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz (shape x) y
|
padding dz (shape x) y
|
||||||
, Just $ CoreOps.depthwiseConv2dNative'
|
, Just $ CoreOps.depthwiseConv2dNative'
|
||||||
((opAttr "strides" .~ strides)
|
((opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
dz x
|
padding dz x
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||||
|
@ -755,9 +747,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||||
[ Just $ CoreOps.maxPoolGrad'
|
[ Just $ CoreOps.maxPoolGrad'
|
||||||
((opAttr "ksize" .~ ksize)
|
((opAttr "ksize" .~ ksize)
|
||||||
. (opAttr "strides" .~ strides)
|
. (opAttr "strides" .~ strides)
|
||||||
. (opAttr "padding" .~ padding)
|
|
||||||
. (opAttr "data_format" .~ dataFormat))
|
. (opAttr "data_format" .~ dataFormat))
|
||||||
x output dz
|
padding x output dz
|
||||||
]
|
]
|
||||||
where
|
where
|
||||||
output :: Tensor Build a
|
output :: Tensor Build a
|
||||||
|
|
|
@ -89,8 +89,8 @@ module TensorFlow.Ops
|
||||||
, CoreOps.identity'
|
, CoreOps.identity'
|
||||||
, CoreOps.matMul
|
, CoreOps.matMul
|
||||||
, CoreOps.matMul'
|
, CoreOps.matMul'
|
||||||
, einsum
|
, CoreOps.einsum
|
||||||
, einsum'
|
, CoreOps.einsum'
|
||||||
, matTranspose
|
, matTranspose
|
||||||
, matTranspose'
|
, matTranspose'
|
||||||
, CoreOps.mean
|
, CoreOps.mean
|
||||||
|
@ -204,13 +204,6 @@ instance ( TensorType a
|
||||||
signum = CoreOps.sign
|
signum = CoreOps.sign
|
||||||
negate = CoreOps.neg
|
negate = CoreOps.neg
|
||||||
|
|
||||||
-- | Einstein summation
|
|
||||||
einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t
|
|
||||||
einsum = einsum' id
|
|
||||||
|
|
||||||
einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t
|
|
||||||
einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation))
|
|
||||||
|
|
||||||
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
|
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
|
||||||
matTranspose = matTranspose' id
|
matTranspose = matTranspose' id
|
||||||
|
|
||||||
|
|
|
@ -716,10 +716,9 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
|
||||||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
|
||||||
let y = TF.conv2DBackpropInput'
|
let y = TF.conv2DBackpropInput'
|
||||||
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
|
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
|
||||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
|
||||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||||
)
|
)
|
||||||
conv_input_shape filter' x
|
"VALID" conv_input_shape filter' x
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
TF.run (dx, TF.shape dx, TF.shape x)
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
@ -736,10 +735,9 @@ testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
|
||||||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||||
let y = TF.depthwiseConv2dNative'
|
let y = TF.depthwiseConv2dNative'
|
||||||
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
|
||||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||||
)
|
)
|
||||||
x filter'
|
"VALID" x filter'
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
TF.run (dx, TF.shape dx, TF.shape x)
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
@ -758,10 +756,9 @@ testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInpu
|
||||||
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
|
||||||
let y = TF.depthwiseConv2dNativeBackpropInput'
|
let y = TF.depthwiseConv2dNativeBackpropInput'
|
||||||
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
|
||||||
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
|
|
||||||
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
|
||||||
)
|
)
|
||||||
conv_input_shape filter' x
|
"VALID" conv_input_shape filter' x
|
||||||
|
|
||||||
[dx] <- TF.gradients y [x]
|
[dx] <- TF.gradients y [x]
|
||||||
TF.run (dx, TF.shape dx, TF.shape x)
|
TF.run (dx, TF.shape dx, TF.shape x)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user