diff options
| author | Zeno Albisser <zeno.albisser@digia.com> | 2013-08-15 21:46:11 +0200 |
|---|---|---|
| committer | Zeno Albisser <zeno.albisser@digia.com> | 2013-08-15 21:46:11 +0200 |
| commit | 679147eead574d186ebf3069647b4c23e8ccace6 (patch) | |
| tree | fc247a0ac8ff119f7c8550879ebb6d3dd8d1ff69 /chromium/net/socket | |
Initial import.
Diffstat (limited to 'chromium/net/socket')
86 files changed, 35130 insertions, 0 deletions
diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc new file mode 100644 index 00000000000..cf13c5e439a --- /dev/null +++ b/chromium/net/socket/buffered_write_stream_socket.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/buffered_write_stream_socket.h" + +#include "base/bind.h" +#include "base/location.h" +#include "base/message_loop/message_loop.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" + +namespace net { + +namespace { + +void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) { + int old_capacity = dst->capacity(); + dst->SetCapacity(old_capacity + src_len); + memcpy(dst->StartOfBuffer() + old_capacity, src->data(), src_len); +} + +} // anonymous namespace + +BufferedWriteStreamSocket::BufferedWriteStreamSocket( + scoped_ptr<StreamSocket> socket_to_wrap) + : wrapped_socket_(socket_to_wrap.Pass()), + io_buffer_(new GrowableIOBuffer()), + backup_buffer_(new GrowableIOBuffer()), + weak_factory_(this), + callback_pending_(false), + wrapped_write_in_progress_(false), + error_(0) { +} + +BufferedWriteStreamSocket::~BufferedWriteStreamSocket() { +} + +int BufferedWriteStreamSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return wrapped_socket_->Read(buf, buf_len, callback); +} + +int BufferedWriteStreamSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (error_) { + return error_; + } + GrowableIOBuffer* idle_buffer = + wrapped_write_in_progress_ ? backup_buffer_.get() : io_buffer_.get(); + AppendBuffer(idle_buffer, buf, buf_len); + if (!callback_pending_) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&BufferedWriteStreamSocket::DoDelayedWrite, + weak_factory_.GetWeakPtr())); + callback_pending_ = true; + } + return buf_len; +} + +bool BufferedWriteStreamSocket::SetReceiveBufferSize(int32 size) { + return wrapped_socket_->SetReceiveBufferSize(size); +} + +bool BufferedWriteStreamSocket::SetSendBufferSize(int32 size) { + return wrapped_socket_->SetSendBufferSize(size); +} + +int BufferedWriteStreamSocket::Connect(const CompletionCallback& callback) { + return wrapped_socket_->Connect(callback); +} + +void BufferedWriteStreamSocket::Disconnect() { + wrapped_socket_->Disconnect(); +} + +bool BufferedWriteStreamSocket::IsConnected() const { + return wrapped_socket_->IsConnected(); +} + +bool BufferedWriteStreamSocket::IsConnectedAndIdle() const { + return wrapped_socket_->IsConnectedAndIdle(); +} + +int BufferedWriteStreamSocket::GetPeerAddress(IPEndPoint* address) const { + return wrapped_socket_->GetPeerAddress(address); +} + +int BufferedWriteStreamSocket::GetLocalAddress(IPEndPoint* address) const { + return wrapped_socket_->GetLocalAddress(address); +} + +const BoundNetLog& BufferedWriteStreamSocket::NetLog() const { + return wrapped_socket_->NetLog(); +} + +void BufferedWriteStreamSocket::SetSubresourceSpeculation() { + wrapped_socket_->SetSubresourceSpeculation(); +} + +void BufferedWriteStreamSocket::SetOmniboxSpeculation() { + wrapped_socket_->SetOmniboxSpeculation(); +} + +bool BufferedWriteStreamSocket::WasEverUsed() const { + return wrapped_socket_->WasEverUsed(); +} + +bool BufferedWriteStreamSocket::UsingTCPFastOpen() const { + return wrapped_socket_->UsingTCPFastOpen(); +} + +bool BufferedWriteStreamSocket::WasNpnNegotiated() const { + return wrapped_socket_->WasNpnNegotiated(); +} + +NextProto BufferedWriteStreamSocket::GetNegotiatedProtocol() const { + return wrapped_socket_->GetNegotiatedProtocol(); +} + +bool BufferedWriteStreamSocket::GetSSLInfo(SSLInfo* ssl_info) { + return wrapped_socket_->GetSSLInfo(ssl_info); +} + +void BufferedWriteStreamSocket::DoDelayedWrite() { + int result = wrapped_socket_->Write( + io_buffer_.get(), + io_buffer_->RemainingCapacity(), + base::Bind(&BufferedWriteStreamSocket::OnIOComplete, + base::Unretained(this))); + if (result == ERR_IO_PENDING) { + callback_pending_ = true; + wrapped_write_in_progress_ = true; + } else { + OnIOComplete(result); + } +} + +void BufferedWriteStreamSocket::OnIOComplete(int result) { + callback_pending_ = false; + wrapped_write_in_progress_ = false; + if (backup_buffer_->RemainingCapacity()) { + AppendBuffer(io_buffer_.get(), backup_buffer_.get(), + backup_buffer_->RemainingCapacity()); + backup_buffer_->SetCapacity(0); + } + if (result < 0) { + error_ = result; + io_buffer_->SetCapacity(0); + } else { + io_buffer_->set_offset(io_buffer_->offset() + result); + if (io_buffer_->RemainingCapacity()) { + DoDelayedWrite(); + } else { + io_buffer_->SetCapacity(0); + } + } +} + +} // namespace net diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h new file mode 100644 index 00000000000..aad5736d0b0 --- /dev/null +++ b/chromium/net/socket/buffered_write_stream_socket.h @@ -0,0 +1,82 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ +#define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ + +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "net/base/net_log.h" +#include "net/socket/stream_socket.h" + +namespace base { +class TimeDelta; +} + +namespace net { + +class AddressList; +class GrowableIOBuffer; +class IPEndPoint; + +// A StreamSocket decorator. All functions are passed through to the wrapped +// socket, except for Write(). +// +// Writes are buffered locally so that multiple Write()s to this class are +// issued as only one Write() to the wrapped socket. This is useful to force +// multiple requests to be issued in a single packet, as is needed to trigger +// edge cases in HTTP pipelining. +// +// Note that the Write() always returns synchronously. It will either buffer the +// entire input or return the most recently reported error. +// +// There are no bounds on the local buffer size. Use carefully. +class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { + public: + explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap); + virtual ~BufferedWriteStreamSocket(); + + // Socket interface + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + // StreamSocket interface + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + private: + void DoDelayedWrite(); + void OnIOComplete(int result); + + scoped_ptr<StreamSocket> wrapped_socket_; + scoped_refptr<GrowableIOBuffer> io_buffer_; + scoped_refptr<GrowableIOBuffer> backup_buffer_; + base::WeakPtrFactory<BufferedWriteStreamSocket> weak_factory_; + bool callback_pending_; + bool wrapped_write_in_progress_; + int error_; + + DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket); +}; + +} // namespace net + +#endif // NET_SOCKET_STREAM_SOCKET_H_ diff --git a/chromium/net/socket/buffered_write_stream_socket_unittest.cc b/chromium/net/socket/buffered_write_stream_socket_unittest.cc new file mode 100644 index 00000000000..485295f33f6 --- /dev/null +++ b/chromium/net/socket/buffered_write_stream_socket_unittest.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/buffered_write_stream_socket.h" + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +class BufferedWriteStreamSocketTest : public testing::Test { + public: + void Finish() { + base::MessageLoop::current()->RunUntilIdle(); + EXPECT_TRUE(data_->at_read_eof()); + EXPECT_TRUE(data_->at_write_eof()); + } + + void Initialize(MockWrite* writes, size_t writes_count) { + data_.reset(new DeterministicSocketData(NULL, 0, writes, writes_count)); + data_->set_connect_data(MockConnect(SYNCHRONOUS, 0)); + if (writes_count) { + data_->StopAfter(writes_count); + } + scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket( + new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get())); + data_->set_delegate(wrapped_socket->AsWeakPtr()); + socket_.reset(new BufferedWriteStreamSocket( + wrapped_socket.PassAs<StreamSocket>())); + socket_->Connect(callback_.callback()); + } + + void TestWrite(const char* text) { + scoped_refptr<StringIOBuffer> buf(new StringIOBuffer(text)); + EXPECT_EQ(buf->size(), + socket_->Write(buf.get(), buf->size(), callback_.callback())); + } + + scoped_ptr<BufferedWriteStreamSocket> socket_; + scoped_ptr<DeterministicSocketData> data_; + BoundNetLog net_log_; + TestCompletionCallback callback_; +}; + +TEST_F(BufferedWriteStreamSocketTest, SingleWrite) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, 0, "abc"), + }; + Initialize(writes, arraysize(writes)); + TestWrite("abc"); + Finish(); +} + +TEST_F(BufferedWriteStreamSocketTest, AsyncWrite) { + MockWrite writes[] = { + MockWrite(ASYNC, 0, "abc"), + }; + Initialize(writes, arraysize(writes)); + TestWrite("abc"); + data_->Run(); + Finish(); +} + +TEST_F(BufferedWriteStreamSocketTest, TwoWritesIntoOne) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, 0, "abcdef"), + }; + Initialize(writes, arraysize(writes)); + TestWrite("abc"); + TestWrite("def"); + Finish(); +} + +TEST_F(BufferedWriteStreamSocketTest, WriteWhileBlocked) { + MockWrite writes[] = { + MockWrite(ASYNC, 0, "abc"), + MockWrite(ASYNC, 1, "def"), + MockWrite(ASYNC, 2, "ghi"), + }; + Initialize(writes, arraysize(writes)); + TestWrite("abc"); + base::MessageLoop::current()->RunUntilIdle(); + TestWrite("def"); + data_->RunFor(1); + TestWrite("ghi"); + data_->RunFor(1); + Finish(); +} + +TEST_F(BufferedWriteStreamSocketTest, ContinuesPartialWrite) { + MockWrite writes[] = { + MockWrite(ASYNC, 0, "abc"), + MockWrite(ASYNC, 1, "def"), + }; + Initialize(writes, arraysize(writes)); + TestWrite("abcdef"); + data_->Run(); + Finish(); +} + +TEST_F(BufferedWriteStreamSocketTest, TwoSeparateWrites) { + MockWrite writes[] = { + MockWrite(ASYNC, 0, "abc"), + MockWrite(ASYNC, 1, "def"), + }; + Initialize(writes, arraysize(writes)); + TestWrite("abc"); + data_->RunFor(1); + TestWrite("def"); + data_->RunFor(1); + Finish(); +} + +} // anonymous namespace + +} // namespace net diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc new file mode 100644 index 00000000000..a86688e3333 --- /dev/null +++ b/chromium/net/socket/client_socket_factory.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_factory.h" + +#include "base/lazy_instance.h" +#include "base/thread_task_runner_handle.h" +#include "base/threading/sequenced_worker_pool.h" +#include "build/build_config.h" +#include "net/cert/cert_database.h" +#include "net/socket/client_socket_handle.h" +#if defined(USE_OPENSSL) +#include "net/socket/ssl_client_socket_openssl.h" +#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) +#include "net/socket/ssl_client_socket_nss.h" +#endif +#include "net/socket/tcp_client_socket.h" +#include "net/udp/udp_client_socket.h" + +namespace net { + +class X509Certificate; + +namespace { + +// ChromeOS and Linux may require interaction with smart cards or TPMs, which +// may cause NSS functions to block for upwards of several seconds. To avoid +// blocking all activity on the current task runner, such as network or IPC +// traffic, run NSS SSL functions on a dedicated thread. +#if defined(OS_CHROMEOS) || defined(OS_LINUX) +bool g_use_dedicated_nss_thread = true; +#else +bool g_use_dedicated_nss_thread = false; +#endif + +class DefaultClientSocketFactory : public ClientSocketFactory, + public CertDatabase::Observer { + public: + DefaultClientSocketFactory() { + if (g_use_dedicated_nss_thread) { + // Use a single thread for the worker pool. + worker_pool_ = new base::SequencedWorkerPool(1, "NSS SSL Thread"); + nss_thread_task_runner_ = + worker_pool_->GetSequencedTaskRunnerWithShutdownBehavior( + worker_pool_->GetSequenceToken(), + base::SequencedWorkerPool::CONTINUE_ON_SHUTDOWN); + } + + CertDatabase::GetInstance()->AddObserver(this); + } + + virtual ~DefaultClientSocketFactory() { + // Note: This code never runs, as the factory is defined as a Leaky + // singleton. + CertDatabase::GetInstance()->RemoveObserver(this); + } + + virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE { + ClearSSLSessionCache(); + } + + virtual void OnCertTrustChanged(const X509Certificate* cert) OVERRIDE { + // Per wtc, we actually only need to flush when trust is reduced. + // Always flush now because OnCertTrustChanged does not tell us this. + // See comments in ClientSocketPoolManager::OnCertTrustChanged. + ClearSSLSessionCache(); + } + + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE { + return scoped_ptr<DatagramClientSocket>( + new UDPClientSocket(bind_type, rand_int_cb, net_log, source)); + } + + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE { + return scoped_ptr<StreamSocket>( + new TCPClientSocket(addresses, net_log, source)); + } + + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE { + // nss_thread_task_runner_ may be NULL if g_use_dedicated_nss_thread is + // false or if the dedicated NSS thread failed to start. If so, cause NSS + // functions to execute on the current task runner. + // + // Note: The current task runner is obtained on each call due to unit + // tests, which may create and tear down the current thread's TaskRunner + // between each test. Because the DefaultClientSocketFactory is leaky, it + // may span multiple tests, and thus the current task runner may change + // from call to call. + scoped_refptr<base::SequencedTaskRunner> nss_task_runner( + nss_thread_task_runner_); + if (!nss_task_runner.get()) + nss_task_runner = base::ThreadTaskRunnerHandle::Get(); + +#if defined(USE_OPENSSL) + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port, + ssl_config, context)); +#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketNSS(nss_task_runner.get(), + transport_socket.Pass(), + host_and_port, + ssl_config, + context)); +#else + NOTIMPLEMENTED(); + return scoped_ptr<SSLClientSocket>(); +#endif + } + + virtual void ClearSSLSessionCache() OVERRIDE { + SSLClientSocket::ClearSessionCache(); + } + + private: + scoped_refptr<base::SequencedWorkerPool> worker_pool_; + scoped_refptr<base::SequencedTaskRunner> nss_thread_task_runner_; +}; + +static base::LazyInstance<DefaultClientSocketFactory>::Leaky + g_default_client_socket_factory = LAZY_INSTANCE_INITIALIZER; + +} // namespace + +// static +ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() { + return g_default_client_socket_factory.Pointer(); +} + +} // namespace net diff --git a/chromium/net/socket/client_socket_factory.h b/chromium/net/socket/client_socket_factory.h new file mode 100644 index 00000000000..6cb5949f0b3 --- /dev/null +++ b/chromium/net/socket/client_socket_factory.h @@ -0,0 +1,65 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CLIENT_SOCKET_FACTORY_H_ +#define NET_SOCKET_CLIENT_SOCKET_FACTORY_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/base/rand_callback.h" +#include "net/udp/datagram_socket.h" + +namespace net { + +class AddressList; +class ClientSocketHandle; +class DatagramClientSocket; +class HostPortPair; +class SSLClientSocket; +struct SSLClientSocketContext; +struct SSLConfig; +class StreamSocket; + +// An interface used to instantiate StreamSocket objects. Used to facilitate +// testing code with mock socket implementations. +class NET_EXPORT ClientSocketFactory { + public: + virtual ~ClientSocketFactory() {} + + // |source| is the NetLog::Source for the entity trying to create the socket, + // if it has one. + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) = 0; + + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* net_log, + const NetLog::Source& source) = 0; + + // It is allowed to pass in a |transport_socket| that is not obtained from a + // socket pool. The caller could create a ClientSocketHandle directly and call + // set_socket() on it to set a valid StreamSocket instance. + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) = 0; + + // Clears cache used for SSL session resumption. + virtual void ClearSSLSessionCache() = 0; + + // Returns the default ClientSocketFactory. + static ClientSocketFactory* GetDefaultFactory(); +}; + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_FACTORY_H_ diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc new file mode 100644 index 00000000000..acb896b36f5 --- /dev/null +++ b/chromium/net/socket/client_socket_handle.cc @@ -0,0 +1,180 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_handle.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/compiler_specific.h" +#include "base/metrics/histogram.h" +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/client_socket_pool_histograms.h" + +namespace net { + +ClientSocketHandle::ClientSocketHandle() + : is_initialized_(false), + pool_(NULL), + layered_pool_(NULL), + is_reused_(false), + callback_(base::Bind(&ClientSocketHandle::OnIOComplete, + base::Unretained(this))), + is_ssl_error_(false) {} + +ClientSocketHandle::~ClientSocketHandle() { + Reset(); +} + +void ClientSocketHandle::Reset() { + ResetInternal(true); + ResetErrorState(); +} + +void ClientSocketHandle::ResetInternal(bool cancel) { + if (group_name_.empty()) // Was Init called? + return; + if (is_initialized()) { + // Because of http://crbug.com/37810 we may not have a pool, but have + // just a raw socket. + socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE); + if (pool_) + // If we've still got a socket, release it back to the ClientSocketPool so + // it can be deleted or reused. + pool_->ReleaseSocket(group_name_, PassSocket(), pool_id_); + } else if (cancel) { + // If we did not get initialized yet, we've got a socket request pending. + // Cancel it. + pool_->CancelRequest(group_name_, this); + } + is_initialized_ = false; + group_name_.clear(); + is_reused_ = false; + user_callback_.Reset(); + if (layered_pool_) { + pool_->RemoveLayeredPool(layered_pool_); + layered_pool_ = NULL; + } + pool_ = NULL; + idle_time_ = base::TimeDelta(); + init_time_ = base::TimeTicks(); + setup_time_ = base::TimeDelta(); + connect_timing_ = LoadTimingInfo::ConnectTiming(); + pool_id_ = -1; +} + +void ClientSocketHandle::ResetErrorState() { + is_ssl_error_ = false; + ssl_error_response_info_ = HttpResponseInfo(); + pending_http_proxy_connection_.reset(); +} + +LoadState ClientSocketHandle::GetLoadState() const { + CHECK(!is_initialized()); + CHECK(!group_name_.empty()); + // Because of http://crbug.com/37810 we may not have a pool, but have + // just a raw socket. + if (!pool_) + return LOAD_STATE_IDLE; + return pool_->GetLoadState(group_name_, this); +} + +bool ClientSocketHandle::IsPoolStalled() const { + return pool_->IsStalled(); +} + +void ClientSocketHandle::AddLayeredPool(LayeredPool* layered_pool) { + CHECK(layered_pool); + CHECK(!layered_pool_); + if (pool_) { + pool_->AddLayeredPool(layered_pool); + layered_pool_ = layered_pool; + } +} + +void ClientSocketHandle::RemoveLayeredPool(LayeredPool* layered_pool) { + CHECK(layered_pool); + CHECK(layered_pool_); + if (pool_) { + pool_->RemoveLayeredPool(layered_pool); + layered_pool_ = NULL; + } +} + +bool ClientSocketHandle::GetLoadTimingInfo( + bool is_reused, + LoadTimingInfo* load_timing_info) const { + // Only return load timing information when there's a socket. + if (!socket_) + return false; + + load_timing_info->socket_log_id = socket_->NetLog().source().id; + load_timing_info->socket_reused = is_reused; + + // No times if the socket is reused. + if (is_reused) + return true; + + load_timing_info->connect_timing = connect_timing_; + return true; +} + +void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) { + socket_ = s.Pass(); +} + +void ClientSocketHandle::OnIOComplete(int result) { + CompletionCallback callback = user_callback_; + user_callback_.Reset(); + HandleInitCompletion(result); + callback.Run(result); +} + +scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() { + return socket_.Pass(); +} + +void ClientSocketHandle::HandleInitCompletion(int result) { + CHECK_NE(ERR_IO_PENDING, result); + ClientSocketPoolHistograms* histograms = pool_->histograms(); + histograms->AddErrorCode(result); + if (result != OK) { + if (!socket_.get()) + ResetInternal(false); // Nothing to cancel since the request failed. + else + is_initialized_ = true; + return; + } + is_initialized_ = true; + CHECK_NE(-1, pool_id_) << "Pool should have set |pool_id_| to a valid value."; + setup_time_ = base::TimeTicks::Now() - init_time_; + + histograms->AddSocketType(reuse_type()); + switch (reuse_type()) { + case ClientSocketHandle::UNUSED: + histograms->AddRequestTime(setup_time()); + break; + case ClientSocketHandle::UNUSED_IDLE: + histograms->AddUnusedIdleTime(idle_time()); + break; + case ClientSocketHandle::REUSED_IDLE: + histograms->AddReusedIdleTime(idle_time()); + break; + default: + NOTREACHED(); + break; + } + + // Broadcast that the socket has been acquired. + // TODO(eroman): This logging is not complete, in particular set_socket() and + // release() socket. It ends up working though, since those methods are being + // used to layer sockets (and the destination sources are the same). + DCHECK(socket_.get()); + socket_->NetLog().BeginEvent( + NetLog::TYPE_SOCKET_IN_USE, + requesting_source_.ToEventParametersCallback()); +} + +} // namespace net diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h new file mode 100644 index 00000000000..9651f089a86 --- /dev/null +++ b/chromium/net/socket/client_socket_handle.h @@ -0,0 +1,241 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CLIENT_SOCKET_HANDLE_H_ +#define NET_SOCKET_CLIENT_SOCKET_HANDLE_H_ + +#include <string> + +#include "base/logging.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/completion_callback.h" +#include "net/base/load_states.h" +#include "net/base/load_timing_info.h" +#include "net/base/net_errors.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/base/request_priority.h" +#include "net/http/http_response_info.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/stream_socket.h" + +namespace net { + +// A container for a StreamSocket. +// +// The handle's |group_name| uniquely identifies the origin and type of the +// connection. It is used by the ClientSocketPool to group similar connected +// client socket objects. +// +class NET_EXPORT ClientSocketHandle { + public: + enum SocketReuseType { + UNUSED = 0, // unused socket that just finished connecting + UNUSED_IDLE, // unused socket that has been idle for awhile + REUSED_IDLE, // previously used socket + NUM_TYPES, + }; + + ClientSocketHandle(); + ~ClientSocketHandle(); + + // Initializes a ClientSocketHandle object, which involves talking to the + // ClientSocketPool to obtain a connected socket, possibly reusing one. This + // method returns either OK or ERR_IO_PENDING. On ERR_IO_PENDING, |priority| + // is used to determine the placement in ClientSocketPool's wait list. + // + // If this method succeeds, then the socket member will be set to an existing + // connected socket if an existing connected socket was available to reuse, + // otherwise it will be set to a new connected socket. Consumers can then + // call is_reused() to see if the socket was reused. If not reusing an + // existing socket, ClientSocketPool may need to establish a new + // connection using |socket_params|. + // + // This method returns ERR_IO_PENDING if it cannot complete synchronously, in + // which case the consumer will be notified of completion via |callback|. + // + // If the pool was not able to reuse an existing socket, the new socket + // may report a recoverable error. In this case, the return value will + // indicate an error and the socket member will be set. If it is determined + // that the error is not recoverable, the Disconnect method should be used + // on the socket, so that it does not get reused. + // + // A non-recoverable error may set additional state in the ClientSocketHandle + // to allow the caller to determine what went wrong. + // + // Init may be called multiple times. + // + // Profiling information for the request is saved to |net_log| if non-NULL. + // + template <typename SocketParams, typename PoolType> + int Init(const std::string& group_name, + const scoped_refptr<SocketParams>& socket_params, + RequestPriority priority, + const CompletionCallback& callback, + PoolType* pool, + const BoundNetLog& net_log); + + // An initialized handle can be reset, which causes it to return to the + // un-initialized state. This releases the underlying socket, which in the + // case of a socket that still has an established connection, indicates that + // the socket may be kept alive for use by a subsequent ClientSocketHandle. + // + // NOTE: To prevent the socket from being kept alive, be sure to call its + // Disconnect method. This will result in the ClientSocketPool deleting the + // StreamSocket. + void Reset(); + + // Used after Init() is called, but before the ClientSocketPool has + // initialized the ClientSocketHandle. + LoadState GetLoadState() const; + + bool IsPoolStalled() const; + + void AddLayeredPool(LayeredPool* layered_pool); + + void RemoveLayeredPool(LayeredPool* layered_pool); + + // Returns true when Init() has completed successfully. + bool is_initialized() const { return is_initialized_; } + + // Returns the time tick when Init() was called. + base::TimeTicks init_time() const { return init_time_; } + + // Returns the time between Init() and when is_initialized() becomes true. + base::TimeDelta setup_time() const { return setup_time_; } + + // Sets the portion of LoadTimingInfo related to connection establishment, and + // the socket id. |is_reused| is needed because the handle may not have full + // reuse information. |load_timing_info| must have all default values when + // called. Returns false and makes no changes to |load_timing_info| when + // |socket_| is NULL. + bool GetLoadTimingInfo(bool is_reused, + LoadTimingInfo* load_timing_info) const; + + // Used by ClientSocketPool to initialize the ClientSocketHandle. + void SetSocket(scoped_ptr<StreamSocket> s); + void set_is_reused(bool is_reused) { is_reused_ = is_reused; } + void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; } + void set_pool_id(int id) { pool_id_ = id; } + void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; } + void set_ssl_error_response_info(const HttpResponseInfo& ssl_error_state) { + ssl_error_response_info_ = ssl_error_state; + } + void set_pending_http_proxy_connection(ClientSocketHandle* connection) { + pending_http_proxy_connection_.reset(connection); + } + + // Only valid if there is no |socket_|. + bool is_ssl_error() const { + DCHECK(socket_.get() == NULL); + return is_ssl_error_; + } + // On an ERR_PROXY_AUTH_REQUESTED error, the |headers| and |auth_challenge| + // fields are filled in. On an ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, + // the |cert_request_info| field is set. + const HttpResponseInfo& ssl_error_response_info() const { + return ssl_error_response_info_; + } + ClientSocketHandle* release_pending_http_proxy_connection() { + return pending_http_proxy_connection_.release(); + } + + // These may only be used if is_initialized() is true. + scoped_ptr<StreamSocket> PassSocket(); + StreamSocket* socket() { return socket_.get(); } + const std::string& group_name() const { return group_name_; } + int id() const { return pool_id_; } + bool is_reused() const { return is_reused_; } + base::TimeDelta idle_time() const { return idle_time_; } + SocketReuseType reuse_type() const { + if (is_reused()) { + return REUSED_IDLE; + } else if (idle_time() == base::TimeDelta()) { + return UNUSED; + } else { + return UNUSED_IDLE; + } + } + const LoadTimingInfo::ConnectTiming& connect_timing() const { + return connect_timing_; + } + void set_connect_timing(const LoadTimingInfo::ConnectTiming& connect_timing) { + connect_timing_ = connect_timing; + } + + private: + // Called on asynchronous completion of an Init() request. + void OnIOComplete(int result); + + // Called on completion (both asynchronous & synchronous) of an Init() + // request. + void HandleInitCompletion(int result); + + // Resets the state of the ClientSocketHandle. |cancel| indicates whether or + // not to try to cancel the request with the ClientSocketPool. Does not + // reset the supplemental error state. + void ResetInternal(bool cancel); + + // Resets the supplemental error state. + void ResetErrorState(); + + bool is_initialized_; + ClientSocketPool* pool_; + LayeredPool* layered_pool_; + scoped_ptr<StreamSocket> socket_; + std::string group_name_; + bool is_reused_; + CompletionCallback callback_; + CompletionCallback user_callback_; + base::TimeDelta idle_time_; + int pool_id_; // See ClientSocketPool::ReleaseSocket() for an explanation. + bool is_ssl_error_; + HttpResponseInfo ssl_error_response_info_; + scoped_ptr<ClientSocketHandle> pending_http_proxy_connection_; + base::TimeTicks init_time_; + base::TimeDelta setup_time_; + + NetLog::Source requesting_source_; + + // Timing information is set when a connection is successfully established. + LoadTimingInfo::ConnectTiming connect_timing_; + + DISALLOW_COPY_AND_ASSIGN(ClientSocketHandle); +}; + +// Template function implementation: +template <typename SocketParams, typename PoolType> +int ClientSocketHandle::Init(const std::string& group_name, + const scoped_refptr<SocketParams>& socket_params, + RequestPriority priority, + const CompletionCallback& callback, + PoolType* pool, + const BoundNetLog& net_log) { + requesting_source_ = net_log.source(); + + CHECK(!group_name.empty()); + // Note that this will result in a compile error if the SocketParams has not + // been registered for the PoolType via REGISTER_SOCKET_PARAMS_FOR_POOL + // (defined in client_socket_pool.h). + CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); + ResetInternal(true); + ResetErrorState(); + pool_ = pool; + group_name_ = group_name; + init_time_ = base::TimeTicks::Now(); + int rv = pool_->RequestSocket( + group_name, &socket_params, priority, this, callback_, net_log); + if (rv == ERR_IO_PENDING) { + user_callback_ = callback; + } else { + HandleInitCompletion(rv); + } + return rv; +} + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_HANDLE_H_ diff --git a/chromium/net/socket/client_socket_pool.cc b/chromium/net/socket/client_socket_pool.cc new file mode 100644 index 00000000000..0eebd11b9fb --- /dev/null +++ b/chromium/net/socket/client_socket_pool.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_pool.h" + +#include "base/logging.h" + +namespace { + +// The maximum duration, in seconds, to keep unused idle persistent sockets +// alive. +// TODO(ziadh): Change this timeout after getting histogram data on how long it +// should be. +int g_unused_idle_socket_timeout_s = 10; + +// The maximum duration, in seconds, to keep used idle persistent sockets alive. +int g_used_idle_socket_timeout_s = 300; // 5 minutes + +} // namespace + +namespace net { + +// static +base::TimeDelta ClientSocketPool::unused_idle_socket_timeout() { + return base::TimeDelta::FromSeconds(g_unused_idle_socket_timeout_s); +} + +// static +void ClientSocketPool::set_unused_idle_socket_timeout(base::TimeDelta timeout) { + DCHECK_GT(timeout.InSeconds(), 0); + g_unused_idle_socket_timeout_s = timeout.InSeconds(); +} + +// static +base::TimeDelta ClientSocketPool::used_idle_socket_timeout() { + return base::TimeDelta::FromSeconds(g_used_idle_socket_timeout_s); +} + +// static +void ClientSocketPool::set_used_idle_socket_timeout(base::TimeDelta timeout) { + DCHECK_GT(timeout.InSeconds(), 0); + g_used_idle_socket_timeout_s = timeout.InSeconds(); +} + +ClientSocketPool::ClientSocketPool() {} + +ClientSocketPool::~ClientSocketPool() {} + +} // namespace net diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h new file mode 100644 index 00000000000..af184547d6e --- /dev/null +++ b/chromium/net/socket/client_socket_pool.h @@ -0,0 +1,221 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_CLIENT_SOCKET_POOL_H_ + +#include <deque> +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/template_util.h" +#include "base/time/time.h" +#include "net/base/completion_callback.h" +#include "net/base/load_states.h" +#include "net/base/net_export.h" +#include "net/base/request_priority.h" +#include "net/dns/host_resolver.h" + +namespace base { +class DictionaryValue; +} + +namespace net { + +class ClientSocketHandle; +class ClientSocketPoolHistograms; +class StreamSocket; + +// ClientSocketPools are layered. This defines an interface for lower level +// socket pools to communicate with higher layer pools. +class NET_EXPORT LayeredPool { + public: + virtual ~LayeredPool() {}; + + // Instructs the LayeredPool to close an idle connection. Return true if one + // was closed. + virtual bool CloseOneIdleConnection() = 0; +}; + +// A ClientSocketPool is used to restrict the number of sockets open at a time. +// It also maintains a list of idle persistent sockets. +// +class NET_EXPORT ClientSocketPool { + public: + // Requests a connected socket for a group_name. + // + // There are five possible results from calling this function: + // 1) RequestSocket returns OK and initializes |handle| with a reused socket. + // 2) RequestSocket returns OK with a newly connected socket. + // 3) RequestSocket returns ERR_IO_PENDING. The handle will be added to a + // wait list until a socket is available to reuse or a new socket finishes + // connecting. |priority| will determine the placement into the wait list. + // 4) An error occurred early on, so RequestSocket returns an error code. + // 5) A recoverable error occurred while setting up the socket. An error + // code is returned, but the |handle| is initialized with the new socket. + // The caller must recover from the error before using the connection, or + // Disconnect the socket before releasing or resetting the |handle|. + // The current recoverable errors are: the errors accepted by + // IsCertificateError(err) and PROXY_AUTH_REQUESTED, or + // HTTPS_PROXY_TUNNEL_RESPONSE when reported by HttpProxyClientSocketPool. + // + // If this function returns OK, then |handle| is initialized upon return. + // The |handle|'s is_initialized method will return true in this case. If a + // StreamSocket was reused, then ClientSocketPool will call + // |handle|->set_reused(true). In either case, the socket will have been + // allocated and will be connected. A client might want to know whether or + // not the socket is reused in order to request a new socket if he encounters + // an error with the reused socket. + // + // If ERR_IO_PENDING is returned, then the callback will be used to notify the + // client of completion. + // + // Profiling information for the request is saved to |net_log| if non-NULL. + virtual int RequestSocket(const std::string& group_name, + const void* params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) = 0; + + // RequestSockets is used to request that |num_sockets| be connected in the + // connection group for |group_name|. If the connection group already has + // |num_sockets| idle sockets / active sockets / currently connecting sockets, + // then this function doesn't do anything. Otherwise, it will start up as + // many connections as necessary to reach |num_sockets| total sockets for the + // group. It uses |params| to control how to connect the sockets. The + // ClientSocketPool will assign a priority to the new connections, if any. + // This priority will probably be lower than all others, since this method + // is intended to make sure ahead of time that |num_sockets| sockets are + // available to talk to a host. + virtual void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) = 0; + + // Called to cancel a RequestSocket call that returned ERR_IO_PENDING. The + // same handle parameter must be passed to this method as was passed to the + // RequestSocket call being cancelled. The associated CompletionCallback is + // not run. However, for performance, we will let one ConnectJob complete + // and go idle. + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) = 0; + + // Called to release a socket once the socket is no longer needed. If the + // socket still has an established connection, then it will be added to the + // set of idle sockets to be used to satisfy future RequestSocket calls. + // Otherwise, the StreamSocket is destroyed. |id| is used to differentiate + // between updated versions of the same pool instance. The pool's id will + // change when it flushes, so it can use this |id| to discard sockets with + // mismatched ids. + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) = 0; + + // This flushes all state from the ClientSocketPool. This means that all + // idle and connecting sockets are discarded with the given |error|. + // Active sockets being held by ClientSocketPool clients will be discarded + // when released back to the pool. + // Does not flush any pools wrapped by |this|. + virtual void FlushWithError(int error) = 0; + + // Returns true if a there is currently a request blocked on the + // per-pool (not per-host) max socket limit. + virtual bool IsStalled() const = 0; + + // Called to close any idle connections held by the connection manager. + virtual void CloseIdleSockets() = 0; + + // The total number of idle sockets in the pool. + virtual int IdleSocketCount() const = 0; + + // The total number of idle sockets in a connection group. + virtual int IdleSocketCountInGroup(const std::string& group_name) const = 0; + + // Determine the LoadState of a connecting ClientSocketHandle. + virtual LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const = 0; + + // Adds a LayeredPool on top of |this|. + virtual void AddLayeredPool(LayeredPool* layered_pool) = 0; + + // Removes a LayeredPool from |this|. + virtual void RemoveLayeredPool(LayeredPool* layered_pool) = 0; + + // Retrieves information on the current state of the pool as a + // DictionaryValue. Caller takes possession of the returned value. + // If |include_nested_pools| is true, the states of any nested + // ClientSocketPools will be included. + virtual base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const = 0; + + // Returns the maximum amount of time to wait before retrying a connect. + static const int kMaxConnectRetryIntervalMs = 250; + + // The set of histograms specific to this pool. We can't use the standard + // UMA_HISTOGRAM_* macros because they are callsite static. + virtual ClientSocketPoolHistograms* histograms() const = 0; + + static base::TimeDelta unused_idle_socket_timeout(); + static void set_unused_idle_socket_timeout(base::TimeDelta timeout); + + static base::TimeDelta used_idle_socket_timeout(); + static void set_used_idle_socket_timeout(base::TimeDelta timeout); + + protected: + ClientSocketPool(); + virtual ~ClientSocketPool(); + + // Return the connection timeout for this pool. + virtual base::TimeDelta ConnectionTimeout() const = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(ClientSocketPool); +}; + +// ClientSocketPool subclasses should indicate valid SocketParams via the +// REGISTER_SOCKET_PARAMS_FOR_POOL macro below. By default, any given +// <PoolType,SocketParams> pair will have its SocketParamsTrait inherit from +// base::false_type, but REGISTER_SOCKET_PARAMS_FOR_POOL will specialize that +// pairing to inherit from base::true_type. This provides compile time +// verification that the correct SocketParams type is used with the appropriate +// PoolType. +template <typename PoolType, typename SocketParams> +struct SocketParamTraits : public base::false_type { +}; + +template <typename PoolType, typename SocketParams> +void CheckIsValidSocketParamsForPool() { + COMPILE_ASSERT(!base::is_pointer<scoped_refptr<SocketParams> >::value, + socket_params_cannot_be_pointer); + COMPILE_ASSERT((SocketParamTraits<PoolType, + scoped_refptr<SocketParams> >::value), + invalid_socket_params_for_pool); +} + +// Provides an empty definition for CheckIsValidSocketParamsForPool() which +// should be optimized out by the compiler. +#define REGISTER_SOCKET_PARAMS_FOR_POOL(pool_type, socket_params) \ +template<> \ +struct SocketParamTraits<pool_type, scoped_refptr<socket_params> > \ + : public base::true_type { \ +} + +template <typename PoolType, typename SocketParams> +void RequestSocketsForPool(PoolType* pool, + const std::string& group_name, + const scoped_refptr<SocketParams>& params, + int num_sockets, + const BoundNetLog& net_log) { + CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); + pool->RequestSockets(group_name, ¶ms, num_sockets, net_log); +} + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc new file mode 100644 index 00000000000..b1ddd40881c --- /dev/null +++ b/chromium/net/socket/client_socket_pool_base.cc @@ -0,0 +1,1266 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_pool_base.h" + +#include "base/compiler_specific.h" +#include "base/format_macros.h" +#include "base/logging.h" +#include "base/message_loop/message_loop.h" +#include "base/metrics/stats_counters.h" +#include "base/stl_util.h" +#include "base/strings/string_util.h" +#include "base/time/time.h" +#include "base/values.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_handle.h" + +using base::TimeDelta; + +namespace net { + +namespace { + +// Indicate whether we should enable idle socket cleanup timer. When timer is +// disabled, sockets are closed next time a socket request is made. +bool g_cleanup_timer_enabled = true; + +// The timeout value, in seconds, used to clean up idle sockets that can't be +// reused. +// +// Note: It's important to close idle sockets that have received data as soon +// as possible because the received data may cause BSOD on Windows XP under +// some conditions. See http://crbug.com/4606. +const int kCleanupInterval = 10; // DO NOT INCREASE THIS TIMEOUT. + +// Indicate whether or not we should establish a new transport layer connection +// after a certain timeout has passed without receiving an ACK. +bool g_connect_backup_jobs_enabled = true; + +// Compares the effective priority of two results, and returns 1 if |request1| +// has greater effective priority than |request2|, 0 if they have the same +// effective priority, and -1 if |request2| has the greater effective priority. +// Requests with |ignore_limits| set have higher effective priority than those +// without. If both requests have |ignore_limits| set/unset, then the request +// with the highest Pririoty has the highest effective priority. Does not take +// into account the fact that Requests are serviced in FIFO order if they would +// otherwise have the same priority. +int CompareEffectiveRequestPriority( + const internal::ClientSocketPoolBaseHelper::Request& request1, + const internal::ClientSocketPoolBaseHelper::Request& request2) { + if (request1.ignore_limits() && !request2.ignore_limits()) + return 1; + if (!request1.ignore_limits() && request2.ignore_limits()) + return -1; + if (request1.priority() > request2.priority()) + return 1; + if (request1.priority() < request2.priority()) + return -1; + return 0; +} + +} // namespace + +ConnectJob::ConnectJob(const std::string& group_name, + base::TimeDelta timeout_duration, + Delegate* delegate, + const BoundNetLog& net_log) + : group_name_(group_name), + timeout_duration_(timeout_duration), + delegate_(delegate), + net_log_(net_log), + idle_(true) { + DCHECK(!group_name.empty()); + DCHECK(delegate); + net_log.BeginEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB, + NetLog::StringCallback("group_name", &group_name_)); +} + +ConnectJob::~ConnectJob() { + net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB); +} + +scoped_ptr<StreamSocket> ConnectJob::PassSocket() { + return socket_.Pass(); +} + +int ConnectJob::Connect() { + if (timeout_duration_ != base::TimeDelta()) + timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout); + + idle_ = false; + + LogConnectStart(); + + int rv = ConnectInternal(); + + if (rv != ERR_IO_PENDING) { + LogConnectCompletion(rv); + delegate_ = NULL; + } + + return rv; +} + +void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) { + if (socket) { + net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET, + socket->NetLog().source().ToEventParametersCallback()); + } + socket_ = socket.Pass(); +} + +void ConnectJob::NotifyDelegateOfCompletion(int rv) { + // The delegate will own |this|. + Delegate* delegate = delegate_; + delegate_ = NULL; + + LogConnectCompletion(rv); + delegate->OnConnectJobComplete(rv, this); +} + +void ConnectJob::ResetTimer(base::TimeDelta remaining_time) { + timer_.Stop(); + timer_.Start(FROM_HERE, remaining_time, this, &ConnectJob::OnTimeout); +} + +void ConnectJob::LogConnectStart() { + connect_timing_.connect_start = base::TimeTicks::Now(); + net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT); +} + +void ConnectJob::LogConnectCompletion(int net_error) { + connect_timing_.connect_end = base::TimeTicks::Now(); + net_log().EndEventWithNetErrorCode( + NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT, net_error); +} + +void ConnectJob::OnTimeout() { + // Make sure the socket is NULL before calling into |delegate|. + SetSocket(scoped_ptr<StreamSocket>()); + + net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT); + + NotifyDelegateOfCompletion(ERR_TIMED_OUT); +} + +namespace internal { + +ClientSocketPoolBaseHelper::Request::Request( + ClientSocketHandle* handle, + const CompletionCallback& callback, + RequestPriority priority, + bool ignore_limits, + Flags flags, + const BoundNetLog& net_log) + : handle_(handle), + callback_(callback), + priority_(priority), + ignore_limits_(ignore_limits), + flags_(flags), + net_log_(net_log) {} + +ClientSocketPoolBaseHelper::Request::~Request() {} + +ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper( + int max_sockets, + int max_sockets_per_group, + base::TimeDelta unused_idle_socket_timeout, + base::TimeDelta used_idle_socket_timeout, + ConnectJobFactory* connect_job_factory) + : idle_socket_count_(0), + connecting_socket_count_(0), + handed_out_socket_count_(0), + max_sockets_(max_sockets), + max_sockets_per_group_(max_sockets_per_group), + use_cleanup_timer_(g_cleanup_timer_enabled), + unused_idle_socket_timeout_(unused_idle_socket_timeout), + used_idle_socket_timeout_(used_idle_socket_timeout), + connect_job_factory_(connect_job_factory), + connect_backup_jobs_enabled_(false), + pool_generation_number_(0), + weak_factory_(this) { + DCHECK_LE(0, max_sockets_per_group); + DCHECK_LE(max_sockets_per_group, max_sockets); + + NetworkChangeNotifier::AddIPAddressObserver(this); +} + +ClientSocketPoolBaseHelper::~ClientSocketPoolBaseHelper() { + // Clean up any idle sockets and pending connect jobs. Assert that we have no + // remaining active sockets or pending requests. They should have all been + // cleaned up prior to |this| being destroyed. + FlushWithError(ERR_ABORTED); + DCHECK(group_map_.empty()); + DCHECK(pending_callback_map_.empty()); + DCHECK_EQ(0, connecting_socket_count_); + CHECK(higher_layer_pools_.empty()); + + NetworkChangeNotifier::RemoveIPAddressObserver(this); +} + +ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair() + : result(OK) { +} + +ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair( + const CompletionCallback& callback_in, int result_in) + : callback(callback_in), + result(result_in) { +} + +ClientSocketPoolBaseHelper::CallbackResultPair::~CallbackResultPair() {} + +// static +void ClientSocketPoolBaseHelper::InsertRequestIntoQueue( + const Request* r, RequestQueue* pending_requests) { + RequestQueue::iterator it = pending_requests->begin(); + // TODO(mmenke): Should the network stack require requests with + // |ignore_limits| have the highest priority? + while (it != pending_requests->end() && + CompareEffectiveRequestPriority(*r, *(*it)) <= 0) { + ++it; + } + pending_requests->insert(it, r); +} + +// static +const ClientSocketPoolBaseHelper::Request* +ClientSocketPoolBaseHelper::RemoveRequestFromQueue( + const RequestQueue::iterator& it, Group* group) { + const Request* req = *it; + group->mutable_pending_requests()->erase(it); + // If there are no more requests, we kill the backup timer. + if (group->pending_requests().empty()) + group->CleanupBackupJob(); + return req; +} + +void ClientSocketPoolBaseHelper::AddLayeredPool(LayeredPool* pool) { + CHECK(pool); + CHECK(!ContainsKey(higher_layer_pools_, pool)); + higher_layer_pools_.insert(pool); +} + +void ClientSocketPoolBaseHelper::RemoveLayeredPool(LayeredPool* pool) { + CHECK(pool); + CHECK(ContainsKey(higher_layer_pools_, pool)); + higher_layer_pools_.erase(pool); +} + +int ClientSocketPoolBaseHelper::RequestSocket( + const std::string& group_name, + const Request* request) { + CHECK(!request->callback().is_null()); + CHECK(request->handle()); + + // Cleanup any timed-out idle sockets if no timer is used. + if (!use_cleanup_timer_) + CleanupIdleSockets(false); + + request->net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL); + Group* group = GetOrCreateGroup(group_name); + + int rv = RequestSocketInternal(group_name, request); + if (rv != ERR_IO_PENDING) { + request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv); + CHECK(!request->handle()->is_initialized()); + delete request; + } else { + InsertRequestIntoQueue(request, group->mutable_pending_requests()); + // Have to do this asynchronously, as closing sockets in higher level pools + // call back in to |this|, which will cause all sorts of fun and exciting + // re-entrancy issues if the socket pool is doing something else at the + // time. + if (group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind( + &ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools, + weak_factory_.GetWeakPtr())); + } + } + return rv; +} + +void ClientSocketPoolBaseHelper::RequestSockets( + const std::string& group_name, + const Request& request, + int num_sockets) { + DCHECK(request.callback().is_null()); + DCHECK(!request.handle()); + + // Cleanup any timed out idle sockets if no timer is used. + if (!use_cleanup_timer_) + CleanupIdleSockets(false); + + if (num_sockets > max_sockets_per_group_) { + num_sockets = max_sockets_per_group_; + } + + request.net_log().BeginEvent( + NetLog::TYPE_SOCKET_POOL_CONNECTING_N_SOCKETS, + NetLog::IntegerCallback("num_sockets", num_sockets)); + + Group* group = GetOrCreateGroup(group_name); + + // RequestSocketsInternal() may delete the group. + bool deleted_group = false; + + int rv = OK; + for (int num_iterations_left = num_sockets; + group->NumActiveSocketSlots() < num_sockets && + num_iterations_left > 0 ; num_iterations_left--) { + rv = RequestSocketInternal(group_name, &request); + if (rv < 0 && rv != ERR_IO_PENDING) { + // We're encountering a synchronous error. Give up. + if (!ContainsKey(group_map_, group_name)) + deleted_group = true; + break; + } + if (!ContainsKey(group_map_, group_name)) { + // Unexpected. The group should only be getting deleted on synchronous + // error. + NOTREACHED(); + deleted_group = true; + break; + } + } + + if (!deleted_group && group->IsEmpty()) + RemoveGroup(group_name); + + if (rv == ERR_IO_PENDING) + rv = OK; + request.net_log().EndEventWithNetErrorCode( + NetLog::TYPE_SOCKET_POOL_CONNECTING_N_SOCKETS, rv); +} + +int ClientSocketPoolBaseHelper::RequestSocketInternal( + const std::string& group_name, + const Request* request) { + ClientSocketHandle* const handle = request->handle(); + const bool preconnecting = !handle; + Group* group = GetOrCreateGroup(group_name); + + if (!(request->flags() & NO_IDLE_SOCKETS)) { + // Try to reuse a socket. + if (AssignIdleSocketToRequest(request, group)) + return OK; + } + + // If there are more ConnectJobs than pending requests, don't need to do + // anything. Can just wait for the extra job to connect, and then assign it + // to the request. + if (!preconnecting && group->TryToUseUnassignedConnectJob()) + return ERR_IO_PENDING; + + // Can we make another active socket now? + if (!group->HasAvailableSocketSlot(max_sockets_per_group_) && + !request->ignore_limits()) { + // TODO(willchan): Consider whether or not we need to close a socket in a + // higher layered group. I don't think this makes sense since we would just + // reuse that socket then if we needed one and wouldn't make it down to this + // layer. + request->net_log().AddEvent( + NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP); + return ERR_IO_PENDING; + } + + if (ReachedMaxSocketsLimit() && !request->ignore_limits()) { + // NOTE(mmenke): Wonder if we really need different code for each case + // here. Only reason for them now seems to be preconnects. + if (idle_socket_count() > 0) { + // There's an idle socket in this pool. Either that's because there's + // still one in this group, but we got here due to preconnecting bypassing + // idle sockets, or because there's an idle socket in another group. + bool closed = CloseOneIdleSocketExceptInGroup(group); + if (preconnecting && !closed) + return ERR_PRECONNECT_MAX_SOCKET_LIMIT; + } else { + // We could check if we really have a stalled group here, but it requires + // a scan of all groups, so just flip a flag here, and do the check later. + request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); + return ERR_IO_PENDING; + } + } + + // We couldn't find a socket to reuse, and there's space to allocate one, + // so allocate and connect a new one. + scoped_ptr<ConnectJob> connect_job( + connect_job_factory_->NewConnectJob(group_name, *request, this)); + + int rv = connect_job->Connect(); + if (rv == OK) { + LogBoundConnectJobToRequest(connect_job->net_log().source(), request); + if (!preconnecting) { + HandOutSocket(connect_job->PassSocket(), false /* not reused */, + connect_job->connect_timing(), handle, base::TimeDelta(), + group, request->net_log()); + } else { + AddIdleSocket(connect_job->PassSocket(), group); + } + } else if (rv == ERR_IO_PENDING) { + // If we don't have any sockets in this group, set a timer for potentially + // creating a new one. If the SYN is lost, this backup socket may complete + // before the slow socket, improving end user latency. + if (connect_backup_jobs_enabled_ && + group->IsEmpty() && !group->HasBackupJob()) { + group->StartBackupSocketTimer(group_name, this); + } + + connecting_socket_count_++; + + group->AddJob(connect_job.Pass(), preconnecting); + } else { + LogBoundConnectJobToRequest(connect_job->net_log().source(), request); + scoped_ptr<StreamSocket> error_socket; + if (!preconnecting) { + DCHECK(handle); + connect_job->GetAdditionalErrorState(handle); + error_socket = connect_job->PassSocket(); + } + if (error_socket) { + HandOutSocket(error_socket.Pass(), false /* not reused */, + connect_job->connect_timing(), handle, base::TimeDelta(), + group, request->net_log()); + } else if (group->IsEmpty()) { + RemoveGroup(group_name); + } + } + + return rv; +} + +bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( + const Request* request, Group* group) { + std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets(); + std::list<IdleSocket>::iterator idle_socket_it = idle_sockets->end(); + + // Iterate through the idle sockets forwards (oldest to newest) + // * Delete any disconnected ones. + // * If we find a used idle socket, assign to |idle_socket|. At the end, + // the |idle_socket_it| will be set to the newest used idle socket. + for (std::list<IdleSocket>::iterator it = idle_sockets->begin(); + it != idle_sockets->end();) { + if (!it->socket->IsConnectedAndIdle()) { + DecrementIdleCount(); + delete it->socket; + it = idle_sockets->erase(it); + continue; + } + + if (it->socket->WasEverUsed()) { + // We found one we can reuse! + idle_socket_it = it; + } + + ++it; + } + + // If we haven't found an idle socket, that means there are no used idle + // sockets. Pick the oldest (first) idle socket (FIFO). + + if (idle_socket_it == idle_sockets->end() && !idle_sockets->empty()) + idle_socket_it = idle_sockets->begin(); + + if (idle_socket_it != idle_sockets->end()) { + DecrementIdleCount(); + base::TimeDelta idle_time = + base::TimeTicks::Now() - idle_socket_it->start_time; + IdleSocket idle_socket = *idle_socket_it; + idle_sockets->erase(idle_socket_it); + HandOutSocket( + scoped_ptr<StreamSocket>(idle_socket.socket), + idle_socket.socket->WasEverUsed(), + LoadTimingInfo::ConnectTiming(), + request->handle(), + idle_time, + group, + request->net_log()); + return true; + } + + return false; +} + +// static +void ClientSocketPoolBaseHelper::LogBoundConnectJobToRequest( + const NetLog::Source& connect_job_source, const Request* request) { + request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + connect_job_source.ToEventParametersCallback()); +} + +void ClientSocketPoolBaseHelper::CancelRequest( + const std::string& group_name, ClientSocketHandle* handle) { + PendingCallbackMap::iterator callback_it = pending_callback_map_.find(handle); + if (callback_it != pending_callback_map_.end()) { + int result = callback_it->second.result; + pending_callback_map_.erase(callback_it); + scoped_ptr<StreamSocket> socket = handle->PassSocket(); + if (socket) { + if (result != OK) + socket->Disconnect(); + ReleaseSocket(handle->group_name(), socket.Pass(), handle->id()); + } + return; + } + + CHECK(ContainsKey(group_map_, group_name)); + + Group* group = GetOrCreateGroup(group_name); + + // Search pending_requests for matching handle. + RequestQueue::iterator it = group->mutable_pending_requests()->begin(); + for (; it != group->pending_requests().end(); ++it) { + if ((*it)->handle() == handle) { + scoped_ptr<const Request> req(RemoveRequestFromQueue(it, group)); + req->net_log().AddEvent(NetLog::TYPE_CANCELLED); + req->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); + + // We let the job run, unless we're at the socket limit and there is + // not another request waiting on the job. + if (group->jobs().size() > group->pending_requests().size() && + ReachedMaxSocketsLimit()) { + RemoveConnectJob(*group->jobs().begin(), group); + CheckForStalledSocketGroups(); + } + break; + } + } +} + +bool ClientSocketPoolBaseHelper::HasGroup(const std::string& group_name) const { + return ContainsKey(group_map_, group_name); +} + +void ClientSocketPoolBaseHelper::CloseIdleSockets() { + CleanupIdleSockets(true); + DCHECK_EQ(0, idle_socket_count_); +} + +int ClientSocketPoolBaseHelper::IdleSocketCountInGroup( + const std::string& group_name) const { + GroupMap::const_iterator i = group_map_.find(group_name); + CHECK(i != group_map_.end()); + + return i->second->idle_sockets().size(); +} + +LoadState ClientSocketPoolBaseHelper::GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const { + if (ContainsKey(pending_callback_map_, handle)) + return LOAD_STATE_CONNECTING; + + if (!ContainsKey(group_map_, group_name)) { + NOTREACHED() << "ClientSocketPool does not contain group: " << group_name + << " for handle: " << handle; + return LOAD_STATE_IDLE; + } + + // Can't use operator[] since it is non-const. + const Group& group = *group_map_.find(group_name)->second; + + // Search the first group.jobs().size() |pending_requests| for |handle|. + // If it's farther back in the deque than that, it doesn't have a + // corresponding ConnectJob. + size_t connect_jobs = group.jobs().size(); + RequestQueue::const_iterator it = group.pending_requests().begin(); + for (size_t i = 0; it != group.pending_requests().end() && i < connect_jobs; + ++it, ++i) { + if ((*it)->handle() != handle) + continue; + + // Just return the state of the farthest along ConnectJob for the first + // group.jobs().size() pending requests. + LoadState max_state = LOAD_STATE_IDLE; + for (ConnectJobSet::const_iterator job_it = group.jobs().begin(); + job_it != group.jobs().end(); ++job_it) { + max_state = std::max(max_state, (*job_it)->GetLoadState()); + } + return max_state; + } + + if (group.IsStalledOnPoolMaxSockets(max_sockets_per_group_)) + return LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL; + return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET; +} + +base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue( + const std::string& name, const std::string& type) const { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("name", name); + dict->SetString("type", type); + dict->SetInteger("handed_out_socket_count", handed_out_socket_count_); + dict->SetInteger("connecting_socket_count", connecting_socket_count_); + dict->SetInteger("idle_socket_count", idle_socket_count_); + dict->SetInteger("max_socket_count", max_sockets_); + dict->SetInteger("max_sockets_per_group", max_sockets_per_group_); + dict->SetInteger("pool_generation_number", pool_generation_number_); + + if (group_map_.empty()) + return dict; + + base::DictionaryValue* all_groups_dict = new base::DictionaryValue(); + for (GroupMap::const_iterator it = group_map_.begin(); + it != group_map_.end(); it++) { + const Group* group = it->second; + base::DictionaryValue* group_dict = new base::DictionaryValue(); + + group_dict->SetInteger("pending_request_count", + group->pending_requests().size()); + if (!group->pending_requests().empty()) { + group_dict->SetInteger("top_pending_priority", + group->TopPendingPriority()); + } + + group_dict->SetInteger("active_socket_count", group->active_socket_count()); + + base::ListValue* idle_socket_list = new base::ListValue(); + std::list<IdleSocket>::const_iterator idle_socket; + for (idle_socket = group->idle_sockets().begin(); + idle_socket != group->idle_sockets().end(); + idle_socket++) { + int source_id = idle_socket->socket->NetLog().source().id; + idle_socket_list->Append(new base::FundamentalValue(source_id)); + } + group_dict->Set("idle_sockets", idle_socket_list); + + base::ListValue* connect_jobs_list = new base::ListValue(); + std::set<ConnectJob*>::const_iterator job = group->jobs().begin(); + for (job = group->jobs().begin(); job != group->jobs().end(); job++) { + int source_id = (*job)->net_log().source().id; + connect_jobs_list->Append(new base::FundamentalValue(source_id)); + } + group_dict->Set("connect_jobs", connect_jobs_list); + + group_dict->SetBoolean("is_stalled", + group->IsStalledOnPoolMaxSockets( + max_sockets_per_group_)); + group_dict->SetBoolean("has_backup_job", group->HasBackupJob()); + + all_groups_dict->SetWithoutPathExpansion(it->first, group_dict); + } + dict->Set("groups", all_groups_dict); + return dict; +} + +bool ClientSocketPoolBaseHelper::IdleSocket::ShouldCleanup( + base::TimeTicks now, + base::TimeDelta timeout) const { + bool timed_out = (now - start_time) >= timeout; + if (timed_out) + return true; + if (socket->WasEverUsed()) + return !socket->IsConnectedAndIdle(); + return !socket->IsConnected(); +} + +void ClientSocketPoolBaseHelper::CleanupIdleSockets(bool force) { + if (idle_socket_count_ == 0) + return; + + // Current time value. Retrieving it once at the function start rather than + // inside the inner loop, since it shouldn't change by any meaningful amount. + base::TimeTicks now = base::TimeTicks::Now(); + + GroupMap::iterator i = group_map_.begin(); + while (i != group_map_.end()) { + Group* group = i->second; + + std::list<IdleSocket>::iterator j = group->mutable_idle_sockets()->begin(); + while (j != group->idle_sockets().end()) { + base::TimeDelta timeout = + j->socket->WasEverUsed() ? + used_idle_socket_timeout_ : unused_idle_socket_timeout_; + if (force || j->ShouldCleanup(now, timeout)) { + delete j->socket; + j = group->mutable_idle_sockets()->erase(j); + DecrementIdleCount(); + } else { + ++j; + } + } + + // Delete group if no longer needed. + if (group->IsEmpty()) { + RemoveGroup(i++); + } else { + ++i; + } + } +} + +ClientSocketPoolBaseHelper::Group* ClientSocketPoolBaseHelper::GetOrCreateGroup( + const std::string& group_name) { + GroupMap::iterator it = group_map_.find(group_name); + if (it != group_map_.end()) + return it->second; + Group* group = new Group; + group_map_[group_name] = group; + return group; +} + +void ClientSocketPoolBaseHelper::RemoveGroup(const std::string& group_name) { + GroupMap::iterator it = group_map_.find(group_name); + CHECK(it != group_map_.end()); + + RemoveGroup(it); +} + +void ClientSocketPoolBaseHelper::RemoveGroup(GroupMap::iterator it) { + delete it->second; + group_map_.erase(it); +} + +// static +bool ClientSocketPoolBaseHelper::connect_backup_jobs_enabled() { + return g_connect_backup_jobs_enabled; +} + +// static +bool ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(bool enabled) { + bool old_value = g_connect_backup_jobs_enabled; + g_connect_backup_jobs_enabled = enabled; + return old_value; +} + +void ClientSocketPoolBaseHelper::EnableConnectBackupJobs() { + connect_backup_jobs_enabled_ = g_connect_backup_jobs_enabled; +} + +void ClientSocketPoolBaseHelper::IncrementIdleCount() { + if (++idle_socket_count_ == 1 && use_cleanup_timer_) + StartIdleSocketTimer(); +} + +void ClientSocketPoolBaseHelper::DecrementIdleCount() { + if (--idle_socket_count_ == 0) + timer_.Stop(); +} + +// static +bool ClientSocketPoolBaseHelper::cleanup_timer_enabled() { + return g_cleanup_timer_enabled; +} + +// static +bool ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(bool enabled) { + bool old_value = g_cleanup_timer_enabled; + g_cleanup_timer_enabled = enabled; + return old_value; +} + +void ClientSocketPoolBaseHelper::StartIdleSocketTimer() { + timer_.Start(FROM_HERE, TimeDelta::FromSeconds(kCleanupInterval), this, + &ClientSocketPoolBaseHelper::OnCleanupTimerFired); +} + +void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + GroupMap::iterator i = group_map_.find(group_name); + CHECK(i != group_map_.end()); + + Group* group = i->second; + + CHECK_GT(handed_out_socket_count_, 0); + handed_out_socket_count_--; + + CHECK_GT(group->active_socket_count(), 0); + group->DecrementActiveSocketCount(); + + const bool can_reuse = socket->IsConnectedAndIdle() && + id == pool_generation_number_; + if (can_reuse) { + // Add it to the idle list. + AddIdleSocket(socket.Pass(), group); + OnAvailableSocketSlot(group_name, group); + } else { + socket.reset(); + } + + CheckForStalledSocketGroups(); +} + +void ClientSocketPoolBaseHelper::CheckForStalledSocketGroups() { + // If we have idle sockets, see if we can give one to the top-stalled group. + std::string top_group_name; + Group* top_group = NULL; + if (!FindTopStalledGroup(&top_group, &top_group_name)) + return; + + if (ReachedMaxSocketsLimit()) { + if (idle_socket_count() > 0) { + CloseOneIdleSocket(); + } else { + // We can't activate more sockets since we're already at our global + // limit. + return; + } + } + + // Note: we don't loop on waking stalled groups. If the stalled group is at + // its limit, may be left with other stalled groups that could be + // woken. This isn't optimal, but there is no starvation, so to avoid + // the looping we leave it at this. + OnAvailableSocketSlot(top_group_name, top_group); +} + +// Search for the highest priority pending request, amongst the groups that +// are not at the |max_sockets_per_group_| limit. Note: for requests with +// the same priority, the winner is based on group hash ordering (and not +// insertion order). +bool ClientSocketPoolBaseHelper::FindTopStalledGroup( + Group** group, + std::string* group_name) const { + CHECK((group && group_name) || (!group && !group_name)); + Group* top_group = NULL; + const std::string* top_group_name = NULL; + bool has_stalled_group = false; + for (GroupMap::const_iterator i = group_map_.begin(); + i != group_map_.end(); ++i) { + Group* curr_group = i->second; + const RequestQueue& queue = curr_group->pending_requests(); + if (queue.empty()) + continue; + if (curr_group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) { + if (!group) + return true; + has_stalled_group = true; + bool has_higher_priority = !top_group || + curr_group->TopPendingPriority() > top_group->TopPendingPriority(); + if (has_higher_priority) { + top_group = curr_group; + top_group_name = &i->first; + } + } + } + + if (top_group) { + CHECK(group); + *group = top_group; + *group_name = *top_group_name; + } else { + CHECK(!has_stalled_group); + } + return has_stalled_group; +} + +void ClientSocketPoolBaseHelper::OnConnectJobComplete( + int result, ConnectJob* job) { + DCHECK_NE(ERR_IO_PENDING, result); + const std::string group_name = job->group_name(); + GroupMap::iterator group_it = group_map_.find(group_name); + CHECK(group_it != group_map_.end()); + Group* group = group_it->second; + + scoped_ptr<StreamSocket> socket = job->PassSocket(); + + // Copies of these are needed because |job| may be deleted before they are + // accessed. + BoundNetLog job_log = job->net_log(); + LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing(); + + // RemoveConnectJob(job, _) must be called by all branches below; + // otherwise, |job| will be leaked. + + if (result == OK) { + DCHECK(socket.get()); + RemoveConnectJob(job, group); + if (!group->pending_requests().empty()) { + scoped_ptr<const Request> r(RemoveRequestFromQueue( + group->mutable_pending_requests()->begin(), group)); + LogBoundConnectJobToRequest(job_log.source(), r.get()); + HandOutSocket( + socket.Pass(), false /* unused socket */, connect_timing, + r->handle(), base::TimeDelta(), group, r->net_log()); + r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); + InvokeUserCallbackLater(r->handle(), r->callback(), result); + } else { + AddIdleSocket(socket.Pass(), group); + OnAvailableSocketSlot(group_name, group); + CheckForStalledSocketGroups(); + } + } else { + // If we got a socket, it must contain error information so pass that + // up so that the caller can retrieve it. + bool handed_out_socket = false; + if (!group->pending_requests().empty()) { + scoped_ptr<const Request> r(RemoveRequestFromQueue( + group->mutable_pending_requests()->begin(), group)); + LogBoundConnectJobToRequest(job_log.source(), r.get()); + job->GetAdditionalErrorState(r->handle()); + RemoveConnectJob(job, group); + if (socket.get()) { + handed_out_socket = true; + HandOutSocket(socket.Pass(), false /* unused socket */, + connect_timing, r->handle(), base::TimeDelta(), group, + r->net_log()); + } + r->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result); + InvokeUserCallbackLater(r->handle(), r->callback(), result); + } else { + RemoveConnectJob(job, group); + } + if (!handed_out_socket) { + OnAvailableSocketSlot(group_name, group); + CheckForStalledSocketGroups(); + } + } +} + +void ClientSocketPoolBaseHelper::OnIPAddressChanged() { + FlushWithError(ERR_NETWORK_CHANGED); +} + +void ClientSocketPoolBaseHelper::FlushWithError(int error) { + pool_generation_number_++; + CancelAllConnectJobs(); + CloseIdleSockets(); + CancelAllRequestsWithError(error); +} + +bool ClientSocketPoolBaseHelper::IsStalled() const { + // If we are not using |max_sockets_|, then clearly we are not stalled + if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_) + return false; + // So in order to be stalled we need to be using |max_sockets_| AND + // we need to have a request that is actually stalled on the global + // socket limit. To find such a request, we look for a group that + // a has more requests that jobs AND where the number of jobs is less + // than |max_sockets_per_group_|. (If the number of jobs is equal to + // |max_sockets_per_group_|, then the request is stalled on the group, + // which does not count.) + for (GroupMap::const_iterator it = group_map_.begin(); + it != group_map_.end(); ++it) { + if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) + return true; + } + return false; +} + +void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job, + Group* group) { + CHECK_GT(connecting_socket_count_, 0); + connecting_socket_count_--; + + DCHECK(group); + group->RemoveJob(job); + + // If we've got no more jobs for this group, then we no longer need a + // backup job either. + if (group->jobs().empty()) + group->CleanupBackupJob(); +} + +void ClientSocketPoolBaseHelper::OnAvailableSocketSlot( + const std::string& group_name, Group* group) { + DCHECK(ContainsKey(group_map_, group_name)); + if (group->IsEmpty()) + RemoveGroup(group_name); + else if (!group->pending_requests().empty()) + ProcessPendingRequest(group_name, group); +} + +void ClientSocketPoolBaseHelper::ProcessPendingRequest( + const std::string& group_name, Group* group) { + int rv = RequestSocketInternal(group_name, + *group->pending_requests().begin()); + if (rv != ERR_IO_PENDING) { + scoped_ptr<const Request> request(RemoveRequestFromQueue( + group->mutable_pending_requests()->begin(), group)); + if (group->IsEmpty()) + RemoveGroup(group_name); + + request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv); + InvokeUserCallbackLater(request->handle(), request->callback(), rv); + } +} + +void ClientSocketPoolBaseHelper::HandOutSocket( + scoped_ptr<StreamSocket> socket, + bool reused, + const LoadTimingInfo::ConnectTiming& connect_timing, + ClientSocketHandle* handle, + base::TimeDelta idle_time, + Group* group, + const BoundNetLog& net_log) { + DCHECK(socket); + handle->SetSocket(socket.Pass()); + handle->set_is_reused(reused); + handle->set_idle_time(idle_time); + handle->set_pool_id(pool_generation_number_); + handle->set_connect_timing(connect_timing); + + if (reused) { + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET, + NetLog::IntegerCallback( + "idle_ms", static_cast<int>(idle_time.InMilliseconds()))); + } + + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + handle->socket()->NetLog().source().ToEventParametersCallback()); + + handed_out_socket_count_++; + group->IncrementActiveSocketCount(); +} + +void ClientSocketPoolBaseHelper::AddIdleSocket( + scoped_ptr<StreamSocket> socket, + Group* group) { + DCHECK(socket); + IdleSocket idle_socket; + idle_socket.socket = socket.release(); + idle_socket.start_time = base::TimeTicks::Now(); + + group->mutable_idle_sockets()->push_back(idle_socket); + IncrementIdleCount(); +} + +void ClientSocketPoolBaseHelper::CancelAllConnectJobs() { + for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) { + Group* group = i->second; + connecting_socket_count_ -= group->jobs().size(); + group->RemoveAllJobs(); + + // Delete group if no longer needed. + if (group->IsEmpty()) { + // RemoveGroup() will call .erase() which will invalidate the iterator, + // but i will already have been incremented to a valid iterator before + // RemoveGroup() is called. + RemoveGroup(i++); + } else { + ++i; + } + } + DCHECK_EQ(0, connecting_socket_count_); +} + +void ClientSocketPoolBaseHelper::CancelAllRequestsWithError(int error) { + for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) { + Group* group = i->second; + + RequestQueue pending_requests; + pending_requests.swap(*group->mutable_pending_requests()); + for (RequestQueue::iterator it2 = pending_requests.begin(); + it2 != pending_requests.end(); ++it2) { + scoped_ptr<const Request> request(*it2); + InvokeUserCallbackLater( + request->handle(), request->callback(), error); + } + + // Delete group if no longer needed. + if (group->IsEmpty()) { + // RemoveGroup() will call .erase() which will invalidate the iterator, + // but i will already have been incremented to a valid iterator before + // RemoveGroup() is called. + RemoveGroup(i++); + } else { + ++i; + } + } +} + +bool ClientSocketPoolBaseHelper::ReachedMaxSocketsLimit() const { + // Each connecting socket will eventually connect and be handed out. + int total = handed_out_socket_count_ + connecting_socket_count_ + + idle_socket_count(); + // There can be more sockets than the limit since some requests can ignore + // the limit + if (total < max_sockets_) + return false; + return true; +} + +bool ClientSocketPoolBaseHelper::CloseOneIdleSocket() { + if (idle_socket_count() == 0) + return false; + return CloseOneIdleSocketExceptInGroup(NULL); +} + +bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup( + const Group* exception_group) { + CHECK_GT(idle_socket_count(), 0); + + for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end(); ++i) { + Group* group = i->second; + if (exception_group == group) + continue; + std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets(); + + if (!idle_sockets->empty()) { + delete idle_sockets->front().socket; + idle_sockets->pop_front(); + DecrementIdleCount(); + if (group->IsEmpty()) + RemoveGroup(i); + + return true; + } + } + + return false; +} + +bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInLayeredPool() { + // This pool doesn't have any idle sockets. It's possible that a pool at a + // higher layer is holding one of this sockets active, but it's actually idle. + // Query the higher layers. + for (std::set<LayeredPool*>::const_iterator it = higher_layer_pools_.begin(); + it != higher_layer_pools_.end(); ++it) { + if ((*it)->CloseOneIdleConnection()) + return true; + } + return false; +} + +void ClientSocketPoolBaseHelper::InvokeUserCallbackLater( + ClientSocketHandle* handle, const CompletionCallback& callback, int rv) { + CHECK(!ContainsKey(pending_callback_map_, handle)); + pending_callback_map_[handle] = CallbackResultPair(callback, rv); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&ClientSocketPoolBaseHelper::InvokeUserCallback, + weak_factory_.GetWeakPtr(), handle)); +} + +void ClientSocketPoolBaseHelper::InvokeUserCallback( + ClientSocketHandle* handle) { + PendingCallbackMap::iterator it = pending_callback_map_.find(handle); + + // Exit if the request has already been cancelled. + if (it == pending_callback_map_.end()) + return; + + CHECK(!handle->is_initialized()); + CompletionCallback callback = it->second.callback; + int result = it->second.result; + pending_callback_map_.erase(it); + callback.Run(result); +} + +void ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools() { + while (IsStalled()) { + // Closing a socket will result in calling back into |this| to use the freed + // socket slot, so nothing else is needed. + if (!CloseOneIdleConnectionInLayeredPool()) + return; + } +} + +ClientSocketPoolBaseHelper::Group::Group() + : unassigned_job_count_(0), + active_socket_count_(0), + weak_factory_(this) {} + +ClientSocketPoolBaseHelper::Group::~Group() { + CleanupBackupJob(); + DCHECK_EQ(0u, unassigned_job_count_); +} + +void ClientSocketPoolBaseHelper::Group::StartBackupSocketTimer( + const std::string& group_name, + ClientSocketPoolBaseHelper* pool) { + // Only allow one timer pending to create a backup socket. + if (weak_factory_.HasWeakPtrs()) + return; + + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&Group::OnBackupSocketTimerFired, weak_factory_.GetWeakPtr(), + group_name, pool), + pool->ConnectRetryInterval()); +} + +bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() { + SanityCheck(); + + if (unassigned_job_count_ == 0) + return false; + --unassigned_job_count_; + return true; +} + +void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job, + bool is_preconnect) { + SanityCheck(); + + if (is_preconnect) + ++unassigned_job_count_; + jobs_.insert(job.release()); +} + +void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) { + scoped_ptr<ConnectJob> owned_job(job); + SanityCheck(); + + std::set<ConnectJob*>::iterator it = jobs_.find(job); + if (it != jobs_.end()) { + jobs_.erase(it); + } else { + NOTREACHED(); + } + size_t job_count = jobs_.size(); + if (job_count < unassigned_job_count_) + unassigned_job_count_ = job_count; +} + +void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( + std::string group_name, + ClientSocketPoolBaseHelper* pool) { + // If there are no more jobs pending, there is no work to do. + // If we've done our cleanups correctly, this should not happen. + if (jobs_.empty()) { + NOTREACHED(); + return; + } + + // If our old job is waiting on DNS, or if we can't create any sockets + // right now due to limits, just reset the timer. + if (pool->ReachedMaxSocketsLimit() || + !HasAvailableSocketSlot(pool->max_sockets_per_group_) || + (*jobs_.begin())->GetLoadState() == LOAD_STATE_RESOLVING_HOST) { + StartBackupSocketTimer(group_name, pool); + return; + } + + if (pending_requests_.empty()) + return; + + scoped_ptr<ConnectJob> backup_job = + pool->connect_job_factory_->NewConnectJob( + group_name, **pending_requests_.begin(), pool); + backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED); + SIMPLE_STATS_COUNTER("socket.backup_created"); + int rv = backup_job->Connect(); + pool->connecting_socket_count_++; + ConnectJob* raw_backup_job = backup_job.get(); + AddJob(backup_job.Pass(), false); + if (rv != ERR_IO_PENDING) + pool->OnConnectJobComplete(rv, raw_backup_job); +} + +void ClientSocketPoolBaseHelper::Group::SanityCheck() { + DCHECK_LE(unassigned_job_count_, jobs_.size()); +} + +void ClientSocketPoolBaseHelper::Group::RemoveAllJobs() { + SanityCheck(); + + // Delete active jobs. + STLDeleteElements(&jobs_); + unassigned_job_count_ = 0; + + // Cancel pending backup job. + weak_factory_.InvalidateWeakPtrs(); +} + +} // namespace internal + +} // namespace net diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h new file mode 100644 index 00000000000..eb642edd730 --- /dev/null +++ b/chromium/net/socket/client_socket_pool_base.h @@ -0,0 +1,819 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// A ClientSocketPoolBase is used to restrict the number of sockets open at +// a time. It also maintains a list of idle persistent sockets for reuse. +// Subclasses of ClientSocketPool should compose ClientSocketPoolBase to handle +// the core logic of (1) restricting the number of active (connected or +// connecting) sockets per "group" (generally speaking, the hostname), (2) +// maintaining a per-group list of idle, persistent sockets for reuse, and (3) +// limiting the total number of active sockets in the system. +// +// ClientSocketPoolBase abstracts socket connection details behind ConnectJob, +// ConnectJobFactory, and SocketParams. When a socket "slot" becomes available, +// the ClientSocketPoolBase will ask the ConnectJobFactory to create a +// ConnectJob with a SocketParams. Subclasses of ClientSocketPool should +// implement their socket specific connection by subclassing ConnectJob and +// implementing ConnectJob::ConnectInternal(). They can control the parameters +// passed to each new ConnectJob instance via their ConnectJobFactory subclass +// and templated SocketParams parameter. +// +#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_ +#define NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_ + +#include <deque> +#include <list> +#include <map> +#include <set> +#include <string> +#include <vector> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" +#include "net/base/load_states.h" +#include "net/base/load_timing_info.h" +#include "net/base/net_errors.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/base/network_change_notifier.h" +#include "net/base/request_priority.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class ClientSocketHandle; + +// ConnectJob provides an abstract interface for "connecting" a socket. +// The connection may involve host resolution, tcp connection, ssl connection, +// etc. +class NET_EXPORT_PRIVATE ConnectJob { + public: + class NET_EXPORT_PRIVATE Delegate { + public: + Delegate() {} + virtual ~Delegate() {} + + // Alerts the delegate that the connection completed. |job| must + // be destroyed by the delegate. A scoped_ptr<> isn't used because + // the caller of this function doesn't own |job|. + virtual void OnConnectJobComplete(int result, + ConnectJob* job) = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(Delegate); + }; + + // A |timeout_duration| of 0 corresponds to no timeout. + ConnectJob(const std::string& group_name, + base::TimeDelta timeout_duration, + Delegate* delegate, + const BoundNetLog& net_log); + virtual ~ConnectJob(); + + // Accessors + const std::string& group_name() const { return group_name_; } + const BoundNetLog& net_log() { return net_log_; } + + // Releases ownership of the underlying socket to the caller. + // Returns the released socket, or NULL if there was a connection + // error. + scoped_ptr<StreamSocket> PassSocket(); + + // Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it + // cannot complete synchronously without blocking, or another net error code + // on error. In asynchronous completion, the ConnectJob will notify + // |delegate_| via OnConnectJobComplete. In both asynchronous and synchronous + // completion, ReleaseSocket() can be called to acquire the connected socket + // if it succeeded. + int Connect(); + + virtual LoadState GetLoadState() const = 0; + + // If Connect returns an error (or OnConnectJobComplete reports an error + // result) this method will be called, allowing the pool to add + // additional error state to the ClientSocketHandle (post late-binding). + virtual void GetAdditionalErrorState(ClientSocketHandle* handle) {} + + const LoadTimingInfo::ConnectTiming& connect_timing() const { + return connect_timing_; + } + + const BoundNetLog& net_log() const { return net_log_; } + + protected: + void SetSocket(scoped_ptr<StreamSocket> socket); + StreamSocket* socket() { return socket_.get(); } + void NotifyDelegateOfCompletion(int rv); + void ResetTimer(base::TimeDelta remainingTime); + + // Connection establishment timing information. + LoadTimingInfo::ConnectTiming connect_timing_; + + private: + virtual int ConnectInternal() = 0; + + void LogConnectStart(); + void LogConnectCompletion(int net_error); + + // Alerts the delegate that the ConnectJob has timed out. + void OnTimeout(); + + const std::string group_name_; + const base::TimeDelta timeout_duration_; + // Timer to abort jobs that take too long. + base::OneShotTimer<ConnectJob> timer_; + Delegate* delegate_; + scoped_ptr<StreamSocket> socket_; + BoundNetLog net_log_; + // A ConnectJob is idle until Connect() has been called. + bool idle_; + + DISALLOW_COPY_AND_ASSIGN(ConnectJob); +}; + +namespace internal { + +// ClientSocketPoolBaseHelper is an internal class that implements almost all +// the functionality from ClientSocketPoolBase without using templates. +// ClientSocketPoolBase adds templated definitions built on top of +// ClientSocketPoolBaseHelper. This class is not for external use, please use +// ClientSocketPoolBase instead. +class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper + : public ConnectJob::Delegate, + public NetworkChangeNotifier::IPAddressObserver { + public: + typedef uint32 Flags; + + // Used to specify specific behavior for the ClientSocketPool. + enum Flag { + NORMAL = 0, // Normal behavior. + NO_IDLE_SOCKETS = 0x1, // Do not return an idle socket. Create a new one. + }; + + class NET_EXPORT_PRIVATE Request { + public: + Request(ClientSocketHandle* handle, + const CompletionCallback& callback, + RequestPriority priority, + bool ignore_limits, + Flags flags, + const BoundNetLog& net_log); + + virtual ~Request(); + + ClientSocketHandle* handle() const { return handle_; } + const CompletionCallback& callback() const { return callback_; } + RequestPriority priority() const { return priority_; } + bool ignore_limits() const { return ignore_limits_; } + Flags flags() const { return flags_; } + const BoundNetLog& net_log() const { return net_log_; } + + private: + ClientSocketHandle* const handle_; + CompletionCallback callback_; + const RequestPriority priority_; + bool ignore_limits_; + const Flags flags_; + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(Request); + }; + + class ConnectJobFactory { + public: + ConnectJobFactory() {} + virtual ~ConnectJobFactory() {} + + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const Request& request, + ConnectJob::Delegate* delegate) const = 0; + + virtual base::TimeDelta ConnectionTimeout() const = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory); + }; + + ClientSocketPoolBaseHelper( + int max_sockets, + int max_sockets_per_group, + base::TimeDelta unused_idle_socket_timeout, + base::TimeDelta used_idle_socket_timeout, + ConnectJobFactory* connect_job_factory); + + virtual ~ClientSocketPoolBaseHelper(); + + // Adds/Removes layered pools. It is expected in the destructor that no + // layered pools remain. + void AddLayeredPool(LayeredPool* pool); + void RemoveLayeredPool(LayeredPool* pool); + + // See ClientSocketPool::RequestSocket for documentation on this function. + // ClientSocketPoolBaseHelper takes ownership of |request|, which must be + // heap allocated. + int RequestSocket(const std::string& group_name, const Request* request); + + // See ClientSocketPool::RequestSocket for documentation on this function. + void RequestSockets(const std::string& group_name, + const Request& request, + int num_sockets); + + // See ClientSocketPool::CancelRequest for documentation on this function. + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle); + + // See ClientSocketPool::ReleaseSocket for documentation on this function. + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id); + + // See ClientSocketPool::FlushWithError for documentation on this function. + void FlushWithError(int error); + + // See ClientSocketPool::IsStalled for documentation on this function. + bool IsStalled() const; + + // See ClientSocketPool::CloseIdleSockets for documentation on this function. + void CloseIdleSockets(); + + // See ClientSocketPool::IdleSocketCount() for documentation on this function. + int idle_socket_count() const { + return idle_socket_count_; + } + + // See ClientSocketPool::IdleSocketCountInGroup() for documentation on this + // function. + int IdleSocketCountInGroup(const std::string& group_name) const; + + // See ClientSocketPool::GetLoadState() for documentation on this function. + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const; + + base::TimeDelta ConnectRetryInterval() const { + // TODO(mbelshe): Make this tuned dynamically based on measured RTT. + // For now, just use the max retry interval. + return base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs); + } + + int NumUnassignedConnectJobsInGroup(const std::string& group_name) const { + return group_map_.find(group_name)->second->unassigned_job_count(); + } + + int NumConnectJobsInGroup(const std::string& group_name) const { + return group_map_.find(group_name)->second->jobs().size(); + } + + int NumActiveSocketsInGroup(const std::string& group_name) const { + return group_map_.find(group_name)->second->active_socket_count(); + } + + bool HasGroup(const std::string& group_name) const; + + // Called to enable/disable cleaning up idle sockets. When enabled, + // idle sockets that have been around for longer than a period defined + // by kCleanupInterval are cleaned up using a timer. Otherwise they are + // closed next time client makes a request. This may reduce network + // activity and power consumption. + static bool cleanup_timer_enabled(); + static bool set_cleanup_timer_enabled(bool enabled); + + // Closes all idle sockets if |force| is true. Else, only closes idle + // sockets that timed out or can't be reused. Made public for testing. + void CleanupIdleSockets(bool force); + + // Closes one idle socket. Picks the first one encountered. + // TODO(willchan): Consider a better algorithm for doing this. Perhaps we + // should keep an ordered list of idle sockets, and close them in order. + // Requires maintaining more state. It's not clear if it's worth it since + // I'm not sure if we hit this situation often. + bool CloseOneIdleSocket(); + + // Checks layered pools to see if they can close an idle connection. + bool CloseOneIdleConnectionInLayeredPool(); + + // See ClientSocketPool::GetInfoAsValue for documentation on this function. + base::DictionaryValue* GetInfoAsValue(const std::string& name, + const std::string& type) const; + + base::TimeDelta ConnectionTimeout() const { + return connect_job_factory_->ConnectionTimeout(); + } + + static bool connect_backup_jobs_enabled(); + static bool set_connect_backup_jobs_enabled(bool enabled); + + void EnableConnectBackupJobs(); + + // ConnectJob::Delegate methods: + virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE; + + // NetworkChangeNotifier::IPAddressObserver methods: + virtual void OnIPAddressChanged() OVERRIDE; + + private: + friend class base::RefCounted<ClientSocketPoolBaseHelper>; + + // Entry for a persistent socket which became idle at time |start_time|. + struct IdleSocket { + IdleSocket() : socket(NULL) {} + + // An idle socket should be removed if it can't be reused, or has been idle + // for too long. |now| is the current time value (TimeTicks::Now()). + // |timeout| is the length of time to wait before timing out an idle socket. + // + // An idle socket can't be reused if it is disconnected or has received + // data unexpectedly (hence no longer idle). The unread data would be + // mistaken for the beginning of the next response if we were to reuse the + // socket for a new request. + bool ShouldCleanup(base::TimeTicks now, base::TimeDelta timeout) const; + + StreamSocket* socket; + base::TimeTicks start_time; + }; + + typedef std::deque<const Request* > RequestQueue; + typedef std::map<const ClientSocketHandle*, const Request*> RequestMap; + + // A Group is allocated per group_name when there are idle sockets or pending + // requests. Otherwise, the Group object is removed from the map. + // |active_socket_count| tracks the number of sockets held by clients. + class Group { + public: + Group(); + ~Group(); + + bool IsEmpty() const { + return active_socket_count_ == 0 && idle_sockets_.empty() && + jobs_.empty() && pending_requests_.empty(); + } + + bool HasAvailableSocketSlot(int max_sockets_per_group) const { + return NumActiveSocketSlots() < max_sockets_per_group; + } + + int NumActiveSocketSlots() const { + return active_socket_count_ + static_cast<int>(jobs_.size()) + + static_cast<int>(idle_sockets_.size()); + } + + bool IsStalledOnPoolMaxSockets(int max_sockets_per_group) const { + return HasAvailableSocketSlot(max_sockets_per_group) && + pending_requests_.size() > jobs_.size(); + } + + RequestPriority TopPendingPriority() const { + return pending_requests_.front()->priority(); + } + + bool HasBackupJob() const { return weak_factory_.HasWeakPtrs(); } + + void CleanupBackupJob() { + weak_factory_.InvalidateWeakPtrs(); + } + + // Set a timer to create a backup socket if it takes too long to create one. + void StartBackupSocketTimer(const std::string& group_name, + ClientSocketPoolBaseHelper* pool); + + // If there's a ConnectJob that's never been assigned to Request, + // decrements |unassigned_job_count_| and returns true. + // Otherwise, returns false. + bool TryToUseUnassignedConnectJob(); + + void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect); + // Remove |job| from this group, which must already own |job|. + void RemoveJob(ConnectJob* job); + void RemoveAllJobs(); + + void IncrementActiveSocketCount() { active_socket_count_++; } + void DecrementActiveSocketCount() { active_socket_count_--; } + + int unassigned_job_count() const { return unassigned_job_count_; } + const std::set<ConnectJob*>& jobs() const { return jobs_; } + const std::list<IdleSocket>& idle_sockets() const { return idle_sockets_; } + const RequestQueue& pending_requests() const { return pending_requests_; } + int active_socket_count() const { return active_socket_count_; } + RequestQueue* mutable_pending_requests() { return &pending_requests_; } + std::list<IdleSocket>* mutable_idle_sockets() { return &idle_sockets_; } + + private: + // Called when the backup socket timer fires. + void OnBackupSocketTimerFired( + std::string group_name, + ClientSocketPoolBaseHelper* pool); + + // Checks that |unassigned_job_count_| does not execeed the number of + // ConnectJobs. + void SanityCheck(); + + // Total number of ConnectJobs that have never been assigned to a Request. + // Since jobs use late binding to requests, which ConnectJobs have or have + // not been assigned to a request are not tracked. This is incremented on + // preconnect and decremented when a preconnect is assigned, or when there + // are fewer than |unassigned_job_count_| ConnectJobs. Not incremented + // when a request is cancelled. + size_t unassigned_job_count_; + + std::list<IdleSocket> idle_sockets_; + std::set<ConnectJob*> jobs_; + RequestQueue pending_requests_; + int active_socket_count_; // number of active sockets used by clients + // A factory to pin the backup_job tasks. + base::WeakPtrFactory<Group> weak_factory_; + }; + + typedef std::map<std::string, Group*> GroupMap; + + typedef std::set<ConnectJob*> ConnectJobSet; + + struct CallbackResultPair { + CallbackResultPair(); + CallbackResultPair(const CompletionCallback& callback_in, int result_in); + ~CallbackResultPair(); + + CompletionCallback callback; + int result; + }; + + typedef std::map<const ClientSocketHandle*, CallbackResultPair> + PendingCallbackMap; + + // Inserts the request into the queue based on order they will receive + // sockets. Sockets which ignore the socket pool limits are first. Then + // requests are sorted by priority, with higher priorities closer to the + // front. Older requests are prioritized over requests of equal priority. + static void InsertRequestIntoQueue(const Request* r, + RequestQueue* pending_requests); + static const Request* RemoveRequestFromQueue(const RequestQueue::iterator& it, + Group* group); + + Group* GetOrCreateGroup(const std::string& group_name); + void RemoveGroup(const std::string& group_name); + void RemoveGroup(GroupMap::iterator it); + + // Called when the number of idle sockets changes. + void IncrementIdleCount(); + void DecrementIdleCount(); + + // Start cleanup timer for idle sockets. + void StartIdleSocketTimer(); + + // Scans the group map for groups which have an available socket slot and + // at least one pending request. Returns true if any groups are stalled, and + // if so (and if both |group| and |group_name| are not NULL), fills |group| + // and |group_name| with data of the stalled group having highest priority. + bool FindTopStalledGroup(Group** group, std::string* group_name) const; + + // Called when timer_ fires. This method scans the idle sockets removing + // sockets that timed out or can't be reused. + void OnCleanupTimerFired() { + CleanupIdleSockets(false); + } + + // Removes |job| from |group|, which must already own |job|. + void RemoveConnectJob(ConnectJob* job, Group* group); + + // Tries to see if we can handle any more requests for |group|. + void OnAvailableSocketSlot(const std::string& group_name, Group* group); + + // Process a pending socket request for a group. + void ProcessPendingRequest(const std::string& group_name, Group* group); + + // Assigns |socket| to |handle| and updates |group|'s counters appropriately. + void HandOutSocket(scoped_ptr<StreamSocket> socket, + bool reused, + const LoadTimingInfo::ConnectTiming& connect_timing, + ClientSocketHandle* handle, + base::TimeDelta time_idle, + Group* group, + const BoundNetLog& net_log); + + // Adds |socket| to the list of idle sockets for |group|. + void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group); + + // Iterates through |group_map_|, canceling all ConnectJobs and deleting + // groups if they are no longer needed. + void CancelAllConnectJobs(); + + // Iterates through |group_map_|, posting |error| callbacks for all + // requests, and then deleting groups if they are no longer needed. + void CancelAllRequestsWithError(int error); + + // Returns true if we can't create any more sockets due to the total limit. + bool ReachedMaxSocketsLimit() const; + + // This is the internal implementation of RequestSocket(). It differs in that + // it does not handle logging into NetLog of the queueing status of + // |request|. + int RequestSocketInternal(const std::string& group_name, + const Request* request); + + // Assigns an idle socket for the group to the request. + // Returns |true| if an idle socket is available, false otherwise. + bool AssignIdleSocketToRequest(const Request* request, Group* group); + + static void LogBoundConnectJobToRequest( + const NetLog::Source& connect_job_source, const Request* request); + + // Same as CloseOneIdleSocket() except it won't close an idle socket in + // |group|. If |group| is NULL, it is ignored. Returns true if it closed a + // socket. + bool CloseOneIdleSocketExceptInGroup(const Group* group); + + // Checks if there are stalled socket groups that should be notified + // for possible wakeup. + void CheckForStalledSocketGroups(); + + // Posts a task to call InvokeUserCallback() on the next iteration through the + // current message loop. Inserts |callback| into |pending_callback_map_|, + // keyed by |handle|. + void InvokeUserCallbackLater( + ClientSocketHandle* handle, const CompletionCallback& callback, int rv); + + // Invokes the user callback for |handle|. By the time this task has run, + // it's possible that the request has been cancelled, so |handle| may not + // exist in |pending_callback_map_|. We look up the callback and result code + // in |pending_callback_map_|. + void InvokeUserCallback(ClientSocketHandle* handle); + + // Tries to close idle sockets in a higher level socket pool as long as this + // this pool is stalled. + void TryToCloseSocketsInLayeredPools(); + + GroupMap group_map_; + + // Map of the ClientSocketHandles for which we have a pending Task to invoke a + // callback. This is necessary since, before we invoke said callback, it's + // possible that the request is cancelled. + PendingCallbackMap pending_callback_map_; + + // Timer used to periodically prune idle sockets that timed out or can't be + // reused. + base::RepeatingTimer<ClientSocketPoolBaseHelper> timer_; + + // The total number of idle sockets in the system. + int idle_socket_count_; + + // Number of connecting sockets across all groups. + int connecting_socket_count_; + + // Number of connected sockets we handed out across all groups. + int handed_out_socket_count_; + + // The maximum total number of sockets. See ReachedMaxSocketsLimit. + const int max_sockets_; + + // The maximum number of sockets kept per group. + const int max_sockets_per_group_; + + // Whether to use timer to cleanup idle sockets. + bool use_cleanup_timer_; + + // The time to wait until closing idle sockets. + const base::TimeDelta unused_idle_socket_timeout_; + const base::TimeDelta used_idle_socket_timeout_; + + const scoped_ptr<ConnectJobFactory> connect_job_factory_; + + // TODO(vandebo) Remove when backup jobs move to TransportClientSocketPool + bool connect_backup_jobs_enabled_; + + // A unique id for the pool. It gets incremented every time we + // FlushWithError() the pool. This is so that when sockets get released back + // to the pool, we can make sure that they are discarded rather than reused. + int pool_generation_number_; + + std::set<LayeredPool*> higher_layer_pools_; + + base::WeakPtrFactory<ClientSocketPoolBaseHelper> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBaseHelper); +}; + +} // namespace internal + +template <typename SocketParams> +class ClientSocketPoolBase { + public: + class Request : public internal::ClientSocketPoolBaseHelper::Request { + public: + Request(ClientSocketHandle* handle, + const CompletionCallback& callback, + RequestPriority priority, + internal::ClientSocketPoolBaseHelper::Flags flags, + bool ignore_limits, + const scoped_refptr<SocketParams>& params, + const BoundNetLog& net_log) + : internal::ClientSocketPoolBaseHelper::Request( + handle, callback, priority, ignore_limits, flags, net_log), + params_(params) {} + + const scoped_refptr<SocketParams>& params() const { return params_; } + + private: + const scoped_refptr<SocketParams> params_; + }; + + class ConnectJobFactory { + public: + ConnectJobFactory() {} + virtual ~ConnectJobFactory() {} + + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const Request& request, + ConnectJob::Delegate* delegate) const = 0; + + virtual base::TimeDelta ConnectionTimeout() const = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory); + }; + + // |max_sockets| is the maximum number of sockets to be maintained by this + // ClientSocketPool. |max_sockets_per_group| specifies the maximum number of + // sockets a "group" can have. |unused_idle_socket_timeout| specifies how + // long to leave an unused idle socket open before closing it. + // |used_idle_socket_timeout| specifies how long to leave a previously used + // idle socket open before closing it. + ClientSocketPoolBase( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + base::TimeDelta unused_idle_socket_timeout, + base::TimeDelta used_idle_socket_timeout, + ConnectJobFactory* connect_job_factory) + : histograms_(histograms), + helper_(max_sockets, max_sockets_per_group, + unused_idle_socket_timeout, used_idle_socket_timeout, + new ConnectJobFactoryAdaptor(connect_job_factory)) {} + + virtual ~ClientSocketPoolBase() {} + + // These member functions simply forward to ClientSocketPoolBaseHelper. + void AddLayeredPool(LayeredPool* pool) { + helper_.AddLayeredPool(pool); + } + + void RemoveLayeredPool(LayeredPool* pool) { + helper_.RemoveLayeredPool(pool); + } + + // RequestSocket bundles up the parameters into a Request and then forwards to + // ClientSocketPoolBaseHelper::RequestSocket(). + int RequestSocket(const std::string& group_name, + const scoped_refptr<SocketParams>& params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { + Request* request = + new Request(handle, callback, priority, + internal::ClientSocketPoolBaseHelper::NORMAL, + params->ignore_limits(), + params, net_log); + return helper_.RequestSocket(group_name, request); + } + + // RequestSockets bundles up the parameters into a Request and then forwards + // to ClientSocketPoolBaseHelper::RequestSockets(). Note that it assigns the + // priority to DEFAULT_PRIORITY and specifies the NO_IDLE_SOCKETS flag. + void RequestSockets(const std::string& group_name, + const scoped_refptr<SocketParams>& params, + int num_sockets, + const BoundNetLog& net_log) { + const Request request(NULL /* no handle */, + CompletionCallback(), + DEFAULT_PRIORITY, + internal::ClientSocketPoolBaseHelper::NO_IDLE_SOCKETS, + params->ignore_limits(), + params, + net_log); + helper_.RequestSockets(group_name, request, num_sockets); + } + + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) { + return helper_.CancelRequest(group_name, handle); + } + + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + return helper_.ReleaseSocket(group_name, socket.Pass(), id); + } + + void FlushWithError(int error) { helper_.FlushWithError(error); } + + bool IsStalled() const { return helper_.IsStalled(); } + + void CloseIdleSockets() { return helper_.CloseIdleSockets(); } + + int idle_socket_count() const { return helper_.idle_socket_count(); } + + int IdleSocketCountInGroup(const std::string& group_name) const { + return helper_.IdleSocketCountInGroup(group_name); + } + + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const { + return helper_.GetLoadState(group_name, handle); + } + + virtual void OnConnectJobComplete(int result, ConnectJob* job) { + return helper_.OnConnectJobComplete(result, job); + } + + int NumUnassignedConnectJobsInGroup(const std::string& group_name) const { + return helper_.NumUnassignedConnectJobsInGroup(group_name); + } + + int NumConnectJobsInGroup(const std::string& group_name) const { + return helper_.NumConnectJobsInGroup(group_name); + } + + int NumActiveSocketsInGroup(const std::string& group_name) const { + return helper_.NumActiveSocketsInGroup(group_name); + } + + bool HasGroup(const std::string& group_name) const { + return helper_.HasGroup(group_name); + } + + void CleanupIdleSockets(bool force) { + return helper_.CleanupIdleSockets(force); + } + + base::DictionaryValue* GetInfoAsValue(const std::string& name, + const std::string& type) const { + return helper_.GetInfoAsValue(name, type); + } + + base::TimeDelta ConnectionTimeout() const { + return helper_.ConnectionTimeout(); + } + + ClientSocketPoolHistograms* histograms() const { + return histograms_; + } + + void EnableConnectBackupJobs() { helper_.EnableConnectBackupJobs(); } + + bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); } + + bool CloseOneIdleConnectionInLayeredPool() { + return helper_.CloseOneIdleConnectionInLayeredPool(); + } + + private: + // This adaptor class exists to bridge the + // internal::ClientSocketPoolBaseHelper::ConnectJobFactory and + // ClientSocketPoolBase::ConnectJobFactory types, allowing clients to use the + // typesafe ClientSocketPoolBase::ConnectJobFactory, rather than having to + // static_cast themselves. + class ConnectJobFactoryAdaptor + : public internal::ClientSocketPoolBaseHelper::ConnectJobFactory { + public: + typedef typename ClientSocketPoolBase<SocketParams>::ConnectJobFactory + ConnectJobFactory; + + explicit ConnectJobFactoryAdaptor(ConnectJobFactory* connect_job_factory) + : connect_job_factory_(connect_job_factory) {} + virtual ~ConnectJobFactoryAdaptor() {} + + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const internal::ClientSocketPoolBaseHelper::Request& request, + ConnectJob::Delegate* delegate) const OVERRIDE { + const Request& casted_request = static_cast<const Request&>(request); + return connect_job_factory_->NewConnectJob( + group_name, casted_request, delegate); + } + + virtual base::TimeDelta ConnectionTimeout() const { + return connect_job_factory_->ConnectionTimeout(); + } + + const scoped_ptr<ConnectJobFactory> connect_job_factory_; + }; + + // Histograms for the pool + ClientSocketPoolHistograms* const histograms_; + internal::ClientSocketPoolBaseHelper helper_; + + DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBase); +}; + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_ diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc new file mode 100644 index 00000000000..6688e01244d --- /dev/null +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -0,0 +1,4168 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_pool_base.h" + +#include <vector> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_vector.h" +#include "base/memory/weak_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/stringprintf.h" +#include "base/threading/platform_thread.h" +#include "base/values.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_log_unittest.h" +#include "net/base/request_priority.h" +#include "net/base/test_completion_callback.h" +#include "net/http/http_response_headers.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using ::testing::Invoke; +using ::testing::Return; + +namespace net { + +namespace { + +const int kDefaultMaxSockets = 4; +const int kDefaultMaxSocketsPerGroup = 2; +const net::RequestPriority kDefaultPriority = MEDIUM; + +// Make sure |handle| sets load times correctly when it has been assigned a +// reused socket. +void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + // Only pass true in as |is_reused|, as in general, HttpStream types should + // have stricter concepts of reuse than socket pools. + EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); + + EXPECT_EQ(true, load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +// Make sure |handle| sets load times correctly when it has been assigned a +// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner +// of a connection where |is_reused| is false may consider the connection +// reused. +void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { + EXPECT_FALSE(handle.is_reused()); + + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + EXPECT_FALSE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasTimes(load_timing_info.connect_timing, + CONNECT_TIMING_HAS_CONNECT_TIMES_ONLY); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); + + TestLoadTimingInfoConnectedReused(handle); +} + +// Make sure |handle| sets load times correctly, in the case that it does not +// currently have a socket. +void TestLoadTimingInfoNotConnected(const ClientSocketHandle& handle) { + // Should only be set to true once a socket is assigned, if at all. + EXPECT_FALSE(handle.is_reused()); + + LoadTimingInfo load_timing_info; + EXPECT_FALSE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + EXPECT_FALSE(load_timing_info.socket_reused); + EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +class TestSocketParams : public base::RefCounted<TestSocketParams> { + public: + TestSocketParams() : ignore_limits_(false) {} + + void set_ignore_limits(bool ignore_limits) { + ignore_limits_ = ignore_limits; + } + bool ignore_limits() { return ignore_limits_; } + + private: + friend class base::RefCounted<TestSocketParams>; + ~TestSocketParams() {} + + bool ignore_limits_; +}; +typedef ClientSocketPoolBase<TestSocketParams> TestClientSocketPoolBase; + +class MockClientSocket : public StreamSocket { + public: + explicit MockClientSocket(net::NetLog* net_log) + : connected_(false), + net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_SOCKET)), + was_used_to_convey_data_(false) { + } + + // Socket implementation. + virtual int Read( + IOBuffer* /* buf */, int len, + const CompletionCallback& /* callback */) OVERRIDE { + return ERR_UNEXPECTED; + } + + virtual int Write( + IOBuffer* /* buf */, int len, + const CompletionCallback& /* callback */) OVERRIDE { + was_used_to_convey_data_ = true; + return len; + } + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; } + virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; } + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + connected_ = true; + return OK; + } + + virtual void Disconnect() OVERRIDE { connected_ = false; } + virtual bool IsConnected() const OVERRIDE { return connected_; } + virtual bool IsConnectedAndIdle() const OVERRIDE { return connected_; } + + virtual int GetPeerAddress(IPEndPoint* /* address */) const OVERRIDE { + return ERR_UNEXPECTED; + } + + virtual int GetLocalAddress(IPEndPoint* /* address */) const OVERRIDE { + return ERR_UNEXPECTED; + } + + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { + return was_used_to_convey_data_; + } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { + return false; + } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + return false; + } + + private: + bool connected_; + BoundNetLog net_log_; + bool was_used_to_convey_data_; + + DISALLOW_COPY_AND_ASSIGN(MockClientSocket); +}; + +class TestConnectJob; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + MockClientSocketFactory() : allocation_count_(0) {} + + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE { + NOTREACHED(); + return scoped_ptr<DatagramClientSocket>(); + } + + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* /* net_log */, + const NetLog::Source& /*source*/) OVERRIDE { + allocation_count_++; + return scoped_ptr<StreamSocket>(); + } + + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE { + NOTIMPLEMENTED(); + return scoped_ptr<SSLClientSocket>(); + } + + virtual void ClearSSLSessionCache() OVERRIDE { + NOTIMPLEMENTED(); + } + + void WaitForSignal(TestConnectJob* job) { waiting_jobs_.push_back(job); } + + void SignalJobs(); + + void SignalJob(size_t job); + + void SetJobLoadState(size_t job, LoadState load_state); + + int allocation_count() const { return allocation_count_; } + + private: + int allocation_count_; + std::vector<TestConnectJob*> waiting_jobs_; +}; + +class TestConnectJob : public ConnectJob { + public: + enum JobType { + kMockJob, + kMockFailingJob, + kMockPendingJob, + kMockPendingFailingJob, + kMockWaitingJob, + kMockRecoverableJob, + kMockPendingRecoverableJob, + kMockAdditionalErrorStateJob, + kMockPendingAdditionalErrorStateJob, + }; + + // The kMockPendingJob uses a slight delay before allowing the connect + // to complete. + static const int kPendingConnectDelay = 2; + + TestConnectJob(JobType job_type, + const std::string& group_name, + const TestClientSocketPoolBase::Request& request, + base::TimeDelta timeout_duration, + ConnectJob::Delegate* delegate, + MockClientSocketFactory* client_socket_factory, + NetLog* net_log) + : ConnectJob(group_name, timeout_duration, delegate, + BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), + job_type_(job_type), + client_socket_factory_(client_socket_factory), + weak_factory_(this), + load_state_(LOAD_STATE_IDLE), + store_additional_error_state_(false) {} + + void Signal() { + DoConnect(waiting_success_, true /* async */, false /* recoverable */); + } + + void set_load_state(LoadState load_state) { load_state_ = load_state; } + + // From ConnectJob: + + virtual LoadState GetLoadState() const OVERRIDE { return load_state_; } + + virtual void GetAdditionalErrorState(ClientSocketHandle* handle) OVERRIDE { + if (store_additional_error_state_) { + // Set all of the additional error state fields in some way. + handle->set_is_ssl_error(true); + HttpResponseInfo info; + info.headers = new HttpResponseHeaders(std::string()); + handle->set_ssl_error_response_info(info); + } + } + + private: + // From ConnectJob: + + virtual int ConnectInternal() OVERRIDE { + AddressList ignored; + client_socket_factory_->CreateTransportClientSocket( + ignored, NULL, net::NetLog::Source()); + SetSocket( + scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log()))); + switch (job_type_) { + case kMockJob: + return DoConnect(true /* successful */, false /* sync */, + false /* recoverable */); + case kMockFailingJob: + return DoConnect(false /* error */, false /* sync */, + false /* recoverable */); + case kMockPendingJob: + set_load_state(LOAD_STATE_CONNECTING); + + // Depending on execution timings, posting a delayed task can result + // in the task getting executed the at the earliest possible + // opportunity or only after returning once from the message loop and + // then a second call into the message loop. In order to make behavior + // more deterministic, we change the default delay to 2ms. This should + // always require us to wait for the second call into the message loop. + // + // N.B. The correct fix for this and similar timing problems is to + // abstract time for the purpose of unittests. Unfortunately, we have + // a lot of third-party components that directly call the various + // time functions, so this change would be rather invasive. + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect), + weak_factory_.GetWeakPtr(), + true /* successful */, + true /* async */, + false /* recoverable */), + base::TimeDelta::FromMilliseconds(kPendingConnectDelay)); + return ERR_IO_PENDING; + case kMockPendingFailingJob: + set_load_state(LOAD_STATE_CONNECTING); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect), + weak_factory_.GetWeakPtr(), + false /* error */, + true /* async */, + false /* recoverable */), + base::TimeDelta::FromMilliseconds(2)); + return ERR_IO_PENDING; + case kMockWaitingJob: + set_load_state(LOAD_STATE_CONNECTING); + client_socket_factory_->WaitForSignal(this); + waiting_success_ = true; + return ERR_IO_PENDING; + case kMockRecoverableJob: + return DoConnect(false /* error */, false /* sync */, + true /* recoverable */); + case kMockPendingRecoverableJob: + set_load_state(LOAD_STATE_CONNECTING); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect), + weak_factory_.GetWeakPtr(), + false /* error */, + true /* async */, + true /* recoverable */), + base::TimeDelta::FromMilliseconds(2)); + return ERR_IO_PENDING; + case kMockAdditionalErrorStateJob: + store_additional_error_state_ = true; + return DoConnect(false /* error */, false /* sync */, + false /* recoverable */); + case kMockPendingAdditionalErrorStateJob: + set_load_state(LOAD_STATE_CONNECTING); + store_additional_error_state_ = true; + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect), + weak_factory_.GetWeakPtr(), + false /* error */, + true /* async */, + false /* recoverable */), + base::TimeDelta::FromMilliseconds(2)); + return ERR_IO_PENDING; + default: + NOTREACHED(); + SetSocket(scoped_ptr<StreamSocket>()); + return ERR_FAILED; + } + } + + int DoConnect(bool succeed, bool was_async, bool recoverable) { + int result = OK; + if (succeed) { + socket()->Connect(CompletionCallback()); + } else if (recoverable) { + result = ERR_PROXY_AUTH_REQUESTED; + } else { + result = ERR_CONNECTION_FAILED; + SetSocket(scoped_ptr<StreamSocket>()); + } + + if (was_async) + NotifyDelegateOfCompletion(result); + return result; + } + + bool waiting_success_; + const JobType job_type_; + MockClientSocketFactory* const client_socket_factory_; + base::WeakPtrFactory<TestConnectJob> weak_factory_; + LoadState load_state_; + bool store_additional_error_state_; + + DISALLOW_COPY_AND_ASSIGN(TestConnectJob); +}; + +class TestConnectJobFactory + : public TestClientSocketPoolBase::ConnectJobFactory { + public: + TestConnectJobFactory(MockClientSocketFactory* client_socket_factory, + NetLog* net_log) + : job_type_(TestConnectJob::kMockJob), + job_types_(NULL), + client_socket_factory_(client_socket_factory), + net_log_(net_log) { + } + + virtual ~TestConnectJobFactory() {} + + void set_job_type(TestConnectJob::JobType job_type) { job_type_ = job_type; } + + void set_job_types(std::list<TestConnectJob::JobType>* job_types) { + job_types_ = job_types; + CHECK(!job_types_->empty()); + } + + void set_timeout_duration(base::TimeDelta timeout_duration) { + timeout_duration_ = timeout_duration; + } + + // ConnectJobFactory implementation. + + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const TestClientSocketPoolBase::Request& request, + ConnectJob::Delegate* delegate) const OVERRIDE { + EXPECT_TRUE(!job_types_ || !job_types_->empty()); + TestConnectJob::JobType job_type = job_type_; + if (job_types_ && !job_types_->empty()) { + job_type = job_types_->front(); + job_types_->pop_front(); + } + return scoped_ptr<ConnectJob>(new TestConnectJob(job_type, + group_name, + request, + timeout_duration_, + delegate, + client_socket_factory_, + net_log_)); + } + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { + return timeout_duration_; + } + + private: + TestConnectJob::JobType job_type_; + std::list<TestConnectJob::JobType>* job_types_; + base::TimeDelta timeout_duration_; + MockClientSocketFactory* const client_socket_factory_; + NetLog* net_log_; + + DISALLOW_COPY_AND_ASSIGN(TestConnectJobFactory); +}; + +class TestClientSocketPool : public ClientSocketPool { + public: + TestClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + base::TimeDelta unused_idle_socket_timeout, + base::TimeDelta used_idle_socket_timeout, + TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory) + : base_(max_sockets, max_sockets_per_group, histograms, + unused_idle_socket_timeout, used_idle_socket_timeout, + connect_job_factory) {} + + virtual ~TestClientSocketPool() {} + + virtual int RequestSocket( + const std::string& group_name, + const void* params, + net::RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE { + const scoped_refptr<TestSocketParams>* casted_socket_params = + static_cast<const scoped_refptr<TestSocketParams>*>(params); + return base_.RequestSocket(group_name, *casted_socket_params, priority, + handle, callback, net_log); + } + + virtual void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) OVERRIDE { + const scoped_refptr<TestSocketParams>* casted_params = + static_cast<const scoped_refptr<TestSocketParams>*>(params); + + base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); + } + + virtual void CancelRequest( + const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE { + base_.CancelRequest(group_name, handle); + } + + virtual void ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE { + base_.ReleaseSocket(group_name, socket.Pass(), id); + } + + virtual void FlushWithError(int error) OVERRIDE { + base_.FlushWithError(error); + } + + virtual bool IsStalled() const OVERRIDE { + return base_.IsStalled(); + } + + virtual void CloseIdleSockets() OVERRIDE { + base_.CloseIdleSockets(); + } + + virtual int IdleSocketCount() const OVERRIDE { + return base_.idle_socket_count(); + } + + virtual int IdleSocketCountInGroup( + const std::string& group_name) const OVERRIDE { + return base_.IdleSocketCountInGroup(group_name); + } + + virtual LoadState GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const OVERRIDE { + return base_.GetLoadState(group_name, handle); + } + + virtual void AddLayeredPool(LayeredPool* pool) OVERRIDE { + base_.AddLayeredPool(pool); + } + + virtual void RemoveLayeredPool(LayeredPool* pool) OVERRIDE { + base_.RemoveLayeredPool(pool); + } + + virtual base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const OVERRIDE { + return base_.GetInfoAsValue(name, type); + } + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { + return base_.ConnectionTimeout(); + } + + virtual ClientSocketPoolHistograms* histograms() const OVERRIDE { + return base_.histograms(); + } + + const TestClientSocketPoolBase* base() const { return &base_; } + + int NumUnassignedConnectJobsInGroup(const std::string& group_name) const { + return base_.NumUnassignedConnectJobsInGroup(group_name); + } + + int NumConnectJobsInGroup(const std::string& group_name) const { + return base_.NumConnectJobsInGroup(group_name); + } + + int NumActiveSocketsInGroup(const std::string& group_name) const { + return base_.NumActiveSocketsInGroup(group_name); + } + + bool HasGroup(const std::string& group_name) const { + return base_.HasGroup(group_name); + } + + void CleanupTimedOutIdleSockets() { base_.CleanupIdleSockets(false); } + + void EnableConnectBackupJobs() { base_.EnableConnectBackupJobs(); } + + bool CloseOneIdleConnectionInLayeredPool() { + return base_.CloseOneIdleConnectionInLayeredPool(); + } + + private: + TestClientSocketPoolBase base_; + + DISALLOW_COPY_AND_ASSIGN(TestClientSocketPool); +}; + +} // namespace + +REGISTER_SOCKET_PARAMS_FOR_POOL(TestClientSocketPool, TestSocketParams); + +namespace { + +void MockClientSocketFactory::SignalJobs() { + for (std::vector<TestConnectJob*>::iterator it = waiting_jobs_.begin(); + it != waiting_jobs_.end(); ++it) { + (*it)->Signal(); + } + waiting_jobs_.clear(); +} + +void MockClientSocketFactory::SignalJob(size_t job) { + ASSERT_LT(job, waiting_jobs_.size()); + waiting_jobs_[job]->Signal(); + waiting_jobs_.erase(waiting_jobs_.begin() + job); +} + +void MockClientSocketFactory::SetJobLoadState(size_t job, + LoadState load_state) { + ASSERT_LT(job, waiting_jobs_.size()); + waiting_jobs_[job]->set_load_state(load_state); +} + +class TestConnectJobDelegate : public ConnectJob::Delegate { + public: + TestConnectJobDelegate() + : have_result_(false), waiting_for_result_(false), result_(OK) {} + virtual ~TestConnectJobDelegate() {} + + virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE { + result_ = result; + scoped_ptr<ConnectJob> owned_job(job); + scoped_ptr<StreamSocket> socket = owned_job->PassSocket(); + // socket.get() should be NULL iff result != OK + EXPECT_EQ(socket == NULL, result != OK); + have_result_ = true; + if (waiting_for_result_) + base::MessageLoop::current()->Quit(); + } + + int WaitForResult() { + DCHECK(!waiting_for_result_); + while (!have_result_) { + waiting_for_result_ = true; + base::MessageLoop::current()->Run(); + waiting_for_result_ = false; + } + have_result_ = false; // auto-reset for next callback + return result_; + } + + private: + bool have_result_; + bool waiting_for_result_; + int result_; +}; + +class ClientSocketPoolBaseTest : public testing::Test { + protected: + ClientSocketPoolBaseTest() + : params_(new TestSocketParams()), + histograms_("ClientSocketPoolTest") { + connect_backup_jobs_enabled_ = + internal::ClientSocketPoolBaseHelper::connect_backup_jobs_enabled(); + internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true); + cleanup_timer_enabled_ = + internal::ClientSocketPoolBaseHelper::cleanup_timer_enabled(); + } + + virtual ~ClientSocketPoolBaseTest() { + internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled( + connect_backup_jobs_enabled_); + internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled( + cleanup_timer_enabled_); + } + + void CreatePool(int max_sockets, int max_sockets_per_group) { + CreatePoolWithIdleTimeouts( + max_sockets, + max_sockets_per_group, + ClientSocketPool::unused_idle_socket_timeout(), + ClientSocketPool::used_idle_socket_timeout()); + } + + void CreatePoolWithIdleTimeouts( + int max_sockets, int max_sockets_per_group, + base::TimeDelta unused_idle_socket_timeout, + base::TimeDelta used_idle_socket_timeout) { + DCHECK(!pool_.get()); + connect_job_factory_ = new TestConnectJobFactory(&client_socket_factory_, + &net_log_); + pool_.reset(new TestClientSocketPool(max_sockets, + max_sockets_per_group, + &histograms_, + unused_idle_socket_timeout, + used_idle_socket_timeout, + connect_job_factory_)); + } + + int StartRequestWithParams( + const std::string& group_name, + RequestPriority priority, + const scoped_refptr<TestSocketParams>& params) { + return test_base_.StartRequestUsingPool< + TestClientSocketPool, TestSocketParams>( + pool_.get(), group_name, priority, params); + } + + int StartRequest(const std::string& group_name, RequestPriority priority) { + return StartRequestWithParams(group_name, priority, params_); + } + + int GetOrderOfRequest(size_t index) const { + return test_base_.GetOrderOfRequest(index); + } + + bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) { + return test_base_.ReleaseOneConnection(keep_alive); + } + + void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) { + test_base_.ReleaseAllConnections(keep_alive); + } + + TestSocketRequest* request(int i) { return test_base_.request(i); } + size_t requests_size() const { return test_base_.requests_size(); } + ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } + size_t completion_count() const { return test_base_.completion_count(); } + + CapturingNetLog net_log_; + bool connect_backup_jobs_enabled_; + bool cleanup_timer_enabled_; + MockClientSocketFactory client_socket_factory_; + TestConnectJobFactory* connect_job_factory_; + scoped_refptr<TestSocketParams> params_; + ClientSocketPoolHistograms histograms_; + scoped_ptr<TestClientSocketPool> pool_; + ClientSocketPoolTest test_base_; +}; + +// Even though a timeout is specified, it doesn't time out on a synchronous +// completion. +TEST_F(ClientSocketPoolBaseTest, ConnectJob_NoTimeoutOnSynchronousCompletion) { + TestConnectJobDelegate delegate; + ClientSocketHandle ignored; + TestClientSocketPoolBase::Request request( + &ignored, CompletionCallback(), kDefaultPriority, + internal::ClientSocketPoolBaseHelper::NORMAL, + false, params_, BoundNetLog()); + scoped_ptr<TestConnectJob> job( + new TestConnectJob(TestConnectJob::kMockJob, + "a", + request, + base::TimeDelta::FromMicroseconds(1), + &delegate, + &client_socket_factory_, + NULL)); + EXPECT_EQ(OK, job->Connect()); +} + +TEST_F(ClientSocketPoolBaseTest, ConnectJob_TimedOut) { + TestConnectJobDelegate delegate; + ClientSocketHandle ignored; + CapturingNetLog log; + + TestClientSocketPoolBase::Request request( + &ignored, CompletionCallback(), kDefaultPriority, + internal::ClientSocketPoolBaseHelper::NORMAL, + false, params_, BoundNetLog()); + // Deleted by TestConnectJobDelegate. + TestConnectJob* job = + new TestConnectJob(TestConnectJob::kMockPendingJob, + "a", + request, + base::TimeDelta::FromMicroseconds(1), + &delegate, + &client_socket_factory_, + &log); + ASSERT_EQ(ERR_IO_PENDING, job->Connect()); + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(1)); + EXPECT_EQ(ERR_TIMED_OUT, delegate.WaitForResult()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + + EXPECT_EQ(6u, entries.size()); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB)); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 1, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT)); + EXPECT_TRUE(LogContainsEvent( + entries, 2, NetLog::TYPE_CONNECT_JOB_SET_SOCKET, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEvent( + entries, 3, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEndEvent( + entries, 4, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT)); + EXPECT_TRUE(LogContainsEndEvent( + entries, 5, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB)); +} + +TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + TestCompletionCallback callback; + ClientSocketHandle handle; + CapturingBoundNetLog log; + TestLoadTimingInfoNotConnected(handle); + + EXPECT_EQ(OK, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + log.bound())); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoConnectedNotReused(handle); + + handle.Reset(); + TestLoadTimingInfoNotConnected(handle); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + + EXPECT_EQ(4u, entries.size()); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKET_POOL)); + EXPECT_TRUE(LogContainsEvent( + entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEvent( + entries, 2, NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEndEvent( + entries, 3, NetLog::TYPE_SOCKET_POOL)); +} + +TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); + CapturingBoundNetLog log; + + ClientSocketHandle handle; + TestCompletionCallback callback; + // Set the additional error state members to ensure that they get cleared. + handle.set_is_ssl_error(true); + HttpResponseInfo info; + info.headers = new HttpResponseHeaders(std::string()); + handle.set_ssl_error_response_info(info); + EXPECT_EQ(ERR_CONNECTION_FAILED, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + log.bound())); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); + EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL); + TestLoadTimingInfoNotConnected(handle); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + + EXPECT_EQ(3u, entries.size()); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKET_POOL)); + EXPECT_TRUE(LogContainsEvent( + entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEndEvent( + entries, 2, NetLog::TYPE_SOCKET_POOL)); +} + +TEST_F(ClientSocketPoolBaseTest, TotalLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + // TODO(eroman): Check that the NetLog contains this event. + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("c", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("d", kDefaultPriority)); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("f", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("g", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + EXPECT_EQ(7, GetOrderOfRequest(7)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(8)); +} + +TEST_F(ClientSocketPoolBaseTest, TotalLimitReachedNewGroup) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + // TODO(eroman): Check that the NetLog contains this event. + + // Reach all limits: max total sockets, and max sockets per group. + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); + + // Now create a new group and verify that we don't starve it. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(6)); +} + +TEST_F(ClientSocketPoolBaseTest, TotalLimitRespectsPriority) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("b", LOWEST)); + EXPECT_EQ(OK, StartRequest("a", MEDIUM)); + EXPECT_EQ(OK, StartRequest("b", HIGHEST)); + EXPECT_EQ(OK, StartRequest("a", LOWEST)); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", HIGHEST)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); + + // First 4 requests don't have to wait, and finish in order. + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + + // Request ("b", HIGHEST) has the highest priority, then ("a", MEDIUM), + // and then ("c", LOWEST). + EXPECT_EQ(7, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + EXPECT_EQ(5, GetOrderOfRequest(7)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(9)); +} + +TEST_F(ClientSocketPoolBaseTest, TotalLimitRespectsGroupLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("a", LOWEST)); + EXPECT_EQ(OK, StartRequest("a", LOW)); + EXPECT_EQ(OK, StartRequest("b", HIGHEST)); + EXPECT_EQ(OK, StartRequest("b", MEDIUM)); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", HIGHEST)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count()); + + // First 4 requests don't have to wait, and finish in order. + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + + // Request ("b", 7) has the highest priority, but we can't make new socket for + // group "b", because it has reached the per-group limit. Then we make + // socket for ("c", 6), because it has higher priority than ("a", 4), + // and we still can't make a socket for group "b". + EXPECT_EQ(5, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + EXPECT_EQ(7, GetOrderOfRequest(7)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(8)); +} + +// Make sure that we count connecting sockets against the total limit. +TEST_F(ClientSocketPoolBaseTest, TotalLimitCountsConnectingSockets) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("c", kDefaultPriority)); + + // Create one asynchronous request. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("d", kDefaultPriority)); + + // We post all of our delayed tasks with a 2ms delay. I.e. they don't + // actually become pending until 2ms after they have been created. In order + // to flush all tasks, we need to wait so that we know there are no + // soon-to-be-pending tasks waiting. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10)); + base::MessageLoop::current()->RunUntilIdle(); + + // The next synchronous request should wait for its turn. + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(6)); +} + +TEST_F(ClientSocketPoolBaseTest, CorrectlyCountStalledGroups) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kDefaultPriority)); + + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); + + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE)); + EXPECT_EQ(kDefaultMaxSockets + 1, client_socket_factory_.allocation_count()); + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE)); + EXPECT_EQ(kDefaultMaxSockets + 2, client_socket_factory_.allocation_count()); + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE)); + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE)); + EXPECT_EQ(kDefaultMaxSockets + 2, client_socket_factory_.allocation_count()); +} + +TEST_F(ClientSocketPoolBaseTest, StallAndThenCancelAndTriggerAvailableSocket) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + ClientSocketHandle handles[4]; + for (size_t i = 0; i < arraysize(handles); ++i) { + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handles[i].Init("b", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + } + + // One will be stalled, cancel all the handles now. + // This should hit the OnAvailableSocketSlot() code where we previously had + // stalled groups, but no longer have any. + for (size_t i = 0; i < arraysize(handles); ++i) + handles[i].Reset(); +} + +TEST_F(ClientSocketPoolBaseTest, CancelStalledSocketAtSocketLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + { + ClientSocketHandle handles[kDefaultMaxSockets]; + TestCompletionCallback callbacks[kDefaultMaxSockets]; + for (int i = 0; i < kDefaultMaxSockets; ++i) { + EXPECT_EQ(OK, handles[i].Init(base::IntToString(i), + params_, + kDefaultPriority, + callbacks[i].callback(), + pool_.get(), + BoundNetLog())); + } + + // Force a stalled group. + ClientSocketHandle stalled_handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + // Cancel the stalled request. + stalled_handle.Reset(); + + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); + EXPECT_EQ(0, pool_->IdleSocketCount()); + + // Dropping out of scope will close all handles and return them to idle. + } + + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); + EXPECT_EQ(kDefaultMaxSockets, pool_->IdleSocketCount()); +} + +TEST_F(ClientSocketPoolBaseTest, CancelPendingSocketAtSocketLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + { + ClientSocketHandle handles[kDefaultMaxSockets]; + for (int i = 0; i < kDefaultMaxSockets; ++i) { + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handles[i].Init(base::IntToString(i), + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + } + + // Force a stalled group. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle stalled_handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + // Since it is stalled, it should have no connect jobs. + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("foo")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("foo")); + + // Cancel the stalled request. + handles[0].Reset(); + + // Now we should have a connect job. + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("foo")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("foo")); + + // The stalled socket should connect. + EXPECT_EQ(OK, callback.WaitForResult()); + + EXPECT_EQ(kDefaultMaxSockets + 1, + client_socket_factory_.allocation_count()); + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("foo")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("foo")); + + // Dropping out of scope will close all handles and return them to idle. + } + + EXPECT_EQ(1, pool_->IdleSocketCount()); +} + +TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + ClientSocketHandle stalled_handle; + TestCompletionCallback callback; + { + EXPECT_FALSE(pool_->IsStalled()); + ClientSocketHandle handles[kDefaultMaxSockets]; + for (int i = 0; i < kDefaultMaxSockets; ++i) { + TestCompletionCallback callback; + EXPECT_EQ(OK, handles[i].Init(base::StringPrintf( + "Take 2: %d", i), + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + } + + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_FALSE(pool_->IsStalled()); + + // Now we will hit the socket limit. + EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_TRUE(pool_->IsStalled()); + + // Dropping out of scope will close all handles and return them to idle. + } + + // But if we wait for it, the released idle sockets will be closed in + // preference of the waiting request. + EXPECT_EQ(OK, callback.WaitForResult()); + + EXPECT_EQ(kDefaultMaxSockets + 1, client_socket_factory_.allocation_count()); + EXPECT_EQ(3, pool_->IdleSocketCount()); +} + +// Regression test for http://crbug.com/40952. +TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + pool_->EnableConnectBackupJobs(); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + for (int i = 0; i < kDefaultMaxSockets; ++i) { + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init(base::IntToString(i), + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + } + + // Flush all the DoReleaseSocket tasks. + base::MessageLoop::current()->RunUntilIdle(); + + // Stall a group. Set a pending job so it'll trigger a backup job if we don't + // reuse a socket. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + + // "0" is special here, since it should be the first entry in the sorted map, + // which is the one which we would close an idle socket for. We shouldn't + // close an idle socket though, since we should reuse the idle socket. + EXPECT_EQ(OK, handle.Init("0", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); + EXPECT_EQ(kDefaultMaxSockets - 1, pool_->IdleSocketCount()); +} + +TEST_F(ClientSocketPoolBaseTest, PendingRequests) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", IDLE)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(kDefaultMaxSocketsPerGroup, + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup, + completion_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(8, GetOrderOfRequest(3)); + EXPECT_EQ(6, GetOrderOfRequest(4)); + EXPECT_EQ(4, GetOrderOfRequest(5)); + EXPECT_EQ(3, GetOrderOfRequest(6)); + EXPECT_EQ(5, GetOrderOfRequest(7)); + EXPECT_EQ(7, GetOrderOfRequest(8)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(9)); +} + +TEST_F(ClientSocketPoolBaseTest, PendingRequests_NoKeepAlive) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + for (size_t i = kDefaultMaxSocketsPerGroup; i < requests_size(); ++i) + EXPECT_EQ(OK, request(i)->WaitForResult()); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup, + completion_count()); +} + +// This test will start up a RequestSocket() and then immediately Cancel() it. +// The pending connect job will be cancelled and should not call back into +// ClientSocketPoolBase. +TEST_F(ClientSocketPoolBaseTest, CancelRequestClearGroup) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + handle.Reset(); +} + +TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + handle.Reset(); + + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_FALSE(callback.have_result()); + + handle.Reset(); +} + +TEST_F(ClientSocketPoolBaseTest, CancelRequest) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + + // Cancel a request. + size_t index_to_cancel = kDefaultMaxSocketsPerGroup + 2; + EXPECT_FALSE((*requests())[index_to_cancel]->handle()->is_initialized()); + (*requests())[index_to_cancel]->handle()->Reset(); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(kDefaultMaxSocketsPerGroup, + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup - 1, + completion_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(5, GetOrderOfRequest(3)); + EXPECT_EQ(3, GetOrderOfRequest(4)); + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, + GetOrderOfRequest(5)); // Canceled request. + EXPECT_EQ(4, GetOrderOfRequest(6)); + EXPECT_EQ(6, GetOrderOfRequest(7)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(8)); +} + +class RequestSocketCallback : public TestCompletionCallbackBase { + public: + RequestSocketCallback(ClientSocketHandle* handle, + TestClientSocketPool* pool, + TestConnectJobFactory* test_connect_job_factory, + TestConnectJob::JobType next_job_type) + : handle_(handle), + pool_(pool), + within_callback_(false), + test_connect_job_factory_(test_connect_job_factory), + next_job_type_(next_job_type), + callback_(base::Bind(&RequestSocketCallback::OnComplete, + base::Unretained(this))) { + } + + virtual ~RequestSocketCallback() {} + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + SetResult(result); + ASSERT_EQ(OK, result); + + if (!within_callback_) { + test_connect_job_factory_->set_job_type(next_job_type_); + + // Don't allow reuse of the socket. Disconnect it and then release it and + // run through the MessageLoop once to get it completely released. + handle_->socket()->Disconnect(); + handle_->Reset(); + { + // TODO: Resolve conflicting intentions of stopping recursion with the + // |!within_callback_| test (above) and the call to |RunUntilIdle()| + // below. http://crbug.com/114130. + base::MessageLoop::ScopedNestableTaskAllower allow( + base::MessageLoop::current()); + base::MessageLoop::current()->RunUntilIdle(); + } + within_callback_ = true; + TestCompletionCallback next_job_callback; + scoped_refptr<TestSocketParams> params(new TestSocketParams()); + int rv = handle_->Init("a", + params, + kDefaultPriority, + next_job_callback.callback(), + pool_, + BoundNetLog()); + switch (next_job_type_) { + case TestConnectJob::kMockJob: + EXPECT_EQ(OK, rv); + break; + case TestConnectJob::kMockPendingJob: + EXPECT_EQ(ERR_IO_PENDING, rv); + + // For pending jobs, wait for new socket to be created. This makes + // sure there are no more pending operations nor any unclosed sockets + // when the test finishes. + // We need to give it a little bit of time to run, so that all the + // operations that happen on timers (e.g. cleanup of idle + // connections) can execute. + { + base::MessageLoop::ScopedNestableTaskAllower allow( + base::MessageLoop::current()); + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10)); + EXPECT_EQ(OK, next_job_callback.WaitForResult()); + } + break; + default: + FAIL() << "Unexpected job type: " << next_job_type_; + break; + } + } + } + + ClientSocketHandle* const handle_; + TestClientSocketPool* const pool_; + bool within_callback_; + TestConnectJobFactory* const test_connect_job_factory_; + TestConnectJob::JobType next_job_type_; + CompletionCallback callback_; +}; + +TEST_F(ClientSocketPoolBaseTest, RequestPendingJobTwice) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + RequestSocketCallback callback( + &handle, pool_.get(), connect_job_factory_, + TestConnectJob::kMockPendingJob); + int rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(OK, callback.WaitForResult()); +} + +TEST_F(ClientSocketPoolBaseTest, RequestPendingJobThenSynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + RequestSocketCallback callback( + &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockJob); + int rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(OK, callback.WaitForResult()); +} + +// Make sure that pending requests get serviced after active requests get +// cancelled. +TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestWithPendingRequests) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Now, kDefaultMaxSocketsPerGroup requests should be active. + // Let's cancel them. + for (int i = 0; i < kDefaultMaxSocketsPerGroup; ++i) { + ASSERT_FALSE(request(i)->handle()->is_initialized()); + request(i)->handle()->Reset(); + } + + // Let's wait for the rest to complete now. + for (size_t i = kDefaultMaxSocketsPerGroup; i < requests_size(); ++i) { + EXPECT_EQ(OK, request(i)->WaitForResult()); + request(i)->handle()->Reset(); + } + + EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup, + completion_count()); +} + +// Make sure that pending requests get serviced after active requests fail. +TEST_F(ClientSocketPoolBaseTest, FailingActiveRequestWithPendingRequests) { + const size_t kMaxSockets = 5; + CreatePool(kMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); + + const size_t kNumberOfRequests = 2 * kDefaultMaxSocketsPerGroup + 1; + ASSERT_LE(kNumberOfRequests, kMaxSockets); // Otherwise the test will hang. + + // Queue up all the requests + for (size_t i = 0; i < kNumberOfRequests; ++i) + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + for (size_t i = 0; i < kNumberOfRequests; ++i) + EXPECT_EQ(ERR_CONNECTION_FAILED, request(i)->WaitForResult()); +} + +TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestThenRequestSocket) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Cancel the active request. + handle.Reset(); + + rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + + EXPECT_FALSE(handle.is_reused()); + TestLoadTimingInfoConnectedNotReused(handle); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +// Regression test for http://crbug.com/17985. +TEST_F(ClientSocketPoolBaseTest, GroupWithPendingRequestsIsNotEmpty) { + const int kMaxSockets = 3; + const int kMaxSocketsPerGroup = 2; + CreatePool(kMaxSockets, kMaxSocketsPerGroup); + + const RequestPriority kHighPriority = HIGHEST; + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + + // This is going to be a pending request in an otherwise empty group. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Reach the maximum socket limit. + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + + // Create a stalled group with high priorities. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kHighPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kHighPriority)); + + // Release the first two sockets from "a". Because this is a keepalive, + // the first release will unblock the pending request for "a". The + // second release will unblock a request for "c", becaue it is the next + // high priority socket. + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE)); + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE)); + + // Closing idle sockets should not get us into trouble, but in the bug + // we were hitting a CHECK here. + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + pool_->CloseIdleSockets(); + + // Run the released socket wakeups. + base::MessageLoop::current()->RunUntilIdle(); +} + +TEST_F(ClientSocketPoolBaseTest, BasicAsynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + CapturingBoundNetLog log; + int rv = handle.Init("a", + params_, + LOWEST, + callback.callback(), + pool_.get(), + log.bound()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + TestLoadTimingInfoNotConnected(handle); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoConnectedNotReused(handle); + + handle.Reset(); + TestLoadTimingInfoNotConnected(handle); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + + EXPECT_EQ(4u, entries.size()); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKET_POOL)); + EXPECT_TRUE(LogContainsEvent( + entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEvent( + entries, 2, NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEndEvent( + entries, 3, NetLog::TYPE_SOCKET_POOL)); +} + +TEST_F(ClientSocketPoolBaseTest, + InitConnectionAsynchronousFailure) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + CapturingBoundNetLog log; + // Set the additional error state members to ensure that they get cleared. + handle.set_is_ssl_error(true); + HttpResponseInfo info; + info.headers = new HttpResponseHeaders(std::string()); + handle.set_ssl_error_response_info(info); + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + log.bound())); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_ssl_error()); + EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + + EXPECT_EQ(3u, entries.size()); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKET_POOL)); + EXPECT_TRUE(LogContainsEvent( + entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + NetLog::PHASE_NONE)); + EXPECT_TRUE(LogContainsEndEvent( + entries, 2, NetLog::TYPE_SOCKET_POOL)); +} + +TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { + // TODO(eroman): Add back the log expectations! Removed them because the + // ordering is difficult, and some may fire during destructor. + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + ClientSocketHandle handle2; + TestCompletionCallback callback2; + + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + CapturingBoundNetLog log2; + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + handle.Reset(); + + + // At this point, request 2 is just waiting for the connect job to finish. + + EXPECT_EQ(OK, callback2.WaitForResult()); + handle2.Reset(); + + // Now request 2 has actually finished. + // TODO(eroman): Add back log expectations. +} + +TEST_F(ClientSocketPoolBaseTest, CancelRequestLimitsJobs) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + + EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a")); + (*requests())[2]->handle()->Reset(); + (*requests())[3]->handle()->Reset(); + EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a")); + + (*requests())[1]->handle()->Reset(); + EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a")); + + (*requests())[0]->handle()->Reset(); + EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a")); +} + +// When requests and ConnectJobs are not coupled, the request will get serviced +// by whatever comes first. +TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + // Start job 1 (async OK) + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + std::vector<TestSocketRequest*> request_order; + size_t completion_count; // unused + TestSocketRequest req1(&request_order, &completion_count); + int rv = req1.handle()->Init("a", + params_, + kDefaultPriority, + req1.callback(), pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, req1.WaitForResult()); + + // Job 1 finished OK. Start job 2 (also async OK). Request 3 is pending + // without a job. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + TestSocketRequest req2(&request_order, &completion_count); + rv = req2.handle()->Init("a", + params_, + kDefaultPriority, + req2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + TestSocketRequest req3(&request_order, &completion_count); + rv = req3.handle()->Init("a", + params_, + kDefaultPriority, + req3.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Both Requests 2 and 3 are pending. We release socket 1 which should + // service request 2. Request 3 should still be waiting. + req1.handle()->Reset(); + // Run the released socket wakeups. + base::MessageLoop::current()->RunUntilIdle(); + ASSERT_TRUE(req2.handle()->socket()); + EXPECT_EQ(OK, req2.WaitForResult()); + EXPECT_FALSE(req3.handle()->socket()); + + // Signal job 2, which should service request 3. + + client_socket_factory_.SignalJobs(); + EXPECT_EQ(OK, req3.WaitForResult()); + + ASSERT_EQ(3U, request_order.size()); + EXPECT_EQ(&req1, request_order[0]); + EXPECT_EQ(&req2, request_order[1]); + EXPECT_EQ(&req3, request_order[2]); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); +} + +// The requests are not coupled to the jobs. So, the requests should finish in +// their priority / insertion order. +TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + // First two jobs are async. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); + + std::vector<TestSocketRequest*> request_order; + size_t completion_count; // unused + TestSocketRequest req1(&request_order, &completion_count); + int rv = req1.handle()->Init("a", + params_, + kDefaultPriority, + req1.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + TestSocketRequest req2(&request_order, &completion_count); + rv = req2.handle()->Init("a", + params_, + kDefaultPriority, + req2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // The pending job is sync. + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + TestSocketRequest req3(&request_order, &completion_count); + rv = req3.handle()->Init("a", + params_, + kDefaultPriority, + req3.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_CONNECTION_FAILED, req1.WaitForResult()); + EXPECT_EQ(OK, req2.WaitForResult()); + EXPECT_EQ(ERR_CONNECTION_FAILED, req3.WaitForResult()); + + ASSERT_EQ(3U, request_order.size()); + EXPECT_EQ(&req1, request_order[0]); + EXPECT_EQ(&req2, request_order[1]); + EXPECT_EQ(&req3, request_order[2]); +} + +// Test GetLoadState in the case there's only one socket request. +TEST_F(ClientSocketPoolBaseTest, LoadStateOneRequest) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); + + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE); + EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState()); + + // No point in completing the connection, since ClientSocketHandles only + // expect the LoadState to be checked while connecting. +} + +// Test GetLoadState in the case there are two socket requests. +TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequests) { + CreatePool(2, 2); + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // If the first Job is in an earlier state than the second, the state of + // the second job should be used for both handles. + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_RESOLVING_HOST); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); + + // If the second Job is in an earlier state than the second, the state of + // the first job should be used for both handles. + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE); + // One request is farther + EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle2.GetLoadState()); + + // Farthest along job connects and the first request gets the socket. The + // second handle switches to the state of the remaining ConnectJob. + client_socket_factory_.SignalJob(0); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); +} + +// Test GetLoadState in the case the per-group limit is reached. +TEST_F(ClientSocketPoolBaseTest, LoadStateGroupLimit) { + CreatePool(2, 1); + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + MEDIUM, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); + + // Request another socket from the same pool, buth with a higher priority. + // The first request should now be stalled at the socket group limit. + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("a", + params_, + HIGHEST, + callback2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); + + // The first handle should remain stalled as the other socket goes through + // the connect process. + + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle2.GetLoadState()); + + client_socket_factory_.SignalJob(0); + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState()); + + // Closing the second socket should cause the stalled handle to finally get a + // ConnectJob. + handle2.socket()->Disconnect(); + handle2.Reset(); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); +} + +// Test GetLoadState in the case the per-pool limit is reached. +TEST_F(ClientSocketPoolBaseTest, LoadStatePoolLimit) { + CreatePool(2, 2); + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Request for socket from another pool. + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("b", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Request another socket from the first pool. Request should stall at the + // socket pool limit. + ClientSocketHandle handle3; + TestCompletionCallback callback3; + rv = handle3.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // The third handle should remain stalled as the other sockets in its group + // goes through the connect process. + + EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL, handle3.GetLoadState()); + + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE); + EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL, handle3.GetLoadState()); + + client_socket_factory_.SignalJob(0); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL, handle3.GetLoadState()); + + // Closing a socket should allow the stalled handle to finally get a new + // ConnectJob. + handle.socket()->Disconnect(); + handle.Reset(); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle3.GetLoadState()); +} + +TEST_F(ClientSocketPoolBaseTest, Recoverable) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockRecoverableJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, + handle.Init("a", params_, kDefaultPriority, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(ClientSocketPoolBaseTest, AsyncRecoverable) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockPendingRecoverableJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); +} + +TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateSynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type( + TestConnectJob::kMockAdditionalErrorStateJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_CONNECTION_FAILED, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_TRUE(handle.is_ssl_error()); + EXPECT_FALSE(handle.ssl_error_response_info().headers.get() == NULL); +} + +TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateAsynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockPendingAdditionalErrorStateJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_TRUE(handle.is_ssl_error()); + EXPECT_FALSE(handle.ssl_error_response_info().headers.get() == NULL); +} + +// Make sure we can reuse sockets when the cleanup timer is disabled. +TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerReuse) { + // Disable cleanup timer. + internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(false); + + CreatePoolWithIdleTimeouts( + kDefaultMaxSockets, kDefaultMaxSocketsPerGroup, + base::TimeDelta(), // Time out unused sockets immediately. + base::TimeDelta::FromDays(1)); // Don't time out used sockets. + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + LOWEST, + callback.callback(), + pool_.get(), + BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + ASSERT_EQ(OK, callback.WaitForResult()); + + // Use and release the socket. + EXPECT_EQ(1, handle.socket()->Write(NULL, 1, CompletionCallback())); + TestLoadTimingInfoConnectedNotReused(handle); + handle.Reset(); + + // Should now have one idle socket. + ASSERT_EQ(1, pool_->IdleSocketCount()); + + // Request a new socket. This should reuse the old socket and complete + // synchronously. + CapturingBoundNetLog log; + rv = handle.Init("a", + params_, + LOWEST, + CompletionCallback(), + pool_.get(), + log.bound()); + ASSERT_EQ(OK, rv); + EXPECT_TRUE(handle.is_reused()); + TestLoadTimingInfoConnectedReused(handle); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEntryWithType( + entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET)); +} + +// Make sure we cleanup old unused sockets when the cleanup timer is disabled. +TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerNoReuse) { + // Disable cleanup timer. + internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(false); + + CreatePoolWithIdleTimeouts( + kDefaultMaxSockets, kDefaultMaxSocketsPerGroup, + base::TimeDelta(), // Time out unused sockets immediately + base::TimeDelta()); // Time out used sockets immediately + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + // Startup two mock pending connect jobs, which will sit in the MessageLoop. + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + LOWEST, + callback.callback(), + pool_.get(), + BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("a", + params_, + LOWEST, + callback2.callback(), + pool_.get(), + BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle2)); + + // Cancel one of the requests. Wait for the other, which will get the first + // job. Release the socket. Run the loop again to make sure the second + // socket is sitting idle and the first one is released (since ReleaseSocket() + // just posts a DoReleaseSocket() task). + + handle.Reset(); + ASSERT_EQ(OK, callback2.WaitForResult()); + // Use the socket. + EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, CompletionCallback())); + handle2.Reset(); + + // We post all of our delayed tasks with a 2ms delay. I.e. they don't + // actually become pending until 2ms after they have been created. In order + // to flush all tasks, we need to wait so that we know there are no + // soon-to-be-pending tasks waiting. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10)); + base::MessageLoop::current()->RunUntilIdle(); + + // Both sockets should now be idle. + ASSERT_EQ(2, pool_->IdleSocketCount()); + + // Request a new socket. This should cleanup the unused and timed out ones. + // A new socket will be created rather than reusing the idle one. + CapturingBoundNetLog log; + TestCompletionCallback callback3; + rv = handle.Init("a", + params_, + LOWEST, + callback3.callback(), + pool_.get(), + log.bound()); + ASSERT_EQ(ERR_IO_PENDING, rv); + ASSERT_EQ(OK, callback3.WaitForResult()); + EXPECT_FALSE(handle.is_reused()); + + // Make sure the idle socket is closed. + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_FALSE(LogContainsEntryWithType( + entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET)); +} + +TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { + CreatePoolWithIdleTimeouts( + kDefaultMaxSockets, kDefaultMaxSocketsPerGroup, + base::TimeDelta(), // Time out unused sockets immediately. + base::TimeDelta::FromDays(1)); // Don't time out used sockets. + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + // Startup two mock pending connect jobs, which will sit in the MessageLoop. + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + LOWEST, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("a", + params_, + LOWEST, + callback2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle2)); + + // Cancel one of the requests. Wait for the other, which will get the first + // job. Release the socket. Run the loop again to make sure the second + // socket is sitting idle and the first one is released (since ReleaseSocket() + // just posts a DoReleaseSocket() task). + + handle.Reset(); + EXPECT_EQ(OK, callback2.WaitForResult()); + // Use the socket. + EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, CompletionCallback())); + handle2.Reset(); + + // We post all of our delayed tasks with a 2ms delay. I.e. they don't + // actually become pending until 2ms after they have been created. In order + // to flush all tasks, we need to wait so that we know there are no + // soon-to-be-pending tasks waiting. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10)); + base::MessageLoop::current()->RunUntilIdle(); + + ASSERT_EQ(2, pool_->IdleSocketCount()); + + // Invoke the idle socket cleanup check. Only one socket should be left, the + // used socket. Request it to make sure that it's used. + + pool_->CleanupTimedOutIdleSockets(); + CapturingBoundNetLog log; + rv = handle.Init("a", + params_, + LOWEST, + callback.callback(), + pool_.get(), + log.bound()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_reused()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEntryWithType( + entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET)); +} + +// Make sure that we process all pending requests even when we're stalling +// because of multiple releasing disconnected sockets. +TEST_F(ClientSocketPoolBaseTest, MultipleReleasingDisconnectedSockets) { + CreatePoolWithIdleTimeouts( + kDefaultMaxSockets, kDefaultMaxSocketsPerGroup, + base::TimeDelta(), // Time out unused sockets immediately. + base::TimeDelta::FromDays(1)); // Don't time out used sockets. + + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + // Startup 4 connect jobs. Two of them will be pending. + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", + params_, + LOWEST, + callback.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(OK, rv); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("a", + params_, + LOWEST, + callback2.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(OK, rv); + + ClientSocketHandle handle3; + TestCompletionCallback callback3; + rv = handle3.Init("a", + params_, + LOWEST, + callback3.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + ClientSocketHandle handle4; + TestCompletionCallback callback4; + rv = handle4.Init("a", + params_, + LOWEST, + callback4.callback(), + pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Release two disconnected sockets. + + handle.socket()->Disconnect(); + handle.Reset(); + handle2.socket()->Disconnect(); + handle2.Reset(); + + EXPECT_EQ(OK, callback3.WaitForResult()); + EXPECT_FALSE(handle3.is_reused()); + EXPECT_EQ(OK, callback4.WaitForResult()); + EXPECT_FALSE(handle4.is_reused()); +} + +// Regression test for http://crbug.com/42267. +// When DoReleaseSocket() is processed for one socket, it is blocked because the +// other stalled groups all have releasing sockets, so no progress can be made. +TEST_F(ClientSocketPoolBaseTest, SocketLimitReleasingSockets) { + CreatePoolWithIdleTimeouts( + 4 /* socket limit */, 4 /* socket limit per group */, + base::TimeDelta(), // Time out unused sockets immediately. + base::TimeDelta::FromDays(1)); // Don't time out used sockets. + + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + // Max out the socket limit with 2 per group. + + ClientSocketHandle handle_a[4]; + TestCompletionCallback callback_a[4]; + ClientSocketHandle handle_b[4]; + TestCompletionCallback callback_b[4]; + + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(OK, handle_a[i].Init("a", + params_, + LOWEST, + callback_a[i].callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, handle_b[i].Init("b", + params_, + LOWEST, + callback_b[i].callback(), + pool_.get(), + BoundNetLog())); + } + + // Make 4 pending requests, 2 per group. + + for (int i = 2; i < 4; ++i) { + EXPECT_EQ(ERR_IO_PENDING, + handle_a[i].Init("a", + params_, + LOWEST, + callback_a[i].callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle_b[i].Init("b", + params_, + LOWEST, + callback_b[i].callback(), + pool_.get(), + BoundNetLog())); + } + + // Release b's socket first. The order is important, because in + // DoReleaseSocket(), we'll process b's released socket, and since both b and + // a are stalled, but 'a' is lower lexicographically, we'll process group 'a' + // first, which has a releasing socket, so it refuses to start up another + // ConnectJob. So, we used to infinite loop on this. + handle_b[0].socket()->Disconnect(); + handle_b[0].Reset(); + handle_a[0].socket()->Disconnect(); + handle_a[0].Reset(); + + // Used to get stuck here. + base::MessageLoop::current()->RunUntilIdle(); + + handle_b[1].socket()->Disconnect(); + handle_b[1].Reset(); + handle_a[1].socket()->Disconnect(); + handle_a[1].Reset(); + + for (int i = 2; i < 4; ++i) { + EXPECT_EQ(OK, callback_b[i].WaitForResult()); + EXPECT_EQ(OK, callback_a[i].WaitForResult()); + } +} + +TEST_F(ClientSocketPoolBaseTest, + ReleasingDisconnectedSocketsMaintainsPriorityOrder) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + EXPECT_EQ(OK, (*requests())[0]->WaitForResult()); + EXPECT_EQ(OK, (*requests())[1]->WaitForResult()); + EXPECT_EQ(2u, completion_count()); + + // Releases one connection. + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE)); + EXPECT_EQ(OK, (*requests())[2]->WaitForResult()); + + EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE)); + EXPECT_EQ(OK, (*requests())[3]->WaitForResult()); + EXPECT_EQ(4u, completion_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(5)); +} + +class TestReleasingSocketRequest : public TestCompletionCallbackBase { + public: + TestReleasingSocketRequest(TestClientSocketPool* pool, + int expected_result, + bool reset_releasing_handle) + : pool_(pool), + expected_result_(expected_result), + reset_releasing_handle_(reset_releasing_handle), + callback_(base::Bind(&TestReleasingSocketRequest::OnComplete, + base::Unretained(this))) { + } + + virtual ~TestReleasingSocketRequest() {} + + ClientSocketHandle* handle() { return &handle_; } + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + SetResult(result); + if (reset_releasing_handle_) + handle_.Reset(); + + scoped_refptr<TestSocketParams> con_params(new TestSocketParams()); + EXPECT_EQ(expected_result_, + handle2_.Init("a", con_params, kDefaultPriority, + callback2_.callback(), pool_, BoundNetLog())); + } + + TestClientSocketPool* const pool_; + int expected_result_; + bool reset_releasing_handle_; + ClientSocketHandle handle_; + ClientSocketHandle handle2_; + CompletionCallback callback_; + TestCompletionCallback callback2_; +}; + + +TEST_F(ClientSocketPoolBaseTest, AdditionalErrorSocketsDontUseSlot) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("b", kDefaultPriority)); + + EXPECT_EQ(static_cast<int>(requests_size()), + client_socket_factory_.allocation_count()); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockPendingAdditionalErrorStateJob); + TestReleasingSocketRequest req(pool_.get(), OK, false); + EXPECT_EQ(ERR_IO_PENDING, + req.handle()->Init("a", params_, kDefaultPriority, req.callback(), + pool_.get(), BoundNetLog())); + // The next job should complete synchronously + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult()); + EXPECT_FALSE(req.handle()->is_initialized()); + EXPECT_FALSE(req.handle()->socket()); + EXPECT_TRUE(req.handle()->is_ssl_error()); + EXPECT_FALSE(req.handle()->ssl_error_response_info().headers.get() == NULL); +} + +// http://crbug.com/44724 regression test. +// We start releasing the pool when we flush on network change. When that +// happens, the only active references are in the ClientSocketHandles. When a +// ConnectJob completes and calls back into the last ClientSocketHandle, that +// callback can release the last reference and delete the pool. After the +// callback finishes, we go back to the stack frame within the now-deleted pool. +// Executing any code that refers to members of the now-deleted pool can cause +// crashes. +TEST_F(ClientSocketPoolBaseTest, CallbackThatReleasesPool) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + pool_->FlushWithError(ERR_NETWORK_CHANGED); + + // We'll call back into this now. + callback.WaitForResult(); +} + +TEST_F(ClientSocketPoolBaseTest, DoNotReuseSocketAfterFlush) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_EQ(ClientSocketHandle::UNUSED, handle.reuse_type()); + + pool_->FlushWithError(ERR_NETWORK_CHANGED); + + handle.Reset(); + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_EQ(ClientSocketHandle::UNUSED, handle.reuse_type()); +} + +class ConnectWithinCallback : public TestCompletionCallbackBase { + public: + ConnectWithinCallback( + const std::string& group_name, + const scoped_refptr<TestSocketParams>& params, + TestClientSocketPool* pool) + : group_name_(group_name), + params_(params), + pool_(pool), + callback_(base::Bind(&ConnectWithinCallback::OnComplete, + base::Unretained(this))) { + } + + virtual ~ConnectWithinCallback() {} + + int WaitForNestedResult() { + return nested_callback_.WaitForResult(); + } + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + SetResult(result); + EXPECT_EQ(ERR_IO_PENDING, + handle_.Init(group_name_, + params_, + kDefaultPriority, + nested_callback_.callback(), + pool_, + BoundNetLog())); + } + + const std::string group_name_; + const scoped_refptr<TestSocketParams> params_; + TestClientSocketPool* const pool_; + ClientSocketHandle handle_; + CompletionCallback callback_; + TestCompletionCallback nested_callback_; + + DISALLOW_COPY_AND_ASSIGN(ConnectWithinCallback); +}; + +TEST_F(ClientSocketPoolBaseTest, AbortAllRequestsOnFlush) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + + // First job will be waiting until it gets aborted. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + + ClientSocketHandle handle; + ConnectWithinCallback callback("a", params_, pool_.get()); + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + // Second job will be started during the first callback, and will + // asynchronously complete with OK. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + pool_->FlushWithError(ERR_NETWORK_CHANGED); + EXPECT_EQ(ERR_NETWORK_CHANGED, callback.WaitForResult()); + EXPECT_EQ(OK, callback.WaitForNestedResult()); +} + +// Cancel a pending socket request while we're at max sockets, +// and verify that the backup socket firing doesn't cause a crash. +TEST_F(ClientSocketPoolBaseTest, BackupSocketCancelAtMaxSockets) { + // Max 4 sockets globally, max 4 sockets per group. + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + pool_->EnableConnectBackupJobs(); + + // Create the first socket and set to ERR_IO_PENDING. This starts the backup + // timer. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + // Start (MaxSockets - 1) connected sockets to reach max sockets. + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + ClientSocketHandle handles[kDefaultMaxSockets]; + for (int i = 1; i < kDefaultMaxSockets; ++i) { + TestCompletionCallback callback; + EXPECT_EQ(OK, handles[i].Init("bar", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + } + + base::MessageLoop::current()->RunUntilIdle(); + + // Cancel the pending request. + handle.Reset(); + + // Wait for the backup timer to fire (add some slop to ensure it fires) + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs / 2 * 3)); + + base::MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); +} + +TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterCancelingAllRequests) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + pool_->EnableConnectBackupJobs(); + + // Create the first socket and set to ERR_IO_PENDING. This starts the backup + // timer. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + ASSERT_TRUE(pool_->HasGroup("bar")); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("bar")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("bar")); + + // Cancel the socket request. This should cancel the backup timer. Wait for + // the backup time to see if it indeed got canceled. + handle.Reset(); + // Wait for the backup timer to fire (add some slop to ensure it fires) + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs / 2 * 3)); + base::MessageLoop::current()->RunUntilIdle(); + ASSERT_TRUE(pool_->HasGroup("bar")); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("bar")); +} + +TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterFinishingAllRequests) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + pool_->EnableConnectBackupJobs(); + + // Create the first socket and set to ERR_IO_PENDING. This starts the backup + // timer. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("bar", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + ASSERT_TRUE(pool_->HasGroup("bar")); + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("bar")); + + // Cancel request 1 and then complete request 2. With the requests finished, + // the backup timer should be cancelled. + handle.Reset(); + EXPECT_EQ(OK, callback2.WaitForResult()); + // Wait for the backup timer to fire (add some slop to ensure it fires) + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs / 2 * 3)); + base::MessageLoop::current()->RunUntilIdle(); +} + +// Test delayed socket binding for the case where we have two connects, +// and while one is waiting on a connect, the other frees up. +// The socket waiting on a connect should switch immediately to the freed +// up socket. +TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingWaitingForConnect) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); + + // No idle sockets, no pending jobs. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + // Create a second socket to the same host, but this one will wait. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + ClientSocketHandle handle2; + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + // No idle sockets, and one connecting job. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Return the first handle to the pool. This will initiate the delayed + // binding. + handle1.Reset(); + + base::MessageLoop::current()->RunUntilIdle(); + + // Still no idle sockets, still one pending connect job. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // The second socket connected, even though it was a Waiting Job. + EXPECT_EQ(OK, callback.WaitForResult()); + + // And we can see there is still one job waiting. + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Finally, signal the waiting Connect. + client_socket_factory_.SignalJobs(); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + base::MessageLoop::current()->RunUntilIdle(); +} + +// Test delayed socket binding when a group is at capacity and one +// of the group's sockets frees up. +TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtGroupCapacity) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); + + // No idle sockets, no pending jobs. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + // Create a second socket to the same host, but this one will wait. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + ClientSocketHandle handle2; + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + // No idle sockets, and one connecting job. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Return the first handle to the pool. This will initiate the delayed + // binding. + handle1.Reset(); + + base::MessageLoop::current()->RunUntilIdle(); + + // Still no idle sockets, still one pending connect job. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // The second socket connected, even though it was a Waiting Job. + EXPECT_EQ(OK, callback.WaitForResult()); + + // And we can see there is still one job waiting. + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Finally, signal the waiting Connect. + client_socket_factory_.SignalJobs(); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + base::MessageLoop::current()->RunUntilIdle(); +} + +// Test out the case where we have one socket connected, one +// connecting, when the first socket finishes and goes idle. +// Although the second connection is pending, the second request +// should complete, by taking the first socket's idle socket. +TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtStall) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); + + // No idle sockets, no pending jobs. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + // Create a second socket to the same host, but this one will wait. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + ClientSocketHandle handle2; + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + // No idle sockets, and one connecting job. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Return the first handle to the pool. This will initiate the delayed + // binding. + handle1.Reset(); + + base::MessageLoop::current()->RunUntilIdle(); + + // Still no idle sockets, still one pending connect job. + EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // The second socket connected, even though it was a Waiting Job. + EXPECT_EQ(OK, callback.WaitForResult()); + + // And we can see there is still one job waiting. + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Finally, signal the waiting Connect. + client_socket_factory_.SignalJobs(); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + base::MessageLoop::current()->RunUntilIdle(); +} + +// Cover the case where on an available socket slot, we have one pending +// request that completes synchronously, thereby making the Group empty. +TEST_F(ClientSocketPoolBaseTest, SynchronouslyProcessOnePendingRequest) { + const int kUnlimitedSockets = 100; + const int kOneSocketPerGroup = 1; + CreatePool(kUnlimitedSockets, kOneSocketPerGroup); + + // Make the first request asynchronous fail. + // This will free up a socket slot later. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Make the second request synchronously fail. This should make the Group + // empty. + connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); + ClientSocketHandle handle2; + TestCompletionCallback callback2; + // It'll be ERR_IO_PENDING now, but the TestConnectJob will synchronously fail + // when created. + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback1.WaitForResult()); + EXPECT_EQ(ERR_CONNECTION_FAILED, callback2.WaitForResult()); + EXPECT_FALSE(pool_->HasGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + ClientSocketHandle handle3; + TestCompletionCallback callback3; + EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a", + params_, + kDefaultPriority, + callback3.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(OK, callback1.WaitForResult()); + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_EQ(OK, callback3.WaitForResult()); + + // Use the socket. + EXPECT_EQ(1, handle1.socket()->Write(NULL, 1, CompletionCallback())); + EXPECT_EQ(1, handle3.socket()->Write(NULL, 1, CompletionCallback())); + + handle1.Reset(); + handle2.Reset(); + handle3.Reset(); + + EXPECT_EQ(OK, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, handle3.Init("a", + params_, + kDefaultPriority, + callback3.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_TRUE(handle1.socket()->WasEverUsed()); + EXPECT_TRUE(handle2.socket()->WasEverUsed()); + EXPECT_FALSE(handle3.socket()->WasEverUsed()); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSockets) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + EXPECT_EQ(OK, callback1.WaitForResult()); + EXPECT_EQ(OK, callback2.WaitForResult()); + handle1.Reset(); + handle2.Reset(); + + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsWhenAlreadyHaveAConnectJob) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + EXPECT_EQ(OK, callback1.WaitForResult()); + EXPECT_EQ(OK, callback2.WaitForResult()); + handle1.Reset(); + handle2.Reset(); + + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, + RequestSocketsWhenAlreadyHaveMultipleConnectJob) { + CreatePool(4, 4); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + ClientSocketHandle handle3; + TestCompletionCallback callback3; + EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a", + params_, + kDefaultPriority, + callback3.callback(), + pool_.get(), + BoundNetLog())); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + + EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + EXPECT_EQ(OK, callback1.WaitForResult()); + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_EQ(OK, callback3.WaitForResult()); + handle1.Reset(); + handle2.Reset(); + handle3.Reset(); + + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(3, pool_->IdleSocketCountInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsAtMaxSocketLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ASSERT_FALSE(pool_->HasGroup("a")); + + pool_->RequestSockets("a", ¶ms_, kDefaultMaxSockets, + BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(kDefaultMaxSockets, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(kDefaultMaxSockets, pool_->NumUnassignedConnectJobsInGroup("a")); + + ASSERT_FALSE(pool_->HasGroup("b")); + + pool_->RequestSockets("b", ¶ms_, kDefaultMaxSockets, + BoundNetLog()); + + ASSERT_FALSE(pool_->HasGroup("b")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsHitMaxSocketLimit) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSockets); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ASSERT_FALSE(pool_->HasGroup("a")); + + pool_->RequestSockets("a", ¶ms_, kDefaultMaxSockets - 1, + BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(kDefaultMaxSockets - 1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(kDefaultMaxSockets - 1, + pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_FALSE(pool_->IsStalled()); + + ASSERT_FALSE(pool_->HasGroup("b")); + + pool_->RequestSockets("b", ¶ms_, kDefaultMaxSockets, + BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("b")); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("b")); + EXPECT_FALSE(pool_->IsStalled()); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountIdleSockets) { + CreatePool(4, 4); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + ASSERT_EQ(OK, callback1.WaitForResult()); + handle1.Reset(); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountActiveSockets) { + CreatePool(4, 4); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + ASSERT_EQ(OK, callback1.WaitForResult()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsSynchronous) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + pool_->RequestSockets("a", ¶ms_, kDefaultMaxSocketsPerGroup, + BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("b", ¶ms_, kDefaultMaxSocketsPerGroup, + BoundNetLog()); + + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b")); + EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->IdleSocketCountInGroup("b")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsSynchronousError) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); + + pool_->RequestSockets("a", ¶ms_, kDefaultMaxSocketsPerGroup, + BoundNetLog()); + + ASSERT_FALSE(pool_->HasGroup("a")); + + connect_job_factory_->set_job_type( + TestConnectJob::kMockAdditionalErrorStateJob); + pool_->RequestSockets("a", ¶ms_, kDefaultMaxSocketsPerGroup, + BoundNetLog()); + + ASSERT_FALSE(pool_->HasGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsMultipleTimesDoesNothing) { + CreatePool(4, 4); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + ASSERT_EQ(OK, callback1.WaitForResult()); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + int rv = handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog()); + if (rv != OK) { + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback2.WaitForResult()); + } + + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + handle1.Reset(); + handle2.Reset(); + + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, RequestSocketsDifferentNumSockets) { + CreatePool(4, 4); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + pool_->RequestSockets("a", ¶ms_, 1, BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 3, BoundNetLog()); + EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(3, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + pool_->RequestSockets("a", ¶ms_, 1, BoundNetLog()); + EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(3, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, PreconnectJobsTakenByNormalRequests) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + pool_->RequestSockets("a", ¶ms_, 1, BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + ASSERT_EQ(OK, callback1.WaitForResult()); + + // Make sure if a preconneced socket is not fully connected when a request + // starts, it has a connect start time. + TestLoadTimingInfoConnectedNotReused(handle1); + handle1.Reset(); + + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); +} + +// Checks that fully connected preconnect jobs have no connect times, and are +// marked as reused. +TEST_F(ClientSocketPoolBaseTest, ConnectedPreconnectJobsHaveNoConnectTimes) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + pool_->RequestSockets("a", ¶ms_, 1, BoundNetLog()); + + ASSERT_TRUE(pool_->HasGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + + // Make sure the idle socket was used. + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + TestLoadTimingInfoConnectedReused(handle); + handle.Reset(); + TestLoadTimingInfoNotConnected(handle); +} + +// http://crbug.com/64940 regression test. +TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) { + const int kMaxTotalSockets = 3; + const int kMaxSocketsPerGroup = 2; + CreatePool(kMaxTotalSockets, kMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + // Note that group name ordering matters here. "a" comes before "b", so + // CloseOneIdleSocket() will try to close "a"'s idle socket. + + // Set up one idle socket in "a". + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + ASSERT_EQ(OK, callback1.WaitForResult()); + handle1.Reset(); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + + // Set up two active sockets in "b". + ClientSocketHandle handle2; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle1.Init("b", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, handle2.Init("b", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + + ASSERT_EQ(OK, callback1.WaitForResult()); + ASSERT_EQ(OK, callback2.WaitForResult()); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b")); + EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("b")); + + // Now we have 1 idle socket in "a" and 2 active sockets in "b". This means + // we've maxed out on sockets, since we set |kMaxTotalSockets| to 3. + // Requesting 2 preconnected sockets for "a" should fail to allocate any more + // sockets for "a", and "b" should still have 2 active sockets. + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("b")); + + // Now release the 2 active sockets for "b". This will give us 1 idle socket + // in "a" and 2 idle sockets in "b". Requesting 2 preconnected sockets for + // "a" should result in closing 1 for "b". + handle1.Reset(); + handle2.Reset(); + EXPECT_EQ(2, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("b")); + + pool_->RequestSockets("a", ¶ms_, 2, BoundNetLog()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a")); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b")); + EXPECT_EQ(1, pool_->IdleSocketCountInGroup("b")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("b")); +} + +TEST_F(ClientSocketPoolBaseTest, PreconnectWithoutBackupJob) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + pool_->EnableConnectBackupJobs(); + + // Make the ConnectJob hang until it times out, shorten the timeout. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + connect_job_factory_->set_timeout_duration( + base::TimeDelta::FromMilliseconds(500)); + pool_->RequestSockets("a", ¶ms_, 1, BoundNetLog()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + + // Verify the backup timer doesn't create a backup job, by making + // the backup job a pending job instead of a waiting job, so it + // *would* complete if it were created. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::MessageLoop::QuitClosure(), + base::TimeDelta::FromSeconds(1)); + base::MessageLoop::current()->Run(); + EXPECT_FALSE(pool_->HasGroup("a")); +} + +TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + pool_->EnableConnectBackupJobs(); + + // Make the ConnectJob hang forever. + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); + pool_->RequestSockets("a", ¶ms_, 1, BoundNetLog()); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + base::MessageLoop::current()->RunUntilIdle(); + + // Make the backup job be a pending job, so it completes normally. + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + // Timer has started, but the backup connect job shouldn't be created yet. + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a")); + ASSERT_EQ(OK, callback.WaitForResult()); + + // The hung connect job should still be there, but everything else should be + // complete. + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); + EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); + EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); +} + +class MockLayeredPool : public LayeredPool { + public: + MockLayeredPool(TestClientSocketPool* pool, + const std::string& group_name) + : pool_(pool), + params_(new TestSocketParams), + group_name_(group_name), + can_release_connection_(true) { + pool_->AddLayeredPool(this); + } + + ~MockLayeredPool() { + pool_->RemoveLayeredPool(this); + } + + int RequestSocket(TestClientSocketPool* pool) { + return handle_.Init(group_name_, params_, kDefaultPriority, + callback_.callback(), pool, BoundNetLog()); + } + + int RequestSocketWithoutLimits(TestClientSocketPool* pool) { + params_->set_ignore_limits(true); + return handle_.Init(group_name_, params_, kDefaultPriority, + callback_.callback(), pool, BoundNetLog()); + } + + bool ReleaseOneConnection() { + if (!handle_.is_initialized() || !can_release_connection_) { + return false; + } + handle_.socket()->Disconnect(); + handle_.Reset(); + return true; + } + + void set_can_release_connection(bool can_release_connection) { + can_release_connection_ = can_release_connection; + } + + MOCK_METHOD0(CloseOneIdleConnection, bool()); + + private: + TestClientSocketPool* const pool_; + scoped_refptr<TestSocketParams> params_; + ClientSocketHandle handle_; + TestCompletionCallback callback_; + const std::string group_name_; + bool can_release_connection_; +}; + +TEST_F(ClientSocketPoolBaseTest, FailToCloseIdleSocketsNotHeldByLayeredPool) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Return(false)); + EXPECT_FALSE(pool_->CloseOneIdleConnectionInLayeredPool()); +} + +TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + EXPECT_TRUE(pool_->CloseOneIdleConnectionInLayeredPool()); +} + +// Tests the basic case of closing an idle socket in a higher layered pool when +// a new request is issued and the lower layer pool is stalled. +TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketsHeldByLayeredPoolWhenNeeded) { + CreatePool(1, 1); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); +} + +// Same as above, but the idle socket is in the same group as the stalled +// socket, and closes the only other request in its group when closing requests +// in higher layered pools. This generally shouldn't happen, but it may be +// possible if a higher level pool issues a request and the request is +// subsequently cancelled. Even if it's not possible, best not to crash. +TEST_F(ClientSocketPoolBaseTest, + CloseIdleSocketsHeldByLayeredPoolWhenNeededSameGroup) { + CreatePool(2, 2); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + // Need a socket in another group for the pool to be stalled (If a group + // has the maximum number of connections already, it's not stalled). + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(OK, handle1.Init("group1", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + MockLayeredPool mock_layered_pool(pool_.get(), "group2"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + ClientSocketHandle handle; + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("group2", + params_, + kDefaultPriority, + callback2.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback2.WaitForResult()); +} + +// Tests the case when an idle socket can be closed when a new request is +// issued, and the new request belongs to a group that was previously stalled. +TEST_F(ClientSocketPoolBaseTest, + CloseIdleSocketsHeldByLayeredPoolInSameGroupWhenNeeded) { + CreatePool(2, 2); + std::list<TestConnectJob::JobType> job_types; + job_types.push_back(TestConnectJob::kMockJob); + job_types.push_back(TestConnectJob::kMockJob); + job_types.push_back(TestConnectJob::kMockJob); + job_types.push_back(TestConnectJob::kMockJob); + connect_job_factory_->set_job_types(&job_types); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(OK, handle1.Init("group1", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + MockLayeredPool mock_layered_pool(pool_.get(), "group2"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillRepeatedly(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + mock_layered_pool.set_can_release_connection(false); + + // The third request is made when the socket pool is in a stalled state. + ClientSocketHandle handle3; + TestCompletionCallback callback3; + EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3", + params_, + kDefaultPriority, + callback3.callback(), + pool_.get(), + BoundNetLog())); + + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(callback3.have_result()); + + // The fourth request is made when the pool is no longer stalled. The third + // request should be serviced first, since it was issued first and has the + // same priority. + mock_layered_pool.set_can_release_connection(true); + ClientSocketHandle handle4; + TestCompletionCallback callback4; + EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3", + params_, + kDefaultPriority, + callback4.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback3.WaitForResult()); + EXPECT_FALSE(callback4.have_result()); + + // Closing a handle should free up another socket slot. + handle1.Reset(); + EXPECT_EQ(OK, callback4.WaitForResult()); +} + +// Tests the case when an idle socket can be closed when a new request is +// issued, and the new request belongs to a group that was previously stalled. +// +// The two differences from the above test are that the stalled requests are not +// in the same group as the layered pool's request, and the the fourth request +// has a higher priority than the third one, so gets a socket first. +TEST_F(ClientSocketPoolBaseTest, + CloseIdleSocketsHeldByLayeredPoolInSameGroupWhenNeeded2) { + CreatePool(2, 2); + std::list<TestConnectJob::JobType> job_types; + job_types.push_back(TestConnectJob::kMockJob); + job_types.push_back(TestConnectJob::kMockJob); + job_types.push_back(TestConnectJob::kMockJob); + job_types.push_back(TestConnectJob::kMockJob); + connect_job_factory_->set_job_types(&job_types); + + ClientSocketHandle handle1; + TestCompletionCallback callback1; + EXPECT_EQ(OK, handle1.Init("group1", + params_, + kDefaultPriority, + callback1.callback(), + pool_.get(), + BoundNetLog())); + + MockLayeredPool mock_layered_pool(pool_.get(), "group2"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillRepeatedly(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + mock_layered_pool.set_can_release_connection(false); + + // The third request is made when the socket pool is in a stalled state. + ClientSocketHandle handle3; + TestCompletionCallback callback3; + EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3", + params_, + MEDIUM, + callback3.callback(), + pool_.get(), + BoundNetLog())); + + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(callback3.have_result()); + + // The fourth request is made when the pool is no longer stalled. This + // request has a higher priority than the third request, so is serviced first. + mock_layered_pool.set_can_release_connection(true); + ClientSocketHandle handle4; + TestCompletionCallback callback4; + EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3", + params_, + HIGHEST, + callback4.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback4.WaitForResult()); + EXPECT_FALSE(callback3.have_result()); + + // Closing a handle should free up another socket slot. + handle1.Reset(); + EXPECT_EQ(OK, callback3.WaitForResult()); +} + +TEST_F(ClientSocketPoolBaseTest, + CloseMultipleIdleSocketsHeldByLayeredPoolWhenNeeded) { + CreatePool(1, 1); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool1(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool1.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool1, CloseOneIdleConnection()) + .WillRepeatedly(Invoke(&mock_layered_pool1, + &MockLayeredPool::ReleaseOneConnection)); + MockLayeredPool mock_layered_pool2(pool_.get(), "bar"); + EXPECT_EQ(OK, mock_layered_pool2.RequestSocketWithoutLimits(pool_.get())); + EXPECT_CALL(mock_layered_pool2, CloseOneIdleConnection()) + .WillRepeatedly(Invoke(&mock_layered_pool2, + &MockLayeredPool::ReleaseOneConnection)); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + pool_.get(), + BoundNetLog())); + EXPECT_EQ(OK, callback.WaitForResult()); +} + +// Test that when a socket pool and group are at their limits, a request +// with |ignore_limits| triggers creation of a new socket, and gets the socket +// instead of a request with the same priority that was issued earlier, but +// that does not have |ignore_limits| set. +TEST_F(ClientSocketPoolBaseTest, IgnoreLimits) { + scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); + params_ignore_limits->set_ignore_limits(true); + CreatePool(1, 1); + + // Issue a request to reach the socket pool limit. + EXPECT_EQ(OK, StartRequestWithParams("a", kDefaultPriority, params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", kDefaultPriority, + params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", kDefaultPriority, + params_ignore_limits)); + ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(OK, request(2)->WaitForResult()); + EXPECT_FALSE(request(1)->have_result()); +} + +// Test that when a socket pool and group are at their limits, a request with +// |ignore_limits| set triggers creation of a new socket, and gets the socket +// instead of a request with a higher priority that was issued earlier, but +// that does not have |ignore_limits| set. +TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsLowPriority) { + scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); + params_ignore_limits->set_ignore_limits(true); + CreatePool(1, 1); + + // Issue a request to reach the socket pool limit. + EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", LOW, + params_ignore_limits)); + ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(OK, request(2)->WaitForResult()); + EXPECT_FALSE(request(1)->have_result()); +} + +// Test that when a socket pool and group are at their limits, a request with +// |ignore_limits| set triggers creation of a new socket, and gets the socket +// instead of a request with a higher priority that was issued later and +// does not have |ignore_limits| set. +TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsLowPriority2) { + scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); + params_ignore_limits->set_ignore_limits(true); + CreatePool(1, 1); + + // Issue a request to reach the socket pool limit. + EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", LOW, + params_ignore_limits)); + ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(OK, request(1)->WaitForResult()); + EXPECT_FALSE(request(2)->have_result()); +} + +// Test that when a socket pool and group are at their limits, a ConnectJob +// issued for a request with |ignore_limits| set is not cancelled when a request +// without |ignore_limits| issued to the same group is cancelled. +TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsCancelOtherJob) { + scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); + params_ignore_limits->set_ignore_limits(true); + CreatePool(1, 1); + + // Issue a request to reach the socket pool limit. + EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, + params_ignore_limits)); + ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + // Cancel the pending request without ignore_limits set. The ConnectJob + // should not be cancelled. + request(1)->handle()->Reset(); + ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); + + EXPECT_EQ(OK, request(2)->WaitForResult()); + EXPECT_FALSE(request(1)->have_result()); +} + +// More involved test of ignore limits. Issues a bunch of requests and later +// checks the order in which they receive sockets. +TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsOrder) { + scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams()); + params_ignore_limits->set_ignore_limits(true); + CreatePool(1, 1); + + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); + + // Requests 0 and 1 do not have ignore_limits set, so they finish last. Since + // the maximum number of sockets per pool is 1, the second requests does not + // trigger a ConnectJob. + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_)); + + // Requests 2 and 3 have ignore_limits set, but have a low priority, so they + // finish just before the first two. + EXPECT_EQ(ERR_IO_PENDING, + StartRequestWithParams("a", LOW, params_ignore_limits)); + EXPECT_EQ(ERR_IO_PENDING, + StartRequestWithParams("a", LOW, params_ignore_limits)); + + // Request 4 finishes first, since it is high priority and ignores limits. + EXPECT_EQ(ERR_IO_PENDING, + StartRequestWithParams("a", HIGHEST, params_ignore_limits)); + + // Request 5 and 6 are cancelled right after starting. This should result in + // creating two ConnectJobs. Since only one request (Request 1) did not + // result in creating a ConnectJob, only one of the ConnectJobs should be + // cancelled when the requests are. + EXPECT_EQ(ERR_IO_PENDING, + StartRequestWithParams("a", HIGHEST, params_ignore_limits)); + EXPECT_EQ(ERR_IO_PENDING, + StartRequestWithParams("a", HIGHEST, params_ignore_limits)); + EXPECT_EQ(6, pool_->NumConnectJobsInGroup("a")); + request(5)->handle()->Reset(); + EXPECT_EQ(6, pool_->NumConnectJobsInGroup("a")); + request(6)->handle()->Reset(); + ASSERT_EQ(5, pool_->NumConnectJobsInGroup("a")); + + // Wait for the last request to get a socket. + EXPECT_EQ(OK, request(1)->WaitForResult()); + + // Check order in which requests received sockets. + // These are 1-based indices, while request(x) uses 0-based indices. + EXPECT_EQ(1, GetOrderOfRequest(5)); + EXPECT_EQ(2, GetOrderOfRequest(3)); + EXPECT_EQ(3, GetOrderOfRequest(4)); + EXPECT_EQ(4, GetOrderOfRequest(1)); + EXPECT_EQ(5, GetOrderOfRequest(2)); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/client_socket_pool_histograms.cc b/chromium/net/socket/client_socket_pool_histograms.cc new file mode 100644 index 00000000000..9af8649c48b --- /dev/null +++ b/chromium/net/socket/client_socket_pool_histograms.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_pool_histograms.h" + +#include <string> + +#include "base/metrics/field_trial.h" +#include "base/metrics/histogram.h" +#include "net/base/net_errors.h" +#include "net/socket/client_socket_handle.h" + +namespace net { + +using base::Histogram; +using base::HistogramBase; +using base::LinearHistogram; +using base::CustomHistogram; + +ClientSocketPoolHistograms::ClientSocketPoolHistograms( + const std::string& pool_name) + : is_http_proxy_connection_(false), + is_socks_connection_(false) { + // UMA_HISTOGRAM_ENUMERATION + socket_type_ = LinearHistogram::FactoryGet("Net.SocketType_" + pool_name, 1, + ClientSocketHandle::NUM_TYPES, ClientSocketHandle::NUM_TYPES + 1, + HistogramBase::kUmaTargetedHistogramFlag); + // UMA_HISTOGRAM_CUSTOM_TIMES + request_time_ = Histogram::FactoryTimeGet( + "Net.SocketRequestTime_" + pool_name, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100, HistogramBase::kUmaTargetedHistogramFlag); + // UMA_HISTOGRAM_CUSTOM_TIMES + unused_idle_time_ = Histogram::FactoryTimeGet( + "Net.SocketIdleTimeBeforeNextUse_UnusedSocket_" + pool_name, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(6), + 100, HistogramBase::kUmaTargetedHistogramFlag); + // UMA_HISTOGRAM_CUSTOM_TIMES + reused_idle_time_ = Histogram::FactoryTimeGet( + "Net.SocketIdleTimeBeforeNextUse_ReusedSocket_" + pool_name, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(6), + 100, HistogramBase::kUmaTargetedHistogramFlag); + // UMA_HISTOGRAM_CUSTOM_ENUMERATION + error_code_ = CustomHistogram::FactoryGet( + "Net.SocketInitErrorCodes_" + pool_name, + GetAllErrorCodesForUma(), + HistogramBase::kUmaTargetedHistogramFlag); + + if (pool_name == "HTTPProxy") + is_http_proxy_connection_ = true; + else if (pool_name == "SOCK") + is_socks_connection_ = true; +} + +ClientSocketPoolHistograms::~ClientSocketPoolHistograms() { +} + +void ClientSocketPoolHistograms::AddSocketType(int type) const { + socket_type_->Add(type); +} + +void ClientSocketPoolHistograms::AddRequestTime(base::TimeDelta time) const { + request_time_->AddTime(time); +} + +void ClientSocketPoolHistograms::AddUnusedIdleTime(base::TimeDelta time) const { + unused_idle_time_->AddTime(time); +} + +void ClientSocketPoolHistograms::AddReusedIdleTime(base::TimeDelta time) const { + reused_idle_time_->AddTime(time); +} + +void ClientSocketPoolHistograms::AddErrorCode(int error_code) const { + // Error codes are positive (since histograms expect positive sample values). + error_code_->Add(-error_code); +} + +} // namespace net diff --git a/chromium/net/socket/client_socket_pool_histograms.h b/chromium/net/socket/client_socket_pool_histograms.h new file mode 100644 index 00000000000..26a406362bd --- /dev/null +++ b/chromium/net/socket/client_socket_pool_histograms.h @@ -0,0 +1,46 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_ +#define NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_ + +#include <string> + +#include "base/memory/ref_counted.h" +#include "base/time/time.h" +#include "net/base/net_export.h" + +namespace base { +class HistogramBase; +} + +namespace net { + +class NET_EXPORT_PRIVATE ClientSocketPoolHistograms { + public: + ClientSocketPoolHistograms(const std::string& pool_name); + ~ClientSocketPoolHistograms(); + + void AddSocketType(int socket_reuse_type) const; + void AddRequestTime(base::TimeDelta time) const; + void AddUnusedIdleTime(base::TimeDelta time) const; + void AddReusedIdleTime(base::TimeDelta time) const; + void AddErrorCode(int error_code) const; + + private: + base::HistogramBase* socket_type_; + base::HistogramBase* request_time_; + base::HistogramBase* unused_idle_time_; + base::HistogramBase* reused_idle_time_; + base::HistogramBase* error_code_; + + bool is_http_proxy_connection_; + bool is_socks_connection_; + + DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolHistograms); +}; + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_ diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc new file mode 100644 index 00000000000..71496d28646 --- /dev/null +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -0,0 +1,467 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_pool_manager.h" + +#include <string> + +#include "base/basictypes.h" +#include "base/logging.h" +#include "base/strings/stringprintf.h" +#include "net/base/load_flags.h" +#include "net/http/http_proxy_client_socket_pool.h" +#include "net/http/http_request_info.h" +#include "net/http/http_stream_factory.h" +#include "net/proxy/proxy_info.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket_pool.h" +#include "net/socket/transport_client_socket_pool.h" + +namespace net { + +namespace { + +// Limit of sockets of each socket pool. +int g_max_sockets_per_pool[] = { + 256, // NORMAL_SOCKET_POOL + 256 // WEBSOCKET_SOCKET_POOL +}; + +COMPILE_ASSERT(arraysize(g_max_sockets_per_pool) == + HttpNetworkSession::NUM_SOCKET_POOL_TYPES, + max_sockets_per_pool_length_mismatch); + +// Default to allow up to 6 connections per host. Experiment and tuning may +// try other values (greater than 0). Too large may cause many problems, such +// as home routers blocking the connections!?!? See http://crbug.com/12066. +// +// WebSocket connections are long-lived, and should be treated differently +// than normal other connections. 6 connections per group sounded too small +// for such use, thus we use a larger limit which was determined somewhat +// arbitrarily. +// TODO(yutak): Look at the usage and determine the right value after +// WebSocket protocol stack starts to work. +int g_max_sockets_per_group[] = { + 6, // NORMAL_SOCKET_POOL + 30 // WEBSOCKET_SOCKET_POOL +}; + +COMPILE_ASSERT(arraysize(g_max_sockets_per_group) == + HttpNetworkSession::NUM_SOCKET_POOL_TYPES, + max_sockets_per_group_length_mismatch); + +// The max number of sockets to allow per proxy server. This applies both to +// http and SOCKS proxies. See http://crbug.com/12066 and +// http://crbug.com/44501 for details about proxy server connection limits. +int g_max_sockets_per_proxy_server[] = { + kDefaultMaxSocketsPerProxyServer, // NORMAL_SOCKET_POOL + kDefaultMaxSocketsPerProxyServer // WEBSOCKET_SOCKET_POOL +}; + +COMPILE_ASSERT(arraysize(g_max_sockets_per_proxy_server) == + HttpNetworkSession::NUM_SOCKET_POOL_TYPES, + max_sockets_per_proxy_server_length_mismatch); + +// The meat of the implementation for the InitSocketHandleForHttpRequest, +// InitSocketHandleForRawConnect and PreconnectSocketsForHttpRequest methods. +int InitSocketPoolHelper(const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + bool force_tunnel, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + int num_preconnect_streams, + ClientSocketHandle* socket_handle, + HttpNetworkSession::SocketPoolType socket_pool_type, + const OnHostResolutionCallback& resolution_callback, + const CompletionCallback& callback) { + scoped_refptr<TransportSocketParams> tcp_params; + scoped_refptr<HttpProxySocketParams> http_proxy_params; + scoped_refptr<SOCKSSocketParams> socks_params; + scoped_ptr<HostPortPair> proxy_host_port; + + bool using_ssl = request_url.SchemeIs("https") || + request_url.SchemeIs("wss") || force_spdy_over_ssl; + + HostPortPair origin_host_port = + HostPortPair(request_url.HostNoBrackets(), + request_url.EffectiveIntPort()); + + if (!using_ssl && session->params().testing_fixed_http_port != 0) { + origin_host_port.set_port(session->params().testing_fixed_http_port); + } else if (using_ssl && session->params().testing_fixed_https_port != 0) { + origin_host_port.set_port(session->params().testing_fixed_https_port); + } + + bool disable_resolver_cache = + request_load_flags & LOAD_BYPASS_CACHE || + request_load_flags & LOAD_VALIDATE_CACHE || + request_load_flags & LOAD_DISABLE_CACHE; + + int load_flags = request_load_flags; + if (session->params().ignore_certificate_errors) + load_flags |= LOAD_IGNORE_ALL_CERT_ERRORS; + + // Build the string used to uniquely identify connections of this type. + // Determine the host and port to connect to. + std::string connection_group = origin_host_port.ToString(); + DCHECK(!connection_group.empty()); + if (request_url.SchemeIs("ftp")) { + // Combining FTP with forced SPDY over SSL would be a "path to madness". + // Make sure we never do that. + DCHECK(!using_ssl); + connection_group = "ftp/" + connection_group; + } + if (using_ssl) { + // All connections in a group should use the same SSLConfig settings. + // Encode version_max in the connection group's name, unless it's the + // default version_max. (We want the common case to use the shortest + // encoding). A version_max of TLS 1.1 is encoded as "ssl(max:3.2)/" + // rather than "tlsv1.1/" because the actual protocol version, which + // is selected by the server, may not be TLS 1.1. Do not encode + // version_min in the connection group's name because version_min + // should be the same for all connections, whereas version_max may + // change for version fallbacks. + std::string prefix = "ssl/"; + if (ssl_config_for_origin.version_max != + SSLConfigService::default_version_max()) { + switch (ssl_config_for_origin.version_max) { + case SSL_PROTOCOL_VERSION_TLS1_2: + prefix = "ssl(max:3.3)/"; + break; + case SSL_PROTOCOL_VERSION_TLS1_1: + prefix = "ssl(max:3.2)/"; + break; + case SSL_PROTOCOL_VERSION_TLS1: + prefix = "ssl(max:3.1)/"; + break; + case SSL_PROTOCOL_VERSION_SSL3: + prefix = "sslv3/"; + break; + default: + CHECK(false); + break; + } + } + connection_group = prefix + connection_group; + } + + bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0; + if (proxy_info.is_direct()) { + tcp_params = new TransportSocketParams(origin_host_port, + request_priority, + disable_resolver_cache, + ignore_limits, + resolution_callback); + } else { + ProxyServer proxy_server = proxy_info.proxy_server(); + proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair())); + scoped_refptr<TransportSocketParams> proxy_tcp_params( + new TransportSocketParams(*proxy_host_port, + request_priority, + disable_resolver_cache, + ignore_limits, + resolution_callback)); + + if (proxy_info.is_http() || proxy_info.is_https()) { + std::string user_agent; + request_extra_headers.GetHeader(HttpRequestHeaders::kUserAgent, + &user_agent); + scoped_refptr<SSLSocketParams> ssl_params; + if (proxy_info.is_https()) { + // Set ssl_params, and unset proxy_tcp_params + ssl_params = new SSLSocketParams(proxy_tcp_params, + NULL, + NULL, + ProxyServer::SCHEME_DIRECT, + *proxy_host_port.get(), + ssl_config_for_proxy, + kPrivacyModeDisabled, + load_flags, + force_spdy_over_ssl, + want_spdy_over_npn); + proxy_tcp_params = NULL; + } + + http_proxy_params = + new HttpProxySocketParams(proxy_tcp_params, + ssl_params, + request_url, + user_agent, + origin_host_port, + session->http_auth_cache(), + session->http_auth_handler_factory(), + session->spdy_session_pool(), + force_tunnel || using_ssl); + } else { + DCHECK(proxy_info.is_socks()); + char socks_version; + if (proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5) + socks_version = '5'; + else + socks_version = '4'; + connection_group = base::StringPrintf( + "socks%c/%s", socks_version, connection_group.c_str()); + + socks_params = new SOCKSSocketParams(proxy_tcp_params, + socks_version == '5', + origin_host_port, + request_priority); + } + } + + // Change group name if privacy mode is enabled. + if (privacy_mode == kPrivacyModeEnabled) + connection_group = "pm/" + connection_group; + + // Deal with SSL - which layers on top of any given proxy. + if (using_ssl) { + scoped_refptr<SSLSocketParams> ssl_params = + new SSLSocketParams(tcp_params, + socks_params, + http_proxy_params, + proxy_info.proxy_server().scheme(), + origin_host_port, + ssl_config_for_origin, + privacy_mode, + load_flags, + force_spdy_over_ssl, + want_spdy_over_npn); + SSLClientSocketPool* ssl_pool = NULL; + if (proxy_info.is_direct()) { + ssl_pool = session->GetSSLSocketPool(socket_pool_type); + } else { + ssl_pool = session->GetSocketPoolForSSLWithProxy(socket_pool_type, + *proxy_host_port); + } + + if (num_preconnect_streams) { + RequestSocketsForPool(ssl_pool, connection_group, ssl_params, + num_preconnect_streams, net_log); + return OK; + } + + return socket_handle->Init(connection_group, ssl_params, + request_priority, callback, ssl_pool, + net_log); + } + + // Finally, get the connection started. + + if (proxy_info.is_http() || proxy_info.is_https()) { + HttpProxyClientSocketPool* pool = + session->GetSocketPoolForHTTPProxy(socket_pool_type, *proxy_host_port); + if (num_preconnect_streams) { + RequestSocketsForPool(pool, connection_group, http_proxy_params, + num_preconnect_streams, net_log); + return OK; + } + + return socket_handle->Init(connection_group, http_proxy_params, + request_priority, callback, + pool, net_log); + } + + if (proxy_info.is_socks()) { + SOCKSClientSocketPool* pool = + session->GetSocketPoolForSOCKSProxy(socket_pool_type, *proxy_host_port); + if (num_preconnect_streams) { + RequestSocketsForPool(pool, connection_group, socks_params, + num_preconnect_streams, net_log); + return OK; + } + + return socket_handle->Init(connection_group, socks_params, + request_priority, callback, pool, + net_log); + } + + DCHECK(proxy_info.is_direct()); + + TransportClientSocketPool* pool = + session->GetTransportSocketPool(socket_pool_type); + if (num_preconnect_streams) { + RequestSocketsForPool(pool, connection_group, tcp_params, + num_preconnect_streams, net_log); + return OK; + } + + return socket_handle->Init(connection_group, tcp_params, + request_priority, callback, + pool, net_log); +} + +} // namespace + +ClientSocketPoolManager::ClientSocketPoolManager() {} +ClientSocketPoolManager::~ClientSocketPoolManager() {} + +// static +int ClientSocketPoolManager::max_sockets_per_pool( + HttpNetworkSession::SocketPoolType pool_type) { + DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES); + return g_max_sockets_per_pool[pool_type]; +} + +// static +void ClientSocketPoolManager::set_max_sockets_per_pool( + HttpNetworkSession::SocketPoolType pool_type, + int socket_count) { + DCHECK_LT(0, socket_count); + DCHECK_GT(1000, socket_count); // Sanity check. + DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES); + g_max_sockets_per_pool[pool_type] = socket_count; + DCHECK_GE(g_max_sockets_per_pool[pool_type], + g_max_sockets_per_group[pool_type]); +} + +// static +int ClientSocketPoolManager::max_sockets_per_group( + HttpNetworkSession::SocketPoolType pool_type) { + DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES); + return g_max_sockets_per_group[pool_type]; +} + +// static +void ClientSocketPoolManager::set_max_sockets_per_group( + HttpNetworkSession::SocketPoolType pool_type, + int socket_count) { + DCHECK_LT(0, socket_count); + // The following is a sanity check... but we should NEVER be near this value. + DCHECK_GT(100, socket_count); + DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES); + g_max_sockets_per_group[pool_type] = socket_count; + + DCHECK_GE(g_max_sockets_per_pool[pool_type], + g_max_sockets_per_group[pool_type]); + DCHECK_GE(g_max_sockets_per_proxy_server[pool_type], + g_max_sockets_per_group[pool_type]); +} + +// static +int ClientSocketPoolManager::max_sockets_per_proxy_server( + HttpNetworkSession::SocketPoolType pool_type) { + DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES); + return g_max_sockets_per_proxy_server[pool_type]; +} + +// static +void ClientSocketPoolManager::set_max_sockets_per_proxy_server( + HttpNetworkSession::SocketPoolType pool_type, + int socket_count) { + DCHECK_LT(0, socket_count); + DCHECK_GT(100, socket_count); // Sanity check. + DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES); + // Assert this case early on. The max number of sockets per group cannot + // exceed the max number of sockets per proxy server. + DCHECK_LE(g_max_sockets_per_group[pool_type], socket_count); + g_max_sockets_per_proxy_server[pool_type] = socket_count; +} + +int InitSocketHandleForHttpRequest( + const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const OnHostResolutionCallback& resolution_callback, + const CompletionCallback& callback) { + DCHECK(socket_handle); + return InitSocketPoolHelper( + request_url, request_extra_headers, request_load_flags, request_priority, + session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn, + ssl_config_for_origin, ssl_config_for_proxy, false, privacy_mode, net_log, + 0, socket_handle, HttpNetworkSession::NORMAL_SOCKET_POOL, + resolution_callback, callback); +} + +int InitSocketHandleForWebSocketRequest( + const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const OnHostResolutionCallback& resolution_callback, + const CompletionCallback& callback) { + DCHECK(socket_handle); + return InitSocketPoolHelper( + request_url, request_extra_headers, request_load_flags, request_priority, + session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn, + ssl_config_for_origin, ssl_config_for_proxy, true, privacy_mode, net_log, + 0, socket_handle, HttpNetworkSession::WEBSOCKET_SOCKET_POOL, + resolution_callback, callback); +} + +int InitSocketHandleForRawConnect( + const HostPortPair& host_port_pair, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback) { + DCHECK(socket_handle); + // Synthesize an HttpRequestInfo. + GURL request_url = GURL("http://" + host_port_pair.ToString()); + HttpRequestHeaders request_extra_headers; + int request_load_flags = 0; + RequestPriority request_priority = MEDIUM; + + return InitSocketPoolHelper( + request_url, request_extra_headers, request_load_flags, request_priority, + session, proxy_info, false, false, ssl_config_for_origin, + ssl_config_for_proxy, true, privacy_mode, net_log, 0, socket_handle, + HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(), + callback); +} + +int PreconnectSocketsForHttpRequest( + const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + int num_preconnect_streams) { + return InitSocketPoolHelper( + request_url, request_extra_headers, request_load_flags, request_priority, + session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn, + ssl_config_for_origin, ssl_config_for_proxy, false, privacy_mode, net_log, + num_preconnect_streams, NULL, HttpNetworkSession::NORMAL_SOCKET_POOL, + OnHostResolutionCallback(), CompletionCallback()); +} + +} // namespace net diff --git a/chromium/net/socket/client_socket_pool_manager.h b/chromium/net/socket/client_socket_pool_manager.h new file mode 100644 index 00000000000..1b78324f233 --- /dev/null +++ b/chromium/net/socket/client_socket_pool_manager.h @@ -0,0 +1,169 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// ClientSocketPoolManager manages access to all ClientSocketPools. It's a +// simple container for all of them. Most importantly, it handles the lifetime +// and destruction order properly. + +#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_ +#define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_ + +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/request_priority.h" +#include "net/http/http_network_session.h" + +class GURL; + +namespace base { +class Value; +} + +namespace net { + +typedef base::Callback<int(const AddressList&, const BoundNetLog& net_log)> +OnHostResolutionCallback; + +class BoundNetLog; +class ClientSocketHandle; +class HostPortPair; +class HttpNetworkSession; +class HttpProxyClientSocketPool; +class HttpRequestHeaders; +class ProxyInfo; +class TransportClientSocketPool; +class SOCKSClientSocketPool; +class SSLClientSocketPool; + +struct SSLConfig; + +// This should rather be a simple constant but Windows shared libs doesn't +// really offer much flexiblity in exporting contants. +enum DefaultMaxValues { kDefaultMaxSocketsPerProxyServer = 32 }; + +class NET_EXPORT_PRIVATE ClientSocketPoolManager { + public: + ClientSocketPoolManager(); + virtual ~ClientSocketPoolManager(); + + // The setter methods below affect only newly created socket pools after the + // methods are called. Normally they should be called at program startup + // before any ClientSocketPoolManagerImpl is created. + static int max_sockets_per_pool(HttpNetworkSession::SocketPoolType pool_type); + static void set_max_sockets_per_pool( + HttpNetworkSession::SocketPoolType pool_type, + int socket_count); + + static int max_sockets_per_group( + HttpNetworkSession::SocketPoolType pool_type); + static void set_max_sockets_per_group( + HttpNetworkSession::SocketPoolType pool_type, + int socket_count); + + static int max_sockets_per_proxy_server( + HttpNetworkSession::SocketPoolType pool_type); + static void set_max_sockets_per_proxy_server( + HttpNetworkSession::SocketPoolType pool_type, + int socket_count); + + virtual void FlushSocketPoolsWithError(int error) = 0; + virtual void CloseIdleSockets() = 0; + virtual TransportClientSocketPool* GetTransportSocketPool() = 0; + virtual SSLClientSocketPool* GetSSLSocketPool() = 0; + virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) = 0; + virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) = 0; + virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) = 0; + // Creates a Value summary of the state of the socket pools. The caller is + // responsible for deleting the returned value. + virtual base::Value* SocketPoolInfoToValue() const = 0; +}; + +// A helper method that uses the passed in proxy information to initialize a +// ClientSocketHandle with the relevant socket pool. Use this method for +// HTTP/HTTPS requests. |ssl_config_for_origin| is only used if the request +// uses SSL and |ssl_config_for_proxy| is used if the proxy server is HTTPS. +// |resolution_callback| will be invoked after the the hostname is +// resolved. If |resolution_callback| does not return OK, then the +// connection will be aborted with that value. +int InitSocketHandleForHttpRequest( + const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const OnHostResolutionCallback& resolution_callback, + const CompletionCallback& callback); + +// A helper method that uses the passed in proxy information to initialize a +// ClientSocketHandle with the relevant socket pool. Use this method for +// HTTP/HTTPS requests for WebSocket handshake. +// |ssl_config_for_origin| is only used if the request +// uses SSL and |ssl_config_for_proxy| is used if the proxy server is HTTPS. +// |resolution_callback| will be invoked after the the hostname is +// resolved. If |resolution_callback| does not return OK, then the +// connection will be aborted with that value. +// This function uses WEBSOCKET_SOCKET_POOL socket pools. +int InitSocketHandleForWebSocketRequest( + const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const OnHostResolutionCallback& resolution_callback, + const CompletionCallback& callback); + +// A helper method that uses the passed in proxy information to initialize a +// ClientSocketHandle with the relevant socket pool. Use this method for +// a raw socket connection to a host-port pair (that needs to tunnel through +// the proxies). +NET_EXPORT int InitSocketHandleForRawConnect( + const HostPortPair& host_port_pair, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback); + +// Similar to InitSocketHandleForHttpRequest except that it initiates the +// desired number of preconnect streams from the relevant socket pool. +int PreconnectSocketsForHttpRequest( + const GURL& request_url, + const HttpRequestHeaders& request_extra_headers, + int request_load_flags, + RequestPriority request_priority, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + bool force_spdy_over_ssl, + bool want_spdy_over_npn, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + int num_preconnect_streams); + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_ diff --git a/chromium/net/socket/client_socket_pool_manager_impl.cc b/chromium/net/socket/client_socket_pool_manager_impl.cc new file mode 100644 index 00000000000..b557874d011 --- /dev/null +++ b/chromium/net/socket/client_socket_pool_manager_impl.cc @@ -0,0 +1,392 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/client_socket_pool_manager_impl.h" + +#include "base/logging.h" +#include "base/values.h" +#include "net/http/http_network_session.h" +#include "net/http/http_proxy_client_socket_pool.h" +#include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket_pool.h" +#include "net/socket/transport_client_socket_pool.h" +#include "net/ssl/ssl_config_service.h" + +namespace net { + +namespace { + +// Appends information about all |socket_pools| to the end of |list|. +template <class MapType> +void AddSocketPoolsToList(base::ListValue* list, + const MapType& socket_pools, + const std::string& type, + bool include_nested_pools) { + for (typename MapType::const_iterator it = socket_pools.begin(); + it != socket_pools.end(); it++) { + list->Append(it->second->GetInfoAsValue(it->first.ToString(), + type, + include_nested_pools)); + } +} + +} // namespace + +ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( + NetLog* net_log, + ClientSocketFactory* socket_factory, + HostResolver* host_resolver, + CertVerifier* cert_verifier, + ServerBoundCertService* server_bound_cert_service, + TransportSecurityState* transport_security_state, + const std::string& ssl_session_cache_shard, + ProxyService* proxy_service, + SSLConfigService* ssl_config_service, + HttpNetworkSession::SocketPoolType pool_type) + : net_log_(net_log), + socket_factory_(socket_factory), + host_resolver_(host_resolver), + cert_verifier_(cert_verifier), + server_bound_cert_service_(server_bound_cert_service), + transport_security_state_(transport_security_state), + ssl_session_cache_shard_(ssl_session_cache_shard), + proxy_service_(proxy_service), + ssl_config_service_(ssl_config_service), + pool_type_(pool_type), + transport_pool_histograms_("TCP"), + transport_socket_pool_(new TransportClientSocketPool( + max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), + &transport_pool_histograms_, + host_resolver, + socket_factory_, + net_log)), + ssl_pool_histograms_("SSL2"), + ssl_socket_pool_(new SSLClientSocketPool( + max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), + &ssl_pool_histograms_, + host_resolver, + cert_verifier, + server_bound_cert_service, + transport_security_state, + ssl_session_cache_shard, + socket_factory, + transport_socket_pool_.get(), + NULL /* no socks proxy */, + NULL /* no http proxy */, + ssl_config_service, + net_log)), + transport_for_socks_pool_histograms_("TCPforSOCKS"), + socks_pool_histograms_("SOCK"), + transport_for_http_proxy_pool_histograms_("TCPforHTTPProxy"), + transport_for_https_proxy_pool_histograms_("TCPforHTTPSProxy"), + ssl_for_https_proxy_pool_histograms_("SSLforHTTPSProxy"), + http_proxy_pool_histograms_("HTTPProxy"), + ssl_socket_pool_for_proxies_histograms_("SSLForProxies") { + CertDatabase::GetInstance()->AddObserver(this); +} + +ClientSocketPoolManagerImpl::~ClientSocketPoolManagerImpl() { + CertDatabase::GetInstance()->RemoveObserver(this); +} + +void ClientSocketPoolManagerImpl::FlushSocketPoolsWithError(int error) { + // Flush the highest level pools first, since higher level pools may release + // stuff to the lower level pools. + + for (SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_proxies_.begin(); + it != ssl_socket_pools_for_proxies_.end(); + ++it) + it->second->FlushWithError(error); + + for (HTTPProxySocketPoolMap::const_iterator it = + http_proxy_socket_pools_.begin(); + it != http_proxy_socket_pools_.end(); + ++it) + it->second->FlushWithError(error); + + for (SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_https_proxies_.begin(); + it != ssl_socket_pools_for_https_proxies_.end(); + ++it) + it->second->FlushWithError(error); + + for (TransportSocketPoolMap::const_iterator it = + transport_socket_pools_for_https_proxies_.begin(); + it != transport_socket_pools_for_https_proxies_.end(); + ++it) + it->second->FlushWithError(error); + + for (TransportSocketPoolMap::const_iterator it = + transport_socket_pools_for_http_proxies_.begin(); + it != transport_socket_pools_for_http_proxies_.end(); + ++it) + it->second->FlushWithError(error); + + for (SOCKSSocketPoolMap::const_iterator it = + socks_socket_pools_.begin(); + it != socks_socket_pools_.end(); + ++it) + it->second->FlushWithError(error); + + for (TransportSocketPoolMap::const_iterator it = + transport_socket_pools_for_socks_proxies_.begin(); + it != transport_socket_pools_for_socks_proxies_.end(); + ++it) + it->second->FlushWithError(error); + + ssl_socket_pool_->FlushWithError(error); + transport_socket_pool_->FlushWithError(error); +} + +void ClientSocketPoolManagerImpl::CloseIdleSockets() { + // Close sockets in the highest level pools first, since higher level pools' + // sockets may release stuff to the lower level pools. + for (SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_proxies_.begin(); + it != ssl_socket_pools_for_proxies_.end(); + ++it) + it->second->CloseIdleSockets(); + + for (HTTPProxySocketPoolMap::const_iterator it = + http_proxy_socket_pools_.begin(); + it != http_proxy_socket_pools_.end(); + ++it) + it->second->CloseIdleSockets(); + + for (SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_https_proxies_.begin(); + it != ssl_socket_pools_for_https_proxies_.end(); + ++it) + it->second->CloseIdleSockets(); + + for (TransportSocketPoolMap::const_iterator it = + transport_socket_pools_for_https_proxies_.begin(); + it != transport_socket_pools_for_https_proxies_.end(); + ++it) + it->second->CloseIdleSockets(); + + for (TransportSocketPoolMap::const_iterator it = + transport_socket_pools_for_http_proxies_.begin(); + it != transport_socket_pools_for_http_proxies_.end(); + ++it) + it->second->CloseIdleSockets(); + + for (SOCKSSocketPoolMap::const_iterator it = + socks_socket_pools_.begin(); + it != socks_socket_pools_.end(); + ++it) + it->second->CloseIdleSockets(); + + for (TransportSocketPoolMap::const_iterator it = + transport_socket_pools_for_socks_proxies_.begin(); + it != transport_socket_pools_for_socks_proxies_.end(); + ++it) + it->second->CloseIdleSockets(); + + ssl_socket_pool_->CloseIdleSockets(); + transport_socket_pool_->CloseIdleSockets(); +} + +TransportClientSocketPool* +ClientSocketPoolManagerImpl::GetTransportSocketPool() { + return transport_socket_pool_.get(); +} + +SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSSLSocketPool() { + return ssl_socket_pool_.get(); +} + +SOCKSClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) { + SOCKSSocketPoolMap::const_iterator it = socks_socket_pools_.find(socks_proxy); + if (it != socks_socket_pools_.end()) { + DCHECK(ContainsKey(transport_socket_pools_for_socks_proxies_, socks_proxy)); + return it->second; + } + + DCHECK(!ContainsKey(transport_socket_pools_for_socks_proxies_, socks_proxy)); + + std::pair<TransportSocketPoolMap::iterator, bool> tcp_ret = + transport_socket_pools_for_socks_proxies_.insert( + std::make_pair( + socks_proxy, + new TransportClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &transport_for_socks_pool_histograms_, + host_resolver_, + socket_factory_, + net_log_))); + DCHECK(tcp_ret.second); + + std::pair<SOCKSSocketPoolMap::iterator, bool> ret = + socks_socket_pools_.insert( + std::make_pair(socks_proxy, new SOCKSClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &socks_pool_histograms_, + host_resolver_, + tcp_ret.first->second, + net_log_))); + + return ret.first->second; +} + +HttpProxyClientSocketPool* +ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) { + HTTPProxySocketPoolMap::const_iterator it = + http_proxy_socket_pools_.find(http_proxy); + if (it != http_proxy_socket_pools_.end()) { + DCHECK(ContainsKey(transport_socket_pools_for_http_proxies_, http_proxy)); + DCHECK(ContainsKey(transport_socket_pools_for_https_proxies_, http_proxy)); + DCHECK(ContainsKey(ssl_socket_pools_for_https_proxies_, http_proxy)); + return it->second; + } + + DCHECK(!ContainsKey(transport_socket_pools_for_http_proxies_, http_proxy)); + DCHECK(!ContainsKey(transport_socket_pools_for_https_proxies_, http_proxy)); + DCHECK(!ContainsKey(ssl_socket_pools_for_https_proxies_, http_proxy)); + + std::pair<TransportSocketPoolMap::iterator, bool> tcp_http_ret = + transport_socket_pools_for_http_proxies_.insert( + std::make_pair( + http_proxy, + new TransportClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &transport_for_http_proxy_pool_histograms_, + host_resolver_, + socket_factory_, + net_log_))); + DCHECK(tcp_http_ret.second); + + std::pair<TransportSocketPoolMap::iterator, bool> tcp_https_ret = + transport_socket_pools_for_https_proxies_.insert( + std::make_pair( + http_proxy, + new TransportClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &transport_for_https_proxy_pool_histograms_, + host_resolver_, + socket_factory_, + net_log_))); + DCHECK(tcp_https_ret.second); + + std::pair<SSLSocketPoolMap::iterator, bool> ssl_https_ret = + ssl_socket_pools_for_https_proxies_.insert(std::make_pair( + http_proxy, + new SSLClientSocketPool(max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &ssl_for_https_proxy_pool_histograms_, + host_resolver_, + cert_verifier_, + server_bound_cert_service_, + transport_security_state_, + ssl_session_cache_shard_, + socket_factory_, + tcp_https_ret.first->second /* https proxy */, + NULL /* no socks proxy */, + NULL /* no http proxy */, + ssl_config_service_.get(), + net_log_))); + DCHECK(tcp_https_ret.second); + + std::pair<HTTPProxySocketPoolMap::iterator, bool> ret = + http_proxy_socket_pools_.insert( + std::make_pair( + http_proxy, + new HttpProxyClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &http_proxy_pool_histograms_, + host_resolver_, + tcp_http_ret.first->second, + ssl_https_ret.first->second, + net_log_))); + + return ret.first->second; +} + +SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) { + SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_proxies_.find(proxy_server); + if (it != ssl_socket_pools_for_proxies_.end()) + return it->second; + + SSLClientSocketPool* new_pool = new SSLClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), + &ssl_pool_histograms_, + host_resolver_, + cert_verifier_, + server_bound_cert_service_, + transport_security_state_, + ssl_session_cache_shard_, + socket_factory_, + NULL, /* no tcp pool, we always go through a proxy */ + GetSocketPoolForSOCKSProxy(proxy_server), + GetSocketPoolForHTTPProxy(proxy_server), + ssl_config_service_.get(), + net_log_); + + std::pair<SSLSocketPoolMap::iterator, bool> ret = + ssl_socket_pools_for_proxies_.insert(std::make_pair(proxy_server, + new_pool)); + + return ret.first->second; +} + +base::Value* ClientSocketPoolManagerImpl::SocketPoolInfoToValue() const { + base::ListValue* list = new base::ListValue(); + list->Append(transport_socket_pool_->GetInfoAsValue("transport_socket_pool", + "transport_socket_pool", + false)); + // Third parameter is false because |ssl_socket_pool_| uses + // |transport_socket_pool_| internally, and do not want to add it a second + // time. + list->Append(ssl_socket_pool_->GetInfoAsValue("ssl_socket_pool", + "ssl_socket_pool", + false)); + AddSocketPoolsToList(list, + http_proxy_socket_pools_, + "http_proxy_socket_pool", + true); + AddSocketPoolsToList(list, + socks_socket_pools_, + "socks_socket_pool", + true); + + // Third parameter is false because |ssl_socket_pools_for_proxies_| use + // socket pools in |http_proxy_socket_pools_| and |socks_socket_pools_|. + AddSocketPoolsToList(list, + ssl_socket_pools_for_proxies_, + "ssl_socket_pool_for_proxies", + false); + return list; +} + +void ClientSocketPoolManagerImpl::OnCertAdded(const X509Certificate* cert) { + FlushSocketPoolsWithError(ERR_NETWORK_CHANGED); +} + +void ClientSocketPoolManagerImpl::OnCertTrustChanged( + const X509Certificate* cert) { + // We should flush the socket pools if we removed trust from a + // cert, because a previously trusted server may have become + // untrusted. + // + // We should not flush the socket pools if we added trust to a + // cert. + // + // Since the OnCertTrustChanged method doesn't tell us what + // kind of trust change it is, we have to flush the socket + // pools to be safe. + FlushSocketPoolsWithError(ERR_NETWORK_CHANGED); +} + +} // namespace net diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h new file mode 100644 index 00000000000..8f6e618d2e1 --- /dev/null +++ b/chromium/net/socket/client_socket_pool_manager_impl.h @@ -0,0 +1,150 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_ +#define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_ + +#include <map> +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/stl_util.h" +#include "base/template_util.h" +#include "base/threading/non_thread_safe.h" +#include "net/cert/cert_database.h" +#include "net/http/http_network_session.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/client_socket_pool_manager.h" + +namespace net { + +class CertVerifier; +class ClientSocketFactory; +class ClientSocketPoolHistograms; +class HttpProxyClientSocketPool; +class HostResolver; +class NetLog; +class ServerBoundCertService; +class ProxyService; +class SOCKSClientSocketPool; +class SSLClientSocketPool; +class SSLConfigService; +class TransportClientSocketPool; +class TransportSecurityState; + +namespace internal { + +// A helper class for auto-deleting Values in the destructor. +template <typename Key, typename Value> +class OwnedPoolMap : public std::map<Key, Value> { + public: + OwnedPoolMap() { + COMPILE_ASSERT(base::is_pointer<Value>::value, + value_must_be_a_pointer); + } + + ~OwnedPoolMap() { + STLDeleteValues(this); + } +}; + +} // namespace internal + +class ClientSocketPoolManagerImpl : public base::NonThreadSafe, + public ClientSocketPoolManager, + public CertDatabase::Observer { + public: + ClientSocketPoolManagerImpl(NetLog* net_log, + ClientSocketFactory* socket_factory, + HostResolver* host_resolver, + CertVerifier* cert_verifier, + ServerBoundCertService* server_bound_cert_service, + TransportSecurityState* transport_security_state, + const std::string& ssl_session_cache_shard, + ProxyService* proxy_service, + SSLConfigService* ssl_config_service, + HttpNetworkSession::SocketPoolType pool_type); + virtual ~ClientSocketPoolManagerImpl(); + + virtual void FlushSocketPoolsWithError(int error) OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; + + virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE; + + virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE; + + virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) OVERRIDE; + + virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) OVERRIDE; + + virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) OVERRIDE; + + // Creates a Value summary of the state of the socket pools. The caller is + // responsible for deleting the returned value. + virtual base::Value* SocketPoolInfoToValue() const OVERRIDE; + + // CertDatabase::Observer methods: + virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE; + virtual void OnCertTrustChanged(const X509Certificate* cert) OVERRIDE; + + private: + typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*> + TransportSocketPoolMap; + typedef internal::OwnedPoolMap<HostPortPair, SOCKSClientSocketPool*> + SOCKSSocketPoolMap; + typedef internal::OwnedPoolMap<HostPortPair, HttpProxyClientSocketPool*> + HTTPProxySocketPoolMap; + typedef internal::OwnedPoolMap<HostPortPair, SSLClientSocketPool*> + SSLSocketPoolMap; + + NetLog* const net_log_; + ClientSocketFactory* const socket_factory_; + HostResolver* const host_resolver_; + CertVerifier* const cert_verifier_; + ServerBoundCertService* const server_bound_cert_service_; + TransportSecurityState* const transport_security_state_; + const std::string ssl_session_cache_shard_; + ProxyService* const proxy_service_; + const scoped_refptr<SSLConfigService> ssl_config_service_; + const HttpNetworkSession::SocketPoolType pool_type_; + + // Note: this ordering is important. + + ClientSocketPoolHistograms transport_pool_histograms_; + scoped_ptr<TransportClientSocketPool> transport_socket_pool_; + + ClientSocketPoolHistograms ssl_pool_histograms_; + scoped_ptr<SSLClientSocketPool> ssl_socket_pool_; + + ClientSocketPoolHistograms transport_for_socks_pool_histograms_; + TransportSocketPoolMap transport_socket_pools_for_socks_proxies_; + + ClientSocketPoolHistograms socks_pool_histograms_; + SOCKSSocketPoolMap socks_socket_pools_; + + ClientSocketPoolHistograms transport_for_http_proxy_pool_histograms_; + TransportSocketPoolMap transport_socket_pools_for_http_proxies_; + + ClientSocketPoolHistograms transport_for_https_proxy_pool_histograms_; + TransportSocketPoolMap transport_socket_pools_for_https_proxies_; + + ClientSocketPoolHistograms ssl_for_https_proxy_pool_histograms_; + SSLSocketPoolMap ssl_socket_pools_for_https_proxies_; + + ClientSocketPoolHistograms http_proxy_pool_histograms_; + HTTPProxySocketPoolMap http_proxy_socket_pools_; + + ClientSocketPoolHistograms ssl_socket_pool_for_proxies_histograms_; + SSLSocketPoolMap ssl_socket_pools_for_proxies_; + + DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolManagerImpl); +}; + +} // namespace net + +#endif // NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_ diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc new file mode 100644 index 00000000000..eba01b5e9cc --- /dev/null +++ b/chromium/net/socket/deterministic_socket_data_unittest.cc @@ -0,0 +1,621 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socket_test_util.h" + +#include <string.h> + +#include "base/memory/ref_counted.h" +#include "testing/platform_test.h" +#include "testing/gtest/include/gtest/gtest.h" + +//----------------------------------------------------------------------------- + +namespace { + +static const char kMsg1[] = "\0hello!\xff"; +static const int kLen1 = arraysize(kMsg1); +static const char kMsg2[] = "\012345678\0"; +static const int kLen2 = arraysize(kMsg2); +static const char kMsg3[] = "bye!"; +static const int kLen3 = arraysize(kMsg3); + +} // anonymous namespace + +namespace net { + +class DeterministicSocketDataTest : public PlatformTest { + public: + DeterministicSocketDataTest(); + + virtual void TearDown(); + + void ReentrantReadCallback(int len, int rv); + void ReentrantWriteCallback(const char* data, int len, int rv); + + protected: + void Initialize(MockRead* reads, size_t reads_count, MockWrite* writes, + size_t writes_count); + + void AssertSyncReadEquals(const char* data, int len); + void AssertAsyncReadEquals(const char* data, int len); + void AssertReadReturns(const char* data, int len, int rv); + void AssertReadBufferEquals(const char* data, int len); + + void AssertSyncWriteEquals(const char* data, int len); + void AssertAsyncWriteEquals(const char* data, int len); + void AssertWriteReturns(const char* data, int len, int rv); + + TestCompletionCallback read_callback_; + TestCompletionCallback write_callback_; + StreamSocket* sock_; + scoped_ptr<DeterministicSocketData> data_; + + private: + scoped_refptr<IOBuffer> read_buf_; + MockConnect connect_data_; + + HostPortPair endpoint_; + scoped_refptr<TransportSocketParams> tcp_params_; + ClientSocketPoolHistograms histograms_; + DeterministicMockClientSocketFactory socket_factory_; + MockTransportClientSocketPool socket_pool_; + ClientSocketHandle connection_; + + DISALLOW_COPY_AND_ASSIGN(DeterministicSocketDataTest); +}; + +DeterministicSocketDataTest::DeterministicSocketDataTest() + : sock_(NULL), + read_buf_(NULL), + connect_data_(SYNCHRONOUS, OK), + endpoint_("www.google.com", 443), + tcp_params_(new TransportSocketParams(endpoint_, + LOWEST, + false, + false, + OnHostResolutionCallback())), + histograms_(std::string()), + socket_pool_(10, 10, &histograms_, &socket_factory_) {} + +void DeterministicSocketDataTest::TearDown() { + // Empty the current queue. + base::MessageLoop::current()->RunUntilIdle(); + PlatformTest::TearDown(); +} + +void DeterministicSocketDataTest::Initialize(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) { + data_.reset(new DeterministicSocketData(reads, reads_count, + writes, writes_count)); + data_->set_connect_data(connect_data_); + socket_factory_.AddSocketDataProvider(data_.get()); + + // Perform the TCP connect + EXPECT_EQ(OK, + connection_.Init(endpoint_.ToString(), + tcp_params_, + LOWEST, + CompletionCallback(), + reinterpret_cast<TransportClientSocketPool*>(&socket_pool_), + BoundNetLog())); + sock_ = connection_.socket(); +} + +void DeterministicSocketDataTest::AssertSyncReadEquals(const char* data, + int len) { + // Issue the read, which will complete immediately + AssertReadReturns(data, len, len); + AssertReadBufferEquals(data, len); +} + +void DeterministicSocketDataTest::AssertAsyncReadEquals(const char* data, + int len) { + // Issue the read, which will be completed asynchronously + AssertReadReturns(data, len, ERR_IO_PENDING); + + EXPECT_FALSE(read_callback_.have_result()); + EXPECT_TRUE(sock_->IsConnected()); + data_->RunFor(1); // Runs 1 step, to cause the callbacks to be invoked + + // Now the read should complete + ASSERT_EQ(len, read_callback_.WaitForResult()); + AssertReadBufferEquals(data, len); +} + +void DeterministicSocketDataTest::AssertReadReturns(const char* data, + int len, int rv) { + read_buf_ = new IOBuffer(len); + ASSERT_EQ(rv, sock_->Read(read_buf_.get(), len, read_callback_.callback())); +} + +void DeterministicSocketDataTest::AssertReadBufferEquals(const char* data, + int len) { + ASSERT_EQ(std::string(data, len), std::string(read_buf_->data(), len)); +} + +void DeterministicSocketDataTest::AssertSyncWriteEquals(const char* data, + int len) { + scoped_refptr<IOBuffer> buf(new IOBuffer(len)); + memcpy(buf->data(), data, len); + + // Issue the write, which will complete immediately + ASSERT_EQ(len, sock_->Write(buf.get(), len, write_callback_.callback())); +} + +void DeterministicSocketDataTest::AssertAsyncWriteEquals(const char* data, + int len) { + // Issue the read, which will be completed asynchronously + AssertWriteReturns(data, len, ERR_IO_PENDING); + + EXPECT_FALSE(read_callback_.have_result()); + EXPECT_TRUE(sock_->IsConnected()); + data_->RunFor(1); // Runs 1 step, to cause the callbacks to be invoked + + ASSERT_EQ(len, write_callback_.WaitForResult()); +} + +void DeterministicSocketDataTest::AssertWriteReturns(const char* data, + int len, int rv) { + scoped_refptr<IOBuffer> buf(new IOBuffer(len)); + memcpy(buf->data(), data, len); + + // Issue the read, which will complete asynchronously + ASSERT_EQ(rv, sock_->Write(buf.get(), len, write_callback_.callback())); +} + +void DeterministicSocketDataTest::ReentrantReadCallback(int len, int rv) { + scoped_refptr<IOBuffer> read_buf(new IOBuffer(len)); + EXPECT_EQ(len, + sock_->Read( + read_buf.get(), + len, + base::Bind(&DeterministicSocketDataTest::ReentrantReadCallback, + base::Unretained(this), + len))); +} + +void DeterministicSocketDataTest::ReentrantWriteCallback( + const char* data, int len, int rv) { + scoped_refptr<IOBuffer> write_buf(new IOBuffer(len)); + memcpy(write_buf->data(), data, len); + EXPECT_EQ(len, + sock_->Write( + write_buf.get(), + len, + base::Bind(&DeterministicSocketDataTest::ReentrantWriteCallback, + base::Unretained(this), + data, + len))); +} + +// ----------- Read + +TEST_F(DeterministicSocketDataTest, SingleSyncReadWhileStopped) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + MockRead(SYNCHRONOUS, 0, 1), // EOF + }; + + Initialize(reads, arraysize(reads), NULL, 0); + + data_->SetStopped(true); + AssertReadReturns(kMsg1, kLen1, ERR_UNEXPECTED); +} + +TEST_F(DeterministicSocketDataTest, SingleSyncReadTooEarly) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 1), // Sync Read + MockRead(SYNCHRONOUS, 0, 2), // EOF + }; + + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, 0, 0) + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + data_->StopAfter(2); + ASSERT_FALSE(data_->stopped()); + AssertReadReturns(kMsg1, kLen1, ERR_UNEXPECTED); +} + +TEST_F(DeterministicSocketDataTest, SingleSyncRead) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + MockRead(SYNCHRONOUS, 0, 1), // EOF + }; + + Initialize(reads, arraysize(reads), NULL, 0); + // Make sure we don't stop before we've read all the data + data_->StopAfter(1); + AssertSyncReadEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, MultipleSyncReads) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + MockRead(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Read + MockRead(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Read + MockRead(SYNCHRONOUS, kMsg3, kLen3, 3), // Sync Read + MockRead(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Read + MockRead(SYNCHRONOUS, kMsg3, kLen3, 5), // Sync Read + MockRead(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Read + MockRead(SYNCHRONOUS, 0, 7), // EOF + }; + + Initialize(reads, arraysize(reads), NULL, 0); + + // Make sure we don't stop before we've read all the data + data_->StopAfter(10); + AssertSyncReadEquals(kMsg1, kLen1); + AssertSyncReadEquals(kMsg2, kLen2); + AssertSyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg2, kLen2); + AssertSyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, SingleAsyncRead) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read + MockRead(SYNCHRONOUS, 0, 1), // EOF + }; + + Initialize(reads, arraysize(reads), NULL, 0); + + AssertAsyncReadEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, MultipleAsyncReads) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read + MockRead(ASYNC, kMsg2, kLen2, 1), // Async Read + MockRead(ASYNC, kMsg3, kLen3, 2), // Async Read + MockRead(ASYNC, kMsg3, kLen3, 3), // Async Read + MockRead(ASYNC, kMsg2, kLen2, 4), // Async Read + MockRead(ASYNC, kMsg3, kLen3, 5), // Async Read + MockRead(ASYNC, kMsg1, kLen1, 6), // Async Read + MockRead(SYNCHRONOUS, 0, 7), // EOF + }; + + Initialize(reads, arraysize(reads), NULL, 0); + + AssertAsyncReadEquals(kMsg1, kLen1); + AssertAsyncReadEquals(kMsg2, kLen2); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg2, kLen2); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, MixedReads) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + MockRead(ASYNC, kMsg2, kLen2, 1), // Async Read + MockRead(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Read + MockRead(ASYNC, kMsg3, kLen3, 3), // Async Read + MockRead(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Read + MockRead(ASYNC, kMsg3, kLen3, 5), // Async Read + MockRead(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Read + MockRead(SYNCHRONOUS, 0, 7), // EOF + }; + + Initialize(reads, arraysize(reads), NULL, 0); + + data_->StopAfter(1); + AssertSyncReadEquals(kMsg1, kLen1); + AssertAsyncReadEquals(kMsg2, kLen2); + data_->StopAfter(1); + AssertSyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg3, kLen3); + data_->StopAfter(1); + AssertSyncReadEquals(kMsg2, kLen2); + AssertAsyncReadEquals(kMsg3, kLen3); + data_->StopAfter(1); + AssertSyncReadEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, SyncReadFromCompletionCallback) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read + MockRead(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Read + }; + + Initialize(reads, arraysize(reads), NULL, 0); + + data_->StopAfter(2); + + scoped_refptr<IOBuffer> read_buf(new IOBuffer(kLen1)); + ASSERT_EQ(ERR_IO_PENDING, + sock_->Read( + read_buf.get(), + kLen1, + base::Bind(&DeterministicSocketDataTest::ReentrantReadCallback, + base::Unretained(this), + kLen2))); + data_->Run(); +} + +// ----------- Write + +TEST_F(DeterministicSocketDataTest, SingleSyncWriteWhileStopped) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + data_->SetStopped(true); + AssertWriteReturns(kMsg1, kLen1, ERR_UNEXPECTED); +} + +TEST_F(DeterministicSocketDataTest, SingleSyncWriteTooEarly) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 1), // Sync Write + }; + + MockRead reads[] = { + MockRead(SYNCHRONOUS, 0, 0) + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + data_->StopAfter(2); + ASSERT_FALSE(data_->stopped()); + AssertWriteReturns(kMsg1, kLen1, ERR_UNEXPECTED); +} + +TEST_F(DeterministicSocketDataTest, SingleSyncWrite) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Write + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + // Make sure we don't stop before we've read all the data + data_->StopAfter(1); + AssertSyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, MultipleSyncWrites) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Write + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Write + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 3), // Sync Write + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Write + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 5), // Sync Write + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Write + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + // Make sure we don't stop before we've read all the data + data_->StopAfter(10); + AssertSyncWriteEquals(kMsg1, kLen1); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, SingleAsyncWrite) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), // Async Write + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + AssertAsyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, MultipleAsyncWrites) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), // Async Write + MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write + MockWrite(ASYNC, kMsg3, kLen3, 2), // Async Write + MockWrite(ASYNC, kMsg3, kLen3, 3), // Async Write + MockWrite(ASYNC, kMsg2, kLen2, 4), // Async Write + MockWrite(ASYNC, kMsg3, kLen3, 5), // Async Write + MockWrite(ASYNC, kMsg1, kLen1, 6), // Async Write + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + AssertAsyncWriteEquals(kMsg1, kLen1); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, MixedWrites) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Write + MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write + MockWrite(ASYNC, kMsg3, kLen3, 3), // Async Write + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Write + MockWrite(ASYNC, kMsg3, kLen3, 5), // Async Write + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Write + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + data_->StopAfter(1); + AssertSyncWriteEquals(kMsg1, kLen1); + AssertAsyncWriteEquals(kMsg2, kLen2); + data_->StopAfter(1); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg3, kLen3); + data_->StopAfter(1); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + data_->StopAfter(1); + AssertSyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(DeterministicSocketDataTest, SyncWriteFromCompletionCallback) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), // Async Write + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Write + }; + + Initialize(NULL, 0, writes, arraysize(writes)); + + data_->StopAfter(2); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + ASSERT_EQ(ERR_IO_PENDING, + sock_->Write( + write_buf.get(), + kLen1, + base::Bind(&DeterministicSocketDataTest::ReentrantWriteCallback, + base::Unretained(this), + kMsg2, + kLen2))); + data_->Run(); +} + +// ----------- Mixed Reads and Writes + +TEST_F(DeterministicSocketDataTest, MixedSyncOperations) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + MockRead(SYNCHRONOUS, kMsg2, kLen2, 3), // Sync Read + MockRead(SYNCHRONOUS, 0, 4), // EOF + }; + + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Write + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + // Make sure we don't stop before we've read/written everything + data_->StopAfter(10); + AssertSyncReadEquals(kMsg1, kLen1); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg2, kLen2); +} + +TEST_F(DeterministicSocketDataTest, MixedAsyncOperations) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), // Sync Read + MockRead(ASYNC, kMsg2, kLen2, 3), // Sync Read + MockRead(ASYNC, 0, 4), // EOF + }; + + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), // Sync Write + MockWrite(ASYNC, kMsg3, kLen3, 2), // Sync Write + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + AssertAsyncReadEquals(kMsg1, kLen1); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg2, kLen2); +} + +TEST_F(DeterministicSocketDataTest, InterleavedAsyncOperations) { + // Order of completion is read, write, write, read + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read + MockRead(ASYNC, kMsg2, kLen2, 3), // Async Read + MockRead(ASYNC, 0, 4), // EOF + }; + + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write + MockWrite(ASYNC, kMsg3, kLen3, 2), // Async Write + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + // Issue the write, which will block until the read completes + AssertWriteReturns(kMsg2, kLen2, ERR_IO_PENDING); + + // Issue the read which will return first + AssertReadReturns(kMsg1, kLen1, ERR_IO_PENDING); + + data_->RunFor(1); + ASSERT_TRUE(read_callback_.have_result()); + ASSERT_EQ(kLen1, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg1, kLen1); + + data_->RunFor(1); + ASSERT_TRUE(write_callback_.have_result()); + ASSERT_EQ(kLen2, write_callback_.WaitForResult()); + + data_->StopAfter(1); + // Issue the read, which will block until the write completes + AssertReadReturns(kMsg2, kLen2, ERR_IO_PENDING); + + // Issue the writes which will return first + AssertWriteReturns(kMsg3, kLen3, ERR_IO_PENDING); + + data_->RunFor(1); + ASSERT_TRUE(write_callback_.have_result()); + ASSERT_EQ(kLen3, write_callback_.WaitForResult()); + + data_->RunFor(1); + ASSERT_TRUE(read_callback_.have_result()); + ASSERT_EQ(kLen2, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg2, kLen2); +} + +TEST_F(DeterministicSocketDataTest, InterleavedMixedOperations) { + // Order of completion is read, write, write, read + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read + MockRead(ASYNC, kMsg2, kLen2, 3), // Async Read + MockRead(SYNCHRONOUS, 0, 4), // EOF + }; + + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + // Issue the write, which will block until the read completes + AssertWriteReturns(kMsg2, kLen2, ERR_IO_PENDING); + + // Issue the writes which will complete immediately + data_->StopAfter(1); + AssertSyncReadEquals(kMsg1, kLen1); + + data_->RunFor(1); + ASSERT_TRUE(write_callback_.have_result()); + ASSERT_EQ(kLen2, write_callback_.WaitForResult()); + + // Issue the read, which will block until the write completes + AssertReadReturns(kMsg2, kLen2, ERR_IO_PENDING); + + // Issue the writes which will complete immediately + data_->StopAfter(1); + AssertSyncWriteEquals(kMsg3, kLen3); + + data_->RunFor(1); + ASSERT_TRUE(read_callback_.have_result()); + ASSERT_EQ(kLen2, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg2, kLen2); +} + +} // namespace net diff --git a/chromium/net/socket/mock_client_socket_pool_manager.cc b/chromium/net/socket/mock_client_socket_pool_manager.cc new file mode 100644 index 00000000000..0496adb9636 --- /dev/null +++ b/chromium/net/socket/mock_client_socket_pool_manager.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/mock_client_socket_pool_manager.h" + +#include "net/http/http_proxy_client_socket_pool.h" +#include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket_pool.h" +#include "net/socket/transport_client_socket_pool.h" + +namespace net { + +MockClientSocketPoolManager::MockClientSocketPoolManager() {} +MockClientSocketPoolManager::~MockClientSocketPoolManager() {} + +void MockClientSocketPoolManager::SetTransportSocketPool( + TransportClientSocketPool* pool) { + transport_socket_pool_.reset(pool); +} + +void MockClientSocketPoolManager::SetSSLSocketPool( + SSLClientSocketPool* pool) { + ssl_socket_pool_.reset(pool); +} + +void MockClientSocketPoolManager::SetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy, + SOCKSClientSocketPool* pool) { + socks_socket_pools_[socks_proxy] = pool; +} + +void MockClientSocketPoolManager::SetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy, + HttpProxyClientSocketPool* pool) { + http_proxy_socket_pools_[http_proxy] = pool; +} + +void MockClientSocketPoolManager::SetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server, + SSLClientSocketPool* pool) { + ssl_socket_pools_for_proxies_[proxy_server] = pool; +} + +void MockClientSocketPoolManager::FlushSocketPoolsWithError(int error) { + NOTIMPLEMENTED(); +} + +void MockClientSocketPoolManager::CloseIdleSockets() { + NOTIMPLEMENTED(); +} + +TransportClientSocketPool* +MockClientSocketPoolManager::GetTransportSocketPool() { + return transport_socket_pool_.get(); +} + +SSLClientSocketPool* MockClientSocketPoolManager::GetSSLSocketPool() { + return ssl_socket_pool_.get(); +} + +SOCKSClientSocketPool* MockClientSocketPoolManager::GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) { + SOCKSSocketPoolMap::const_iterator it = socks_socket_pools_.find(socks_proxy); + if (it != socks_socket_pools_.end()) + return it->second; + return NULL; +} + +HttpProxyClientSocketPool* +MockClientSocketPoolManager::GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) { + HTTPProxySocketPoolMap::const_iterator it = + http_proxy_socket_pools_.find(http_proxy); + if (it != http_proxy_socket_pools_.end()) + return it->second; + return NULL; +} + +SSLClientSocketPool* MockClientSocketPoolManager::GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) { + SSLSocketPoolMap::const_iterator it = + ssl_socket_pools_for_proxies_.find(proxy_server); + if (it != ssl_socket_pools_for_proxies_.end()) + return it->second; + return NULL; +} + +base::Value* MockClientSocketPoolManager::SocketPoolInfoToValue() const { + NOTIMPLEMENTED(); + return NULL; +} + +} // namespace net diff --git a/chromium/net/socket/mock_client_socket_pool_manager.h b/chromium/net/socket/mock_client_socket_pool_manager.h new file mode 100644 index 00000000000..c2c3792a4f6 --- /dev/null +++ b/chromium/net/socket/mock_client_socket_pool_manager.h @@ -0,0 +1,63 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_MOCK_CLIENT_SOCKET_POOL_MANAGER_H_ +#define NET_SOCKET_MOCK_CLIENT_SOCKET_POOL_MANAGER_H_ + +#include "base/basictypes.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/client_socket_pool_manager_impl.h" + +namespace net { + +class MockClientSocketPoolManager : public ClientSocketPoolManager { + public: + MockClientSocketPoolManager(); + virtual ~MockClientSocketPoolManager(); + + // Sets "override" socket pools that get used instead. + void SetTransportSocketPool(TransportClientSocketPool* pool); + void SetSSLSocketPool(SSLClientSocketPool* pool); + void SetSocketPoolForSOCKSProxy(const HostPortPair& socks_proxy, + SOCKSClientSocketPool* pool); + void SetSocketPoolForHTTPProxy(const HostPortPair& http_proxy, + HttpProxyClientSocketPool* pool); + void SetSocketPoolForSSLWithProxy(const HostPortPair& proxy_server, + SSLClientSocketPool* pool); + + // ClientSocketPoolManager methods: + virtual void FlushSocketPoolsWithError(int error) OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; + virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE; + virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE; + virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) OVERRIDE; + virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) OVERRIDE; + virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) OVERRIDE; + virtual base::Value* SocketPoolInfoToValue() const OVERRIDE; + + private: + typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*> + TransportSocketPoolMap; + typedef internal::OwnedPoolMap<HostPortPair, SOCKSClientSocketPool*> + SOCKSSocketPoolMap; + typedef internal::OwnedPoolMap<HostPortPair, HttpProxyClientSocketPool*> + HTTPProxySocketPoolMap; + typedef internal::OwnedPoolMap<HostPortPair, SSLClientSocketPool*> + SSLSocketPoolMap; + + scoped_ptr<TransportClientSocketPool> transport_socket_pool_; + scoped_ptr<SSLClientSocketPool> ssl_socket_pool_; + SOCKSSocketPoolMap socks_socket_pools_; + HTTPProxySocketPoolMap http_proxy_socket_pools_; + SSLSocketPoolMap ssl_socket_pools_for_proxies_; + + DISALLOW_COPY_AND_ASSIGN(MockClientSocketPoolManager); +}; + +} // namespace net + +#endif // NET_SOCKET_MOCK_CLIENT_SOCKET_POOL_MANAGER_H_ diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h new file mode 100644 index 00000000000..0bd307a3778 --- /dev/null +++ b/chromium/net/socket/next_proto.h @@ -0,0 +1,39 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_NEXT_PROTO_H_ +#define NET_SOCKET_NEXT_PROTO_H_ + +namespace net { + +// Next Protocol Negotiation (NPN), if successful, results in agreement on an +// application-level string that specifies the application level protocol to +// use over the TLS connection. NextProto enumerates the application level +// protocols that we recognise. +enum NextProto { + kProtoUnknown = 0, + kProtoHTTP11 = 1, + kProtoMinimumVersion = kProtoHTTP11, + + // TODO(akalin): Stop advertising SPDY/1 and remove this. + kProtoSPDY1 = 2, + kProtoSPDYMinimumVersion = kProtoSPDY1, + kProtoSPDY2 = 3, + // TODO(akalin): Stop adverising SPDY/2.1, too. + kProtoSPDY21 = 4, + kProtoSPDY3 = 5, + kProtoSPDY31 = 6, + kProtoSPDY4a2 = 7, + // We lump in HTTP/2 with the SPDY protocols for now. + kProtoHTTP2Draft04 = 8, + kProtoSPDYMaximumVersion = kProtoHTTP2Draft04, + + kProtoQUIC1SPDY3 = 9, + + kProtoMaximumVersion = kProtoQUIC1SPDY3, +}; + +} // namespace net + +#endif // NET_SOCKET_NEXT_PROTO_H_ diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc new file mode 100644 index 00000000000..ae037b17c1e --- /dev/null +++ b/chromium/net/socket/nss_ssl_util.cc @@ -0,0 +1,276 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/nss_ssl_util.h" + +#include <nss.h> +#include <secerr.h> +#include <ssl.h> +#include <sslerr.h> +#include <sslproto.h> + +#include <string> + +#include "base/bind.h" +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/memory/singleton.h" +#include "base/threading/thread_restrictions.h" +#include "base/values.h" +#include "build/build_config.h" +#include "crypto/nss_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" + +#if defined(OS_WIN) +#include "base/win/windows_version.h" +#endif + +namespace net { + +class NSSSSLInitSingleton { + public: + NSSSSLInitSingleton() { + crypto::EnsureNSSInit(); + + NSS_SetDomesticPolicy(); + + const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers(); + const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers(); + + // Disable ECDSA cipher suites on platforms that do not support ECDSA + // signed certificates, as servers may use the presence of such + // ciphersuites as a hint to send an ECDSA certificate. + bool disableECDSA = false; +#if defined(OS_WIN) + if (base::win::GetVersion() < base::win::VERSION_VISTA) + disableECDSA = true; +#endif + + // Explicitly enable exactly those ciphers with keys of at least 80 bits + for (int i = 0; i < num_ciphers; i++) { + SSLCipherSuiteInfo info; + if (SSL_GetCipherSuiteInfo(ssl_ciphers[i], &info, + sizeof(info)) == SECSuccess) { + bool enabled = info.effectiveKeyBits >= 80; + if (info.authAlgorithm == ssl_auth_ecdsa && disableECDSA) + enabled = false; + + // Trim the list of cipher suites in order to keep the size of the + // ClientHello down. DSS, ECDH, CAMELLIA, SEED, ECC+3DES, and + // HMAC-SHA256 cipher suites are disabled. + if (info.symCipher == ssl_calg_camellia || + info.symCipher == ssl_calg_seed || + (info.symCipher == ssl_calg_3des && info.keaType != ssl_kea_rsa) || + info.authAlgorithm == ssl_auth_dsa || + info.macAlgorithm == ssl_hmac_sha256 || + info.nonStandard || + strcmp(info.keaTypeName, "ECDH") == 0) { + enabled = false; + } + + if (ssl_ciphers[i] == TLS_DHE_DSS_WITH_AES_128_CBC_SHA) { + // Enabled to allow servers with only a DSA certificate to function. + enabled = true; + } + SSL_CipherPrefSetDefault(ssl_ciphers[i], enabled); + } + } + + // Enable SSL. + SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE); + + // All other SSL options are set per-session by SSLClientSocket and + // SSLServerSocket. + } + + ~NSSSSLInitSingleton() { + // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY. + SSL_ClearSessionCache(); + } +}; + +static base::LazyInstance<NSSSSLInitSingleton> g_nss_ssl_init_singleton = + LAZY_INSTANCE_INITIALIZER; + +// Initialize the NSS SSL library if it isn't already initialized. This must +// be called before any other NSS SSL functions. This function is +// thread-safe, and the NSS SSL library will only ever be initialized once. +// The NSS SSL library will be properly shut down on program exit. +void EnsureNSSSSLInit() { + // Initializing SSL causes us to do blocking IO. + // Temporarily allow it until we fix + // http://code.google.com/p/chromium/issues/detail?id=59847 + base::ThreadRestrictions::ScopedAllowIO allow_io; + + g_nss_ssl_init_singleton.Get(); +} + +// Map a Chromium net error code to an NSS error code. +// See _MD_unix_map_default_error in the NSS source +// tree for inspiration. +PRErrorCode MapErrorToNSS(int result) { + if (result >=0) + return result; + + switch (result) { + case ERR_IO_PENDING: + return PR_WOULD_BLOCK_ERROR; + case ERR_ACCESS_DENIED: + case ERR_NETWORK_ACCESS_DENIED: + // For connect, this could be mapped to PR_ADDRESS_NOT_SUPPORTED_ERROR. + return PR_NO_ACCESS_RIGHTS_ERROR; + case ERR_NOT_IMPLEMENTED: + return PR_NOT_IMPLEMENTED_ERROR; + case ERR_SOCKET_NOT_CONNECTED: + return PR_NOT_CONNECTED_ERROR; + case ERR_INTERNET_DISCONNECTED: // Equivalent to ENETDOWN. + return PR_NETWORK_UNREACHABLE_ERROR; // Best approximation. + case ERR_CONNECTION_TIMED_OUT: + case ERR_TIMED_OUT: + return PR_IO_TIMEOUT_ERROR; + case ERR_CONNECTION_RESET: + return PR_CONNECT_RESET_ERROR; + case ERR_CONNECTION_ABORTED: + return PR_CONNECT_ABORTED_ERROR; + case ERR_CONNECTION_REFUSED: + return PR_CONNECT_REFUSED_ERROR; + case ERR_ADDRESS_UNREACHABLE: + return PR_HOST_UNREACHABLE_ERROR; // Also PR_NETWORK_UNREACHABLE_ERROR. + case ERR_ADDRESS_INVALID: + return PR_ADDRESS_NOT_AVAILABLE_ERROR; + case ERR_NAME_NOT_RESOLVED: + return PR_DIRECTORY_LOOKUP_ERROR; + default: + LOG(WARNING) << "MapErrorToNSS " << result + << " mapped to PR_UNKNOWN_ERROR"; + return PR_UNKNOWN_ERROR; + } +} + +// The default error mapping function. +// Maps an NSS error code to a network error code. +int MapNSSError(PRErrorCode err) { + // TODO(port): fill this out as we learn what's important + switch (err) { + case PR_WOULD_BLOCK_ERROR: + return ERR_IO_PENDING; + case PR_ADDRESS_NOT_SUPPORTED_ERROR: // For connect. + case PR_NO_ACCESS_RIGHTS_ERROR: + return ERR_ACCESS_DENIED; + case PR_IO_TIMEOUT_ERROR: + return ERR_TIMED_OUT; + case PR_CONNECT_RESET_ERROR: + return ERR_CONNECTION_RESET; + case PR_CONNECT_ABORTED_ERROR: + return ERR_CONNECTION_ABORTED; + case PR_CONNECT_REFUSED_ERROR: + return ERR_CONNECTION_REFUSED; + case PR_NOT_CONNECTED_ERROR: + return ERR_SOCKET_NOT_CONNECTED; + case PR_HOST_UNREACHABLE_ERROR: + case PR_NETWORK_UNREACHABLE_ERROR: + return ERR_ADDRESS_UNREACHABLE; + case PR_ADDRESS_NOT_AVAILABLE_ERROR: + return ERR_ADDRESS_INVALID; + case PR_INVALID_ARGUMENT_ERROR: + return ERR_INVALID_ARGUMENT; + case PR_END_OF_FILE_ERROR: + return ERR_CONNECTION_CLOSED; + case PR_NOT_IMPLEMENTED_ERROR: + return ERR_NOT_IMPLEMENTED; + + case SEC_ERROR_LIBRARY_FAILURE: + return ERR_UNEXPECTED; + case SEC_ERROR_INVALID_ARGS: + return ERR_INVALID_ARGUMENT; + case SEC_ERROR_NO_MEMORY: + return ERR_OUT_OF_MEMORY; + case SEC_ERROR_NO_KEY: + return ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY; + case SEC_ERROR_INVALID_KEY: + case SSL_ERROR_SIGN_HASHES_FAILURE: + LOG(ERROR) << "ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED: NSS error " << err + << ", OS error " << PR_GetOSError(); + return ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED; + // A handshake (initial or renegotiation) may fail because some signature + // (for example, the signature in the ServerKeyExchange message for an + // ephemeral Diffie-Hellman cipher suite) is invalid. + case SEC_ERROR_BAD_SIGNATURE: + return ERR_SSL_PROTOCOL_ERROR; + + case SSL_ERROR_SSL_DISABLED: + return ERR_NO_SSL_VERSIONS_ENABLED; + case SSL_ERROR_NO_CYPHER_OVERLAP: + case SSL_ERROR_PROTOCOL_VERSION_ALERT: + case SSL_ERROR_UNSUPPORTED_VERSION: + return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; + case SSL_ERROR_HANDSHAKE_FAILURE_ALERT: + case SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT: + case SSL_ERROR_ILLEGAL_PARAMETER_ALERT: + return ERR_SSL_PROTOCOL_ERROR; + case SSL_ERROR_DECOMPRESSION_FAILURE_ALERT: + return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; + case SSL_ERROR_BAD_MAC_ALERT: + return ERR_SSL_BAD_RECORD_MAC_ALERT; + case SSL_ERROR_DECRYPT_ERROR_ALERT: + return ERR_SSL_DECRYPT_ERROR_ALERT; + case SSL_ERROR_UNSAFE_NEGOTIATION: + return ERR_SSL_UNSAFE_NEGOTIATION; + case SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY: + return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; + case SSL_ERROR_HANDSHAKE_NOT_COMPLETED: + return ERR_SSL_HANDSHAKE_NOT_COMPLETED; + case SEC_ERROR_BAD_KEY: + case SSL_ERROR_EXTRACT_PUBLIC_KEY_FAILURE: + // TODO(wtc): the following errors may also occur in contexts unrelated + // to the peer's public key. We should add new error codes for them, or + // map them to ERR_SSL_BAD_PEER_PUBLIC_KEY only in the right context. + // General unsupported/unknown key algorithm error. + case SEC_ERROR_UNSUPPORTED_KEYALG: + // General DER decoding errors. + case SEC_ERROR_BAD_DER: + case SEC_ERROR_EXTRA_INPUT: + return ERR_SSL_BAD_PEER_PUBLIC_KEY; + + default: { + if (IS_SSL_ERROR(err)) { + LOG(WARNING) << "Unknown SSL error " << err + << " mapped to net::ERR_SSL_PROTOCOL_ERROR"; + return ERR_SSL_PROTOCOL_ERROR; + } + LOG(WARNING) << "Unknown error " << err << " mapped to net::ERR_FAILED"; + return ERR_FAILED; + } + } +} + +// Returns parameters to attach to the NetLog when we receive an error in +// response to a call to an NSS function. Used instead of +// NetLogSSLErrorCallback with events of type TYPE_SSL_NSS_ERROR. +base::Value* NetLogSSLFailedNSSFunctionCallback( + const char* function, + const char* param, + int ssl_lib_error, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("function", function); + if (param[0] != '\0') + dict->SetString("param", param); + dict->SetInteger("ssl_lib_error", ssl_lib_error); + return dict; +} + +void LogFailedNSSFunction(const BoundNetLog& net_log, + const char* function, + const char* param) { + DCHECK(function); + DCHECK(param); + net_log.AddEvent( + NetLog::TYPE_SSL_NSS_ERROR, + base::Bind(&NetLogSSLFailedNSSFunctionCallback, + function, param, PR_GetError())); +} + +} // namespace net diff --git a/chromium/net/socket/nss_ssl_util.h b/chromium/net/socket/nss_ssl_util.h new file mode 100644 index 00000000000..09ae3562cd7 --- /dev/null +++ b/chromium/net/socket/nss_ssl_util.h @@ -0,0 +1,35 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file is only included in ssl_client_socket_nss.cc and +// ssl_server_socket_nss.cc to share common functions of NSS. + +#ifndef NET_SOCKET_NSS_SSL_UTIL_H_ +#define NET_SOCKET_NSS_SSL_UTIL_H_ + +#include <prerror.h> + +#include "net/base/net_export.h" + +namespace net { + +class BoundNetLog; + +// Initalize NSS SSL library. +NET_EXPORT void EnsureNSSSSLInit(); + +// Log a failed NSS funcion call. +void LogFailedNSSFunction(const BoundNetLog& net_log, + const char* function, + const char* param); + +// Map network error code to NSS error code. +PRErrorCode MapErrorToNSS(int result); + +// Map NSS error code to network error code. +int MapNSSError(PRErrorCode err); + +} // namespace net + +#endif // NET_SOCKET_NSS_SSL_UTIL_H_ diff --git a/chromium/net/socket/server_socket.h b/chromium/net/socket/server_socket.h new file mode 100644 index 00000000000..11151eea153 --- /dev/null +++ b/chromium/net/socket/server_socket.h @@ -0,0 +1,40 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SERVER_SOCKET_H_ +#define NET_SOCKET_SERVER_SOCKET_H_ + +#include "base/memory/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" + +namespace net { + +class IPEndPoint; +class StreamSocket; + +class NET_EXPORT ServerSocket { + public: + ServerSocket() { } + virtual ~ServerSocket() { } + + // Bind the socket and start listening. Destroy the socket to stop + // listening. + virtual int Listen(const net::IPEndPoint& address, int backlog) = 0; + + // Gets current address the socket is bound to. + virtual int GetLocalAddress(IPEndPoint* address) const = 0; + + // Accept connection. Callback is called when new connection is + // accepted. + virtual int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(ServerSocket); +}; + +} // namespace net + +#endif // NET_SOCKET_SERVER_SOCKET_H_ diff --git a/chromium/net/socket/socket.h b/chromium/net/socket/socket.h new file mode 100644 index 00000000000..fccb25873a4 --- /dev/null +++ b/chromium/net/socket/socket.h @@ -0,0 +1,62 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKET_H_ +#define NET_SOCKET_SOCKET_H_ + +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" + +namespace net { + +class IOBuffer; + +// Represents a read/write socket. +class NET_EXPORT Socket { + public: + virtual ~Socket() {} + + // Reads data, up to |buf_len| bytes, from the socket. The number of bytes + // read is returned, or an error is returned upon failure. + // ERR_SOCKET_NOT_CONNECTED should be returned if the socket is not currently + // connected. Zero is returned once to indicate end-of-file; the return value + // of subsequent calls is undefined, and may be OS dependent. ERR_IO_PENDING + // is returned if the operation could not be completed synchronously, in which + // case the result will be passed to the callback when available. If the + // operation is not completed immediately, the socket acquires a reference to + // the provided buffer until the callback is invoked or the socket is + // closed. If the socket is Disconnected before the read completes, the + // callback will not be invoked. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; + + // Writes data, up to |buf_len| bytes, to the socket. Note: data may be + // written partially. The number of bytes written is returned, or an error + // is returned upon failure. ERR_SOCKET_NOT_CONNECTED should be returned if + // the socket is not currently connected. The return value when the + // connection is closed is undefined, and may be OS dependent. ERR_IO_PENDING + // is returned if the operation could not be completed synchronously, in which + // case the result will be passed to the callback when available. If the + // operation is not completed immediately, the socket acquires a reference to + // the provided buffer until the callback is invoked or the socket is + // closed. Implementations of this method should not modify the contents + // of the actual buffer that is written to the socket. If the socket is + // Disconnected before the write completes, the callback will not be invoked. + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; + + // Set the receive buffer size (in bytes) for the socket. + // Note: changing this value can affect the TCP window size on some platforms. + // Returns true on success, or false on failure. + virtual bool SetReceiveBufferSize(int32 size) = 0; + + // Set the send buffer size (in bytes) for the socket. + // Note: changing this value can affect the TCP window size on some platforms. + // Returns true on success, or false on failure. + virtual bool SetSendBufferSize(int32 size) = 0; +}; + +} // namespace net + +#endif // NET_SOCKET_SOCKET_H_ diff --git a/chromium/net/socket/socket_net_log_params.cc b/chromium/net/socket/socket_net_log_params.cc new file mode 100644 index 00000000000..bcc12c86bcf --- /dev/null +++ b/chromium/net/socket/socket_net_log_params.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socket_net_log_params.h" + +#include "base/bind.h" +#include "base/values.h" +#include "net/base/host_port_pair.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_util.h" + +namespace net { + +namespace { + +base::Value* NetLogSocketErrorCallback(int net_error, + int os_error, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetInteger("net_error", net_error); + dict->SetInteger("os_error", os_error); + return dict; +} + +base::Value* NetLogHostPortPairCallback(const HostPortPair* host_and_port, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("host_and_port", host_and_port->ToString()); + return dict; +} + +base::Value* NetLogIPEndPointCallback(const IPEndPoint* address, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("address", address->ToString()); + return dict; +} + +base::Value* NetLogSourceAddressCallback(const struct sockaddr* net_address, + socklen_t address_len, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("source_address", + NetAddressToStringWithPort(net_address, address_len)); + return dict; +} + +} // namespace + +NetLog::ParametersCallback CreateNetLogSocketErrorCallback(int net_error, + int os_error) { + return base::Bind(&NetLogSocketErrorCallback, net_error, os_error); +} + +NetLog::ParametersCallback CreateNetLogHostPortPairCallback( + const HostPortPair* host_and_port) { + return base::Bind(&NetLogHostPortPairCallback, host_and_port); +} + +NetLog::ParametersCallback CreateNetLogIPEndPointCallback( + const IPEndPoint* address) { + return base::Bind(&NetLogIPEndPointCallback, address); +} + +NetLog::ParametersCallback CreateNetLogSourceAddressCallback( + const struct sockaddr* net_address, + socklen_t address_len) { + return base::Bind(&NetLogSourceAddressCallback, net_address, address_len); +} + +} // namespace net diff --git a/chromium/net/socket/socket_net_log_params.h b/chromium/net/socket/socket_net_log_params.h new file mode 100644 index 00000000000..f5fe652d125 --- /dev/null +++ b/chromium/net/socket/socket_net_log_params.h @@ -0,0 +1,38 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_ +#define NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_ + +#include "net/base/net_log.h" +#include "net/base/sys_addrinfo.h" + +namespace net { + +class HostPortPair; +class IPEndPoint; + +// Creates a NetLog callback for socket error events. +NetLog::ParametersCallback CreateNetLogSocketErrorCallback(int net_error, + int os_error); + +// Creates a NetLog callback for a HostPortPair. +// |host_and_port| must remain valid for the lifetime of the returned callback. +NetLog::ParametersCallback CreateNetLogHostPortPairCallback( + const HostPortPair* host_and_port); + +// Creates a NetLog callback for an IPEndPoint. +// |address| must remain valid for the lifetime of the returned callback. +NetLog::ParametersCallback CreateNetLogIPEndPointCallback( + const IPEndPoint* address); + +// Creates a NetLog callback for the source sockaddr on connect events. +// |net_address| must remain valid for the lifetime of the returned callback. +NetLog::ParametersCallback CreateNetLogSourceAddressCallback( + const struct sockaddr* net_address, + socklen_t address_len); + +} // namespace net + +#endif // NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_ diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc new file mode 100644 index 00000000000..159f62e42c6 --- /dev/null +++ b/chromium/net/socket/socket_test_util.cc @@ -0,0 +1,1888 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socket_test_util.h" + +#include <algorithm> +#include <vector> + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/compiler_specific.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/time/time.h" +#include "net/base/address_family.h" +#include "net/base/address_list.h" +#include "net/base/auth.h" +#include "net/base/load_timing_info.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket.h" +#include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_info.h" +#include "testing/gtest/include/gtest/gtest.h" + +// Socket events are easier to debug if you log individual reads and writes. +// Enable these if locally debugging, but they are too noisy for the waterfall. +#if 0 +#define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() " +#else +#define NET_TRACE(level, s) EAT_STREAM_PARAMETERS +#endif + +namespace net { + +namespace { + +inline char AsciifyHigh(char x) { + char nybble = static_cast<char>((x >> 4) & 0x0F); + return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10); +} + +inline char AsciifyLow(char x) { + char nybble = static_cast<char>((x >> 0) & 0x0F); + return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10); +} + +inline char Asciify(char x) { + if ((x < 0) || !isprint(x)) + return '.'; + return x; +} + +void DumpData(const char* data, int data_len) { + if (logging::LOG_INFO < logging::GetMinLogLevel()) + return; + DVLOG(1) << "Length: " << data_len; + const char* pfx = "Data: "; + if (!data || (data_len <= 0)) { + DVLOG(1) << pfx << "<None>"; + } else { + int i; + for (i = 0; i <= (data_len - 4); i += 4) { + DVLOG(1) << pfx + << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) + << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) + << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2]) + << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) + << " '" + << Asciify(data[i + 0]) + << Asciify(data[i + 1]) + << Asciify(data[i + 2]) + << Asciify(data[i + 3]) + << "'"; + pfx = " "; + } + // Take care of any 'trailing' bytes, if data_len was not a multiple of 4. + switch (data_len - i) { + case 3: + DVLOG(1) << pfx + << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) + << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) + << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2]) + << " '" + << Asciify(data[i + 0]) + << Asciify(data[i + 1]) + << Asciify(data[i + 2]) + << " '"; + break; + case 2: + DVLOG(1) << pfx + << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) + << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) + << " '" + << Asciify(data[i + 0]) + << Asciify(data[i + 1]) + << " '"; + break; + case 1: + DVLOG(1) << pfx + << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) + << " '" + << Asciify(data[i + 0]) + << " '"; + break; + } + } +} + +template <MockReadWriteType type> +void DumpMockReadWrite(const MockReadWrite<type>& r) { + if (logging::LOG_INFO < logging::GetMinLogLevel()) + return; + DVLOG(1) << "Async: " << (r.mode == ASYNC) + << "\nResult: " << r.result; + DumpData(r.data, r.data_len); + const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : ""; + DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop + << "\nTime: " << r.time_stamp.ToInternalValue(); +} + +} // namespace + +MockConnect::MockConnect() : mode(ASYNC), result(OK) { + IPAddressNumber ip; + CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); + peer_addr = IPEndPoint(ip, 0); +} + +MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) { + IPAddressNumber ip; + CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); + peer_addr = IPEndPoint(ip, 0); +} + +MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) : + mode(io_mode), + result(r), + peer_addr(addr) { +} + +MockConnect::~MockConnect() {} + +StaticSocketDataProvider::StaticSocketDataProvider() + : reads_(NULL), + read_index_(0), + read_count_(0), + writes_(NULL), + write_index_(0), + write_count_(0) { +} + +StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) + : reads_(reads), + read_index_(0), + read_count_(reads_count), + writes_(writes), + write_index_(0), + write_count_(writes_count) { +} + +StaticSocketDataProvider::~StaticSocketDataProvider() {} + +const MockRead& StaticSocketDataProvider::PeekRead() const { + DCHECK(!at_read_eof()); + return reads_[read_index_]; +} + +const MockWrite& StaticSocketDataProvider::PeekWrite() const { + DCHECK(!at_write_eof()); + return writes_[write_index_]; +} + +const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const { + DCHECK_LT(index, read_count_); + return reads_[index]; +} + +const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const { + DCHECK_LT(index, write_count_); + return writes_[index]; +} + +MockRead StaticSocketDataProvider::GetNextRead() { + DCHECK(!at_read_eof()); + reads_[read_index_].time_stamp = base::Time::Now(); + return reads_[read_index_++]; +} + +MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { + if (!writes_) { + // Not using mock writes; succeed synchronously. + return MockWriteResult(SYNCHRONOUS, data.length()); + } + DCHECK(!at_write_eof()); + + // Check that what we are writing matches the expectation. + // Then give the mocked return value. + MockWrite* w = &writes_[write_index_++]; + w->time_stamp = base::Time::Now(); + int result = w->result; + if (w->data) { + // Note - we can simulate a partial write here. If the expected data + // is a match, but shorter than the write actually written, that is legal. + // Example: + // Application writes "foobarbaz" (9 bytes) + // Expected write was "foo" (3 bytes) + // This is a success, and we return 3 to the application. + std::string expected_data(w->data, w->data_len); + EXPECT_GE(data.length(), expected_data.length()); + std::string actual_data(data.substr(0, w->data_len)); + EXPECT_EQ(expected_data, actual_data); + if (expected_data != actual_data) + return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); + if (result == OK) + result = w->data_len; + } + return MockWriteResult(w->mode, result); +} + +void StaticSocketDataProvider::Reset() { + read_index_ = 0; + write_index_ = 0; +} + +DynamicSocketDataProvider::DynamicSocketDataProvider() + : short_read_limit_(0), + allow_unconsumed_reads_(false) { +} + +DynamicSocketDataProvider::~DynamicSocketDataProvider() {} + +MockRead DynamicSocketDataProvider::GetNextRead() { + if (reads_.empty()) + return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); + MockRead result = reads_.front(); + if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) { + reads_.pop_front(); + } else { + result.data_len = short_read_limit_; + reads_.front().data += result.data_len; + reads_.front().data_len -= result.data_len; + } + return result; +} + +void DynamicSocketDataProvider::Reset() { + reads_.clear(); +} + +void DynamicSocketDataProvider::SimulateRead(const char* data, + const size_t length) { + if (!allow_unconsumed_reads_) { + EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data; + } + reads_.push_back(MockRead(ASYNC, data, length)); +} + +SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result) + : connect(mode, result), + next_proto_status(SSLClientSocket::kNextProtoUnsupported), + was_npn_negotiated(false), + protocol_negotiated(kProtoUnknown), + client_cert_sent(false), + cert_request_info(NULL), + channel_id_sent(false) { +} + +SSLSocketDataProvider::~SSLSocketDataProvider() { +} + +void SSLSocketDataProvider::SetNextProto(NextProto proto) { + was_npn_negotiated = true; + next_proto_status = SSLClientSocket::kNextProtoNegotiated; + protocol_negotiated = proto; + next_proto = SSLClientSocket::NextProtoToString(proto); +} + +DelayedSocketData::DelayedSocketData( + int write_delay, MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count) + : StaticSocketDataProvider(reads, reads_count, writes, writes_count), + write_delay_(write_delay), + read_in_progress_(false), + weak_factory_(this) { + DCHECK_GE(write_delay_, 0); +} + +DelayedSocketData::DelayedSocketData( + const MockConnect& connect, int write_delay, MockRead* reads, + size_t reads_count, MockWrite* writes, size_t writes_count) + : StaticSocketDataProvider(reads, reads_count, writes, writes_count), + write_delay_(write_delay), + read_in_progress_(false), + weak_factory_(this) { + DCHECK_GE(write_delay_, 0); + set_connect_data(connect); +} + +DelayedSocketData::~DelayedSocketData() { +} + +void DelayedSocketData::ForceNextRead() { + DCHECK(read_in_progress_); + write_delay_ = 0; + CompleteRead(); +} + +MockRead DelayedSocketData::GetNextRead() { + MockRead out = MockRead(ASYNC, ERR_IO_PENDING); + if (write_delay_ <= 0) + out = StaticSocketDataProvider::GetNextRead(); + read_in_progress_ = (out.result == ERR_IO_PENDING); + return out; +} + +MockWriteResult DelayedSocketData::OnWrite(const std::string& data) { + MockWriteResult rv = StaticSocketDataProvider::OnWrite(data); + // Now that our write has completed, we can allow reads to continue. + if (!--write_delay_ && read_in_progress_) + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&DelayedSocketData::CompleteRead, + weak_factory_.GetWeakPtr()), + base::TimeDelta::FromMilliseconds(100)); + return rv; +} + +void DelayedSocketData::Reset() { + set_socket(NULL); + read_in_progress_ = false; + weak_factory_.InvalidateWeakPtrs(); + StaticSocketDataProvider::Reset(); +} + +void DelayedSocketData::CompleteRead() { + if (socket() && read_in_progress_) + socket()->OnReadComplete(GetNextRead()); +} + +OrderedSocketData::OrderedSocketData( + MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) + : StaticSocketDataProvider(reads, reads_count, writes, writes_count), + sequence_number_(0), loop_stop_stage_(0), + blocked_(false), weak_factory_(this) { +} + +OrderedSocketData::OrderedSocketData( + const MockConnect& connect, + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count) + : StaticSocketDataProvider(reads, reads_count, writes, writes_count), + sequence_number_(0), loop_stop_stage_(0), + blocked_(false), weak_factory_(this) { + set_connect_data(connect); +} + +void OrderedSocketData::EndLoop() { + // If we've already stopped the loop, don't do it again until we've advanced + // to the next sequence_number. + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; + if (loop_stop_stage_ > 0) { + const MockRead& next_read = StaticSocketDataProvider::PeekRead(); + if ((next_read.sequence_number & ~MockRead::STOPLOOP) > + loop_stop_stage_) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Clearing stop index"; + loop_stop_stage_ = 0; + } else { + return; + } + } + // Record the sequence_number at which we stopped the loop. + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Posting Quit at read " << read_index(); + loop_stop_stage_ = sequence_number_; +} + +MockRead OrderedSocketData::GetNextRead() { + weak_factory_.InvalidateWeakPtrs(); + blocked_ = false; + const MockRead& next_read = StaticSocketDataProvider::PeekRead(); + if (next_read.sequence_number & MockRead::STOPLOOP) + EndLoop(); + if ((next_read.sequence_number & ~MockRead::STOPLOOP) <= + sequence_number_++) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 + << ": Read " << read_index(); + DumpMockReadWrite(next_read); + blocked_ = (next_read.result == ERR_IO_PENDING); + return StaticSocketDataProvider::GetNextRead(); + } + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 + << ": I/O Pending"; + MockRead result = MockRead(ASYNC, ERR_IO_PENDING); + DumpMockReadWrite(result); + blocked_ = true; + return result; +} + +MockWriteResult OrderedSocketData::OnWrite(const std::string& data) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Write " << write_index(); + DumpMockReadWrite(PeekWrite()); + ++sequence_number_; + if (blocked_) { + // TODO(willchan): This 100ms delay seems to work around some weirdness. We + // should probably fix the weirdness. One example is in SpdyStream, + // DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the + // SYN_REPLY causes OnResponseReceived() to get called before + // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED(). + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&OrderedSocketData::CompleteRead, + weak_factory_.GetWeakPtr()), + base::TimeDelta::FromMilliseconds(100)); + } + return StaticSocketDataProvider::OnWrite(data); +} + +void OrderedSocketData::Reset() { + NET_TRACE(INFO, " *** ") << "Stage " + << sequence_number_ << ": Reset()"; + sequence_number_ = 0; + loop_stop_stage_ = 0; + set_socket(NULL); + weak_factory_.InvalidateWeakPtrs(); + StaticSocketDataProvider::Reset(); +} + +void OrderedSocketData::CompleteRead() { + if (socket() && blocked_) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_; + socket()->OnReadComplete(GetNextRead()); + } +} + +OrderedSocketData::~OrderedSocketData() {} + +DeterministicSocketData::DeterministicSocketData(MockRead* reads, + size_t reads_count, MockWrite* writes, size_t writes_count) + : StaticSocketDataProvider(reads, reads_count, writes, writes_count), + sequence_number_(0), + current_read_(), + current_write_(), + stopping_sequence_number_(0), + stopped_(false), + print_debug_(false), + is_running_(false) { + VerifyCorrectSequenceNumbers(reads, reads_count, writes, writes_count); +} + +DeterministicSocketData::~DeterministicSocketData() {} + +void DeterministicSocketData::Run() { + DCHECK(!is_running_); + is_running_ = true; + + SetStopped(false); + int counter = 0; + // Continue to consume data until all data has run out, or the stopped_ flag + // has been set. Consuming data requires two separate operations -- running + // the tasks in the message loop, and explicitly invoking the read/write + // callbacks (simulating network I/O). We check our conditions between each, + // since they can change in either. + while ((!at_write_eof() || !at_read_eof()) && !stopped()) { + if (counter % 2 == 0) + base::RunLoop().RunUntilIdle(); + if (counter % 2 == 1) { + InvokeCallbacks(); + } + counter++; + } + // We're done consuming new data, but it is possible there are still some + // pending callbacks which we expect to complete before returning. + while (delegate_.get() && + (delegate_->WritePending() || delegate_->ReadPending()) && + !stopped()) { + InvokeCallbacks(); + base::RunLoop().RunUntilIdle(); + } + SetStopped(false); + is_running_ = false; +} + +void DeterministicSocketData::RunFor(int steps) { + StopAfter(steps); + Run(); +} + +void DeterministicSocketData::SetStop(int seq) { + DCHECK_LT(sequence_number_, seq); + stopping_sequence_number_ = seq; + stopped_ = false; +} + +void DeterministicSocketData::StopAfter(int seq) { + SetStop(sequence_number_ + seq); +} + +MockRead DeterministicSocketData::GetNextRead() { + current_read_ = StaticSocketDataProvider::PeekRead(); + + // Synchronous read while stopped is an error + if (stopped() && current_read_.mode == SYNCHRONOUS) { + LOG(ERROR) << "Unable to perform synchronous IO while stopped"; + return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); + } + + // Async read which will be called back in a future step. + if (sequence_number_ < current_read_.sequence_number) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": I/O Pending"; + MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING); + if (current_read_.mode == SYNCHRONOUS) { + LOG(ERROR) << "Unable to perform synchronous read: " + << current_read_.sequence_number + << " at stage: " << sequence_number_; + result = MockRead(SYNCHRONOUS, ERR_UNEXPECTED); + } + if (print_debug_) + DumpMockReadWrite(result); + return result; + } + + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Read " << read_index(); + if (print_debug_) + DumpMockReadWrite(current_read_); + + // Increment the sequence number if IO is complete + if (current_read_.mode == SYNCHRONOUS) + NextStep(); + + DCHECK_NE(ERR_IO_PENDING, current_read_.result); + StaticSocketDataProvider::GetNextRead(); + + return current_read_; +} + +MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { + const MockWrite& next_write = StaticSocketDataProvider::PeekWrite(); + current_write_ = next_write; + + // Synchronous write while stopped is an error + if (stopped() && next_write.mode == SYNCHRONOUS) { + LOG(ERROR) << "Unable to perform synchronous IO while stopped"; + return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); + } + + // Async write which will be called back in a future step. + if (sequence_number_ < next_write.sequence_number) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": I/O Pending"; + if (next_write.mode == SYNCHRONOUS) { + LOG(ERROR) << "Unable to perform synchronous write: " + << next_write.sequence_number << " at stage: " << sequence_number_; + return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); + } + } else { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Write " << write_index(); + } + + if (print_debug_) + DumpMockReadWrite(next_write); + + // Move to the next step if I/O is synchronous, since the operation will + // complete when this method returns. + if (next_write.mode == SYNCHRONOUS) + NextStep(); + + // This is either a sync write for this step, or an async write. + return StaticSocketDataProvider::OnWrite(data); +} + +void DeterministicSocketData::Reset() { + NET_TRACE(INFO, " *** ") << "Stage " + << sequence_number_ << ": Reset()"; + sequence_number_ = 0; + StaticSocketDataProvider::Reset(); + NOTREACHED(); +} + +void DeterministicSocketData::InvokeCallbacks() { + if (delegate_.get() && delegate_->WritePending() && + (current_write().sequence_number == sequence_number())) { + NextStep(); + delegate_->CompleteWrite(); + return; + } + if (delegate_.get() && delegate_->ReadPending() && + (current_read().sequence_number == sequence_number())) { + NextStep(); + delegate_->CompleteRead(); + return; + } +} + +void DeterministicSocketData::NextStep() { + // Invariant: Can never move *past* the stopping step. + DCHECK_LT(sequence_number_, stopping_sequence_number_); + sequence_number_++; + if (sequence_number_ == stopping_sequence_number_) + SetStopped(true); +} + +void DeterministicSocketData::VerifyCorrectSequenceNumbers( + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count) { + size_t read = 0; + size_t write = 0; + int expected = 0; + while (read < reads_count || write < writes_count) { + // Check to see that we have a read or write at the expected + // state. + if (read < reads_count && reads[read].sequence_number == expected) { + ++read; + ++expected; + continue; + } + if (write < writes_count && writes[write].sequence_number == expected) { + ++write; + ++expected; + continue; + } + NOTREACHED() << "Missing sequence number: " << expected; + return; + } + DCHECK_EQ(read, reads_count); + DCHECK_EQ(write, writes_count); +} + +MockClientSocketFactory::MockClientSocketFactory() {} + +MockClientSocketFactory::~MockClientSocketFactory() {} + +void MockClientSocketFactory::AddSocketDataProvider( + SocketDataProvider* data) { + mock_data_.Add(data); +} + +void MockClientSocketFactory::AddSSLSocketDataProvider( + SSLSocketDataProvider* data) { + mock_ssl_data_.Add(data); +} + +void MockClientSocketFactory::ResetNextMockIndexes() { + mock_data_.ResetNextIndex(); + mock_ssl_data_.ResetNextIndex(); +} + +scoped_ptr<DatagramClientSocket> +MockClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const net::NetLog::Source& source) { + SocketDataProvider* data_provider = mock_data_.GetNext(); + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); +} + +scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( + const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) { + SocketDataProvider* data_provider = mock_data_.GetNext(); + scoped_ptr<MockTCPClientSocket> socket( + new MockTCPClientSocket(addresses, net_log, data_provider)); + data_provider->set_socket(socket.get()); + return socket.PassAs<StreamSocket>(); +} + +scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) { + return scoped_ptr<SSLClientSocket>( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); +} + +void MockClientSocketFactory::ClearSSLSessionCache() { +} + +const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ"; + +MockClientSocket::MockClientSocket(const BoundNetLog& net_log) + : weak_factory_(this), + connected_(false), + net_log_(net_log) { + IPAddressNumber ip; + CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); + peer_addr_ = IPEndPoint(ip, 0); +} + +bool MockClientSocket::SetReceiveBufferSize(int32 size) { + return true; +} + +bool MockClientSocket::SetSendBufferSize(int32 size) { + return true; +} + +void MockClientSocket::Disconnect() { + connected_ = false; +} + +bool MockClientSocket::IsConnected() const { + return connected_; +} + +bool MockClientSocket::IsConnectedAndIdle() const { + return connected_; +} + +int MockClientSocket::GetPeerAddress(IPEndPoint* address) const { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = peer_addr_; + return OK; +} + +int MockClientSocket::GetLocalAddress(IPEndPoint* address) const { + IPAddressNumber ip; + bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); + CHECK(rv); + *address = IPEndPoint(ip, 123); + return OK; +} + +const BoundNetLog& MockClientSocket::NetLog() const { + return net_log_; +} + +void MockClientSocket::GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) { +} + +int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) { + memset(out, 'A', outlen); + return OK; +} + +int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) { + out->assign(MockClientSocket::kTlsUnique); + return OK; +} + +ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const { + NOTREACHED(); + return NULL; +} + +SSLClientSocket::NextProtoStatus +MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) { + proto->clear(); + server_protos->clear(); + return SSLClientSocket::kNextProtoUnsupported; +} + +MockClientSocket::~MockClientSocket() {} + +void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback, + int result) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MockClientSocket::RunCallback, + weak_factory_.GetWeakPtr(), + callback, + result)); +} + +void MockClientSocket::RunCallback(const net::CompletionCallback& callback, + int result) { + if (!callback.is_null()) + callback.Run(result); +} + +MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses, + net::NetLog* net_log, + SocketDataProvider* data) + : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), + addresses_(addresses), + data_(data), + read_offset_(0), + read_data_(SYNCHRONOUS, ERR_UNEXPECTED), + need_read_data_(true), + peer_closed_connection_(false), + pending_buf_(NULL), + pending_buf_len_(0), + was_used_to_convey_data_(false) { + DCHECK(data_); + peer_addr_ = data->connect_data().peer_addr; + data_->Reset(); +} + +MockTCPClientSocket::~MockTCPClientSocket() {} + +int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (!connected_) + return ERR_UNEXPECTED; + + // If the buffer is already in use, a read is already in progress! + DCHECK(pending_buf_ == NULL); + + // Store our async IO data. + pending_buf_ = buf; + pending_buf_len_ = buf_len; + pending_callback_ = callback; + + if (need_read_data_) { + read_data_ = data_->GetNextRead(); + if (read_data_.result == ERR_CONNECTION_CLOSED) { + // This MockRead is just a marker to instruct us to set + // peer_closed_connection_. + peer_closed_connection_ = true; + } + if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { + // This MockRead is just a marker to instruct us to set + // peer_closed_connection_. Skip it and get the next one. + read_data_ = data_->GetNextRead(); + peer_closed_connection_ = true; + } + // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility + // to complete the async IO manually later (via OnReadComplete). + if (read_data_.result == ERR_IO_PENDING) { + // We need to be using async IO in this case. + DCHECK(!callback.is_null()); + return ERR_IO_PENDING; + } + need_read_data_ = false; + } + + return CompleteRead(); +} + +int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + MockWriteResult write_result = data_->OnWrite(data); + + was_used_to_convey_data_ = true; + + if (write_result.mode == ASYNC) { + RunCallbackAsync(callback, write_result.result); + return ERR_IO_PENDING; + } + + return write_result.result; +} + +int MockTCPClientSocket::Connect(const CompletionCallback& callback) { + if (connected_) + return OK; + connected_ = true; + peer_closed_connection_ = false; + if (data_->connect_data().mode == ASYNC) { + if (data_->connect_data().result == ERR_IO_PENDING) + pending_callback_ = callback; + else + RunCallbackAsync(callback, data_->connect_data().result); + return ERR_IO_PENDING; + } + return data_->connect_data().result; +} + +void MockTCPClientSocket::Disconnect() { + MockClientSocket::Disconnect(); + pending_callback_.Reset(); +} + +bool MockTCPClientSocket::IsConnected() const { + return connected_ && !peer_closed_connection_; +} + +bool MockTCPClientSocket::IsConnectedAndIdle() const { + return IsConnected(); +} + +int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const { + if (addresses_.empty()) + return MockClientSocket::GetPeerAddress(address); + + *address = addresses_[0]; + return OK; +} + +bool MockTCPClientSocket::WasEverUsed() const { + return was_used_to_convey_data_; +} + +bool MockTCPClientSocket::UsingTCPFastOpen() const { + return false; +} + +bool MockTCPClientSocket::WasNpnNegotiated() const { + return false; +} + +bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + +void MockTCPClientSocket::OnReadComplete(const MockRead& data) { + // There must be a read pending. + DCHECK(pending_buf_); + // You can't complete a read with another ERR_IO_PENDING status code. + DCHECK_NE(ERR_IO_PENDING, data.result); + // Since we've been waiting for data, need_read_data_ should be true. + DCHECK(need_read_data_); + + read_data_ = data; + need_read_data_ = false; + + // The caller is simulating that this IO completes right now. Don't + // let CompleteRead() schedule a callback. + read_data_.mode = SYNCHRONOUS; + + CompletionCallback callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); +} + +void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) { + CompletionCallback callback = pending_callback_; + RunCallback(callback, data.result); +} + +int MockTCPClientSocket::CompleteRead() { + DCHECK(pending_buf_); + DCHECK(pending_buf_len_ > 0); + + was_used_to_convey_data_ = true; + + // Save the pending async IO data and reset our |pending_| state. + IOBuffer* buf = pending_buf_; + int buf_len = pending_buf_len_; + CompletionCallback callback = pending_callback_; + pending_buf_ = NULL; + pending_buf_len_ = 0; + pending_callback_.Reset(); + + int result = read_data_.result; + DCHECK(result != ERR_IO_PENDING); + + if (read_data_.data) { + if (read_data_.data_len - read_offset_ > 0) { + result = std::min(buf_len, read_data_.data_len - read_offset_); + memcpy(buf->data(), read_data_.data + read_offset_, result); + read_offset_ += result; + if (read_offset_ == read_data_.data_len) { + need_read_data_ = true; + read_offset_ = 0; + } + } else { + result = 0; // EOF + } + } + + if (read_data_.mode == ASYNC) { + DCHECK(!callback.is_null()); + RunCallbackAsync(callback, result); + return ERR_IO_PENDING; + } + return result; +} + +DeterministicSocketHelper::DeterministicSocketHelper( + net::NetLog* net_log, + DeterministicSocketData* data) + : write_pending_(false), + write_result_(0), + read_data_(), + read_buf_(NULL), + read_buf_len_(0), + read_pending_(false), + data_(data), + was_used_to_convey_data_(false), + peer_closed_connection_(false), + net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) { +} + +DeterministicSocketHelper::~DeterministicSocketHelper() {} + +void DeterministicSocketHelper::CompleteWrite() { + was_used_to_convey_data_ = true; + write_pending_ = false; + write_callback_.Run(write_result_); +} + +int DeterministicSocketHelper::CompleteRead() { + DCHECK_GT(read_buf_len_, 0); + DCHECK_LE(read_data_.data_len, read_buf_len_); + DCHECK(read_buf_); + + was_used_to_convey_data_ = true; + + if (read_data_.result == ERR_IO_PENDING) + read_data_ = data_->GetNextRead(); + DCHECK_NE(ERR_IO_PENDING, read_data_.result); + // If read_data_.mode is ASYNC, we do not need to wait, since this is already + // the callback. Therefore we don't even bother to check it. + int result = read_data_.result; + + if (read_data_.data_len > 0) { + DCHECK(read_data_.data); + result = std::min(read_buf_len_, read_data_.data_len); + memcpy(read_buf_->data(), read_data_.data, result); + } + + if (read_pending_) { + read_pending_ = false; + read_callback_.Run(result); + } + + return result; +} + +int DeterministicSocketHelper::Write( + IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + std::string data(buf->data(), buf_len); + MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.mode == ASYNC) { + write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(!write_callback_.is_null()); + write_pending_ = true; + return ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + write_pending_ = false; + return write_result.result; +} + +int DeterministicSocketHelper::Read( + IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + + read_data_ = data_->GetNextRead(); + // The buffer should always be big enough to contain all the MockRead data. To + // use small buffers, split the data into multiple MockReads. + DCHECK_LE(read_data_.data_len, buf_len); + + if (read_data_.result == ERR_CONNECTION_CLOSED) { + // This MockRead is just a marker to instruct us to set + // peer_closed_connection_. + peer_closed_connection_ = true; + } + if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { + // This MockRead is just a marker to instruct us to set + // peer_closed_connection_. Skip it and get the next one. + read_data_ = data_->GetNextRead(); + peer_closed_connection_ = true; + } + + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + + if (read_data_.mode == ASYNC || (read_data_.result == ERR_IO_PENDING)) { + read_pending_ = true; + DCHECK(!read_callback_.is_null()); + return ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + return CompleteRead(); +} + +DeterministicMockUDPClientSocket::DeterministicMockUDPClientSocket( + net::NetLog* net_log, + DeterministicSocketData* data) + : connected_(false), + helper_(net_log, data) { +} + +DeterministicMockUDPClientSocket::~DeterministicMockUDPClientSocket() {} + +bool DeterministicMockUDPClientSocket::WritePending() const { + return helper_.write_pending(); +} + +bool DeterministicMockUDPClientSocket::ReadPending() const { + return helper_.read_pending(); +} + +void DeterministicMockUDPClientSocket::CompleteWrite() { + helper_.CompleteWrite(); +} + +int DeterministicMockUDPClientSocket::CompleteRead() { + return helper_.CompleteRead(); +} + +int DeterministicMockUDPClientSocket::Connect(const IPEndPoint& address) { + if (connected_) + return OK; + connected_ = true; + peer_address_ = address; + return helper_.data()->connect_data().result; +}; + +int DeterministicMockUDPClientSocket::Write( + IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (!connected_) + return ERR_UNEXPECTED; + + return helper_.Write(buf, buf_len, callback); +} + +int DeterministicMockUDPClientSocket::Read( + IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (!connected_) + return ERR_UNEXPECTED; + + return helper_.Read(buf, buf_len, callback); +} + +bool DeterministicMockUDPClientSocket::SetReceiveBufferSize(int32 size) { + return true; +} + +bool DeterministicMockUDPClientSocket::SetSendBufferSize(int32 size) { + return true; +} + +void DeterministicMockUDPClientSocket::Close() { + connected_ = false; +} + +int DeterministicMockUDPClientSocket::GetPeerAddress( + IPEndPoint* address) const { + *address = peer_address_; + return OK; +} + +int DeterministicMockUDPClientSocket::GetLocalAddress( + IPEndPoint* address) const { + IPAddressNumber ip; + bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); + CHECK(rv); + *address = IPEndPoint(ip, 123); + return OK; +} + +const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const { + return helper_.net_log(); +} + +void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {} + +void DeterministicMockUDPClientSocket::OnConnectComplete( + const MockConnect& data) { + NOTIMPLEMENTED(); +} + +DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( + net::NetLog* net_log, + DeterministicSocketData* data) + : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), + helper_(net_log, data) { + peer_addr_ = data->connect_data().peer_addr; +} + +DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} + +bool DeterministicMockTCPClientSocket::WritePending() const { + return helper_.write_pending(); +} + +bool DeterministicMockTCPClientSocket::ReadPending() const { + return helper_.read_pending(); +} + +void DeterministicMockTCPClientSocket::CompleteWrite() { + helper_.CompleteWrite(); +} + +int DeterministicMockTCPClientSocket::CompleteRead() { + return helper_.CompleteRead(); +} + +int DeterministicMockTCPClientSocket::Write( + IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (!connected_) + return ERR_UNEXPECTED; + + return helper_.Write(buf, buf_len, callback); +} + +int DeterministicMockTCPClientSocket::Read( + IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (!connected_) + return ERR_UNEXPECTED; + + return helper_.Read(buf, buf_len, callback); +} + +// TODO(erikchen): Support connect sequencing. +int DeterministicMockTCPClientSocket::Connect( + const CompletionCallback& callback) { + if (connected_) + return OK; + connected_ = true; + if (helper_.data()->connect_data().mode == ASYNC) { + RunCallbackAsync(callback, helper_.data()->connect_data().result); + return ERR_IO_PENDING; + } + return helper_.data()->connect_data().result; +} + +void DeterministicMockTCPClientSocket::Disconnect() { + MockClientSocket::Disconnect(); +} + +bool DeterministicMockTCPClientSocket::IsConnected() const { + return connected_ && !helper_.peer_closed_connection(); +} + +bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const { + return IsConnected(); +} + +bool DeterministicMockTCPClientSocket::WasEverUsed() const { + return helper_.was_used_to_convey_data(); +} + +bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const { + return false; +} + +bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const { + return false; +} + +bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + +void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} + +void DeterministicMockTCPClientSocket::OnConnectComplete( + const MockConnect& data) {} + +// static +void MockSSLClientSocket::ConnectCallback( + MockSSLClientSocket* ssl_client_socket, + const CompletionCallback& callback, + int rv) { + if (rv == OK) + ssl_client_socket->connected_ = true; + callback.Run(rv); +} + +MockSSLClientSocket::MockSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_port_pair, + const SSLConfig& ssl_config, + SSLSocketDataProvider* data) + : MockClientSocket( + // Have to use the right BoundNetLog for LoadTimingInfo regression + // tests. + transport_socket->socket()->NetLog()), + transport_(transport_socket.Pass()), + data_(data), + is_npn_state_set_(false), + new_npn_value_(false), + is_protocol_negotiated_set_(false), + protocol_negotiated_(kProtoUnknown) { + DCHECK(data_); + peer_addr_ = data->connect.peer_addr; +} + +MockSSLClientSocket::~MockSSLClientSocket() { + Disconnect(); +} + +int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return transport_->socket()->Read(buf, buf_len, callback); +} + +int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + return transport_->socket()->Write(buf, buf_len, callback); +} + +int MockSSLClientSocket::Connect(const CompletionCallback& callback) { + int rv = transport_->socket()->Connect( + base::Bind(&ConnectCallback, base::Unretained(this), callback)); + if (rv == OK) { + if (data_->connect.result == OK) + connected_ = true; + if (data_->connect.mode == ASYNC) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + return data_->connect.result; + } + return rv; +} + +void MockSSLClientSocket::Disconnect() { + MockClientSocket::Disconnect(); + if (transport_->socket() != NULL) + transport_->socket()->Disconnect(); +} + +bool MockSSLClientSocket::IsConnected() const { + return transport_->socket()->IsConnected(); +} + +bool MockSSLClientSocket::WasEverUsed() const { + return transport_->socket()->WasEverUsed(); +} + +bool MockSSLClientSocket::UsingTCPFastOpen() const { + return transport_->socket()->UsingTCPFastOpen(); +} + +int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->socket()->GetPeerAddress(address); +} + +bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + ssl_info->cert = data_->cert; + ssl_info->client_cert_sent = data_->client_cert_sent; + ssl_info->channel_id_sent = data_->channel_id_sent; + return true; +} + +void MockSSLClientSocket::GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) { + DCHECK(cert_request_info); + if (data_->cert_request_info) { + cert_request_info->host_and_port = + data_->cert_request_info->host_and_port; + cert_request_info->client_certs = data_->cert_request_info->client_certs; + } else { + cert_request_info->Reset(); + } +} + +SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( + std::string* proto, std::string* server_protos) { + *proto = data_->next_proto; + *server_protos = data_->server_protos; + return data_->next_proto_status; +} + +bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { + is_npn_state_set_ = true; + return new_npn_value_ = negotiated; +} + +bool MockSSLClientSocket::WasNpnNegotiated() const { + if (is_npn_state_set_) + return new_npn_value_; + return data_->was_npn_negotiated; +} + +NextProto MockSSLClientSocket::GetNegotiatedProtocol() const { + if (is_protocol_negotiated_set_) + return protocol_negotiated_; + return data_->protocol_negotiated; +} + +void MockSSLClientSocket::set_protocol_negotiated( + NextProto protocol_negotiated) { + is_protocol_negotiated_set_ = true; + protocol_negotiated_ = protocol_negotiated; +} + +bool MockSSLClientSocket::WasChannelIDSent() const { + return data_->channel_id_sent; +} + +void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) { + data_->channel_id_sent = channel_id_sent; +} + +ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const { + return data_->server_bound_cert_service; +} + +void MockSSLClientSocket::OnReadComplete(const MockRead& data) { + NOTIMPLEMENTED(); +} + +void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { + NOTIMPLEMENTED(); +} + +MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, + net::NetLog* net_log) + : connected_(false), + data_(data), + read_offset_(0), + read_data_(SYNCHRONOUS, ERR_UNEXPECTED), + need_read_data_(true), + pending_buf_(NULL), + pending_buf_len_(0), + net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), + weak_factory_(this) { + DCHECK(data_); + data_->Reset(); + peer_addr_ = data->connect_data().peer_addr; +} + +MockUDPClientSocket::~MockUDPClientSocket() {} + +int MockUDPClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (!connected_) + return ERR_UNEXPECTED; + + // If the buffer is already in use, a read is already in progress! + DCHECK(pending_buf_ == NULL); + + // Store our async IO data. + pending_buf_ = buf; + pending_buf_len_ = buf_len; + pending_callback_ = callback; + + if (need_read_data_) { + read_data_ = data_->GetNextRead(); + // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility + // to complete the async IO manually later (via OnReadComplete). + if (read_data_.result == ERR_IO_PENDING) { + // We need to be using async IO in this case. + DCHECK(!callback.is_null()); + return ERR_IO_PENDING; + } + need_read_data_ = false; + } + + return CompleteRead(); +} + +int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.mode == ASYNC) { + RunCallbackAsync(callback, write_result.result); + return ERR_IO_PENDING; + } + return write_result.result; +} + +bool MockUDPClientSocket::SetReceiveBufferSize(int32 size) { + return true; +} + +bool MockUDPClientSocket::SetSendBufferSize(int32 size) { + return true; +} + +void MockUDPClientSocket::Close() { + connected_ = false; +} + +int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const { + *address = peer_addr_; + return OK; +} + +int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const { + IPAddressNumber ip; + bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); + CHECK(rv); + *address = IPEndPoint(ip, 123); + return OK; +} + +const BoundNetLog& MockUDPClientSocket::NetLog() const { + return net_log_; +} + +int MockUDPClientSocket::Connect(const IPEndPoint& address) { + connected_ = true; + peer_addr_ = address; + return OK; +} + +void MockUDPClientSocket::OnReadComplete(const MockRead& data) { + // There must be a read pending. + DCHECK(pending_buf_); + // You can't complete a read with another ERR_IO_PENDING status code. + DCHECK_NE(ERR_IO_PENDING, data.result); + // Since we've been waiting for data, need_read_data_ should be true. + DCHECK(need_read_data_); + + read_data_ = data; + need_read_data_ = false; + + // The caller is simulating that this IO completes right now. Don't + // let CompleteRead() schedule a callback. + read_data_.mode = SYNCHRONOUS; + + net::CompletionCallback callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); +} + +void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) { + NOTIMPLEMENTED(); +} + +int MockUDPClientSocket::CompleteRead() { + DCHECK(pending_buf_); + DCHECK(pending_buf_len_ > 0); + + // Save the pending async IO data and reset our |pending_| state. + IOBuffer* buf = pending_buf_; + int buf_len = pending_buf_len_; + CompletionCallback callback = pending_callback_; + pending_buf_ = NULL; + pending_buf_len_ = 0; + pending_callback_.Reset(); + + int result = read_data_.result; + DCHECK(result != ERR_IO_PENDING); + + if (read_data_.data) { + if (read_data_.data_len - read_offset_ > 0) { + result = std::min(buf_len, read_data_.data_len - read_offset_); + memcpy(buf->data(), read_data_.data + read_offset_, result); + read_offset_ += result; + if (read_offset_ == read_data_.data_len) { + need_read_data_ = true; + read_offset_ = 0; + } + } else { + result = 0; // EOF + } + } + + if (read_data_.mode == ASYNC) { + DCHECK(!callback.is_null()); + RunCallbackAsync(callback, result); + return ERR_IO_PENDING; + } + return result; +} + +void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback, + int result) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MockUDPClientSocket::RunCallback, + weak_factory_.GetWeakPtr(), + callback, + result)); +} + +void MockUDPClientSocket::RunCallback(const CompletionCallback& callback, + int result) { + if (!callback.is_null()) + callback.Run(result); +} + +TestSocketRequest::TestSocketRequest( + std::vector<TestSocketRequest*>* request_order, size_t* completion_count) + : request_order_(request_order), + completion_count_(completion_count), + callback_(base::Bind(&TestSocketRequest::OnComplete, + base::Unretained(this))) { + DCHECK(request_order); + DCHECK(completion_count); +} + +TestSocketRequest::~TestSocketRequest() { +} + +void TestSocketRequest::OnComplete(int result) { + SetResult(result); + (*completion_count_)++; + request_order_->push_back(this); +} + +// static +const int ClientSocketPoolTest::kIndexOutOfBounds = -1; + +// static +const int ClientSocketPoolTest::kRequestNotFound = -2; + +ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {} +ClientSocketPoolTest::~ClientSocketPoolTest() {} + +int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const { + index--; + if (index >= requests_.size()) + return kIndexOutOfBounds; + + for (size_t i = 0; i < request_order_.size(); i++) + if (requests_[index] == request_order_[i]) + return i + 1; + + return kRequestNotFound; +} + +bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) { + ScopedVector<TestSocketRequest>::iterator i; + for (i = requests_.begin(); i != requests_.end(); ++i) { + if ((*i)->handle()->is_initialized()) { + if (keep_alive == NO_KEEP_ALIVE) + (*i)->handle()->socket()->Disconnect(); + (*i)->handle()->Reset(); + base::RunLoop().RunUntilIdle(); + return true; + } + } + return false; +} + +void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { + bool released_one; + do { + released_one = ReleaseOneConnection(keep_alive); + } while (released_one); +} + +MockTransportClientSocketPool::MockConnectJob::MockConnectJob( + scoped_ptr<StreamSocket> socket, + ClientSocketHandle* handle, + const CompletionCallback& callback) + : socket_(socket.Pass()), + handle_(handle), + user_callback_(callback) { +} + +MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {} + +int MockTransportClientSocketPool::MockConnectJob::Connect() { + int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect, + base::Unretained(this))); + if (rv == OK) { + user_callback_.Reset(); + OnConnect(OK); + } + return rv; +} + +bool MockTransportClientSocketPool::MockConnectJob::CancelHandle( + const ClientSocketHandle* handle) { + if (handle != handle_) + return false; + socket_.reset(); + handle_ = NULL; + user_callback_.Reset(); + return true; +} + +void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { + if (!socket_.get()) + return; + if (rv == OK) { + handle_->SetSocket(socket_.Pass()); + + // Needed for socket pool tests that layer other sockets on top of mock + // sockets. + LoadTimingInfo::ConnectTiming connect_timing; + base::TimeTicks now = base::TimeTicks::Now(); + connect_timing.dns_start = now; + connect_timing.dns_end = now; + connect_timing.connect_start = now; + connect_timing.connect_end = now; + handle_->set_connect_timing(connect_timing); + } else { + socket_.reset(); + } + + handle_ = NULL; + + if (!user_callback_.is_null()) { + CompletionCallback callback = user_callback_; + user_callback_.Reset(); + callback.Run(rv); + } +} + +MockTransportClientSocketPool::MockTransportClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + ClientSocketFactory* socket_factory) + : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms, + NULL, NULL, NULL), + client_socket_factory_(socket_factory), + release_count_(0), + cancel_count_(0) { +} + +MockTransportClientSocketPool::~MockTransportClientSocketPool() {} + +int MockTransportClientSocketPool::RequestSocket( + const std::string& group_name, const void* socket_params, + RequestPriority priority, ClientSocketHandle* handle, + const CompletionCallback& callback, const BoundNetLog& net_log) { + scoped_ptr<StreamSocket> socket = + client_socket_factory_->CreateTransportClientSocket( + AddressList(), net_log.net_log(), net::NetLog::Source()); + MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback); + job_list_.push_back(job); + handle->set_pool_id(1); + return job->Connect(); +} + +void MockTransportClientSocketPool::CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) { + std::vector<MockConnectJob*>::iterator i; + for (i = job_list_.begin(); i != job_list_.end(); ++i) { + if ((*i)->CancelHandle(handle)) { + cancel_count_++; + break; + } + } +} + +void MockTransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + EXPECT_EQ(1, id); + release_count_++; +} + +DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} + +DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {} + +void DeterministicMockClientSocketFactory::AddSocketDataProvider( + DeterministicSocketData* data) { + mock_data_.Add(data); +} + +void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider( + SSLSocketDataProvider* data) { + mock_ssl_data_.Add(data); +} + +void DeterministicMockClientSocketFactory::ResetNextMockIndexes() { + mock_data_.ResetNextIndex(); + mock_ssl_data_.ResetNextIndex(); +} + +MockSSLClientSocket* DeterministicMockClientSocketFactory:: + GetMockSSLClientSocket(size_t index) const { + DCHECK_LT(index, ssl_client_sockets_.size()); + return ssl_client_sockets_[index]; +} + +scoped_ptr<DatagramClientSocket> +DeterministicMockClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const NetLog::Source& source) { + DeterministicSocketData* data_provider = mock_data().GetNext(); + scoped_ptr<DeterministicMockUDPClientSocket> socket( + new DeterministicMockUDPClientSocket(net_log, data_provider)); + data_provider->set_delegate(socket->AsWeakPtr()); + udp_client_sockets().push_back(socket.get()); + return socket.PassAs<DatagramClientSocket>(); +} + +scoped_ptr<StreamSocket> +DeterministicMockClientSocketFactory::CreateTransportClientSocket( + const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) { + DeterministicSocketData* data_provider = mock_data().GetNext(); + scoped_ptr<DeterministicMockTCPClientSocket> socket( + new DeterministicMockTCPClientSocket(net_log, data_provider)); + data_provider->set_delegate(socket->AsWeakPtr()); + tcp_client_sockets().push_back(socket.get()); + return socket.PassAs<StreamSocket>(); +} + +scoped_ptr<SSLClientSocket> +DeterministicMockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) { + scoped_ptr<MockSSLClientSocket> socket( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); + ssl_client_sockets_.push_back(socket.get()); + return socket.PassAs<SSLClientSocket>(); +} + +void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { +} + +MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + TransportClientSocketPool* transport_pool) + : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms, + NULL, transport_pool, NULL), + transport_pool_(transport_pool) { +} + +MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} + +int MockSOCKSClientSocketPool::RequestSocket( + const std::string& group_name, const void* socket_params, + RequestPriority priority, ClientSocketHandle* handle, + const CompletionCallback& callback, const BoundNetLog& net_log) { + return transport_pool_->RequestSocket( + group_name, socket_params, priority, handle, callback, net_log); +} + +void MockSOCKSClientSocketPool::CancelRequest( + const std::string& group_name, + ClientSocketHandle* handle) { + return transport_pool_->CancelRequest(group_name, handle); +} + +void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id); +} + +const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; +const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); + +const char kSOCKS5GreetResponse[] = { 0x05, 0x00 }; +const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse); + +const char kSOCKS5OkRequest[] = + { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 }; +const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest); + +const char kSOCKS5OkResponse[] = + { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 }; +const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse); + +} // namespace net diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h new file mode 100644 index 00000000000..a888249654c --- /dev/null +++ b/chromium/net/socket/socket_test_util.h @@ -0,0 +1,1198 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ +#define NET_SOCKET_SOCKET_TEST_UTIL_H_ + +#include <cstring> +#include <deque> +#include <string> +#include <vector> + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/logging.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/memory/weak_ptr.h" +#include "base/strings/string16.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/test_completion_callback.h" +#include "net/http/http_auth_controller.h" +#include "net/http/http_proxy_client_socket_pool.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/ssl_client_socket_pool.h" +#include "net/socket/transport_client_socket_pool.h" +#include "net/ssl/ssl_config_service.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +enum { + // A private network error code used by the socket test utility classes. + // If the |result| member of a MockRead is + // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a + // marker that indicates the peer will close the connection after the next + // MockRead. The other members of that MockRead are ignored. + ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000, +}; + +class AsyncSocket; +class MockClientSocket; +class ServerBoundCertService; +class SSLClientSocket; +class StreamSocket; + +enum IoMode { + ASYNC, + SYNCHRONOUS +}; + +struct MockConnect { + // Asynchronous connection success. + // Creates a MockConnect with |mode| ASYC, |result| OK, and + // |peer_addr| 192.0.2.33. + MockConnect(); + // Creates a MockConnect with the specified mode and result, with + // |peer_addr| 192.0.2.33. + MockConnect(IoMode io_mode, int r); + MockConnect(IoMode io_mode, int r, IPEndPoint addr); + ~MockConnect(); + + IoMode mode; + int result; + IPEndPoint peer_addr; +}; + +// MockRead and MockWrite shares the same interface and members, but we'd like +// to have distinct types because we don't want to have them used +// interchangably. To do this, a struct template is defined, and MockRead and +// MockWrite are instantiated by using this template. Template parameter |type| +// is not used in the struct definition (it purely exists for creating a new +// type). +// +// |data| in MockRead and MockWrite has different meanings: |data| in MockRead +// is the data returned from the socket when MockTCPClientSocket::Read() is +// attempted, while |data| in MockWrite is the expected data that should be +// given in MockTCPClientSocket::Write(). +enum MockReadWriteType { + MOCK_READ, + MOCK_WRITE +}; + +template <MockReadWriteType type> +struct MockReadWrite { + // Flag to indicate that the message loop should be terminated. + enum { + STOPLOOP = 1 << 31 + }; + + // Default + MockReadWrite() : mode(SYNCHRONOUS), result(0), data(NULL), data_len(0), + sequence_number(0), time_stamp(base::Time::Now()) {} + + // Read/write failure (no data). + MockReadWrite(IoMode io_mode, int result) : mode(io_mode), result(result), + data(NULL), data_len(0), sequence_number(0), + time_stamp(base::Time::Now()) { } + + // Read/write failure (no data), with sequence information. + MockReadWrite(IoMode io_mode, int result, int seq) : mode(io_mode), + result(result), data(NULL), data_len(0), sequence_number(seq), + time_stamp(base::Time::Now()) { } + + // Asynchronous read/write success (inferred data length). + explicit MockReadWrite(const char* data) : mode(ASYNC), result(0), + data(data), data_len(strlen(data)), sequence_number(0), + time_stamp(base::Time::Now()) { } + + // Read/write success (inferred data length). + MockReadWrite(IoMode io_mode, const char* data) : mode(io_mode), result(0), + data(data), data_len(strlen(data)), sequence_number(0), + time_stamp(base::Time::Now()) { } + + // Read/write success. + MockReadWrite(IoMode io_mode, const char* data, int data_len) : mode(io_mode), + result(0), data(data), data_len(data_len), sequence_number(0), + time_stamp(base::Time::Now()) { } + + // Read/write success (inferred data length) with sequence information. + MockReadWrite(IoMode io_mode, int seq, const char* data) : mode(io_mode), + result(0), data(data), data_len(strlen(data)), sequence_number(seq), + time_stamp(base::Time::Now()) { } + + // Read/write success with sequence information. + MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) : + mode(io_mode), result(0), data(data), data_len(data_len), + sequence_number(seq), time_stamp(base::Time::Now()) { } + + IoMode mode; + int result; + const char* data; + int data_len; + + // For OrderedSocketData, which only allows reads to occur in a particular + // sequence. If a read occurs before the given |sequence_number| is reached, + // an ERR_IO_PENDING is returned. + int sequence_number; // The sequence number at which a read is allowed + // to occur. + base::Time time_stamp; // The time stamp at which the operation occurred. +}; + +typedef MockReadWrite<MOCK_READ> MockRead; +typedef MockReadWrite<MOCK_WRITE> MockWrite; + +struct MockWriteResult { + MockWriteResult(IoMode io_mode, int result) + : mode(io_mode), + result(result) {} + + IoMode mode; + int result; +}; + +// The SocketDataProvider is an interface used by the MockClientSocket +// for getting data about individual reads and writes on the socket. +class SocketDataProvider { + public: + SocketDataProvider() : socket_(NULL) {} + + virtual ~SocketDataProvider() {} + + // Returns the buffer and result code for the next simulated read. + // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller + // that it will be called via the AsyncSocket::OnReadComplete() + // function at a later time. + virtual MockRead GetNextRead() = 0; + virtual MockWriteResult OnWrite(const std::string& data) = 0; + virtual void Reset() = 0; + + // Accessor for the socket which is using the SocketDataProvider. + AsyncSocket* socket() { return socket_; } + void set_socket(AsyncSocket* socket) { socket_ = socket; } + + MockConnect connect_data() const { return connect_; } + void set_connect_data(const MockConnect& connect) { connect_ = connect; } + + private: + MockConnect connect_; + AsyncSocket* socket_; + + DISALLOW_COPY_AND_ASSIGN(SocketDataProvider); +}; + +// The AsyncSocket is an interface used by the SocketDataProvider to +// complete the asynchronous read operation. +class AsyncSocket { + public: + // If an async IO is pending because the SocketDataProvider returned + // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete + // is called to complete the asynchronous read operation. + // data.async is ignored, and this read is completed synchronously as + // part of this call. + virtual void OnReadComplete(const MockRead& data) = 0; + virtual void OnConnectComplete(const MockConnect& data) = 0; +}; + +// SocketDataProvider which responds based on static tables of mock reads and +// writes. +class StaticSocketDataProvider : public SocketDataProvider { + public: + StaticSocketDataProvider(); + StaticSocketDataProvider(MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + virtual ~StaticSocketDataProvider(); + + // These functions get access to the next available read and write data. + const MockRead& PeekRead() const; + const MockWrite& PeekWrite() const; + // These functions get random access to the read and write data, for timing. + const MockRead& PeekRead(size_t index) const; + const MockWrite& PeekWrite(size_t index) const; + size_t read_index() const { return read_index_; } + size_t write_index() const { return write_index_; } + size_t read_count() const { return read_count_; } + size_t write_count() const { return write_count_; } + + bool at_read_eof() const { return read_index_ >= read_count_; } + bool at_write_eof() const { return write_index_ >= write_count_; } + + virtual void CompleteRead() {} + + // SocketDataProvider implementation. + virtual MockRead GetNextRead() OVERRIDE; + virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; + ; virtual void Reset() OVERRIDE; + + private: + MockRead* reads_; + size_t read_index_; + size_t read_count_; + MockWrite* writes_; + size_t write_index_; + size_t write_count_; + + DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider); +}; + +// SocketDataProvider which can make decisions about next mock reads based on +// received writes. It can also be used to enforce order of operations, for +// example that tested code must send the "Hello!" message before receiving +// response. This is useful for testing conversation-like protocols like FTP. +class DynamicSocketDataProvider : public SocketDataProvider { + public: + DynamicSocketDataProvider(); + virtual ~DynamicSocketDataProvider(); + + int short_read_limit() const { return short_read_limit_; } + void set_short_read_limit(int limit) { short_read_limit_ = limit; } + + void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } + + // SocketDataProvider implementation. + virtual MockRead GetNextRead() OVERRIDE; + virtual MockWriteResult OnWrite(const std::string& data) = 0; + virtual void Reset() OVERRIDE; + + protected: + // The next time there is a read from this socket, it will return |data|. + // Before calling SimulateRead next time, the previous data must be consumed. + void SimulateRead(const char* data, size_t length); + void SimulateRead(const char* data) { + SimulateRead(data, std::strlen(data)); + } + + private: + std::deque<MockRead> reads_; + + // Max number of bytes we will read at a time. 0 means no limit. + int short_read_limit_; + + // If true, we'll not require the client to consume all data before we + // mock the next read. + bool allow_unconsumed_reads_; + + DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider); +}; + +// SSLSocketDataProviders only need to keep track of the return code from calls +// to Connect(). +struct SSLSocketDataProvider { + SSLSocketDataProvider(IoMode mode, int result); + ~SSLSocketDataProvider(); + + void SetNextProto(NextProto proto); + + MockConnect connect; + SSLClientSocket::NextProtoStatus next_proto_status; + std::string next_proto; + std::string server_protos; + bool was_npn_negotiated; + NextProto protocol_negotiated; + bool client_cert_sent; + SSLCertRequestInfo* cert_request_info; + scoped_refptr<X509Certificate> cert; + bool channel_id_sent; + ServerBoundCertService* server_bound_cert_service; +}; + +// A DataProvider where the client must write a request before the reads (e.g. +// the response) will complete. +class DelayedSocketData : public StaticSocketDataProvider { + public: + // |write_delay| the number of MockWrites to complete before allowing + // a MockRead to complete. + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a + // MockRead(true, 0, 0); + DelayedSocketData(int write_delay, + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + // |connect| the result for the connect phase. + // |reads| the list of MockRead completions. + // |write_delay| the number of MockWrites to complete before allowing + // a MockRead to complete. + // |writes| the list of MockWrite completions. + // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a + // MockRead(true, 0, 0); + DelayedSocketData(const MockConnect& connect, int write_delay, + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + virtual ~DelayedSocketData(); + + void ForceNextRead(); + + // StaticSocketDataProvider: + virtual MockRead GetNextRead() OVERRIDE; + virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; + virtual void Reset() OVERRIDE; + virtual void CompleteRead() OVERRIDE; + + private: + int write_delay_; + bool read_in_progress_; + base::WeakPtrFactory<DelayedSocketData> weak_factory_; +}; + +// A DataProvider where the reads are ordered. +// If a read is requested before its sequence number is reached, we return an +// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to +// wait). +// The sequence number is incremented on every read and write operation. +// The message loop may be interrupted by setting the high bit of the sequence +// number in the MockRead's sequence number. When that MockRead is reached, +// we post a Quit message to the loop. This allows us to interrupt the reading +// of data before a complete message has arrived, and provides support for +// testing server push when the request is issued while the response is in the +// middle of being received. +class OrderedSocketData : public StaticSocketDataProvider { + public: + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + // Note: All MockReads and MockWrites must be async. + // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a + // MockRead(true, 0, 0); + OrderedSocketData(MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + virtual ~OrderedSocketData(); + + // |connect| the result for the connect phase. + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + // Note: All MockReads and MockWrites must be async. + // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a + // MockRead(true, 0, 0); + OrderedSocketData(const MockConnect& connect, + MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + // Posts a quit message to the current message loop, if one is running. + void EndLoop(); + + // StaticSocketDataProvider: + virtual MockRead GetNextRead() OVERRIDE; + virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; + virtual void Reset() OVERRIDE; + virtual void CompleteRead() OVERRIDE; + + private: + int sequence_number_; + int loop_stop_stage_; + bool blocked_; + base::WeakPtrFactory<OrderedSocketData> weak_factory_; +}; + +class DeterministicMockTCPClientSocket; + +// This class gives the user full control over the network activity, +// specifically the timing of the COMPLETION of I/O operations. Regardless of +// the order in which I/O operations are initiated, this class ensures that they +// complete in the correct order. +// +// Network activity is modeled as a sequence of numbered steps which is +// incremented whenever an I/O operation completes. This can happen under two +// different circumstances: +// +// 1) Performing a synchronous I/O operation. (Invoking Read() or Write() +// when the corresponding MockRead or MockWrite is marked !async). +// 2) Running the Run() method of this class. The run method will invoke +// the current MessageLoop, running all pending events, and will then +// invoke any pending IO callbacks. +// +// In addition, this class allows for I/O processing to "stop" at a specified +// step, by calling SetStop(int) or StopAfter(int). Initiating an I/O operation +// by calling Read() or Write() while stopped is permitted if the operation is +// asynchronous. It is an error to perform synchronous I/O while stopped. +// +// When creating the MockReads and MockWrites, note that the sequence number +// refers to the number of the step in which the I/O will complete. In the +// case of synchronous I/O, this will be the same step as the I/O is initiated. +// However, in the case of asynchronous I/O, this I/O may be initiated in +// a much earlier step. Furthermore, when the a Read() or Write() is separated +// from its completion by other Read() or Writes()'s, it can not be marked +// synchronous. If it is, ERR_UNUEXPECTED will be returned indicating that a +// synchronous Read() or Write() could not be completed synchronously because of +// the specific ordering constraints. +// +// Sequence numbers are preserved across both reads and writes. There should be +// no gaps in sequence numbers, and no repeated sequence numbers. i.e. +// MockRead reads[] = { +// MockRead(false, "first read", length, 0) // sync +// MockRead(true, "second read", length, 2) // async +// }; +// MockWrite writes[] = { +// MockWrite(true, "first write", length, 1), // async +// MockWrite(false, "second write", length, 3), // sync +// }; +// +// Example control flow: +// Read() is called. The current step is 0. The first available read is +// synchronous, so the call to Read() returns length. The current step is +// now 1. Next, Read() is called again. The next available read can +// not be completed until step 2, so Read() returns ERR_IO_PENDING. The current +// step is still 1. Write is called(). The first available write is able to +// complete in this step, but is marked asynchronous. Write() returns +// ERR_IO_PENDING. The current step is still 1. At this point RunFor(1) is +// called which will cause the write callback to be invoked, and will then +// stop. The current state is now 2. RunFor(1) is called again, which +// causes the read callback to be invoked, and will then stop. Then current +// step is 2. Write() is called again. Then next available write is +// synchronous so the call to Write() returns length. +// +// For examples of how to use this class, see: +// deterministic_socket_data_unittests.cc +class DeterministicSocketData + : public StaticSocketDataProvider { + public: + // The Delegate is an abstract interface which handles the communication from + // the DeterministicSocketData to the Deterministic MockSocket. The + // MockSockets directly store a pointer to the DeterministicSocketData, + // whereas the DeterministicSocketData only stores a pointer to the + // abstract Delegate interface. + class Delegate { + public: + // Returns true if there is currently a write pending. That is to say, if + // an asynchronous write has been started but the callback has not been + // invoked. + virtual bool WritePending() const = 0; + // Returns true if there is currently a read pending. That is to say, if + // an asynchronous read has been started but the callback has not been + // invoked. + virtual bool ReadPending() const = 0; + // Called to complete an asynchronous write to execute the write callback. + virtual void CompleteWrite() = 0; + // Called to complete an asynchronous read to execute the read callback. + virtual int CompleteRead() = 0; + + protected: + virtual ~Delegate() {} + }; + + // |reads| the list of MockRead completions. + // |writes| the list of MockWrite completions. + DeterministicSocketData(MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + virtual ~DeterministicSocketData(); + + // Consume all the data up to the give stop point (via SetStop()). + void Run(); + + // Set the stop point to be |steps| from now, and then invoke Run(). + void RunFor(int steps); + + // Stop at step |seq|, which must be in the future. + virtual void SetStop(int seq); + + // Stop |seq| steps after the current step. + virtual void StopAfter(int seq); + bool stopped() const { return stopped_; } + void SetStopped(bool val) { stopped_ = val; } + MockRead& current_read() { return current_read_; } + MockWrite& current_write() { return current_write_; } + int sequence_number() const { return sequence_number_; } + void set_delegate(base::WeakPtr<Delegate> delegate) { + delegate_ = delegate; + } + + // StaticSocketDataProvider: + + // When the socket calls Read(), that calls GetNextRead(), and expects either + // ERR_IO_PENDING or data. + virtual MockRead GetNextRead() OVERRIDE; + + // When the socket calls Write(), it always completes synchronously. OnWrite() + // checks to make sure the written data matches the expected data. The + // callback will not be invoked until its sequence number is reached. + virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; + virtual void Reset() OVERRIDE; + virtual void CompleteRead() OVERRIDE {} + + private: + // Invoke the read and write callbacks, if the timing is appropriate. + void InvokeCallbacks(); + + void NextStep(); + + void VerifyCorrectSequenceNumbers(MockRead* reads, size_t reads_count, + MockWrite* writes, size_t writes_count); + + int sequence_number_; + MockRead current_read_; + MockWrite current_write_; + int stopping_sequence_number_; + bool stopped_; + base::WeakPtr<Delegate> delegate_; + bool print_debug_; + bool is_running_; +}; + +// Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket +// objects get instantiated, they take their data from the i'th element of this +// array. +template<typename T> +class SocketDataProviderArray { + public: + SocketDataProviderArray() : next_index_(0) {} + + T* GetNext() { + DCHECK_LT(next_index_, data_providers_.size()); + return data_providers_[next_index_++]; + } + + void Add(T* data_provider) { + DCHECK(data_provider); + data_providers_.push_back(data_provider); + } + + size_t next_index() { return next_index_; } + + void ResetNextIndex() { + next_index_ = 0; + } + + private: + // Index of the next |data_providers_| element to use. Not an iterator + // because those are invalidated on vector reallocation. + size_t next_index_; + + // SocketDataProviders to be returned. + std::vector<T*> data_providers_; +}; + +class MockUDPClientSocket; +class MockTCPClientSocket; +class MockSSLClientSocket; + +// ClientSocketFactory which contains arrays of sockets of each type. +// You should first fill the arrays using AddMock{SSL,}Socket. When the factory +// is asked to create a socket, it takes next entry from appropriate array. +// You can use ResetNextMockIndexes to reset that next entry index for all mock +// socket types. +class MockClientSocketFactory : public ClientSocketFactory { + public: + MockClientSocketFactory(); + virtual ~MockClientSocketFactory(); + + void AddSocketDataProvider(SocketDataProvider* socket); + void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); + void ResetNextMockIndexes(); + + SocketDataProviderArray<SocketDataProvider>& mock_data() { + return mock_data_; + } + + // ClientSocketFactory + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE; + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE; + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE; + virtual void ClearSSLSessionCache() OVERRIDE; + + private: + SocketDataProviderArray<SocketDataProvider> mock_data_; + SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; +}; + +class MockClientSocket : public SSLClientSocket { + public: + // Value returned by GetTLSUniqueChannelBinding(). + static const char kTlsUnique[]; + + // The BoundNetLog is needed to test LoadTimingInfo, which uses NetLog IDs as + // unique socket IDs. + explicit MockClientSocket(const BoundNetLog& net_log); + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) = 0; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) = 0; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + + // SSLClientSocket implementation. + virtual void GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) OVERRIDE; + virtual int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) OVERRIDE; + virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + virtual NextProtoStatus GetNextProto(std::string* proto, + std::string* server_protos) OVERRIDE; + virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + + protected: + virtual ~MockClientSocket(); + void RunCallbackAsync(const CompletionCallback& callback, int result); + void RunCallback(const CompletionCallback& callback, int result); + + base::WeakPtrFactory<MockClientSocket> weak_factory_; + + // True if Connect completed successfully and Disconnect hasn't been called. + bool connected_; + + // Address of the "remote" peer we're connected to. + IPEndPoint peer_addr_; + + BoundNetLog net_log_; +}; + +class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { + public: + MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log, + SocketDataProvider* socket); + virtual ~MockTCPClientSocket(); + + const AddressList& addresses() const { return addresses_; } + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // AsyncSocket: + virtual void OnReadComplete(const MockRead& data) OVERRIDE; + virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + + private: + int CompleteRead(); + + AddressList addresses_; + + SocketDataProvider* data_; + int read_offset_; + MockRead read_data_; + bool need_read_data_; + + // True if the peer has closed the connection. This allows us to simulate + // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real + // TCPClientSocket. + bool peer_closed_connection_; + + // While an asynchronous IO is pending, we save our user-buffer state. + IOBuffer* pending_buf_; + int pending_buf_len_; + CompletionCallback pending_callback_; + bool was_used_to_convey_data_; +}; + +// DeterministicSocketHelper is a helper class that can be used +// to simulate net::Socket::Read() and net::Socket::Write() +// using deterministic |data|. +// Note: This is provided as a common helper class because +// of the inheritance hierarchy of DeterministicMock[UDP,TCP]ClientSocket and a +// desire not to introduce an additional common base class. +class DeterministicSocketHelper { + public: + DeterministicSocketHelper(net::NetLog* net_log, + DeterministicSocketData* data); + virtual ~DeterministicSocketHelper(); + + bool write_pending() const { return write_pending_; } + bool read_pending() const { return read_pending_; } + + void CompleteWrite(); + int CompleteRead(); + + int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); + int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); + + const BoundNetLog& net_log() const { return net_log_; } + + bool was_used_to_convey_data() const { return was_used_to_convey_data_; } + + bool peer_closed_connection() const { return peer_closed_connection_; } + + DeterministicSocketData* data() const { return data_; } + + private: + bool write_pending_; + CompletionCallback write_callback_; + int write_result_; + + MockRead read_data_; + + IOBuffer* read_buf_; + int read_buf_len_; + bool read_pending_; + CompletionCallback read_callback_; + DeterministicSocketData* data_; + bool was_used_to_convey_data_; + bool peer_closed_connection_; + BoundNetLog net_log_; +}; + +// Mock UDP socket to be used in conjunction with DeterministicSocketData. +class DeterministicMockUDPClientSocket + : public DatagramClientSocket, + public AsyncSocket, + public DeterministicSocketData::Delegate, + public base::SupportsWeakPtr<DeterministicMockUDPClientSocket> { + public: + DeterministicMockUDPClientSocket(net::NetLog* net_log, + DeterministicSocketData* data); + virtual ~DeterministicMockUDPClientSocket(); + + // DeterministicSocketData::Delegate: + virtual bool WritePending() const OVERRIDE; + virtual bool ReadPending() const OVERRIDE; + virtual void CompleteWrite() OVERRIDE; + virtual int CompleteRead() OVERRIDE; + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + // DatagramSocket implementation. + virtual void Close() OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + + // DatagramClientSocket implementation. + virtual int Connect(const IPEndPoint& address) OVERRIDE; + + // AsyncSocket implementation. + virtual void OnReadComplete(const MockRead& data) OVERRIDE; + virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + + private: + bool connected_; + IPEndPoint peer_address_; + DeterministicSocketHelper helper_; +}; + +// Mock TCP socket to be used in conjunction with DeterministicSocketData. +class DeterministicMockTCPClientSocket + : public MockClientSocket, + public AsyncSocket, + public DeterministicSocketData::Delegate, + public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> { + public: + DeterministicMockTCPClientSocket(net::NetLog* net_log, + DeterministicSocketData* data); + virtual ~DeterministicMockTCPClientSocket(); + + // DeterministicSocketData::Delegate: + virtual bool WritePending() const OVERRIDE; + virtual bool ReadPending() const OVERRIDE; + virtual void CompleteWrite() OVERRIDE; + virtual int CompleteRead() OVERRIDE; + + // Socket: + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + + // StreamSocket: + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // AsyncSocket: + virtual void OnReadComplete(const MockRead& data) OVERRIDE; + virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + + private: + DeterministicSocketHelper helper_; +}; + +class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { + public: + MockSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + SSLSocketDataProvider* socket); + virtual ~MockSSLClientSocket(); + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // SSLClientSocket implementation. + virtual void GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) OVERRIDE; + virtual NextProtoStatus GetNextProto(std::string* proto, + std::string* server_protos) OVERRIDE; + virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE; + virtual void set_protocol_negotiated( + NextProto protocol_negotiated) OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + + // This MockSocket does not implement the manual async IO feature. + virtual void OnReadComplete(const MockRead& data) OVERRIDE; + virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + + virtual bool WasChannelIDSent() const OVERRIDE; + virtual void set_channel_id_sent(bool channel_id_sent) OVERRIDE; + virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + + private: + static void ConnectCallback(MockSSLClientSocket *ssl_client_socket, + const CompletionCallback& callback, + int rv); + + scoped_ptr<ClientSocketHandle> transport_; + SSLSocketDataProvider* data_; + bool is_npn_state_set_; + bool new_npn_value_; + bool is_protocol_negotiated_set_; + NextProto protocol_negotiated_; +}; + +class MockUDPClientSocket + : public DatagramClientSocket, + public AsyncSocket { + public: + MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log); + virtual ~MockUDPClientSocket(); + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + // DatagramSocket implementation. + virtual void Close() OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + + // DatagramClientSocket implementation. + virtual int Connect(const IPEndPoint& address) OVERRIDE; + + // AsyncSocket implementation. + virtual void OnReadComplete(const MockRead& data) OVERRIDE; + virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + + private: + int CompleteRead(); + + void RunCallbackAsync(const CompletionCallback& callback, int result); + void RunCallback(const CompletionCallback& callback, int result); + + bool connected_; + SocketDataProvider* data_; + int read_offset_; + MockRead read_data_; + bool need_read_data_; + + // Address of the "remote" peer we're connected to. + IPEndPoint peer_addr_; + + // While an asynchronous IO is pending, we save our user-buffer state. + IOBuffer* pending_buf_; + int pending_buf_len_; + CompletionCallback pending_callback_; + + BoundNetLog net_log_; + + base::WeakPtrFactory<MockUDPClientSocket> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(MockUDPClientSocket); +}; + +class TestSocketRequest : public TestCompletionCallbackBase { + public: + TestSocketRequest(std::vector<TestSocketRequest*>* request_order, + size_t* completion_count); + virtual ~TestSocketRequest(); + + ClientSocketHandle* handle() { return &handle_; } + + const net::CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result); + + ClientSocketHandle handle_; + std::vector<TestSocketRequest*>* request_order_; + size_t* completion_count_; + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(TestSocketRequest); +}; + +class ClientSocketPoolTest { + public: + enum KeepAlive { + KEEP_ALIVE, + + // A socket will be disconnected in addition to handle being reset. + NO_KEEP_ALIVE, + }; + + static const int kIndexOutOfBounds; + static const int kRequestNotFound; + + ClientSocketPoolTest(); + ~ClientSocketPoolTest(); + + template <typename PoolType, typename SocketParams> + int StartRequestUsingPool(PoolType* socket_pool, + const std::string& group_name, + RequestPriority priority, + const scoped_refptr<SocketParams>& socket_params) { + DCHECK(socket_pool); + TestSocketRequest* request = new TestSocketRequest(&request_order_, + &completion_count_); + requests_.push_back(request); + int rv = request->handle()->Init( + group_name, socket_params, priority, request->callback(), + socket_pool, BoundNetLog()); + if (rv != ERR_IO_PENDING) + request_order_.push_back(request); + return rv; + } + + // Provided there were n requests started, takes |index| in range 1..n + // and returns order in which that request completed, in range 1..n, + // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound + // if that request did not complete (for example was canceled). + int GetOrderOfRequest(size_t index) const; + + // Resets first initialized socket handle from |requests_|. If found such + // a handle, returns true. + bool ReleaseOneConnection(KeepAlive keep_alive); + + // Releases connections until there is nothing to release. + void ReleaseAllConnections(KeepAlive keep_alive); + + // Note that this uses 0-based indices, while GetOrderOfRequest takes and + // returns 0-based indices. + TestSocketRequest* request(int i) { return requests_[i]; } + + size_t requests_size() const { return requests_.size(); } + ScopedVector<TestSocketRequest>* requests() { return &requests_; } + size_t completion_count() const { return completion_count_; } + + private: + ScopedVector<TestSocketRequest> requests_; + std::vector<TestSocketRequest*> request_order_; + size_t completion_count_; +}; + +class MockTransportClientSocketPool : public TransportClientSocketPool { + public: + class MockConnectJob { + public: + MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, + const CompletionCallback& callback); + ~MockConnectJob(); + + int Connect(); + bool CancelHandle(const ClientSocketHandle* handle); + + private: + void OnConnect(int rv); + + scoped_ptr<StreamSocket> socket_; + ClientSocketHandle* handle_; + CompletionCallback user_callback_; + + DISALLOW_COPY_AND_ASSIGN(MockConnectJob); + }; + + MockTransportClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + ClientSocketFactory* socket_factory); + + virtual ~MockTransportClientSocketPool(); + + int release_count() const { return release_count_; } + int cancel_count() const { return cancel_count_; } + + // TransportClientSocketPool implementation. + virtual int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE; + + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE; + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; + + private: + ClientSocketFactory* client_socket_factory_; + ScopedVector<MockConnectJob> job_list_; + int release_count_; + int cancel_count_; + + DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool); +}; + +class DeterministicMockClientSocketFactory : public ClientSocketFactory { + public: + DeterministicMockClientSocketFactory(); + virtual ~DeterministicMockClientSocketFactory(); + + void AddSocketDataProvider(DeterministicSocketData* socket); + void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); + void ResetNextMockIndexes(); + + // Return |index|-th MockSSLClientSocket (starting from 0) that the factory + // created. + MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; + + SocketDataProviderArray<DeterministicSocketData>& mock_data() { + return mock_data_; + } + std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } + std::vector<DeterministicMockUDPClientSocket*>& udp_client_sockets() { + return udp_client_sockets_; + } + + // ClientSocketFactory + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE; + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE; + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE; + virtual void ClearSSLSessionCache() OVERRIDE; + + private: + SocketDataProviderArray<DeterministicSocketData> mock_data_; + SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; + + // Store pointers to handed out sockets in case the test wants to get them. + std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_; + std::vector<DeterministicMockUDPClientSocket*> udp_client_sockets_; + std::vector<MockSSLClientSocket*> ssl_client_sockets_; +}; + +class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { + public: + MockSOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + TransportClientSocketPool* transport_pool); + + virtual ~MockSOCKSClientSocketPool(); + + // SOCKSClientSocketPool implementation. + virtual int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE; + + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE; + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; + + private: + TransportClientSocketPool* const transport_pool_; + + DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool); +}; + +// Constants for a successful SOCKS v5 handshake. +extern const char kSOCKS5GreetRequest[]; +extern const int kSOCKS5GreetRequestLength; + +extern const char kSOCKS5GreetResponse[]; +extern const int kSOCKS5GreetResponseLength; + +extern const char kSOCKS5OkRequest[]; +extern const int kSOCKS5OkRequestLength; + +extern const char kSOCKS5OkResponse[]; +extern const int kSOCKS5OkResponseLength; + +} // namespace net + +#endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc new file mode 100644 index 00000000000..537b584a932 --- /dev/null +++ b/chromium/net/socket/socks5_client_socket.cc @@ -0,0 +1,487 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks5_client_socket.h" + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/debug/trace_event.h" +#include "base/format_macros.h" +#include "base/strings/string_util.h" +#include "base/sys_byteorder.h" +#include "net/base/io_buffer.h" +#include "net/base/net_log.h" +#include "net/base/net_util.h" +#include "net/socket/client_socket_handle.h" + +namespace net { + +const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2; +const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10; +const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5; +const uint8 SOCKS5ClientSocket::kSOCKS5Version = 0x05; +const uint8 SOCKS5ClientSocket::kTunnelCommand = 0x01; +const uint8 SOCKS5ClientSocket::kNullByte = 0x00; + +COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4); +COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6); + +SOCKS5ClientSocket::SOCKS5ClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info) + : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, + base::Unretained(this))), + transport_(transport_socket.Pass()), + next_state_(STATE_NONE), + completed_handshake_(false), + bytes_sent_(0), + bytes_received_(0), + read_header_size(kReadHeaderSize), + host_request_info_(req_info), + net_log_(transport_->socket()->NetLog()) { +} + +SOCKS5ClientSocket::~SOCKS5ClientSocket() { + Disconnect(); +} + +int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(transport_->socket()); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT); + + next_state_ = STATE_GREET_WRITE; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv); + } + return rv; +} + +void SOCKS5ClientSocket::Disconnect() { + completed_handshake_ = false; + transport_->socket()->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool SOCKS5ClientSocket::IsConnected() const { + return completed_handshake_ && transport_->socket()->IsConnected(); +} + +bool SOCKS5ClientSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); +} + +const BoundNetLog& SOCKS5ClientSocket::NetLog() const { + return net_log_; +} + +void SOCKS5ClientSocket::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void SOCKS5ClientSocket::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool SOCKS5ClientSocket::WasEverUsed() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasEverUsed(); + } + NOTREACHED(); + return false; +} + +bool SOCKS5ClientSocket::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->UsingTCPFastOpen(); + } + NOTREACHED(); + return false; +} + +bool SOCKS5ClientSocket::WasNpnNegotiated() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasNpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; + +} + +// Read is called by the transport layer above to read. This can only be done +// if the SOCKS handshake is complete. +int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + return transport_->socket()->Read(buf, buf_len, callback); +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} + +bool SOCKS5ClientSocket::SetReceiveBufferSize(int32 size) { + return transport_->socket()->SetReceiveBufferSize(size); +} + +bool SOCKS5ClientSocket::SetSendBufferSize(int32 size) { + return transport_->socket()->SetSendBufferSize(size); +} + +void SOCKS5ClientSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(!user_callback_.is_null()); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + CompletionCallback c = user_callback_; + user_callback_.Reset(); + c.Run(result); +} + +void SOCKS5ClientSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(NetLog::TYPE_SOCKS5_CONNECT); + DoCallback(rv); + } +} + +int SOCKS5ClientSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_GREET_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_WRITE); + rv = DoGreetWrite(); + break; + case STATE_GREET_WRITE_COMPLETE: + rv = DoGreetWriteComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_WRITE, rv); + break; + case STATE_GREET_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_READ); + rv = DoGreetRead(); + break; + case STATE_GREET_READ_COMPLETE: + rv = DoGreetReadComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_READ, rv); + break; + case STATE_HANDSHAKE_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE); + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE, rv); + break; + case STATE_HANDSHAKE_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_READ); + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLog::TYPE_SOCKS5_HANDSHAKE_READ, rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 }; // no authentication +const char kSOCKS5GreetReadData[] = { 0x05, 0x00 }; + +int SOCKS5ClientSocket::DoGreetWrite() { + // Since we only have 1 byte to send the hostname length in, if the + // URL has a hostname longer than 255 characters we can't send it. + if (0xFF < host_request_info_.hostname().size()) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_HOSTNAME_TOO_BIG); + return ERR_SOCKS_CONNECTION_FAILED; + } + + if (buffer_.empty()) { + buffer_ = std::string(kSOCKS5GreetWriteData, + arraysize(kSOCKS5GreetWriteData)); + bytes_sent_ = 0; + } + + next_state_ = STATE_GREET_WRITE_COMPLETE; + size_t handshake_buf_len = buffer_.size() - bytes_sent_; + handshake_buf_ = new IOBuffer(handshake_buf_len); + memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->socket() + ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_); +} + +int SOCKS5ClientSocket::DoGreetWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + bytes_received_ = 0; + next_state_ = STATE_GREET_READ; + } else { + next_state_ = STATE_GREET_WRITE; + } + return OK; +} + +int SOCKS5ClientSocket::DoGreetRead() { + next_state_ = STATE_GREET_READ_COMPLETE; + size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_; + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->socket() + ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_); +} + +int SOCKS5ClientSocket::DoGreetReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING); + return ERR_SOCKS_CONNECTION_FAILED; + } + + bytes_received_ += result; + buffer_.append(handshake_buf_->data(), result); + if (bytes_received_ < kGreetReadHeaderSize) { + next_state_ = STATE_GREET_READ; + return OK; + } + + // Got the greet data. + if (buffer_[0] != kSOCKS5Version) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION, + NetLog::IntegerCallback("version", buffer_[0])); + return ERR_SOCKS_CONNECTION_FAILED; + } + if (buffer_[1] != 0x00) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_AUTH, + NetLog::IntegerCallback("method", buffer_[1])); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.clear(); + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; +} + +int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake) + const { + DCHECK(handshake->empty()); + + handshake->push_back(kSOCKS5Version); + handshake->push_back(kTunnelCommand); // Connect command + handshake->push_back(kNullByte); // Reserved null + + handshake->push_back(kEndPointDomain); // The type of the address. + + DCHECK_GE(static_cast<size_t>(0xFF), host_request_info_.hostname().size()); + + // First add the size of the hostname, followed by the hostname. + handshake->push_back(static_cast<unsigned char>( + host_request_info_.hostname().size())); + handshake->append(host_request_info_.hostname()); + + uint16 nw_port = base::HostToNet16(host_request_info_.port()); + handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port)); + return OK; +} + +// Writes the SOCKS handshake data to the underlying socket connection. +int SOCKS5ClientSocket::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + if (buffer_.empty()) { + int rv = BuildHandshakeWriteBuffer(&buffer_); + if (rv != OK) + return rv; + bytes_sent_ = 0; + } + + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], + handshake_buf_len); + return transport_->socket() + ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_); +} + +int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + next_state_ = STATE_HANDSHAKE_READ; + buffer_.clear(); + } else if (bytes_sent_ < buffer_.size()) { + next_state_ = STATE_HANDSHAKE_WRITE; + } else { + NOTREACHED(); + } + + return OK; +} + +int SOCKS5ClientSocket::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (buffer_.empty()) { + bytes_received_ = 0; + read_header_size = kReadHeaderSize; + } + + int handshake_buf_len = read_header_size - bytes_received_; + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->socket() + ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_); +} + +int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + + // The underlying socket closed unexpectedly. + if (result == 0) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + bytes_received_ += result; + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (bytes_received_ == kReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION, + NetLog::IntegerCallback("version", buffer_[0])); + return ERR_SOCKS_CONNECTION_FAILED; + } + if (buffer_[1] != 0x00) { + net_log_.AddEvent(NetLog::TYPE_SOCKS_SERVER_ERROR, + NetLog::IntegerCallback("error_code", buffer_[1])); + return ERR_SOCKS_CONNECTION_FAILED; + } + + // We check the type of IP/Domain the server returns and accordingly + // increase the size of the response. For domains, we need to read the + // size of the domain, so the initial request size is upto the domain + // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is + // read, we substract 1 byte from the additional request size. + SocksEndPointAddressType address_type = + static_cast<SocksEndPointAddressType>(buffer_[3]); + if (address_type == kEndPointDomain) + read_header_size += static_cast<uint8>(buffer_[4]); + else if (address_type == kEndPointResolvedIPv4) + read_header_size += sizeof(struct in_addr) - 1; + else if (address_type == kEndPointResolvedIPv6) + read_header_size += sizeof(struct in6_addr) - 1; + else { + net_log_.AddEvent(NetLog::TYPE_SOCKS_UNKNOWN_ADDRESS_TYPE, + NetLog::IntegerCallback("address_type", buffer_[3])); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size += 2; // for the port. + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + // When the final bytes are read, setup handshake. We ignore the rest + // of the response since they represent the SOCKSv5 endpoint and have + // no use when doing a tunnel connection. + if (bytes_received_ == read_header_size) { + completed_handshake_ = true; + buffer_.clear(); + next_state_ = STATE_NONE; + return OK; + } + + next_state_ = STATE_HANDSHAKE_READ; + return OK; +} + +int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->socket()->GetPeerAddress(address); +} + +int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->socket()->GetLocalAddress(address); +} + +} // namespace net diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h new file mode 100644 index 00000000000..45216244f10 --- /dev/null +++ b/chromium/net/socket/socks5_client_socket.h @@ -0,0 +1,155 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKS5_CLIENT_SOCKET_H_ +#define NET_SOCKET_SOCKS5_CLIENT_SOCKET_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/gtest_prod_util.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/dns/host_resolver.h" +#include "net/socket/stream_socket.h" +#include "url/gurl.h" + +namespace net { + +class ClientSocketHandle; +class BoundNetLog; + +// This StreamSocket is used to setup a SOCKSv5 handshake with a socks proxy. +// Currently no SOCKSv5 authentication is supported. +class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { + public: + // |req_info| contains the hostname and port to which the socket above will + // communicate to via the SOCKS layer. + // + // Although SOCKS 5 supports 3 different modes of addressing, we will + // always pass it a hostname. This means the DNS resolving is done + // proxy side. + SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info); + + // On destruction Disconnect() is called. + virtual ~SOCKS5ClientSocket(); + + // StreamSocket implementation. + + // Does the SOCKS handshake and completes the protocol. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + + private: + enum State { + STATE_GREET_WRITE, + STATE_GREET_WRITE_COMPLETE, + STATE_GREET_READ, + STATE_GREET_READ_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_NONE, + }; + + // Addressing type that can be specified in requests or responses. + enum SocksEndPointAddressType { + kEndPointDomain = 0x03, + kEndPointResolvedIPv4 = 0x01, + kEndPointResolvedIPv6 = 0x04, + }; + + static const unsigned int kGreetReadHeaderSize; + static const unsigned int kWriteHeaderSize; + static const unsigned int kReadHeaderSize; + static const uint8 kSOCKS5Version; + static const uint8 kTunnelCommand; + static const uint8 kNullByte; + + void DoCallback(int result); + void OnIOComplete(int result); + + int DoLoop(int last_io_result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + int DoGreetRead(); + int DoGreetReadComplete(int result); + int DoGreetWrite(); + int DoGreetWriteComplete(int result); + + // Writes the SOCKS handshake buffer into |handshake| + // and return OK on success. + int BuildHandshakeWriteBuffer(std::string* handshake) const; + + CompletionCallback io_callback_; + + // Stores the underlying socket. + scoped_ptr<ClientSocketHandle> transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr<IOBuffer> handshake_buf_; + + // While writing, this buffer stores the complete write handshake data. + // While reading, it stores the handshake information received so far. + std::string buffer_; + + // This becomes true when the SOCKS handshake has completed and the + // overlying connection is free to communicate. + bool completed_handshake_; + + // These contain the bytes sent / received by the SOCKS handshake. + size_t bytes_sent_; + size_t bytes_received_; + + size_t read_header_size; + + HostResolver::RequestInfo host_request_info_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocket); +}; + +} // namespace net + +#endif // NET_SOCKET_SOCKS5_CLIENT_SOCKET_H_ diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc new file mode 100644 index 00000000000..4c9240ff5c1 --- /dev/null +++ b/chromium/net/socket/socks5_client_socket_unittest.cc @@ -0,0 +1,375 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks5_client_socket.h" + +#include <algorithm> +#include <iterator> +#include <map> + +#include "base/sys_byteorder.h" +#include "net/base/address_list.h" +#include "net/base/net_log.h" +#include "net/base/net_log_unittest.h" +#include "net/base/test_completion_callback.h" +#include "net/base/winsock_init.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/tcp_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +//----------------------------------------------------------------------------- + +namespace net { + +namespace { + +// Base class to test SOCKS5ClientSocket +class SOCKS5ClientSocketTest : public PlatformTest { + public: + SOCKS5ClientSocketTest(); + // Create a SOCKSClientSocket on top of a MockSocket. + scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count, + const std::string& hostname, + int port, + NetLog* net_log); + + virtual void SetUp(); + + protected: + const uint16 kNwPort; + CapturingNetLog net_log_; + scoped_ptr<SOCKS5ClientSocket> user_sock_; + AddressList address_list_; + // Filled in by BuildMockSocket() and owned by its return value + // (which |user_sock| is set to). + StreamSocket* tcp_sock_; + TestCompletionCallback callback_; + scoped_ptr<MockHostResolver> host_resolver_; + scoped_ptr<SocketDataProvider> data_; + + private: + DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest); +}; + +SOCKS5ClientSocketTest::SOCKS5ClientSocketTest() + : kNwPort(base::HostToNet16(80)), + host_resolver_(new MockHostResolver) { +} + +// Set up platform before every test case +void SOCKS5ClientSocketTest::SetUp() { + PlatformTest::SetUp(); + + // Resolve the "localhost" AddressList used by the TCP connection to connect. + HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080)); + TestCompletionCallback callback; + int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(), + NULL, BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + ASSERT_EQ(OK, rv); +} + +scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket( + MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count, + const std::string& hostname, + int port, + NetLog* net_log) { + TestCompletionCallback callback; + data_.reset(new StaticSocketDataProvider(reads, reads_count, + writes, writes_count)); + tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); + + int rv = tcp_sock_->Connect(callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(tcp_sock_->IsConnected()); + + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + // |connection| takes ownership of |tcp_sock_|, but keep a + // non-owning pointer to it. + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket( + connection.Pass(), + HostResolver::RequestInfo(HostPortPair(hostname, port)))); +} + +// Tests a complete SOCKS5 handshake and the disconnection. +TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { + const std::string payload_write = "random data"; + const std::string payload_read = "moar random data"; + + const char kOkRequest[] = { + 0x05, // Version + 0x01, // Command (CONNECT) + 0x00, // Reserved. + 0x03, // Address type (DOMAINNAME). + 0x09, // Length of domain (9) + // Domain string: + 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', + 0x00, 0x50, // 16-bit port (80) + }; + + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)), + MockWrite(ASYNC, payload_write.data(), payload_write.size()) }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength), + MockRead(ASYNC, payload_read.data(), payload_read.size()) }; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + "localhost", 80, &net_log_); + + // At this state the TCP connection is completed but not the SOCKS handshake. + EXPECT_TRUE(tcp_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnected()); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(user_sock_->IsConnected()); + + CapturingNetLog::CapturedEntryList net_log_entries; + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, + NetLog::TYPE_SOCKS5_CONNECT)); + + rv = callback_.WaitForResult(); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, + NetLog::TYPE_SOCKS5_CONNECT)); + + scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); + memcpy(buffer->data(), payload_write.data(), payload_write.size()); + rv = user_sock_->Write( + buffer.get(), payload_write.size(), callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(static_cast<int>(payload_write.size()), rv); + + buffer = new IOBuffer(payload_read.size()); + rv = + user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(static_cast<int>(payload_read.size()), rv); + EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); + + user_sock_->Disconnect(); + EXPECT_FALSE(tcp_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnected()); +} + +// Test that you can call Connect() again after having called Disconnect(). +TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { + const std::string hostname = "my-host-name"; + const char kSOCKS5DomainRequest[] = { + 0x05, // VER + 0x01, // CMD + 0x00, // RSV + 0x03, // ATYPE + }; + + std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest)); + request.push_back(hostname.size()); + request.append(hostname); + request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort)); + + for (int i = 0; i < 2; ++i) { + MockWrite data_writes[] = { + MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(SYNCHRONOUS, request.data(), request.size()) + }; + MockRead data_reads[] = { + MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength) + }; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, NULL); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + + user_sock_->Disconnect(); + EXPECT_FALSE(user_sock_->IsConnected()); + } +} + +// Test that we fail trying to connect to a hosname longer than 255 bytes. +TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { + // Create a string of length 256, where each character is 'x'. + std::string large_host_name; + std::fill_n(std::back_inserter(large_host_name), 256, 'x'); + + // Create a SOCKS socket, with mock transport socket. + MockWrite data_writes[] = {MockWrite()}; + MockRead data_reads[] = {MockRead()}; + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + large_host_name, 80, NULL); + + // Try to connect -- should fail (without having read/written anything to + // the transport socket first) because the hostname is too long. + TestCompletionCallback callback; + int rv = user_sock_->Connect(callback.callback()); + EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); +} + +TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { + const std::string hostname = "www.google.com"; + + const char kOkRequest[] = { + 0x05, // Version + 0x01, // Command (CONNECT) + 0x00, // Reserved. + 0x03, // Address type (DOMAINNAME). + 0x0E, // Length of domain (14) + // Domain string: + 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm', + 0x00, 0x50, // 16-bit port (80) + }; + + // Test for partial greet request write + { + const char partial1[] = { 0x05, 0x01 }; + const char partial2[] = { 0x00 }; + MockWrite data_writes[] = { + MockWrite(ASYNC, arraysize(partial1)), + MockWrite(ASYNC, partial2, arraysize(partial2)), + MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + CapturingNetLog::CapturedEntryList net_log_entries; + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, + NetLog::TYPE_SOCKS5_CONNECT)); + + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, + NetLog::TYPE_SOCKS5_CONNECT)); + } + + // Test for partial greet response read + { + const char partial1[] = { 0x05 }; + const char partial2[] = { 0x00 }; + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; + MockRead data_reads[] = { + MockRead(ASYNC, partial1, arraysize(partial1)), + MockRead(ASYNC, partial2, arraysize(partial2)), + MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + CapturingNetLog::CapturedEntryList net_log_entries; + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, + NetLog::TYPE_SOCKS5_CONNECT)); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, + NetLog::TYPE_SOCKS5_CONNECT)); + } + + // Test for partial handshake request write. + { + const int kSplitPoint = 3; // Break handshake write into two parts. + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(ASYNC, kOkRequest, kSplitPoint), + MockWrite(ASYNC, kOkRequest + kSplitPoint, + arraysize(kOkRequest) - kSplitPoint) + }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + CapturingNetLog::CapturedEntryList net_log_entries; + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, + NetLog::TYPE_SOCKS5_CONNECT)); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, + NetLog::TYPE_SOCKS5_CONNECT)); + } + + // Test for partial handshake response read + { + const int kSplitPoint = 6; // Break the handshake read into two parts. + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), + MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) + }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), + MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint), + MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint, + kSOCKS5OkResponseLength - kSplitPoint) + }; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + CapturingNetLog::CapturedEntryList net_log_entries; + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, + NetLog::TYPE_SOCKS5_CONNECT)); + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, + NetLog::TYPE_SOCKS5_CONNECT)); + } +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc new file mode 100644 index 00000000000..1941fdbfd95 --- /dev/null +++ b/chromium/net/socket/socks_client_socket.cc @@ -0,0 +1,432 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket.h" + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/compiler_specific.h" +#include "base/sys_byteorder.h" +#include "net/base/io_buffer.h" +#include "net/base/net_log.h" +#include "net/base/net_util.h" +#include "net/socket/client_socket_handle.h" + +namespace net { + +// Every SOCKS server requests a user-id from the client. It is optional +// and we send an empty string. +static const char kEmptyUserId[] = ""; + +// For SOCKS4, the client sends 8 bytes plus the size of the user-id. +static const unsigned int kWriteHeaderSize = 8; + +// For SOCKS4 the server sends 8 bytes for acknowledgement. +static const unsigned int kReadHeaderSize = 8; + +// Server Response codes for SOCKS. +static const uint8 kServerResponseOk = 0x5A; +static const uint8 kServerResponseRejected = 0x5B; +static const uint8 kServerResponseNotReachable = 0x5C; +static const uint8 kServerResponseMismatchedUserId = 0x5D; + +static const uint8 kSOCKSVersion4 = 0x04; +static const uint8 kSOCKSStreamRequest = 0x01; + +// A struct holding the essential details of the SOCKS4 Server Request. +// The port in the header is stored in network byte order. +struct SOCKS4ServerRequest { + uint8 version; + uint8 command; + uint16 nw_port; + uint8 ip[4]; +}; +COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, + socks4_server_request_struct_wrong_size); + +// A struct holding details of the SOCKS4 Server Response. +struct SOCKS4ServerResponse { + uint8 reserved_null; + uint8 code; + uint16 port; + uint8 ip[4]; +}; +COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, + socks4_server_response_struct_wrong_size); + +SOCKSClientSocket::SOCKSClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info, + HostResolver* host_resolver) + : transport_(transport_socket.Pass()), + next_state_(STATE_NONE), + completed_handshake_(false), + bytes_sent_(0), + bytes_received_(0), + host_resolver_(host_resolver), + host_request_info_(req_info), + net_log_(transport_->socket()->NetLog()) { +} + +SOCKSClientSocket::~SOCKSClientSocket() { + Disconnect(); +} + +int SOCKSClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(transport_.get()); + DCHECK(transport_->socket()); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + next_state_ = STATE_RESOLVE_HOST; + + net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); + } + return rv; +} + +void SOCKSClientSocket::Disconnect() { + completed_handshake_ = false; + host_resolver_.Cancel(); + transport_->socket()->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool SOCKSClientSocket::IsConnected() const { + return completed_handshake_ && transport_->socket()->IsConnected(); +} + +bool SOCKSClientSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); +} + +const BoundNetLog& SOCKSClientSocket::NetLog() const { + return net_log_; +} + +void SOCKSClientSocket::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void SOCKSClientSocket::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool SOCKSClientSocket::WasEverUsed() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasEverUsed(); + } + NOTREACHED(); + return false; +} + +bool SOCKSClientSocket::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->UsingTCPFastOpen(); + } + NOTREACHED(); + return false; +} + +bool SOCKSClientSocket::WasNpnNegotiated() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasNpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; + +} + +// Read is called by the transport layer above to read. This can only be done +// if the SOCKS handshake is complete. +int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + return transport_->socket()->Read(buf, buf_len, callback); +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + return transport_->socket()->Write(buf, buf_len, callback); +} + +bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) { + return transport_->socket()->SetReceiveBufferSize(size); +} + +bool SOCKSClientSocket::SetSendBufferSize(int32 size) { + return transport_->socket()->SetSendBufferSize(size); +} + +void SOCKSClientSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(!user_callback_.is_null()); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + CompletionCallback c = user_callback_; + user_callback_.Reset(); + DVLOG(1) << "Finished setting up SOCKS handshake"; + c.Run(result); +} + +void SOCKSClientSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); + DoCallback(rv); + } +} + +int SOCKSClientSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_RESOLVE_HOST: + DCHECK_EQ(OK, rv); + rv = DoResolveHost(); + break; + case STATE_RESOLVE_HOST_COMPLETE: + rv = DoResolveHostComplete(rv); + break; + case STATE_HANDSHAKE_WRITE: + DCHECK_EQ(OK, rv); + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + break; + case STATE_HANDSHAKE_READ: + DCHECK_EQ(OK, rv); + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int SOCKSClientSocket::DoResolveHost() { + next_state_ = STATE_RESOLVE_HOST_COMPLETE; + // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4 + // addresses for the target host. + host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); + return host_resolver_.Resolve( + host_request_info_, &addresses_, + base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), + net_log_); +} + +int SOCKSClientSocket::DoResolveHostComplete(int result) { + if (result != OK) { + // Resolving the hostname failed; fail the request rather than automatically + // falling back to SOCKS4a (since it can be confusing to see invalid IP + // addresses being sent to the SOCKS4 server when it doesn't support 4A.) + return result; + } + + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; +} + +// Builds the buffer that is to be sent to the server. +const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { + SOCKS4ServerRequest request; + request.version = kSOCKSVersion4; + request.command = kSOCKSStreamRequest; + request.nw_port = base::HostToNet16(host_request_info_.port()); + + DCHECK(!addresses_.empty()); + const IPEndPoint& endpoint = addresses_.front(); + + // We disabled IPv6 results when resolving the hostname, so none of the + // results in the list will be IPv6. + // TODO(eroman): we only ever use the first address in the list. It would be + // more robust to try all the IP addresses we have before + // failing the connect attempt. + CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); + CHECK_LE(endpoint.address().size(), sizeof(request.ip)); + memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); + + DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); + + std::string handshake_data(reinterpret_cast<char*>(&request), + sizeof(request)); + handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); + + return handshake_data; +} + +// Writes the SOCKS handshake data to the underlying socket connection. +int SOCKSClientSocket::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + if (buffer_.empty()) { + buffer_ = BuildHandshakeWriteBuffer(); + bytes_sent_ = 0; + } + + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_GT(handshake_buf_len, 0); + handshake_buf_ = new IOBuffer(handshake_buf_len); + memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], + handshake_buf_len); + return transport_->socket()->Write( + handshake_buf_.get(), + handshake_buf_len, + base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); +} + +int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + next_state_ = STATE_HANDSHAKE_READ; + buffer_.clear(); + } else if (bytes_sent_ < buffer_.size()) { + next_state_ = STATE_HANDSHAKE_WRITE; + } else { + return ERR_UNEXPECTED; + } + + return OK; +} + +int SOCKSClientSocket::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (buffer_.empty()) { + bytes_received_ = 0; + } + + int handshake_buf_len = kReadHeaderSize - bytes_received_; + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->socket()->Read( + handshake_buf_.get(), + handshake_buf_len, + base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); +} + +int SOCKSClientSocket::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + + // The underlying socket closed unexpectedly. + if (result == 0) + return ERR_CONNECTION_CLOSED; + + if (bytes_received_ + result > kReadHeaderSize) { + // TODO(eroman): Describe failure in NetLog. + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + bytes_received_ += result; + if (bytes_received_ < kReadHeaderSize) { + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + const SOCKS4ServerResponse* response = + reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); + + if (response->reserved_null != 0x00) { + LOG(ERROR) << "Unknown response from SOCKS server."; + return ERR_SOCKS_CONNECTION_FAILED; + } + + switch (response->code) { + case kServerResponseOk: + completed_handshake_ = true; + return OK; + case kServerResponseRejected: + LOG(ERROR) << "SOCKS request rejected or failed"; + return ERR_SOCKS_CONNECTION_FAILED; + case kServerResponseNotReachable: + LOG(ERROR) << "SOCKS request failed because client is not running " + << "identd (or not reachable from the server)"; + return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE; + case kServerResponseMismatchedUserId: + LOG(ERROR) << "SOCKS request failed because client's identd could " + << "not confirm the user ID string in the request"; + return ERR_SOCKS_CONNECTION_FAILED; + default: + LOG(ERROR) << "SOCKS server sent unknown response"; + return ERR_SOCKS_CONNECTION_FAILED; + } + + // Note: we ignore the last 6 bytes as specified by the SOCKS protocol +} + +int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->socket()->GetPeerAddress(address); +} + +int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->socket()->GetLocalAddress(address); +} + +} // namespace net diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h new file mode 100644 index 00000000000..285c75ec295 --- /dev/null +++ b/chromium/net/socket/socks_client_socket.h @@ -0,0 +1,134 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ +#define NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/gtest_prod_util.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/dns/host_resolver.h" +#include "net/dns/single_request_host_resolver.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class ClientSocketHandle; +class BoundNetLog; + +// The SOCKS client socket implementation +class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { + public: + // |req_info| contains the hostname and port to which the socket above will + // communicate to via the socks layer. For testing the referrer is optional. + SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info, + HostResolver* host_resolver); + + // On destruction Disconnect() is called. + virtual ~SOCKSClientSocket(); + + // StreamSocket implementation. + + // Does the SOCKS handshake and completes the protocol. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + + private: + FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, CompleteHandshake); + FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, SOCKS4AFailedDNS); + FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6); + + enum State { + STATE_RESOLVE_HOST, + STATE_RESOLVE_HOST_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_NONE, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + + int DoLoop(int last_io_result); + int DoResolveHost(); + int DoResolveHostComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + + const std::string BuildHandshakeWriteBuffer() const; + + // Stores the underlying socket. + scoped_ptr<ClientSocketHandle> transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr<IOBuffer> handshake_buf_; + + // While writing, this buffer stores the complete write handshake data. + // While reading, it stores the handshake information received so far. + std::string buffer_; + + // This becomes true when the SOCKS handshake has completed and the + // overlying connection is free to communicate. + bool completed_handshake_; + + // These contain the bytes sent / received by the SOCKS handshake. + size_t bytes_sent_; + size_t bytes_received_; + + // Used to resolve the hostname to which the SOCKS proxy will connect. + SingleRequestHostResolver host_resolver_; + AddressList addresses_; + HostResolver::RequestInfo host_request_info_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocket); +}; + +} // namespace net + +#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_H_ diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc new file mode 100644 index 00000000000..e49eabaa84c --- /dev/null +++ b/chromium/net/socket/socks_client_socket_pool.cc @@ -0,0 +1,310 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket_pool.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/time/time.h" +#include "base/values.h" +#include "net/base/net_errors.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/socks5_client_socket.h" +#include "net/socket/socks_client_socket.h" +#include "net/socket/transport_client_socket_pool.h" + +namespace net { + +SOCKSSocketParams::SOCKSSocketParams( + const scoped_refptr<TransportSocketParams>& proxy_server, + bool socks_v5, + const HostPortPair& host_port_pair, + RequestPriority priority) + : transport_params_(proxy_server), + destination_(host_port_pair), + socks_v5_(socks_v5) { + if (transport_params_.get()) + ignore_limits_ = transport_params_->ignore_limits(); + else + ignore_limits_ = false; + destination_.set_priority(priority); +} + +SOCKSSocketParams::~SOCKSSocketParams() {} + +// SOCKSConnectJobs will time out after this many seconds. Note this is on +// top of the timeout for the transport socket. +static const int kSOCKSConnectJobTimeoutInSeconds = 30; + +SOCKSConnectJob::SOCKSConnectJob( + const std::string& group_name, + const scoped_refptr<SOCKSSocketParams>& socks_params, + const base::TimeDelta& timeout_duration, + TransportClientSocketPool* transport_pool, + HostResolver* host_resolver, + Delegate* delegate, + NetLog* net_log) + : ConnectJob(group_name, timeout_duration, delegate, + BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), + socks_params_(socks_params), + transport_pool_(transport_pool), + resolver_(host_resolver), + callback_(base::Bind(&SOCKSConnectJob::OnIOComplete, + base::Unretained(this))) { +} + +SOCKSConnectJob::~SOCKSConnectJob() { + // We don't worry about cancelling the tcp socket since the destructor in + // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of + // it. +} + +LoadState SOCKSConnectJob::GetLoadState() const { + switch (next_state_) { + case STATE_TRANSPORT_CONNECT: + case STATE_TRANSPORT_CONNECT_COMPLETE: + return transport_socket_handle_->GetLoadState(); + case STATE_SOCKS_CONNECT: + case STATE_SOCKS_CONNECT_COMPLETE: + return LOAD_STATE_CONNECTING; + default: + NOTREACHED(); + return LOAD_STATE_IDLE; + } +} + +void SOCKSConnectJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + NotifyDelegateOfCompletion(rv); // Deletes |this| +} + +int SOCKSConnectJob::DoLoop(int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = DoTransportConnectComplete(rv); + break; + case STATE_SOCKS_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoSOCKSConnect(); + break; + case STATE_SOCKS_CONNECT_COMPLETE: + rv = DoSOCKSConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + return rv; +} + +int SOCKSConnectJob::DoTransportConnect() { + next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; + transport_socket_handle_.reset(new ClientSocketHandle()); + return transport_socket_handle_->Init( + group_name(), socks_params_->transport_params(), + socks_params_->destination().priority(), callback_, transport_pool_, + net_log()); +} + +int SOCKSConnectJob::DoTransportConnectComplete(int result) { + if (result != OK) + return ERR_PROXY_CONNECTION_FAILED; + + // Reset the timer to just the length of time allowed for SOCKS handshake + // so that a fast TCP connection plus a slow SOCKS failure doesn't take + // longer to timeout than it should. + ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); + next_state_ = STATE_SOCKS_CONNECT; + return result; +} + +int SOCKSConnectJob::DoSOCKSConnect() { + next_state_ = STATE_SOCKS_CONNECT_COMPLETE; + + // Add a SOCKS connection on top of the tcp socket. + if (socks_params_->is_socks_v5()) { + socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), + socks_params_->destination())); + } else { + socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), + socks_params_->destination(), + resolver_)); + } + return socket_->Connect( + base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))); +} + +int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { + if (result != OK) { + socket_->Disconnect(); + return result; + } + + SetSocket(socket_.Pass()); + return result; +} + +int SOCKSConnectJob::ConnectInternal() { + next_state_ = STATE_TRANSPORT_CONNECT; + return DoLoop(OK); +} + +scoped_ptr<ConnectJob> +SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, + request.params(), + ConnectionTimeout(), + transport_pool_, + host_resolver_, + delegate, + net_log_)); +} + +base::TimeDelta +SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { + return transport_pool_->ConnectionTimeout() + + base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); +} + +SOCKSClientSocketPool::SOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + TransportClientSocketPool* transport_pool, + NetLog* net_log) + : transport_pool_(transport_pool), + base_(max_sockets, max_sockets_per_group, histograms, + ClientSocketPool::unused_idle_socket_timeout(), + ClientSocketPool::used_idle_socket_timeout(), + new SOCKSConnectJobFactory(transport_pool, + host_resolver, + net_log)) { + // We should always have a |transport_pool_| except in unit tests. + if (transport_pool_) + transport_pool_->AddLayeredPool(this); +} + +SOCKSClientSocketPool::~SOCKSClientSocketPool() { + // We should always have a |transport_pool_| except in unit tests. + if (transport_pool_) + transport_pool_->RemoveLayeredPool(this); +} + +int SOCKSClientSocketPool::RequestSocket( + const std::string& group_name, const void* socket_params, + RequestPriority priority, ClientSocketHandle* handle, + const CompletionCallback& callback, const BoundNetLog& net_log) { + const scoped_refptr<SOCKSSocketParams>* casted_socket_params = + static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); + + return base_.RequestSocket(group_name, *casted_socket_params, priority, + handle, callback, net_log); +} + +void SOCKSClientSocketPool::RequestSockets( + const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) { + const scoped_refptr<SOCKSSocketParams>* casted_params = + static_cast<const scoped_refptr<SOCKSSocketParams>*>(params); + + base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); +} + +void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) { + base_.CancelRequest(group_name, handle); +} + +void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); +} + +void SOCKSClientSocketPool::FlushWithError(int error) { + base_.FlushWithError(error); +} + +bool SOCKSClientSocketPool::IsStalled() const { + return base_.IsStalled() || transport_pool_->IsStalled(); +} + +void SOCKSClientSocketPool::CloseIdleSockets() { + base_.CloseIdleSockets(); +} + +int SOCKSClientSocketPool::IdleSocketCount() const { + return base_.idle_socket_count(); +} + +int SOCKSClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return base_.IdleSocketCountInGroup(group_name); +} + +LoadState SOCKSClientSocketPool::GetLoadState( + const std::string& group_name, const ClientSocketHandle* handle) const { + return base_.GetLoadState(group_name, handle); +} + +void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + +base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const { + base::DictionaryValue* dict = base_.GetInfoAsValue(name, type); + if (include_nested_pools) { + base::ListValue* list = new base::ListValue(); + list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", + "transport_socket_pool", + false)); + dict->Set("nested_pools", list); + } + return dict; +} + +base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { + return base_.ConnectionTimeout(); +} + +ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { + return base_.histograms(); +}; + +bool SOCKSClientSocketPool::CloseOneIdleConnection() { + if (base_.CloseOneIdleSocket()) + return true; + return base_.CloseOneIdleConnectionInLayeredPool(); +} + +} // namespace net diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h new file mode 100644 index 00000000000..fe69a78df69 --- /dev/null +++ b/chromium/net/socket/socks_client_socket_pool.h @@ -0,0 +1,211 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/host_port_pair.h" +#include "net/dns/host_resolver.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/client_socket_pool_histograms.h" + +namespace net { + +class ConnectJobFactory; +class TransportClientSocketPool; +class TransportSocketParams; + +class NET_EXPORT_PRIVATE SOCKSSocketParams + : public base::RefCounted<SOCKSSocketParams> { + public: + SOCKSSocketParams(const scoped_refptr<TransportSocketParams>& proxy_server, + bool socks_v5, const HostPortPair& host_port_pair, + RequestPriority priority); + + const scoped_refptr<TransportSocketParams>& transport_params() const { + return transport_params_; + } + const HostResolver::RequestInfo& destination() const { return destination_; } + bool is_socks_v5() const { return socks_v5_; } + bool ignore_limits() const { return ignore_limits_; } + + private: + friend class base::RefCounted<SOCKSSocketParams>; + ~SOCKSSocketParams(); + + // The transport (likely TCP) connection must point toward the proxy server. + const scoped_refptr<TransportSocketParams> transport_params_; + // This is the HTTP destination. + HostResolver::RequestInfo destination_; + const bool socks_v5_; + bool ignore_limits_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSSocketParams); +}; + +// SOCKSConnectJob handles the handshake to a socks server after setting up +// an underlying transport socket. +class SOCKSConnectJob : public ConnectJob { + public: + SOCKSConnectJob(const std::string& group_name, + const scoped_refptr<SOCKSSocketParams>& params, + const base::TimeDelta& timeout_duration, + TransportClientSocketPool* transport_pool, + HostResolver* host_resolver, + Delegate* delegate, + NetLog* net_log); + virtual ~SOCKSConnectJob(); + + // ConnectJob methods. + virtual LoadState GetLoadState() const OVERRIDE; + + private: + enum State { + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_SOCKS_CONNECT, + STATE_SOCKS_CONNECT_COMPLETE, + STATE_NONE, + }; + + void OnIOComplete(int result); + + // Runs the state transition loop. + int DoLoop(int result); + + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + int DoSOCKSConnect(); + int DoSOCKSConnectComplete(int result); + + // Begins the transport connection and the SOCKS handshake. Returns OK on + // success and ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal() OVERRIDE; + + scoped_refptr<SOCKSSocketParams> socks_params_; + TransportClientSocketPool* const transport_pool_; + HostResolver* const resolver_; + + State next_state_; + CompletionCallback callback_; + scoped_ptr<ClientSocketHandle> transport_socket_handle_; + scoped_ptr<StreamSocket> socket_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJob); +}; + +class NET_EXPORT_PRIVATE SOCKSClientSocketPool + : public ClientSocketPool, public LayeredPool { + public: + SOCKSClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + TransportClientSocketPool* transport_pool, + NetLog* net_log); + + virtual ~SOCKSClientSocketPool(); + + // ClientSocketPool implementation. + virtual int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE; + + virtual void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) OVERRIDE; + + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE; + + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; + + virtual void FlushWithError(int error) OVERRIDE; + + virtual bool IsStalled() const OVERRIDE; + + virtual void CloseIdleSockets() OVERRIDE; + + virtual int IdleSocketCount() const OVERRIDE; + + virtual int IdleSocketCountInGroup( + const std::string& group_name) const OVERRIDE; + + virtual LoadState GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const OVERRIDE; + + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const OVERRIDE; + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + + virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + + // LayeredPool implementation. + virtual bool CloseOneIdleConnection() OVERRIDE; + + private: + typedef ClientSocketPoolBase<SOCKSSocketParams> PoolBase; + + class SOCKSConnectJobFactory : public PoolBase::ConnectJobFactory { + public: + SOCKSConnectJobFactory(TransportClientSocketPool* transport_pool, + HostResolver* host_resolver, + NetLog* net_log) + : transport_pool_(transport_pool), + host_resolver_(host_resolver), + net_log_(net_log) {} + + virtual ~SOCKSConnectJobFactory() {} + + // ClientSocketPoolBase::ConnectJobFactory methods. + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const OVERRIDE; + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + + private: + TransportClientSocketPool* const transport_pool_; + HostResolver* const host_resolver_; + NetLog* net_log_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJobFactory); + }; + + TransportClientSocketPool* const transport_pool_; + PoolBase base_; + + DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketPool); +}; + +REGISTER_SOCKET_PARAMS_FOR_POOL(SOCKSClientSocketPool, SOCKSSocketParams); + +} // namespace net + +#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc new file mode 100644 index 00000000000..77440d36a19 --- /dev/null +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -0,0 +1,297 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket_pool.h" + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/time/time.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; + +// Make sure |handle|'s load times are set correctly. Only connect times should +// be set. +void TestLoadTimingInfo(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + // None of these tests use a NetLog. + EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + EXPECT_FALSE(load_timing_info.socket_reused); + + ExpectConnectTimingHasTimes(load_timing_info.connect_timing, + CONNECT_TIMING_HAS_CONNECT_TIMES_ONLY); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +class SOCKSClientSocketPoolTest : public testing::Test { + protected: + class SOCKS5MockData { + public: + explicit SOCKS5MockData(IoMode mode) { + writes_.reset(new MockWrite[3]); + writes_[0] = MockWrite(mode, kSOCKS5GreetRequest, + kSOCKS5GreetRequestLength); + writes_[1] = MockWrite(mode, kSOCKS5OkRequest, kSOCKS5OkRequestLength); + writes_[2] = MockWrite(mode, 0); + + reads_.reset(new MockRead[3]); + reads_[0] = MockRead(mode, kSOCKS5GreetResponse, + kSOCKS5GreetResponseLength); + reads_[1] = MockRead(mode, kSOCKS5OkResponse, kSOCKS5OkResponseLength); + reads_[2] = MockRead(mode, 0); + + data_.reset(new StaticSocketDataProvider(reads_.get(), 3, + writes_.get(), 3)); + } + + SocketDataProvider* data_provider() { return data_.get(); } + + private: + scoped_ptr<StaticSocketDataProvider> data_; + scoped_ptr<MockWrite[]> writes_; + scoped_ptr<MockRead[]> reads_; + }; + + SOCKSClientSocketPoolTest() + : ignored_transport_socket_params_(new TransportSocketParams( + HostPortPair("proxy", 80), MEDIUM, false, false, + OnHostResolutionCallback())), + transport_histograms_("MockTCP"), + transport_socket_pool_( + kMaxSockets, kMaxSocketsPerGroup, + &transport_histograms_, + &transport_client_socket_factory_), + ignored_socket_params_(new SOCKSSocketParams( + ignored_transport_socket_params_, true, HostPortPair("host", 80), + MEDIUM)), + socks_histograms_("SOCKSUnitTest"), + pool_(kMaxSockets, kMaxSocketsPerGroup, + &socks_histograms_, + NULL, + &transport_socket_pool_, + NULL) { + } + + virtual ~SOCKSClientSocketPoolTest() {} + + int StartRequest(const std::string& group_name, RequestPriority priority) { + return test_base_.StartRequestUsingPool( + &pool_, group_name, priority, ignored_socket_params_); + } + + int GetOrderOfRequest(size_t index) const { + return test_base_.GetOrderOfRequest(index); + } + + ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } + + scoped_refptr<TransportSocketParams> ignored_transport_socket_params_; + ClientSocketPoolHistograms transport_histograms_; + MockClientSocketFactory transport_client_socket_factory_; + MockTransportClientSocketPool transport_socket_pool_; + + scoped_refptr<SOCKSSocketParams> ignored_socket_params_; + ClientSocketPoolHistograms socks_histograms_; + SOCKSClientSocketPool pool_; + ClientSocketPoolTest test_base_; +}; + +TEST_F(SOCKSClientSocketPoolTest, Simple) { + SOCKS5MockData data(SYNCHRONOUS); + data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + &pool_, BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); +} + +TEST_F(SOCKSClientSocketPoolTest, Async) { + SOCKS5MockData data(ASYNC); + transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); +} + +TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) { + StaticSocketDataProvider socket_data; + socket_data.set_connect_data(MockConnect(SYNCHRONOUS, + ERR_CONNECTION_REFUSED)); + transport_client_socket_factory_.AddSocketDataProvider(&socket_data); + + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + &pool_, BoundNetLog()); + EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) { + StaticSocketDataProvider socket_data; + socket_data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_REFUSED)); + transport_client_socket_factory_.AddSocketDataProvider(&socket_data); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); +} + +TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) { + MockRead failed_read[] = { + MockRead(SYNCHRONOUS, 0), + }; + StaticSocketDataProvider socket_data( + failed_read, arraysize(failed_read), NULL, 0); + socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider(&socket_data); + + ClientSocketHandle handle; + EXPECT_EQ(0, transport_socket_pool_.release_count()); + int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + &pool_, BoundNetLog()); + EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_EQ(1, transport_socket_pool_.release_count()); +} + +TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) { + MockRead failed_read[] = { + MockRead(ASYNC, 0), + }; + StaticSocketDataProvider socket_data( + failed_read, arraysize(failed_read), NULL, 0); + socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider(&socket_data); + + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(0, transport_socket_pool_.release_count()); + int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_EQ(1, transport_socket_pool_.release_count()); +} + +TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) { + SOCKS5MockData data(SYNCHRONOUS); + transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + // We need two connections because the pool base lets one cancelled + // connect job proceed for potential future use. + SOCKS5MockData data2(SYNCHRONOUS); + transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider()); + + EXPECT_EQ(0, transport_socket_pool_.cancel_count()); + int rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + pool_.CancelRequest("a", (*requests())[0]->handle()); + pool_.CancelRequest("a", (*requests())[1]->handle()); + // Requests in the connect phase don't actually get cancelled. + EXPECT_EQ(0, transport_socket_pool_.cancel_count()); + + // Now wait for the TCP sockets to connect. + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1)); + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2)); + EXPECT_EQ(0, transport_socket_pool_.cancel_count()); + EXPECT_EQ(2, pool_.IdleSocketCount()); + + (*requests())[0]->handle()->Reset(); + (*requests())[1]->handle()->Reset(); +} + +TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) { + SOCKS5MockData data(ASYNC); + data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); + // We need two connections because the pool base lets one cancelled + // connect job proceed for potential future use. + SOCKS5MockData data2(ASYNC); + data2.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider()); + + EXPECT_EQ(0, transport_socket_pool_.cancel_count()); + EXPECT_EQ(0, transport_socket_pool_.release_count()); + int rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + rv = StartRequest("a", LOW); + EXPECT_EQ(ERR_IO_PENDING, rv); + + pool_.CancelRequest("a", (*requests())[0]->handle()); + pool_.CancelRequest("a", (*requests())[1]->handle()); + EXPECT_EQ(0, transport_socket_pool_.cancel_count()); + // Requests in the connect phase don't actually get cancelled. + EXPECT_EQ(0, transport_socket_pool_.release_count()); + + // Now wait for the async data to reach the SOCKS connect jobs. + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1)); + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2)); + EXPECT_EQ(0, transport_socket_pool_.cancel_count()); + EXPECT_EQ(0, transport_socket_pool_.release_count()); + EXPECT_EQ(2, pool_.IdleSocketCount()); + + (*requests())[0]->handle()->Reset(); + (*requests())[1]->handle()->Reset(); +} + +// It would be nice to also test the timeouts in SOCKSClientSocketPool. + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc new file mode 100644 index 00000000000..8c30838959d --- /dev/null +++ b/chromium/net/socket/socks_client_socket_unittest.cc @@ -0,0 +1,415 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socks_client_socket.h" + +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/net_log.h" +#include "net/base/net_log_unittest.h" +#include "net/base/test_completion_callback.h" +#include "net/base/winsock_init.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/tcp_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +//----------------------------------------------------------------------------- + +namespace net { + +const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; +const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; + +class SOCKSClientSocketTest : public PlatformTest { + public: + SOCKSClientSocketTest(); + // Create a SOCKSClientSocket on top of a MockSocket. + scoped_ptr<SOCKSClientSocket> BuildMockSocket( + MockRead reads[], size_t reads_count, + MockWrite writes[], size_t writes_count, + HostResolver* host_resolver, + const std::string& hostname, int port, + NetLog* net_log); + virtual void SetUp(); + + protected: + scoped_ptr<SOCKSClientSocket> user_sock_; + AddressList address_list_; + // Filled in by BuildMockSocket() and owned by its return value + // (which |user_sock| is set to). + StreamSocket* tcp_sock_; + TestCompletionCallback callback_; + scoped_ptr<MockHostResolver> host_resolver_; + scoped_ptr<SocketDataProvider> data_; +}; + +SOCKSClientSocketTest::SOCKSClientSocketTest() + : host_resolver_(new MockHostResolver) { +} + +// Set up platform before every test case +void SOCKSClientSocketTest::SetUp() { + PlatformTest::SetUp(); +} + +scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( + MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count, + HostResolver* host_resolver, + const std::string& hostname, + int port, + NetLog* net_log) { + + TestCompletionCallback callback; + data_.reset(new StaticSocketDataProvider(reads, reads_count, + writes, writes_count)); + tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); + + int rv = tcp_sock_->Connect(callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(tcp_sock_->IsConnected()); + + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + // |connection| takes ownership of |tcp_sock_|, but keep a + // non-owning pointer to it. + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket( + connection.Pass(), + HostResolver::RequestInfo(HostPortPair(hostname, port)), + host_resolver)); +} + +// Implementation of HostResolver that never completes its resolve request. +// We use this in the test "DisconnectWhileHostResolveInProgress" to make +// sure that the outstanding resolve request gets cancelled. +class HangingHostResolverWithCancel : public HostResolver { + public: + HangingHostResolverWithCancel() : outstanding_request_(NULL) {} + + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) OVERRIDE { + DCHECK(addresses); + DCHECK_EQ(false, callback.is_null()); + EXPECT_FALSE(HasOutstandingRequest()); + outstanding_request_ = reinterpret_cast<RequestHandle>(1); + *out_req = outstanding_request_; + return ERR_IO_PENDING; + } + + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) OVERRIDE { + NOTIMPLEMENTED(); + return ERR_UNEXPECTED; + } + + virtual void CancelRequest(RequestHandle req) OVERRIDE { + EXPECT_TRUE(HasOutstandingRequest()); + EXPECT_EQ(outstanding_request_, req); + outstanding_request_ = NULL; + } + + bool HasOutstandingRequest() { + return outstanding_request_ != NULL; + } + + private: + RequestHandle outstanding_request_; + + DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel); +}; + +// Tests a complete handshake and the disconnection. +TEST_F(SOCKSClientSocketTest, CompleteHandshake) { + const std::string payload_write = "random data"; + const std::string payload_read = "moar random data"; + + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), + MockWrite(ASYNC, payload_write.data(), payload_write.size()) }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)), + MockRead(ASYNC, payload_read.data(), payload_read.size()) }; + CapturingNetLog log; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); + + // At this state the TCP connection is completed but not the SOCKS handshake. + EXPECT_TRUE(tcp_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnected()); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE( + LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT)); + EXPECT_FALSE(user_sock_->IsConnected()); + + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEndEvent( + entries, -1, NetLog::TYPE_SOCKS_CONNECT)); + + scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); + memcpy(buffer->data(), payload_write.data(), payload_write.size()); + rv = user_sock_->Write( + buffer.get(), payload_write.size(), callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(static_cast<int>(payload_write.size()), rv); + + buffer = new IOBuffer(payload_read.size()); + rv = + user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback_.WaitForResult(); + EXPECT_EQ(static_cast<int>(payload_read.size()), rv); + EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); + + user_sock_->Disconnect(); + EXPECT_FALSE(tcp_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnected()); +} + +// List of responses from the socks server and the errors they should +// throw up are tested here. +TEST_F(SOCKSClientSocketTest, HandshakeFailures) { + const struct { + const char fail_reply[8]; + Error fail_code; + } tests[] = { + // Failure of the server response code + { + { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, + ERR_SOCKS_CONNECTION_FAILED, + }, + // Failure of the null byte + { + { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, + ERR_SOCKS_CONNECTION_FAILED, + }, + }; + + //--------------------------------------- + + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { + MockWrite data_writes[] = { + MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; + MockRead data_reads[] = { + MockRead(SYNCHRONOUS, tests[i].fail_reply, + arraysize(tests[i].fail_reply)) }; + CapturingNetLog log; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKS_CONNECT)); + + rv = callback_.WaitForResult(); + EXPECT_EQ(tests[i].fail_code, rv); + EXPECT_FALSE(user_sock_->IsConnected()); + EXPECT_TRUE(tcp_sock_->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEndEvent( + entries, -1, NetLog::TYPE_SOCKS_CONNECT)); + } +} + +// Tests scenario when the server sends the handshake response in +// more than one packet. +TEST_F(SOCKSClientSocketTest, PartialServerReads) { + const char kSOCKSPartialReply1[] = { 0x00 }; + const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; + + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), + MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; + CapturingNetLog log; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKS_CONNECT)); + + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEndEvent( + entries, -1, NetLog::TYPE_SOCKS_CONNECT)); +} + +// Tests scenario when the client sends the handshake request in +// more than one packet. +TEST_F(SOCKSClientSocketTest, PartialClientWrites) { + const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; + const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; + + MockWrite data_writes[] = { + MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)), + // simulate some empty writes + MockWrite(ASYNC, 0), + MockWrite(ASYNC, 0), + MockWrite(ASYNC, kSOCKSPartialRequest2, + arraysize(kSOCKSPartialRequest2)) }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; + CapturingNetLog log; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKS_CONNECT)); + + rv = callback_.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(user_sock_->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEndEvent( + entries, -1, NetLog::TYPE_SOCKS_CONNECT)); +} + +// Tests the case when the server sends a smaller sized handshake data +// and closes the connection. +TEST_F(SOCKSClientSocketTest, FailedSocketRead) { + MockWrite data_writes[] = { + MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; + MockRead data_reads[] = { + MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), + // close connection unexpectedly + MockRead(SYNCHRONOUS, 0) }; + CapturingNetLog log; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKS_CONNECT)); + + rv = callback_.WaitForResult(); + EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); + EXPECT_FALSE(user_sock_->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEndEvent( + entries, -1, NetLog::TYPE_SOCKS_CONNECT)); +} + +// Tries to connect to an unknown hostname. Should fail rather than +// falling back to SOCKS4a. +TEST_F(SOCKSClientSocketTest, FailedDNS) { + const char hostname[] = "unresolved.ipv4.address"; + + host_resolver_->rules()->AddSimulatedFailure(hostname); + + CapturingNetLog log; + + user_sock_ = BuildMockSocket(NULL, 0, + NULL, 0, + host_resolver_.get(), + hostname, 80, + &log); + + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent( + entries, 0, NetLog::TYPE_SOCKS_CONNECT)); + + rv = callback_.WaitForResult(); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); + EXPECT_FALSE(user_sock_->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsEndEvent( + entries, -1, NetLog::TYPE_SOCKS_CONNECT)); +} + +// Calls Disconnect() while a host resolve is in progress. The outstanding host +// resolve should be cancelled. +TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { + scoped_ptr<HangingHostResolverWithCancel> hanging_resolver( + new HangingHostResolverWithCancel()); + + // Doesn't matter what the socket data is, we will never use it -- garbage. + MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) }; + MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) }; + + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hanging_resolver.get(), + "foo", 80, + NULL); + + // Start connecting (will get stuck waiting for the host to resolve). + int rv = user_sock_->Connect(callback_.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_FALSE(user_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); + + // The host resolver should have received the resolve request. + EXPECT_TRUE(hanging_resolver->HasOutstandingRequest()); + + // Disconnect the SOCKS socket -- this should cancel the outstanding resolve. + user_sock_->Disconnect(); + + EXPECT_FALSE(hanging_resolver->HasOutstandingRequest()); + + EXPECT_FALSE(user_sock_->IsConnected()); + EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); +} + +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket.cc b/chromium/net/socket/ssl_client_socket.cc new file mode 100644 index 00000000000..54f66a1f681 --- /dev/null +++ b/chromium/net/socket/ssl_client_socket.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_client_socket.h" + +#include "base/strings/string_util.h" + +namespace net { + +SSLClientSocket::SSLClientSocket() + : was_npn_negotiated_(false), + was_spdy_negotiated_(false), + protocol_negotiated_(kProtoUnknown), + channel_id_sent_(false) { +} + +// static +NextProto SSLClientSocket::NextProtoFromString( + const std::string& proto_string) { + if (proto_string == "http1.1" || proto_string == "http/1.1") { + return kProtoHTTP11; + } else if (proto_string == "spdy/1") { + return kProtoSPDY1; + } else if (proto_string == "spdy/2") { + return kProtoSPDY2; + } else if (proto_string == "spdy/3") { + return kProtoSPDY3; + } else if (proto_string == "spdy/3.1") { + return kProtoSPDY31; + } else if (proto_string == "spdy/4a2") { + return kProtoSPDY4a2; + } else if (proto_string == "HTTP-draft-04/2.0") { + return kProtoHTTP2Draft04; + } else if (proto_string == "quic/1+spdy/3") { + return kProtoQUIC1SPDY3; + } else { + return kProtoUnknown; + } +} + +// static +const char* SSLClientSocket::NextProtoToString(NextProto next_proto) { + switch (next_proto) { + case kProtoHTTP11: + return "http/1.1"; + case kProtoSPDY1: + return "spdy/1"; + case kProtoSPDY2: + return "spdy/2"; + case kProtoSPDY3: + return "spdy/3"; + case kProtoSPDY31: + return "spdy/3.1"; + case kProtoSPDY4a2: + return "spdy/4a2"; + case kProtoHTTP2Draft04: + return "HTTP-draft-04/2.0"; + case kProtoQUIC1SPDY3: + return "quic/1+spdy/3"; + case kProtoSPDY21: + case kProtoUnknown: + break; + } + return "unknown"; +} + +// static +const char* SSLClientSocket::NextProtoStatusToString( + const SSLClientSocket::NextProtoStatus status) { + switch (status) { + case kNextProtoUnsupported: + return "unsupported"; + case kNextProtoNegotiated: + return "negotiated"; + case kNextProtoNoOverlap: + return "no-overlap"; + } + return NULL; +} + +// static +std::string SSLClientSocket::ServerProtosToString( + const std::string& server_protos) { + const char* protos = server_protos.c_str(); + size_t protos_len = server_protos.length(); + std::vector<std::string> server_protos_with_commas; + for (size_t i = 0; i < protos_len; ) { + const size_t len = protos[i]; + std::string proto_str(&protos[i + 1], len); + server_protos_with_commas.push_back(proto_str); + i += len + 1; + } + return JoinString(server_protos_with_commas, ','); +} + +bool SSLClientSocket::WasNpnNegotiated() const { + return was_npn_negotiated_; +} + +NextProto SSLClientSocket::GetNegotiatedProtocol() const { + return protocol_negotiated_; +} + +bool SSLClientSocket::IgnoreCertError(int error, int load_flags) { + if (error == OK || load_flags & LOAD_IGNORE_ALL_CERT_ERRORS) + return true; + + if (error == ERR_CERT_COMMON_NAME_INVALID && + (load_flags & LOAD_IGNORE_CERT_COMMON_NAME_INVALID)) + return true; + + if (error == ERR_CERT_DATE_INVALID && + (load_flags & LOAD_IGNORE_CERT_DATE_INVALID)) + return true; + + if (error == ERR_CERT_AUTHORITY_INVALID && + (load_flags & LOAD_IGNORE_CERT_AUTHORITY_INVALID)) + return true; + + return false; +} + +bool SSLClientSocket::set_was_npn_negotiated(bool negotiated) { + return was_npn_negotiated_ = negotiated; +} + +bool SSLClientSocket::was_spdy_negotiated() const { + return was_spdy_negotiated_; +} + +bool SSLClientSocket::set_was_spdy_negotiated(bool negotiated) { + return was_spdy_negotiated_ = negotiated; +} + +void SSLClientSocket::set_protocol_negotiated(NextProto protocol_negotiated) { + protocol_negotiated_ = protocol_negotiated; +} + +bool SSLClientSocket::WasChannelIDSent() const { + return channel_id_sent_; +} + +void SSLClientSocket::set_channel_id_sent(bool channel_id_sent) { + channel_id_sent_ = channel_id_sent; +} + +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h new file mode 100644 index 00000000000..41ee0873347 --- /dev/null +++ b/chromium/net/socket/ssl_client_socket.h @@ -0,0 +1,141 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_H_ +#define NET_SOCKET_SSL_CLIENT_SOCKET_H_ + +#include <string> + +#include "net/base/completion_callback.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/socket/ssl_socket.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class CertVerifier; +class ServerBoundCertService; +class SSLCertRequestInfo; +class SSLInfo; +class TransportSecurityState; + +// This struct groups together several fields which are used by various +// classes related to SSLClientSocket. +struct SSLClientSocketContext { + SSLClientSocketContext() + : cert_verifier(NULL), + server_bound_cert_service(NULL), + transport_security_state(NULL) {} + + SSLClientSocketContext(CertVerifier* cert_verifier_arg, + ServerBoundCertService* server_bound_cert_service_arg, + TransportSecurityState* transport_security_state_arg, + const std::string& ssl_session_cache_shard_arg) + : cert_verifier(cert_verifier_arg), + server_bound_cert_service(server_bound_cert_service_arg), + transport_security_state(transport_security_state_arg), + ssl_session_cache_shard(ssl_session_cache_shard_arg) {} + + CertVerifier* cert_verifier; + ServerBoundCertService* server_bound_cert_service; + TransportSecurityState* transport_security_state; + // ssl_session_cache_shard is an opaque string that identifies a shard of the + // SSL session cache. SSL sockets with the same ssl_session_cache_shard may + // resume each other's SSL sessions but we'll never sessions between shards. + const std::string ssl_session_cache_shard; +}; + +// A client socket that uses SSL as the transport layer. +// +// NOTE: The SSL handshake occurs within the Connect method after a TCP +// connection is established. If a SSL error occurs during the handshake, +// Connect will fail. +// +class NET_EXPORT SSLClientSocket : public SSLSocket { + public: + SSLClientSocket(); + + // Next Protocol Negotiation (NPN) allows a TLS client and server to come to + // an agreement about the application level protocol to speak over a + // connection. + enum NextProtoStatus { + // WARNING: These values are serialized to disk. Don't change them. + + kNextProtoUnsupported = 0, // The server doesn't support NPN. + kNextProtoNegotiated = 1, // We agreed on a protocol. + kNextProtoNoOverlap = 2, // No protocols in common. We requested + // the first protocol in our list. + }; + + // StreamSocket: + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + + // Gets the SSL CertificateRequest info of the socket after Connect failed + // with ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + virtual void GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) = 0; + + // Get the application level protocol that we negotiated with the server. + // *proto is set to the resulting protocol (n.b. that the string may have + // embedded NULs). + // kNextProtoUnsupported: *proto is cleared. + // kNextProtoNegotiated: *proto is set to the negotiated protocol. + // kNextProtoNoOverlap: *proto is set to the first protocol in the + // supported list. + // *server_protos is set to the server advertised protocols. + virtual NextProtoStatus GetNextProto(std::string* proto, + std::string* server_protos) = 0; + + static NextProto NextProtoFromString(const std::string& proto_string); + + static const char* NextProtoToString(NextProto next_proto); + + static const char* NextProtoStatusToString(const NextProtoStatus status); + + // Can be used with the second argument(|server_protos|) of |GetNextProto| to + // construct a comma separated string of server advertised protocols. + static std::string ServerProtosToString(const std::string& server_protos); + + static bool IgnoreCertError(int error, int load_flags); + + // ClearSessionCache clears the SSL session cache, used to resume SSL + // sessions. + static void ClearSessionCache(); + + virtual bool set_was_npn_negotiated(bool negotiated); + + virtual bool was_spdy_negotiated() const; + + virtual bool set_was_spdy_negotiated(bool negotiated); + + virtual void set_protocol_negotiated(NextProto protocol_negotiated); + + // Returns the ServerBoundCertService used by this socket, or NULL if + // server bound certificates are not supported. + virtual ServerBoundCertService* GetServerBoundCertService() const = 0; + + // Returns true if a channel ID was sent on this connection. + // This may be useful for protocols, like SPDY, which allow the same + // connection to be shared between multiple domains, each of which need + // a channel ID. + virtual bool WasChannelIDSent() const; + + virtual void set_channel_id_sent(bool channel_id_sent); + + private: + // True if NPN was responded to, independent of selecting SPDY or HTTP. + bool was_npn_negotiated_; + // True if NPN successfully negotiated SPDY. + bool was_spdy_negotiated_; + // Protocol that we negotiated with the server. + NextProto protocol_negotiated_; + // True if a channel ID was sent. + bool channel_id_sent_; +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_CLIENT_SOCKET_H_ diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc new file mode 100644 index 00000000000..acc1b0dee2e --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -0,0 +1,3504 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file includes code SSLClientSocketNSS::DoVerifyCertComplete() derived +// from AuthCertificateCallback() in +// mozilla/security/manager/ssl/src/nsNSSCallbacks.cpp. + +/* ***** BEGIN LICENSE BLOCK ***** + * Version: MPL 1.1/GPL 2.0/LGPL 2.1 + * + * The contents of this file are subject to the Mozilla Public License Version + * 1.1 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * http://www.mozilla.org/MPL/ + * + * Software distributed under the License is distributed on an "AS IS" basis, + * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License + * for the specific language governing rights and limitations under the + * License. + * + * The Original Code is the Netscape security libraries. + * + * The Initial Developer of the Original Code is + * Netscape Communications Corporation. + * Portions created by the Initial Developer are Copyright (C) 2000 + * the Initial Developer. All Rights Reserved. + * + * Contributor(s): + * Ian McGreer <mcgreer@netscape.com> + * Javier Delgadillo <javi@netscape.com> + * Kai Engert <kengert@redhat.com> + * + * Alternatively, the contents of this file may be used under the terms of + * either the GNU General Public License Version 2 or later (the "GPL"), or + * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"), + * in which case the provisions of the GPL or the LGPL are applicable instead + * of those above. If you wish to allow use of your version of this file only + * under the terms of either the GPL or the LGPL, and not to allow others to + * use your version of this file under the terms of the MPL, indicate your + * decision by deleting the provisions above and replace them with the notice + * and other provisions required by the GPL or the LGPL. If you do not delete + * the provisions above, a recipient may use your version of this file under + * the terms of any one of the MPL, the GPL or the LGPL. + * + * ***** END LICENSE BLOCK ***** */ + +#include "net/socket/ssl_client_socket_nss.h" + +#include <certdb.h> +#include <hasht.h> +#include <keyhi.h> +#include <nspr.h> +#include <nss.h> +#include <ocsp.h> +#include <pk11pub.h> +#include <secerr.h> +#include <sechash.h> +#include <ssl.h> +#include <sslerr.h> +#include <sslproto.h> + +#include <algorithm> +#include <limits> +#include <map> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback_helpers.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/memory/singleton.h" +#include "base/metrics/histogram.h" +#include "base/single_thread_task_runner.h" +#include "base/stl_util.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/string_util.h" +#include "base/strings/stringprintf.h" +#include "base/thread_task_runner_handle.h" +#include "base/threading/thread_restrictions.h" +#include "base/values.h" +#include "crypto/ec_private_key.h" +#include "crypto/nss_util.h" +#include "crypto/nss_util_internal.h" +#include "crypto/rsa_private_key.h" +#include "crypto/scoped_nss_types.h" +#include "net/base/address_list.h" +#include "net/base/connection_type_histograms.h" +#include "net/base/dns_util.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/cert/asn1_util.h" +#include "net/cert/cert_status_flags.h" +#include "net/cert/cert_verifier.h" +#include "net/cert/single_request_cert_verifier.h" +#include "net/cert/x509_certificate_net_log_param.h" +#include "net/cert/x509_util.h" +#include "net/http/transport_security_state.h" +#include "net/ocsp/nss_ocsp.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/nss_ssl_util.h" +#include "net/socket/ssl_error_params.h" +#include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_info.h" + +#if defined(OS_WIN) +#include <windows.h> +#include <wincrypt.h> + +#include "base/win/windows_version.h" +#elif defined(OS_MACOSX) +#include <Security/SecBase.h> +#include <Security/SecCertificate.h> +#include <Security/SecIdentity.h> + +#include "base/mac/mac_logging.h" +#include "base/synchronization/lock.h" +#include "crypto/mac_security_services_lock.h" +#elif defined(USE_NSS) +#include <dlfcn.h> +#endif + +namespace net { + +// State machines are easier to debug if you log state transitions. +// Enable these if you want to see what's going on. +#if 1 +#define EnterFunction(x) +#define LeaveFunction(x) +#define GotoState(s) next_handshake_state_ = s +#else +#define EnterFunction(x)\ + VLOG(1) << (void *)this << " " << __FUNCTION__ << " enter " << x\ + << "; next_handshake_state " << next_handshake_state_ +#define LeaveFunction(x)\ + VLOG(1) << (void *)this << " " << __FUNCTION__ << " leave " << x\ + << "; next_handshake_state " << next_handshake_state_ +#define GotoState(s)\ + do {\ + VLOG(1) << (void *)this << " " << __FUNCTION__ << " jump to state " << s;\ + next_handshake_state_ = s;\ + } while (0) +#endif + +namespace { + +// SSL plaintext fragments are shorter than 16KB. Although the record layer +// overhead is allowed to be 2K + 5 bytes, in practice the overhead is much +// smaller than 1KB. So a 17KB buffer should be large enough to hold an +// entire SSL record. +const int kRecvBufferSize = 17 * 1024; +const int kSendBufferSize = 17 * 1024; + +// Used by SSLClientSocketNSS::Core to indicate there is no read result +// obtained by a previous operation waiting to be returned to the caller. +// This constant can be any non-negative/non-zero value (eg: it does not +// overlap with any value of the net::Error range, including net::OK). +const int kNoPendingReadResult = 1; + +#if defined(OS_WIN) +// CERT_OCSP_RESPONSE_PROP_ID is only implemented on Vista+, but it can be +// set on Windows XP without error. There is some overhead from the server +// sending the OCSP response if it supports the extension, for the subset of +// XP clients who will request it but be unable to use it, but this is an +// acceptable trade-off for simplicity of implementation. +bool IsOCSPStaplingSupported() { + return true; +} +#elif defined(USE_NSS) +typedef SECStatus +(*CacheOCSPResponseFromSideChannelFunction)( + CERTCertDBHandle *handle, CERTCertificate *cert, PRTime time, + SECItem *encodedResponse, void *pwArg); + +// On Linux, we dynamically link against the system version of libnss3.so. In +// order to continue working on systems without up-to-date versions of NSS we +// lookup CERT_CacheOCSPResponseFromSideChannel with dlsym. + +// RuntimeLibNSSFunctionPointers is a singleton which caches the results of any +// runtime symbol resolution that we need. +class RuntimeLibNSSFunctionPointers { + public: + CacheOCSPResponseFromSideChannelFunction + GetCacheOCSPResponseFromSideChannelFunction() { + return cache_ocsp_response_from_side_channel_; + } + + static RuntimeLibNSSFunctionPointers* GetInstance() { + return Singleton<RuntimeLibNSSFunctionPointers>::get(); + } + + private: + friend struct DefaultSingletonTraits<RuntimeLibNSSFunctionPointers>; + + RuntimeLibNSSFunctionPointers() { + cache_ocsp_response_from_side_channel_ = + (CacheOCSPResponseFromSideChannelFunction) + dlsym(RTLD_DEFAULT, "CERT_CacheOCSPResponseFromSideChannel"); + } + + CacheOCSPResponseFromSideChannelFunction + cache_ocsp_response_from_side_channel_; +}; + +CacheOCSPResponseFromSideChannelFunction +GetCacheOCSPResponseFromSideChannelFunction() { + return RuntimeLibNSSFunctionPointers::GetInstance() + ->GetCacheOCSPResponseFromSideChannelFunction(); +} + +bool IsOCSPStaplingSupported() { + return GetCacheOCSPResponseFromSideChannelFunction() != NULL; +} +#else +// TODO(agl): Figure out if we can plumb the OCSP response into Mac's system +// certificate validation functions. +bool IsOCSPStaplingSupported() { + return false; +} +#endif + +class FreeCERTCertificate { + public: + inline void operator()(CERTCertificate* x) const { + CERT_DestroyCertificate(x); + } +}; +typedef scoped_ptr_malloc<CERTCertificate, FreeCERTCertificate> + ScopedCERTCertificate; + +#if defined(OS_WIN) + +// This callback is intended to be used with CertFindChainInStore. In addition +// to filtering by extended/enhanced key usage, we do not show expired +// certificates and require digital signature usage in the key usage +// extension. +// +// This matches our behavior on Mac OS X and that of NSS. It also matches the +// default behavior of IE8. See http://support.microsoft.com/kb/890326 and +// http://blogs.msdn.com/b/askie/archive/2009/06/09/my-expired-client-certificates-no-longer-display-when-connecting-to-my-web-server-using-ie8.aspx +BOOL WINAPI ClientCertFindCallback(PCCERT_CONTEXT cert_context, + void* find_arg) { + VLOG(1) << "Calling ClientCertFindCallback from _nss"; + // Verify the certificate's KU is good. + BYTE key_usage; + if (CertGetIntendedKeyUsage(X509_ASN_ENCODING, cert_context->pCertInfo, + &key_usage, 1)) { + if (!(key_usage & CERT_DIGITAL_SIGNATURE_KEY_USAGE)) + return FALSE; + } else { + DWORD err = GetLastError(); + // If |err| is non-zero, it's an actual error. Otherwise the extension + // just isn't present, and we treat it as if everything was allowed. + if (err) { + DLOG(ERROR) << "CertGetIntendedKeyUsage failed: " << err; + return FALSE; + } + } + + // Verify the current time is within the certificate's validity period. + if (CertVerifyTimeValidity(NULL, cert_context->pCertInfo) != 0) + return FALSE; + + // Verify private key metadata is associated with this certificate. + DWORD size = 0; + if (!CertGetCertificateContextProperty( + cert_context, CERT_KEY_PROV_INFO_PROP_ID, NULL, &size)) { + return FALSE; + } + + return TRUE; +} + +#endif + +void DestroyCertificates(CERTCertificate** certs, size_t len) { + for (size_t i = 0; i < len; i++) + CERT_DestroyCertificate(certs[i]); +} + +// Helper functions to make it possible to log events from within the +// SSLClientSocketNSS::Core. +void AddLogEvent(const base::WeakPtr<BoundNetLog>& net_log, + NetLog::EventType event_type) { + if (!net_log) + return; + net_log->AddEvent(event_type); +} + +// Helper function to make it possible to log events from within the +// SSLClientSocketNSS::Core. +void AddLogEventWithCallback(const base::WeakPtr<BoundNetLog>& net_log, + NetLog::EventType event_type, + const NetLog::ParametersCallback& callback) { + if (!net_log) + return; + net_log->AddEvent(event_type, callback); +} + +// Helper function to make it easier to call BoundNetLog::AddByteTransferEvent +// from within the SSLClientSocketNSS::Core. +// AddByteTransferEvent expects to receive a const char*, which within the +// Core is backed by an IOBuffer. If the "const char*" is bound via +// base::Bind and posted to another thread, and the IOBuffer that backs that +// pointer then goes out of scope on the origin thread, this would result in +// an invalid read of a stale pointer. +// Instead, provide a signature that accepts an IOBuffer*, so that a reference +// to the owning IOBuffer can be bound to the Callback. This ensures that the +// IOBuffer will stay alive long enough to cross threads if needed. +void LogByteTransferEvent( + const base::WeakPtr<BoundNetLog>& net_log, NetLog::EventType event_type, + int len, IOBuffer* buffer) { + if (!net_log) + return; + net_log->AddByteTransferEvent(event_type, len, buffer->data()); +} + +// PeerCertificateChain is a helper object which extracts the certificate +// chain, as given by the server, from an NSS socket and performs the needed +// resource management. The first element of the chain is the leaf certificate +// and the other elements are in the order given by the server. +class PeerCertificateChain { + public: + PeerCertificateChain() {} + PeerCertificateChain(const PeerCertificateChain& other); + ~PeerCertificateChain(); + PeerCertificateChain& operator=(const PeerCertificateChain& other); + + // Resets the current chain, freeing any resources, and updates the current + // chain to be a copy of the chain stored in |nss_fd|. + // If |nss_fd| is NULL, then the current certificate chain will be freed. + void Reset(PRFileDesc* nss_fd); + + // Returns the current certificate chain as a vector of DER-encoded + // base::StringPieces. The returned vector remains valid until Reset is + // called. + std::vector<base::StringPiece> AsStringPieceVector() const; + + bool empty() const { return certs_.empty(); } + size_t size() const { return certs_.size(); } + + CERTCertificate* operator[](size_t index) const { + DCHECK_LT(index, certs_.size()); + return certs_[index]; + } + + private: + std::vector<CERTCertificate*> certs_; +}; + +PeerCertificateChain::PeerCertificateChain( + const PeerCertificateChain& other) { + *this = other; +} + +PeerCertificateChain::~PeerCertificateChain() { + Reset(NULL); +} + +PeerCertificateChain& PeerCertificateChain::operator=( + const PeerCertificateChain& other) { + if (this == &other) + return *this; + + Reset(NULL); + certs_.reserve(other.certs_.size()); + for (size_t i = 0; i < other.certs_.size(); ++i) + certs_.push_back(CERT_DupCertificate(other.certs_[i])); + + return *this; +} + +void PeerCertificateChain::Reset(PRFileDesc* nss_fd) { + for (size_t i = 0; i < certs_.size(); ++i) + CERT_DestroyCertificate(certs_[i]); + certs_.clear(); + + if (nss_fd == NULL) + return; + + unsigned int num_certs = 0; + SECStatus rv = SSL_PeerCertificateChain(nss_fd, NULL, &num_certs, 0); + DCHECK_EQ(SECSuccess, rv); + + // The handshake on |nss_fd| may not have completed. + if (num_certs == 0) + return; + + certs_.resize(num_certs); + const unsigned int expected_num_certs = num_certs; + rv = SSL_PeerCertificateChain(nss_fd, vector_as_array(&certs_), + &num_certs, expected_num_certs); + DCHECK_EQ(SECSuccess, rv); + DCHECK_EQ(expected_num_certs, num_certs); +} + +std::vector<base::StringPiece> +PeerCertificateChain::AsStringPieceVector() const { + std::vector<base::StringPiece> v(certs_.size()); + for (unsigned i = 0; i < certs_.size(); i++) { + v[i] = base::StringPiece( + reinterpret_cast<const char*>(certs_[i]->derCert.data), + certs_[i]->derCert.len); + } + + return v; +} + +// HandshakeState is a helper struct used to pass handshake state between +// the NSS task runner and the network task runner. +// +// It contains members that may be read or written on the NSS task runner, +// but which also need to be read from the network task runner. The NSS task +// runner will notify the network task runner whenever this state changes, so +// that the network task runner can safely make a copy, which avoids the need +// for locking. +struct HandshakeState { + HandshakeState() { Reset(); } + + void Reset() { + next_proto_status = SSLClientSocket::kNextProtoUnsupported; + next_proto.clear(); + server_protos.clear(); + channel_id_sent = false; + server_cert_chain.Reset(NULL); + server_cert = NULL; + resumed_handshake = false; + ssl_connection_status = 0; + } + + // Set to kNextProtoNegotiated if NPN was successfully negotiated, with the + // negotiated protocol stored in |next_proto|. + SSLClientSocket::NextProtoStatus next_proto_status; + std::string next_proto; + // If the server supports NPN, the protocols supported by the server. + std::string server_protos; + + // True if a channel ID was sent. + bool channel_id_sent; + + // List of DER-encoded X.509 DistinguishedName of certificate authorities + // allowed by the server. + std::vector<std::string> cert_authorities; + + // Set when the handshake fully completes. + // + // The server certificate is first received from NSS as an NSS certificate + // chain (|server_cert_chain|) and then converted into a platform-specific + // X509Certificate object (|server_cert|). It's possible for some + // certificates to be successfully parsed by NSS, and not by the platform + // libraries (i.e.: when running within a sandbox, different parsing + // algorithms, etc), so it's not safe to assume that |server_cert| will + // always be non-NULL. + PeerCertificateChain server_cert_chain; + scoped_refptr<X509Certificate> server_cert; + + // True if the current handshake was the result of TLS session resumption. + bool resumed_handshake; + + // The negotiated security parameters (TLS version, cipher, extensions) of + // the SSL connection. + int ssl_connection_status; +}; + +// Client-side error mapping functions. + +// Map NSS error code to network error code. +int MapNSSClientError(PRErrorCode err) { + switch (err) { + case SSL_ERROR_BAD_CERT_ALERT: + case SSL_ERROR_UNSUPPORTED_CERT_ALERT: + case SSL_ERROR_REVOKED_CERT_ALERT: + case SSL_ERROR_EXPIRED_CERT_ALERT: + case SSL_ERROR_CERTIFICATE_UNKNOWN_ALERT: + case SSL_ERROR_UNKNOWN_CA_ALERT: + case SSL_ERROR_ACCESS_DENIED_ALERT: + return ERR_BAD_SSL_CLIENT_AUTH_CERT; + default: + return MapNSSError(err); + } +} + +// Map NSS error code from the first SSL handshake to network error code. +int MapNSSClientHandshakeError(PRErrorCode err) { + switch (err) { + // If the server closed on us, it is a protocol error. + // Some TLS-intolerant servers do this when we request TLS. + case PR_END_OF_FILE_ERROR: + return ERR_SSL_PROTOCOL_ERROR; + default: + return MapNSSClientError(err); + } +} + +} // namespace + +// SSLClientSocketNSS::Core provides a thread-safe, ref-counted core that is +// able to marshal data between NSS functions and an underlying transport +// socket. +// +// All public functions are meant to be called from the network task runner, +// and any callbacks supplied will be invoked there as well, provided that +// Detach() has not been called yet. +// +///////////////////////////////////////////////////////////////////////////// +// +// Threading within SSLClientSocketNSS and SSLClientSocketNSS::Core: +// +// Because NSS may block on either hardware or user input during operations +// such as signing, creating certificates, or locating private keys, the Core +// handles all of the interactions with the underlying NSS SSL socket, so +// that these blocking calls can be executed on a dedicated task runner. +// +// Note that the network task runner and the NSS task runner may be executing +// on the same thread. If that happens, then it's more performant to try to +// complete as much work as possible synchronously, even if it might block, +// rather than continually PostTask-ing to the same thread. +// +// Because NSS functions should only be called on the NSS task runner, while +// I/O resources should only be accessed on the network task runner, most +// public functions are implemented via three methods, each with different +// task runner affinities. +// +// In the single-threaded mode (where the network and NSS task runners run on +// the same thread), these are all attempted synchronously, while in the +// multi-threaded mode, message passing is used. +// +// 1) NSS Task Runner: Execute NSS function (DoPayloadRead, DoPayloadWrite, +// DoHandshake) +// 2) NSS Task Runner: Prepare data to go from NSS to an IO function: +// (BufferRecv, BufferSend) +// 3) Network Task Runner: Perform IO on that data (DoBufferRecv, +// DoBufferSend, DoGetDomainBoundCert, OnGetDomainBoundCertComplete) +// 4) Both Task Runners: Callback for asynchronous completion or to marshal +// data from the network task runner back to NSS (BufferRecvComplete, +// BufferSendComplete, OnHandshakeIOComplete) +// +///////////////////////////////////////////////////////////////////////////// +// Single-threaded example +// +// |--------------------------Network Task Runner--------------------------| +// SSLClientSocketNSS Core (Transport Socket) +// Read() +// |-------------------------V +// Read() +// | +// DoPayloadRead() +// | +// BufferRecv() +// | +// DoBufferRecv() +// |-------------------------V +// Read() +// V-------------------------| +// BufferRecvComplete() +// | +// PostOrRunCallback() +// V-------------------------| +// (Read Callback) +// +///////////////////////////////////////////////////////////////////////////// +// Multi-threaded example: +// +// |--------------------Network Task Runner-------------|--NSS Task Runner--| +// SSLClientSocketNSS Core Socket Core +// Read() +// |---------------------V +// Read() +// |-------------------------------V +// Read() +// | +// DoPayloadRead() +// | +// BufferRecv +// V-------------------------------| +// DoBufferRecv +// |----------------V +// Read() +// V----------------| +// BufferRecvComplete() +// |-------------------------------V +// BufferRecvComplete() +// | +// PostOrRunCallback() +// V-------------------------------| +// PostOrRunCallback() +// V---------------------| +// (Read Callback) +// +///////////////////////////////////////////////////////////////////////////// +class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { + public: + // Creates a new Core. + // + // Any calls to NSS are executed on the |nss_task_runner|, while any calls + // that need to operate on the underlying transport, net log, or server + // bound certificate fetching will happen on the |network_task_runner|, so + // that their lifetimes match that of the owning SSLClientSocketNSS. + // + // The caller retains ownership of |transport|, |net_log|, and + // |server_bound_cert_service|, and they will not be accessed once Detach() + // has been called. + Core(base::SequencedTaskRunner* network_task_runner, + base::SequencedTaskRunner* nss_task_runner, + ClientSocketHandle* transport, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + BoundNetLog* net_log, + ServerBoundCertService* server_bound_cert_service); + + // Called on the network task runner. + // Transfers ownership of |socket|, an NSS SSL socket, and |buffers|, the + // underlying memio implementation, to the Core. Returns true if the Core + // was successfully registered with the socket. + bool Init(PRFileDesc* socket, memio_Private* buffers); + + // Called on the network task runner. + // Sets the predicted certificate chain that the peer will send, for use + // with the TLS CachedInfo extension. If called, it must not be called + // before Init() or after Connect(). + void SetPredictedCertificates( + const std::vector<std::string>& predicted_certificates); + + // Called on the network task runner. + // + // Attempts to perform an SSL handshake. If the handshake cannot be + // completed synchronously, returns ERR_IO_PENDING, invoking |callback| on + // the network task runner once the handshake has completed. Otherwise, + // returns OK on success or a network error code on failure. + int Connect(const CompletionCallback& callback); + + // Called on the network task runner. + // Signals that the resources owned by the network task runner are going + // away. No further callbacks will be invoked on the network task runner. + // May be called at any time. + void Detach(); + + // Called on the network task runner. + // Returns the current state of the underlying SSL socket. May be called at + // any time. + const HandshakeState& state() const { return network_handshake_state_; } + + // Called on the network task runner. + // Read() and Write() mirror the net::Socket functions of the same name. + // If ERR_IO_PENDING is returned, |callback| will be invoked on the network + // task runner at a later point, unless the caller calls Detach(). + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + + // Called on the network task runner. + bool IsConnected(); + bool HasPendingAsyncOperation(); + bool HasUnhandledReceivedData(); + + private: + friend class base::RefCountedThreadSafe<Core>; + ~Core(); + + enum State { + STATE_NONE, + STATE_HANDSHAKE, + STATE_GET_DOMAIN_BOUND_CERT_COMPLETE, + }; + + bool OnNSSTaskRunner() const; + bool OnNetworkTaskRunner() const; + + //////////////////////////////////////////////////////////////////////////// + // Methods that are ONLY called on the NSS task runner: + //////////////////////////////////////////////////////////////////////////// + + // Called by NSS during full handshakes to allow the application to + // verify the certificate. Instead of verifying the certificate in the midst + // of the handshake, SECSuccess is always returned and the peer's certificate + // is verified afterwards. + // This behaviour is an artifact of the original SSLClientSocketWin + // implementation, which could not verify the peer's certificate until after + // the handshake had completed, as well as bugs in NSS that prevent + // SSL_RestartHandshakeAfterCertReq from working. + static SECStatus OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server); + + // Callbacks called by NSS when the peer requests client certificate + // authentication. + // See the documentation in third_party/nss/ssl/ssl.h for the meanings of + // the arguments. +#if defined(NSS_PLATFORM_CLIENT_AUTH) + // When NSS has been integrated with awareness of the underlying system + // cryptographic libraries, this callback allows the caller to supply a + // native platform certificate and key for use by NSS. At most, one of + // either (result_certs, result_private_key) or (result_nss_certificate, + // result_nss_private_key) should be set. + // |arg| contains a pointer to the current SSLClientSocketNSS::Core. + static SECStatus PlatformClientAuthHandler( + void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertList** result_certs, + void** result_private_key, + CERTCertificate** result_nss_certificate, + SECKEYPrivateKey** result_nss_private_key); +#else + static SECStatus ClientAuthHandler(void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertificate** result_certificate, + SECKEYPrivateKey** result_private_key); +#endif + + // Called by NSS once the handshake has completed. + // |arg| contains a pointer to the current SSLClientSocketNSS::Core. + static void HandshakeCallback(PRFileDesc* socket, void* arg); + + // Handles an NSS error generated while handshaking or performing IO. + // Returns a network error code mapped from the original NSS error. + int HandleNSSError(PRErrorCode error, bool handshake_error); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + + int DoHandshake(); + int DoGetDBCertComplete(int result); + + int DoPayloadRead(); + int DoPayloadWrite(); + + bool DoTransportIO(); + int BufferRecv(); + int BufferSend(); + + void OnRecvComplete(int result); + void OnSendComplete(int result); + + void DoConnectCallback(int result); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + // Client channel ID handler. + static SECStatus ClientChannelIDHandler( + void* arg, + PRFileDesc* socket, + SECKEYPublicKey **out_public_key, + SECKEYPrivateKey **out_private_key); + + // ImportChannelIDKeys is a helper function for turning a DER-encoded cert and + // key into a SECKEYPublicKey and SECKEYPrivateKey. Returns OK upon success + // and an error code otherwise. + // Requires |domain_bound_private_key_| and |domain_bound_cert_| to have been + // set by a call to ServerBoundCertService->GetDomainBoundCert. The caller + // takes ownership of the |*cert| and |*key|. + int ImportChannelIDKeys(SECKEYPublicKey** public_key, SECKEYPrivateKey** key); + + // Updates the NSS and platform specific certificates. + void UpdateServerCert(); + // Updates the nss_handshake_state_ with the negotiated security parameters. + void UpdateConnectionStatus(); + // Record histograms for channel id support during full handshakes - resumed + // handshakes are ignored. + void RecordChannelIDSupport(); + // UpdateNextProto gets any application-layer protocol that may have been + // negotiated by the TLS connection. + void UpdateNextProto(); + + //////////////////////////////////////////////////////////////////////////// + // Methods that are ONLY called on the network task runner: + //////////////////////////////////////////////////////////////////////////// + int DoBufferRecv(IOBuffer* buffer, int len); + int DoBufferSend(IOBuffer* buffer, int len); + int DoGetDomainBoundCert(const std::string& host); + + void OnGetDomainBoundCertComplete(int result); + void OnHandshakeStateUpdated(const HandshakeState& state); + void OnNSSBufferUpdated(int amount_in_read_buffer); + void DidNSSRead(int result); + void DidNSSWrite(int result); + void RecordChannelIDSupportOnNetworkTaskRunner( + bool negotiated_channel_id, + bool channel_id_enabled, + bool supports_ecc) const; + + //////////////////////////////////////////////////////////////////////////// + // Methods that are called on both the network task runner and the NSS + // task runner. + //////////////////////////////////////////////////////////////////////////// + void OnHandshakeIOComplete(int result); + void BufferRecvComplete(IOBuffer* buffer, int result); + void BufferSendComplete(int result); + + // PostOrRunCallback is a helper function to ensure that |callback| is + // invoked on the network task runner, but only if Detach() has not yet + // been called. + void PostOrRunCallback(const tracked_objects::Location& location, + const base::Closure& callback); + + // Uses PostOrRunCallback and |weak_net_log_| to try and log a + // SSL_CLIENT_CERT_PROVIDED event, with the indicated count. + void AddCertProvidedEvent(int cert_count); + + // Sets the handshake state |channel_id_sent| flag and logs the + // SSL_CHANNEL_ID_PROVIDED event. + void SetChannelIDProvided(); + + //////////////////////////////////////////////////////////////////////////// + // Members that are ONLY accessed on the network task runner: + //////////////////////////////////////////////////////////////////////////// + + // True if the owning SSLClientSocketNSS has called Detach(). No further + // callbacks will be invoked nor access to members owned by the network + // task runner. + bool detached_; + + // The underlying transport to use for network IO. + ClientSocketHandle* transport_; + base::WeakPtrFactory<BoundNetLog> weak_net_log_factory_; + + // The current handshake state. Mirrors |nss_handshake_state_|. + HandshakeState network_handshake_state_; + + // The service for retrieving Channel ID keys. May be NULL. + ServerBoundCertService* server_bound_cert_service_; + ServerBoundCertService::RequestHandle domain_bound_cert_request_handle_; + + // The information about NSS task runner. + int unhandled_buffer_size_; + bool nss_waiting_read_; + bool nss_waiting_write_; + bool nss_is_closed_; + + //////////////////////////////////////////////////////////////////////////// + // Members that are ONLY accessed on the NSS task runner: + //////////////////////////////////////////////////////////////////////////// + HostPortPair host_and_port_; + SSLConfig ssl_config_; + + // NSS SSL socket. + PRFileDesc* nss_fd_; + + // Buffers for the network end of the SSL state machine + memio_Private* nss_bufs_; + + // Used by DoPayloadRead() when attempting to fill the caller's buffer with + // as much data as possible, without blocking. + // If DoPayloadRead() encounters an error after having read some data, stores + // the results to return on the *next* call to DoPayloadRead(). A value of + // kNoPendingReadResult indicates there is no pending result, otherwise 0 + // indicates EOF and < 0 indicates an error. + int pending_read_result_; + // Contains the previously observed NSS error. Only valid when + // pending_read_result_ != kNoPendingReadResult. + PRErrorCode pending_read_nss_error_; + + // The certificate chain, in DER form, that is expected to be received from + // the server. + std::vector<std::string> predicted_certs_; + + State next_handshake_state_; + + // True if channel ID extension was negotiated. + bool channel_id_xtn_negotiated_; + // True if the handshake state machine was interrupted for channel ID. + bool channel_id_needed_; + // True if the handshake state machine was interrupted for client auth. + bool client_auth_cert_needed_; + // True if NSS has called HandshakeCallback. + bool handshake_callback_called_; + + HandshakeState nss_handshake_state_; + + bool transport_recv_busy_; + bool transport_recv_eof_; + bool transport_send_busy_; + + // Used by Read function. + scoped_refptr<IOBuffer> user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr<IOBuffer> user_write_buf_; + int user_write_buf_len_; + + CompletionCallback user_connect_callback_; + CompletionCallback user_read_callback_; + CompletionCallback user_write_callback_; + + //////////////////////////////////////////////////////////////////////////// + // Members that are accessed on both the network task runner and the NSS + // task runner. + //////////////////////////////////////////////////////////////////////////// + scoped_refptr<base::SequencedTaskRunner> network_task_runner_; + scoped_refptr<base::SequencedTaskRunner> nss_task_runner_; + + // Dereferenced only on the network task runner, but bound to tasks destined + // for the network task runner from the NSS task runner. + base::WeakPtr<BoundNetLog> weak_net_log_; + + // Written on the network task runner by the |server_bound_cert_service_|, + // prior to invoking OnHandshakeIOComplete. + // Read on the NSS task runner when once OnHandshakeIOComplete is invoked + // on the NSS task runner. + std::string domain_bound_private_key_; + std::string domain_bound_cert_; + + DISALLOW_COPY_AND_ASSIGN(Core); +}; + +SSLClientSocketNSS::Core::Core( + base::SequencedTaskRunner* network_task_runner, + base::SequencedTaskRunner* nss_task_runner, + ClientSocketHandle* transport, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + BoundNetLog* net_log, + ServerBoundCertService* server_bound_cert_service) + : detached_(false), + transport_(transport), + weak_net_log_factory_(net_log), + server_bound_cert_service_(server_bound_cert_service), + unhandled_buffer_size_(0), + nss_waiting_read_(false), + nss_waiting_write_(false), + nss_is_closed_(false), + host_and_port_(host_and_port), + ssl_config_(ssl_config), + nss_fd_(NULL), + nss_bufs_(NULL), + pending_read_result_(kNoPendingReadResult), + pending_read_nss_error_(0), + next_handshake_state_(STATE_NONE), + channel_id_xtn_negotiated_(false), + channel_id_needed_(false), + client_auth_cert_needed_(false), + handshake_callback_called_(false), + transport_recv_busy_(false), + transport_recv_eof_(false), + transport_send_busy_(false), + user_read_buf_len_(0), + user_write_buf_len_(0), + network_task_runner_(network_task_runner), + nss_task_runner_(nss_task_runner), + weak_net_log_(weak_net_log_factory_.GetWeakPtr()) { +} + +SSLClientSocketNSS::Core::~Core() { + // TODO(wtc): Send SSL close_notify alert. + if (nss_fd_ != NULL) { + PR_Close(nss_fd_); + nss_fd_ = NULL; + } +} + +bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, + memio_Private* buffers) { + DCHECK(OnNetworkTaskRunner()); + DCHECK(!nss_fd_); + DCHECK(!nss_bufs_); + + nss_fd_ = socket; + nss_bufs_ = buffers; + + SECStatus rv = SECSuccess; + + if (!ssl_config_.next_protos.empty()) { + size_t wire_length = 0; + for (std::vector<std::string>::const_iterator + i = ssl_config_.next_protos.begin(); + i != ssl_config_.next_protos.end(); ++i) { + if (i->size() > 255) { + LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i; + continue; + } + wire_length += i->size(); + wire_length++; + } + scoped_ptr<uint8[]> wire_protos(new uint8[wire_length]); + uint8* dst = wire_protos.get(); + for (std::vector<std::string>::const_iterator + i = ssl_config_.next_protos.begin(); + i != ssl_config_.next_protos.end(); i++) { + if (i->size() > 255) + continue; + *dst++ = i->size(); + memcpy(dst, i->data(), i->size()); + dst += i->size(); + } + DCHECK_EQ(dst, wire_protos.get() + wire_length); + rv = SSL_SetNextProtoNego(nss_fd_, wire_protos.get(), wire_length); + if (rv != SECSuccess) + LogFailedNSSFunction(*weak_net_log_, "SSL_SetNextProtoCallback", ""); + } + + rv = SSL_AuthCertificateHook( + nss_fd_, SSLClientSocketNSS::Core::OwnAuthCertHandler, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(*weak_net_log_, "SSL_AuthCertificateHook", ""); + return false; + } + +#if defined(NSS_PLATFORM_CLIENT_AUTH) + rv = SSL_GetPlatformClientAuthDataHook( + nss_fd_, SSLClientSocketNSS::Core::PlatformClientAuthHandler, + this); +#else + rv = SSL_GetClientAuthDataHook( + nss_fd_, SSLClientSocketNSS::Core::ClientAuthHandler, this); +#endif + if (rv != SECSuccess) { + LogFailedNSSFunction(*weak_net_log_, "SSL_GetClientAuthDataHook", ""); + return false; + } + + if (ssl_config_.channel_id_enabled) { + if (!server_bound_cert_service_) { + DVLOG(1) << "NULL server_bound_cert_service_, not enabling channel ID."; + } else if (!crypto::ECPrivateKey::IsSupported()) { + DVLOG(1) << "Elliptic Curve not supported, not enabling channel ID."; + } else if (!server_bound_cert_service_->IsSystemTimeValid()) { + DVLOG(1) << "System time is weird, not enabling channel ID."; + } else { + rv = SSL_SetClientChannelIDCallback( + nss_fd_, SSLClientSocketNSS::Core::ClientChannelIDHandler, this); + if (rv != SECSuccess) + LogFailedNSSFunction(*weak_net_log_, "SSL_SetClientChannelIDCallback", + ""); + } + } + + rv = SSL_HandshakeCallback( + nss_fd_, SSLClientSocketNSS::Core::HandshakeCallback, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(*weak_net_log_, "SSL_HandshakeCallback", ""); + return false; + } + + return true; +} + +void SSLClientSocketNSS::Core::SetPredictedCertificates( + const std::vector<std::string>& predicted_certs) { + if (predicted_certs.empty()) + return; + + if (!OnNSSTaskRunner()) { + DCHECK(!detached_); + nss_task_runner_->PostTask( + FROM_HERE, + base::Bind(&Core::SetPredictedCertificates, this, predicted_certs)); + return; + } + + DCHECK(nss_fd_); + + predicted_certs_ = predicted_certs; + + scoped_ptr<CERTCertificate*[]> certs( + new CERTCertificate*[predicted_certs.size()]); + + for (size_t i = 0; i < predicted_certs.size(); i++) { + SECItem derCert; + derCert.data = const_cast<uint8*>(reinterpret_cast<const uint8*>( + predicted_certs[i].data())); + derCert.len = predicted_certs[i].size(); + certs[i] = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &derCert, NULL /* no nickname given */, + PR_FALSE /* not permanent */, PR_TRUE /* copy DER data */); + if (!certs[i]) { + DestroyCertificates(&certs[0], i); + NOTREACHED(); + return; + } + } + + SECStatus rv; +#ifdef SSL_ENABLE_CACHED_INFO + rv = SSL_SetPredictedPeerCertificates(nss_fd_, certs.get(), + predicted_certs.size()); + DCHECK_EQ(SECSuccess, rv); +#else + rv = SECFailure; // Not implemented. +#endif + DestroyCertificates(&certs[0], predicted_certs.size()); + + if (rv != SECSuccess) { + LOG(WARNING) << "SetPredictedCertificates failed: " + << host_and_port_.ToString(); + } +} + +int SSLClientSocketNSS::Core::Connect(const CompletionCallback& callback) { + if (!OnNSSTaskRunner()) { + DCHECK(!detached_); + bool posted = nss_task_runner_->PostTask( + FROM_HERE, + base::Bind(IgnoreResult(&Core::Connect), this, callback)); + return posted ? ERR_IO_PENDING : ERR_ABORTED; + } + + DCHECK(OnNSSTaskRunner()); + DCHECK_EQ(STATE_NONE, next_handshake_state_); + DCHECK(user_read_callback_.is_null()); + DCHECK(user_write_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); + DCHECK(!user_read_buf_.get()); + DCHECK(!user_write_buf_.get()); + + next_handshake_state_ = STATE_HANDSHAKE; + int rv = DoHandshakeLoop(OK); + if (rv == ERR_IO_PENDING) { + user_connect_callback_ = callback; + } else if (rv > OK) { + rv = OK; + } + if (rv != ERR_IO_PENDING && !OnNetworkTaskRunner()) { + PostOrRunCallback(FROM_HERE, base::Bind(callback, rv)); + return ERR_IO_PENDING; + } + + return rv; +} + +void SSLClientSocketNSS::Core::Detach() { + DCHECK(OnNetworkTaskRunner()); + + detached_ = true; + transport_ = NULL; + weak_net_log_factory_.InvalidateWeakPtrs(); + + network_handshake_state_.Reset(); + + domain_bound_cert_request_handle_.Cancel(); +} + +int SSLClientSocketNSS::Core::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (!OnNSSTaskRunner()) { + DCHECK(OnNetworkTaskRunner()); + DCHECK(!detached_); + DCHECK(transport_); + DCHECK(!nss_waiting_read_); + + nss_waiting_read_ = true; + bool posted = nss_task_runner_->PostTask( + FROM_HERE, + base::Bind(IgnoreResult(&Core::Read), this, make_scoped_refptr(buf), + buf_len, callback)); + if (!posted) { + nss_is_closed_ = true; + nss_waiting_read_ = false; + } + return posted ? ERR_IO_PENDING : ERR_ABORTED; + } + + DCHECK(OnNSSTaskRunner()); + DCHECK(handshake_callback_called_); + DCHECK_EQ(STATE_NONE, next_handshake_state_); + DCHECK(user_read_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); + DCHECK(!user_read_buf_.get()); + DCHECK(nss_bufs_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoReadLoop(OK); + if (rv == ERR_IO_PENDING) { + if (OnNetworkTaskRunner()) + nss_waiting_read_ = true; + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + + if (!OnNetworkTaskRunner()) { + PostOrRunCallback(FROM_HERE, base::Bind(&Core::DidNSSRead, this, rv)); + PostOrRunCallback(FROM_HERE, base::Bind(callback, rv)); + return ERR_IO_PENDING; + } else { + DCHECK(!nss_waiting_read_); + if (rv <= 0) + nss_is_closed_ = true; + } + } + + return rv; +} + +int SSLClientSocketNSS::Core::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (!OnNSSTaskRunner()) { + DCHECK(OnNetworkTaskRunner()); + DCHECK(!detached_); + DCHECK(transport_); + DCHECK(!nss_waiting_write_); + + nss_waiting_write_ = true; + bool posted = nss_task_runner_->PostTask( + FROM_HERE, + base::Bind(IgnoreResult(&Core::Write), this, make_scoped_refptr(buf), + buf_len, callback)); + if (!posted) { + nss_is_closed_ = true; + nss_waiting_write_ = false; + } + return posted ? ERR_IO_PENDING : ERR_ABORTED; + } + + DCHECK(OnNSSTaskRunner()); + DCHECK(handshake_callback_called_); + DCHECK_EQ(STATE_NONE, next_handshake_state_); + DCHECK(user_write_callback_.is_null()); + DCHECK(user_connect_callback_.is_null()); + DCHECK(!user_write_buf_.get()); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + if (rv == ERR_IO_PENDING) { + if (OnNetworkTaskRunner()) + nss_waiting_write_ = true; + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + + if (!OnNetworkTaskRunner()) { + PostOrRunCallback(FROM_HERE, base::Bind(&Core::DidNSSWrite, this, rv)); + PostOrRunCallback(FROM_HERE, base::Bind(callback, rv)); + return ERR_IO_PENDING; + } else { + DCHECK(!nss_waiting_write_); + if (rv < 0) + nss_is_closed_ = true; + } + } + + return rv; +} + +bool SSLClientSocketNSS::Core::IsConnected() { + DCHECK(OnNetworkTaskRunner()); + return !nss_is_closed_; +} + +bool SSLClientSocketNSS::Core::HasPendingAsyncOperation() { + DCHECK(OnNetworkTaskRunner()); + return nss_waiting_read_ || nss_waiting_write_; +} + +bool SSLClientSocketNSS::Core::HasUnhandledReceivedData() { + DCHECK(OnNetworkTaskRunner()); + return unhandled_buffer_size_ != 0; +} + +bool SSLClientSocketNSS::Core::OnNSSTaskRunner() const { + return nss_task_runner_->RunsTasksOnCurrentThread(); +} + +bool SSLClientSocketNSS::Core::OnNetworkTaskRunner() const { + return network_task_runner_->RunsTasksOnCurrentThread(); +} + +// static +SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler( + void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { + Core* core = reinterpret_cast<Core*>(arg); + if (!core->handshake_callback_called_) { + // Only need to turn off False Start in the initial handshake. Also, it is + // unsafe to call SSL_OptionSet in a renegotiation because the "first + // handshake" lock isn't already held, which will result in an assertion + // failure in the ssl_Get1stHandshakeLock call in SSL_OptionSet. + PRBool negotiated_extension; + SECStatus rv = SSL_HandshakeNegotiatedExtension(socket, + ssl_app_layer_protocol_xtn, + &negotiated_extension); + if (rv != SECSuccess || !negotiated_extension) { + rv = SSL_HandshakeNegotiatedExtension(socket, + ssl_next_proto_nego_xtn, + &negotiated_extension); + } + if (rv != SECSuccess || !negotiated_extension) { + // If the server doesn't support NPN or ALPN, then we don't do False + // Start with it. + SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE); + } + } + + // Tell NSS to not verify the certificate. + return SECSuccess; +} + +#if defined(NSS_PLATFORM_CLIENT_AUTH) +// static +SECStatus SSLClientSocketNSS::Core::PlatformClientAuthHandler( + void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertList** result_certs, + void** result_private_key, + CERTCertificate** result_nss_certificate, + SECKEYPrivateKey** result_nss_private_key) { + Core* core = reinterpret_cast<Core*>(arg); + DCHECK(core->OnNSSTaskRunner()); + + core->PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEvent, core->weak_net_log_, + NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED)); + + core->client_auth_cert_needed_ = !core->ssl_config_.send_client_cert; +#if defined(OS_WIN) + if (core->ssl_config_.send_client_cert) { + if (core->ssl_config_.client_cert) { + PCCERT_CONTEXT cert_context = + core->ssl_config_.client_cert->os_cert_handle(); + + HCRYPTPROV_OR_NCRYPT_KEY_HANDLE crypt_prov = 0; + DWORD key_spec = 0; + BOOL must_free = FALSE; + DWORD flags = 0; + if (base::win::GetVersion() >= base::win::VERSION_VISTA) + flags |= CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG; + + BOOL acquired_key = CryptAcquireCertificatePrivateKey( + cert_context, flags, NULL, &crypt_prov, &key_spec, &must_free); + + if (acquired_key) { + // Should never get a cached handle back - ownership must always be + // transferred. + CHECK_EQ(must_free, TRUE); + + SECItem der_cert; + der_cert.type = siDERCertBuffer; + der_cert.data = cert_context->pbCertEncoded; + der_cert.len = cert_context->cbCertEncoded; + + // TODO(rsleevi): Error checking for NSS allocation errors. + CERTCertDBHandle* db_handle = CERT_GetDefaultCertDB(); + CERTCertificate* user_cert = CERT_NewTempCertificate( + db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); + if (!user_cert) { + // Importing the certificate can fail for reasons including a serial + // number collision. See crbug.com/97355. + core->AddCertProvidedEvent(0); + return SECFailure; + } + CERTCertList* cert_chain = CERT_NewCertList(); + CERT_AddCertToListTail(cert_chain, user_cert); + + // Add the intermediates. + X509Certificate::OSCertHandles intermediates = + core->ssl_config_.client_cert->GetIntermediateCertificates(); + for (X509Certificate::OSCertHandles::const_iterator it = + intermediates.begin(); it != intermediates.end(); ++it) { + der_cert.data = (*it)->pbCertEncoded; + der_cert.len = (*it)->cbCertEncoded; + + CERTCertificate* intermediate = CERT_NewTempCertificate( + db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); + if (!intermediate) { + CERT_DestroyCertList(cert_chain); + core->AddCertProvidedEvent(0); + return SECFailure; + } + CERT_AddCertToListTail(cert_chain, intermediate); + } + PCERT_KEY_CONTEXT key_context = reinterpret_cast<PCERT_KEY_CONTEXT>( + PORT_ZAlloc(sizeof(CERT_KEY_CONTEXT))); + key_context->cbSize = sizeof(*key_context); + // NSS will free this context when no longer in use. + key_context->hCryptProv = crypt_prov; + key_context->dwKeySpec = key_spec; + *result_private_key = key_context; + *result_certs = cert_chain; + + int cert_count = 1 + intermediates.size(); + core->AddCertProvidedEvent(cert_count); + return SECSuccess; + } + LOG(WARNING) << "Client cert found without private key"; + } + + // Send no client certificate. + core->AddCertProvidedEvent(0); + return SECFailure; + } + + core->nss_handshake_state_.cert_authorities.clear(); + + std::vector<CERT_NAME_BLOB> issuer_list(ca_names->nnames); + for (int i = 0; i < ca_names->nnames; ++i) { + issuer_list[i].cbData = ca_names->names[i].len; + issuer_list[i].pbData = ca_names->names[i].data; + core->nss_handshake_state_.cert_authorities.push_back(std::string( + reinterpret_cast<const char*>(ca_names->names[i].data), + static_cast<size_t>(ca_names->names[i].len))); + } + + // Update the network task runner's view of the handshake state now that + // server certificate request has been recorded. + core->PostOrRunCallback( + FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, + core->nss_handshake_state_)); + + // Tell NSS to suspend the client authentication. We will then abort the + // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + return SECWouldBlock; +#elif defined(OS_MACOSX) + if (core->ssl_config_.send_client_cert) { + if (core->ssl_config_.client_cert.get()) { + OSStatus os_error = noErr; + SecIdentityRef identity = NULL; + SecKeyRef private_key = NULL; + X509Certificate::OSCertHandles chain; + { + base::AutoLock lock(crypto::GetMacSecurityServicesLock()); + os_error = SecIdentityCreateWithCertificate( + NULL, core->ssl_config_.client_cert->os_cert_handle(), &identity); + } + if (os_error == noErr) { + os_error = SecIdentityCopyPrivateKey(identity, &private_key); + CFRelease(identity); + } + + if (os_error == noErr) { + // TODO(rsleevi): Error checking for NSS allocation errors. + *result_certs = CERT_NewCertList(); + *result_private_key = private_key; + + chain.push_back(core->ssl_config_.client_cert->os_cert_handle()); + const X509Certificate::OSCertHandles& intermediates = + core->ssl_config_.client_cert->GetIntermediateCertificates(); + if (!intermediates.empty()) + chain.insert(chain.end(), intermediates.begin(), intermediates.end()); + + for (size_t i = 0, chain_count = chain.size(); i < chain_count; ++i) { + CSSM_DATA cert_data; + SecCertificateRef cert_ref = chain[i]; + os_error = SecCertificateGetData(cert_ref, &cert_data); + if (os_error != noErr) + break; + + SECItem der_cert; + der_cert.type = siDERCertBuffer; + der_cert.data = cert_data.Data; + der_cert.len = cert_data.Length; + CERTCertificate* nss_cert = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); + if (!nss_cert) { + // In the event of an NSS error, make up an OS error and reuse + // the error handling below. + os_error = errSecCreateChainFailed; + break; + } + CERT_AddCertToListTail(*result_certs, nss_cert); + } + } + + if (os_error == noErr) { + core->AddCertProvidedEvent(chain.size()); + return SECSuccess; + } + + OSSTATUS_LOG(WARNING, os_error) + << "Client cert found, but could not be used"; + if (*result_certs) { + CERT_DestroyCertList(*result_certs); + *result_certs = NULL; + } + if (*result_private_key) + *result_private_key = NULL; + if (private_key) + CFRelease(private_key); + } + + // Send no client certificate. + core->AddCertProvidedEvent(0); + return SECFailure; + } + + core->nss_handshake_state_.cert_authorities.clear(); + + // Retrieve the cert issuers accepted by the server. + std::vector<CertPrincipal> valid_issuers; + int n = ca_names->nnames; + for (int i = 0; i < n; i++) { + core->nss_handshake_state_.cert_authorities.push_back(std::string( + reinterpret_cast<const char*>(ca_names->names[i].data), + static_cast<size_t>(ca_names->names[i].len))); + } + + // Update the network task runner's view of the handshake state now that + // server certificate request has been recorded. + core->PostOrRunCallback( + FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, + core->nss_handshake_state_)); + + // Tell NSS to suspend the client authentication. We will then abort the + // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + return SECWouldBlock; +#else + return SECFailure; +#endif +} + +#elif defined(OS_IOS) + +SECStatus SSLClientSocketNSS::Core::ClientAuthHandler( + void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertificate** result_certificate, + SECKEYPrivateKey** result_private_key) { + Core* core = reinterpret_cast<Core*>(arg); + DCHECK(core->OnNSSTaskRunner()); + + core->PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEvent, core->weak_net_log_, + NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED)); + + // TODO(droger): Support client auth on iOS. See http://crbug.com/145954). + LOG(WARNING) << "Client auth is not supported"; + + // Never send a certificate. + core->AddCertProvidedEvent(0); + return SECFailure; +} + +#else // NSS_PLATFORM_CLIENT_AUTH + +// static +// Based on Mozilla's NSS_GetClientAuthData. +SECStatus SSLClientSocketNSS::Core::ClientAuthHandler( + void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertificate** result_certificate, + SECKEYPrivateKey** result_private_key) { + Core* core = reinterpret_cast<Core*>(arg); + DCHECK(core->OnNSSTaskRunner()); + + core->PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEvent, core->weak_net_log_, + NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED)); + + // Regular client certificate requested. + core->client_auth_cert_needed_ = !core->ssl_config_.send_client_cert; + void* wincx = SSL_RevealPinArg(socket); + + if (core->ssl_config_.send_client_cert) { + // Second pass: a client certificate should have been selected. + if (core->ssl_config_.client_cert.get()) { + CERTCertificate* cert = + CERT_DupCertificate(core->ssl_config_.client_cert->os_cert_handle()); + SECKEYPrivateKey* privkey = PK11_FindKeyByAnyCert(cert, wincx); + if (privkey) { + // TODO(jsorianopastor): We should wait for server certificate + // verification before sending our credentials. See + // http://crbug.com/13934. + *result_certificate = cert; + *result_private_key = privkey; + // A cert_count of -1 means the number of certificates is unknown. + // NSS will construct the certificate chain. + core->AddCertProvidedEvent(-1); + + return SECSuccess; + } + LOG(WARNING) << "Client cert found without private key"; + } + // Send no client certificate. + core->AddCertProvidedEvent(0); + return SECFailure; + } + + // First pass: client certificate is needed. + core->nss_handshake_state_.cert_authorities.clear(); + + // Retrieve the DER-encoded DistinguishedName of the cert issuers accepted by + // the server and save them in |cert_authorities|. + for (int i = 0; i < ca_names->nnames; i++) { + core->nss_handshake_state_.cert_authorities.push_back(std::string( + reinterpret_cast<const char*>(ca_names->names[i].data), + static_cast<size_t>(ca_names->names[i].len))); + } + + // Update the network task runner's view of the handshake state now that + // server certificate request has been recorded. + core->PostOrRunCallback( + FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, + core->nss_handshake_state_)); + + // Tell NSS to suspend the client authentication. We will then abort the + // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + return SECWouldBlock; +} +#endif // NSS_PLATFORM_CLIENT_AUTH + +// static +void SSLClientSocketNSS::Core::HandshakeCallback( + PRFileDesc* socket, + void* arg) { + Core* core = reinterpret_cast<Core*>(arg); + DCHECK(core->OnNSSTaskRunner()); + + core->handshake_callback_called_ = true; + + HandshakeState* nss_state = &core->nss_handshake_state_; + + PRBool last_handshake_resumed; + SECStatus rv = SSL_HandshakeResumedSession(socket, &last_handshake_resumed); + if (rv == SECSuccess && last_handshake_resumed) { + nss_state->resumed_handshake = true; + } else { + nss_state->resumed_handshake = false; + } + + core->RecordChannelIDSupport(); + core->UpdateServerCert(); + core->UpdateConnectionStatus(); + core->UpdateNextProto(); + + // Update the network task runners view of the handshake state whenever + // a handshake has completed. + core->PostOrRunCallback( + FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, + *nss_state)); +} + +int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error, + bool handshake_error) { + DCHECK(OnNSSTaskRunner()); + + int net_error = handshake_error ? MapNSSClientHandshakeError(nss_error) : + MapNSSClientError(nss_error); + +#if defined(OS_WIN) + // On Windows, a handle to the HCRYPTPROV is cached in the X509Certificate + // os_cert_handle() as an optimization. However, if the certificate + // private key is stored on a smart card, and the smart card is removed, + // the cached HCRYPTPROV will not be able to obtain the HCRYPTKEY again, + // preventing client certificate authentication. Because the + // X509Certificate may outlive the individual SSLClientSocketNSS, due to + // caching in X509Certificate, this failure ends up preventing client + // certificate authentication with the same certificate for all future + // attempts, even after the smart card has been re-inserted. By setting + // the CERT_KEY_PROV_HANDLE_PROP_ID to NULL, the cached HCRYPTPROV will + // typically be freed. This allows a new HCRYPTPROV to be obtained from + // the certificate on the next attempt, which should succeed if the smart + // card has been re-inserted, or will typically prompt the user to + // re-insert the smart card if not. + if ((net_error == ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY || + net_error == ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED) && + ssl_config_.send_client_cert && ssl_config_.client_cert) { + CertSetCertificateContextProperty( + ssl_config_.client_cert->os_cert_handle(), + CERT_KEY_PROV_HANDLE_PROP_ID, 0, NULL); + } +#endif + + return net_error; +} + +int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) { + DCHECK(OnNSSTaskRunner()); + + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + + switch (state) { + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + case STATE_GET_DOMAIN_BOUND_CERT_COMPLETE: + rv = DoGetDBCertComplete(rv); + break; + case STATE_NONE: + default: + rv = ERR_UNEXPECTED; + LOG(DFATAL) << "unexpected state " << state; + break; + } + + // Do the actual network I/O + bool network_moved = DoTransportIO(); + if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) { + // In general we exit the loop if rv is ERR_IO_PENDING. In this + // special case we keep looping even if rv is ERR_IO_PENDING because + // the transport IO may allow DoHandshake to make progress. + DCHECK(rv == OK || rv == ERR_IO_PENDING); + rv = OK; // This causes us to stay in the loop. + } + } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); + return rv; +} + +int SSLClientSocketNSS::Core::DoReadLoop(int result) { + DCHECK(OnNSSTaskRunner()); + DCHECK(handshake_callback_called_); + DCHECK_EQ(STATE_NONE, next_handshake_state_); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogSSLErrorCallback(rv, 0))); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadRead(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + + return rv; +} + +int SSLClientSocketNSS::Core::DoWriteLoop(int result) { + DCHECK(OnNSSTaskRunner()); + DCHECK(handshake_callback_called_); + DCHECK_EQ(STATE_NONE, next_handshake_state_); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogSSLErrorCallback(rv, 0))); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + + LeaveFunction(rv); + return rv; +} + +int SSLClientSocketNSS::Core::DoHandshake() { + DCHECK(OnNSSTaskRunner()); + + int net_error = net::OK; + SECStatus rv = SSL_ForceHandshake(nss_fd_); + + // Note: this function may be called multiple times during the handshake, so + // even though channel id and client auth are separate else cases, they can + // both be used during a single SSL handshake. + if (channel_id_needed_) { + GotoState(STATE_GET_DOMAIN_BOUND_CERT_COMPLETE); + net_error = ERR_IO_PENDING; + } else if (client_auth_cert_needed_) { + net_error = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogSSLErrorCallback(net_error, 0))); + + // If the handshake already succeeded (because the server requests but + // doesn't require a client cert), we need to invalidate the SSL session + // so that we won't try to resume the non-client-authenticated session in + // the next handshake. This will cause the server to ask for a client + // cert again. + if (rv == SECSuccess && SSL_InvalidateSession(nss_fd_) != SECSuccess) + LOG(WARNING) << "Couldn't invalidate SSL session: " << PR_GetError(); + } else if (rv == SECSuccess) { + if (!handshake_callback_called_) { + // Workaround for https://bugzilla.mozilla.org/show_bug.cgi?id=562434 - + // SSL_ForceHandshake returned SECSuccess prematurely. + rv = SECFailure; + net_error = ERR_SSL_PROTOCOL_ERROR; + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogSSLErrorCallback(net_error, 0))); + } else { + #if defined(SSL_ENABLE_OCSP_STAPLING) + // TODO(agl): figure out how to plumb an OCSP response into the Mac + // system library and update IsOCSPStaplingSupported for Mac. + if (IsOCSPStaplingSupported()) { + const SECItemArray* ocsp_responses = + SSL_PeerStapledOCSPResponses(nss_fd_); + if (ocsp_responses->len) { + #if defined(OS_WIN) + if (nss_handshake_state_.server_cert) { + CRYPT_DATA_BLOB ocsp_response_blob; + ocsp_response_blob.cbData = ocsp_responses->items[0].len; + ocsp_response_blob.pbData = ocsp_responses->items[0].data; + BOOL ok = CertSetCertificateContextProperty( + nss_handshake_state_.server_cert->os_cert_handle(), + CERT_OCSP_RESPONSE_PROP_ID, + CERT_SET_PROPERTY_IGNORE_PERSIST_ERROR_FLAG, + &ocsp_response_blob); + if (!ok) { + VLOG(1) << "Failed to set OCSP response property: " + << GetLastError(); + } + } + #elif defined(USE_NSS) + CacheOCSPResponseFromSideChannelFunction cache_ocsp_response = + GetCacheOCSPResponseFromSideChannelFunction(); + + cache_ocsp_response( + CERT_GetDefaultCertDB(), + nss_handshake_state_.server_cert_chain[0], PR_Now(), + &ocsp_responses->items[0], NULL); + #endif + } + } + #endif + } + // Done! + } else { + PRErrorCode prerr = PR_GetError(); + net_error = HandleNSSError(prerr, true); + + // Some network devices that inspect application-layer packets seem to + // inject TCP reset packets to break the connections when they see + // TLS 1.1 in ClientHello or ServerHello. See http://crbug.com/130293. + // + // Only allow ERR_CONNECTION_RESET to trigger a fallback from TLS 1.1 or + // 1.2. We don't lose much in this fallback because the explicit IV for CBC + // mode in TLS 1.1 is approximated by record splitting in TLS 1.0. The + // fallback will be more painful for TLS 1.2 when we have GCM support. + // + // ERR_CONNECTION_RESET is a common network error, so we don't want it + // to trigger a version fallback in general, especially the TLS 1.0 -> + // SSL 3.0 fallback, which would drop TLS extensions. + if (prerr == PR_CONNECT_RESET_ERROR && + ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1) { + net_error = ERR_SSL_PROTOCOL_ERROR; + } + + // If not done, stay in this state + if (net_error == ERR_IO_PENDING) { + GotoState(STATE_HANDSHAKE); + } else { + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogSSLErrorCallback(net_error, prerr))); + } + } + + return net_error; +} + +int SSLClientSocketNSS::Core::DoGetDBCertComplete(int result) { + SECStatus rv; + PostOrRunCallback( + FROM_HERE, + base::Bind(&BoundNetLog::EndEventWithNetErrorCode, weak_net_log_, + NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT, result)); + + channel_id_needed_ = false; + + if (result != OK) + return result; + + SECKEYPublicKey* public_key; + SECKEYPrivateKey* private_key; + int error = ImportChannelIDKeys(&public_key, &private_key); + if (error != OK) + return error; + + rv = SSL_RestartHandshakeAfterChannelIDReq(nss_fd_, public_key, private_key); + if (rv != SECSuccess) + return MapNSSError(PORT_GetError()); + + SetChannelIDProvided(); + GotoState(STATE_HANDSHAKE); + return OK; +} + +int SSLClientSocketNSS::Core::DoPayloadRead() { + DCHECK(OnNSSTaskRunner()); + DCHECK(user_read_buf_.get()); + DCHECK_GT(user_read_buf_len_, 0); + + int rv; + // If a previous greedy read resulted in an error that was not consumed (eg: + // due to the caller having read some data successfully), then return that + // pending error now. + if (pending_read_result_ != kNoPendingReadResult) { + rv = pending_read_result_; + PRErrorCode prerr = pending_read_nss_error_; + pending_read_result_ = kNoPendingReadResult; + pending_read_nss_error_ = 0; + + if (rv == 0) { + PostOrRunCallback( + FROM_HERE, + base::Bind(&LogByteTransferEvent, weak_net_log_, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, + scoped_refptr<IOBuffer>(user_read_buf_))); + } else { + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogSSLErrorCallback(rv, prerr))); + } + return rv; + } + + // Perform a greedy read, attempting to read as much as the caller has + // requested. In the current NSS implementation, PR_Read will return + // exactly one SSL application data record's worth of data per invocation. + // The record size is dictated by the server, and may be noticeably smaller + // than the caller's buffer. This may be as little as a single byte, if the + // server is performing 1/n-1 record splitting. + // + // However, this greedy read may result in renegotiations/re-handshakes + // happening or may lead to some data being read, followed by an EOF (such as + // a TLS close-notify). If at least some data was read, then that result + // should be deferred until the next call to DoPayloadRead(). Otherwise, if no + // data was read, it's safe to return the error or EOF immediately. + int total_bytes_read = 0; + do { + rv = PR_Read(nss_fd_, user_read_buf_->data() + total_bytes_read, + user_read_buf_len_ - total_bytes_read); + if (rv > 0) + total_bytes_read += rv; + } while (total_bytes_read < user_read_buf_len_ && rv > 0); + int amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_); + PostOrRunCallback(FROM_HERE, base::Bind(&Core::OnNSSBufferUpdated, this, + amount_in_read_buffer)); + + if (total_bytes_read == user_read_buf_len_) { + // The caller's entire request was satisfied without error. No further + // processing needed. + rv = total_bytes_read; + } else { + // Otherwise, an error occurred (rv <= 0). The error needs to be handled + // immediately, while the NSPR/NSS errors are still available in + // thread-local storage. However, the handled/remapped error code should + // only be returned if no application data was already read; if it was, the + // error code should be deferred until the next call of DoPayloadRead. + // + // If no data was read, |*next_result| will point to the return value of + // this function. If at least some data was read, |*next_result| will point + // to |pending_read_error_|, to be returned in a future call to + // DoPayloadRead() (e.g.: after the current data is handled). + int* next_result = &rv; + if (total_bytes_read > 0) { + pending_read_result_ = rv; + rv = total_bytes_read; + next_result = &pending_read_result_; + } + + if (client_auth_cert_needed_) { + *next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; + pending_read_nss_error_ = 0; + } else if (*next_result < 0) { + // If *next_result == 0, then that indicates EOF, and no special error + // handling is needed. + pending_read_nss_error_ = PR_GetError(); + *next_result = HandleNSSError(pending_read_nss_error_, false); + if (rv > 0 && *next_result == ERR_IO_PENDING) { + // If at least some data was read from PR_Read(), do not treat + // insufficient data as an error to return in the next call to + // DoPayloadRead() - instead, let the call fall through to check + // PR_Read() again. This is because DoTransportIO() may complete + // in between the next call to DoPayloadRead(), and thus it is + // important to check PR_Read() on subsequent invocations to see + // if a complete record may now be read. + pending_read_nss_error_ = 0; + pending_read_result_ = kNoPendingReadResult; + } + } + } + + DCHECK_NE(ERR_IO_PENDING, pending_read_result_); + + if (rv >= 0) { + PostOrRunCallback( + FROM_HERE, + base::Bind(&LogByteTransferEvent, weak_net_log_, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, + scoped_refptr<IOBuffer>(user_read_buf_))); + } else if (rv != ERR_IO_PENDING) { + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogSSLErrorCallback(rv, pending_read_nss_error_))); + pending_read_nss_error_ = 0; + } + return rv; +} + +int SSLClientSocketNSS::Core::DoPayloadWrite() { + DCHECK(OnNSSTaskRunner()); + + DCHECK(user_write_buf_.get()); + + int old_amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_); + int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_); + int new_amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_); + // PR_Write could potentially consume the unhandled data in the memio read + // buffer if a renegotiation is in progress. If the buffer is consumed, + // notify the latest buffer size to NetworkRunner. + if (old_amount_in_read_buffer != new_amount_in_read_buffer) { + PostOrRunCallback( + FROM_HERE, + base::Bind(&Core::OnNSSBufferUpdated, this, new_amount_in_read_buffer)); + } + if (rv >= 0) { + PostOrRunCallback( + FROM_HERE, + base::Bind(&LogByteTransferEvent, weak_net_log_, + NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, + scoped_refptr<IOBuffer>(user_write_buf_))); + return rv; + } + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) + return ERR_IO_PENDING; + + rv = HandleNSSError(prerr, false); + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_WRITE_ERROR, + CreateNetLogSSLErrorCallback(rv, prerr))); + return rv; +} + +// Do as much network I/O as possible between the buffer and the +// transport socket. Return true if some I/O performed, false +// otherwise (error or ERR_IO_PENDING). +bool SSLClientSocketNSS::Core::DoTransportIO() { + DCHECK(OnNSSTaskRunner()); + + bool network_moved = false; + if (nss_bufs_ != NULL) { + int rv; + // Read and write as much data as we can. The loop is neccessary + // because Write() may return synchronously. + do { + rv = BufferSend(); + if (rv != ERR_IO_PENDING && rv != 0) + network_moved = true; + } while (rv > 0); + if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING) + network_moved = true; + } + return network_moved; +} + +int SSLClientSocketNSS::Core::BufferRecv() { + DCHECK(OnNSSTaskRunner()); + + if (transport_recv_busy_) + return ERR_IO_PENDING; + + // If NSS is blocked on reading from |nss_bufs_|, because it is empty, + // determine how much data NSS wants to read. If NSS was not blocked, + // this will return 0. + int requested = memio_GetReadRequest(nss_bufs_); + if (requested == 0) { + // This is not a perfect match of error codes, as no operation is + // actually pending. However, returning 0 would be interpreted as a + // possible sign of EOF, which is also an inappropriate match. + return ERR_IO_PENDING; + } + + char* buf; + int nb = memio_GetReadParams(nss_bufs_, &buf); + int rv; + if (!nb) { + // buffer too full to read into, so no I/O possible at moment + rv = ERR_IO_PENDING; + } else { + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(nb)); + if (OnNetworkTaskRunner()) { + rv = DoBufferRecv(read_buffer.get(), nb); + } else { + bool posted = network_task_runner_->PostTask( + FROM_HERE, + base::Bind(IgnoreResult(&Core::DoBufferRecv), this, read_buffer, + nb)); + rv = posted ? ERR_IO_PENDING : ERR_ABORTED; + } + + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + if (rv > 0) { + memcpy(buf, read_buffer->data(), rv); + } else if (rv == 0) { + transport_recv_eof_ = true; + } + memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv)); + } + } + return rv; +} + +// Return 0 if nss_bufs_ was empty, +// > 0 for bytes transferred immediately, +// < 0 for error (or the non-error ERR_IO_PENDING). +int SSLClientSocketNSS::Core::BufferSend() { + DCHECK(OnNSSTaskRunner()); + + if (transport_send_busy_) + return ERR_IO_PENDING; + + const char* buf1; + const char* buf2; + unsigned int len1, len2; + memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2); + const unsigned int len = len1 + len2; + + int rv = 0; + if (len) { + scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); + memcpy(send_buffer->data(), buf1, len1); + memcpy(send_buffer->data() + len1, buf2, len2); + + if (OnNetworkTaskRunner()) { + rv = DoBufferSend(send_buffer.get(), len); + } else { + bool posted = network_task_runner_->PostTask( + FROM_HERE, + base::Bind(IgnoreResult(&Core::DoBufferSend), this, send_buffer, + len)); + rv = posted ? ERR_IO_PENDING : ERR_ABORTED; + } + + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv)); + } + } + + return rv; +} + +void SSLClientSocketNSS::Core::OnRecvComplete(int result) { + DCHECK(OnNSSTaskRunner()); + + if (next_handshake_state_ == STATE_HANDSHAKE) { + OnHandshakeIOComplete(result); + return; + } + + // Network layer received some data, check if client requested to read + // decrypted data. + if (!user_read_buf_.get()) + return; + + int rv = DoReadLoop(result); + if (rv != ERR_IO_PENDING) + DoReadCallback(rv); +} + +void SSLClientSocketNSS::Core::OnSendComplete(int result) { + DCHECK(OnNSSTaskRunner()); + + if (next_handshake_state_ == STATE_HANDSHAKE) { + OnHandshakeIOComplete(result); + return; + } + + // OnSendComplete may need to call DoPayloadRead while the renegotiation + // handshake is in progress. + int rv_read = ERR_IO_PENDING; + int rv_write = ERR_IO_PENDING; + bool network_moved; + do { + if (user_read_buf_.get()) + rv_read = DoPayloadRead(); + if (user_write_buf_.get()) + rv_write = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv_read == ERR_IO_PENDING && rv_write == ERR_IO_PENDING && + (user_read_buf_.get() || user_write_buf_.get()) && network_moved); + + // If the parent SSLClientSocketNSS is deleted during the processing of the + // Read callback and OnNSSTaskRunner() == OnNetworkTaskRunner(), then the Core + // will be detached (and possibly deleted). Guard against deletion by taking + // an extra reference, then check if the Core was detached before invoking the + // next callback. + scoped_refptr<Core> guard(this); + if (user_read_buf_.get() && rv_read != ERR_IO_PENDING) + DoReadCallback(rv_read); + + if (OnNetworkTaskRunner() && detached_) + return; + + if (user_write_buf_.get() && rv_write != ERR_IO_PENDING) + DoWriteCallback(rv_write); +} + +// As part of Connect(), the SSLClientSocketNSS object performs an SSL +// handshake. This requires network IO, which in turn calls +// BufferRecvComplete() with a non-zero byte count. This byte count eventually +// winds its way through the state machine and ends up being passed to the +// callback. For Read() and Write(), that's what we want. But for Connect(), +// the caller expects OK (i.e. 0) for success. +void SSLClientSocketNSS::Core::DoConnectCallback(int rv) { + DCHECK(OnNSSTaskRunner()); + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(!user_connect_callback_.is_null()); + + base::Closure c = base::Bind( + base::ResetAndReturn(&user_connect_callback_), + rv > OK ? OK : rv); + PostOrRunCallback(FROM_HERE, c); +} + +void SSLClientSocketNSS::Core::DoReadCallback(int rv) { + DCHECK(OnNSSTaskRunner()); + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(!user_read_callback_.is_null()); + + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + int amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_); + // This is used to curry the |amount_int_read_buffer| and |user_cb| back to + // the network task runner. + PostOrRunCallback( + FROM_HERE, + base::Bind(&Core::OnNSSBufferUpdated, this, amount_in_read_buffer)); + PostOrRunCallback( + FROM_HERE, + base::Bind(&Core::DidNSSRead, this, rv)); + PostOrRunCallback( + FROM_HERE, + base::Bind(base::ResetAndReturn(&user_read_callback_), rv)); +} + +void SSLClientSocketNSS::Core::DoWriteCallback(int rv) { + DCHECK(OnNSSTaskRunner()); + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(!user_write_callback_.is_null()); + + // Since Run may result in Write being called, clear |user_write_callback_| + // up front. + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + // Update buffer status because DoWriteLoop called DoTransportIO which may + // perform read operations. + int amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_); + // This is used to curry the |amount_int_read_buffer| and |user_cb| back to + // the network task runner. + PostOrRunCallback( + FROM_HERE, + base::Bind(&Core::OnNSSBufferUpdated, this, amount_in_read_buffer)); + PostOrRunCallback( + FROM_HERE, + base::Bind(&Core::DidNSSWrite, this, rv)); + PostOrRunCallback( + FROM_HERE, + base::Bind(base::ResetAndReturn(&user_write_callback_), rv)); +} + +SECStatus SSLClientSocketNSS::Core::ClientChannelIDHandler( + void* arg, + PRFileDesc* socket, + SECKEYPublicKey **out_public_key, + SECKEYPrivateKey **out_private_key) { + Core* core = reinterpret_cast<Core*>(arg); + DCHECK(core->OnNSSTaskRunner()); + + core->PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEvent, core->weak_net_log_, + NetLog::TYPE_SSL_CHANNEL_ID_REQUESTED)); + + // We have negotiated the TLS channel ID extension. + core->channel_id_xtn_negotiated_ = true; + std::string host = core->host_and_port_.host(); + int error = ERR_UNEXPECTED; + if (core->OnNetworkTaskRunner()) { + error = core->DoGetDomainBoundCert(host); + } else { + bool posted = core->network_task_runner_->PostTask( + FROM_HERE, + base::Bind( + IgnoreResult(&Core::DoGetDomainBoundCert), + core, host)); + error = posted ? ERR_IO_PENDING : ERR_ABORTED; + } + + if (error == ERR_IO_PENDING) { + // Asynchronous case. + core->channel_id_needed_ = true; + return SECWouldBlock; + } + + core->PostOrRunCallback( + FROM_HERE, + base::Bind(&BoundNetLog::EndEventWithNetErrorCode, core->weak_net_log_, + NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT, error)); + SECStatus rv = SECSuccess; + if (error == OK) { + // Synchronous success. + int result = core->ImportChannelIDKeys(out_public_key, out_private_key); + if (result == OK) + core->SetChannelIDProvided(); + else + rv = SECFailure; + } else { + rv = SECFailure; + } + + return rv; +} + +int SSLClientSocketNSS::Core::ImportChannelIDKeys(SECKEYPublicKey** public_key, + SECKEYPrivateKey** key) { + // Set the certificate. + SECItem cert_item; + cert_item.data = (unsigned char*) domain_bound_cert_.data(); + cert_item.len = domain_bound_cert_.size(); + ScopedCERTCertificate cert(CERT_NewTempCertificate(CERT_GetDefaultCertDB(), + &cert_item, + NULL, + PR_FALSE, + PR_TRUE)); + if (cert == NULL) + return MapNSSError(PORT_GetError()); + + // Set the private key. + if (!crypto::ECPrivateKey::ImportFromEncryptedPrivateKeyInfo( + ServerBoundCertService::kEPKIPassword, + reinterpret_cast<const unsigned char*>( + domain_bound_private_key_.data()), + domain_bound_private_key_.size(), + &cert->subjectPublicKeyInfo, + false, + false, + key, + public_key)) { + int error = MapNSSError(PORT_GetError()); + return error; + } + + return OK; +} + +void SSLClientSocketNSS::Core::UpdateServerCert() { + nss_handshake_state_.server_cert_chain.Reset(nss_fd_); + nss_handshake_state_.server_cert = X509Certificate::CreateFromDERCertChain( + nss_handshake_state_.server_cert_chain.AsStringPieceVector()); + if (nss_handshake_state_.server_cert.get()) { + // Since this will be called asynchronously on another thread, it needs to + // own a reference to the certificate. + NetLog::ParametersCallback net_log_callback = + base::Bind(&NetLogX509CertificateCallback, + nss_handshake_state_.server_cert); + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_CERTIFICATES_RECEIVED, + net_log_callback)); + } +} + +void SSLClientSocketNSS::Core::UpdateConnectionStatus() { + SSLChannelInfo channel_info; + SECStatus ok = SSL_GetChannelInfo(nss_fd_, + &channel_info, sizeof(channel_info)); + if (ok == SECSuccess && + channel_info.length == sizeof(channel_info) && + channel_info.cipherSuite) { + nss_handshake_state_.ssl_connection_status |= + (static_cast<int>(channel_info.cipherSuite) & + SSL_CONNECTION_CIPHERSUITE_MASK) << + SSL_CONNECTION_CIPHERSUITE_SHIFT; + + nss_handshake_state_.ssl_connection_status |= + (static_cast<int>(channel_info.compressionMethod) & + SSL_CONNECTION_COMPRESSION_MASK) << + SSL_CONNECTION_COMPRESSION_SHIFT; + + // NSS 3.14.x doesn't have a version macro for TLS 1.2 (because NSS didn't + // support it yet), so use 0x0303 directly. + int version = SSL_CONNECTION_VERSION_UNKNOWN; + if (channel_info.protocolVersion < SSL_LIBRARY_VERSION_3_0) { + // All versions less than SSL_LIBRARY_VERSION_3_0 are treated as SSL + // version 2. + version = SSL_CONNECTION_VERSION_SSL2; + } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) { + version = SSL_CONNECTION_VERSION_SSL3; + } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_1_TLS) { + version = SSL_CONNECTION_VERSION_TLS1; + } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_1) { + version = SSL_CONNECTION_VERSION_TLS1_1; + } else if (channel_info.protocolVersion == 0x0303) { + version = SSL_CONNECTION_VERSION_TLS1_2; + } + nss_handshake_state_.ssl_connection_status |= + (version & SSL_CONNECTION_VERSION_MASK) << + SSL_CONNECTION_VERSION_SHIFT; + } + + PRBool peer_supports_renego_ext; + ok = SSL_HandshakeNegotiatedExtension(nss_fd_, ssl_renegotiation_info_xtn, + &peer_supports_renego_ext); + if (ok == SECSuccess) { + if (!peer_supports_renego_ext) { + nss_handshake_state_.ssl_connection_status |= + SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; + // Log an informational message if the server does not support secure + // renegotiation (RFC 5746). + VLOG(1) << "The server " << host_and_port_.ToString() + << " does not support the TLS renegotiation_info extension."; + } + UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported", + peer_supports_renego_ext, 2); + + // We would like to eliminate fallback to SSLv3 for non-buggy servers + // because of security concerns. For example, Google offers forward + // secrecy with ECDHE but that requires TLS 1.0. An attacker can block + // TLSv1 connections and force us to downgrade to SSLv3 and remove forward + // secrecy. + // + // Yngve from Opera has suggested using the renegotiation extension as an + // indicator that SSLv3 fallback was mistaken: + // tools.ietf.org/html/draft-pettersen-tls-version-rollback-removal-00 . + // + // As a first step, measure how often clients perform version fallback + // while the server advertises support secure renegotiation. + if (ssl_config_.version_fallback && + channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) { + UMA_HISTOGRAM_BOOLEAN("Net.SSLv3FallbackToRenegoPatchedServer", + peer_supports_renego_ext == PR_TRUE); + } + } + + if (ssl_config_.version_fallback) { + nss_handshake_state_.ssl_connection_status |= + SSL_CONNECTION_VERSION_FALLBACK; + } +} + +void SSLClientSocketNSS::Core::UpdateNextProto() { + uint8 buf[256]; + SSLNextProtoState state; + unsigned buf_len; + + SECStatus rv = SSL_GetNextProto(nss_fd_, &state, buf, &buf_len, sizeof(buf)); + if (rv != SECSuccess) + return; + + nss_handshake_state_.next_proto = + std::string(reinterpret_cast<char*>(buf), buf_len); + switch (state) { + case SSL_NEXT_PROTO_NEGOTIATED: + case SSL_NEXT_PROTO_SELECTED: + nss_handshake_state_.next_proto_status = kNextProtoNegotiated; + break; + case SSL_NEXT_PROTO_NO_OVERLAP: + nss_handshake_state_.next_proto_status = kNextProtoNoOverlap; + break; + case SSL_NEXT_PROTO_NO_SUPPORT: + nss_handshake_state_.next_proto_status = kNextProtoUnsupported; + break; + default: + NOTREACHED(); + break; + } +} + +void SSLClientSocketNSS::Core::RecordChannelIDSupport() { + DCHECK(OnNSSTaskRunner()); + if (nss_handshake_state_.resumed_handshake) + return; + + // Copy the NSS task runner-only state to the network task runner and + // log histograms from there, since the histograms also need access to the + // network task runner state. + PostOrRunCallback( + FROM_HERE, + base::Bind(&Core::RecordChannelIDSupportOnNetworkTaskRunner, + this, + channel_id_xtn_negotiated_, + ssl_config_.channel_id_enabled, + crypto::ECPrivateKey::IsSupported())); +} + +void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNetworkTaskRunner( + bool negotiated_channel_id, + bool channel_id_enabled, + bool supports_ecc) const { + DCHECK(OnNetworkTaskRunner()); + + // Since this enum is used for a histogram, do not change or re-use values. + enum { + DISABLED = 0, + CLIENT_ONLY = 1, + CLIENT_AND_SERVER = 2, + CLIENT_NO_ECC = 3, + CLIENT_BAD_SYSTEM_TIME = 4, + CLIENT_NO_SERVER_BOUND_CERT_SERVICE = 5, + DOMAIN_BOUND_CERT_USAGE_MAX + } supported = DISABLED; + if (negotiated_channel_id) { + supported = CLIENT_AND_SERVER; + } else if (channel_id_enabled) { + if (!server_bound_cert_service_) + supported = CLIENT_NO_SERVER_BOUND_CERT_SERVICE; + else if (!supports_ecc) + supported = CLIENT_NO_ECC; + else if (!server_bound_cert_service_->IsSystemTimeValid()) + supported = CLIENT_BAD_SYSTEM_TIME; + else + supported = CLIENT_ONLY; + } + UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.Support", supported, + DOMAIN_BOUND_CERT_USAGE_MAX); +} + +int SSLClientSocketNSS::Core::DoBufferRecv(IOBuffer* read_buffer, int len) { + DCHECK(OnNetworkTaskRunner()); + DCHECK_GT(len, 0); + + if (detached_) + return ERR_ABORTED; + + int rv = transport_->socket()->Read( + read_buffer, len, + base::Bind(&Core::BufferRecvComplete, base::Unretained(this), + scoped_refptr<IOBuffer>(read_buffer))); + + if (!OnNSSTaskRunner() && rv != ERR_IO_PENDING) { + nss_task_runner_->PostTask( + FROM_HERE, base::Bind(&Core::BufferRecvComplete, this, + scoped_refptr<IOBuffer>(read_buffer), rv)); + return rv; + } + + return rv; +} + +int SSLClientSocketNSS::Core::DoBufferSend(IOBuffer* send_buffer, int len) { + DCHECK(OnNetworkTaskRunner()); + DCHECK_GT(len, 0); + + if (detached_) + return ERR_ABORTED; + + int rv = transport_->socket()->Write( + send_buffer, len, + base::Bind(&Core::BufferSendComplete, + base::Unretained(this))); + + if (!OnNSSTaskRunner() && rv != ERR_IO_PENDING) { + nss_task_runner_->PostTask( + FROM_HERE, + base::Bind(&Core::BufferSendComplete, this, rv)); + return rv; + } + + return rv; +} + +int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) { + DCHECK(OnNetworkTaskRunner()); + + if (detached_) + return ERR_FAILED; + + weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT); + + int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert( + host, + &domain_bound_private_key_, + &domain_bound_cert_, + base::Bind(&Core::OnGetDomainBoundCertComplete, base::Unretained(this)), + &domain_bound_cert_request_handle_); + + if (rv != ERR_IO_PENDING && !OnNSSTaskRunner()) { + nss_task_runner_->PostTask( + FROM_HERE, + base::Bind(&Core::OnHandshakeIOComplete, this, rv)); + return ERR_IO_PENDING; + } + + return rv; +} + +void SSLClientSocketNSS::Core::OnHandshakeStateUpdated( + const HandshakeState& state) { + DCHECK(OnNetworkTaskRunner()); + network_handshake_state_ = state; +} + +void SSLClientSocketNSS::Core::OnNSSBufferUpdated(int amount_in_read_buffer) { + DCHECK(OnNetworkTaskRunner()); + unhandled_buffer_size_ = amount_in_read_buffer; +} + +void SSLClientSocketNSS::Core::DidNSSRead(int result) { + DCHECK(OnNetworkTaskRunner()); + DCHECK(nss_waiting_read_); + nss_waiting_read_ = false; + if (result <= 0) + nss_is_closed_ = true; +} + +void SSLClientSocketNSS::Core::DidNSSWrite(int result) { + DCHECK(OnNetworkTaskRunner()); + DCHECK(nss_waiting_write_); + nss_waiting_write_ = false; + if (result < 0) + nss_is_closed_ = true; +} + +void SSLClientSocketNSS::Core::BufferSendComplete(int result) { + if (!OnNSSTaskRunner()) { + if (detached_) + return; + + nss_task_runner_->PostTask( + FROM_HERE, base::Bind(&Core::BufferSendComplete, this, result)); + return; + } + + DCHECK(OnNSSTaskRunner()); + + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result)); + transport_send_busy_ = false; + OnSendComplete(result); +} + +void SSLClientSocketNSS::Core::OnHandshakeIOComplete(int result) { + if (!OnNSSTaskRunner()) { + if (detached_) + return; + + nss_task_runner_->PostTask( + FROM_HERE, base::Bind(&Core::OnHandshakeIOComplete, this, result)); + return; + } + + DCHECK(OnNSSTaskRunner()); + + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) + DoConnectCallback(rv); +} + +void SSLClientSocketNSS::Core::OnGetDomainBoundCertComplete(int result) { + DVLOG(1) << __FUNCTION__ << " " << result; + DCHECK(OnNetworkTaskRunner()); + + OnHandshakeIOComplete(result); +} + +void SSLClientSocketNSS::Core::BufferRecvComplete( + IOBuffer* read_buffer, + int result) { + DCHECK(read_buffer); + + if (!OnNSSTaskRunner()) { + if (detached_) + return; + + nss_task_runner_->PostTask( + FROM_HERE, base::Bind(&Core::BufferRecvComplete, this, + scoped_refptr<IOBuffer>(read_buffer), result)); + return; + } + + DCHECK(OnNSSTaskRunner()); + + if (result > 0) { + char* buf; + int nb = memio_GetReadParams(nss_bufs_, &buf); + CHECK_GE(nb, result); + memcpy(buf, read_buffer->data(), result); + } else if (result == 0) { + transport_recv_eof_ = true; + } + + memio_PutReadResult(nss_bufs_, MapErrorToNSS(result)); + transport_recv_busy_ = false; + OnRecvComplete(result); +} + +void SSLClientSocketNSS::Core::PostOrRunCallback( + const tracked_objects::Location& location, + const base::Closure& task) { + if (!OnNetworkTaskRunner()) { + network_task_runner_->PostTask( + FROM_HERE, + base::Bind(&Core::PostOrRunCallback, this, location, task)); + return; + } + + if (detached_ || task.is_null()) + return; + task.Run(); +} + +void SSLClientSocketNSS::Core::AddCertProvidedEvent(int cert_count) { + PostOrRunCallback( + FROM_HERE, + base::Bind(&AddLogEventWithCallback, weak_net_log_, + NetLog::TYPE_SSL_CLIENT_CERT_PROVIDED, + NetLog::IntegerCallback("cert_count", cert_count))); +} + +void SSLClientSocketNSS::Core::SetChannelIDProvided() { + PostOrRunCallback( + FROM_HERE, base::Bind(&AddLogEvent, weak_net_log_, + NetLog::TYPE_SSL_CHANNEL_ID_PROVIDED)); + nss_handshake_state_.channel_id_sent = true; + // Update the network task runner's view of the handshake state now that + // channel id has been sent. + PostOrRunCallback( + FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, this, + nss_handshake_state_)); +} + +SSLClientSocketNSS::SSLClientSocketNSS( + base::SequencedTaskRunner* nss_task_runner, + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) + : nss_task_runner_(nss_task_runner), + transport_(transport_socket.Pass()), + host_and_port_(host_and_port), + ssl_config_(ssl_config), + cert_verifier_(context.cert_verifier), + server_bound_cert_service_(context.server_bound_cert_service), + ssl_session_cache_shard_(context.ssl_session_cache_shard), + completed_handshake_(false), + next_handshake_state_(STATE_NONE), + nss_fd_(NULL), + net_log_(transport_->socket()->NetLog()), + transport_security_state_(context.transport_security_state), + valid_thread_id_(base::kInvalidThreadId) { + EnterFunction(""); + InitCore(); + LeaveFunction(""); +} + +SSLClientSocketNSS::~SSLClientSocketNSS() { + EnterFunction(""); + Disconnect(); + LeaveFunction(""); +} + +// static +void SSLClientSocket::ClearSessionCache() { + // SSL_ClearSessionCache can't be called before NSS is initialized. Don't + // bother initializing NSS just to clear an empty SSL session cache. + if (!NSS_IsInitialized()) + return; + + SSL_ClearSessionCache(); +} + +bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { + EnterFunction(""); + ssl_info->Reset(); + if (core_->state().server_cert_chain.empty() || + !core_->state().server_cert_chain[0]) { + return false; + } + + ssl_info->cert_status = server_cert_verify_result_.cert_status; + ssl_info->cert = server_cert_verify_result_.verified_cert; + ssl_info->connection_status = + core_->state().ssl_connection_status; + ssl_info->public_key_hashes = server_cert_verify_result_.public_key_hashes; + for (HashValueVector::const_iterator i = side_pinned_public_keys_.begin(); + i != side_pinned_public_keys_.end(); ++i) { + ssl_info->public_key_hashes.push_back(*i); + } + ssl_info->is_issued_by_known_root = + server_cert_verify_result_.is_issued_by_known_root; + ssl_info->client_cert_sent = + ssl_config_.send_client_cert && ssl_config_.client_cert.get(); + ssl_info->channel_id_sent = WasChannelIDSent(); + + PRUint16 cipher_suite = SSLConnectionStatusToCipherSuite( + core_->state().ssl_connection_status); + SSLCipherSuiteInfo cipher_info; + SECStatus ok = SSL_GetCipherSuiteInfo(cipher_suite, + &cipher_info, sizeof(cipher_info)); + if (ok == SECSuccess) { + ssl_info->security_bits = cipher_info.effectiveKeyBits; + } else { + ssl_info->security_bits = -1; + LOG(DFATAL) << "SSL_GetCipherSuiteInfo returned " << PR_GetError() + << " for cipherSuite " << cipher_suite; + } + + ssl_info->handshake_type = core_->state().resumed_handshake ? + SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL; + + LeaveFunction(""); + return true; +} + +void SSLClientSocketNSS::GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) { + EnterFunction(""); + // TODO(rch): switch SSLCertRequestInfo.host_and_port to a HostPortPair + cert_request_info->host_and_port = host_and_port_.ToString(); + cert_request_info->cert_authorities = core_->state().cert_authorities; + LeaveFunction(""); +} + +int SSLClientSocketNSS::ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + + // SSL_ExportKeyingMaterial may block the current thread if |core_| is in + // the midst of a handshake. + SECStatus result = SSL_ExportKeyingMaterial( + nss_fd_, label.data(), label.size(), has_context, + reinterpret_cast<const unsigned char*>(context.data()), + context.length(), out, outlen); + if (result != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ExportKeyingMaterial", ""); + return MapNSSError(PORT_GetError()); + } + return OK; +} + +int SSLClientSocketNSS::GetTLSUniqueChannelBinding(std::string* out) { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + unsigned char buf[64]; + unsigned int len; + SECStatus result = SSL_GetChannelBinding(nss_fd_, + SSL_CHANNEL_BINDING_TLS_UNIQUE, + buf, &len, arraysize(buf)); + if (result != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", ""); + return MapNSSError(PORT_GetError()); + } + out->assign(reinterpret_cast<char*>(buf), len); + return OK; +} + +SSLClientSocket::NextProtoStatus +SSLClientSocketNSS::GetNextProto(std::string* proto, + std::string* server_protos) { + *proto = core_->state().next_proto; + *server_protos = core_->state().server_protos; + return core_->state().next_proto_status; +} + +int SSLClientSocketNSS::Connect(const CompletionCallback& callback) { + EnterFunction(""); + DCHECK(transport_.get()); + // It is an error to create an SSLClientSocket whose context has no + // TransportSecurityState. + DCHECK(transport_security_state_); + DCHECK_EQ(STATE_NONE, next_handshake_state_); + DCHECK(user_connect_callback_.is_null()); + DCHECK(!callback.is_null()); + + EnsureThreadIdAssigned(); + + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT); + + int rv = Init(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + rv = InitializeSSLOptions(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + rv = InitializeSSLPeerName(); + if (rv != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + return rv; + } + + GotoState(STATE_HANDSHAKE); + + rv = DoHandshakeLoop(OK); + if (rv == ERR_IO_PENDING) { + user_connect_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + } + + LeaveFunction(""); + return rv > OK ? OK : rv; +} + +void SSLClientSocketNSS::Disconnect() { + EnterFunction(""); + + CHECK(CalledOnValidThread()); + + // Shut down anything that may call us back. + core_->Detach(); + verifier_.reset(); + transport_->socket()->Disconnect(); + + // Reset object state. + user_connect_callback_.Reset(); + server_cert_verify_result_.Reset(); + completed_handshake_ = false; + start_cert_verification_time_ = base::TimeTicks(); + InitCore(); + + LeaveFunction(""); +} + +bool SSLClientSocketNSS::IsConnected() const { + EnterFunction(""); + bool ret = completed_handshake_ && + (core_->HasPendingAsyncOperation() || + (core_->IsConnected() && core_->HasUnhandledReceivedData()) || + transport_->socket()->IsConnected()); + LeaveFunction(""); + return ret; +} + +bool SSLClientSocketNSS::IsConnectedAndIdle() const { + EnterFunction(""); + bool ret = completed_handshake_ && + !core_->HasPendingAsyncOperation() && + !(core_->IsConnected() && core_->HasUnhandledReceivedData()) && + transport_->socket()->IsConnectedAndIdle(); + LeaveFunction(""); + return ret; +} + +int SSLClientSocketNSS::GetPeerAddress(IPEndPoint* address) const { + return transport_->socket()->GetPeerAddress(address); +} + +int SSLClientSocketNSS::GetLocalAddress(IPEndPoint* address) const { + return transport_->socket()->GetLocalAddress(address); +} + +const BoundNetLog& SSLClientSocketNSS::NetLog() const { + return net_log_; +} + +void SSLClientSocketNSS::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void SSLClientSocketNSS::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool SSLClientSocketNSS::WasEverUsed() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasEverUsed(); + } + NOTREACHED(); + return false; +} + +bool SSLClientSocketNSS::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->UsingTCPFastOpen(); + } + NOTREACHED(); + return false; +} + +int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(core_.get()); + DCHECK(!callback.is_null()); + + EnterFunction(buf_len); + int rv = core_->Read(buf, buf_len, callback); + LeaveFunction(rv); + + return rv; +} + +int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(core_.get()); + DCHECK(!callback.is_null()); + + EnterFunction(buf_len); + int rv = core_->Write(buf, buf_len, callback); + LeaveFunction(rv); + + return rv; +} + +bool SSLClientSocketNSS::SetReceiveBufferSize(int32 size) { + return transport_->socket()->SetReceiveBufferSize(size); +} + +bool SSLClientSocketNSS::SetSendBufferSize(int32 size) { + return transport_->socket()->SetSendBufferSize(size); +} + +int SSLClientSocketNSS::Init() { + EnterFunction(""); + // Initialize the NSS SSL library in a threadsafe way. This also + // initializes the NSS base library. + EnsureNSSSSLInit(); + if (!NSS_IsInitialized()) + return ERR_UNEXPECTED; +#if defined(USE_NSS) || defined(OS_IOS) + if (ssl_config_.cert_io_enabled) { + // We must call EnsureNSSHttpIOInit() here, on the IO thread, to get the IO + // loop by MessageLoopForIO::current(). + // X509Certificate::Verify() runs on a worker thread of CertVerifier. + EnsureNSSHttpIOInit(); + } +#endif + + LeaveFunction(""); + return OK; +} + +void SSLClientSocketNSS::InitCore() { + core_ = new Core(base::ThreadTaskRunnerHandle::Get().get(), + nss_task_runner_.get(), + transport_.get(), + host_and_port_, + ssl_config_, + &net_log_, + server_bound_cert_service_); +} + +int SSLClientSocketNSS::InitializeSSLOptions() { + // Transport connected, now hook it up to nss + nss_fd_ = memio_CreateIOLayer(kRecvBufferSize, kSendBufferSize); + if (nss_fd_ == NULL) { + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR error code. + } + + // Grab pointer to buffers + memio_Private* nss_bufs = memio_GetSecret(nss_fd_); + + /* Create SSL state machine */ + /* Push SSL onto our fake I/O socket */ + nss_fd_ = SSL_ImportFD(NULL, nss_fd_); + if (nss_fd_ == NULL) { + LogFailedNSSFunction(net_log_, "SSL_ImportFD", ""); + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code. + } + // TODO(port): set more ssl options! Check errors! + + int rv; + + rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2"); + return ERR_UNEXPECTED; + } + + // Don't do V2 compatible hellos because they don't support TLS extensions. + rv = SSL_OptionSet(nss_fd_, SSL_V2_COMPATIBLE_HELLO, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_V2_COMPATIBLE_HELLO"); + return ERR_UNEXPECTED; + } + + SSLVersionRange version_range; + version_range.min = ssl_config_.version_min; + version_range.max = ssl_config_.version_max; + rv = SSL_VersionRangeSet(nss_fd_, &version_range); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_VersionRangeSet", ""); + return ERR_NO_SSL_VERSIONS_ENABLED; + } + + for (std::vector<uint16>::const_iterator it = + ssl_config_.disabled_cipher_suites.begin(); + it != ssl_config_.disabled_cipher_suites.end(); ++it) { + // This will fail if the specified cipher is not implemented by NSS, but + // the failure is harmless. + SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE); + } + + // Support RFC 5077 + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction( + net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS"); + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START, + ssl_config_.false_start_enabled); + if (rv != SECSuccess) + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_FALSE_START"); + + // We allow servers to request renegotiation. Since we're a client, + // prohibiting this is rather a waste of time. Only servers are in a + // position to prevent renegotiation attacks. + // http://extendedsubset.com/?p=8 + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_RENEGOTIATION, + SSL_RENEGOTIATE_TRANSITIONAL); + if (rv != SECSuccess) { + LogFailedNSSFunction( + net_log_, "SSL_OptionSet", "SSL_ENABLE_RENEGOTIATION"); + } + + rv = SSL_OptionSet(nss_fd_, SSL_CBC_RANDOM_IV, PR_TRUE); + if (rv != SECSuccess) + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_CBC_RANDOM_IV"); + +// Added in NSS 3.15 +#ifdef SSL_ENABLE_OCSP_STAPLING + if (IsOCSPStaplingSupported()) { + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_OCSP_STAPLING, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", + "SSL_ENABLE_OCSP_STAPLING"); + } + } +#endif + +// Chromium patch to libssl +#ifdef SSL_ENABLE_CACHED_INFO + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_CACHED_INFO, + ssl_config_.cached_info_enabled); + if (rv != SECSuccess) + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_CACHED_INFO"); +#endif + + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT"); + return ERR_UNEXPECTED; + } + + if (!core_->Init(nss_fd_, nss_bufs)) + return ERR_UNEXPECTED; + + // Tell SSL the hostname we're trying to connect to. + SSL_SetURL(nss_fd_, host_and_port_.host().c_str()); + + // Tell SSL we're a client; needed if not letting NSPR do socket I/O + SSL_ResetHandshake(nss_fd_, PR_FALSE); + + return OK; +} + +int SSLClientSocketNSS::InitializeSSLPeerName() { + // Tell NSS who we're connected to + IPEndPoint peer_address; + int err = transport_->socket()->GetPeerAddress(&peer_address); + if (err != OK) + return err; + + SockaddrStorage storage; + if (!peer_address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_UNEXPECTED; + + PRNetAddr peername; + memset(&peername, 0, sizeof(peername)); + DCHECK_LE(static_cast<size_t>(storage.addr_len), sizeof(peername)); + size_t len = std::min(static_cast<size_t>(storage.addr_len), + sizeof(peername)); + memcpy(&peername, storage.addr, len); + + // Adjust the address family field for BSD, whose sockaddr + // structure has a one-byte length and one-byte address family + // field at the beginning. PRNetAddr has a two-byte address + // family field at the beginning. + peername.raw.family = storage.addr->sa_family; + + memio_SetPeerName(nss_fd_, &peername); + + // Set the peer ID for session reuse. This is necessary when we create an + // SSL tunnel through a proxy -- GetPeerName returns the proxy's address + // rather than the destination server's address in that case. + std::string peer_id = host_and_port_.ToString(); + // If the ssl_session_cache_shard_ is non-empty, we append it to the peer id. + // This will cause session cache misses between sockets with different values + // of ssl_session_cache_shard_ and this is used to partition the session cache + // for incognito mode. + if (!ssl_session_cache_shard_.empty()) { + peer_id += "/" + ssl_session_cache_shard_; + } + SECStatus rv = SSL_SetSockPeerID(nss_fd_, const_cast<char*>(peer_id.c_str())); + if (rv != SECSuccess) + LogFailedNSSFunction(net_log_, "SSL_SetSockPeerID", peer_id.c_str()); + + return OK; +} + +void SSLClientSocketNSS::DoConnectCallback(int rv) { + EnterFunction(rv); + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(!user_connect_callback_.is_null()); + + base::ResetAndReturn(&user_connect_callback_).Run(rv > OK ? OK : rv); + LeaveFunction(""); +} + +void SSLClientSocketNSS::OnHandshakeIOComplete(int result) { + EnterFunction(result); + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + DoConnectCallback(rv); + } + LeaveFunction(""); +} + +int SSLClientSocketNSS::DoHandshakeLoop(int last_io_result) { + EnterFunction(last_io_result); + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + // (This is a quirk carried over from the windows + // implementation. It makes reading the logs a bit harder.) + // State handlers can and often do call GotoState just + // to stay in the current state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + switch (state) { + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + case STATE_HANDSHAKE_COMPLETE: + rv = DoHandshakeComplete(rv); + break; + case STATE_VERIFY_CERT: + DCHECK(rv == OK); + rv = DoVerifyCert(rv); + break; + case STATE_VERIFY_CERT_COMPLETE: + rv = DoVerifyCertComplete(rv); + break; + case STATE_NONE: + default: + rv = ERR_UNEXPECTED; + LOG(DFATAL) << "unexpected state " << state; + break; + } + } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); + LeaveFunction(""); + return rv; +} + +int SSLClientSocketNSS::DoHandshake() { + EnterFunction(""); + int rv = core_->Connect( + base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, + base::Unretained(this))); + GotoState(STATE_HANDSHAKE_COMPLETE); + + LeaveFunction(rv); + return rv; +} + +int SSLClientSocketNSS::DoHandshakeComplete(int result) { + EnterFunction(result); + + if (result == OK) { + // SSL handshake is completed. Let's verify the certificate. + GotoState(STATE_VERIFY_CERT); + // Done! + } + set_channel_id_sent(core_->state().channel_id_sent); + + LeaveFunction(result); + return result; +} + + +int SSLClientSocketNSS::DoVerifyCert(int result) { + DCHECK(!core_->state().server_cert_chain.empty()); + DCHECK(core_->state().server_cert_chain[0]); + + GotoState(STATE_VERIFY_CERT_COMPLETE); + + // If the certificate is expected to be bad we can use the expectation as + // the cert status. + base::StringPiece der_cert( + reinterpret_cast<char*>( + core_->state().server_cert_chain[0]->derCert.data), + core_->state().server_cert_chain[0]->derCert.len); + CertStatus cert_status; + if (ssl_config_.IsAllowedBadCert(der_cert, &cert_status)) { + DCHECK(start_cert_verification_time_.is_null()); + VLOG(1) << "Received an expected bad cert with status: " << cert_status; + server_cert_verify_result_.Reset(); + server_cert_verify_result_.cert_status = cert_status; + server_cert_verify_result_.verified_cert = core_->state().server_cert; + return OK; + } + + // We may have failed to create X509Certificate object if we are + // running inside sandbox. + if (!core_->state().server_cert.get()) { + server_cert_verify_result_.Reset(); + server_cert_verify_result_.cert_status = CERT_STATUS_INVALID; + return ERR_CERT_INVALID; + } + + start_cert_verification_time_ = base::TimeTicks::Now(); + + int flags = 0; + if (ssl_config_.rev_checking_enabled) + flags |= CertVerifier::VERIFY_REV_CHECKING_ENABLED; + if (ssl_config_.verify_ev_cert) + flags |= CertVerifier::VERIFY_EV_CERT; + if (ssl_config_.cert_io_enabled) + flags |= CertVerifier::VERIFY_CERT_IO_ENABLED; + if (ssl_config_.rev_checking_required_local_anchors) + flags |= CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS; + verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); + return verifier_->Verify( + core_->state().server_cert.get(), + host_and_port_.host(), + flags, + SSLConfigService::GetCRLSet().get(), + &server_cert_verify_result_, + base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, + base::Unretained(this)), + net_log_); +} + +// Derived from AuthCertificateCallback() in +// mozilla/source/security/manager/ssl/src/nsNSSCallbacks.cpp. +int SSLClientSocketNSS::DoVerifyCertComplete(int result) { + verifier_.reset(); + + if (!start_cert_verification_time_.is_null()) { + base::TimeDelta verify_time = + base::TimeTicks::Now() - start_cert_verification_time_; + if (result == OK) + UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTime", verify_time); + else + UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTimeError", verify_time); + } + + // We used to remember the intermediate CA certs in the NSS database + // persistently. However, NSS opens a connection to the SQLite database + // during NSS initialization and doesn't close the connection until NSS + // shuts down. If the file system where the database resides is gone, + // the database connection goes bad. What's worse, the connection won't + // recover when the file system comes back. Until this NSS or SQLite bug + // is fixed, we need to avoid using the NSS database for non-essential + // purposes. See https://bugzilla.mozilla.org/show_bug.cgi?id=508081 and + // http://crbug.com/15630 for more info. + + // TODO(hclam): Skip logging if server cert was expected to be bad because + // |server_cert_verify_result_| doesn't contain all the information about + // the cert. + if (result == OK) + LogConnectionTypeMetrics(); + + completed_handshake_ = true; + +#if defined(OFFICIAL_BUILD) && !defined(OS_ANDROID) && !defined(OS_IOS) + // Take care of any mandates for public key pinning. + // + // Pinning is only enabled for official builds to make sure that others don't + // end up with pins that cannot be easily updated. + // + // TODO(agl): We might have an issue here where a request for foo.example.com + // merges into a SPDY connection to www.example.com, and gets a different + // certificate. + + // Perform pin validation if, and only if, all these conditions obtain: + // + // * a TransportSecurityState object is available; + // * the server's certificate chain is valid (or suffers from only a minor + // error); + // * the server's certificate chain chains up to a known root (i.e. not a + // user-installed trust anchor); and + // * the build is recent (very old builds should fail open so that users + // have some chance to recover). + // + const CertStatus cert_status = server_cert_verify_result_.cert_status; + if (transport_security_state_ && + (result == OK || + (IsCertificateError(result) && IsCertStatusMinorError(cert_status))) && + server_cert_verify_result_.is_issued_by_known_root && + TransportSecurityState::IsBuildTimely()) { + bool sni_available = + ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1 || + ssl_config_.version_fallback; + const std::string& host = host_and_port_.host(); + + TransportSecurityState::DomainState domain_state; + if (transport_security_state_->GetDomainState(host, sni_available, + &domain_state) && + domain_state.HasPublicKeyPins()) { + if (!domain_state.CheckPublicKeyPins( + server_cert_verify_result_.public_key_hashes)) { + result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN; + UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", false); + TransportSecurityState::ReportUMAOnPinFailure(host); + } else { + UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", true); + } + } + } +#endif + + // Exit DoHandshakeLoop and return the result to the caller to Connect. + DCHECK_EQ(STATE_NONE, next_handshake_state_); + return result; +} + +void SSLClientSocketNSS::LogConnectionTypeMetrics() const { + UpdateConnectionTypeHistograms(CONNECTION_SSL); + int ssl_version = SSLConnectionStatusToVersion( + core_->state().ssl_connection_status); + switch (ssl_version) { + case SSL_CONNECTION_VERSION_SSL2: + UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2); + break; + case SSL_CONNECTION_VERSION_SSL3: + UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3); + break; + case SSL_CONNECTION_VERSION_TLS1: + UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1); + break; + case SSL_CONNECTION_VERSION_TLS1_1: + UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1); + break; + case SSL_CONNECTION_VERSION_TLS1_2: + UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2); + break; + }; +} + +void SSLClientSocketNSS::EnsureThreadIdAssigned() const { + base::AutoLock auto_lock(lock_); + if (valid_thread_id_ != base::kInvalidThreadId) + return; + valid_thread_id_ = base::PlatformThread::CurrentId(); +} + +bool SSLClientSocketNSS::CalledOnValidThread() const { + EnsureThreadIdAssigned(); + base::AutoLock auto_lock(lock_); + return valid_thread_id_ == base::PlatformThread::CurrentId(); +} + +ServerBoundCertService* SSLClientSocketNSS::GetServerBoundCertService() const { + return server_bound_cert_service_; +} + +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h new file mode 100644 index 00000000000..b41d28d74a8 --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -0,0 +1,196 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_NSS_H_ +#define NET_SOCKET_SSL_CLIENT_SOCKET_NSS_H_ + +#include <certt.h> +#include <keyt.h> +#include <nspr.h> +#include <nss.h> + +#include <string> +#include <vector> + +#include "base/memory/scoped_ptr.h" +#include "base/synchronization/lock.h" +#include "base/threading/platform_thread.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/base/nss_memio.h" +#include "net/cert/cert_verify_result.h" +#include "net/cert/x509_certificate.h" +#include "net/socket/ssl_client_socket.h" +#include "net/ssl/server_bound_cert_service.h" +#include "net/ssl/ssl_config_service.h" + +namespace base { +class SequencedTaskRunner; +} + +namespace net { + +class BoundNetLog; +class CertVerifier; +class ClientSocketHandle; +class ServerBoundCertService; +class SingleRequestCertVerifier; +class TransportSecurityState; +class X509Certificate; + +// An SSL client socket implemented with Mozilla NSS. +class SSLClientSocketNSS : public SSLClientSocket { + public: + // Takes ownership of the |transport_socket|, which must already be connected. + // The hostname specified in |host_and_port| will be compared with the name(s) + // in the server's certificate during the SSL handshake. If SSL client + // authentication is requested, the host_and_port field of SSLCertRequestInfo + // will be populated with |host_and_port|. |ssl_config| specifies + // the SSL settings. + // + // Because calls to NSS may block, such as due to needing to access slow + // hardware or needing to synchronously unlock protected tokens, calls to + // NSS may optionally be run on a dedicated thread. If synchronous/blocking + // behaviour is desired, for performance or compatibility, the current task + // runner should be supplied instead. + SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner, + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context); + virtual ~SSLClientSocketNSS(); + + // SSLClientSocket implementation. + virtual void GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) OVERRIDE; + virtual NextProtoStatus GetNextProto(std::string* proto, + std::string* server_protos) OVERRIDE; + + // SSLSocket implementation. + virtual int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) OVERRIDE; + virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + + private: + // Helper class to handle marshalling any NSS interaction to and from the + // NSS and network task runners. Not every call needs to happen on the Core + class Core; + + enum State { + STATE_NONE, + STATE_HANDSHAKE, + STATE_HANDSHAKE_COMPLETE, + STATE_VERIFY_CERT, + STATE_VERIFY_CERT_COMPLETE, + }; + + int Init(); + void InitCore(); + + // Initializes NSS SSL options. Returns a net error code. + int InitializeSSLOptions(); + + // Initializes the socket peer name in SSL. Returns a net error code. + int InitializeSSLPeerName(); + + void DoConnectCallback(int result); + void OnHandshakeIOComplete(int result); + + int DoHandshakeLoop(int last_io_result); + int DoHandshake(); + int DoHandshakeComplete(int result); + int DoVerifyCert(int result); + int DoVerifyCertComplete(int result); + + void LogConnectionTypeMetrics() const; + + // The following methods are for debugging bug 65948. Will remove this code + // after fixing bug 65948. + void EnsureThreadIdAssigned() const; + bool CalledOnValidThread() const; + + // The task runner used to perform NSS operations. + scoped_refptr<base::SequencedTaskRunner> nss_task_runner_; + scoped_ptr<ClientSocketHandle> transport_; + HostPortPair host_and_port_; + SSLConfig ssl_config_; + + scoped_refptr<Core> core_; + + CompletionCallback user_connect_callback_; + + CertVerifyResult server_cert_verify_result_; + HashValueVector side_pinned_public_keys_; + + CertVerifier* const cert_verifier_; + scoped_ptr<SingleRequestCertVerifier> verifier_; + + // The service for retrieving Channel ID keys. May be NULL. + ServerBoundCertService* server_bound_cert_service_; + + // ssl_session_cache_shard_ is an opaque string that partitions the SSL + // session cache. i.e. sessions created with one value will not attempt to + // resume on the socket with a different value. + const std::string ssl_session_cache_shard_; + + // True if the SSL handshake has been completed. + bool completed_handshake_; + + State next_handshake_state_; + + // The NSS SSL state machine. This is owned by |core_|. + // TODO(rsleevi): http://crbug.com/130616 - Remove this member once + // ExportKeyingMaterial is updated to be asynchronous. + PRFileDesc* nss_fd_; + + BoundNetLog net_log_; + + base::TimeTicks start_cert_verification_time_; + + TransportSecurityState* transport_security_state_; + + // The following two variables are added for debugging bug 65948. Will + // remove this code after fixing bug 65948. + // Added the following code Debugging in release mode. + mutable base::Lock lock_; + // This is mutable so that CalledOnValidThread can set it. + // It's guarded by |lock_|. + mutable base::PlatformThreadId valid_thread_id_; +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_CLIENT_SOCKET_NSS_H_ diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc new file mode 100644 index 00000000000..4591cec5b9d --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -0,0 +1,1435 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// OpenSSL binding for SSLClientSocket. The class layout and general principle +// of operation is derived from SSLClientSocketNSS. + +#include "net/socket/ssl_client_socket_openssl.h" + +#include <openssl/err.h> +#include <openssl/opensslv.h> +#include <openssl/ssl.h> + +#include "base/bind.h" +#include "base/callback_helpers.h" +#include "base/memory/singleton.h" +#include "base/metrics/histogram.h" +#include "base/synchronization/lock.h" +#include "crypto/openssl_util.h" +#include "net/base/net_errors.h" +#include "net/cert/cert_verifier.h" +#include "net/cert/single_request_cert_verifier.h" +#include "net/cert/x509_certificate_net_log_param.h" +#include "net/socket/ssl_error_params.h" +#include "net/ssl/openssl_client_key_store.h" +#include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_info.h" + +namespace net { + +namespace { + +// Enable this to see logging for state machine state transitions. +#if 0 +#define GotoState(s) do { DVLOG(2) << (void *)this << " " << __FUNCTION__ << \ + " jump to state " << s; \ + next_handshake_state_ = s; } while (0) +#else +#define GotoState(s) next_handshake_state_ = s +#endif + +const int kSessionCacheTimeoutSeconds = 60 * 60; +const size_t kSessionCacheMaxEntires = 1024; + +// This constant can be any non-negative/non-zero value (eg: it does not +// overlap with any value of the net::Error range, including net::OK). +const int kNoPendingReadResult = 1; + +// If a client doesn't have a list of protocols that it supports, but +// the server supports NPN, choosing "http/1.1" is the best answer. +const char kDefaultSupportedNPNProtocol[] = "http/1.1"; + +#if OPENSSL_VERSION_NUMBER < 0x1000103fL +// This method doesn't seem to have made it into the OpenSSL headers. +unsigned long SSL_CIPHER_get_id(const SSL_CIPHER* cipher) { return cipher->id; } +#endif + +// Used for encoding the |connection_status| field of an SSLInfo object. +int EncodeSSLConnectionStatus(int cipher_suite, + int compression, + int version) { + return ((cipher_suite & SSL_CONNECTION_CIPHERSUITE_MASK) << + SSL_CONNECTION_CIPHERSUITE_SHIFT) | + ((compression & SSL_CONNECTION_COMPRESSION_MASK) << + SSL_CONNECTION_COMPRESSION_SHIFT) | + ((version & SSL_CONNECTION_VERSION_MASK) << + SSL_CONNECTION_VERSION_SHIFT); +} + +// Returns the net SSL version number (see ssl_connection_status_flags.h) for +// this SSL connection. +int GetNetSSLVersion(SSL* ssl) { + switch (SSL_version(ssl)) { + case SSL2_VERSION: + return SSL_CONNECTION_VERSION_SSL2; + case SSL3_VERSION: + return SSL_CONNECTION_VERSION_SSL3; + case TLS1_VERSION: + return SSL_CONNECTION_VERSION_TLS1; + case 0x0302: + return SSL_CONNECTION_VERSION_TLS1_1; + case 0x0303: + return SSL_CONNECTION_VERSION_TLS1_2; + default: + return SSL_CONNECTION_VERSION_UNKNOWN; + } +} + +int MapOpenSSLErrorSSL() { + // Walk down the error stack to find the SSLerr generated reason. + unsigned long error_code; + do { + error_code = ERR_get_error(); + if (error_code == 0) + return ERR_SSL_PROTOCOL_ERROR; + } while (ERR_GET_LIB(error_code) != ERR_LIB_SSL); + + DVLOG(1) << "OpenSSL SSL error, reason: " << ERR_GET_REASON(error_code) + << ", name: " << ERR_error_string(error_code, NULL); + switch (ERR_GET_REASON(error_code)) { + case SSL_R_READ_TIMEOUT_EXPIRED: + return ERR_TIMED_OUT; + case SSL_R_BAD_RESPONSE_ARGUMENT: + return ERR_INVALID_ARGUMENT; + case SSL_R_UNKNOWN_CERTIFICATE_TYPE: + case SSL_R_UNKNOWN_CIPHER_TYPE: + case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE: + case SSL_R_UNKNOWN_PKEY_TYPE: + case SSL_R_UNKNOWN_REMOTE_ERROR_TYPE: + case SSL_R_UNKNOWN_SSL_VERSION: + return ERR_NOT_IMPLEMENTED; + case SSL_R_UNSUPPORTED_SSL_VERSION: + case SSL_R_NO_CIPHER_MATCH: + case SSL_R_NO_SHARED_CIPHER: + case SSL_R_TLSV1_ALERT_INSUFFICIENT_SECURITY: + case SSL_R_TLSV1_ALERT_PROTOCOL_VERSION: + case SSL_R_UNSUPPORTED_PROTOCOL: + return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; + case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE: + case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE: + case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED: + case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED: + case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN: + case SSL_R_TLSV1_ALERT_ACCESS_DENIED: + case SSL_R_TLSV1_ALERT_UNKNOWN_CA: + return ERR_BAD_SSL_CLIENT_AUTH_CERT; + case SSL_R_BAD_DECOMPRESSION: + case SSL_R_SSLV3_ALERT_DECOMPRESSION_FAILURE: + return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; + case SSL_R_SSLV3_ALERT_BAD_RECORD_MAC: + return ERR_SSL_BAD_RECORD_MAC_ALERT; + case SSL_R_TLSV1_ALERT_DECRYPT_ERROR: + return ERR_SSL_DECRYPT_ERROR_ALERT; + case SSL_R_UNSAFE_LEGACY_RENEGOTIATION_DISABLED: + return ERR_SSL_UNSAFE_NEGOTIATION; + case SSL_R_WRONG_NUMBER_OF_KEY_BITS: + return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; + // SSL_R_UNKNOWN_PROTOCOL is reported if premature application data is + // received (see http://crbug.com/42538), and also if all the protocol + // versions supported by the server were disabled in this socket instance. + // Mapped to ERR_SSL_PROTOCOL_ERROR for compatibility with other SSL sockets + // in the former scenario. + case SSL_R_UNKNOWN_PROTOCOL: + case SSL_R_SSL_HANDSHAKE_FAILURE: + case SSL_R_DECRYPTION_FAILED: + case SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC: + case SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG: + case SSL_R_DIGEST_CHECK_FAILED: + case SSL_R_DUPLICATE_COMPRESSION_ID: + case SSL_R_ECGROUP_TOO_LARGE_FOR_CIPHER: + case SSL_R_ENCRYPTED_LENGTH_TOO_LONG: + case SSL_R_ERROR_IN_RECEIVED_CIPHER_LIST: + case SSL_R_EXCESSIVE_MESSAGE_SIZE: + case SSL_R_EXTRA_DATA_IN_MESSAGE: + case SSL_R_GOT_A_FIN_BEFORE_A_CCS: + case SSL_R_ILLEGAL_PADDING: + case SSL_R_INVALID_CHALLENGE_LENGTH: + case SSL_R_INVALID_COMMAND: + case SSL_R_INVALID_PURPOSE: + case SSL_R_INVALID_STATUS_RESPONSE: + case SSL_R_INVALID_TICKET_KEYS_LENGTH: + case SSL_R_KEY_ARG_TOO_LONG: + case SSL_R_READ_WRONG_PACKET_TYPE: + case SSL_R_SSLV3_ALERT_UNEXPECTED_MESSAGE: + // TODO(joth): SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE may be returned from the + // server after receiving ClientHello if there's no common supported cipher. + // Ideally we'd map that specific case to ERR_SSL_VERSION_OR_CIPHER_MISMATCH + // to match the NSS implementation. See also http://goo.gl/oMtZW + case SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE: + case SSL_R_SSLV3_ALERT_NO_CERTIFICATE: + case SSL_R_SSLV3_ALERT_ILLEGAL_PARAMETER: + case SSL_R_TLSV1_ALERT_DECODE_ERROR: + case SSL_R_TLSV1_ALERT_DECRYPTION_FAILED: + case SSL_R_TLSV1_ALERT_EXPORT_RESTRICTION: + case SSL_R_TLSV1_ALERT_INTERNAL_ERROR: + case SSL_R_TLSV1_ALERT_NO_RENEGOTIATION: + case SSL_R_TLSV1_ALERT_RECORD_OVERFLOW: + case SSL_R_TLSV1_ALERT_USER_CANCELLED: + return ERR_SSL_PROTOCOL_ERROR; + default: + LOG(WARNING) << "Unmapped error reason: " << ERR_GET_REASON(error_code); + return ERR_FAILED; + } +} + +// Converts an OpenSSL error code into a net error code, walking the OpenSSL +// error stack if needed. Note that |tracer| is not currently used in the +// implementation, but is passed in anyway as this ensures the caller will clear +// any residual codes left on the error stack. +int MapOpenSSLError(int err, const crypto::OpenSSLErrStackTracer& tracer) { + switch (err) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return ERR_IO_PENDING; + case SSL_ERROR_SYSCALL: + LOG(ERROR) << "OpenSSL SYSCALL error, earliest error code in " + "error queue: " << ERR_peek_error() << ", errno: " + << errno; + return ERR_SSL_PROTOCOL_ERROR; + case SSL_ERROR_SSL: + return MapOpenSSLErrorSSL(); + default: + // TODO(joth): Implement full mapping. + LOG(WARNING) << "Unknown OpenSSL error " << err; + return ERR_SSL_PROTOCOL_ERROR; + } +} + +// We do certificate verification after handshake, so we disable the default +// by registering a no-op verify function. +int NoOpVerifyCallback(X509_STORE_CTX*, void *) { + DVLOG(3) << "skipping cert verify"; + return 1; +} + +// OpenSSL manages a cache of SSL_SESSION, this class provides the application +// side policy for that cache about session re-use: we retain one session per +// unique HostPortPair, per shard. +class SSLSessionCache { + public: + SSLSessionCache() {} + + void OnSessionAdded(const HostPortPair& host_and_port, + const std::string& shard, + SSL_SESSION* session) { + // Declare the session cleaner-upper before the lock, so any call into + // OpenSSL to free the session will happen after the lock is released. + crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; + base::AutoLock lock(lock_); + + DCHECK_EQ(0U, session_map_.count(session)); + const std::string cache_key = GetCacheKey(host_and_port, shard); + + std::pair<HostPortMap::iterator, bool> res = + host_port_map_.insert(std::make_pair(cache_key, session)); + if (!res.second) { // Already exists: replace old entry. + session_to_free.reset(res.first->second); + session_map_.erase(session_to_free.get()); + res.first->second = session; + } + DVLOG(2) << "Adding session " << session << " => " + << cache_key << ", new entry = " << res.second; + DCHECK(host_port_map_[cache_key] == session); + session_map_[session] = res.first; + DCHECK_EQ(host_port_map_.size(), session_map_.size()); + DCHECK_LE(host_port_map_.size(), kSessionCacheMaxEntires); + } + + void OnSessionRemoved(SSL_SESSION* session) { + // Declare the session cleaner-upper before the lock, so any call into + // OpenSSL to free the session will happen after the lock is released. + crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; + base::AutoLock lock(lock_); + + SessionMap::iterator it = session_map_.find(session); + if (it == session_map_.end()) + return; + DVLOG(2) << "Remove session " << session << " => " << it->second->first; + DCHECK(it->second->second == session); + host_port_map_.erase(it->second); + session_map_.erase(it); + session_to_free.reset(session); + DCHECK_EQ(host_port_map_.size(), session_map_.size()); + } + + // Looks up the host:port in the cache, and if a session is found it is added + // to |ssl|, returning true on success. + bool SetSSLSession(SSL* ssl, const HostPortPair& host_and_port, + const std::string& shard) { + base::AutoLock lock(lock_); + const std::string cache_key = GetCacheKey(host_and_port, shard); + HostPortMap::iterator it = host_port_map_.find(cache_key); + if (it == host_port_map_.end()) + return false; + DVLOG(2) << "Lookup session: " << it->second << " => " << cache_key; + SSL_SESSION* session = it->second; + DCHECK(session); + DCHECK(session_map_[session] == it); + // Ideally we'd release |lock_| before calling into OpenSSL here, however + // that opens a small risk |session| will go out of scope before it is used. + // Alternatively we would take a temporary local refcount on |session|, + // except OpenSSL does not provide a public API for adding a ref (c.f. + // SSL_SESSION_free which decrements the ref). + return SSL_set_session(ssl, session) == 1; + } + + // Flush removes all entries from the cache. This is called when a client + // certificate is added. + void Flush() { + for (HostPortMap::iterator i = host_port_map_.begin(); + i != host_port_map_.end(); i++) { + SSL_SESSION_free(i->second); + } + host_port_map_.clear(); + session_map_.clear(); + } + + private: + static std::string GetCacheKey(const HostPortPair& host_and_port, + const std::string& shard) { + return host_and_port.ToString() + "/" + shard; + } + + // A pair of maps to allow bi-directional lookups between host:port and an + // associated session. + typedef std::map<std::string, SSL_SESSION*> HostPortMap; + typedef std::map<SSL_SESSION*, HostPortMap::iterator> SessionMap; + HostPortMap host_port_map_; + SessionMap session_map_; + + // Protects access to both the above maps. + base::Lock lock_; + + DISALLOW_COPY_AND_ASSIGN(SSLSessionCache); +}; + +class SSLContext { + public: + static SSLContext* GetInstance() { return Singleton<SSLContext>::get(); } + SSL_CTX* ssl_ctx() { return ssl_ctx_.get(); } + SSLSessionCache* session_cache() { return &session_cache_; } + + SSLClientSocketOpenSSL* GetClientSocketFromSSL(SSL* ssl) { + DCHECK(ssl); + SSLClientSocketOpenSSL* socket = static_cast<SSLClientSocketOpenSSL*>( + SSL_get_ex_data(ssl, ssl_socket_data_index_)); + DCHECK(socket); + return socket; + } + + bool SetClientSocketForSSL(SSL* ssl, SSLClientSocketOpenSSL* socket) { + return SSL_set_ex_data(ssl, ssl_socket_data_index_, socket) != 0; + } + + private: + friend struct DefaultSingletonTraits<SSLContext>; + + SSLContext() { + crypto::EnsureOpenSSLInit(); + ssl_socket_data_index_ = SSL_get_ex_new_index(0, 0, 0, 0, 0); + DCHECK_NE(ssl_socket_data_index_, -1); + ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method())); + SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), NoOpVerifyCallback, NULL); + SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_CLIENT); + SSL_CTX_sess_set_new_cb(ssl_ctx_.get(), NewSessionCallbackStatic); + SSL_CTX_sess_set_remove_cb(ssl_ctx_.get(), RemoveSessionCallbackStatic); + SSL_CTX_set_timeout(ssl_ctx_.get(), kSessionCacheTimeoutSeconds); + SSL_CTX_sess_set_cache_size(ssl_ctx_.get(), kSessionCacheMaxEntires); + SSL_CTX_set_client_cert_cb(ssl_ctx_.get(), ClientCertCallback); +#if defined(OPENSSL_NPN_NEGOTIATED) + // TODO(kristianm): Only select this if ssl_config_.next_proto is not empty. + // It would be better if the callback were not a global setting, + // but that is an OpenSSL issue. + SSL_CTX_set_next_proto_select_cb(ssl_ctx_.get(), SelectNextProtoCallback, + NULL); +#endif + } + + static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) { + return GetInstance()->NewSessionCallback(ssl, session); + } + + int NewSessionCallback(SSL* ssl, SSL_SESSION* session) { + SSLClientSocketOpenSSL* socket = GetClientSocketFromSSL(ssl); + session_cache_.OnSessionAdded(socket->host_and_port(), + socket->ssl_session_cache_shard(), + session); + return 1; // 1 => We took ownership of |session|. + } + + static void RemoveSessionCallbackStatic(SSL_CTX* ctx, SSL_SESSION* session) { + return GetInstance()->RemoveSessionCallback(ctx, session); + } + + void RemoveSessionCallback(SSL_CTX* ctx, SSL_SESSION* session) { + DCHECK(ctx == ssl_ctx()); + session_cache_.OnSessionRemoved(session); + } + + static int ClientCertCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey) { + SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); + CHECK(socket); + return socket->ClientCertRequestCallback(ssl, x509, pkey); + } + + static int SelectNextProtoCallback(SSL* ssl, + unsigned char** out, unsigned char* outlen, + const unsigned char* in, + unsigned int inlen, void* arg) { + SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); + return socket->SelectNextProtoCallback(out, outlen, in, inlen); + } + + // This is the index used with SSL_get_ex_data to retrieve the owner + // SSLClientSocketOpenSSL object from an SSL instance. + int ssl_socket_data_index_; + + // session_cache_ must appear before |ssl_ctx_| because the destruction of + // |ssl_ctx_| may trigger callbacks into |session_cache_|. Therefore, + // |session_cache_| must be destructed after |ssl_ctx_|. + SSLSessionCache session_cache_; + crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_; +}; + +// Utility to construct the appropriate set & clear masks for use the OpenSSL +// options and mode configuration functions. (SSL_set_options etc) +struct SslSetClearMask { + SslSetClearMask() : set_mask(0), clear_mask(0) {} + void ConfigureFlag(long flag, bool state) { + (state ? set_mask : clear_mask) |= flag; + // Make sure we haven't got any intersection in the set & clear options. + DCHECK_EQ(0, set_mask & clear_mask) << flag << ":" << state; + } + long set_mask; + long clear_mask; +}; + +} // namespace + +// static +void SSLClientSocket::ClearSessionCache() { + SSLContext* context = SSLContext::GetInstance(); + context->session_cache()->Flush(); +} + +SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) + : transport_send_busy_(false), + transport_recv_busy_(false), + transport_recv_eof_(false), + weak_factory_(this), + pending_read_error_(kNoPendingReadResult), + completed_handshake_(false), + client_auth_cert_needed_(false), + cert_verifier_(context.cert_verifier), + ssl_(NULL), + transport_bio_(NULL), + transport_(transport_socket.Pass()), + host_and_port_(host_and_port), + ssl_config_(ssl_config), + ssl_session_cache_shard_(context.ssl_session_cache_shard), + trying_cached_session_(false), + next_handshake_state_(STATE_NONE), + npn_status_(kNextProtoUnsupported), + net_log_(transport_->socket()->NetLog()) { +} + +SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { + Disconnect(); +} + +bool SSLClientSocketOpenSSL::Init() { + DCHECK(!ssl_); + DCHECK(!transport_bio_); + + SSLContext* context = SSLContext::GetInstance(); + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + + ssl_ = SSL_new(context->ssl_ctx()); + if (!ssl_ || !context->SetClientSocketForSSL(ssl_, this)) + return false; + + if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) + return false; + + trying_cached_session_ = + context->session_cache()->SetSSLSession(ssl_, host_and_port_, + ssl_session_cache_shard_); + + BIO* ssl_bio = NULL; + // 0 => use default buffer sizes. + if (!BIO_new_bio_pair(&ssl_bio, 0, &transport_bio_, 0)) + return false; + DCHECK(ssl_bio); + DCHECK(transport_bio_); + + SSL_set_bio(ssl_, ssl_bio, ssl_bio); + + // OpenSSL defaults some options to on, others to off. To avoid ambiguity, + // set everything we care about to an absolute value. + SslSetClearMask options; + options.ConfigureFlag(SSL_OP_NO_SSLv2, true); + bool ssl3_enabled = (ssl_config_.version_min == SSL_PROTOCOL_VERSION_SSL3); + options.ConfigureFlag(SSL_OP_NO_SSLv3, !ssl3_enabled); + bool tls1_enabled = (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1 && + ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1); + options.ConfigureFlag(SSL_OP_NO_TLSv1, !tls1_enabled); +#if defined(SSL_OP_NO_TLSv1_1) + bool tls1_1_enabled = + (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_1 && + ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1); + options.ConfigureFlag(SSL_OP_NO_TLSv1_1, !tls1_1_enabled); +#endif +#if defined(SSL_OP_NO_TLSv1_2) + bool tls1_2_enabled = + (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_2 && + ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_2); + options.ConfigureFlag(SSL_OP_NO_TLSv1_2, !tls1_2_enabled); +#endif + +#if defined(SSL_OP_NO_COMPRESSION) + options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true); +#endif + + // TODO(joth): Set this conditionally, see http://crbug.com/55410 + options.ConfigureFlag(SSL_OP_LEGACY_SERVER_CONNECT, true); + + SSL_set_options(ssl_, options.set_mask); + SSL_clear_options(ssl_, options.clear_mask); + + // Same as above, this time for the SSL mode. + SslSetClearMask mode; + +#if defined(SSL_MODE_RELEASE_BUFFERS) + mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true); +#endif + +#if defined(SSL_MODE_SMALL_BUFFERS) + mode.ConfigureFlag(SSL_MODE_SMALL_BUFFERS, true); +#endif + + SSL_set_mode(ssl_, mode.set_mask); + SSL_clear_mode(ssl_, mode.clear_mask); + + // Removing ciphers by ID from OpenSSL is a bit involved as we must use the + // textual name with SSL_set_cipher_list because there is no public API to + // directly remove a cipher by ID. + STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(ssl_); + DCHECK(ciphers); + // See SSLConfig::disabled_cipher_suites for description of the suites + // disabled by default. Note that !SHA384 only removes HMAC-SHA384 cipher + // suites, not GCM cipher suites with SHA384 as the handshake hash. + std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA384:!aECDH"); + // Walk through all the installed ciphers, seeing if any need to be + // appended to the cipher removal |command|. + for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { + const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i); + const uint16 id = SSL_CIPHER_get_id(cipher); + // Remove any ciphers with a strength of less than 80 bits. Note the NSS + // implementation uses "effective" bits here but OpenSSL does not provide + // this detail. This only impacts Triple DES: reports 112 vs. 168 bits, + // both of which are greater than 80 anyway. + bool disable = SSL_CIPHER_get_bits(cipher, NULL) < 80; + if (!disable) { + disable = std::find(ssl_config_.disabled_cipher_suites.begin(), + ssl_config_.disabled_cipher_suites.end(), id) != + ssl_config_.disabled_cipher_suites.end(); + } + if (disable) { + const char* name = SSL_CIPHER_get_name(cipher); + DVLOG(3) << "Found cipher to remove: '" << name << "', ID: " << id + << " strength: " << SSL_CIPHER_get_bits(cipher, NULL); + command.append(":!"); + command.append(name); + } + } + int rv = SSL_set_cipher_list(ssl_, command.c_str()); + // If this fails (rv = 0) it means there are no ciphers enabled on this SSL. + // This will almost certainly result in the socket failing to complete the + // handshake at which point the appropriate error is bubbled up to the client. + LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') " + "returned " << rv; + return true; +} + +int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, + X509** x509, + EVP_PKEY** pkey) { + DVLOG(3) << "OpenSSL ClientCertRequestCallback called"; + DCHECK(ssl == ssl_); + DCHECK(*x509 == NULL); + DCHECK(*pkey == NULL); + + if (!ssl_config_.send_client_cert) { + // First pass: we know that a client certificate is needed, but we do not + // have one at hand. + client_auth_cert_needed_ = true; + STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl); + for (int i = 0; i < sk_X509_NAME_num(authorities); i++) { + X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i); + unsigned char* str = NULL; + int length = i2d_X509_NAME(ca_name, &str); + cert_authorities_.push_back(std::string( + reinterpret_cast<const char*>(str), + static_cast<size_t>(length))); + OPENSSL_free(str); + } + + return -1; // Suspends handshake. + } + + // Second pass: a client certificate should have been selected. + if (ssl_config_.client_cert.get()) { + // A note about ownership: FetchClientCertPrivateKey() increments + // the reference count of the EVP_PKEY. Ownership of this reference + // is passed directly to OpenSSL, which will release the reference + // using EVP_PKEY_free() when the SSL object is destroyed. + OpenSSLClientKeyStore::ScopedEVP_PKEY privkey; + if (OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey( + ssl_config_.client_cert.get(), &privkey)) { + // TODO(joth): (copied from NSS) We should wait for server certificate + // verification before sending our credentials. See http://crbug.com/13934 + *x509 = X509Certificate::DupOSCertHandle( + ssl_config_.client_cert->os_cert_handle()); + *pkey = privkey.release(); + return 1; + } + LOG(WARNING) << "Client cert found without private key"; + } + + // Send no client certificate. + return 0; +} + +// SSLClientSocket methods + +bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { + ssl_info->Reset(); + if (!server_cert_.get()) + return false; + + ssl_info->cert = server_cert_verify_result_.verified_cert; + ssl_info->cert_status = server_cert_verify_result_.cert_status; + ssl_info->is_issued_by_known_root = + server_cert_verify_result_.is_issued_by_known_root; + ssl_info->public_key_hashes = + server_cert_verify_result_.public_key_hashes; + ssl_info->client_cert_sent = + ssl_config_.send_client_cert && ssl_config_.client_cert.get(); + ssl_info->channel_id_sent = WasChannelIDSent(); + + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); + CHECK(cipher); + ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); + const COMP_METHOD* compression = SSL_get_current_compression(ssl_); + + ssl_info->connection_status = EncodeSSLConnectionStatus( + SSL_CIPHER_get_id(cipher), + compression ? compression->type : 0, + GetNetSSLVersion(ssl_)); + + bool peer_supports_renego_ext = !!SSL_get_secure_renegotiation_support(ssl_); + if (!peer_supports_renego_ext) + ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; + UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported", + implicit_cast<int>(peer_supports_renego_ext), 2); + + if (ssl_config_.version_fallback) + ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK; + + ssl_info->handshake_type = SSL_session_reused(ssl_) ? + SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL; + + DVLOG(3) << "Encoded connection status: cipher suite = " + << SSLConnectionStatusToCipherSuite(ssl_info->connection_status) + << " version = " + << SSLConnectionStatusToVersion(ssl_info->connection_status); + return true; +} + +void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) { + cert_request_info->host_and_port = host_and_port_.ToString(); + cert_request_info->cert_authorities = cert_authorities_; +} + +int SSLClientSocketOpenSSL::ExportKeyingMaterial( + const base::StringPiece& label, + bool has_context, const base::StringPiece& context, + unsigned char* out, unsigned int outlen) { + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + + int rv = SSL_export_keying_material( + ssl_, out, outlen, const_cast<char*>(label.data()), + label.size(), + reinterpret_cast<unsigned char*>(const_cast<char*>(context.data())), + context.length(), + context.length() > 0); + + if (rv != 1) { + int ssl_error = SSL_get_error(ssl_, rv); + LOG(ERROR) << "Failed to export keying material;" + << " returned " << rv + << ", SSL error code " << ssl_error; + return MapOpenSSLError(ssl_error, err_tracer); + } + return OK; +} + +int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) { + return ERR_NOT_IMPLEMENTED; +} + +SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( + std::string* proto, std::string* server_protos) { + *proto = npn_proto_; + *server_protos = server_protos_; + return npn_status_; +} + +ServerBoundCertService* +SSLClientSocketOpenSSL::GetServerBoundCertService() const { + return NULL; +} + +void SSLClientSocketOpenSSL::DoReadCallback(int rv) { + // Since Run may result in Read being called, clear |user_read_callback_| + // up front. + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + base::ResetAndReturn(&user_read_callback_).Run(rv); +} + +void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { + // Since Run may result in Write being called, clear |user_write_callback_| + // up front. + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + base::ResetAndReturn(&user_write_callback_).Run(rv); +} + +// StreamSocket implementation. +int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT); + + // Set up new ssl object. + if (!Init()) { + int result = ERR_UNEXPECTED; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, result); + return result; + } + + // Set SSL to client mode. Handshake happens in the loop below. + SSL_set_connect_state(ssl_); + + GotoState(STATE_HANDSHAKE); + int rv = DoHandshakeLoop(net::OK); + if (rv == ERR_IO_PENDING) { + user_connect_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + } + + return rv > OK ? OK : rv; +} + +void SSLClientSocketOpenSSL::Disconnect() { + if (ssl_) { + // Calling SSL_shutdown prevents the session from being marked as + // unresumable. + SSL_shutdown(ssl_); + SSL_free(ssl_); + ssl_ = NULL; + } + if (transport_bio_) { + BIO_free_all(transport_bio_); + transport_bio_ = NULL; + } + + // Shut down anything that may call us back. + verifier_.reset(); + transport_->socket()->Disconnect(); + + // Null all callbacks, delete all buffers. + transport_send_busy_ = false; + send_buffer_ = NULL; + transport_recv_busy_ = false; + transport_recv_eof_ = false; + recv_buffer_ = NULL; + + user_connect_callback_.Reset(); + user_read_callback_.Reset(); + user_write_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + + server_cert_verify_result_.Reset(); + completed_handshake_ = false; + + cert_authorities_.clear(); + client_auth_cert_needed_ = false; +} + +int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + // (This is a quirk carried over from the windows + // implementation. It makes reading the logs a bit harder.) + // State handlers can and often do call GotoState just + // to stay in the current state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + switch (state) { + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + case STATE_VERIFY_CERT: + DCHECK(rv == OK); + rv = DoVerifyCert(rv); + break; + case STATE_VERIFY_CERT_COMPLETE: + rv = DoVerifyCertComplete(rv); + break; + case STATE_NONE: + default: + rv = ERR_UNEXPECTED; + NOTREACHED() << "unexpected state" << state; + break; + } + + bool network_moved = DoTransportIO(); + if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) { + // In general we exit the loop if rv is ERR_IO_PENDING. In this + // special case we keep looping even if rv is ERR_IO_PENDING because + // the transport IO may allow DoHandshake to make progress. + rv = OK; // This causes us to stay in the loop. + } + } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); + return rv; +} + +int SSLClientSocketOpenSSL::DoHandshake() { + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + int net_error = net::OK; + int rv = SSL_do_handshake(ssl_); + + if (client_auth_cert_needed_) { + net_error = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; + // If the handshake already succeeded (because the server requests but + // doesn't require a client cert), we need to invalidate the SSL session + // so that we won't try to resume the non-client-authenticated session in + // the next handshake. This will cause the server to ask for a client + // cert again. + if (rv == 1) { + // Remove from session cache but don't clear this connection. + SSL_SESSION* session = SSL_get_session(ssl_); + if (session) { + int rv = SSL_CTX_remove_session(SSL_get_SSL_CTX(ssl_), session); + LOG_IF(WARNING, !rv) << "Couldn't invalidate SSL session: " << session; + } + } + } else if (rv == 1) { + if (trying_cached_session_ && logging::DEBUG_MODE) { + DVLOG(2) << "Result of session reuse for " << host_and_port_.ToString() + << " is: " << (SSL_session_reused(ssl_) ? "Success" : "Fail"); + } + // SSL handshake is completed. Let's verify the certificate. + const bool got_cert = !!UpdateServerCert(); + DCHECK(got_cert); + net_log_.AddEvent( + NetLog::TYPE_SSL_CERTIFICATES_RECEIVED, + base::Bind(&NetLogX509CertificateCallback, + base::Unretained(server_cert_.get()))); + GotoState(STATE_VERIFY_CERT); + } else { + int ssl_error = SSL_get_error(ssl_, rv); + net_error = MapOpenSSLError(ssl_error, err_tracer); + + // If not done, stay in this state + if (net_error == ERR_IO_PENDING) { + GotoState(STATE_HANDSHAKE); + } else { + LOG(ERROR) << "handshake failed; returned " << rv + << ", SSL error code " << ssl_error + << ", net_error " << net_error; + net_log_.AddEvent( + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogSSLErrorCallback(net_error, ssl_error)); + } + } + return net_error; +} + +// SelectNextProtoCallback is called by OpenSSL during the handshake. If the +// server supports NPN, selects a protocol from the list that the server +// provides. According to third_party/openssl/openssl/ssl/ssl_lib.c, the +// callback can assume that |in| is syntactically valid. +int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen) { +#if defined(OPENSSL_NPN_NEGOTIATED) + if (ssl_config_.next_protos.empty()) { + *out = reinterpret_cast<uint8*>( + const_cast<char*>(kDefaultSupportedNPNProtocol)); + *outlen = arraysize(kDefaultSupportedNPNProtocol) - 1; + npn_status_ = kNextProtoUnsupported; + return SSL_TLSEXT_ERR_OK; + } + + // Assume there's no overlap between our protocols and the server's list. + npn_status_ = kNextProtoNoOverlap; + + // For each protocol in server preference order, see if we support it. + for (unsigned int i = 0; i < inlen; i += in[i] + 1) { + for (std::vector<std::string>::const_iterator + j = ssl_config_.next_protos.begin(); + j != ssl_config_.next_protos.end(); ++j) { + if (in[i] == j->size() && + memcmp(&in[i + 1], j->data(), in[i]) == 0) { + // We found a match. + *out = const_cast<unsigned char*>(in) + i + 1; + *outlen = in[i]; + npn_status_ = kNextProtoNegotiated; + break; + } + } + if (npn_status_ == kNextProtoNegotiated) + break; + } + + // If we didn't find a protocol, we select the first one from our list. + if (npn_status_ == kNextProtoNoOverlap) { + *out = reinterpret_cast<uint8*>(const_cast<char*>( + ssl_config_.next_protos[0].data())); + *outlen = ssl_config_.next_protos[0].size(); + } + + npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen); + server_protos_.assign(reinterpret_cast<const char*>(in), inlen); + DVLOG(2) << "next protocol: '" << npn_proto_ << "' status: " << npn_status_; +#endif + return SSL_TLSEXT_ERR_OK; +} + +int SSLClientSocketOpenSSL::DoVerifyCert(int result) { + DCHECK(server_cert_.get()); + GotoState(STATE_VERIFY_CERT_COMPLETE); + + CertStatus cert_status; + if (ssl_config_.IsAllowedBadCert(server_cert_.get(), &cert_status)) { + VLOG(1) << "Received an expected bad cert with status: " << cert_status; + server_cert_verify_result_.Reset(); + server_cert_verify_result_.cert_status = cert_status; + server_cert_verify_result_.verified_cert = server_cert_; + return OK; + } + + int flags = 0; + if (ssl_config_.rev_checking_enabled) + flags |= CertVerifier::VERIFY_REV_CHECKING_ENABLED; + if (ssl_config_.verify_ev_cert) + flags |= CertVerifier::VERIFY_EV_CERT; + if (ssl_config_.cert_io_enabled) + flags |= CertVerifier::VERIFY_CERT_IO_ENABLED; + if (ssl_config_.rev_checking_required_local_anchors) + flags |= CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS; + verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); + return verifier_->Verify( + server_cert_.get(), + host_and_port_.host(), + flags, + NULL /* no CRL set */, + &server_cert_verify_result_, + base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete, + base::Unretained(this)), + net_log_); +} + +int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { + verifier_.reset(); + + if (result == OK) { + // TODO(joth): Work out if we need to remember the intermediate CA certs + // when the server sends them to us, and do so here. + } else { + DVLOG(1) << "DoVerifyCertComplete error " << ErrorToString(result) + << " (" << result << ")"; + } + + completed_handshake_ = true; + // Exit DoHandshakeLoop and return the result to the caller to Connect. + DCHECK_EQ(STATE_NONE, next_handshake_state_); + return result; +} + +X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() { + if (server_cert_.get()) + return server_cert_.get(); + + crypto::ScopedOpenSSL<X509, X509_free> cert(SSL_get_peer_certificate(ssl_)); + if (!cert.get()) { + LOG(WARNING) << "SSL_get_peer_certificate returned NULL"; + return NULL; + } + + // Unlike SSL_get_peer_certificate, SSL_get_peer_cert_chain does not + // increment the reference so sk_X509_free does not need to be called. + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); + X509Certificate::OSCertHandles intermediates; + if (chain) { + for (int i = 0; i < sk_X509_num(chain); ++i) + intermediates.push_back(sk_X509_value(chain, i)); + } + server_cert_ = X509Certificate::CreateFromHandle(cert.get(), intermediates); + DCHECK(server_cert_.get()); + + return server_cert_.get(); +} + +bool SSLClientSocketOpenSSL::DoTransportIO() { + bool network_moved = false; + int rv; + // Read and write as much data as possible. The loop is necessary because + // Write() may return synchronously. + do { + rv = BufferSend(); + if (rv != ERR_IO_PENDING && rv != 0) + network_moved = true; + } while (rv > 0); + if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING) + network_moved = true; + return network_moved; +} + +int SSLClientSocketOpenSSL::BufferSend(void) { + if (transport_send_busy_) + return ERR_IO_PENDING; + + if (!send_buffer_.get()) { + // Get a fresh send buffer out of the send BIO. + size_t max_read = BIO_ctrl_pending(transport_bio_); + if (!max_read) + return 0; // Nothing pending in the OpenSSL write BIO. + send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read); + int read_bytes = BIO_read(transport_bio_, send_buffer_->data(), max_read); + DCHECK_GT(read_bytes, 0); + CHECK_EQ(static_cast<int>(max_read), read_bytes); + } + + int rv = transport_->socket()->Write( + send_buffer_.get(), + send_buffer_->BytesRemaining(), + base::Bind(&SSLClientSocketOpenSSL::BufferSendComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + TransportWriteComplete(rv); + } + return rv; +} + +void SSLClientSocketOpenSSL::BufferSendComplete(int result) { + transport_send_busy_ = false; + TransportWriteComplete(result); + OnSendComplete(result); +} + +void SSLClientSocketOpenSSL::TransportWriteComplete(int result) { + DCHECK(ERR_IO_PENDING != result); + if (result < 0) { + // Got a socket write error; close the BIO to indicate this upward. + DVLOG(1) << "TransportWriteComplete error " << result; + (void)BIO_shutdown_wr(transport_bio_); + BIO_set_mem_eof_return(transport_bio_, 0); + send_buffer_ = NULL; + } else { + DCHECK(send_buffer_.get()); + send_buffer_->DidConsume(result); + DCHECK_GE(send_buffer_->BytesRemaining(), 0); + if (send_buffer_->BytesRemaining() <= 0) + send_buffer_ = NULL; + } +} + +int SSLClientSocketOpenSSL::BufferRecv(void) { + if (transport_recv_busy_) + return ERR_IO_PENDING; + + // Determine how much was requested from |transport_bio_| that was not + // actually available. + size_t requested = BIO_ctrl_get_read_request(transport_bio_); + if (requested == 0) { + // This is not a perfect match of error codes, as no operation is + // actually pending. However, returning 0 would be interpreted as + // a possible sign of EOF, which is also an inappropriate match. + return ERR_IO_PENDING; + } + + // Known Issue: While only reading |requested| data is the more correct + // implementation, it has the downside of resulting in frequent reads: + // One read for the SSL record header (~5 bytes) and one read for the SSL + // record body. Rather than issuing these reads to the underlying socket + // (and constantly allocating new IOBuffers), a single Read() request to + // fill |transport_bio_| is issued. As long as an SSL client socket cannot + // be gracefully shutdown (via SSL close alerts) and re-used for non-SSL + // traffic, this over-subscribed Read()ing will not cause issues. + size_t max_write = BIO_ctrl_get_write_guarantee(transport_bio_); + if (!max_write) + return ERR_IO_PENDING; + + recv_buffer_ = new IOBuffer(max_write); + int rv = transport_->socket()->Read( + recv_buffer_.get(), + max_write, + base::Bind(&SSLClientSocketOpenSSL::BufferRecvComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + TransportReadComplete(rv); + } + return rv; +} + +void SSLClientSocketOpenSSL::BufferRecvComplete(int result) { + TransportReadComplete(result); + OnRecvComplete(result); +} + +void SSLClientSocketOpenSSL::TransportReadComplete(int result) { + DCHECK(ERR_IO_PENDING != result); + if (result <= 0) { + DVLOG(1) << "TransportReadComplete result " << result; + // Received 0 (end of file) or an error. Either way, bubble it up to the + // SSL layer via the BIO. TODO(joth): consider stashing the error code, to + // relay up to the SSL socket client (i.e. via DoReadCallback). + if (result == 0) + transport_recv_eof_ = true; + BIO_set_mem_eof_return(transport_bio_, 0); + (void)BIO_shutdown_wr(transport_bio_); + } else { + DCHECK(recv_buffer_.get()); + int ret = BIO_write(transport_bio_, recv_buffer_->data(), result); + // A write into a memory BIO should always succeed. + CHECK_EQ(result, ret); + } + recv_buffer_ = NULL; + transport_recv_busy_ = false; +} + +void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { + if (!user_connect_callback_.is_null()) { + CompletionCallback c = user_connect_callback_; + user_connect_callback_.Reset(); + c.Run(rv > OK ? OK : rv); + } +} + +void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) { + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + DoConnectCallback(rv); + } +} + +void SSLClientSocketOpenSSL::OnSendComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + // OnSendComplete may need to call DoPayloadRead while the renegotiation + // handshake is in progress. + int rv_read = ERR_IO_PENDING; + int rv_write = ERR_IO_PENDING; + bool network_moved; + do { + if (user_read_buf_.get()) + rv_read = DoPayloadRead(); + if (user_write_buf_.get()) + rv_write = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv_read == ERR_IO_PENDING && rv_write == ERR_IO_PENDING && + (user_read_buf_.get() || user_write_buf_.get()) && network_moved); + + // Performing the Read callback may cause |this| to be deleted. If this + // happens, the Write callback should not be invoked. Guard against this by + // holding a WeakPtr to |this| and ensuring it's still valid. + base::WeakPtr<SSLClientSocketOpenSSL> guard(weak_factory_.GetWeakPtr()); + if (user_read_buf_.get() && rv_read != ERR_IO_PENDING) + DoReadCallback(rv_read); + + if (!guard.get()) + return; + + if (user_write_buf_.get() && rv_write != ERR_IO_PENDING) + DoWriteCallback(rv_write); +} + +void SSLClientSocketOpenSSL::OnRecvComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + // Network layer received some data, check if client requested to read + // decrypted data. + if (!user_read_buf_.get()) + return; + + int rv = DoReadLoop(result); + if (rv != ERR_IO_PENDING) + DoReadCallback(rv); +} + +bool SSLClientSocketOpenSSL::IsConnected() const { + // If the handshake has not yet completed. + if (!completed_handshake_) + return false; + // If an asynchronous operation is still pending. + if (user_read_buf_.get() || user_write_buf_.get()) + return true; + + return transport_->socket()->IsConnected(); +} + +bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const { + // If the handshake has not yet completed. + if (!completed_handshake_) + return false; + // If an asynchronous operation is still pending. + if (user_read_buf_.get() || user_write_buf_.get()) + return false; + // If there is data waiting to be sent, or data read from the network that + // has not yet been consumed. + if (BIO_ctrl_pending(transport_bio_) > 0 || + BIO_ctrl_wpending(transport_bio_) > 0) { + return false; + } + + return transport_->socket()->IsConnectedAndIdle(); +} + +int SSLClientSocketOpenSSL::GetPeerAddress(IPEndPoint* addressList) const { + return transport_->socket()->GetPeerAddress(addressList); +} + +int SSLClientSocketOpenSSL::GetLocalAddress(IPEndPoint* addressList) const { + return transport_->socket()->GetLocalAddress(addressList); +} + +const BoundNetLog& SSLClientSocketOpenSSL::NetLog() const { + return net_log_; +} + +void SSLClientSocketOpenSSL::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void SSLClientSocketOpenSSL::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool SSLClientSocketOpenSSL::WasEverUsed() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->WasEverUsed(); + + NOTREACHED(); + return false; +} + +bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) + return transport_->socket()->UsingTCPFastOpen(); + + NOTREACHED(); + return false; +} + +// Socket methods + +int SSLClientSocketOpenSSL::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + + return rv; +} + +int SSLClientSocketOpenSSL::DoReadLoop(int result) { + if (result < 0) + return result; + + bool network_moved; + int rv; + do { + rv = DoPayloadRead(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + + return rv; +} + +int SSLClientSocketOpenSSL::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + + return rv; +} + +int SSLClientSocketOpenSSL::DoWriteLoop(int result) { + if (result < 0) + return result; + + bool network_moved; + int rv; + do { + rv = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + + return rv; +} + +bool SSLClientSocketOpenSSL::SetReceiveBufferSize(int32 size) { + return transport_->socket()->SetReceiveBufferSize(size); +} + +bool SSLClientSocketOpenSSL::SetSendBufferSize(int32 size) { + return transport_->socket()->SetSendBufferSize(size); +} + +int SSLClientSocketOpenSSL::DoPayloadRead() { + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + + int rv; + if (pending_read_error_ != kNoPendingReadResult) { + rv = pending_read_error_; + pending_read_error_ = kNoPendingReadResult; + if (rv == 0) { + net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, + rv, user_read_buf_->data()); + } + return rv; + } + + int total_bytes_read = 0; + do { + rv = SSL_read(ssl_, user_read_buf_->data() + total_bytes_read, + user_read_buf_len_ - total_bytes_read); + if (rv > 0) + total_bytes_read += rv; + } while (total_bytes_read < user_read_buf_len_ && rv > 0); + + if (total_bytes_read == user_read_buf_len_) { + rv = total_bytes_read; + } else { + // Otherwise, an error occurred (rv <= 0). The error needs to be handled + // immediately, while the OpenSSL errors are still available in + // thread-local storage. However, the handled/remapped error code should + // only be returned if no application data was already read; if it was, the + // error code should be deferred until the next call of DoPayloadRead. + // + // If no data was read, |*next_result| will point to the return value of + // this function. If at least some data was read, |*next_result| will point + // to |pending_read_error_|, to be returned in a future call to + // DoPayloadRead() (e.g.: after the current data is handled). + int *next_result = &rv; + if (total_bytes_read > 0) { + pending_read_error_ = rv; + rv = total_bytes_read; + next_result = &pending_read_error_; + } + + if (client_auth_cert_needed_) { + *next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; + } else if (*next_result < 0) { + int err = SSL_get_error(ssl_, *next_result); + *next_result = MapOpenSSLError(err, err_tracer); + if (rv > 0 && *next_result == ERR_IO_PENDING) { + // If at least some data was read from SSL_read(), do not treat + // insufficient data as an error to return in the next call to + // DoPayloadRead() - instead, let the call fall through to check + // SSL_read() again. This is because DoTransportIO() may complete + // in between the next call to DoPayloadRead(), and thus it is + // important to check SSL_read() on subsequent invocations to see + // if a complete record may now be read. + *next_result = kNoPendingReadResult; + } + } + } + + if (rv >= 0) { + net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, + user_read_buf_->data()); + } + return rv; +} + +int SSLClientSocketOpenSSL::DoPayloadWrite() { + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_); + + if (rv >= 0) { + net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, + user_write_buf_->data()); + return rv; + } + + int err = SSL_get_error(ssl_, rv); + return MapOpenSSLError(err, err_tracer); +} + +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h new file mode 100644 index 00000000000..f66d95cc69d --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -0,0 +1,203 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_OPENSSL_H_ +#define NET_SOCKET_SSL_CLIENT_SOCKET_OPENSSL_H_ + +#include <string> + +#include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/io_buffer.h" +#include "net/cert/cert_verify_result.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/ssl_client_socket.h" +#include "net/ssl/ssl_config_service.h" + +// Avoid including misc OpenSSL headers, i.e.: +// <openssl/bio.h> +typedef struct bio_st BIO; +// <openssl/evp.h> +typedef struct evp_pkey_st EVP_PKEY; +// <openssl/ssl.h> +typedef struct ssl_st SSL; +// <openssl/x509.h> +typedef struct x509_st X509; + +namespace net { + +class CertVerifier; +class SingleRequestCertVerifier; +class SSLCertRequestInfo; +class SSLInfo; + +// An SSL client socket implemented with OpenSSL. +class SSLClientSocketOpenSSL : public SSLClientSocket { + public: + // Takes ownership of the transport_socket, which may already be connected. + // The given hostname will be compared with the name(s) in the server's + // certificate during the SSL handshake. ssl_config specifies the SSL + // settings. + SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context); + virtual ~SSLClientSocketOpenSSL(); + + const HostPortPair& host_and_port() const { return host_and_port_; } + const std::string& ssl_session_cache_shard() const { + return ssl_session_cache_shard_; + } + + // Callback from the SSL layer that indicates the remote server is requesting + // a certificate for this client. + int ClientCertRequestCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey); + + // Callback from the SSL layer to check which NPN protocol we are supporting + int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen, + const unsigned char* in, unsigned int inlen); + + // SSLClientSocket implementation. + virtual void GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) OVERRIDE; + virtual NextProtoStatus GetNextProto(std::string* proto, + std::string* server_protos) OVERRIDE; + virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + + // SSLSocket implementation. + virtual int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) OVERRIDE; + virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + private: + bool Init(); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + bool DoTransportIO(); + int DoHandshake(); + int DoVerifyCert(int result); + int DoVerifyCertComplete(int result); + void DoConnectCallback(int result); + X509Certificate* UpdateServerCert(); + + void OnHandshakeIOComplete(int result); + void OnSendComplete(int result); + void OnRecvComplete(int result); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + int DoPayloadRead(); + int DoPayloadWrite(); + + int BufferSend(); + int BufferRecv(); + void BufferSendComplete(int result); + void BufferRecvComplete(int result); + void TransportWriteComplete(int result); + void TransportReadComplete(int result); + + bool transport_send_busy_; + bool transport_recv_busy_; + bool transport_recv_eof_; + + scoped_refptr<DrainableIOBuffer> send_buffer_; + scoped_refptr<IOBuffer> recv_buffer_; + + CompletionCallback user_connect_callback_; + CompletionCallback user_read_callback_; + CompletionCallback user_write_callback_; + + base::WeakPtrFactory<SSLClientSocketOpenSSL> weak_factory_; + + // Used by Read function. + scoped_refptr<IOBuffer> user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr<IOBuffer> user_write_buf_; + int user_write_buf_len_; + + // Used by DoPayloadRead() when attempting to fill the caller's buffer with + // as much data as possible without blocking. + // If DoPayloadRead() encounters an error after having read some data, stores + // the result to return on the *next* call to DoPayloadRead(). A value > 0 + // indicates there is no pending result, otherwise 0 indicates EOF and < 0 + // indicates an error. + int pending_read_error_; + + // Set when handshake finishes. + scoped_refptr<X509Certificate> server_cert_; + CertVerifyResult server_cert_verify_result_; + bool completed_handshake_; + + // Stores client authentication information between ClientAuthHandler and + // GetSSLCertRequestInfo calls. + bool client_auth_cert_needed_; + // List of DER-encoded X.509 DistinguishedName of certificate authorities + // allowed by the server. + std::vector<std::string> cert_authorities_; + + CertVerifier* const cert_verifier_; + scoped_ptr<SingleRequestCertVerifier> verifier_; + + // OpenSSL stuff + SSL* ssl_; + BIO* transport_bio_; + + scoped_ptr<ClientSocketHandle> transport_; + const HostPortPair host_and_port_; + SSLConfig ssl_config_; + // ssl_session_cache_shard_ is an opaque string that partitions the SSL + // session cache. i.e. sessions created with one value will not attempt to + // resume on the socket with a different value. + const std::string ssl_session_cache_shard_; + + // Used for session cache diagnostics. + bool trying_cached_session_; + + enum State { + STATE_NONE, + STATE_HANDSHAKE, + STATE_VERIFY_CERT, + STATE_VERIFY_CERT_COMPLETE, + }; + State next_handshake_state_; + NextProtoStatus npn_status_; + std::string npn_proto_; + std::string server_protos_; + BoundNetLog net_log_; +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_CLIENT_SOCKET_OPENSSL_H_ diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc new file mode 100644 index 00000000000..04f899903ac --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc @@ -0,0 +1,279 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_client_socket.h" + +#include <errno.h> +#include <string.h> + +#include <openssl/bio.h> +#include <openssl/bn.h> +#include <openssl/evp.h> +#include <openssl/pem.h> +#include <openssl/rsa.h> + +#include "base/file_util.h" +#include "base/files/file_path.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_handle.h" +#include "base/values.h" +#include "crypto/openssl_util.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_log_unittest.h" +#include "net/base/test_completion_callback.h" +#include "net/base/test_data_directory.h" +#include "net/cert/mock_cert_verifier.h" +#include "net/cert/test_root_certs.h" +#include "net/dns/host_resolver.h" +#include "net/http/transport_security_state.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/tcp_client_socket.h" +#include "net/ssl/openssl_client_key_store.h" +#include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_config_service.h" +#include "net/test/cert_test_util.h" +#include "net/test/spawned_test_server/spawned_test_server.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { + +typedef OpenSSLClientKeyStore::ScopedEVP_PKEY ScopedEVP_PKEY; + +// BIO_free is a macro, it can't be used as a template parameter. +void BIO_free_func(BIO* bio) { + BIO_free(bio); +} + +typedef crypto::ScopedOpenSSL<BIO, BIO_free_func> ScopedBIO; +typedef crypto::ScopedOpenSSL<RSA, RSA_free> ScopedRSA; +typedef crypto::ScopedOpenSSL<BIGNUM, BN_free> ScopedBIGNUM; + +const SSLConfig kDefaultSSLConfig; + +// Loads a PEM-encoded private key file into a scoped EVP_PKEY object. +// |filepath| is the private key file path. +// |*pkey| is reset to the new EVP_PKEY on success, untouched otherwise. +// Returns true on success, false on failure. +bool LoadPrivateKeyOpenSSL( + const base::FilePath& filepath, + OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) { + std::string data; + if (!file_util::ReadFileToString(filepath, &data)) { + LOG(ERROR) << "Could not read private key file: " + << filepath.value() << ": " << strerror(errno); + return false; + } + ScopedBIO bio( + BIO_new_mem_buf( + const_cast<char*>(reinterpret_cast<const char*>(data.data())), + static_cast<int>(data.size()))); + if (!bio.get()) { + LOG(ERROR) << "Could not allocate BIO for buffer?"; + return false; + } + EVP_PKEY* result = PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL); + if (result == NULL) { + LOG(ERROR) << "Could not decode private key file: " + << filepath.value(); + return false; + } + pkey->reset(result); + return true; +} + +class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { + public: + SSLClientSocketOpenSSLClientAuthTest() + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new net::MockCertVerifier), + transport_security_state_(new net::TransportSecurityState) { + cert_verifier_->set_default_result(net::OK); + context_.cert_verifier = cert_verifier_.get(); + context_.transport_security_state = transport_security_state_.get(); + key_store_ = net::OpenSSLClientKeyStore::GetInstance(); + } + + virtual ~SSLClientSocketOpenSSLClientAuthTest() { + key_store_->Flush(); + } + + protected: + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config) { + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket(connection.Pass(), + host_and_port, + ssl_config, + context_); + } + + // Connect to a HTTPS test server. + bool ConnectToTestServer(SpawnedTestServer::SSLOptions& ssl_options) { + test_server_.reset(new SpawnedTestServer(SpawnedTestServer::TYPE_HTTPS, + ssl_options, + base::FilePath())); + if (!test_server_->Start()) { + LOG(ERROR) << "Could not start SpawnedTestServer"; + return false; + } + + if (!test_server_->GetAddressList(&addr_)) { + LOG(ERROR) << "Could not get SpawnedTestServer address list"; + return false; + } + + transport_.reset(new TCPClientSocket( + addr_, &log_, NetLog::Source())); + int rv = callback_.GetResult( + transport_->Connect(callback_.callback())); + if (rv != OK) { + LOG(ERROR) << "Could not connect to SpawnedTestServer"; + return false; + } + return true; + } + + // Record a certificate's private key to ensure it can be used + // by the OpenSSL-based SSLClientSocket implementation. + // |ssl_config| provides a client certificate. + // |private_key| must be an EVP_PKEY for the corresponding private key. + // Returns true on success, false on failure. + bool RecordPrivateKey(SSLConfig& ssl_config, + EVP_PKEY* private_key) { + return key_store_->RecordClientCertPrivateKey( + ssl_config.client_cert.get(), private_key); + } + + // Create an SSLClientSocket object and use it to connect to a test + // server, then wait for connection results. This must be called after + // a succesful ConnectToTestServer() call. + // |ssl_config| the SSL configuration to use. + // |result| will retrieve the ::Connect() result value. + // Returns true on succes, false otherwise. Success means that the socket + // could be created and its Connect() was called, not that the connection + // itself was a success. + bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config, + int* result) { + sock_ = CreateSSLClientSocket(transport_.Pass(), + test_server_->host_port_pair(), + ssl_config); + + if (sock_->IsConnected()) { + LOG(ERROR) << "SSL Socket prematurely connected"; + return false; + } + + *result = callback_.GetResult(sock_->Connect(callback_.callback())); + return true; + } + + + // Check that the client certificate was sent. + // Returns true on success. + bool CheckSSLClientSocketSentCert() { + SSLInfo ssl_info; + sock_->GetSSLInfo(&ssl_info); + return ssl_info.client_cert_sent; + } + + ClientSocketFactory* socket_factory_; + scoped_ptr<MockCertVerifier> cert_verifier_; + scoped_ptr<TransportSecurityState> transport_security_state_; + SSLClientSocketContext context_; + OpenSSLClientKeyStore* key_store_; + scoped_ptr<SpawnedTestServer> test_server_; + AddressList addr_; + TestCompletionCallback callback_; + CapturingNetLog log_; + scoped_ptr<StreamSocket> transport_; + scoped_ptr<SSLClientSocket> sock_; +}; + +// Connect to a server requesting client authentication, do not send +// any client certificates. It should refuse the connection. +TEST_F(SSLClientSocketOpenSSLClientAuthTest, NoCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + + ASSERT_TRUE(ConnectToTestServer(ssl_options)); + + base::FilePath certs_dir = GetTestCertsDirectory(); + SSLConfig ssl_config = kDefaultSSLConfig; + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); + EXPECT_FALSE(sock_->IsConnected()); +} + +// Connect to a server requesting client authentication, and send it +// an empty certificate. It should refuse the connection. +TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendEmptyCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + ssl_options.client_authorities.push_back( + GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem")); + + ASSERT_TRUE(ConnectToTestServer(ssl_options)); + + base::FilePath certs_dir = GetTestCertsDirectory(); + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.send_client_cert = true; + ssl_config.client_cert = NULL; + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock_->IsConnected()); +} + +// Connect to a server requesting client authentication. Send it a +// matching certificate. It should allow the connection. +TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + ssl_options.client_authorities.push_back( + GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem")); + + ASSERT_TRUE(ConnectToTestServer(ssl_options)); + + base::FilePath certs_dir = GetTestCertsDirectory(); + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.send_client_cert = true; + ssl_config.client_cert = ImportCertFromFile(certs_dir, "client_1.pem"); + + // This is required to ensure that signing works with the client + // certificate's private key. + OpenSSLClientKeyStore::ScopedEVP_PKEY client_private_key; + ASSERT_TRUE(LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"), + &client_private_key)); + EXPECT_TRUE(RecordPrivateKey(ssl_config, client_private_key.get())); + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock_->IsConnected()); + + EXPECT_TRUE(CheckSSLClientSocketSentCert()); + + sock_->Disconnect(); + EXPECT_FALSE(sock_->IsConnected()); +} + +} // namespace +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc new file mode 100644 index 00000000000..d07c76ffb49 --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -0,0 +1,664 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_client_socket_pool.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/metrics/field_trial.h" +#include "base/metrics/histogram.h" +#include "base/metrics/sparse_histogram.h" +#include "base/values.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_errors.h" +#include "net/http/http_proxy_client_socket.h" +#include "net/http/http_proxy_client_socket_pool.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socks_client_socket_pool.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/transport_client_socket_pool.h" +#include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_info.h" + +namespace net { + +SSLSocketParams::SSLSocketParams( + const scoped_refptr<TransportSocketParams>& transport_params, + const scoped_refptr<SOCKSSocketParams>& socks_params, + const scoped_refptr<HttpProxySocketParams>& http_proxy_params, + ProxyServer::Scheme proxy, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + PrivacyMode privacy_mode, + int load_flags, + bool force_spdy_over_ssl, + bool want_spdy_over_npn) + : transport_params_(transport_params), + http_proxy_params_(http_proxy_params), + socks_params_(socks_params), + proxy_(proxy), + host_and_port_(host_and_port), + ssl_config_(ssl_config), + privacy_mode_(privacy_mode), + load_flags_(load_flags), + force_spdy_over_ssl_(force_spdy_over_ssl), + want_spdy_over_npn_(want_spdy_over_npn), + ignore_limits_(false) { + switch (proxy_) { + case ProxyServer::SCHEME_DIRECT: + DCHECK(transport_params_.get() != NULL); + DCHECK(http_proxy_params_.get() == NULL); + DCHECK(socks_params_.get() == NULL); + ignore_limits_ = transport_params_->ignore_limits(); + break; + case ProxyServer::SCHEME_HTTP: + case ProxyServer::SCHEME_HTTPS: + DCHECK(transport_params_.get() == NULL); + DCHECK(http_proxy_params_.get() != NULL); + DCHECK(socks_params_.get() == NULL); + ignore_limits_ = http_proxy_params_->ignore_limits(); + break; + case ProxyServer::SCHEME_SOCKS4: + case ProxyServer::SCHEME_SOCKS5: + DCHECK(transport_params_.get() == NULL); + DCHECK(http_proxy_params_.get() == NULL); + DCHECK(socks_params_.get() != NULL); + ignore_limits_ = socks_params_->ignore_limits(); + break; + default: + LOG(DFATAL) << "unknown proxy type"; + break; + } +} + +SSLSocketParams::~SSLSocketParams() {} + +// Timeout for the SSL handshake portion of the connect. +static const int kSSLHandshakeTimeoutInSeconds = 30; + +SSLConnectJob::SSLConnectJob(const std::string& group_name, + const scoped_refptr<SSLSocketParams>& params, + const base::TimeDelta& timeout_duration, + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + const SSLClientSocketContext& context, + Delegate* delegate, + NetLog* net_log) + : ConnectJob(group_name, + timeout_duration, + delegate, + BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), + params_(params), + transport_pool_(transport_pool), + socks_pool_(socks_pool), + http_proxy_pool_(http_proxy_pool), + client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver), + context_(context.cert_verifier, + context.server_bound_cert_service, + context.transport_security_state, + (params->privacy_mode() == kPrivacyModeEnabled + ? "pm/" + context.ssl_session_cache_shard + : context.ssl_session_cache_shard)), + callback_(base::Bind(&SSLConnectJob::OnIOComplete, + base::Unretained(this))) {} + +SSLConnectJob::~SSLConnectJob() {} + +LoadState SSLConnectJob::GetLoadState() const { + switch (next_state_) { + case STATE_TUNNEL_CONNECT_COMPLETE: + if (transport_socket_handle_->socket()) + return LOAD_STATE_ESTABLISHING_PROXY_TUNNEL; + // else, fall through. + case STATE_TRANSPORT_CONNECT: + case STATE_TRANSPORT_CONNECT_COMPLETE: + case STATE_SOCKS_CONNECT: + case STATE_SOCKS_CONNECT_COMPLETE: + case STATE_TUNNEL_CONNECT: + return transport_socket_handle_->GetLoadState(); + case STATE_SSL_CONNECT: + case STATE_SSL_CONNECT_COMPLETE: + return LOAD_STATE_SSL_HANDSHAKE; + default: + NOTREACHED(); + return LOAD_STATE_IDLE; + } +} + +void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle* handle) { + // Headers in |error_response_info_| indicate a proxy tunnel setup + // problem. See DoTunnelConnectComplete. + if (error_response_info_.headers.get()) { + handle->set_pending_http_proxy_connection( + transport_socket_handle_.release()); + } + handle->set_ssl_error_response_info(error_response_info_); + if (!connect_timing_.ssl_start.is_null()) + handle->set_is_ssl_error(true); +} + +void SSLConnectJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + NotifyDelegateOfCompletion(rv); // Deletes |this|. +} + +int SSLConnectJob::DoLoop(int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = DoTransportConnectComplete(rv); + break; + case STATE_SOCKS_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoSOCKSConnect(); + break; + case STATE_SOCKS_CONNECT_COMPLETE: + rv = DoSOCKSConnectComplete(rv); + break; + case STATE_TUNNEL_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTunnelConnect(); + break; + case STATE_TUNNEL_CONNECT_COMPLETE: + rv = DoTunnelConnectComplete(rv); + break; + case STATE_SSL_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoSSLConnect(); + break; + case STATE_SSL_CONNECT_COMPLETE: + rv = DoSSLConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + return rv; +} + +int SSLConnectJob::DoTransportConnect() { + DCHECK(transport_pool_); + + next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; + transport_socket_handle_.reset(new ClientSocketHandle()); + scoped_refptr<TransportSocketParams> transport_params = + params_->transport_params(); + return transport_socket_handle_->Init( + group_name(), transport_params, + transport_params->destination().priority(), callback_, transport_pool_, + net_log()); +} + +int SSLConnectJob::DoTransportConnectComplete(int result) { + if (result == OK) + next_state_ = STATE_SSL_CONNECT; + + return result; +} + +int SSLConnectJob::DoSOCKSConnect() { + DCHECK(socks_pool_); + next_state_ = STATE_SOCKS_CONNECT_COMPLETE; + transport_socket_handle_.reset(new ClientSocketHandle()); + scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params(); + return transport_socket_handle_->Init( + group_name(), socks_params, socks_params->destination().priority(), + callback_, socks_pool_, net_log()); +} + +int SSLConnectJob::DoSOCKSConnectComplete(int result) { + if (result == OK) + next_state_ = STATE_SSL_CONNECT; + + return result; +} + +int SSLConnectJob::DoTunnelConnect() { + DCHECK(http_proxy_pool_); + next_state_ = STATE_TUNNEL_CONNECT_COMPLETE; + + transport_socket_handle_.reset(new ClientSocketHandle()); + scoped_refptr<HttpProxySocketParams> http_proxy_params = + params_->http_proxy_params(); + return transport_socket_handle_->Init( + group_name(), http_proxy_params, + http_proxy_params->destination().priority(), callback_, http_proxy_pool_, + net_log()); +} + +int SSLConnectJob::DoTunnelConnectComplete(int result) { + // Extract the information needed to prompt for appropriate proxy + // authentication so that when ClientSocketPoolBaseHelper calls + // |GetAdditionalErrorState|, we can easily set the state. + if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { + error_response_info_ = transport_socket_handle_->ssl_error_response_info(); + } else if (result == ERR_PROXY_AUTH_REQUESTED || + result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) { + StreamSocket* socket = transport_socket_handle_->socket(); + HttpProxyClientSocket* tunnel_socket = + static_cast<HttpProxyClientSocket*>(socket); + error_response_info_ = *tunnel_socket->GetConnectResponseInfo(); + } + if (result < 0) + return result; + + next_state_ = STATE_SSL_CONNECT; + return result; +} + +int SSLConnectJob::DoSSLConnect() { + next_state_ = STATE_SSL_CONNECT_COMPLETE; + // Reset the timeout to just the time allowed for the SSL handshake. + ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds)); + + // If the handle has a fresh socket, get its connect start and DNS times. + // This should always be the case. + const LoadTimingInfo::ConnectTiming& socket_connect_timing = + transport_socket_handle_->connect_timing(); + if (!transport_socket_handle_->is_reused() && + !socket_connect_timing.connect_start.is_null()) { + // Overwriting |connect_start| serves two purposes - it adjusts timing so + // |connect_start| doesn't include dns times, and it adjusts the time so + // as not to include time spent waiting for an idle socket. + connect_timing_.connect_start = socket_connect_timing.connect_start; + connect_timing_.dns_start = socket_connect_timing.dns_start; + connect_timing_.dns_end = socket_connect_timing.dns_end; + } + + connect_timing_.ssl_start = base::TimeTicks::Now(); + + ssl_socket_ = client_socket_factory_->CreateSSLClientSocket( + transport_socket_handle_.Pass(), + params_->host_and_port(), + params_->ssl_config(), + context_); + return ssl_socket_->Connect(callback_); +} + +int SSLConnectJob::DoSSLConnectComplete(int result) { + connect_timing_.ssl_end = base::TimeTicks::Now(); + + SSLClientSocket::NextProtoStatus status = + SSLClientSocket::kNextProtoUnsupported; + std::string proto; + std::string server_protos; + // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket + // that hasn't had SSL_ImportFD called on it. If we get a certificate error + // here, then we know that we called SSL_ImportFD. + if (result == OK || IsCertificateError(result)) + status = ssl_socket_->GetNextProto(&proto, &server_protos); + + // If we want spdy over npn, make sure it succeeded. + if (status == SSLClientSocket::kNextProtoNegotiated) { + ssl_socket_->set_was_npn_negotiated(true); + NextProto protocol_negotiated = + SSLClientSocket::NextProtoFromString(proto); + ssl_socket_->set_protocol_negotiated(protocol_negotiated); + // If we negotiated a SPDY version, it must have been present in + // SSLConfig::next_protos. + // TODO(mbelshe): Verify this. + if (protocol_negotiated >= kProtoSPDYMinimumVersion && + protocol_negotiated <= kProtoSPDYMaximumVersion) { + ssl_socket_->set_was_spdy_negotiated(true); + } + } + if (params_->want_spdy_over_npn() && !ssl_socket_->was_spdy_negotiated()) + return ERR_NPN_NEGOTIATION_FAILED; + + // Spdy might be turned on by default, or it might be over npn. + bool using_spdy = params_->force_spdy_over_ssl() || + params_->want_spdy_over_npn(); + + if (result == OK || + ssl_socket_->IgnoreCertError(result, params_->load_flags())) { + DCHECK(!connect_timing_.ssl_start.is_null()); + base::TimeDelta connect_duration = + connect_timing_.ssl_end - connect_timing_.ssl_start; + if (using_spdy) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SpdyConnectionLatency_2", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + } +#if defined(SPDY_PROXY_AUTH_ORIGIN) + bool using_data_reduction_proxy = params_->host_and_port().Equals( + HostPortPair::FromURL(GURL(SPDY_PROXY_AUTH_ORIGIN))); + if (using_data_reduction_proxy) { + UMA_HISTOGRAM_CUSTOM_TIMES( + "Net.SSL_Connection_Latency_DataReductionProxy", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + } +#endif + + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_2", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + + SSLInfo ssl_info; + ssl_socket_->GetSSLInfo(&ssl_info); + + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSL_CipherSuite", + SSLConnectionStatusToCipherSuite( + ssl_info.connection_status)); + + if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Resume_Handshake", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + } else if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_FULL) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Full_Handshake", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + } + + const std::string& host = params_->host_and_port().host(); + bool is_google = host == "google.com" || + (host.size() > 11 && + host.rfind(".google.com") == host.size() - 11); + if (is_google) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google2", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google_" + "Resume_Handshake", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + } else if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_FULL) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google_" + "Full_Handshake", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(1), + 100); + } + } + } + + if (result == OK || IsCertificateError(result)) { + SetSocket(ssl_socket_.PassAs<StreamSocket>()); + } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { + error_response_info_.cert_request_info = new SSLCertRequestInfo; + ssl_socket_->GetSSLCertRequestInfo( + error_response_info_.cert_request_info.get()); + } + + return result; +} + +int SSLConnectJob::ConnectInternal() { + switch (params_->proxy()) { + case ProxyServer::SCHEME_DIRECT: + next_state_ = STATE_TRANSPORT_CONNECT; + break; + case ProxyServer::SCHEME_HTTP: + case ProxyServer::SCHEME_HTTPS: + next_state_ = STATE_TUNNEL_CONNECT; + break; + case ProxyServer::SCHEME_SOCKS4: + case ProxyServer::SCHEME_SOCKS5: + next_state_ = STATE_SOCKS_CONNECT; + break; + default: + NOTREACHED() << "unknown proxy type"; + break; + } + return DoLoop(OK); +} + +SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + const SSLClientSocketContext& context, + NetLog* net_log) + : transport_pool_(transport_pool), + socks_pool_(socks_pool), + http_proxy_pool_(http_proxy_pool), + client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver), + context_(context), + net_log_(net_log) { + base::TimeDelta max_transport_timeout = base::TimeDelta(); + base::TimeDelta pool_timeout; + if (transport_pool_) + max_transport_timeout = transport_pool_->ConnectionTimeout(); + if (socks_pool_) { + pool_timeout = socks_pool_->ConnectionTimeout(); + if (pool_timeout > max_transport_timeout) + max_transport_timeout = pool_timeout; + } + if (http_proxy_pool_) { + pool_timeout = http_proxy_pool_->ConnectionTimeout(); + if (pool_timeout > max_transport_timeout) + max_transport_timeout = pool_timeout; + } + timeout_ = max_transport_timeout + + base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds); +} + +SSLClientSocketPool::SSLClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + CertVerifier* cert_verifier, + ServerBoundCertService* server_bound_cert_service, + TransportSecurityState* transport_security_state, + const std::string& ssl_session_cache_shard, + ClientSocketFactory* client_socket_factory, + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + SSLConfigService* ssl_config_service, + NetLog* net_log) + : transport_pool_(transport_pool), + socks_pool_(socks_pool), + http_proxy_pool_(http_proxy_pool), + base_(max_sockets, max_sockets_per_group, histograms, + ClientSocketPool::unused_idle_socket_timeout(), + ClientSocketPool::used_idle_socket_timeout(), + new SSLConnectJobFactory(transport_pool, + socks_pool, + http_proxy_pool, + client_socket_factory, + host_resolver, + SSLClientSocketContext( + cert_verifier, + server_bound_cert_service, + transport_security_state, + ssl_session_cache_shard), + net_log)), + ssl_config_service_(ssl_config_service) { + if (ssl_config_service_.get()) + ssl_config_service_->AddObserver(this); + if (transport_pool_) + transport_pool_->AddLayeredPool(this); + if (socks_pool_) + socks_pool_->AddLayeredPool(this); + if (http_proxy_pool_) + http_proxy_pool_->AddLayeredPool(this); +} + +SSLClientSocketPool::~SSLClientSocketPool() { + if (http_proxy_pool_) + http_proxy_pool_->RemoveLayeredPool(this); + if (socks_pool_) + socks_pool_->RemoveLayeredPool(this); + if (transport_pool_) + transport_pool_->RemoveLayeredPool(this); + if (ssl_config_service_.get()) + ssl_config_service_->RemoveObserver(this); +} + +scoped_ptr<ConnectJob> +SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return scoped_ptr<ConnectJob>( + new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), + transport_pool_, socks_pool_, http_proxy_pool_, + client_socket_factory_, host_resolver_, + context_, delegate, net_log_)); +} + +base::TimeDelta +SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() const { + return timeout_; +} + +int SSLClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { + const scoped_refptr<SSLSocketParams>* casted_socket_params = + static_cast<const scoped_refptr<SSLSocketParams>*>(socket_params); + + return base_.RequestSocket(group_name, *casted_socket_params, priority, + handle, callback, net_log); +} + +void SSLClientSocketPool::RequestSockets( + const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) { + const scoped_refptr<SSLSocketParams>* casted_params = + static_cast<const scoped_refptr<SSLSocketParams>*>(params); + + base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); +} + +void SSLClientSocketPool::CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) { + base_.CancelRequest(group_name, handle); +} + +void SSLClientSocketPool::ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); +} + +void SSLClientSocketPool::FlushWithError(int error) { + base_.FlushWithError(error); +} + +bool SSLClientSocketPool::IsStalled() const { + return base_.IsStalled() || + (transport_pool_ && transport_pool_->IsStalled()) || + (socks_pool_ && socks_pool_->IsStalled()) || + (http_proxy_pool_ && http_proxy_pool_->IsStalled()); +} + +void SSLClientSocketPool::CloseIdleSockets() { + base_.CloseIdleSockets(); +} + +int SSLClientSocketPool::IdleSocketCount() const { + return base_.idle_socket_count(); +} + +int SSLClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return base_.IdleSocketCountInGroup(group_name); +} + +LoadState SSLClientSocketPool::GetLoadState( + const std::string& group_name, const ClientSocketHandle* handle) const { + return base_.GetLoadState(group_name, handle); +} + +void SSLClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void SSLClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + +base::DictionaryValue* SSLClientSocketPool::GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const { + base::DictionaryValue* dict = base_.GetInfoAsValue(name, type); + if (include_nested_pools) { + base::ListValue* list = new base::ListValue(); + if (transport_pool_) { + list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", + "transport_socket_pool", + false)); + } + if (socks_pool_) { + list->Append(socks_pool_->GetInfoAsValue("socks_pool", + "socks_pool", + true)); + } + if (http_proxy_pool_) { + list->Append(http_proxy_pool_->GetInfoAsValue("http_proxy_pool", + "http_proxy_pool", + true)); + } + dict->Set("nested_pools", list); + } + return dict; +} + +base::TimeDelta SSLClientSocketPool::ConnectionTimeout() const { + return base_.ConnectionTimeout(); +} + +ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const { + return base_.histograms(); +} + +void SSLClientSocketPool::OnSSLConfigChanged() { + FlushWithError(ERR_NETWORK_CHANGED); +} + +bool SSLClientSocketPool::CloseOneIdleConnection() { + if (base_.CloseOneIdleSocket()) + return true; + return base_.CloseOneIdleConnectionInLayeredPool(); +} + +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h new file mode 100644 index 00000000000..431a1b7ceea --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -0,0 +1,297 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ + +#include <string> + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/privacy_mode.h" +#include "net/dns/host_resolver.h" +#include "net/http/http_response_info.h" +#include "net/proxy/proxy_server.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/ssl_client_socket.h" +#include "net/ssl/ssl_config_service.h" + +namespace net { + +class CertVerifier; +class ClientSocketFactory; +class ConnectJobFactory; +class HostPortPair; +class HttpProxyClientSocketPool; +class HttpProxySocketParams; +class SOCKSClientSocketPool; +class SOCKSSocketParams; +class SSLClientSocket; +class TransportClientSocketPool; +class TransportSecurityState; +class TransportSocketParams; + +// SSLSocketParams only needs the socket params for the transport socket +// that will be used (denoted by |proxy|). +class NET_EXPORT_PRIVATE SSLSocketParams + : public base::RefCounted<SSLSocketParams> { + public: + SSLSocketParams(const scoped_refptr<TransportSocketParams>& transport_params, + const scoped_refptr<SOCKSSocketParams>& socks_params, + const scoped_refptr<HttpProxySocketParams>& http_proxy_params, + ProxyServer::Scheme proxy, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + PrivacyMode privacy_mode, + int load_flags, + bool force_spdy_over_ssl, + bool want_spdy_over_npn); + + const scoped_refptr<TransportSocketParams>& transport_params() { + return transport_params_; + } + const scoped_refptr<HttpProxySocketParams>& http_proxy_params() { + return http_proxy_params_; + } + const scoped_refptr<SOCKSSocketParams>& socks_params() { + return socks_params_; + } + ProxyServer::Scheme proxy() const { return proxy_; } + const HostPortPair& host_and_port() const { return host_and_port_; } + const SSLConfig& ssl_config() const { return ssl_config_; } + PrivacyMode privacy_mode() const { return privacy_mode_; } + int load_flags() const { return load_flags_; } + bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; } + bool want_spdy_over_npn() const { return want_spdy_over_npn_; } + bool ignore_limits() const { return ignore_limits_; } + + private: + friend class base::RefCounted<SSLSocketParams>; + ~SSLSocketParams(); + + const scoped_refptr<TransportSocketParams> transport_params_; + const scoped_refptr<HttpProxySocketParams> http_proxy_params_; + const scoped_refptr<SOCKSSocketParams> socks_params_; + const ProxyServer::Scheme proxy_; + const HostPortPair host_and_port_; + const SSLConfig ssl_config_; + const PrivacyMode privacy_mode_; + const int load_flags_; + const bool force_spdy_over_ssl_; + const bool want_spdy_over_npn_; + bool ignore_limits_; + + DISALLOW_COPY_AND_ASSIGN(SSLSocketParams); +}; + +// SSLConnectJob handles the SSL handshake after setting up the underlying +// connection as specified in the params. +class SSLConnectJob : public ConnectJob { + public: + SSLConnectJob( + const std::string& group_name, + const scoped_refptr<SSLSocketParams>& params, + const base::TimeDelta& timeout_duration, + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + const SSLClientSocketContext& context, + Delegate* delegate, + NetLog* net_log); + virtual ~SSLConnectJob(); + + // ConnectJob methods. + virtual LoadState GetLoadState() const OVERRIDE; + + virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE; + + private: + enum State { + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_SOCKS_CONNECT, + STATE_SOCKS_CONNECT_COMPLETE, + STATE_TUNNEL_CONNECT, + STATE_TUNNEL_CONNECT_COMPLETE, + STATE_SSL_CONNECT, + STATE_SSL_CONNECT_COMPLETE, + STATE_NONE, + }; + + void OnIOComplete(int result); + + // Runs the state transition loop. + int DoLoop(int result); + + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + int DoSOCKSConnect(); + int DoSOCKSConnectComplete(int result); + int DoTunnelConnect(); + int DoTunnelConnectComplete(int result); + int DoSSLConnect(); + int DoSSLConnectComplete(int result); + + // Starts the SSL connection process. Returns OK on success and + // ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal() OVERRIDE; + + scoped_refptr<SSLSocketParams> params_; + TransportClientSocketPool* const transport_pool_; + SOCKSClientSocketPool* const socks_pool_; + HttpProxyClientSocketPool* const http_proxy_pool_; + ClientSocketFactory* const client_socket_factory_; + HostResolver* const host_resolver_; + + const SSLClientSocketContext context_; + + State next_state_; + CompletionCallback callback_; + scoped_ptr<ClientSocketHandle> transport_socket_handle_; + scoped_ptr<SSLClientSocket> ssl_socket_; + + HttpResponseInfo error_response_info_; + + DISALLOW_COPY_AND_ASSIGN(SSLConnectJob); +}; + +class NET_EXPORT_PRIVATE SSLClientSocketPool + : public ClientSocketPool, + public LayeredPool, + public SSLConfigService::Observer { + public: + // Only the pools that will be used are required. i.e. if you never + // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. + SSLClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + CertVerifier* cert_verifier, + ServerBoundCertService* server_bound_cert_service, + TransportSecurityState* transport_security_state, + const std::string& ssl_session_cache_shard, + ClientSocketFactory* client_socket_factory, + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + SSLConfigService* ssl_config_service, + NetLog* net_log); + + virtual ~SSLClientSocketPool(); + + // ClientSocketPool implementation. + virtual int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE; + + virtual void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) OVERRIDE; + + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE; + + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; + + virtual void FlushWithError(int error) OVERRIDE; + + virtual bool IsStalled() const OVERRIDE; + + virtual void CloseIdleSockets() OVERRIDE; + + virtual int IdleSocketCount() const OVERRIDE; + + virtual int IdleSocketCountInGroup( + const std::string& group_name) const OVERRIDE; + + virtual LoadState GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const OVERRIDE; + + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const OVERRIDE; + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + + virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + + // LayeredPool implementation. + virtual bool CloseOneIdleConnection() OVERRIDE; + + private: + typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; + + // SSLConfigService::Observer implementation. + + // When the user changes the SSL config, we flush all idle sockets so they + // won't get re-used. + virtual void OnSSLConfigChanged() OVERRIDE; + + class SSLConnectJobFactory : public PoolBase::ConnectJobFactory { + public: + SSLConnectJobFactory( + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + const SSLClientSocketContext& context, + NetLog* net_log); + + virtual ~SSLConnectJobFactory() {} + + // ClientSocketPoolBase::ConnectJobFactory methods. + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const OVERRIDE; + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + + private: + TransportClientSocketPool* const transport_pool_; + SOCKSClientSocketPool* const socks_pool_; + HttpProxyClientSocketPool* const http_proxy_pool_; + ClientSocketFactory* const client_socket_factory_; + HostResolver* const host_resolver_; + const SSLClientSocketContext context_; + base::TimeDelta timeout_; + NetLog* net_log_; + + DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory); + }; + + TransportClientSocketPool* const transport_pool_; + SOCKSClientSocketPool* const socks_pool_; + HttpProxyClientSocketPool* const http_proxy_pool_; + PoolBase base_; + const scoped_refptr<SSLConfigService> ssl_config_service_; + + DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); +}; + +REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams); + +} // namespace net + +#endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc new file mode 100644 index 00000000000..280f6e7af1a --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -0,0 +1,857 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/http/http_proxy_client_socket_pool.h" + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/strings/string_util.h" +#include "base/strings/utf_string_conversions.h" +#include "base/time/time.h" +#include "net/base/auth.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/cert/cert_verifier.h" +#include "net/dns/mock_host_resolver.h" +#include "net/http/http_auth_handler_factory.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/http/http_server_properties_impl.h" +#include "net/proxy/proxy_service.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/next_proto.h" +#include "net/socket/socket_test_util.h" +#include "net/spdy/spdy_session.h" +#include "net/spdy/spdy_session_pool.h" +#include "net/spdy/spdy_test_util_common.h" +#include "net/ssl/ssl_config_service_defaults.h" +#include "net/test/test_certificate_data.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; + +// Make sure |handle|'s load times are set correctly. DNS and connect start +// times comes from mock client sockets in these tests, so primarily serves to +// check those times were copied, and ssl times / connect end are set correctly. +void TestLoadTimingInfo(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + EXPECT_FALSE(load_timing_info.socket_reused); + // None of these tests use a NetLog. + EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasTimes( + load_timing_info.connect_timing, + CONNECT_TIMING_HAS_SSL_TIMES | CONNECT_TIMING_HAS_DNS_TIMES); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +// Just like TestLoadTimingInfo, except DNS times are expected to be null, for +// tests over proxies that do DNS lookups themselves. +void TestLoadTimingInfoNoDns(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + // None of these tests use a NetLog. + EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + EXPECT_FALSE(load_timing_info.socket_reused); + + ExpectConnectTimingHasTimes(load_timing_info.connect_timing, + CONNECT_TIMING_HAS_SSL_TIMES); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +class SSLClientSocketPoolTest + : public testing::Test, + public ::testing::WithParamInterface<NextProto> { + protected: + SSLClientSocketPoolTest() + : proxy_service_(ProxyService::CreateDirect()), + ssl_config_service_(new SSLConfigServiceDefaults), + http_auth_handler_factory_( + HttpAuthHandlerFactory::CreateDefault(&host_resolver_)), + session_(CreateNetworkSession()), + direct_transport_socket_params_( + new TransportSocketParams(HostPortPair("host", 443), + MEDIUM, + false, + false, + OnHostResolutionCallback())), + transport_histograms_("MockTCP"), + transport_socket_pool_(kMaxSockets, + kMaxSocketsPerGroup, + &transport_histograms_, + &socket_factory_), + proxy_transport_socket_params_( + new TransportSocketParams(HostPortPair("proxy", 443), + MEDIUM, + false, + false, + OnHostResolutionCallback())), + socks_socket_params_( + new SOCKSSocketParams(proxy_transport_socket_params_, + true, + HostPortPair("sockshost", 443), + MEDIUM)), + socks_histograms_("MockSOCKS"), + socks_socket_pool_(kMaxSockets, + kMaxSocketsPerGroup, + &socks_histograms_, + &transport_socket_pool_), + http_proxy_socket_params_( + new HttpProxySocketParams(proxy_transport_socket_params_, + NULL, + GURL("http://host"), + std::string(), + HostPortPair("host", 80), + session_->http_auth_cache(), + session_->http_auth_handler_factory(), + session_->spdy_session_pool(), + true)), + http_proxy_histograms_("MockHttpProxy"), + http_proxy_socket_pool_(kMaxSockets, + kMaxSocketsPerGroup, + &http_proxy_histograms_, + &host_resolver_, + &transport_socket_pool_, + NULL, + NULL) { + scoped_refptr<SSLConfigService> ssl_config_service( + new SSLConfigServiceDefaults); + ssl_config_service->GetSSLConfig(&ssl_config_); + } + + void CreatePool(bool transport_pool, bool http_proxy_pool, bool socks_pool) { + ssl_histograms_.reset(new ClientSocketPoolHistograms("SSLUnitTest")); + pool_.reset(new SSLClientSocketPool( + kMaxSockets, + kMaxSocketsPerGroup, + ssl_histograms_.get(), + NULL /* host_resolver */, + NULL /* cert_verifier */, + NULL /* server_bound_cert_service */, + NULL /* transport_security_state */, + std::string() /* ssl_session_cache_shard */, + &socket_factory_, + transport_pool ? &transport_socket_pool_ : NULL, + socks_pool ? &socks_socket_pool_ : NULL, + http_proxy_pool ? &http_proxy_socket_pool_ : NULL, + NULL, + NULL)); + } + + scoped_refptr<SSLSocketParams> SSLParams(ProxyServer::Scheme proxy, + bool want_spdy_over_npn) { + return make_scoped_refptr(new SSLSocketParams( + proxy == ProxyServer::SCHEME_DIRECT ? direct_transport_socket_params_ + : NULL, + proxy == ProxyServer::SCHEME_SOCKS5 ? socks_socket_params_ : NULL, + proxy == ProxyServer::SCHEME_HTTP ? http_proxy_socket_params_ : NULL, + proxy, + HostPortPair("host", 443), + ssl_config_, + kPrivacyModeDisabled, + 0, + false, + want_spdy_over_npn)); + } + + void AddAuthToCache() { + const base::string16 kFoo(ASCIIToUTF16("foo")); + const base::string16 kBar(ASCIIToUTF16("bar")); + session_->http_auth_cache()->Add(GURL("http://proxy:443/"), + "MyRealm1", + HttpAuth::AUTH_SCHEME_BASIC, + "Basic realm=MyRealm1", + AuthCredentials(kFoo, kBar), + "/"); + } + + HttpNetworkSession* CreateNetworkSession() { + HttpNetworkSession::Params params; + params.host_resolver = &host_resolver_; + params.cert_verifier = cert_verifier_.get(); + params.transport_security_state = transport_security_state_.get(); + params.proxy_service = proxy_service_.get(); + params.client_socket_factory = &socket_factory_; + params.ssl_config_service = ssl_config_service_.get(); + params.http_auth_handler_factory = http_auth_handler_factory_.get(); + params.http_server_properties = + http_server_properties_.GetWeakPtr(); + params.enable_spdy_compression = false; + params.spdy_default_protocol = GetParam(); + return new HttpNetworkSession(params); + } + + void TestIPPoolingDisabled(SSLSocketDataProvider* ssl); + + MockClientSocketFactory socket_factory_; + MockCachingHostResolver host_resolver_; + scoped_ptr<CertVerifier> cert_verifier_; + scoped_ptr<TransportSecurityState> transport_security_state_; + const scoped_ptr<ProxyService> proxy_service_; + const scoped_refptr<SSLConfigService> ssl_config_service_; + const scoped_ptr<HttpAuthHandlerFactory> http_auth_handler_factory_; + HttpServerPropertiesImpl http_server_properties_; + const scoped_refptr<HttpNetworkSession> session_; + + scoped_refptr<TransportSocketParams> direct_transport_socket_params_; + ClientSocketPoolHistograms transport_histograms_; + MockTransportClientSocketPool transport_socket_pool_; + + scoped_refptr<TransportSocketParams> proxy_transport_socket_params_; + + scoped_refptr<SOCKSSocketParams> socks_socket_params_; + ClientSocketPoolHistograms socks_histograms_; + MockSOCKSClientSocketPool socks_socket_pool_; + + scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_; + ClientSocketPoolHistograms http_proxy_histograms_; + HttpProxyClientSocketPool http_proxy_socket_pool_; + + SSLConfig ssl_config_; + scoped_ptr<ClientSocketPoolHistograms> ssl_histograms_; + scoped_ptr<SSLClientSocketPool> pool_; +}; + +INSTANTIATE_TEST_CASE_P( + NextProto, + SSLClientSocketPoolTest, + testing::Values(kProtoSPDY2, kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2, + kProtoHTTP2Draft04)); + +TEST_P(SSLClientSocketPoolTest, TCPFail) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + int rv = handle.Init("a", params, MEDIUM, CompletionCallback(), pool_.get(), + BoundNetLog()); + EXPECT_EQ(ERR_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, TCPFailAsync) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, BasicDirect) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); +} + +TEST_P(SSLClientSocketPoolTest, BasicDirectAsync) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); +} + +TEST_P(SSLClientSocketPoolTest, DirectCertError) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, ERR_CERT_COMMON_NAME_INVALID); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); +} + +TEST_P(SSLClientSocketPoolTest, DirectSSLError) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_TRUE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, DirectWithNPN) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.SetNextProto(kProtoHTTP11); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); + SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); +} + +TEST_P(SSLClientSocketPoolTest, DirectNoSPDY) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.SetNextProto(kProtoHTTP11); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + true); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_NPN_NEGOTIATION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_TRUE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, DirectGotSPDY) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.SetNextProto(GetParam()); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + true); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); + + SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); + std::string proto; + std::string server_protos; + ssl_socket->GetNextProto(&proto, &server_protos); + EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto)); +} + +TEST_P(SSLClientSocketPoolTest, DirectGotBonusSPDY) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.SetNextProto(GetParam()); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT, + true); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfo(handle); + + SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); + EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); + std::string proto; + std::string server_protos; + ssl_socket->GetNextProto(&proto, &server_protos); + EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto)); +} + +TEST_P(SSLClientSocketPoolTest, SOCKSFail) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, SOCKSFailAsync) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, SOCKSBasic) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + // SOCKS5 generally has no DNS times, but the mock SOCKS5 sockets used here + // don't go through the real logic, unlike in the HTTP proxy tests. + TestLoadTimingInfo(handle); +} + +TEST_P(SSLClientSocketPoolTest, SOCKSBasicAsync) { + StaticSocketDataProvider data; + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + // SOCKS5 generally has no DNS times, but the mock SOCKS5 sockets used here + // don't go through the real logic, unlike in the HTTP proxy tests. + TestLoadTimingInfo(handle); +} + +TEST_P(SSLClientSocketPoolTest, HttpProxyFail) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, HttpProxyFailAsync) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_FAILED)); + socket_factory_.AddSocketDataProvider(&data); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); +} + +TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead(SYNCHRONOUS, "HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + AddAuthToCache(); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoNoDns(handle); +} + +TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + socket_factory_.AddSocketDataProvider(&data); + AddAuthToCache(); + SSLSocketDataProvider ssl(ASYNC, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoNoDns(handle); +} + +TEST_P(SSLClientSocketPoolTest, NeedProxyAuth) { + MockWrite writes[] = { + MockWrite("CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + MockRead reads[] = { + MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), + MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), + MockRead("Content-Length: 10\r\n\r\n"), + MockRead("0123456789"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP, + false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init( + "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_FALSE(handle.is_ssl_error()); + const HttpResponseInfo& tunnel_info = handle.ssl_error_response_info(); + EXPECT_EQ(tunnel_info.headers->response_code(), 407); + scoped_ptr<ClientSocketHandle> tunnel_handle( + handle.release_pending_http_proxy_connection()); + EXPECT_TRUE(tunnel_handle->socket()); + EXPECT_FALSE(tunnel_handle->socket()->IsConnected()); +} + +TEST_P(SSLClientSocketPoolTest, IPPooling) { + const int kTestPort = 80; + struct TestHosts { + std::string name; + std::string iplist; + SpdySessionKey key; + AddressList addresses; + } test_hosts[] = { + { "www.webkit.org", "192.0.2.33,192.168.0.1,192.168.0.5" }, + { "code.google.com", "192.168.0.2,192.168.0.3,192.168.0.5" }, + { "js.webkit.org", "192.168.0.4,192.168.0.1,192.0.2.33" }, + }; + + host_resolver_.set_synchronous_mode(true); + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) { + host_resolver_.rules()->AddIPLiteralRule( + test_hosts[i].name, test_hosts[i].iplist, std::string()); + + // This test requires that the HostResolver cache be populated. Normal + // code would have done this already, but we do it manually. + HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); + host_resolver_.Resolve(info, &test_hosts[i].addresses, CompletionCallback(), + NULL, BoundNetLog()); + + // Setup a SpdySessionKey + test_hosts[i].key = SpdySessionKey( + HostPortPair(test_hosts[i].name, kTestPort), ProxyServer::Direct(), + kPrivacyModeDisabled); + } + + MockRead reads[] = { + MockRead(ASYNC, ERR_IO_PENDING), + }; + StaticSocketDataProvider data(reads, arraysize(reads), NULL, 0); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.cert = X509Certificate::CreateFromBytes( + reinterpret_cast<const char*>(webkit_der), sizeof(webkit_der)); + ssl.SetNextProto(GetParam()); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(true /* tcp pool */, false, false); + base::WeakPtr<SpdySession> spdy_session = + CreateSecureSpdySession(session_, test_hosts[0].key, BoundNetLog()); + + EXPECT_TRUE( + HasSpdySession(session_->spdy_session_pool(), test_hosts[0].key)); + EXPECT_FALSE( + HasSpdySession(session_->spdy_session_pool(), test_hosts[1].key)); + EXPECT_TRUE( + HasSpdySession(session_->spdy_session_pool(), test_hosts[2].key)); + + session_->spdy_session_pool()->CloseAllSessions(); +} + +void SSLClientSocketPoolTest::TestIPPoolingDisabled( + SSLSocketDataProvider* ssl) { + const int kTestPort = 80; + struct TestHosts { + std::string name; + std::string iplist; + SpdySessionKey key; + AddressList addresses; + } test_hosts[] = { + { "www.webkit.org", "192.0.2.33,192.168.0.1,192.168.0.5" }, + { "js.webkit.com", "192.168.0.4,192.168.0.1,192.0.2.33" }, + }; + + TestCompletionCallback callback; + int rv; + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) { + host_resolver_.rules()->AddIPLiteralRule( + test_hosts[i].name, test_hosts[i].iplist, std::string()); + + // This test requires that the HostResolver cache be populated. Normal + // code would have done this already, but we do it manually. + HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); + rv = host_resolver_.Resolve(info, &test_hosts[i].addresses, + callback.callback(), NULL, BoundNetLog()); + EXPECT_EQ(OK, callback.GetResult(rv)); + + // Setup a SpdySessionKey + test_hosts[i].key = SpdySessionKey( + HostPortPair(test_hosts[i].name, kTestPort), ProxyServer::Direct(), + kPrivacyModeDisabled); + } + + MockRead reads[] = { + MockRead(ASYNC, ERR_IO_PENDING), + }; + StaticSocketDataProvider data(reads, arraysize(reads), NULL, 0); + socket_factory_.AddSocketDataProvider(&data); + socket_factory_.AddSSLSocketDataProvider(ssl); + + CreatePool(true /* tcp pool */, false, false); + base::WeakPtr<SpdySession> spdy_session = + CreateSecureSpdySession(session_, test_hosts[0].key, BoundNetLog()); + + EXPECT_TRUE( + HasSpdySession(session_->spdy_session_pool(), test_hosts[0].key)); + EXPECT_FALSE( + HasSpdySession(session_->spdy_session_pool(), test_hosts[1].key)); + + session_->spdy_session_pool()->CloseAllSessions(); +} + +// Verifies that an SSL connection with client authentication disables SPDY IP +// pooling. +TEST_P(SSLClientSocketPoolTest, IPPoolingClientCert) { + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.cert = X509Certificate::CreateFromBytes( + reinterpret_cast<const char*>(webkit_der), sizeof(webkit_der)); + ssl.client_cert_sent = true; + ssl.SetNextProto(GetParam()); + TestIPPoolingDisabled(&ssl); +} + +// Verifies that an SSL connection with channel ID disables SPDY IP pooling. +TEST_P(SSLClientSocketPoolTest, IPPoolingChannelID) { + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.channel_id_sent = true; + ssl.SetNextProto(GetParam()); + TestIPPoolingDisabled(&ssl); +} + +// It would be nice to also test the timeouts in SSLClientSocketPool. + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc new file mode 100644 index 00000000000..f791928580f --- /dev/null +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -0,0 +1,1798 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_client_socket.h" + +#include "base/callback_helpers.h" +#include "base/memory/ref_counted.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_log_unittest.h" +#include "net/base/test_completion_callback.h" +#include "net/base/test_data_directory.h" +#include "net/cert/mock_cert_verifier.h" +#include "net/cert/test_root_certs.h" +#include "net/dns/host_resolver.h" +#include "net/http/transport_security_state.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/tcp_client_socket.h" +#include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_config_service.h" +#include "net/test/cert_test_util.h" +#include "net/test/spawned_test_server/spawned_test_server.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +//----------------------------------------------------------------------------- + +namespace net { + +namespace { + +const SSLConfig kDefaultSSLConfig; + +// WrappedStreamSocket is a base class that wraps an existing StreamSocket, +// forwarding the Socket and StreamSocket interfaces to the underlying +// transport. +// This is to provide a common base class for subclasses to override specific +// StreamSocket methods for testing, while still communicating with a 'real' +// StreamSocket. +class WrappedStreamSocket : public StreamSocket { + public: + explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) + : transport_(transport.Pass()) {} + virtual ~WrappedStreamSocket() {} + + // StreamSocket implementation: + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return transport_->Connect(callback); + } + virtual void Disconnect() OVERRIDE { transport_->Disconnect(); } + virtual bool IsConnected() const OVERRIDE { + return transport_->IsConnected(); + } + virtual bool IsConnectedAndIdle() const OVERRIDE { + return transport_->IsConnectedAndIdle(); + } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return transport_->GetPeerAddress(address); + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + return transport_->GetLocalAddress(address); + } + virtual const BoundNetLog& NetLog() const OVERRIDE { + return transport_->NetLog(); + } + virtual void SetSubresourceSpeculation() OVERRIDE { + transport_->SetSubresourceSpeculation(); + } + virtual void SetOmniboxSpeculation() OVERRIDE { + transport_->SetOmniboxSpeculation(); + } + virtual bool WasEverUsed() const OVERRIDE { + return transport_->WasEverUsed(); + } + virtual bool UsingTCPFastOpen() const OVERRIDE { + return transport_->UsingTCPFastOpen(); + } + virtual bool WasNpnNegotiated() const OVERRIDE { + return transport_->WasNpnNegotiated(); + } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return transport_->GetNegotiatedProtocol(); + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + return transport_->GetSSLInfo(ssl_info); + } + + // Socket implementation: + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return transport_->Read(buf, buf_len, callback); + } + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return transport_->Write(buf, buf_len, callback); + } + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { + return transport_->SetReceiveBufferSize(size); + } + virtual bool SetSendBufferSize(int32 size) OVERRIDE { + return transport_->SetSendBufferSize(size); + } + + protected: + scoped_ptr<StreamSocket> transport_; +}; + +// ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that +// will ensure a certain amount of data is internally buffered before +// satisfying a Read() request. It exists to mimic OS-level internal +// buffering, but in a way to guarantee that X number of bytes will be +// returned to callers of Read(), regardless of how quickly the OS receives +// them from the TestServer. +class ReadBufferingStreamSocket : public WrappedStreamSocket { + public: + explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); + virtual ~ReadBufferingStreamSocket() {} + + // Socket implementation: + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + + // Sets the internal buffer to |size|. This must not be greater than + // the largest value supplied to Read() - that is, it does not handle + // having "leftovers" at the end of Read(). + // Each call to Read() will be prevented from completion until at least + // |size| data has been read. + // Set to 0 to turn off buffering, causing Read() to transparently + // read via the underlying transport. + void SetBufferSize(int size); + + private: + enum State { + STATE_NONE, + STATE_READ, + STATE_READ_COMPLETE, + }; + + int DoLoop(int result); + int DoRead(); + int DoReadComplete(int result); + void OnReadCompleted(int result); + + State state_; + scoped_refptr<GrowableIOBuffer> read_buffer_; + int buffer_size_; + + scoped_refptr<IOBuffer> user_read_buf_; + CompletionCallback user_read_callback_; +}; + +ReadBufferingStreamSocket::ReadBufferingStreamSocket( + scoped_ptr<StreamSocket> transport) + : WrappedStreamSocket(transport.Pass()), + read_buffer_(new GrowableIOBuffer()), + buffer_size_(0) {} + +void ReadBufferingStreamSocket::SetBufferSize(int size) { + DCHECK(!user_read_buf_.get()); + buffer_size_ = size; + read_buffer_->SetCapacity(size); +} + +int ReadBufferingStreamSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (buffer_size_ == 0) + return transport_->Read(buf, buf_len, callback); + + if (buf_len < buffer_size_) + return ERR_UNEXPECTED; + + state_ = STATE_READ; + user_read_buf_ = buf; + int result = DoLoop(OK); + if (result == ERR_IO_PENDING) + user_read_callback_ = callback; + else + user_read_buf_ = NULL; + return result; +} + +int ReadBufferingStreamSocket::DoLoop(int result) { + int rv = result; + do { + State current_state = state_; + state_ = STATE_NONE; + switch (current_state) { + case STATE_READ: + rv = DoRead(); + break; + case STATE_READ_COMPLETE: + rv = DoReadComplete(rv); + break; + case STATE_NONE: + default: + NOTREACHED() << "Unexpected state: " << current_state; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && state_ != STATE_NONE); + return rv; +} + +int ReadBufferingStreamSocket::DoRead() { + state_ = STATE_READ_COMPLETE; + int rv = + transport_->Read(read_buffer_.get(), + read_buffer_->RemainingCapacity(), + base::Bind(&ReadBufferingStreamSocket::OnReadCompleted, + base::Unretained(this))); + return rv; +} + +int ReadBufferingStreamSocket::DoReadComplete(int result) { + state_ = STATE_NONE; + if (result <= 0) + return result; + + read_buffer_->set_offset(read_buffer_->offset() + result); + if (read_buffer_->RemainingCapacity() > 0) { + state_ = STATE_READ; + return OK; + } + + memcpy(user_read_buf_->data(), + read_buffer_->StartOfBuffer(), + read_buffer_->capacity()); + read_buffer_->set_offset(0); + return read_buffer_->capacity(); +} + +void ReadBufferingStreamSocket::OnReadCompleted(int result) { + result = DoLoop(result); + if (result == ERR_IO_PENDING) + return; + + user_read_buf_ = NULL; + base::ResetAndReturn(&user_read_callback_).Run(result); +} + +// Simulates synchronously receiving an error during Read() or Write() +class SynchronousErrorStreamSocket : public WrappedStreamSocket { + public: + explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport); + virtual ~SynchronousErrorStreamSocket() {} + + // Socket implementation: + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + + // Sets the next Read() call and all future calls to return |error|. + // If there is already a pending asynchronous read, the configured error + // will not be returned until that asynchronous read has completed and Read() + // is called again. + void SetNextReadError(Error error) { + DCHECK_GE(0, error); + have_read_error_ = true; + pending_read_error_ = error; + } + + // Sets the next Write() call and all future calls to return |error|. + // If there is already a pending asynchronous write, the configured error + // will not be returned until that asynchronous write has completed and + // Write() is called again. + void SetNextWriteError(Error error) { + DCHECK_GE(0, error); + have_write_error_ = true; + pending_write_error_ = error; + } + + private: + bool have_read_error_; + int pending_read_error_; + + bool have_write_error_; + int pending_write_error_; + + DISALLOW_COPY_AND_ASSIGN(SynchronousErrorStreamSocket); +}; + +SynchronousErrorStreamSocket::SynchronousErrorStreamSocket( + scoped_ptr<StreamSocket> transport) + : WrappedStreamSocket(transport.Pass()), + have_read_error_(false), + pending_read_error_(OK), + have_write_error_(false), + pending_write_error_(OK) {} + +int SynchronousErrorStreamSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (have_read_error_) + return pending_read_error_; + return transport_->Read(buf, buf_len, callback); +} + +int SynchronousErrorStreamSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + if (have_write_error_) + return pending_write_error_; + return transport_->Write(buf, buf_len, callback); +} + +// FakeBlockingStreamSocket wraps an existing StreamSocket and simulates the +// underlying transport needing to complete things asynchronously in a +// deterministic manner (e.g.: independent of the TestServer and the OS's +// semantics). +class FakeBlockingStreamSocket : public WrappedStreamSocket { + public: + explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport); + virtual ~FakeBlockingStreamSocket() {} + + // Socket implementation: + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return read_state_.RunWrappedFunction(buf, buf_len, callback); + } + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { + return write_state_.RunWrappedFunction(buf, buf_len, callback); + } + + // Causes the next call to Read() to return ERR_IO_PENDING, not completing + // (invoking the callback) until UnblockRead() has been called and the + // underlying transport has completed. + void SetNextReadShouldBlock() { read_state_.SetShouldBlock(); } + void UnblockRead() { read_state_.Unblock(); } + + // Causes the next call to Write() to return ERR_IO_PENDING, not completing + // (invoking the callback) until UnblockWrite() has been called and the + // underlying transport has completed. + void SetNextWriteShouldBlock() { write_state_.SetShouldBlock(); } + void UnblockWrite() { write_state_.Unblock(); } + + private: + // Tracks the state for simulating a blocking Read/Write operation. + class BlockingState { + public: + // Wrapper for the underlying Socket function to call (ie: Read/Write). + typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)> + WrappedSocketFunction; + + explicit BlockingState(const WrappedSocketFunction& function); + ~BlockingState() {} + + // Sets the next call to RunWrappedFunction() to block, returning + // ERR_IO_PENDING and not invoking the user callback until Unblock() is + // called. + void SetShouldBlock(); + + // Unblocks the currently blocked pending function, invoking the user + // callback if the results are immediately available. + // Note: It's not valid to call this unless SetShouldBlock() has been + // called beforehand. + void Unblock(); + + // Performs the wrapped socket function on the underlying transport. If + // configured to block via SetShouldBlock(), then |user_callback| will not + // be invoked until Unblock() has been called. + int RunWrappedFunction(IOBuffer* buf, + int len, + const CompletionCallback& user_callback); + + private: + // Handles completion from the underlying wrapped socket function. + void OnCompleted(int result); + + WrappedSocketFunction wrapped_function_; + bool should_block_; + bool have_result_; + int pending_result_; + CompletionCallback user_callback_; + }; + + BlockingState read_state_; + BlockingState write_state_; + + DISALLOW_COPY_AND_ASSIGN(FakeBlockingStreamSocket); +}; + +FakeBlockingStreamSocket::FakeBlockingStreamSocket( + scoped_ptr<StreamSocket> transport) + : WrappedStreamSocket(transport.Pass()), + read_state_(base::Bind(&Socket::Read, + base::Unretained(transport_.get()))), + write_state_(base::Bind(&Socket::Write, + base::Unretained(transport_.get()))) {} + +FakeBlockingStreamSocket::BlockingState::BlockingState( + const WrappedSocketFunction& function) + : wrapped_function_(function), + should_block_(false), + have_result_(false), + pending_result_(OK) {} + +void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() { + DCHECK(!should_block_); + should_block_ = true; +} + +void FakeBlockingStreamSocket::BlockingState::Unblock() { + DCHECK(should_block_); + should_block_ = false; + + // If the operation is still pending in the underlying transport, immediately + // return - OnCompleted() will handle invoking the callback once the transport + // has completed. + if (!have_result_) + return; + + have_result_ = false; + + base::ResetAndReturn(&user_callback_).Run(pending_result_); +} + +int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction( + IOBuffer* buf, + int len, + const CompletionCallback& callback) { + + // The callback to be called by the underlying transport. Either forward + // directly to the user's callback if not set to block, or intercept it with + // OnCompleted so that the user's callback is not invoked until Unblock() is + // called. + CompletionCallback transport_callback = + !should_block_ ? callback : base::Bind(&BlockingState::OnCompleted, + base::Unretained(this)); + int rv = wrapped_function_.Run(buf, len, transport_callback); + if (should_block_) { + user_callback_ = callback; + // May have completed synchronously. + have_result_ = (rv != ERR_IO_PENDING); + pending_result_ = rv; + return ERR_IO_PENDING; + } + + return rv; +} + +void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) { + if (should_block_) { + // Store the result so that the callback can be invoked once Unblock() is + // called. + have_result_ = true; + pending_result_ = result; + return; + } + + // Otherwise, the Unblock() function was called before the underlying + // transport completed, so run the user's callback immediately. + base::ResetAndReturn(&user_callback_).Run(result); +} + +// CompletionCallback that will delete the associated StreamSocket when +// the callback is invoked. +class DeleteSocketCallback : public TestCompletionCallbackBase { + public: + explicit DeleteSocketCallback(StreamSocket* socket) + : socket_(socket), + callback_(base::Bind(&DeleteSocketCallback::OnComplete, + base::Unretained(this))) {} + virtual ~DeleteSocketCallback() {} + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + if (socket_) { + delete socket_; + socket_ = NULL; + } else { + ADD_FAILURE() << "Deleting socket twice"; + } + SetResult(result); + } + + StreamSocket* socket_; + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); +}; + +class SSLClientSocketTest : public PlatformTest { + public: + SSLClientSocketTest() + : socket_factory_(ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new MockCertVerifier), + transport_security_state_(new TransportSecurityState) { + cert_verifier_->set_default_result(OK); + context_.cert_verifier = cert_verifier_.get(); + context_.transport_security_state = transport_security_state_.get(); + } + + protected: + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config) { + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket( + connection.Pass(), host_and_port, ssl_config, context_); + } + + ClientSocketFactory* socket_factory_; + scoped_ptr<MockCertVerifier> cert_verifier_; + scoped_ptr<TransportSecurityState> transport_security_state_; + SSLClientSocketContext context_; +}; + +//----------------------------------------------------------------------------- + +// LogContainsSSLConnectEndEvent returns true if the given index in the given +// log is an SSL connect end event. The NSS sockets will cork in an attempt to +// merge the first application data record with the Finished message when false +// starting. However, in order to avoid the server timing out the handshake, +// they'll give up waiting for application data and send the Finished after a +// timeout. This means that an SSL connect end event may appear as a socket +// write. +static bool LogContainsSSLConnectEndEvent( + const CapturingNetLog::CapturedEntryList& log, + int i) { + return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) || + LogContainsEvent( + log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); +} +; + +TEST_F(SSLClientSocketTest, Connect) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + +TEST_F(SSLClientSocketTest, ConnectExpired) { + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_EXPIRED); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_EQ(ERR_CERT_DATE_INVALID, rv); + + // Rather than testing whether or not the underlying socket is connected, + // test that the handshake has finished. This is because it may be + // desirable to disconnect the socket before showing a user prompt, since + // the user may take indefinitely long to respond. + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); +} + +TEST_F(SSLClientSocketTest, ConnectMismatched) { + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv); + + // Rather than testing whether or not the underlying socket is connected, + // test that the handshake has finished. This is because it may be + // desirable to disconnect the socket before showing a user prompt, since + // the user may take indefinitely long to respond. + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); +} + +// Attempt to connect to a page which requests a client certificate. It should +// return an error code on connect. +TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + log.GetEntries(&entries); + // Because we prematurely kill the handshake at CertificateRequest, + // the server may still send data (notably the ServerHelloDone) + // after the error is returned. As a result, the SSL_CONNECT may not + // be the last entry. See http://crbug.com/54445. We use + // ExpectLogContainsSomewhere instead of + // LogContainsSSLConnectEndEvent to avoid assuming, e.g., only one + // extra read instead of two. This occurs before the handshake ends, + // so the corking logic of LogContainsSSLConnectEndEvent isn't + // necessary. + // + // TODO(davidben): When SSL_RestartHandshakeAfterCertReq in NSS is + // fixed and we can respond to the first CertificateRequest + // without closing the socket, add a unit test for sending the + // certificate. This test may still be useful as we'll want to close + // the socket on a timeout if the user takes a long time to pick a + // cert. Related bug: https://bugzilla.mozilla.org/show_bug.cgi?id=542832 + ExpectLogContainsSomewhere( + entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END); + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); + EXPECT_FALSE(sock->IsConnected()); +} + +// Connect to a server requesting optional client authentication. Send it a +// null certificate. It should allow the connection. +// +// TODO(davidben): Also test providing an actual certificate. +TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.send_client_cert = true; + ssl_config.client_cert = NULL; + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + EXPECT_FALSE(sock->IsConnected()); + + // Our test server accepts certificate-less connections. + // TODO(davidben): Add a test which requires them and verify the error. + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); + + // We responded to the server's certificate request with a Certificate + // message with no client certificate in it. ssl_info.client_cert_sent + // should be false in this case. + SSLInfo ssl_info; + sock->GetSSLInfo(&ssl_info); + EXPECT_FALSE(ssl_info.client_cert_sent); + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + +// TODO(wtc): Add unit tests for IsConnectedAndIdle: +// - Server closes an SSL connection (with a close_notify alert message). +// - Server closes the underlying TCP connection directly. +// - Server sends data unexpectedly. + +TEST_F(SSLClientSocketTest, Read) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); + + rv = sock->Write( + request_buffer.get(), arraysize(request_text) - 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); + + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + for (;;) { + rv = sock->Read(buf.get(), 4096, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_GE(rv, 0); + if (rv <= 0) + break; + } +} + +// Tests that the SSLClientSocket properly handles when the underlying transport +// synchronously returns an error code - such as if an intermediary terminates +// the socket connection uncleanly. +// This is a regression test for http://crbug.com/238536 +TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to avoid handshake non-determinism. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config)); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + static const int kRequestTextSize = + static_cast<int>(arraysize(request_text) - 1); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); + memcpy(request_buffer->data(), request_text, kRequestTextSize); + + rv = callback.GetResult( + sock->Write(request_buffer.get(), kRequestTextSize, callback.callback())); + EXPECT_EQ(kRequestTextSize, rv); + + // Simulate an unclean/forcible shutdown. + raw_transport->SetNextReadError(ERR_CONNECTION_RESET); + + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + + // Note: This test will hang if this bug has regressed. Simply checking that + // rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate + // result when using a dedicated task runner for NSS. + rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback())); + +#if !defined(USE_OPENSSL) + // SSLClientSocketNSS records the error exactly + EXPECT_EQ(ERR_CONNECTION_RESET, rv); +#else + // SSLClientSocketOpenSSL treats any errors as a simple EOF. + EXPECT_EQ(0, rv); +#endif +} + +// Tests that the SSLClientSocket properly handles when the underlying transport +// asynchronously returns an error code while writing data - such as if an +// intermediary terminates the socket connection uncleanly. +// This is a regression test for http://crbug.com/249848 +TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer + // is retained in order to configure additional errors. + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to avoid handshake non-determinism. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config)); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + static const int kRequestTextSize = + static_cast<int>(arraysize(request_text) - 1); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); + memcpy(request_buffer->data(), request_text, kRequestTextSize); + + // Simulate an unclean/forcible shutdown on the underlying socket. + // However, simulate this error asynchronously. + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + raw_transport->SetNextWriteShouldBlock(); + + // This write should complete synchronously, because the TLS ciphertext + // can be created and placed into the outgoing buffers independent of the + // underlying transport. + rv = callback.GetResult( + sock->Write(request_buffer.get(), kRequestTextSize, callback.callback())); + EXPECT_EQ(kRequestTextSize, rv); + + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + + rv = sock->Read(buf.get(), 4096, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + // Now unblock the outgoing request, having it fail with the connection + // being reset. + raw_transport->UnblockWrite(); + + // Note: This will cause an inifite loop if this bug has regressed. Simply + // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING + // is a legitimate result when using a dedicated task runner for NSS. + rv = callback.GetResult(rv); + +#if !defined(USE_OPENSSL) + // SSLClientSocketNSS records the error exactly + EXPECT_EQ(ERR_CONNECTION_RESET, rv); +#else + // SSLClientSocketOpenSSL treats any errors as a simple EOF. + EXPECT_EQ(0, rv); +#endif +} + +// Test the full duplex mode, with Read and Write pending at the same time. +// This test also serves as a regression test for http://crbug.com/29815. +TEST_F(SSLClientSocketTest, Read_FullDuplex) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; // Used for everything except Write. + + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + // Issue a "hanging" Read first. + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + rv = sock->Read(buf.get(), 4096, callback.callback()); + // We haven't written the request, so there should be no response yet. + ASSERT_EQ(ERR_IO_PENDING, rv); + + // Write the request. + // The request is padded with a User-Agent header to a size that causes the + // memio circular buffer (4k bytes) in SSLClientSocketNSS to wrap around. + // This tests the fix for http://crbug.com/29815. + std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name "; + for (int i = 0; i < 3770; ++i) + request_text.push_back('*'); + request_text.append("\r\n\r\n"); + scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text)); + + TestCompletionCallback callback2; // Used for Write only. + rv = sock->Write( + request_buffer.get(), request_text.size(), callback2.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback2.WaitForResult(); + EXPECT_EQ(static_cast<int>(request_text.size()), rv); + + // Now get the Read result. + rv = callback.WaitForResult(); + EXPECT_GT(rv, 0); +} + +// Attempts to Read() and Write() from an SSLClientSocketNSS in full duplex +// mode when the underlying transport is blocked on sending data. When the +// underlying transport completes due to an error, it should invoke both the +// Read() and Write() callbacks. If the socket is deleted by the Read() +// callback, the Write() callback should not be invoked. +// Regression test for http://crbug.com/232633 +TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer + // is retained in order to configure additional errors. + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); + + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to avoid handshake non-determinism. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + scoped_ptr<SSLClientSocket> sock = + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name "; + request_text.append(20 * 1024, '*'); + request_text.append("\r\n\r\n"); + scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer( + new StringIOBuffer(request_text), request_text.size())); + + // Simulate errors being returned from the underlying Read() and Write() ... + raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + // ... but have those errors returned asynchronously. Because the Write() will + // return first, this will trigger the error. + raw_transport->SetNextReadShouldBlock(); + raw_transport->SetNextWriteShouldBlock(); + + // Enqueue a Read() before calling Write(), which should "hang" due to + // the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return. + SSLClientSocket* raw_sock = sock.get(); + DeleteSocketCallback read_callback(sock.release()); + scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096)); + rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback()); + + // Ensure things didn't complete synchronously, otherwise |sock| is invalid. + ASSERT_EQ(ERR_IO_PENDING, rv); + ASSERT_FALSE(read_callback.have_result()); + +#if !defined(USE_OPENSSL) + // NSS follows a pattern where a call to PR_Write will only consume as + // much data as it can encode into application data records before the + // internal memio buffer is full, which should only fill if writing a large + // amount of data and the underlying transport is blocked. Once this happens, + // NSS will return (total size of all application data records it wrote) - 1, + // with the caller expected to resume with the remaining unsent data. + // + // This causes SSLClientSocketNSS::Write to return that it wrote some data + // before it will return ERR_IO_PENDING, so make an extra call to Write() to + // get the socket in the state needed for the test below. + // + // This is not needed for OpenSSL, because for OpenSSL, + // SSL_MODE_ENABLE_PARTIAL_WRITE is not specified - thus + // SSLClientSocketOpenSSL::Write() will not return until all of + // |request_buffer| has been written to the underlying BIO (although not + // necessarily the underlying transport). + rv = callback.GetResult(raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback())); + ASSERT_LT(0, rv); + request_buffer->DidConsume(rv); + + // Guard to ensure that |request_buffer| was larger than all of the internal + // buffers (transport, memio, NSS) along the way - otherwise the next call + // to Write() will crash with an invalid buffer. + ASSERT_LT(0, request_buffer->BytesRemaining()); +#endif + + // Attempt to write the remaining data. NSS will not be able to consume the + // application data because the internal buffers are full, while OpenSSL will + // return that its blocked because the underlying transport is blocked. + rv = raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback()); + ASSERT_EQ(ERR_IO_PENDING, rv); + ASSERT_FALSE(callback.have_result()); + + // Now unblock Write(), which will invoke OnSendComplete and (eventually) + // call the Read() callback, deleting the socket and thus aborting calling + // the Write() callback. + raw_transport->UnblockWrite(); + + rv = read_callback.WaitForResult(); + +#if !defined(USE_OPENSSL) + // NSS records the error exactly. + EXPECT_EQ(ERR_CONNECTION_RESET, rv); +#else + // OpenSSL treats any errors as a simple EOF. + EXPECT_EQ(0, rv); +#endif + + // The Write callback should not have been called. + EXPECT_FALSE(callback.have_result()); +} + +TEST_F(SSLClientSocketTest, Read_SmallChunks) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); + + rv = sock->Write( + request_buffer.get(), arraysize(request_text) - 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); + + scoped_refptr<IOBuffer> buf(new IOBuffer(1)); + for (;;) { + rv = sock->Read(buf.get(), 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_GE(rv, 0); + if (rv <= 0) + break; + } +} + +TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<ReadBufferingStreamSocket> transport( + new ReadBufferingStreamSocket(real_transport.Pass())); + ReadBufferingStreamSocket* raw_transport = transport.get(); + int rv = callback.GetResult(transport->Connect(callback.callback())); + ASSERT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + kDefaultSSLConfig)); + + rv = callback.GetResult(sock->Connect(callback.callback())); + ASSERT_EQ(OK, rv); + ASSERT_TRUE(sock->IsConnected()); + + const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); + + rv = callback.GetResult(sock->Write( + request_buffer.get(), arraysize(request_text) - 1, callback.callback())); + ASSERT_GT(rv, 0); + ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); + + // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of + // data (the max SSL record size) at a time. Ensure that at least 15K worth + // of SSL data is buffered first. The 15K of buffered data is made up of + // many smaller SSL records (the TestServer writes along 1350 byte + // plaintext boundaries), although there may also be a few records that are + // smaller or larger, due to timing and SSL False Start. + // 15K was chosen because 15K is smaller than the 17K (max) read issued by + // the SSLClientSocket implementation, and larger than the minimum amount + // of ciphertext necessary to contain the 8K of plaintext requested below. + raw_transport->SetBufferSize(15000); + + scoped_refptr<IOBuffer> buffer(new IOBuffer(8192)); + rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback())); + ASSERT_EQ(rv, 8192); +} + +TEST_F(SSLClientSocketTest, Read_Interrupted) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); + + rv = sock->Write( + request_buffer.get(), arraysize(request_text) - 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); + + // Do a partial read and then exit. This test should not crash! + scoped_refptr<IOBuffer> buf(new IOBuffer(512)); + rv = sock->Read(buf.get(), 512, callback.callback()); + EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_GT(rv, 0); +} + +TEST_F(SSLClientSocketTest, Read_FullLogging) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + log.SetLogLevel(NetLog::LOG_ALL); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); + + rv = sock->Write( + request_buffer.get(), arraysize(request_text) - 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + size_t last_index = ExpectLogContainsSomewhereAfter( + entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); + + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + for (;;) { + rv = sock->Read(buf.get(), 4096, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_GE(rv, 0); + if (rv <= 0) + break; + + log.GetEntries(&entries); + last_index = + ExpectLogContainsSomewhereAfter(entries, + last_index + 1, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, + NetLog::PHASE_NONE); + } +} + +// Regression test for http://crbug.com/42538 +TEST_F(SSLClientSocketTest, PrematureApplicationData) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + TestCompletionCallback callback; + + static const unsigned char application_data[] = { + 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, + 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46, + 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6, + 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36, + 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d, + 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f, + 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57, + 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82, + 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, + 0x0a}; + + // All reads and writes complete synchronously (async=false). + MockRead data_reads[] = { + MockRead(SYNCHRONOUS, + reinterpret_cast<const char*>(application_data), + arraysize(application_data)), + MockRead(SYNCHRONOUS, OK), }; + + StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0); + + scoped_ptr<StreamSocket> transport( + new MockTCPClientSocket(addr, NULL, &data)); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv); +} + +TEST_F(SSLClientSocketTest, CipherSuiteDisables) { + // Rather than exhaustively disabling every RC4 ciphersuite defined at + // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, + // only disabling those cipher suites that the test server actually + // implements. + const uint16 kCiphersToDisable[] = {0x0005, // TLS_RSA_WITH_RC4_128_SHA + }; + + SpawnedTestServer::SSLOptions ssl_options; + // Enable only RC4 on the test server. + ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config; + for (size_t i = 0; i < arraysize(kCiphersToDisable); ++i) + ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + + // NSS has special handling that maps a handshake_failure alert received + // immediately after a client_hello to be a mismatched cipher suite error, + // leading to ERR_SSL_VERSION_OR_CIPHER_MISMATCH. When using OpenSSL or + // Secure Transport (OS X), the handshake_failure is bubbled up without any + // interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure + // indicates that no cipher suite was negotiated with the test server. + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH || + rv == ERR_SSL_PROTOCOL_ERROR); + // The exact ordering differs between SSLClientSocketNSS (which issues an + // extra read) and SSLClientSocketMac (which does not). Just make sure the + // error appears somewhere in the log. + log.GetEntries(&entries); + ExpectLogContainsSomewhere( + entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE); + + // We cannot test sock->IsConnected(), as the NSS implementation disconnects + // the socket when it encounters an error, whereas other implementations + // leave it connected. + // Because this an error that the test server is mutually aware of, as opposed + // to being an error such as a certificate name mismatch, which is + // client-only, the exact index of the SSL connect end depends on how + // quickly the test server closes the underlying socket. If the test server + // closes before the IO message loop pumps messages, there may be a 0-byte + // Read event in the NetLog due to TCPClientSocket picking up the EOF. As a + // result, the SSL connect end event will be the second-to-last entry, + // rather than the last entry. + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1) || + LogContainsSSLConnectEndEvent(entries, -2)); +} + +// When creating an SSLClientSocket, it is allowed to pass in a +// ClientSocketHandle that is not obtained from a client socket pool. +// Here we verify that such a simple ClientSocketHandle, not associated with any +// client socket pool, can be destroyed safely. +TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle()); + socket_handle->SetSocket(transport.Pass()); + + scoped_ptr<SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(socket_handle.Pass(), + test_server.host_port_pair(), + kDefaultSSLConfig, + context_)); + + EXPECT_FALSE(sock->IsConnected()); + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); +} + +// Verifies that SSLClientSocket::ExportKeyingMaterial return a success +// code and different keying label results in different keying material. +TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + const int kKeyingMaterialSize = 32; + const char* kKeyingLabel1 = "client-socket-test-1"; + const char* kKeyingContext = ""; + unsigned char client_out1[kKeyingMaterialSize]; + memset(client_out1, 0, sizeof(client_out1)); + rv = sock->ExportKeyingMaterial( + kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1)); + EXPECT_EQ(rv, OK); + + const char* kKeyingLabel2 = "client-socket-test-2"; + unsigned char client_out2[kKeyingMaterialSize]; + memset(client_out2, 0, sizeof(client_out2)); + rv = sock->ExportKeyingMaterial( + kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2)); + EXPECT_EQ(rv, OK); + EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0); +} + +// Verifies that SSLClientSocket::ClearSessionCache can be called without +// explicit NSS initialization. +TEST(SSLClientSocket, ClearSessionCache) { + SSLClientSocket::ClearSessionCache(); +} + +// This tests that SSLInfo contains a properly re-constructed certificate +// chain. That, in turn, verifies that GetSSLInfo is giving us the chain as +// verified, not the chain as served by the server. (They may be different.) +// +// CERT_CHAIN_WRONG_ROOT is redundant-server-chain.pem. It contains A +// (end-entity) -> B -> C, and C is signed by D. redundant-validated-chain.pem +// contains a chain of A -> B -> C2, where C2 is the same public key as C, but +// a self-signed root. Such a situation can occur when a new root (C2) is +// cross-certified by an old root (D) and has two different versions of its +// floating around. Servers may supply C2 as an intermediate, but the +// SSLClientSocket should return the chain that was verified, from +// verify_result, instead. +TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { + // By default, cause the CertVerifier to treat all certificates as + // expired. + cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID); + + // We will expect SSLInfo to ultimately contain this chain. + CertificateList certs = + CreateCertificateListFromFile(GetTestCertsDirectory(), + "redundant-validated-chain.pem", + X509Certificate::FORMAT_AUTO); + ASSERT_EQ(3U, certs.size()); + + X509Certificate::OSCertHandles temp_intermediates; + temp_intermediates.push_back(certs[1]->os_cert_handle()); + temp_intermediates.push_back(certs[2]->os_cert_handle()); + + CertVerifyResult verify_result; + verify_result.verified_cert = X509Certificate::CreateFromHandle( + certs[0]->os_cert_handle(), temp_intermediates); + + // Add a rule that maps the server cert (A) to the chain of A->B->C2 + // rather than A->B->C. + cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK); + + // Load and install the root for the validated chain. + scoped_refptr<X509Certificate> root_cert = ImportCertFromFile( + GetTestCertsDirectory(), "redundant-validated-chain-root.pem"); + ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert); + ScopedTestRoot scoped_root(root_cert.get()); + + // Set up a test server with CERT_CHAIN_WRONG_ROOT. + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, + ssl_options, + base::FilePath(FILE_PATH_LITERAL("net/data/ssl"))); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + EXPECT_FALSE(sock->IsConnected()); + rv = sock->Connect(callback.callback()); + + CapturingNetLog::CapturedEntryList entries; + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + log.GetEntries(&entries); + EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); + + SSLInfo ssl_info; + sock->GetSSLInfo(&ssl_info); + + // Verify that SSLInfo contains the corrected re-constructed chain A -> B + // -> C2. + const X509Certificate::OSCertHandles& intermediates = + ssl_info.cert->GetIntermediateCertificates(); + ASSERT_EQ(2U, intermediates.size()); + EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(), + certs[0]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0], + certs[1]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1], + certs[2]->os_cert_handle())); + + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); +} + +// Verifies the correctness of GetSSLCertRequestInfo. +class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { + protected: + // Creates a test server with the given SSLOptions, connects to it and returns + // the SSLCertRequestInfo reported by the socket. + scoped_refptr<SSLCertRequestInfo> GetCertRequest( + SpawnedTestServer::SSLOptions ssl_options) { + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + if (!test_server.Start()) + return NULL; + + AddressList addr; + if (!test_server.GetAddressList(&addr)) + return NULL; + + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + EXPECT_FALSE(sock->IsConnected()); + + rv = sock->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); + sock->GetSSLCertRequestInfo(request_info.get()); + sock->Disconnect(); + EXPECT_FALSE(sock->IsConnected()); + + return request_info; + } +}; + +TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options); + ASSERT_TRUE(request_info.get()); + EXPECT_EQ(0u, request_info->cert_authorities.size()); +} + +TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { + const base::FilePath::CharType kThawteFile[] = + FILE_PATH_LITERAL("thawte.single.pem"); + const unsigned char kThawteDN[] = { + 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e, + 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79, + 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, + 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, + 0x53, 0x47, 0x43, 0x20, 0x43, 0x41}; + const size_t kThawteLen = sizeof(kThawteDN); + + const base::FilePath::CharType kDiginotarFile[] = + FILE_PATH_LITERAL("diginotar_root_ca.pem"); + const unsigned char kDiginotarDN[] = { + 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31, + 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69, + 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74, + 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, + 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f, + 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e, + 0x6c}; + const size_t kDiginotarLen = sizeof(kDiginotarDN); + + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + ssl_options.client_authorities.push_back( + GetTestClientCertsDirectory().Append(kThawteFile)); + ssl_options.client_authorities.push_back( + GetTestClientCertsDirectory().Append(kDiginotarFile)); + scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options); + ASSERT_TRUE(request_info.get()); + ASSERT_EQ(2u, request_info->cert_authorities.size()); + EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), + request_info->cert_authorities[0]); + EXPECT_EQ( + std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), + request_info->cert_authorities[1]); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/ssl_error_params.cc b/chromium/net/socket/ssl_error_params.cc new file mode 100644 index 00000000000..37561f0de48 --- /dev/null +++ b/chromium/net/socket/ssl_error_params.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_error_params.h" + +#include "base/bind.h" +#include "base/values.h" + +namespace net { + +namespace { + +base::Value* NetLogSSLErrorCallback(int net_error, + int ssl_lib_error, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetInteger("net_error", net_error); + if (ssl_lib_error) + dict->SetInteger("ssl_lib_error", ssl_lib_error); + return dict; +} + +} // namespace + +NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error, + int ssl_lib_error) { + return base::Bind(&NetLogSSLErrorCallback, net_error, ssl_lib_error); +} + +} // namespace net diff --git a/chromium/net/socket/ssl_error_params.h b/chromium/net/socket/ssl_error_params.h new file mode 100644 index 00000000000..07a1c4d99d9 --- /dev/null +++ b/chromium/net/socket/ssl_error_params.h @@ -0,0 +1,18 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_ERROR_PARAMS_H_ +#define NET_SOCKET_SSL_ERROR_PARAMS_H_ + +#include "net/base/net_log.h" + +namespace net { + +// Creates NetLog callback for when we receive an SSL error. +NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error, + int ssl_lib_error); + +} // namespace net + +#endif // NET_SOCKET_SSL_ERROR_PARAMS_H_ diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h new file mode 100644 index 00000000000..8b607bf80cf --- /dev/null +++ b/chromium/net/socket/ssl_server_socket.h @@ -0,0 +1,64 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SERVER_SOCKET_H_ +#define NET_SOCKET_SSL_SERVER_SOCKET_H_ + +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/socket/ssl_socket.h" +#include "net/socket/stream_socket.h" + +namespace crypto { +class RSAPrivateKey; +} // namespace crypto + +namespace net { + +struct SSLConfig; +class X509Certificate; + +class SSLServerSocket : public SSLSocket { + public: + virtual ~SSLServerSocket() {} + + // Perform the SSL server handshake, and notify the supplied callback + // if the process completes asynchronously. If Disconnect is called before + // completion then the callback will be silently, as for other StreamSocket + // calls. + virtual int Handshake(const CompletionCallback& callback) = 0; +}; + +// Configures the underlying SSL library for the use of SSL server sockets. +// +// Due to the requirements of the underlying libraries, this should be called +// early in process initialization, before any SSL socket, client or server, +// has been used. +// +// Note: If a process does not use SSL server sockets, this call may be +// omitted. +NET_EXPORT void EnableSSLServerSockets(); + +// Creates an SSL server socket over an already-connected transport socket. +// The caller must provide the server certificate and private key to use. +// +// The returned SSLServerSocket takes ownership of |socket|. Stubbed versions +// of CreateSSLServerSocket will delete |socket| and return NULL. +// It takes a reference to |certificate|. +// The |key| and |ssl_config| parameters are copied. |key| cannot be const +// because the methods used to copy its contents are non-const. +// +// The caller starts the SSL server handshake by calling Handshake on the +// returned socket. +NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, + X509Certificate* certificate, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config); + +} // namespace net + +#endif // NET_SOCKET_SSL_SERVER_SOCKET_H_ diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc new file mode 100644 index 00000000000..7e5d70118ac --- /dev/null +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -0,0 +1,828 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/ssl_server_socket_nss.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif + +#if defined(USE_SYSTEM_SSL) +#include <dlfcn.h> +#endif +#if defined(OS_MACOSX) +#include <Security/Security.h> +#endif +#include <certdb.h> +#include <cryptohi.h> +#include <hasht.h> +#include <keyhi.h> +#include <nspr.h> +#include <nss.h> +#include <pk11pub.h> +#include <secerr.h> +#include <sechash.h> +#include <ssl.h> +#include <sslerr.h> +#include <sslproto.h> + +#include <limits> + +#include "base/lazy_instance.h" +#include "base/memory/ref_counted.h" +#include "crypto/rsa_private_key.h" +#include "crypto/nss_util_internal.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/nss_ssl_util.h" +#include "net/socket/ssl_error_params.h" + +// SSL plaintext fragments are shorter than 16KB. Although the record layer +// overhead is allowed to be 2K + 5 bytes, in practice the overhead is much +// smaller than 1KB. So a 17KB buffer should be large enough to hold an +// entire SSL record. +static const int kRecvBufferSize = 17 * 1024; +static const int kSendBufferSize = 17 * 1024; + +#define GotoState(s) next_handshake_state_ = s + +namespace net { + +namespace { + +bool g_nss_server_sockets_init = false; + +class NSSSSLServerInitSingleton { + public: + NSSSSLServerInitSingleton() { + EnsureNSSSSLInit(); + + SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL); + g_nss_server_sockets_init = true; + } + + ~NSSSSLServerInitSingleton() { + SSL_ShutdownServerSessionIDCache(); + g_nss_server_sockets_init = false; + } +}; + +static base::LazyInstance<NSSSSLServerInitSingleton> + g_nss_ssl_server_init_singleton = LAZY_INSTANCE_INITIALIZER; + +} // namespace + +void EnableSSLServerSockets() { + g_nss_ssl_server_init_singleton.Get(); +} + +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, + X509Certificate* cert, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config) { + DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" + << "called yet!"; + + return scoped_ptr<SSLServerSocket>( + new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config)); +} + +SSLServerSocketNSS::SSLServerSocketNSS( + scoped_ptr<StreamSocket> transport_socket, + scoped_refptr<X509Certificate> cert, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config) + : transport_send_busy_(false), + transport_recv_busy_(false), + user_read_buf_len_(0), + user_write_buf_len_(0), + nss_fd_(NULL), + nss_bufs_(NULL), + transport_socket_(transport_socket.Pass()), + ssl_config_(ssl_config), + cert_(cert), + next_handshake_state_(STATE_NONE), + completed_handshake_(false) { + ssl_config_.false_start_enabled = false; + ssl_config_.version_min = SSL_PROTOCOL_VERSION_SSL3; + ssl_config_.version_max = SSL_PROTOCOL_VERSION_TLS1_1; + + // TODO(hclam): Need a better way to clone a key. + std::vector<uint8> key_bytes; + CHECK(key->ExportPrivateKey(&key_bytes)); + key_.reset(crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_bytes)); + CHECK(key_.get()); +} + +SSLServerSocketNSS::~SSLServerSocketNSS() { + if (nss_fd_ != NULL) { + PR_Close(nss_fd_); + nss_fd_ = NULL; + } +} + +int SSLServerSocketNSS::Handshake(const CompletionCallback& callback) { + net_log_.BeginEvent(NetLog::TYPE_SSL_SERVER_HANDSHAKE); + + int rv = Init(); + if (rv != OK) { + LOG(ERROR) << "Failed to initialize NSS"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv); + return rv; + } + + rv = InitializeSSLOptions(); + if (rv != OK) { + LOG(ERROR) << "Failed to initialize SSL options"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv); + return rv; + } + + // Set peer address. TODO(hclam): This should be in a separate method. + PRNetAddr peername; + memset(&peername, 0, sizeof(peername)); + peername.raw.family = AF_INET; + memio_SetPeerName(nss_fd_, &peername); + + GotoState(STATE_HANDSHAKE); + rv = DoHandshakeLoop(OK); + if (rv == ERR_IO_PENDING) { + user_handshake_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv); + } + + return rv > OK ? OK : rv; +} + +int SSLServerSocketNSS::ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + SECStatus result = SSL_ExportKeyingMaterial( + nss_fd_, label.data(), label.size(), has_context, + reinterpret_cast<const unsigned char*>(context.data()), + context.length(), out, outlen); + if (result != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ExportKeyingMaterial", ""); + return MapNSSError(PORT_GetError()); + } + return OK; +} + +int SSLServerSocketNSS::GetTLSUniqueChannelBinding(std::string* out) { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + unsigned char buf[64]; + unsigned int len; + SECStatus result = SSL_GetChannelBinding(nss_fd_, + SSL_CHANNEL_BINDING_TLS_UNIQUE, + buf, &len, arraysize(buf)); + if (result != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", ""); + return MapNSSError(PORT_GetError()); + } + out->assign(reinterpret_cast<char*>(buf), len); + return OK; +} + +int SSLServerSocketNSS::Connect(const CompletionCallback& callback) { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(user_read_callback_.is_null()); + DCHECK(user_handshake_callback_.is_null()); + DCHECK(!user_read_buf_.get()); + DCHECK(nss_bufs_); + DCHECK(!callback.is_null()); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + DCHECK(completed_handshake_); + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + return rv; +} + +int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(user_write_callback_.is_null()); + DCHECK(!user_write_buf_.get()); + DCHECK(nss_bufs_); + DCHECK(!callback.is_null()); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + return rv; +} + +bool SSLServerSocketNSS::SetReceiveBufferSize(int32 size) { + return transport_socket_->SetReceiveBufferSize(size); +} + +bool SSLServerSocketNSS::SetSendBufferSize(int32 size) { + return transport_socket_->SetSendBufferSize(size); +} + +bool SSLServerSocketNSS::IsConnected() const { + return completed_handshake_; +} + +void SSLServerSocketNSS::Disconnect() { + transport_socket_->Disconnect(); +} + +bool SSLServerSocketNSS::IsConnectedAndIdle() const { + return completed_handshake_ && transport_socket_->IsConnectedAndIdle(); +} + +int SSLServerSocketNSS::GetPeerAddress(IPEndPoint* address) const { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + return transport_socket_->GetPeerAddress(address); +} + +int SSLServerSocketNSS::GetLocalAddress(IPEndPoint* address) const { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + return transport_socket_->GetLocalAddress(address); +} + +const BoundNetLog& SSLServerSocketNSS::NetLog() const { + return net_log_; +} + +void SSLServerSocketNSS::SetSubresourceSpeculation() { + transport_socket_->SetSubresourceSpeculation(); +} + +void SSLServerSocketNSS::SetOmniboxSpeculation() { + transport_socket_->SetOmniboxSpeculation(); +} + +bool SSLServerSocketNSS::WasEverUsed() const { + return transport_socket_->WasEverUsed(); +} + +bool SSLServerSocketNSS::UsingTCPFastOpen() const { + return transport_socket_->UsingTCPFastOpen(); +} + +bool SSLServerSocketNSS::WasNpnNegotiated() const { + return false; +} + +NextProto SSLServerSocketNSS::GetNegotiatedProtocol() const { + // NPN is not supported by this class. + return kProtoUnknown; +} + +bool SSLServerSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { + NOTIMPLEMENTED(); + return false; +} + +int SSLServerSocketNSS::InitializeSSLOptions() { + // Transport connected, now hook it up to nss + nss_fd_ = memio_CreateIOLayer(kRecvBufferSize, kSendBufferSize); + if (nss_fd_ == NULL) { + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR error code. + } + + // Grab pointer to buffers + nss_bufs_ = memio_GetSecret(nss_fd_); + + /* Create SSL state machine */ + /* Push SSL onto our fake I/O socket */ + nss_fd_ = SSL_ImportFD(NULL, nss_fd_); + if (nss_fd_ == NULL) { + LogFailedNSSFunction(net_log_, "SSL_ImportFD", ""); + return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code. + } + // TODO(port): set more ssl options! Check errors! + + int rv; + + rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2"); + return ERR_UNEXPECTED; + } + + SSLVersionRange version_range; + version_range.min = ssl_config_.version_min; + version_range.max = ssl_config_.version_max; + rv = SSL_VersionRangeSet(nss_fd_, &version_range); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_VersionRangeSet", ""); + return ERR_NO_SSL_VERSIONS_ENABLED; + } + + for (std::vector<uint16>::const_iterator it = + ssl_config_.disabled_cipher_suites.begin(); + it != ssl_config_.disabled_cipher_suites.end(); ++it) { + // This will fail if the specified cipher is not implemented by NSS, but + // the failure is harmless. + SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE); + } + + // Server socket doesn't need session tickets. + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction( + net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS"); + } + + // Doing this will force PR_Accept perform handshake as server. + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_SERVER, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_SERVER"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_REQUEST_CERTIFICATE, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUEST_CERTIFICATE"); + return ERR_UNEXPECTED; + } + + rv = SSL_OptionSet(nss_fd_, SSL_REQUIRE_CERTIFICATE, PR_FALSE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUIRE_CERTIFICATE"); + return ERR_UNEXPECTED; + } + + rv = SSL_AuthCertificateHook(nss_fd_, OwnAuthCertHandler, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_AuthCertificateHook", ""); + return ERR_UNEXPECTED; + } + + rv = SSL_HandshakeCallback(nss_fd_, HandshakeCallback, this); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_HandshakeCallback", ""); + return ERR_UNEXPECTED; + } + + // Get a certificate of CERTCertificate structure. + std::string der_string; + if (!X509Certificate::GetDEREncoded(cert_->os_cert_handle(), &der_string)) + return ERR_UNEXPECTED; + + SECItem der_cert; + der_cert.data = reinterpret_cast<unsigned char*>(const_cast<char*>( + der_string.data())); + der_cert.len = der_string.length(); + der_cert.type = siDERCertBuffer; + + // Parse into a CERTCertificate structure. + CERTCertificate* cert = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); + if (!cert) { + LogFailedNSSFunction(net_log_, "CERT_NewTempCertificate", ""); + return MapNSSError(PORT_GetError()); + } + + // Get a key of SECKEYPrivateKey* structure. + std::vector<uint8> key_vector; + if (!key_->ExportPrivateKey(&key_vector)) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + SECKEYPrivateKeyStr* private_key = NULL; + PK11SlotInfo* slot = crypto::GetPrivateNSSKeySlot(); + if (!slot) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + SECItem der_private_key_info; + der_private_key_info.data = + const_cast<unsigned char*>(&key_vector.front()); + der_private_key_info.len = key_vector.size(); + // The server's RSA private key must be imported into NSS with the + // following key usage bits: + // - KU_KEY_ENCIPHERMENT, required for the RSA key exchange algorithm. + // - KU_DIGITAL_SIGNATURE, required for the DHE_RSA and ECDHE_RSA key + // exchange algorithms. + const unsigned int key_usage = KU_KEY_ENCIPHERMENT | KU_DIGITAL_SIGNATURE; + rv = PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot, &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE, + key_usage, &private_key, NULL); + PK11_FreeSlot(slot); + if (rv != SECSuccess) { + CERT_DestroyCertificate(cert); + return ERR_UNEXPECTED; + } + + // Assign server certificate and private key. + SSLKEAType cert_kea = NSS_FindCertKEAType(cert); + rv = SSL_ConfigSecureServer(nss_fd_, cert, private_key, cert_kea); + CERT_DestroyCertificate(cert); + SECKEY_DestroyPrivateKey(private_key); + + if (rv != SECSuccess) { + PRErrorCode prerr = PR_GetError(); + LOG(ERROR) << "Failed to config SSL server: " << prerr; + LogFailedNSSFunction(net_log_, "SSL_ConfigureSecureServer", ""); + return ERR_UNEXPECTED; + } + + // Tell SSL we're a server; needed if not letting NSPR do socket I/O + rv = SSL_ResetHandshake(nss_fd_, PR_TRUE); + if (rv != SECSuccess) { + LogFailedNSSFunction(net_log_, "SSL_ResetHandshake", ""); + return ERR_UNEXPECTED; + } + + return OK; +} + +void SSLServerSocketNSS::OnSendComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + if (!completed_handshake_) + return; + + if (user_write_buf_.get()) { + int rv = DoWriteLoop(result); + if (rv != ERR_IO_PENDING) + DoWriteCallback(rv); + } else { + // Ensure that any queued ciphertext is flushed. + DoTransportIO(); + } +} + +void SSLServerSocketNSS::OnRecvComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + // Network layer received some data, check if client requested to read + // decrypted data. + if (!user_read_buf_.get() || !completed_handshake_) + return; + + int rv = DoReadLoop(result); + if (rv != ERR_IO_PENDING) + DoReadCallback(rv); +} + +void SSLServerSocketNSS::OnHandshakeIOComplete(int result) { + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv); + if (!user_handshake_callback_.is_null()) + DoHandshakeCallback(rv); + } +} + +// Return 0 for EOF, +// > 0 for bytes transferred immediately, +// < 0 for error (or the non-error ERR_IO_PENDING). +int SSLServerSocketNSS::BufferSend(void) { + if (transport_send_busy_) + return ERR_IO_PENDING; + + const char* buf1; + const char* buf2; + unsigned int len1, len2; + memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2); + const unsigned int len = len1 + len2; + + int rv = 0; + if (len) { + scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); + memcpy(send_buffer->data(), buf1, len1); + memcpy(send_buffer->data() + len1, buf2, len2); + rv = transport_socket_->Write( + send_buffer.get(), + len, + base::Bind(&SSLServerSocketNSS::BufferSendComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv)); + } + } + + return rv; +} + +void SSLServerSocketNSS::BufferSendComplete(int result) { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result)); + transport_send_busy_ = false; + OnSendComplete(result); +} + +int SSLServerSocketNSS::BufferRecv(void) { + if (transport_recv_busy_) return ERR_IO_PENDING; + + char* buf; + int nb = memio_GetReadParams(nss_bufs_, &buf); + int rv; + if (!nb) { + // buffer too full to read into, so no I/O possible at moment + rv = ERR_IO_PENDING; + } else { + recv_buffer_ = new IOBuffer(nb); + rv = transport_socket_->Read( + recv_buffer_.get(), + nb, + base::Bind(&SSLServerSocketNSS::BufferRecvComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + if (rv > 0) + memcpy(buf, recv_buffer_->data(), rv); + memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv)); + recv_buffer_ = NULL; + } + } + return rv; +} + +void SSLServerSocketNSS::BufferRecvComplete(int result) { + if (result > 0) { + char* buf; + memio_GetReadParams(nss_bufs_, &buf); + memcpy(buf, recv_buffer_->data(), result); + } + recv_buffer_ = NULL; + memio_PutReadResult(nss_bufs_, MapErrorToNSS(result)); + transport_recv_busy_ = false; + OnRecvComplete(result); +} + +// Do as much network I/O as possible between the buffer and the +// transport socket. Return true if some I/O performed, false +// otherwise (error or ERR_IO_PENDING). +bool SSLServerSocketNSS::DoTransportIO() { + bool network_moved = false; + if (nss_bufs_ != NULL) { + int rv; + // Read and write as much data as we can. The loop is neccessary + // because Write() may return synchronously. + do { + rv = BufferSend(); + if (rv > 0) + network_moved = true; + } while (rv > 0); + if (BufferRecv() >= 0) + network_moved = true; + } + return network_moved; +} + +int SSLServerSocketNSS::DoPayloadRead() { + DCHECK(user_read_buf_.get()); + DCHECK_GT(user_read_buf_len_, 0); + int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_); + if (rv >= 0) + return rv; + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) { + return ERR_IO_PENDING; + } + rv = MapNSSError(prerr); + net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogSSLErrorCallback(rv, prerr)); + return rv; +} + +int SSLServerSocketNSS::DoPayloadWrite() { + DCHECK(user_write_buf_.get()); + int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_); + if (rv >= 0) + return rv; + PRErrorCode prerr = PR_GetError(); + if (prerr == PR_WOULD_BLOCK_ERROR) { + return ERR_IO_PENDING; + } + rv = MapNSSError(prerr); + net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, + CreateNetLogSSLErrorCallback(rv, prerr)); + return rv; +} + +int SSLServerSocketNSS::DoHandshakeLoop(int last_io_result) { + int rv = last_io_result; + do { + // Default to STATE_NONE for next state. + // (This is a quirk carried over from the windows + // implementation. It makes reading the logs a bit harder.) + // State handlers can and often do call GotoState just + // to stay in the current state. + State state = next_handshake_state_; + GotoState(STATE_NONE); + switch (state) { + case STATE_HANDSHAKE: + rv = DoHandshake(); + break; + case STATE_NONE: + default: + rv = ERR_UNEXPECTED; + LOG(DFATAL) << "unexpected state " << state; + break; + } + + // Do the actual network I/O + bool network_moved = DoTransportIO(); + if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) { + // In general we exit the loop if rv is ERR_IO_PENDING. In this + // special case we keep looping even if rv is ERR_IO_PENDING because + // the transport IO may allow DoHandshake to make progress. + rv = OK; // This causes us to stay in the loop. + } + } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); + return rv; +} + +int SSLServerSocketNSS::DoReadLoop(int result) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogSSLErrorCallback(rv, 0)); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadRead(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + return rv; +} + +int SSLServerSocketNSS::DoWriteLoop(int result) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + + if (result < 0) + return result; + + if (!nss_bufs_) { + LOG(DFATAL) << "!nss_bufs_"; + int rv = ERR_UNEXPECTED; + net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, + CreateNetLogSSLErrorCallback(rv, 0)); + return rv; + } + + bool network_moved; + int rv; + do { + rv = DoPayloadWrite(); + network_moved = DoTransportIO(); + } while (rv == ERR_IO_PENDING && network_moved); + return rv; +} + +int SSLServerSocketNSS::DoHandshake() { + int net_error = OK; + SECStatus rv = SSL_ForceHandshake(nss_fd_); + + if (rv == SECSuccess) { + completed_handshake_ = true; + } else { + PRErrorCode prerr = PR_GetError(); + net_error = MapNSSError(prerr); + + // If not done, stay in this state + if (net_error == ERR_IO_PENDING) { + GotoState(STATE_HANDSHAKE); + } else { + LOG(ERROR) << "handshake failed; NSS error code " << prerr + << ", net_error " << net_error; + net_log_.AddEvent(NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogSSLErrorCallback(net_error, prerr)); + } + } + return net_error; +} + +void SSLServerSocketNSS::DoHandshakeCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + + CompletionCallback c = user_handshake_callback_; + user_handshake_callback_.Reset(); + c.Run(rv > OK ? OK : rv); +} + +void SSLServerSocketNSS::DoReadCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(!user_read_callback_.is_null()); + + // Since Run may result in Read being called, clear |user_read_callback_| + // up front. + CompletionCallback c = user_read_callback_; + user_read_callback_.Reset(); + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c.Run(rv); +} + +void SSLServerSocketNSS::DoWriteCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(!user_write_callback_.is_null()); + + // Since Run may result in Write being called, clear |user_write_callback_| + // up front. + CompletionCallback c = user_write_callback_; + user_write_callback_.Reset(); + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c.Run(rv); +} + +// static +// NSS calls this if an incoming certificate needs to be verified. +// Do nothing but return SECSuccess. +// This is called only in full handshake mode. +// Peer certificate is retrieved in HandshakeCallback() later, which is called +// in full handshake mode or in resumption handshake mode. +SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { + // TODO(hclam): Implement. + // Tell NSS to not verify the certificate. + return SECSuccess; +} + +// static +// NSS calls this when handshake is completed. +// After the SSL handshake is finished we need to verify the certificate. +void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, + void* arg) { + // TODO(hclam): Implement. +} + +int SSLServerSocketNSS::Init() { + // Initialize the NSS SSL library in a threadsafe way. This also + // initializes the NSS base library. + EnsureNSSSSLInit(); + if (!NSS_IsInitialized()) + return ERR_UNEXPECTED; + + EnableSSLServerSockets(); + return OK; +} + +} // namespace net diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h new file mode 100644 index 00000000000..8bbb0e338ac --- /dev/null +++ b/chromium/net/socket/ssl_server_socket_nss.h @@ -0,0 +1,150 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ +#define NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ + +#include <certt.h> +#include <keyt.h> +#include <nspr.h> +#include <nss.h> + +#include "base/memory/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_log.h" +#include "net/base/nss_memio.h" +#include "net/socket/ssl_server_socket.h" +#include "net/ssl/ssl_config_service.h" + +namespace net { + +class SSLServerSocketNSS : public SSLServerSocket { + public: + // See comments on CreateSSLServerSocket for details of how these + // parameters are used. + SSLServerSocketNSS(scoped_ptr<StreamSocket> socket, + scoped_refptr<X509Certificate> certificate, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config); + virtual ~SSLServerSocketNSS(); + + // SSLServerSocket interface. + virtual int Handshake(const CompletionCallback& callback) OVERRIDE; + + // SSLSocket interface. + virtual int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) OVERRIDE; + virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + + // Socket interface (via StreamSocket). + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + private: + enum State { + STATE_NONE, + STATE_HANDSHAKE, + }; + + int InitializeSSLOptions(); + + void OnSendComplete(int result); + void OnRecvComplete(int result); + void OnHandshakeIOComplete(int result); + + int BufferSend(); + void BufferSendComplete(int result); + int BufferRecv(); + void BufferRecvComplete(int result); + bool DoTransportIO(); + int DoPayloadRead(); + int DoPayloadWrite(); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + int DoHandshake(); + void DoHandshakeCallback(int result); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + static SECStatus OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server); + static void HandshakeCallback(PRFileDesc* socket, void* arg); + + virtual int Init(); + + // Members used to send and receive buffer. + bool transport_send_busy_; + bool transport_recv_busy_; + + scoped_refptr<IOBuffer> recv_buffer_; + + BoundNetLog net_log_; + + CompletionCallback user_handshake_callback_; + CompletionCallback user_read_callback_; + CompletionCallback user_write_callback_; + + // Used by Read function. + scoped_refptr<IOBuffer> user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr<IOBuffer> user_write_buf_; + int user_write_buf_len_; + + // The NSS SSL state machine + PRFileDesc* nss_fd_; + + // Buffers for the network end of the SSL state machine + memio_Private* nss_bufs_; + + // StreamSocket for sending and receiving data. + scoped_ptr<StreamSocket> transport_socket_; + + // Options for the SSL socket. + SSLConfig ssl_config_; + + // Certificate for the server. + scoped_refptr<X509Certificate> cert_; + + // Private key used by the server. + scoped_ptr<crypto::RSAPrivateKey> key_; + + State next_handshake_state_; + bool completed_handshake_; + + DISALLOW_COPY_AND_ASSIGN(SSLServerSocketNSS); +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_ diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc new file mode 100644 index 00000000000..c327f2caf10 --- /dev/null +++ b/chromium/net/socket/ssl_server_socket_openssl.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/logging.h" +#include "net/socket/ssl_server_socket.h" + +// TODO(bulach): Provide simple stubs for EnableSSLServerSockets and +// CreateSSLServerSocket so that when building for OpenSSL rather than NSS, +// so that the code using SSL server sockets can be compiled and disabled +// programatically rather than requiring to be carved out from the compile. + +namespace net { + +void EnableSSLServerSockets() { + NOTIMPLEMENTED(); +} + +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, + X509Certificate* certificate, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config) { + NOTIMPLEMENTED(); + return scoped_ptr<SSLServerSocket>(); +} + +} // namespace net diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc new file mode 100644 index 00000000000..64c85490b29 --- /dev/null +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -0,0 +1,588 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This test suite uses SSLClientSocket to test the implementation of +// SSLServerSocket. In order to establish connections between the sockets +// we need two additional classes: +// 1. FakeSocket +// Connects SSL socket to FakeDataChannel. This class is just a stub. +// +// 2. FakeDataChannel +// Implements the actual exchange of data between two FakeSockets. +// +// Implementations of these two classes are included in this file. + +#include "net/socket/ssl_server_socket.h" + +#include <stdlib.h> + +#include <queue> + +#include "base/compiler_specific.h" +#include "base/file_util.h" +#include "base/files/file_path.h" +#include "base/message_loop/message_loop.h" +#include "base/path_service.h" +#include "crypto/nss_util.h" +#include "crypto/rsa_private_key.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/test_data_directory.h" +#include "net/cert/cert_status_flags.h" +#include "net/cert/mock_cert_verifier.h" +#include "net/cert/x509_certificate.h" +#include "net/http/transport_security_state.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_config_service.h" +#include "net/ssl/ssl_info.h" +#include "net/test/cert_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { + +class FakeDataChannel { + public: + FakeDataChannel() + : read_buf_len_(0), + weak_factory_(this), + closed_(false), + write_called_after_close_(false) { + } + + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + if (closed_) + return 0; + if (data_.empty()) { + read_callback_ = callback; + read_buf_ = buf; + read_buf_len_ = buf_len; + return net::ERR_IO_PENDING; + } + return PropogateData(buf, buf_len); + } + + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + if (closed_) { + if (write_called_after_close_) + return net::ERR_CONNECTION_RESET; + write_called_after_close_ = true; + write_callback_ = callback; + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback, + weak_factory_.GetWeakPtr())); + return net::ERR_IO_PENDING; + } + data_.push(new net::DrainableIOBuffer(buf, buf_len)); + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, + weak_factory_.GetWeakPtr())); + return buf_len; + } + + // Closes the FakeDataChannel. After Close() is called, Read() returns 0, + // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that + // after the FakeDataChannel is closed, the first Write() call completes + // asynchronously, which is necessary to reproduce bug 127822. + void Close() { + closed_ = true; + } + + private: + void DoReadCallback() { + if (read_callback_.is_null() || data_.empty()) + return; + + int copied = PropogateData(read_buf_, read_buf_len_); + CompletionCallback callback = read_callback_; + read_callback_.Reset(); + read_buf_ = NULL; + read_buf_len_ = 0; + callback.Run(copied); + } + + void DoWriteCallback() { + if (write_callback_.is_null()) + return; + + CompletionCallback callback = write_callback_; + write_callback_.Reset(); + callback.Run(net::ERR_CONNECTION_RESET); + } + + int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { + scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); + int copied = std::min(buf->BytesRemaining(), read_buf_len); + memcpy(read_buf->data(), buf->data(), copied); + buf->DidConsume(copied); + + if (!buf->BytesRemaining()) + data_.pop(); + return copied; + } + + CompletionCallback read_callback_; + scoped_refptr<net::IOBuffer> read_buf_; + int read_buf_len_; + + CompletionCallback write_callback_; + + std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; + + base::WeakPtrFactory<FakeDataChannel> weak_factory_; + + // True if Close() has been called. + bool closed_; + + // Controls the completion of Write() after the FakeDataChannel is closed. + // After the FakeDataChannel is closed, the first Write() call completes + // asynchronously. + bool write_called_after_close_; + + DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); +}; + +class FakeSocket : public StreamSocket { + public: + FakeSocket(FakeDataChannel* incoming_channel, + FakeDataChannel* outgoing_channel) + : incoming_(incoming_channel), + outgoing_(outgoing_channel) { + } + + virtual ~FakeSocket() { + } + + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + // Read random number of bytes. + buf_len = rand() % buf_len + 1; + return incoming_->Read(buf, buf_len, callback); + } + + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + // Write random number of bytes. + buf_len = rand() % buf_len + 1; + return outgoing_->Write(buf, buf_len, callback); + } + + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { + return true; + } + + virtual bool SetSendBufferSize(int32 size) OVERRIDE { + return true; + } + + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return net::OK; + } + + virtual void Disconnect() OVERRIDE { + incoming_->Close(); + outgoing_->Close(); + } + + virtual bool IsConnected() const OVERRIDE { + return true; + } + + virtual bool IsConnectedAndIdle() const OVERRIDE { + return true; + } + + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + net::IPAddressNumber ip_address(net::kIPv4AddressSize); + *address = net::IPEndPoint(ip_address, 0 /*port*/); + return net::OK; + } + + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + net::IPAddressNumber ip_address(4); + *address = net::IPEndPoint(ip_address, 0); + return net::OK; + } + + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + + virtual bool WasEverUsed() const OVERRIDE { + return true; + } + + virtual bool UsingTCPFastOpen() const OVERRIDE { + return false; + } + + + virtual bool WasNpnNegotiated() const OVERRIDE { + return false; + } + + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + return false; + } + + private: + net::BoundNetLog net_log_; + FakeDataChannel* incoming_; + FakeDataChannel* outgoing_; + + DISALLOW_COPY_AND_ASSIGN(FakeSocket); +}; + +} // namespace + +// Verify the correctness of the test helper classes first. +TEST(FakeSocketTest, DataTransfer) { + // Establish channels between two sockets. + FakeDataChannel channel_1; + FakeDataChannel channel_2; + FakeSocket client(&channel_1, &channel_2); + FakeSocket server(&channel_2, &channel_1); + + const char kTestData[] = "testing123"; + const int kTestDataSize = strlen(kTestData); + const int kReadBufSize = 1024; + scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); + scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); + + // Write then read. + int written = + server.Write(write_buf.get(), kTestDataSize, CompletionCallback()); + EXPECT_GT(written, 0); + EXPECT_LE(written, kTestDataSize); + + int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback()); + EXPECT_GT(read, 0); + EXPECT_LE(read, written); + EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); + + // Read then write. + TestCompletionCallback callback; + EXPECT_EQ(net::ERR_IO_PENDING, + server.Read(read_buf.get(), kReadBufSize, callback.callback())); + + written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback()); + EXPECT_GT(written, 0); + EXPECT_LE(written, kTestDataSize); + + read = callback.WaitForResult(); + EXPECT_GT(read, 0); + EXPECT_LE(read, written); + EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); +} + +class SSLServerSocketTest : public PlatformTest { + public: + SSLServerSocketTest() + : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new MockCertVerifier()), + transport_security_state_(new TransportSecurityState) { + cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID); + } + + protected: + void Initialize() { + scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); + client_connection->SetSocket( + scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); + scoped_ptr<StreamSocket> server_socket( + new FakeSocket(&channel_2_, &channel_1_)); + + base::FilePath certs_dir(GetTestCertsDirectory()); + + base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); + std::string cert_der; + ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); + + scoped_refptr<net::X509Certificate> cert = + X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); + + base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); + std::string key_string; + ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); + std::vector<uint8> key_vector( + reinterpret_cast<const uint8*>(key_string.data()), + reinterpret_cast<const uint8*>(key_string.data() + + key_string.length())); + + scoped_ptr<crypto::RSAPrivateKey> private_key( + crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); + + net::SSLConfig ssl_config; + ssl_config.cached_info_enabled = false; + ssl_config.false_start_enabled = false; + ssl_config.channel_id_enabled = false; + ssl_config.version_min = SSL_PROTOCOL_VERSION_SSL3; + ssl_config.version_max = SSL_PROTOCOL_VERSION_TLS1_1; + + // Certificate provided by the host doesn't need authority. + net::SSLConfig::CertAndStatus cert_and_status; + cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; + cert_and_status.der_cert = cert_der; + ssl_config.allowed_bad_certs.push_back(cert_and_status); + + net::HostPortPair host_and_pair("unittest", 0); + net::SSLClientSocketContext context; + context.cert_verifier = cert_verifier_.get(); + context.transport_security_state = transport_security_state_.get(); + client_socket_ = + socket_factory_->CreateSSLClientSocket( + client_connection.Pass(), host_and_pair, ssl_config, context); + server_socket_ = net::CreateSSLServerSocket( + server_socket.Pass(), + cert.get(), private_key.get(), net::SSLConfig()); + } + + FakeDataChannel channel_1_; + FakeDataChannel channel_2_; + scoped_ptr<net::SSLClientSocket> client_socket_; + scoped_ptr<net::SSLServerSocket> server_socket_; + net::ClientSocketFactory* socket_factory_; + scoped_ptr<net::MockCertVerifier> cert_verifier_; + scoped_ptr<net::TransportSecurityState> transport_security_state_; +}; + +// SSLServerSocket is only implemented using NSS. +#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) + +// This test only executes creation of client and server sockets. This is to +// test that creation of sockets doesn't crash and have minimal code to run +// under valgrind in order to help debugging memory problems. +TEST_F(SSLServerSocketTest, Initialize) { + Initialize(); +} + +// This test executes Connect() on SSLClientSocket and Handshake() on +// SSLServerSocket to make sure handshaking between the two sockets is +// completed successfully. +TEST_F(SSLServerSocketTest, Handshake) { + Initialize(); + + TestCompletionCallback connect_callback; + TestCompletionCallback handshake_callback; + + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + int client_ret = client_socket_->Connect(connect_callback.callback()); + EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + if (client_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, connect_callback.WaitForResult()); + } + if (server_ret == net::ERR_IO_PENDING) { + EXPECT_EQ(net::OK, handshake_callback.WaitForResult()); + } + + // Make sure the cert status is expected. + SSLInfo ssl_info; + client_socket_->GetSSLInfo(&ssl_info); + EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); +} + +TEST_F(SSLServerSocketTest, DataTransfer) { + Initialize(); + + TestCompletionCallback connect_callback; + TestCompletionCallback handshake_callback; + + // Establish connection. + int client_ret = client_socket_->Connect(connect_callback.callback()); + ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + client_ret = connect_callback.GetResult(client_ret); + ASSERT_EQ(net::OK, client_ret); + server_ret = handshake_callback.GetResult(server_ret); + ASSERT_EQ(net::OK, server_ret); + + const int kReadBufSize = 1024; + scoped_refptr<net::StringIOBuffer> write_buf = + new net::StringIOBuffer("testing123"); + scoped_refptr<net::DrainableIOBuffer> read_buf = + new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize), + kReadBufSize); + + // Write then read. + TestCompletionCallback write_callback; + TestCompletionCallback read_callback; + server_ret = server_socket_->Write( + write_buf.get(), write_buf->size(), write_callback.callback()); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + client_ret = client_socket_->Read( + read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + server_ret = write_callback.GetResult(server_ret); + EXPECT_GT(server_ret, 0); + client_ret = read_callback.GetResult(client_ret); + ASSERT_GT(client_ret, 0); + + read_buf->DidConsume(client_ret); + while (read_buf->BytesConsumed() < write_buf->size()) { + client_ret = client_socket_->Read( + read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + client_ret = read_callback.GetResult(client_ret); + ASSERT_GT(client_ret, 0); + read_buf->DidConsume(client_ret); + } + EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); + read_buf->SetOffset(0); + EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); + + // Read then write. + write_buf = new net::StringIOBuffer("hello123"); + server_ret = server_socket_->Read( + read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + client_ret = client_socket_->Write( + write_buf.get(), write_buf->size(), write_callback.callback()); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + server_ret = read_callback.GetResult(server_ret); + ASSERT_GT(server_ret, 0); + client_ret = write_callback.GetResult(client_ret); + EXPECT_GT(client_ret, 0); + + read_buf->DidConsume(server_ret); + while (read_buf->BytesConsumed() < write_buf->size()) { + server_ret = server_socket_->Read( + read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + server_ret = read_callback.GetResult(server_ret); + ASSERT_GT(server_ret, 0); + read_buf->DidConsume(server_ret); + } + EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); + read_buf->SetOffset(0); + EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); +} + +// A regression test for bug 127822 (http://crbug.com/127822). +// If the server closes the connection after the handshake is finished, +// the client's Write() call should not cause an infinite loop. +// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. +TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { + Initialize(); + + TestCompletionCallback connect_callback; + TestCompletionCallback handshake_callback; + + // Establish connection. + int client_ret = client_socket_->Connect(connect_callback.callback()); + ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + client_ret = connect_callback.GetResult(client_ret); + ASSERT_EQ(net::OK, client_ret); + server_ret = handshake_callback.GetResult(server_ret); + ASSERT_EQ(net::OK, server_ret); + + scoped_refptr<net::StringIOBuffer> write_buf = + new net::StringIOBuffer("testing123"); + + // The server closes the connection. The server needs to write some + // data first so that the client's Read() calls from the transport + // socket won't return ERR_IO_PENDING. This ensures that the client + // will call Read() on the transport socket again. + TestCompletionCallback write_callback; + + server_ret = server_socket_->Write( + write_buf.get(), write_buf->size(), write_callback.callback()); + EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + + server_ret = write_callback.GetResult(server_ret); + EXPECT_GT(server_ret, 0); + + server_socket_->Disconnect(); + + // The client writes some data. This should not cause an infinite loop. + client_ret = client_socket_->Write( + write_buf.get(), write_buf->size(), write_callback.callback()); + EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + + client_ret = write_callback.GetResult(client_ret); + EXPECT_GT(client_ret, 0); + + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, base::MessageLoop::QuitClosure(), + base::TimeDelta::FromMilliseconds(10)); + base::MessageLoop::current()->Run(); +} + +// This test executes ExportKeyingMaterial() on the client and server sockets, +// after connecting them, and verifies that the results match. +// This test will fail if False Start is enabled (see crbug.com/90208). +TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { + Initialize(); + + TestCompletionCallback connect_callback; + TestCompletionCallback handshake_callback; + + int client_ret = client_socket_->Connect(connect_callback.callback()); + ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + + if (client_ret == net::ERR_IO_PENDING) { + ASSERT_EQ(net::OK, connect_callback.WaitForResult()); + } + if (server_ret == net::ERR_IO_PENDING) { + ASSERT_EQ(net::OK, handshake_callback.WaitForResult()); + } + + const int kKeyingMaterialSize = 32; + const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test"; + const char* kKeyingContext = ""; + unsigned char server_out[kKeyingMaterialSize]; + int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, + false, kKeyingContext, + server_out, sizeof(server_out)); + ASSERT_EQ(net::OK, rv); + + unsigned char client_out[kKeyingMaterialSize]; + rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, + false, kKeyingContext, + client_out, sizeof(client_out)); + ASSERT_EQ(net::OK, rv); + EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); + + const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; + unsigned char client_bad[kKeyingMaterialSize]; + rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, + false, kKeyingContext, + client_bad, sizeof(client_bad)); + ASSERT_EQ(rv, net::OK); + EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); +} +#endif + +} // namespace net diff --git a/chromium/net/socket/ssl_socket.h b/chromium/net/socket/ssl_socket.h new file mode 100644 index 00000000000..68d1e4a2bfe --- /dev/null +++ b/chromium/net/socket/ssl_socket.h @@ -0,0 +1,37 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SSL_SOCKET_H_ +#define NET_SOCKET_SSL_SOCKET_H_ + +#include "base/basictypes.h" +#include "base/strings/string_piece.h" +#include "net/socket/stream_socket.h" + +namespace net { + +// SSLSocket interface defines method that are common between client +// and server SSL sockets. +class NET_EXPORT SSLSocket : public StreamSocket { +public: + virtual ~SSLSocket() {} + + // Exports data derived from the SSL master-secret (see RFC 5705). + // If |has_context| is false, uses the no-context construction from the + // RFC and |context| is ignored. The call will fail with an error if + // the socket is not connected or the SSL implementation does not + // support the operation. + virtual int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) = 0; + + // Stores the the tls-unique channel binding (see RFC 5929) in |*out|. + virtual int GetTLSUniqueChannelBinding(std::string* out) = 0; +}; + +} // namespace net + +#endif // NET_SOCKET_SSL_SOCKET_H_ diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc new file mode 100644 index 00000000000..c85c671800d --- /dev/null +++ b/chromium/net/socket/stream_listen_socket.cc @@ -0,0 +1,308 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/stream_listen_socket.h" + +#if defined(OS_WIN) +// winsock2.h must be included first in order to ensure it is included before +// windows.h. +#include <winsock2.h> +#elif defined(OS_POSIX) +#include <arpa/inet.h> +#include <errno.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <sys/types.h> +#include "net/base/net_errors.h" +#endif + +#include "base/logging.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/posix/eintr_wrapper.h" +#include "base/sys_byteorder.h" +#include "base/threading/platform_thread.h" +#include "build/build_config.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" + +using std::string; + +#if defined(OS_WIN) +typedef int socklen_t; +#endif // defined(OS_WIN) + +namespace net { + +namespace { + +const int kReadBufSize = 4096; + +} // namespace + +#if defined(OS_WIN) +const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET; +const int StreamListenSocket::kSocketError = SOCKET_ERROR; +#elif defined(OS_POSIX) +const SocketDescriptor StreamListenSocket::kInvalidSocket = -1; +const int StreamListenSocket::kSocketError = -1; +#endif + +StreamListenSocket::StreamListenSocket(SocketDescriptor s, + StreamListenSocket::Delegate* del) + : socket_delegate_(del), + socket_(s), + reads_paused_(false), + has_pending_reads_(false) { +#if defined(OS_WIN) + socket_event_ = WSACreateEvent(); + // TODO(ibrar): error handling in case of socket_event_ == WSA_INVALID_EVENT. + WatchSocket(NOT_WAITING); +#elif defined(OS_POSIX) + wait_state_ = NOT_WAITING; +#endif +} + +StreamListenSocket::~StreamListenSocket() { +#if defined(OS_WIN) + if (socket_event_) { + WSACloseEvent(socket_event_); + socket_event_ = WSA_INVALID_EVENT; + } +#endif + CloseSocket(socket_); +} + +void StreamListenSocket::Send(const char* bytes, int len, + bool append_linefeed) { + SendInternal(bytes, len); + if (append_linefeed) + SendInternal("\r\n", 2); +} + +void StreamListenSocket::Send(const string& str, bool append_linefeed) { + Send(str.data(), static_cast<int>(str.length()), append_linefeed); +} + +int StreamListenSocket::GetLocalAddress(IPEndPoint* address) { + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len)) { +#if defined(OS_WIN) + int err = WSAGetLastError(); +#else + int err = errno; +#endif + return MapSystemError(err); + } + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_FAILED; + return OK; +} + +SocketDescriptor StreamListenSocket::AcceptSocket() { + SocketDescriptor conn = HANDLE_EINTR(accept(socket_, NULL, NULL)); + if (conn == kInvalidSocket) + LOG(ERROR) << "Error accepting connection."; + else + SetNonBlocking(conn); + return conn; +} + +void StreamListenSocket::SendInternal(const char* bytes, int len) { + char* send_buf = const_cast<char *>(bytes); + int len_left = len; + while (true) { + int sent = HANDLE_EINTR(send(socket_, send_buf, len_left, 0)); + if (sent == len_left) { // A shortcut to avoid extraneous checks. + break; + } + if (sent == kSocketError) { +#if defined(OS_WIN) + if (WSAGetLastError() != WSAEWOULDBLOCK) { + LOG(ERROR) << "send failed: WSAGetLastError()==" << WSAGetLastError(); +#elif defined(OS_POSIX) + if (errno != EWOULDBLOCK && errno != EAGAIN) { + LOG(ERROR) << "send failed: errno==" << errno; +#endif + break; + } + // Otherwise we would block, and now we have to wait for a retry. + // Fall through to PlatformThread::YieldCurrentThread() + } else { + // sent != len_left according to the shortcut above. + // Shift the buffer start and send the remainder after a short while. + send_buf += sent; + len_left -= sent; + } + base::PlatformThread::YieldCurrentThread(); + } +} + +void StreamListenSocket::Listen() { + int backlog = 10; // TODO(erikkay): maybe don't allow any backlog? + if (listen(socket_, backlog) == -1) { + // TODO(erikkay): error handling. + LOG(ERROR) << "Could not listen on socket."; + return; + } +#if defined(OS_POSIX) + WatchSocket(WAITING_ACCEPT); +#endif +} + +void StreamListenSocket::Read() { + char buf[kReadBufSize + 1]; // +1 for null termination. + int len; + do { + len = HANDLE_EINTR(recv(socket_, buf, kReadBufSize, 0)); + if (len == kSocketError) { +#if defined(OS_WIN) + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { +#elif defined(OS_POSIX) + if (errno == EWOULDBLOCK || errno == EAGAIN) { +#endif + break; + } else { + // TODO(ibrar): some error handling required here. + break; + } + } else if (len == 0) { + // In Windows, Close() is called by OnObjectSignaled. In POSIX, we need + // to call it here. +#if defined(OS_POSIX) + Close(); +#endif + } else { + // TODO(ibrar): maybe change DidRead to take a length instead. + DCHECK_GT(len, 0); + DCHECK_LE(len, kReadBufSize); + buf[len] = 0; // Already create a buffer with +1 length. + socket_delegate_->DidRead(this, buf, len); + } + } while (len == kReadBufSize); +} + +void StreamListenSocket::Close() { +#if defined(OS_POSIX) + if (wait_state_ == NOT_WAITING) + return; + wait_state_ = NOT_WAITING; +#endif + UnwatchSocket(); + socket_delegate_->DidClose(this); +} + +void StreamListenSocket::CloseSocket(SocketDescriptor s) { + if (s && s != kInvalidSocket) { + UnwatchSocket(); +#if defined(OS_WIN) + closesocket(s); +#elif defined(OS_POSIX) + close(s); +#endif + } +} + +void StreamListenSocket::WatchSocket(WaitState state) { +#if defined(OS_WIN) + WSAEventSelect(socket_, socket_event_, FD_ACCEPT | FD_CLOSE | FD_READ); + watcher_.StartWatching(socket_event_, this); +#elif defined(OS_POSIX) + // Implicitly calls StartWatchingFileDescriptor(). + base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_READ, &watcher_, this); + wait_state_ = state; +#endif +} + +void StreamListenSocket::UnwatchSocket() { +#if defined(OS_WIN) + watcher_.StopWatching(); +#elif defined(OS_POSIX) + watcher_.StopWatchingFileDescriptor(); +#endif +} + +// TODO(ibrar): We can add these functions into OS dependent files. +#if defined(OS_WIN) +// MessageLoop watcher callback. +void StreamListenSocket::OnObjectSignaled(HANDLE object) { + WSANETWORKEVENTS ev; + if (kSocketError == WSAEnumNetworkEvents(socket_, socket_event_, &ev)) { + // TODO + return; + } + + if (ev.lNetworkEvents & FD_CLOSE) { + Close(); + // Close might have deleted this object. We should return immediately. + return; + } + + // The object was reset by WSAEnumNetworkEvents. Watch for the next signal. + watcher_.StartWatching(object, this); + + if (ev.lNetworkEvents == 0) { + // Occasionally the event is set even though there is no new data. + // The net seems to think that this is ignorable. + return; + } + if (ev.lNetworkEvents & FD_ACCEPT) { + Accept(); + } + if (ev.lNetworkEvents & FD_READ) { + if (reads_paused_) { + has_pending_reads_ = true; + } else { + Read(); + // Read() might call Close() internally and 'this' can be invalid here + return; + } + } +} +#elif defined(OS_POSIX) +void StreamListenSocket::OnFileCanReadWithoutBlocking(int fd) { + switch (wait_state_) { + case WAITING_ACCEPT: + Accept(); + break; + case WAITING_READ: + if (reads_paused_) { + has_pending_reads_ = true; + } else { + Read(); + } + break; + default: + // Close() is called by Read() in the Linux case. + NOTREACHED(); + break; + } +} + +void StreamListenSocket::OnFileCanWriteWithoutBlocking(int fd) { + // MessagePumpLibevent callback, we don't listen for write events + // so we shouldn't ever reach here. + NOTREACHED(); +} + +#endif + +void StreamListenSocket::PauseReads() { + DCHECK(!reads_paused_); + reads_paused_ = true; +} + +void StreamListenSocket::ResumeReads() { + DCHECK(reads_paused_); + reads_paused_ = false; + if (has_pending_reads_) { + has_pending_reads_ = false; + Read(); + } +} + +} // namespace net diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h new file mode 100644 index 00000000000..6f03eefaca2 --- /dev/null +++ b/chromium/net/socket/stream_listen_socket.h @@ -0,0 +1,155 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Stream-based listen socket implementation that handles reading and writing +// to the socket, but does not handle creating the socket nor connecting +// sockets, which are handled by subclasses on creation and in Accept, +// respectively. + +// StreamListenSocket handles IO asynchronously in the specified MessageLoop. +// This class is NOT thread safe. It uses WSAEVENT handles to monitor activity +// in a given MessageLoop. This means that callbacks will happen in that loop's +// thread always and that all other methods (including constructor and +// destructor) should also be called from the same thread. + +#ifndef NET_SOCKET_STREAM_LISTEN_SOCKET_H_ +#define NET_SOCKET_STREAM_LISTEN_SOCKET_H_ + +#include "build/build_config.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif +#include <string> +#if defined(OS_WIN) +#include "base/win/object_watcher.h" +#elif defined(OS_POSIX) +#include "base/message_loop/message_loop.h" +#endif + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "net/base/net_export.h" +#include "net/socket/stream_listen_socket.h" + +#if defined(OS_POSIX) +typedef int SocketDescriptor; +#else +typedef SOCKET SocketDescriptor; +#endif + +namespace net { + +class IPEndPoint; + +class NET_EXPORT StreamListenSocket + : public base::RefCountedThreadSafe<StreamListenSocket>, +#if defined(OS_WIN) + public base::win::ObjectWatcher::Delegate { +#elif defined(OS_POSIX) + public base::MessageLoopForIO::Watcher { +#endif + + public: + // TODO(erikkay): this delegate should really be split into two parts + // to split up the listener from the connected socket. Perhaps this class + // should be split up similarly. + class Delegate { + public: + // |server| is the original listening Socket, connection is the new + // Socket that was created. Ownership of |connection| is transferred + // to the delegate with this call. + virtual void DidAccept(StreamListenSocket* server, + StreamListenSocket* connection) = 0; + virtual void DidRead(StreamListenSocket* connection, + const char* data, + int len) = 0; + virtual void DidClose(StreamListenSocket* sock) = 0; + + protected: + virtual ~Delegate() {} + }; + + // Send data to the socket. + void Send(const char* bytes, int len, bool append_linefeed = false); + void Send(const std::string& str, bool append_linefeed = false); + + // Copies the local address to |address|. Returns a network error code. + int GetLocalAddress(IPEndPoint* address); + + static const SocketDescriptor kInvalidSocket; + static const int kSocketError; + + protected: + enum WaitState { + NOT_WAITING = 0, + WAITING_ACCEPT = 1, + WAITING_READ = 2 + }; + + StreamListenSocket(SocketDescriptor s, Delegate* del); + virtual ~StreamListenSocket(); + + SocketDescriptor AcceptSocket(); + virtual void Accept() = 0; + + void Listen(); + void Read(); + void Close(); + void CloseSocket(SocketDescriptor s); + + // Pass any value in case of Windows, because in Windows + // we are not using state. + void WatchSocket(WaitState state); + void UnwatchSocket(); + + Delegate* const socket_delegate_; + + private: + friend class base::RefCountedThreadSafe<StreamListenSocket>; + friend class TransportClientSocketTest; + + void SendInternal(const char* bytes, int len); + +#if defined(OS_WIN) + // ObjectWatcher delegate. + virtual void OnObjectSignaled(HANDLE object); + base::win::ObjectWatcher watcher_; + HANDLE socket_event_; +#elif defined(OS_POSIX) + // Called by MessagePumpLibevent when the socket is ready to do I/O. + virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; + virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; + WaitState wait_state_; + // The socket's libevent wrapper. + base::MessageLoopForIO::FileDescriptorWatcher watcher_; +#endif + + // NOTE: This is for unit test use only! + // Pause/Resume calling Read(). Note that ResumeReads() will also call + // Read() if there is anything to read. + void PauseReads(); + void ResumeReads(); + + const SocketDescriptor socket_; + bool reads_paused_; + bool has_pending_reads_; + + DISALLOW_COPY_AND_ASSIGN(StreamListenSocket); +}; + +// Abstract factory that must be subclassed for each subclass of +// StreamListenSocket. +class NET_EXPORT StreamListenSocketFactory { + public: + virtual ~StreamListenSocketFactory() {} + + // Returns a new instance of StreamListenSocket or NULL if an error occurred. + virtual scoped_refptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const = 0; +}; + +} // namespace net + +#endif // NET_SOCKET_STREAM_LISTEN_SOCKET_H_ diff --git a/chromium/net/socket/stream_socket.cc b/chromium/net/socket/stream_socket.cc new file mode 100644 index 00000000000..fb194f64625 --- /dev/null +++ b/chromium/net/socket/stream_socket.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/stream_socket.h" + +#include "base/metrics/field_trial.h" +#include "base/metrics/histogram.h" +#include "base/strings/string_number_conversions.h" +#include "base/values.h" + +namespace net { + +StreamSocket::UseHistory::UseHistory() + : was_ever_connected_(false), + was_used_to_convey_data_(false), + omnibox_speculation_(false), + subresource_speculation_(false) { +} + +StreamSocket::UseHistory::~UseHistory() { + EmitPreconnectionHistograms(); +} + +void StreamSocket::UseHistory::Reset() { + EmitPreconnectionHistograms(); + was_ever_connected_ = false; + was_used_to_convey_data_ = false; + // omnibox_speculation_ and subresource_speculation_ values + // are intentionally preserved. +} + +void StreamSocket::UseHistory::set_was_ever_connected() { + DCHECK(!was_used_to_convey_data_); + was_ever_connected_ = true; +} + +void StreamSocket::UseHistory::set_was_used_to_convey_data() { + DCHECK(was_ever_connected_); + was_used_to_convey_data_ = true; +} + + +void StreamSocket::UseHistory::set_subresource_speculation() { + DCHECK(was_ever_connected_); + // TODO(jar): We should transition to marking a socket (or stream) at + // construction time as being created for speculative reasons. This current + // approach of trying to track use of a socket to convey data can make + // mistakes when other sockets (such as ones sitting in the pool for a long + // time) are issued. Unused sockets can be left over when a when a set of + // connections to a host are made, and one is "unlucky" and takes so long to + // complete a connection, that another socket is used, and recycled before a + // second connection comes available. Similarly, re-try connections can leave + // an original (slow to connect socket) in the pool, and that can be issued + // to a speculative requester. In any cases such old sockets will fail when an + // attempt is made to used them!... and then it will look like a speculative + // socket was discarded without any user!?!?! + if (was_used_to_convey_data_) + return; + subresource_speculation_ = true; +} + +void StreamSocket::UseHistory::set_omnibox_speculation() { + DCHECK(was_ever_connected_); + if (was_used_to_convey_data_) + return; + omnibox_speculation_ = true; +} + +bool StreamSocket::UseHistory::was_used_to_convey_data() const { + DCHECK(!was_used_to_convey_data_ || was_ever_connected_); + return was_used_to_convey_data_; +} + +void StreamSocket::UseHistory::EmitPreconnectionHistograms() const { + DCHECK(!subresource_speculation_ || !omnibox_speculation_); + // 0 ==> non-speculative, never connected. + // 1 ==> non-speculative never used (but connected). + // 2 ==> non-speculative and used. + // 3 ==> omnibox_speculative never connected. + // 4 ==> omnibox_speculative never used (but connected). + // 5 ==> omnibox_speculative and used. + // 6 ==> subresource_speculative never connected. + // 7 ==> subresource_speculative never used (but connected). + // 8 ==> subresource_speculative and used. + int result; + if (was_used_to_convey_data_) + result = 2; + else if (was_ever_connected_) + result = 1; + else + result = 0; // Never used, and not really connected. + + if (omnibox_speculation_) + result += 3; + else if (subresource_speculation_) + result += 6; + UMA_HISTOGRAM_ENUMERATION("Net.PreconnectUtilization2", result, 9); +} + +} // namespace net diff --git a/chromium/net/socket/stream_socket.h b/chromium/net/socket/stream_socket.h new file mode 100644 index 00000000000..38eec86dd92 --- /dev/null +++ b/chromium/net/socket/stream_socket.h @@ -0,0 +1,138 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_STREAM_SOCKET_H_ +#define NET_SOCKET_STREAM_SOCKET_H_ + +#include "net/base/net_log.h" +#include "net/socket/next_proto.h" +#include "net/socket/socket.h" + +namespace net { + +class AddressList; +class IPEndPoint; +class SSLInfo; + +class NET_EXPORT_PRIVATE StreamSocket : public Socket { + public: + virtual ~StreamSocket() {} + + // Called to establish a connection. Returns OK if the connection could be + // established synchronously. Otherwise, ERR_IO_PENDING is returned and the + // given callback will run asynchronously when the connection is established + // or when an error occurs. The result is some other error code if the + // connection could not be established. + // + // The socket's Read and Write methods may not be called until Connect + // succeeds. + // + // It is valid to call Connect on an already connected socket, in which case + // OK is simply returned. + // + // Connect may also be called again after a call to the Disconnect method. + // + virtual int Connect(const CompletionCallback& callback) = 0; + + // Called to disconnect a socket. Does nothing if the socket is already + // disconnected. After calling Disconnect it is possible to call Connect + // again to establish a new connection. + // + // If IO (Connect, Read, or Write) is pending when the socket is + // disconnected, the pending IO is cancelled, and the completion callback + // will not be called. + virtual void Disconnect() = 0; + + // Called to test if the connection is still alive. Returns false if a + // connection wasn't established or the connection is dead. + virtual bool IsConnected() const = 0; + + // Called to test if the connection is still alive and idle. Returns false + // if a connection wasn't established, the connection is dead, or some data + // have been received. + virtual bool IsConnectedAndIdle() const = 0; + + // Copies the peer address to |address| and returns a network error code. + // ERR_SOCKET_NOT_CONNECTED will be returned if the socket is not connected. + virtual int GetPeerAddress(IPEndPoint* address) const = 0; + + // Copies the local address to |address| and returns a network error code. + // ERR_SOCKET_NOT_CONNECTED will be returned if the socket is not bound. + virtual int GetLocalAddress(IPEndPoint* address) const = 0; + + // Gets the NetLog for this socket. + virtual const BoundNetLog& NetLog() const = 0; + + // Set the annotation to indicate this socket was created for speculative + // reasons. This call is generally forwarded to a basic TCPClientSocket*, + // where a UseHistory can be updated. + virtual void SetSubresourceSpeculation() = 0; + virtual void SetOmniboxSpeculation() = 0; + + // Returns true if the underlying transport socket ever had any reads or + // writes. StreamSockets layered on top of transport sockets should forward + // this call to the transport socket. + virtual bool WasEverUsed() const = 0; + + // Returns true if the underlying transport socket is using TCP FastOpen. + // TCP FastOpen is an experiment with sending data in the TCP SYN packet. + virtual bool UsingTCPFastOpen() const = 0; + + // Returns true if NPN was negotiated during the connection of this socket. + virtual bool WasNpnNegotiated() const = 0; + + // Returns the protocol negotiated via NPN for this socket, or + // kProtoUnknown will be returned if NPN is not applicable. + virtual NextProto GetNegotiatedProtocol() const = 0; + + // Gets the SSL connection information of the socket. Returns false if + // SSL was not used by this socket. + virtual bool GetSSLInfo(SSLInfo* ssl_info) = 0; + + protected: + // The following class is only used to gather statistics about the history of + // a socket. It is only instantiated and used in basic sockets, such as + // TCPClientSocket* instances. Other classes that are derived from + // StreamSocket should forward any potential settings to their underlying + // transport sockets. + class UseHistory { + public: + UseHistory(); + ~UseHistory(); + + // Resets the state of UseHistory and emits histograms for the + // current state. + void Reset(); + + void set_was_ever_connected(); + void set_was_used_to_convey_data(); + + // The next two setters only have any impact if the socket has not yet been + // used to transmit data. If called later, we assume that the socket was + // reused from the pool, and was NOT constructed to service a speculative + // request. + void set_subresource_speculation(); + void set_omnibox_speculation(); + + bool was_used_to_convey_data() const; + + private: + // Summarize the statistics for this socket. + void EmitPreconnectionHistograms() const; + // Indicate if this was ever connected. + bool was_ever_connected_; + // Indicate if this socket was ever used to transmit or receive data. + bool was_used_to_convey_data_; + + // Indicate if this socket was first created for speculative use, and + // identify the motivation. + bool omnibox_speculation_; + bool subresource_speculation_; + DISALLOW_COPY_AND_ASSIGN(UseHistory); + }; +}; + +} // namespace net + +#endif // NET_SOCKET_STREAM_SOCKET_H_ diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc new file mode 100644 index 00000000000..dbd21056f39 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_client_socket.h" + +#include "base/file_util.h" +#include "base/files/file_path.h" + +namespace net { + +namespace { + +#if defined(OS_LINUX) + +// Checks to see if the system supports TCP FastOpen. Notably, it requires +// kernel support. Additionally, this checks system configuration to ensure that +// it's enabled. +bool SystemSupportsTCPFastOpen() { + static const base::FilePath::CharType kTCPFastOpenProcFilePath[] = + "/proc/sys/net/ipv4/tcp_fastopen"; + std::string system_enabled_tcp_fastopen; + if (!file_util::ReadFileToString( + base::FilePath(kTCPFastOpenProcFilePath), + &system_enabled_tcp_fastopen)) { + return false; + } + + // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225 + // TFO_CLIENT_ENABLE is the LSB + if (system_enabled_tcp_fastopen.empty() || + (system_enabled_tcp_fastopen[0] & 0x1) == 0) { + return false; + } + + return true; +} + +#else + +bool SystemSupportsTCPFastOpen() { + return false; +} + +#endif + +} + +static bool g_tcp_fastopen_enabled = false; + +void SetTCPFastOpenEnabled(bool value) { + g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen(); +} + +bool IsTCPFastOpenEnabled() { + return g_tcp_fastopen_enabled; +} + +} // namespace net diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h new file mode 100644 index 00000000000..8a2c0cd73f0 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket.h @@ -0,0 +1,35 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_H_ +#define NET_SOCKET_TCP_CLIENT_SOCKET_H_ + +#include "build/build_config.h" +#include "net/base/net_export.h" + +#if defined(OS_WIN) +#include "net/socket/tcp_client_socket_win.h" +#elif defined(OS_POSIX) +#include "net/socket/tcp_client_socket_libevent.h" +#endif + +namespace net { + +// A client socket that uses TCP as the transport layer. +#if defined(OS_WIN) +typedef TCPClientSocketWin TCPClientSocket; +#elif defined(OS_POSIX) +typedef TCPClientSocketLibevent TCPClientSocket; +#endif + +// Enable/disable experimental TCP FastOpen option. +// Not thread safe. Must be called during initialization/startup only. +NET_EXPORT void SetTCPFastOpenEnabled(bool value); + +// Check if the TCP FastOpen option is enabled. +bool IsTCPFastOpenEnabled(); + +} // namespace net + +#endif // NET_SOCKET_TCP_CLIENT_SOCKET_H_ diff --git a/chromium/net/socket/tcp_client_socket_libevent.cc b/chromium/net/socket/tcp_client_socket_libevent.cc new file mode 100644 index 00000000000..2f7e4b4b255 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket_libevent.cc @@ -0,0 +1,830 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_client_socket.h" + +#include <errno.h> +#include <fcntl.h> +#include <netdb.h> +#include <sys/socket.h> +#include <netinet/tcp.h> +#if defined(OS_POSIX) +#include <netinet/in.h> +#endif + +#include "base/logging.h" +#include "base/message_loop/message_loop.h" +#include "base/metrics/histogram.h" +#include "base/metrics/stats_counters.h" +#include "base/posix/eintr_wrapper.h" +#include "base/strings/string_util.h" +#include "net/base/connection_type_histograms.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_util.h" +#include "net/base/network_change_notifier.h" +#include "net/socket/socket_net_log_params.h" + +// If we don't have a definition for TCPI_OPT_SYN_DATA, create one. +#ifndef TCPI_OPT_SYN_DATA +#define TCPI_OPT_SYN_DATA 32 +#endif + +namespace net { + +namespace { + +const int kInvalidSocket = -1; +const int kTCPKeepAliveSeconds = 45; + +// SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets +// will wait up to 200ms for more data to complete a packet before transmitting. +// After calling this function, the kernel will not wait. See TCP_NODELAY in +// `man 7 tcp`. +bool SetTCPNoDelay(int fd, bool no_delay) { + int on = no_delay ? 1 : 0; + int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, + sizeof(on)); + return error == 0; +} + +// SetTCPKeepAlive sets SO_KEEPALIVE. +bool SetTCPKeepAlive(int fd, bool enable, int delay) { + int on = enable ? 1 : 0; + if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on))) { + PLOG(ERROR) << "Failed to set SO_KEEPALIVE on fd: " << fd; + return false; + } +#if defined(OS_LINUX) || defined(OS_ANDROID) + // Set seconds until first TCP keep alive. + if (setsockopt(fd, SOL_TCP, TCP_KEEPIDLE, &delay, sizeof(delay))) { + PLOG(ERROR) << "Failed to set TCP_KEEPIDLE on fd: " << fd; + return false; + } + // Set seconds between TCP keep alives. + if (setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &delay, sizeof(delay))) { + PLOG(ERROR) << "Failed to set TCP_KEEPINTVL on fd: " << fd; + return false; + } +#endif + return true; +} + +// Sets socket parameters. Returns the OS error code (or 0 on +// success). +int SetupSocket(int socket) { + if (SetNonBlocking(socket)) + return errno; + + // This mirrors the behaviour on Windows. See the comment in + // tcp_client_socket_win.cc after searching for "NODELAY". + SetTCPNoDelay(socket, true); // If SetTCPNoDelay fails, we don't care. + SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds); + + return 0; +} + +// Creates a new socket and sets default parameters for it. Returns +// the OS error code (or 0 on success). +int CreateSocket(int family, int* socket) { + *socket = ::socket(family, SOCK_STREAM, IPPROTO_TCP); + if (*socket == kInvalidSocket) + return errno; + int error = SetupSocket(*socket); + if (error) { + if (HANDLE_EINTR(close(*socket)) < 0) + PLOG(ERROR) << "close"; + *socket = kInvalidSocket; + return error; + } + return 0; +} + +int MapConnectError(int os_error) { + switch (os_error) { + case EACCES: + return ERR_NETWORK_ACCESS_DENIED; + case ETIMEDOUT: + return ERR_CONNECTION_TIMED_OUT; + default: { + int net_error = MapSystemError(os_error); + if (net_error == ERR_FAILED) + return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED. + + // Give a more specific error when the user is offline. + if (net_error == ERR_ADDRESS_UNREACHABLE && + NetworkChangeNotifier::IsOffline()) { + return ERR_INTERNET_DISCONNECTED; + } + return net_error; + } + } +} + +} // namespace + +//----------------------------------------------------------------------------- + +TCPClientSocketLibevent::TCPClientSocketLibevent( + const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(kInvalidSocket), + bound_socket_(kInvalidSocket), + addresses_(addresses), + current_address_index_(-1), + read_watcher_(this), + write_watcher_(this), + next_connect_state_(CONNECT_STATE_NONE), + connect_os_error_(0), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + previously_disconnected_(false), + use_tcp_fastopen_(IsTCPFastOpenEnabled()), + tcp_fastopen_connected_(false), + fast_open_status_(FAST_OPEN_STATUS_UNKNOWN) { + net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, + source.ToEventParametersCallback()); +} + +TCPClientSocketLibevent::~TCPClientSocketLibevent() { + Disconnect(); + net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); + if (tcp_fastopen_connected_) { + UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection", + fast_open_status_, FAST_OPEN_MAX_VALUE); + } +} + +int TCPClientSocketLibevent::AdoptSocket(int socket) { + DCHECK_EQ(socket_, kInvalidSocket); + + int error = SetupSocket(socket); + if (error) + return MapSystemError(error); + + socket_ = socket; + + // This is to make GetPeerAddress() work. It's up to the caller ensure + // that |address_| contains a reasonable address for this + // socket. (i.e. at least match IPv4 vs IPv6!). + current_address_index_ = 0; + use_history_.set_was_ever_connected(); + + return OK; +} + +int TCPClientSocketLibevent::Bind(const IPEndPoint& address) { + if (current_address_index_ >= 0 || bind_address_.get()) { + // Cannot bind the socket if we are already bound connected or + // connecting. + return ERR_UNEXPECTED; + } + + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + + // Create |bound_socket_| and try to bind it to |address|. + int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_); + if (error) + return MapSystemError(error); + + if (HANDLE_EINTR(bind(bound_socket_, storage.addr, storage.addr_len))) { + error = errno; + if (HANDLE_EINTR(close(bound_socket_)) < 0) + PLOG(ERROR) << "close"; + bound_socket_ = kInvalidSocket; + return MapSystemError(error); + } + + bind_address_.reset(new IPEndPoint(address)); + + return 0; +} + +int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + + // If already connected, then just return OK. + if (socket_ != kInvalidSocket) + return OK; + + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + DCHECK(!waiting_connect()); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + addresses_.CreateNetLogCallback()); + + // We will try to connect to each address in addresses_. Start with the + // first one in the list. + next_connect_state_ = CONNECT_STATE_CONNECT; + current_address_index_ = 0; + + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); + write_callback_ = callback; + } else { + LogConnectCompletion(rv); + } + + return rv; +} + +int TCPClientSocketLibevent::DoConnectLoop(int result) { + DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); + + int rv = result; + do { + ConnectState state = next_connect_state_; + next_connect_state_ = CONNECT_STATE_NONE; + switch (state) { + case CONNECT_STATE_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoConnect(); + break; + case CONNECT_STATE_CONNECT_COMPLETE: + rv = DoConnectComplete(rv); + break; + default: + LOG(DFATAL) << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); + + return rv; +} + +int TCPClientSocketLibevent::DoConnect() { + DCHECK_GE(current_address_index_, 0); + DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + DCHECK_EQ(0, connect_os_error_); + + const IPEndPoint& endpoint = addresses_[current_address_index_]; + + if (previously_disconnected_) { + use_history_.Reset(); + previously_disconnected_ = false; + } + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(&endpoint)); + + next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; + + if (bound_socket_ != kInvalidSocket) { + DCHECK(bind_address_.get()); + socket_ = bound_socket_; + bound_socket_ = kInvalidSocket; + } else { + // Create a non-blocking socket. + connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_); + if (connect_os_error_) + return MapSystemError(connect_os_error_); + + if (bind_address_.get()) { + SockaddrStorage storage; + if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + if (HANDLE_EINTR(bind(socket_, storage.addr, storage.addr_len))) + return MapSystemError(errno); + } + } + + // Connect the socket. + if (!use_tcp_fastopen_) { + SockaddrStorage storage; + if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + + if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) { + // Connected without waiting! + return OK; + } + } else { + // With TCP FastOpen, we pretend that the socket is connected. + DCHECK(!tcp_fastopen_connected_); + return OK; + } + + // Check if the connect() failed synchronously. + connect_os_error_ = errno; + if (connect_os_error_ != EINPROGRESS) + return MapConnectError(connect_os_error_); + + // Otherwise the connect() is going to complete asynchronously, so watch + // for its completion. + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + connect_os_error_ = errno; + DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_; + return MapSystemError(connect_os_error_); + } + + return ERR_IO_PENDING; +} + +int TCPClientSocketLibevent::DoConnectComplete(int result) { + // Log the end of this attempt (and any OS error it threw). + int os_error = connect_os_error_; + connect_os_error_ = 0; + if (result != OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + NetLog::IntegerCallback("os_error", os_error)); + } else { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); + } + + if (result == OK) { + write_socket_watcher_.StopWatchingFileDescriptor(); + use_history_.set_was_ever_connected(); + return OK; // Done! + } + + // Close whatever partially connected socket we currently have. + DoDisconnect(); + + // Try to fall back to the next address in the list. + if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { + next_connect_state_ = CONNECT_STATE_CONNECT; + ++current_address_index_; + return OK; + } + + // Otherwise there is nothing to fall back to, so give up. + return result; +} + +void TCPClientSocketLibevent::Disconnect() { + DCHECK(CalledOnValidThread()); + + DoDisconnect(); + current_address_index_ = -1; + bind_address_.reset(); +} + +void TCPClientSocketLibevent::DoDisconnect() { + if (socket_ == kInvalidSocket) + return; + + bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = write_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + if (HANDLE_EINTR(close(socket_)) < 0) + PLOG(ERROR) << "close"; + socket_ = kInvalidSocket; + previously_disconnected_ = true; +} + +bool TCPClientSocketLibevent::IsConnected() const { + DCHECK(CalledOnValidThread()); + + if (socket_ == kInvalidSocket || waiting_connect()) + return false; + + if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { + // With TCP FastOpen, we pretend that the socket is connected. + // This allows GetPeerAddress() to return current_ai_ as the peer + // address. Since we don't fail over to the next address if + // sendto() fails, current_ai_ is the only possible peer address. + CHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + return true; + } + + // Check if connection is alive. + char c; + int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK)); + if (rv == 0) + return false; + if (rv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) + return false; + + return true; +} + +bool TCPClientSocketLibevent::IsConnectedAndIdle() const { + DCHECK(CalledOnValidThread()); + + if (socket_ == kInvalidSocket || waiting_connect()) + return false; + + // TODO(wtc): should we also handle the TCP FastOpen case here, + // as we do in IsConnected()? + + // Check if connection is alive and we haven't received any data + // unexpectedly. + char c; + int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK)); + if (rv >= 0) + return false; + if (errno != EAGAIN && errno != EWOULDBLOCK) + return false; + + return true; +} + +int TCPClientSocketLibevent::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_); + DCHECK(!waiting_connect()); + DCHECK(read_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_GT(buf_len, 0); + + int nread = HANDLE_EINTR(read(socket_, buf->data(), buf_len)); + if (nread >= 0) { + base::StatsCounter read_bytes("tcp.read_bytes"); + read_bytes.Add(nread); + if (nread > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread, + buf->data()); + RecordFastOpenStatus(); + return nread; + } + if (errno != EAGAIN && errno != EWOULDBLOCK) { + int net_error = MapSystemError(errno); + net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(net_error, errno)); + return net_error; + } + + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_READ, + &read_socket_watcher_, &read_watcher_)) { + DVLOG(1) << "WatchFileDescriptor failed on read, errno " << errno; + return MapSystemError(errno); + } + + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + return ERR_IO_PENDING; +} + +int TCPClientSocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_); + DCHECK(!waiting_connect()); + DCHECK(write_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_GT(buf_len, 0); + + int nwrite = InternalWrite(buf, buf_len); + if (nwrite >= 0) { + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(nwrite); + if (nwrite > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite, + buf->data()); + return nwrite; + } + if (errno != EAGAIN && errno != EWOULDBLOCK) { + int net_error = MapSystemError(errno); + net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, + CreateNetLogSocketErrorCallback(net_error, errno)); + return net_error; + } + + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno; + return MapSystemError(errno); + } + + write_buf_ = buf; + write_buf_len_ = buf_len; + write_callback_ = callback; + return ERR_IO_PENDING; +} + +int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { + int nwrite; + if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { + SockaddrStorage storage; + if (!addresses_[current_address_index_].ToSockAddr(storage.addr, + &storage.addr_len)) { + errno = EINVAL; + return -1; + } + + int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. +#if defined(OS_LINUX) + // sendto() will fail with EPIPE when the system doesn't support TCP Fast + // Open. Theoretically that shouldn't happen since the caller should check + // for system support on startup, but users may dynamically disable TCP Fast + // Open via sysctl. + flags |= MSG_NOSIGNAL; +#endif // defined(OS_LINUX) + nwrite = HANDLE_EINTR(sendto(socket_, + buf->data(), + buf_len, + flags, + storage.addr, + storage.addr_len)); + tcp_fastopen_connected_ = true; + + if (nwrite < 0) { + DCHECK_NE(EPIPE, errno); + + // If errno == EINPROGRESS, that means the kernel didn't have a cookie + // and would block. The kernel is internally doing a connect() though. + // Remap EINPROGRESS to EAGAIN so we treat this the same as our other + // asynchronous cases. Note that the user buffer has not been copied to + // kernel space. + if (errno == EINPROGRESS) { + errno = EAGAIN; + fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN; + } else { + fast_open_status_ = FAST_OPEN_ERROR; + } + } else { + fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN; + } + } else { + nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len)); + } + return nwrite; +} + +bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) { + DCHECK(CalledOnValidThread()); + int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF, + reinterpret_cast<const char*>(&size), + sizeof(size)); + DCHECK(!rv) << "Could not set socket receive buffer size: " << errno; + return rv == 0; +} + +bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) { + DCHECK(CalledOnValidThread()); + int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF, + reinterpret_cast<const char*>(&size), + sizeof(size)); + DCHECK(!rv) << "Could not set socket send buffer size: " << errno; + return rv == 0; +} + +bool TCPClientSocketLibevent::SetKeepAlive(bool enable, int delay) { + int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_; + return SetTCPKeepAlive(socket, enable, delay); +} + +bool TCPClientSocketLibevent::SetNoDelay(bool no_delay) { + int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_; + return SetTCPNoDelay(socket, no_delay); +} + +void TCPClientSocketLibevent::ReadWatcher::OnFileCanReadWithoutBlocking(int) { + socket_->RecordFastOpenStatus(); + if (!socket_->read_callback_.is_null()) + socket_->DidCompleteRead(); +} + +void TCPClientSocketLibevent::WriteWatcher::OnFileCanWriteWithoutBlocking(int) { + if (socket_->waiting_connect()) { + socket_->DidCompleteConnect(); + } else if (!socket_->write_callback_.is_null()) { + socket_->DidCompleteWrite(); + } +} + +void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { + if (net_error == OK) + UpdateConnectionTypeHistograms(CONNECTION_ANY); + + if (net_error != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, net_error); + return; + } + + SockaddrStorage storage; + int rv = getsockname(socket_, storage.addr, &storage.addr_len); + if (rv != 0) { + PLOG(ERROR) << "getsockname() [rv: " << rv << "] error: "; + NOTREACHED(); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, rv); + return; + } + + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT, + CreateNetLogSourceAddressCallback(storage.addr, + storage.addr_len)); +} + +void TCPClientSocketLibevent::DoReadCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(!read_callback_.is_null()); + + // since Run may result in Read being called, clear read_callback_ up front. + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); +} + +void TCPClientSocketLibevent::DoWriteCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(!write_callback_.is_null()); + + // since Run may result in Write being called, clear write_callback_ up front. + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); +} + +void TCPClientSocketLibevent::DidCompleteConnect() { + DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); + + // Get the error that connect() completed with. + int os_error = 0; + socklen_t len = sizeof(os_error); + if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0) + os_error = errno; + + // TODO(eroman): Is this check really necessary? + if (os_error == EINPROGRESS || os_error == EALREADY) { + NOTREACHED(); // This indicates a bug in libevent or our code. + return; + } + + connect_os_error_ = os_error; + int rv = DoConnectLoop(MapConnectError(os_error)); + if (rv != ERR_IO_PENDING) { + LogConnectCompletion(rv); + DoWriteCallback(rv); + } +} + +void TCPClientSocketLibevent::DidCompleteRead() { + int bytes_transferred; + bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(), + read_buf_len_)); + + int result; + if (bytes_transferred >= 0) { + result = bytes_transferred; + base::StatsCounter read_bytes("tcp.read_bytes"); + read_bytes.Add(bytes_transferred); + if (bytes_transferred > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result, + read_buf_->data()); + } else { + result = MapSystemError(errno); + if (result != ERR_IO_PENDING) { + net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(result, errno)); + } + } + + if (result != ERR_IO_PENDING) { + read_buf_ = NULL; + read_buf_len_ = 0; + bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + DoReadCallback(result); + } +} + +void TCPClientSocketLibevent::DidCompleteWrite() { + int bytes_transferred; + bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(), + write_buf_len_)); + + int result; + if (bytes_transferred >= 0) { + result = bytes_transferred; + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(bytes_transferred); + if (bytes_transferred > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result, + write_buf_->data()); + } else { + result = MapSystemError(errno); + if (result != ERR_IO_PENDING) { + net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, + CreateNetLogSocketErrorCallback(result, errno)); + } + } + + if (result != ERR_IO_PENDING) { + write_buf_ = NULL; + write_buf_len_ = 0; + write_socket_watcher_.StopWatchingFileDescriptor(); + DoWriteCallback(result); + } +} + +int TCPClientSocketLibevent::GetPeerAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = addresses_[current_address_index_]; + return OK; +} + +int TCPClientSocketLibevent::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (socket_ == kInvalidSocket) { + if (bind_address_.get()) { + *address = *bind_address_; + return OK; + } + return ERR_SOCKET_NOT_CONNECTED; + } + + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len)) + return MapSystemError(errno); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_FAILED; + + return OK; +} + +void TCPClientSocketLibevent::RecordFastOpenStatus() { + if (use_tcp_fastopen_ && + (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN || + fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) { + DCHECK_NE(FAST_OPEN_STATUS_UNKNOWN, fast_open_status_); + bool getsockopt_success(false); + bool server_acked_data(false); +#if defined(TCP_INFO) + // Probe to see the if the socket used TCP Fast Open. + tcp_info info; + socklen_t info_len = sizeof(tcp_info); + getsockopt_success = + getsockopt(socket_, IPPROTO_TCP, TCP_INFO, &info, &info_len) == 0 && + info_len == sizeof(tcp_info); + server_acked_data = getsockopt_success && + (info.tcpi_options & TCPI_OPT_SYN_DATA); +#endif + if (getsockopt_success) { + if (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN) { + fast_open_status_ = (server_acked_data ? FAST_OPEN_SYN_DATA_ACK : + FAST_OPEN_SYN_DATA_NACK); + } else { + fast_open_status_ = (server_acked_data ? FAST_OPEN_NO_SYN_DATA_ACK : + FAST_OPEN_NO_SYN_DATA_NACK); + } + } else { + fast_open_status_ = (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ? + FAST_OPEN_SYN_DATA_FAILED : + FAST_OPEN_NO_SYN_DATA_FAILED); + } + } +} + +const BoundNetLog& TCPClientSocketLibevent::NetLog() const { + return net_log_; +} + +void TCPClientSocketLibevent::SetSubresourceSpeculation() { + use_history_.set_subresource_speculation(); +} + +void TCPClientSocketLibevent::SetOmniboxSpeculation() { + use_history_.set_omnibox_speculation(); +} + +bool TCPClientSocketLibevent::WasEverUsed() const { + return use_history_.was_used_to_convey_data(); +} + +bool TCPClientSocketLibevent::UsingTCPFastOpen() const { + return use_tcp_fastopen_; +} + +bool TCPClientSocketLibevent::WasNpnNegotiated() const { + return false; +} + +NextProto TCPClientSocketLibevent::GetNegotiatedProtocol() const { + return kProtoUnknown; +} + +bool TCPClientSocketLibevent::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + +} // namespace net diff --git a/chromium/net/socket/tcp_client_socket_libevent.h b/chromium/net/socket/tcp_client_socket_libevent.h new file mode 100644 index 00000000000..e5a0d8deab4 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket_libevent.h @@ -0,0 +1,256 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ +#define NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/non_thread_safe.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" +#include "net/base/net_log.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class BoundNetLog; + +// A client socket that uses TCP as the transport layer. +class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, + public base::NonThreadSafe { + public: + // The IP address(es) and port number to connect to. The TCP socket will try + // each IP address in the list until it succeeds in establishing a + // connection. + TCPClientSocketLibevent(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source); + + virtual ~TCPClientSocketLibevent(); + + // AdoptSocket causes the given, connected socket to be adopted as a TCP + // socket. This object must not be connected. This object takes ownership of + // the given socket and then acts as if Connect() had been called. This + // function is used by TCPServerSocket() to adopt accepted connections + // and for testing. + int AdoptSocket(int socket); + + // Binds the socket to a local IP address and port. + int Bind(const IPEndPoint& address); + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + virtual bool SetKeepAlive(bool enable, int delay); + virtual bool SetNoDelay(bool no_delay); + + private: + // State machine for connecting the socket. + enum ConnectState { + CONNECT_STATE_CONNECT, + CONNECT_STATE_CONNECT_COMPLETE, + CONNECT_STATE_NONE, + }; + + // States that a fast open socket attempt can result in. + enum FastOpenStatus { + FAST_OPEN_STATUS_UNKNOWN, + + // The initial fast open connect attempted returned synchronously, + // indicating that we had and sent a cookie along with the initial data. + FAST_OPEN_FAST_CONNECT_RETURN, + + // The initial fast open connect attempted returned asynchronously, + // indicating that we did not have a cookie for the server. + FAST_OPEN_SLOW_CONNECT_RETURN, + + // Some other error occurred on connection, so we couldn't tell if + // fast open would have worked. + FAST_OPEN_ERROR, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // had acked the data we sent. + FAST_OPEN_SYN_DATA_ACK, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // had nacked the data we sent. + FAST_OPEN_SYN_DATA_NACK, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the + // socket was using fast open failed. + FAST_OPEN_SYN_DATA_FAILED, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and we later confirmed that the server had acked initial data. This + // should never happen (we didn't send data, so it shouldn't have + // been acked). + FAST_OPEN_NO_SYN_DATA_ACK, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and we later discovered that the server had nacked initial data. This + // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN. + FAST_OPEN_NO_SYN_DATA_NACK, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and our later probe for ack/nack state failed. + FAST_OPEN_NO_SYN_DATA_FAILED, + + FAST_OPEN_MAX_VALUE + }; + + class ReadWatcher : public base::MessageLoopForIO::Watcher { + public: + explicit ReadWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} + + // MessageLoopForIO::Watcher methods + + virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE; + + virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE {} + + private: + TCPClientSocketLibevent* const socket_; + + DISALLOW_COPY_AND_ASSIGN(ReadWatcher); + }; + + class WriteWatcher : public base::MessageLoopForIO::Watcher { + public: + explicit WriteWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} + + // MessageLoopForIO::Watcher implementation. + virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {} + virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE; + + private: + TCPClientSocketLibevent* const socket_; + + DISALLOW_COPY_AND_ASSIGN(WriteWatcher); + }; + + // State machine used by Connect(). + int DoConnectLoop(int result); + int DoConnect(); + int DoConnectComplete(int result); + + // Helper used by Disconnect(), which disconnects minus the logging and + // resetting of current_address_index_. + void DoDisconnect(); + + void DoReadCallback(int rv); + void DoWriteCallback(int rv); + void DidCompleteRead(); + void DidCompleteWrite(); + void DidCompleteConnect(); + + // Returns true if a Connect() is in progress. + bool waiting_connect() const { + return next_connect_state_ != CONNECT_STATE_NONE; + } + + // Helper to add a TCP_CONNECT (end) event to the NetLog. + void LogConnectCompletion(int net_error); + + // Internal function to write to a socket. + int InternalWrite(IOBuffer* buf, int buf_len); + + // Called when the socket is known to be in a connected state. + void RecordFastOpenStatus(); + + int socket_; + + // Local IP address and port we are bound to. Set to NULL if Bind() + // was't called (in that cases OS chooses address/port). + scoped_ptr<IPEndPoint> bind_address_; + + // Stores bound socket between Bind() and Connect() calls. + int bound_socket_; + + // The list of addresses we should try in order to establish a connection. + AddressList addresses_; + + // Where we are in above list. Set to -1 if uninitialized. + int current_address_index_; + + // The socket's libevent wrappers + base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; + base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; + + // The corresponding watchers for reads and writes. + ReadWatcher read_watcher_; + WriteWatcher write_watcher_; + + // The buffer used by OnSocketReady to retry Read requests + scoped_refptr<IOBuffer> read_buf_; + int read_buf_len_; + + // The buffer used by OnSocketReady to retry Write requests + scoped_refptr<IOBuffer> write_buf_; + int write_buf_len_; + + // External callback; called when read is complete. + CompletionCallback read_callback_; + + // External callback; called when write is complete. + CompletionCallback write_callback_; + + // The next state for the Connect() state machine. + ConnectState next_connect_state_; + + // The OS error that CONNECT_STATE_CONNECT last completed with. + int connect_os_error_; + + BoundNetLog net_log_; + + // This socket was previously disconnected and has not been re-connected. + bool previously_disconnected_; + + // Record of connectivity and transmissions, for use in speculative connection + // histograms. + UseHistory use_history_; + + // Enables experimental TCP FastOpen option. + const bool use_tcp_fastopen_; + + // True when TCP FastOpen is in use and we have done the connect. + bool tcp_fastopen_connected_; + + enum FastOpenStatus fast_open_status_; + + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketLibevent); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_client_socket_unittest.cc b/chromium/net/socket/tcp_client_socket_unittest.cc new file mode 100644 index 00000000000..ce0c53559f8 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket_unittest.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file contains some tests for TCPClientSocket. +// transport_client_socket_unittest.cc contans some other tests that +// are common for TCP and other types of sockets. + +#include "net/socket/tcp_client_socket.h" + +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/tcp_server_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +// Try binding a socket to loopback interface and verify that we can +// still connect to a server on the same interface. +TEST(TCPClientSocketTest, BindLoopbackToLoopback) { + IPAddressNumber lo_address; + ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address)); + + TCPServerSocket server(NULL, NetLog::Source()); + ASSERT_EQ(OK, server.Listen(IPEndPoint(lo_address, 0), 1)); + IPEndPoint server_address; + ASSERT_EQ(OK, server.GetLocalAddress(&server_address)); + + TCPClientSocket socket(AddressList(server_address), NULL, NetLog::Source()); + + EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0))); + + IPEndPoint local_address_result; + EXPECT_EQ(OK, socket.GetLocalAddress(&local_address_result)); + EXPECT_EQ(lo_address, local_address_result.address()); + + TestCompletionCallback connect_callback; + EXPECT_EQ(ERR_IO_PENDING, socket.Connect(connect_callback.callback())); + + TestCompletionCallback accept_callback; + scoped_ptr<StreamSocket> accepted_socket; + int result = server.Accept(&accepted_socket, accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + + EXPECT_TRUE(socket.IsConnected()); + socket.Disconnect(); + EXPECT_FALSE(socket.IsConnected()); + EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED, + socket.GetLocalAddress(&local_address_result)); +} + +// Try to bind socket to the loopback interface and connect to an +// external address, verify that connection fails. +TEST(TCPClientSocketTest, BindLoopbackToExternal) { + IPAddressNumber external_ip; + ASSERT_TRUE(ParseIPLiteralToNumber("72.14.213.105", &external_ip)); + TCPClientSocket socket(AddressList::CreateFromIPAddress(external_ip, 80), + NULL, NetLog::Source()); + + IPAddressNumber lo_address; + ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address)); + EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0))); + + TestCompletionCallback connect_callback; + int result = socket.Connect(connect_callback.callback()); + if (result == ERR_IO_PENDING) + result = connect_callback.WaitForResult(); + + // We may get different errors here on different system, but + // connect() is not expected to succeed. + EXPECT_NE(OK, result); +} + +// Bind a socket to the IPv4 loopback interface and try to connect to +// the IPv6 loopback interface, verify that connection fails. +TEST(TCPClientSocketTest, BindLoopbackToIPv6) { + IPAddressNumber ipv6_lo_ip; + ASSERT_TRUE(ParseIPLiteralToNumber("::1", &ipv6_lo_ip)); + TCPServerSocket server(NULL, NetLog::Source()); + int listen_result = server.Listen(IPEndPoint(ipv6_lo_ip, 0), 1); + if (listen_result != OK) { + LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is disabled." + " Skipping the test"; + return; + } + + IPEndPoint server_address; + ASSERT_EQ(OK, server.GetLocalAddress(&server_address)); + TCPClientSocket socket(AddressList(server_address), NULL, NetLog::Source()); + + IPAddressNumber ipv4_lo_ip; + ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &ipv4_lo_ip)); + EXPECT_EQ(OK, socket.Bind(IPEndPoint(ipv4_lo_ip, 0))); + + TestCompletionCallback connect_callback; + int result = socket.Connect(connect_callback.callback()); + if (result == ERR_IO_PENDING) + result = connect_callback.WaitForResult(); + + EXPECT_NE(OK, result); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/tcp_client_socket_win.cc b/chromium/net/socket/tcp_client_socket_win.cc new file mode 100644 index 00000000000..9b0a5b50bf1 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket_win.cc @@ -0,0 +1,1045 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_client_socket_win.h" + +#include <mstcpip.h> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/metrics/stats_counters.h" +#include "base/strings/string_util.h" +#include "base/win/object_watcher.h" +#include "base/win/windows_version.h" +#include "net/base/connection_type_histograms.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_util.h" +#include "net/base/network_change_notifier.h" +#include "net/base/winsock_init.h" +#include "net/base/winsock_util.h" +#include "net/socket/socket_net_log_params.h" + +namespace net { + +namespace { + +const int kTCPKeepAliveSeconds = 45; +bool g_disable_overlapped_reads = false; + +bool SetSocketReceiveBufferSize(SOCKET socket, int32 size) { + int rv = setsockopt(socket, SOL_SOCKET, SO_RCVBUF, + reinterpret_cast<const char*>(&size), sizeof(size)); + DCHECK(!rv) << "Could not set socket receive buffer size: " << GetLastError(); + return rv == 0; +} + +bool SetSocketSendBufferSize(SOCKET socket, int32 size) { + int rv = setsockopt(socket, SOL_SOCKET, SO_SNDBUF, + reinterpret_cast<const char*>(&size), sizeof(size)); + DCHECK(!rv) << "Could not set socket send buffer size: " << GetLastError(); + return rv == 0; +} + +// Disable Nagle. +// The Nagle implementation on windows is governed by RFC 896. The idea +// behind Nagle is to reduce small packets on the network. When Nagle is +// enabled, if a partial packet has been sent, the TCP stack will disallow +// further *partial* packets until an ACK has been received from the other +// side. Good applications should always strive to send as much data as +// possible and avoid partial-packet sends. However, in most real world +// applications, there are edge cases where this does not happen, and two +// partial packets may be sent back to back. For a browser, it is NEVER +// a benefit to delay for an RTT before the second packet is sent. +// +// As a practical example in Chromium today, consider the case of a small +// POST. I have verified this: +// Client writes 649 bytes of header (partial packet #1) +// Client writes 50 bytes of POST data (partial packet #2) +// In the above example, with Nagle, a RTT delay is inserted between these +// two sends due to nagle. RTTs can easily be 100ms or more. The best +// fix is to make sure that for POSTing data, we write as much data as +// possible and minimize partial packets. We will fix that. But disabling +// Nagle also ensure we don't run into this delay in other edge cases. +// See also: +// http://technet.microsoft.com/en-us/library/bb726981.aspx +bool DisableNagle(SOCKET socket, bool disable) { + BOOL val = disable ? TRUE : FALSE; + int rv = setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast<const char*>(&val), + sizeof(val)); + DCHECK(!rv) << "Could not disable nagle"; + return rv == 0; +} + +// Enable TCP Keep-Alive to prevent NAT routers from timing out TCP +// connections. See http://crbug.com/27400 for details. +bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { + int delay = delay_secs * 1000; + struct tcp_keepalive keepalive_vals = { + enable ? 1 : 0, // TCP keep-alive on. + delay, // Delay seconds before sending first TCP keep-alive packet. + delay, // Delay seconds between sending TCP keep-alive packets. + }; + DWORD bytes_returned = 0xABAB; + int rv = WSAIoctl(socket, SIO_KEEPALIVE_VALS, &keepalive_vals, + sizeof(keepalive_vals), NULL, 0, + &bytes_returned, NULL, NULL); + DCHECK(!rv) << "Could not enable TCP Keep-Alive for socket: " << socket + << " [error: " << WSAGetLastError() << "]."; + + // Disregard any failure in disabling nagle or enabling TCP Keep-Alive. + return rv == 0; +} + +// Sets socket parameters. Returns the OS error code (or 0 on +// success). +int SetupSocket(SOCKET socket) { + // Increase the socket buffer sizes from the default sizes for WinXP. In + // performance testing, there is substantial benefit by increasing from 8KB + // to 64KB. + // See also: + // http://support.microsoft.com/kb/823764/EN-US + // On Vista, if we manually set these sizes, Vista turns off its receive + // window auto-tuning feature. + // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx + // Since Vista's auto-tune is better than any static value we can could set, + // only change these on pre-vista machines. + if (base::win::GetVersion() < base::win::VERSION_VISTA) { + const int32 kSocketBufferSize = 64 * 1024; + SetSocketReceiveBufferSize(socket, kSocketBufferSize); + SetSocketSendBufferSize(socket, kSocketBufferSize); + } + + DisableNagle(socket, true); + SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds); + return 0; +} + +// Creates a new socket and sets default parameters for it. Returns +// the OS error code (or 0 on success). +int CreateSocket(int family, SOCKET* socket) { + *socket = CreatePlatformSocket(family, SOCK_STREAM, IPPROTO_TCP); + if (*socket == INVALID_SOCKET) { + int os_error = WSAGetLastError(); + LOG(ERROR) << "CreatePlatformSocket failed: " << os_error; + return os_error; + } + int error = SetupSocket(*socket); + if (error) { + if (closesocket(*socket) < 0) + PLOG(ERROR) << "closesocket"; + *socket = INVALID_SOCKET; + return error; + } + return 0; +} + +int MapConnectError(int os_error) { + switch (os_error) { + // connect fails with WSAEACCES when Windows Firewall blocks the + // connection. + case WSAEACCES: + return ERR_NETWORK_ACCESS_DENIED; + case WSAETIMEDOUT: + return ERR_CONNECTION_TIMED_OUT; + default: { + int net_error = MapSystemError(os_error); + if (net_error == ERR_FAILED) + return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED. + + // Give a more specific error when the user is offline. + if (net_error == ERR_ADDRESS_UNREACHABLE && + NetworkChangeNotifier::IsOffline()) { + return ERR_INTERNET_DISCONNECTED; + } + + return net_error; + } + } +} + +} // namespace + +//----------------------------------------------------------------------------- + +// This class encapsulates all the state that has to be preserved as long as +// there is a network IO operation in progress. If the owner TCPClientSocketWin +// is destroyed while an operation is in progress, the Core is detached and it +// lives until the operation completes and the OS doesn't reference any resource +// declared on this class anymore. +class TCPClientSocketWin::Core : public base::RefCounted<Core> { + public: + explicit Core(TCPClientSocketWin* socket); + + // Start watching for the end of a read or write operation. + void WatchForRead(); + void WatchForWrite(); + + // The TCPClientSocketWin is going away. + void Detach() { socket_ = NULL; } + + // Throttle the read size based on our current slow start state. + // Returns the throttled read size. + int ThrottleReadSize(int size) { + if (slow_start_throttle_ < kMaxSlowStartThrottle) { + size = std::min(size, slow_start_throttle_); + slow_start_throttle_ *= 2; + } + return size; + } + + // The separate OVERLAPPED variables for asynchronous operation. + // |read_overlapped_| is used for both Connect() and Read(). + // |write_overlapped_| is only used for Write(); + OVERLAPPED read_overlapped_; + OVERLAPPED write_overlapped_; + + // The buffers used in Read() and Write(). + scoped_refptr<IOBuffer> read_iobuffer_; + scoped_refptr<IOBuffer> write_iobuffer_; + int read_buffer_length_; + int write_buffer_length_; + + // Remember the state of g_disable_overlapped_reads for the duration of the + // socket based on what it was when the socket was created. + bool disable_overlapped_reads_; + bool non_blocking_reads_initialized_; + + private: + friend class base::RefCounted<Core>; + + class ReadDelegate : public base::win::ObjectWatcher::Delegate { + public: + explicit ReadDelegate(Core* core) : core_(core) {} + virtual ~ReadDelegate() {} + + // base::ObjectWatcher::Delegate methods: + virtual void OnObjectSignaled(HANDLE object); + + private: + Core* const core_; + }; + + class WriteDelegate : public base::win::ObjectWatcher::Delegate { + public: + explicit WriteDelegate(Core* core) : core_(core) {} + virtual ~WriteDelegate() {} + + // base::ObjectWatcher::Delegate methods: + virtual void OnObjectSignaled(HANDLE object); + + private: + Core* const core_; + }; + + ~Core(); + + // The socket that created this object. + TCPClientSocketWin* socket_; + + // |reader_| handles the signals from |read_watcher_|. + ReadDelegate reader_; + // |writer_| handles the signals from |write_watcher_|. + WriteDelegate writer_; + + // |read_watcher_| watches for events from Connect() and Read(). + base::win::ObjectWatcher read_watcher_; + // |write_watcher_| watches for events from Write(); + base::win::ObjectWatcher write_watcher_; + + // When doing reads from the socket, we try to mirror TCP's slow start. + // We do this because otherwise the async IO subsystem artifically delays + // returning data to the application. + static const int kInitialSlowStartThrottle = 1 * 1024; + static const int kMaxSlowStartThrottle = 32 * kInitialSlowStartThrottle; + int slow_start_throttle_; + + DISALLOW_COPY_AND_ASSIGN(Core); +}; + +TCPClientSocketWin::Core::Core( + TCPClientSocketWin* socket) + : read_buffer_length_(0), + write_buffer_length_(0), + disable_overlapped_reads_(g_disable_overlapped_reads), + non_blocking_reads_initialized_(false), + socket_(socket), + reader_(this), + writer_(this), + slow_start_throttle_(kInitialSlowStartThrottle) { + memset(&read_overlapped_, 0, sizeof(read_overlapped_)); + memset(&write_overlapped_, 0, sizeof(write_overlapped_)); + + read_overlapped_.hEvent = WSACreateEvent(); + write_overlapped_.hEvent = WSACreateEvent(); +} + +TCPClientSocketWin::Core::~Core() { + // Make sure the message loop is not watching this object anymore. + read_watcher_.StopWatching(); + write_watcher_.StopWatching(); + + WSACloseEvent(read_overlapped_.hEvent); + memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_)); + WSACloseEvent(write_overlapped_.hEvent); + memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_)); +} + +void TCPClientSocketWin::Core::WatchForRead() { + // We grab an extra reference because there is an IO operation in progress. + // Balanced in ReadDelegate::OnObjectSignaled(). + AddRef(); + read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_); +} + +void TCPClientSocketWin::Core::WatchForWrite() { + // We grab an extra reference because there is an IO operation in progress. + // Balanced in WriteDelegate::OnObjectSignaled(). + AddRef(); + write_watcher_.StartWatching(write_overlapped_.hEvent, &writer_); +} + +void TCPClientSocketWin::Core::ReadDelegate::OnObjectSignaled( + HANDLE object) { + DCHECK_EQ(object, core_->read_overlapped_.hEvent); + if (core_->socket_) { + if (core_->socket_->waiting_connect()) { + core_->socket_->DidCompleteConnect(); + } else if (core_->disable_overlapped_reads_) { + core_->socket_->DidSignalRead(); + } else { + core_->socket_->DidCompleteRead(); + } + } + + core_->Release(); +} + +void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled( + HANDLE object) { + DCHECK_EQ(object, core_->write_overlapped_.hEvent); + if (core_->socket_) + core_->socket_->DidCompleteWrite(); + + core_->Release(); +} + +//----------------------------------------------------------------------------- + +TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(INVALID_SOCKET), + bound_socket_(INVALID_SOCKET), + addresses_(addresses), + current_address_index_(-1), + waiting_read_(false), + waiting_write_(false), + next_connect_state_(CONNECT_STATE_NONE), + connect_os_error_(0), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + previously_disconnected_(false) { + net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, + source.ToEventParametersCallback()); + EnsureWinsockInit(); +} + +TCPClientSocketWin::~TCPClientSocketWin() { + Disconnect(); + net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); +} + +int TCPClientSocketWin::AdoptSocket(SOCKET socket) { + DCHECK_EQ(socket_, INVALID_SOCKET); + + int error = SetupSocket(socket); + if (error) + return MapSystemError(error); + + socket_ = socket; + SetNonBlocking(socket_); + + core_ = new Core(this); + current_address_index_ = 0; + use_history_.set_was_ever_connected(); + + return OK; +} + +int TCPClientSocketWin::Bind(const IPEndPoint& address) { + if (current_address_index_ >= 0 || bind_address_.get()) { + // Cannot bind the socket if we are already connected or connecting. + return ERR_UNEXPECTED; + } + + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + + // Create |bound_socket_| and try to bind it to |address|. + int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_); + if (error) + return MapSystemError(error); + + if (bind(bound_socket_, storage.addr, storage.addr_len)) { + error = errno; + if (closesocket(bound_socket_) < 0) + PLOG(ERROR) << "closesocket"; + bound_socket_ = INVALID_SOCKET; + return MapSystemError(error); + } + + bind_address_.reset(new IPEndPoint(address)); + + return 0; +} + + +int TCPClientSocketWin::Connect(const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + + // If already connected, then just return OK. + if (socket_ != INVALID_SOCKET) + return OK; + + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + addresses_.CreateNetLogCallback()); + + // We will try to connect to each address in addresses_. Start with the + // first one in the list. + next_connect_state_ = CONNECT_STATE_CONNECT; + current_address_index_ = 0; + + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); + // TODO(ajwong): Is setting read_callback_ the right thing to do here?? + read_callback_ = callback; + } else { + LogConnectCompletion(rv); + } + + return rv; +} + +int TCPClientSocketWin::DoConnectLoop(int result) { + DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); + + int rv = result; + do { + ConnectState state = next_connect_state_; + next_connect_state_ = CONNECT_STATE_NONE; + switch (state) { + case CONNECT_STATE_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoConnect(); + break; + case CONNECT_STATE_CONNECT_COMPLETE: + rv = DoConnectComplete(rv); + break; + default: + LOG(DFATAL) << "bad state " << state; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); + + return rv; +} + +int TCPClientSocketWin::DoConnect() { + DCHECK_GE(current_address_index_, 0); + DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + DCHECK_EQ(0, connect_os_error_); + + const IPEndPoint& endpoint = addresses_[current_address_index_]; + + if (previously_disconnected_) { + use_history_.Reset(); + previously_disconnected_ = false; + } + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(&endpoint)); + + next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; + + if (bound_socket_ != INVALID_SOCKET) { + DCHECK(bind_address_.get()); + socket_ = bound_socket_; + bound_socket_ = INVALID_SOCKET; + } else { + connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_); + if (connect_os_error_ != 0) + return MapSystemError(connect_os_error_); + + if (bind_address_.get()) { + SockaddrStorage storage; + if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + if (bind(socket_, storage.addr, storage.addr_len)) + return MapSystemError(errno); + } + } + + DCHECK(!core_); + core_ = new Core(this); + // WSAEventSelect sets the socket to non-blocking mode as a side effect. + // Our connect() and recv() calls require that the socket be non-blocking. + WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT); + + SockaddrStorage storage; + if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + if (!connect(socket_, storage.addr, storage.addr_len)) { + // Connected without waiting! + // + // The MSDN page for connect says: + // With a nonblocking socket, the connection attempt cannot be completed + // immediately. In this case, connect will return SOCKET_ERROR, and + // WSAGetLastError will return WSAEWOULDBLOCK. + // which implies that for a nonblocking socket, connect never returns 0. + // It's not documented whether the event object will be signaled or not + // if connect does return 0. So the code below is essentially dead code + // and we don't know if it's correct. + NOTREACHED(); + + if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) + return OK; + } else { + int os_error = WSAGetLastError(); + if (os_error != WSAEWOULDBLOCK) { + LOG(ERROR) << "connect failed: " << os_error; + connect_os_error_ = os_error; + return MapConnectError(os_error); + } + } + + core_->WatchForRead(); + return ERR_IO_PENDING; +} + +int TCPClientSocketWin::DoConnectComplete(int result) { + // Log the end of this attempt (and any OS error it threw). + int os_error = connect_os_error_; + connect_os_error_ = 0; + if (result != OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + NetLog::IntegerCallback("os_error", os_error)); + } else { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); + } + + if (result == OK) { + use_history_.set_was_ever_connected(); + return OK; // Done! + } + + // Close whatever partially connected socket we currently have. + DoDisconnect(); + + // Try to fall back to the next address in the list. + if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { + next_connect_state_ = CONNECT_STATE_CONNECT; + ++current_address_index_; + return OK; + } + + // Otherwise there is nothing to fall back to, so give up. + return result; +} + +void TCPClientSocketWin::Disconnect() { + DCHECK(CalledOnValidThread()); + + DoDisconnect(); + current_address_index_ = -1; + bind_address_.reset(); +} + +void TCPClientSocketWin::DoDisconnect() { + DCHECK(CalledOnValidThread()); + + if (socket_ == INVALID_SOCKET) + return; + + // Note: don't use CancelIo to cancel pending IO because it doesn't work + // when there is a Winsock layered service provider. + + // In most socket implementations, closing a socket results in a graceful + // connection shutdown, but in Winsock we have to call shutdown explicitly. + // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure" + // at http://msdn.microsoft.com/en-us/library/ms738547.aspx + shutdown(socket_, SD_SEND); + + // This cancels any pending IO. + closesocket(socket_); + socket_ = INVALID_SOCKET; + + if (waiting_connect()) { + // We closed the socket, so this notification will never come. + // From MSDN' WSAEventSelect documentation: + // "Closing a socket with closesocket also cancels the association and + // selection of network events specified in WSAEventSelect for the socket". + core_->Release(); + } + + waiting_read_ = false; + waiting_write_ = false; + + core_->Detach(); + core_ = NULL; + + previously_disconnected_ = true; +} + +bool TCPClientSocketWin::IsConnected() const { + DCHECK(CalledOnValidThread()); + + if (socket_ == INVALID_SOCKET || waiting_connect()) + return false; + + if (waiting_read_) + return true; + + // Check if connection is alive. + char c; + int rv = recv(socket_, &c, 1, MSG_PEEK); + if (rv == 0) + return false; + if (rv == SOCKET_ERROR && WSAGetLastError() != WSAEWOULDBLOCK) + return false; + + return true; +} + +bool TCPClientSocketWin::IsConnectedAndIdle() const { + DCHECK(CalledOnValidThread()); + + if (socket_ == INVALID_SOCKET || waiting_connect()) + return false; + + if (waiting_read_) + return true; + + // Check if connection is alive and we haven't received any data + // unexpectedly. + char c; + int rv = recv(socket_, &c, 1, MSG_PEEK); + if (rv >= 0) + return false; + if (WSAGetLastError() != WSAEWOULDBLOCK) + return false; + + return true; +} + +int TCPClientSocketWin::GetPeerAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = addresses_[current_address_index_]; + return OK; +} + +int TCPClientSocketWin::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (socket_ == INVALID_SOCKET) { + if (bind_address_.get()) { + *address = *bind_address_; + return OK; + } + return ERR_SOCKET_NOT_CONNECTED; + } + + struct sockaddr_storage addr_storage; + socklen_t addr_len = sizeof(addr_storage); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + if (getsockname(socket_, addr, &addr_len)) + return MapSystemError(WSAGetLastError()); + if (!address->FromSockAddr(addr, addr_len)) + return ERR_FAILED; + return OK; +} + +void TCPClientSocketWin::SetSubresourceSpeculation() { + use_history_.set_subresource_speculation(); +} + +void TCPClientSocketWin::SetOmniboxSpeculation() { + use_history_.set_omnibox_speculation(); +} + +bool TCPClientSocketWin::WasEverUsed() const { + return use_history_.was_used_to_convey_data(); +} + +bool TCPClientSocketWin::UsingTCPFastOpen() const { + // Not supported on windows. + return false; +} + +bool TCPClientSocketWin::WasNpnNegotiated() const { + return false; +} + +NextProto TCPClientSocketWin::GetNegotiatedProtocol() const { + return kProtoUnknown; +} + +bool TCPClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + +int TCPClientSocketWin::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_read_); + DCHECK(read_callback_.is_null()); + DCHECK(!core_->read_iobuffer_); + + return DoRead(buf, buf_len, callback); +} + +int TCPClientSocketWin::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_write_); + DCHECK(write_callback_.is_null()); + DCHECK_GT(buf_len, 0); + DCHECK(!core_->write_iobuffer_); + + base::StatsCounter writes("tcp.writes"); + writes.Increment(); + + WSABUF write_buffer; + write_buffer.len = buf_len; + write_buffer.buf = buf->data(); + + // TODO(wtc): Remove the assertion after enough testing. + AssertEventNotSignaled(core_->write_overlapped_.hEvent); + DWORD num; + int rv = WSASend(socket_, &write_buffer, 1, &num, 0, + &core_->write_overlapped_, NULL); + if (rv == 0) { + if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) { + rv = static_cast<int>(num); + if (rv > buf_len || rv < 0) { + // It seems that some winsock interceptors report that more was written + // than was available. Treat this as an error. http://crbug.com/27870 + LOG(ERROR) << "Detected broken LSP: Asked to write " << buf_len + << " bytes, but " << rv << " bytes reported."; + return ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES; + } + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(rv); + if (rv > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, + buf->data()); + return rv; + } + } else { + int os_error = WSAGetLastError(); + if (os_error != WSA_IO_PENDING) { + int net_error = MapSystemError(os_error); + net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, + CreateNetLogSocketErrorCallback(net_error, os_error)); + return net_error; + } + } + waiting_write_ = true; + write_callback_ = callback; + core_->write_iobuffer_ = buf; + core_->write_buffer_length_ = buf_len; + core_->WatchForWrite(); + return ERR_IO_PENDING; +} + +bool TCPClientSocketWin::SetReceiveBufferSize(int32 size) { + DCHECK(CalledOnValidThread()); + return SetSocketReceiveBufferSize(socket_, size); +} + +bool TCPClientSocketWin::SetSendBufferSize(int32 size) { + DCHECK(CalledOnValidThread()); + return SetSocketSendBufferSize(socket_, size); +} + +bool TCPClientSocketWin::SetKeepAlive(bool enable, int delay) { + return SetTCPKeepAlive(socket_, enable, delay); +} + +bool TCPClientSocketWin::SetNoDelay(bool no_delay) { + return DisableNagle(socket_, no_delay); +} + +void TCPClientSocketWin::DisableOverlappedReads() { + g_disable_overlapped_reads = true; +} + +void TCPClientSocketWin::LogConnectCompletion(int net_error) { + if (net_error == OK) + UpdateConnectionTypeHistograms(CONNECTION_ANY); + + if (net_error != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, net_error); + return; + } + + struct sockaddr_storage source_address; + socklen_t addrlen = sizeof(source_address); + int rv = getsockname( + socket_, reinterpret_cast<struct sockaddr*>(&source_address), &addrlen); + if (rv != 0) { + LOG(ERROR) << "getsockname() [rv: " << rv + << "] error: " << WSAGetLastError(); + NOTREACHED(); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, rv); + return; + } + + net_log_.EndEvent( + NetLog::TYPE_TCP_CONNECT, + CreateNetLogSourceAddressCallback( + reinterpret_cast<const struct sockaddr*>(&source_address), + sizeof(source_address))); +} + +int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (core_->disable_overlapped_reads_) { + if (!core_->non_blocking_reads_initialized_) { + WSAEventSelect(socket_, core_->read_overlapped_.hEvent, + FD_READ | FD_CLOSE); + core_->non_blocking_reads_initialized_ = true; + } + int rv = recv(socket_, buf->data(), buf_len, 0); + if (rv == SOCKET_ERROR) { + int os_error = WSAGetLastError(); + if (os_error != WSAEWOULDBLOCK) { + int net_error = MapSystemError(os_error); + net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(net_error, os_error)); + return net_error; + } + } else { + base::StatsCounter read_bytes("tcp.read_bytes"); + if (rv > 0) { + use_history_.set_was_used_to_convey_data(); + read_bytes.Add(rv); + } + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, + buf->data()); + return rv; + } + } else { + buf_len = core_->ThrottleReadSize(buf_len); + + WSABUF read_buffer; + read_buffer.len = buf_len; + read_buffer.buf = buf->data(); + + // TODO(wtc): Remove the assertion after enough testing. + AssertEventNotSignaled(core_->read_overlapped_.hEvent); + DWORD num; + DWORD flags = 0; + int rv = WSARecv(socket_, &read_buffer, 1, &num, &flags, + &core_->read_overlapped_, NULL); + if (rv == 0) { + if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) { + base::StatsCounter read_bytes("tcp.read_bytes"); + if (num > 0) { + use_history_.set_was_used_to_convey_data(); + read_bytes.Add(num); + } + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num, + buf->data()); + return static_cast<int>(num); + } + } else { + int os_error = WSAGetLastError(); + if (os_error != WSA_IO_PENDING) { + int net_error = MapSystemError(os_error); + net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(net_error, os_error)); + return net_error; + } + } + } + + waiting_read_ = true; + read_callback_ = callback; + core_->read_iobuffer_ = buf; + core_->read_buffer_length_ = buf_len; + core_->WatchForRead(); + return ERR_IO_PENDING; +} + +void TCPClientSocketWin::DoReadCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(!read_callback_.is_null()); + + // Since Run may result in Read being called, clear read_callback_ up front. + CompletionCallback c = read_callback_; + read_callback_.Reset(); + c.Run(rv); +} + +void TCPClientSocketWin::DoWriteCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(!write_callback_.is_null()); + + // since Run may result in Write being called, clear write_callback_ up front. + CompletionCallback c = write_callback_; + write_callback_.Reset(); + c.Run(rv); +} + +void TCPClientSocketWin::DidCompleteConnect() { + DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); + int result; + + WSANETWORKEVENTS events; + int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent, + &events); + int os_error = 0; + if (rv == SOCKET_ERROR) { + NOTREACHED(); + os_error = WSAGetLastError(); + result = MapSystemError(os_error); + } else if (events.lNetworkEvents & FD_CONNECT) { + os_error = events.iErrorCode[FD_CONNECT_BIT]; + result = MapConnectError(os_error); + } else { + NOTREACHED(); + result = ERR_UNEXPECTED; + } + + connect_os_error_ = os_error; + rv = DoConnectLoop(result); + if (rv != ERR_IO_PENDING) { + LogConnectCompletion(rv); + DoReadCallback(rv); + } +} + +void TCPClientSocketWin::DidCompleteRead() { + DCHECK(waiting_read_); + DWORD num_bytes, flags; + BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_, + &num_bytes, FALSE, &flags); + waiting_read_ = false; + int rv; + if (ok) { + base::StatsCounter read_bytes("tcp.read_bytes"); + read_bytes.Add(num_bytes); + if (num_bytes > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, + num_bytes, core_->read_iobuffer_->data()); + rv = static_cast<int>(num_bytes); + } else { + int os_error = WSAGetLastError(); + rv = MapSystemError(os_error); + net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(rv, os_error)); + } + WSAResetEvent(core_->read_overlapped_.hEvent); + core_->read_iobuffer_ = NULL; + core_->read_buffer_length_ = 0; + DoReadCallback(rv); +} + +void TCPClientSocketWin::DidCompleteWrite() { + DCHECK(waiting_write_); + + DWORD num_bytes, flags; + BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_, + &num_bytes, FALSE, &flags); + WSAResetEvent(core_->write_overlapped_.hEvent); + waiting_write_ = false; + int rv; + if (!ok) { + int os_error = WSAGetLastError(); + rv = MapSystemError(os_error); + net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, + CreateNetLogSocketErrorCallback(rv, os_error)); + } else { + rv = static_cast<int>(num_bytes); + if (rv > core_->write_buffer_length_ || rv < 0) { + // It seems that some winsock interceptors report that more was written + // than was available. Treat this as an error. http://crbug.com/27870 + LOG(ERROR) << "Detected broken LSP: Asked to write " + << core_->write_buffer_length_ << " bytes, but " << rv + << " bytes reported."; + rv = ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES; + } else { + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(num_bytes); + if (num_bytes > 0) + use_history_.set_was_used_to_convey_data(); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, num_bytes, + core_->write_iobuffer_->data()); + } + } + core_->write_iobuffer_ = NULL; + DoWriteCallback(rv); +} + +void TCPClientSocketWin::DidSignalRead() { + DCHECK(waiting_read_); + int os_error = 0; + WSANETWORKEVENTS network_events; + int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent, + &network_events); + if (rv == SOCKET_ERROR) { + os_error = WSAGetLastError(); + rv = MapSystemError(os_error); + } else if (network_events.lNetworkEvents) { + DCHECK_EQ(network_events.lNetworkEvents & ~(FD_READ | FD_CLOSE), 0); + // If network_events.lNetworkEvents is FD_CLOSE and + // network_events.iErrorCode[FD_CLOSE_BIT] is 0, it is a graceful + // connection closure. It is tempting to directly set rv to 0 in + // this case, but the MSDN pages for WSAEventSelect and + // WSAAsyncSelect recommend we still call DoRead(): + // FD_CLOSE should only be posted after all data is read from a + // socket, but an application should check for remaining data upon + // receipt of FD_CLOSE to avoid any possibility of losing data. + // + // If network_events.iErrorCode[FD_READ_BIT] or + // network_events.iErrorCode[FD_CLOSE_BIT] is nonzero, still call + // DoRead() because recv() reports a more accurate error code + // (WSAECONNRESET vs. WSAECONNABORTED) when the connection was + // reset. + rv = DoRead(core_->read_iobuffer_, core_->read_buffer_length_, + read_callback_); + if (rv == ERR_IO_PENDING) + return; + } else { + // This may happen because Read() may succeed synchronously and + // consume all the received data without resetting the event object. + core_->WatchForRead(); + return; + } + waiting_read_ = false; + core_->read_iobuffer_ = NULL; + core_->read_buffer_length_ = 0; + DoReadCallback(rv); +} + +} // namespace net diff --git a/chromium/net/socket/tcp_client_socket_win.h b/chromium/net/socket/tcp_client_socket_win.h new file mode 100644 index 00000000000..26c8b9feff2 --- /dev/null +++ b/chromium/net/socket/tcp_client_socket_win.h @@ -0,0 +1,162 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ +#define NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ + +#include <winsock2.h> + +#include "base/memory/scoped_ptr.h" +#include "base/threading/non_thread_safe.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" +#include "net/base/net_log.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class BoundNetLog; + +class NET_EXPORT TCPClientSocketWin : public StreamSocket, + NON_EXPORTED_BASE(base::NonThreadSafe) { + public: + // The IP address(es) and port number to connect to. The TCP socket will try + // each IP address in the list until it succeeds in establishing a + // connection. + TCPClientSocketWin(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source); + + virtual ~TCPClientSocketWin(); + + // AdoptSocket causes the given, connected socket to be adopted as a TCP + // socket. This object must not be connected. This object takes ownership of + // the given socket and then acts as if Connect() had been called. This + // function is used by TCPServerSocket() to adopt accepted connections + // and for testing. + int AdoptSocket(SOCKET socket); + + // Binds the socket to a local IP address and port. + int Bind(const IPEndPoint& address); + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + virtual int GetPeerAddress(IPEndPoint* address) const; + virtual int GetLocalAddress(IPEndPoint* address) const; + virtual const BoundNetLog& NetLog() const { return net_log_; } + virtual void SetSubresourceSpeculation(); + virtual void SetOmniboxSpeculation(); + virtual bool WasEverUsed() const; + virtual bool UsingTCPFastOpen() const; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); + + virtual bool SetReceiveBufferSize(int32 size); + virtual bool SetSendBufferSize(int32 size); + + virtual bool SetKeepAlive(bool enable, int delay); + virtual bool SetNoDelay(bool no_delay); + + // Perform reads in non-blocking mode instead of overlapped mode. + // Used for experiments. + static void DisableOverlappedReads(); + + private: + // State machine for connecting the socket. + enum ConnectState { + CONNECT_STATE_CONNECT, + CONNECT_STATE_CONNECT_COMPLETE, + CONNECT_STATE_NONE, + }; + + class Core; + + // State machine used by Connect(). + int DoConnectLoop(int result); + int DoConnect(); + int DoConnectComplete(int result); + + // Helper used by Disconnect(), which disconnects minus the logging and + // resetting of current_address_index_. + void DoDisconnect(); + + // Returns true if a Connect() is in progress. + bool waiting_connect() const { + return next_connect_state_ != CONNECT_STATE_NONE; + } + + // Called after Connect() has completed with |net_error|. + void LogConnectCompletion(int net_error); + + int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + void DoReadCallback(int rv); + void DoWriteCallback(int rv); + void DidCompleteConnect(); + void DidCompleteRead(); + void DidCompleteWrite(); + void DidSignalRead(); + + SOCKET socket_; + + // Local IP address and port we are bound to. Set to NULL if Bind() + // was't called (in that cases OS chooses address/port). + scoped_ptr<IPEndPoint> bind_address_; + + // Stores bound socket between Bind() and Connect() calls. + SOCKET bound_socket_; + + // The list of addresses we should try in order to establish a connection. + AddressList addresses_; + + // Where we are in above list. Set to -1 if uninitialized. + int current_address_index_; + + // The various states that the socket could be in. + bool waiting_read_; + bool waiting_write_; + + // The core of the socket that can live longer than the socket itself. We pass + // resources to the Windows async IO functions and we have to make sure that + // they are not destroyed while the OS still references them. + scoped_refptr<Core> core_; + + // External callback; called when connect or read is complete. + CompletionCallback read_callback_; + + // External callback; called when write is complete. + CompletionCallback write_callback_; + + // The next state for the Connect() state machine. + ConnectState next_connect_state_; + + // The OS error that CONNECT_STATE_CONNECT last completed with. + int connect_os_error_; + + BoundNetLog net_log_; + + // This socket was previously disconnected and has not been re-connected. + bool previously_disconnected_; + + // Record of connectivity and transmissions, for use in speculative connection + // histograms. + UseHistory use_history_; + + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketWin); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc new file mode 100644 index 00000000000..aab2e45d0e9 --- /dev/null +++ b/chromium/net/socket/tcp_listen_socket.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_listen_socket.h" + +#if defined(OS_WIN) +// winsock2.h must be included first in order to ensure it is included before +// windows.h. +#include <winsock2.h> +#elif defined(OS_POSIX) +#include <arpa/inet.h> +#include <errno.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <sys/types.h> +#include "net/base/net_errors.h" +#endif + +#include "base/logging.h" +#include "base/sys_byteorder.h" +#include "base/threading/platform_thread.h" +#include "build/build_config.h" +#include "net/base/net_util.h" +#include "net/base/winsock_init.h" + +using std::string; + +namespace net { + +// static +scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen( + const string& ip, int port, StreamListenSocket::Delegate* del) { + SocketDescriptor s = CreateAndBind(ip, port); + if (s == kInvalidSocket) + return NULL; + scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del)); + sock->Listen(); + return sock; +} + +TCPListenSocket::TCPListenSocket(SocketDescriptor s, + StreamListenSocket::Delegate* del) + : StreamListenSocket(s, del) { +} + +TCPListenSocket::~TCPListenSocket() {} + +SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) { +#if defined(OS_WIN) + EnsureWinsockInit(); +#endif + + SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s != kInvalidSocket) { +#if defined(OS_POSIX) + // Allow rapid reuse. + static const int kOn = 1; + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &kOn, sizeof(kOn)); +#endif + sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr(ip.c_str()); + addr.sin_port = base::HostToNet16(port); + if (bind(s, reinterpret_cast<sockaddr*>(&addr), sizeof(addr))) { +#if defined(OS_WIN) + closesocket(s); +#elif defined(OS_POSIX) + close(s); +#endif + LOG(ERROR) << "Could not bind socket to " << ip << ":" << port; + s = kInvalidSocket; + } + } + return s; +} + +SocketDescriptor TCPListenSocket::CreateAndBindAnyPort(const string& ip, + int* port) { + SocketDescriptor s = CreateAndBind(ip, 0); + if (s == kInvalidSocket) + return kInvalidSocket; + sockaddr_in addr; + socklen_t addr_size = sizeof(addr); + bool failed = getsockname(s, reinterpret_cast<struct sockaddr*>(&addr), + &addr_size) != 0; + if (addr_size != sizeof(addr)) + failed = true; + if (failed) { + LOG(ERROR) << "Could not determine bound port, getsockname() failed"; +#if defined(OS_WIN) + closesocket(s); +#elif defined(OS_POSIX) + close(s); +#endif + return kInvalidSocket; + } + *port = base::NetToHost16(addr.sin_port); + return s; +} + +void TCPListenSocket::Accept() { + SocketDescriptor conn = AcceptSocket(); + if (conn == kInvalidSocket) + return; + scoped_refptr<TCPListenSocket> sock( + new TCPListenSocket(conn, socket_delegate_)); + // It's up to the delegate to AddRef if it wants to keep it around. +#if defined(OS_POSIX) + sock->WatchSocket(WAITING_READ); +#endif + socket_delegate_->DidAccept(this, sock.get()); +} + +TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) + : ip_(ip), + port_(port) { +} + +TCPListenSocketFactory::~TCPListenSocketFactory() {} + +scoped_refptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return TCPListenSocket::CreateAndListen(ip_, port_, delegate); +} + +} // namespace net diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h new file mode 100644 index 00000000000..dbc5347e945 --- /dev/null +++ b/chromium/net/socket/tcp_listen_socket.h @@ -0,0 +1,64 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_LISTEN_SOCKET_H_ +#define NET_SOCKET_TCP_LISTEN_SOCKET_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "net/base/net_export.h" +#include "net/socket/stream_listen_socket.h" + +namespace net { + +// Implements a TCP socket. Note that this is ref counted. +class NET_EXPORT TCPListenSocket : public StreamListenSocket { + public: + // Listen on port for the specified IP address. Use 127.0.0.1 to only + // accept local connections. + static scoped_refptr<TCPListenSocket> CreateAndListen( + const std::string& ip, int port, StreamListenSocket::Delegate* del); + + // Get raw TCP socket descriptor bound to ip:port. + static SocketDescriptor CreateAndBind(const std::string& ip, int port); + + // Get raw TCP socket descriptor bound to ip and return port it is bound to. + static SocketDescriptor CreateAndBindAnyPort(const std::string& ip, + int* port); + + protected: + friend class scoped_refptr<TCPListenSocket>; + + TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del); + virtual ~TCPListenSocket(); + + // Implements StreamListenSocket::Accept. + virtual void Accept() OVERRIDE; + + private: + DISALLOW_COPY_AND_ASSIGN(TCPListenSocket); +}; + +// Factory that can be used to instantiate TCPListenSocket. +class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory { + public: + TCPListenSocketFactory(const std::string& ip, int port); + virtual ~TCPListenSocketFactory(); + + // StreamListenSocketFactory overrides. + virtual scoped_refptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const OVERRIDE; + + private: + const std::string ip_; + const int port_; + + DISALLOW_COPY_AND_ASSIGN(TCPListenSocketFactory); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_LISTEN_SOCKET_H_ diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc new file mode 100644 index 00000000000..d13b784cbdc --- /dev/null +++ b/chromium/net/socket/tcp_listen_socket_unittest.cc @@ -0,0 +1,291 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_listen_socket_unittest.h" + +#include <fcntl.h> +#include <sys/types.h> + +#include "base/bind.h" +#include "base/posix/eintr_wrapper.h" +#include "base/sys_byteorder.h" +#include "net/base/net_util.h" +#include "testing/platform_test.h" + +namespace net { + +const int TCPListenSocketTester::kTestPort = 9999; + +static const int kReadBufSize = 1024; +static const char kHelloWorld[] = "HELLO, WORLD"; +static const int kMaxQueueSize = 20; +static const char kLoopback[] = "127.0.0.1"; +static const int kDefaultTimeoutMs = 5000; + +TCPListenSocketTester::TCPListenSocketTester() + : loop_(NULL), server_(NULL), connection_(NULL), cv_(&lock_) {} + +void TCPListenSocketTester::SetUp() { + base::Thread::Options options; + options.message_loop_type = base::MessageLoop::TYPE_IO; + thread_.reset(new base::Thread("socketio_test")); + thread_->StartWithOptions(options); + loop_ = reinterpret_cast<base::MessageLoopForIO*>(thread_->message_loop()); + + loop_->PostTask(FROM_HERE, base::Bind( + &TCPListenSocketTester::Listen, this)); + + // verify Listen succeeded + NextAction(); + ASSERT_FALSE(server_.get() == NULL); + ASSERT_EQ(ACTION_LISTEN, last_action_.type()); + + // verify the connect/accept and setup test_socket_ + test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_); + struct sockaddr_in client; + client.sin_family = AF_INET; + client.sin_addr.s_addr = inet_addr(kLoopback); + client.sin_port = base::HostToNet16(kTestPort); + int ret = HANDLE_EINTR( + connect(test_socket_, reinterpret_cast<sockaddr*>(&client), + sizeof(client))); +#if defined(OS_POSIX) + // The connect() call may be interrupted by a signal. When connect() + // is retried on EINTR, it fails with EISCONN. + if (ret == StreamListenSocket::kSocketError) + ASSERT_EQ(EISCONN, errno); +#else + // Don't have signals. + ASSERT_NE(StreamListenSocket::kSocketError, ret); +#endif + + NextAction(); + ASSERT_EQ(ACTION_ACCEPT, last_action_.type()); +} + +void TCPListenSocketTester::TearDown() { +#if defined(OS_WIN) + ASSERT_EQ(0, closesocket(test_socket_)); +#elif defined(OS_POSIX) + ASSERT_EQ(0, HANDLE_EINTR(close(test_socket_))); +#endif + NextAction(); + ASSERT_EQ(ACTION_CLOSE, last_action_.type()); + + loop_->PostTask(FROM_HERE, base::Bind( + &TCPListenSocketTester::Shutdown, this)); + NextAction(); + ASSERT_EQ(ACTION_SHUTDOWN, last_action_.type()); + + thread_.reset(); + loop_ = NULL; +} + +void TCPListenSocketTester::ReportAction( + const TCPListenSocketTestAction& action) { + base::AutoLock locked(lock_); + queue_.push_back(action); + cv_.Broadcast(); +} + +void TCPListenSocketTester::NextAction() { + base::AutoLock locked(lock_); + while (queue_.empty()) + cv_.Wait(); + last_action_ = queue_.front(); + queue_.pop_front(); +} + +int TCPListenSocketTester::ClearTestSocket() { + char buf[kReadBufSize]; + int len_ret = 0; + do { + int len = HANDLE_EINTR(recv(test_socket_, buf, kReadBufSize, 0)); + if (len == StreamListenSocket::kSocketError || len == 0) { + break; + } else { + len_ret += len; + } + } while (true); + return len_ret; +} + +void TCPListenSocketTester::Shutdown() { + connection_->Release(); + connection_ = NULL; + server_->Release(); + server_ = NULL; + ReportAction(TCPListenSocketTestAction(ACTION_SHUTDOWN)); +} + +void TCPListenSocketTester::Listen() { + server_ = DoListen(); + ASSERT_TRUE(server_.get()); + server_->AddRef(); + ReportAction(TCPListenSocketTestAction(ACTION_LISTEN)); +} + +void TCPListenSocketTester::SendFromTester() { + connection_->Send(kHelloWorld); + ReportAction(TCPListenSocketTestAction(ACTION_SEND)); +} + +void TCPListenSocketTester::TestClientSend() { + ASSERT_TRUE(Send(test_socket_, kHelloWorld)); + NextAction(); + ASSERT_EQ(ACTION_READ, last_action_.type()); + ASSERT_EQ(last_action_.data(), kHelloWorld); +} + +void TCPListenSocketTester::TestClientSendLong() { + size_t hello_len = strlen(kHelloWorld); + std::string long_string; + size_t long_len = 0; + for (int i = 0; i < 200; i++) { + long_string += kHelloWorld; + long_len += hello_len; + } + ASSERT_TRUE(Send(test_socket_, long_string)); + size_t read_len = 0; + while (read_len < long_len) { + NextAction(); + ASSERT_EQ(ACTION_READ, last_action_.type()); + std::string last_data = last_action_.data(); + size_t len = last_data.length(); + if (long_string.compare(read_len, len, last_data)) { + ASSERT_EQ(long_string.compare(read_len, len, last_data), 0); + } + read_len += last_data.length(); + } + ASSERT_EQ(read_len, long_len); +} + +void TCPListenSocketTester::TestServerSend() { + loop_->PostTask(FROM_HERE, base::Bind( + &TCPListenSocketTester::SendFromTester, this)); + NextAction(); + ASSERT_EQ(ACTION_SEND, last_action_.type()); + const int buf_len = 200; + char buf[buf_len+1]; + unsigned recv_len = 0; + while (recv_len < strlen(kHelloWorld)) { + int r = HANDLE_EINTR(recv(test_socket_, + buf + recv_len, buf_len - recv_len, 0)); + ASSERT_GE(r, 0); + recv_len += static_cast<unsigned>(r); + if (!r) + break; + } + buf[recv_len] = 0; + ASSERT_STREQ(kHelloWorld, buf); +} + +void TCPListenSocketTester::TestServerSendMultiple() { + // Send enough data to exceed the socket receive window. 20kb is probably a + // safe bet. + int send_count = (1024*20) / (sizeof(kHelloWorld)-1); + + // Send multiple writes. Since no reading is occurring the data should be + // buffered in TCPListenSocket. + for (int i = 0; i < send_count; ++i) { + loop_->PostTask(FROM_HERE, base::Bind( + &TCPListenSocketTester::SendFromTester, this)); + NextAction(); + ASSERT_EQ(ACTION_SEND, last_action_.type()); + } + + // Make multiple reads. All of the data should eventually be returned. + char buf[sizeof(kHelloWorld)]; + const int buf_len = sizeof(kHelloWorld); + for (int i = 0; i < send_count; ++i) { + unsigned recv_len = 0; + while (recv_len < buf_len-1) { + int r = HANDLE_EINTR(recv(test_socket_, + buf + recv_len, buf_len - 1 - recv_len, 0)); + ASSERT_GE(r, 0); + recv_len += static_cast<unsigned>(r); + if (!r) + break; + } + buf[recv_len] = 0; + ASSERT_STREQ(kHelloWorld, buf); + } +} + +bool TCPListenSocketTester::Send(SocketDescriptor sock, + const std::string& str) { + int len = static_cast<int>(str.length()); + int send_len = HANDLE_EINTR(send(sock, str.data(), len, 0)); + if (send_len == StreamListenSocket::kSocketError) { + LOG(ERROR) << "send failed: " << errno; + return false; + } else if (send_len != len) { + return false; + } + return true; +} + +void TCPListenSocketTester::DidAccept(StreamListenSocket* server, + StreamListenSocket* connection) { + connection_ = connection; + connection_->AddRef(); + ReportAction(TCPListenSocketTestAction(ACTION_ACCEPT)); +} + +void TCPListenSocketTester::DidRead(StreamListenSocket* connection, + const char* data, + int len) { + std::string str(data, len); + ReportAction(TCPListenSocketTestAction(ACTION_READ, str)); +} + +void TCPListenSocketTester::DidClose(StreamListenSocket* sock) { + ReportAction(TCPListenSocketTestAction(ACTION_CLOSE)); +} + +TCPListenSocketTester::~TCPListenSocketTester() {} + +scoped_refptr<TCPListenSocket> TCPListenSocketTester::DoListen() { + return TCPListenSocket::CreateAndListen(kLoopback, kTestPort, this); +} + +class TCPListenSocketTest: public PlatformTest { + public: + TCPListenSocketTest() { + tester_ = NULL; + } + + virtual void SetUp() { + PlatformTest::SetUp(); + tester_ = new TCPListenSocketTester(); + tester_->SetUp(); + } + + virtual void TearDown() { + PlatformTest::TearDown(); + tester_->TearDown(); + tester_ = NULL; + } + + scoped_refptr<TCPListenSocketTester> tester_; +}; + +TEST_F(TCPListenSocketTest, ClientSend) { + tester_->TestClientSend(); +} + +TEST_F(TCPListenSocketTest, ClientSendLong) { + tester_->TestClientSendLong(); +} + +TEST_F(TCPListenSocketTest, ServerSend) { + tester_->TestServerSend(); +} + +TEST_F(TCPListenSocketTest, ServerSendMultiple) { + tester_->TestServerSendMultiple(); +} + +} // namespace net diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h new file mode 100644 index 00000000000..048a0186705 --- /dev/null +++ b/chromium/net/socket/tcp_listen_socket_unittest.h @@ -0,0 +1,122 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_LISTEN_SOCKET_UNITTEST_H_ +#define NET_BASE_LISTEN_SOCKET_UNITTEST_H_ + +#include "build/build_config.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#elif defined(OS_POSIX) +#include <arpa/inet.h> +#include <errno.h> +#include <sys/socket.h> +#endif + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/strings/string_util.h" +#include "base/synchronization/condition_variable.h" +#include "base/synchronization/lock.h" +#include "base/threading/thread.h" +#include "net/base/net_util.h" +#include "net/base/winsock_init.h" +#include "net/socket/tcp_listen_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +enum ActionType { + ACTION_NONE = 0, + ACTION_LISTEN = 1, + ACTION_ACCEPT = 2, + ACTION_READ = 3, + ACTION_SEND = 4, + ACTION_CLOSE = 5, + ACTION_SHUTDOWN = 6 +}; + +class TCPListenSocketTestAction { + public: + TCPListenSocketTestAction() : action_(ACTION_NONE) {} + explicit TCPListenSocketTestAction(ActionType action) : action_(action) {} + TCPListenSocketTestAction(ActionType action, std::string data) + : action_(action), + data_(data) {} + + const std::string data() const { return data_; } + ActionType type() const { return action_; } + + private: + ActionType action_; + std::string data_; +}; + + +// This had to be split out into a separate class because I couldn't +// make the testing::Test class refcounted. +class TCPListenSocketTester : + public StreamListenSocket::Delegate, + public base::RefCountedThreadSafe<TCPListenSocketTester> { + + public: + TCPListenSocketTester(); + + void SetUp(); + void TearDown(); + + void ReportAction(const TCPListenSocketTestAction& action); + void NextAction(); + + // read all pending data from the test socket + int ClearTestSocket(); + // Release the connection and server sockets + void Shutdown(); + void Listen(); + void SendFromTester(); + // verify the send/read from client to server + void TestClientSend(); + // verify send/read of a longer string + void TestClientSendLong(); + // verify a send/read from server to client + void TestServerSend(); + // verify multiple sends and reads from server to client. + void TestServerSendMultiple(); + + virtual bool Send(SocketDescriptor sock, const std::string& str); + + // StreamListenSocket::Delegate: + virtual void DidAccept(StreamListenSocket* server, + StreamListenSocket* connection) OVERRIDE; + virtual void DidRead(StreamListenSocket* connection, const char* data, + int len) OVERRIDE; + virtual void DidClose(StreamListenSocket* sock) OVERRIDE; + + scoped_ptr<base::Thread> thread_; + base::MessageLoopForIO* loop_; + scoped_refptr<TCPListenSocket> server_; + StreamListenSocket* connection_; + TCPListenSocketTestAction last_action_; + + SocketDescriptor test_socket_; + static const int kTestPort; + + base::Lock lock_; // protects |queue_| and wraps |cv_| + base::ConditionVariable cv_; + std::deque<TCPListenSocketTestAction> queue_; + + protected: + friend class base::RefCountedThreadSafe<TCPListenSocketTester>; + + virtual ~TCPListenSocketTester(); + + virtual scoped_refptr<TCPListenSocket> DoListen(); +}; + +} // namespace net + +#endif // NET_BASE_LISTEN_SOCKET_UNITTEST_H_ diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h new file mode 100644 index 00000000000..4970a150e8d --- /dev/null +++ b/chromium/net/socket/tcp_server_socket.h @@ -0,0 +1,26 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SERVER_SOCKET_H_ +#define NET_SOCKET_TCP_SERVER_SOCKET_H_ + +#include "build/build_config.h" + +#if defined(OS_WIN) +#include "net/socket/tcp_server_socket_win.h" +#elif defined(OS_POSIX) +#include "net/socket/tcp_server_socket_libevent.h" +#endif + +namespace net { + +#if defined(OS_WIN) +typedef TCPServerSocketWin TCPServerSocket; +#elif defined(OS_POSIX) +typedef TCPServerSocketLibevent TCPServerSocket; +#endif + +} // namespace net + +#endif // NET_SOCKET_TCP_SERVER_SOCKET_H_ diff --git a/chromium/net/socket/tcp_server_socket_libevent.cc b/chromium/net/socket/tcp_server_socket_libevent.cc new file mode 100644 index 00000000000..38dda962f46 --- /dev/null +++ b/chromium/net/socket/tcp_server_socket_libevent.cc @@ -0,0 +1,223 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_server_socket_libevent.h" + +#include <errno.h> +#include <fcntl.h> +#include <netdb.h> +#include <sys/socket.h> + +#include "build/build_config.h" + +#if defined(OS_POSIX) +#include <netinet/in.h> +#endif + +#include "base/posix/eintr_wrapper.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/socket/socket_net_log_params.h" +#include "net/socket/tcp_client_socket.h" + +namespace net { + +namespace { + +const int kInvalidSocket = -1; + +} // namespace + +TCPServerSocketLibevent::TCPServerSocketLibevent( + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(kInvalidSocket), + accept_socket_(NULL), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { + net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, + source.ToEventParametersCallback()); +} + +TCPServerSocketLibevent::~TCPServerSocketLibevent() { + if (socket_ != kInvalidSocket) + Close(); + net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); +} + +int TCPServerSocketLibevent::Listen(const IPEndPoint& address, int backlog) { + DCHECK(CalledOnValidThread()); + DCHECK_GT(backlog, 0); + DCHECK_EQ(socket_, kInvalidSocket); + + socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP); + if (socket_ < 0) { + PLOG(ERROR) << "socket() returned an error"; + return MapSystemError(errno); + } + + if (SetNonBlocking(socket_)) { + int result = MapSystemError(errno); + Close(); + return result; + } + + int result = SetSocketOptions(); + if (result != OK) { + Close(); + return result; + } + + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) { + Close(); + return ERR_ADDRESS_INVALID; + } + + result = bind(socket_, storage.addr, storage.addr_len); + if (result < 0) { + PLOG(ERROR) << "bind() returned an error"; + result = MapSystemError(errno); + Close(); + return result; + } + + result = listen(socket_, backlog); + if (result < 0) { + PLOG(ERROR) << "listen() returned an error"; + result = MapSystemError(errno); + Close(); + return result; + } + + return OK; +} + +int TCPServerSocketLibevent::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len) < 0) + return MapSystemError(errno); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_FAILED; + + return OK; +} + +int TCPServerSocketLibevent::Accept( + scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(socket); + DCHECK(!callback.is_null()); + DCHECK(accept_callback_.is_null()); + + net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); + + int result = AcceptInternal(socket); + + if (result == ERR_IO_PENDING) { + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_READ, + &accept_socket_watcher_, this)) { + PLOG(ERROR) << "WatchFileDescriptor failed on read"; + return MapSystemError(errno); + } + + accept_socket_ = socket; + accept_callback_ = callback; + } + + return result; +} + +int TCPServerSocketLibevent::SetSocketOptions() { + // SO_REUSEADDR is useful for server sockets to bind to a recently unbound + // port. When a socket is closed, the end point changes its state to TIME_WAIT + // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer + // acknowledges its closure. For server sockets, it is usually safe to + // bind to a TIME_WAIT end point immediately, which is a widely adopted + // behavior. + // + // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to + // an end point that is already bound by another socket. To do that one must + // set SO_REUSEPORT instead. This option is not provided on Linux prior + // to 3.9. + // + // SO_REUSEPORT is provided in MacOS X and iOS. + int true_value = 1; + int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &true_value, + sizeof(true_value)); + if (rv < 0) + return MapSystemError(errno); + return OK; +} + +int TCPServerSocketLibevent::AcceptInternal( + scoped_ptr<StreamSocket>* socket) { + SockaddrStorage storage; + int new_socket = HANDLE_EINTR(accept(socket_, + storage.addr, + &storage.addr_len)); + if (new_socket < 0) { + int net_error = MapSystemError(errno); + if (net_error != ERR_IO_PENDING) + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); + return net_error; + } + + IPEndPoint address; + if (!address.FromSockAddr(storage.addr, storage.addr_len)) { + NOTREACHED(); + if (HANDLE_EINTR(close(new_socket)) < 0) + PLOG(ERROR) << "close"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); + return ERR_FAILED; + } + scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket( + AddressList(address), + net_log_.net_log(), net_log_.source())); + int adopt_result = tcp_socket->AdoptSocket(new_socket); + if (adopt_result != OK) { + if (HANDLE_EINTR(close(new_socket)) < 0) + PLOG(ERROR) << "close"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); + return adopt_result; + } + socket->reset(tcp_socket.release()); + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(&address)); + return OK; +} + +void TCPServerSocketLibevent::Close() { + if (socket_ != kInvalidSocket) { + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + if (HANDLE_EINTR(close(socket_)) < 0) + PLOG(ERROR) << "close"; + socket_ = kInvalidSocket; + } +} + +void TCPServerSocketLibevent::OnFileCanReadWithoutBlocking(int fd) { + DCHECK(CalledOnValidThread()); + + int result = AcceptInternal(accept_socket_); + if (result != ERR_IO_PENDING) { + accept_socket_ = NULL; + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + CompletionCallback callback = accept_callback_; + accept_callback_.Reset(); + callback.Run(result); + } +} + +void TCPServerSocketLibevent::OnFileCanWriteWithoutBlocking(int fd) { + NOTREACHED(); +} + +} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_libevent.h b/chromium/net/socket/tcp_server_socket_libevent.h new file mode 100644 index 00000000000..fe69472a653 --- /dev/null +++ b/chromium/net/socket/tcp_server_socket_libevent.h @@ -0,0 +1,55 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ +#define NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ + +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/non_thread_safe.h" +#include "net/base/completion_callback.h" +#include "net/base/net_log.h" +#include "net/socket/server_socket.h" + +namespace net { + +class IPEndPoint; + +class NET_EXPORT_PRIVATE TCPServerSocketLibevent : + public ServerSocket, + public base::NonThreadSafe, + public base::MessageLoopForIO::Watcher { + public: + TCPServerSocketLibevent(net::NetLog* net_log, + const net::NetLog::Source& source); + virtual ~TCPServerSocketLibevent(); + + // net::ServerSocket implementation. + virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) OVERRIDE; + + // MessageLoopForIO::Watcher implementation. + virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; + virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; + + private: + int SetSocketOptions(); + int AcceptInternal(scoped_ptr<StreamSocket>* socket); + void Close(); + + int socket_; + + base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; + + scoped_ptr<StreamSocket>* accept_socket_; + CompletionCallback accept_callback_; + + BoundNetLog net_log_; +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_server_socket_unittest.cc b/chromium/net/socket/tcp_server_socket_unittest.cc new file mode 100644 index 00000000000..fd81e550d08 --- /dev/null +++ b/chromium/net/socket/tcp_server_socket_unittest.cc @@ -0,0 +1,251 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_server_socket.h" + +#include <string> +#include <vector> + +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/tcp_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { +const int kListenBacklog = 5; + +class TCPServerSocketTest : public PlatformTest { + protected: + TCPServerSocketTest() + : socket_(NULL, NetLog::Source()) { + } + + void SetUpIPv4() { + IPEndPoint address; + ParseAddress("127.0.0.1", 0, &address); + ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog)); + ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); + } + + void SetUpIPv6(bool* success) { + *success = false; + IPEndPoint address; + ParseAddress("::1", 0, &address); + if (socket_.Listen(address, kListenBacklog) != 0) { + LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " + "disabled. Skipping the test"; + return; + } + ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); + *success = true; + } + + void ParseAddress(std::string ip_str, int port, IPEndPoint* address) { + IPAddressNumber ip_number; + bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); + if (!rv) + return; + *address = IPEndPoint(ip_number, port); + } + + static IPEndPoint GetPeerAddress(StreamSocket* socket) { + IPEndPoint address; + EXPECT_EQ(OK, socket->GetPeerAddress(&address)); + return address; + } + + AddressList local_address_list() const { + return AddressList(local_address_); + } + + TCPServerSocket socket_; + IPEndPoint local_address_; +}; + +TEST_F(TCPServerSocketTest, Accept) { + ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<StreamSocket> accepted_socket; + int result = socket_.Accept(&accepted_socket, accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + ASSERT_TRUE(accepted_socket.get() != NULL); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), + local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); +} + +// Test Accept() callback. +TEST_F(TCPServerSocketTest, AcceptAsync) { + ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); + + TestCompletionCallback accept_callback; + scoped_ptr<StreamSocket> accepted_socket; + + ASSERT_EQ(ERR_IO_PENDING, + socket_.Accept(&accepted_socket, accept_callback.callback())); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + EXPECT_EQ(OK, accept_callback.WaitForResult()); + + EXPECT_TRUE(accepted_socket != NULL); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), + local_address_.address()); +} + +// Accept two connections simultaneously. +TEST_F(TCPServerSocketTest, Accept2Connections) { + ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); + + TestCompletionCallback accept_callback; + scoped_ptr<StreamSocket> accepted_socket; + + ASSERT_EQ(ERR_IO_PENDING, + socket_.Accept(&accepted_socket, accept_callback.callback())); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback connect_callback2; + TCPClientSocket connecting_socket2(local_address_list(), + NULL, NetLog::Source()); + connecting_socket2.Connect(connect_callback2.callback()); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + + TestCompletionCallback accept_callback2; + scoped_ptr<StreamSocket> accepted_socket2; + int result = socket_.Accept(&accepted_socket2, accept_callback2.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback2.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + + EXPECT_TRUE(accepted_socket != NULL); + EXPECT_TRUE(accepted_socket2 != NULL); + EXPECT_NE(accepted_socket.get(), accepted_socket2.get()); + + EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), + local_address_.address()); + EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(), + local_address_.address()); +} + +TEST_F(TCPServerSocketTest, AcceptIPv6) { + bool initialized = false; + ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized)); + if (!initialized) + return; + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<StreamSocket> accepted_socket; + int result = socket_.Accept(&accepted_socket, accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + ASSERT_TRUE(accepted_socket.get() != NULL); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), + local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); +} + +TEST_F(TCPServerSocketTest, AcceptIO) { + ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<StreamSocket> accepted_socket; + int result = socket_.Accept(&accepted_socket, accept_callback.callback()); + ASSERT_EQ(OK, accept_callback.GetResult(result)); + + ASSERT_TRUE(accepted_socket.get() != NULL); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), + local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + + const std::string message("test message"); + std::vector<char> buffer(message.size()); + + size_t bytes_written = 0; + while (bytes_written < message.size()) { + scoped_refptr<net::IOBufferWithSize> write_buffer( + new net::IOBufferWithSize(message.size() - bytes_written)); + memmove(write_buffer->data(), message.data(), message.size()); + + TestCompletionCallback write_callback; + int write_result = accepted_socket->Write( + write_buffer.get(), write_buffer->size(), write_callback.callback()); + write_result = write_callback.GetResult(write_result); + ASSERT_TRUE(write_result >= 0); + ASSERT_TRUE(bytes_written + write_result <= message.size()); + bytes_written += write_result; + } + + size_t bytes_read = 0; + while (bytes_read < message.size()) { + scoped_refptr<net::IOBufferWithSize> read_buffer( + new net::IOBufferWithSize(message.size() - bytes_read)); + TestCompletionCallback read_callback; + int read_result = connecting_socket.Read( + read_buffer.get(), read_buffer->size(), read_callback.callback()); + read_result = read_callback.GetResult(read_result); + ASSERT_TRUE(read_result >= 0); + ASSERT_TRUE(bytes_read + read_result <= message.size()); + memmove(&buffer[bytes_read], read_buffer->data(), read_result); + bytes_read += read_result; + } + + std::string received_message(buffer.begin(), buffer.end()); + ASSERT_EQ(message, received_message); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_win.cc b/chromium/net/socket/tcp_server_socket_win.cc new file mode 100644 index 00000000000..0ac77be5e81 --- /dev/null +++ b/chromium/net/socket/tcp_server_socket_win.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_server_socket_win.h" + +#include <mstcpip.h> + +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/winsock_init.h" +#include "net/base/winsock_util.h" +#include "net/socket/socket_net_log_params.h" +#include "net/socket/tcp_client_socket.h" + +namespace net { + +TCPServerSocketWin::TCPServerSocketWin(net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(INVALID_SOCKET), + socket_event_(WSA_INVALID_EVENT), + accept_socket_(NULL), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { + net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, + source.ToEventParametersCallback()); + EnsureWinsockInit(); +} + +TCPServerSocketWin::~TCPServerSocketWin() { + Close(); + net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); +} + +int TCPServerSocketWin::Listen(const IPEndPoint& address, int backlog) { + DCHECK(CalledOnValidThread()); + DCHECK_GT(backlog, 0); + DCHECK_EQ(socket_, INVALID_SOCKET); + DCHECK_EQ(socket_event_, WSA_INVALID_EVENT); + + socket_event_ = WSACreateEvent(); + if (socket_event_ == WSA_INVALID_EVENT) { + PLOG(ERROR) << "WSACreateEvent()"; + return ERR_FAILED; + } + + socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP); + if (socket_ == INVALID_SOCKET) { + PLOG(ERROR) << "socket() returned an error"; + return MapSystemError(WSAGetLastError()); + } + + if (SetNonBlocking(socket_)) { + int result = MapSystemError(WSAGetLastError()); + Close(); + return result; + } + + int result = SetSocketOptions(); + if (result != OK) { + Close(); + return result; + } + + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) { + Close(); + return ERR_ADDRESS_INVALID; + } + + result = bind(socket_, storage.addr, storage.addr_len); + if (result < 0) { + PLOG(ERROR) << "bind() returned an error"; + result = MapSystemError(WSAGetLastError()); + Close(); + return result; + } + + result = listen(socket_, backlog); + if (result < 0) { + PLOG(ERROR) << "listen() returned an error"; + result = MapSystemError(WSAGetLastError()); + Close(); + return result; + } + + return OK; +} + +int TCPServerSocketWin::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len)) + return MapSystemError(WSAGetLastError()); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_FAILED; + + return OK; +} + +int TCPServerSocketWin::Accept( + scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(socket); + DCHECK(!callback.is_null()); + DCHECK(accept_callback_.is_null()); + + net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); + + int result = AcceptInternal(socket); + + if (result == ERR_IO_PENDING) { + // Start watching + WSAEventSelect(socket_, socket_event_, FD_ACCEPT); + accept_watcher_.StartWatching(socket_event_, this); + + accept_socket_ = socket; + accept_callback_ = callback; + } + + return result; +} + +int TCPServerSocketWin::SetSocketOptions() { + // On Windows, a bound end point can be hijacked by another process by + // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE + // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the + // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another + // socket to forcibly bind to the end point until the end point is unbound. + // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE. + // MSDN: http://goo.gl/M6fjQ. + // + // Unlike on *nix, on Windows a TCP server socket can always bind to an end + // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not + // needed here. + // + // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end + // point in TIME_WAIT status. It does not have this effect for a TCP server + // socket. + + BOOL true_value = 1; + int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast<const char*>(&true_value), + sizeof(true_value)); + if (rv < 0) + return MapSystemError(errno); + return OK; +} + +int TCPServerSocketWin::AcceptInternal(scoped_ptr<StreamSocket>* socket) { + SockaddrStorage storage; + int new_socket = accept(socket_, storage.addr, &storage.addr_len); + if (new_socket < 0) { + int net_error = MapSystemError(WSAGetLastError()); + if (net_error != ERR_IO_PENDING) + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); + return net_error; + } + + IPEndPoint address; + if (!address.FromSockAddr(storage.addr, storage.addr_len)) { + NOTREACHED(); + if (closesocket(new_socket) < 0) + PLOG(ERROR) << "closesocket"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); + return ERR_FAILED; + } + scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket( + AddressList(address), + net_log_.net_log(), net_log_.source())); + int adopt_result = tcp_socket->AdoptSocket(new_socket); + if (adopt_result != OK) { + if (closesocket(new_socket) < 0) + PLOG(ERROR) << "closesocket"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); + return adopt_result; + } + socket->reset(tcp_socket.release()); + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(&address)); + return OK; +} + +void TCPServerSocketWin::Close() { + if (socket_ != INVALID_SOCKET) { + if (closesocket(socket_) < 0) + PLOG(ERROR) << "closesocket"; + socket_ = INVALID_SOCKET; + } + + if (socket_event_) { + WSACloseEvent(socket_event_); + socket_event_ = WSA_INVALID_EVENT; + } +} + +void TCPServerSocketWin::OnObjectSignaled(HANDLE object) { + WSANETWORKEVENTS ev; + if (WSAEnumNetworkEvents(socket_, socket_event_, &ev) == SOCKET_ERROR) { + PLOG(ERROR) << "WSAEnumNetworkEvents()"; + return; + } + + if (ev.lNetworkEvents & FD_ACCEPT) { + int result = AcceptInternal(accept_socket_); + if (result != ERR_IO_PENDING) { + accept_socket_ = NULL; + CompletionCallback callback = accept_callback_; + accept_callback_.Reset(); + callback.Run(result); + } + } +} + +} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_win.h b/chromium/net/socket/tcp_server_socket_win.h new file mode 100644 index 00000000000..5a1d378ad9b --- /dev/null +++ b/chromium/net/socket/tcp_server_socket_win.h @@ -0,0 +1,58 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ +#define NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ + +#include <winsock2.h> + +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/non_thread_safe.h" +#include "base/win/object_watcher.h" +#include "net/base/completion_callback.h" +#include "net/base/net_log.h" +#include "net/socket/server_socket.h" + +namespace net { + +class IPEndPoint; + +class NET_EXPORT_PRIVATE TCPServerSocketWin + : public ServerSocket, + NON_EXPORTED_BASE(public base::NonThreadSafe), + public base::win::ObjectWatcher::Delegate { + public: + TCPServerSocketWin(net::NetLog* net_log, + const net::NetLog::Source& source); + ~TCPServerSocketWin(); + + // net::ServerSocket implementation. + virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) OVERRIDE; + + // base::ObjectWatcher::Delegate implementation. + virtual void OnObjectSignaled(HANDLE object); + + private: + int SetSocketOptions(); + int AcceptInternal(scoped_ptr<StreamSocket>* socket); + void Close(); + + SOCKET socket_; + HANDLE socket_event_; + + base::win::ObjectWatcher accept_watcher_; + + scoped_ptr<StreamSocket>* accept_socket_; + CompletionCallback accept_callback_; + + BoundNetLog net_log_; +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc new file mode 100644 index 00000000000..6d0afac59fb --- /dev/null +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -0,0 +1,477 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/transport_client_socket_pool.h" + +#include <algorithm> + +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/message_loop/message_loop.h" +#include "base/metrics/histogram.h" +#include "base/strings/string_util.h" +#include "base/time/time.h" +#include "base/values.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/socket_net_log_params.h" +#include "net/socket/tcp_client_socket.h" + +using base::TimeDelta; + +namespace net { + +// TODO(willchan): Base this off RTT instead of statically setting it. Note we +// choose a timeout that is different from the backup connect job timer so they +// don't synchronize. +const int TransportConnectJob::kIPv6FallbackTimerInMs = 300; + +namespace { + +// Returns true iff all addresses in |list| are in the IPv6 family. +bool AddressListOnlyContainsIPv6(const AddressList& list) { + DCHECK(!list.empty()); + for (AddressList::const_iterator iter = list.begin(); iter != list.end(); + ++iter) { + if (iter->GetFamily() != ADDRESS_FAMILY_IPV6) + return false; + } + return true; +} + +} // namespace + +TransportSocketParams::TransportSocketParams( + const HostPortPair& host_port_pair, + RequestPriority priority, + bool disable_resolver_cache, + bool ignore_limits, + const OnHostResolutionCallback& host_resolution_callback) + : destination_(host_port_pair), + ignore_limits_(ignore_limits), + host_resolution_callback_(host_resolution_callback) { + Initialize(priority, disable_resolver_cache); +} + +TransportSocketParams::~TransportSocketParams() {} + +void TransportSocketParams::Initialize(RequestPriority priority, + bool disable_resolver_cache) { + destination_.set_priority(priority); + if (disable_resolver_cache) + destination_.set_allow_cached_response(false); +} + +// TransportConnectJobs will time out after this many seconds. Note this is +// the total time, including both host resolution and TCP connect() times. +// +// TODO(eroman): The use of this constant needs to be re-evaluated. The time +// needed for TCPClientSocketXXX::Connect() can be arbitrarily long, since +// the address list may contain many alternatives, and most of those may +// timeout. Even worse, the per-connect timeout threshold varies greatly +// between systems (anywhere from 20 seconds to 190 seconds). +// See comment #12 at http://crbug.com/23364 for specifics. +static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. + +TransportConnectJob::TransportConnectJob( + const std::string& group_name, + const scoped_refptr<TransportSocketParams>& params, + base::TimeDelta timeout_duration, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + Delegate* delegate, + NetLog* net_log) + : ConnectJob(group_name, timeout_duration, delegate, + BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), + params_(params), + client_socket_factory_(client_socket_factory), + resolver_(host_resolver), + next_state_(STATE_NONE) { +} + +TransportConnectJob::~TransportConnectJob() { + // We don't worry about cancelling the host resolution and TCP connect, since + // ~SingleRequestHostResolver and ~StreamSocket will take care of it. +} + +LoadState TransportConnectJob::GetLoadState() const { + switch (next_state_) { + case STATE_RESOLVE_HOST: + case STATE_RESOLVE_HOST_COMPLETE: + return LOAD_STATE_RESOLVING_HOST; + case STATE_TRANSPORT_CONNECT: + case STATE_TRANSPORT_CONNECT_COMPLETE: + return LOAD_STATE_CONNECTING; + default: + NOTREACHED(); + return LOAD_STATE_IDLE; + } +} + +// static +void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) { + for (AddressList::iterator i = list->begin(); i != list->end(); ++i) { + if (i->GetFamily() == ADDRESS_FAMILY_IPV4) { + std::rotate(list->begin(), i, list->end()); + break; + } + } +} + +void TransportConnectJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + NotifyDelegateOfCompletion(rv); // Deletes |this| +} + +int TransportConnectJob::DoLoop(int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_RESOLVE_HOST: + DCHECK_EQ(OK, rv); + rv = DoResolveHost(); + break; + case STATE_RESOLVE_HOST_COMPLETE: + rv = DoResolveHostComplete(rv); + break; + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = DoTransportConnectComplete(rv); + break; + default: + NOTREACHED(); + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + return rv; +} + +int TransportConnectJob::DoResolveHost() { + next_state_ = STATE_RESOLVE_HOST_COMPLETE; + connect_timing_.dns_start = base::TimeTicks::Now(); + + return resolver_.Resolve( + params_->destination(), &addresses_, + base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)), + net_log()); +} + +int TransportConnectJob::DoResolveHostComplete(int result) { + connect_timing_.dns_end = base::TimeTicks::Now(); + // Overwrite connection start time, since for connections that do not go + // through proxies, |connect_start| should not include dns lookup time. + connect_timing_.connect_start = connect_timing_.dns_end; + + if (result == OK) { + // Invoke callback, and abort if it fails. + if (!params_->host_resolution_callback().is_null()) + result = params_->host_resolution_callback().Run(addresses_, net_log()); + + if (result == OK) + next_state_ = STATE_TRANSPORT_CONNECT; + } + return result; +} + +int TransportConnectJob::DoTransportConnect() { + next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; + transport_socket_ = client_socket_factory_->CreateTransportClientSocket( + addresses_, net_log().net_log(), net_log().source()); + int rv = transport_socket_->Connect( + base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); + if (rv == ERR_IO_PENDING && + addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV6 && + !AddressListOnlyContainsIPv6(addresses_)) { + fallback_timer_.Start(FROM_HERE, + base::TimeDelta::FromMilliseconds(kIPv6FallbackTimerInMs), + this, &TransportConnectJob::DoIPv6FallbackTransportConnect); + } + return rv; +} + +int TransportConnectJob::DoTransportConnectComplete(int result) { + if (result == OK) { + bool is_ipv4 = addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV4; + DCHECK(!connect_timing_.connect_start.is_null()); + DCHECK(!connect_timing_.dns_start.is_null()); + base::TimeTicks now = base::TimeTicks::Now(); + base::TimeDelta total_duration = now - connect_timing_.dns_start; + UMA_HISTOGRAM_CUSTOM_TIMES( + "Net.DNS_Resolution_And_TCP_Connection_Latency2", + total_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + base::TimeDelta connect_duration = now - connect_timing_.connect_start; + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + if (is_ipv4) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + } else { + if (AddressListOnlyContainsIPv6(addresses_)) { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + } else { + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + } + } + SetSocket(transport_socket_.Pass()); + fallback_timer_.Stop(); + } else { + // Be a bit paranoid and kill off the fallback members to prevent reuse. + fallback_transport_socket_.reset(); + fallback_addresses_.reset(); + } + + return result; +} + +void TransportConnectJob::DoIPv6FallbackTransportConnect() { + // The timer should only fire while we're waiting for the main connect to + // succeed. + if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) { + NOTREACHED(); + return; + } + + DCHECK(!fallback_transport_socket_.get()); + DCHECK(!fallback_addresses_.get()); + + fallback_addresses_.reset(new AddressList(addresses_)); + MakeAddressListStartWithIPv4(fallback_addresses_.get()); + fallback_transport_socket_ = + client_socket_factory_->CreateTransportClientSocket( + *fallback_addresses_, net_log().net_log(), net_log().source()); + fallback_connect_start_time_ = base::TimeTicks::Now(); + int rv = fallback_transport_socket_->Connect( + base::Bind( + &TransportConnectJob::DoIPv6FallbackTransportConnectComplete, + base::Unretained(this))); + if (rv != ERR_IO_PENDING) + DoIPv6FallbackTransportConnectComplete(rv); +} + +void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { + // This should only happen when we're waiting for the main connect to succeed. + if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) { + NOTREACHED(); + return; + } + + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(fallback_transport_socket_.get()); + DCHECK(fallback_addresses_.get()); + + if (result == OK) { + DCHECK(!fallback_connect_start_time_.is_null()); + DCHECK(!connect_timing_.dns_start.is_null()); + base::TimeTicks now = base::TimeTicks::Now(); + base::TimeDelta total_duration = now - connect_timing_.dns_start; + UMA_HISTOGRAM_CUSTOM_TIMES( + "Net.DNS_Resolution_And_TCP_Connection_Latency2", + total_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + base::TimeDelta connect_duration = now - fallback_connect_start_time_; + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + SetSocket(fallback_transport_socket_.Pass()); + next_state_ = STATE_NONE; + transport_socket_.reset(); + } else { + // Be a bit paranoid and kill off the fallback members to prevent reuse. + fallback_transport_socket_.reset(); + fallback_addresses_.reset(); + } + NotifyDelegateOfCompletion(result); // Deletes |this| +} + +int TransportConnectJob::ConnectInternal() { + next_state_ = STATE_RESOLVE_HOST; + return DoLoop(OK); +} + +scoped_ptr<ConnectJob> + TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return scoped_ptr<ConnectJob>( + new TransportConnectJob(group_name, + request.params(), + ConnectionTimeout(), + client_socket_factory_, + host_resolver_, + delegate, + net_log_)); +} + +base::TimeDelta + TransportClientSocketPool::TransportConnectJobFactory::ConnectionTimeout() + const { + return base::TimeDelta::FromSeconds(kTransportConnectJobTimeoutInSeconds); +} + +TransportClientSocketPool::TransportClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + ClientSocketFactory* client_socket_factory, + NetLog* net_log) + : base_(max_sockets, max_sockets_per_group, histograms, + ClientSocketPool::unused_idle_socket_timeout(), + ClientSocketPool::used_idle_socket_timeout(), + new TransportConnectJobFactory(client_socket_factory, + host_resolver, net_log)) { + base_.EnableConnectBackupJobs(); +} + +TransportClientSocketPool::~TransportClientSocketPool() {} + +int TransportClientSocketPool::RequestSocket( + const std::string& group_name, + const void* params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { + const scoped_refptr<TransportSocketParams>* casted_params = + static_cast<const scoped_refptr<TransportSocketParams>*>(params); + + if (net_log.IsLoggingAllEvents()) { + // TODO(eroman): Split out the host and port parameters. + net_log.AddEvent( + NetLog::TYPE_TCP_CLIENT_SOCKET_POOL_REQUESTED_SOCKET, + CreateNetLogHostPortPairCallback( + &casted_params->get()->destination().host_port_pair())); + } + + return base_.RequestSocket(group_name, *casted_params, priority, handle, + callback, net_log); +} + +void TransportClientSocketPool::RequestSockets( + const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) { + const scoped_refptr<TransportSocketParams>* casted_params = + static_cast<const scoped_refptr<TransportSocketParams>*>(params); + + if (net_log.IsLoggingAllEvents()) { + // TODO(eroman): Split out the host and port parameters. + net_log.AddEvent( + NetLog::TYPE_TCP_CLIENT_SOCKET_POOL_REQUESTED_SOCKETS, + CreateNetLogHostPortPairCallback( + &casted_params->get()->destination().host_port_pair())); + } + + base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); +} + +void TransportClientSocketPool::CancelRequest( + const std::string& group_name, + ClientSocketHandle* handle) { + base_.CancelRequest(group_name, handle); +} + +void TransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); +} + +void TransportClientSocketPool::FlushWithError(int error) { + base_.FlushWithError(error); +} + +bool TransportClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void TransportClientSocketPool::CloseIdleSockets() { + base_.CloseIdleSockets(); +} + +int TransportClientSocketPool::IdleSocketCount() const { + return base_.idle_socket_count(); +} + +int TransportClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return base_.IdleSocketCountInGroup(group_name); +} + +LoadState TransportClientSocketPool::GetLoadState( + const std::string& group_name, const ClientSocketHandle* handle) const { + return base_.GetLoadState(group_name, handle); +} + +void TransportClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void TransportClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + +base::DictionaryValue* TransportClientSocketPool::GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const { + return base_.GetInfoAsValue(name, type); +} + +base::TimeDelta TransportClientSocketPool::ConnectionTimeout() const { + return base_.ConnectionTimeout(); +} + +ClientSocketPoolHistograms* TransportClientSocketPool::histograms() const { + return base_.histograms(); +} + +} // namespace net diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h new file mode 100644 index 00000000000..f07dc1f5675 --- /dev/null +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -0,0 +1,221 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +#include "net/base/host_port_pair.h" +#include "net/dns/host_resolver.h" +#include "net/dns/single_request_host_resolver.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/client_socket_pool_histograms.h" + +namespace net { + +class ClientSocketFactory; + +typedef base::Callback<int(const AddressList&, const BoundNetLog& net_log)> +OnHostResolutionCallback; + +class NET_EXPORT_PRIVATE TransportSocketParams + : public base::RefCounted<TransportSocketParams> { + public: + // |host_resolution_callback| will be invoked after the the hostname is + // resolved. If |host_resolution_callback| does not return OK, then the + // connection will be aborted with that value. + TransportSocketParams( + const HostPortPair& host_port_pair, + RequestPriority priority, + bool disable_resolver_cache, + bool ignore_limits, + const OnHostResolutionCallback& host_resolution_callback); + + const HostResolver::RequestInfo& destination() const { return destination_; } + bool ignore_limits() const { return ignore_limits_; } + const OnHostResolutionCallback& host_resolution_callback() const { + return host_resolution_callback_; + } + + private: + friend class base::RefCounted<TransportSocketParams>; + ~TransportSocketParams(); + + void Initialize(RequestPriority priority, bool disable_resolver_cache); + + HostResolver::RequestInfo destination_; + bool ignore_limits_; + const OnHostResolutionCallback host_resolution_callback_; + + DISALLOW_COPY_AND_ASSIGN(TransportSocketParams); +}; + +// TransportConnectJob handles the host resolution necessary for socket creation +// and the transport (likely TCP) connect. TransportConnectJob also has fallback +// logic for IPv6 connect() timeouts (which may happen due to networks / routers +// with broken IPv6 support). Those timeouts take 20s, so rather than make the +// user wait 20s for the timeout to fire, we use a fallback timer +// (kIPv6FallbackTimerInMs) and start a connect() to a IPv4 address if the timer +// fires. Then we race the IPv4 connect() against the IPv6 connect() (which has +// a headstart) and return the one that completes first to the socket pool. +class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { + public: + TransportConnectJob(const std::string& group_name, + const scoped_refptr<TransportSocketParams>& params, + base::TimeDelta timeout_duration, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + Delegate* delegate, + NetLog* net_log); + virtual ~TransportConnectJob(); + + // ConnectJob methods. + virtual LoadState GetLoadState() const OVERRIDE; + + // Rolls |addrlist| forward until the first IPv4 address, if any. + // WARNING: this method should only be used to implement the prefer-IPv4 hack. + static void MakeAddressListStartWithIPv4(AddressList* addrlist); + + static const int kIPv6FallbackTimerInMs; + + private: + enum State { + STATE_RESOLVE_HOST, + STATE_RESOLVE_HOST_COMPLETE, + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_NONE, + }; + + void OnIOComplete(int result); + + // Runs the state transition loop. + int DoLoop(int result); + + int DoResolveHost(); + int DoResolveHostComplete(int result); + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + + // Not part of the state machine. + void DoIPv6FallbackTransportConnect(); + void DoIPv6FallbackTransportConnectComplete(int result); + + // Begins the host resolution and the TCP connect. Returns OK on success + // and ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal() OVERRIDE; + + scoped_refptr<TransportSocketParams> params_; + ClientSocketFactory* const client_socket_factory_; + SingleRequestHostResolver resolver_; + AddressList addresses_; + State next_state_; + + scoped_ptr<StreamSocket> transport_socket_; + + scoped_ptr<StreamSocket> fallback_transport_socket_; + scoped_ptr<AddressList> fallback_addresses_; + base::TimeTicks fallback_connect_start_time_; + base::OneShotTimer<TransportConnectJob> fallback_timer_; + + DISALLOW_COPY_AND_ASSIGN(TransportConnectJob); +}; + +class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { + public: + TransportClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + ClientSocketFactory* client_socket_factory, + NetLog* net_log); + + virtual ~TransportClientSocketPool(); + + // ClientSocketPool implementation. + virtual int RequestSocket(const std::string& group_name, + const void* resolve_info, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) OVERRIDE; + virtual void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) OVERRIDE; + virtual void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) OVERRIDE; + virtual void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; + virtual void FlushWithError(int error) OVERRIDE; + virtual bool IsStalled() const OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; + virtual int IdleSocketCount() const OVERRIDE; + virtual int IdleSocketCountInGroup( + const std::string& group_name) const OVERRIDE; + virtual LoadState GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const OVERRIDE; + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + virtual base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const OVERRIDE; + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + + private: + typedef ClientSocketPoolBase<TransportSocketParams> PoolBase; + + class TransportConnectJobFactory + : public PoolBase::ConnectJobFactory { + public: + TransportConnectJobFactory(ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + NetLog* net_log) + : client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver), + net_log_(net_log) {} + + virtual ~TransportConnectJobFactory() {} + + // ClientSocketPoolBase::ConnectJobFactory methods. + + virtual scoped_ptr<ConnectJob> NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const OVERRIDE; + + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + + private: + ClientSocketFactory* const client_socket_factory_; + HostResolver* const host_resolver_; + NetLog* net_log_; + + DISALLOW_COPY_AND_ASSIGN(TransportConnectJobFactory); + }; + + PoolBase base_; + + DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool); +}; + +REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool, + TransportSocketParams); + +} // namespace net + +#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc new file mode 100644 index 00000000000..c607a38b78a --- /dev/null +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -0,0 +1,1355 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/transport_client_socket_pool.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/platform_thread.h" +#include "net/base/capturing_net_log.h" +#include "net/base/ip_endpoint.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +using internal::ClientSocketPoolBaseHelper; + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; +const net::RequestPriority kDefaultPriority = LOW; + +// Make sure |handle| sets load times correctly when it has been assigned a +// reused socket. +void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + // Only pass true in as |is_reused|, as in general, HttpStream types should + // have stricter concepts of reuse than socket pools. + EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); + + EXPECT_TRUE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +// Make sure |handle| sets load times correctly when it has been assigned a +// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner +// of a connection where |is_reused| is false may consider the connection +// reused. +void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { + EXPECT_FALSE(handle.is_reused()); + + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + EXPECT_FALSE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasTimes(load_timing_info.connect_timing, + CONNECT_TIMING_HAS_DNS_TIMES); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); + + TestLoadTimingInfoConnectedReused(handle); +} + +void SetIPv4Address(IPEndPoint* address) { + IPAddressNumber number; + CHECK(ParseIPLiteralToNumber("1.1.1.1", &number)); + *address = IPEndPoint(number, 80); +} + +void SetIPv6Address(IPEndPoint* address) { + IPAddressNumber number; + CHECK(ParseIPLiteralToNumber("1:abcd::3:4:ff", &number)); + *address = IPEndPoint(number, 80); +} + +class MockClientSocket : public StreamSocket { + public: + MockClientSocket(const AddressList& addrlist, net::NetLog* net_log) + : connected_(false), + addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { + } + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + connected_ = true; + return OK; + } + virtual void Disconnect() OVERRIDE { + connected_ = false; + } + virtual bool IsConnected() const OVERRIDE { + return connected_; + } + virtual bool IsConnectedAndIdle() const OVERRIDE { + return connected_; + } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_UNEXPECTED; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!connected_) + return ERR_SOCKET_NOT_CONNECTED; + if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) + SetIPv4Address(address); + else + SetIPv6Address(address); + return OK; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return false; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { + return false; + } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + return false; + } + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; } + virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; } + + private: + bool connected_; + const AddressList addrlist_; + BoundNetLog net_log_; +}; + +class MockFailingClientSocket : public StreamSocket { + public: + MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) + : addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { + } + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + return ERR_CONNECTION_FAILED; + } + + virtual void Disconnect() OVERRIDE {} + + virtual bool IsConnected() const OVERRIDE { + return false; + } + virtual bool IsConnectedAndIdle() const OVERRIDE { + return false; + } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_UNEXPECTED; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + return ERR_UNEXPECTED; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return false; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { + return false; + } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + return false; + } + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; } + virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; } + + private: + const AddressList addrlist_; + BoundNetLog net_log_; +}; + +class MockPendingClientSocket : public StreamSocket { + public: + // |should_connect| indicates whether the socket should successfully complete + // or fail. + // |should_stall| indicates that this socket should never connect. + // |delay_ms| is the delay, in milliseconds, before simulating a connect. + MockPendingClientSocket( + const AddressList& addrlist, + bool should_connect, + bool should_stall, + base::TimeDelta delay, + net::NetLog* net_log) + : weak_factory_(this), + should_connect_(should_connect), + should_stall_(should_stall), + delay_(delay), + is_connected_(false), + addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { + } + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE { + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&MockPendingClientSocket::DoCallback, + weak_factory_.GetWeakPtr(), callback), + delay_); + return ERR_IO_PENDING; + } + + virtual void Disconnect() OVERRIDE {} + + virtual bool IsConnected() const OVERRIDE { + return is_connected_; + } + virtual bool IsConnectedAndIdle() const OVERRIDE { + return is_connected_; + } + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + return ERR_UNEXPECTED; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!is_connected_) + return ERR_SOCKET_NOT_CONNECTED; + if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) + SetIPv4Address(address); + else + SetIPv6Address(address); + return OK; + } + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + virtual void SetSubresourceSpeculation() OVERRIDE {} + virtual void SetOmniboxSpeculation() OVERRIDE {} + virtual bool WasEverUsed() const OVERRIDE { return false; } + virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } + virtual bool WasNpnNegotiated() const OVERRIDE { + return false; + } + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + return kProtoUnknown; + } + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + return false; + } + + // Socket implementation. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + return ERR_FAILED; + } + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; } + virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; } + + private: + void DoCallback(const CompletionCallback& callback) { + if (should_stall_) + return; + + if (should_connect_) { + is_connected_ = true; + callback.Run(OK); + } else { + is_connected_ = false; + callback.Run(ERR_CONNECTION_FAILED); + } + } + + base::WeakPtrFactory<MockPendingClientSocket> weak_factory_; + bool should_connect_; + bool should_stall_; + base::TimeDelta delay_; + bool is_connected_; + const AddressList addrlist_; + BoundNetLog net_log_; +}; + +class MockClientSocketFactory : public ClientSocketFactory { + public: + enum ClientSocketType { + MOCK_CLIENT_SOCKET, + MOCK_FAILING_CLIENT_SOCKET, + MOCK_PENDING_CLIENT_SOCKET, + MOCK_PENDING_FAILING_CLIENT_SOCKET, + // A delayed socket will pause before connecting through the message loop. + MOCK_DELAYED_CLIENT_SOCKET, + // A stalled socket that never connects at all. + MOCK_STALLED_CLIENT_SOCKET, + }; + + explicit MockClientSocketFactory(NetLog* net_log) + : net_log_(net_log), allocation_count_(0), + client_socket_type_(MOCK_CLIENT_SOCKET), client_socket_types_(NULL), + client_socket_index_(0), client_socket_index_max_(0), + delay_(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs)) {} + + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) OVERRIDE { + NOTREACHED(); + return scoped_ptr<DatagramClientSocket>(); + } + + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* /* net_log */, + const NetLog::Source& /* source */) OVERRIDE { + allocation_count_++; + + ClientSocketType type = client_socket_type_; + if (client_socket_types_ && + client_socket_index_ < client_socket_index_max_) { + type = client_socket_types_[client_socket_index_++]; + } + + switch (type) { + case MOCK_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); + case MOCK_FAILING_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockFailingClientSocket(addresses, net_log_)); + case MOCK_PENDING_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, base::TimeDelta(), net_log_)); + case MOCK_PENDING_FAILING_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, false, false, base::TimeDelta(), net_log_)); + case MOCK_DELAYED_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, delay_, net_log_)); + case MOCK_STALLED_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, true, base::TimeDelta(), net_log_)); + default: + NOTREACHED(); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); + } + } + + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE { + NOTIMPLEMENTED(); + return scoped_ptr<SSLClientSocket>(); + } + + virtual void ClearSSLSessionCache() OVERRIDE { + NOTIMPLEMENTED(); + } + + int allocation_count() const { return allocation_count_; } + + // Set the default ClientSocketType. + void set_client_socket_type(ClientSocketType type) { + client_socket_type_ = type; + } + + // Set a list of ClientSocketTypes to be used. + void set_client_socket_types(ClientSocketType* type_list, int num_types) { + DCHECK_GT(num_types, 0); + client_socket_types_ = type_list; + client_socket_index_ = 0; + client_socket_index_max_ = num_types; + } + + void set_delay(base::TimeDelta delay) { delay_ = delay; } + + private: + NetLog* net_log_; + int allocation_count_; + ClientSocketType client_socket_type_; + ClientSocketType* client_socket_types_; + int client_socket_index_; + int client_socket_index_max_; + base::TimeDelta delay_; +}; + +class TransportClientSocketPoolTest : public testing::Test { + protected: + TransportClientSocketPoolTest() + : connect_backup_jobs_enabled_( + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)), + params_( + new TransportSocketParams(HostPortPair("www.google.com", 80), + kDefaultPriority, false, false, + OnHostResolutionCallback())), + low_params_( + new TransportSocketParams(HostPortPair("www.google.com", 80), + LOW, false, false, + OnHostResolutionCallback())), + histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), + host_resolver_(new MockHostResolver), + client_socket_factory_(&net_log_), + pool_(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL) { + } + + virtual ~TransportClientSocketPoolTest() { + internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled( + connect_backup_jobs_enabled_); + } + + int StartRequest(const std::string& group_name, RequestPriority priority) { + scoped_refptr<TransportSocketParams> params(new TransportSocketParams( + HostPortPair("www.google.com", 80), MEDIUM, false, false, + OnHostResolutionCallback())); + return test_base_.StartRequestUsingPool( + &pool_, group_name, priority, params); + } + + int GetOrderOfRequest(size_t index) { + return test_base_.GetOrderOfRequest(index); + } + + bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) { + return test_base_.ReleaseOneConnection(keep_alive); + } + + void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) { + test_base_.ReleaseAllConnections(keep_alive); + } + + ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } + size_t completion_count() const { return test_base_.completion_count(); } + + bool connect_backup_jobs_enabled_; + CapturingNetLog net_log_; + scoped_refptr<TransportSocketParams> params_; + scoped_refptr<TransportSocketParams> low_params_; + scoped_ptr<ClientSocketPoolHistograms> histograms_; + scoped_ptr<MockHostResolver> host_resolver_; + MockClientSocketFactory client_socket_factory_; + TransportClientSocketPool pool_; + ClientSocketPoolTest test_base_; +}; + +TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) { + IPAddressNumber ip_number; + ASSERT_TRUE(ParseIPLiteralToNumber("192.168.1.1", &ip_number)); + IPEndPoint addrlist_v4_1(ip_number, 80); + ASSERT_TRUE(ParseIPLiteralToNumber("192.168.1.2", &ip_number)); + IPEndPoint addrlist_v4_2(ip_number, 80); + ASSERT_TRUE(ParseIPLiteralToNumber("2001:4860:b006::64", &ip_number)); + IPEndPoint addrlist_v6_1(ip_number, 80); + ASSERT_TRUE(ParseIPLiteralToNumber("2001:4860:b006::66", &ip_number)); + IPEndPoint addrlist_v6_2(ip_number, 80); + + AddressList addrlist; + + // Test 1: IPv4 only. Expect no change. + addrlist.clear(); + addrlist.push_back(addrlist_v4_1); + addrlist.push_back(addrlist_v4_2); + TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist); + ASSERT_EQ(2u, addrlist.size()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[1].GetFamily()); + + // Test 2: IPv6 only. Expect no change. + addrlist.clear(); + addrlist.push_back(addrlist_v6_1); + addrlist.push_back(addrlist_v6_2); + TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist); + ASSERT_EQ(2u, addrlist.size()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[0].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[1].GetFamily()); + + // Test 3: IPv4 then IPv6. Expect no change. + addrlist.clear(); + addrlist.push_back(addrlist_v4_1); + addrlist.push_back(addrlist_v4_2); + addrlist.push_back(addrlist_v6_1); + addrlist.push_back(addrlist_v6_2); + TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist); + ASSERT_EQ(4u, addrlist.size()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[1].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[2].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[3].GetFamily()); + + // Test 4: IPv6, IPv4, IPv6, IPv4. Expect first IPv6 moved to the end. + addrlist.clear(); + addrlist.push_back(addrlist_v6_1); + addrlist.push_back(addrlist_v4_1); + addrlist.push_back(addrlist_v6_2); + addrlist.push_back(addrlist_v4_2); + TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist); + ASSERT_EQ(4u, addrlist.size()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[1].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[2].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[3].GetFamily()); + + // Test 5: IPv6, IPv6, IPv4, IPv4. Expect first two IPv6's moved to the end. + addrlist.clear(); + addrlist.push_back(addrlist_v6_1); + addrlist.push_back(addrlist_v6_2); + addrlist.push_back(addrlist_v4_1); + addrlist.push_back(addrlist_v4_2); + TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist); + ASSERT_EQ(4u, addrlist.size()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[1].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[2].GetFamily()); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[3].GetFamily()); +} + +TEST_F(TransportClientSocketPoolTest, Basic) { + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoConnectedNotReused(handle); +} + +TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { + host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name"); + TestCompletionCallback callback; + ClientSocketHandle handle; + HostPortPair host_port_pair("unresolvable.host.name", 80); + scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( + host_port_pair, kDefaultPriority, false, false, + OnHostResolutionCallback())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", dest, kDefaultPriority, callback.callback(), + &pool_, BoundNetLog())); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult()); +} + +TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, kDefaultPriority, callback.callback(), + &pool_, BoundNetLog())); + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + + // Make the host resolutions complete synchronously this time. + host_resolver_->set_synchronous_mode(true); + EXPECT_EQ(ERR_CONNECTION_FAILED, + handle.Init("a", params_, kDefaultPriority, callback.callback(), + &pool_, BoundNetLog())); +} + +TEST_F(TransportClientSocketPoolTest, PendingRequests) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, (*requests())[0]->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + // Rest of them finish synchronously, until we reach the per-group limit. + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + + // The rest are pending since we've used all active sockets. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(kMaxSocketsPerGroup, client_socket_factory_.allocation_count()); + + // One initial asynchronous request and then 10 pending requests. + EXPECT_EQ(11U, completion_count()); + + // First part of requests, all with the same priority, finishes in FIFO order. + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + + // Make sure that rest of the requests complete in the order of priority. + EXPECT_EQ(7, GetOrderOfRequest(7)); + EXPECT_EQ(14, GetOrderOfRequest(8)); + EXPECT_EQ(15, GetOrderOfRequest(9)); + EXPECT_EQ(10, GetOrderOfRequest(10)); + EXPECT_EQ(13, GetOrderOfRequest(11)); + EXPECT_EQ(8, GetOrderOfRequest(12)); + EXPECT_EQ(16, GetOrderOfRequest(13)); + EXPECT_EQ(11, GetOrderOfRequest(14)); + EXPECT_EQ(12, GetOrderOfRequest(15)); + EXPECT_EQ(9, GetOrderOfRequest(16)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(17)); +} + +TEST_F(TransportClientSocketPoolTest, PendingRequests_NoKeepAlive) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, (*requests())[0]->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + // Rest of them finish synchronously, until we reach the per-group limit. + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + + // The rest are pending since we've used all active sockets. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + // The pending requests should finish successfully. + EXPECT_EQ(OK, (*requests())[6]->WaitForResult()); + EXPECT_EQ(OK, (*requests())[7]->WaitForResult()); + EXPECT_EQ(OK, (*requests())[8]->WaitForResult()); + EXPECT_EQ(OK, (*requests())[9]->WaitForResult()); + EXPECT_EQ(OK, (*requests())[10]->WaitForResult()); + + EXPECT_EQ(static_cast<int>(requests()->size()), + client_socket_factory_.allocation_count()); + + // First asynchronous request, and then last 5 pending requests. + EXPECT_EQ(6U, completion_count()); +} + +// This test will start up a RequestSocket() and then immediately Cancel() it. +// The pending host resolution will eventually complete, and destroy the +// ClientSocketPool which will crash if the group was not cleared properly. +TEST_F(TransportClientSocketPoolTest, CancelRequestClearGroup) { + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, kDefaultPriority, callback.callback(), + &pool_, BoundNetLog())); + handle.Reset(); +} + +TEST_F(TransportClientSocketPoolTest, TwoRequestsCancelOne) { + ClientSocketHandle handle; + TestCompletionCallback callback; + ClientSocketHandle handle2; + TestCompletionCallback callback2; + + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, kDefaultPriority, callback.callback(), + &pool_, BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", params_, kDefaultPriority, callback2.callback(), + &pool_, BoundNetLog())); + + handle.Reset(); + + EXPECT_EQ(OK, callback2.WaitForResult()); + handle2.Reset(); +} + +TEST_F(TransportClientSocketPoolTest, ConnectCancelConnect) { + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, kDefaultPriority, callback.callback(), + &pool_, BoundNetLog())); + + handle.Reset(); + + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, kDefaultPriority, callback2.callback(), + &pool_, BoundNetLog())); + + host_resolver_->set_synchronous_mode(true); + // At this point, handle has two ConnectingSockets out for it. Due to the + // setting the mock resolver into synchronous mode, the host resolution for + // both will return in the same loop of the MessageLoop. The client socket + // is a pending socket, so the Connect() will asynchronously complete on the + // next loop of the MessageLoop. That means that the first + // ConnectingSocket will enter OnIOComplete, and then the second one will. + // If the first one is not cancelled, it will advance the load state, and + // then the second one will crash. + + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_FALSE(callback.have_result()); + + handle.Reset(); +} + +TEST_F(TransportClientSocketPoolTest, CancelRequest) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, (*requests())[0]->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + + // Reached per-group limit, queue up requests. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST)); + + // Cancel a request. + size_t index_to_cancel = kMaxSocketsPerGroup + 2; + EXPECT_FALSE((*requests())[index_to_cancel]->handle()->is_initialized()); + (*requests())[index_to_cancel]->handle()->Reset(); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(kMaxSocketsPerGroup, + client_socket_factory_.allocation_count()); + EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + EXPECT_EQ(14, GetOrderOfRequest(7)); + EXPECT_EQ(7, GetOrderOfRequest(8)); + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, + GetOrderOfRequest(9)); // Canceled request. + EXPECT_EQ(9, GetOrderOfRequest(10)); + EXPECT_EQ(10, GetOrderOfRequest(11)); + EXPECT_EQ(11, GetOrderOfRequest(12)); + EXPECT_EQ(8, GetOrderOfRequest(13)); + EXPECT_EQ(12, GetOrderOfRequest(14)); + EXPECT_EQ(13, GetOrderOfRequest(15)); + EXPECT_EQ(15, GetOrderOfRequest(16)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(17)); +} + +class RequestSocketCallback : public TestCompletionCallbackBase { + public: + RequestSocketCallback(ClientSocketHandle* handle, + TransportClientSocketPool* pool) + : handle_(handle), + pool_(pool), + within_callback_(false), + callback_(base::Bind(&RequestSocketCallback::OnComplete, + base::Unretained(this))) { + } + + virtual ~RequestSocketCallback() {} + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + SetResult(result); + ASSERT_EQ(OK, result); + + if (!within_callback_) { + // Don't allow reuse of the socket. Disconnect it and then release it and + // run through the MessageLoop once to get it completely released. + handle_->socket()->Disconnect(); + handle_->Reset(); + { + base::MessageLoop::ScopedNestableTaskAllower allow( + base::MessageLoop::current()); + base::MessageLoop::current()->RunUntilIdle(); + } + within_callback_ = true; + scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( + HostPortPair("www.google.com", 80), LOWEST, false, false, + OnHostResolutionCallback())); + int rv = handle_->Init("a", dest, LOWEST, callback(), pool_, + BoundNetLog()); + EXPECT_EQ(OK, rv); + } + } + + ClientSocketHandle* const handle_; + TransportClientSocketPool* const pool_; + bool within_callback_; + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(RequestSocketCallback); +}; + +TEST_F(TransportClientSocketPoolTest, RequestTwice) { + ClientSocketHandle handle; + RequestSocketCallback callback(&handle, &pool_); + scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( + HostPortPair("www.google.com", 80), LOWEST, false, false, + OnHostResolutionCallback())); + int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_, + BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + + // The callback is going to request "www.google.com". We want it to complete + // synchronously this time. + host_resolver_->set_synchronous_mode(true); + + EXPECT_EQ(OK, callback.WaitForResult()); + + handle.Reset(); +} + +// Make sure that pending requests get serviced after active requests get +// cancelled. +TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + + // Queue up all the requests + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Now, kMaxSocketsPerGroup requests should be active. Let's cancel them. + ASSERT_LE(kMaxSocketsPerGroup, static_cast<int>(requests()->size())); + for (int i = 0; i < kMaxSocketsPerGroup; i++) + (*requests())[i]->handle()->Reset(); + + // Let's wait for the rest to complete now. + for (size_t i = kMaxSocketsPerGroup; i < requests()->size(); ++i) { + EXPECT_EQ(OK, (*requests())[i]->WaitForResult()); + (*requests())[i]->handle()->Reset(); + } + + EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count()); +} + +// Make sure that pending requests get serviced after active requests fail. +TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); + + const int kNumRequests = 2 * kMaxSocketsPerGroup + 1; + ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang. + + // Queue up all the requests + for (int i = 0; i < kNumRequests; i++) + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + for (int i = 0; i < kNumRequests; i++) + EXPECT_EQ(ERR_CONNECTION_FAILED, (*requests())[i]->WaitForResult()); +} + +TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoConnectedNotReused(handle); + + handle.Reset(); + // Need to run all pending to release the socket back to the pool. + base::MessageLoop::current()->RunUntilIdle(); + + // Now we should have 1 idle socket. + EXPECT_EQ(1, pool_.IdleSocketCount()); + + rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(OK, rv); + EXPECT_EQ(0, pool_.IdleSocketCount()); + TestLoadTimingInfoConnectedReused(handle); +} + +TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + + handle.Reset(); + + // Need to run all pending to release the socket back to the pool. + base::MessageLoop::current()->RunUntilIdle(); + + // Now we should have 1 idle socket. + EXPECT_EQ(1, pool_.IdleSocketCount()); + + // After an IP address change, we should have 0 idle sockets. + NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests(); + base::MessageLoop::current()->RunUntilIdle(); // Notification happens async. + + EXPECT_EQ(0, pool_.IdleSocketCount()); +} + +TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { + // Case 1 tests the first socket stalling, and the backup connecting. + MockClientSocketFactory::ClientSocketType case1_types[] = { + // The first socket will not connect. + MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + // The second socket will connect more quickly. + MockClientSocketFactory::MOCK_CLIENT_SOCKET + }; + + // Case 2 tests the first socket being slow, so that we start the + // second connect, but the second connect stalls, and we still + // complete the first. + MockClientSocketFactory::ClientSocketType case2_types[] = { + // The first socket will connect, although delayed. + MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + // The second socket will not connect. + MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET + }; + + MockClientSocketFactory::ClientSocketType* cases[2] = { + case1_types, + case2_types + }; + + for (size_t index = 0; index < arraysize(cases); ++index) { + client_socket_factory_.set_client_socket_types(cases[index], 2); + + EXPECT_EQ(0, pool_.IdleSocketCount()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + // Create the first socket, set the timer. + base::MessageLoop::current()->RunUntilIdle(); + + // Wait for the backup socket timer to fire. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs + 50)); + + // Let the appropriate socket connect. + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + + // One socket is stalled, the other is active. + EXPECT_EQ(0, pool_.IdleSocketCount()); + handle.Reset(); + + // Close all pending connect jobs and existing sockets. + pool_.FlushWithError(ERR_NETWORK_CHANGED); + } +} + +// Test the case where a socket took long enough to start the creation +// of the backup socket, but then we cancelled the request after that. +TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); + + enum { CANCEL_BEFORE_WAIT, CANCEL_AFTER_WAIT }; + + for (int index = CANCEL_BEFORE_WAIT; index < CANCEL_AFTER_WAIT; ++index) { + EXPECT_EQ(0, pool_.IdleSocketCount()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("c", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + // Create the first socket, set the timer. + base::MessageLoop::current()->RunUntilIdle(); + + if (index == CANCEL_AFTER_WAIT) { + // Wait for the backup socket timer to fire. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs)); + } + + // Let the appropriate socket connect. + base::MessageLoop::current()->RunUntilIdle(); + + handle.Reset(); + + EXPECT_FALSE(callback.have_result()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + // One socket is stalled, the other is active. + EXPECT_EQ(0, pool_.IdleSocketCount()); + } +} + +// Test the case where a socket took long enough to start the creation +// of the backup socket and never completes, and then the backup +// connection fails. +TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { + MockClientSocketFactory::ClientSocketType case_types[] = { + // The first socket will not connect. + MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + // The second socket will fail immediately. + MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET + }; + + client_socket_factory_.set_client_socket_types(case_types, 2); + + EXPECT_EQ(0, pool_.IdleSocketCount()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + // Create the first socket, set the timer. + base::MessageLoop::current()->RunUntilIdle(); + + // Wait for the backup socket timer to fire. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs)); + + // Let the second connect be synchronous. Otherwise, the emulated + // host resolution takes an extra trip through the message loop. + host_resolver_->set_synchronous_mode(true); + + // Let the appropriate socket connect. + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + EXPECT_EQ(0, pool_.IdleSocketCount()); + handle.Reset(); + + // Reset for the next case. + host_resolver_->set_synchronous_mode(false); +} + +// Test the case where a socket took long enough to start the creation +// of the backup socket and eventually completes, but the backup socket +// fails. +TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { + MockClientSocketFactory::ClientSocketType case_types[] = { + // The first socket will connect, although delayed. + MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + // The second socket will not connect. + MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET + }; + + client_socket_factory_.set_client_socket_types(case_types, 2); + client_socket_factory_.set_delay(base::TimeDelta::FromSeconds(5)); + + EXPECT_EQ(0, pool_.IdleSocketCount()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + // Create the first socket, set the timer. + base::MessageLoop::current()->RunUntilIdle(); + + // Wait for the backup socket timer to fire. + base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs)); + + // Let the second connect be synchronous. Otherwise, the emulated + // host resolution takes an extra trip through the message loop. + host_resolver_->set_synchronous_mode(true); + + // Let the appropriate socket connect. + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + handle.Reset(); + + // Reset for the next case. + host_resolver_->set_synchronous_mode(false); +} + +// Test the case of the IPv6 address stalling, and falling back to the IPv4 +// socket which finishes first. +TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + }; + + client_socket_factory_.set_client_socket_types(case_types, 2); + + // Resolve an AddressList with a IPv6 address first and then a IPv4 address. + host_resolver_->rules() + ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +// Test the case of the IPv6 address being slow, thus falling back to trying to +// connect to the IPv4 address, but having the connect to the IPv6 address +// finish first. +TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET + }; + + client_socket_factory_.set_client_socket_types(case_types, 2); + client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds( + TransportConnectJob::kIPv6FallbackTimerInMs + 50)); + + // Resolve an AddressList with a IPv6 address first and then a IPv4 address. + host_resolver_->rules() + ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + + // Resolve an AddressList with only IPv6 addresses. + host_resolver_->rules() + ->AddIPLiteralRule("*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); +} + +TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_client_socket_type( + MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + + // Resolve an AddressList with only IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc new file mode 100644 index 00000000000..5c5a303b82f --- /dev/null +++ b/chromium/net/socket/transport_client_socket_unittest.cc @@ -0,0 +1,449 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_client_socket.h" + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_log_unittest.h" +#include "net/base/test_completion_callback.h" +#include "net/base/winsock_init.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/tcp_listen_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { + +const char kServerReply[] = "HTTP/1.1 404 Not Found"; + +enum ClientSocketTestTypes { + TCP, + SCTP +}; + +} // namespace + +class TransportClientSocketTest + : public StreamListenSocket::Delegate, + public ::testing::TestWithParam<ClientSocketTestTypes> { + public: + TransportClientSocketTest() + : listen_port_(0), + socket_factory_(ClientSocketFactory::GetDefaultFactory()), + close_server_socket_on_next_send_(false) { + } + + virtual ~TransportClientSocketTest() { + } + + // Implement StreamListenSocket::Delegate methods + virtual void DidAccept(StreamListenSocket* server, + StreamListenSocket* connection) OVERRIDE { + connected_sock_ = reinterpret_cast<TCPListenSocket*>(connection); + } + virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE { + // TODO(dkegel): this might not be long enough to tickle some bugs. + connected_sock_->Send(kServerReply, arraysize(kServerReply) - 1, + false /* Don't append line feed */); + if (close_server_socket_on_next_send_) + CloseServerSocket(); + } + virtual void DidClose(StreamListenSocket* sock) OVERRIDE {} + + // Testcase hooks + virtual void SetUp(); + + void CloseServerSocket() { + // delete the connected_sock_, which will close it. + connected_sock_ = NULL; + } + + void PauseServerReads() { + connected_sock_->PauseReads(); + } + + void ResumeServerReads() { + connected_sock_->ResumeReads(); + } + + int DrainClientSocket(IOBuffer* buf, + uint32 buf_len, + uint32 bytes_to_read, + TestCompletionCallback* callback); + + void SendClientRequest(); + + void set_close_server_socket_on_next_send(bool close) { + close_server_socket_on_next_send_ = close; + } + + protected: + int listen_port_; + CapturingNetLog net_log_; + ClientSocketFactory* const socket_factory_; + scoped_ptr<StreamSocket> sock_; + + private: + scoped_refptr<TCPListenSocket> listen_sock_; + scoped_refptr<TCPListenSocket> connected_sock_; + bool close_server_socket_on_next_send_; +}; + +void TransportClientSocketTest::SetUp() { + ::testing::TestWithParam<ClientSocketTestTypes>::SetUp(); + + // Find a free port to listen on + scoped_refptr<TCPListenSocket> sock; + int port; + // Range of ports to listen on. Shouldn't need to try many. + const int kMinPort = 10100; + const int kMaxPort = 10200; +#if defined(OS_WIN) + EnsureWinsockInit(); +#endif + for (port = kMinPort; port < kMaxPort; port++) { + sock = TCPListenSocket::CreateAndListen("127.0.0.1", port, this); + if (sock.get()) + break; + } + ASSERT_TRUE(sock.get() != NULL); + listen_sock_ = sock; + listen_port_ = port; + + AddressList addr; + // MockHostResolver resolves everything to 127.0.0.1. + scoped_ptr<HostResolver> resolver(new MockHostResolver()); + HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_)); + TestCompletionCallback callback; + int rv = resolver->Resolve(info, &addr, callback.callback(), NULL, + BoundNetLog()); + CHECK_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + CHECK_EQ(rv, OK); + sock_ = + socket_factory_->CreateTransportClientSocket(addr, + &net_log_, + NetLog::Source()); +} + +int TransportClientSocketTest::DrainClientSocket( + IOBuffer* buf, uint32 buf_len, + uint32 bytes_to_read, TestCompletionCallback* callback) { + int rv = OK; + uint32 bytes_read = 0; + + while (bytes_read < bytes_to_read) { + rv = sock_->Read(buf, buf_len, callback->callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback->WaitForResult(); + + EXPECT_GE(rv, 0); + bytes_read += rv; + } + + return static_cast<int>(bytes_read); +} + +void TransportClientSocketTest::SendClientRequest() { + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); + TestCompletionCallback callback; + int rv; + + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); + rv = sock_->Write( + request_buffer.get(), arraysize(request_text) - 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(rv, static_cast<int>(arraysize(request_text) - 1)); +} + +// TODO(leighton): Add SCTP to this list when it is ready. +INSTANTIATE_TEST_CASE_P(StreamSocket, + TransportClientSocketTest, + ::testing::Values(TCP)); + +TEST_P(TransportClientSocketTest, Connect) { + TestCompletionCallback callback; + EXPECT_FALSE(sock_->IsConnected()); + + int rv = sock_->Connect(callback.callback()); + + net::CapturingNetLog::CapturedEntryList net_log_entries; + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(net::LogContainsBeginEvent( + net_log_entries, 0, net::NetLog::TYPE_SOCKET_ALIVE)); + EXPECT_TRUE(net::LogContainsBeginEvent( + net_log_entries, 1, net::NetLog::TYPE_TCP_CONNECT)); + if (rv != OK) { + ASSERT_EQ(rv, ERR_IO_PENDING); + rv = callback.WaitForResult(); + EXPECT_EQ(rv, OK); + } + + EXPECT_TRUE(sock_->IsConnected()); + net_log_.GetEntries(&net_log_entries); + EXPECT_TRUE(net::LogContainsEndEvent( + net_log_entries, -1, net::NetLog::TYPE_TCP_CONNECT)); + + sock_->Disconnect(); + EXPECT_FALSE(sock_->IsConnected()); +} + +TEST_P(TransportClientSocketTest, IsConnected) { + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + TestCompletionCallback callback; + uint32 bytes_read; + + EXPECT_FALSE(sock_->IsConnected()); + EXPECT_FALSE(sock_->IsConnectedAndIdle()); + int rv = sock_->Connect(callback.callback()); + if (rv != OK) { + ASSERT_EQ(rv, ERR_IO_PENDING); + rv = callback.WaitForResult(); + EXPECT_EQ(rv, OK); + } + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_TRUE(sock_->IsConnectedAndIdle()); + + // Send the request and wait for the server to respond. + SendClientRequest(); + + // Drain a single byte so we know we've received some data. + bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback); + ASSERT_EQ(bytes_read, 1u); + + // Socket should be considered connected, but not idle, due to + // pending data. + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_FALSE(sock_->IsConnectedAndIdle()); + + bytes_read = DrainClientSocket( + buf.get(), 4096, arraysize(kServerReply) - 2, &callback); + ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2); + + // After draining the data, the socket should be back to connected + // and idle. + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_TRUE(sock_->IsConnectedAndIdle()); + + // This time close the server socket immediately after the server response. + set_close_server_socket_on_next_send(true); + SendClientRequest(); + + bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback); + ASSERT_EQ(bytes_read, 1u); + + // As above because of data. + EXPECT_TRUE(sock_->IsConnected()); + EXPECT_FALSE(sock_->IsConnectedAndIdle()); + + bytes_read = DrainClientSocket( + buf.get(), 4096, arraysize(kServerReply) - 2, &callback); + ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2); + + // Once the data is drained, the socket should now be seen as not + // connected. + if (sock_->IsConnected()) { + // In the unlikely event that the server's connection closure is not + // processed in time, wait for the connection to be closed. + rv = sock_->Read(buf.get(), 4096, callback.callback()); + EXPECT_EQ(0, callback.GetResult(rv)); + EXPECT_FALSE(sock_->IsConnected()); + } + EXPECT_FALSE(sock_->IsConnectedAndIdle()); +} + +TEST_P(TransportClientSocketTest, Read) { + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); + if (rv != OK) { + ASSERT_EQ(rv, ERR_IO_PENDING); + + rv = callback.WaitForResult(); + EXPECT_EQ(rv, OK); + } + SendClientRequest(); + + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + uint32 bytes_read = DrainClientSocket( + buf.get(), 4096, arraysize(kServerReply) - 1, &callback); + ASSERT_EQ(bytes_read, arraysize(kServerReply) - 1); + + // All data has been read now. Read once more to force an ERR_IO_PENDING, and + // then close the server socket, and note the close. + + rv = sock_->Read(buf.get(), 4096, callback.callback()); + ASSERT_EQ(ERR_IO_PENDING, rv); + CloseServerSocket(); + EXPECT_EQ(0, callback.WaitForResult()); +} + +TEST_P(TransportClientSocketTest, Read_SmallChunks) { + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); + if (rv != OK) { + ASSERT_EQ(rv, ERR_IO_PENDING); + + rv = callback.WaitForResult(); + EXPECT_EQ(rv, OK); + } + SendClientRequest(); + + scoped_refptr<IOBuffer> buf(new IOBuffer(1)); + uint32 bytes_read = 0; + while (bytes_read < arraysize(kServerReply) - 1) { + rv = sock_->Read(buf.get(), 1, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + ASSERT_EQ(1, rv); + bytes_read += rv; + } + + // All data has been read now. Read once more to force an ERR_IO_PENDING, and + // then close the server socket, and note the close. + + rv = sock_->Read(buf.get(), 1, callback.callback()); + ASSERT_EQ(ERR_IO_PENDING, rv); + CloseServerSocket(); + EXPECT_EQ(0, callback.WaitForResult()); +} + +TEST_P(TransportClientSocketTest, Read_Interrupted) { + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); + if (rv != OK) { + ASSERT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(rv, OK); + } + SendClientRequest(); + + // Do a partial read and then exit. This test should not crash! + scoped_refptr<IOBuffer> buf(new IOBuffer(16)); + rv = sock_->Read(buf.get(), 16, callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_NE(0, rv); +} + +TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); + if (rv != OK) { + ASSERT_EQ(rv, ERR_IO_PENDING); + + rv = callback.WaitForResult(); + EXPECT_EQ(rv, OK); + } + + // Read first. There's no data, so it should return ERR_IO_PENDING. + const int kBufLen = 4096; + scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen)); + rv = sock_->Read(buf.get(), kBufLen, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + PauseServerReads(); + const int kWriteBufLen = 64 * 1024; + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen)); + char* request_data = request_buffer->data(); + memset(request_data, 'A', kWriteBufLen); + TestCompletionCallback write_callback; + + while (true) { + rv = sock_->Write( + request_buffer.get(), kWriteBufLen, write_callback.callback()); + ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) { + ResumeServerReads(); + rv = write_callback.WaitForResult(); + break; + } + } + + // At this point, both read and write have returned ERR_IO_PENDING, and the + // write callback has executed. We wait for the read callback to run now to + // make sure that the socket can handle full duplex communications. + + rv = callback.WaitForResult(); + EXPECT_GE(rv, 0); +} + +TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { + TestCompletionCallback callback; + int rv = sock_->Connect(callback.callback()); + if (rv != OK) { + ASSERT_EQ(ERR_IO_PENDING, rv); + + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + } + + PauseServerReads(); + const int kWriteBufLen = 64 * 1024; + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen)); + char* request_data = request_buffer->data(); + memset(request_data, 'A', kWriteBufLen); + TestCompletionCallback write_callback; + + while (true) { + rv = sock_->Write( + request_buffer.get(), kWriteBufLen, write_callback.callback()); + ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + + if (rv == ERR_IO_PENDING) + break; + } + + // Now we have the Write() blocked on ERR_IO_PENDING. It's time to force the + // Read() to block on ERR_IO_PENDING too. + + const int kBufLen = 4096; + scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen)); + while (true) { + rv = sock_->Read(buf.get(), kBufLen, callback.callback()); + ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + if (rv == ERR_IO_PENDING) + break; + } + + // At this point, both read and write have returned ERR_IO_PENDING. Now we + // run the write and read callbacks to make sure they can handle full duplex + // communications. + + ResumeServerReads(); + rv = write_callback.WaitForResult(); + EXPECT_GE(rv, 0); + + // It's possible the read is blocked because it's already read all the data. + // Close the server socket, so there will at least be a 0-byte read. + CloseServerSocket(); + + rv = callback.WaitForResult(); + EXPECT_GE(rv, 0); +} + +} // namespace net diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc new file mode 100644 index 00000000000..5b6b2498245 --- /dev/null +++ b/chromium/net/socket/unix_domain_socket_posix.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/unix_domain_socket_posix.h" + +#include <cstring> +#include <string> + +#include <errno.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/posix/eintr_wrapper.h" +#include "base/threading/platform_thread.h" +#include "build/build_config.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" + +namespace net { + +namespace { + +bool NoAuthenticationCallback(uid_t, gid_t) { + return true; +} + +bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) { +#if defined(OS_LINUX) || defined(OS_ANDROID) + struct ucred user_cred; + socklen_t len = sizeof(user_cred); + if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1) + return false; + *user_id = user_cred.uid; + *group_id = user_cred.gid; +#else + if (getpeereid(socket, user_id, group_id) == -1) + return false; +#endif + return true; +} + +} // namespace + +// static +UnixDomainSocket::AuthCallback NoAuthentication() { + return base::Bind(NoAuthenticationCallback); +} + +// static +UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback, + bool use_abstract_namespace) { + SocketDescriptor s = CreateAndBind(path, use_abstract_namespace); + if (s == kInvalidSocket && !fallback_path.empty()) + s = CreateAndBind(fallback_path, use_abstract_namespace); + if (s == kInvalidSocket) + return NULL; + UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback); + sock->Listen(); + return sock; +} + +// static +scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) { + return CreateAndListenInternal(path, "", del, auth_callback, false); +} + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) +// static +scoped_refptr<UnixDomainSocket> +UnixDomainSocket::CreateAndListenWithAbstractNamespace( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) { + return make_scoped_refptr( + CreateAndListenInternal(path, fallback_path, del, auth_callback, true)); +} +#endif + +UnixDomainSocket::UnixDomainSocket( + SocketDescriptor s, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) + : StreamListenSocket(s, del), + auth_callback_(auth_callback) {} + +UnixDomainSocket::~UnixDomainSocket() {} + +// static +SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path, + bool use_abstract_namespace) { + sockaddr_un addr; + static const size_t kPathMax = sizeof(addr.sun_path); + if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax) + return kInvalidSocket; + const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0); + if (s == kInvalidSocket) + return kInvalidSocket; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + socklen_t addr_len; + if (use_abstract_namespace) { + // Convert the path given into abstract socket name. It must start with + // the '\0' character, so we are adding it. |addr_len| must specify the + // length of the structure exactly, as potentially the socket name may + // have '\0' characters embedded (although we don't support this). + // Note that addr.sun_path is already zero initialized. + memcpy(addr.sun_path + 1, path.c_str(), path.size()); + addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; + } else { + memcpy(addr.sun_path, path.c_str(), path.size()); + addr_len = sizeof(sockaddr_un); + } + if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) { + LOG(ERROR) << "Could not bind unix domain socket to " << path; + if (use_abstract_namespace) + LOG(ERROR) << " (with abstract namespace enabled)"; + if (HANDLE_EINTR(close(s)) < 0) + LOG(ERROR) << "close() error"; + return kInvalidSocket; + } + return s; +} + +void UnixDomainSocket::Accept() { + SocketDescriptor conn = StreamListenSocket::AcceptSocket(); + if (conn == kInvalidSocket) + return; + uid_t user_id; + gid_t group_id; + if (!GetPeerIds(conn, &user_id, &group_id) || + !auth_callback_.Run(user_id, group_id)) { + if (HANDLE_EINTR(close(conn)) < 0) + LOG(ERROR) << "close() error"; + return; + } + scoped_refptr<UnixDomainSocket> sock( + new UnixDomainSocket(conn, socket_delegate_, auth_callback_)); + // It's up to the delegate to AddRef if it wants to keep it around. + sock->WatchSocket(WAITING_READ); + socket_delegate_->DidAccept(this, sock.get()); +} + +UnixDomainSocketFactory::UnixDomainSocketFactory( + const std::string& path, + const UnixDomainSocket::AuthCallback& auth_callback) + : path_(path), + auth_callback_(auth_callback) {} + +UnixDomainSocketFactory::~UnixDomainSocketFactory() {} + +scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return UnixDomainSocket::CreateAndListen( + path_, delegate, auth_callback_); +} + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) + +UnixDomainSocketWithAbstractNamespaceFactory:: +UnixDomainSocketWithAbstractNamespaceFactory( + const std::string& path, + const std::string& fallback_path, + const UnixDomainSocket::AuthCallback& auth_callback) + : UnixDomainSocketFactory(path, auth_callback), + fallback_path_(fallback_path) {} + +UnixDomainSocketWithAbstractNamespaceFactory:: +~UnixDomainSocketWithAbstractNamespaceFactory() {} + +scoped_refptr<StreamListenSocket> +UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return UnixDomainSocket::CreateAndListenWithAbstractNamespace( + path_, fallback_path_, delegate, auth_callback_); +} + +#endif + +} // namespace net diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h new file mode 100644 index 00000000000..2ef06803d24 --- /dev/null +++ b/chromium/net/socket/unix_domain_socket_posix.h @@ -0,0 +1,126 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ +#define NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/callback_forward.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "build/build_config.h" +#include "net/base/net_export.h" +#include "net/socket/stream_listen_socket.h" + +#if defined(OS_ANDROID) || defined(OS_LINUX) +// Feature only supported on Linux currently. This lets the Unix Domain Socket +// not be backed by the file system. +#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED +#endif + +namespace net { + +// Unix Domain Socket Implementation. Supports abstract namespaces on Linux. +class NET_EXPORT UnixDomainSocket : public StreamListenSocket { + public: + // Callback that returns whether the already connected client, identified by + // its process |user_id| and |group_id|, is allowed to keep the connection + // open. Note that the socket is closed immediately in case the callback + // returns false. + typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback; + + // Returns an authentication callback that always grants access for + // convenience in case you don't want to use authentication. + static AuthCallback NoAuthentication(); + + // Note that the returned UnixDomainSocket instance does not take ownership of + // |del|. + static scoped_refptr<UnixDomainSocket> CreateAndListen( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) + // Same as above except that the created socket uses the abstract namespace + // which is a Linux-only feature. If |fallback_path| is not empty, + // make the second attempt with the provided fallback name. + static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); +#endif + + private: + UnixDomainSocket(SocketDescriptor s, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); + virtual ~UnixDomainSocket(); + + static UnixDomainSocket* CreateAndListenInternal( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback, + bool use_abstract_namespace); + + static SocketDescriptor CreateAndBind(const std::string& path, + bool use_abstract_namespace); + + // StreamListenSocket: + virtual void Accept() OVERRIDE; + + AuthCallback auth_callback_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket); +}; + +// Factory that can be used to instantiate UnixDomainSocket. +class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory { + public: + // Note that this class does not take ownership of the provided delegate. + UnixDomainSocketFactory(const std::string& path, + const UnixDomainSocket::AuthCallback& auth_callback); + virtual ~UnixDomainSocketFactory(); + + // StreamListenSocketFactory: + virtual scoped_refptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const OVERRIDE; + + protected: + const std::string path_; + const UnixDomainSocket::AuthCallback auth_callback_; + + private: + DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory); +}; + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) +// Use this factory to instantiate UnixDomainSocket using the abstract +// namespace feature (only supported on Linux). +class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory + : public UnixDomainSocketFactory { + public: + UnixDomainSocketWithAbstractNamespaceFactory( + const std::string& path, + const std::string& fallback_path, + const UnixDomainSocket::AuthCallback& auth_callback); + virtual ~UnixDomainSocketWithAbstractNamespaceFactory(); + + // UnixDomainSocketFactory: + virtual scoped_refptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const OVERRIDE; + + private: + std::string fallback_path_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory); +}; +#endif + +} // namespace net + +#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_socket_posix_unittest.cc new file mode 100644 index 00000000000..5abe03b4ae3 --- /dev/null +++ b/chromium/net/socket/unix_domain_socket_posix_unittest.cc @@ -0,0 +1,338 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <errno.h> +#include <fcntl.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/time.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstring> +#include <queue> +#include <string> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/file_util.h" +#include "base/files/file_path.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/posix/eintr_wrapper.h" +#include "base/synchronization/condition_variable.h" +#include "base/synchronization/lock.h" +#include "base/threading/platform_thread.h" +#include "base/threading/thread.h" +#include "net/socket/unix_domain_socket_posix.h" +#include "testing/gtest/include/gtest/gtest.h" + +using std::queue; +using std::string; + +namespace net { +namespace { + +const char kSocketFilename[] = "unix_domain_socket_for_testing"; +const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2"; +const char kInvalidSocketPath[] = "/invalid/path"; +const char kMsg[] = "hello"; + +enum EventType { + EVENT_ACCEPT, + EVENT_AUTH_DENIED, + EVENT_AUTH_GRANTED, + EVENT_CLOSE, + EVENT_LISTEN, + EVENT_READ, +}; + +string MakeSocketPath(const string& socket_file_name) { + base::FilePath temp_dir; + file_util::GetTempDir(&temp_dir); + return temp_dir.Append(socket_file_name).value(); +} + +string MakeSocketPath() { + return MakeSocketPath(kSocketFilename); +} + +class EventManager : public base::RefCounted<EventManager> { + public: + EventManager() : condition_(&mutex_) {} + + bool HasPendingEvent() { + base::AutoLock lock(mutex_); + return !events_.empty(); + } + + void Notify(EventType event) { + base::AutoLock lock(mutex_); + events_.push(event); + condition_.Broadcast(); + } + + EventType WaitForEvent() { + base::AutoLock lock(mutex_); + while (events_.empty()) + condition_.Wait(); + EventType event = events_.front(); + events_.pop(); + return event; + } + + private: + friend class base::RefCounted<EventManager>; + virtual ~EventManager() {} + + queue<EventType> events_; + base::Lock mutex_; + base::ConditionVariable condition_; +}; + +class TestListenSocketDelegate : public StreamListenSocket::Delegate { + public: + explicit TestListenSocketDelegate( + const scoped_refptr<EventManager>& event_manager) + : event_manager_(event_manager) {} + + virtual void DidAccept(StreamListenSocket* server, + StreamListenSocket* connection) OVERRIDE { + LOG(ERROR) << __PRETTY_FUNCTION__; + connection_ = connection; + Notify(EVENT_ACCEPT); + } + + virtual void DidRead(StreamListenSocket* connection, + const char* data, + int len) OVERRIDE { + { + base::AutoLock lock(mutex_); + DCHECK(len); + data_.assign(data, len - 1); + } + Notify(EVENT_READ); + } + + virtual void DidClose(StreamListenSocket* sock) OVERRIDE { + Notify(EVENT_CLOSE); + } + + void OnListenCompleted() { + Notify(EVENT_LISTEN); + } + + string ReceivedData() { + base::AutoLock lock(mutex_); + return data_; + } + + private: + void Notify(EventType event) { + event_manager_->Notify(event); + } + + const scoped_refptr<EventManager> event_manager_; + scoped_refptr<StreamListenSocket> connection_; + base::Lock mutex_; + string data_; +}; + +bool UserCanConnectCallback( + bool allow_user, const scoped_refptr<EventManager>& event_manager, + uid_t, gid_t) { + event_manager->Notify( + allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED); + return allow_user; +} + +class UnixDomainSocketTestHelper : public testing::Test { + public: + void CreateAndListen() { + socket_ = UnixDomainSocket::CreateAndListen( + file_path_.value(), socket_delegate_.get(), MakeAuthCallback()); + socket_delegate_->OnListenCompleted(); + } + + protected: + UnixDomainSocketTestHelper(const string& path, bool allow_user) + : file_path_(path), + allow_user_(allow_user) {} + + virtual void SetUp() OVERRIDE { + event_manager_ = new EventManager(); + socket_delegate_.reset(new TestListenSocketDelegate(event_manager_)); + DeleteSocketFile(); + } + + virtual void TearDown() OVERRIDE { + DeleteSocketFile(); + socket_ = NULL; + socket_delegate_.reset(); + event_manager_ = NULL; + } + + UnixDomainSocket::AuthCallback MakeAuthCallback() { + return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_); + } + + void DeleteSocketFile() { + ASSERT_FALSE(file_path_.empty()); + base::DeleteFile(file_path_, false /* not recursive */); + } + + SocketDescriptor CreateClientSocket() { + const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0); + if (sock < 0) { + LOG(ERROR) << "socket() error"; + return StreamListenSocket::kInvalidSocket; + } + sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + socklen_t addr_len; + strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path)); + addr_len = sizeof(sockaddr_un); + if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { + LOG(ERROR) << "connect() error"; + return StreamListenSocket::kInvalidSocket; + } + return sock; + } + + scoped_ptr<base::Thread> CreateAndRunServerThread() { + base::Thread::Options options; + options.message_loop_type = base::MessageLoop::TYPE_IO; + scoped_ptr<base::Thread> thread(new base::Thread("socketio_test")); + thread->StartWithOptions(options); + thread->message_loop()->PostTask( + FROM_HERE, + base::Bind(&UnixDomainSocketTestHelper::CreateAndListen, + base::Unretained(this))); + return thread.Pass(); + } + + const base::FilePath file_path_; + const bool allow_user_; + scoped_refptr<EventManager> event_manager_; + scoped_ptr<TestListenSocketDelegate> socket_delegate_; + scoped_refptr<UnixDomainSocket> socket_; +}; + +class UnixDomainSocketTest : public UnixDomainSocketTestHelper { + protected: + UnixDomainSocketTest() + : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {} +}; + +class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper { + protected: + UnixDomainSocketTestWithInvalidPath() + : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {} +}; + +class UnixDomainSocketTestWithForbiddenUser + : public UnixDomainSocketTestHelper { + protected: + UnixDomainSocketTestWithForbiddenUser() + : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {} +}; + +TEST_F(UnixDomainSocketTest, CreateAndListen) { + CreateAndListen(); + EXPECT_FALSE(socket_.get() == NULL); +} + +TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) { + CreateAndListen(); + EXPECT_TRUE(socket_.get() == NULL); +} + +#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED +// Test with an invalid path to make sure that the socket is not backed by a +// file. +TEST_F(UnixDomainSocketTestWithInvalidPath, + CreateAndListenWithAbstractNamespace) { + socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( + file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); + EXPECT_FALSE(socket_.get() == NULL); +} + +TEST_F(UnixDomainSocketTest, TestFallbackName) { + scoped_refptr<UnixDomainSocket> existing_socket = + UnixDomainSocket::CreateAndListenWithAbstractNamespace( + file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); + EXPECT_FALSE(existing_socket.get() == NULL); + // First, try to bind socket with the same name with no fallback name. + socket_ = + UnixDomainSocket::CreateAndListenWithAbstractNamespace( + file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); + EXPECT_TRUE(socket_.get() == NULL); + // Now with a fallback name. + socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( + file_path_.value(), + MakeSocketPath(kFallbackSocketName), + socket_delegate_.get(), + MakeAuthCallback()); + EXPECT_FALSE(socket_.get() == NULL); + existing_socket = NULL; +} +#endif + +TEST_F(UnixDomainSocketTest, TestWithClient) { + const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); + EventType event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_LISTEN, event); + + // Create the client socket. + const SocketDescriptor sock = CreateClientSocket(); + ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_AUTH_GRANTED, event); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_ACCEPT, event); + + // Send a message from the client to the server. + ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); + ASSERT_NE(-1, ret); + ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret)); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_READ, event); + ASSERT_EQ(kMsg, socket_delegate_->ReceivedData()); + + // Close the client socket. + ret = HANDLE_EINTR(close(sock)); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_CLOSE, event); +} + +TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { + const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); + EventType event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_LISTEN, event); + const SocketDescriptor sock = CreateClientSocket(); + ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_AUTH_DENIED, event); + + // Wait until the file descriptor is closed by the server. + struct pollfd poll_fd; + poll_fd.fd = sock; + poll_fd.events = POLLIN; + poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */); + + // Send() must fail. + ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); + ASSERT_EQ(-1, ret); + ASSERT_EQ(EPIPE, errno); + ASSERT_FALSE(event_manager_->HasPendingEvent()); +} + +} // namespace +} // namespace net |
