mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
alignment, removed some typing
This commit is contained in:
parent
a0354c0169
commit
828a257c80
|
@ -407,13 +407,13 @@ toT = Tensor ValueKind
|
|||
-- | Wrapper around `TensorFlow.GenOps.Core.slice` that builds vectors from scalars for
|
||||
-- simple slicing operations.
|
||||
flatSlice :: forall v1 t i . (TensorType t)
|
||||
=> Tensor v1 t -- ^ __input__
|
||||
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||
-- 'input' to slice from.
|
||||
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
|
||||
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
||||
-- are included in the slice (i.e. this is equivalent to setting
|
||||
-- size = input.dim_size(0) - begin).
|
||||
=> Tensor v1 t -- ^ __input__
|
||||
-> Int32 -- ^ __begin__: specifies the offset into the first dimension of
|
||||
-- 'input' to slice from.
|
||||
-> Int32 -- ^ __size__: specifies the number of elements of the first dimension
|
||||
-- of 'input' to slice. If size is -1, all remaining elements in the dimension
|
||||
-- are included in the slice (i.e. this is equivalent to setting
|
||||
-- size = input.dim_size(0) - begin).
|
||||
-> Tensor Value t -- ^ __output__
|
||||
flatSlice input begin size = CoreOps.slice input (vector [begin]) (vector [size])
|
||||
|
||||
|
@ -446,9 +446,9 @@ opGrad "Gather" _ [toT -> x, toT -> indices] [dz] =
|
|||
where
|
||||
-- TODO(gnezdo): Use colocateWith but it requires Build monad.
|
||||
denseShape = shape (x :: Tensor Value a)
|
||||
numRows = scalarize $ flatSlice denseShape (0 :: Int32) 1
|
||||
numRows = scalarize $ flatSlice denseShape 0 1
|
||||
valuesShape = CoreOps.concat 0 [ allDimensions
|
||||
, flatSlice denseShape 1 (-1 :: Int32)
|
||||
, flatSlice denseShape 1 (-1)
|
||||
]
|
||||
values = reshape dz valuesShape
|
||||
-- TODO(fmayle): This could be either Int32 or Int64.
|
||||
|
@ -643,7 +643,7 @@ opGrad "SparseSegmentSum" _ [toT -> x, toT -> y, toT -> t] [dz] =
|
|||
, Nothing
|
||||
, Nothing
|
||||
]
|
||||
where inputRows = flatSlice (shape (x :: Tensor Value a)) (0 :: Int32) 1
|
||||
where inputRows = flatSlice (shape (x :: Tensor Value a)) 0 1
|
||||
|
||||
opGrad "LabelClasses" _ _ _ = [Nothing, Nothing]
|
||||
opGrad "LabelWeights" _ _ _ = [Nothing]
|
||||
|
|
|
@ -260,9 +260,7 @@ constant (Shape shape') values
|
|||
-- | Reshape a N-D tensor down to a scalar.
|
||||
--
|
||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||
scalarize :: (TensorType a)
|
||||
=> Tensor v a
|
||||
-> Tensor Value a
|
||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||
where
|
||||
scalarShape = [] :: [Int32]
|
||||
|
|
Loading…
Reference in New Issue
Block a user