view UnitTestsSources/TestMessageBroker2.cpp @ 299:3897f9f28cfa am-callable-and-promise

backup work in progress: updated messaging framework with ICallable
author am@osimis.io
date Fri, 14 Sep 2018 16:44:01 +0200
parents
children b4abaeb783b1
line wrap: on
line source

/**
 * Stone of Orthanc
 * Copyright (C) 2012-2016 Sebastien Jodogne, Medical Physics
 * Department, University Hospital of Liege, Belgium
 * Copyright (C) 2017-2018 Osimis S.A., Belgium
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU Affero General Public License
 * as published by the Free Software Foundation, either version 3 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 **/


#include "gtest/gtest.h"

#include "Framework/Messages/MessageBroker.h"

#include <boost/noncopyable.hpp>
#include <boost/function.hpp>
#include <boost/bind.hpp>

#include <string>
#include <map>
#include <set>

// Counter modified by the observer/promise callbacks defined below.  Each test
// resets it to 0 and then inspects it to find out whether a callback has run.
int testCounter = 0;
namespace {

  class IObserver;
  class IObservable;
  class Promise;

  // Identifiers of the generic messages.  MessageType_LastGenericStoneMessage
  // is the last generic value: CustomMessageType starts numbering after it.
  enum MessageType
  {
    MessageType_Test1 = 0,
    MessageType_Test2 = 1,

    MessageType_CustomMessage = 2,
    MessageType_LastGenericStoneMessage = 3
  };

  // Base class of all the messages.  A message is non-copyable and stores the
  // MessageType it has been constructed with.
  struct IMessage : public boost::noncopyable
  {
    MessageType messageType_;

  public:
    IMessage(const MessageType& messageType) :
      messageType_(messageType)
    {
    }

    virtual ~IMessage()
    {
    }

    // Identifier of the message; IObservable::EmitMessage() uses it to select
    // the callables to notify.  Subclasses may override it.
    virtual int GetType() const
    {
      return messageType_;
    }
  };


  // Message whose identifier is an application-defined integer.  The base
  // class is tagged with MessageType_CustomMessage, whereas GetType() reports
  // the custom identifier given to the constructor.
  struct ICustomMessage : public IMessage
  {
    int customMessageType_;

  public:
    ICustomMessage(int customMessageType) :
      IMessage(MessageType_CustomMessage),
      customMessageType_(customMessageType)
    {
    }

    virtual ~ICustomMessage()
    {
    }

    // Returns the custom identifier instead of the generic MessageType.
    virtual int GetType() const
    {
      return customMessageType_;
    }
  };


  // An ICallable references an object together with one of its member
  // functions, so that the pair can be notified by an IObservable.
  // The object must derive from IObserver.
  // The member function must be of type "void Function(const IMessage& message)",
  // or take a const reference to a class derived from IMessage.
  class ICallable : public boost::noncopyable
  {
  public:
    virtual ~ICallable() {}

    // Invokes the referenced member function with "message".
    virtual void Apply(const IMessage& message) = 0;

    // Type of the messages this callable is interested in.
    virtual MessageType GetMessageType() const = 0;

    // Observer that owns the referenced member function.
    virtual IObserver* GetObserver() const = 0;
  };

  template <typename TObserver,
            typename TMessage>
  class Callable : public ICallable
  {
  private:
    typedef void (TObserver::* MemberFunction) (const TMessage&);

    TObserver&      observer_;
    MemberFunction  function_;

  public:
    Callable(TObserver& observer,
             MemberFunction function) :
      observer_(observer),
      function_(function)
    {
    }

    void ApplyInternal(const TMessage& message)
    {
      (observer_.*function_) (message);
    }

    virtual void Apply(const IMessage& message)
    {
      ApplyInternal(dynamic_cast<const TMessage&>(message));
    }

    virtual MessageType GetMessageType() const
    {
      return static_cast<MessageType>(TMessage::Type);
    }

    virtual IObserver* GetObserver() const
    {
      return &observer_;
    }
  };




  /*
   * This is a central message broker.  It stores the set of observers that
   * are currently alive: an observer is added by Register() and removed by
   * Unregister() (IObserver calls them from its constructor and destructor).
   * This way, it can prevent an observable from sending a message to a
   * deleted observer.
   */
  class MessageBroker : public boost::noncopyable
  {

    std::set<IObserver*> activeObservers_;  // the observers that are currently alive (that have not been deleted)

  public:

    // Marks "observer" as alive.  Registering the same observer twice has no
    // additional effect, since the observers are stored in a set.
    void Register(IObserver& observer)
    {
      activeObservers_.insert(&observer);
    }

    // Marks "observer" as dead.  Does nothing if it was not registered.
    void Unregister(IObserver& observer)
    {
      activeObservers_.erase(&observer);
    }

    // Returns true iff "observer" is currently registered.  Only the address
    // is compared, the pointer is never dereferenced.  The method is const
    // because it does not modify the broker.
    bool IsActive(IObserver* observer) const
    {
      return activeObservers_.count(observer) != 0;
    }
  };


  class Promise : public boost::noncopyable
  {
  protected:
    MessageBroker&                    broker_;

    ICallable* successCallable_;
    ICallable* failureCallable_;

  public:
    Promise(MessageBroker& broker)
      : broker_(broker),
        successCallable_(NULL),
        failureCallable_(NULL)
    {
    }

    void Success(const IMessage& message)
    {
      // check the target is still alive in the broker
      if (broker_.IsActive(successCallable_->GetObserver()))
      {
        successCallable_->Apply(message);
      }
    }

    void Failure(const IMessage& message)
    {
      // check the target is still alive in the broker
      if (broker_.IsActive(failureCallable_->GetObserver()))
      {
        failureCallable_->Apply(message);
      }
    }

    Promise& Then(ICallable* successCallable)
    {
      if (successCallable_ != NULL)
      {
        // TODO: throw throw new "Promise may only have a single success target"
      }
      successCallable_ = successCallable;
      return *this;
    }

    Promise& Else(ICallable* failureCallable)
    {
      if (failureCallable_ != NULL)
      {
        // TODO: throw throw new "Promise may only have a single failure target"
      }
      failureCallable_ = failureCallable;
      return *this;
    }

  };

  // Base class of the objects that can receive messages.  An observer
  // registers itself in the broker for as long as it exists: registration
  // happens in the constructor, unregistration in the destructor.
  class IObserver : public boost::noncopyable
  {
  protected:
    MessageBroker&  broker_;

  public:
    IObserver(MessageBroker& broker) :
      broker_(broker)
    {
      // from now on, the broker reports this observer as active
      broker.Register(*this);
    }

    virtual ~IObserver()
    {
      // callables still referencing this observer will be skipped by emitters
      broker_.Unregister(*this);
    }
  };


  class IObservable : public boost::noncopyable
  {
  protected:
    MessageBroker&                     broker_;

    typedef std::map<int, std::set<ICallable*> >   Callables;
    Callables  callables_;
  public:

    IObservable(MessageBroker& broker)
      : broker_(broker)
    {
    }

    virtual ~IObservable()
    {
      for (Callables::const_iterator it = callables_.begin();
           it != callables_.end(); ++it)
      {
        for (std::set<ICallable*>::const_iterator
               it2 = it->second.begin(); it2 != it->second.end(); ++it2)
        {
          delete *it2;
        }
      }
    }

    void Register(ICallable* callable)
    {
      MessageType messageType = callable->GetMessageType();

      callables_[messageType].insert(callable);
    }

    void EmitMessage(const IMessage& message)
    {
      Callables::const_iterator found = callables_.find(message.GetType());

      if (found != callables_.end())
      {
        for (std::set<ICallable*>::const_iterator
               it = found->second.begin(); it != found->second.end(); ++it)
        {
          if (broker_.IsActive((*it)->GetObserver()))
          {
            (*it)->Apply(message);
          }
        }
      }
    }

  };


  // Identifiers of the custom messages used by the tests.  They are numbered
  // after the last generic MessageType value.
  enum CustomMessageType
  {
    CustomMessageType_First = MessageType_LastGenericStoneMessage + 1,

    CustomMessageType_Completed = CustomMessageType_First + 1,
    CustomMessageType_Increment = CustomMessageType_Completed + 1
  };

  // Observable used by the "permanent connection" tests.  It defines the
  // custom message the tests emit through IObservable::EmitMessage().
  class MyObservable : public IObservable
  {
  public:
    // Custom message of type CustomMessageType_Completed carrying an integer.
    struct MyCustomMessage : public ICustomMessage
    {
      int payload_;

      enum
      {
        Type = CustomMessageType_Completed
      };

      MyCustomMessage(int payload) :
        ICustomMessage(Type),
        payload_(payload)
      {
      }
    };

    MyObservable(MessageBroker& broker) :
      IObservable(broker)
    {
    }
  };

  // Observer used by the "permanent connection" tests.
  class MyObserver : public IObserver
  {
  public:
    MyObserver(MessageBroker& broker) :
      IObserver(broker)
    {
    }

    // Adds the payload of the message to the global test counter.
    void HandleCompletedMessage(const MyObservable::MyCustomMessage& message)
    {
      testCounter = testCounter + message.payload_;
    }
  };


  class MyPromiseSource : public IObservable
  {
    Promise* currentPromise_;
  public:
    struct MyPromiseMessage: public ICustomMessage
    {
      int increment;
      enum
      {
        Type = CustomMessageType_Increment
      };

      MyPromiseMessage(int increment)
        : ICustomMessage(Type),
          increment(increment)
      {}
    };

    MyPromiseSource(MessageBroker& broker)
      : IObservable(broker),
        currentPromise_(NULL)
    {}

    Promise& StartSomethingAsync()
    {
      currentPromise_ = new Promise(broker_);
      return *currentPromise_;
    }

    void CompleteSomethingAsyncWithSuccess(int payload)
    {
      currentPromise_->Success(MyPromiseMessage(payload));
      delete currentPromise_;
    }

    void CompleteSomethingAsyncWithFailure(int payload)
    {
      currentPromise_->Failure(MyPromiseMessage(payload));
      delete currentPromise_;
    }
  };


  // Observer used as the success/failure target of the promise tests.
  class MyPromiseTarget : public IObserver
  {
  public:
    MyPromiseTarget(MessageBroker& broker) :
      IObserver(broker)
    {
    }

    // Success handler: adds the increment of the message to the test counter.
    void IncrementCounter(const MyPromiseSource::MyPromiseMessage& args)
    {
      testCounter = testCounter + args.increment;
    }

    // Failure handler: subtracts the increment from the test counter.
    void DecrementCounter(const MyPromiseSource::MyPromiseMessage& args)
    {
      testCounter = testCounter - args.increment;
    }
  };
}


// An observer registered on an observable is notified at each emission.
TEST(MessageBroker2, TestPermanentConnectionSimpleUseCase)
{
  typedef Callable<MyObserver, MyObservable::MyCustomMessage>  CompletedCallable;

  MessageBroker broker;
  MyObservable  observable(broker);
  MyObserver    observer(broker);

  // permanently connect the observer to the observable (the observable takes
  // ownership of the callable)
  observable.Register(new CompletedCallable(observer, &MyObserver::HandleCompletedMessage));

  // first emission: the handler adds the payload to the counter
  testCounter = 0;
  observable.EmitMessage(MyObservable::MyCustomMessage(12));
  ASSERT_EQ(12, testCounter);

  // the connection is permanent: a second emission notifies the observer again
  testCounter = 0;
  observable.EmitMessage(MyObservable::MyCustomMessage(20));
  ASSERT_EQ(20, testCounter);
}

// Once an observer has been deleted, the observable stops notifying it.
TEST(MessageBroker2, TestPermanentConnectionDeleteObserver)
{
  typedef Callable<MyObserver, MyObservable::MyCustomMessage>  CompletedCallable;

  MessageBroker broker;
  MyObservable  observable(broker);
  MyObserver*   observer = new MyObserver(broker);

  // permanently connect the observer to the observable
  observable.Register(new CompletedCallable(*observer, &MyObserver::HandleCompletedMessage));

  // while the observer is alive, it receives the message
  testCounter = 0;
  observable.EmitMessage(MyObservable::MyCustomMessage(12));
  ASSERT_EQ(12, testCounter);

  // deleting the observer unregisters it from the broker
  delete observer;

  // the callable is still registered in the observable, but it must not be
  // invoked anymore: the counter stays untouched
  testCounter = 0;
  observable.EmitMessage(MyObservable::MyCustomMessage(20));
  ASSERT_EQ(0, testCounter);
}


// A promise invokes its "Then" target on success and its "Else" target on
// failure.
TEST(MessageBroker2, TestPromiseSuccessFailure)
{
  typedef Callable<MyPromiseTarget, MyPromiseSource::MyPromiseMessage>  TargetCallable;

  MessageBroker    broker;
  MyPromiseSource  source(broker);
  MyPromiseTarget  target(broker);

  // successful promise: only IncrementCounter() must be called
  source.StartSomethingAsync()
      .Then(new TargetCallable(target, &MyPromiseTarget::IncrementCounter))
      .Else(new TargetCallable(target, &MyPromiseTarget::DecrementCounter));

  testCounter = 0;
  source.CompleteSomethingAsyncWithSuccess(10);
  ASSERT_EQ(10, testCounter);

  // failing promise: only DecrementCounter() must be called
  source.StartSomethingAsync()
      .Then(new TargetCallable(target, &MyPromiseTarget::IncrementCounter))
      .Else(new TargetCallable(target, &MyPromiseTarget::DecrementCounter));

  testCounter = 0;
  source.CompleteSomethingAsyncWithFailure(15);
  ASSERT_EQ(-15, testCounter);
}

// A promise whose target has been deleted before completion must neither
// throw nor invoke the callbacks, be it on success or on failure.
TEST(MessageBroker2, TestPromiseDeleteTarget)
{
  MessageBroker broker;
  MyPromiseSource source(broker);
  MyPromiseTarget* target = new MyPromiseTarget(broker);

  // create the promise
  source.StartSomethingAsync()
      .Then(new Callable<MyPromiseTarget, MyPromiseSource::MyPromiseMessage>(*target, &MyPromiseTarget::IncrementCounter))
      .Else(new Callable<MyPromiseTarget, MyPromiseSource::MyPromiseMessage>(*target, &MyPromiseTarget::DecrementCounter));

  // delete the promise target
  delete target;

  // trigger the promise, make sure it does not throw and does not call the callback
  testCounter = 0;
  source.CompleteSomethingAsyncWithSuccess(10);
  ASSERT_EQ(0, testCounter);

  // test a failing promise: use a fresh target, as the previous one has been
  // deleted and its pointer must not be dereferenced anymore
  target = new MyPromiseTarget(broker);

  source.StartSomethingAsync()
      .Then(new Callable<MyPromiseTarget, MyPromiseSource::MyPromiseMessage>(*target, &MyPromiseTarget::IncrementCounter))
      .Else(new Callable<MyPromiseTarget, MyPromiseSource::MyPromiseMessage>(*target, &MyPromiseTarget::DecrementCounter));

  // delete the promise target before the promise completes
  delete target;

  testCounter = 0;
  source.CompleteSomethingAsyncWithFailure(15);
  ASSERT_EQ(0, testCounter);
}



//#include <stdio.h>
//#include <boost/noncopyable.hpp>

//#include <string>
//#include <memory>
//#include <map>
//#include <set>

//enum MessageType
//{
//  MessageType_SeriesDownloaded = 1
//};


//class IMessage : public boost::noncopyable
//{
//private:
//  MessageType  type_;

//public:
//  IMessage(MessageType  type) :
//    type_(type)
//  {
//  }

//  virtual ~IMessage()
//  {
//  }

//  MessageType GetMessageType() const
//  {
//    return type_;
//  }
//};


//class IObserver : public boost::noncopyable
//{
//public:
//  virtual ~IObserver()
//  {
//  }
//};


//class SeriesDownloadedMessage : public IMessage
//{
//private:
//  std::string value_;

//public:
//  enum
//  {
//    Type = MessageType_SeriesDownloaded
//  };

//  SeriesDownloadedMessage(const std::string& value) :
//    IMessage(static_cast<MessageType>(Type)),
//    value_(value)
//  {
//  }

//  const std::string& GetValue() const
//  {
//    return value_;
//  }
//};


//class MyObserver : public IObserver
//{
//public:
//  void OnSeriesDownloaded(const SeriesDownloadedMessage& message)
//  {
//    printf("received: [%s]\n", message.GetValue().c_str());
//  }
//};



//class ICallable : public boost::noncopyable  // can only reference the base classes
//{
//public:
//  virtual ~ICallable()
//  {
//  }

//  virtual void Apply(const IMessage& message) = 0;

//  virtual MessageType GetMessageType() const = 0;
//};



//template <typename Observer,
//          typename Message>
//class Callable : public ICallable
//{
//private:
//  typedef void (Observer::* MemberFunction) (const Message&);

//  Observer&       observer_;
//  MemberFunction  function_;

//public:
//  Callable(Observer& observer,
//           MemberFunction function) :
//    observer_(observer),
//    function_(function)
//  {
//  }

//  void ApplyInternal(const Message& message)
//  {
//    (observer_.*function_) (message);
//  }

//  virtual void Apply(const IMessage& message)
//  {
//    ApplyInternal(dynamic_cast<const Message&>(message));
//  }

//  virtual MessageType GetMessageType() const
//  {
//    return static_cast<MessageType>(Message::Type);
//  }
//};



//class IObservable : public boost::noncopyable
//{
//private:
//  typedef std::map<MessageType, std::set<ICallable*> >   Callables;

//  Callables  callables_;

//public:
//  virtual ~IObservable()
//  {
//    for (Callables::const_iterator it = callables_.begin();
//         it != callables_.end(); ++it)
//    {
//      for (std::set<ICallable*>::const_iterator
//             it2 = it->second.begin(); it2 != it->second.end(); ++it2)
//      {
//        delete *it2;
//      }
//    }
//  }

//  void Register(ICallable* callable)
//  {
//    MessageType type = callable->GetMessageType();

//    callables_[type].insert(callable);
//  }

//  void Emit(const IMessage& message) const
//  {
//    Callables::const_iterator found = callables_.find(message.GetMessageType());

//    if (found != callables_.end())
//    {
//      for (std::set<ICallable*>::const_iterator
//             it = found->second.begin(); it != found->second.end(); ++it)
//      {
//        (*it)->Apply(message);
//      }
//    }
//  }
//};




//int main()
//{
//  MyObserver observer;

//  SeriesDownloadedMessage message("coucou");

//  IObservable observable;
//  observable.Register(new Callable<MyObserver, SeriesDownloadedMessage>(observer, &MyObserver::OnSeriesDownloaded));
//  observable.Register(new Callable<MyObserver, SeriesDownloadedMessage>(observer, &MyObserver::OnSeriesDownloaded));

//  SeriesDownloadedMessage message2("hello");
//  observable.Emit(message2);

//  printf("%d\n", SeriesDownloadedMessage::Type);
//}