mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Further nitpicks by blackgnezdo addressed.
This commit is contained in:
parent
f9477682c7
commit
5258303cba
|
@ -43,10 +43,10 @@ module TensorFlow.Convolution
|
||||||
, depthwiseConv2dNativeBackpropInput'
|
, depthwiseConv2dNativeBackpropInput'
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.Word(Word16)
|
import Data.Word (Word16)
|
||||||
import Data.Int(Int32,Int64)
|
import Data.Int (Int32,Int64)
|
||||||
import Data.ByteString(ByteString)
|
import Data.ByteString (ByteString)
|
||||||
import Lens.Family2 ((.~), (&))
|
import Lens.Family2 ((.~))
|
||||||
|
|
||||||
import qualified TensorFlow.BuildOp as TF
|
import qualified TensorFlow.BuildOp as TF
|
||||||
import qualified TensorFlow.Core as TF
|
import qualified TensorFlow.Core as TF
|
||||||
|
@ -55,8 +55,13 @@ import qualified TensorFlow.GenOps.Core as TF
|
||||||
-- TODO: Support other convolution parameters such as stride.
|
-- TODO: Support other convolution parameters such as stride.
|
||||||
|
|
||||||
-- | Convolution padding.
|
-- | Convolution padding.
|
||||||
data Padding = PaddingValid -- ^ output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
|
data Padding =
|
||||||
| PaddingSame -- ^ output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
|
-- | output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
|
||||||
|
PaddingValid
|
||||||
|
-- | output_spatial_shape[i] = ceil(
|
||||||
|
-- (input_spatial_shape[i] -
|
||||||
|
-- (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])
|
||||||
|
| PaddingSame
|
||||||
|
|
||||||
paddingToByteString :: Padding -> ByteString
|
paddingToByteString :: Padding -> ByteString
|
||||||
paddingToByteString x = case x of
|
paddingToByteString x = case x of
|
||||||
|
@ -68,10 +73,6 @@ data DataFormat = ChannelLast -- ^ Channel is the last dimension (e.g. NWC, NHW
|
||||||
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
|
| ChannelFirst -- ^ Channel is the first dimension after N (e.g. NCW, NCHW, NCDHW)
|
||||||
|
|
||||||
-- TODO: Address 1D convolution.
|
-- TODO: Address 1D convolution.
|
||||||
--dataFormat1D :: DataFormat -> ByteString
|
|
||||||
--dataFormat1D x = case x of
|
|
||||||
-- ChannelLast -> "NWC"
|
|
||||||
-- ChannelFirst -> "NCW"
|
|
||||||
|
|
||||||
dataFormat2D :: DataFormat -> ByteString
|
dataFormat2D :: DataFormat -> ByteString
|
||||||
dataFormat2D x = case x of
|
dataFormat2D x = case x of
|
||||||
|
|
|
@ -30,9 +30,9 @@ module TensorFlow.Minimize
|
||||||
, adam'
|
, adam'
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.Complex
|
import Data.Complex (Complex)
|
||||||
import Data.Int
|
import Data.Int (Int8,Int16,Int32,Int64)
|
||||||
import Data.Word
|
import Data.Word (Word8,Word16,Word32,Word64)
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Default (Default(..))
|
import Data.Default (Default(..))
|
||||||
import Data.List (zipWith4)
|
import Data.List (zipWith4)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user