
Is there a way to restrict the number of connections to a specific port in a Twisted server?


I have a Python Twisted server application that interfaces with a legacy client application, and each client has been assigned a specific port on which to connect to the server. I have set up listeners on all of those ports and it is working beautifully, but I need to build in a safeguard so that no more than one client can connect on the same server port. Too many things in the client application break when another client is connected to the same port, and I can't update that application right now, so I have to live with how it operates.

I know I could add logic to connectionMade() that checks whether a client is already connected on that port and, if so, closes the new connection. But I would rather reject the connection to begin with, so the client is never allowed to connect. Then the client will know they made a mistake and can change the port they are trying to connect on.

Here is a stripped down version of my server code if that helps.

from twisted.internet.protocol import Factory
from twisted.internet.protocol import Protocol
from twisted.internet import reactor
from twisted.internet import task
import time

class MyServerTasks():
    @staticmethod
    def someFunction(msg):
        pass  #Do stuff

    @staticmethod
    def someOtherFunction(msg):
        pass  #Do other stuff

class MyServer(Protocol):

    def __init__(self, users):
        self.users = users
        self.name = None

    def connectionMade(self):
        pass  #Depending on which port is connected, go do stuff

    def connectionLost(self, reason):
        pass  #Update dictionaries and other global info

    def dataReceived(self, line):
        t = time.strftime('%Y-%m-%d %H:%M:%S')
        d = self.transport.getHost()
        print("{} Received message from {}:{}...{}".format(t, d.host, d.port, line))  #debug
        self.handle_GOTDATA(line)

    def handle_GOTDATA(self, msg):
        #Parse the received data string and do stuff based on the message.
        #For example (msg is bytes on Python 3, so compare against a bytes literal):
        if b"99" in msg:
            MyServerTasks.someFunction(msg)

class MyServerFactory(Factory):

    def __init__(self):
        self.users = {} # maps user names to MyServer instances
        self.clients = [] # created once here so a new connection does not reset it

    def buildProtocol(self, *args, **kwargs):
        protocol = MyServer(self.users)
        protocol.factory = self
        return protocol

reactor.listenTCP(50010, MyServerFactory())
reactor.listenTCP(50011, MyServerFactory())
reactor.listenTCP(50012, MyServerFactory())
reactor.listenTCP(50013, MyServerFactory())

reactor.run()

Solution

  • When a client connects, Twisted calls the factory's buildProtocol method to create a protocol instance that handles that connection.

    You can therefore keep a counter of connected clients in MyServerFactory. Once the counter reaches the maximum number of allowed clients, return None from buildProtocol instead of creating a new protocol. When buildProtocol returns None, Twisted closes the connection immediately. Note that the TCP connection is still accepted briefly before it is closed, so from the client's side it looks like a connection that drops right away. Because each listenTCP call in the question gets its own factory instance, the counter is tracked per port.

    For example:

    class MyServer(Protocol):

        def __init__(self, users):
            self.users = users
            self.name = None

        def connectionMade(self):
            pass  #Depending on which port is connected, go do stuff

        def connectionLost(self, reason):
            #Update dictionaries and other global info
            #Free the slot so another client can connect on this port
            self.factory.counter -= 1

        def dataReceived(self, line):
            t = time.strftime('%Y-%m-%d %H:%M:%S')
            d = self.transport.getHost()
            print("{} Received message from {}:{}...{}".format(t, d.host, d.port, line))  #debug
            self.handle_GOTDATA(line)

        def handle_GOTDATA(self, msg):
            #Parse the received data (bytes on Python 3) and do stuff based on the message.
            #For example:
            if b"99" in msg:
                MyServerTasks.someFunction(msg)


    class MyServerFactory(Factory):
        #Maximum simultaneous clients per factory; set to 1 for one client per port
        MAX_CLIENT = 2

        def __init__(self):
            self.users = {} # maps user names to MyServer instances
            self.clients = []
            self.counter = 0 # number of currently connected clients

        def buildProtocol(self, *args, **kwargs):
            #Returning None makes Twisted close the new connection right away
            if self.counter >= self.MAX_CLIENT:
                return None
            self.counter += 1
            protocol = MyServer(self.users)
            protocol.factory = self
            return protocol