1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2025-01-11 19:39:49 +01:00

Update to tensorflow 1.7 (#185)

All of the non-s/1.3/1.7/ changes are because

* There are new tensorflow datatypes
* Some ops have looser types (e.g. fill now accepts both int64 and int32)
* There are more ops of type "func"
This commit is contained in:
fkm3 2018-04-17 12:24:31 -04:00 committed by GitHub
parent e35211d49b
commit 1e2dca8701
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 48 additions and 15 deletions

View file

@ -1,7 +1,7 @@
# ChangeLog
## Upcoming (v0.2.0.0)
- Switch to tensorflow 1.3.
- Switch to tensorflow 1.7.
- Expand the `Rendered` class and add a `ToTensor` class to let more functions
(gradients, feed, colocateWith) support `ResourceHandle` wrappers like
`Variables`.

View file

@ -3,7 +3,7 @@
# stack to be installed on the host. This comes at the expense of
# flexibility.
FROM tensorflow/tensorflow:1.3.0
FROM tensorflow/tensorflow:1.7.0
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
# The build context directory is the top of the tensorflow-haskell
@ -28,8 +28,8 @@ RUN \
curl -O -L https://github.com/google/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip && \
unzip -d /usr/local protoc-3.2.0-linux-x86_64.zip bin/protoc && \
chmod 755 /usr/local/bin/protoc && \
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz && \
tar zxf libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz -C /usr/local && \
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz && \
tar zxf libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz -C /usr/local && \
ldconfig && \
stack setup && \
stack test --only-dependencies

View file

@ -1,6 +1,6 @@
# Prepare the image with:
# docker build -t tensorflow/haskell:v0 docker
FROM tensorflow/tensorflow:1.3.0
FROM tensorflow/tensorflow:1.7.0
LABEL maintainer="TensorFlow authors <tensorflow-haskell@googlegroups.com>"
RUN apt-get update
@ -27,8 +27,8 @@ RUN \
curl -O -L https://github.com/google/protobuf/releases/download/v3.2.0/protoc-3.2.0-linux-x86_64.zip && \
unzip -d /usr/local protoc-3.2.0-linux-x86_64.zip bin/protoc && \
chmod 755 /usr/local/bin/protoc && \
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz && \
tar zxf libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz -C /usr/local && \
curl -O https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz && \
tar zxf libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz -C /usr/local && \
ldconfig
ENV LANG en_US.UTF-8

View file

@ -92,12 +92,20 @@ blackList =
[ -- Requires the "func" type:
"FilterDataset"
, "FlatMapDataset"
, "GeneratorDataset"
, "GroupByWindowDataset"
, "InterleaveDataset"
, "MapAndBatchDataset"
, "MapDataset"
, "MapDataset"
, "OneShotIterator"
, "ParallelInterleaveDataset"
, "ParallelMapDataset"
, "RemoteCall"
, "ScanDataset"
, "SymbolicGradient"
, "_If"
, "_While"
]
autogenModulesDir :: LocalBuildInfo -> FilePath

View file

@ -150,7 +150,7 @@ imports = stack [
, "import Data.Complex (Complex)"
, "import Data.Int (Int8, Int16, Int32, Int64)"
, "import Data.Proxy (Proxy(Proxy))"
, "import Data.Word (Word8, Word16)"
, "import Data.Word (Word8, Word16, Word32, Word64)"
, "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build"
, "import TensorFlow.BuildOp"
@ -415,9 +415,12 @@ dtTypeToHaskell DT_QUINT16 = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_QUINT8 = "Data.Word.Word8" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
dtTypeToHaskell DT_UINT32 = "Data.Word.Word32"
dtTypeToHaskell DT_UINT64 = "Data.Word.Word64"
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
dtTypeToHaskell DT_VARIANT = "Variant"
dtTypeToHaskell x =
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x

View file

@ -377,7 +377,7 @@ truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
truncatedNormal' = CoreOps.truncatedNormal'
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Build a
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
zeros (Shape s) = CoreOps.fill (vector s) (scalar 0)
shape :: TensorType t => Tensor v t -> Tensor Build Int32
shape = CoreOps.shape

View file

@ -41,6 +41,7 @@ module TensorFlow.Types
, Attribute(..)
, DataType(..)
, ResourceHandle
, Variant
-- * Lists
, ListOf(..)
, List
@ -72,7 +73,7 @@ import Data.Monoid ((<>))
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word64)
import Data.Word (Word8, Word16, Word32, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~), (^..))
@ -109,7 +110,8 @@ import Proto.Tensorflow.Core.Framework.Tensor as Tensor
, int64Val
, resourceHandleVal
, stringVal
, stringVal
, uint32Val
, uint64Val
)
import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..)
@ -124,6 +126,11 @@ import qualified TensorFlow.Internal.FFI as FFI
type ResourceHandle = ResourceHandleProto
-- | Dynamic type.
-- TensorFlow variants aren't supported yet. This type acts a placeholder to
-- simplify op generation.
data Variant
-- | The class of scalar types supported by tensorflow.
class TensorType a where
tensorType :: a -> DataType
@ -163,6 +170,16 @@ instance TensorType Word16 where
tensorRefType _ = DT_UINT16_REF
tensorVal = intVal . integral
instance TensorType Word32 where
tensorType _ = DT_UINT32
tensorRefType _ = DT_UINT32_REF
tensorVal = uint32Val
instance TensorType Word64 where
tensorType _ = DT_UINT64
tensorRefType _ = DT_UINT64_REF
tensorVal = uint64Val
instance TensorType Int16 where
tensorType _ = DT_INT16
tensorRefType _ = DT_INT16_REF
@ -198,6 +215,11 @@ instance TensorType ResourceHandle where
tensorRefType _ = DT_RESOURCE_REF
tensorVal = resourceHandleVal
instance TensorType Variant where
tensorType _ = DT_VARIANT
tensorRefType _ = DT_VARIANT_REF
tensorVal = error "TODO Variant"
-- | Tensor data with the correct memory layout for tensorflow.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }

@ -1 +1 @@
Subproject commit 408fd454d7d2a16269576ea12bcd516e25a6b0c5
Subproject commit 92e6c3e4f5c1cabfda1e61547a6a1b268ef95fa5

View file

@ -25,7 +25,7 @@ else
fi
echo "Downloading libtensorflow..."
curl https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.3.0.tar.gz > libtensorflow.tar.gz
curl https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.7.0.tar.gz > libtensorflow.tar.gz
echo "Extracting and copying libtensorflow..."
sudo tar zxf libtensorflow.tar.gz -C /usr/local

View file

@ -13,8 +13,8 @@ let
name = "tensorflow-c";
src = fetchurl {
url = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.3.0.tar.gz";
sha256 = "1d4bda5316063b70cf50a668d774b2067ef2a8ab163ff2eb29592bf3c24e2183";
url = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.7.0.tar.gz";
sha256 = "621642b1fddd3831e048817d2220d9d7cf8ba359ac81c83a808bcdd9a982ee90";
};
buildCommand = ''