module TensorFlow.Ops
( CoreOps.add
, CoreOps.abs
, CoreOps.addN
, CoreOps.argMax
, CoreOps.assign
, CoreOps.broadcastGradientArgs
, CoreOps.cast
, CoreOps.concat
, constant
, CoreOps.equal
, expandDims
, initializedVariable
, zeroInitializedVariable
, CoreOps.fill
, CoreOps.oneHot
, CoreOps.matMul
, matTranspose
, CoreOps.mean
, CoreOps.mul
, CoreOps.neg
, CoreOps.pack
, placeholder
, CoreOps.range
, reducedShape
, CoreOps.relu
, CoreOps.reluGrad
, CoreOps.reshape
, restore
, restoreFromName
, save
, scalar
, shape
, CoreOps.sign
, CoreOps.size
, CoreOps.softmax
, CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.sparseToDense
, CoreOps.sub
, CoreOps.sum
, CoreOps.transpose
, truncatedNormal
, CoreOps.variable
, vector
, zeros
, CoreOps.zerosLike
, scalarize
) where
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8)
import Lens.Family2 ((.~), (&))
import Text.Printf (printf)
import Proto.Tensorflow.Core.Framework.Tensor
( TensorProto
, dtype
, tensorShape
)
import qualified Proto.Tensorflow.Core.Framework.TensorShape
as TensorShape
import TensorFlow.Build
import TensorFlow.BuildOp
import TensorFlow.ControlFlow (group)
import TensorFlow.Output (unNodeName)
import TensorFlow.Tensor
import TensorFlow.Types
import qualified TensorFlow.GenOps.Core as CoreOps
import qualified Prelude (abs)
-- | Arithmetic on tensors through the standard 'Num' operators.
--
-- Each method builds the corresponding op from "TensorFlow.GenOps.Core".
-- 'fromInteger' goes through 'scalar', so integer literals can be used
-- directly where a tensor is expected (e.g. @x + 1@).
-- The instance head matches any @Tensor v a@; the @v ~ Value@ constraint
-- then fixes @v@ to 'Value', which is what the generated ops return.
instance ( TensorType a
         , Num a
         , v ~ Value
         , OneOf '[ Double, Float, Int32, Int64
                  , Complex Float, Complex Double] a) => Num (Tensor v a) where
    (+) = CoreOps.add
    (*) = CoreOps.mul
    (-) = CoreOps.sub
    abs = CoreOps.abs
    fromInteger = scalar . fromInteger
    signum = CoreOps.sign
    negate = CoreOps.neg
-- | Transpose by applying the fixed axis permutation @[1, 0]@.
--
-- The permutation has exactly two entries, so this is intended for
-- matrices (two-dimensional tensors).
matTranspose :: forall a v . TensorType a
             => Tensor v a -> Tensor Value a
matTranspose m = CoreOps.transpose m swapAxes
  where
    swapAxes = vector [1, 0 :: Int32]
-- | Build a @Placeholder@ op with the given shape. Its @dtype@ attribute is
-- taken from the element type @a@.
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
placeholder shape' = buildOp placeholderDef
  where
    placeholderDef = opDef "Placeholder"
        & opAttr "dtype" .~ tensorType (undefined :: a)
        & opAttr "shape" .~ shape'
-- | Create a variable and register an @Assign@ of @initializer@ into it as
-- an initializer (via 'addInitializer'). Returns the variable itself, not
-- the assignment.
initializedVariable :: forall a . TensorType a
                    => Tensor Value a -> Build (Tensor Ref a)
initializedVariable initializer = do
    -- The variable is created with an empty shape. Because validate_shape is
    -- False on the Assign below, presumably its real shape comes from the
    -- initializer (TODO confirm).
    var <- CoreOps.variable []
    let assignDef = opDef "Assign"
            & opAttr "T" .~ tensorType (undefined :: a)
            & opAttr "use_locking" .~ True
            & opAttr "validate_shape" .~ False
    (assignment :: Tensor Ref a) <- buildOp assignDef var initializer
    group assignment >>= addInitializer
    pure var
-- | Create a variable of the given shape whose initializer is 'zeros'.
zeroInitializedVariable
    :: (TensorType a, Num a) =>
       TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable dims = initializedVariable (zeros dims)
-- | Build a @Save@ op writing @xs@ to the file at @path@.
--
-- Each tensor's rendered node name becomes a scalar string tensor. These are
-- combined with 'CoreOps.pack' and passed in alongside the tensors. The op's
-- @T@ attribute lists the element type once per tensor.
save :: forall a v . TensorType a
     => ByteString    -- ^ File path.
     -> [Tensor v a]  -- ^ Tensors to save.
     -> Build ControlNode
save path xs = do
    tensorNames <- mapM nameTensor xs
    buildOp saveDef (scalar path) (CoreOps.pack tensorNames) xs
  where
    nameTensor t = scalar . encodeUtf8 . unNodeName <$> renderNodeName t
    elemTypes = replicate (length xs) (tensorType (undefined :: a))
    saveDef = opDef "Save" & opAttr "T" .~ elemTypes
-- | Build a @Restore@ op that reads the entry @name@ from the file at @path@,
-- and assign its result into the variable @x@.
restoreFromName :: forall a . TensorType a
                => ByteString    -- ^ File path.
                -> ByteString    -- ^ Entry name to read.
                -> Tensor Ref a  -- ^ Variable to assign the result to.
                -> Build ControlNode
restoreFromName path name x = CoreOps.assign x restored >>= group
  where
    restored :: Tensor Value a
    restored = buildOp restoreDef (scalar path) (scalar name)
    restoreDef = opDef "Restore" & opAttr "dt" .~ tensorType (undefined :: a)
-- | Restore the variable @x@ from the file at @path@, using the variable's
-- own rendered node name (UTF-8 encoded) as the entry name.
restore :: forall a . TensorType a
        => ByteString    -- ^ File path.
        -> Tensor Ref a  -- ^ Variable to restore.
        -> Build ControlNode
restore path x =
    renderNodeName x >>= \rendered ->
        restoreFromName path (encodeUtf8 (unNodeName rendered)) x
-- | Create a @Const@ tensor of the given shape from a flat list of values.
--
-- The number of values must equal the product of the shape's dimensions.
-- Otherwise this calls 'error' with a message giving both counts.
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
constant (Shape shape') values
    | expectedCount /= actualCount =
        error $ printf "invalid tensor length: expected %d got %d"
                       expectedCount
                       (length values)
    | otherwise = buildOp constDef
  where
    expectedCount = product shape'
    actualCount = fromIntegral (length values)
    elemType = tensorType (undefined :: a)
    -- The TensorProto carries the dtype, one dimension entry per shape
    -- component, and the element list.
    valueProto :: TensorProto
    valueProto = def
        & dtype .~ elemType
        & tensorShape . TensorShape.dim .~
              [def & TensorShape.size .~ d | d <- shape']
        & tensorVal .~ values
    constDef = opDef "Const"
        & opAttr "value" .~ valueProto
        & opAttr "dtype" .~ elemType
-- | Reshape a tensor using an empty (zero-length) shape vector, i.e. to a
-- scalar. Presumably only meaningful for single-element inputs (TODO confirm).
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t emptyShape
  where
    emptyShape = vector ([] :: [Int32])
-- | Create a one-dimensional constant whose single dimension is the list
-- length.
vector :: TensorType a => [a] -> Tensor Value a
vector xs = constant [len] xs
  where
    len = fromIntegral (length xs)
-- | Create a constant with an empty shape holding exactly one value.
scalar :: forall a . TensorType a => a -> Tensor Value a
scalar v = constant [] (pure v)
-- | Build a @TruncatedNormal@ op producing values of type @a@. The argument
-- is an 'Int64' tensor, presumably the output shape (TODO confirm).
truncatedNormal :: forall a v . TensorType a
                => Tensor v Int64
                -> Build (Tensor Value a)
truncatedNormal outShape = buildOp truncatedNormalDef outShape
  where
    truncatedNormalDef = opDef "TruncatedNormal"
        & opAttr "dtype" .~ tensorType (undefined :: a)
        & opAttr "T" .~ tensorType (undefined :: Int64)
-- | Create a tensor of the given shape by filling it with the scalar 0.
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape dimList) = CoreOps.fill dimsTensor (scalar 0)
  where
    -- Shape dimensions are converted to the index type 'CoreOps.fill' expects.
    dimsTensor = vector (fromIntegral <$> dimList)
-- | Thin wrapper around 'CoreOps.shape' that fixes the result to 'Int32'.
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
shape t = CoreOps.shape t
-- | Thin wrapper around 'CoreOps.expandDims' that fixes the axis tensor to
-- 'Int32'.
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims input axis = CoreOps.expandDims input axis
-- | Shape left after reducing @inputShape@ over @axes@, with the reduced
-- dimensions set to 1.
--
-- Both inputs are cast to 'Int32'. The result is built with
-- 'CoreOps.dynamicStitch': the indices @range 0 rank 1@ are paired with the
-- input shape, and the normalized axes are paired with a tensor of ones.
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
                Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
reducedShape inputShape axes =
    CoreOps.dynamicStitch [CoreOps.range 0 rank 1, normAxes]
                          [shape32, CoreOps.fill (shape normAxes) 1]
  where
    asInt32 :: TensorType s => Tensor v s -> Tensor Value Int32
    asInt32 = CoreOps.cast
    shape32 = asInt32 inputShape
    rank = CoreOps.size shape32
    -- Maps negative axes to non-negative indices (assuming axes lie in
    -- [-rank, rank)).
    normAxes = (asInt32 axes + rank) `CoreOps.mod` rank