{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Tensor where
import Data.ByteString (ByteString)
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 ((^.))
import Lens.Family2.State ((%=), use)
import Proto.Tensorflow.Core.Framework.NodeDef (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types
( TensorType
, TensorData(..)
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI
-- | A tensor whose graph 'Output' is held inside a 'TensorKind' context @v@
-- (the kinds defined in this module are 'Value', 'Ref' and 'Build').
--
-- @a@ is a phantom parameter: it does not appear in the constructor's field.
-- Presumably it is the element type of the tensor; confirm against callers.
--
-- The 'TensorKind' constraint is stored in the constructor, so pattern
-- matching on 'Tensor' brings @TensorKind v@ back into scope.
data Tensor v a where
    Tensor :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
-- | An identity-like wrapper used as the 'TensorKind' of tensors that have
-- already been rendered (see 'render', which produces @Tensor Value@).
--
-- All instances simply unwrap the payload, operate on it, and rewrap it.
newtype Value a = Value {runValue :: a}

instance Functor Value where
    fmap h (Value y) = Value (h y)

instance Applicative Value where
    pure = Value
    -- Apply the wrapped function to the wrapped argument.
    Value g <*> Value y = Value (g y)

instance Monad Value where
    -- Unwrap the payload and hand it straight to the continuation.
    Value y >>= k = k y
-- | An identity-like wrapper used as the 'TensorKind' of reference tensors.
-- 'value' converts a @Tensor Ref@ into a @Tensor Value@.
--
-- All instances simply unwrap the payload, operate on it, and rewrap it.
newtype Ref a = Ref {runRef :: a}

instance Functor Ref where
    fmap h (Ref y) = Ref (h y)

instance Applicative Ref where
    pure = Ref
    -- Apply the wrapped function to the wrapped argument.
    Ref g <*> Ref y = Ref (g y)

instance Monad Ref where
    -- Unwrap the payload and hand it straight to the continuation.
    Ref y >>= k = k y
-- | Reinterpret a reference tensor as a value tensor that shares the same
-- underlying 'Output'. This is a pure re-wrapping; no 'Build' action runs.
value :: Tensor Ref a -> Tensor Value a
value (Tensor (Ref out)) = Tensor (Value out)
-- | Render a tensor of any kind into a 'Value' tensor. The tensor is first
-- lifted to a 'Build' expression with 'expr' and then passed to 'render'.
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
-- Matching the 'Tensor' constructor brings @TensorKind v@ into scope, which
-- 'expr' requires.
renderValue t@(Tensor _) = render (expr t)
-- | Pairs a graph 'Output' with raw FFI tensor data, as constructed by 'feed'.
-- Presumably this supplies that data for the output when the graph is run;
-- confirm against the session code.
data Feed = Feed Output FFI.TensorData
-- | Tensor types whose graph 'Output' can be extracted purely, without
-- running any 'Build' action.
class Rendered t where
    -- | The graph output backing the tensor.
    renderedOutput :: t a -> Output

-- | A 'Value' tensor just unwraps its stored output.
instance Rendered (Tensor Value) where
    renderedOutput (Tensor (Value out)) = out

-- | A 'Ref' tensor just unwraps its stored output.
instance Rendered (Tensor Ref) where
    renderedOutput (Tensor (Ref out)) = out
-- | The name of the graph node that owns the tensor's rendered 'Output'.
tensorNodeName :: Rendered t => t a -> NodeName
tensorNodeName t = outputNodeName (renderedOutput t)
-- | Pair a rendered tensor's 'Output' with the raw FFI data carried by a
-- 'TensorData' value. The shared parameter @a@ ties their types together.
feed :: Rendered t => t a -> TensorData a -> Feed
feed t tensorData =
    case tensorData of
        TensorData raw -> Feed (renderedOutput t) raw
-- | Build a tensor that refers to a graph output by name.
--
-- The name is converted to an 'Output' through its 'IsString' instance and
-- lifted into the tensor kind with 'pure'. No check is made here that the
-- name actually exists in any graph.
tensorFromName :: TensorKind v => Text.Text -> Tensor v a
tensorFromName name = Tensor (pure (fromString (Text.unpack name)))

-- | 'tensorFromName' specialised to 'Value' tensors.
tensorValueFromName :: Text.Text -> Tensor Value a
tensorValueFromName = tensorFromName

-- | 'tensorFromName' specialised to 'Ref' tensors.
tensorRefFromName :: Text.Text -> Tensor Ref a
tensorRefFromName = tensorFromName
-- | A 'ListOf' tensors of kind @v@. Presumably it is indexed by a type-level
-- list of element types (the @as@ in 'tensorListOutputs'); see 'ListOf'.
type TensorList v = ListOf (Tensor v)

-- | The rendered 'Output' of every tensor in the list, in list order.
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
tensorListOutputs ts =
    case ts of
        Nil -> []
        t :/ rest -> renderedOutput t : tensorListOutputs rest
-- | Run an action with its device set to the device recorded on the node
-- that produced the given tensor.
--
-- The node is looked up with 'lookupNode' inside 'build', and its @device@
-- field is wrapped in 'Device' and passed to 'withDevice'.
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
colocateWith t x = do
    dev <- build (nodeDevice <$> lookupNode nodeName)
    withDevice (Just dev) x
  where
    nodeName = outputNodeName (renderedOutput t)
    nodeDevice nodeDef = Device (nodeDef ^. device)
-- | Run a 'Build'-kind tensor's action and wrap the resulting 'Output' as a
-- 'Value' tensor.
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
render (Tensor action) = fmap (\out -> Tensor (Value out)) (build action)
-- | Lift a tensor of any kind to a 'Build'-kind tensor by converting its
-- output container with 'toBuild'. This function is pure; it runs nothing.
expr :: TensorKind v => Tensor v a -> Tensor Build a
expr t =
    case t of
        Tensor out -> Tensor (toBuild out)
-- | Record a 'ByteString' tensor as a summary. Its 'Output' is resolved in
-- 'Build' and prepended to the @summaries@ state field, which
-- 'collectAllSummaries' later reads.
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString
           -> m ()
addSummary t =
    build $ toBuild (tensorOutput t) >>= \out -> summaries %= (out :)
-- | Every output currently stored in the @summaries@ state field, each
-- wrapped as a 'Value' tensor. The list is in stored order. Because
-- 'addSummary' prepends, the most recently added summary presumably comes
-- first; verify if other code also writes @summaries@.
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
collectAllSummaries = build (wrapAll <$> use summaries)
  where
    wrapAll outs = [Tensor (Value out) | out <- outs]

-- | A rendered tensor of 'ByteString's: the type of the summaries returned
-- by 'collectAllSummaries'.
type SummaryTensor = Tensor Value ByteString
-- | Monadic containers that can hold a tensor's 'Output' and be converted
-- into a 'Build' action yielding their contents.
class Monad v => TensorKind v where
    toBuild :: v a -> Build a

-- | A 'Value' is already available: just return its payload.
instance TensorKind Value where
    toBuild (Value x) = return x

-- | A 'Ref' is already available: just return its payload.
instance TensorKind Ref where
    toBuild (Ref x) = return x

-- | A 'Build' action is already in the target form.
instance TensorKind Build where
    toBuild action = action
-- | Types that can be converted into a 'Build'-kind 'Tensor'.
class ToTensor t where
    toTensor :: TensorType a => t a -> Tensor Build a

-- | A tensor of any kind converts by lifting it with 'expr'.
instance TensorKind v => ToTensor (Tensor v) where
    toTensor t = expr t