mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-28 12:39:46 +01:00
168 lines
5.2 KiB
Haskell
168 lines
5.2 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
{-# LANGUAGE CPP #-}
|
|
|
|
-- | Generates the wrappers for Ops shipped with tensorflow.
|
|
module Main where
|
|
|
|
import Distribution.PackageDescription
|
|
( PackageDescription(..)
|
|
, libBuildInfo
|
|
, hsSourceDirs
|
|
)
|
|
import qualified Distribution.Simple.BuildPaths as BuildPaths
|
|
import Distribution.Simple.LocalBuildInfo (LocalBuildInfo)
|
|
import Distribution.Simple
|
|
( defaultMainWithHooks
|
|
, simpleUserHooks
|
|
, UserHooks(..)
|
|
)
|
|
import Data.List (intercalate)
|
|
import Data.ProtoLens (decodeMessage)
|
|
import System.Directory (createDirectoryIfMissing)
|
|
import System.Exit (exitFailure)
|
|
import System.FilePath ((</>))
|
|
import System.IO (hPutStrLn, stderr)
|
|
import TensorFlow.Internal.FFI (getAllOpList)
|
|
import TensorFlow.OpGen (docOpList, OpGenFlags(..))
|
|
import Text.PrettyPrint.Mainland (prettyLazyText)
|
|
import qualified Data.Text.Lazy.IO as Text
|
|
|
|
-- | Entry point of the custom Setup: the standard Cabal driver, run with
-- 'generatingOpsWrappers' so that the op wrappers are regenerated before the
-- build, haddock and repl hooks run.
main :: IO ()
main = defaultMainWithHooks generatingOpsWrappers
|
|
|
|
-- TODO: Generalize for user libraries by replacing getAllOpList with
-- a wrapper around TF_LoadLibrary. The complicated part is the interplay
-- between Bazel and the Haskell build system.

-- | Cabal hooks identical to 'simpleUserHooks', except that the build,
-- haddock and repl hooks first write @TensorFlow/GenOps/Core.hs@ under the
-- autogen modules directory and then delegate to the default hook.
generatingOpsWrappers :: UserHooks
generatingOpsWrappers = simpleUserHooks
    { buildHook = \pkg lbi uhs flags -> do
          generateSources lbi
          buildHook simpleUserHooks pkg lbi uhs flags
    , haddockHook = \pkg lbi uhs flags -> do
          generateSources lbi
          haddockHook simpleUserHooks pkg lbi uhs flags
    , replHook = \pkg lbi uhs flags args -> do
          generateSources lbi
          replHook simpleUserHooks pkg lbi uhs flags args
    }
  where
    -- Decode the serialized op list returned by 'getAllOpList' and write the
    -- generated bindings (excluding every name in 'blackList'). A decode
    -- failure is reported on stderr and aborts the setup with a failure exit.
    generateSources :: LocalBuildInfo -> IO ()
    generateSources lbi = do
        let genDir = autogenModulesDir lbi </> "TensorFlow/GenOps"
            opGenFlags = OpGenFlags
                { outputFile  = genDir </> "Core.hs"
                , prefix      = "TensorFlow.GenOps"
                , excludeList = intercalate "," blackList
                }
        createDirectoryIfMissing True genDir
        serialized <- getAllOpList
        either
            (\err -> hPutStrLn stderr err >> exitFailure)
            (Text.writeFile (outputFile opGenFlags)
                . prettyLazyText 80
                . docOpList opGenFlags)
            (decodeMessage serialized)
|
|
|
|
-- | Prepend the autogen modules directory to the @hs-source-dirs@ of the
-- package's library component. Used to fool 'sdist' by pointing it at the
-- generated source files.
--
-- Only the library component is rewritten; if the package has no library,
-- the description is returned with 'library' still 'Nothing'.
fudgePackageDesc
    :: LocalBuildInfo -> PackageDescription -> PackageDescription
fudgePackageDesc lbi pkgDesc =
    pkgDesc { library = fmap withAutogenDir (library pkgDesc) }
  where
    autogenDir = autogenModulesDir lbi
    withAutogenDir lib =
        lib { libBuildInfo = prependDir (libBuildInfo lib) }
    prependDir bi = bi { hsSourceDirs = autogenDir : hsSourceDirs bi }
|
|
|
|
-- | Op names the generator must skip. Joined with commas into the
-- 'excludeList' of the 'OpGenFlags' used for code generation. Each name
-- appears exactly once.
blackList :: [String]
blackList =
    [ -- Requires the "func" type:
      "BatchFunction"
    , "Case"
    , "ChooseFastestBranchDataset"
    , "ExperimentalGroupByReducerDataset"
    , "ExperimentalGroupByWindowDataset"
    , "ExperimentalMapAndBatchDataset"
    , "ExperimentalMapDataset"
    , "ExperimentalNumaMapAndBatchDataset"
    , "ExperimentalParallelInterleaveDataset"
    , "ExperimentalScanDataset"
    , "ExperimentalTakeWhileDataset"
    , "FilterDataset"
    , "FlatMapDataset"
    , "For"
    , "GeneratorDataset"
    , "GroupByReducerDataset"
    , "GroupByWindowDataset"
    , "If"
    , "InterleaveDataset"
    , "LegacyParallelInterleaveDatasetV2"
    , "LoadDataset"
    , "MapAndBatchDataset"
    , "MapAndBatchDatasetV2"
    , "MapDataset"
    , "MapDefun"
    , "OneShotIterator"
    , "ParallelInterleaveDataset"
    , "ParallelInterleaveDatasetV2"
    , "ParallelInterleaveDatasetV3"
    , "ParallelInterleaveDatasetV4"
    , "ParallelMapDataset"
    , "ParallelMapDatasetV2"
    , "ParseSequenceExample"
    , "ParseSequenceExampleV2"
    , "ParseSingleSequenceExample"
    , "PartitionedCall"
    , "ReduceDataset"
    , "RemoteCall"
    , "SaveDataset"
    , "ScanDataset"
    , "SnapshotDatasetV2"
    , "StatefulPartitionedCall"
    , "StatelessCase"
    , "StatelessIf"
    , "StatelessWhile"
    , "SymbolicGradient"
    , "TakeWhileDataset"
    , "TPUCompile"
    , "TPUPartitionedCall"
    , "TPUReplicate"
    , "While"
    , "XlaIf"
    , "XlaLaunch"
    , "XlaReduce"
    , "XlaReduceWindow"
    , "XlaSelectAndScatter"
    , "XlaScatter"
    , "XlaWhile"
    , "_If"
    , "_TPUReplicate"
    , "_While"
    , "_XlaCompile"
      -- Incorrectly generated:
    , "_FusedBatchNormGradEx"
      -- Could not deduce:
    , "_MklFusedBatchNorm"
    , "_MklFusedBatchNormEx"
    , "_MklFusedBatchNormGrad"
    , "_MklFusedBatchNormGradV2"
    , "_MklFusedBatchNormGradV3"
    , "_MklFusedBatchNormV2"
    , "_MklFusedBatchNormV3"
    ]
|
|
|
|
-- | Directory where Cabal places the package's auto-generated modules.
-- From Cabal 2.0 onward the package-level accessor
-- 'BuildPaths.autogenPackageModulesDir' is used; older Cabal versions
-- provide 'BuildPaths.autogenModulesDir' instead.
autogenModulesDir :: LocalBuildInfo -> FilePath
#if MIN_VERSION_Cabal(2,0,0)
autogenModulesDir lbi = BuildPaths.autogenPackageModulesDir lbi
#else
autogenModulesDir lbi = BuildPaths.autogenModulesDir lbi
#endif
|