I created a TPU VM on GCP.
I am following the documentation page on how to run a calculation on a Cloud TPU VM by using PyTorch.
I have set the XRT TPU device configuration in the VM with:
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
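The value above packs several fields into one string. The sketch below splits it into parts, assuming the "<job>;<task>;<host:port>" layout seen here. It also assumes multiple workers would be separated by "|", which is not verified. It then rebuilds the TPU_SYSTEM device name that appears in the error further down. The helper names (parse_xrt_tpu_config, tpu_system_device) are hypothetical, and replica:0 is assumed.

```python
"""Sketch: split an XRT_TPU_CONFIG value into its fields.

Assumes the "<job>;<task>;<host:port>" layout of the value above; splitting
multiple workers on "|" is an assumption, not verified against torch_xla.
"""
from typing import List, NamedTuple


class XrtWorker(NamedTuple):
    job: str
    task: int
    host: str
    port: int


def parse_xrt_tpu_config(value: str) -> List[XrtWorker]:
    """Hypothetical helper: parse each "<job>;<task>;<host:port>" entry."""
    workers = []
    for entry in value.split("|"):
        job, task, address = entry.split(";")
        host, port = address.rsplit(":", 1)
        workers.append(XrtWorker(job, int(task), host, int(port)))
    return workers


def tpu_system_device(worker: XrtWorker) -> str:
    # Mirrors the device name in the error message; replica:0 is assumed.
    return f"/job:{worker.job}/replica:0/task:{worker.task}/device:TPU_SYSTEM:0"


if __name__ == "__main__":
    (worker,) = parse_xrt_tpu_config("localservice;0;localhost:51011")
    print(worker)
    print(tpu_system_device(worker))
```

For the value used here, the job is localservice and the task is 0. These are the same names as in '/job:localservice/replica:0/task:0/device:TPU_SYSTEM:0' from the error. The configuration string is therefore being read, but no matching TPU device is found behind it.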
I created a Python file with the following contents:
import torch
import torch_xla.core.xla_model as xm
dev = xm.xla_device()
t1 = torch.randn(3,3,device=dev)
t2 = torch.randn(3,3,device=dev)
print(t1 + t2)
But when I run the file with python3 tpu_test.py, I get the following error:
$ python3 tpu_test.py
Traceback (most recent call last):
File "tpu_test.py", line 6, in <module>
dev = xm.xla_device()
File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 244, in xla_device
devices = get_xla_supported_devices(devkind=devkind)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 138, in get_xla_supported_devices
xla_devices = _DEVICES.value
File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 32, in value
self._value = self._gen_fn()
File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 20, in <lambda>
_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
RuntimeError: tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:1374 : Check failed: session.Run({tensorflow::Output(result, 0)}, &outputs) == ::tensorflow::Status::OK() (INVALID_ARGUMENT: No matching devices found for '/job:localservice/replica:0/task:0/device:TPU_SYSTEM:0' vs. OK)
*** Begin stack trace ***
tensorflow::CurrentStackTrace[abi:cxx11]()
xla::XrtComputationClient::InitializeAndFetchTopology(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tensorflow::ConfigProto const&)
xla::XrtComputationClient::InitializeDevices(std::unique_ptr<tensorflow::tpu::TopologyProto, std::default_delete<tensorflow::tpu::TopologyProto> >)
xla::XrtComputationClient::XrtComputationClient(xla::XrtComputationClient::Options, std::unique_ptr<tensorflow::tpu::TopologyProto, std::default_delete<tensorflow::tpu::TopologyProto> >)
xla::ComputationClient::Create()
xla::ComputationClient::Get()
PyCFunction_Call
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyObject_GetAttr
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCode
PyRun_SimpleFileExFlags
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
Try setting the following environment variable:
export TPU_NUM_DEVICES=[number of logical TPU cores]
For example, on a v3-8:
export TPU_NUM_DEVICES=8
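If you prefer not to rely on shell exports, the same two variables can be set from inside the script. This is a minimal sketch. It assumes the XRT client reads XRT_TPU_CONFIG and TPU_NUM_DEVICES when it is first created, which the traceback shows happens on the first xm.xla_device() call. Setting them before torch_xla is imported is the safe choice. The helper configure_xrt_env is hypothetical, and the value 8 assumes a v3-8.

```python
"""Sketch: set the XRT environment variables from Python before torch_xla loads.

Assumption: the XRT client reads XRT_TPU_CONFIG and TPU_NUM_DEVICES when it is
first created (on the first xm.xla_device() call), so they must already be in
os.environ by then. TPU_NUM_DEVICES=8 assumes a v3-8.
"""
import os
from typing import Dict


def configure_xrt_env(num_devices: int,
                      tpu_config: str = "localservice;0;localhost:51011") -> Dict[str, str]:
    """Hypothetical helper: write the XRT settings into this process's
    environment (overwriting any existing values) and return them."""
    if num_devices < 1:
        raise ValueError("num_devices must be a positive number of TPU cores")
    settings = {
        "XRT_TPU_CONFIG": tpu_config,
        "TPU_NUM_DEVICES": str(num_devices),
    }
    os.environ.update(settings)
    return settings


if __name__ == "__main__":
    configure_xrt_env(num_devices=8)  # 8 logical cores on a v3-8

    # Import torch_xla only after the environment is set.
    try:
        import torch
        import torch_xla.core.xla_model as xm
    except ImportError:
        print("torch_xla is not installed; environment configured only")
    else:
        dev = xm.xla_device()
        t1 = torch.randn(3, 3, device=dev)
        t2 = torch.randn(3, 3, device=dev)
        print(t1 + t2)
```

The shell exports are equivalent and usually simpler. The in-script version mainly helps when the script is launched by a tool that does not inherit your shell environment.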