From 561e8744d13d6a14b5bc89c65f08ffb05fa64b69 Mon Sep 17 00:00:00 2001 From: Barney Bittner Date: Fri, 12 Jul 2024 20:34:47 +0100 Subject: [PATCH 1/2] Bundle httpfs by default --- vendor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor.py b/vendor.py index cb5758c3..820d457b 100644 --- a/vendor.py +++ b/vendor.py @@ -15,7 +15,7 @@ # list of extensions to bundle -extensions = ['parquet', 'icu', 'json'] +extensions = ['parquet', 'icu', 'json', 'httpfs'] # path to target basedir = os.getcwd() From fb6c8892f663253012b277e8aab8e2aa99873068 Mon Sep 17 00:00:00 2001 From: barnab Date: Mon, 22 Jul 2024 22:44:08 +0100 Subject: [PATCH 2/2] optionally include httpfs plugin when building from source --- binding.gyp | 22 + binding.gyp.in | 5 +- .../httpfs/create_secret_functions.cpp | 250 ++++ src/duckdb/extension/httpfs/crypto.cpp | 27 + src/duckdb/extension/httpfs/hffs.cpp | 418 ++++++ src/duckdb/extension/httpfs/httpfs.cpp | 762 +++++++++++ .../extension/httpfs/httpfs_extension.cpp | 101 ++ .../include/create_secret_functions.hpp | 51 + .../extension/httpfs/include/crypto.hpp | 19 + src/duckdb/extension/httpfs/include/hffs.hpp | 71 + .../httpfs/include/http_metadata_cache.hpp | 91 ++ .../extension/httpfs/include/httpfs.hpp | 178 +++ .../httpfs/include/httpfs_extension.hpp | 14 + src/duckdb/extension/httpfs/include/s3fs.hpp | 256 ++++ src/duckdb/extension/httpfs/s3fs.cpp | 1217 +++++++++++++++++ vendor.py | 53 +- 16 files changed, 3524 insertions(+), 11 deletions(-) create mode 100644 src/duckdb/extension/httpfs/create_secret_functions.cpp create mode 100644 src/duckdb/extension/httpfs/crypto.cpp create mode 100644 src/duckdb/extension/httpfs/hffs.cpp create mode 100644 src/duckdb/extension/httpfs/httpfs.cpp create mode 100644 src/duckdb/extension/httpfs/httpfs_extension.cpp create mode 100644 src/duckdb/extension/httpfs/include/create_secret_functions.hpp create mode 100644 src/duckdb/extension/httpfs/include/crypto.hpp create mode 100644 src/duckdb/extension/httpfs/include/hffs.hpp create mode 100644 src/duckdb/extension/httpfs/include/http_metadata_cache.hpp create mode 100644 src/duckdb/extension/httpfs/include/httpfs.hpp create mode 100644 src/duckdb/extension/httpfs/include/httpfs_extension.hpp create mode 100644 src/duckdb/extension/httpfs/include/s3fs.hpp create mode 100644 src/duckdb/extension/httpfs/s3fs.cpp diff --git a/binding.gyp b/binding.gyp index 0b91d445..6bd6c5bb 100644 --- a/binding.gyp +++ b/binding.gyp @@ -2,6 +2,9 @@ "targets": [ { "target_name": "<(module_name)", + "variables": { + "include_httpfs": " CreateS3SecretFunctions::CreateSecretFunctionInternal(ClientContext &context, + CreateSecretInput &input, + S3AuthParams params) { + // for r2 we can set the endpoint using the account id + if (input.type == "r2" && input.options.find("account_id") != input.options.end()) { + params.endpoint = input.options["account_id"].ToString() + ".r2.cloudflarestorage.com"; + } + + // apply any overridden settings + for (const auto &named_param : input.options) { + auto lower_name = StringUtil::Lower(named_param.first); + + if (lower_name == "key_id") { + params.access_key_id = named_param.second.ToString(); + } else if (lower_name == "secret") { + params.secret_access_key = named_param.second.ToString(); + } else if (lower_name == "region") { + params.region = named_param.second.ToString(); + } else if (lower_name == "session_token") { + params.session_token = named_param.second.ToString(); + } else if (lower_name == "endpoint") { + params.endpoint = 
named_param.second.ToString(); + } else if (lower_name == "url_style") { + params.url_style = named_param.second.ToString(); + } else if (lower_name == "use_ssl") { + if (named_param.second.type() != LogicalType::BOOLEAN) { + throw InvalidInputException("Invalid type past to secret option: '%s', found '%s', expected: 'BOOLEAN'", + lower_name, named_param.second.type().ToString()); + } + params.use_ssl = named_param.second.GetValue(); + } else if (lower_name == "url_compatibility_mode") { + if (named_param.second.type() != LogicalType::BOOLEAN) { + throw InvalidInputException("Invalid type past to secret option: '%s', found '%s', expected: 'BOOLEAN'", + lower_name, named_param.second.type().ToString()); + } + params.s3_url_compatibility_mode = named_param.second.GetValue(); + } else if (lower_name == "account_id") { + continue; // handled already + } else { + throw InternalException("Unknown named parameter passed to CreateSecretFunctionInternal: " + lower_name); + } + } + + // Set scope to user provided scope or the default + auto scope = input.scope; + if (scope.empty()) { + if (input.type == "s3") { + scope.push_back("s3://"); + scope.push_back("s3n://"); + scope.push_back("s3a://"); + } else if (input.type == "r2") { + scope.push_back("r2://"); + } else if (input.type == "gcs") { + scope.push_back("gcs://"); + scope.push_back("gs://"); + } else { + throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type); + } + } + + return S3SecretHelper::CreateSecret(scope, input.type, input.provider, input.name, params); +} + +unique_ptr CreateS3SecretFunctions::CreateS3SecretFromSettings(ClientContext &context, + CreateSecretInput &input) { + auto &opener = context.client_data->file_opener; + FileOpenerInfo info; + auto params = S3AuthParams::ReadFrom(opener.get(), info); + return CreateSecretFunctionInternal(context, input, params); +} + +unique_ptr CreateS3SecretFunctions::CreateS3SecretFromConfig(ClientContext &context, + CreateSecretInput &input) { + S3AuthParams empty_params; + empty_params.use_ssl = true; + empty_params.s3_url_compatibility_mode = false; + empty_params.region = "us-east-1"; + empty_params.endpoint = "s3.amazonaws.com"; + + if (input.type == "gcs") { + empty_params.endpoint = "storage.googleapis.com"; + } + + if (input.type == "gcs" || input.type == "r2") { + empty_params.url_style = "path"; + } + + return CreateSecretFunctionInternal(context, input, empty_params); +} + +void CreateS3SecretFunctions::SetBaseNamedParams(CreateSecretFunction &function, string &type) { + function.named_parameters["key_id"] = LogicalType::VARCHAR; + function.named_parameters["secret"] = LogicalType::VARCHAR; + function.named_parameters["region"] = LogicalType::VARCHAR; + function.named_parameters["session_token"] = LogicalType::VARCHAR; + function.named_parameters["endpoint"] = LogicalType::VARCHAR; + function.named_parameters["url_style"] = LogicalType::VARCHAR; + function.named_parameters["use_ssl"] = LogicalType::BOOLEAN; + function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN; + + if (type == "r2") { + function.named_parameters["account_id"] = LogicalType::VARCHAR; + } +} + +void CreateS3SecretFunctions::RegisterCreateSecretFunction(DatabaseInstance &instance, string type) { + // Register the new type + SecretType secret_type; + secret_type.name = type; + secret_type.deserializer = KeyValueSecret::Deserialize; + secret_type.default_provider = "config"; + + ExtensionUtil::RegisterSecretType(instance, secret_type); + + CreateSecretFunction 
from_empty_config_fun2 = {type, "config", CreateS3SecretFromConfig}; + CreateSecretFunction from_settings_fun2 = {type, "duckdb_settings", CreateS3SecretFromSettings}; + SetBaseNamedParams(from_empty_config_fun2, type); + SetBaseNamedParams(from_settings_fun2, type); + ExtensionUtil::RegisterFunction(instance, from_empty_config_fun2); + ExtensionUtil::RegisterFunction(instance, from_settings_fun2); +} + +void CreateBearerTokenFunctions::Register(DatabaseInstance &instance) { + // Generic Bearer secret + SecretType secret_type; + secret_type.name = GENERIC_BEARER_TYPE; + secret_type.deserializer = KeyValueSecret::Deserialize; + secret_type.default_provider = "config"; + ExtensionUtil::RegisterSecretType(instance, secret_type); + + // Generic Bearer config provider + CreateSecretFunction config_fun = {GENERIC_BEARER_TYPE, "config", CreateBearerSecretFromConfig}; + config_fun.named_parameters["token"] = LogicalType::VARCHAR; + ExtensionUtil::RegisterFunction(instance, config_fun); + + // HuggingFace secret + SecretType secret_type_hf; + secret_type_hf.name = HUGGINGFACE_TYPE; + secret_type_hf.deserializer = KeyValueSecret::Deserialize; + secret_type_hf.default_provider = "config"; + ExtensionUtil::RegisterSecretType(instance, secret_type_hf); + + // Huggingface config provider + CreateSecretFunction hf_config_fun = {HUGGINGFACE_TYPE, "config", CreateBearerSecretFromConfig}; + hf_config_fun.named_parameters["token"] = LogicalType::VARCHAR; + ExtensionUtil::RegisterFunction(instance, hf_config_fun); + + // Huggingface credential_chain provider + CreateSecretFunction hf_cred_fun = {HUGGINGFACE_TYPE, "credential_chain", + CreateHuggingFaceSecretFromCredentialChain}; + ExtensionUtil::RegisterFunction(instance, hf_cred_fun); +} + +unique_ptr CreateBearerTokenFunctions::CreateSecretFunctionInternal(ClientContext &context, + CreateSecretInput &input, + const string &token) { + // Set scope to user provided scope or the default + auto scope = input.scope; + if (scope.empty()) { + if (input.type == GENERIC_BEARER_TYPE) { + scope.push_back(""); + } else if (input.type == HUGGINGFACE_TYPE) { + scope.push_back("hf://"); + } else { + throw InternalException("Unknown secret type found in httpfs extension: '%s'", input.type); + } + } + auto return_value = make_uniq(scope, input.type, input.provider, input.name); + + //! Set key value map + return_value->secret_map["token"] = token; + + //! Set redact keys + return_value->redact_keys = {"token"}; + + return std::move(return_value); +} + +unique_ptr CreateBearerTokenFunctions::CreateBearerSecretFromConfig(ClientContext &context, + CreateSecretInput &input) { + string token; + + auto token_input = input.options.find("token"); + for (const auto &named_param : input.options) { + auto lower_name = StringUtil::Lower(named_param.first); + if (lower_name == "token") { + token = named_param.second.ToString(); + } + } + + return CreateSecretFunctionInternal(context, input, token); +} + +static string TryReadTokenFile(const string &token_path, const string error_source_message, + bool fail_on_exception = true) { + try { + LocalFileSystem fs; + auto handle = fs.OpenFile(token_path, {FileOpenFlags::FILE_FLAGS_READ}); + return handle->ReadLine(); + } catch (std::exception &ex) { + if (!fail_on_exception) { + return ""; + } + ErrorData error(ex); + throw IOException("Failed to read token path '%s'%s. 
(error: %s)", token_path, error_source_message, + error.RawMessage()); + } +} + +unique_ptr +CreateBearerTokenFunctions::CreateHuggingFaceSecretFromCredentialChain(ClientContext &context, + CreateSecretInput &input) { + // Step 1: Try the ENV variable HF_TOKEN + const char *hf_token_env = std::getenv("HF_TOKEN"); + if (hf_token_env) { + return CreateSecretFunctionInternal(context, input, hf_token_env); + } + // Step 2: Try the ENV variable HF_TOKEN_PATH + const char *hf_token_path_env = std::getenv("HF_TOKEN_PATH"); + if (hf_token_path_env) { + auto token = TryReadTokenFile(hf_token_path_env, " fetched from HF_TOKEN_PATH env variable"); + return CreateSecretFunctionInternal(context, input, token); + } + + // Step 3: Try the path $HF_HOME/token + const char *hf_home_env = std::getenv("HF_HOME"); + if (hf_home_env) { + auto token_path = LocalFileSystem().JoinPath(hf_home_env, "token"); + auto token = TryReadTokenFile(token_path, " constructed using the HF_HOME variable: '$HF_HOME/token'"); + return CreateSecretFunctionInternal(context, input, token); + } + + // Step 4: Check the default path + auto token = TryReadTokenFile("~/.cache/huggingface/token", "", false); + return CreateSecretFunctionInternal(context, input, token); +} +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/crypto.cpp b/src/duckdb/extension/httpfs/crypto.cpp new file mode 100644 index 00000000..f6d399ed --- /dev/null +++ b/src/duckdb/extension/httpfs/crypto.cpp @@ -0,0 +1,27 @@ +#include "crypto.hpp" +#include "mbedtls_wrapper.hpp" + +namespace duckdb { + +void sha256(const char *in, size_t in_len, hash_bytes &out) { + duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(in, in_len, (char *)out); +} + +void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out) { + duckdb_mbedtls::MbedTlsWrapper::Hmac256(secret, secret_len, message.data(), message.size(), (char *)out); +} + +void hmac256(std::string message, hash_bytes secret, hash_bytes &out) { + hmac256(message, (char *)secret, sizeof(hash_bytes), out); +} + +void hex256(hash_bytes &in, hash_str &out) { + const char *hex = "0123456789abcdef"; + unsigned char *pin = in; + unsigned char *pout = out; + for (; pin < in + sizeof(in); pout += 2, pin++) { + pout[0] = hex[(*pin >> 4) & 0xF]; + pout[1] = hex[*pin & 0xF]; + } +} +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/hffs.cpp b/src/duckdb/extension/httpfs/hffs.cpp new file mode 100644 index 00000000..ebc01335 --- /dev/null +++ b/src/duckdb/extension/httpfs/hffs.cpp @@ -0,0 +1,418 @@ +#include "hffs.hpp" + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/exception/http_exception.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/secret/secret_manager.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +#include +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +#include + +namespace duckdb { + +static duckdb::unique_ptr initialize_http_headers(HeaderMap &header_map) { + auto headers = make_uniq(); + for (auto &entry : header_map) { + headers->insert(entry); + } + return headers; +} + +HuggingFaceFileSystem::~HuggingFaceFileSystem() { +} + +static string ParseNextUrlFromLinkHeader(const string &link_header_content) { + auto split_outer = StringUtil::Split(link_header_content, ','); + for (auto &split : split_outer) { + auto split_inner = StringUtil::Split(split, ';'); + if 
(split_inner.size() != 2) { + throw IOException("Unexpected link header for huggingface pagination: %s", link_header_content); + } + + StringUtil::Trim(split_inner[1]); + if (split_inner[1] == "rel=\"next\"") { + StringUtil::Trim(split_inner[0]); + + if (!StringUtil::StartsWith(split_inner[0], "<") || !StringUtil::EndsWith(split_inner[0], ">")) { + throw IOException("Unexpected link header for huggingface pagination: %s", link_header_content); + } + + return split_inner[0].substr(1, split_inner[0].size() - 2); + } + } + + throw IOException("Failed to parse Link header for paginated response, pagination support"); +} + +HFFileHandle::~HFFileHandle() {}; + +void HFFileHandle::InitializeClient(optional_ptr client_context) { + http_client = HTTPFileSystem::GetClient(this->http_params, parsed_url.endpoint.c_str(), this); +} + +string HuggingFaceFileSystem::ListHFRequest(ParsedHFUrl &url, HTTPParams &http_params, string &next_page_url, + optional_ptr state) { + HeaderMap header_map; + auto headers = initialize_http_headers(header_map); + string link_header_result; + + auto client = HTTPFileSystem::GetClient(http_params, url.endpoint.c_str(), nullptr); + std::stringstream response; + + std::function request([&]() { + if (state) { + state->get_count++; + } + + return client->Get( + next_page_url.c_str(), *headers, + [&](const duckdb_httplib_openssl::Response &response) { + if (response.status >= 400) { + throw HTTPException(response, "HTTP GET error on '%s' (HTTP %d)", next_page_url, response.status); + } + auto link_res = response.headers.find("Link"); + if (link_res != response.headers.end()) { + link_header_result = link_res->second; + } + return true; + }, + [&](const char *data, size_t data_length) { + if (state) { + state->total_bytes_received += data_length; + } + response << string(data, data_length); + return true; + }); + }); + + auto res = RunRequestWithRetry(request, next_page_url, "GET", http_params, nullptr); + + if (res->code != 200) { + throw IOException(res->error + " error for HTTP GET to '" + next_page_url + "'"); + } + + if (!link_header_result.empty()) { + next_page_url = ParseNextUrlFromLinkHeader(link_header_result); + } else { + next_page_url = ""; + } + + return response.str(); +} + +static bool Match(vector::const_iterator key, vector::const_iterator key_end, + vector::const_iterator pattern, vector::const_iterator pattern_end) { + + while (key != key_end && pattern != pattern_end) { + if (*pattern == "**") { + if (std::next(pattern) == pattern_end) { + return true; + } + while (key != key_end) { + if (Match(key, key_end, std::next(pattern), pattern_end)) { + return true; + } + key++; + } + return false; + } + if (!LikeFun::Glob(key->data(), key->length(), pattern->data(), pattern->length())) { + return false; + } + key++; + pattern++; + } + return key == key_end && pattern == pattern_end; +} + +void ParseListResult(string &input, vector &files, vector &directories) { + enum parse_entry { FILE, DIR, UNKNOWN }; + idx_t idx = 0; + idx_t nested = 0; + bool found_path; + parse_entry type; + string current_string; +base: + found_path = false; + type = parse_entry::UNKNOWN; + for (; idx < input.size(); idx++) { + if (input[idx] == '{') { + idx++; + goto entry; + } + } + goto end; +entry: + while (idx < input.size()) { + if (input[idx] == '}') { + if (nested) { + idx++; + nested--; + continue; + } else if (!found_path || type == parse_entry::UNKNOWN) { + throw IOException("Failed to parse list result"); + } else if (type == parse_entry::FILE) { + files.push_back("/" + 
current_string); + } else { + directories.push_back("/" + current_string); + } + current_string = ""; + idx++; + goto base; + } else if (input[idx] == '{') { + nested++; + idx++; + } else if (strncmp(input.c_str() + idx, "\"type\":\"directory\"", 18) == 0) { + type = parse_entry::DIR; + idx += 18; + } else if (strncmp(input.c_str() + idx, "\"type\":\"file\"", 13) == 0) { + type = parse_entry::FILE; + idx += 13; + } else if (strncmp(input.c_str() + idx, "\"path\":\"", 8) == 0) { + idx += 8; + found_path = true; + goto pathname; + } else { + idx++; + } + } + goto end; +pathname: + while (idx < input.size()) { + // Handle escaped quote in url + if (input[idx] == '\\' && idx + 1 < input.size() && input[idx] == '\"') { + current_string += '\"'; + idx += 2; + } else if (input[idx] == '\"') { + idx++; + goto entry; + } else { + current_string += input[idx]; + idx++; + } + } +end: + return; +} + +// Some valid example Urls: +// - hf://datasets/lhoestq/demo1/default/train/0000.parquet +// - hf://datasets/lhoestq/demo1/default/train/*.parquet +// - hf://datasets/lhoestq/demo1/*/train/file_[abc].parquet +// - hf://datasets/lhoestq/demo1/**/train/*.parquet +vector HuggingFaceFileSystem::Glob(const string &path, FileOpener *opener) { + // Ensure the glob pattern is a valid HF url + auto parsed_glob_url = HFUrlParse(path); + auto first_wildcard_pos = parsed_glob_url.path.find_first_of("*[\\"); + + if (first_wildcard_pos == string::npos) { + return {path}; + } + + string shared_path = parsed_glob_url.path.substr(0, first_wildcard_pos); + auto last_path_slash = shared_path.find_last_of('/', first_wildcard_pos); + + // trim the final + if (last_path_slash == string::npos) { + // Root path + shared_path = ""; + } else { + shared_path = shared_path.substr(0, last_path_slash); + } + + auto http_params = HTTPParams::ReadFrom(opener); + SetParams(http_params, path, opener); + auto http_state = HTTPState::TryGetState(opener).get(); + + ParsedHFUrl curr_hf_path = parsed_glob_url; + curr_hf_path.path = shared_path; + + vector files; + vector dirs = {shared_path}; + string next_page_url = ""; + + // Loop over the paths and paginated responses for each path + while (true) { + if (next_page_url.empty() && !dirs.empty()) { + // Done with previous dir, load the next one + curr_hf_path.path = dirs.back(); + dirs.pop_back(); + next_page_url = HuggingFaceFileSystem::GetTreeUrl(curr_hf_path, http_params.hf_max_per_page); + } else if (next_page_url.empty()) { + // No more pages to read, also no more dirs + break; + } + + auto response_str = ListHFRequest(curr_hf_path, http_params, next_page_url, http_state); + ParseListResult(response_str, files, dirs); + } + + vector pattern_splits = StringUtil::Split(parsed_glob_url.path, "/"); + vector result; + for (const auto &file : files) { + + vector file_splits = StringUtil::Split(file, "/"); + bool is_match = Match(file_splits.begin(), file_splits.end(), pattern_splits.begin(), pattern_splits.end()); + + if (is_match) { + curr_hf_path.path = file; + result.push_back(GetHFUrl(curr_hf_path)); + } + } + + // Prune files using match + return result; +} + +unique_ptr HuggingFaceFileSystem::HeadRequest(FileHandle &handle, string hf_url, + HeaderMap header_map) { + auto &hf_handle = handle.Cast(); + auto http_url = HuggingFaceFileSystem::GetFileUrl(hf_handle.parsed_url); + return HTTPFileSystem::HeadRequest(handle, http_url, header_map); +} + +unique_ptr HuggingFaceFileSystem::GetRequest(FileHandle &handle, string s3_url, HeaderMap header_map) { + auto &hf_handle = handle.Cast(); + auto 
http_url = HuggingFaceFileSystem::GetFileUrl(hf_handle.parsed_url); + return HTTPFileSystem::GetRequest(handle, http_url, header_map); +} + +unique_ptr HuggingFaceFileSystem::GetRangeRequest(FileHandle &handle, string s3_url, + HeaderMap header_map, idx_t file_offset, + char *buffer_out, idx_t buffer_out_len) { + auto &hf_handle = handle.Cast(); + auto http_url = HuggingFaceFileSystem::GetFileUrl(hf_handle.parsed_url); + return HTTPFileSystem::GetRangeRequest(handle, http_url, header_map, file_offset, buffer_out, buffer_out_len); +} + +unique_ptr HuggingFaceFileSystem::CreateHandle(const string &path, FileOpenFlags flags, + optional_ptr opener) { + D_ASSERT(flags.Compression() == FileCompressionType::UNCOMPRESSED); + + auto parsed_url = HFUrlParse(path); + + auto params = HTTPParams::ReadFrom(opener); + SetParams(params, path, opener); + + return duckdb::make_uniq(*this, std::move(parsed_url), path, flags, params); +} + +void HuggingFaceFileSystem::SetParams(HTTPParams ¶ms, const string &path, optional_ptr opener) { + auto secret_manager = FileOpener::TryGetSecretManager(opener); + auto transaction = FileOpener::TryGetCatalogTransaction(opener); + if (secret_manager && transaction) { + auto secret_match = secret_manager->LookupSecret(*transaction, path, "huggingface"); + + if (secret_match.HasMatch()) { + const auto &kv_secret = dynamic_cast(*secret_match.secret_entry->secret); + params.bearer_token = kv_secret.TryGetValue("token", true).ToString(); + } + } +} + +static void ThrowParseError(const string &url) { + throw IOException( + "Failed to parse '%s'. Please format url like: 'hf://datasets/my-username/my-dataset/path/to/file.parquet'", + url); +} + +ParsedHFUrl HuggingFaceFileSystem::HFUrlParse(const string &url) { + ParsedHFUrl result; + + if (!StringUtil::StartsWith(url, "hf://")) { + throw InternalException("Not an hf url"); + } + + size_t last_delim = 5; + size_t curr_delim; + + // Parse Repository type + curr_delim = url.find('/', last_delim); + if (curr_delim == string::npos) { + ThrowParseError(url); + } + result.repo_type = url.substr(last_delim, curr_delim - last_delim); + if (result.repo_type != "datasets" && result.repo_type != "spaces") { + throw IOException( + "Failed to parse: '%s'. Currently DuckDB only supports querying datasets or spaces, so the url should " + "start with 'hf://datasets' or 'hf://spaces'", + url); + } + + last_delim = curr_delim; + + // Parse repository and revision + auto repo_delim = url.find('/', last_delim + 1); + if (repo_delim == string::npos) { + ThrowParseError(url); + } + + auto next_at = url.find('@', repo_delim + 1); + auto next_slash = url.find('/', repo_delim + 1); + + if (next_slash == string::npos) { + ThrowParseError(url); + } + + if (next_at != string::npos && next_at < next_slash) { + result.repository = url.substr(last_delim + 1, next_at - last_delim - 1); + result.revision = url.substr(next_at + 1, next_slash - next_at - 1); + } else { + result.repository = url.substr(last_delim + 1, next_slash - last_delim - 1); + } + last_delim = next_slash; + + // The remainder is the path + result.path = url.substr(last_delim); + + return result; +} + +string HuggingFaceFileSystem::GetHFUrl(const ParsedHFUrl &url) { + if (url.revision == "main") { + return "hf://" + url.repo_type + "/" + url.repository + url.path; + } else { + return "hf://" + url.repo_type + "/" + url.repository + "@" + url.revision + url.path; + } +} + +string HuggingFaceFileSystem::GetTreeUrl(const ParsedHFUrl &url, idx_t limit) { + //! 
Url format {endpoint}/api/{repo_type}/{repository}/tree/{revision}{encoded_path_in_repo} + string http_url = url.endpoint; + + http_url = JoinPath(http_url, "api"); + http_url = JoinPath(http_url, url.repo_type); + http_url = JoinPath(http_url, url.repository); + http_url = JoinPath(http_url, "tree"); + http_url = JoinPath(http_url, url.revision); + http_url += url.path; + + if (limit > 0) { + http_url += "?limit=" + to_string(limit); + } + + return http_url; +} + +string HuggingFaceFileSystem::GetFileUrl(const ParsedHFUrl &url) { + //! Url format {endpoint}/{repo_type}[/{repository}/{revision}{encoded_path_in_repo} + string http_url = url.endpoint; + http_url = JoinPath(http_url, url.repo_type); + http_url = JoinPath(http_url, url.repository); + http_url = JoinPath(http_url, "resolve"); + http_url = JoinPath(http_url, url.revision); + http_url += url.path; + + return http_url; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/httpfs.cpp b/src/duckdb/extension/httpfs/httpfs.cpp new file mode 100644 index 00000000..659f93dc --- /dev/null +++ b/src/duckdb/extension/httpfs/httpfs.cpp @@ -0,0 +1,762 @@ +#include "httpfs.hpp" + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/exception/http_exception.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/thread.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#include "duckdb/logging/http_logger.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/main/secret/secret_manager.hpp" + +#include +#include +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +#include + +namespace duckdb { + +static duckdb::unique_ptr initialize_http_headers(HeaderMap &header_map) { + auto headers = make_uniq(); + for (auto &entry : header_map) { + headers->insert(entry); + } + return headers; +} + +HTTPParams HTTPParams::ReadFrom(optional_ptr opener) { + uint64_t timeout = DEFAULT_TIMEOUT; + uint64_t retries = DEFAULT_RETRIES; + uint64_t retry_wait_ms = DEFAULT_RETRY_WAIT_MS; + float retry_backoff = DEFAULT_RETRY_BACKOFF; + bool force_download = DEFAULT_FORCE_DOWNLOAD; + bool keep_alive = DEFAULT_KEEP_ALIVE; + bool enable_server_cert_verification = DEFAULT_ENABLE_SERVER_CERT_VERIFICATION; + std::string ca_cert_file; + uint64_t hf_max_per_page = DEFAULT_HF_MAX_PER_PAGE; + + Value value; + if (FileOpener::TryGetCurrentSetting(opener, "http_timeout", value)) { + timeout = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "force_download", value)) { + force_download = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "http_retries", value)) { + retries = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "http_retry_wait_ms", value)) { + retry_wait_ms = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "http_retry_backoff", value)) { + retry_backoff = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "http_keep_alive", value)) { + keep_alive = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "enable_server_cert_verification", value)) { + enable_server_cert_verification = value.GetValue(); + } + if (FileOpener::TryGetCurrentSetting(opener, "ca_cert_file", value)) { + ca_cert_file = value.ToString(); + } + if (FileOpener::TryGetCurrentSetting(opener, "hf_max_per_page", value)) { + hf_max_per_page = 
value.GetValue(); + } + + return {timeout, + retries, + retry_wait_ms, + retry_backoff, + force_download, + keep_alive, + enable_server_cert_verification, + ca_cert_file, + "", + hf_max_per_page}; +} + +void HTTPFileSystem::ParseUrl(string &url, string &path_out, string &proto_host_port_out) { + if (url.rfind("http://", 0) != 0 && url.rfind("https://", 0) != 0) { + throw IOException("URL needs to start with http:// or https://"); + } + auto slash_pos = url.find('/', 8); + if (slash_pos == string::npos) { + throw IOException("URL needs to contain a '/' after the host"); + } + proto_host_port_out = url.substr(0, slash_pos); + + path_out = url.substr(slash_pos); + + if (path_out.empty()) { + throw IOException("URL needs to contain a path"); + } +} + +// Retry the request performed by fun using the exponential backoff strategy defined in params. Before retry, the +// retry callback is called +duckdb::unique_ptr +HTTPFileSystem::RunRequestWithRetry(const std::function &request, string &url, + string method, const HTTPParams ¶ms, + const std::function &retry_cb) { + idx_t tries = 0; + while (true) { + std::exception_ptr caught_e = nullptr; + duckdb_httplib_openssl::Error err; + duckdb_httplib_openssl::Response response; + int status; + + try { + auto res = request(); + err = res.error(); + if (err == duckdb_httplib_openssl::Error::Success) { + status = res->status; + response = res.value(); + } + } catch (IOException &e) { + caught_e = std::current_exception(); + } + + // Note: all duckdb_httplib_openssl::Error types will be retried. + if (err == duckdb_httplib_openssl::Error::Success) { + switch (status) { + case 408: // Request Timeout + case 418: // Server is pretending to be a teapot + case 429: // Rate limiter hit + case 500: // Server has error + case 503: // Server has error + case 504: // Server has error + break; + default: + return make_uniq(response, url); + } + } + + tries += 1; + + if (tries <= params.retries) { + if (tries > 1) { + uint64_t sleep_amount = (uint64_t)((float)params.retry_wait_ms * pow(params.retry_backoff, tries - 2)); + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_amount)); + } + if (retry_cb) { + retry_cb(); + } + } else { + if (caught_e) { + std::rethrow_exception(caught_e); + } else if (err == duckdb_httplib_openssl::Error::Success) { + throw HTTPException(response, "Request returned HTTP %d for HTTP %s to '%s'", status, method, url); + } else { + throw IOException("%s error for HTTP %s to '%s'", to_string(err), method, url); + } + } + } +} + +unique_ptr HTTPFileSystem::PostRequest(FileHandle &handle, string url, HeaderMap header_map, + duckdb::unique_ptr &buffer_out, idx_t &buffer_out_len, + char *buffer_in, idx_t buffer_in_len, string params) { + auto &hfs = handle.Cast(); + string path, proto_host_port; + ParseUrl(url, path, proto_host_port); + auto headers = initialize_http_headers(header_map); + idx_t out_offset = 0; + + std::function request([&]() { + auto client = GetClient(hfs.http_params, proto_host_port.c_str(), &hfs); + + if (hfs.state) { + hfs.state->post_count++; + hfs.state->total_bytes_sent += buffer_in_len; + } + + // We use a custom Request method here, because there is no Post call with a contentreceiver in httplib + duckdb_httplib_openssl::Request req; + req.method = "POST"; + req.path = path; + req.headers = *headers; + req.headers.emplace("Content-Type", "application/octet-stream"); + req.content_receiver = [&](const char *data, size_t data_length, uint64_t /*offset*/, + uint64_t /*total_length*/) { + if (hfs.state) { + 
hfs.state->total_bytes_received += data_length; + } + if (out_offset + data_length > buffer_out_len) { + // Buffer too small, increase its size by at least 2x to fit the new value + auto new_size = MaxValue(out_offset + data_length, buffer_out_len * 2); + auto tmp = duckdb::unique_ptr {new char[new_size]}; + memcpy(tmp.get(), buffer_out.get(), buffer_out_len); + buffer_out = std::move(tmp); + buffer_out_len = new_size; + } + memcpy(buffer_out.get() + out_offset, data, data_length); + out_offset += data_length; + return true; + }; + req.body.assign(buffer_in, buffer_in_len); + return client->send(req); + }); + + return RunRequestWithRetry(request, url, "POST", hfs.http_params); +} + +unique_ptr HTTPFileSystem::GetClient(const HTTPParams &http_params, + const char *proto_host_port, + optional_ptr hfs) { + auto client = make_uniq(proto_host_port); + client->set_follow_location(true); + client->set_keep_alive(http_params.keep_alive); + if (!http_params.ca_cert_file.empty()) { + client->set_ca_cert_path(http_params.ca_cert_file.c_str()); + } + client->enable_server_certificate_verification(http_params.enable_server_cert_verification); + client->set_write_timeout(http_params.timeout); + client->set_read_timeout(http_params.timeout); + client->set_connection_timeout(http_params.timeout); + client->set_decompress(false); + if (hfs && hfs->http_logger) { + client->set_logger( + hfs->http_logger->GetLogger()); + } + if (!http_params.bearer_token.empty()) { + client->set_bearer_token_auth(http_params.bearer_token.c_str()); + } + return client; +} + +unique_ptr HTTPFileSystem::PutRequest(FileHandle &handle, string url, HeaderMap header_map, + char *buffer_in, idx_t buffer_in_len, string params) { + auto &hfs = handle.Cast(); + string path, proto_host_port; + ParseUrl(url, path, proto_host_port); + auto headers = initialize_http_headers(header_map); + + std::function request([&]() { + auto client = GetClient(hfs.http_params, proto_host_port.c_str(), &hfs); + if (hfs.state) { + hfs.state->put_count++; + hfs.state->total_bytes_sent += buffer_in_len; + } + return client->Put(path.c_str(), *headers, buffer_in, buffer_in_len, "application/octet-stream"); + }); + + return RunRequestWithRetry(request, url, "PUT", hfs.http_params); +} + +unique_ptr HTTPFileSystem::HeadRequest(FileHandle &handle, string url, HeaderMap header_map) { + auto &hfs = handle.Cast(); + string path, proto_host_port; + ParseUrl(url, path, proto_host_port); + auto headers = initialize_http_headers(header_map); + + std::function request([&]() { + if (hfs.state) { + hfs.state->head_count++; + } + return hfs.http_client->Head(path.c_str(), *headers); + }); + + std::function on_retry( + [&]() { hfs.http_client = GetClient(hfs.http_params, proto_host_port.c_str(), &hfs); }); + + return RunRequestWithRetry(request, url, "HEAD", hfs.http_params, on_retry); +} + +unique_ptr HTTPFileSystem::GetRequest(FileHandle &handle, string url, HeaderMap header_map) { + auto &hfh = handle.Cast(); + string path, proto_host_port; + ParseUrl(url, path, proto_host_port); + auto headers = initialize_http_headers(header_map); + + D_ASSERT(hfh.cached_file_handle); + + std::function request([&]() { + D_ASSERT(hfh.state); + hfh.state->get_count++; + return hfh.http_client->Get( + path.c_str(), *headers, + [&](const duckdb_httplib_openssl::Response &response) { + if (response.status >= 400) { + string error = "HTTP GET error on '" + url + "' (HTTP " + to_string(response.status) + ")"; + if (response.status == 416) { + error += " This could mean the file was changed. 
Try disabling the duckdb http metadata cache " + "if enabled, and confirm the server supports range requests."; + } + throw IOException(error); + } + return true; + }, + [&](const char *data, size_t data_length) { + D_ASSERT(hfh.state); + if (hfh.state) { + hfh.state->total_bytes_received += data_length; + } + if (!hfh.cached_file_handle->GetCapacity()) { + hfh.cached_file_handle->AllocateBuffer(data_length); + hfh.length = data_length; + hfh.cached_file_handle->Write(data, data_length); + } else { + auto new_capacity = hfh.cached_file_handle->GetCapacity(); + while (new_capacity < hfh.length + data_length) { + new_capacity *= 2; + } + // Grow buffer when running out of space + if (new_capacity != hfh.cached_file_handle->GetCapacity()) { + hfh.cached_file_handle->GrowBuffer(new_capacity, hfh.length); + } + // We can just copy stuff + hfh.cached_file_handle->Write(data, data_length, hfh.length); + hfh.length += data_length; + } + return true; + }); + }); + + std::function on_retry( + [&]() { hfh.http_client = GetClient(hfh.http_params, proto_host_port.c_str(), &hfh); }); + + return RunRequestWithRetry(request, url, "GET", hfh.http_params, on_retry); +} + +unique_ptr HTTPFileSystem::GetRangeRequest(FileHandle &handle, string url, HeaderMap header_map, + idx_t file_offset, char *buffer_out, idx_t buffer_out_len) { + auto &hfs = handle.Cast(); + string path, proto_host_port; + ParseUrl(url, path, proto_host_port); + auto headers = initialize_http_headers(header_map); + + // send the Range header to read only subset of file + string range_expr = "bytes=" + to_string(file_offset) + "-" + to_string(file_offset + buffer_out_len - 1); + headers->insert(pair("Range", range_expr)); + + idx_t out_offset = 0; + + std::function request([&]() { + if (hfs.state) { + hfs.state->get_count++; + } + return hfs.http_client->Get( + path.c_str(), *headers, + [&](const duckdb_httplib_openssl::Response &response) { + if (response.status >= 400) { + string error = "HTTP GET error on '" + url + "' (HTTP " + to_string(response.status) + ")"; + if (response.status == 416) { + error += " This could mean the file was changed. Try disabling the duckdb http metadata cache " + "if enabled, and confirm the server supports range requests."; + } + throw HTTPException(response, error); + } + if (response.status < 300) { // done redirecting + out_offset = 0; + if (response.has_header("Content-Length")) { + auto content_length = stoll(response.get_header_value("Content-Length", 0)); + if ((idx_t)content_length != buffer_out_len) { + throw IOException("HTTP GET error: Content-Length from server mismatches requested " + "range, server may not support range requests."); + } + } + } + return true; + }, + [&](const char *data, size_t data_length) { + if (hfs.state) { + hfs.state->total_bytes_received += data_length; + } + if (buffer_out != nullptr) { + if (data_length + out_offset > buffer_out_len) { + // As of v0.8.2-dev4424 we might end up here when very big files are served from servers + // that returns more data than requested via range header. This is an uncommon but legal + // behaviour, so we have to improve logic elsewhere to properly handle this case. + + // To avoid corruption of memory, we bail out. 
+ throw IOException("Server sent back more data than expected, `SET force_download=true` might " + "help in this case"); + } + memcpy(buffer_out + out_offset, data, data_length); + out_offset += data_length; + } + return true; + }); + }); + + std::function on_retry( + [&]() { hfs.http_client = GetClient(hfs.http_params, proto_host_port.c_str(), &hfs); }); + + return RunRequestWithRetry(request, url, "GET Range", hfs.http_params, on_retry); +} + +HTTPFileHandle::HTTPFileHandle(FileSystem &fs, const string &path, FileOpenFlags flags, const HTTPParams &http_params) + : FileHandle(fs, path), http_params(http_params), flags(flags), length(0), buffer_available(0), buffer_idx(0), + file_offset(0), buffer_start(0), buffer_end(0) { +} + +unique_ptr HTTPFileSystem::CreateHandle(const string &path, FileOpenFlags flags, + optional_ptr opener) { + D_ASSERT(flags.Compression() == FileCompressionType::UNCOMPRESSED); + + auto params = HTTPParams::ReadFrom(opener); + + auto secret_manager = FileOpener::TryGetSecretManager(opener); + auto transaction = FileOpener::TryGetCatalogTransaction(opener); + if (secret_manager && transaction) { + auto secret_match = secret_manager->LookupSecret(*transaction, path, "bearer"); + + if (secret_match.HasMatch()) { + const auto &kv_secret = dynamic_cast(*secret_match.secret_entry->secret); + params.bearer_token = kv_secret.TryGetValue("token", true).ToString(); + } + } + + return duckdb::make_uniq(*this, path, flags, params); +} + +unique_ptr HTTPFileSystem::OpenFile(const string &path, FileOpenFlags flags, + optional_ptr opener) { + D_ASSERT(flags.Compression() == FileCompressionType::UNCOMPRESSED); + + if (flags.ReturnNullIfNotExists()) { + try { + auto handle = CreateHandle(path, flags, opener); + handle->Initialize(opener); + return std::move(handle); + } catch (...) { + return nullptr; + } + } + + auto handle = CreateHandle(path, flags, opener); + handle->Initialize(opener); + return std::move(handle); +} + +// Buffered read from http file. 
+// Note that buffering is disabled when FileFlags::FILE_FLAGS_DIRECT_IO is set +void HTTPFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + auto &hfh = handle.Cast(); + + D_ASSERT(hfh.state); + if (hfh.cached_file_handle) { + if (!hfh.cached_file_handle->Initialized()) { + throw InternalException("Cached file not initialized properly"); + } + memcpy(buffer, hfh.cached_file_handle->GetData() + location, nr_bytes); + hfh.file_offset = location + nr_bytes; + return; + } + + idx_t to_read = nr_bytes; + idx_t buffer_offset = 0; + + // Don't buffer when DirectIO is set or when we are doing parallel reads + bool skip_buffer = hfh.flags.DirectIO() || hfh.flags.RequireParallelAccess(); + if (skip_buffer && to_read > 0) { + GetRangeRequest(hfh, hfh.path, {}, location, (char *)buffer, to_read); + hfh.buffer_available = 0; + hfh.buffer_idx = 0; + hfh.file_offset = location + nr_bytes; + return; + } + + if (location >= hfh.buffer_start && location < hfh.buffer_end) { + hfh.file_offset = location; + hfh.buffer_idx = location - hfh.buffer_start; + hfh.buffer_available = (hfh.buffer_end - hfh.buffer_start) - hfh.buffer_idx; + } else { + // reset buffer + hfh.buffer_available = 0; + hfh.buffer_idx = 0; + hfh.file_offset = location; + } + while (to_read > 0) { + auto buffer_read_len = MinValue(hfh.buffer_available, to_read); + if (buffer_read_len > 0) { + D_ASSERT(hfh.buffer_start + hfh.buffer_idx + buffer_read_len <= hfh.buffer_end); + memcpy((char *)buffer + buffer_offset, hfh.read_buffer.get() + hfh.buffer_idx, buffer_read_len); + + buffer_offset += buffer_read_len; + to_read -= buffer_read_len; + + hfh.buffer_idx += buffer_read_len; + hfh.buffer_available -= buffer_read_len; + hfh.file_offset += buffer_read_len; + } + + if (to_read > 0 && hfh.buffer_available == 0) { + auto new_buffer_available = MinValue(hfh.READ_BUFFER_LEN, hfh.length - hfh.file_offset); + + // Bypass buffer if we read more than buffer size + if (to_read > new_buffer_available) { + GetRangeRequest(hfh, hfh.path, {}, location + buffer_offset, (char *)buffer + buffer_offset, to_read); + hfh.buffer_available = 0; + hfh.buffer_idx = 0; + hfh.file_offset += to_read; + break; + } else { + GetRangeRequest(hfh, hfh.path, {}, hfh.file_offset, (char *)hfh.read_buffer.get(), + new_buffer_available); + hfh.buffer_available = new_buffer_available; + hfh.buffer_idx = 0; + hfh.buffer_start = hfh.file_offset; + hfh.buffer_end = hfh.buffer_start + new_buffer_available; + } + } + } +} + +int64_t HTTPFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &hfh = (HTTPFileHandle &)handle; + idx_t max_read = hfh.length - hfh.file_offset; + nr_bytes = MinValue(max_read, nr_bytes); + Read(handle, buffer, nr_bytes, hfh.file_offset); + return nr_bytes; +} + +void HTTPFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + throw NotImplementedException("Writing to HTTP files not implemented"); +} + +int64_t HTTPFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &hfh = (HTTPFileHandle &)handle; + Write(handle, buffer, nr_bytes, hfh.file_offset); + return nr_bytes; +} + +void HTTPFileSystem::FileSync(FileHandle &handle) { + throw NotImplementedException("FileSync for HTTP files not implemented"); +} + +int64_t HTTPFileSystem::GetFileSize(FileHandle &handle) { + auto &sfh = handle.Cast(); + return sfh.length; +} + +time_t HTTPFileSystem::GetLastModifiedTime(FileHandle &handle) { + auto &sfh = handle.Cast(); + return sfh.last_modified; +} + 
+bool HTTPFileSystem::FileExists(const string &filename, optional_ptr opener) { + try { + auto handle = OpenFile(filename, FileFlags::FILE_FLAGS_READ, opener); + auto &sfh = handle->Cast(); + if (sfh.length == 0) { + return false; + } + return true; + } catch (...) { + return false; + }; +} + +bool HTTPFileSystem::CanHandleFile(const string &fpath) { + return fpath.rfind("https://", 0) == 0 || fpath.rfind("http://", 0) == 0; +} + +void HTTPFileSystem::Seek(FileHandle &handle, idx_t location) { + auto &sfh = handle.Cast(); + sfh.file_offset = location; +} + +idx_t HTTPFileSystem::SeekPosition(FileHandle &handle) { + auto &sfh = handle.Cast(); + return sfh.file_offset; +} + +optional_ptr HTTPFileSystem::GetGlobalCache() { + lock_guard lock(global_cache_lock); + if (!global_metadata_cache) { + global_metadata_cache = make_uniq(false, true); + } + return global_metadata_cache.get(); +} + +// Get either the local, global, or no cache depending on settings +static optional_ptr TryGetMetadataCache(optional_ptr opener, HTTPFileSystem &httpfs) { + auto db = FileOpener::TryGetDatabase(opener); + auto client_context = FileOpener::TryGetClientContext(opener); + if (!db) { + return nullptr; + } + + bool use_shared_cache = db->config.options.http_metadata_cache_enable; + if (use_shared_cache) { + return httpfs.GetGlobalCache(); + } else if (client_context) { + auto lookup = client_context->registered_state.find("http_cache"); + if (lookup == client_context->registered_state.end()) { + auto cache = make_shared_ptr(true, true); + client_context->registered_state["http_cache"] = cache; + return cache.get(); + } else { + return (HTTPMetadataCache *)lookup->second.get(); + } + } + return nullptr; +} + +void HTTPFileHandle::Initialize(optional_ptr opener) { + InitializeClient(FileOpener::TryGetClientContext(opener)); + auto &hfs = file_system.Cast(); + state = HTTPState::TryGetState(opener); + if (!state) { + state = make_shared_ptr(); + } + + auto current_cache = TryGetMetadataCache(opener, hfs); + + bool should_write_cache = false; + if (!http_params.force_download && current_cache && !flags.OpenForWriting()) { + + HTTPMetadataCacheEntry value; + bool found = current_cache->Find(path, value); + + if (found) { + last_modified = value.last_modified; + length = value.length; + + if (flags.OpenForReading()) { + read_buffer = duckdb::unique_ptr(new data_t[READ_BUFFER_LEN]); + } + return; + } + + should_write_cache = true; + } + + // If we're writing to a file, we might as well remove it from the cache + if (current_cache && flags.OpenForWriting()) { + current_cache->Erase(path); + } + + auto res = hfs.HeadRequest(*this, path, {}); + string range_length; + + if (res->code != 200) { + if (flags.OpenForWriting() && res->code == 404) { + if (!flags.CreateFileIfNotExists() && !flags.OverwriteExistingFile()) { + throw IOException("Unable to open URL \"" + path + + "\" for writing: file does not exist and CREATE flag is not set"); + } + length = 0; + return; + } else { + // HEAD request fail, use Range request for another try (read only one byte) + if (flags.OpenForReading() && res->code != 404) { + auto range_res = hfs.GetRangeRequest(*this, path, {}, 0, nullptr, 2); + if (range_res->code != 206) { + throw IOException("Unable to connect to URL \"%s\": %d (%s)", path, res->code, res->error); + } + auto range_find = range_res->headers["Content-Range"].find("/"); + + if (range_find == std::string::npos || range_res->headers["Content-Range"].size() < range_find + 1) { + throw IOException("Unknown Content-Range Header 
\"The value of Content-Range Header\": (%s)", + range_res->headers["Content-Range"]); + } + + range_length = range_res->headers["Content-Range"].substr(range_find + 1); + if (range_length == "*") { + throw IOException("Unknown total length of the document \"%s\": %d (%s)", path, res->code, + res->error); + } + res = std::move(range_res); + } else { + throw HTTPException(*res, "Unable to connect to URL \"%s\": %s (%s)", res->http_url, + to_string(res->code), res->error); + } + } + } + + // Initialize the read buffer now that we know the file exists + if (flags.OpenForReading()) { + read_buffer = duckdb::unique_ptr(new data_t[READ_BUFFER_LEN]); + } + + if (res->headers.find("Content-Length") == res->headers.end() || res->headers["Content-Length"].empty()) { + // There was no content-length header, we can not do range requests here, so we set the length to 0 + length = 0; + } else { + try { + if (res->headers.find("Content-Range") == res->headers.end() || res->headers["Content-Range"].empty()) { + length = std::stoll(res->headers["Content-Length"]); + } else { + length = std::stoll(range_length); + } + } catch (std::invalid_argument &e) { + throw IOException("Invalid Content-Length header received: %s", res->headers["Content-Length"]); + } catch (std::out_of_range &e) { + throw IOException("Invalid Content-Length header received: %s", res->headers["Content-Length"]); + } + } + if (state && (length == 0 || http_params.force_download)) { + auto &cache_entry = state->GetCachedFile(path); + cached_file_handle = cache_entry->GetHandle(); + if (!cached_file_handle->Initialized()) { + // Try to fully download the file first + auto full_download_result = hfs.GetRequest(*this, path, {}); + if (full_download_result->code != 200) { + throw HTTPException(*res, "Full download failed to to URL \"%s\": %s (%s)", + full_download_result->http_url, to_string(full_download_result->code), + full_download_result->error); + } + + // Mark the file as initialized, set its final length, and unlock it to allowing parallel reads + cached_file_handle->SetInitialized(length); + + // We shouldn't write these to cache + should_write_cache = false; + } else { + length = cached_file_handle->GetSize(); + } + } + + if (!res->headers["Last-Modified"].empty()) { + auto result = StrpTimeFormat::Parse("%a, %d %h %Y %T %Z", res->headers["Last-Modified"]); + + struct tm tm {}; + tm.tm_year = result.data[0] - 1900; + tm.tm_mon = result.data[1] - 1; + tm.tm_mday = result.data[2]; + tm.tm_hour = result.data[3]; + tm.tm_min = result.data[4]; + tm.tm_sec = result.data[5]; + tm.tm_isdst = 0; + last_modified = mktime(&tm); + } + + if (should_write_cache) { + current_cache->Insert(path, {length, last_modified}); + } +} + +void HTTPFileHandle::InitializeClient(optional_ptr context) { + string path_out, proto_host_port; + HTTPFileSystem::ParseUrl(path, path_out, proto_host_port); + http_client = HTTPFileSystem::GetClient(this->http_params, proto_host_port.c_str(), this); + if (context && ClientConfig::GetConfig(*context).enable_http_logging) { + http_logger = context->client_data->http_logger.get(); + http_client->set_logger( + http_logger->GetLogger()); + } +} + +ResponseWrapper::ResponseWrapper(duckdb_httplib_openssl::Response &res, string &original_url) { + code = res.status; + error = res.reason; + for (auto &h : res.headers) { + headers[h.first] = h.second; + } + http_url = original_url; + body = res.body; +} + +HTTPFileHandle::~HTTPFileHandle() = default; +} // namespace duckdb diff --git 
a/src/duckdb/extension/httpfs/httpfs_extension.cpp b/src/duckdb/extension/httpfs/httpfs_extension.cpp new file mode 100644 index 00000000..8f1966c9 --- /dev/null +++ b/src/duckdb/extension/httpfs/httpfs_extension.cpp @@ -0,0 +1,101 @@ +#define DUCKDB_EXTENSION_MAIN + +#include "httpfs_extension.hpp" + +#include "create_secret_functions.hpp" +#include "duckdb.hpp" +#include "s3fs.hpp" +#include "hffs.hpp" + +namespace duckdb { + +static void LoadInternal(DatabaseInstance &instance) { + S3FileSystem::Verify(); // run some tests to see if all the hashes work out + auto &fs = instance.GetFileSystem(); + + fs.RegisterSubSystem(make_uniq()); + fs.RegisterSubSystem(make_uniq()); + fs.RegisterSubSystem(make_uniq(BufferManager::GetBufferManager(instance))); + + auto &config = DBConfig::GetConfig(instance); + + // Global HTTP config + // Single timeout value is used for all 4 types of timeouts, we could split it into 4 if users need that + config.AddExtensionOption("http_timeout", "HTTP timeout read/write/connection/retry", LogicalType::UBIGINT, + Value(30000)); + config.AddExtensionOption("http_retries", "HTTP retries on I/O error", LogicalType::UBIGINT, Value(3)); + config.AddExtensionOption("http_retry_wait_ms", "Time between retries", LogicalType::UBIGINT, Value(100)); + config.AddExtensionOption("force_download", "Forces upfront download of file", LogicalType::BOOLEAN, Value(false)); + // Reduces the number of requests made while waiting, for example retry_wait_ms of 50 and backoff factor of 2 will + // result in wait times of 0 50 100 200 400...etc. + config.AddExtensionOption("http_retry_backoff", "Backoff factor for exponentially increasing retry wait time", + LogicalType::FLOAT, Value(4)); + config.AddExtensionOption( + "http_keep_alive", + "Keep alive connections. 
Setting this to false can help when running into connection failures", + LogicalType::BOOLEAN, Value(true)); + config.AddExtensionOption("enable_server_cert_verification", "Enable server side certificate verification.", + LogicalType::BOOLEAN, Value(false)); + config.AddExtensionOption("ca_cert_file", "Path to a custom certificate file for self-signed certificates.", + LogicalType::VARCHAR, Value("")); + // Global S3 config + config.AddExtensionOption("s3_region", "S3 Region", LogicalType::VARCHAR, Value("us-east-1")); + config.AddExtensionOption("s3_access_key_id", "S3 Access Key ID", LogicalType::VARCHAR); + config.AddExtensionOption("s3_secret_access_key", "S3 Access Key", LogicalType::VARCHAR); + config.AddExtensionOption("s3_session_token", "S3 Session Token", LogicalType::VARCHAR); + config.AddExtensionOption("s3_endpoint", "S3 Endpoint", LogicalType::VARCHAR); + config.AddExtensionOption("s3_url_style", "S3 URL style", LogicalType::VARCHAR, Value("vhost")); + config.AddExtensionOption("s3_use_ssl", "S3 use SSL", LogicalType::BOOLEAN, Value(true)); + config.AddExtensionOption("s3_url_compatibility_mode", "Disable Globs and Query Parameters on S3 URLs", + LogicalType::BOOLEAN, Value(false)); + + // S3 Uploader config + config.AddExtensionOption("s3_uploader_max_filesize", "S3 Uploader max filesize (between 50GB and 5TB)", + LogicalType::VARCHAR, "800GB"); + config.AddExtensionOption("s3_uploader_max_parts_per_file", "S3 Uploader max parts per file (between 1 and 10000)", + LogicalType::UBIGINT, Value(10000)); + config.AddExtensionOption("s3_uploader_thread_limit", "S3 Uploader global thread limit", LogicalType::UBIGINT, + Value(50)); + + // HuggingFace options + config.AddExtensionOption("hf_max_per_page", "Debug option to limit number of items returned in list requests", + LogicalType::UBIGINT, Value::UBIGINT(0)); + + auto provider = make_uniq(config); + provider->SetAll(); + + CreateS3SecretFunctions::Register(instance); + CreateBearerTokenFunctions::Register(instance); +} + +void HttpfsExtension::Load(DuckDB &db) { + LoadInternal(*db.instance); +} +std::string HttpfsExtension::Name() { + return "httpfs"; +} + +std::string HttpfsExtension::Version() const { +#ifdef EXT_VERSION_HTTPFS + return EXT_VERSION_HTTPFS; +#else + return ""; +#endif +} + +} // namespace duckdb + +extern "C" { + +DUCKDB_EXTENSION_API void httpfs_init(duckdb::DatabaseInstance &db) { + LoadInternal(db); +} + +DUCKDB_EXTENSION_API const char *httpfs_version() { + return duckdb::DuckDB::LibraryVersion(); +} +} + +#ifndef DUCKDB_EXTENSION_MAIN +#error DUCKDB_EXTENSION_MAIN not defined +#endif diff --git a/src/duckdb/extension/httpfs/include/create_secret_functions.hpp b/src/duckdb/extension/httpfs/include/create_secret_functions.hpp new file mode 100644 index 00000000..91d4d8d0 --- /dev/null +++ b/src/duckdb/extension/httpfs/include/create_secret_functions.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "duckdb.hpp" + +namespace duckdb { +struct CreateSecretInput; +struct S3AuthParams; +class CreateSecretFunction; +class BaseSecret; + +struct CreateS3SecretFunctions { +public: + //! Register all CreateSecretFunctions + static void Register(DatabaseInstance &instance); + +protected: + //! Internal function to create BaseSecret from S3AuthParams + static unique_ptr CreateSecretFunctionInternal(ClientContext &context, CreateSecretInput &input, + S3AuthParams params); + + //! 
Function for the "settings" provider: creates secret from current duckdb settings + static unique_ptr CreateS3SecretFromSettings(ClientContext &context, CreateSecretInput &input); + //! Function for the "config" provider: creates secret from parameters passed by user + static unique_ptr CreateS3SecretFromConfig(ClientContext &context, CreateSecretInput &input); + + //! Helper function to set named params of secret function + static void SetBaseNamedParams(CreateSecretFunction &function, string &type); + //! Helper function to create secret types s3/r2/gcs + static void RegisterCreateSecretFunction(DatabaseInstance &instance, string type); +}; + +struct CreateBearerTokenFunctions { +public: + static constexpr const char *GENERIC_BEARER_TYPE = "bearer"; + static constexpr const char *HUGGINGFACE_TYPE = "huggingface"; + + //! Register all CreateSecretFunctions + static void Register(DatabaseInstance &instance); + +protected: + //! Internal function to create bearer token + static unique_ptr CreateSecretFunctionInternal(ClientContext &context, CreateSecretInput &input, + const string &token); + //! Function for the "config" provider: creates secret from parameters passed by user + static unique_ptr CreateBearerSecretFromConfig(ClientContext &context, CreateSecretInput &input); + //! Function for the "config" provider: creates secret from parameters passed by user + static unique_ptr CreateHuggingFaceSecretFromCredentialChain(ClientContext &context, + CreateSecretInput &input); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/include/crypto.hpp b/src/duckdb/extension/httpfs/include/crypto.hpp new file mode 100644 index 00000000..7fc755f2 --- /dev/null +++ b/src/duckdb/extension/httpfs/include/crypto.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace duckdb { + +typedef unsigned char hash_bytes[32]; +typedef unsigned char hash_str[64]; + +void sha256(const char *in, size_t in_len, hash_bytes &out); + +void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out); + +void hmac256(std::string message, hash_bytes secret, hash_bytes &out); + +void hex256(hash_bytes &in, hash_str &out); + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/include/hffs.hpp b/src/duckdb/extension/httpfs/include/hffs.hpp new file mode 100644 index 00000000..e577382e --- /dev/null +++ b/src/duckdb/extension/httpfs/include/hffs.hpp @@ -0,0 +1,71 @@ +#pragma once + +#include "httpfs.hpp" + +namespace duckdb { + +struct ParsedHFUrl { + //! Path within the + string path; + //! Name of the repo (i presume) + string repository; + + //! Endpoint, defaults to HF + string endpoint = "https://huggingface.co"; + //! Which revision/branch/tag to use + string revision = "main"; + //! For DuckDB this may be a sensible default? 
+ string repo_type = "datasets"; +}; + +class HuggingFaceFileSystem : public HTTPFileSystem { +public: + ~HuggingFaceFileSystem() override; + + vector Glob(const string &path, FileOpener *opener = nullptr) override; + + duckdb::unique_ptr HeadRequest(FileHandle &handle, string hf_url, HeaderMap header_map) override; + duckdb::unique_ptr GetRequest(FileHandle &handle, string hf_url, HeaderMap header_map) override; + duckdb::unique_ptr GetRangeRequest(FileHandle &handle, string hf_url, HeaderMap header_map, + idx_t file_offset, char *buffer_out, + idx_t buffer_out_len) override; + + bool CanHandleFile(const string &fpath) override { + return fpath.rfind("hf://", 0) == 0; + }; + + string GetName() const override { + return "HuggingFaceFileSystem"; + } + static ParsedHFUrl HFUrlParse(const string &url); + string GetHFUrl(const ParsedHFUrl &url); + string GetTreeUrl(const ParsedHFUrl &url, idx_t limit); + string GetFileUrl(const ParsedHFUrl &url); + + static void SetParams(HTTPParams ¶ms, const string &path, optional_ptr opener); + +protected: + duckdb::unique_ptr CreateHandle(const string &path, FileOpenFlags flags, + optional_ptr opener) override; + + string ListHFRequest(ParsedHFUrl &url, HTTPParams &http_params, string &next_page_url, + optional_ptr state); +}; + +class HFFileHandle : public HTTPFileHandle { + friend class HuggingFaceFileSystem; + +public: + HFFileHandle(FileSystem &fs, ParsedHFUrl hf_url, string http_url, FileOpenFlags flags, + const HTTPParams &http_params) + : HTTPFileHandle(fs, std::move(http_url), flags, http_params), parsed_url(std::move(hf_url)) { + } + ~HFFileHandle() override; + + void InitializeClient(optional_ptr client_context) override; + +protected: + ParsedHFUrl parsed_url; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/include/http_metadata_cache.hpp b/src/duckdb/extension/httpfs/include/http_metadata_cache.hpp new file mode 100644 index 00000000..73d032b0 --- /dev/null +++ b/src/duckdb/extension/httpfs/include/http_metadata_cache.hpp @@ -0,0 +1,91 @@ +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/list.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_context_state.hpp" + +#include +#include + +namespace duckdb { + +struct HTTPMetadataCacheEntry { + idx_t length; + time_t last_modified; +}; + +// Simple cache with a max age for an entry to be valid +class HTTPMetadataCache : public ClientContextState { +public: + explicit HTTPMetadataCache(bool flush_on_query_end_p, bool shared_p) + : flush_on_query_end(flush_on_query_end_p), shared(shared_p) {}; + + void Insert(const string &path, HTTPMetadataCacheEntry val) { + if (shared) { + lock_guard parallel_lock(lock); + map[path] = val; + } else { + map[path] = val; + } + }; + + void Erase(string path) { + if (shared) { + lock_guard parallel_lock(lock); + map.erase(path); + } else { + map.erase(path); + } + }; + + bool Find(string path, HTTPMetadataCacheEntry &ret_val) { + if (shared) { + lock_guard parallel_lock(lock); + auto lookup = map.find(path); + if (lookup != map.end()) { + ret_val = lookup->second; + return true; + } else { + return false; + } + } else { + auto lookup = map.find(path); + if (lookup != map.end()) { + ret_val = lookup->second; + return true; + } else { + return false; + } + } + }; + + void Clear() { + if (shared) { + 
lock_guard parallel_lock(lock); + map.clear(); + } else { + map.clear(); + } + } + + //! Called by the ClientContext when the current query ends + void QueryEnd(ClientContext &context) override { + if (flush_on_query_end) { + Clear(); + } + } + +protected: + mutex lock; + unordered_map map; + bool flush_on_query_end; + bool shared; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/include/httpfs.hpp b/src/duckdb/extension/httpfs/include/httpfs.hpp new file mode 100644 index 00000000..1c49889b --- /dev/null +++ b/src/duckdb/extension/httpfs/include/httpfs.hpp @@ -0,0 +1,178 @@ +#pragma once + +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/client_data.hpp" +#include "http_metadata_cache.hpp" + +namespace duckdb_httplib_openssl { +struct Response; +struct Result; +class Client; +} // namespace duckdb_httplib_openssl + +namespace duckdb { + +class HTTPLogger; + +using HeaderMap = case_insensitive_map_t; + +// avoid including httplib in header +struct ResponseWrapper { +public: + explicit ResponseWrapper(duckdb_httplib_openssl::Response &res, string &original_url); + int code; + string error; + HeaderMap headers; + string http_url; + string body; +}; + +struct HTTPParams { + + static constexpr uint64_t DEFAULT_TIMEOUT = 30000; // 30 sec + static constexpr uint64_t DEFAULT_RETRIES = 3; + static constexpr uint64_t DEFAULT_RETRY_WAIT_MS = 100; + static constexpr float DEFAULT_RETRY_BACKOFF = 4; + static constexpr bool DEFAULT_FORCE_DOWNLOAD = false; + static constexpr bool DEFAULT_KEEP_ALIVE = true; + static constexpr bool DEFAULT_ENABLE_SERVER_CERT_VERIFICATION = false; + static constexpr uint64_t DEFAULT_HF_MAX_PER_PAGE = 0; + + uint64_t timeout; + uint64_t retries; + uint64_t retry_wait_ms; + float retry_backoff; + bool force_download; + bool keep_alive; + bool enable_server_cert_verification; + std::string ca_cert_file; + + string bearer_token; + + idx_t hf_max_per_page; + + static HTTPParams ReadFrom(optional_ptr opener); +}; + +class HTTPFileHandle : public FileHandle { +public: + HTTPFileHandle(FileSystem &fs, const string &path, FileOpenFlags flags, const HTTPParams ¶ms); + ~HTTPFileHandle() override; + // This two-phase construction allows subclasses more flexible setup. 
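+	// (the constructor only records the path, flags and HTTP parameters; Initialize() typically performs the
+	//  metadata request that fills in length/last_modified once a FileOpener is available)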
+ virtual void Initialize(optional_ptr opener); + + // We keep an http client stored for connection reuse with keep-alive headers + duckdb::unique_ptr http_client; + optional_ptr http_logger; + + const HTTPParams http_params; + + // File handle info + FileOpenFlags flags; + idx_t length; + time_t last_modified; + + // When using full file download, the full file will be written to a cached file handle + unique_ptr cached_file_handle; + + // Read info + idx_t buffer_available; + idx_t buffer_idx; + idx_t file_offset; + idx_t buffer_start; + idx_t buffer_end; + + // Read buffer + duckdb::unique_ptr read_buffer; + constexpr static idx_t READ_BUFFER_LEN = 1000000; + + shared_ptr state; + + void AddHeaders(HeaderMap &map); + +public: + void Close() override { + } + +protected: + virtual void InitializeClient(optional_ptr client_context); +}; + +class HTTPFileSystem : public FileSystem { +public: + static duckdb::unique_ptr + GetClient(const HTTPParams &http_params, const char *proto_host_port, optional_ptr hfs); + static void ParseUrl(string &url, string &path_out, string &proto_host_port_out); + duckdb::unique_ptr OpenFile(const string &path, FileOpenFlags flags, + optional_ptr opener = nullptr) final; + + vector Glob(const string &path, FileOpener *opener = nullptr) override { + return {path}; // FIXME + } + + // HTTP Requests + virtual duckdb::unique_ptr HeadRequest(FileHandle &handle, string url, HeaderMap header_map); + // Get Request with range parameter that GETs exactly buffer_out_len bytes from the url + virtual duckdb::unique_ptr GetRangeRequest(FileHandle &handle, string url, HeaderMap header_map, + idx_t file_offset, char *buffer_out, + idx_t buffer_out_len); + // Get Request without a range (i.e., downloads full file) + virtual duckdb::unique_ptr GetRequest(FileHandle &handle, string url, HeaderMap header_map); + // Post Request that can handle variable sized responses without a content-length header (needed for s3 multipart) + virtual duckdb::unique_ptr PostRequest(FileHandle &handle, string url, HeaderMap header_map, + duckdb::unique_ptr &buffer_out, + idx_t &buffer_out_len, char *buffer_in, idx_t buffer_in_len, + string params = ""); + virtual duckdb::unique_ptr PutRequest(FileHandle &handle, string url, HeaderMap header_map, + char *buffer_in, idx_t buffer_in_len, string params = ""); + + // FS methods + void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + void FileSync(FileHandle &handle) override; + int64_t GetFileSize(FileHandle &handle) override; + time_t GetLastModifiedTime(FileHandle &handle) override; + bool FileExists(const string &filename, optional_ptr opener) override; + void Seek(FileHandle &handle, idx_t location) override; + idx_t SeekPosition(FileHandle &handle) override; + bool CanHandleFile(const string &fpath) override; + bool CanSeek() override { + return true; + } + bool OnDiskFile(FileHandle &handle) override { + return false; + } + bool IsPipe(const string &filename, optional_ptr opener) override { + return false; + } + string GetName() const override { + return "HTTPFileSystem"; + } + string PathSeparator(const string &path) override { + return "/"; + } + static void Verify(); + + optional_ptr GetGlobalCache(); + +protected: + virtual duckdb::unique_ptr CreateHandle(const 
string &path, FileOpenFlags flags, + optional_ptr opener); + + static duckdb::unique_ptr + RunRequestWithRetry(const std::function &request, string &url, string method, + const HTTPParams ¶ms, const std::function &retry_cb = {}); + +private: + // Global cache + mutex global_cache_lock; + duckdb::unique_ptr global_metadata_cache; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/include/httpfs_extension.hpp b/src/duckdb/extension/httpfs/include/httpfs_extension.hpp new file mode 100644 index 00000000..3c4f3a11 --- /dev/null +++ b/src/duckdb/extension/httpfs/include/httpfs_extension.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "duckdb.hpp" + +namespace duckdb { + +class HttpfsExtension : public Extension { +public: + void Load(DuckDB &db) override; + std::string Name() override; + std::string Version() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/include/s3fs.hpp b/src/duckdb/extension/httpfs/include/s3fs.hpp new file mode 100644 index 00000000..501c49c0 --- /dev/null +++ b/src/duckdb/extension/httpfs/include/s3fs.hpp @@ -0,0 +1,256 @@ +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/secret/secret_manager.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "httpfs.hpp" + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +#include +#include +#include + +namespace duckdb { + +struct S3AuthParams { + string region; + string access_key_id; + string secret_access_key; + string session_token; + string endpoint; + string url_style; + bool use_ssl = true; + bool s3_url_compatibility_mode = false; + + static S3AuthParams ReadFrom(optional_ptr opener, FileOpenerInfo &info); + static unique_ptr ReadFromStoredCredentials(optional_ptr opener, string path); +}; + +struct AWSEnvironmentCredentialsProvider { + static constexpr const char *REGION_ENV_VAR = "AWS_REGION"; + static constexpr const char *DEFAULT_REGION_ENV_VAR = "AWS_DEFAULT_REGION"; + static constexpr const char *ACCESS_KEY_ENV_VAR = "AWS_ACCESS_KEY_ID"; + static constexpr const char *SECRET_KEY_ENV_VAR = "AWS_SECRET_ACCESS_KEY"; + static constexpr const char *SESSION_TOKEN_ENV_VAR = "AWS_SESSION_TOKEN"; + static constexpr const char *DUCKDB_ENDPOINT_ENV_VAR = "DUCKDB_S3_ENDPOINT"; + static constexpr const char *DUCKDB_USE_SSL_ENV_VAR = "DUCKDB_S3_USE_SSL"; + + explicit AWSEnvironmentCredentialsProvider(DBConfig &config) : config(config) {}; + + DBConfig &config; + + void SetExtensionOptionValue(string key, const char *env_var); + void SetAll(); + S3AuthParams CreateParams(); +}; + +struct ParsedS3Url { + const string http_proto; + const string prefix; + const string host; + const string bucket; + const string key; + const string path; + const string query_param; + const string trimmed_s3_url; + + string GetHTTPUrl(S3AuthParams &auth_params, const string &http_query_string = ""); +}; + +struct S3ConfigParams { + static constexpr uint64_t DEFAULT_MAX_FILESIZE = 800000000000; // 800GB + static constexpr uint64_t DEFAULT_MAX_PARTS_PER_FILE = 10000; // AWS DEFAULT + static constexpr uint64_t DEFAULT_MAX_UPLOAD_THREADS = 50; + + uint64_t max_file_size; + uint64_t max_parts_per_file; + uint64_t max_upload_threads; + + static S3ConfigParams 
ReadFrom(optional_ptr opener); +}; + +class S3SecretHelper { +public: + //! Create an S3 type secret + static unique_ptr CreateSecret(vector &prefix_paths_p, string &type, string &provider, + string &name, S3AuthParams ¶ms); + //! Parse S3AuthParams from secret + static S3AuthParams GetParams(const KeyValueSecret &secret); +}; + +class S3FileSystem; + +// Holds the buffered data for 1 part of an S3 Multipart upload +class S3WriteBuffer { +public: + explicit S3WriteBuffer(idx_t buffer_start, size_t buffer_size, BufferHandle buffer_p) + : idx(0), buffer_start(buffer_start), buffer(std::move(buffer_p)) { + buffer_end = buffer_start + buffer_size; + part_no = buffer_start / buffer_size; + uploading = false; + } + + void *Ptr() { + return buffer.Ptr(); + } + + // The S3 multipart part number. Note that internally we start at 0 but AWS S3 starts at 1 + idx_t part_no; + + idx_t idx; + idx_t buffer_start; + idx_t buffer_end; + BufferHandle buffer; + atomic uploading; +}; + +class S3FileHandle : public HTTPFileHandle { + friend class S3FileSystem; + +public: + S3FileHandle(FileSystem &fs, string path_p, FileOpenFlags flags, const HTTPParams &http_params, + const S3AuthParams &auth_params_p, const S3ConfigParams &config_params_p) + : HTTPFileHandle(fs, std::move(path_p), flags, http_params), auth_params(auth_params_p), + config_params(config_params_p), uploads_in_progress(0), parts_uploaded(0), upload_finalized(false), + uploader_has_error(false), upload_exception(nullptr) { + if (flags.OpenForReading() && flags.OpenForWriting()) { + throw NotImplementedException("Cannot open an HTTP file for both reading and writing"); + } else if (flags.OpenForAppending()) { + throw NotImplementedException("Cannot open an HTTP file for appending"); + } + } + ~S3FileHandle() override; + + S3AuthParams auth_params; + const S3ConfigParams config_params; + +public: + void Close() override; + void Initialize(optional_ptr opener) override; + + shared_ptr GetBuffer(uint16_t write_buffer_idx); + +protected: + string multipart_upload_id; + size_t part_size; + + //! Write buffers for this file + mutex write_buffers_lock; + unordered_map> write_buffers; + + //! Synchronization for upload threads + mutex uploads_in_progress_lock; + std::condition_variable uploads_in_progress_cv; + std::condition_variable final_flush_cv; + uint16_t uploads_in_progress; + + //! Etags are stored for each part + mutex part_etags_lock; + unordered_map part_etags; + + //! Info for upload + atomic parts_uploaded; + bool upload_finalized = true; + + //! Error handling in upload threads + atomic uploader_has_error {false}; + std::exception_ptr upload_exception; + + void InitializeClient(optional_ptr client_context) override; + + //! 
Rethrow IO Exception originating from an upload thread + void RethrowIOError() { + if (uploader_has_error) { + std::rethrow_exception(upload_exception); + } + } +}; + +class S3FileSystem : public HTTPFileSystem { +public: + explicit S3FileSystem(BufferManager &buffer_manager) : buffer_manager(buffer_manager) { + } + + BufferManager &buffer_manager; + string GetName() const override; + +public: + duckdb::unique_ptr HeadRequest(FileHandle &handle, string s3_url, HeaderMap header_map) override; + duckdb::unique_ptr GetRequest(FileHandle &handle, string url, HeaderMap header_map) override; + duckdb::unique_ptr GetRangeRequest(FileHandle &handle, string s3_url, HeaderMap header_map, + idx_t file_offset, char *buffer_out, + idx_t buffer_out_len) override; + duckdb::unique_ptr PostRequest(FileHandle &handle, string s3_url, HeaderMap header_map, + duckdb::unique_ptr &buffer_out, idx_t &buffer_out_len, + char *buffer_in, idx_t buffer_in_len, + string http_params = "") override; + duckdb::unique_ptr PutRequest(FileHandle &handle, string s3_url, HeaderMap header_map, + char *buffer_in, idx_t buffer_in_len, + string http_params = "") override; + + static void Verify(); + + bool CanHandleFile(const string &fpath) override; + bool OnDiskFile(FileHandle &handle) override { + return false; + } + void FileSync(FileHandle &handle) override; + void Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + + string InitializeMultipartUpload(S3FileHandle &file_handle); + void FinalizeMultipartUpload(S3FileHandle &file_handle); + + void FlushAllBuffers(S3FileHandle &handle); + + void ReadQueryParams(const string &url_query_param, S3AuthParams ¶ms); + static ParsedS3Url S3UrlParse(string url, S3AuthParams ¶ms); + + static string UrlEncode(const string &input, bool encode_slash = false); + static string UrlDecode(string input); + + // Uploads the contents of write_buffer to S3. + // Note: caller is responsible to not call this method twice on the same buffer + static void UploadBuffer(S3FileHandle &file_handle, shared_ptr write_buffer); + + vector Glob(const string &glob_pattern, FileOpener *opener = nullptr) override; + bool ListFiles(const string &directory, const std::function &callback, + FileOpener *opener = nullptr) override; + + //! Wrapper around BufferManager::Allocate to limit the number of buffers + BufferHandle Allocate(idx_t part_size, uint16_t max_threads); + + //! 
S3 is object storage so directories effectively always exist + bool DirectoryExists(const string &directory, optional_ptr opener = nullptr) override { + return true; + } + +protected: + static void NotifyUploadsInProgress(S3FileHandle &file_handle); + duckdb::unique_ptr CreateHandle(const string &path, FileOpenFlags flags, + optional_ptr opener) override; + + void FlushBuffer(S3FileHandle &handle, shared_ptr write_buffer); + string GetPayloadHash(char *buffer, idx_t buffer_len); + + // helper for ReadQueryParams + void GetQueryParam(const string &key, string ¶m, CPPHTTPLIB_NAMESPACE::Params &query_params); +}; + +// Helper class to do s3 ListObjectV2 api call https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html +struct AWSListObjectV2 { + static string Request(string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params, + string &continuation_token, optional_ptr state, bool use_delimiter = false); + static void ParseKey(string &aws_response, vector &result); + static vector ParseCommonPrefix(string &aws_response); + static string ParseContinuationToken(string &aws_response); +}; +} // namespace duckdb diff --git a/src/duckdb/extension/httpfs/s3fs.cpp b/src/duckdb/extension/httpfs/s3fs.cpp new file mode 100644 index 00000000..4a481d68 --- /dev/null +++ b/src/duckdb/extension/httpfs/s3fs.cpp @@ -0,0 +1,1217 @@ +#include "s3fs.hpp" + +#include "crypto.hpp" +#include "duckdb.hpp" +#ifndef DUCKDB_AMALGAMATION +#include "duckdb/common/exception/http_exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/thread.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/function/scalar/strftime_format.hpp" +#endif + +#include +#include +#include +#include +#include + +namespace duckdb { + +static HeaderMap create_s3_header(string url, string query, string host, string service, string method, + const S3AuthParams &auth_params, string date_now = "", string datetime_now = "", + string payload_hash = "", string content_type = "") { + + HeaderMap res; + res["Host"] = host; + // If access key is not set, we don't set the headers at all to allow accessing public files through s3 urls + if (auth_params.secret_access_key.empty() && auth_params.access_key_id.empty()) { + return res; + } + + if (payload_hash == "") { + payload_hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; // Empty payload hash + } + + // we can pass date/time but this is mostly useful in testing. normally we just get the current datetime here. 
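+	// The rest of this function implements AWS Signature Version 4: build the canonical request from the method,
+	// URI, query string, signed headers and payload hash; hash it into the string-to-sign together with the
+	// date/region/service scope; derive the signing key by chaining HMAC-SHA256 over the secret key, date, region,
+	// service and "aws4_request"; and finally emit the Authorization header with scope, signed headers and signature.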
+ if (datetime_now.empty()) { + auto timestamp = Timestamp::GetCurrentTimestamp(); + date_now = StrfTimeFormat::Format(timestamp, "%Y%m%d"); + datetime_now = StrfTimeFormat::Format(timestamp, "%Y%m%dT%H%M%SZ"); + } + + res["x-amz-date"] = datetime_now; + res["x-amz-content-sha256"] = payload_hash; + if (auth_params.session_token.length() > 0) { + res["x-amz-security-token"] = auth_params.session_token; + } + + string signed_headers = ""; + hash_bytes canonical_request_hash; + hash_str canonical_request_hash_str; + if (content_type.length() > 0) { + signed_headers += "content-type;"; + } + signed_headers += "host;x-amz-content-sha256;x-amz-date"; + if (auth_params.session_token.length() > 0) { + signed_headers += ";x-amz-security-token"; + } + auto canonical_request = method + "\n" + S3FileSystem::UrlEncode(url) + "\n" + query; + if (content_type.length() > 0) { + canonical_request += "\ncontent-type:" + content_type; + } + canonical_request += "\nhost:" + host + "\nx-amz-content-sha256:" + payload_hash + "\nx-amz-date:" + datetime_now; + if (auth_params.session_token.length() > 0) { + canonical_request += "\nx-amz-security-token:" + auth_params.session_token; + } + + canonical_request += "\n\n" + signed_headers + "\n" + payload_hash; + sha256(canonical_request.c_str(), canonical_request.length(), canonical_request_hash); + + hex256(canonical_request_hash, canonical_request_hash_str); + auto string_to_sign = "AWS4-HMAC-SHA256\n" + datetime_now + "\n" + date_now + "/" + auth_params.region + "/" + + service + "/aws4_request\n" + string((char *)canonical_request_hash_str, sizeof(hash_str)); + // compute signature + hash_bytes k_date, k_region, k_service, signing_key, signature; + hash_str signature_str; + auto sign_key = "AWS4" + auth_params.secret_access_key; + hmac256(date_now, sign_key.c_str(), sign_key.length(), k_date); + hmac256(auth_params.region, k_date, k_region); + hmac256(service, k_region, k_service); + hmac256("aws4_request", k_service, signing_key); + hmac256(string_to_sign, signing_key, signature); + hex256(signature, signature_str); + + res["Authorization"] = "AWS4-HMAC-SHA256 Credential=" + auth_params.access_key_id + "/" + date_now + "/" + + auth_params.region + "/" + service + "/aws4_request, SignedHeaders=" + signed_headers + + ", Signature=" + string((char *)signature_str, sizeof(hash_str)); + + return res; +} + +static duckdb::unique_ptr initialize_http_headers(HeaderMap &header_map) { + auto headers = make_uniq(); + for (auto &entry : header_map) { + headers->insert(entry); + } + return headers; +} + +string S3FileSystem::UrlDecode(string input) { + string result; + result.reserve(input.size()); + char ch; + replace(input.begin(), input.end(), '+', ' '); + for (idx_t i = 0; i < input.length(); i++) { + if (int(input[i]) == 37) { + unsigned int ii; + sscanf(input.substr(i + 1, 2).c_str(), "%x", &ii); + ch = static_cast(ii); + result += ch; + i += 2; + } else { + result += input[i]; + } + } + return result; +} + +string S3FileSystem::UrlEncode(const string &input, bool encode_slash) { + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + static const char *hex_digit = "0123456789ABCDEF"; + string result; + result.reserve(input.size()); + for (idx_t i = 0; i < input.length(); i++) { + char ch = input[i]; + if ((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || + ch == '-' || ch == '~' || ch == '.') { + result += ch; + } else if (ch == '/') { + if (encode_slash) { + result += string("%2F"); + } else 
{ + result += ch; + } + } else { + result += string("%"); + result += hex_digit[static_cast(ch) >> 4]; + result += hex_digit[static_cast(ch) & 15]; + } + } + return result; +} + +void AWSEnvironmentCredentialsProvider::SetExtensionOptionValue(string key, const char *env_var_name) { + char *evar; + + if ((evar = std::getenv(env_var_name)) != NULL) { + if (StringUtil::Lower(evar) == "false") { + this->config.SetOption(key, Value(false)); + } else if (StringUtil::Lower(evar) == "true") { + this->config.SetOption(key, Value(true)); + } else { + this->config.SetOption(key, Value(evar)); + } + } +} + +void AWSEnvironmentCredentialsProvider::SetAll() { + this->SetExtensionOptionValue("s3_region", DEFAULT_REGION_ENV_VAR); + this->SetExtensionOptionValue("s3_region", REGION_ENV_VAR); + this->SetExtensionOptionValue("s3_access_key_id", ACCESS_KEY_ENV_VAR); + this->SetExtensionOptionValue("s3_secret_access_key", SECRET_KEY_ENV_VAR); + this->SetExtensionOptionValue("s3_session_token", SESSION_TOKEN_ENV_VAR); + this->SetExtensionOptionValue("s3_endpoint", DUCKDB_ENDPOINT_ENV_VAR); + this->SetExtensionOptionValue("s3_use_ssl", DUCKDB_USE_SSL_ENV_VAR); +} + +S3AuthParams AWSEnvironmentCredentialsProvider::CreateParams() { + S3AuthParams params; + + params.region = DEFAULT_REGION_ENV_VAR; + params.region = REGION_ENV_VAR; + params.access_key_id = ACCESS_KEY_ENV_VAR; + params.secret_access_key = SECRET_KEY_ENV_VAR; + params.session_token = SESSION_TOKEN_ENV_VAR; + params.endpoint = DUCKDB_ENDPOINT_ENV_VAR; + params.use_ssl = DUCKDB_USE_SSL_ENV_VAR; + + return params; +} + +unique_ptr S3AuthParams::ReadFromStoredCredentials(optional_ptr opener, string path) { + if (!opener) { + return nullptr; + } + auto db = opener->TryGetDatabase(); + if (!db) { + return nullptr; + } + auto &secret_manager = db->GetSecretManager(); + auto context = opener->TryGetClientContext(); + auto transaction = context ? 
CatalogTransaction::GetSystemCatalogTransaction(*context)
+	                                  : CatalogTransaction::GetSystemTransaction(*db);
+
+	auto secret_match = secret_manager.LookupSecret(transaction, path, "s3");
+	if (!secret_match.HasMatch()) {
+		secret_match = secret_manager.LookupSecret(transaction, path, "r2");
+	}
+	if (!secret_match.HasMatch()) {
+		secret_match = secret_manager.LookupSecret(transaction, path, "gcs");
+	}
+	if (!secret_match.HasMatch()) {
+		return nullptr;
+	}
+
+	// Return the stored credentials
+	const auto &secret = secret_match.GetSecret();
+	const auto &kv_secret = dynamic_cast(secret);
+
+	return make_uniq(S3SecretHelper::GetParams(kv_secret));
+}
+
+S3AuthParams S3AuthParams::ReadFrom(optional_ptr opener, FileOpenerInfo &info) {
+	S3AuthParams result;
+	Value value;
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_region", value, info)) {
+		result.region = value.ToString();
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_access_key_id", value, info)) {
+		result.access_key_id = value.ToString();
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_secret_access_key", value, info)) {
+		result.secret_access_key = value.ToString();
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_session_token", value, info)) {
+		result.session_token = value.ToString();
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_endpoint", value, info)) {
+		if (value.ToString().empty()) {
+			if (StringUtil::StartsWith(info.file_path, "gcs://") || StringUtil::StartsWith(info.file_path, "gs://")) {
+				result.endpoint = "storage.googleapis.com";
+			} else {
+				result.endpoint = "s3.amazonaws.com";
+			}
+		} else {
+			result.endpoint = value.ToString();
+		}
+	} else {
+		result.endpoint = "s3.amazonaws.com";
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_url_style", value, info)) {
+		auto val_str = value.ToString();
+		if (!(val_str == "vhost" || val_str == "path" || val_str.empty())) {
+			throw std::runtime_error(
+			    "Incorrect setting found for s3_url_style, allowed values are: 'path' and 'vhost'");
+		}
+		result.url_style = val_str;
+	} else {
+		result.url_style = "vhost";
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_use_ssl", value, info)) {
+		result.use_ssl = value.GetValue();
+	} else {
+		result.use_ssl = true;
+	}
+
+	if (FileOpener::TryGetCurrentSetting(opener, "s3_url_compatibility_mode", value, info)) {
+		result.s3_url_compatibility_mode = value.GetValue();
+	} else {
+		result.s3_url_compatibility_mode = true;
+	}
+
+	return result;
+}
+
+unique_ptr S3SecretHelper::CreateSecret(vector &prefix_paths_p, string &type, string &provider,
+                                                    string &name, S3AuthParams &params) {
+	auto return_value = make_uniq(prefix_paths_p, type, provider, name);
+
+	//! Set key value map
+	return_value->secret_map["region"] = params.region;
+	return_value->secret_map["key_id"] = params.access_key_id;
+	return_value->secret_map["secret"] = params.secret_access_key;
+	return_value->secret_map["session_token"] = params.session_token;
+	return_value->secret_map["endpoint"] = params.endpoint;
+	return_value->secret_map["url_style"] = params.url_style;
+	return_value->secret_map["use_ssl"] = params.use_ssl;
+	return_value->secret_map["s3_url_compatibility_mode"] = params.s3_url_compatibility_mode;
+
+	//! 
Set redact keys + return_value->redact_keys = {"secret", "session_token"}; + + return return_value; +} + +S3AuthParams S3SecretHelper::GetParams(const KeyValueSecret &secret) { + S3AuthParams params; + if (!secret.TryGetValue("region").IsNull()) { + params.region = secret.TryGetValue("region").ToString(); + } + if (!secret.TryGetValue("key_id").IsNull()) { + params.access_key_id = secret.TryGetValue("key_id").ToString(); + } + if (!secret.TryGetValue("secret").IsNull()) { + params.secret_access_key = secret.TryGetValue("secret").ToString(); + } + if (!secret.TryGetValue("session_token").IsNull()) { + params.session_token = secret.TryGetValue("session_token").ToString(); + } + if (!secret.TryGetValue("endpoint").IsNull()) { + params.endpoint = secret.TryGetValue("endpoint").ToString(); + } + if (!secret.TryGetValue("url_style").IsNull()) { + params.url_style = secret.TryGetValue("url_style").ToString(); + } + if (!secret.TryGetValue("use_ssl").IsNull()) { + params.use_ssl = secret.TryGetValue("use_ssl").GetValue(); + } + if (!secret.TryGetValue("s3_url_compatibility_mode").IsNull()) { + params.s3_url_compatibility_mode = secret.TryGetValue("s3_url_compatibility_mode").GetValue(); + } + return params; +} + +S3FileHandle::~S3FileHandle() { + if (Exception::UncaughtException()) { + // We are in an exception, don't do anything + return; + } + + try { + Close(); + } catch (...) { // NOLINT + } +} + +S3ConfigParams S3ConfigParams::ReadFrom(optional_ptr opener) { + uint64_t uploader_max_filesize; + uint64_t max_parts_per_file; + uint64_t max_upload_threads; + Value value; + + if (FileOpener::TryGetCurrentSetting(opener, "s3_uploader_max_filesize", value)) { + uploader_max_filesize = DBConfig::ParseMemoryLimit(value.GetValue()); + } else { + uploader_max_filesize = S3ConfigParams::DEFAULT_MAX_FILESIZE; + } + + if (FileOpener::TryGetCurrentSetting(opener, "s3_uploader_max_parts_per_file", value)) { + max_parts_per_file = value.GetValue(); + } else { + max_parts_per_file = S3ConfigParams::DEFAULT_MAX_PARTS_PER_FILE; // AWS Default + } + + if (FileOpener::TryGetCurrentSetting(opener, "s3_uploader_thread_limit", value)) { + max_upload_threads = value.GetValue(); + } else { + max_upload_threads = S3ConfigParams::DEFAULT_MAX_UPLOAD_THREADS; + } + + return {uploader_max_filesize, max_parts_per_file, max_upload_threads}; +} + +void S3FileHandle::Close() { + auto &s3fs = (S3FileSystem &)file_system; + if (flags.OpenForWriting() && !upload_finalized) { + s3fs.FlushAllBuffers(*this); + if (parts_uploaded) { + s3fs.FinalizeMultipartUpload(*this); + } + } +} + +void S3FileHandle::InitializeClient(optional_ptr client_context) { + auto parsed_url = S3FileSystem::S3UrlParse(path, this->auth_params); + + string proto_host_port = parsed_url.http_proto + parsed_url.host; + http_client = HTTPFileSystem::GetClient(this->http_params, proto_host_port.c_str(), this); +} + +// Opens the multipart upload and returns the ID +string S3FileSystem::InitializeMultipartUpload(S3FileHandle &file_handle) { + auto &s3fs = (S3FileSystem &)file_handle.file_system; + + // AWS response is around 300~ chars in docs so this should be enough to not need a resize + idx_t response_buffer_len = 1000; + auto response_buffer = duckdb::unique_ptr {new char[response_buffer_len]}; + + string query_param = "uploads="; + auto res = s3fs.PostRequest(file_handle, file_handle.path, {}, response_buffer, response_buffer_len, nullptr, 0, + query_param); + string result(response_buffer.get(), response_buffer_len); + + auto open_tag_pos = result.find("", 
0); + auto close_tag_pos = result.find("", open_tag_pos); + + if (open_tag_pos == string::npos || close_tag_pos == string::npos) { + throw std::runtime_error("Unexpected response while initializing S3 multipart upload"); + } + + open_tag_pos += 10; // Skip open tag + + return result.substr(open_tag_pos, close_tag_pos - open_tag_pos); +} + +void S3FileSystem::NotifyUploadsInProgress(S3FileHandle &file_handle) { + { + unique_lock lck(file_handle.uploads_in_progress_lock); + file_handle.uploads_in_progress--; + } + // Note that there are 2 cv's because otherwise we might deadlock when the final flushing thread is notified while + // another thread is still waiting for an upload thread + file_handle.uploads_in_progress_cv.notify_one(); + file_handle.final_flush_cv.notify_one(); +} + +void S3FileSystem::UploadBuffer(S3FileHandle &file_handle, shared_ptr write_buffer) { + auto &s3fs = (S3FileSystem &)file_handle.file_system; + + string query_param = "partNumber=" + to_string(write_buffer->part_no + 1) + "&" + + "uploadId=" + S3FileSystem::UrlEncode(file_handle.multipart_upload_id, true); + unique_ptr res; + case_insensitive_map_t::iterator etag_lookup; + + try { + res = s3fs.PutRequest(file_handle, file_handle.path, {}, (char *)write_buffer->Ptr(), write_buffer->idx, + query_param); + + if (res->code != 200) { + throw HTTPException(*res, "Unable to connect to URL %s %s (HTTP code %s)", res->http_url, res->error, + to_string(res->code)); + } + + etag_lookup = res->headers.find("ETag"); + if (etag_lookup == res->headers.end()) { + throw IOException("Unexpected response when uploading part to S3"); + } + + } catch (IOException &ex) { + // Ensure only one thread sets the exception + bool f = false; + auto exchanged = file_handle.uploader_has_error.compare_exchange_strong(f, true); + if (exchanged) { + file_handle.upload_exception = std::current_exception(); + } + + NotifyUploadsInProgress(file_handle); + + return; + } + + // Insert etag + { + unique_lock lck(file_handle.part_etags_lock); + file_handle.part_etags.insert(std::pair(write_buffer->part_no, etag_lookup->second)); + } + + file_handle.parts_uploaded++; + + // Free up space for another thread to acquire an S3WriteBuffer + write_buffer.reset(); + + NotifyUploadsInProgress(file_handle); +} + +void S3FileSystem::FlushBuffer(S3FileHandle &file_handle, shared_ptr write_buffer) { + if (write_buffer->idx == 0) { + return; + } + + auto uploading = write_buffer->uploading.load(); + if (uploading) { + return; + } + bool can_upload = write_buffer->uploading.compare_exchange_strong(uploading, true); + if (!can_upload) { + return; + } + + file_handle.RethrowIOError(); + + { + unique_lock lck(file_handle.write_buffers_lock); + file_handle.write_buffers.erase(write_buffer->part_no); + } + + { + unique_lock lck(file_handle.uploads_in_progress_lock); + // check if there are upload threads available + if (file_handle.uploads_in_progress >= file_handle.config_params.max_upload_threads) { + // there are not - wait for one to become available + file_handle.uploads_in_progress_cv.wait(lck, [&file_handle] { + return file_handle.uploads_in_progress < file_handle.config_params.max_upload_threads; + }); + } + file_handle.uploads_in_progress++; + } + + thread upload_thread(UploadBuffer, std::ref(file_handle), write_buffer); + upload_thread.detach(); +} + +// Note that FlushAll currently does not allow to continue writing afterwards. Therefore, FinalizeMultipartUpload should +// be called right after it! 
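+// Callers therefore pair the two calls: S3FileHandle::Close() and S3FileSystem::FileSync() flush all buffers and
+// then finalize the multipart upload (see below).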
+// TODO: we can fix this by keeping the last partially written buffer in memory and allow reuploading it with new data. +void S3FileSystem::FlushAllBuffers(S3FileHandle &file_handle) { + // Collect references to all buffers to check + vector> to_flush; + file_handle.write_buffers_lock.lock(); + for (auto &item : file_handle.write_buffers) { + to_flush.push_back(item.second); + } + file_handle.write_buffers_lock.unlock(); + + // Flush all buffers that aren't already uploading + for (auto &write_buffer : to_flush) { + if (!write_buffer->uploading) { + FlushBuffer(file_handle, write_buffer); + } + } + unique_lock lck(file_handle.uploads_in_progress_lock); + file_handle.final_flush_cv.wait(lck, [&file_handle] { return file_handle.uploads_in_progress == 0; }); + + file_handle.RethrowIOError(); +} + +void S3FileSystem::FinalizeMultipartUpload(S3FileHandle &file_handle) { + auto &s3fs = (S3FileSystem &)file_handle.file_system; + file_handle.upload_finalized = true; + + std::stringstream ss; + ss << ""; + + auto parts = file_handle.parts_uploaded.load(); + for (auto i = 0; i < parts; i++) { + auto etag_lookup = file_handle.part_etags.find(i); + if (etag_lookup == file_handle.part_etags.end()) { + throw IOException("Unknown part number"); + } + ss << "" << etag_lookup->second << "" << i + 1 << ""; + } + ss << ""; + string body = ss.str(); + + // Response is around ~400 in AWS docs so this should be enough to not need a resize + idx_t response_buffer_len = 1000; + auto response_buffer = duckdb::unique_ptr {new char[response_buffer_len]}; + + string query_param = "uploadId=" + S3FileSystem::UrlEncode(file_handle.multipart_upload_id, true); + auto res = s3fs.PostRequest(file_handle, file_handle.path, {}, response_buffer, response_buffer_len, + (char *)body.c_str(), body.length(), query_param); + string result(response_buffer.get(), response_buffer_len); + + auto open_tag_pos = result.find("code, + result); + } +} + +// Wrapper around the BufferManager::Allocate to that allows limiting the number of buffers that will be handed out +BufferHandle S3FileSystem::Allocate(idx_t part_size, uint16_t max_threads) { + return buffer_manager.Allocate(MemoryTag::EXTENSION, part_size); +} + +shared_ptr S3FileHandle::GetBuffer(uint16_t write_buffer_idx) { + auto &s3fs = (S3FileSystem &)file_system; + + // Check if write buffer already exists + { + unique_lock lck(write_buffers_lock); + auto lookup_result = write_buffers.find(write_buffer_idx); + if (lookup_result != write_buffers.end()) { + shared_ptr buffer = lookup_result->second; + return buffer; + } + } + + auto buffer_handle = s3fs.Allocate(part_size, config_params.max_upload_threads); + auto new_write_buffer = + make_shared_ptr(write_buffer_idx * part_size, part_size, std::move(buffer_handle)); + { + unique_lock lck(write_buffers_lock); + auto lookup_result = write_buffers.find(write_buffer_idx); + + // Check if other thread has created the same buffer, if so we return theirs and drop ours. 
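+		// (the buffer above was allocated outside the lock, so another thread may have inserted an entry for the
+		//  same part number in the meantime; if so, return that shared buffer and let ours be freed)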
+ if (lookup_result != write_buffers.end()) { + // write_buffer_idx << std::endl; + shared_ptr write_buffer = lookup_result->second; + return write_buffer; + } + write_buffers.insert(pair>(write_buffer_idx, new_write_buffer)); + } + + return new_write_buffer; +} + +void S3FileSystem::GetQueryParam(const string &key, string ¶m, duckdb_httplib_openssl::Params &query_params) { + auto found_param = query_params.find(key); + if (found_param != query_params.end()) { + param = found_param->second; + query_params.erase(found_param); + } +} + +void S3FileSystem::ReadQueryParams(const string &url_query_param, S3AuthParams ¶ms) { + if (url_query_param.empty()) { + return; + } + + duckdb_httplib_openssl::Params query_params; + duckdb_httplib_openssl::detail::parse_query_text(url_query_param, query_params); + + GetQueryParam("s3_region", params.region, query_params); + GetQueryParam("s3_access_key_id", params.access_key_id, query_params); + GetQueryParam("s3_secret_access_key", params.secret_access_key, query_params); + GetQueryParam("s3_session_token", params.session_token, query_params); + GetQueryParam("s3_endpoint", params.endpoint, query_params); + GetQueryParam("s3_url_style", params.url_style, query_params); + auto found_param = query_params.find("s3_use_ssl"); + if (found_param != query_params.end()) { + if (found_param->second == "true") { + params.use_ssl = true; + } else if (found_param->second == "false") { + params.use_ssl = false; + } else { + throw IOException("Incorrect setting found for s3_use_ssl, allowed values are: 'true' or 'false'"); + } + query_params.erase(found_param); + } + if (!query_params.empty()) { + throw IOException("Invalid query parameters found. Supported parameters are:\n's3_region', 's3_access_key_id', " + "'s3_secret_access_key', 's3_session_token',\n's3_endpoint', 's3_url_style', 's3_use_ssl'"); + } +} + +static string GetPrefix(string url) { + const string prefixes[] = {"s3://", "s3a://", "s3n://", "gcs://", "gs://", "r2://"}; + for (auto &prefix : prefixes) { + if (StringUtil::StartsWith(url, prefix)) { + return prefix; + } + } + throw IOException("URL needs to start with s3://, gcs:// or r2://"); + return string(); +} + +ParsedS3Url S3FileSystem::S3UrlParse(string url, S3AuthParams ¶ms) { + string http_proto, prefix, host, bucket, key, path, query_param, trimmed_s3_url; + + prefix = GetPrefix(url); + auto prefix_end_pos = url.find("//") + 2; + auto slash_pos = url.find('/', prefix_end_pos); + if (slash_pos == string::npos) { + throw IOException("URL needs to contain a '/' after the host"); + } + bucket = url.substr(prefix_end_pos, slash_pos - prefix_end_pos); + if (bucket.empty()) { + throw IOException("URL needs to contain a bucket name"); + } + + if (params.s3_url_compatibility_mode) { + // In url compatibility mode, we will ignore any special chars, so query param strings are disabled + trimmed_s3_url = url; + key += url.substr(slash_pos); + } else { + // Parse query parameters + auto question_pos = url.find_first_of('?'); + if (question_pos != string::npos) { + query_param = url.substr(question_pos + 1); + trimmed_s3_url = url.substr(0, question_pos); + } else { + trimmed_s3_url = url; + } + + if (!query_param.empty()) { + key += url.substr(slash_pos, question_pos - slash_pos); + } else { + key += url.substr(slash_pos); + } + } + + if (key.empty()) { + throw IOException("URL needs to contain key"); + } + + // Derived host and path based on the endpoint + auto sub_path_pos = params.endpoint.find_first_of('/'); + if (sub_path_pos != string::npos) { + // Host 
header should conform to : so not include the path + host = params.endpoint.substr(0, sub_path_pos); + path = params.endpoint.substr(sub_path_pos); + } else { + host = params.endpoint; + path = ""; + } + + // Update host and path according to the url style + // See https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html + if (params.url_style == "vhost" || params.url_style == "") { + host = bucket + "." + host; + } else if (params.url_style == "path") { + path += "/" + bucket; + } + + // Append key (including leading slash) to the path + path += key; + + // Remove leading slash from key + key = key.substr(1); + + http_proto = params.use_ssl ? "https://" : "http://"; + + return {http_proto, prefix, host, bucket, key, path, query_param, trimmed_s3_url}; +} + +string S3FileSystem::GetPayloadHash(char *buffer, idx_t buffer_len) { + if (buffer_len > 0) { + hash_bytes payload_hash_bytes; + hash_str payload_hash_str; + sha256(buffer, buffer_len, payload_hash_bytes); + hex256(payload_hash_bytes, payload_hash_str); + return string((char *)payload_hash_str, sizeof(payload_hash_str)); + } else { + return ""; + } +} + +string ParsedS3Url::GetHTTPUrl(S3AuthParams &auth_params, const string &http_query_string) { + string full_url = http_proto + host + S3FileSystem::UrlEncode(path); + + if (!http_query_string.empty()) { + full_url += "?" + http_query_string; + } + return full_url; +} + +unique_ptr S3FileSystem::PostRequest(FileHandle &handle, string url, HeaderMap header_map, + duckdb::unique_ptr &buffer_out, idx_t &buffer_out_len, + char *buffer_in, idx_t buffer_in_len, string http_params) { + auto auth_params = handle.Cast().auth_params; + auto parsed_s3_url = S3UrlParse(url, auth_params); + string http_url = parsed_s3_url.GetHTTPUrl(auth_params, http_params); + auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); + auto headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "POST", auth_params, "", + "", payload_hash, "application/octet-stream"); + + return HTTPFileSystem::PostRequest(handle, http_url, headers, buffer_out, buffer_out_len, buffer_in, buffer_in_len); +} + +unique_ptr S3FileSystem::PutRequest(FileHandle &handle, string url, HeaderMap header_map, + char *buffer_in, idx_t buffer_in_len, string http_params) { + auto auth_params = handle.Cast().auth_params; + auto parsed_s3_url = S3UrlParse(url, auth_params); + string http_url = parsed_s3_url.GetHTTPUrl(auth_params, http_params); + auto content_type = "application/octet-stream"; + auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); + + auto headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "PUT", auth_params, "", + "", payload_hash, content_type); + return HTTPFileSystem::PutRequest(handle, http_url, headers, buffer_in, buffer_in_len); +} + +unique_ptr S3FileSystem::HeadRequest(FileHandle &handle, string s3_url, HeaderMap header_map) { + auto auth_params = handle.Cast().auth_params; + auto parsed_s3_url = S3UrlParse(s3_url, auth_params); + string http_url = parsed_s3_url.GetHTTPUrl(auth_params); + auto headers = + create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "HEAD", auth_params, "", "", "", ""); + return HTTPFileSystem::HeadRequest(handle, http_url, headers); +} + +unique_ptr S3FileSystem::GetRequest(FileHandle &handle, string s3_url, HeaderMap header_map) { + auto auth_params = handle.Cast().auth_params; + auto parsed_s3_url = S3UrlParse(s3_url, auth_params); + string http_url = parsed_s3_url.GetHTTPUrl(auth_params); 
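+	// Plain GET (full-file download): signed with an empty canonical query string and the empty-payload hash.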
+ auto headers = + create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); + return HTTPFileSystem::GetRequest(handle, http_url, headers); +} + +unique_ptr S3FileSystem::GetRangeRequest(FileHandle &handle, string s3_url, HeaderMap header_map, + idx_t file_offset, char *buffer_out, idx_t buffer_out_len) { + auto auth_params = handle.Cast().auth_params; + auto parsed_s3_url = S3UrlParse(s3_url, auth_params); + string http_url = parsed_s3_url.GetHTTPUrl(auth_params); + auto headers = + create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); + return HTTPFileSystem::GetRangeRequest(handle, http_url, headers, file_offset, buffer_out, buffer_out_len); +} + +unique_ptr S3FileSystem::CreateHandle(const string &path, FileOpenFlags flags, + optional_ptr opener) { + FileOpenerInfo info = {path}; + + S3AuthParams auth_params; + auto registered_params = S3AuthParams::ReadFromStoredCredentials(opener, path); + if (registered_params) { + auth_params = *registered_params; + } else { + auth_params = S3AuthParams::ReadFrom(opener, info); + } + + // Scan the query string for any s3 authentication parameters + auto parsed_s3_url = S3UrlParse(path, auth_params); + ReadQueryParams(parsed_s3_url.query_param, auth_params); + + return duckdb::make_uniq(*this, path, flags, HTTPParams::ReadFrom(opener), auth_params, + S3ConfigParams::ReadFrom(opener)); +} + +// this computes the signature from https://czak.pl/2015/09/15/s3-rest-api-with-curl.html +void S3FileSystem::Verify() { + S3AuthParams auth_params; + auth_params.region = "us-east-1"; + auth_params.access_key_id = "AKIAIOSFODNN7EXAMPLE"; + auth_params.secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; + + auto test_header = create_s3_header("/", "", "my-precious-bucket.s3.amazonaws.com", "s3", "GET", auth_params, + "20150915", "20150915T124500Z"); + if (test_header["Authorization"] != + "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20150915/us-east-1/s3/aws4_request, " + "SignedHeaders=host;x-amz-content-sha256;x-amz-date, " + "Signature=182072eb53d85c36b2d791a1fa46a12d23454ec1e921b02075c23aee40166d5a") { + throw std::runtime_error("test fail"); + } + + if (UrlEncode("/category=Books/") != "/category%3DBooks/") { + throw std::runtime_error("test fail"); + } + if (UrlEncode("/?category=Books&title=Ducks Retreat/") != "/%3Fcategory%3DBooks%26title%3DDucks%20Retreat/") { + throw std::runtime_error("test fail"); + } + if (UrlEncode("/?category=Books&title=Ducks Retreat/", true) != + "%2F%3Fcategory%3DBooks%26title%3DDucks%20Retreat%2F") { + throw std::runtime_error("test fail"); + } + // AWS_SECRET_ACCESS_KEY="vs1BZPxSL2qVARBSg5vCMKJsavCoEPlo/HSHRaVe" AWS_ACCESS_KEY_ID="ASIAYSPIOYDTHTBIITVC" + // AWS_SESSION_TOKEN="IQoJb3JpZ2luX2VjENX//////////wEaCWV1LXdlc3QtMSJHMEUCIQDfjzs9BYHrEXDMU/NR+PHV1uSTr7CSVSQdjKSfiPRLdgIgCCztF0VMbi9+uHHAfBVKhV4t9MlUrQg3VAOIsLxrWyoqlAIIHRAAGgw1ODk0MzQ4OTY2MTQiDOGl2DsYxENcKCbh+irxARe91faI+hwUhT60sMGRFg0GWefKnPclH4uRFzczrDOcJlAAaQRJ7KOsT8BrJlrY1jSgjkO7PkVjPp92vi6lJX77bg99MkUTJActiOKmd84XvAE5bFc/jFbqechtBjXzopAPkKsGuaqAhCenXnFt6cwq+LZikv/NJGVw7TRphLV+Aq9PSL9XwdzIgsW2qXwe1c3rxDNj53yStRZHVggdxJ0OgHx5v040c98gFphzSULHyg0OY6wmCMTYcswpb4kO2IIi6AiD9cY25TlwPKRKPi5CdBsTPnyTeW62u7PvwK0fTSy4ZuJUuGKQnH2cKmCXquEwoOHEiQY6nQH9fzY/EDGHMRxWWhxu0HiqIfsuFqC7GS0p0ToKQE+pzNsvVwMjZc+KILIDDQpdCWRIwu53I5PZy2Cvk+3y4XLvdZKQCsAKqeOc4c94UAS4NmUT7mCDOuRV0cLBVM8F0JYBGrUxyI+YoIvHhQWmnRLuKgTb5PkF7ZWrXBHFWG5/tZDOvBbbaCWTlRCL9b0Vpg5+BM/81xd8jChP4w83" + // aws 
--region eu-west-1 --debug s3 ls my-precious-bucket 2>&1 | less + string canonical_query_string = "delimiter=%2F&encoding-type=url&list-type=2&prefix="; // aws s3 ls + + S3AuthParams auth_params2; + auth_params2.region = "eu-west-1"; + auth_params2.access_key_id = "ASIAYSPIOYDTHTBIITVC"; + auth_params2.secret_access_key = "vs1BZPxSL2qVARBSg5vCMKJsavCoEPlo/HSHRaVe"; + auth_params2.session_token = + "IQoJb3JpZ2luX2VjENX//////////wEaCWV1LXdlc3QtMSJHMEUCIQDfjzs9BYHrEXDMU/" + "NR+PHV1uSTr7CSVSQdjKSfiPRLdgIgCCztF0VMbi9+" + "uHHAfBVKhV4t9MlUrQg3VAOIsLxrWyoqlAIIHRAAGgw1ODk0MzQ4OTY2MTQiDOGl2DsYxENcKCbh+irxARe91faI+" + "hwUhT60sMGRFg0GWefKnPclH4uRFzczrDOcJlAAaQRJ7KOsT8BrJlrY1jSgjkO7PkVjPp92vi6lJX77bg99MkUTJA" + "ctiOKmd84XvAE5bFc/jFbqechtBjXzopAPkKsGuaqAhCenXnFt6cwq+LZikv/" + "NJGVw7TRphLV+" + "Aq9PSL9XwdzIgsW2qXwe1c3rxDNj53yStRZHVggdxJ0OgHx5v040c98gFphzSULHyg0OY6wmCMTYcswpb4kO2IIi6" + "AiD9cY25TlwPKRKPi5CdBsTPnyTeW62u7PvwK0fTSy4ZuJUuGKQnH2cKmCXquEwoOHEiQY6nQH9fzY/" + "EDGHMRxWWhxu0HiqIfsuFqC7GS0p0ToKQE+pzNsvVwMjZc+KILIDDQpdCWRIwu53I5PZy2Cvk+" + "3y4XLvdZKQCsAKqeOc4c94UAS4NmUT7mCDOuRV0cLBVM8F0JYBGrUxyI+" + "YoIvHhQWmnRLuKgTb5PkF7ZWrXBHFWG5/tZDOvBbbaCWTlRCL9b0Vpg5+BM/81xd8jChP4w83"; + + auto test_header2 = create_s3_header("/", canonical_query_string, "my-precious-bucket.s3.eu-west-1.amazonaws.com", + "s3", "GET", auth_params2, "20210904", "20210904T121746Z"); + if (test_header2["Authorization"] != + "AWS4-HMAC-SHA256 Credential=ASIAYSPIOYDTHTBIITVC/20210904/eu-west-1/s3/aws4_request, " + "SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, " + "Signature=4d9d6b59d7836b6485f6ad822de97be40287da30347d83042ea7fbed530dc4c0") { + throw std::runtime_error("test fail"); + } + + S3AuthParams auth_params3; + auth_params3.region = "eu-west-1"; + auth_params3.access_key_id = "S3RVER"; + auth_params3.secret_access_key = "S3RVER"; + + auto test_header3 = + create_s3_header("/correct_auth_test.csv", "", "test-bucket-ceiveran.s3.amazonaws.com", "s3", "PUT", + auth_params3, "20220121", "20220121T141452Z", + "28a0cf6ac5c4cb73793091fe6ecc6a68bf90855ac9186158748158f50241bb0c", "text/data;charset=utf-8"); + if (test_header3["Authorization"] != "AWS4-HMAC-SHA256 Credential=S3RVER/20220121/eu-west-1/s3/aws4_request, " + "SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date, " + "Signature=5d9a6cbfaa78a6d0f2ab7df0445e2f1cc9c80cd3655ac7de9e7219c036f23f02") { + throw std::runtime_error("test3 fail"); + } + + // bug #4082 + S3AuthParams auth_params4; + auth_params4.region = "auto"; + auth_params4.access_key_id = "asdf"; + auth_params4.secret_access_key = "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf"; + create_s3_header("/", "", "exampple.com", "s3", "GET", auth_params4); + + if (UrlEncode("/category=Books/") != "/category%3DBooks/") { + throw std::runtime_error("test fail"); + } + if (UrlEncode("/?category=Books&title=Ducks Retreat/") != "/%3Fcategory%3DBooks%26title%3DDucks%20Retreat/") { + throw std::runtime_error("test fail"); + } + if (UrlEncode("/?category=Books&title=Ducks Retreat/", true) != + "%2F%3Fcategory%3DBooks%26title%3DDucks%20Retreat%2F") { + throw std::runtime_error("test fail"); + } + + // TODO add a test that checks the signing for path-style +} + +void S3FileHandle::Initialize(optional_ptr opener) { + HTTPFileHandle::Initialize(opener); + + auto &s3fs = file_system.Cast(); + + if (flags.OpenForWriting()) { + auto aws_minimum_part_size = 5242880; // 5 MiB https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html + auto max_part_count = 
config_params.max_parts_per_file; + auto required_part_size = config_params.max_file_size / max_part_count; + auto minimum_part_size = MaxValue(aws_minimum_part_size, required_part_size); + + // Round part size up to multiple of BLOCK_SIZE + part_size = ((minimum_part_size + Storage::BLOCK_SIZE - 1) / Storage::BLOCK_SIZE) * Storage::BLOCK_SIZE; + D_ASSERT(part_size * max_part_count >= config_params.max_file_size); + + multipart_upload_id = s3fs.InitializeMultipartUpload(*this); + } +} + +bool S3FileSystem::CanHandleFile(const string &fpath) { + + return fpath.rfind("s3://", 0) * fpath.rfind("s3a://", 0) * fpath.rfind("s3n://", 0) * fpath.rfind("gcs://", 0) * + fpath.rfind("gs://", 0) * fpath.rfind("r2://", 0) == + 0; +} + +void S3FileSystem::FileSync(FileHandle &handle) { + auto &s3fh = handle.Cast(); + if (!s3fh.upload_finalized) { + FlushAllBuffers(s3fh); + FinalizeMultipartUpload(s3fh); + } +} + +void S3FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + auto &s3fh = handle.Cast(); + if (!s3fh.flags.OpenForWriting()) { + throw InternalException("Write called on file not opened in write mode"); + } + int64_t bytes_written = 0; + + while (bytes_written < nr_bytes) { + auto curr_location = location + bytes_written; + + if (curr_location != s3fh.file_offset) { + throw InternalException("Non-sequential write not supported!"); + } + + // Find buffer for writing + auto write_buffer_idx = curr_location / s3fh.part_size; + + // Get write buffer, may block until buffer is available + auto write_buffer = s3fh.GetBuffer(write_buffer_idx); + + // Writing to buffer + auto idx_to_write = curr_location - write_buffer->buffer_start; + auto bytes_to_write = MinValue(nr_bytes - bytes_written, s3fh.part_size - idx_to_write); + memcpy((char *)write_buffer->Ptr() + idx_to_write, (char *)buffer + bytes_written, bytes_to_write); + write_buffer->idx += bytes_to_write; + + // Flush to HTTP if full + if (write_buffer->idx >= s3fh.part_size) { + FlushBuffer(s3fh, write_buffer); + } + s3fh.file_offset += bytes_to_write; + bytes_written += bytes_to_write; + } +} + +static bool Match(vector::const_iterator key, vector::const_iterator key_end, + vector::const_iterator pattern, vector::const_iterator pattern_end) { + + while (key != key_end && pattern != pattern_end) { + if (*pattern == "**") { + if (std::next(pattern) == pattern_end) { + return true; + } + while (key != key_end) { + if (Match(key, key_end, std::next(pattern), pattern_end)) { + return true; + } + key++; + } + return false; + } + if (!LikeFun::Glob(key->data(), key->length(), pattern->data(), pattern->length())) { + return false; + } + key++; + pattern++; + } + return key == key_end && pattern == pattern_end; +} + +vector S3FileSystem::Glob(const string &glob_pattern, FileOpener *opener) { + if (opener == nullptr) { + throw InternalException("Cannot S3 Glob without FileOpener"); + } + + FileOpenerInfo info = {glob_pattern}; + + // Trim any query parameters from the string + S3AuthParams s3_auth_params; + auto registered_params = S3AuthParams::ReadFromStoredCredentials(opener, glob_pattern); + if (registered_params) { + s3_auth_params = *registered_params; + } else { + s3_auth_params = S3AuthParams::ReadFrom(opener, info); + } + + // In url compatibility mode, we ignore globs allowing users to query files with the glob chars + if (s3_auth_params.s3_url_compatibility_mode) { + return {glob_pattern}; + } + + auto parsed_s3_url = S3UrlParse(glob_pattern, s3_auth_params); + auto parsed_glob_url = 
parsed_s3_url.trimmed_s3_url; + + // AWS matches on prefix, not glob pattern, so we take a substring until the first wildcard char for the aws calls + auto first_wildcard_pos = parsed_glob_url.find_first_of("*[\\"); + if (first_wildcard_pos == string::npos) { + return {glob_pattern}; + } + + string shared_path = parsed_glob_url.substr(0, first_wildcard_pos); + auto http_params = HTTPParams::ReadFrom(opener); + + ReadQueryParams(parsed_s3_url.query_param, s3_auth_params); + + // Do main listobjectsv2 request + vector s3_keys; + string main_continuation_token; + + // Main paging loop + do { + // main listobject call, may + string response_str = AWSListObjectV2::Request(shared_path, http_params, s3_auth_params, + main_continuation_token, HTTPState::TryGetState(opener).get()); + main_continuation_token = AWSListObjectV2::ParseContinuationToken(response_str); + AWSListObjectV2::ParseKey(response_str, s3_keys); + + // Repeat requests until the keys of all common prefixes are parsed. + auto common_prefixes = AWSListObjectV2::ParseCommonPrefix(response_str); + while (!common_prefixes.empty()) { + auto prefix_path = parsed_s3_url.prefix + parsed_s3_url.bucket + '/' + common_prefixes.back(); + common_prefixes.pop_back(); + + // TODO we could optimize here by doing a match on the prefix, if it doesn't match we can skip this prefix + // Paging loop for common prefix requests + string common_prefix_continuation_token; + do { + auto prefix_res = + AWSListObjectV2::Request(prefix_path, http_params, s3_auth_params, common_prefix_continuation_token, + HTTPState::TryGetState(opener).get()); + AWSListObjectV2::ParseKey(prefix_res, s3_keys); + auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res); + common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end()); + common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res); + } while (!common_prefix_continuation_token.empty()); + } + } while (!main_continuation_token.empty()); + + vector pattern_splits = StringUtil::Split(parsed_s3_url.key, "/"); + vector result; + for (const auto &s3_key : s3_keys) { + + vector key_splits = StringUtil::Split(s3_key, "/"); + bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end()); + + if (is_match) { + auto result_full_url = parsed_s3_url.prefix + parsed_s3_url.bucket + "/" + s3_key; + // if a ? char was present, we re-add it here as the url parsing will have trimmed it. + if (!parsed_s3_url.query_param.empty()) { + result_full_url += '?' 
+ parsed_s3_url.query_param; + } + result.push_back(result_full_url); + } + } + return result; +} + +string S3FileSystem::GetName() const { + return "S3FileSystem"; +} + +bool S3FileSystem::ListFiles(const string &directory, const std::function &callback, + FileOpener *opener) { + string trimmed_dir = directory; + StringUtil::RTrim(trimmed_dir, PathSeparator(trimmed_dir)); + auto glob_res = Glob(JoinPath(trimmed_dir, "**"), opener); + + if (glob_res.empty()) { + return false; + } + + for (const auto &file : glob_res) { + callback(file, false); + } + + return true; +} + +string AWSListObjectV2::Request(string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params, + string &continuation_token, optional_ptr state, bool use_delimiter) { + auto parsed_url = S3FileSystem::S3UrlParse(path, s3_auth_params); + + // Construct the ListObjectsV2 call + string req_path = parsed_url.path.substr(0, parsed_url.path.length() - parsed_url.key.length()); + + string req_params; + if (!continuation_token.empty()) { + req_params += "continuation-token=" + S3FileSystem::UrlEncode(continuation_token, true); + req_params += "&"; + } + req_params += "encoding-type=url&list-type=2"; + req_params += "&prefix=" + S3FileSystem::UrlEncode(parsed_url.key, true); + + if (use_delimiter) { + req_params += "&delimiter=%2F"; + } + + string listobjectv2_url = req_path + "?" + req_params; + + auto header_map = + create_s3_header(req_path, req_params, parsed_url.host, "s3", "GET", s3_auth_params, "", "", "", ""); + auto headers = initialize_http_headers(header_map); + + auto client = S3FileSystem::GetClient(http_params, (parsed_url.http_proto + parsed_url.host).c_str(), + nullptr); // Get requests use fresh connection + std::stringstream response; + auto res = client->Get( + listobjectv2_url.c_str(), *headers, + [&](const duckdb_httplib_openssl::Response &response) { + if (response.status >= 400) { + throw HTTPException(response, "HTTP GET error on '%s' (HTTP %d)", listobjectv2_url, response.status); + } + return true; + }, + [&](const char *data, size_t data_length) { + if (state) { + state->total_bytes_received += data_length; + } + response << string(data, data_length); + return true; + }); + if (state) { + state->get_count++; + } + if (res.error() != duckdb_httplib_openssl::Error::Success) { + throw IOException(to_string(res.error()) + " error for HTTP GET to '" + listobjectv2_url + "'"); + } + + return response.str(); +} + +void AWSListObjectV2::ParseKey(string &aws_response, vector &result) { + idx_t cur_pos = 0; + while (true) { + auto next_open_tag_pos = aws_response.find("", cur_pos); + if (next_open_tag_pos == string::npos) { + break; + } else { + auto next_close_tag_pos = aws_response.find("", next_open_tag_pos + 5); + if (next_close_tag_pos == string::npos) { + throw InternalException("Failed to parse S3 result"); + } + auto parsed_path = S3FileSystem::UrlDecode( + aws_response.substr(next_open_tag_pos + 5, next_close_tag_pos - next_open_tag_pos - 5)); + if (parsed_path.back() != '/') { + result.push_back(parsed_path); + } + cur_pos = next_close_tag_pos + 6; + } + } +} + +string AWSListObjectV2::ParseContinuationToken(string &aws_response) { + + auto open_tag_pos = aws_response.find(""); + if (open_tag_pos == string::npos) { + return ""; + } else { + auto close_tag_pos = aws_response.find("", open_tag_pos + 23); + if (close_tag_pos == string::npos) { + throw InternalException("Failed to parse S3 result"); + } + return aws_response.substr(open_tag_pos + 23, close_tag_pos - open_tag_pos - 23); + } +} + +vector 
AWSListObjectV2::ParseCommonPrefix(string &aws_response) { + vector s3_prefixes; + idx_t cur_pos = 0; + while (true) { + cur_pos = aws_response.find("", cur_pos); + if (cur_pos == string::npos) { + break; + } + auto next_open_tag_pos = aws_response.find("", cur_pos); + if (next_open_tag_pos == string::npos) { + throw InternalException("Parsing error while parsing s3 listobject result"); + } else { + auto next_close_tag_pos = aws_response.find("", next_open_tag_pos + 8); + if (next_close_tag_pos == string::npos) { + throw InternalException("Failed to parse S3 result"); + } + auto parsed_path = aws_response.substr(next_open_tag_pos + 8, next_close_tag_pos - next_open_tag_pos - 8); + s3_prefixes.push_back(parsed_path); + cur_pos = next_close_tag_pos + 6; + } + } + return s3_prefixes; +} + +} // namespace duckdb diff --git a/vendor.py b/vendor.py index 820d457b..156a616e 100644 --- a/vendor.py +++ b/vendor.py @@ -9,13 +9,11 @@ parser.add_argument('--duckdb', action='store', help='Path to the DuckDB Version to be vendored in', required=True, type=str) - - args = parser.parse_args() - # list of extensions to bundle -extensions = ['parquet', 'icu', 'json', 'httpfs'] +extensions = ['parquet', 'icu', 'json'] +optional_extensions_list = ['httpfs'] # path to target basedir = os.getcwd() @@ -31,6 +29,37 @@ sys.path.append(scripts_dir) import package_build + +def sanitize_path(x): + return x.replace('\\', '/') + + +def get_optional_extensions(original_source_list, original_includes): + results = [] + for ext in optional_extensions_list: + (optional_sources, optional_includes, _) = package_build.build_package(target_dir, [ext], False) + optional_sources = [os.path.relpath(x, basedir) if os.path.isabs(x) else os.path.join('src', x) for x in + optional_sources] + optional_includes = [os.path.join('src', 'duckdb', x) for x in optional_includes] + condition = [ + f"include_{ext}=='true'", + { + 'sources': [sanitize_path(x) for x in optional_sources if x not in original_source_list], + 'include_dirs': [sanitize_path(x) for x in optional_includes if x not in original_includes], + 'defines': ['DUCKDB_EXTENSION_{}_LINKED'.format(ext.upper())] + } + ] + results.append(condition) + return results + + +def get_optional_extensions_variables(): + result = {} + for ext in optional_extensions_list: + result[f'include_{ext}'] = "