mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
Add tensorflow-haskell-deptyped example (#174)
https://github.com/tensorflow/haskell/issues/156
This commit is contained in:
parent
bc15709cb4
commit
760c067e89
1 changed files with 30 additions and 0 deletions
30
README.md
30
README.md
|
@ -120,7 +120,37 @@ No exportable Nix package will appear, but local development is possible.
|
||||||
|
|
||||||
# Related Projects
|
# Related Projects
|
||||||
|
|
||||||
|
## Statically validated tensor shapes
|
||||||
|
|
||||||
https://github.com/helq/tensorflow-haskell-deptyped is experimenting with using dependent types to statically validate tensor shapes. May be merged with this repository in the future.
|
https://github.com/helq/tensorflow-haskell-deptyped is experimenting with using dependent types to statically validate tensor shapes. May be merged with this repository in the future.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```haskell
|
||||||
|
{-# LANGUAGE DataKinds, ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
import Data.Maybe (fromJust)
|
||||||
|
import Data.Vector.Sized (Vector, fromList)
|
||||||
|
import TensorFlow.DepTyped
|
||||||
|
|
||||||
|
test :: IO (Vector 8 Float)
|
||||||
|
test = runSession $ do
|
||||||
|
(x :: Placeholder "x" '[4,3] Float) <- placeholder
|
||||||
|
|
||||||
|
let elems1 = fromJust $ fromList [1,2,3,4,1,2]
|
||||||
|
elems2 = fromJust $ fromList [5,6,7,8]
|
||||||
|
(w :: Tensor '[3,2] '[] Build Float) = constant elems1
|
||||||
|
(b :: Tensor '[4,1] '[] Build Float) = constant elems2
|
||||||
|
y = (x `matMul` w) `add` b -- y shape: [4,2] (b shape is [4.1] but `add` broadcasts it to [4,2])
|
||||||
|
|
||||||
|
let (inputX :: TensorData "x" [4,3] Float) =
|
||||||
|
encodeTensorData . fromJust $ fromList [1,2,3,4,1,0,7,9,5,3,5,4]
|
||||||
|
|
||||||
|
runWithFeeds (feed x inputX :~~ NilFeedList) y
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = test >>= print
|
||||||
|
```
|
||||||
|
|
||||||
# License
|
# License
|
||||||
This project is licensed under the terms of the [Apache 2.0 license](LICENSE).
|
This project is licensed under the terms of the [Apache 2.0 license](LICENSE).
|
||||||
|
|
Loading…
Reference in a new issue