1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Add tensorflow-haskell-deptyped example (#174)

https://github.com/tensorflow/haskell/issues/156
This commit is contained in:
fkm3 2018-01-17 13:31:23 -05:00 committed by GitHub
parent bc15709cb4
commit 760c067e89
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -120,7 +120,37 @@ No exportable Nix package will appear, but local development is possible.
# Related Projects
## Statically validated tensor shapes
https://github.com/helq/tensorflow-haskell-deptyped is experimenting with using dependent types to statically validate tensor shapes. May be merged with this repository in the future.
Example:
```haskell
{-# LANGUAGE DataKinds, ScopedTypeVariables #-}
import Data.Maybe (fromJust)
import Data.Vector.Sized (Vector, fromList)
import TensorFlow.DepTyped
test :: IO (Vector 8 Float)
test = runSession $ do
(x :: Placeholder "x" '[4,3] Float) <- placeholder
let elems1 = fromJust $ fromList [1,2,3,4,1,2]
elems2 = fromJust $ fromList [5,6,7,8]
(w :: Tensor '[3,2] '[] Build Float) = constant elems1
(b :: Tensor '[4,1] '[] Build Float) = constant elems2
y = (x `matMul` w) `add` b -- y shape: [4,2] (b shape is [4.1] but `add` broadcasts it to [4,2])
let (inputX :: TensorData "x" [4,3] Float) =
encodeTensorData . fromJust $ fromList [1,2,3,4,1,0,7,9,5,3,5,4]
runWithFeeds (feed x inputX :~~ NilFeedList) y
main :: IO ()
main = test >>= print
```
# License
This project is licensed under the terms of the [Apache 2.0 license](LICENSE).