I'm trying to write my own shared_ptr/weak_ptr implementation in C++. I have the following requirements:
I do NOT need support for the following:
- multithreading (synchronisation)
- support for polymorphic types as the templated type of the shared_ptr (such as a `shared_ptr<Base>` that holds a pointer to a derived object)
Reasons for wanting to write my own implementation:
- need to supply a separate allocator for the control block
- need to reduce the size of the control block (the standard version has a very large control block due to its support for multithreading and polymorphism among other things)
Concerns:
- I'm worried about using my current implementation in production code (need some suggestions on how best to thoroughly test it)
- I'm concerned I may have left out important features from my implementation
The following is what I've written so far (compilable in a C++11 compliant compiler, with main function and example):
#include <cstdint>
#include <iostream>
#include <memory>
#include <utility>
/// Abstract reference-count record.
///
/// Holds two plain (non-atomic) counters and declares the three operations a
/// concrete control block has to provide: destroying the shared object,
/// destroying the control block itself, and producing a heap copy of itself.
struct shared_ptr_control_base
{
    /// Number of shared references; a freshly constructed block starts at one.
    uint32_t m_shared = 1;
    /// Number of weak references; a freshly constructed block starts at zero.
    uint32_t m_weak = 0;

    virtual ~shared_ptr_control_base() = default;

    /// Adds one shared reference.
    void increment_count_shared() noexcept { ++m_shared; }
    /// Removes one shared reference. No underflow check is performed.
    void decrement_count_shared() noexcept { --m_shared; }

    /// Adds one weak reference.
    void increment_count_weak() noexcept { ++m_weak; }
    /// Removes one weak reference. No underflow check is performed.
    void decrement_count_weak() noexcept { --m_weak; }

    /// Disposes of the shared object passed in as an untyped pointer.
    virtual void destroy_shared(void*) noexcept = 0;
    /// Disposes of the control block itself.
    virtual void destruct() noexcept = 0;
    /// Returns a newly created control block derived from this one.
    virtual shared_ptr_control_base* create() const = 0;
};
/// Concrete control block that remembers the allocator used for the shared
/// object.
///
/// @tparam SharedType     Type of the object managed by the owning pointers.
/// @tparam AllocatorType  Allocator that was used to allocate and construct the
///                        shared object; a copy is stored in the block and used
///                        again to destroy and deallocate that object.
///
/// The control block itself is always allocated with
/// std::allocator<shared_ptr_control_derived>.
///
/// All allocator calls go through std::allocator_traits, so allocators that do
/// not provide construct()/destroy() members (they are optional since C++11 and
/// were removed from std::allocator in C++20) work as well.
template <typename SharedType, typename AllocatorType> struct shared_ptr_control_derived : shared_ptr_control_base
{
    shared_ptr_control_derived() = delete;

    /// Stores a copy of @p a_allocator for later use by destroy_shared().
    shared_ptr_control_derived(AllocatorType a_allocator) : m_allocator(a_allocator) { }

    /// Allocates a new control block on the heap and copy-constructs it from
    /// this object (counters and allocator are copied).
    ///
    /// @return Pointer to the new block; release it with destruct().
    /// @throws Whatever allocation or the copy constructor throws. The raw
    ///         storage is returned to the allocator before rethrowing.
    shared_ptr_control_derived<SharedType, AllocatorType>* create() const override
    {
        block_allocator l_allocator;
        shared_ptr_control_derived<SharedType, AllocatorType>* const l_block =
            block_traits::allocate(l_allocator, 1);
        try
        {
            block_traits::construct(l_allocator, l_block, *this);
        }
        catch (...)
        {
            // Do not leak the storage if copying the allocator throws.
            block_traits::deallocate(l_allocator, l_block, 1);
            throw;
        }
        return l_block;
    }

    /// Destroys the object at @p a_pointer and returns its storage to the
    /// stored allocator.
    ///
    /// @param a_pointer  Must point to a single SharedType object obtained from
    ///                   an allocator compatible with m_allocator.
    void destroy_shared(void* a_pointer) noexcept override
    {
        SharedType* const l_object = static_cast<SharedType*>(a_pointer);
        object_traits::destroy(m_allocator, l_object);
        object_traits::deallocate(m_allocator, l_object, 1);
    }

    /// Destroys this control block and frees its storage. The block must have
    /// been produced by create(); `this` is invalid once the call returns.
    void destruct() noexcept override
    {
        block_allocator l_allocator;
        block_traits::destroy(l_allocator, this);
        block_traits::deallocate(l_allocator, this, 1);
    }

    /// Allocator used to dispose of the shared object.
    mutable AllocatorType m_allocator;

private:
    using block_allocator = std::allocator<shared_ptr_control_derived<SharedType, AllocatorType>>;
    using block_traits = std::allocator_traits<block_allocator>;
    using object_traits = std::allocator_traits<AllocatorType>;
};
template <typename SharedType> struct shared_ptr;
/// Non-owning observer of an object managed by shared_ptr<SharedType>.
///
/// A weak_ptr keeps the control block alive (through the weak count) but not
/// the object. Use expired() / operator bool to test whether the object still
/// exists, or lock() to obtain a shared_ptr that keeps it alive while in use.
///
/// Not thread-safe: the reference counts are plain integers.
/// Every special member writes a trace line to std::cout; the move operations
/// are declared noexcept on the assumption that std::cout is left in its
/// default, non-throwing configuration.
template <typename SharedType> struct weak_ptr
{
    friend struct shared_ptr<SharedType>;

    /// Creates an empty weak_ptr that observes nothing.
    weak_ptr() noexcept : m_pointer(nullptr), m_control(nullptr) { }

    /// Observes the same object as @p a_that and adds one weak reference.
    weak_ptr(const weak_ptr<SharedType>& a_that) :
        m_pointer(a_that.m_pointer),
        m_control(a_that.m_control)
    {
        std::cout << "weak_ptr<T>::weak_ptr(const weak_ptr<T>&)" << std::endl;
        if (m_control != nullptr)
        {
            m_control->increment_count_weak();
        }
    }

    /// Takes over the weak reference held by @p a_that, leaving it empty.
    weak_ptr(weak_ptr<SharedType>&& a_that) noexcept :
        m_pointer(a_that.m_pointer),
        m_control(a_that.m_control)
    {
        std::cout << "weak_ptr<T>::weak_ptr(weak_ptr<T>&&)" << std::endl;
        a_that.m_pointer = nullptr;
        a_that.m_control = nullptr;
    }

    /// Starts observing the object owned by @p a_that and adds one weak reference.
    weak_ptr(const shared_ptr<SharedType>& a_that) :
        m_pointer(a_that.m_pointer),
        m_control(a_that.m_control)
    {
        std::cout << "weak_ptr<T>::weak_ptr(const shared_ptr<T>&)" << std::endl;
        if (m_control != nullptr)
        {
            m_control->increment_count_weak();
        }
    }

    /// Drops the current weak reference (if any) and observes what @p a_rhs observes.
    /// Assigning from an observer of the same control block is a no-op.
    weak_ptr<SharedType>& operator=(const weak_ptr<SharedType>& a_rhs)
    {
        std::cout << "weak_ptr<T>& weak_ptr<T>::operator = (const weak_ptr<T>&)" << std::endl;
        if (a_rhs.m_control != m_control)
        {
            decrement_destruct();
            m_pointer = a_rhs.m_pointer;
            m_control = a_rhs.m_control;
            if (m_control != nullptr) { m_control->increment_count_weak(); }
        }
        return *this;
    }

    /// Drops the current weak reference (if any) and takes over the one held
    /// by @p a_rhs, leaving @p a_rhs empty. Self-move is a no-op.
    weak_ptr<SharedType>& operator=(weak_ptr<SharedType>&& a_rhs) noexcept
    {
        std::cout << "weak_ptr<T>& weak_ptr<T>::operator = (weak_ptr<T>&&)" << std::endl;
        if (this != &a_rhs)
        {
            // Always give up our own reference, even when both sides observe
            // the same control block: a_rhs is emptied below, so exactly one
            // of the two weak references must disappear.
            decrement_destruct();
            m_pointer = a_rhs.m_pointer;
            m_control = a_rhs.m_control;
            a_rhs.m_pointer = nullptr;
            a_rhs.m_control = nullptr;
        }
        return *this;
    }

    /// Drops the current weak reference (if any) and observes the object owned by @p a_rhs.
    weak_ptr<SharedType>& operator=(const shared_ptr<SharedType>& a_rhs)
    {
        std::cout << "weak_ptr<T>& weak_ptr<T>::operator = (const shared_ptr<T>&)" << std::endl;
        if (a_rhs.m_control != m_control)
        {
            decrement_destruct();
            m_pointer = a_rhs.m_pointer;
            m_control = a_rhs.m_control;
            if (m_control != nullptr) { m_control->increment_count_weak(); }
        }
        return *this;
    }

    ~weak_ptr()
    {
        decrement_destruct();
    }

    /// Gives up this object's weak reference and leaves it empty.
    ///
    /// If this was the last weak reference and no shared references remain,
    /// the control block is destroyed. Safe to call on an empty weak_ptr and
    /// safe to call more than once.
    void decrement_destruct()
    {
        if (m_control == nullptr) { return; }
        m_control->decrement_count_weak();
        if (m_control->m_weak == 0 && m_control->m_shared == 0)
        {
            std::cout << "weak_ptr -> destructing control block" << std::endl;
            m_control->destruct();
        }
        // Never keep pointers to a block we no longer hold a reference on.
        m_pointer = nullptr;
        m_control = nullptr;
    }

    /// @return true when there is no live object to observe (empty, or the
    ///         shared count has dropped to zero).
    bool expired() const noexcept { return m_control == nullptr || m_control->m_shared == 0; }

    /// Obtains shared ownership of the observed object.
    ///
    /// @return A shared_ptr owning the object, or an empty shared_ptr when
    ///         this weak_ptr is expired.
    shared_ptr<SharedType> lock() const
    {
        shared_ptr<SharedType> l_owner;
        if (!expired())
        {
            m_control->increment_count_shared();
            l_owner.m_pointer = m_pointer;
            l_owner.m_control = m_control;
        }
        return l_owner;
    }

    /// @return The observed object, or nullptr when expired.
    SharedType* operator->() const noexcept { return get(); }

    /// Precondition: !expired(). Dereferencing an expired weak_ptr is undefined.
    SharedType& operator*() const noexcept { return *m_pointer; }

    /// @return true while the observed object is still alive.
    explicit operator bool() const noexcept { return !expired(); }

    /// @return The number of shared owners of the observed object (0 when empty).
    uint32_t use_count() const noexcept { return m_control ? m_control->m_shared : 0; }

    /// @return The observed object, or nullptr when expired, so a stale
    ///         address is never handed out.
    SharedType* get() const noexcept { return expired() ? nullptr : m_pointer; }

private:
    SharedType* m_pointer;
    shared_ptr_control_base* m_control;
};
template <typename SharedType> struct shared_ptr
{
friend struct weak_ptr<SharedType>;
shared_ptr() : m_pointer(nullptr), m_control(nullptr) { }
explicit shared_ptr(SharedType* const a_pointer) :
m_pointer(a_pointer),
m_control(nullptr)
{
std::cout << "shared_ptr<T>::shared_ptr(T*)" << std::endl;
if (m_pointer != nullptr) { create_control(std::allocator<SharedType>()); }
}
template <typename AllocatorType> explicit shared_ptr(SharedType* const a_pointer, const AllocatorType a_allocator) :
m_pointer(a_pointer),
m_control(nullptr)
{
if (m_pointer != nullptr) { create_control(a_allocator); }
}
shared_ptr(const shared_ptr<SharedType>& a_that) :
m_pointer(a_that.m_pointer),
m_control(a_that.m_control)
{
std::cout << "shared_ptr<T>::shared_ptr(const shared_ptr<T>&)" << std::endl;
if (m_control != nullptr)
{
m_control->increment_count_shared();
}
}
shared_ptr<SharedType>& operator=(const shared_ptr<SharedType>& a_rhs)
{
std::cout << "shared_ptr<T>& shared_ptr<T>::operator = (const shared_ptr<T>&)" << std::endl;
if (a_rhs.m_control != m_control)
{
if (m_control != nullptr) { decrement_destruct(); }
m_pointer = a_rhs.m_pointer;
m_control = a_rhs.m_control;
if (m_control != nullptr) { m_control->increment_count_shared(); }
}
return *this;
}
~shared_ptr()
{
if (m_control) { decrement_destruct(); }
}
SharedType* operator->() const noexcept { return m_pointer; }
SharedType& operator*() const noexcept { return *m_pointer; }
explicit operator bool() const noexcept { return m_pointer != nullptr; }
uint32_t use_count() const noexcept { return m_control ? m_control->m_shared : 0; }
void decrement_destruct()
{
m_control->decrement_count_shared();
if (m_control->m_shared == 0)
{
std::cout << "shared_ptr -> destructing shared object" << std::endl;
m_control->destroy_shared(m_pointer);
if (m_control->m_weak == 0) { std::cout << "shared_ptr -> destructing control block" << std::endl; m_control->destruct(); }
}
}
void reset() noexcept { shared_ptr<SharedType>().swap(*this); }
void reset(SharedType* const a_pointer) noexcept { shared_ptr<SharedType>(a_pointer).swap(*this); }
void swap(shared_ptr<SharedType>& a_that) noexcept { std::swap(m_pointer, a_that.m_pointer); std::swap(m_control, a_that.m_control); }
template <typename AllocatorType> void create_control(AllocatorType a_allocator)
{
m_control = shared_ptr_control_derived<SharedType, AllocatorType>(a_allocator).create();
}
SharedType* get() const noexcept { return m_pointer; };
private:
SharedType* m_pointer;
shared_ptr_control_base* m_control;
};
/// Allocates storage for a single int with std::allocator<int> and constructs
/// it with the given value.
///
/// Goes through std::allocator_traits because std::allocator<T>::construct is
/// deprecated in C++17 and removed in C++20.
///
/// @param a_argument  Initial value of the new int.
/// @return Pointer to the new int. The caller owns it and must release it with
///         a std::allocator<int> (for example by handing it to shared_ptr<int>).
int* allocate(const int a_argument)
{
    using traits = std::allocator_traits<std::allocator<int>>;
    std::allocator<int> l_allocator;
    int* const l_p = traits::allocate(l_allocator, 1);
    traits::construct(l_allocator, l_p, a_argument);
    return l_p;
}
/// Small demonstration: a weak_ptr observes an int while its shared_ptr owner
/// is alive and reports that it is no longer usable once the owner is gone.
int main()
{
    // Declared in the outer scope so that it outlives the owner created below.
    // It is empty at this point and must not be dereferenced yet.
    weak_ptr<int> weak_1;

    {
        // Allocate an int holding 42 and hand ownership to shared_1.
        shared_ptr<int> shared_1(allocate(42));

        // Let weak_1 observe the int owned by shared_1.
        weak_1 = shared_1;

        const bool l_usable = static_cast<bool>(weak_1);
        if (l_usable)
        {
            std::cout << "weak_1 is safe to use" << std::endl;

            // Write through the raw pointer first, then through operator*.
            int* const l_raw = weak_1.get();
            *l_raw = 47;
            *weak_1 = 42;

            std::cout << *weak_1 << std::endl;
        }
    }   // shared_1 goes out of scope here and is destroyed

    std::cout << "weak_1 control block's shared count: " << weak_1.use_count() << std::endl;

    if (static_cast<bool>(weak_1))
    {
        return 0;
    }

    std::cout << "weak_1 is NOT safe to use" << std::endl;
    return 0;
}
`std::unique_ptr<>` should be enough for sequential code, provided that the ownership semantics are well thought out.