1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

alignment, removed some typing

This commit is contained in:
silky 2016-11-18 08:02:41 +11:00
parent a0354c0169
commit 828a257c80
2 changed files with 11 additions and 13 deletions

View File

@ -407,13 +407,13 @@ toT = Tensor ValueKind
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
-- simple slicing operations.
flatSlice :: forall v1 t i . (TensorType t)
=> Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from.
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
-- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin).
=> Tensor v1 t -- ^ __input__
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
-- 'input' to slice from.
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
-- are included in the slice (i.e. this is equivalent to setting
-- size = input.dim_size(0) - begin).
-> Tensor Value t -- ^ __output__
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
@ -446,9 +446,9 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
where
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
denseShape = shape (x :: Tensor Value a)
numRows = scalarize $ flatSlice denseShape (0 :: Int32) 1
numRows = scalarize $ flatSlice denseShape 0 1
valuesShape = CoreOps.concat 0 [ allDimensions
, flatSlice denseShape 1 (-1 :: Int32)
, flatSlice denseShape 1 (-1)
]
values = reshape dz valuesShape
-- TODO(fmayle): This could be either Int32 or Int64.
@ -643,7 +643,7 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
, Nothing
, Nothing
]
where inputRows = flatSlice (shape (x :: Tensor Value a)) (0 :: Int32) 1
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
opGrad "LabelWeights" _ _ _ = [Nothing]

View File

@ -260,9 +260,7 @@ constant (Shape shape') values
-- | Reshape a N-D tensor down to a scalar.
--
-- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a)
=> Tensor v a
-> Tensor Value a
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t (vector scalarShape)
where
scalarShape = [] :: [Int32]