tensorflow-haskell/tensorflow/src/TensorFlow/BuildOp.hs

200 lines
6.4 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
( OpResult
, BuildOp
, buildOp
, buildListOp
, eqLengthGuard
)
where
import Control.Monad (replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, runState, get, put)
import Data.Int (Int64)
import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
-- | State threaded through a 'Result' computation: the index to assign to the
-- next tensor output, followed by the list-result lengths that have not been
-- consumed yet.
data ResultState = ResultState !OutputIx [Int64] deriving Show
-- | Computation used to assemble an op's results: it reads the 'Op' being
-- unpacked and updates a 'ResultState' as results are produced.
type Result = ReaderT Op (State ResultState)
-- | Types that may appear as the output of an op.
class OpResult a where
    -- | Assemble a value of this type inside the 'Result' computation.
    toResult :: Result a
-- | A pair of results; components are assembled left to right.
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
    toResult = do
        r1 <- toResult
        r2 <- toResult
        return (r1, r2)
-- | A triple of results; components are assembled left to right.
instance (OpResult a1, OpResult a2, OpResult a3)
         => OpResult (a1, a2, a3) where
    toResult = do
        r1 <- toResult
        r2 <- toResult
        r3 <- toResult
        return (r1, r2, r3)
-- | A 4-tuple of results; components are assembled left to right.
instance ( OpResult a1
         , OpResult a2
         , OpResult a3
         , OpResult a4
         )
         => OpResult (a1, a2, a3, a4) where
    toResult =
        (,,,)
            <$> toResult
            <*> toResult
            <*> toResult
            <*> toResult
-- | A 5-tuple of results; components are assembled left to right.
instance ( OpResult a1
         , OpResult a2
         , OpResult a3
         , OpResult a4
         , OpResult a5
         )
         => OpResult (a1, a2, a3, a4, a5) where
    toResult = do
        r1 <- toResult
        r2 <- toResult
        r3 <- toResult
        r4 <- toResult
        r5 <- toResult
        return (r1, r2, r3, r4, r5)
-- | A 6-tuple of results; components are assembled left to right.
instance (OpResult a1, OpResult a2, OpResult a3,
          OpResult a4, OpResult a5, OpResult a6)
         => OpResult (a1, a2, a3, a4, a5, a6) where
    toResult = do
        r1 <- toResult
        r2 <- toResult
        r3 <- toResult
        r4 <- toResult
        r5 <- toResult
        r6 <- toResult
        return (r1, r2, r3, r4, r5, r6)
-- | Produce the next tensor result of the given kind: it refers to the
-- current op at the current output index, and the index is then advanced by
-- one for whichever result comes next.
tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult kind = do
    op <- ask
    ResultState ix counts <- get
    -- Store the incremented index strictly before handing out the tensor.
    put $! ResultState (ix + 1) counts
    return $! Tensor kind (output ix op)
-- | A single value tensor taken from the next output index.
instance OpResult (Tensor Value a) where
    toResult = tensorResult ValueKind
-- | A single reference tensor taken from the next output index.
instance OpResult (Tensor Ref a) where
    toResult = tensorResult RefKind
-- | A control node wraps the op itself and leaves the 'ResultState' untouched.
instance OpResult ControlNode where
    toResult = fmap ControlNode ask
-- | A list of results. Its length is the next pending count stored in the
-- 'ResultState'; that count is removed before the elements are assembled.
-- Calls 'error' when no count is left.
instance OpResult a => OpResult [a] where
    toResult = get >>= takeCount
      where
        takeCount (ResultState _ []) =
            error $ "Ran out of counts in toResult. " ++
                    "Likely misuse of buildListOp."
        takeCount (ResultState ix (count : remaining)) = do
            put $! ResultState ix remaining
            replicateM (fromIntegral count) toResult
-- | Assemble a result of type @a@ for the given op.
--
-- The computation starts at output index 0 with @ns@ as the pending
-- list-result lengths. Every length must be consumed by the time the result
-- has been assembled; if any is left over, 'error' is called with the
-- original lengths and the final 'ResultState'.
runResult :: OpResult a => [Int64] -> Op -> a
runResult ns o =
    case runState (runReaderT toResult o) (ResultState 0 ns) of
        (x, ResultState _ []) -> x
        -- The second component is the whole final state, not just the counts.
        (_, leftover) -> error $ "Unused length in runResult attributes: " ++
                                 show (ns, leftover)
-- | Build the result of a new "pure" op, i.e. one that may be deduped with
-- identical ops within the same scope. The op is left unrendered.
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
pureResult counts opDef inputs =
    runResult counts (Unrendered (addReversedInputs opDef inputs))
-- | Build the result of a new "stateful" op, i.e. one that will not be
-- deduped with otherwise identical ops. The op is added through 'addNewOp'
-- inside 'Build', and the results refer to the rendered node.
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
buildResult counts opDef inputs = fmap fromRendered newNode
  where
    newNode = addNewOp (addReversedInputs opDef inputs)
    fromRendered = runResult counts . Rendered
-- | Monoidally append the given outputs, in reverse order, to the inputs of
-- an 'OpDef'. Callers accumulate inputs newest-first, so reversing restores
-- the order in which they were supplied.
addReversedInputs :: OpDef -> [Output] -> OpDef
addReversedInputs opDef accumulated = (opInputs <>~ reverse accumulated) opDef
-- | Types that can serve as op-building functions.
class BuildOp f where
    -- | Worker behind 'buildOp' and 'buildListOp'.
    buildOp'
        :: [Int64]   -- ^ Sizes of list results (having number_attr)
        -> OpDef     -- ^ Definition of the op being built.
        -> [Output]  -- ^ Accumulator for inputs to the op.
        -> f
-- | Begin building an operation whose result is a structured set of tensors
-- (singletons or tuples). No list-result sizes are supplied and the input
-- accumulator starts out empty.
buildOp :: BuildOp f => OpDef -> f
buildOp opDef = buildOp' noListSizes opDef noInputs
  where
    noListSizes = []
    noInputs = []
-- | Begin building an operation whose result contains lists of tensors.
-- The input accumulator starts out empty.
buildListOp
    :: BuildOp f
    => [Int64]  -- ^ Cardinality of the corresponding list of tensors output.
    -> OpDef
    -> f
buildListOp counts opDef = buildOp' counts opDef noInputs
  where
    noInputs = []
-- | An op used only as a control node: the accumulated inputs are attached
-- and the list-result sizes are ignored.
instance BuildOp ControlNode where
    buildOp' _ opDef inputs =
        ControlNode (Unrendered (addReversedInputs opDef inputs))
-- | An op returning a single value tensor is built as a pure op.
instance BuildOp (Tensor Value a) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a single reference tensor is built as a pure op.
instance BuildOp (Tensor Ref a) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a list of value tensors is built as a pure op.
instance BuildOp [Tensor Value a] where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a pair of results is built as a pure op.
instance (OpResult t1, OpResult t2) => BuildOp (t1, t2) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a triple of results is built as a pure op.
instance (OpResult t1, OpResult t2, OpResult t3)
         => BuildOp (t1, t2, t3) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a 4-tuple of results is built as a pure op.
instance ( OpResult t1
         , OpResult t2
         , OpResult t3
         , OpResult t4
         )
         => BuildOp (t1, t2, t3, t4) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a 5-tuple of results is built as a pure op.
instance ( OpResult t1
         , OpResult t2
         , OpResult t3
         , OpResult t4
         , OpResult t5
         )
         => BuildOp (t1, t2, t3, t4, t5) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op returning a 6-tuple of results is built as a pure op.
instance (OpResult t1, OpResult t2, OpResult t3,
          OpResult t4, OpResult t5, OpResult t6)
         => BuildOp (t1, t2, t3, t4, t5, t6) where
    buildOp' counts opDef inputs = pureResult counts opDef inputs
-- | An op whose result lives in 'Build' is built via 'buildResult', which
-- adds it as a new op rather than leaving it unrendered.
instance OpResult a => BuildOp (Build a) where
    buildOp' counts opDef inputs = buildResult counts opDef inputs
-- | Accept one more tensor argument: its output is pushed onto the front of
-- the input accumulator before building continues.
instance BuildOp f => BuildOp (Tensor v a -> f) where
    buildOp' counts opDef pending tensor =
        buildOp' counts opDef (newest : pending)
      where
        newest = tensor ^. tensorOutput
-- | Accept a list of tensors as one argument: their outputs are pushed onto
-- the front of the input accumulator in reverse, matching the newest-first
-- order the accumulator is kept in.
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
    buildOp' counts opDef pending tensors =
        buildOp' counts opDef (newestFirst ++ pending)
      where
        newestFirst = reverse [t ^. tensorOutput | t <- tensors]
-- | Checks, for every named group, that all the integers paired with its
-- members are equal. Returns 'True' when every group passes (a group with no
-- members passes trivially); otherwise calls 'error' with a message naming
-- the offending group and listing its members.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard groups =
    and [lengthsAgree name members | (name, members) <- groups]
  where
    lengthsAgree _ [] = True
    -- Comparing every later member against the first is enough to establish
    -- that all of them are equal.
    lengthsAgree name members@((_, firstLen) : others)
        | all ((== firstLen) . snd) others = True
        | otherwise =
            error ("number_attr " ++ name ++
                   " contains tensors with different length " ++ show members)