diff --git a/shrlok/shrlok.py b/shrlok/shrlok.py index 273333b..70ce367 100644 --- a/shrlok/shrlok.py +++ b/shrlok/shrlok.py @@ -1,4 +1,27 @@ #!/usr/bin/env python3 +"""shrlok server: receive text/files from a socket, put them on a Web server. + +The server expects messages with the following format to come through the +socket: + +1. a length, as ASCII digits, +2. a null byte, +3. a JSON object containing at least the "type" key with a known value, +4. another null byte, +5. the text or file itself. + +The length is of the JSON object + the null char + the content itself. + +Example, with \\0 representing a null byte: + +28\\0{"type":"txt"}\\0hello shrlok! + +The content is 13 bytes (assuming no LF at the end), the header is 14 bytes, +plus the null byte it is 13 + 14 + 1 = 28 bytes, thus the prefixed length. + +After the content is succesfully retrieved and put on an appropriate location, +the server will reply the file path through the socket and close the connection. +""" import argparse import json @@ -29,11 +52,9 @@ HTML_TEMPLATE = """\ class Handler(socketserver.StreamRequestHandler): def handle(self): - fragments = [] - while (chunk := self.request.recv(4096)): - fragments.append(chunk) - data = b"".join(fragments) + data = self.receive_input() + # Extract header. try: first_zero = data.index(b"\0") header_data, data = data[:first_zero], data[first_zero + 1:] @@ -42,17 +63,38 @@ class Handler(socketserver.StreamRequestHandler): print("Bad header.") return + file_name = None if header.get("type") == "txt": file_name = write_text(data, title=header.get("title")) else: print("Unknown type.") - return - - if not file_name: + if file_name is None: return print(f"{len(data)} bytes — {header} — '{file_name}'.") - self.request.sendall(file_name.encode()) + try: + self.request.sendall(file_name.encode()) + except BrokenPipeError: + print("Broken pipe.") + + def receive_input(self): + length = None + data = b"" + while True: + chunk = self.request.recv(4096) + if not chunk: + break + data += chunk + if length is None: + try: + first_zero = data.index(b"\0") + except ValueError: + return b"" + length_data, data = data[:first_zero], data[first_zero + 1:] + length = int(length_data.decode()) + if len(data) >= length: # retrieval completed + break + return data def write_text(data: bytes, title=None): @@ -92,6 +134,7 @@ def main(): try: with socketserver.UnixStreamServer(socket_path, Handler) as server: + os.chmod(socket_path, 0o664) server.serve_forever() except KeyboardInterrupt: print("Stopping server.")