I am trying to write a unit test for code that interacts with an SFTP server using the paramiko
library. The code under test receives a list of remote file locations and a callback. Each file is fetched and sent into the callback. The test shall simulate a scenario, where the caller sends two files to visit and one of the files fails with an IOError. I want to make sure that the failing file is excluded from the response.
Here is the code.py
:
import io
from typing import Callable, List
import typing
import paramiko
def visit_files(files: List[str], callback: Callable[[typing.BinaryIO], None]) -> List[str]:
response = []
with paramiko.SSHClient() as ssh:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect("test.rebex.net", port=22, username="demo", password="password")
with ssh.open_sftp() as sftp:
for file_name in files:
try:
with sftp.open(file_name, "rb") as f:
try:
b = f.read()
callback(io.BytesIO(b))
response.append(file_name)
except ValueError:
print("Something went wrong")
except IOError:
print("Unknown IO error")
return response
And my test_code.py
:
import typing
from unittest.mock import Mock
from pytest_mock import MockerFixture
from src.utils.code import visit_files
def test_visiting(mocker: MockerFixture):
mock = mocker.patch('paramiko.SSHClient')
ssh_client_mock = mock.return_value
ssh_client_mock.connect.return_value = Mock()
sftp_mock = ssh_client_mock.open_sftp.return_value
sftp_mock.open.side_effect = [
Mock(read=Mock(return_value=b'Hello, World!')), # Mock for the first file
IOError("Unable to open file"), # Simulate IOError for the second file
]
def print_size(b: typing.BinaryIO) -> None:
print(b.tell())
response = visit_files(files=["file1.txt", "file2.txt"], callback=print_size)
assert response == ["file1.txt"]
The error I am receiving is: TypeError: a bytes-like object is required, not 'MagicMock'
in line callback(io.BytesIO(b))
. I can't figure out where my mocks are not set up properly.
I would do something like this, starting by refactoring visit_files
so you can inject the client dependency
import io
from typing import Callable, List
import typing
import paramiko
def visit_files(files: List[str], callback: Callable[[typing.BinaryIO], None], client: paramiko.SSHClient = None) -> List[str]:
response = []
if not client:
client = paramiko.SSHClient
with client() as ssh:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect("test.rebex.net", port=22, username="demo", password="password")
with ssh.open_sftp() as sftp:
for file_name in files:
try:
with sftp.open(file_name, "rb") as f:
try:
b = f.read()
callback(io.BytesIO(b))
response.append(file_name)
except ValueError:
print("Something went wrong")
except IOError:
print("Unknown IO error")
return response
and then, create SSHClientMock
taking into account the context managers:
import typing
from unittest.mock import MagicMock
import paramiko
from contextlib import contextmanager
import io
from src.utils.code import visit_files
class SSHClientMock(MagicMock):
def __init__(self, **kwargs):
super().__init__(spec=paramiko.SSHClient, **kwargs)
def __enter__(self):
return self
def __exit__(self):
pass
@contextmanager
def open_sftp(self):
def _open_sftp_mock_open(filename, mode):
if filename == "file1.txt":
return io.BytesIO(b"Hello World!")
elif filename == "file2.txt":
raise IOError("Unable to open file")
assert False
open_sftp_mock = MagicMock()
open_sftp_mock.open = _open_sftp_mock_open
yield open_sftp_mock
def test_visiting():
def print_size(b: typing.BinaryIO) -> None:
print(b.tell())
response = visit_files(files=["file1.txt", "file2.txt"], callback=print_size, client=SSHClientMock())
assert response == ["file1.txt"]