Tags: scala, slice, overwrite, scala-breeze

Is it possible to directly overwrite selected columns of a DenseMatrix row using a random list of indexes?


I'm trying to implement the following Python code in Scala/Breeze:

import numpy as np
mat = np.random.normal(size=(2, 5))
print(mat)
indexes = np.random.choice(5, replace = False, size = 3)
print(indexes)
mat[0, [indexes]] = 0
print(mat)

# the output:
# ./rowsliceOverwrite.py
[[ 0.30389599  0.84549682 -0.38408994 -1.11550844 -0.28496995]
 [-1.55260273 -0.41368681 -0.40455289  0.13054527 -1.43541557]]
[1 3 4]
[[ 0.30389599  0.         -0.38408994  0.          0.        ]
 [-1.55260273 -0.41368681 -0.40455289  0.13054527 -1.43541557]]

The goal is to directly zero selected indexes of a DenseMatrix row in a single expression.

Here's my Scala attempt, followed by the error message:

import breeze.linalg.*
import breeze.stats.*

// First attempt: mirror NumPy's `mat[0, [indexes]] = 0` by passing a sequence of
// column indexes directly to the matrix's assignment syntax.
def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604,  0.34085698,  0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228,  0.86001178)
  )
  println(mat)
  // Column positions in row 0 that should be overwritten.
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  mat(0, indexes) = 0.0 // does not compile: this assignment form requires an Int column index, not IndexedSeq[Int]
  println(mat)

# ./rowsliceOverwrite.sc
-- [E007] Type Mismatch Error: C:\Users\philwalk\workspace\tprf_py\.\rowsliceOverwrite.sc:14:9 -----------------------------------------------------------------------------------------------------------------------------------------------
14 |  mat(0, indexes) = 0
   |         ^^^^^^^
   |         Found:    (indexes : IndexedSeq[Int])
   |         Required: Int
   |
   | longer explanation available when compiling with `-explain`
1 error found
Errors encountered during compilation

It can be done in 3 steps:

import breeze.linalg.*
import breeze.stats.*

// Workaround: zero selected columns of row 0 of `mat` in three explicit steps.
// Note: `mat(0, ::).t` is a DenseVector view that shares storage with `mat`, so
// step 2 already updates the matrix; step 3 is a harmless write-back that only
// makes the intent explicit.
def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604,  0.34085698,  0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228,  0.86001178)
  )
  println(mat)
  // Column positions in row 0 that should be overwritten.
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  // Step 1: take row 0 as a column vector (`val`: it is never reassigned).
  val row0 = mat(0, ::).t
  // Step 2: zero the selected entries with Breeze's in-place `:=` operator.
  row0(indexes) := 0.0
  // Step 3: copy the row back into the matrix (redundant, since row0 aliases it).
  mat(0, ::) := row0.t
  println(mat)
# ./rowsliceOverwrite.sc
-0.25010575  0.44800905  0.13285604  0.34085698   0.38346101
-1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178
Vector(1, 3, 4)
-0.25010575  0.0         0.13285604  0.0          0.0
-1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178

It's not as readable as the Python code; it would be nice if there were a way to do it more directly, as a single expression.


Solution

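  • After experimenting for a while, I found a single expression that works. `mat(0, ::)` returns row 0 as a transposed view, and `.t` turns it into a `DenseVector` that still shares the matrix's storage. Assigning through it with `:=` therefore zeroes the selected entries of `mat` in place: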

    import breeze.linalg.*
    import breeze.stats.*
    
    // Zero the columns listed in `indexes` of row 0 of `mat`, in place, with a single expression.
    def main(args: Array[String]): Unit =
      val mat = DenseMatrix(
        (-0.25010575, 0.44800905, 0.13285604,  0.34085698,  0.38346101),
        (-1.97209990, 1.37114368, 1.56601999, -0.13052228,  0.86001178)
      )
      println(mat)
      // Column positions in row 0 that should be overwritten.
      val indexes = IndexedSeq(1, 3, 4)
      println(indexes)
      // mat(0, ::) is row 0 as a transposed view; .t yields a DenseVector that
      // still shares mat's storage, so `:=` writes 0.0 directly into the matrix.
      mat(0, ::).t(indexes) := 0.0
      println(mat)
    
    # the output:
    # ./rowsliceOverwrite.sc
    -0.25010575  0.44800905  0.13285604  0.34085698   0.38346101
    -1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178
    Vector(1, 3, 4)
    -0.25010575  0.0         0.13285604  0.0          0.0
    -1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178