Further nitpicks by blackgnezdo addressed.
This commit is contained in:
parent
f9477682c7
commit
5258303cba
|
@ -43,10 +43,10 @@ module TensorFlow.Convolution
|
|||
, depthwiseConv2dNativeBackpropInput'
|
||||
) where
|
||||
|
||||
import Data.Word(Word16)
|
||||
import Data.Int(Int32,Int64)
|
||||
import Data.ByteString(ByteString)
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import Data.Word (Word16)
|
||||
import Data.Int (Int32,Int64)
|
||||
import Data.ByteString (ByteString)
|
||||
import Lens.Family2 ((.~))
|
||||
|
||||
import qualified TensorFlow.BuildOp as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
|
@ -55,8 +55,13 @@ import qualified TensorFlow.GenOps.Core as TF
|
|||
-- TODO: Support other convolution parameters such as stride.
|
||||
|
||||
-- | Convolution padding.
|
||||
data Padding = PaddingValid -- ^ output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
|
||||
| PaddingSame -- ^ output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
|
||||
data Padding =
|
||||
-- | output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
|
||||
PaddingValid
|
||||
-- | output_spatial_shape[i] = ceil(
|
||||
-- (input_spatial_shape[i] -
|
||||
-- (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
|
||||
| PaddingSame
|
||||
|
||||
paddingToByteString :: Padding -> ByteString
|
||||
paddingToByteString x = case x of
|
||||
|
@ -68,10 +73,6 @@ data DataFormat = ChannelLast -- ^ Channel is the last dimension (e.g. NWC, NHW
|
|||
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
|
||||
|
||||
-- TODO: Address 1D convolution.
|
||||
--dataFormat1D :: DataFormat -> ByteString
|
||||
--dataFormat1D x = case x of
|
||||
-- ChannelLast -> "NWC"
|
||||
-- ChannelFirst -> "NCW"
|
||||
|
||||
dataFormat2D :: DataFormat -> ByteString
|
||||
dataFormat2D x = case x of
|
||||
|
|
|
@ -30,9 +30,9 @@ module TensorFlow.Minimize
|
|||
, adam'
|
||||
) where
|
||||
|
||||
import Data.Complex
|
||||
import Data.Int
|
||||
import Data.Word
|
||||
import Data.Complex (Complex)
|
||||
import Data.Int (Int8,Int16,Int32,Int64)
|
||||
import Data.Word (Word8,Word16,Word32,Word64)
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Default (Default(..))
|
||||
import Data.List (zipWith4)
|
||||
|
|
Loading…
Reference in New Issue