Further nitpicks by blackgnezdo addressed.

This commit is contained in:
jcmartin 2020-11-05 20:52:38 +00:00
parent f9477682c7
commit 5258303cba
2 changed files with 14 additions and 13 deletions

View File

@ -43,10 +43,10 @@ module TensorFlow.Convolution
, depthwiseConv2dNativeBackpropInput'
) where
import Data.Word(Word16)
import Data.Int(Int32,Int64)
import Data.ByteString(ByteString)
import Lens.Family2 ((.~), (&))
import Data.Word (Word16)
import Data.Int (Int32,Int64)
import Data.ByteString (ByteString)
import Lens.Family2 ((.~))
import qualified TensorFlow.BuildOp as TF
import qualified TensorFlow.Core as TF
@ -55,8 +55,13 @@ import qualified TensorFlow.GenOps.Core as TF
-- TODO: Support other convolution parameters such as stride.
-- | Convolution padding.
data Padding = PaddingValid -- ^ output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
| PaddingSame -- ^ output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
data Padding =
-- | output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
PaddingValid
-- | output_spatial_shape[i] = ceil(
-- (input_spatial_shape[i] -
-- (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
| PaddingSame
paddingToByteString :: Padding -> ByteString
paddingToByteString x = case x of
@ -68,10 +73,6 @@ data DataFormat = ChannelLast -- ^ Channel is the last dimension (e.g. NWC, NHW
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
-- TODO: Address 1D convolution.
--dataFormat1D :: DataFormat -> ByteString
--dataFormat1D x = case x of
-- ChannelLast -> "NWC"
-- ChannelFirst -> "NCW"
dataFormat2D :: DataFormat -> ByteString
dataFormat2D x = case x of

View File

@ -30,9 +30,9 @@ module TensorFlow.Minimize
, adam'
) where
import Data.Complex
import Data.Int
import Data.Word
import Data.Complex (Complex)
import Data.Int (Int8,Int16,Int32,Int64)
import Data.Word (Word8,Word16,Word32,Word64)
import Control.Monad (zipWithM)
import Data.Default (Default(..))
import Data.List (zipWith4)