import gzip
import json
import logging
import re
import threading
import time
import zlib
from email.parser import BytesParser
from functools import partial
from http import HTTPStatus, server
from urllib.parse import parse_qs, unquote, urlparse
from gcp_storage_emulator import settings
from gcp_storage_emulator.handlers import buckets, objects
from gcp_storage_emulator.storage import Storage
logger = logging.getLogger(__name__)
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
def _wipe_data(req, res, storage):
keep_buckets = bool(req.query.get("keep-buckets"))
logger.debug("Wiping storage")
if keep_buckets:
logger.debug("...while keeping the buckets")
storage.wipe(keep_buckets)
logger.debug("Storage wiped")
res.write("OK")
def _health_check(req, res, storage):
res.write("OK")
HANDLERS = (
(r"^{}/b$".format(settings.API_ENDPOINT), {GET: buckets.ls, POST: buckets.insert}),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)$".format(settings.API_ENDPOINT),
{GET: buckets.get, DELETE: buckets.delete},
),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o$".format(settings.API_ENDPOINT),
{GET: objects.ls},
),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o/(?P<object_id>.*[^/]+)/copyTo/b/".format(
settings.API_ENDPOINT
)
+ r"(?P<dest_bucket_name>[-.\w]+)/o/(?P<dest_object_id>.*[^/]+)$",
{POST: objects.copy},
),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o/(?P<object_id>.*[^/]+)/rewriteTo/b/".format(
settings.API_ENDPOINT
)
+ r"(?P<dest_bucket_name>[-.\w]+)/o/(?P<dest_object_id>.*[^/]+)$",
{POST: objects.rewrite},
),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o/(?P<object_id>.*[^/]+)/compose$".format(
settings.API_ENDPOINT
),
{POST: objects.compose},
),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o/(?P<object_id>.*[^/]+)$".format(
settings.API_ENDPOINT
),
{GET: objects.get, DELETE: objects.delete, PATCH: objects.patch},
),
# Non-default API endpoints
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o$".format(settings.UPLOAD_API_ENDPOINT),
{POST: objects.insert, PUT: objects.upload_partial},
),
(
r"^{}/b/(?P<bucket_name>[-.\w]+)/o/(?P<object_id>.*[^/]+)$".format(
settings.DOWNLOAD_API_ENDPOINT
),
{GET: objects.download},
),
(
r"^{}$".format(settings.BATCH_API_ENDPOINT),
{POST: objects.batch},
),
# Internal API, not supported by the real GCS
(r"^/$", {GET: _health_check}), # Health check endpoint
(r"^/wipe$", {GET: _wipe_data}), # Wipe all data
    # Public file serving (same as objects.download) and signed URLs
(
r"^/(?P<bucket_name>[-.\w]+)/(?P<object_id>.*[^/]+)$",
{GET: objects.download, PUT: objects.xml_upload},
),
)
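# Routing example (illustrative values): a GET request for
# settings.API_ENDPOINT + "/b/my-bucket/o/images%2Fcat.png" matches the
# objects.get route above, and Request.params then yields
# {"bucket_name": "my-bucket", "object_id": "images/cat.png"} after unquoting.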
BATCH_HANDLERS = (
r"^(?P<method>[\w]+).*{}/b/(?P<bucket_name>[-.\w]+)/o/(?P<object_id>[^\?]+[^/])([\?].*)?$".format(
settings.API_ENDPOINT
),
r"^(?P<method>[\w]+).*{}/b/(?P<bucket_name>[-.\w]+)([\?].*)?$".format(
settings.API_ENDPOINT
),
r"^Content-Type:\s*(?P<content_type>[-.\w/]+)$",
)
def _parse_batch_item(item):
    """Parse a single part of a batch request into a dict of parameters."""
    parsed_params = {}
    content_reached = False
    partial_content = ""
current_content = item.get_payload()
for line in current_content.splitlines():
if not content_reached:
if not line:
content_reached = True
else:
for regex in BATCH_HANDLERS:
pattern = re.compile(regex)
match = pattern.fullmatch(line)
if match:
for k, v in match.groupdict().items():
parsed_params[k] = unquote(v)
else:
partial_content += line
if partial_content and parsed_params.get("content_type") == "application/json":
parsed_params["meta"] = json.loads(partial_content)
return parsed_params
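# Sketch of a parsed batch part (illustrative values; assumes
# settings.API_ENDPOINT is "/storage/v1"): given a part whose payload reads
#
#   PATCH /storage/v1/b/example-bucket/o/notes.txt
#   Content-Type: application/json
#
#   {"metadata": {"owner": "me"}}
#
# _parse_batch_item returns
#   {"method": "PATCH", "bucket_name": "example-bucket",
#    "object_id": "notes.txt", "content_type": "application/json",
#    "meta": {"metadata": {"owner": "me"}}}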
def _read_raw_data(request_handler):
    """Read the raw request body, honoring Content-Length or chunked encoding."""
if request_handler.headers["Content-Length"]:
return request_handler.rfile.read(
int(request_handler.headers["Content-Length"])
)
if request_handler.headers["Transfer-Encoding"] == "chunked":
raw_data = b""
while True:
line = request_handler.rfile.readline().strip()
chunk_size = int(line, 16) if line else 0
if chunk_size == 0:
break
raw_data += request_handler.rfile.read(chunk_size)
request_handler.rfile.readline()
return raw_data
return None
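# For reference: a chunked body arrives as hex-sized frames terminated by a
# zero-sized chunk, so b"4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n" decodes to
# b"Wikipedia" with the loop above.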
def _decode_raw_data(raw_data, request_handler):
    """Decompress the request body according to its Content-Encoding, if any."""
if not raw_data:
return None
if request_handler.headers["Content-Encoding"] == "gzip":
return gzip.decompress(raw_data)
if request_handler.headers["Content-Encoding"] == "deflate":
return zlib.decompress(raw_data)
return raw_data
def _read_data(request_handler, query):
    """Read the request body and parse it according to its Content-Type."""
raw_data = _decode_raw_data(_read_raw_data(request_handler), request_handler)
if not raw_data:
return None
content_type = request_handler.headers["Content-Type"] or "application/octet-stream"
if content_type.startswith("application/json") and "upload_id" not in query:
return json.loads(raw_data)
if content_type.startswith("multipart/"):
parser = BytesParser()
header = bytes("Content-Type:" + content_type + "\r\n", "utf-8")
msg = parser.parsebytes(header + raw_data)
payload = msg.get_payload()
if content_type.startswith("multipart/mixed"):
# Batch https://cloud.google.com/storage/docs/json_api/v1/how-tos/batch
rv = list()
for item in payload:
parsed_params = _parse_batch_item(item)
rv.append(parsed_params)
return rv
        # For a multipart upload, the Google API expects the first part to be
        # a JSON-encoded metadata object and the second (and only other) part
        # to be the file content
return {
"meta": json.loads(payload[0].get_payload()),
"content": payload[1].get_payload(decode=True),
"content-type": payload[1].get_content_type(),
}
return raw_data
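# Sketch of a multipart upload body handled above (boundary value
# hypothetical): with Content-Type "multipart/related; boundary=sep", a body
# such as
#
#   --sep
#   Content-Type: application/json
#
#   {"name": "notes.txt"}
#   --sep
#   Content-Type: text/plain
#
#   hello
#   --sep--
#
# yields {"meta": {"name": "notes.txt"}, "content": <bytes of the second
# part>, "content-type": "text/plain"}.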
class Request(object):
def __init__(self, request_handler, method):
super().__init__()
self._path = request_handler.path
self._request_handler = request_handler
self._server_address = request_handler.server.server_address
self._base_url = "http://{}:{}".format(
self._server_address[0], self._server_address[1]
)
self._full_url = self._base_url + self._path
self._parsed_url = urlparse(self._full_url)
self._query = parse_qs(self._parsed_url.query)
        self._method = method
        self._data = None
        self._match = None
        self._parsed_params = None
@property
def path(self):
return self._parsed_url.path
@property
def base_url(self):
return self._base_url
@property
def full_url(self):
return self._full_url
@property
def method(self):
        return self._method
@property
def query(self):
return self._query
@property
def params(self):
if not self._match:
return None
if not self._parsed_params:
self._parsed_params = {}
for k, v in self._match.groupdict().items():
self._parsed_params[k] = unquote(v)
return self._parsed_params
@property
def data(self):
if not self._data:
self._data = _read_data(self._request_handler, self._query)
return self._data
def get_header(self, key, default=None):
return self._request_handler.headers.get(key, default)
def set_match(self, match):
self._match = match
class Response(object):
def __init__(self, handler):
super().__init__()
self._handler = handler
self.status = HTTPStatus.OK
self._headers = {}
self._content = ""
def write(self, content):
logger.warning(
"[RESPONSE] Content handled as string, should be handled as stream"
)
self._content += content
def write_file(self, content, content_type="application/octet-stream"):
if content_type is not None:
self["Content-type"] = content_type
self._content = content
def json(self, obj):
self["Content-type"] = "application/json"
self._content = json.dumps(obj)
def __setitem__(self, key, value):
self._headers[key] = value
def __getitem__(self, key):
return self._headers[key]
def close(self):
self._handler.send_response(self.status.value, self.status.phrase)
for k, v in self._headers.items():
self._handler.send_header(k, v)
content = self._content
if isinstance(self._content, str):
content = self._content.encode("utf-8")
self._handler.send_header("Content-Length", str(len(content)))
self._handler.end_headers()
self._handler.wfile.write(content)
class Router(object):
def __init__(self, request_handler):
super().__init__()
self._request_handler = request_handler
def handle(self, method):
if self._request_handler.headers["x-http-method-override"]:
method = self._request_handler.headers["x-http-method-override"]
request = Request(self._request_handler, method)
response = Response(self._request_handler)
for regex, handlers in HANDLERS:
pattern = re.compile(regex)
match = pattern.fullmatch(request.path)
            if match:
                request.set_match(match)
                handler = handlers.get(method)
                if handler is None:
                    # The path matched, but this method is not supported on it.
                    logger.error(
                        "Method not implemented: {} - {}".format(
                            request.method, request.path
                        )
                    )
                    response.status = HTTPStatus.NOT_IMPLEMENTED
                    break
                try:
                    handler(request, response, self._request_handler.storage)
                except Exception as e:
                    logger.error(
                        "An error occurred while running the handler for {} {}".format(
                            request.method,
                            request.full_url,
                        )
                    )
                    logger.error(e)
                    raise
                break
else:
logger.error(
"Method not implemented: {} - {}".format(request.method, request.path)
)
response.status = HTTPStatus.NOT_IMPLEMENTED
response.close()
class RequestHandler(server.BaseHTTPRequestHandler):
def __init__(self, storage, *args, **kwargs):
self.storage = storage
super().__init__(*args, **kwargs)
def do_GET(self):
router = Router(self)
router.handle(GET)
def do_POST(self):
router = Router(self)
router.handle(POST)
def do_DELETE(self):
router = Router(self)
router.handle(DELETE)
def do_PUT(self):
router = Router(self)
router.handle(PUT)
def do_PATCH(self):
router = Router(self)
router.handle(PATCH)
def log_message(self, format, *args):
logger.info(format % args)
class APIThread(threading.Thread):
def __init__(self, host, port, storage, *args, **kwargs):
super().__init__(*args, **kwargs)
self._host = host
self._port = port
self.is_running = threading.Event()
self._httpd = None
self._storage = storage
def run(self):
self._httpd = server.HTTPServer(
(self._host, self._port), partial(RequestHandler, self._storage)
)
self.is_running.set()
self._httpd.serve_forever()
    def join(self, timeout=None):
        self.is_running.clear()
        if self._httpd:
            logger.info("[API] Stopping API server")
            self._httpd.shutdown()
            self._httpd.server_close()
            # Wait for run() to return now that the server has shut down.
            super().join(timeout)
class Server(object):
def __init__(self, host, port, in_memory, default_bucket=None, data_dir=None):
self._storage = Storage(use_memory_fs=in_memory, data_dir=data_dir)
if default_bucket:
logger.debug('[SERVER] Creating default bucket "{}"'.format(default_bucket))
buckets.create_bucket(default_bucket, self._storage)
self._api = APIThread(host, port, self._storage)
# Context Manager
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.stop()
def start(self):
self._api.start()
        self._api.is_running.wait()  # Block until the API thread is up and serving
def stop(self):
self._api.join(timeout=1)
def wipe(self, keep_buckets=False):
self._storage.wipe(keep_buckets=keep_buckets)
def run(self):
try:
self.start()
logger.info("[SERVER] All services started")
while True:
try:
time.sleep(0.1)
except KeyboardInterrupt:
logger.info("[SERVER] Received keyboard interrupt")
break
finally:
self.stop()
def create_server(host, port, in_memory=False, default_bucket=None, data_dir=None):
    """Create a Server instance; call start() or run() on it to begin serving."""
    logger.info("Starting server at {}:{}".format(host, port))
return Server(
host,
port,
in_memory=in_memory,
default_bucket=default_bucket,
data_dir=data_dir,
)
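if __name__ == "__main__":
    # Minimal usage sketch; the host, port, and bucket name here are
    # illustrative. Runs an in-memory emulator with one default bucket
    # until interrupted with Ctrl+C.
    create_server(
        "localhost", 9023, in_memory=True, default_bucket="example-bucket"
    ).run()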