From 440fe803d210914ba6a55ea30e71bfdcce7291cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=98=D0=B2=D0=B0=D0=BD=20=D0=A1=D0=BE=D0=BB=D0=BD=D1=86?= =?UTF-8?q?=D0=B5=D0=B2?= Date: Thu, 3 Oct 2024 14:50:42 +0300 Subject: [PATCH] Split TCP socket server and HTTP handler. - Small improvement to the file reader function; - Move MAX_REQUEST_LINE_SIZE to common server file. --- main.py | 3 +- server/common.py | 2 + server/file_read.py | 10 +- server/http_handler.py | 245 +++++++++++++++++++++++++++++++++++++++++ server/main.py | 244 +--------------------------------------- 5 files changed, 255 insertions(+), 249 deletions(-) create mode 100644 server/http_handler.py diff --git a/main.py b/main.py index ef2a453..8acb773 100755 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import os import logging from server.main import init_server_socket +from server.http_handler import HTTPHandler import config @@ -18,4 +19,4 @@ if __name__ == "__main__": log.error("Attempt to start server...") - init_server_socket() + init_server_socket(server_handler=HTTPHandler) diff --git a/server/common.py b/server/common.py index dff3e1b..50cf85e 100644 --- a/server/common.py +++ b/server/common.py @@ -1,6 +1,8 @@ import enum +MAX_REQUEST_LINE_SIZE = 64 * 1024 + @enum.unique class HandlerType(enum.Enum): STATIC_FILE = 0 diff --git a/server/file_read.py b/server/file_read.py index c993466..f245601 100644 --- a/server/file_read.py +++ b/server/file_read.py @@ -10,12 +10,10 @@ def fileiter(filename): while True: file_data = f.read(1024) + yield file_data + if len(file_data) >= 1024: - log.debug("File size greeter 1024 bytes, yield data") - yield file_data - log.debug("Continue") + log.debug("File size greeter 1024 bytes, yield data. Continue to read") else: - log.debug("File size lower 1024 bytes, yield data") - yield file_data - log.debug("Break") + log.debug("File size lower 1024 bytes, yield data. 
Break to read") break diff --git a/server/http_handler.py b/server/http_handler.py new file mode 100644 index 0000000..db73ef0 --- /dev/null +++ b/server/http_handler.py @@ -0,0 +1,245 @@ +import logging + +import config + +from .common import MAX_REQUEST_LINE_SIZE, HandlerType, HTTPMethod +from .file_read import fileiter +from .response import Response + + +log = logging.getLogger(__name__) + + +class HTTPHandler: + def __init__(self, conn, addr): + self.conn = conn + self.addr = addr + + self.read_fd = conn.makefile("rb") + self.write_fd = conn.makefile("wb") + + self.http_method = "" + self.http_address = "" + self.http_get_request = {} + self.http_headers = {} + + self.data = b"" + + self.__handle_connection() + + def __write_data(self, data_iter): + for chunk in data_iter: + self.write_fd.write(chunk) + self.write_fd.flush() + + def __read_data(self, size=MAX_REQUEST_LINE_SIZE+1): + return self.read_fd.readline(size) + + def __close(self): + self.write_fd.close() + self.read_fd.close() + self.conn.close() + + def __handle_connection(self): + """ + Chain of handle first line of HTTP request + """ + raw = self.__read_data() + + if len(raw) > MAX_REQUEST_LINE_SIZE: + log.debug("Request line too long") + + r = Response(status_code=413, data=b"Request line too long") + self.__write_data(r) + + self.__close() + return + + # FIXME data maybe isn't UTF-8 or just binary + req_line = raw.decode().rstrip("\r\n") + req_line_splited = req_line.split(" ") + + if len(req_line_splited) != 3: + if len(req_line_splited) > 3: + log.debug("Request head too long") + else: + log.debug("Request head too small") + + r = Response(status_code=400, data=b"Bad request") + self.__write_data(r) + self.__close() + return + + + if req_line_splited[2].startswith("HTTP/"): + if (req_line_splited[2] == "HTTP/1.0" or + req_line_splited[2] == "HTTP/1.1"): + http_method = req_line_splited[0] + else: + log.debug("Unsupported HTTP version") + + r = Response(status_code=505, data=b"HTTP Version Not 
Supported") + self.__write_data(r) + self.__close() + return + else: + log.debug("Protocol isn't HTTP version 1.0 or 1.1") + + r = Response(status_code=400, data=b"Bad request") + self.__write_data(r) + self.__close() + return + + if http_method == "GET": + self.http_method = HTTPMethod.GET + elif http_method == "POST": + self.http_method = HTTPMethod.POST + elif http_method == "PUT": + self.http_method = HTTPMethod.PUT + elif http_method == "HEAD": + self.http_method = HTTPMethod.HEAD + elif http_method == "DELETE": + self.http_method = HTTPMethod.DELETE + elif http_method == "OPTIONS": + self.http_method = HTTPMethod.OPTIONS + elif http_method == "TRACE": + self.http_method = HTTPMethod.TRACE + elif http_method == "PATCH": + self.http_method = HTTPMethod.PATCH + else: + log.debug("Method not allowed") + + r = Response(status_code=405, data=b"Method not allowed") + self.__write_data(r) + self.__close() + return + + if (self.http_method == HTTPMethod.GET and + req_line_splited[1].find("?") >= 0): + + addr, get_req = req_line_splited[1].split("?", 1) + self.http_address = addr + + for param in get_req.split("&"): + if len(param) == 0: + continue + + elif param.find("=") >= 0: + name, val = param.split("=", 1) + + else: + name, val = (param, "") + + self.http_get_request.update({name: val}) + else: + self.http_address = req_line_splited[1] + self.http_get_request = {} + + self.__http_header_handle() + + def __http_header_handle(self): + """ + Chain of handle headers in HTTP request + """ + while True: + raw = self.__read_data() + + if len(raw) > MAX_REQUEST_LINE_SIZE: + log.debug("Request header line too long") + self.conn.send(b"") + self.conn.close() + return + + if raw == b"": + log.debug("Client not send data, close connection") + self.conn.send(b"") + self.conn.close() + return + + if raw == b"\r\n" or raw == b"\n": + self.__http_data_handle() + break + else: + decoded_data = raw.decode("UTF-8").rstrip("\r\n") + decoded_data_split = decoded_data.split(":", 1) + 
self.http_headers.update( + { + decoded_data_split[0]: decoded_data_split[1].strip(" ") + }) + + def __http_data_handle(self): + """ + Chain of receive data partitions HTTP request + """ + + if "Content-Length" in self.http_headers: + # TODO: if content-length biggest of MAX_REQUEST_LINE_SIZE - partition receive + log.debug("Request has contain body, try to read") + log.debug("Check Content-Length value type:") + + if self.http_headers["Content-Length"].isdigit(): + log.debug(" Pass") + bytes_to_receive = int(self.http_headers["Content-Length"]) + else: + log.error(" Content-Length is not integer") + self.conn.send(b"") + self.conn.close() + return + + log.debug("Want to receive {} bytes from client".format(bytes_to_receive)) + self.data = self.read_fd.read(bytes_to_receive) + + log.debug("Request method: %s" % self.http_method) + log.debug("Requested address: %s" % self.http_address) + log.debug("Request headers: %s" % self.http_headers) + log.debug("Request in GET after ?: %s" % self.http_get_request) + log.debug("Data: %s" % self.data) + + self.__send_form_data() + + def __send_form_data(self): + """ + Chain of handle, pack & send data to client + """ + found = False + + for path in config.urls: + if self.http_address == path.address: + if self.http_method in path.methods: + found = True + + if path.handler_type == HandlerType.STATIC_FILE: + log.debug("Address associated with static file on path {}".format(path.link)) + r = Response(data=fileiter(path.link)) + self.__write_data(r) + + elif path.handler_type == HandlerType.FUNCTION: + log.debug("Address associated with function") + func = path.link + + if func.__code__.co_argcount == 0: + func_return = func() + + elif func.__code__.co_argcount == 1: + func_return = func(self.http_headers) + + elif func.__code__.co_argcount == 2: + func_return = func(self.http_headers, self.http_get_request) + + elif func.__code__.co_argcount == 3: + func_return = func(self.http_headers, self.http_get_request, self.data) + + 
self.__write_data(func_return) + + else: + log.warning("Address configured on server, but type not allowed URL type in URLs list") + + r = Response(status_code=404, + data=b"Address configured on server, but type not allowed URL type in URLs list") + self.__write_data(r) + + if not found: + r = Response(status_code=404, data=b"Not found!") + self.__write_func(r) + + self.__close() diff --git a/server/main.py b/server/main.py index 7ffbeb5..970dbfe 100644 --- a/server/main.py +++ b/server/main.py @@ -5,251 +5,11 @@ import os import config -from .file_read import fileiter -from .response import Response -from .common import HandlerType, HTTPMethod - - -MAX_REQLINE = 64 * 1024 log = logging.getLogger(__name__) -class HTTPHandler: - def __init__(self, conn, addr): - self.conn = conn - self.addr = addr - - self.read_fd = conn.makefile("rb") - self.write_fd = conn.makefile("wb") - - self.http_method = "" - self.http_address = "" - self.http_get_request = {} - self.http_headers = {} - - self.data = b"" - - self.__handle_connection() - - def __write_data(self, data_iter): - for chunk in data_iter: - self.write_fd.write(chunk) - self.write_fd.flush() - - def __read_data(self, size=MAX_REQLINE+1): - return self.read_fd.readline(size) - - def __close(self): - self.write_fd.close() - self.read_fd.close() - self.conn.close() - - def __handle_connection(self): - """ - Chain of handle first line of HTTP request - """ - raw = self.__read_data() - - if len(raw) > MAX_REQLINE: - log.debug("Request line too long") - - r = Response(status_code=413, data=b"Request line too long") - self.__write_data(r) - - self.__close() - return - - # FIXME data maybe isn't UTF-8 or just binary - req_line = raw.decode().rstrip("\r\n") - req_line_splited = req_line.split(" ") - - if len(req_line_splited) != 3: - if len(req_line_splited) > 3: - log.debug("Request head too long") - else: - log.debug("Request head too small") - - r = Response(status_code=400, data=b"Bad request") - self.__write_data(r) - 
self.__close() - return - - - if req_line_splited[2].startswith("HTTP/"): - if (req_line_splited[2] == "HTTP/1.0" or - req_line_splited[2] == "HTTP/1.1"): - http_method = req_line_splited[0] - else: - log.debug("Unsupported HTTP version") - - r = Response(status_code=505, data=b"HTTP Version Not Supported") - self.__write_data(r) - self.__close() - return - else: - log.debug("Protocol isn't HTTP version 1.0 or 1.1") - - r = Response(status_code=400, data=b"Bad request") - self.__write_data(r) - self.__close() - return - - if http_method == "GET": - self.http_method = HTTPMethod.GET - elif http_method == "POST": - self.http_method = HTTPMethod.POST - elif http_method == "PUT": - self.http_method = HTTPMethod.PUT - elif http_method == "HEAD": - self.http_method = HTTPMethod.HEAD - elif http_method == "DELETE": - self.http_method = HTTPMethod.DELETE - elif http_method == "OPTIONS": - self.http_method = HTTPMethod.OPTIONS - elif http_method == "TRACE": - self.http_method = HTTPMethod.TRACE - elif http_method == "PATCH": - self.http_method = HTTPMethod.PATCH - else: - log.debug("Method not allowed") - - r = Response(status_code=405, data=b"Method not allowed") - self.__write_data(r) - self.__close() - return - - if (self.http_method == HTTPMethod.GET and - req_line_splited[1].find("?") >= 0): - - addr, get_req = req_line_splited[1].split("?", 1) - self.http_address = addr - - for param in get_req.split("&"): - if len(param) == 0: - continue - - elif param.find("=") >= 0: - name, val = param.split("=", 1) - - else: - name, val = (param, "") - - self.http_get_request.update({name: val}) - else: - self.http_address = req_line_splited[1] - self.http_get_request = {} - - self.__http_header_handle() - - def __http_header_handle(self): - """ - Chain of handle headers in HTTP request - """ - while True: - raw = self.__read_data() - - if len(raw) > MAX_REQLINE: - log.debug("Request header line too long") - self.conn.send(b"") - self.conn.close() - return - - if raw == b"": - 
log.debug("Client not send data, close connection") - self.conn.send(b"") - self.conn.close() - return - - if raw == b"\r\n" or raw == b"\n": - self.__http_data_handle() - break - else: - decoded_data = raw.decode("UTF-8").rstrip("\r\n") - decoded_data_split = decoded_data.split(":", 1) - self.http_headers.update( - { - decoded_data_split[0]: decoded_data_split[1].strip(" ") - }) - - def __http_data_handle(self): - """ - Chain of receive data partitions HTTP request - """ - - if "Content-Length" in self.http_headers: - # TODO: if content-length biggest of MAX_REQLINE - partition receive - log.debug("Request has contain body, try to read") - log.debug("Check Content-Length value type:") - - if self.http_headers["Content-Length"].isdigit(): - log.debug(" Pass") - bytes_to_receive = int(self.http_headers["Content-Length"]) - else: - log.error(" Content-Length is not integer") - self.conn.send(b"") - self.conn.close() - return - - log.debug("Want to receive {} bytes from client".format(bytes_to_receive)) - self.data = self.read_fd.read(bytes_to_receive) - - log.debug("Request method: %s" % self.http_method) - log.debug("Requested address: %s" % self.http_address) - log.debug("Request headers: %s" % self.http_headers) - log.debug("Request in GET after ?: %s" % self.http_get_request) - log.debug("Data: %s" % self.data) - - self.__send_form_data() - - def __send_form_data(self): - """ - Chain of handle, pack & send data to client - """ - found = False - - for path in config.urls: - if self.http_address == path.address: - if self.http_method in path.methods: - found = True - - if path.handler_type == HandlerType.STATIC_FILE: - log.debug("Address associated with static file on path {}".format(path.link)) - r = Response(data=fileiter(path.link)) - self.__write_data(r) - - elif path.handler_type == HandlerType.FUNCTION: - log.debug("Address associated with function") - func = path.link - - if func.__code__.co_argcount == 0: - func_return = func() - - elif 
func.__code__.co_argcount == 1: - func_return = func(self.http_headers) - - elif func.__code__.co_argcount == 2: - func_return = func(self.http_headers, self.http_get_request) - - elif func.__code__.co_argcount == 3: - func_return = func(self.http_headers, self.http_get_request, self.data) - - self.__write_data(func_return) - - else: - log.warning("Address configured on server, but type not allowed URL type in URLs list") - - r = Response(status_code=404, - data=b"Address configured on server, but type not allowed URL type in URLs list") - self.__write_data(r) - - if not found: - r = Response(status_code=404, data=b"Not found!") - self.__write_func(r) - - self.__close() - -def init_server_socket(): +def init_server_socket(server_handler): try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as main_server_socket: main_server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) @@ -262,7 +22,7 @@ def init_server_socket(): conn, addr = main_server_socket.accept() log.info("Accepted connection from %s", addr[0]) - thr_serv_conn = threading.Thread(target=HTTPHandler, args=(conn, addr,), daemon=True) + thr_serv_conn = threading.Thread(target=server_handler, args=(conn, addr,), daemon=True) thr_serv_conn.run() except KeyboardInterrupt: