mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Update to tensorflow 1.7 (#185)
All of the non-s/1.3/1.7/ changes are because * There are new tensorflow datatypes * Some ops have looser types (e.g. fill now accepts both int64 and int32) * There are more ops of type "func"
This commit is contained in:
parent
e35211d49b
commit
1e2dca8701
10 changed files with 48 additions and 15 deletions
|
@ -1,7 +1,7 @@
|
||||||
# ChangeLog
|
# ChangeLog
|
||||||
|
|
||||||
## Upcoming (v0.2.0.0)
|
## Upcoming (v0.2.0.0)
|
||||||
- Switch to tensorflow 1.3.
|
- Switch to tensorflow 1.7.
|
||||||
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
|
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
|
||||||
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
|
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
|
||||||
`Variables`.
|
`Variables`.
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
# stack to be installed on the host. This comes at the expense of
|
# stack to be installed on the host. This comes at the expense of
|
||||||
# flexibility.
|
# flexibility.
|
||||||
|
|
||||||
FROM tensorflow/tensorflow:1.3.0
|
FROM tensorflow/tensorflow:1.7.0
|
||||||
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
||||||
|
|
||||||
# The build context directory is the top of the tensorflow-haskell
|
# The build context directory is the top of the tensorflow-haskell
|
||||||
|
@ -28,8 +28,8 @@ RUN \
|
||||||
curl -O -L https://github.com/google/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip && \
|
curl -O -L https://github.com/google/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip && \
|
||||||
unzip -d /usr/local protoc-3.2.0-linux-x86_64.zip bin/protoc && \
|
unzip -d /usr/local protoc-3.2.0-linux-x86_64.zip bin/protoc && \
|
||||||
chmod 755 /usr/local/bin/protoc && \
|
chmod 755 /usr/local/bin/protoc && \
|
||||||
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz && \
|
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz && \
|
||||||
tar zxf libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz -C /usr/local && \
|
tar zxf libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz -C /usr/local && \
|
||||||
ldconfig && \
|
ldconfig && \
|
||||||
stack setup && \
|
stack setup && \
|
||||||
stack test --only-dependencies
|
stack test --only-dependencies
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Prepare the image with:
|
# Prepare the image with:
|
||||||
# docker build -t tensorflow/haskell:v0 docker
|
# docker build -t tensorflow/haskell:v0 docker
|
||||||
FROM tensorflow/tensorflow:1.3.0
|
FROM tensorflow/tensorflow:1.7.0
|
||||||
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
|
||||||
|
|
||||||
RUN apt-get update
|
RUN apt-get update
|
||||||
|
@ -27,8 +27,8 @@ RUN \
|
||||||
curl -O -L https://github.com/google/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip && \
|
curl -O -L https://github.com/google/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip && \
|
||||||
unzip -d /usr/local protoc-3.2.0-linux-x86_64.zip bin/protoc && \
|
unzip -d /usr/local protoc-3.2.0-linux-x86_64.zip bin/protoc && \
|
||||||
chmod 755 /usr/local/bin/protoc && \
|
chmod 755 /usr/local/bin/protoc && \
|
||||||
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz && \
|
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz && \
|
||||||
tar zxf libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz -C /usr/local && \
|
tar zxf libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz -C /usr/local && \
|
||||||
ldconfig
|
ldconfig
|
||||||
|
|
||||||
ENV LANG en_US.UTF-8
|
ENV LANG en_US.UTF-8
|
||||||
|
|
|
@ -92,12 +92,20 @@ blackList =
|
||||||
[ -- Requires the "func" type:
|
[ -- Requires the "func" type:
|
||||||
"FilterDataset"
|
"FilterDataset"
|
||||||
, "FlatMapDataset"
|
, "FlatMapDataset"
|
||||||
|
, "GeneratorDataset"
|
||||||
, "GroupByWindowDataset"
|
, "GroupByWindowDataset"
|
||||||
, "InterleaveDataset"
|
, "InterleaveDataset"
|
||||||
|
, "MapAndBatchDataset"
|
||||||
|
, "MapDataset"
|
||||||
, "MapDataset"
|
, "MapDataset"
|
||||||
, "OneShotIterator"
|
, "OneShotIterator"
|
||||||
|
, "ParallelInterleaveDataset"
|
||||||
, "ParallelMapDataset"
|
, "ParallelMapDataset"
|
||||||
|
, "RemoteCall"
|
||||||
|
, "ScanDataset"
|
||||||
, "SymbolicGradient"
|
, "SymbolicGradient"
|
||||||
|
, "_If"
|
||||||
|
, "_While"
|
||||||
]
|
]
|
||||||
|
|
||||||
autogenModulesDir :: LocalBuildInfo -> FilePath
|
autogenModulesDir :: LocalBuildInfo -> FilePath
|
||||||
|
|
|
@ -150,7 +150,7 @@ imports = stack [
|
||||||
, "import Data.Complex (Complex)"
|
, "import Data.Complex (Complex)"
|
||||||
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
||||||
, "import Data.Proxy (Proxy(Proxy))"
|
, "import Data.Proxy (Proxy(Proxy))"
|
||||||
, "import Data.Word (Word8, Word16)"
|
, "import Data.Word (Word8, Word16, Word32, Word64)"
|
||||||
, "import Lens.Family2 ((.~), (&))"
|
, "import Lens.Family2 ((.~), (&))"
|
||||||
, "import TensorFlow.Build"
|
, "import TensorFlow.Build"
|
||||||
, "import TensorFlow.BuildOp"
|
, "import TensorFlow.BuildOp"
|
||||||
|
@ -415,9 +415,12 @@ dtTypeToHaskell DT_QUINT16 = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||||
dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
|
dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
|
||||||
dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
||||||
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||||
|
dtTypeToHaskell DT_UINT32 = "Data.Word.Word32"
|
||||||
|
dtTypeToHaskell DT_UINT64 = "Data.Word.Word64"
|
||||||
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||||
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||||
dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
|
dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
|
||||||
|
dtTypeToHaskell DT_VARIANT = "Variant"
|
||||||
dtTypeToHaskell x =
|
dtTypeToHaskell x =
|
||||||
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||||
|
|
||||||
|
|
|
@ -377,7 +377,7 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||||
truncatedNormal' = CoreOps.truncatedNormal'
|
truncatedNormal' = CoreOps.truncatedNormal'
|
||||||
|
|
||||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
|
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
|
||||||
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
zeros (Shape s) = CoreOps.fill (vector s) (scalar 0)
|
||||||
|
|
||||||
shape :: TensorType t => Tensor v t -> Tensor Build Int32
|
shape :: TensorType t => Tensor v t -> Tensor Build Int32
|
||||||
shape = CoreOps.shape
|
shape = CoreOps.shape
|
||||||
|
|
|
@ -41,6 +41,7 @@ module TensorFlow.Types
|
||||||
, Attribute(..)
|
, Attribute(..)
|
||||||
, DataType(..)
|
, DataType(..)
|
||||||
, ResourceHandle
|
, ResourceHandle
|
||||||
|
, Variant
|
||||||
-- * Lists
|
-- * Lists
|
||||||
, ListOf(..)
|
, ListOf(..)
|
||||||
, List
|
, List
|
||||||
|
@ -72,7 +73,7 @@ import Data.Monoid ((<>))
|
||||||
import Data.ProtoLens.TextFormat (showMessageShort)
|
import Data.ProtoLens.TextFormat (showMessageShort)
|
||||||
import Data.Proxy (Proxy(..))
|
import Data.Proxy (Proxy(..))
|
||||||
import Data.String (IsString)
|
import Data.String (IsString)
|
||||||
import Data.Word (Word8, Word16, Word64)
|
import Data.Word (Word8, Word16, Word32, Word64)
|
||||||
import Foreign.Storable (Storable)
|
import Foreign.Storable (Storable)
|
||||||
import GHC.Exts (Constraint, IsList(..))
|
import GHC.Exts (Constraint, IsList(..))
|
||||||
import Lens.Family2 (Lens', view, (&), (.~), (^..))
|
import Lens.Family2 (Lens', view, (&), (.~), (^..))
|
||||||
|
@ -109,7 +110,8 @@ import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||||
, int64Val
|
, int64Val
|
||||||
, resourceHandleVal
|
, resourceHandleVal
|
||||||
, stringVal
|
, stringVal
|
||||||
, stringVal
|
, uint32Val
|
||||||
|
, uint64Val
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
( TensorShapeProto(..)
|
( TensorShapeProto(..)
|
||||||
|
@ -124,6 +126,11 @@ import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
type ResourceHandle = ResourceHandleProto
|
type ResourceHandle = ResourceHandleProto
|
||||||
|
|
||||||
|
-- | Dynamic type.
|
||||||
|
-- TensorFlow variants aren't supported yet. This type acts a placeholder to
|
||||||
|
-- simplify op generation.
|
||||||
|
data Variant
|
||||||
|
|
||||||
-- | The class of scalar types supported by tensorflow.
|
-- | The class of scalar types supported by tensorflow.
|
||||||
class TensorType a where
|
class TensorType a where
|
||||||
tensorType :: a -> DataType
|
tensorType :: a -> DataType
|
||||||
|
@ -163,6 +170,16 @@ instance TensorType Word16 where
|
||||||
tensorRefType _ = DT_UINT16_REF
|
tensorRefType _ = DT_UINT16_REF
|
||||||
tensorVal = intVal . integral
|
tensorVal = intVal . integral
|
||||||
|
|
||||||
|
instance TensorType Word32 where
|
||||||
|
tensorType _ = DT_UINT32
|
||||||
|
tensorRefType _ = DT_UINT32_REF
|
||||||
|
tensorVal = uint32Val
|
||||||
|
|
||||||
|
instance TensorType Word64 where
|
||||||
|
tensorType _ = DT_UINT64
|
||||||
|
tensorRefType _ = DT_UINT64_REF
|
||||||
|
tensorVal = uint64Val
|
||||||
|
|
||||||
instance TensorType Int16 where
|
instance TensorType Int16 where
|
||||||
tensorType _ = DT_INT16
|
tensorType _ = DT_INT16
|
||||||
tensorRefType _ = DT_INT16_REF
|
tensorRefType _ = DT_INT16_REF
|
||||||
|
@ -198,6 +215,11 @@ instance TensorType ResourceHandle where
|
||||||
tensorRefType _ = DT_RESOURCE_REF
|
tensorRefType _ = DT_RESOURCE_REF
|
||||||
tensorVal = resourceHandleVal
|
tensorVal = resourceHandleVal
|
||||||
|
|
||||||
|
instance TensorType Variant where
|
||||||
|
tensorType _ = DT_VARIANT
|
||||||
|
tensorRefType _ = DT_VARIANT_REF
|
||||||
|
tensorVal = error "TODO Variant"
|
||||||
|
|
||||||
-- | Tensor data with the correct memory layout for tensorflow.
|
-- | Tensor data with the correct memory layout for tensorflow.
|
||||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||||
|
|
||||||
|
|
2
third_party/tensorflow
vendored
2
third_party/tensorflow
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit 408fd454d7d2a16269576ea12bcd516e25a6b0c5
|
Subproject commit 92e6c3e4f5c1cabfda1e61547a6a1b268ef95fa5
|
|
@ -25,7 +25,7 @@ else
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Downloading libtensorflow..."
|
echo "Downloading libtensorflow..."
|
||||||
curl https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.3.0.tar.gz > libtensorflow.tar.gz
|
curl https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.7.0.tar.gz > libtensorflow.tar.gz
|
||||||
|
|
||||||
echo "Extracting and copying libtensorflow..."
|
echo "Extracting and copying libtensorflow..."
|
||||||
sudo tar zxf libtensorflow.tar.gz -C /usr/local
|
sudo tar zxf libtensorflow.tar.gz -C /usr/local
|
||||||
|
|
|
@ -13,8 +13,8 @@ let
|
||||||
name = "tensorflow-c";
|
name = "tensorflow-c";
|
||||||
|
|
||||||
src = fetchurl {
|
src = fetchurl {
|
||||||
url = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz";
|
url = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz";
|
||||||
sha256 = "1d4bda5316063b70cf50a668d774b2067ef2a8ab163ff2eb29592bf3c24e2183";
|
sha256 = "621642b1fddd3831e048817d2220d9d7cf8ba359ac81c83a808bcdd9a982ee90";
|
||||||
};
|
};
|
||||||
|
|
||||||
buildCommand = ''
|
buildCommand = ''
|
||||||
|
|
Loading…
Reference in a new issue