1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Added ByteString as a possible argument to a function.

This commit is contained in:
jcmartin 2020-09-04 20:17:17 +00:00
parent b30d9a52c1
commit 43eebd22ad
4 changed files with 15 additions and 34 deletions

View File

@ -266,7 +266,7 @@ getExplicitInputAttr o implicitAttrs a
, a ^. maybe'defaultValue == Nothing , a ^. maybe'defaultValue == Nothing
, t <- parseAttrType o (a ^. type') , t <- parseAttrType o (a ^. type')
, t `elem` map AttrSingle , t `elem` map AttrSingle
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape] [AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape, AttrBytes]
++ [AttrList AttrType] = Just t ++ [AttrList AttrType] = Just t
| otherwise = Nothing | otherwise = Nothing

View File

@ -678,16 +678,14 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] = opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput' [ Just $ CoreOps.conv2DBackpropInput'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
(shape x) y dz padding (shape x) y dz
, Just $ CoreOps.conv2DBackpropFilter' , Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
x (shape y) dz padding x (shape y) dz
] ]
where where
strides = lookupAttr nodeDef "strides" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64]
@ -699,16 +697,14 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
[ Nothing [ Nothing
, Just $ CoreOps.conv2DBackpropFilter' , Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
dz (shape x) y padding dz (shape x) y
, Just $ CoreOps.conv2D' , Just $ CoreOps.conv2D'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu) . (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
dz x padding dz x
] ]
where where
strides = lookupAttr nodeDef "strides" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64]
@ -719,14 +715,12 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] = opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput' [ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
(shape x) y dz padding (shape x) y dz
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter' , Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
x (shape y) dz padding x (shape y) dz
] ]
where where
strides = lookupAttr nodeDef "strides" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64]
@ -737,14 +731,12 @@ opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz]
[ Nothing [ Nothing
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter' , Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
dz (shape x) y padding dz (shape x) y
, Just $ CoreOps.depthwiseConv2dNative' , Just $ CoreOps.depthwiseConv2dNative'
((opAttr "strides" .~ strides) ((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
dz x padding dz x
] ]
where where
strides = lookupAttr nodeDef "strides" :: [Int64] strides = lookupAttr nodeDef "strides" :: [Int64]
@ -755,9 +747,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad' [ Just $ CoreOps.maxPoolGrad'
((opAttr "ksize" .~ ksize) ((opAttr "ksize" .~ ksize)
. (opAttr "strides" .~ strides) . (opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat)) . (opAttr "data_format" .~ dataFormat))
x output dz padding x output dz
] ]
where where
output :: Tensor Build a output :: Tensor Build a

View File

@ -89,8 +89,8 @@ module TensorFlow.Ops
, CoreOps.identity' , CoreOps.identity'
, CoreOps.matMul , CoreOps.matMul
, CoreOps.matMul' , CoreOps.matMul'
, einsum , CoreOps.einsum
, einsum' , CoreOps.einsum'
, matTranspose , matTranspose
, matTranspose' , matTranspose'
, CoreOps.mean , CoreOps.mean
@ -204,13 +204,6 @@ instance ( TensorType a
signum = CoreOps.sign signum = CoreOps.sign
negate = CoreOps.neg negate = CoreOps.neg
-- | Einstein summation
einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t
einsum = einsum' id
einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t
einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation))
matTranspose :: TensorType a => Tensor e a -> Tensor Build a matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose = matTranspose' id matTranspose = matTranspose' id

View File

@ -716,10 +716,9 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float)) filter' <- TF.render $ TF.fill filterShape (TF.scalar (1::Float))
let y = TF.conv2DBackpropInput' let y = TF.conv2DBackpropInput'
( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1]) ( (TF.opAttr "strides" .~ [1::Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC")) . (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
) )
conv_input_shape filter' x "VALID" conv_input_shape filter' x
[dx] <- TF.gradients y [x] [dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x) TF.run (dx, TF.shape dx, TF.shape x)
@ -736,10 +735,9 @@ testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float)) filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNative' let y = TF.depthwiseConv2dNative'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1]) ( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC")) . (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
) )
x filter' "VALID" x filter'
[dx] <- TF.gradients y [x] [dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x) TF.run (dx, TF.shape dx, TF.shape x)
@ -758,10 +756,9 @@ testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInpu
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float)) filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
let y = TF.depthwiseConv2dNativeBackpropInput' let y = TF.depthwiseConv2dNativeBackpropInput'
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1]) ( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
. (TF.opAttr "data_format" .~ (BS.pack "NHWC")) . (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
) )
conv_input_shape filter' x "VALID" conv_input_shape filter' x
[dx] <- TF.gradients y [x] [dx] <- TF.gradients y [x]
TF.run (dx, TF.shape dx, TF.shape x) TF.run (dx, TF.shape dx, TF.shape x)