1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Further nitpicks by blackgnezdo addressed.

This commit is contained in:
jcmartin 2020-11-05 20:52:38 +00:00
parent f9477682c7
commit 5258303cba
2 changed files with 14 additions and 13 deletions

View File

@ -43,10 +43,10 @@ module TensorFlow.Convolution
, depthwiseConv2dNativeBackpropInput' , depthwiseConv2dNativeBackpropInput'
) where ) where
import Data.Word(Word16) import Data.Word (Word16)
import Data.Int(Int32,Int64) import Data.Int (Int32,Int64)
import Data.ByteString(ByteString) import Data.ByteString (ByteString)
import Lens.Family2 ((.~), (&)) import Lens.Family2 ((.~))
import qualified TensorFlow.BuildOp as TF import qualified TensorFlow.BuildOp as TF
import qualified TensorFlow.Core as TF import qualified TensorFlow.Core as TF
@ -55,8 +55,13 @@ import qualified TensorFlow.GenOps.Core as TF
-- TODO: Support other convolution parameters such as stride. -- TODO: Support other convolution parameters such as stride.
-- | Convolution padding. -- | Convolution padding.
data Padding = PaddingValid -- ^ output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) data Padding =
| PaddingSame -- ^ output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i]) -- | output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
PaddingValid
-- | output_spatial_shape[i] = ceil(
-- (input_spatial_shape[i] -
-- (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
| PaddingSame
paddingToByteString :: Padding -> ByteString paddingToByteString :: Padding -> ByteString
paddingToByteString x = case x of paddingToByteString x = case x of
@ -68,10 +73,6 @@ data DataFormat = ChannelLast -- ^ Channel is the last dimension (e.g. NWC, NHW
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW) | ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
-- TODO: Address 1D convolution. -- TODO: Address 1D convolution.
--dataFormat1D :: DataFormat -> ByteString
--dataFormat1D x = case x of
-- ChannelLast -> "NWC"
-- ChannelFirst -> "NCW"
dataFormat2D :: DataFormat -> ByteString dataFormat2D :: DataFormat -> ByteString
dataFormat2D x = case x of dataFormat2D x = case x of

View File

@ -30,9 +30,9 @@ module TensorFlow.Minimize
, adam' , adam'
) where ) where
import Data.Complex import Data.Complex (Complex)
import Data.Int import Data.Int (Int8,Int16,Int32,Int64)
import Data.Word import Data.Word (Word8,Word16,Word32,Word64)
import Control.Monad (zipWithM) import Control.Monad (zipWithM)
import Data.Default (Default(..)) import Data.Default (Default(..))
import Data.List (zipWith4) import Data.List (zipWith4)