changeset 1873:e0966648ebd0

unit tests of SegmentTree
author Sebastien Jodogne <s.jodogne@gmail.com>
date Tue, 11 Jan 2022 15:36:04 +0100
parents db8a8a19b543
children 08f2476e8f5e
files OrthancStone/Sources/Toolbox/SegmentTree.cpp OrthancStone/Sources/Toolbox/SegmentTree.h UnitTestsSources/ComputationalGeometryTests.cpp
diffstat 3 files changed, 237 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/OrthancStone/Sources/Toolbox/SegmentTree.cpp	Tue Jan 11 12:15:22 2022 +0100
+++ b/OrthancStone/Sources/Toolbox/SegmentTree.cpp	Tue Jan 11 15:36:04 2022 +0100
@@ -104,9 +104,9 @@
   }
 
 
-  void SegmentTree::Visit(size_t low,
-                          size_t high,
-                          IVisitor& visitor)
+  void SegmentTree::VisitSegment(size_t low,
+                                 size_t high,
+                                 IVisitor& visitor) const
   {
     if (low >= high)
     {
@@ -124,7 +124,7 @@
     if (b <= bv &&
         ev <= e)
     {
-      // The interval of this node is fully inside the user-provided interval
+      // The segment of this node is fully inside the user-provided segment
       visitor.Visit(*this, true);
     }
     else if (!IsLeaf())
@@ -134,16 +134,71 @@
 
       if (b < middle)
       {
-        GetLeftChild().Visit(b, e, visitor);
+        GetLeftChild().VisitSegment(b, e, visitor);
       }
 
       if (middle < e)
       {
-        GetRightChild().Visit(b, e, visitor);
+        GetRightChild().VisitSegment(b, e, visitor);
       }
       
-      // The interval of this node only partially intersects the user-provided interval
+      // The segment of this node only partially intersects the user-provided segment
       visitor.Visit(*this, false);
     }
   }
+
+
+  const SegmentTree* SegmentTree::FindLeaf(size_t low) const
+  {
+    if (IsLeaf())
+    {
+      if (low == lowBound_)
+      {
+        return this;
+      }
+      else
+      {
+        return NULL;
+      }
+    }
+    else
+    {
+      size_t middle = (lowBound_ + highBound_) / 2;
+      if (low < middle)
+      {
+        return GetLeftChild().FindLeaf(low);
+      }
+      else
+      {
+        return GetRightChild().FindLeaf(low);
+      }
+    }
+  }
+
+
+  const SegmentTree* SegmentTree::FindNode(size_t low,
+                                           size_t high) const
+  {
+    if (low == lowBound_ &&
+        high == highBound_)
+    {
+      return this;
+    }
+    else if (IsLeaf())
+    {
+      return NULL;
+    }
+    else
+    {
+      size_t middle = (lowBound_ + highBound_) / 2;
+      if (low < middle)
+      {
+        return GetLeftChild().FindNode(low, high);
+      }
+      else
+      {
+        return GetRightChild().FindNode(low, high);
+      }
+    }
+  }
 }
--- a/OrthancStone/Sources/Toolbox/SegmentTree.h	Tue Jan 11 12:15:22 2022 +0100
+++ b/OrthancStone/Sources/Toolbox/SegmentTree.h	Tue Jan 11 15:36:04 2022 +0100
@@ -54,6 +54,8 @@
       {
       }
 
+      // "fullyInside" is true iff. the segment of "node" is fully
+      // inside the user-provided segment
       virtual void Visit(const SegmentTree& node,
                          bool fullyInside) = 0;
     };
@@ -99,10 +101,20 @@
 
     size_t CountNodes() const;
 
-    // This corresponds to both methods "INSERT()" and "DELETE()" from
-    // the reference textbook
-    void Visit(size_t low,
-               size_t high,
-               IVisitor& visitor);
+    /**
+     * Apply the given visitor to all the segments that intersect the
+     * [low,high] segment. This corresponds to both methods "INSERT()"
+     * and "DELETE()" from the reference textbook.
+     **/
+    void VisitSegment(size_t low,
+                      size_t high,
+                      IVisitor& visitor) const;
+
+    // For unit tests
+    const SegmentTree* FindLeaf(size_t low) const;
+
+    // For unit tests
+    const SegmentTree* FindNode(size_t low,
+                                size_t high) const;
   };
 }
--- a/UnitTestsSources/ComputationalGeometryTests.cpp	Tue Jan 11 12:15:22 2022 +0100
+++ b/UnitTestsSources/ComputationalGeometryTests.cpp	Tue Jan 11 15:36:04 2022 +0100
@@ -30,29 +30,69 @@
 
 namespace
 {
+  typedef Orthanc::SingleValueObject<int>  Counter;
+  
   class CounterFactory : public OrthancStone::SegmentTree::IPayloadFactory
   {
+  private:
+    int value_;
+    
   public:
-    virtual Orthanc::IDynamicObject* Create()
+    CounterFactory(int value) :
+      value_(value)
+    {
+    }
+    
+    virtual Orthanc::IDynamicObject* Create() ORTHANC_OVERRIDE
     {
-      return new Orthanc::SingleValueObject<int>(42);
+      return new Counter(value_);
+    }
+  };
+
+  class IncrementVisitor : public OrthancStone::SegmentTree::IVisitor
+  {
+  private:
+    int increment_;
+
+  public:
+    IncrementVisitor(int increment) :
+      increment_(increment)
+    {
+    }
+
+    virtual void Visit(const OrthancStone::SegmentTree& node,
+                       bool fullyInside) ORTHANC_OVERRIDE
+    {
+      if (fullyInside)
+      {
+        Counter& payload = node.GetTypedPayload<Counter>();
+
+        if (payload.GetValue() + increment_ < 0)
+        {
+          throw Orthanc::OrthancException(Orthanc::ErrorCode_InternalError);
+        }
+        else
+        {
+          payload.SetValue(payload.GetValue() + increment_);
+        }
+      }
     }
   };
 }
 
 
-TEST(SegmentTree, Basic)
+TEST(SegmentTree, Create)
 {
-  CounterFactory factory;
+  CounterFactory factory(42);
   OrthancStone::SegmentTree root(4u, 15u, factory);   // Check out Figure 1.1 (page 14) from textbook
   
   ASSERT_EQ(4u, root.GetLowBound());
   ASSERT_EQ(15u, root.GetHighBound());
   ASSERT_FALSE(root.IsLeaf());
-  ASSERT_EQ(42, root.GetTypedPayload< Orthanc::SingleValueObject<int> >().GetValue());
+  ASSERT_EQ(42, root.GetTypedPayload<Counter>().GetValue());
   ASSERT_EQ(21u, root.CountNodes());
 
-  OrthancStone::SegmentTree* n = &root.GetLeftChild();
+  const OrthancStone::SegmentTree* n = &root.GetLeftChild();
   ASSERT_EQ(4u, n->GetLowBound());
   ASSERT_EQ(9u, n->GetHighBound());
   ASSERT_FALSE(n->IsLeaf());
@@ -173,4 +213,116 @@
   ASSERT_EQ(15u, n->GetHighBound());
   ASSERT_TRUE(n->IsLeaf());
   ASSERT_EQ(1u, n->CountNodes());
+
+  ASSERT_TRUE(root.FindLeaf(3) == NULL);
+  for (size_t i = 4; i < 15; i++)
+  {
+    n = root.FindLeaf(i);
+    ASSERT_TRUE(n != NULL);
+    ASSERT_TRUE(n->IsLeaf());
+    ASSERT_EQ(i, n->GetLowBound());
+    ASSERT_EQ(i + 1, n->GetHighBound());
+    ASSERT_EQ(42, n->GetTypedPayload<Counter>().GetValue());
+  }
+  ASSERT_TRUE(root.FindLeaf(15) == NULL);
 }
+
+
+static bool CheckCounter(const OrthancStone::SegmentTree& node,
+                         int expectedValue)
+{
+  if (node.GetTypedPayload<Counter>().GetValue() != expectedValue)
+  {
+    return false;
+  }
+  else if (node.IsLeaf())
+  {
+    return true;
+  }
+  else
+  {
+    return (CheckCounter(node.GetLeftChild(), expectedValue) &&
+            CheckCounter(node.GetRightChild(), expectedValue));
+  }
+}
+
+
+#if 0
+static void Print(const OrthancStone::SegmentTree& node,
+                  unsigned int indent)
+{
+  for (size_t i = 0; i < indent; i++)
+    printf("    ");
+  printf("(%lu,%lu): %d\n", node.GetLowBound(), node.GetHighBound(), node.GetTypedPayload<Counter>().GetValue());
+  if (!node.IsLeaf())
+  {
+    Print(node.GetLeftChild(), indent + 1);
+    Print(node.GetRightChild(), indent + 1);
+  }
+}
+#endif
+
+
+TEST(SegmentTree, Visit)
+{
+  CounterFactory factory(0);
+  OrthancStone::SegmentTree root(4u, 15u, factory);   // Check out Figure 1.1 (page 14) from textbook
+
+  ASSERT_TRUE(CheckCounter(root, 0));
+
+  IncrementVisitor plus(1);
+  IncrementVisitor minus(-1);
+
+  root.VisitSegment(0, 20, plus);
+  ASSERT_EQ(1, root.GetTypedPayload<Counter>().GetValue());
+  ASSERT_TRUE(CheckCounter(root.GetLeftChild(), 0));
+  ASSERT_TRUE(CheckCounter(root.GetRightChild(), 0));
+
+  root.VisitSegment(0, 20, plus);
+  ASSERT_EQ(2, root.GetTypedPayload<Counter>().GetValue());
+  ASSERT_TRUE(CheckCounter(root.GetLeftChild(), 0));
+  ASSERT_TRUE(CheckCounter(root.GetRightChild(), 0));
+
+  root.VisitSegment(0, 20, minus);
+  root.VisitSegment(0, 20, minus);
+  ASSERT_TRUE(CheckCounter(root, 0));
+
+  root.VisitSegment(8, 11, plus);
+  ASSERT_EQ(0, root.FindNode(4, 15)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(4, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(4, 6)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(4, 5)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(5, 6)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(6, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(6, 7)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(7, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(7, 8)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(1, root.FindNode(8, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(9, 15)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(9, 12)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(1, root.FindNode(9, 10)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(10, 12)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(1, root.FindNode(10, 11)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(11, 12)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(12, 15)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(12, 13)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(13, 15)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(13, 14)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(14, 15)->GetTypedPayload<Counter>().GetValue());
+  
+  root.VisitSegment(9, 11, minus);
+  ASSERT_EQ(0, root.FindNode(4, 15)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(4, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(4, 6)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(4, 5)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(5, 6)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(6, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(6, 7)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(7, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(0, root.FindNode(7, 8)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_EQ(1, root.FindNode(8, 9)->GetTypedPayload<Counter>().GetValue());
+  ASSERT_TRUE(CheckCounter(root.GetRightChild(), 0));
+
+  root.VisitSegment(8, 9, minus);
+  ASSERT_TRUE(CheckCounter(root, 0));
+}