mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-27 03:05:01 +01:00
Add tensorflow-haskell-deptyped example (#174)
https://github.com/tensorflow/haskell/issues/156
This commit is contained in:
parent
bc15709cb4
commit
760c067e89
1 changed files with 30 additions and 0 deletions
30
README.md
30
README.md
|
@ -120,7 +120,37 @@ No exportable Nix package will appear, but local development is possible.
|
|||
|
||||
# Related Projects
|
||||
|
||||
## Statically validated tensor shapes
|
||||
|
||||
https://github.com/helq/tensorflow-haskell-deptyped is experimenting with using dependent types to statically validate tensor shapes. May be merged with this repository in the future.
|
||||
|
||||
Example:
|
||||
|
||||
```haskell
|
||||
{-# LANGUAGE DataKinds, ScopedTypeVariables #-}
|
||||
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Vector.Sized (Vector, fromList)
|
||||
import TensorFlow.DepTyped
|
||||
|
||||
test :: IO (Vector 8 Float)
|
||||
test = runSession $ do
|
||||
(x :: Placeholder "x" '[4,3] Float) <- placeholder
|
||||
|
||||
let elems1 = fromJust $ fromList [1,2,3,4,1,2]
|
||||
elems2 = fromJust $ fromList [5,6,7,8]
|
||||
(w :: Tensor '[3,2] '[] Build Float) = constant elems1
|
||||
(b :: Tensor '[4,1] '[] Build Float) = constant elems2
|
||||
y = (x `matMul` w) `add` b -- y shape: [4,2] (b shape is [4.1] but `add` broadcasts it to [4,2])
|
||||
|
||||
let (inputX :: TensorData "x" [4,3] Float) =
|
||||
encodeTensorData . fromJust $ fromList [1,2,3,4,1,0,7,9,5,3,5,4]
|
||||
|
||||
runWithFeeds (feed x inputX :~~ NilFeedList) y
|
||||
|
||||
main :: IO ()
|
||||
main = test >>= print
|
||||
```
|
||||
|
||||
# License
|
||||
This project is licensed under the terms of the [Apache 2.0 license](LICENSE).
|
||||
|
|
Loading…
Add table
Reference in a new issue