I'm trying to implement the following Python code in Scala/Breeze:
import numpy as np
mat = np.random.normal(size=(2, 5))
print(mat)
indexes = np.random.choice(5, replace = False, size = 3)
print(indexes)
mat[0, [indexes]] = 0
print(mat)
# the output:
# ./rowsliceOverwrite.py
[[ 0.30389599 0.84549682 -0.38408994 -1.11550844 -0.28496995]
[-1.55260273 -0.41368681 -0.40455289 0.13054527 -1.43541557]]
[1 3 4]
[[ 0.30389599 0. -0.38408994 0. 0. ]
[-1.55260273 -0.41368681 -0.40455289 0.13054527 -1.43541557]]
The goal is to directly zero selected indexes of a DenseMatrix row in a single expression.
Here's my Scala attempt, followed by the error message:
import breeze.linalg.*
import breeze.stats.*
def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
  )
  println(mat)
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  mat(0, indexes) = 0.0
  println(mat)
# ./rowsliceOverwrite.sc
-- [E007] Type Mismatch Error: C:\Users\philwalk\workspace\tprf_py\.\rowsliceOverwrite.sc:14:9 -----------------------------------------------------------------------------------------------------------------------------------------------
14 | mat(0, indexes) = 0
| ^^^^^^^
| Found: (indexes : IndexedSeq[Int])
| Required: Int
|
| longer explanation available when compiling with `-explain`
1 error found
Errors encountered during compilation
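For reference, the error arises because `mat(0, indexes) = 0.0` is sugar for `mat.update(0, indexes, 0.0)`, and `DenseMatrix`'s `update` takes an `Int` column index. One workaround that is not from the original question is to update the entries one at a time, which still fits in a single expression. A minimal sketch, assuming a scala-cli script (`.sc`) and a Breeze 2.1.0 dependency (the `//> using` lines are assumptions about the build setup); `zeroByLoop` is a hypothetical helper name:

```scala
//> using scala 3
//> using dep org.scalanlp::breeze:2.1.0

import breeze.linalg.*

// `mat(r, c) = v` desugars to `mat.update(r, c, v)`, which needs Int indices,
// so a whole IndexedSeq cannot be passed as the column index.
// Hypothetical workaround: zero each selected column of `row` individually.
def zeroByLoop(mat: DenseMatrix[Double], row: Int, cols: Seq[Int]): Unit =
  cols.foreach(c => mat(row, c) = 0.0)

val demo = DenseMatrix(
  (-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
  (-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
)
zeroByLoop(demo, 0, IndexedSeq(1, 3, 4))
println(demo)
```

This avoids slicing entirely, at the cost of an explicit loop instead of a vectorized assignment.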
It can be done in three steps (extract the row, zero the selected entries, write the row back):
import breeze.linalg.*
import breeze.stats.*
def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
  )
  println(mat)
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  // step 1: take row 0 as a DenseVector
  val row0 = mat(0, ::).t
  // step 2: zero the selected entries
  row0(indexes) := 0.0
  // step 3: copy the row back into mat
  mat(0, ::) := row0.t
  println(mat)
# ./rowsliceOverwrite.sc
-0.25010575 0.44800905 0.13285604 0.34085698 0.38346101
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
Vector(1, 3, 4)
-0.25010575 0.0 0.13285604 0.0 0.0
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
It's not as readable as the Python code; it would be nice if there were a more direct way to do it, as a single expression.
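As far as I can tell, the write-back step (`mat(0, ::) := row0.t`) is not actually needed, assuming Breeze's row slice shares storage with the matrix: `mat(0, ::)` is a `Transpose` wrapped around a strided `DenseVector` over `mat`'s own data array, not a copy, so zeroing `row0` already updates `mat`. A minimal sketch of the two-step version, again as a scala-cli script with an assumed Breeze 2.1.0 dependency:

```scala
//> using scala 3
//> using dep org.scalanlp::breeze:2.1.0

import breeze.linalg.*

val mat = DenseMatrix(
  (-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
  (-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
)
val indexes = IndexedSeq(1, 3, 4)

// mat(0, ::) is a Transpose[DenseVector[Double]]; .t unwraps it to a
// DenseVector that views row 0 of mat's storage (no copy is made).
val row0 = mat(0, ::).t
// Slicing by indexes yields another view, so := writes through to mat.
row0(indexes) := 0.0

// No write-back step: mat already contains the zeros.
println(mat)
```

If this holds, it also explains why a single chained expression can work: every intermediate value is a view into `mat`.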
After experimenting for a while, I found an expression that seems to work:
import breeze.linalg.*
import breeze.stats.*
def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
  )
  println(mat)
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  // transpose row 0 to a DenseVector, slice it by indexes, assign 0.0 in place
  mat(0, ::).t(indexes) := 0.0
  println(mat)
# the output:
# ./rowsliceOverwrite.sc
-0.25010575 0.44800905 0.13285604 0.34085698 0.38346101
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
Vector(1, 3, 4)
-0.25010575 0.0 0.13285604 0.0 0.0
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
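If the idiom is needed in several places, it could be wrapped in a small helper. The sketch below is hypothetical (`zeroInRow` and `zeroInCol` are names made up for illustration, and the Breeze 2.1.0 dependency is an assumption). The column variant needs no `.t` because `mat(::, col)` already returns a `DenseVector`.

```scala
//> using scala 3
//> using dep org.scalanlp::breeze:2.1.0

import breeze.linalg.*

// Hypothetical helper: zero the given columns of one row, using the
// single-expression idiom (row view -> .t -> slice -> in-place assign).
def zeroInRow(mat: DenseMatrix[Double], row: Int, cols: IndexedSeq[Int]): Unit =
  mat(row, ::).t(cols) := 0.0

// Hypothetical helper: zero the given rows of one column.
// The column is bound to a val before slicing (see note below).
def zeroInCol(mat: DenseMatrix[Double], col: Int, rows: IndexedSeq[Int]): Unit =
  val c = mat(::, col) // DenseVector view of column `col`
  c(rows) := 0.0

val mat = DenseMatrix(
  (-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
  (-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
)
zeroInRow(mat, 0, IndexedSeq(1, 3, 4))
println(mat)
```

Binding the column to a local `val` avoids a Scala pitfall: in `mat(::, col)(rows)`, the second argument list may be taken as an explicit argument for `apply`'s implicit parameter list rather than as a new slice.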