-- | Types for operations ('Op', 'OpDef'), their names ('NodeName',
-- 'PendingNodeName'), and references to individual operation outputs
-- ('Output'), together with lenses and a traversal over those types.
module TensorFlow.Output
( ControlNode(..)
, Device(..)
, NodeName(..)
, Op(..)
, opUnrendered
, OpDef(..)
, opName
, opType
, opAttr
, opInputs
, opControlInputs
, OpType(..)
, OutputIx(..)
, Output(..)
, output
, outputIndex
, outputOp
, PendingNodeName(..)
) where
import qualified Data.Map.Strict as Map
import Data.ProtoLens.TextFormat (showMessage)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens)
import TensorFlow.Orphans ()
-- | An 'Op' wrapped in its own type.
--
-- NOTE(review): presumably marks an op that is used only as a control
-- dependency -- confirm against callers.
newtype ControlNode = ControlNode
    { unControlNode :: Op  -- ^ The wrapped op.
    }
-- | The type of an operation, stored as 'Text'.
newtype OpType = OpType
    { unOpType :: Text  -- ^ The underlying text.
    } deriving (Eq, Ord, Show)

-- | Lets a string literal stand for an 'OpType' by packing it into 'Text'.
instance IsString OpType where
    fromString str = OpType (Text.pack str)
-- | A reference to one output of an op: an 'OutputIx' paired with the 'Op'
-- it belongs to.  Both fields are strict.
data Output = Output !OutputIx !Op
    deriving (Eq, Ord, Show)
-- | Build an 'Output' from an index and an op.  Equivalent to applying the
-- 'Output' constructor directly.
output :: OutputIx -> Op -> Output
output ix op = Output ix op
-- | Lens onto the 'Op' component of an 'Output'.  Setting it leaves the
-- index untouched.
outputOp :: Lens' Output Op
outputOp = lens getOp setOp
  where
    getOp (Output _ op) = op
    setOp (Output ix _) op = Output ix op
-- | Lens onto the 'OutputIx' component of an 'Output'.  Setting it leaves
-- the op untouched.
outputIndex :: Lens' Output OutputIx
outputIndex = lens getIx setIx
  where
    getIx (Output ix _) = ix
    setIx (Output _ op) ix = Output ix op
-- | Index of an output, wrapping an 'Int'.  The 'Num' and 'Enum' instances
-- are inherited from 'Int', so integer literals can be used as indices.
newtype OutputIx = OutputIx
    { unOutputIx :: Int  -- ^ The raw index.
    } deriving (Eq, Ord, Num, Enum, Show)
-- | A device name, stored as 'Text'.
newtype Device = Device
    { deviceName :: Text  -- ^ The underlying name.
    } deriving (Eq, Ord, IsString)

-- | Shows only the wrapped name (as a quoted string), omitting the
-- 'Device' constructor.
instance Show Device where
    show = show . deviceName
-- | An operation in one of two states: already rendered to a 'NodeDef', or
-- still described only by an 'OpDef'.  Both payloads are strict.
data Op
    = Rendered !NodeDef   -- ^ Carries the rendered 'NodeDef'.
    | Unrendered !OpDef   -- ^ Carries the not-yet-rendered 'OpDef'.
    deriving (Eq, Ord)
-- | Textual form of an 'Op'.  A rendered op prints its 'NodeDef' through
-- 'showMessage'; an unrendered op prints only its pending name.
instance Show Op where
    show op = case op of
        Rendered nodeDef -> "Rendered " ++ showMessage nodeDef
        Unrendered opDef -> "Unrendered " ++ show (opDef ^. opName)
-- | Traversal focusing on the 'OpDef' inside an 'Unrendered' op.  A
-- 'Rendered' op has no target and is returned unchanged.
opUnrendered :: Traversal' Op OpDef
opUnrendered f op = case op of
    Unrendered opDef -> fmap Unrendered (f opDef)
    Rendered _ -> pure op
-- | Description of an operation: its pending name, its type, its attributes,
-- and its data and control inputs.  The first three fields are strict; the
-- two input lists are lazy.
data OpDef = OpDef
    { _opName :: !PendingNodeName
      -- ^ Name of the op, which may be explicit or implicit.
    , _opType :: !OpType
      -- ^ Type of the op.
    , _opAttrs :: !(Map.Map Text AttrValue)
      -- ^ Attribute values keyed by 'Text'.
    , _opInputs :: [Output]
      -- ^ Outputs that feed this op.
    , _opControlInputs :: [NodeName]
      -- ^ Names of the nodes listed as control inputs.
    } deriving (Eq, Ord)
-- | The name of an op before rendering: either given explicitly as 'Text'
-- or left implicit.
data PendingNodeName
    = ExplicitName !Text  -- ^ A name supplied as 'Text'.
    | ImplicitName        -- ^ No explicit name was supplied.
    deriving (Eq, Ord, Show)
-- | The name of a node, stored as 'Text'.
newtype NodeName = NodeName
    { unNodeName :: Text  -- ^ The underlying text.
    } deriving (Eq, Ord, Show)
-- | Lens onto the '_opName' field of an 'OpDef'.
opName :: Lens' OpDef PendingNodeName
opName = lens _opName setName
  where
    setName opDef newName = opDef { _opName = newName }
-- | Lens onto the '_opType' field of an 'OpDef'.
opType :: Lens' OpDef OpType
opType = lens _opType setType
  where
    setType opDef newType = opDef { _opType = newType }
-- | Lens onto the attribute stored under key @n@ in an 'OpDef'.
--
-- Reading looks @n@ up in the attribute map and falls back to 'def' when the
-- key is absent; writing inserts the new value under @n@.  The stored
-- 'AttrValue' is converted to and from @a@ through 'attrLens'.
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr n = attrs . entry . attrLens
  where
    -- The whole attribute map of the op.
    attrs :: Lens' OpDef (Map.Map Text AttrValue)
    attrs = lens _opAttrs (\opDef m -> opDef { _opAttrs = m })

    -- The single map entry for @n@, defaulting when missing.
    entry :: Lens' (Map.Map Text AttrValue) AttrValue
    entry = lens (Map.findWithDefault def n) (\m v -> Map.insert n v m)
-- | Lens onto the '_opInputs' field of an 'OpDef'.
opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs setInputs
  where
    setInputs opDef newInputs = opDef { _opInputs = newInputs }
-- | Lens onto the '_opControlInputs' field of an 'OpDef'.
opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs setControlInputs
  where
    setControlInputs opDef newInputs = opDef { _opControlInputs = newInputs }
-- | Parse a string literal into an 'Output' of an already-rendered node.
--
-- The string is split at its first @':'@.  If everything after that colon
-- parses completely as an integer @N@, the result is output @N@ of the node
-- named by the text before the colon (e.g. @"foo:1"@ is output 1 of @foo@).
-- Otherwise the entire string is used as the node name and the result refers
-- to output 0.
instance IsString Output where
    fromString s = case break (== ':') s of
        -- 'reads' returns the list of (value, unconsumed input) parses, which
        -- is what this pattern expects.  Using 'read' here would instead try
        -- to parse the suffix as a list of tuples and throw
        -- "Prelude.read: no parse" on ordinary input such as "foo:1".
        -- The remainder is matched with @[]@ so the whole suffix must be
        -- consumed, and so the pattern is not a string literal.
        (nodeText, ':' : ixStr)
            | [(ix, [])] <- reads ixStr ->
                Output (fromInteger ix) (assigned nodeText)
        _ -> Output 0 (assigned s)
      where
        -- A rendered op whose 'NodeDef' is the default apart from its name.
        assigned nm = Rendered $ def & name .~ Text.pack nm