Mercurial > hg > orthanc-dicomweb
view UnitTestsSources/UnitTestsMain.cpp @ 283:69c2330aa561
reorganizing MultipartStreamReader
author | Sebastien Jodogne <s.jodogne@gmail.com> |
---|---|
date | Fri, 07 Jun 2019 12:40:06 +0200 |
parents | 266556663ae6 |
children | c4352c49baef |
line wrap: on
line source
/**
 * Orthanc - A Lightweight, RESTful DICOM Store
 * Copyright (C) 2012-2016 Sebastien Jodogne, Medical Physics
 * Department, University Hospital of Liege, Belgium
 * Copyright (C) 2017-2019 Osimis S.A., Belgium
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU Affero General Public License
 * as published by the Free Software Foundation, either version 3 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 **/


#include <gtest/gtest.h>
#include <boost/lexical_cast.hpp>
#include <boost/algorithm/string/predicate.hpp>

#include <cassert>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <stdint.h>
#include <string>
#include <vector>

#include "../Plugin/Configuration.h"

using namespace OrthancPlugins;

OrthancPluginContext* context_ = NULL;


// Checks that "ParseContentType()" lower-cases the main content type
// and the parameter names, keeps the case of the parameter values, and
// removes the optional quotes around a parameter value.
TEST(ContentType, Parse)
{
  std::string c;
  std::map<std::string, std::string> a;

  ParseContentType(c, a, "Multipart/Related; TYPE=Application/Dicom; Boundary=heLLO");
  ASSERT_EQ(c, "multipart/related");
  ASSERT_EQ(2u, a.size());
  ASSERT_EQ(a["type"], "Application/Dicom");
  ASSERT_EQ(a["boundary"], "heLLO");

  // The WADO-RS client must support the case where the WADO-RS server
  // escapes the "type" subfield in the Content-Type header
  // cf. https://tools.ietf.org/html/rfc7231#section-3.1.1.1
  ParseContentType(c, a, "Multipart/Related; TYPE=\"Application/Dicom\" ; Boundary=heLLO");
  ASSERT_EQ(c, "multipart/related");
  ASSERT_EQ(2u, a.size());
  ASSERT_EQ(a["type"], "Application/Dicom");
  ASSERT_EQ(a["boundary"], "heLLO");

  ParseContentType(c, a, "");
  ASSERT_TRUE(c.empty());
  ASSERT_EQ(0u, a.size());

  ParseContentType(c, a, "multipart/related");
  ASSERT_EQ(c, "multipart/related");
  ASSERT_EQ(0u, a.size());
}


#include <Core/ChunkedBuffer.h>
#include <Core/OrthancException.h>
#include <Core/SystemToolbox.h>
#include <Core/Toolbox.h>

#include <boost/algorithm/searching/boyer_moore.hpp>
#include <boost/algorithm/searching/boyer_moore_horspool.hpp>
#include <boost/algorithm/searching/knuth_morris_pratt.hpp>
#include <boost/date_time/posix_time/posix_time.hpp>
#include <boost/version.hpp>


namespace Orthanc
{
  // Convenience class that wraps a Boost algorithm for string matching
  //
  // The matcher remembers the location of the last successful match as
  // a pair of iterators into the corpus given to "Apply()". These
  // iterators are only meaningful as long as that corpus is alive and
  // unmodified: it is up to the caller to guarantee this.
  class StringMatcher : public boost::noncopyable
  {
  public:
    typedef std::string::const_iterator Iterator;

  private:
    typedef boost::algorithm::boyer_moore<Iterator>  Search;
    //typedef boost::algorithm::boyer_moore_horspool<std::string::const_iterator>  Search;

    // WARNING - The lifetime of "pattern_" must be larger than
    // "search_", as the latter references "pattern_"
    std::string  pattern_;
    Search       search_;
    bool         valid_;       // Whether the last call to "Apply()" found the pattern
    Iterator     matchBegin_;  // First character of the last match
    Iterator     matchEnd_;    // One past the last character of the last match

  public:
    // Creates a matcher looking for "pattern". No match is available
    // until "Apply()" has succeeded.
    StringMatcher(const std::string& pattern) :
      pattern_(pattern),
      search_(pattern_.begin(), pattern_.end()),
      valid_(false)
    {
    }

    // Returns the pattern that is searched for.
    const std::string& GetPattern() const
    {
      return pattern_;
    }

    // Returns "true" iff the last call to "Apply()" found the pattern.
    bool IsValid() const
    {
      return valid_;
    }

    // Searches for the first occurrence of the pattern in the range
    // [start, end). Returns "true" iff the pattern was found, in which
    // case the location of the match can be read with the getters below.
    bool Apply(Iterator start,
               Iterator end)
    {
      // Since Boost 1.62, the searcher returns the pair of iterators
      // delimiting the match, instead of the iterator to its beginning
#if BOOST_VERSION >= 106200
      matchBegin_ = search_(start, end).first;
#else
      matchBegin_ = search_(start, end);
#endif

      if (matchBegin_ == end)
      {
        valid_ = false;
      }
      else
      {
        matchEnd_ = matchBegin_ + pattern_.size();
        assert(matchEnd_ <= end);
        valid_ = true;
      }

      return valid_;
    }

    // Same as above, but searches the whole "corpus". The string must
    // outlive any subsequent use of the match iterators/pointers.
    bool Apply(const std::string& corpus)
    {
      return Apply(corpus.begin(), corpus.end());
    }

    // Returns the iterator to the first character of the last match.
    // Throws "ErrorCode_BadSequenceOfCalls" if there is no valid match.
    Iterator GetMatchBegin() const
    {
      if (valid_)
      {
        return matchBegin_;
      }
      else
      {
        throw OrthancException(ErrorCode_BadSequenceOfCalls);
      }
    }

    // Returns the iterator one past the last character of the last
    // match. Throws "ErrorCode_BadSequenceOfCalls" if there is no valid
    // match.
    Iterator GetMatchEnd() const
    {
      if (valid_)
      {
        return matchEnd_;
      }
      else
      {
        throw OrthancException(ErrorCode_BadSequenceOfCalls);
      }
    }

    // Returns a raw pointer to the first character of the last match.
    const char* GetPointerBegin() const
    {
      return &GetMatchBegin()[0];
    }

    // Returns a raw pointer one past the last character of the last
    // match. NOTE(review): this dereferences the end-of-match iterator,
    // so callers should make sure that the match does not finish
    // exactly at the end of the corpus before calling this method.
    const char* GetPointerEnd() const
    {
      return &GetMatchEnd()[0];
    }
  };


  // Incremental parser for a multipart stream whose parts are
  // delimited by "--" followed by the boundary. The stream is provided
  // chunk by chunk through "AddChunk()", and each complete part is
  // forwarded, together with its headers, to the registered handler.
  class MultipartStreamReader : public boost::noncopyable
  {
  public:
    // Header names (in lower case, as produced by "ParseHeaders()")
    // mapped to their values
    typedef std::map<std::string, std::string>  HttpHeaders;

    // Callback interface that receives each part of the stream.
    class IHandler : public boost::noncopyable
    {
    public:
      virtual ~IHandler()
      {
      }

      // Invoked once for each complete part. "part" points to "size"
      // bytes that are owned by the reader and that are only valid
      // during the call.
      virtual void Apply(const HttpHeaders& headers,
                         const void* part,
                         size_t size) = 0;
    };

  private:
    enum State
    {
      State_UnusedArea,  // The first boundary has not been seen yet
      State_Content,     // Positioned on a boundary, parts are being read
      State_Done         // The closing boundary has been seen
    };

    State                         state_;
    IHandler*                     handler_;         // Not owned, can be NULL
    StringMatcher                 headersMatcher_;  // Looks for the end of the headers of a part
    std::auto_ptr<StringMatcher>  boundaryMatcher_; // Looks for "--" + boundary
    //std::auto_ptr<Search>         boundaryMatcher_;
    ChunkedBuffer                 buffer_;          // Data received but not parsed yet
    size_t                        blockSize_;       // See "AddChunk()"


    // Fills "headers" with the "key: value" lines found in the range
    // [start, end). The keys are converted to lower case, the lines
    // without a colon are ignored.
    static void ParseHeaders(HttpHeaders& headers,
                             StringMatcher::Iterator start,
                             StringMatcher::Iterator end)
    {
      std::string tmp(start, end);

      std::vector<std::string> lines;
      Toolbox::TokenizeString(lines, tmp, '\n');

      headers.clear();

      for (size_t i = 0; i < lines.size(); i++)
      {
        size_t separator = lines[i].find(':');
        if (separator != std::string::npos)
        {
          std::string key = Toolbox::StripSpaces(lines[i].substr(0, separator));
          std::string value = Toolbox::StripSpaces(lines[i].substr(separator + 1));

          Toolbox::ToLowerCase(key);
          headers[key] = value;
        }
      }
    }


    // Reads the header "key" as a non-negative integer. Returns "false"
    // if the header is absent, "true" (with "target" set) if it is
    // present. Throws "ErrorCode_ParameterOutOfRange" if the value is
    // not an integer, is negative, or does not fit in a "size_t".
    static bool LookupHeaderSizeValue(size_t& target,
                                      const HttpHeaders& headers,
                                      const std::string& key)
    {
      HttpHeaders::const_iterator it = headers.find(key);
      if (it == headers.end())
      {
        return false;
      }
      else
      {
        int64_t value;

        try
        {
          value = boost::lexical_cast<int64_t>(it->second);
        }
        catch (const boost::bad_lexical_cast&)
        {
          throw OrthancException(ErrorCode_ParameterOutOfRange);
        }

        // The second test prevents a silent truncation on the
        // platforms where "size_t" is narrower than 64 bits
        if (value < 0 ||
            static_cast<uint64_t>(value) > std::numeric_limits<size_t>::max())
        {
          throw OrthancException(ErrorCode_ParameterOutOfRange);
        }
        else
        {
          target = static_cast<size_t>(value);
          return true;
        }
      }
    }


    // Splits the value of the header "key" on ';'. The first token is
    // stored in "main", the subsequent "name=value" tokens are stored
    // in "parameters" (with the name in lower case). Both outputs are
    // always cleared first. Returns "false" if the header is absent or
    // if it yields no token.
    static bool ParseHeaderValues(std::string& main,
                                  HttpHeaders& parameters,
                                  const HttpHeaders& headers,
                                  const std::string& key)
    {
      HttpHeaders::const_iterator it = headers.find(key);

      main.clear();
      parameters.clear();

      if (it == headers.end())
      {
        return false;
      }
      else
      {
        std::vector<std::string> tokens;
        Toolbox::TokenizeString(tokens, it->second, ';');

        if (tokens.empty())
        {
          return false;
        }

        main = tokens[0];

        for (size_t i = 1; i < tokens.size(); i++)
        {
          size_t separator = tokens[i].find('=');
          if (separator != std::string::npos)
          {
            // Named "name" in order not to shadow the "key" argument
            std::string name = Toolbox::StripSpaces(tokens[i].substr(0, separator));
            std::string value = Toolbox::StripSpaces(tokens[i].substr(separator + 1));

            if (!name.empty())
            {
              Toolbox::ToLowerCase(name);
              parameters[name] = value;
            }
          }
        }

        return true;
      }
    }


    // Parses as many complete parts as possible out of the buffered
    // data, and invokes the handler on each of them. The bytes that
    // cannot be interpreted yet are pushed back into the buffer. Does
    // nothing if no handler is set or if the stream is finished.
    // Throws "ErrorCode_NetworkProtocol" on a malformed stream.
    void ParseStream()
    {
      if (handler_ == NULL ||
          state_ == State_Done)
      {
        return;
      }

      assert(boundaryMatcher_.get() != NULL);

      // The unparsed tail of "corpus" is explicitly put back into
      // "buffer_" below (presumably "Flatten()" leaves the buffer empty)
      std::string corpus;
      buffer_.Flatten(corpus);

      StringMatcher::Iterator current = corpus.begin();
      StringMatcher::Iterator corpusEnd = corpus.end();

      if (state_ == State_UnusedArea)
      {
        /**
         * "Before the first boundary is an area that is ignored by
         * MIME-compliant clients. This area is generally used to put
         * a message to users of old non-MIME clients."
         * https://en.wikipedia.org/wiki/MIME#Multipart_messages
         **/

        if (boundaryMatcher_->Apply(current, corpusEnd))
        {
          current = boundaryMatcher_->GetMatchBegin();
          state_ = State_Content;
        }
        else
        {
          // We have not seen the end of the unused area yet
          std::string reminder(current, corpusEnd);
          buffer_.AddChunkDestructive(reminder);
          return;
        }
      }

      // Invariant: at the beginning of each iteration, "current" is
      // expected to point to a boundary
      for (;;)
      {
        size_t patternSize = boundaryMatcher_->GetPattern().size();
        size_t remainingSize = std::distance(current, corpusEnd);
        if (remainingSize < patternSize + 2)
        {
          break;  // Not enough data available
        }

        // The boundary is either followed by "--" (end of the stream),
        // or by an endline (start of a new part)
        std::string boundary(current, current + patternSize + 2);
        if (boundary == boundaryMatcher_->GetPattern() + "--")
        {
          state_ = State_Done;
          return;
        }

        if (boundary != boundaryMatcher_->GetPattern() + "\r\n")
        {
          throw OrthancException(ErrorCode_NetworkProtocol,
                                 "Garbage between two items in a multipart stream");
        }

        StringMatcher::Iterator start = current + patternSize + 2;

        if (!headersMatcher_.Apply(start, corpusEnd))
        {
          break;  // Not enough data available
        }

        HttpHeaders headers;
        ParseHeaders(headers, start, headersMatcher_.GetMatchBegin());

        size_t contentLength;
        if (!LookupHeaderSizeValue(contentLength, headers, "content-length"))
        {
          // No "Content-Length" header: The body extends up to the
          // endline that precedes the next boundary
          if (boundaryMatcher_->Apply(headersMatcher_.GetMatchEnd(), corpusEnd))
          {
            size_t d = std::distance(headersMatcher_.GetMatchEnd(), boundaryMatcher_->GetMatchBegin());
            if (d <= 1)
            {
              throw OrthancException(ErrorCode_NetworkProtocol);
            }
            else
            {
              contentLength = d - 2;
            }
          }
          else
          {
            break;  // Not enough data available to have a full part
          }
        }

        // "contentLength" might come from an untrusted header: Compare
        // sizes instead of advancing an iterator, as moving an iterator
        // past the end of the corpus is undefined behavior, and as
        // "contentLength + 2" could overflow
        const size_t available = static_cast<size_t>(
          std::distance(headersMatcher_.GetMatchEnd(), corpusEnd));

        if (available < 2 ||
            contentLength > available - 2)
        {
          break;  // Not enough data available to have a full part
        }

        const char* p = headersMatcher_.GetPointerEnd() + contentLength;
        if (p[0] != '\r' ||
            p[1] != '\n')
        {
          throw OrthancException(ErrorCode_NetworkProtocol,
                                 "No endline at the end of a part");
        }

        handler_->Apply(headers, headersMatcher_.GetPointerEnd(), contentLength);

        // Skip the body and its trailing endline
        current = headersMatcher_.GetMatchEnd() + contentLength + 2;
      }

      // Keep the incomplete data for the next invocation
      if (current != corpusEnd)
      {
        std::string reminder(current, corpusEnd);
        buffer_.AddChunkDestructive(reminder);
      }
    }


  public:
    // Creates a reader for the given boundary (without the leading
    // "--"). The default block size is 10MB.
    MultipartStreamReader(const std::string& boundary) :
      state_(State_UnusedArea),
      handler_(NULL),
      headersMatcher_("\r\n\r\n"),
      boundaryMatcher_(new StringMatcher("--" + boundary)),
      blockSize_(10 * 1024 * 1024)
    {
    }

    // Sets the block size (cf. "AddChunk()"). Throws
    // "ErrorCode_ParameterOutOfRange" if "size" is zero.
    void SetBlockSize(size_t size)
    {
      if (size == 0)
      {
        throw OrthancException(ErrorCode_ParameterOutOfRange);
      }
      else
      {
        blockSize_ = size;
      }
    }

    // Returns the current block size.
    size_t GetBlockSize() const
    {
      return blockSize_;
    }

    // Registers the handler that will receive the parts. The handler
    // is not owned by the reader and must outlive it.
    void SetHandler(IHandler& handler)
    {
      handler_ = &handler;
    }

    // Appends "size" bytes to the stream. The buffered data is only
    // parsed when the amount of buffered bytes crosses a multiple of
    // the block size, which avoids reparsing on each small chunk.
    void AddChunk(const void* chunk,
                  size_t size)
    {
      if (state_ == State_Done)
      {
        // The closing boundary has been seen: "ParseStream()" will
        // never look at the buffer again, so drop the trailing data
        // instead of accumulating it indefinitely
        return;
      }

      if (size != 0)
      {
        size_t oldSize = buffer_.GetNumBytes();

        buffer_.AddChunk(chunk, size);

        if (oldSize / blockSize_ != buffer_.GetNumBytes() / blockSize_)
        {
          ParseStream();
        }
      }
    }

    // Same as above, for a chunk stored in a string.
    void AddChunk(const std::string& chunk)
    {
      if (!chunk.empty())
      {
        AddChunk(chunk.c_str(), chunk.size());
      }
    }

    // Must be called once all the chunks have been added, in order to
    // parse the data that is still pending in the buffer.
    void CloseStream()
    {
      if (buffer_.GetNumBytes() != 0)
      {
        ParseStream();
      }
    }


    // Copies the raw value of the "content-type" header into
    // "contentType". Returns "false" (leaving "contentType" untouched)
    // if this header is absent.
    static bool GetMainContentType(std::string& contentType,
                                   const HttpHeaders& headers)
    {
      HttpHeaders::const_iterator it = headers.find("content-type");

      if (it == headers.end())
      {
        return false;
      }
      else
      {
        contentType = it->second;
        return true;
      }
    }


    // Extracts the main content type and the "boundary" parameter from
    // the "content-type" header. Returns "true" iff the header is
    // present, its main type is non-empty, and a non-empty boundary was
    // found. NB: "contentType" can be modified even on failure.
    static bool ParseMultipartHeaders(std::string& contentType,
                                      std::string& boundary,
                                      const HttpHeaders& headers)
    {
      std::string tmp;
      if (!GetMainContentType(tmp, headers))
      {
        return false;
      }

      std::vector<std::string> tokens;
      Orthanc::Toolbox::TokenizeString(tokens, tmp, ';');

      if (tokens.empty())
      {
        return false;
      }

      contentType = Orthanc::Toolbox::StripSpaces(tokens[0]);
      if (contentType.empty())
      {
        return false;
      }

      for (size_t i = 0; i < tokens.size(); i++)
      {
        std::vector<std::string> items;
        Orthanc::Toolbox::TokenizeString(items, tokens[i], '=');

        if (items.size() == 2)
        {
          // The name of the parameter is case-insensitive
          if (boost::iequals("boundary", Orthanc::Toolbox::StripSpaces(items[0])))
          {
            boundary = Orthanc::Toolbox::StripSpaces(items[1]);
            return !boundary.empty();
          }
        }
      }

      return false;
    }
  };
}



// Handler that records a copy of the headers and of the body of each
// part it receives, so that the tests can inspect them afterwards.
class MultipartTester : public Orthanc::MultipartStreamReader::IHandler
{
private:
  struct Part
  {
    Orthanc::MultipartStreamReader::HttpHeaders   headers_;
    std::string  data_;

    // Copies the "size" bytes pointed to by "part", as the memory is
    // only valid during the callback
    Part(const Orthanc::MultipartStreamReader::HttpHeaders& headers,
         const void* part,
         size_t size) :
      headers_(headers),
      data_(reinterpret_cast<const char*>(part), size)
    {
    }
  };

  std::vector<Part> parts_;

public:
  // Stores one new part.
  virtual void Apply(const Orthanc::MultipartStreamReader::HttpHeaders& headers,
                     const void* part,
                     size_t size)
  {
    parts_.push_back(Part(headers, part, size));
  }

  // Returns the number of parts received so far.
  unsigned int GetCount() const
  {
    return static_cast<unsigned int>(parts_.size());
  }

  // Returns the headers of the part "i" (no bounds checking).
  Orthanc::MultipartStreamReader::HttpHeaders& GetHeaders(size_t i)
  {
    return parts_[i].headers_;
  }

  // Returns the body of the part "i" (no bounds checking).
  const std::string& GetData(size_t i) const
  {
    return parts_[i].data_;
  }
};


// Checks the location of the match, and that the getters throw as
// long as no match is available.
TEST(StringMatcher, Basic)
{
  Orthanc::StringMatcher matcher("---");

  ASSERT_THROW(matcher.GetMatchBegin(), Orthanc::OrthancException);

  {
    // Only the first occurrence is reported
    const std::string s = "abc----def";
    ASSERT_TRUE(matcher.Apply(s));
    ASSERT_EQ(3, std::distance(s.begin(), matcher.GetMatchBegin()));
    ASSERT_EQ("---", std::string(matcher.GetMatchBegin(), matcher.GetMatchEnd()));
  }

  {
    // Match located at the very end of the corpus
    const std::string s = "abc---";
    ASSERT_TRUE(matcher.Apply(s));
    ASSERT_EQ(3, std::distance(s.begin(), matcher.GetMatchBegin()));
    ASSERT_EQ(s.end(), matcher.GetMatchEnd());
    ASSERT_EQ("---", std::string(matcher.GetMatchBegin(), matcher.GetMatchEnd()));
    ASSERT_EQ("", std::string(matcher.GetMatchEnd(), s.end()));
  }

  {
    // A failed search invalidates the previous match
    const std::string s = "abc--def";
    ASSERT_FALSE(matcher.Apply(s));
    ASSERT_THROW(matcher.GetMatchBegin(), Orthanc::OrthancException);
    ASSERT_THROW(matcher.GetMatchEnd(), Orthanc::OrthancException);
  }
}


// Checks the extraction of the content type and of the boundary from
// the HTTP headers.
TEST(MultipartStreamReader, ParseHeaders)
{
  std::string ct, b;

  {
    // No "content-type" header
    Orthanc::MultipartStreamReader::HttpHeaders h;
    h["hello"] = "world";
    ASSERT_FALSE(Orthanc::MultipartStreamReader::GetMainContentType(ct, h));
    ASSERT_FALSE(Orthanc::MultipartStreamReader::ParseMultipartHeaders(ct, b, h));
  }

  {
    // "content-type" header without a boundary
    Orthanc::MultipartStreamReader::HttpHeaders h;
    h["content-type"] = "world";
    ASSERT_TRUE(Orthanc::MultipartStreamReader::GetMainContentType(ct, h));
    ASSERT_EQ(ct, "world");
    ASSERT_FALSE(Orthanc::MultipartStreamReader::ParseMultipartHeaders(ct, b, h));
  }

  {
    // Full multipart "content-type" header
    Orthanc::MultipartStreamReader::HttpHeaders h;
    h["content-type"] = "multipart/related; type=value; boundary=1234; hello=world";
    ASSERT_TRUE(Orthanc::MultipartStreamReader::GetMainContentType(ct, h));
    ASSERT_EQ(ct, h["content-type"]);
    ASSERT_TRUE(Orthanc::MultipartStreamReader::ParseMultipartHeaders(ct, b, h));
    ASSERT_EQ(ct, "multipart/related");
    ASSERT_EQ(b, "1234");
  }
}


// Decodes a stream of 10 parts (half of them with a "Content-Length"
// header) surrounded by garbage, first byte per byte, then in one
// single chunk.
TEST(MultipartStreamReader, BytePerByte)
{
  std::string stream = "GARBAGE";

  std::string boundary = "123456789123456789";

  {
    for (size_t i = 0; i < 10; i++)
    {
      std::string f = "hello " + boost::lexical_cast<std::string>(i);

      stream += "\r\n--" + boundary + "\r\n";
      if (i % 2 == 0)
        stream += "Content-Length: " + boost::lexical_cast<std::string>(f.size()) + "\r\n";
      stream += "Content-Type: toto " + boost::lexical_cast<std::string>(i) + "\r\n\r\n";
      stream += f;
    }

    stream += "\r\n--" + boundary + "--";
    stream += "GARBAGE";
  }

  for (unsigned int k = 0; k < 2; k++)
  {
    MultipartTester decoded;

    Orthanc::MultipartStreamReader reader(boundary);
    reader.SetBlockSize(1);  // Forces a parsing after each chunk
    reader.SetHandler(decoded);

    if (k == 0)
    {
      for (size_t i = 0; i < stream.size(); i++)
      {
        reader.AddChunk(&stream[i], 1);
      }
    }
    else
    {
      reader.AddChunk(stream);
    }

    reader.CloseStream();

    ASSERT_EQ(10u, decoded.GetCount());

    for (size_t i = 0; i < 10; i++)
    {
      ASSERT_EQ("hello " + boost::lexical_cast<std::string>(i), decoded.GetData(i));
      ASSERT_EQ("toto " + boost::lexical_cast<std::string>(i), decoded.GetHeaders(i)["content-type"]);

      if (i % 2 == 0)
      {
        ASSERT_EQ(2u, decoded.GetHeaders(i).size());
        ASSERT_TRUE(decoded.GetHeaders(i).find("content-length") != decoded.GetHeaders(i).end());
      }
    }
  }
}


// Entry point of the unit tests.
int main(int argc, char **argv)
{
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}