summaryrefslogtreecommitdiffstats
path: root/chromium/net/socket
diff options
context:
space:
mode:
authorZeno Albisser <zeno.albisser@digia.com>2013-08-15 21:46:11 +0200
committerZeno Albisser <zeno.albisser@digia.com>2013-08-15 21:46:11 +0200
commit679147eead574d186ebf3069647b4c23e8ccace6 (patch)
treefc247a0ac8ff119f7c8550879ebb6d3dd8d1ff69 /chromium/net/socket
Initial import.
Diffstat (limited to 'chromium/net/socket')
-rw-r--r--chromium/net/socket/buffered_write_stream_socket.cc161
-rw-r--r--chromium/net/socket/buffered_write_stream_socket.h82
-rw-r--r--chromium/net/socket/buffered_write_stream_socket_unittest.cc124
-rw-r--r--chromium/net/socket/client_socket_factory.cc142
-rw-r--r--chromium/net/socket/client_socket_factory.h65
-rw-r--r--chromium/net/socket/client_socket_handle.cc180
-rw-r--r--chromium/net/socket/client_socket_handle.h241
-rw-r--r--chromium/net/socket/client_socket_pool.cc50
-rw-r--r--chromium/net/socket/client_socket_pool.h221
-rw-r--r--chromium/net/socket/client_socket_pool_base.cc1266
-rw-r--r--chromium/net/socket/client_socket_pool_base.h819
-rw-r--r--chromium/net/socket/client_socket_pool_base_unittest.cc4168
-rw-r--r--chromium/net/socket/client_socket_pool_histograms.cc83
-rw-r--r--chromium/net/socket/client_socket_pool_histograms.h46
-rw-r--r--chromium/net/socket/client_socket_pool_manager.cc467
-rw-r--r--chromium/net/socket/client_socket_pool_manager.h169
-rw-r--r--chromium/net/socket/client_socket_pool_manager_impl.cc392
-rw-r--r--chromium/net/socket/client_socket_pool_manager_impl.h150
-rw-r--r--chromium/net/socket/deterministic_socket_data_unittest.cc621
-rw-r--r--chromium/net/socket/mock_client_socket_pool_manager.cc94
-rw-r--r--chromium/net/socket/mock_client_socket_pool_manager.h63
-rw-r--r--chromium/net/socket/next_proto.h39
-rw-r--r--chromium/net/socket/nss_ssl_util.cc276
-rw-r--r--chromium/net/socket/nss_ssl_util.h35
-rw-r--r--chromium/net/socket/server_socket.h40
-rw-r--r--chromium/net/socket/socket.h62
-rw-r--r--chromium/net/socket/socket_net_log_params.cc72
-rw-r--r--chromium/net/socket/socket_net_log_params.h38
-rw-r--r--chromium/net/socket/socket_test_util.cc1888
-rw-r--r--chromium/net/socket/socket_test_util.h1198
-rw-r--r--chromium/net/socket/socks5_client_socket.cc487
-rw-r--r--chromium/net/socket/socks5_client_socket.h155
-rw-r--r--chromium/net/socket/socks5_client_socket_unittest.cc375
-rw-r--r--chromium/net/socket/socks_client_socket.cc432
-rw-r--r--chromium/net/socket/socks_client_socket.h134
-rw-r--r--chromium/net/socket/socks_client_socket_pool.cc310
-rw-r--r--chromium/net/socket/socks_client_socket_pool.h211
-rw-r--r--chromium/net/socket/socks_client_socket_pool_unittest.cc297
-rw-r--r--chromium/net/socket/socks_client_socket_unittest.cc415
-rw-r--r--chromium/net/socket/ssl_client_socket.cc148
-rw-r--r--chromium/net/socket/ssl_client_socket.h141
-rw-r--r--chromium/net/socket/ssl_client_socket_nss.cc3504
-rw-r--r--chromium/net/socket/ssl_client_socket_nss.h196
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl.cc1435
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl.h203
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl_unittest.cc279
-rw-r--r--chromium/net/socket/ssl_client_socket_pool.cc664
-rw-r--r--chromium/net/socket/ssl_client_socket_pool.h297
-rw-r--r--chromium/net/socket/ssl_client_socket_pool_unittest.cc857
-rw-r--r--chromium/net/socket/ssl_client_socket_unittest.cc1798
-rw-r--r--chromium/net/socket/ssl_error_params.cc31
-rw-r--r--chromium/net/socket/ssl_error_params.h18
-rw-r--r--chromium/net/socket/ssl_server_socket.h64
-rw-r--r--chromium/net/socket/ssl_server_socket_nss.cc828
-rw-r--r--chromium/net/socket/ssl_server_socket_nss.h150
-rw-r--r--chromium/net/socket/ssl_server_socket_openssl.cc28
-rw-r--r--chromium/net/socket/ssl_server_socket_unittest.cc588
-rw-r--r--chromium/net/socket/ssl_socket.h37
-rw-r--r--chromium/net/socket/stream_listen_socket.cc308
-rw-r--r--chromium/net/socket/stream_listen_socket.h155
-rw-r--r--chromium/net/socket/stream_socket.cc101
-rw-r--r--chromium/net/socket/stream_socket.h138
-rw-r--r--chromium/net/socket/tcp_client_socket.cc59
-rw-r--r--chromium/net/socket/tcp_client_socket.h35
-rw-r--r--chromium/net/socket/tcp_client_socket_libevent.cc830
-rw-r--r--chromium/net/socket/tcp_client_socket_libevent.h256
-rw-r--r--chromium/net/socket/tcp_client_socket_unittest.cc113
-rw-r--r--chromium/net/socket/tcp_client_socket_win.cc1045
-rw-r--r--chromium/net/socket/tcp_client_socket_win.h162
-rw-r--r--chromium/net/socket/tcp_listen_socket.cc128
-rw-r--r--chromium/net/socket/tcp_listen_socket.h64
-rw-r--r--chromium/net/socket/tcp_listen_socket_unittest.cc291
-rw-r--r--chromium/net/socket/tcp_listen_socket_unittest.h122
-rw-r--r--chromium/net/socket/tcp_server_socket.h26
-rw-r--r--chromium/net/socket/tcp_server_socket_libevent.cc223
-rw-r--r--chromium/net/socket/tcp_server_socket_libevent.h55
-rw-r--r--chromium/net/socket/tcp_server_socket_unittest.cc251
-rw-r--r--chromium/net/socket/tcp_server_socket_win.cc217
-rw-r--r--chromium/net/socket/tcp_server_socket_win.h58
-rw-r--r--chromium/net/socket/transport_client_socket_pool.cc477
-rw-r--r--chromium/net/socket/transport_client_socket_pool.h221
-rw-r--r--chromium/net/socket/transport_client_socket_pool_unittest.cc1355
-rw-r--r--chromium/net/socket/transport_client_socket_unittest.cc449
-rw-r--r--chromium/net/socket/unix_domain_socket_posix.cc193
-rw-r--r--chromium/net/socket/unix_domain_socket_posix.h126
-rw-r--r--chromium/net/socket/unix_domain_socket_posix_unittest.cc338
86 files changed, 35130 insertions, 0 deletions
diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc
new file mode 100644
index 00000000000..cf13c5e439a
--- /dev/null
+++ b/chromium/net/socket/buffered_write_stream_socket.cc
@@ -0,0 +1,161 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/buffered_write_stream_socket.h"
+
+#include "base/bind.h"
+#include "base/location.h"
+#include "base/message_loop/message_loop.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+
+namespace net {
+
+namespace {
+
+void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) {
+ int old_capacity = dst->capacity();
+ dst->SetCapacity(old_capacity + src_len);
+ memcpy(dst->StartOfBuffer() + old_capacity, src->data(), src_len);
+}
+
+} // anonymous namespace
+
+BufferedWriteStreamSocket::BufferedWriteStreamSocket(
+ scoped_ptr<StreamSocket> socket_to_wrap)
+ : wrapped_socket_(socket_to_wrap.Pass()),
+ io_buffer_(new GrowableIOBuffer()),
+ backup_buffer_(new GrowableIOBuffer()),
+ weak_factory_(this),
+ callback_pending_(false),
+ wrapped_write_in_progress_(false),
+ error_(0) {
+}
+
+BufferedWriteStreamSocket::~BufferedWriteStreamSocket() {
+}
+
+int BufferedWriteStreamSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ return wrapped_socket_->Read(buf, buf_len, callback);
+}
+
+int BufferedWriteStreamSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (error_) {
+ return error_;
+ }
+ GrowableIOBuffer* idle_buffer =
+ wrapped_write_in_progress_ ? backup_buffer_.get() : io_buffer_.get();
+ AppendBuffer(idle_buffer, buf, buf_len);
+ if (!callback_pending_) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&BufferedWriteStreamSocket::DoDelayedWrite,
+ weak_factory_.GetWeakPtr()));
+ callback_pending_ = true;
+ }
+ return buf_len;
+}
+
+bool BufferedWriteStreamSocket::SetReceiveBufferSize(int32 size) {
+ return wrapped_socket_->SetReceiveBufferSize(size);
+}
+
+bool BufferedWriteStreamSocket::SetSendBufferSize(int32 size) {
+ return wrapped_socket_->SetSendBufferSize(size);
+}
+
+int BufferedWriteStreamSocket::Connect(const CompletionCallback& callback) {
+ return wrapped_socket_->Connect(callback);
+}
+
+void BufferedWriteStreamSocket::Disconnect() {
+ wrapped_socket_->Disconnect();
+}
+
+bool BufferedWriteStreamSocket::IsConnected() const {
+ return wrapped_socket_->IsConnected();
+}
+
+bool BufferedWriteStreamSocket::IsConnectedAndIdle() const {
+ return wrapped_socket_->IsConnectedAndIdle();
+}
+
+int BufferedWriteStreamSocket::GetPeerAddress(IPEndPoint* address) const {
+ return wrapped_socket_->GetPeerAddress(address);
+}
+
+int BufferedWriteStreamSocket::GetLocalAddress(IPEndPoint* address) const {
+ return wrapped_socket_->GetLocalAddress(address);
+}
+
+const BoundNetLog& BufferedWriteStreamSocket::NetLog() const {
+ return wrapped_socket_->NetLog();
+}
+
+void BufferedWriteStreamSocket::SetSubresourceSpeculation() {
+ wrapped_socket_->SetSubresourceSpeculation();
+}
+
+void BufferedWriteStreamSocket::SetOmniboxSpeculation() {
+ wrapped_socket_->SetOmniboxSpeculation();
+}
+
+bool BufferedWriteStreamSocket::WasEverUsed() const {
+ return wrapped_socket_->WasEverUsed();
+}
+
+bool BufferedWriteStreamSocket::UsingTCPFastOpen() const {
+ return wrapped_socket_->UsingTCPFastOpen();
+}
+
+bool BufferedWriteStreamSocket::WasNpnNegotiated() const {
+ return wrapped_socket_->WasNpnNegotiated();
+}
+
+NextProto BufferedWriteStreamSocket::GetNegotiatedProtocol() const {
+ return wrapped_socket_->GetNegotiatedProtocol();
+}
+
+bool BufferedWriteStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ return wrapped_socket_->GetSSLInfo(ssl_info);
+}
+
+void BufferedWriteStreamSocket::DoDelayedWrite() {
+ int result = wrapped_socket_->Write(
+ io_buffer_.get(),
+ io_buffer_->RemainingCapacity(),
+ base::Bind(&BufferedWriteStreamSocket::OnIOComplete,
+ base::Unretained(this)));
+ if (result == ERR_IO_PENDING) {
+ callback_pending_ = true;
+ wrapped_write_in_progress_ = true;
+ } else {
+ OnIOComplete(result);
+ }
+}
+
+void BufferedWriteStreamSocket::OnIOComplete(int result) {
+ callback_pending_ = false;
+ wrapped_write_in_progress_ = false;
+ if (backup_buffer_->RemainingCapacity()) {
+ AppendBuffer(io_buffer_.get(), backup_buffer_.get(),
+ backup_buffer_->RemainingCapacity());
+ backup_buffer_->SetCapacity(0);
+ }
+ if (result < 0) {
+ error_ = result;
+ io_buffer_->SetCapacity(0);
+ } else {
+ io_buffer_->set_offset(io_buffer_->offset() + result);
+ if (io_buffer_->RemainingCapacity()) {
+ DoDelayedWrite();
+ } else {
+ io_buffer_->SetCapacity(0);
+ }
+ }
+}
+
+} // namespace net
diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h
new file mode 100644
index 00000000000..aad5736d0b0
--- /dev/null
+++ b/chromium/net/socket/buffered_write_stream_socket.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_
+#define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_
+
+#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "net/base/net_log.h"
+#include "net/socket/stream_socket.h"
+
+namespace base {
+class TimeDelta;
+}
+
+namespace net {
+
+class AddressList;
+class GrowableIOBuffer;
+class IPEndPoint;
+
+// A StreamSocket decorator. All functions are passed through to the wrapped
+// socket, except for Write().
+//
+// Writes are buffered locally so that multiple Write()s to this class are
+// issued as only one Write() to the wrapped socket. This is useful to force
+// multiple requests to be issued in a single packet, as is needed to trigger
+// edge cases in HTTP pipelining.
+//
+// Note that the Write() always returns synchronously. It will either buffer the
+// entire input or return the most recently reported error.
+//
+// There are no bounds on the local buffer size. Use carefully.
+class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket {
+ public:
+ explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap);
+ virtual ~BufferedWriteStreamSocket();
+
+ // Socket interface
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ // StreamSocket interface
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ private:
+ void DoDelayedWrite();
+ void OnIOComplete(int result);
+
+ scoped_ptr<StreamSocket> wrapped_socket_;
+ scoped_refptr<GrowableIOBuffer> io_buffer_;
+ scoped_refptr<GrowableIOBuffer> backup_buffer_;
+ base::WeakPtrFactory<BufferedWriteStreamSocket> weak_factory_;
+ bool callback_pending_;
+ bool wrapped_write_in_progress_;
+ int error_;
+
+ DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_STREAM_SOCKET_H_
diff --git a/chromium/net/socket/buffered_write_stream_socket_unittest.cc b/chromium/net/socket/buffered_write_stream_socket_unittest.cc
new file mode 100644
index 00000000000..485295f33f6
--- /dev/null
+++ b/chromium/net/socket/buffered_write_stream_socket_unittest.cc
@@ -0,0 +1,124 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/buffered_write_stream_socket.h"
+
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/socket/socket_test_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+class BufferedWriteStreamSocketTest : public testing::Test {
+ public:
+ void Finish() {
+ base::MessageLoop::current()->RunUntilIdle();
+ EXPECT_TRUE(data_->at_read_eof());
+ EXPECT_TRUE(data_->at_write_eof());
+ }
+
+ void Initialize(MockWrite* writes, size_t writes_count) {
+ data_.reset(new DeterministicSocketData(NULL, 0, writes, writes_count));
+ data_->set_connect_data(MockConnect(SYNCHRONOUS, 0));
+ if (writes_count) {
+ data_->StopAfter(writes_count);
+ }
+ scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket(
+ new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get()));
+ data_->set_delegate(wrapped_socket->AsWeakPtr());
+ socket_.reset(new BufferedWriteStreamSocket(
+ wrapped_socket.PassAs<StreamSocket>()));
+ socket_->Connect(callback_.callback());
+ }
+
+ void TestWrite(const char* text) {
+ scoped_refptr<StringIOBuffer> buf(new StringIOBuffer(text));
+ EXPECT_EQ(buf->size(),
+ socket_->Write(buf.get(), buf->size(), callback_.callback()));
+ }
+
+ scoped_ptr<BufferedWriteStreamSocket> socket_;
+ scoped_ptr<DeterministicSocketData> data_;
+ BoundNetLog net_log_;
+ TestCompletionCallback callback_;
+};
+
+TEST_F(BufferedWriteStreamSocketTest, SingleWrite) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, 0, "abc"),
+ };
+ Initialize(writes, arraysize(writes));
+ TestWrite("abc");
+ Finish();
+}
+
+TEST_F(BufferedWriteStreamSocketTest, AsyncWrite) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, 0, "abc"),
+ };
+ Initialize(writes, arraysize(writes));
+ TestWrite("abc");
+ data_->Run();
+ Finish();
+}
+
+TEST_F(BufferedWriteStreamSocketTest, TwoWritesIntoOne) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, 0, "abcdef"),
+ };
+ Initialize(writes, arraysize(writes));
+ TestWrite("abc");
+ TestWrite("def");
+ Finish();
+}
+
+TEST_F(BufferedWriteStreamSocketTest, WriteWhileBlocked) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, 0, "abc"),
+ MockWrite(ASYNC, 1, "def"),
+ MockWrite(ASYNC, 2, "ghi"),
+ };
+ Initialize(writes, arraysize(writes));
+ TestWrite("abc");
+ base::MessageLoop::current()->RunUntilIdle();
+ TestWrite("def");
+ data_->RunFor(1);
+ TestWrite("ghi");
+ data_->RunFor(1);
+ Finish();
+}
+
+TEST_F(BufferedWriteStreamSocketTest, ContinuesPartialWrite) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, 0, "abc"),
+ MockWrite(ASYNC, 1, "def"),
+ };
+ Initialize(writes, arraysize(writes));
+ TestWrite("abcdef");
+ data_->Run();
+ Finish();
+}
+
+TEST_F(BufferedWriteStreamSocketTest, TwoSeparateWrites) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, 0, "abc"),
+ MockWrite(ASYNC, 1, "def"),
+ };
+ Initialize(writes, arraysize(writes));
+ TestWrite("abc");
+ data_->RunFor(1);
+ TestWrite("def");
+ data_->RunFor(1);
+ Finish();
+}
+
+} // anonymous namespace
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc
new file mode 100644
index 00000000000..a86688e3333
--- /dev/null
+++ b/chromium/net/socket/client_socket_factory.cc
@@ -0,0 +1,142 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_factory.h"
+
+#include "base/lazy_instance.h"
+#include "base/thread_task_runner_handle.h"
+#include "base/threading/sequenced_worker_pool.h"
+#include "build/build_config.h"
+#include "net/cert/cert_database.h"
+#include "net/socket/client_socket_handle.h"
+#if defined(USE_OPENSSL)
+#include "net/socket/ssl_client_socket_openssl.h"
+#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN)
+#include "net/socket/ssl_client_socket_nss.h"
+#endif
+#include "net/socket/tcp_client_socket.h"
+#include "net/udp/udp_client_socket.h"
+
+namespace net {
+
+class X509Certificate;
+
+namespace {
+
+// ChromeOS and Linux may require interaction with smart cards or TPMs, which
+// may cause NSS functions to block for upwards of several seconds. To avoid
+// blocking all activity on the current task runner, such as network or IPC
+// traffic, run NSS SSL functions on a dedicated thread.
+#if defined(OS_CHROMEOS) || defined(OS_LINUX)
+bool g_use_dedicated_nss_thread = true;
+#else
+bool g_use_dedicated_nss_thread = false;
+#endif
+
+class DefaultClientSocketFactory : public ClientSocketFactory,
+ public CertDatabase::Observer {
+ public:
+ DefaultClientSocketFactory() {
+ if (g_use_dedicated_nss_thread) {
+ // Use a single thread for the worker pool.
+ worker_pool_ = new base::SequencedWorkerPool(1, "NSS SSL Thread");
+ nss_thread_task_runner_ =
+ worker_pool_->GetSequencedTaskRunnerWithShutdownBehavior(
+ worker_pool_->GetSequenceToken(),
+ base::SequencedWorkerPool::CONTINUE_ON_SHUTDOWN);
+ }
+
+ CertDatabase::GetInstance()->AddObserver(this);
+ }
+
+ virtual ~DefaultClientSocketFactory() {
+ // Note: This code never runs, as the factory is defined as a Leaky
+ // singleton.
+ CertDatabase::GetInstance()->RemoveObserver(this);
+ }
+
+ virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE {
+ ClearSSLSessionCache();
+ }
+
+ virtual void OnCertTrustChanged(const X509Certificate* cert) OVERRIDE {
+ // Per wtc, we actually only need to flush when trust is reduced.
+ // Always flush now because OnCertTrustChanged does not tell us this.
+ // See comments in ClientSocketPoolManager::OnCertTrustChanged.
+ ClearSSLSessionCache();
+ }
+
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE {
+ return scoped_ptr<DatagramClientSocket>(
+ new UDPClientSocket(bind_type, rand_int_cb, net_log, source));
+ }
+
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE {
+ return scoped_ptr<StreamSocket>(
+ new TCPClientSocket(addresses, net_log, source));
+ }
+
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) OVERRIDE {
+ // nss_thread_task_runner_ may be NULL if g_use_dedicated_nss_thread is
+ // false or if the dedicated NSS thread failed to start. If so, cause NSS
+ // functions to execute on the current task runner.
+ //
+ // Note: The current task runner is obtained on each call due to unit
+ // tests, which may create and tear down the current thread's TaskRunner
+ // between each test. Because the DefaultClientSocketFactory is leaky, it
+ // may span multiple tests, and thus the current task runner may change
+ // from call to call.
+ scoped_refptr<base::SequencedTaskRunner> nss_task_runner(
+ nss_thread_task_runner_);
+ if (!nss_task_runner.get())
+ nss_task_runner = base::ThreadTaskRunnerHandle::Get();
+
+#if defined(USE_OPENSSL)
+ return scoped_ptr<SSLClientSocket>(
+ new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port,
+ ssl_config, context));
+#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN)
+ return scoped_ptr<SSLClientSocket>(
+ new SSLClientSocketNSS(nss_task_runner.get(),
+ transport_socket.Pass(),
+ host_and_port,
+ ssl_config,
+ context));
+#else
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLClientSocket>();
+#endif
+ }
+
+ virtual void ClearSSLSessionCache() OVERRIDE {
+ SSLClientSocket::ClearSessionCache();
+ }
+
+ private:
+ scoped_refptr<base::SequencedWorkerPool> worker_pool_;
+ scoped_refptr<base::SequencedTaskRunner> nss_thread_task_runner_;
+};
+
+static base::LazyInstance<DefaultClientSocketFactory>::Leaky
+ g_default_client_socket_factory = LAZY_INSTANCE_INITIALIZER;
+
+} // namespace
+
+// static
+ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() {
+ return g_default_client_socket_factory.Pointer();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_factory.h b/chromium/net/socket/client_socket_factory.h
new file mode 100644
index 00000000000..6cb5949f0b3
--- /dev/null
+++ b/chromium/net/socket/client_socket_factory.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_CLIENT_SOCKET_FACTORY_H_
+#define NET_SOCKET_CLIENT_SOCKET_FACTORY_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
+#include "net/udp/datagram_socket.h"
+
+namespace net {
+
+class AddressList;
+class ClientSocketHandle;
+class DatagramClientSocket;
+class HostPortPair;
+class SSLClientSocket;
+struct SSLClientSocketContext;
+struct SSLConfig;
+class StreamSocket;
+
+// An interface used to instantiate StreamSocket objects. Used to facilitate
+// testing code with mock socket implementations.
+class NET_EXPORT ClientSocketFactory {
+ public:
+ virtual ~ClientSocketFactory() {}
+
+ // |source| is the NetLog::Source for the entity trying to create the socket,
+ // if it has one.
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) = 0;
+
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* net_log,
+ const NetLog::Source& source) = 0;
+
+ // It is allowed to pass in a |transport_socket| that is not obtained from a
+ // socket pool. The caller could create a ClientSocketHandle directly and call
+ // set_socket() on it to set a valid StreamSocket instance.
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) = 0;
+
+ // Clears cache used for SSL session resumption.
+ virtual void ClearSSLSessionCache() = 0;
+
+ // Returns the default ClientSocketFactory.
+ static ClientSocketFactory* GetDefaultFactory();
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_FACTORY_H_
diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc
new file mode 100644
index 00000000000..acb896b36f5
--- /dev/null
+++ b/chromium/net/socket/client_socket_handle.cc
@@ -0,0 +1,180 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_handle.h"
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/compiler_specific.h"
+#include "base/metrics/histogram.h"
+#include "base/logging.h"
+#include "net/base/net_errors.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/client_socket_pool_histograms.h"
+
+namespace net {
+
+ClientSocketHandle::ClientSocketHandle()
+ : is_initialized_(false),
+ pool_(NULL),
+ layered_pool_(NULL),
+ is_reused_(false),
+ callback_(base::Bind(&ClientSocketHandle::OnIOComplete,
+ base::Unretained(this))),
+ is_ssl_error_(false) {}
+
+ClientSocketHandle::~ClientSocketHandle() {
+ Reset();
+}
+
+void ClientSocketHandle::Reset() {
+ ResetInternal(true);
+ ResetErrorState();
+}
+
+void ClientSocketHandle::ResetInternal(bool cancel) {
+ if (group_name_.empty()) // Was Init called?
+ return;
+ if (is_initialized()) {
+ // Because of http://crbug.com/37810 we may not have a pool, but have
+ // just a raw socket.
+ socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE);
+ if (pool_)
+ // If we've still got a socket, release it back to the ClientSocketPool so
+ // it can be deleted or reused.
+ pool_->ReleaseSocket(group_name_, PassSocket(), pool_id_);
+ } else if (cancel) {
+ // If we did not get initialized yet, we've got a socket request pending.
+ // Cancel it.
+ pool_->CancelRequest(group_name_, this);
+ }
+ is_initialized_ = false;
+ group_name_.clear();
+ is_reused_ = false;
+ user_callback_.Reset();
+ if (layered_pool_) {
+ pool_->RemoveLayeredPool(layered_pool_);
+ layered_pool_ = NULL;
+ }
+ pool_ = NULL;
+ idle_time_ = base::TimeDelta();
+ init_time_ = base::TimeTicks();
+ setup_time_ = base::TimeDelta();
+ connect_timing_ = LoadTimingInfo::ConnectTiming();
+ pool_id_ = -1;
+}
+
+void ClientSocketHandle::ResetErrorState() {
+ is_ssl_error_ = false;
+ ssl_error_response_info_ = HttpResponseInfo();
+ pending_http_proxy_connection_.reset();
+}
+
+LoadState ClientSocketHandle::GetLoadState() const {
+ CHECK(!is_initialized());
+ CHECK(!group_name_.empty());
+ // Because of http://crbug.com/37810 we may not have a pool, but have
+ // just a raw socket.
+ if (!pool_)
+ return LOAD_STATE_IDLE;
+ return pool_->GetLoadState(group_name_, this);
+}
+
+bool ClientSocketHandle::IsPoolStalled() const {
+ return pool_->IsStalled();
+}
+
+void ClientSocketHandle::AddLayeredPool(LayeredPool* layered_pool) {
+ CHECK(layered_pool);
+ CHECK(!layered_pool_);
+ if (pool_) {
+ pool_->AddLayeredPool(layered_pool);
+ layered_pool_ = layered_pool;
+ }
+}
+
+void ClientSocketHandle::RemoveLayeredPool(LayeredPool* layered_pool) {
+ CHECK(layered_pool);
+ CHECK(layered_pool_);
+ if (pool_) {
+ pool_->RemoveLayeredPool(layered_pool);
+ layered_pool_ = NULL;
+ }
+}
+
+bool ClientSocketHandle::GetLoadTimingInfo(
+ bool is_reused,
+ LoadTimingInfo* load_timing_info) const {
+ // Only return load timing information when there's a socket.
+ if (!socket_)
+ return false;
+
+ load_timing_info->socket_log_id = socket_->NetLog().source().id;
+ load_timing_info->socket_reused = is_reused;
+
+ // No times if the socket is reused.
+ if (is_reused)
+ return true;
+
+ load_timing_info->connect_timing = connect_timing_;
+ return true;
+}
+
+void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) {
+ socket_ = s.Pass();
+}
+
+void ClientSocketHandle::OnIOComplete(int result) {
+ CompletionCallback callback = user_callback_;
+ user_callback_.Reset();
+ HandleInitCompletion(result);
+ callback.Run(result);
+}
+
+scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() {
+ return socket_.Pass();
+}
+
+void ClientSocketHandle::HandleInitCompletion(int result) {
+ CHECK_NE(ERR_IO_PENDING, result);
+ ClientSocketPoolHistograms* histograms = pool_->histograms();
+ histograms->AddErrorCode(result);
+ if (result != OK) {
+ if (!socket_.get())
+ ResetInternal(false); // Nothing to cancel since the request failed.
+ else
+ is_initialized_ = true;
+ return;
+ }
+ is_initialized_ = true;
+ CHECK_NE(-1, pool_id_) << "Pool should have set |pool_id_| to a valid value.";
+ setup_time_ = base::TimeTicks::Now() - init_time_;
+
+ histograms->AddSocketType(reuse_type());
+ switch (reuse_type()) {
+ case ClientSocketHandle::UNUSED:
+ histograms->AddRequestTime(setup_time());
+ break;
+ case ClientSocketHandle::UNUSED_IDLE:
+ histograms->AddUnusedIdleTime(idle_time());
+ break;
+ case ClientSocketHandle::REUSED_IDLE:
+ histograms->AddReusedIdleTime(idle_time());
+ break;
+ default:
+ NOTREACHED();
+ break;
+ }
+
+ // Broadcast that the socket has been acquired.
+ // TODO(eroman): This logging is not complete, in particular set_socket() and
+ // release() socket. It ends up working though, since those methods are being
+ // used to layer sockets (and the destination sources are the same).
+ DCHECK(socket_.get());
+ socket_->NetLog().BeginEvent(
+ NetLog::TYPE_SOCKET_IN_USE,
+ requesting_source_.ToEventParametersCallback());
+}
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h
new file mode 100644
index 00000000000..9651f089a86
--- /dev/null
+++ b/chromium/net/socket/client_socket_handle.h
@@ -0,0 +1,241 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_CLIENT_SOCKET_HANDLE_H_
+#define NET_SOCKET_CLIENT_SOCKET_HANDLE_H_
+
+#include <string>
+
+#include "base/logging.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "net/base/completion_callback.h"
+#include "net/base/load_states.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/base/request_priority.h"
+#include "net/http/http_response_info.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+// A container for a StreamSocket.
+//
+// The handle's |group_name| uniquely identifies the origin and type of the
+// connection. It is used by the ClientSocketPool to group similar connected
+// client socket objects.
+//
+class NET_EXPORT ClientSocketHandle {
+ public:
+ enum SocketReuseType {
+ UNUSED = 0, // unused socket that just finished connecting
+ UNUSED_IDLE, // unused socket that has been idle for awhile
+ REUSED_IDLE, // previously used socket
+ NUM_TYPES,
+ };
+
+ ClientSocketHandle();
+ ~ClientSocketHandle();
+
+ // Initializes a ClientSocketHandle object, which involves talking to the
+ // ClientSocketPool to obtain a connected socket, possibly reusing one. This
+ // method returns either OK or ERR_IO_PENDING. On ERR_IO_PENDING, |priority|
+ // is used to determine the placement in ClientSocketPool's wait list.
+ //
+ // If this method succeeds, then the socket member will be set to an existing
+ // connected socket if an existing connected socket was available to reuse,
+ // otherwise it will be set to a new connected socket. Consumers can then
+ // call is_reused() to see if the socket was reused. If not reusing an
+ // existing socket, ClientSocketPool may need to establish a new
+ // connection using |socket_params|.
+ //
+ // This method returns ERR_IO_PENDING if it cannot complete synchronously, in
+ // which case the consumer will be notified of completion via |callback|.
+ //
+ // If the pool was not able to reuse an existing socket, the new socket
+ // may report a recoverable error. In this case, the return value will
+ // indicate an error and the socket member will be set. If it is determined
+ // that the error is not recoverable, the Disconnect method should be used
+ // on the socket, so that it does not get reused.
+ //
+ // A non-recoverable error may set additional state in the ClientSocketHandle
+ // to allow the caller to determine what went wrong.
+ //
+ // Init may be called multiple times.
+ //
+ // Profiling information for the request is saved to |net_log| if non-NULL.
+ //
+ template <typename SocketParams, typename PoolType>
+ int Init(const std::string& group_name,
+ const scoped_refptr<SocketParams>& socket_params,
+ RequestPriority priority,
+ const CompletionCallback& callback,
+ PoolType* pool,
+ const BoundNetLog& net_log);
+
+ // An initialized handle can be reset, which causes it to return to the
+ // un-initialized state. This releases the underlying socket, which in the
+ // case of a socket that still has an established connection, indicates that
+ // the socket may be kept alive for use by a subsequent ClientSocketHandle.
+ //
+ // NOTE: To prevent the socket from being kept alive, be sure to call its
+ // Disconnect method. This will result in the ClientSocketPool deleting the
+ // StreamSocket.
+ void Reset();
+
+ // Used after Init() is called, but before the ClientSocketPool has
+ // initialized the ClientSocketHandle.
+ LoadState GetLoadState() const;
+
+ bool IsPoolStalled() const;
+
+ void AddLayeredPool(LayeredPool* layered_pool);
+
+ void RemoveLayeredPool(LayeredPool* layered_pool);
+
+ // Returns true when Init() has completed successfully.
+ bool is_initialized() const { return is_initialized_; }
+
+ // Returns the time tick when Init() was called.
+ base::TimeTicks init_time() const { return init_time_; }
+
+ // Returns the time between Init() and when is_initialized() becomes true.
+ base::TimeDelta setup_time() const { return setup_time_; }
+
+ // Sets the portion of LoadTimingInfo related to connection establishment, and
+ // the socket id. |is_reused| is needed because the handle may not have full
+ // reuse information. |load_timing_info| must have all default values when
+ // called. Returns false and makes no changes to |load_timing_info| when
+ // |socket_| is NULL.
+ bool GetLoadTimingInfo(bool is_reused,
+ LoadTimingInfo* load_timing_info) const;
+
+ // Used by ClientSocketPool to initialize the ClientSocketHandle.
+ void SetSocket(scoped_ptr<StreamSocket> s);
+ void set_is_reused(bool is_reused) { is_reused_ = is_reused; }
+ void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; }
+ void set_pool_id(int id) { pool_id_ = id; }
+ void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; }
+ void set_ssl_error_response_info(const HttpResponseInfo& ssl_error_state) {
+ ssl_error_response_info_ = ssl_error_state;
+ }
+ void set_pending_http_proxy_connection(ClientSocketHandle* connection) {
+ pending_http_proxy_connection_.reset(connection);
+ }
+
+ // Only valid if there is no |socket_|.
+ bool is_ssl_error() const {
+ DCHECK(socket_.get() == NULL);
+ return is_ssl_error_;
+ }
+ // On an ERR_PROXY_AUTH_REQUESTED error, the |headers| and |auth_challenge|
+ // fields are filled in. On an ERR_SSL_CLIENT_AUTH_CERT_NEEDED error,
+ // the |cert_request_info| field is set.
+ const HttpResponseInfo& ssl_error_response_info() const {
+ return ssl_error_response_info_;
+ }
+ ClientSocketHandle* release_pending_http_proxy_connection() {
+ return pending_http_proxy_connection_.release();
+ }
+
+ // These may only be used if is_initialized() is true.
+ scoped_ptr<StreamSocket> PassSocket();
+ StreamSocket* socket() { return socket_.get(); }
+ const std::string& group_name() const { return group_name_; }
+ int id() const { return pool_id_; }
+ bool is_reused() const { return is_reused_; }
+ base::TimeDelta idle_time() const { return idle_time_; }
+ SocketReuseType reuse_type() const {
+ if (is_reused()) {
+ return REUSED_IDLE;
+ } else if (idle_time() == base::TimeDelta()) {
+ return UNUSED;
+ } else {
+ return UNUSED_IDLE;
+ }
+ }
+ const LoadTimingInfo::ConnectTiming& connect_timing() const {
+ return connect_timing_;
+ }
+ void set_connect_timing(const LoadTimingInfo::ConnectTiming& connect_timing) {
+ connect_timing_ = connect_timing;
+ }
+
+ private:
+ // Called on asynchronous completion of an Init() request.
+ void OnIOComplete(int result);
+
+ // Called on completion (both asynchronous & synchronous) of an Init()
+ // request.
+ void HandleInitCompletion(int result);
+
+ // Resets the state of the ClientSocketHandle. |cancel| indicates whether or
+ // not to try to cancel the request with the ClientSocketPool. Does not
+ // reset the supplemental error state.
+ void ResetInternal(bool cancel);
+
+ // Resets the supplemental error state.
+ void ResetErrorState();
+
+ bool is_initialized_;
+ ClientSocketPool* pool_;
+ LayeredPool* layered_pool_;
+ scoped_ptr<StreamSocket> socket_;
+ std::string group_name_;
+ bool is_reused_;
+ CompletionCallback callback_;
+ CompletionCallback user_callback_;
+ base::TimeDelta idle_time_;
+ int pool_id_; // See ClientSocketPool::ReleaseSocket() for an explanation.
+ bool is_ssl_error_;
+ HttpResponseInfo ssl_error_response_info_;
+ scoped_ptr<ClientSocketHandle> pending_http_proxy_connection_;
+ base::TimeTicks init_time_;
+ base::TimeDelta setup_time_;
+
+ NetLog::Source requesting_source_;
+
+ // Timing information is set when a connection is successfully established.
+ LoadTimingInfo::ConnectTiming connect_timing_;
+
+ DISALLOW_COPY_AND_ASSIGN(ClientSocketHandle);
+};
+
+// Template function implementation:
+template <typename SocketParams, typename PoolType>
+int ClientSocketHandle::Init(const std::string& group_name,
+ const scoped_refptr<SocketParams>& socket_params,
+ RequestPriority priority,
+ const CompletionCallback& callback,
+ PoolType* pool,
+ const BoundNetLog& net_log) {
+ requesting_source_ = net_log.source();
+
+ CHECK(!group_name.empty());
+ // Note that this will result in a compile error if the SocketParams has not
+ // been registered for the PoolType via REGISTER_SOCKET_PARAMS_FOR_POOL
+ // (defined in client_socket_pool.h).
+ CheckIsValidSocketParamsForPool<PoolType, SocketParams>();
+ ResetInternal(true);
+ ResetErrorState();
+ pool_ = pool;
+ group_name_ = group_name;
+ init_time_ = base::TimeTicks::Now();
+ int rv = pool_->RequestSocket(
+ group_name, &socket_params, priority, this, callback_, net_log);
+ if (rv == ERR_IO_PENDING) {
+ user_callback_ = callback;
+ } else {
+ HandleInitCompletion(rv);
+ }
+ return rv;
+}
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_HANDLE_H_
diff --git a/chromium/net/socket/client_socket_pool.cc b/chromium/net/socket/client_socket_pool.cc
new file mode 100644
index 00000000000..0eebd11b9fb
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool.cc
@@ -0,0 +1,50 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_pool.h"
+
+#include "base/logging.h"
+
+namespace {
+
+// The maximum duration, in seconds, to keep unused idle persistent sockets
+// alive.
+// TODO(ziadh): Change this timeout after getting histogram data on how long it
+// should be.
+int g_unused_idle_socket_timeout_s = 10;
+
+// The maximum duration, in seconds, to keep used idle persistent sockets alive.
+int g_used_idle_socket_timeout_s = 300; // 5 minutes
+
+} // namespace
+
+namespace net {
+
+// static
+base::TimeDelta ClientSocketPool::unused_idle_socket_timeout() {
+ return base::TimeDelta::FromSeconds(g_unused_idle_socket_timeout_s);
+}
+
+// static
+void ClientSocketPool::set_unused_idle_socket_timeout(base::TimeDelta timeout) {
+ DCHECK_GT(timeout.InSeconds(), 0);
+ g_unused_idle_socket_timeout_s = timeout.InSeconds();
+}
+
+// static
+base::TimeDelta ClientSocketPool::used_idle_socket_timeout() {
+ return base::TimeDelta::FromSeconds(g_used_idle_socket_timeout_s);
+}
+
+// static
+void ClientSocketPool::set_used_idle_socket_timeout(base::TimeDelta timeout) {
+ DCHECK_GT(timeout.InSeconds(), 0);
+ g_used_idle_socket_timeout_s = timeout.InSeconds();
+}
+
+ClientSocketPool::ClientSocketPool() {}
+
+ClientSocketPool::~ClientSocketPool() {}
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h
new file mode 100644
index 00000000000..af184547d6e
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool.h
@@ -0,0 +1,221 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_H_
+#define NET_SOCKET_CLIENT_SOCKET_POOL_H_
+
+#include <deque>
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/template_util.h"
+#include "base/time/time.h"
+#include "net/base/completion_callback.h"
+#include "net/base/load_states.h"
+#include "net/base/net_export.h"
+#include "net/base/request_priority.h"
+#include "net/dns/host_resolver.h"
+
+namespace base {
+class DictionaryValue;
+}
+
+namespace net {
+
+class ClientSocketHandle;
+class ClientSocketPoolHistograms;
+class StreamSocket;
+
+// ClientSocketPools are layered. This defines an interface for lower level
+// socket pools to communicate with higher layer pools.
+class NET_EXPORT LayeredPool {
+ public:
+ virtual ~LayeredPool() {};
+
+ // Instructs the LayeredPool to close an idle connection. Return true if one
+ // was closed.
+ virtual bool CloseOneIdleConnection() = 0;
+};
+
+// A ClientSocketPool is used to restrict the number of sockets open at a time.
+// It also maintains a list of idle persistent sockets.
+//
+class NET_EXPORT ClientSocketPool {
+ public:
+ // Requests a connected socket for a group_name.
+ //
+ // There are five possible results from calling this function:
+ // 1) RequestSocket returns OK and initializes |handle| with a reused socket.
+ // 2) RequestSocket returns OK with a newly connected socket.
+ // 3) RequestSocket returns ERR_IO_PENDING. The handle will be added to a
+ // wait list until a socket is available to reuse or a new socket finishes
+ // connecting. |priority| will determine the placement into the wait list.
+ // 4) An error occurred early on, so RequestSocket returns an error code.
+ // 5) A recoverable error occurred while setting up the socket. An error
+ // code is returned, but the |handle| is initialized with the new socket.
+ // The caller must recover from the error before using the connection, or
+ // Disconnect the socket before releasing or resetting the |handle|.
+ // The current recoverable errors are: the errors accepted by
+ // IsCertificateError(err) and PROXY_AUTH_REQUESTED, or
+ // HTTPS_PROXY_TUNNEL_RESPONSE when reported by HttpProxyClientSocketPool.
+ //
+ // If this function returns OK, then |handle| is initialized upon return.
+ // The |handle|'s is_initialized method will return true in this case. If a
+ // StreamSocket was reused, then ClientSocketPool will call
+ // |handle|->set_reused(true). In either case, the socket will have been
+ // allocated and will be connected. A client might want to know whether or
+ // not the socket is reused in order to request a new socket if he encounters
+ // an error with the reused socket.
+ //
+ // If ERR_IO_PENDING is returned, then the callback will be used to notify the
+ // client of completion.
+ //
+ // Profiling information for the request is saved to |net_log| if non-NULL.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) = 0;
+
+ // RequestSockets is used to request that |num_sockets| be connected in the
+ // connection group for |group_name|. If the connection group already has
+ // |num_sockets| idle sockets / active sockets / currently connecting sockets,
+ // then this function doesn't do anything. Otherwise, it will start up as
+ // many connections as necessary to reach |num_sockets| total sockets for the
+ // group. It uses |params| to control how to connect the sockets. The
+ // ClientSocketPool will assign a priority to the new connections, if any.
+ // This priority will probably be lower than all others, since this method
+ // is intended to make sure ahead of time that |num_sockets| sockets are
+ // available to talk to a host.
+ virtual void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) = 0;
+
+ // Called to cancel a RequestSocket call that returned ERR_IO_PENDING. The
+ // same handle parameter must be passed to this method as was passed to the
+ // RequestSocket call being cancelled. The associated CompletionCallback is
+ // not run. However, for performance, we will let one ConnectJob complete
+ // and go idle.
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) = 0;
+
+ // Called to release a socket once the socket is no longer needed. If the
+ // socket still has an established connection, then it will be added to the
+ // set of idle sockets to be used to satisfy future RequestSocket calls.
+ // Otherwise, the StreamSocket is destroyed. |id| is used to differentiate
+ // between updated versions of the same pool instance. The pool's id will
+ // change when it flushes, so it can use this |id| to discard sockets with
+ // mismatched ids.
+ virtual void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) = 0;
+
+ // This flushes all state from the ClientSocketPool. This means that all
+ // idle and connecting sockets are discarded with the given |error|.
+ // Active sockets being held by ClientSocketPool clients will be discarded
+ // when released back to the pool.
+ // Does not flush any pools wrapped by |this|.
+ virtual void FlushWithError(int error) = 0;
+
+ // Returns true if a there is currently a request blocked on the
+ // per-pool (not per-host) max socket limit.
+ virtual bool IsStalled() const = 0;
+
+ // Called to close any idle connections held by the connection manager.
+ virtual void CloseIdleSockets() = 0;
+
+ // The total number of idle sockets in the pool.
+ virtual int IdleSocketCount() const = 0;
+
+ // The total number of idle sockets in a connection group.
+ virtual int IdleSocketCountInGroup(const std::string& group_name) const = 0;
+
+ // Determine the LoadState of a connecting ClientSocketHandle.
+ virtual LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const = 0;
+
+ // Adds a LayeredPool on top of |this|.
+ virtual void AddLayeredPool(LayeredPool* layered_pool) = 0;
+
+ // Removes a LayeredPool from |this|.
+ virtual void RemoveLayeredPool(LayeredPool* layered_pool) = 0;
+
+ // Retrieves information on the current state of the pool as a
+ // DictionaryValue. Caller takes possession of the returned value.
+ // If |include_nested_pools| is true, the states of any nested
+ // ClientSocketPools will be included.
+ virtual base::DictionaryValue* GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const = 0;
+
+ // Returns the maximum amount of time to wait before retrying a connect.
+ static const int kMaxConnectRetryIntervalMs = 250;
+
+ // The set of histograms specific to this pool. We can't use the standard
+ // UMA_HISTOGRAM_* macros because they are callsite static.
+ virtual ClientSocketPoolHistograms* histograms() const = 0;
+
+ static base::TimeDelta unused_idle_socket_timeout();
+ static void set_unused_idle_socket_timeout(base::TimeDelta timeout);
+
+ static base::TimeDelta used_idle_socket_timeout();
+ static void set_used_idle_socket_timeout(base::TimeDelta timeout);
+
+ protected:
+ ClientSocketPool();
+ virtual ~ClientSocketPool();
+
+ // Return the connection timeout for this pool.
+ virtual base::TimeDelta ConnectionTimeout() const = 0;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ClientSocketPool);
+};
+
+// ClientSocketPool subclasses should indicate valid SocketParams via the
+// REGISTER_SOCKET_PARAMS_FOR_POOL macro below. By default, any given
+// <PoolType,SocketParams> pair will have its SocketParamsTrait inherit from
+// base::false_type, but REGISTER_SOCKET_PARAMS_FOR_POOL will specialize that
+// pairing to inherit from base::true_type. This provides compile time
+// verification that the correct SocketParams type is used with the appropriate
+// PoolType.
+template <typename PoolType, typename SocketParams>
+struct SocketParamTraits : public base::false_type {
+};
+
+template <typename PoolType, typename SocketParams>
+void CheckIsValidSocketParamsForPool() {
+ COMPILE_ASSERT(!base::is_pointer<scoped_refptr<SocketParams> >::value,
+ socket_params_cannot_be_pointer);
+ COMPILE_ASSERT((SocketParamTraits<PoolType,
+ scoped_refptr<SocketParams> >::value),
+ invalid_socket_params_for_pool);
+}
+
+// Provides an empty definition for CheckIsValidSocketParamsForPool() which
+// should be optimized out by the compiler.
+#define REGISTER_SOCKET_PARAMS_FOR_POOL(pool_type, socket_params) \
+template<> \
+struct SocketParamTraits<pool_type, scoped_refptr<socket_params> > \
+ : public base::true_type { \
+}
+
+template <typename PoolType, typename SocketParams>
+void RequestSocketsForPool(PoolType* pool,
+ const std::string& group_name,
+ const scoped_refptr<SocketParams>& params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
+ CheckIsValidSocketParamsForPool<PoolType, SocketParams>();
+ pool->RequestSockets(group_name, &params, num_sockets, net_log);
+}
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc
new file mode 100644
index 00000000000..b1ddd40881c
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_base.cc
@@ -0,0 +1,1266 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_pool_base.h"
+
+#include "base/compiler_specific.h"
+#include "base/format_macros.h"
+#include "base/logging.h"
+#include "base/message_loop/message_loop.h"
+#include "base/metrics/stats_counters.h"
+#include "base/stl_util.h"
+#include "base/strings/string_util.h"
+#include "base/time/time.h"
+#include "base/values.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/socket/client_socket_handle.h"
+
+using base::TimeDelta;
+
+namespace net {
+
+namespace {
+
+// Indicate whether we should enable idle socket cleanup timer. When timer is
+// disabled, sockets are closed next time a socket request is made.
+bool g_cleanup_timer_enabled = true;
+
+// The timeout value, in seconds, used to clean up idle sockets that can't be
+// reused.
+//
+// Note: It's important to close idle sockets that have received data as soon
+// as possible because the received data may cause BSOD on Windows XP under
+// some conditions. See http://crbug.com/4606.
+const int kCleanupInterval = 10; // DO NOT INCREASE THIS TIMEOUT.
+
+// Indicate whether or not we should establish a new transport layer connection
+// after a certain timeout has passed without receiving an ACK.
+bool g_connect_backup_jobs_enabled = true;
+
+// Compares the effective priority of two results, and returns 1 if |request1|
+// has greater effective priority than |request2|, 0 if they have the same
+// effective priority, and -1 if |request2| has the greater effective priority.
+// Requests with |ignore_limits| set have higher effective priority than those
+// without. If both requests have |ignore_limits| set/unset, then the request
+// with the highest Pririoty has the highest effective priority. Does not take
+// into account the fact that Requests are serviced in FIFO order if they would
+// otherwise have the same priority.
+int CompareEffectiveRequestPriority(
+ const internal::ClientSocketPoolBaseHelper::Request& request1,
+ const internal::ClientSocketPoolBaseHelper::Request& request2) {
+ if (request1.ignore_limits() && !request2.ignore_limits())
+ return 1;
+ if (!request1.ignore_limits() && request2.ignore_limits())
+ return -1;
+ if (request1.priority() > request2.priority())
+ return 1;
+ if (request1.priority() < request2.priority())
+ return -1;
+ return 0;
+}
+
+} // namespace
+
+ConnectJob::ConnectJob(const std::string& group_name,
+ base::TimeDelta timeout_duration,
+ Delegate* delegate,
+ const BoundNetLog& net_log)
+ : group_name_(group_name),
+ timeout_duration_(timeout_duration),
+ delegate_(delegate),
+ net_log_(net_log),
+ idle_(true) {
+ DCHECK(!group_name.empty());
+ DCHECK(delegate);
+ net_log.BeginEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB,
+ NetLog::StringCallback("group_name", &group_name_));
+}
+
+ConnectJob::~ConnectJob() {
+ net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB);
+}
+
+scoped_ptr<StreamSocket> ConnectJob::PassSocket() {
+ return socket_.Pass();
+}
+
+int ConnectJob::Connect() {
+ if (timeout_duration_ != base::TimeDelta())
+ timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout);
+
+ idle_ = false;
+
+ LogConnectStart();
+
+ int rv = ConnectInternal();
+
+ if (rv != ERR_IO_PENDING) {
+ LogConnectCompletion(rv);
+ delegate_ = NULL;
+ }
+
+ return rv;
+}
+
+void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) {
+ if (socket) {
+ net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET,
+ socket->NetLog().source().ToEventParametersCallback());
+ }
+ socket_ = socket.Pass();
+}
+
+void ConnectJob::NotifyDelegateOfCompletion(int rv) {
+ // The delegate will own |this|.
+ Delegate* delegate = delegate_;
+ delegate_ = NULL;
+
+ LogConnectCompletion(rv);
+ delegate->OnConnectJobComplete(rv, this);
+}
+
+void ConnectJob::ResetTimer(base::TimeDelta remaining_time) {
+ timer_.Stop();
+ timer_.Start(FROM_HERE, remaining_time, this, &ConnectJob::OnTimeout);
+}
+
+void ConnectJob::LogConnectStart() {
+ connect_timing_.connect_start = base::TimeTicks::Now();
+ net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT);
+}
+
+void ConnectJob::LogConnectCompletion(int net_error) {
+ connect_timing_.connect_end = base::TimeTicks::Now();
+ net_log().EndEventWithNetErrorCode(
+ NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT, net_error);
+}
+
+void ConnectJob::OnTimeout() {
+ // Make sure the socket is NULL before calling into |delegate|.
+ SetSocket(scoped_ptr<StreamSocket>());
+
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT);
+
+ NotifyDelegateOfCompletion(ERR_TIMED_OUT);
+}
+
+namespace internal {
+
+ClientSocketPoolBaseHelper::Request::Request(
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ RequestPriority priority,
+ bool ignore_limits,
+ Flags flags,
+ const BoundNetLog& net_log)
+ : handle_(handle),
+ callback_(callback),
+ priority_(priority),
+ ignore_limits_(ignore_limits),
+ flags_(flags),
+ net_log_(net_log) {}
+
+ClientSocketPoolBaseHelper::Request::~Request() {}
+
+ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper(
+ int max_sockets,
+ int max_sockets_per_group,
+ base::TimeDelta unused_idle_socket_timeout,
+ base::TimeDelta used_idle_socket_timeout,
+ ConnectJobFactory* connect_job_factory)
+ : idle_socket_count_(0),
+ connecting_socket_count_(0),
+ handed_out_socket_count_(0),
+ max_sockets_(max_sockets),
+ max_sockets_per_group_(max_sockets_per_group),
+ use_cleanup_timer_(g_cleanup_timer_enabled),
+ unused_idle_socket_timeout_(unused_idle_socket_timeout),
+ used_idle_socket_timeout_(used_idle_socket_timeout),
+ connect_job_factory_(connect_job_factory),
+ connect_backup_jobs_enabled_(false),
+ pool_generation_number_(0),
+ weak_factory_(this) {
+ DCHECK_LE(0, max_sockets_per_group);
+ DCHECK_LE(max_sockets_per_group, max_sockets);
+
+ NetworkChangeNotifier::AddIPAddressObserver(this);
+}
+
+ClientSocketPoolBaseHelper::~ClientSocketPoolBaseHelper() {
+ // Clean up any idle sockets and pending connect jobs. Assert that we have no
+ // remaining active sockets or pending requests. They should have all been
+ // cleaned up prior to |this| being destroyed.
+ FlushWithError(ERR_ABORTED);
+ DCHECK(group_map_.empty());
+ DCHECK(pending_callback_map_.empty());
+ DCHECK_EQ(0, connecting_socket_count_);
+ CHECK(higher_layer_pools_.empty());
+
+ NetworkChangeNotifier::RemoveIPAddressObserver(this);
+}
+
+ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair()
+ : result(OK) {
+}
+
+ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair(
+ const CompletionCallback& callback_in, int result_in)
+ : callback(callback_in),
+ result(result_in) {
+}
+
+ClientSocketPoolBaseHelper::CallbackResultPair::~CallbackResultPair() {}
+
+// static
+void ClientSocketPoolBaseHelper::InsertRequestIntoQueue(
+ const Request* r, RequestQueue* pending_requests) {
+ RequestQueue::iterator it = pending_requests->begin();
+ // TODO(mmenke): Should the network stack require requests with
+ // |ignore_limits| have the highest priority?
+ while (it != pending_requests->end() &&
+ CompareEffectiveRequestPriority(*r, *(*it)) <= 0) {
+ ++it;
+ }
+ pending_requests->insert(it, r);
+}
+
+// static
+const ClientSocketPoolBaseHelper::Request*
+ClientSocketPoolBaseHelper::RemoveRequestFromQueue(
+ const RequestQueue::iterator& it, Group* group) {
+ const Request* req = *it;
+ group->mutable_pending_requests()->erase(it);
+ // If there are no more requests, we kill the backup timer.
+ if (group->pending_requests().empty())
+ group->CleanupBackupJob();
+ return req;
+}
+
+void ClientSocketPoolBaseHelper::AddLayeredPool(LayeredPool* pool) {
+ CHECK(pool);
+ CHECK(!ContainsKey(higher_layer_pools_, pool));
+ higher_layer_pools_.insert(pool);
+}
+
+void ClientSocketPoolBaseHelper::RemoveLayeredPool(LayeredPool* pool) {
+ CHECK(pool);
+ CHECK(ContainsKey(higher_layer_pools_, pool));
+ higher_layer_pools_.erase(pool);
+}
+
+int ClientSocketPoolBaseHelper::RequestSocket(
+ const std::string& group_name,
+ const Request* request) {
+ CHECK(!request->callback().is_null());
+ CHECK(request->handle());
+
+ // Cleanup any timed-out idle sockets if no timer is used.
+ if (!use_cleanup_timer_)
+ CleanupIdleSockets(false);
+
+ request->net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL);
+ Group* group = GetOrCreateGroup(group_name);
+
+ int rv = RequestSocketInternal(group_name, request);
+ if (rv != ERR_IO_PENDING) {
+ request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv);
+ CHECK(!request->handle()->is_initialized());
+ delete request;
+ } else {
+ InsertRequestIntoQueue(request, group->mutable_pending_requests());
+ // Have to do this asynchronously, as closing sockets in higher level pools
+ // call back in to |this|, which will cause all sorts of fun and exciting
+ // re-entrancy issues if the socket pool is doing something else at the
+ // time.
+ if (group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(
+ &ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools,
+ weak_factory_.GetWeakPtr()));
+ }
+ }
+ return rv;
+}
+
+void ClientSocketPoolBaseHelper::RequestSockets(
+ const std::string& group_name,
+ const Request& request,
+ int num_sockets) {
+ DCHECK(request.callback().is_null());
+ DCHECK(!request.handle());
+
+ // Cleanup any timed out idle sockets if no timer is used.
+ if (!use_cleanup_timer_)
+ CleanupIdleSockets(false);
+
+ if (num_sockets > max_sockets_per_group_) {
+ num_sockets = max_sockets_per_group_;
+ }
+
+ request.net_log().BeginEvent(
+ NetLog::TYPE_SOCKET_POOL_CONNECTING_N_SOCKETS,
+ NetLog::IntegerCallback("num_sockets", num_sockets));
+
+ Group* group = GetOrCreateGroup(group_name);
+
+ // RequestSocketsInternal() may delete the group.
+ bool deleted_group = false;
+
+ int rv = OK;
+ for (int num_iterations_left = num_sockets;
+ group->NumActiveSocketSlots() < num_sockets &&
+ num_iterations_left > 0 ; num_iterations_left--) {
+ rv = RequestSocketInternal(group_name, &request);
+ if (rv < 0 && rv != ERR_IO_PENDING) {
+ // We're encountering a synchronous error. Give up.
+ if (!ContainsKey(group_map_, group_name))
+ deleted_group = true;
+ break;
+ }
+ if (!ContainsKey(group_map_, group_name)) {
+ // Unexpected. The group should only be getting deleted on synchronous
+ // error.
+ NOTREACHED();
+ deleted_group = true;
+ break;
+ }
+ }
+
+ if (!deleted_group && group->IsEmpty())
+ RemoveGroup(group_name);
+
+ if (rv == ERR_IO_PENDING)
+ rv = OK;
+ request.net_log().EndEventWithNetErrorCode(
+ NetLog::TYPE_SOCKET_POOL_CONNECTING_N_SOCKETS, rv);
+}
+
+int ClientSocketPoolBaseHelper::RequestSocketInternal(
+ const std::string& group_name,
+ const Request* request) {
+ ClientSocketHandle* const handle = request->handle();
+ const bool preconnecting = !handle;
+ Group* group = GetOrCreateGroup(group_name);
+
+ if (!(request->flags() & NO_IDLE_SOCKETS)) {
+ // Try to reuse a socket.
+ if (AssignIdleSocketToRequest(request, group))
+ return OK;
+ }
+
+ // If there are more ConnectJobs than pending requests, don't need to do
+ // anything. Can just wait for the extra job to connect, and then assign it
+ // to the request.
+ if (!preconnecting && group->TryToUseUnassignedConnectJob())
+ return ERR_IO_PENDING;
+
+ // Can we make another active socket now?
+ if (!group->HasAvailableSocketSlot(max_sockets_per_group_) &&
+ !request->ignore_limits()) {
+ // TODO(willchan): Consider whether or not we need to close a socket in a
+ // higher layered group. I don't think this makes sense since we would just
+ // reuse that socket then if we needed one and wouldn't make it down to this
+ // layer.
+ request->net_log().AddEvent(
+ NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP);
+ return ERR_IO_PENDING;
+ }
+
+ if (ReachedMaxSocketsLimit() && !request->ignore_limits()) {
+ // NOTE(mmenke): Wonder if we really need different code for each case
+ // here. Only reason for them now seems to be preconnects.
+ if (idle_socket_count() > 0) {
+ // There's an idle socket in this pool. Either that's because there's
+ // still one in this group, but we got here due to preconnecting bypassing
+ // idle sockets, or because there's an idle socket in another group.
+ bool closed = CloseOneIdleSocketExceptInGroup(group);
+ if (preconnecting && !closed)
+ return ERR_PRECONNECT_MAX_SOCKET_LIMIT;
+ } else {
+ // We could check if we really have a stalled group here, but it requires
+ // a scan of all groups, so just flip a flag here, and do the check later.
+ request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS);
+ return ERR_IO_PENDING;
+ }
+ }
+
+ // We couldn't find a socket to reuse, and there's space to allocate one,
+ // so allocate and connect a new one.
+ scoped_ptr<ConnectJob> connect_job(
+ connect_job_factory_->NewConnectJob(group_name, *request, this));
+
+ int rv = connect_job->Connect();
+ if (rv == OK) {
+ LogBoundConnectJobToRequest(connect_job->net_log().source(), request);
+ if (!preconnecting) {
+ HandOutSocket(connect_job->PassSocket(), false /* not reused */,
+ connect_job->connect_timing(), handle, base::TimeDelta(),
+ group, request->net_log());
+ } else {
+ AddIdleSocket(connect_job->PassSocket(), group);
+ }
+ } else if (rv == ERR_IO_PENDING) {
+ // If we don't have any sockets in this group, set a timer for potentially
+ // creating a new one. If the SYN is lost, this backup socket may complete
+ // before the slow socket, improving end user latency.
+ if (connect_backup_jobs_enabled_ &&
+ group->IsEmpty() && !group->HasBackupJob()) {
+ group->StartBackupSocketTimer(group_name, this);
+ }
+
+ connecting_socket_count_++;
+
+ group->AddJob(connect_job.Pass(), preconnecting);
+ } else {
+ LogBoundConnectJobToRequest(connect_job->net_log().source(), request);
+ scoped_ptr<StreamSocket> error_socket;
+ if (!preconnecting) {
+ DCHECK(handle);
+ connect_job->GetAdditionalErrorState(handle);
+ error_socket = connect_job->PassSocket();
+ }
+ if (error_socket) {
+ HandOutSocket(error_socket.Pass(), false /* not reused */,
+ connect_job->connect_timing(), handle, base::TimeDelta(),
+ group, request->net_log());
+ } else if (group->IsEmpty()) {
+ RemoveGroup(group_name);
+ }
+ }
+
+ return rv;
+}
+
+bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest(
+ const Request* request, Group* group) {
+ std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets();
+ std::list<IdleSocket>::iterator idle_socket_it = idle_sockets->end();
+
+ // Iterate through the idle sockets forwards (oldest to newest)
+ // * Delete any disconnected ones.
+ // * If we find a used idle socket, assign to |idle_socket|. At the end,
+ // the |idle_socket_it| will be set to the newest used idle socket.
+ for (std::list<IdleSocket>::iterator it = idle_sockets->begin();
+ it != idle_sockets->end();) {
+ if (!it->socket->IsConnectedAndIdle()) {
+ DecrementIdleCount();
+ delete it->socket;
+ it = idle_sockets->erase(it);
+ continue;
+ }
+
+ if (it->socket->WasEverUsed()) {
+ // We found one we can reuse!
+ idle_socket_it = it;
+ }
+
+ ++it;
+ }
+
+ // If we haven't found an idle socket, that means there are no used idle
+ // sockets. Pick the oldest (first) idle socket (FIFO).
+
+ if (idle_socket_it == idle_sockets->end() && !idle_sockets->empty())
+ idle_socket_it = idle_sockets->begin();
+
+ if (idle_socket_it != idle_sockets->end()) {
+ DecrementIdleCount();
+ base::TimeDelta idle_time =
+ base::TimeTicks::Now() - idle_socket_it->start_time;
+ IdleSocket idle_socket = *idle_socket_it;
+ idle_sockets->erase(idle_socket_it);
+ HandOutSocket(
+ scoped_ptr<StreamSocket>(idle_socket.socket),
+ idle_socket.socket->WasEverUsed(),
+ LoadTimingInfo::ConnectTiming(),
+ request->handle(),
+ idle_time,
+ group,
+ request->net_log());
+ return true;
+ }
+
+ return false;
+}
+
+// static
+void ClientSocketPoolBaseHelper::LogBoundConnectJobToRequest(
+ const NetLog::Source& connect_job_source, const Request* request) {
+ request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ connect_job_source.ToEventParametersCallback());
+}
+
+void ClientSocketPoolBaseHelper::CancelRequest(
+ const std::string& group_name, ClientSocketHandle* handle) {
+ PendingCallbackMap::iterator callback_it = pending_callback_map_.find(handle);
+ if (callback_it != pending_callback_map_.end()) {
+ int result = callback_it->second.result;
+ pending_callback_map_.erase(callback_it);
+ scoped_ptr<StreamSocket> socket = handle->PassSocket();
+ if (socket) {
+ if (result != OK)
+ socket->Disconnect();
+ ReleaseSocket(handle->group_name(), socket.Pass(), handle->id());
+ }
+ return;
+ }
+
+ CHECK(ContainsKey(group_map_, group_name));
+
+ Group* group = GetOrCreateGroup(group_name);
+
+ // Search pending_requests for matching handle.
+ RequestQueue::iterator it = group->mutable_pending_requests()->begin();
+ for (; it != group->pending_requests().end(); ++it) {
+ if ((*it)->handle() == handle) {
+ scoped_ptr<const Request> req(RemoveRequestFromQueue(it, group));
+ req->net_log().AddEvent(NetLog::TYPE_CANCELLED);
+ req->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
+
+ // We let the job run, unless we're at the socket limit and there is
+ // not another request waiting on the job.
+ if (group->jobs().size() > group->pending_requests().size() &&
+ ReachedMaxSocketsLimit()) {
+ RemoveConnectJob(*group->jobs().begin(), group);
+ CheckForStalledSocketGroups();
+ }
+ break;
+ }
+ }
+}
+
+bool ClientSocketPoolBaseHelper::HasGroup(const std::string& group_name) const {
+ return ContainsKey(group_map_, group_name);
+}
+
+void ClientSocketPoolBaseHelper::CloseIdleSockets() {
+ CleanupIdleSockets(true);
+ DCHECK_EQ(0, idle_socket_count_);
+}
+
+int ClientSocketPoolBaseHelper::IdleSocketCountInGroup(
+ const std::string& group_name) const {
+ GroupMap::const_iterator i = group_map_.find(group_name);
+ CHECK(i != group_map_.end());
+
+ return i->second->idle_sockets().size();
+}
+
+LoadState ClientSocketPoolBaseHelper::GetLoadState(
+ const std::string& group_name,
+ const ClientSocketHandle* handle) const {
+ if (ContainsKey(pending_callback_map_, handle))
+ return LOAD_STATE_CONNECTING;
+
+ if (!ContainsKey(group_map_, group_name)) {
+ NOTREACHED() << "ClientSocketPool does not contain group: " << group_name
+ << " for handle: " << handle;
+ return LOAD_STATE_IDLE;
+ }
+
+ // Can't use operator[] since it is non-const.
+ const Group& group = *group_map_.find(group_name)->second;
+
+ // Search the first group.jobs().size() |pending_requests| for |handle|.
+ // If it's farther back in the deque than that, it doesn't have a
+ // corresponding ConnectJob.
+ size_t connect_jobs = group.jobs().size();
+ RequestQueue::const_iterator it = group.pending_requests().begin();
+ for (size_t i = 0; it != group.pending_requests().end() && i < connect_jobs;
+ ++it, ++i) {
+ if ((*it)->handle() != handle)
+ continue;
+
+ // Just return the state of the farthest along ConnectJob for the first
+ // group.jobs().size() pending requests.
+ LoadState max_state = LOAD_STATE_IDLE;
+ for (ConnectJobSet::const_iterator job_it = group.jobs().begin();
+ job_it != group.jobs().end(); ++job_it) {
+ max_state = std::max(max_state, (*job_it)->GetLoadState());
+ }
+ return max_state;
+ }
+
+ if (group.IsStalledOnPoolMaxSockets(max_sockets_per_group_))
+ return LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL;
+ return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
+}
+
+base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue(
+ const std::string& name, const std::string& type) const {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("name", name);
+ dict->SetString("type", type);
+ dict->SetInteger("handed_out_socket_count", handed_out_socket_count_);
+ dict->SetInteger("connecting_socket_count", connecting_socket_count_);
+ dict->SetInteger("idle_socket_count", idle_socket_count_);
+ dict->SetInteger("max_socket_count", max_sockets_);
+ dict->SetInteger("max_sockets_per_group", max_sockets_per_group_);
+ dict->SetInteger("pool_generation_number", pool_generation_number_);
+
+ if (group_map_.empty())
+ return dict;
+
+ base::DictionaryValue* all_groups_dict = new base::DictionaryValue();
+ for (GroupMap::const_iterator it = group_map_.begin();
+ it != group_map_.end(); it++) {
+ const Group* group = it->second;
+ base::DictionaryValue* group_dict = new base::DictionaryValue();
+
+ group_dict->SetInteger("pending_request_count",
+ group->pending_requests().size());
+ if (!group->pending_requests().empty()) {
+ group_dict->SetInteger("top_pending_priority",
+ group->TopPendingPriority());
+ }
+
+ group_dict->SetInteger("active_socket_count", group->active_socket_count());
+
+ base::ListValue* idle_socket_list = new base::ListValue();
+ std::list<IdleSocket>::const_iterator idle_socket;
+ for (idle_socket = group->idle_sockets().begin();
+ idle_socket != group->idle_sockets().end();
+ idle_socket++) {
+ int source_id = idle_socket->socket->NetLog().source().id;
+ idle_socket_list->Append(new base::FundamentalValue(source_id));
+ }
+ group_dict->Set("idle_sockets", idle_socket_list);
+
+ base::ListValue* connect_jobs_list = new base::ListValue();
+ std::set<ConnectJob*>::const_iterator job = group->jobs().begin();
+ for (job = group->jobs().begin(); job != group->jobs().end(); job++) {
+ int source_id = (*job)->net_log().source().id;
+ connect_jobs_list->Append(new base::FundamentalValue(source_id));
+ }
+ group_dict->Set("connect_jobs", connect_jobs_list);
+
+ group_dict->SetBoolean("is_stalled",
+ group->IsStalledOnPoolMaxSockets(
+ max_sockets_per_group_));
+ group_dict->SetBoolean("has_backup_job", group->HasBackupJob());
+
+ all_groups_dict->SetWithoutPathExpansion(it->first, group_dict);
+ }
+ dict->Set("groups", all_groups_dict);
+ return dict;
+}
+
+bool ClientSocketPoolBaseHelper::IdleSocket::ShouldCleanup(
+ base::TimeTicks now,
+ base::TimeDelta timeout) const {
+ bool timed_out = (now - start_time) >= timeout;
+ if (timed_out)
+ return true;
+ if (socket->WasEverUsed())
+ return !socket->IsConnectedAndIdle();
+ return !socket->IsConnected();
+}
+
+void ClientSocketPoolBaseHelper::CleanupIdleSockets(bool force) {
+ if (idle_socket_count_ == 0)
+ return;
+
+ // Current time value. Retrieving it once at the function start rather than
+ // inside the inner loop, since it shouldn't change by any meaningful amount.
+ base::TimeTicks now = base::TimeTicks::Now();
+
+ GroupMap::iterator i = group_map_.begin();
+ while (i != group_map_.end()) {
+ Group* group = i->second;
+
+ std::list<IdleSocket>::iterator j = group->mutable_idle_sockets()->begin();
+ while (j != group->idle_sockets().end()) {
+ base::TimeDelta timeout =
+ j->socket->WasEverUsed() ?
+ used_idle_socket_timeout_ : unused_idle_socket_timeout_;
+ if (force || j->ShouldCleanup(now, timeout)) {
+ delete j->socket;
+ j = group->mutable_idle_sockets()->erase(j);
+ DecrementIdleCount();
+ } else {
+ ++j;
+ }
+ }
+
+ // Delete group if no longer needed.
+ if (group->IsEmpty()) {
+ RemoveGroup(i++);
+ } else {
+ ++i;
+ }
+ }
+}
+
+ClientSocketPoolBaseHelper::Group* ClientSocketPoolBaseHelper::GetOrCreateGroup(
+ const std::string& group_name) {
+ GroupMap::iterator it = group_map_.find(group_name);
+ if (it != group_map_.end())
+ return it->second;
+ Group* group = new Group;
+ group_map_[group_name] = group;
+ return group;
+}
+
+void ClientSocketPoolBaseHelper::RemoveGroup(const std::string& group_name) {
+ GroupMap::iterator it = group_map_.find(group_name);
+ CHECK(it != group_map_.end());
+
+ RemoveGroup(it);
+}
+
+void ClientSocketPoolBaseHelper::RemoveGroup(GroupMap::iterator it) {
+ delete it->second;
+ group_map_.erase(it);
+}
+
+// static
+bool ClientSocketPoolBaseHelper::connect_backup_jobs_enabled() {
+ return g_connect_backup_jobs_enabled;
+}
+
+// static
+bool ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(bool enabled) {
+ bool old_value = g_connect_backup_jobs_enabled;
+ g_connect_backup_jobs_enabled = enabled;
+ return old_value;
+}
+
+void ClientSocketPoolBaseHelper::EnableConnectBackupJobs() {
+ connect_backup_jobs_enabled_ = g_connect_backup_jobs_enabled;
+}
+
+void ClientSocketPoolBaseHelper::IncrementIdleCount() {
+ if (++idle_socket_count_ == 1 && use_cleanup_timer_)
+ StartIdleSocketTimer();
+}
+
+void ClientSocketPoolBaseHelper::DecrementIdleCount() {
+ if (--idle_socket_count_ == 0)
+ timer_.Stop();
+}
+
+// static
+bool ClientSocketPoolBaseHelper::cleanup_timer_enabled() {
+ return g_cleanup_timer_enabled;
+}
+
+// static
+bool ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(bool enabled) {
+ bool old_value = g_cleanup_timer_enabled;
+ g_cleanup_timer_enabled = enabled;
+ return old_value;
+}
+
+void ClientSocketPoolBaseHelper::StartIdleSocketTimer() {
+ timer_.Start(FROM_HERE, TimeDelta::FromSeconds(kCleanupInterval), this,
+ &ClientSocketPoolBaseHelper::OnCleanupTimerFired);
+}
+
+void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ GroupMap::iterator i = group_map_.find(group_name);
+ CHECK(i != group_map_.end());
+
+ Group* group = i->second;
+
+ CHECK_GT(handed_out_socket_count_, 0);
+ handed_out_socket_count_--;
+
+ CHECK_GT(group->active_socket_count(), 0);
+ group->DecrementActiveSocketCount();
+
+ const bool can_reuse = socket->IsConnectedAndIdle() &&
+ id == pool_generation_number_;
+ if (can_reuse) {
+ // Add it to the idle list.
+ AddIdleSocket(socket.Pass(), group);
+ OnAvailableSocketSlot(group_name, group);
+ } else {
+ socket.reset();
+ }
+
+ CheckForStalledSocketGroups();
+}
+
+void ClientSocketPoolBaseHelper::CheckForStalledSocketGroups() {
+ // If we have idle sockets, see if we can give one to the top-stalled group.
+ std::string top_group_name;
+ Group* top_group = NULL;
+ if (!FindTopStalledGroup(&top_group, &top_group_name))
+ return;
+
+ if (ReachedMaxSocketsLimit()) {
+ if (idle_socket_count() > 0) {
+ CloseOneIdleSocket();
+ } else {
+ // We can't activate more sockets since we're already at our global
+ // limit.
+ return;
+ }
+ }
+
+ // Note: we don't loop on waking stalled groups. If the stalled group is at
+ // its limit, may be left with other stalled groups that could be
+ // woken. This isn't optimal, but there is no starvation, so to avoid
+ // the looping we leave it at this.
+ OnAvailableSocketSlot(top_group_name, top_group);
+}
+
+// Search for the highest priority pending request, amongst the groups that
+// are not at the |max_sockets_per_group_| limit. Note: for requests with
+// the same priority, the winner is based on group hash ordering (and not
+// insertion order).
+bool ClientSocketPoolBaseHelper::FindTopStalledGroup(
+ Group** group,
+ std::string* group_name) const {
+ CHECK((group && group_name) || (!group && !group_name));
+ Group* top_group = NULL;
+ const std::string* top_group_name = NULL;
+ bool has_stalled_group = false;
+ for (GroupMap::const_iterator i = group_map_.begin();
+ i != group_map_.end(); ++i) {
+ Group* curr_group = i->second;
+ const RequestQueue& queue = curr_group->pending_requests();
+ if (queue.empty())
+ continue;
+ if (curr_group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) {
+ if (!group)
+ return true;
+ has_stalled_group = true;
+ bool has_higher_priority = !top_group ||
+ curr_group->TopPendingPriority() > top_group->TopPendingPriority();
+ if (has_higher_priority) {
+ top_group = curr_group;
+ top_group_name = &i->first;
+ }
+ }
+ }
+
+ if (top_group) {
+ CHECK(group);
+ *group = top_group;
+ *group_name = *top_group_name;
+ } else {
+ CHECK(!has_stalled_group);
+ }
+ return has_stalled_group;
+}
+
+void ClientSocketPoolBaseHelper::OnConnectJobComplete(
+ int result, ConnectJob* job) {
+ DCHECK_NE(ERR_IO_PENDING, result);
+ const std::string group_name = job->group_name();
+ GroupMap::iterator group_it = group_map_.find(group_name);
+ CHECK(group_it != group_map_.end());
+ Group* group = group_it->second;
+
+ scoped_ptr<StreamSocket> socket = job->PassSocket();
+
+ // Copies of these are needed because |job| may be deleted before they are
+ // accessed.
+ BoundNetLog job_log = job->net_log();
+ LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing();
+
+ // RemoveConnectJob(job, _) must be called by all branches below;
+ // otherwise, |job| will be leaked.
+
+ if (result == OK) {
+ DCHECK(socket.get());
+ RemoveConnectJob(job, group);
+ if (!group->pending_requests().empty()) {
+ scoped_ptr<const Request> r(RemoveRequestFromQueue(
+ group->mutable_pending_requests()->begin(), group));
+ LogBoundConnectJobToRequest(job_log.source(), r.get());
+ HandOutSocket(
+ socket.Pass(), false /* unused socket */, connect_timing,
+ r->handle(), base::TimeDelta(), group, r->net_log());
+ r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL);
+ InvokeUserCallbackLater(r->handle(), r->callback(), result);
+ } else {
+ AddIdleSocket(socket.Pass(), group);
+ OnAvailableSocketSlot(group_name, group);
+ CheckForStalledSocketGroups();
+ }
+ } else {
+ // If we got a socket, it must contain error information so pass that
+ // up so that the caller can retrieve it.
+ bool handed_out_socket = false;
+ if (!group->pending_requests().empty()) {
+ scoped_ptr<const Request> r(RemoveRequestFromQueue(
+ group->mutable_pending_requests()->begin(), group));
+ LogBoundConnectJobToRequest(job_log.source(), r.get());
+ job->GetAdditionalErrorState(r->handle());
+ RemoveConnectJob(job, group);
+ if (socket.get()) {
+ handed_out_socket = true;
+ HandOutSocket(socket.Pass(), false /* unused socket */,
+ connect_timing, r->handle(), base::TimeDelta(), group,
+ r->net_log());
+ }
+ r->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result);
+ InvokeUserCallbackLater(r->handle(), r->callback(), result);
+ } else {
+ RemoveConnectJob(job, group);
+ }
+ if (!handed_out_socket) {
+ OnAvailableSocketSlot(group_name, group);
+ CheckForStalledSocketGroups();
+ }
+ }
+}
+
+void ClientSocketPoolBaseHelper::OnIPAddressChanged() {
+ FlushWithError(ERR_NETWORK_CHANGED);
+}
+
+void ClientSocketPoolBaseHelper::FlushWithError(int error) {
+ pool_generation_number_++;
+ CancelAllConnectJobs();
+ CloseIdleSockets();
+ CancelAllRequestsWithError(error);
+}
+
+bool ClientSocketPoolBaseHelper::IsStalled() const {
+ // If we are not using |max_sockets_|, then clearly we are not stalled
+ if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_)
+ return false;
+ // So in order to be stalled we need to be using |max_sockets_| AND
+ // we need to have a request that is actually stalled on the global
+ // socket limit. To find such a request, we look for a group that
+ // a has more requests that jobs AND where the number of jobs is less
+ // than |max_sockets_per_group_|. (If the number of jobs is equal to
+ // |max_sockets_per_group_|, then the request is stalled on the group,
+ // which does not count.)
+ for (GroupMap::const_iterator it = group_map_.begin();
+ it != group_map_.end(); ++it) {
+ if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_))
+ return true;
+ }
+ return false;
+}
+
+void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job,
+ Group* group) {
+ CHECK_GT(connecting_socket_count_, 0);
+ connecting_socket_count_--;
+
+ DCHECK(group);
+ group->RemoveJob(job);
+
+ // If we've got no more jobs for this group, then we no longer need a
+ // backup job either.
+ if (group->jobs().empty())
+ group->CleanupBackupJob();
+}
+
+void ClientSocketPoolBaseHelper::OnAvailableSocketSlot(
+ const std::string& group_name, Group* group) {
+ DCHECK(ContainsKey(group_map_, group_name));
+ if (group->IsEmpty())
+ RemoveGroup(group_name);
+ else if (!group->pending_requests().empty())
+ ProcessPendingRequest(group_name, group);
+}
+
+void ClientSocketPoolBaseHelper::ProcessPendingRequest(
+ const std::string& group_name, Group* group) {
+ int rv = RequestSocketInternal(group_name,
+ *group->pending_requests().begin());
+ if (rv != ERR_IO_PENDING) {
+ scoped_ptr<const Request> request(RemoveRequestFromQueue(
+ group->mutable_pending_requests()->begin(), group));
+ if (group->IsEmpty())
+ RemoveGroup(group_name);
+
+ request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv);
+ InvokeUserCallbackLater(request->handle(), request->callback(), rv);
+ }
+}
+
+void ClientSocketPoolBaseHelper::HandOutSocket(
+ scoped_ptr<StreamSocket> socket,
+ bool reused,
+ const LoadTimingInfo::ConnectTiming& connect_timing,
+ ClientSocketHandle* handle,
+ base::TimeDelta idle_time,
+ Group* group,
+ const BoundNetLog& net_log) {
+ DCHECK(socket);
+ handle->SetSocket(socket.Pass());
+ handle->set_is_reused(reused);
+ handle->set_idle_time(idle_time);
+ handle->set_pool_id(pool_generation_number_);
+ handle->set_connect_timing(connect_timing);
+
+ if (reused) {
+ net_log.AddEvent(
+ NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET,
+ NetLog::IntegerCallback(
+ "idle_ms", static_cast<int>(idle_time.InMilliseconds())));
+ }
+
+ net_log.AddEvent(
+ NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
+ handle->socket()->NetLog().source().ToEventParametersCallback());
+
+ handed_out_socket_count_++;
+ group->IncrementActiveSocketCount();
+}
+
+void ClientSocketPoolBaseHelper::AddIdleSocket(
+ scoped_ptr<StreamSocket> socket,
+ Group* group) {
+ DCHECK(socket);
+ IdleSocket idle_socket;
+ idle_socket.socket = socket.release();
+ idle_socket.start_time = base::TimeTicks::Now();
+
+ group->mutable_idle_sockets()->push_back(idle_socket);
+ IncrementIdleCount();
+}
+
+void ClientSocketPoolBaseHelper::CancelAllConnectJobs() {
+ for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) {
+ Group* group = i->second;
+ connecting_socket_count_ -= group->jobs().size();
+ group->RemoveAllJobs();
+
+ // Delete group if no longer needed.
+ if (group->IsEmpty()) {
+ // RemoveGroup() will call .erase() which will invalidate the iterator,
+ // but i will already have been incremented to a valid iterator before
+ // RemoveGroup() is called.
+ RemoveGroup(i++);
+ } else {
+ ++i;
+ }
+ }
+ DCHECK_EQ(0, connecting_socket_count_);
+}
+
+void ClientSocketPoolBaseHelper::CancelAllRequestsWithError(int error) {
+ for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) {
+ Group* group = i->second;
+
+ RequestQueue pending_requests;
+ pending_requests.swap(*group->mutable_pending_requests());
+ for (RequestQueue::iterator it2 = pending_requests.begin();
+ it2 != pending_requests.end(); ++it2) {
+ scoped_ptr<const Request> request(*it2);
+ InvokeUserCallbackLater(
+ request->handle(), request->callback(), error);
+ }
+
+ // Delete group if no longer needed.
+ if (group->IsEmpty()) {
+ // RemoveGroup() will call .erase() which will invalidate the iterator,
+ // but i will already have been incremented to a valid iterator before
+ // RemoveGroup() is called.
+ RemoveGroup(i++);
+ } else {
+ ++i;
+ }
+ }
+}
+
+bool ClientSocketPoolBaseHelper::ReachedMaxSocketsLimit() const {
+ // Each connecting socket will eventually connect and be handed out.
+ int total = handed_out_socket_count_ + connecting_socket_count_ +
+ idle_socket_count();
+ // There can be more sockets than the limit since some requests can ignore
+ // the limit
+ if (total < max_sockets_)
+ return false;
+ return true;
+}
+
+bool ClientSocketPoolBaseHelper::CloseOneIdleSocket() {
+ if (idle_socket_count() == 0)
+ return false;
+ return CloseOneIdleSocketExceptInGroup(NULL);
+}
+
+bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup(
+ const Group* exception_group) {
+ CHECK_GT(idle_socket_count(), 0);
+
+ for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end(); ++i) {
+ Group* group = i->second;
+ if (exception_group == group)
+ continue;
+ std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets();
+
+ if (!idle_sockets->empty()) {
+ delete idle_sockets->front().socket;
+ idle_sockets->pop_front();
+ DecrementIdleCount();
+ if (group->IsEmpty())
+ RemoveGroup(i);
+
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInLayeredPool() {
+ // This pool doesn't have any idle sockets. It's possible that a pool at a
+ // higher layer is holding one of this sockets active, but it's actually idle.
+ // Query the higher layers.
+ for (std::set<LayeredPool*>::const_iterator it = higher_layer_pools_.begin();
+ it != higher_layer_pools_.end(); ++it) {
+ if ((*it)->CloseOneIdleConnection())
+ return true;
+ }
+ return false;
+}
+
+void ClientSocketPoolBaseHelper::InvokeUserCallbackLater(
+ ClientSocketHandle* handle, const CompletionCallback& callback, int rv) {
+ CHECK(!ContainsKey(pending_callback_map_, handle));
+ pending_callback_map_[handle] = CallbackResultPair(callback, rv);
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&ClientSocketPoolBaseHelper::InvokeUserCallback,
+ weak_factory_.GetWeakPtr(), handle));
+}
+
+void ClientSocketPoolBaseHelper::InvokeUserCallback(
+ ClientSocketHandle* handle) {
+ PendingCallbackMap::iterator it = pending_callback_map_.find(handle);
+
+ // Exit if the request has already been cancelled.
+ if (it == pending_callback_map_.end())
+ return;
+
+ CHECK(!handle->is_initialized());
+ CompletionCallback callback = it->second.callback;
+ int result = it->second.result;
+ pending_callback_map_.erase(it);
+ callback.Run(result);
+}
+
+void ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools() {
+ while (IsStalled()) {
+ // Closing a socket will result in calling back into |this| to use the freed
+ // socket slot, so nothing else is needed.
+ if (!CloseOneIdleConnectionInLayeredPool())
+ return;
+ }
+}
+
+ClientSocketPoolBaseHelper::Group::Group()
+ : unassigned_job_count_(0),
+ active_socket_count_(0),
+ weak_factory_(this) {}
+
+ClientSocketPoolBaseHelper::Group::~Group() {
+ CleanupBackupJob();
+ DCHECK_EQ(0u, unassigned_job_count_);
+}
+
+void ClientSocketPoolBaseHelper::Group::StartBackupSocketTimer(
+ const std::string& group_name,
+ ClientSocketPoolBaseHelper* pool) {
+ // Only allow one timer pending to create a backup socket.
+ if (weak_factory_.HasWeakPtrs())
+ return;
+
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&Group::OnBackupSocketTimerFired, weak_factory_.GetWeakPtr(),
+ group_name, pool),
+ pool->ConnectRetryInterval());
+}
+
+bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() {
+ SanityCheck();
+
+ if (unassigned_job_count_ == 0)
+ return false;
+ --unassigned_job_count_;
+ return true;
+}
+
+void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job,
+ bool is_preconnect) {
+ SanityCheck();
+
+ if (is_preconnect)
+ ++unassigned_job_count_;
+ jobs_.insert(job.release());
+}
+
+void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) {
+ scoped_ptr<ConnectJob> owned_job(job);
+ SanityCheck();
+
+ std::set<ConnectJob*>::iterator it = jobs_.find(job);
+ if (it != jobs_.end()) {
+ jobs_.erase(it);
+ } else {
+ NOTREACHED();
+ }
+ size_t job_count = jobs_.size();
+ if (job_count < unassigned_job_count_)
+ unassigned_job_count_ = job_count;
+}
+
+void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired(
+ std::string group_name,
+ ClientSocketPoolBaseHelper* pool) {
+ // If there are no more jobs pending, there is no work to do.
+ // If we've done our cleanups correctly, this should not happen.
+ if (jobs_.empty()) {
+ NOTREACHED();
+ return;
+ }
+
+ // If our old job is waiting on DNS, or if we can't create any sockets
+ // right now due to limits, just reset the timer.
+ if (pool->ReachedMaxSocketsLimit() ||
+ !HasAvailableSocketSlot(pool->max_sockets_per_group_) ||
+ (*jobs_.begin())->GetLoadState() == LOAD_STATE_RESOLVING_HOST) {
+ StartBackupSocketTimer(group_name, pool);
+ return;
+ }
+
+ if (pending_requests_.empty())
+ return;
+
+ scoped_ptr<ConnectJob> backup_job =
+ pool->connect_job_factory_->NewConnectJob(
+ group_name, **pending_requests_.begin(), pool);
+ backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED);
+ SIMPLE_STATS_COUNTER("socket.backup_created");
+ int rv = backup_job->Connect();
+ pool->connecting_socket_count_++;
+ ConnectJob* raw_backup_job = backup_job.get();
+ AddJob(backup_job.Pass(), false);
+ if (rv != ERR_IO_PENDING)
+ pool->OnConnectJobComplete(rv, raw_backup_job);
+}
+
+void ClientSocketPoolBaseHelper::Group::SanityCheck() {
+ DCHECK_LE(unassigned_job_count_, jobs_.size());
+}
+
+void ClientSocketPoolBaseHelper::Group::RemoveAllJobs() {
+ SanityCheck();
+
+ // Delete active jobs.
+ STLDeleteElements(&jobs_);
+ unassigned_job_count_ = 0;
+
+ // Cancel pending backup job.
+ weak_factory_.InvalidateWeakPtrs();
+}
+
+} // namespace internal
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h
new file mode 100644
index 00000000000..eb642edd730
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_base.h
@@ -0,0 +1,819 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// A ClientSocketPoolBase is used to restrict the number of sockets open at
+// a time. It also maintains a list of idle persistent sockets for reuse.
+// Subclasses of ClientSocketPool should compose ClientSocketPoolBase to handle
+// the core logic of (1) restricting the number of active (connected or
+// connecting) sockets per "group" (generally speaking, the hostname), (2)
+// maintaining a per-group list of idle, persistent sockets for reuse, and (3)
+// limiting the total number of active sockets in the system.
+//
+// ClientSocketPoolBase abstracts socket connection details behind ConnectJob,
+// ConnectJobFactory, and SocketParams. When a socket "slot" becomes available,
+// the ClientSocketPoolBase will ask the ConnectJobFactory to create a
+// ConnectJob with a SocketParams. Subclasses of ClientSocketPool should
+// implement their socket specific connection by subclassing ConnectJob and
+// implementing ConnectJob::ConnectInternal(). They can control the parameters
+// passed to each new ConnectJob instance via their ConnectJobFactory subclass
+// and templated SocketParams parameter.
+//
+#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_
+#define NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_
+
+#include <deque>
+#include <list>
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "base/time/time.h"
+#include "base/timer/timer.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
+#include "net/base/load_states.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/base/network_change_notifier.h"
+#include "net/base/request_priority.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class ClientSocketHandle;
+
+// ConnectJob provides an abstract interface for "connecting" a socket.
+// The connection may involve host resolution, tcp connection, ssl connection,
+// etc.
+class NET_EXPORT_PRIVATE ConnectJob {
+ public:
+ class NET_EXPORT_PRIVATE Delegate {
+ public:
+ Delegate() {}
+ virtual ~Delegate() {}
+
+ // Alerts the delegate that the connection completed. |job| must
+ // be destroyed by the delegate. A scoped_ptr<> isn't used because
+ // the caller of this function doesn't own |job|.
+ virtual void OnConnectJobComplete(int result,
+ ConnectJob* job) = 0;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(Delegate);
+ };
+
+ // A |timeout_duration| of 0 corresponds to no timeout.
+ ConnectJob(const std::string& group_name,
+ base::TimeDelta timeout_duration,
+ Delegate* delegate,
+ const BoundNetLog& net_log);
+ virtual ~ConnectJob();
+
+ // Accessors
+ const std::string& group_name() const { return group_name_; }
+ const BoundNetLog& net_log() { return net_log_; }
+
+ // Releases ownership of the underlying socket to the caller.
+ // Returns the released socket, or NULL if there was a connection
+ // error.
+ scoped_ptr<StreamSocket> PassSocket();
+
+ // Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it
+ // cannot complete synchronously without blocking, or another net error code
+ // on error. In asynchronous completion, the ConnectJob will notify
+ // |delegate_| via OnConnectJobComplete. In both asynchronous and synchronous
+ // completion, ReleaseSocket() can be called to acquire the connected socket
+ // if it succeeded.
+ int Connect();
+
+ virtual LoadState GetLoadState() const = 0;
+
+ // If Connect returns an error (or OnConnectJobComplete reports an error
+ // result) this method will be called, allowing the pool to add
+ // additional error state to the ClientSocketHandle (post late-binding).
+ virtual void GetAdditionalErrorState(ClientSocketHandle* handle) {}
+
+ const LoadTimingInfo::ConnectTiming& connect_timing() const {
+ return connect_timing_;
+ }
+
+ const BoundNetLog& net_log() const { return net_log_; }
+
+ protected:
+ void SetSocket(scoped_ptr<StreamSocket> socket);
+ StreamSocket* socket() { return socket_.get(); }
+ void NotifyDelegateOfCompletion(int rv);
+ void ResetTimer(base::TimeDelta remainingTime);
+
+ // Connection establishment timing information.
+ LoadTimingInfo::ConnectTiming connect_timing_;
+
+ private:
+ virtual int ConnectInternal() = 0;
+
+ void LogConnectStart();
+ void LogConnectCompletion(int net_error);
+
+ // Alerts the delegate that the ConnectJob has timed out.
+ void OnTimeout();
+
+ const std::string group_name_;
+ const base::TimeDelta timeout_duration_;
+ // Timer to abort jobs that take too long.
+ base::OneShotTimer<ConnectJob> timer_;
+ Delegate* delegate_;
+ scoped_ptr<StreamSocket> socket_;
+ BoundNetLog net_log_;
+ // A ConnectJob is idle until Connect() has been called.
+ bool idle_;
+
+ DISALLOW_COPY_AND_ASSIGN(ConnectJob);
+};
+
+namespace internal {
+
+// ClientSocketPoolBaseHelper is an internal class that implements almost all
+// the functionality from ClientSocketPoolBase without using templates.
+// ClientSocketPoolBase adds templated definitions built on top of
+// ClientSocketPoolBaseHelper. This class is not for external use, please use
+// ClientSocketPoolBase instead.
+class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
+ : public ConnectJob::Delegate,
+ public NetworkChangeNotifier::IPAddressObserver {
+ public:
+ typedef uint32 Flags;
+
+ // Used to specify specific behavior for the ClientSocketPool.
+ enum Flag {
+ NORMAL = 0, // Normal behavior.
+ NO_IDLE_SOCKETS = 0x1, // Do not return an idle socket. Create a new one.
+ };
+
+ class NET_EXPORT_PRIVATE Request {
+ public:
+ Request(ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ RequestPriority priority,
+ bool ignore_limits,
+ Flags flags,
+ const BoundNetLog& net_log);
+
+ virtual ~Request();
+
+ ClientSocketHandle* handle() const { return handle_; }
+ const CompletionCallback& callback() const { return callback_; }
+ RequestPriority priority() const { return priority_; }
+ bool ignore_limits() const { return ignore_limits_; }
+ Flags flags() const { return flags_; }
+ const BoundNetLog& net_log() const { return net_log_; }
+
+ private:
+ ClientSocketHandle* const handle_;
+ CompletionCallback callback_;
+ const RequestPriority priority_;
+ bool ignore_limits_;
+ const Flags flags_;
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(Request);
+ };
+
+ class ConnectJobFactory {
+ public:
+ ConnectJobFactory() {}
+ virtual ~ConnectJobFactory() {}
+
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const Request& request,
+ ConnectJob::Delegate* delegate) const = 0;
+
+ virtual base::TimeDelta ConnectionTimeout() const = 0;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory);
+ };
+
+ ClientSocketPoolBaseHelper(
+ int max_sockets,
+ int max_sockets_per_group,
+ base::TimeDelta unused_idle_socket_timeout,
+ base::TimeDelta used_idle_socket_timeout,
+ ConnectJobFactory* connect_job_factory);
+
+ virtual ~ClientSocketPoolBaseHelper();
+
+ // Adds/Removes layered pools. It is expected in the destructor that no
+ // layered pools remain.
+ void AddLayeredPool(LayeredPool* pool);
+ void RemoveLayeredPool(LayeredPool* pool);
+
+ // See ClientSocketPool::RequestSocket for documentation on this function.
+ // ClientSocketPoolBaseHelper takes ownership of |request|, which must be
+ // heap allocated.
+ int RequestSocket(const std::string& group_name, const Request* request);
+
+ // See ClientSocketPool::RequestSocket for documentation on this function.
+ void RequestSockets(const std::string& group_name,
+ const Request& request,
+ int num_sockets);
+
+ // See ClientSocketPool::CancelRequest for documentation on this function.
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle);
+
+ // See ClientSocketPool::ReleaseSocket for documentation on this function.
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id);
+
+ // See ClientSocketPool::FlushWithError for documentation on this function.
+ void FlushWithError(int error);
+
+ // See ClientSocketPool::IsStalled for documentation on this function.
+ bool IsStalled() const;
+
+ // See ClientSocketPool::CloseIdleSockets for documentation on this function.
+ void CloseIdleSockets();
+
+ // See ClientSocketPool::IdleSocketCount() for documentation on this function.
+ int idle_socket_count() const {
+ return idle_socket_count_;
+ }
+
+ // See ClientSocketPool::IdleSocketCountInGroup() for documentation on this
+ // function.
+ int IdleSocketCountInGroup(const std::string& group_name) const;
+
+ // See ClientSocketPool::GetLoadState() for documentation on this function.
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const;
+
+ base::TimeDelta ConnectRetryInterval() const {
+ // TODO(mbelshe): Make this tuned dynamically based on measured RTT.
+ // For now, just use the max retry interval.
+ return base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs);
+ }
+
+ int NumUnassignedConnectJobsInGroup(const std::string& group_name) const {
+ return group_map_.find(group_name)->second->unassigned_job_count();
+ }
+
+ int NumConnectJobsInGroup(const std::string& group_name) const {
+ return group_map_.find(group_name)->second->jobs().size();
+ }
+
+ int NumActiveSocketsInGroup(const std::string& group_name) const {
+ return group_map_.find(group_name)->second->active_socket_count();
+ }
+
+ bool HasGroup(const std::string& group_name) const;
+
+ // Called to enable/disable cleaning up idle sockets. When enabled,
+ // idle sockets that have been around for longer than a period defined
+ // by kCleanupInterval are cleaned up using a timer. Otherwise they are
+ // closed next time client makes a request. This may reduce network
+ // activity and power consumption.
+ static bool cleanup_timer_enabled();
+ static bool set_cleanup_timer_enabled(bool enabled);
+
+ // Closes all idle sockets if |force| is true. Else, only closes idle
+ // sockets that timed out or can't be reused. Made public for testing.
+ void CleanupIdleSockets(bool force);
+
+ // Closes one idle socket. Picks the first one encountered.
+ // TODO(willchan): Consider a better algorithm for doing this. Perhaps we
+ // should keep an ordered list of idle sockets, and close them in order.
+ // Requires maintaining more state. It's not clear if it's worth it since
+ // I'm not sure if we hit this situation often.
+ bool CloseOneIdleSocket();
+
+ // Checks layered pools to see if they can close an idle connection.
+ bool CloseOneIdleConnectionInLayeredPool();
+
+ // See ClientSocketPool::GetInfoAsValue for documentation on this function.
+ base::DictionaryValue* GetInfoAsValue(const std::string& name,
+ const std::string& type) const;
+
+ base::TimeDelta ConnectionTimeout() const {
+ return connect_job_factory_->ConnectionTimeout();
+ }
+
+ static bool connect_backup_jobs_enabled();
+ static bool set_connect_backup_jobs_enabled(bool enabled);
+
+ void EnableConnectBackupJobs();
+
+ // ConnectJob::Delegate methods:
+ virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE;
+
+ // NetworkChangeNotifier::IPAddressObserver methods:
+ virtual void OnIPAddressChanged() OVERRIDE;
+
+ private:
+ friend class base::RefCounted<ClientSocketPoolBaseHelper>;
+
+ // Entry for a persistent socket which became idle at time |start_time|.
+ struct IdleSocket {
+ IdleSocket() : socket(NULL) {}
+
+ // An idle socket should be removed if it can't be reused, or has been idle
+ // for too long. |now| is the current time value (TimeTicks::Now()).
+ // |timeout| is the length of time to wait before timing out an idle socket.
+ //
+ // An idle socket can't be reused if it is disconnected or has received
+ // data unexpectedly (hence no longer idle). The unread data would be
+ // mistaken for the beginning of the next response if we were to reuse the
+ // socket for a new request.
+ bool ShouldCleanup(base::TimeTicks now, base::TimeDelta timeout) const;
+
+ StreamSocket* socket;
+ base::TimeTicks start_time;
+ };
+
+ typedef std::deque<const Request* > RequestQueue;
+ typedef std::map<const ClientSocketHandle*, const Request*> RequestMap;
+
+ // A Group is allocated per group_name when there are idle sockets or pending
+ // requests. Otherwise, the Group object is removed from the map.
+ // |active_socket_count| tracks the number of sockets held by clients.
+ class Group {
+ public:
+ Group();
+ ~Group();
+
+ bool IsEmpty() const {
+ return active_socket_count_ == 0 && idle_sockets_.empty() &&
+ jobs_.empty() && pending_requests_.empty();
+ }
+
+ bool HasAvailableSocketSlot(int max_sockets_per_group) const {
+ return NumActiveSocketSlots() < max_sockets_per_group;
+ }
+
+ int NumActiveSocketSlots() const {
+ return active_socket_count_ + static_cast<int>(jobs_.size()) +
+ static_cast<int>(idle_sockets_.size());
+ }
+
+ bool IsStalledOnPoolMaxSockets(int max_sockets_per_group) const {
+ return HasAvailableSocketSlot(max_sockets_per_group) &&
+ pending_requests_.size() > jobs_.size();
+ }
+
+ RequestPriority TopPendingPriority() const {
+ return pending_requests_.front()->priority();
+ }
+
+ bool HasBackupJob() const { return weak_factory_.HasWeakPtrs(); }
+
+ void CleanupBackupJob() {
+ weak_factory_.InvalidateWeakPtrs();
+ }
+
+ // Set a timer to create a backup socket if it takes too long to create one.
+ void StartBackupSocketTimer(const std::string& group_name,
+ ClientSocketPoolBaseHelper* pool);
+
+ // If there's a ConnectJob that's never been assigned to Request,
+ // decrements |unassigned_job_count_| and returns true.
+ // Otherwise, returns false.
+ bool TryToUseUnassignedConnectJob();
+
+ void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect);
+ // Remove |job| from this group, which must already own |job|.
+ void RemoveJob(ConnectJob* job);
+ void RemoveAllJobs();
+
+ void IncrementActiveSocketCount() { active_socket_count_++; }
+ void DecrementActiveSocketCount() { active_socket_count_--; }
+
+ int unassigned_job_count() const { return unassigned_job_count_; }
+ const std::set<ConnectJob*>& jobs() const { return jobs_; }
+ const std::list<IdleSocket>& idle_sockets() const { return idle_sockets_; }
+ const RequestQueue& pending_requests() const { return pending_requests_; }
+ int active_socket_count() const { return active_socket_count_; }
+ RequestQueue* mutable_pending_requests() { return &pending_requests_; }
+ std::list<IdleSocket>* mutable_idle_sockets() { return &idle_sockets_; }
+
+ private:
+ // Called when the backup socket timer fires.
+ void OnBackupSocketTimerFired(
+ std::string group_name,
+ ClientSocketPoolBaseHelper* pool);
+
+ // Checks that |unassigned_job_count_| does not execeed the number of
+ // ConnectJobs.
+ void SanityCheck();
+
+ // Total number of ConnectJobs that have never been assigned to a Request.
+ // Since jobs use late binding to requests, which ConnectJobs have or have
+ // not been assigned to a request are not tracked. This is incremented on
+ // preconnect and decremented when a preconnect is assigned, or when there
+ // are fewer than |unassigned_job_count_| ConnectJobs. Not incremented
+ // when a request is cancelled.
+ size_t unassigned_job_count_;
+
+ std::list<IdleSocket> idle_sockets_;
+ std::set<ConnectJob*> jobs_;
+ RequestQueue pending_requests_;
+ int active_socket_count_; // number of active sockets used by clients
+ // A factory to pin the backup_job tasks.
+ base::WeakPtrFactory<Group> weak_factory_;
+ };
+
+ typedef std::map<std::string, Group*> GroupMap;
+
+ typedef std::set<ConnectJob*> ConnectJobSet;
+
+ struct CallbackResultPair {
+ CallbackResultPair();
+ CallbackResultPair(const CompletionCallback& callback_in, int result_in);
+ ~CallbackResultPair();
+
+ CompletionCallback callback;
+ int result;
+ };
+
+ typedef std::map<const ClientSocketHandle*, CallbackResultPair>
+ PendingCallbackMap;
+
+ // Inserts the request into the queue based on order they will receive
+ // sockets. Sockets which ignore the socket pool limits are first. Then
+ // requests are sorted by priority, with higher priorities closer to the
+ // front. Older requests are prioritized over requests of equal priority.
+ static void InsertRequestIntoQueue(const Request* r,
+ RequestQueue* pending_requests);
+ static const Request* RemoveRequestFromQueue(const RequestQueue::iterator& it,
+ Group* group);
+
+ Group* GetOrCreateGroup(const std::string& group_name);
+ void RemoveGroup(const std::string& group_name);
+ void RemoveGroup(GroupMap::iterator it);
+
+ // Called when the number of idle sockets changes.
+ void IncrementIdleCount();
+ void DecrementIdleCount();
+
+ // Start cleanup timer for idle sockets.
+ void StartIdleSocketTimer();
+
+ // Scans the group map for groups which have an available socket slot and
+ // at least one pending request. Returns true if any groups are stalled, and
+ // if so (and if both |group| and |group_name| are not NULL), fills |group|
+ // and |group_name| with data of the stalled group having highest priority.
+ bool FindTopStalledGroup(Group** group, std::string* group_name) const;
+
+ // Called when timer_ fires. This method scans the idle sockets removing
+ // sockets that timed out or can't be reused.
+ void OnCleanupTimerFired() {
+ CleanupIdleSockets(false);
+ }
+
+ // Removes |job| from |group|, which must already own |job|.
+ void RemoveConnectJob(ConnectJob* job, Group* group);
+
+ // Tries to see if we can handle any more requests for |group|.
+ void OnAvailableSocketSlot(const std::string& group_name, Group* group);
+
+ // Process a pending socket request for a group.
+ void ProcessPendingRequest(const std::string& group_name, Group* group);
+
+ // Assigns |socket| to |handle| and updates |group|'s counters appropriately.
+ void HandOutSocket(scoped_ptr<StreamSocket> socket,
+ bool reused,
+ const LoadTimingInfo::ConnectTiming& connect_timing,
+ ClientSocketHandle* handle,
+ base::TimeDelta time_idle,
+ Group* group,
+ const BoundNetLog& net_log);
+
+ // Adds |socket| to the list of idle sockets for |group|.
+ void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group);
+
+ // Iterates through |group_map_|, canceling all ConnectJobs and deleting
+ // groups if they are no longer needed.
+ void CancelAllConnectJobs();
+
+ // Iterates through |group_map_|, posting |error| callbacks for all
+ // requests, and then deleting groups if they are no longer needed.
+ void CancelAllRequestsWithError(int error);
+
+ // Returns true if we can't create any more sockets due to the total limit.
+ bool ReachedMaxSocketsLimit() const;
+
+ // This is the internal implementation of RequestSocket(). It differs in that
+ // it does not handle logging into NetLog of the queueing status of
+ // |request|.
+ int RequestSocketInternal(const std::string& group_name,
+ const Request* request);
+
+ // Assigns an idle socket for the group to the request.
+ // Returns |true| if an idle socket is available, false otherwise.
+ bool AssignIdleSocketToRequest(const Request* request, Group* group);
+
+ static void LogBoundConnectJobToRequest(
+ const NetLog::Source& connect_job_source, const Request* request);
+
+ // Same as CloseOneIdleSocket() except it won't close an idle socket in
+ // |group|. If |group| is NULL, it is ignored. Returns true if it closed a
+ // socket.
+ bool CloseOneIdleSocketExceptInGroup(const Group* group);
+
+ // Checks if there are stalled socket groups that should be notified
+ // for possible wakeup.
+ void CheckForStalledSocketGroups();
+
+ // Posts a task to call InvokeUserCallback() on the next iteration through the
+ // current message loop. Inserts |callback| into |pending_callback_map_|,
+ // keyed by |handle|.
+ void InvokeUserCallbackLater(
+ ClientSocketHandle* handle, const CompletionCallback& callback, int rv);
+
+ // Invokes the user callback for |handle|. By the time this task has run,
+ // it's possible that the request has been cancelled, so |handle| may not
+ // exist in |pending_callback_map_|. We look up the callback and result code
+ // in |pending_callback_map_|.
+ void InvokeUserCallback(ClientSocketHandle* handle);
+
+ // Tries to close idle sockets in a higher level socket pool as long as this
+ // this pool is stalled.
+ void TryToCloseSocketsInLayeredPools();
+
+ GroupMap group_map_;
+
+ // Map of the ClientSocketHandles for which we have a pending Task to invoke a
+ // callback. This is necessary since, before we invoke said callback, it's
+ // possible that the request is cancelled.
+ PendingCallbackMap pending_callback_map_;
+
+ // Timer used to periodically prune idle sockets that timed out or can't be
+ // reused.
+ base::RepeatingTimer<ClientSocketPoolBaseHelper> timer_;
+
+ // The total number of idle sockets in the system.
+ int idle_socket_count_;
+
+ // Number of connecting sockets across all groups.
+ int connecting_socket_count_;
+
+ // Number of connected sockets we handed out across all groups.
+ int handed_out_socket_count_;
+
+ // The maximum total number of sockets. See ReachedMaxSocketsLimit.
+ const int max_sockets_;
+
+ // The maximum number of sockets kept per group.
+ const int max_sockets_per_group_;
+
+ // Whether to use timer to cleanup idle sockets.
+ bool use_cleanup_timer_;
+
+ // The time to wait until closing idle sockets.
+ const base::TimeDelta unused_idle_socket_timeout_;
+ const base::TimeDelta used_idle_socket_timeout_;
+
+ const scoped_ptr<ConnectJobFactory> connect_job_factory_;
+
+ // TODO(vandebo) Remove when backup jobs move to TransportClientSocketPool
+ bool connect_backup_jobs_enabled_;
+
+ // A unique id for the pool. It gets incremented every time we
+ // FlushWithError() the pool. This is so that when sockets get released back
+ // to the pool, we can make sure that they are discarded rather than reused.
+ int pool_generation_number_;
+
+ std::set<LayeredPool*> higher_layer_pools_;
+
+ base::WeakPtrFactory<ClientSocketPoolBaseHelper> weak_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBaseHelper);
+};
+
+} // namespace internal
+
+template <typename SocketParams>
+class ClientSocketPoolBase {
+ public:
+ class Request : public internal::ClientSocketPoolBaseHelper::Request {
+ public:
+ Request(ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ RequestPriority priority,
+ internal::ClientSocketPoolBaseHelper::Flags flags,
+ bool ignore_limits,
+ const scoped_refptr<SocketParams>& params,
+ const BoundNetLog& net_log)
+ : internal::ClientSocketPoolBaseHelper::Request(
+ handle, callback, priority, ignore_limits, flags, net_log),
+ params_(params) {}
+
+ const scoped_refptr<SocketParams>& params() const { return params_; }
+
+ private:
+ const scoped_refptr<SocketParams> params_;
+ };
+
+ class ConnectJobFactory {
+ public:
+ ConnectJobFactory() {}
+ virtual ~ConnectJobFactory() {}
+
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const Request& request,
+ ConnectJob::Delegate* delegate) const = 0;
+
+ virtual base::TimeDelta ConnectionTimeout() const = 0;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ConnectJobFactory);
+ };
+
+ // |max_sockets| is the maximum number of sockets to be maintained by this
+ // ClientSocketPool. |max_sockets_per_group| specifies the maximum number of
+ // sockets a "group" can have. |unused_idle_socket_timeout| specifies how
+ // long to leave an unused idle socket open before closing it.
+ // |used_idle_socket_timeout| specifies how long to leave a previously used
+ // idle socket open before closing it.
+ ClientSocketPoolBase(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ base::TimeDelta unused_idle_socket_timeout,
+ base::TimeDelta used_idle_socket_timeout,
+ ConnectJobFactory* connect_job_factory)
+ : histograms_(histograms),
+ helper_(max_sockets, max_sockets_per_group,
+ unused_idle_socket_timeout, used_idle_socket_timeout,
+ new ConnectJobFactoryAdaptor(connect_job_factory)) {}
+
+ virtual ~ClientSocketPoolBase() {}
+
+ // These member functions simply forward to ClientSocketPoolBaseHelper.
+ void AddLayeredPool(LayeredPool* pool) {
+ helper_.AddLayeredPool(pool);
+ }
+
+ void RemoveLayeredPool(LayeredPool* pool) {
+ helper_.RemoveLayeredPool(pool);
+ }
+
+ // RequestSocket bundles up the parameters into a Request and then forwards to
+ // ClientSocketPoolBaseHelper::RequestSocket().
+ int RequestSocket(const std::string& group_name,
+ const scoped_refptr<SocketParams>& params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) {
+ Request* request =
+ new Request(handle, callback, priority,
+ internal::ClientSocketPoolBaseHelper::NORMAL,
+ params->ignore_limits(),
+ params, net_log);
+ return helper_.RequestSocket(group_name, request);
+ }
+
+ // RequestSockets bundles up the parameters into a Request and then forwards
+ // to ClientSocketPoolBaseHelper::RequestSockets(). Note that it assigns the
+ // priority to DEFAULT_PRIORITY and specifies the NO_IDLE_SOCKETS flag.
+ void RequestSockets(const std::string& group_name,
+ const scoped_refptr<SocketParams>& params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
+ const Request request(NULL /* no handle */,
+ CompletionCallback(),
+ DEFAULT_PRIORITY,
+ internal::ClientSocketPoolBaseHelper::NO_IDLE_SOCKETS,
+ params->ignore_limits(),
+ params,
+ net_log);
+ helper_.RequestSockets(group_name, request, num_sockets);
+ }
+
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) {
+ return helper_.CancelRequest(group_name, handle);
+ }
+
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ return helper_.ReleaseSocket(group_name, socket.Pass(), id);
+ }
+
+ void FlushWithError(int error) { helper_.FlushWithError(error); }
+
+ bool IsStalled() const { return helper_.IsStalled(); }
+
+ void CloseIdleSockets() { return helper_.CloseIdleSockets(); }
+
+ int idle_socket_count() const { return helper_.idle_socket_count(); }
+
+ int IdleSocketCountInGroup(const std::string& group_name) const {
+ return helper_.IdleSocketCountInGroup(group_name);
+ }
+
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const {
+ return helper_.GetLoadState(group_name, handle);
+ }
+
+ virtual void OnConnectJobComplete(int result, ConnectJob* job) {
+ return helper_.OnConnectJobComplete(result, job);
+ }
+
+ int NumUnassignedConnectJobsInGroup(const std::string& group_name) const {
+ return helper_.NumUnassignedConnectJobsInGroup(group_name);
+ }
+
+ int NumConnectJobsInGroup(const std::string& group_name) const {
+ return helper_.NumConnectJobsInGroup(group_name);
+ }
+
+ int NumActiveSocketsInGroup(const std::string& group_name) const {
+ return helper_.NumActiveSocketsInGroup(group_name);
+ }
+
+ bool HasGroup(const std::string& group_name) const {
+ return helper_.HasGroup(group_name);
+ }
+
+ void CleanupIdleSockets(bool force) {
+ return helper_.CleanupIdleSockets(force);
+ }
+
+ base::DictionaryValue* GetInfoAsValue(const std::string& name,
+ const std::string& type) const {
+ return helper_.GetInfoAsValue(name, type);
+ }
+
+ base::TimeDelta ConnectionTimeout() const {
+ return helper_.ConnectionTimeout();
+ }
+
+ ClientSocketPoolHistograms* histograms() const {
+ return histograms_;
+ }
+
+ void EnableConnectBackupJobs() { helper_.EnableConnectBackupJobs(); }
+
+ bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); }
+
+ bool CloseOneIdleConnectionInLayeredPool() {
+ return helper_.CloseOneIdleConnectionInLayeredPool();
+ }
+
+ private:
+ // This adaptor class exists to bridge the
+ // internal::ClientSocketPoolBaseHelper::ConnectJobFactory and
+ // ClientSocketPoolBase::ConnectJobFactory types, allowing clients to use the
+ // typesafe ClientSocketPoolBase::ConnectJobFactory, rather than having to
+ // static_cast themselves.
+ class ConnectJobFactoryAdaptor
+ : public internal::ClientSocketPoolBaseHelper::ConnectJobFactory {
+ public:
+ typedef typename ClientSocketPoolBase<SocketParams>::ConnectJobFactory
+ ConnectJobFactory;
+
+ explicit ConnectJobFactoryAdaptor(ConnectJobFactory* connect_job_factory)
+ : connect_job_factory_(connect_job_factory) {}
+ virtual ~ConnectJobFactoryAdaptor() {}
+
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const internal::ClientSocketPoolBaseHelper::Request& request,
+ ConnectJob::Delegate* delegate) const OVERRIDE {
+ const Request& casted_request = static_cast<const Request&>(request);
+ return connect_job_factory_->NewConnectJob(
+ group_name, casted_request, delegate);
+ }
+
+ virtual base::TimeDelta ConnectionTimeout() const {
+ return connect_job_factory_->ConnectionTimeout();
+ }
+
+ const scoped_ptr<ConnectJobFactory> connect_job_factory_;
+ };
+
+ // Histograms for the pool
+ ClientSocketPoolHistograms* const histograms_;
+ internal::ClientSocketPoolBaseHelper helper_;
+
+ DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBase);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_POOL_BASE_H_
diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc
new file mode 100644
index 00000000000..6688e01244d
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_base_unittest.cc
@@ -0,0 +1,4168 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_pool_base.h"
+
+#include <vector>
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/callback.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_vector.h"
+#include "base/memory/weak_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/run_loop.h"
+#include "base/strings/string_number_conversions.h"
+#include "base/strings/stringprintf.h"
+#include "base/threading/platform_thread.h"
+#include "base/values.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/load_timing_info_test_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_log_unittest.h"
+#include "net/base/request_priority.h"
+#include "net/base/test_completion_callback.h"
+#include "net/http/http_response_headers.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using ::testing::Invoke;
+using ::testing::Return;
+
+namespace net {
+
+namespace {
+
+const int kDefaultMaxSockets = 4;
+const int kDefaultMaxSocketsPerGroup = 2;
+const net::RequestPriority kDefaultPriority = MEDIUM;
+
+// Make sure |handle| sets load times correctly when it has been assigned a
+// reused socket.
+void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
+ LoadTimingInfo load_timing_info;
+ // Only pass true in as |is_reused|, as in general, HttpStream types should
+ // have stricter concepts of reuse than socket pools.
+ EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
+
+ EXPECT_EQ(true, load_timing_info.socket_reused);
+ EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+// Make sure |handle| sets load times correctly when it has been assigned a
+// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner
+// of a connection where |is_reused| is false may consider the connection
+// reused.
+void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
+ EXPECT_FALSE(handle.is_reused());
+
+ LoadTimingInfo load_timing_info;
+ EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+ EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
+ CONNECT_TIMING_HAS_CONNECT_TIMES_ONLY);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+
+ TestLoadTimingInfoConnectedReused(handle);
+}
+
+// Make sure |handle| sets load times correctly, in the case that it does not
+// currently have a socket.
+void TestLoadTimingInfoNotConnected(const ClientSocketHandle& handle) {
+ // Should only be set to true once a socket is assigned, if at all.
+ EXPECT_FALSE(handle.is_reused());
+
+ LoadTimingInfo load_timing_info;
+ EXPECT_FALSE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+ EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+class TestSocketParams : public base::RefCounted<TestSocketParams> {
+ public:
+ TestSocketParams() : ignore_limits_(false) {}
+
+ void set_ignore_limits(bool ignore_limits) {
+ ignore_limits_ = ignore_limits;
+ }
+ bool ignore_limits() { return ignore_limits_; }
+
+ private:
+ friend class base::RefCounted<TestSocketParams>;
+ ~TestSocketParams() {}
+
+ bool ignore_limits_;
+};
+typedef ClientSocketPoolBase<TestSocketParams> TestClientSocketPoolBase;
+
+class MockClientSocket : public StreamSocket {
+ public:
+ explicit MockClientSocket(net::NetLog* net_log)
+ : connected_(false),
+ net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_SOCKET)),
+ was_used_to_convey_data_(false) {
+ }
+
+ // Socket implementation.
+ virtual int Read(
+ IOBuffer* /* buf */, int len,
+ const CompletionCallback& /* callback */) OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+
+ virtual int Write(
+ IOBuffer* /* buf */, int len,
+ const CompletionCallback& /* callback */) OVERRIDE {
+ was_used_to_convey_data_ = true;
+ return len;
+ }
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; }
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ connected_ = true;
+ return OK;
+ }
+
+ virtual void Disconnect() OVERRIDE { connected_ = false; }
+ virtual bool IsConnected() const OVERRIDE { return connected_; }
+ virtual bool IsConnectedAndIdle() const OVERRIDE { return connected_; }
+
+ virtual int GetPeerAddress(IPEndPoint* /* address */) const OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+
+ virtual int GetLocalAddress(IPEndPoint* /* address */) const OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return net_log_;
+ }
+
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+ virtual bool WasEverUsed() const OVERRIDE {
+ return was_used_to_convey_data_;
+ }
+ virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ return false;
+ }
+
+ private:
+ bool connected_;
+ BoundNetLog net_log_;
+ bool was_used_to_convey_data_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockClientSocket);
+};
+
+class TestConnectJob;
+
+class MockClientSocketFactory : public ClientSocketFactory {
+ public:
+ MockClientSocketFactory() : allocation_count_(0) {}
+
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE {
+ NOTREACHED();
+ return scoped_ptr<DatagramClientSocket>();
+ }
+
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* /* net_log */,
+ const NetLog::Source& /*source*/) OVERRIDE {
+ allocation_count_++;
+ return scoped_ptr<StreamSocket>();
+ }
+
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) OVERRIDE {
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLClientSocket>();
+ }
+
+ virtual void ClearSSLSessionCache() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+
+ void WaitForSignal(TestConnectJob* job) { waiting_jobs_.push_back(job); }
+
+ void SignalJobs();
+
+ void SignalJob(size_t job);
+
+ void SetJobLoadState(size_t job, LoadState load_state);
+
+ int allocation_count() const { return allocation_count_; }
+
+ private:
+ int allocation_count_;
+ std::vector<TestConnectJob*> waiting_jobs_;
+};
+
+class TestConnectJob : public ConnectJob {
+ public:
+ enum JobType {
+ kMockJob,
+ kMockFailingJob,
+ kMockPendingJob,
+ kMockPendingFailingJob,
+ kMockWaitingJob,
+ kMockRecoverableJob,
+ kMockPendingRecoverableJob,
+ kMockAdditionalErrorStateJob,
+ kMockPendingAdditionalErrorStateJob,
+ };
+
+ // The kMockPendingJob uses a slight delay before allowing the connect
+ // to complete.
+ static const int kPendingConnectDelay = 2;
+
+ TestConnectJob(JobType job_type,
+ const std::string& group_name,
+ const TestClientSocketPoolBase::Request& request,
+ base::TimeDelta timeout_duration,
+ ConnectJob::Delegate* delegate,
+ MockClientSocketFactory* client_socket_factory,
+ NetLog* net_log)
+ : ConnectJob(group_name, timeout_duration, delegate,
+ BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
+ job_type_(job_type),
+ client_socket_factory_(client_socket_factory),
+ weak_factory_(this),
+ load_state_(LOAD_STATE_IDLE),
+ store_additional_error_state_(false) {}
+
+ void Signal() {
+ DoConnect(waiting_success_, true /* async */, false /* recoverable */);
+ }
+
+ void set_load_state(LoadState load_state) { load_state_ = load_state; }
+
+ // From ConnectJob:
+
+ virtual LoadState GetLoadState() const OVERRIDE { return load_state_; }
+
+ virtual void GetAdditionalErrorState(ClientSocketHandle* handle) OVERRIDE {
+ if (store_additional_error_state_) {
+ // Set all of the additional error state fields in some way.
+ handle->set_is_ssl_error(true);
+ HttpResponseInfo info;
+ info.headers = new HttpResponseHeaders(std::string());
+ handle->set_ssl_error_response_info(info);
+ }
+ }
+
+ private:
+ // From ConnectJob:
+
+ virtual int ConnectInternal() OVERRIDE {
+ AddressList ignored;
+ client_socket_factory_->CreateTransportClientSocket(
+ ignored, NULL, net::NetLog::Source());
+ SetSocket(
+ scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log())));
+ switch (job_type_) {
+ case kMockJob:
+ return DoConnect(true /* successful */, false /* sync */,
+ false /* recoverable */);
+ case kMockFailingJob:
+ return DoConnect(false /* error */, false /* sync */,
+ false /* recoverable */);
+ case kMockPendingJob:
+ set_load_state(LOAD_STATE_CONNECTING);
+
+ // Depending on execution timings, posting a delayed task can result
+ // in the task getting executed the at the earliest possible
+ // opportunity or only after returning once from the message loop and
+ // then a second call into the message loop. In order to make behavior
+ // more deterministic, we change the default delay to 2ms. This should
+ // always require us to wait for the second call into the message loop.
+ //
+ // N.B. The correct fix for this and similar timing problems is to
+ // abstract time for the purpose of unittests. Unfortunately, we have
+ // a lot of third-party components that directly call the various
+ // time functions, so this change would be rather invasive.
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect),
+ weak_factory_.GetWeakPtr(),
+ true /* successful */,
+ true /* async */,
+ false /* recoverable */),
+ base::TimeDelta::FromMilliseconds(kPendingConnectDelay));
+ return ERR_IO_PENDING;
+ case kMockPendingFailingJob:
+ set_load_state(LOAD_STATE_CONNECTING);
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect),
+ weak_factory_.GetWeakPtr(),
+ false /* error */,
+ true /* async */,
+ false /* recoverable */),
+ base::TimeDelta::FromMilliseconds(2));
+ return ERR_IO_PENDING;
+ case kMockWaitingJob:
+ set_load_state(LOAD_STATE_CONNECTING);
+ client_socket_factory_->WaitForSignal(this);
+ waiting_success_ = true;
+ return ERR_IO_PENDING;
+ case kMockRecoverableJob:
+ return DoConnect(false /* error */, false /* sync */,
+ true /* recoverable */);
+ case kMockPendingRecoverableJob:
+ set_load_state(LOAD_STATE_CONNECTING);
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect),
+ weak_factory_.GetWeakPtr(),
+ false /* error */,
+ true /* async */,
+ true /* recoverable */),
+ base::TimeDelta::FromMilliseconds(2));
+ return ERR_IO_PENDING;
+ case kMockAdditionalErrorStateJob:
+ store_additional_error_state_ = true;
+ return DoConnect(false /* error */, false /* sync */,
+ false /* recoverable */);
+ case kMockPendingAdditionalErrorStateJob:
+ set_load_state(LOAD_STATE_CONNECTING);
+ store_additional_error_state_ = true;
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(base::IgnoreResult(&TestConnectJob::DoConnect),
+ weak_factory_.GetWeakPtr(),
+ false /* error */,
+ true /* async */,
+ false /* recoverable */),
+ base::TimeDelta::FromMilliseconds(2));
+ return ERR_IO_PENDING;
+ default:
+ NOTREACHED();
+ SetSocket(scoped_ptr<StreamSocket>());
+ return ERR_FAILED;
+ }
+ }
+
+ int DoConnect(bool succeed, bool was_async, bool recoverable) {
+ int result = OK;
+ if (succeed) {
+ socket()->Connect(CompletionCallback());
+ } else if (recoverable) {
+ result = ERR_PROXY_AUTH_REQUESTED;
+ } else {
+ result = ERR_CONNECTION_FAILED;
+ SetSocket(scoped_ptr<StreamSocket>());
+ }
+
+ if (was_async)
+ NotifyDelegateOfCompletion(result);
+ return result;
+ }
+
+ bool waiting_success_;
+ const JobType job_type_;
+ MockClientSocketFactory* const client_socket_factory_;
+ base::WeakPtrFactory<TestConnectJob> weak_factory_;
+ LoadState load_state_;
+ bool store_additional_error_state_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestConnectJob);
+};
+
+class TestConnectJobFactory
+ : public TestClientSocketPoolBase::ConnectJobFactory {
+ public:
+ TestConnectJobFactory(MockClientSocketFactory* client_socket_factory,
+ NetLog* net_log)
+ : job_type_(TestConnectJob::kMockJob),
+ job_types_(NULL),
+ client_socket_factory_(client_socket_factory),
+ net_log_(net_log) {
+ }
+
+ virtual ~TestConnectJobFactory() {}
+
+ void set_job_type(TestConnectJob::JobType job_type) { job_type_ = job_type; }
+
+ void set_job_types(std::list<TestConnectJob::JobType>* job_types) {
+ job_types_ = job_types;
+ CHECK(!job_types_->empty());
+ }
+
+ void set_timeout_duration(base::TimeDelta timeout_duration) {
+ timeout_duration_ = timeout_duration;
+ }
+
+ // ConnectJobFactory implementation.
+
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const TestClientSocketPoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const OVERRIDE {
+ EXPECT_TRUE(!job_types_ || !job_types_->empty());
+ TestConnectJob::JobType job_type = job_type_;
+ if (job_types_ && !job_types_->empty()) {
+ job_type = job_types_->front();
+ job_types_->pop_front();
+ }
+ return scoped_ptr<ConnectJob>(new TestConnectJob(job_type,
+ group_name,
+ request,
+ timeout_duration_,
+ delegate,
+ client_socket_factory_,
+ net_log_));
+ }
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE {
+ return timeout_duration_;
+ }
+
+ private:
+ TestConnectJob::JobType job_type_;
+ std::list<TestConnectJob::JobType>* job_types_;
+ base::TimeDelta timeout_duration_;
+ MockClientSocketFactory* const client_socket_factory_;
+ NetLog* net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestConnectJobFactory);
+};
+
+class TestClientSocketPool : public ClientSocketPool {
+ public:
+ TestClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ base::TimeDelta unused_idle_socket_timeout,
+ base::TimeDelta used_idle_socket_timeout,
+ TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory)
+ : base_(max_sockets, max_sockets_per_group, histograms,
+ unused_idle_socket_timeout, used_idle_socket_timeout,
+ connect_job_factory) {}
+
+ virtual ~TestClientSocketPool() {}
+
+ virtual int RequestSocket(
+ const std::string& group_name,
+ const void* params,
+ net::RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) OVERRIDE {
+ const scoped_refptr<TestSocketParams>* casted_socket_params =
+ static_cast<const scoped_refptr<TestSocketParams>*>(params);
+ return base_.RequestSocket(group_name, *casted_socket_params, priority,
+ handle, callback, net_log);
+ }
+
+ virtual void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) OVERRIDE {
+ const scoped_refptr<TestSocketParams>* casted_params =
+ static_cast<const scoped_refptr<TestSocketParams>*>(params);
+
+ base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
+ }
+
+ virtual void CancelRequest(
+ const std::string& group_name,
+ ClientSocketHandle* handle) OVERRIDE {
+ base_.CancelRequest(group_name, handle);
+ }
+
+ virtual void ReleaseSocket(
+ const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
+ }
+
+ virtual void FlushWithError(int error) OVERRIDE {
+ base_.FlushWithError(error);
+ }
+
+ virtual bool IsStalled() const OVERRIDE {
+ return base_.IsStalled();
+ }
+
+ virtual void CloseIdleSockets() OVERRIDE {
+ base_.CloseIdleSockets();
+ }
+
+ virtual int IdleSocketCount() const OVERRIDE {
+ return base_.idle_socket_count();
+ }
+
+ virtual int IdleSocketCountInGroup(
+ const std::string& group_name) const OVERRIDE {
+ return base_.IdleSocketCountInGroup(group_name);
+ }
+
+ virtual LoadState GetLoadState(
+ const std::string& group_name,
+ const ClientSocketHandle* handle) const OVERRIDE {
+ return base_.GetLoadState(group_name, handle);
+ }
+
+ virtual void AddLayeredPool(LayeredPool* pool) OVERRIDE {
+ base_.AddLayeredPool(pool);
+ }
+
+ virtual void RemoveLayeredPool(LayeredPool* pool) OVERRIDE {
+ base_.RemoveLayeredPool(pool);
+ }
+
+ virtual base::DictionaryValue* GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const OVERRIDE {
+ return base_.GetInfoAsValue(name, type);
+ }
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE {
+ return base_.ConnectionTimeout();
+ }
+
+ virtual ClientSocketPoolHistograms* histograms() const OVERRIDE {
+ return base_.histograms();
+ }
+
+ const TestClientSocketPoolBase* base() const { return &base_; }
+
+ int NumUnassignedConnectJobsInGroup(const std::string& group_name) const {
+ return base_.NumUnassignedConnectJobsInGroup(group_name);
+ }
+
+ int NumConnectJobsInGroup(const std::string& group_name) const {
+ return base_.NumConnectJobsInGroup(group_name);
+ }
+
+ int NumActiveSocketsInGroup(const std::string& group_name) const {
+ return base_.NumActiveSocketsInGroup(group_name);
+ }
+
+ bool HasGroup(const std::string& group_name) const {
+ return base_.HasGroup(group_name);
+ }
+
+ void CleanupTimedOutIdleSockets() { base_.CleanupIdleSockets(false); }
+
+ void EnableConnectBackupJobs() { base_.EnableConnectBackupJobs(); }
+
+ bool CloseOneIdleConnectionInLayeredPool() {
+ return base_.CloseOneIdleConnectionInLayeredPool();
+ }
+
+ private:
+ TestClientSocketPoolBase base_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestClientSocketPool);
+};
+
+} // namespace
+
+REGISTER_SOCKET_PARAMS_FOR_POOL(TestClientSocketPool, TestSocketParams);
+
+namespace {
+
+void MockClientSocketFactory::SignalJobs() {
+ for (std::vector<TestConnectJob*>::iterator it = waiting_jobs_.begin();
+ it != waiting_jobs_.end(); ++it) {
+ (*it)->Signal();
+ }
+ waiting_jobs_.clear();
+}
+
+void MockClientSocketFactory::SignalJob(size_t job) {
+ ASSERT_LT(job, waiting_jobs_.size());
+ waiting_jobs_[job]->Signal();
+ waiting_jobs_.erase(waiting_jobs_.begin() + job);
+}
+
+void MockClientSocketFactory::SetJobLoadState(size_t job,
+ LoadState load_state) {
+ ASSERT_LT(job, waiting_jobs_.size());
+ waiting_jobs_[job]->set_load_state(load_state);
+}
+
+class TestConnectJobDelegate : public ConnectJob::Delegate {
+ public:
+ TestConnectJobDelegate()
+ : have_result_(false), waiting_for_result_(false), result_(OK) {}
+ virtual ~TestConnectJobDelegate() {}
+
+ virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE {
+ result_ = result;
+ scoped_ptr<ConnectJob> owned_job(job);
+ scoped_ptr<StreamSocket> socket = owned_job->PassSocket();
+ // socket.get() should be NULL iff result != OK
+ EXPECT_EQ(socket == NULL, result != OK);
+ have_result_ = true;
+ if (waiting_for_result_)
+ base::MessageLoop::current()->Quit();
+ }
+
+ int WaitForResult() {
+ DCHECK(!waiting_for_result_);
+ while (!have_result_) {
+ waiting_for_result_ = true;
+ base::MessageLoop::current()->Run();
+ waiting_for_result_ = false;
+ }
+ have_result_ = false; // auto-reset for next callback
+ return result_;
+ }
+
+ private:
+ bool have_result_;
+ bool waiting_for_result_;
+ int result_;
+};
+
+class ClientSocketPoolBaseTest : public testing::Test {
+ protected:
+ ClientSocketPoolBaseTest()
+ : params_(new TestSocketParams()),
+ histograms_("ClientSocketPoolTest") {
+ connect_backup_jobs_enabled_ =
+ internal::ClientSocketPoolBaseHelper::connect_backup_jobs_enabled();
+ internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true);
+ cleanup_timer_enabled_ =
+ internal::ClientSocketPoolBaseHelper::cleanup_timer_enabled();
+ }
+
+ virtual ~ClientSocketPoolBaseTest() {
+ internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(
+ connect_backup_jobs_enabled_);
+ internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(
+ cleanup_timer_enabled_);
+ }
+
+ void CreatePool(int max_sockets, int max_sockets_per_group) {
+ CreatePoolWithIdleTimeouts(
+ max_sockets,
+ max_sockets_per_group,
+ ClientSocketPool::unused_idle_socket_timeout(),
+ ClientSocketPool::used_idle_socket_timeout());
+ }
+
+ void CreatePoolWithIdleTimeouts(
+ int max_sockets, int max_sockets_per_group,
+ base::TimeDelta unused_idle_socket_timeout,
+ base::TimeDelta used_idle_socket_timeout) {
+ DCHECK(!pool_.get());
+ connect_job_factory_ = new TestConnectJobFactory(&client_socket_factory_,
+ &net_log_);
+ pool_.reset(new TestClientSocketPool(max_sockets,
+ max_sockets_per_group,
+ &histograms_,
+ unused_idle_socket_timeout,
+ used_idle_socket_timeout,
+ connect_job_factory_));
+ }
+
+ int StartRequestWithParams(
+ const std::string& group_name,
+ RequestPriority priority,
+ const scoped_refptr<TestSocketParams>& params) {
+ return test_base_.StartRequestUsingPool<
+ TestClientSocketPool, TestSocketParams>(
+ pool_.get(), group_name, priority, params);
+ }
+
+ int StartRequest(const std::string& group_name, RequestPriority priority) {
+ return StartRequestWithParams(group_name, priority, params_);
+ }
+
+ int GetOrderOfRequest(size_t index) const {
+ return test_base_.GetOrderOfRequest(index);
+ }
+
+ bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) {
+ return test_base_.ReleaseOneConnection(keep_alive);
+ }
+
+ void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) {
+ test_base_.ReleaseAllConnections(keep_alive);
+ }
+
+ TestSocketRequest* request(int i) { return test_base_.request(i); }
+ size_t requests_size() const { return test_base_.requests_size(); }
+ ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }
+ size_t completion_count() const { return test_base_.completion_count(); }
+
+ CapturingNetLog net_log_;
+ bool connect_backup_jobs_enabled_;
+ bool cleanup_timer_enabled_;
+ MockClientSocketFactory client_socket_factory_;
+ TestConnectJobFactory* connect_job_factory_;
+ scoped_refptr<TestSocketParams> params_;
+ ClientSocketPoolHistograms histograms_;
+ scoped_ptr<TestClientSocketPool> pool_;
+ ClientSocketPoolTest test_base_;
+};
+
+// Even though a timeout is specified, it doesn't time out on a synchronous
+// completion.
+TEST_F(ClientSocketPoolBaseTest, ConnectJob_NoTimeoutOnSynchronousCompletion) {
+ TestConnectJobDelegate delegate;
+ ClientSocketHandle ignored;
+ TestClientSocketPoolBase::Request request(
+ &ignored, CompletionCallback(), kDefaultPriority,
+ internal::ClientSocketPoolBaseHelper::NORMAL,
+ false, params_, BoundNetLog());
+ scoped_ptr<TestConnectJob> job(
+ new TestConnectJob(TestConnectJob::kMockJob,
+ "a",
+ request,
+ base::TimeDelta::FromMicroseconds(1),
+ &delegate,
+ &client_socket_factory_,
+ NULL));
+ EXPECT_EQ(OK, job->Connect());
+}
+
+TEST_F(ClientSocketPoolBaseTest, ConnectJob_TimedOut) {
+ TestConnectJobDelegate delegate;
+ ClientSocketHandle ignored;
+ CapturingNetLog log;
+
+ TestClientSocketPoolBase::Request request(
+ &ignored, CompletionCallback(), kDefaultPriority,
+ internal::ClientSocketPoolBaseHelper::NORMAL,
+ false, params_, BoundNetLog());
+ // Deleted by TestConnectJobDelegate.
+ TestConnectJob* job =
+ new TestConnectJob(TestConnectJob::kMockPendingJob,
+ "a",
+ request,
+ base::TimeDelta::FromMicroseconds(1),
+ &delegate,
+ &client_socket_factory_,
+ &log);
+ ASSERT_EQ(ERR_IO_PENDING, job->Connect());
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(1));
+ EXPECT_EQ(ERR_TIMED_OUT, delegate.WaitForResult());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+
+ EXPECT_EQ(6u, entries.size());
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB));
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 2, NetLog::TYPE_CONNECT_JOB_SET_SOCKET,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 3, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, 4, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_CONNECT));
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, 5, NetLog::TYPE_SOCKET_POOL_CONNECT_JOB));
+}
+
+TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ CapturingBoundNetLog log;
+ TestLoadTimingInfoNotConnected(handle);
+
+ EXPECT_EQ(OK,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ log.bound()));
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoConnectedNotReused(handle);
+
+ handle.Reset();
+ TestLoadTimingInfoNotConnected(handle);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+
+ EXPECT_EQ(4u, entries.size());
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKET_POOL));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 2, NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, 3, NetLog::TYPE_SOCKET_POOL));
+}
+
+TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob);
+ CapturingBoundNetLog log;
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ // Set the additional error state members to ensure that they get cleared.
+ handle.set_is_ssl_error(true);
+ HttpResponseInfo info;
+ info.headers = new HttpResponseHeaders(std::string());
+ handle.set_ssl_error_response_info(info);
+ EXPECT_EQ(ERR_CONNECTION_FAILED,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ log.bound()));
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+ EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL);
+ TestLoadTimingInfoNotConnected(handle);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+
+ EXPECT_EQ(3u, entries.size());
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKET_POOL));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, 2, NetLog::TYPE_SOCKET_POOL));
+}
+
+TEST_F(ClientSocketPoolBaseTest, TotalLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ // TODO(eroman): Check that the NetLog contains this event.
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("c", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("d", kDefaultPriority));
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count());
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("f", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("g", kDefaultPriority));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+ EXPECT_EQ(6, GetOrderOfRequest(6));
+ EXPECT_EQ(7, GetOrderOfRequest(7));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(8));
+}
+
+TEST_F(ClientSocketPoolBaseTest, TotalLimitReachedNewGroup) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ // TODO(eroman): Check that the NetLog contains this event.
+
+ // Reach all limits: max total sockets, and max sockets per group.
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count());
+
+ // Now create a new group and verify that we don't starve it.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kDefaultPriority));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(6));
+}
+
+TEST_F(ClientSocketPoolBaseTest, TotalLimitRespectsPriority) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("b", LOWEST));
+ EXPECT_EQ(OK, StartRequest("a", MEDIUM));
+ EXPECT_EQ(OK, StartRequest("b", HIGHEST));
+ EXPECT_EQ(OK, StartRequest("a", LOWEST));
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", HIGHEST));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count());
+
+ // First 4 requests don't have to wait, and finish in order.
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+
+ // Request ("b", HIGHEST) has the highest priority, then ("a", MEDIUM),
+ // and then ("c", LOWEST).
+ EXPECT_EQ(7, GetOrderOfRequest(5));
+ EXPECT_EQ(6, GetOrderOfRequest(6));
+ EXPECT_EQ(5, GetOrderOfRequest(7));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(9));
+}
+
+TEST_F(ClientSocketPoolBaseTest, TotalLimitRespectsGroupLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("a", LOWEST));
+ EXPECT_EQ(OK, StartRequest("a", LOW));
+ EXPECT_EQ(OK, StartRequest("b", HIGHEST));
+ EXPECT_EQ(OK, StartRequest("b", MEDIUM));
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", HIGHEST));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSockets, completion_count());
+
+ // First 4 requests don't have to wait, and finish in order.
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+
+ // Request ("b", 7) has the highest priority, but we can't make new socket for
+ // group "b", because it has reached the per-group limit. Then we make
+ // socket for ("c", 6), because it has higher priority than ("a", 4),
+ // and we still can't make a socket for group "b".
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+ EXPECT_EQ(6, GetOrderOfRequest(6));
+ EXPECT_EQ(7, GetOrderOfRequest(7));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(8));
+}
+
+// Make sure that we count connecting sockets against the total limit.
+TEST_F(ClientSocketPoolBaseTest, TotalLimitCountsConnectingSockets) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("c", kDefaultPriority));
+
+ // Create one asynchronous request.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("d", kDefaultPriority));
+
+ // We post all of our delayed tasks with a 2ms delay. I.e. they don't
+ // actually become pending until 2ms after they have been created. In order
+ // to flush all tasks, we need to wait so that we know there are no
+ // soon-to-be-pending tasks waiting.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10));
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // The next synchronous request should wait for its turn.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("e", kDefaultPriority));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(6));
+}
+
+TEST_F(ClientSocketPoolBaseTest, CorrectlyCountStalledGroups) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("b", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kDefaultPriority));
+
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE));
+ EXPECT_EQ(kDefaultMaxSockets + 1, client_socket_factory_.allocation_count());
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE));
+ EXPECT_EQ(kDefaultMaxSockets + 2, client_socket_factory_.allocation_count());
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE));
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE));
+ EXPECT_EQ(kDefaultMaxSockets + 2, client_socket_factory_.allocation_count());
+}
+
+TEST_F(ClientSocketPoolBaseTest, StallAndThenCancelAndTriggerAvailableSocket) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ClientSocketHandle handles[4];
+ for (size_t i = 0; i < arraysize(handles); ++i) {
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handles[i].Init("b",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ // One will be stalled, cancel all the handles now.
+ // This should hit the OnAvailableSocketSlot() code where we previously had
+ // stalled groups, but no longer have any.
+ for (size_t i = 0; i < arraysize(handles); ++i)
+ handles[i].Reset();
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelStalledSocketAtSocketLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ {
+ ClientSocketHandle handles[kDefaultMaxSockets];
+ TestCompletionCallback callbacks[kDefaultMaxSockets];
+ for (int i = 0; i < kDefaultMaxSockets; ++i) {
+ EXPECT_EQ(OK, handles[i].Init(base::IntToString(i),
+ params_,
+ kDefaultPriority,
+ callbacks[i].callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ // Force a stalled group.
+ ClientSocketHandle stalled_handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ // Cancel the stalled request.
+ stalled_handle.Reset();
+
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+
+ // Dropping out of scope will close all handles and return them to idle.
+ }
+
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+ EXPECT_EQ(kDefaultMaxSockets, pool_->IdleSocketCount());
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelPendingSocketAtSocketLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ {
+ ClientSocketHandle handles[kDefaultMaxSockets];
+ for (int i = 0; i < kDefaultMaxSockets; ++i) {
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handles[i].Init(base::IntToString(i),
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ // Force a stalled group.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle stalled_handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ // Since it is stalled, it should have no connect jobs.
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("foo"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("foo"));
+
+ // Cancel the stalled request.
+ handles[0].Reset();
+
+ // Now we should have a connect job.
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("foo"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("foo"));
+
+ // The stalled socket should connect.
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ EXPECT_EQ(kDefaultMaxSockets + 1,
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("foo"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("foo"));
+
+ // Dropping out of scope will close all handles and return them to idle.
+ }
+
+ EXPECT_EQ(1, pool_->IdleSocketCount());
+}
+
+TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ ClientSocketHandle stalled_handle;
+ TestCompletionCallback callback;
+ {
+ EXPECT_FALSE(pool_->IsStalled());
+ ClientSocketHandle handles[kDefaultMaxSockets];
+ for (int i = 0; i < kDefaultMaxSockets; ++i) {
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handles[i].Init(base::StringPrintf(
+ "Take 2: %d", i),
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_FALSE(pool_->IsStalled());
+
+ // Now we will hit the socket limit.
+ EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_TRUE(pool_->IsStalled());
+
+ // Dropping out of scope will close all handles and return them to idle.
+ }
+
+ // But if we wait for it, the released idle sockets will be closed in
+ // preference of the waiting request.
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ EXPECT_EQ(kDefaultMaxSockets + 1, client_socket_factory_.allocation_count());
+ EXPECT_EQ(3, pool_->IdleSocketCount());
+}
+
+// Regression test for http://crbug.com/40952.
+TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ pool_->EnableConnectBackupJobs();
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ for (int i = 0; i < kDefaultMaxSockets; ++i) {
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handle.Init(base::IntToString(i),
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ // Flush all the DoReleaseSocket tasks.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Stall a group. Set a pending job so it'll trigger a backup job if we don't
+ // reuse a socket.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+
+ // "0" is special here, since it should be the first entry in the sorted map,
+ // which is the one which we would close an idle socket for. We shouldn't
+ // close an idle socket though, since we should reuse the idle socket.
+ EXPECT_EQ(OK, handle.Init("0",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+ EXPECT_EQ(kDefaultMaxSockets - 1, pool_->IdleSocketCount());
+}
+
+TEST_F(ClientSocketPoolBaseTest, PendingRequests) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", IDLE));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+
+ ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE);
+
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup,
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup,
+ completion_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(8, GetOrderOfRequest(3));
+ EXPECT_EQ(6, GetOrderOfRequest(4));
+ EXPECT_EQ(4, GetOrderOfRequest(5));
+ EXPECT_EQ(3, GetOrderOfRequest(6));
+ EXPECT_EQ(5, GetOrderOfRequest(7));
+ EXPECT_EQ(7, GetOrderOfRequest(8));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(9));
+}
+
+TEST_F(ClientSocketPoolBaseTest, PendingRequests_NoKeepAlive) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ for (size_t i = kDefaultMaxSocketsPerGroup; i < requests_size(); ++i)
+ EXPECT_EQ(OK, request(i)->WaitForResult());
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup,
+ completion_count());
+}
+
+// This test will start up a RequestSocket() and then immediately Cancel() it.
+// The pending connect job will be cancelled and should not call back into
+// ClientSocketPoolBase.
+TEST_F(ClientSocketPoolBaseTest, CancelRequestClearGroup) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ handle.Reset();
+}
+
+TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ handle.Reset();
+
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ EXPECT_FALSE(callback.have_result());
+
+ handle.Reset();
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelRequest) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+
+ // Cancel a request.
+ size_t index_to_cancel = kDefaultMaxSocketsPerGroup + 2;
+ EXPECT_FALSE((*requests())[index_to_cancel]->handle()->is_initialized());
+ (*requests())[index_to_cancel]->handle()->Reset();
+
+ ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE);
+
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup,
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup - 1,
+ completion_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(5, GetOrderOfRequest(3));
+ EXPECT_EQ(3, GetOrderOfRequest(4));
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound,
+ GetOrderOfRequest(5)); // Canceled request.
+ EXPECT_EQ(4, GetOrderOfRequest(6));
+ EXPECT_EQ(6, GetOrderOfRequest(7));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(8));
+}
+
+class RequestSocketCallback : public TestCompletionCallbackBase {
+ public:
+ RequestSocketCallback(ClientSocketHandle* handle,
+ TestClientSocketPool* pool,
+ TestConnectJobFactory* test_connect_job_factory,
+ TestConnectJob::JobType next_job_type)
+ : handle_(handle),
+ pool_(pool),
+ within_callback_(false),
+ test_connect_job_factory_(test_connect_job_factory),
+ next_job_type_(next_job_type),
+ callback_(base::Bind(&RequestSocketCallback::OnComplete,
+ base::Unretained(this))) {
+ }
+
+ virtual ~RequestSocketCallback() {}
+
+ const CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result) {
+ SetResult(result);
+ ASSERT_EQ(OK, result);
+
+ if (!within_callback_) {
+ test_connect_job_factory_->set_job_type(next_job_type_);
+
+ // Don't allow reuse of the socket. Disconnect it and then release it and
+ // run through the MessageLoop once to get it completely released.
+ handle_->socket()->Disconnect();
+ handle_->Reset();
+ {
+ // TODO: Resolve conflicting intentions of stopping recursion with the
+ // |!within_callback_| test (above) and the call to |RunUntilIdle()|
+ // below. http://crbug.com/114130.
+ base::MessageLoop::ScopedNestableTaskAllower allow(
+ base::MessageLoop::current());
+ base::MessageLoop::current()->RunUntilIdle();
+ }
+ within_callback_ = true;
+ TestCompletionCallback next_job_callback;
+ scoped_refptr<TestSocketParams> params(new TestSocketParams());
+ int rv = handle_->Init("a",
+ params,
+ kDefaultPriority,
+ next_job_callback.callback(),
+ pool_,
+ BoundNetLog());
+ switch (next_job_type_) {
+ case TestConnectJob::kMockJob:
+ EXPECT_EQ(OK, rv);
+ break;
+ case TestConnectJob::kMockPendingJob:
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // For pending jobs, wait for new socket to be created. This makes
+ // sure there are no more pending operations nor any unclosed sockets
+ // when the test finishes.
+ // We need to give it a little bit of time to run, so that all the
+ // operations that happen on timers (e.g. cleanup of idle
+ // connections) can execute.
+ {
+ base::MessageLoop::ScopedNestableTaskAllower allow(
+ base::MessageLoop::current());
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10));
+ EXPECT_EQ(OK, next_job_callback.WaitForResult());
+ }
+ break;
+ default:
+ FAIL() << "Unexpected job type: " << next_job_type_;
+ break;
+ }
+ }
+ }
+
+ ClientSocketHandle* const handle_;
+ TestClientSocketPool* const pool_;
+ bool within_callback_;
+ TestConnectJobFactory* const test_connect_job_factory_;
+ TestConnectJob::JobType next_job_type_;
+ CompletionCallback callback_;
+};
+
+TEST_F(ClientSocketPoolBaseTest, RequestPendingJobTwice) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ RequestSocketCallback callback(
+ &handle, pool_.get(), connect_job_factory_,
+ TestConnectJob::kMockPendingJob);
+ int rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestPendingJobThenSynchronous) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ RequestSocketCallback callback(
+ &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockJob);
+ int rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+}
+
+// Make sure that pending requests get serviced after active requests get
+// cancelled.
+TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestWithPendingRequests) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ // Now, kDefaultMaxSocketsPerGroup requests should be active.
+ // Let's cancel them.
+ for (int i = 0; i < kDefaultMaxSocketsPerGroup; ++i) {
+ ASSERT_FALSE(request(i)->handle()->is_initialized());
+ request(i)->handle()->Reset();
+ }
+
+ // Let's wait for the rest to complete now.
+ for (size_t i = kDefaultMaxSocketsPerGroup; i < requests_size(); ++i) {
+ EXPECT_EQ(OK, request(i)->WaitForResult());
+ request(i)->handle()->Reset();
+ }
+
+ EXPECT_EQ(requests_size() - kDefaultMaxSocketsPerGroup,
+ completion_count());
+}
+
+// Make sure that pending requests get serviced after active requests fail.
+TEST_F(ClientSocketPoolBaseTest, FailingActiveRequestWithPendingRequests) {
+ const size_t kMaxSockets = 5;
+ CreatePool(kMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob);
+
+ const size_t kNumberOfRequests = 2 * kDefaultMaxSocketsPerGroup + 1;
+ ASSERT_LE(kNumberOfRequests, kMaxSockets); // Otherwise the test will hang.
+
+ // Queue up all the requests
+ for (size_t i = 0; i < kNumberOfRequests; ++i)
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ for (size_t i = 0; i < kNumberOfRequests; ++i)
+ EXPECT_EQ(ERR_CONNECTION_FAILED, request(i)->WaitForResult());
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestThenRequestSocket) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // Cancel the active request.
+ handle.Reset();
+
+ rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ EXPECT_FALSE(handle.is_reused());
+ TestLoadTimingInfoConnectedNotReused(handle);
+ EXPECT_EQ(2, client_socket_factory_.allocation_count());
+}
+
+// Regression test for http://crbug.com/17985.
+TEST_F(ClientSocketPoolBaseTest, GroupWithPendingRequestsIsNotEmpty) {
+ const int kMaxSockets = 3;
+ const int kMaxSocketsPerGroup = 2;
+ CreatePool(kMaxSockets, kMaxSocketsPerGroup);
+
+ const RequestPriority kHighPriority = HIGHEST;
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+
+ // This is going to be a pending request in an otherwise empty group.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ // Reach the maximum socket limit.
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+
+ // Create a stalled group with high priorities.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kHighPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("c", kHighPriority));
+
+ // Release the first two sockets from "a". Because this is a keepalive,
+ // the first release will unblock the pending request for "a". The
+ // second release will unblock a request for "c", becaue it is the next
+ // high priority socket.
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE));
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::KEEP_ALIVE));
+
+ // Closing idle sockets should not get us into trouble, but in the bug
+ // we were hitting a CHECK here.
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ pool_->CloseIdleSockets();
+
+ // Run the released socket wakeups.
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+TEST_F(ClientSocketPoolBaseTest, BasicAsynchronous) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ CapturingBoundNetLog log;
+ int rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback.callback(),
+ pool_.get(),
+ log.bound());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+ TestLoadTimingInfoNotConnected(handle);
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoConnectedNotReused(handle);
+
+ handle.Reset();
+ TestLoadTimingInfoNotConnected(handle);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+
+ EXPECT_EQ(4u, entries.size());
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKET_POOL));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 2, NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, 3, NetLog::TYPE_SOCKET_POOL));
+}
+
+TEST_F(ClientSocketPoolBaseTest,
+ InitConnectionAsynchronousFailure) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ CapturingBoundNetLog log;
+ // Set the additional error state members to ensure that they get cleared.
+ handle.set_is_ssl_error(true);
+ HttpResponseInfo info;
+ info.headers = new HttpResponseHeaders(std::string());
+ handle.set_ssl_error_response_info(info);
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ log.bound()));
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_ssl_error());
+ EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+
+ EXPECT_EQ(3u, entries.size());
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKET_POOL));
+ EXPECT_TRUE(LogContainsEvent(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ NetLog::PHASE_NONE));
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, 2, NetLog::TYPE_SOCKET_POOL));
+}
+
+TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) {
+ // TODO(eroman): Add back the log expectations! Removed them because the
+ // ordering is difficult, and some may fire during destructor.
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ CapturingBoundNetLog log2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ handle.Reset();
+
+
+ // At this point, request 2 is just waiting for the connect job to finish.
+
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ handle2.Reset();
+
+ // Now request 2 has actually finished.
+ // TODO(eroman): Add back log expectations.
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelRequestLimitsJobs) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a"));
+ (*requests())[2]->handle()->Reset();
+ (*requests())[3]->handle()->Reset();
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a"));
+
+ (*requests())[1]->handle()->Reset();
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a"));
+
+ (*requests())[0]->handle()->Reset();
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->NumConnectJobsInGroup("a"));
+}
+
+// When requests and ConnectJobs are not coupled, the request will get serviced
+// by whatever comes first.
+TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ // Start job 1 (async OK)
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ std::vector<TestSocketRequest*> request_order;
+ size_t completion_count; // unused
+ TestSocketRequest req1(&request_order, &completion_count);
+ int rv = req1.handle()->Init("a",
+ params_,
+ kDefaultPriority,
+ req1.callback(), pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(OK, req1.WaitForResult());
+
+ // Job 1 finished OK. Start job 2 (also async OK). Request 3 is pending
+ // without a job.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ TestSocketRequest req2(&request_order, &completion_count);
+ rv = req2.handle()->Init("a",
+ params_,
+ kDefaultPriority,
+ req2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ TestSocketRequest req3(&request_order, &completion_count);
+ rv = req3.handle()->Init("a",
+ params_,
+ kDefaultPriority,
+ req3.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // Both Requests 2 and 3 are pending. We release socket 1 which should
+ // service request 2. Request 3 should still be waiting.
+ req1.handle()->Reset();
+ // Run the released socket wakeups.
+ base::MessageLoop::current()->RunUntilIdle();
+ ASSERT_TRUE(req2.handle()->socket());
+ EXPECT_EQ(OK, req2.WaitForResult());
+ EXPECT_FALSE(req3.handle()->socket());
+
+ // Signal job 2, which should service request 3.
+
+ client_socket_factory_.SignalJobs();
+ EXPECT_EQ(OK, req3.WaitForResult());
+
+ ASSERT_EQ(3U, request_order.size());
+ EXPECT_EQ(&req1, request_order[0]);
+ EXPECT_EQ(&req2, request_order[1]);
+ EXPECT_EQ(&req3, request_order[2]);
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+}
+
+// The requests are not coupled to the jobs. So, the requests should finish in
+// their priority / insertion order.
+TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ // First two jobs are async.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob);
+
+ std::vector<TestSocketRequest*> request_order;
+ size_t completion_count; // unused
+ TestSocketRequest req1(&request_order, &completion_count);
+ int rv = req1.handle()->Init("a",
+ params_,
+ kDefaultPriority,
+ req1.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ TestSocketRequest req2(&request_order, &completion_count);
+ rv = req2.handle()->Init("a",
+ params_,
+ kDefaultPriority,
+ req2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // The pending job is sync.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ TestSocketRequest req3(&request_order, &completion_count);
+ rv = req3.handle()->Init("a",
+ params_,
+ kDefaultPriority,
+ req3.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, req1.WaitForResult());
+ EXPECT_EQ(OK, req2.WaitForResult());
+ EXPECT_EQ(ERR_CONNECTION_FAILED, req3.WaitForResult());
+
+ ASSERT_EQ(3U, request_order.size());
+ EXPECT_EQ(&req1, request_order[0]);
+ EXPECT_EQ(&req2, request_order[1]);
+ EXPECT_EQ(&req3, request_order[2]);
+}
+
+// Test GetLoadState in the case there's only one socket request.
+TEST_F(ClientSocketPoolBaseTest, LoadStateOneRequest) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState());
+
+ client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE);
+ EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState());
+
+ // No point in completing the connection, since ClientSocketHandles only
+ // expect the LoadState to be checked while connecting.
+}
+
+// Test GetLoadState in the case there are two socket requests.
+TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequests) {
+ CreatePool(2, 2);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ rv = handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // If the first Job is in an earlier state than the second, the state of
+ // the second job should be used for both handles.
+ client_socket_factory_.SetJobLoadState(0, LOAD_STATE_RESOLVING_HOST);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState());
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState());
+
+ // If the second Job is in an earlier state than the second, the state of
+ // the first job should be used for both handles.
+ client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE);
+ // One request is farther
+ EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState());
+ EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle2.GetLoadState());
+
+ // Farthest along job connects and the first request gets the socket. The
+ // second handle switches to the state of the remaining ConnectJob.
+ client_socket_factory_.SignalJob(0);
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState());
+}
+
+// Test GetLoadState in the case the per-group limit is reached.
+TEST_F(ClientSocketPoolBaseTest, LoadStateGroupLimit) {
+ CreatePool(2, 1);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ MEDIUM,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState());
+
+ // Request another socket from the same pool, buth with a higher priority.
+ // The first request should now be stalled at the socket group limit.
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ rv = handle2.Init("a",
+ params_,
+ HIGHEST,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState());
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState());
+
+ // The first handle should remain stalled as the other socket goes through
+ // the connect process.
+
+ client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE);
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState());
+ EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle2.GetLoadState());
+
+ client_socket_factory_.SignalJob(0);
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState());
+
+ // Closing the second socket should cause the stalled handle to finally get a
+ // ConnectJob.
+ handle2.socket()->Disconnect();
+ handle2.Reset();
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState());
+}
+
+// Test GetLoadState in the case the per-pool limit is reached.
+TEST_F(ClientSocketPoolBaseTest, LoadStatePoolLimit) {
+ CreatePool(2, 2);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // Request for socket from another pool.
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ rv = handle2.Init("b",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // Request another socket from the first pool. Request should stall at the
+ // socket pool limit.
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback3;
+ rv = handle3.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // The third handle should remain stalled as the other sockets in its group
+ // goes through the connect process.
+
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState());
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL, handle3.GetLoadState());
+
+ client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE);
+ EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState());
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL, handle3.GetLoadState());
+
+ client_socket_factory_.SignalJob(0);
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL, handle3.GetLoadState());
+
+ // Closing a socket should allow the stalled handle to finally get a new
+ // ConnectJob.
+ handle.socket()->Disconnect();
+ handle.Reset();
+ EXPECT_EQ(LOAD_STATE_CONNECTING, handle3.GetLoadState());
+}
+
+TEST_F(ClientSocketPoolBaseTest, Recoverable) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockRecoverableJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED,
+ handle.Init("a", params_, kDefaultPriority, callback.callback(),
+ pool_.get(), BoundNetLog()));
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+}
+
+TEST_F(ClientSocketPoolBaseTest, AsyncRecoverable) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(
+ TestConnectJob::kMockPendingRecoverableJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+ EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+}
+
+TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateSynchronous) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(
+ TestConnectJob::kMockAdditionalErrorStateJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_CONNECTION_FAILED,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_TRUE(handle.is_ssl_error());
+ EXPECT_FALSE(handle.ssl_error_response_info().headers.get() == NULL);
+}
+
+TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateAsynchronous) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(
+ TestConnectJob::kMockPendingAdditionalErrorStateJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_TRUE(handle.is_ssl_error());
+ EXPECT_FALSE(handle.ssl_error_response_info().headers.get() == NULL);
+}
+
+// Make sure we can reuse sockets when the cleanup timer is disabled.
+TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerReuse) {
+ // Disable cleanup timer.
+ internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(false);
+
+ CreatePoolWithIdleTimeouts(
+ kDefaultMaxSockets, kDefaultMaxSocketsPerGroup,
+ base::TimeDelta(), // Time out unused sockets immediately.
+ base::TimeDelta::FromDays(1)); // Don't time out used sockets.
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+ ASSERT_EQ(OK, callback.WaitForResult());
+
+ // Use and release the socket.
+ EXPECT_EQ(1, handle.socket()->Write(NULL, 1, CompletionCallback()));
+ TestLoadTimingInfoConnectedNotReused(handle);
+ handle.Reset();
+
+ // Should now have one idle socket.
+ ASSERT_EQ(1, pool_->IdleSocketCount());
+
+ // Request a new socket. This should reuse the old socket and complete
+ // synchronously.
+ CapturingBoundNetLog log;
+ rv = handle.Init("a",
+ params_,
+ LOWEST,
+ CompletionCallback(),
+ pool_.get(),
+ log.bound());
+ ASSERT_EQ(OK, rv);
+ EXPECT_TRUE(handle.is_reused());
+ TestLoadTimingInfoConnectedReused(handle);
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a"));
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEntryWithType(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET));
+}
+
+// Make sure we cleanup old unused sockets when the cleanup timer is disabled.
+TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerNoReuse) {
+ // Disable cleanup timer.
+ internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(false);
+
+ CreatePoolWithIdleTimeouts(
+ kDefaultMaxSockets, kDefaultMaxSocketsPerGroup,
+ base::TimeDelta(), // Time out unused sockets immediately
+ base::TimeDelta()); // Time out used sockets immediately
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ // Startup two mock pending connect jobs, which will sit in the MessageLoop.
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ rv = handle2.Init("a",
+ params_,
+ LOWEST,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle2));
+
+ // Cancel one of the requests. Wait for the other, which will get the first
+ // job. Release the socket. Run the loop again to make sure the second
+ // socket is sitting idle and the first one is released (since ReleaseSocket()
+ // just posts a DoReleaseSocket() task).
+
+ handle.Reset();
+ ASSERT_EQ(OK, callback2.WaitForResult());
+ // Use the socket.
+ EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, CompletionCallback()));
+ handle2.Reset();
+
+ // We post all of our delayed tasks with a 2ms delay. I.e. they don't
+ // actually become pending until 2ms after they have been created. In order
+ // to flush all tasks, we need to wait so that we know there are no
+ // soon-to-be-pending tasks waiting.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10));
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Both sockets should now be idle.
+ ASSERT_EQ(2, pool_->IdleSocketCount());
+
+ // Request a new socket. This should cleanup the unused and timed out ones.
+ // A new socket will be created rather than reusing the idle one.
+ CapturingBoundNetLog log;
+ TestCompletionCallback callback3;
+ rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback3.callback(),
+ pool_.get(),
+ log.bound());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ ASSERT_EQ(OK, callback3.WaitForResult());
+ EXPECT_FALSE(handle.is_reused());
+
+ // Make sure the idle socket is closed.
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a"));
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_FALSE(LogContainsEntryWithType(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET));
+}
+
+TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) {
+ CreatePoolWithIdleTimeouts(
+ kDefaultMaxSockets, kDefaultMaxSocketsPerGroup,
+ base::TimeDelta(), // Time out unused sockets immediately.
+ base::TimeDelta::FromDays(1)); // Don't time out used sockets.
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ // Startup two mock pending connect jobs, which will sit in the MessageLoop.
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle));
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ rv = handle2.Init("a",
+ params_,
+ LOWEST,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle2));
+
+ // Cancel one of the requests. Wait for the other, which will get the first
+ // job. Release the socket. Run the loop again to make sure the second
+ // socket is sitting idle and the first one is released (since ReleaseSocket()
+ // just posts a DoReleaseSocket() task).
+
+ handle.Reset();
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ // Use the socket.
+ EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, CompletionCallback()));
+ handle2.Reset();
+
+ // We post all of our delayed tasks with a 2ms delay. I.e. they don't
+ // actually become pending until 2ms after they have been created. In order
+ // to flush all tasks, we need to wait so that we know there are no
+ // soon-to-be-pending tasks waiting.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(10));
+ base::MessageLoop::current()->RunUntilIdle();
+
+ ASSERT_EQ(2, pool_->IdleSocketCount());
+
+ // Invoke the idle socket cleanup check. Only one socket should be left, the
+ // used socket. Request it to make sure that it's used.
+
+ pool_->CleanupTimedOutIdleSockets();
+ CapturingBoundNetLog log;
+ rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback.callback(),
+ pool_.get(),
+ log.bound());
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(handle.is_reused());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEntryWithType(
+ entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET));
+}
+
+// Make sure that we process all pending requests even when we're stalling
+// because of multiple releasing disconnected sockets.
+TEST_F(ClientSocketPoolBaseTest, MultipleReleasingDisconnectedSockets) {
+ CreatePoolWithIdleTimeouts(
+ kDefaultMaxSockets, kDefaultMaxSocketsPerGroup,
+ base::TimeDelta(), // Time out unused sockets immediately.
+ base::TimeDelta::FromDays(1)); // Don't time out used sockets.
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ // Startup 4 connect jobs. Two of them will be pending.
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init("a",
+ params_,
+ LOWEST,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(OK, rv);
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ rv = handle2.Init("a",
+ params_,
+ LOWEST,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(OK, rv);
+
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback3;
+ rv = handle3.Init("a",
+ params_,
+ LOWEST,
+ callback3.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ ClientSocketHandle handle4;
+ TestCompletionCallback callback4;
+ rv = handle4.Init("a",
+ params_,
+ LOWEST,
+ callback4.callback(),
+ pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // Release two disconnected sockets.
+
+ handle.socket()->Disconnect();
+ handle.Reset();
+ handle2.socket()->Disconnect();
+ handle2.Reset();
+
+ EXPECT_EQ(OK, callback3.WaitForResult());
+ EXPECT_FALSE(handle3.is_reused());
+ EXPECT_EQ(OK, callback4.WaitForResult());
+ EXPECT_FALSE(handle4.is_reused());
+}
+
+// Regression test for http://crbug.com/42267.
+// When DoReleaseSocket() is processed for one socket, it is blocked because the
+// other stalled groups all have releasing sockets, so no progress can be made.
+TEST_F(ClientSocketPoolBaseTest, SocketLimitReleasingSockets) {
+ CreatePoolWithIdleTimeouts(
+ 4 /* socket limit */, 4 /* socket limit per group */,
+ base::TimeDelta(), // Time out unused sockets immediately.
+ base::TimeDelta::FromDays(1)); // Don't time out used sockets.
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ // Max out the socket limit with 2 per group.
+
+ ClientSocketHandle handle_a[4];
+ TestCompletionCallback callback_a[4];
+ ClientSocketHandle handle_b[4];
+ TestCompletionCallback callback_b[4];
+
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ(OK, handle_a[i].Init("a",
+ params_,
+ LOWEST,
+ callback_a[i].callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, handle_b[i].Init("b",
+ params_,
+ LOWEST,
+ callback_b[i].callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ // Make 4 pending requests, 2 per group.
+
+ for (int i = 2; i < 4; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle_a[i].Init("a",
+ params_,
+ LOWEST,
+ callback_a[i].callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle_b[i].Init("b",
+ params_,
+ LOWEST,
+ callback_b[i].callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ // Release b's socket first. The order is important, because in
+ // DoReleaseSocket(), we'll process b's released socket, and since both b and
+ // a are stalled, but 'a' is lower lexicographically, we'll process group 'a'
+ // first, which has a releasing socket, so it refuses to start up another
+ // ConnectJob. So, we used to infinite loop on this.
+ handle_b[0].socket()->Disconnect();
+ handle_b[0].Reset();
+ handle_a[0].socket()->Disconnect();
+ handle_a[0].Reset();
+
+ // Used to get stuck here.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ handle_b[1].socket()->Disconnect();
+ handle_b[1].Reset();
+ handle_a[1].socket()->Disconnect();
+ handle_a[1].Reset();
+
+ for (int i = 2; i < 4; ++i) {
+ EXPECT_EQ(OK, callback_b[i].WaitForResult());
+ EXPECT_EQ(OK, callback_a[i].WaitForResult());
+ }
+}
+
+TEST_F(ClientSocketPoolBaseTest,
+ ReleasingDisconnectedSocketsMaintainsPriorityOrder) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ EXPECT_EQ(OK, (*requests())[0]->WaitForResult());
+ EXPECT_EQ(OK, (*requests())[1]->WaitForResult());
+ EXPECT_EQ(2u, completion_count());
+
+ // Releases one connection.
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE));
+ EXPECT_EQ(OK, (*requests())[2]->WaitForResult());
+
+ EXPECT_TRUE(ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE));
+ EXPECT_EQ(OK, (*requests())[3]->WaitForResult());
+ EXPECT_EQ(4u, completion_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(5));
+}
+
+class TestReleasingSocketRequest : public TestCompletionCallbackBase {
+ public:
+ TestReleasingSocketRequest(TestClientSocketPool* pool,
+ int expected_result,
+ bool reset_releasing_handle)
+ : pool_(pool),
+ expected_result_(expected_result),
+ reset_releasing_handle_(reset_releasing_handle),
+ callback_(base::Bind(&TestReleasingSocketRequest::OnComplete,
+ base::Unretained(this))) {
+ }
+
+ virtual ~TestReleasingSocketRequest() {}
+
+ ClientSocketHandle* handle() { return &handle_; }
+
+ const CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result) {
+ SetResult(result);
+ if (reset_releasing_handle_)
+ handle_.Reset();
+
+ scoped_refptr<TestSocketParams> con_params(new TestSocketParams());
+ EXPECT_EQ(expected_result_,
+ handle2_.Init("a", con_params, kDefaultPriority,
+ callback2_.callback(), pool_, BoundNetLog()));
+ }
+
+ TestClientSocketPool* const pool_;
+ int expected_result_;
+ bool reset_releasing_handle_;
+ ClientSocketHandle handle_;
+ ClientSocketHandle handle2_;
+ CompletionCallback callback_;
+ TestCompletionCallback callback2_;
+};
+
+
+TEST_F(ClientSocketPoolBaseTest, AdditionalErrorSocketsDontUseSlot) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("b", kDefaultPriority));
+
+ EXPECT_EQ(static_cast<int>(requests_size()),
+ client_socket_factory_.allocation_count());
+
+ connect_job_factory_->set_job_type(
+ TestConnectJob::kMockPendingAdditionalErrorStateJob);
+ TestReleasingSocketRequest req(pool_.get(), OK, false);
+ EXPECT_EQ(ERR_IO_PENDING,
+ req.handle()->Init("a", params_, kDefaultPriority, req.callback(),
+ pool_.get(), BoundNetLog()));
+ // The next job should complete synchronously
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, req.WaitForResult());
+ EXPECT_FALSE(req.handle()->is_initialized());
+ EXPECT_FALSE(req.handle()->socket());
+ EXPECT_TRUE(req.handle()->is_ssl_error());
+ EXPECT_FALSE(req.handle()->ssl_error_response_info().headers.get() == NULL);
+}
+
+// http://crbug.com/44724 regression test.
+// We start releasing the pool when we flush on network change. When that
+// happens, the only active references are in the ClientSocketHandles. When a
+// ConnectJob completes and calls back into the last ClientSocketHandle, that
+// callback can release the last reference and delete the pool. After the
+// callback finishes, we go back to the stack frame within the now-deleted pool.
+// Executing any code that refers to members of the now-deleted pool can cause
+// crashes.
+TEST_F(ClientSocketPoolBaseTest, CallbackThatReleasesPool) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ pool_->FlushWithError(ERR_NETWORK_CHANGED);
+
+ // We'll call back into this now.
+ callback.WaitForResult();
+}
+
+TEST_F(ClientSocketPoolBaseTest, DoNotReuseSocketAfterFlush) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_EQ(ClientSocketHandle::UNUSED, handle.reuse_type());
+
+ pool_->FlushWithError(ERR_NETWORK_CHANGED);
+
+ handle.Reset();
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_EQ(ClientSocketHandle::UNUSED, handle.reuse_type());
+}
+
+class ConnectWithinCallback : public TestCompletionCallbackBase {
+ public:
+ ConnectWithinCallback(
+ const std::string& group_name,
+ const scoped_refptr<TestSocketParams>& params,
+ TestClientSocketPool* pool)
+ : group_name_(group_name),
+ params_(params),
+ pool_(pool),
+ callback_(base::Bind(&ConnectWithinCallback::OnComplete,
+ base::Unretained(this))) {
+ }
+
+ virtual ~ConnectWithinCallback() {}
+
+ int WaitForNestedResult() {
+ return nested_callback_.WaitForResult();
+ }
+
+ const CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result) {
+ SetResult(result);
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle_.Init(group_name_,
+ params_,
+ kDefaultPriority,
+ nested_callback_.callback(),
+ pool_,
+ BoundNetLog()));
+ }
+
+ const std::string group_name_;
+ const scoped_refptr<TestSocketParams> params_;
+ TestClientSocketPool* const pool_;
+ ClientSocketHandle handle_;
+ CompletionCallback callback_;
+ TestCompletionCallback nested_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(ConnectWithinCallback);
+};
+
+TEST_F(ClientSocketPoolBaseTest, AbortAllRequestsOnFlush) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+
+ // First job will be waiting until it gets aborted.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+
+ ClientSocketHandle handle;
+ ConnectWithinCallback callback("a", params_, pool_.get());
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ // Second job will be started during the first callback, and will
+ // asynchronously complete with OK.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ pool_->FlushWithError(ERR_NETWORK_CHANGED);
+ EXPECT_EQ(ERR_NETWORK_CHANGED, callback.WaitForResult());
+ EXPECT_EQ(OK, callback.WaitForNestedResult());
+}
+
+// Cancel a pending socket request while we're at max sockets,
+// and verify that the backup socket firing doesn't cause a crash.
+TEST_F(ClientSocketPoolBaseTest, BackupSocketCancelAtMaxSockets) {
+ // Max 4 sockets globally, max 4 sockets per group.
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ pool_->EnableConnectBackupJobs();
+
+ // Create the first socket and set to ERR_IO_PENDING. This starts the backup
+ // timer.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ // Start (MaxSockets - 1) connected sockets to reach max sockets.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+ ClientSocketHandle handles[kDefaultMaxSockets];
+ for (int i = 1; i < kDefaultMaxSockets; ++i) {
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handles[i].Init("bar",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ }
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Cancel the pending request.
+ handle.Reset();
+
+ // Wait for the backup timer to fire (add some slop to ensure it fires)
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs / 2 * 3));
+
+ base::MessageLoop::current()->RunUntilIdle();
+ EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count());
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterCancelingAllRequests) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ pool_->EnableConnectBackupJobs();
+
+ // Create the first socket and set to ERR_IO_PENDING. This starts the backup
+ // timer.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ ASSERT_TRUE(pool_->HasGroup("bar"));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("bar"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("bar"));
+
+ // Cancel the socket request. This should cancel the backup timer. Wait for
+ // the backup time to see if it indeed got canceled.
+ handle.Reset();
+ // Wait for the backup timer to fire (add some slop to ensure it fires)
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs / 2 * 3));
+ base::MessageLoop::current()->RunUntilIdle();
+ ASSERT_TRUE(pool_->HasGroup("bar"));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("bar"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterFinishingAllRequests) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ pool_->EnableConnectBackupJobs();
+
+ // Create the first socket and set to ERR_IO_PENDING. This starts the backup
+ // timer.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle2.Init("bar",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ ASSERT_TRUE(pool_->HasGroup("bar"));
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("bar"));
+
+ // Cancel request 1 and then complete request 2. With the requests finished,
+ // the backup timer should be cancelled.
+ handle.Reset();
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ // Wait for the backup timer to fire (add some slop to ensure it fires)
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs / 2 * 3));
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+// Test delayed socket binding for the case where we have two connects,
+// and while one is waiting on a connect, the other frees up.
+// The socket waiting on a connect should switch immediately to the freed
+// up socket.
+TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingWaitingForConnect) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // No idle sockets, no pending jobs.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ // Create a second socket to the same host, but this one will wait.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ ClientSocketHandle handle2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ // No idle sockets, and one connecting job.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Return the first handle to the pool. This will initiate the delayed
+ // binding.
+ handle1.Reset();
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Still no idle sockets, still one pending connect job.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // The second socket connected, even though it was a Waiting Job.
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // And we can see there is still one job waiting.
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Finally, signal the waiting Connect.
+ client_socket_factory_.SignalJobs();
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+// Test delayed socket binding when a group is at capacity and one
+// of the group's sockets frees up.
+TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtGroupCapacity) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // No idle sockets, no pending jobs.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ // Create a second socket to the same host, but this one will wait.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ ClientSocketHandle handle2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ // No idle sockets, and one connecting job.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Return the first handle to the pool. This will initiate the delayed
+ // binding.
+ handle1.Reset();
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Still no idle sockets, still one pending connect job.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // The second socket connected, even though it was a Waiting Job.
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // And we can see there is still one job waiting.
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Finally, signal the waiting Connect.
+ client_socket_factory_.SignalJobs();
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+// Test out the case where we have one socket connected, one
+// connecting, when the first socket finishes and goes idle.
+// Although the second connection is pending, the second request
+// should complete, by taking the first socket's idle socket.
+TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtStall) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // No idle sockets, no pending jobs.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ // Create a second socket to the same host, but this one will wait.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ ClientSocketHandle handle2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ // No idle sockets, and one connecting job.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Return the first handle to the pool. This will initiate the delayed
+ // binding.
+ handle1.Reset();
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Still no idle sockets, still one pending connect job.
+ EXPECT_EQ(0, pool_->IdleSocketCount());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // The second socket connected, even though it was a Waiting Job.
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // And we can see there is still one job waiting.
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Finally, signal the waiting Connect.
+ client_socket_factory_.SignalJobs();
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+// Cover the case where on an available socket slot, we have one pending
+// request that completes synchronously, thereby making the Group empty.
+TEST_F(ClientSocketPoolBaseTest, SynchronouslyProcessOnePendingRequest) {
+ const int kUnlimitedSockets = 100;
+ const int kOneSocketPerGroup = 1;
+ CreatePool(kUnlimitedSockets, kOneSocketPerGroup);
+
+ // Make the first request asynchronous fail.
+ // This will free up a socket slot later.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Make the second request synchronously fail. This should make the Group
+ // empty.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob);
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ // It'll be ERR_IO_PENDING now, but the TestConnectJob will synchronously fail
+ // when created.
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback1.WaitForResult());
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback2.WaitForResult());
+ EXPECT_FALSE(pool_->HasGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback3;
+ EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a",
+ params_,
+ kDefaultPriority,
+ callback3.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(OK, callback1.WaitForResult());
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ EXPECT_EQ(OK, callback3.WaitForResult());
+
+ // Use the socket.
+ EXPECT_EQ(1, handle1.socket()->Write(NULL, 1, CompletionCallback()));
+ EXPECT_EQ(1, handle3.socket()->Write(NULL, 1, CompletionCallback()));
+
+ handle1.Reset();
+ handle2.Reset();
+ handle3.Reset();
+
+ EXPECT_EQ(OK, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, handle3.Init("a",
+ params_,
+ kDefaultPriority,
+ callback3.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_TRUE(handle1.socket()->WasEverUsed());
+ EXPECT_TRUE(handle2.socket()->WasEverUsed());
+ EXPECT_FALSE(handle3.socket()->WasEverUsed());
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSockets) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ EXPECT_EQ(OK, callback1.WaitForResult());
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ handle1.Reset();
+ handle2.Reset();
+
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsWhenAlreadyHaveAConnectJob) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ EXPECT_EQ(OK, callback1.WaitForResult());
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ handle1.Reset();
+ handle2.Reset();
+
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest,
+ RequestSocketsWhenAlreadyHaveMultipleConnectJob) {
+ CreatePool(4, 4);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback3;
+ EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a",
+ params_,
+ kDefaultPriority,
+ callback3.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+
+ EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ EXPECT_EQ(OK, callback1.WaitForResult());
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ EXPECT_EQ(OK, callback3.WaitForResult());
+ handle1.Reset();
+ handle2.Reset();
+ handle3.Reset();
+
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(3, pool_->IdleSocketCountInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsAtMaxSocketLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ASSERT_FALSE(pool_->HasGroup("a"));
+
+ pool_->RequestSockets("a", &params_, kDefaultMaxSockets,
+ BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(kDefaultMaxSockets, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(kDefaultMaxSockets, pool_->NumUnassignedConnectJobsInGroup("a"));
+
+ ASSERT_FALSE(pool_->HasGroup("b"));
+
+ pool_->RequestSockets("b", &params_, kDefaultMaxSockets,
+ BoundNetLog());
+
+ ASSERT_FALSE(pool_->HasGroup("b"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsHitMaxSocketLimit) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSockets);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ASSERT_FALSE(pool_->HasGroup("a"));
+
+ pool_->RequestSockets("a", &params_, kDefaultMaxSockets - 1,
+ BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(kDefaultMaxSockets - 1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(kDefaultMaxSockets - 1,
+ pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_FALSE(pool_->IsStalled());
+
+ ASSERT_FALSE(pool_->HasGroup("b"));
+
+ pool_->RequestSockets("b", &params_, kDefaultMaxSockets,
+ BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("b"));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("b"));
+ EXPECT_FALSE(pool_->IsStalled());
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountIdleSockets) {
+ CreatePool(4, 4);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ ASSERT_EQ(OK, callback1.WaitForResult());
+ handle1.Reset();
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountActiveSockets) {
+ CreatePool(4, 4);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ ASSERT_EQ(OK, callback1.WaitForResult());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsSynchronous) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ pool_->RequestSockets("a", &params_, kDefaultMaxSocketsPerGroup,
+ BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("b", &params_, kDefaultMaxSocketsPerGroup,
+ BoundNetLog());
+
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b"));
+ EXPECT_EQ(kDefaultMaxSocketsPerGroup, pool_->IdleSocketCountInGroup("b"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsSynchronousError) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob);
+
+ pool_->RequestSockets("a", &params_, kDefaultMaxSocketsPerGroup,
+ BoundNetLog());
+
+ ASSERT_FALSE(pool_->HasGroup("a"));
+
+ connect_job_factory_->set_job_type(
+ TestConnectJob::kMockAdditionalErrorStateJob);
+ pool_->RequestSockets("a", &params_, kDefaultMaxSocketsPerGroup,
+ BoundNetLog());
+
+ ASSERT_FALSE(pool_->HasGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsMultipleTimesDoesNothing) {
+ CreatePool(4, 4);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ ASSERT_EQ(OK, callback1.WaitForResult());
+
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ int rv = handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog());
+ if (rv != OK) {
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ }
+
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ handle1.Reset();
+ handle2.Reset();
+
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->IdleSocketCountInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, RequestSocketsDifferentNumSockets) {
+ CreatePool(4, 4);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ pool_->RequestSockets("a", &params_, 1, BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+ EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(2, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 3, BoundNetLog());
+ EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(3, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ pool_->RequestSockets("a", &params_, 1, BoundNetLog());
+ EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(3, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, PreconnectJobsTakenByNormalRequests) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ pool_->RequestSockets("a", &params_, 1, BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ ASSERT_EQ(OK, callback1.WaitForResult());
+
+ // Make sure if a preconneced socket is not fully connected when a request
+ // starts, it has a connect start time.
+ TestLoadTimingInfoConnectedNotReused(handle1);
+ handle1.Reset();
+
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+}
+
+// Checks that fully connected preconnect jobs have no connect times, and are
+// marked as reused.
+TEST_F(ClientSocketPoolBaseTest, ConnectedPreconnectJobsHaveNoConnectTimes) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+ pool_->RequestSockets("a", &params_, 1, BoundNetLog());
+
+ ASSERT_TRUE(pool_->HasGroup("a"));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(OK, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ // Make sure the idle socket was used.
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ TestLoadTimingInfoConnectedReused(handle);
+ handle.Reset();
+ TestLoadTimingInfoNotConnected(handle);
+}
+
+// http://crbug.com/64940 regression test.
+TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) {
+ const int kMaxTotalSockets = 3;
+ const int kMaxSocketsPerGroup = 2;
+ CreatePool(kMaxTotalSockets, kMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ // Note that group name ordering matters here. "a" comes before "b", so
+ // CloseOneIdleSocket() will try to close "a"'s idle socket.
+
+ // Set up one idle socket in "a".
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ASSERT_EQ(OK, callback1.WaitForResult());
+ handle1.Reset();
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+
+ // Set up two active sockets in "b".
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle1.Init("b",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(ERR_IO_PENDING, handle2.Init("b",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ ASSERT_EQ(OK, callback1.WaitForResult());
+ ASSERT_EQ(OK, callback2.WaitForResult());
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("b"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b"));
+ EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("b"));
+
+ // Now we have 1 idle socket in "a" and 2 active sockets in "b". This means
+ // we've maxed out on sockets, since we set |kMaxTotalSockets| to 3.
+ // Requesting 2 preconnected sockets for "a" should fail to allocate any more
+ // sockets for "a", and "b" should still have 2 active sockets.
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("b"));
+ EXPECT_EQ(2, pool_->NumActiveSocketsInGroup("b"));
+
+ // Now release the 2 active sockets for "b". This will give us 1 idle socket
+ // in "a" and 2 idle sockets in "b". Requesting 2 preconnected sockets for
+ // "a" should result in closing 1 for "b".
+ handle1.Reset();
+ handle2.Reset();
+ EXPECT_EQ(2, pool_->IdleSocketCountInGroup("b"));
+ EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("b"));
+
+ pool_->RequestSockets("a", &params_, 2, BoundNetLog());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("b"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("b"));
+ EXPECT_EQ(1, pool_->IdleSocketCountInGroup("b"));
+ EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("b"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, PreconnectWithoutBackupJob) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ pool_->EnableConnectBackupJobs();
+
+ // Make the ConnectJob hang until it times out, shorten the timeout.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ connect_job_factory_->set_timeout_duration(
+ base::TimeDelta::FromMilliseconds(500));
+ pool_->RequestSockets("a", &params_, 1, BoundNetLog());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+
+ // Verify the backup timer doesn't create a backup job, by making
+ // the backup job a pending job instead of a waiting job, so it
+ // *would* complete if it were created.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::MessageLoop::QuitClosure(),
+ base::TimeDelta::FromSeconds(1));
+ base::MessageLoop::current()->Run();
+ EXPECT_FALSE(pool_->HasGroup("a"));
+}
+
+TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ pool_->EnableConnectBackupJobs();
+
+ // Make the ConnectJob hang forever.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob);
+ pool_->RequestSockets("a", &params_, 1, BoundNetLog());
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(1, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Make the backup job be a pending job, so it completes normally.
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ // Timer has started, but the backup connect job shouldn't be created yet.
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(0, pool_->NumActiveSocketsInGroup("a"));
+ ASSERT_EQ(OK, callback.WaitForResult());
+
+ // The hung connect job should still be there, but everything else should be
+ // complete.
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a"));
+ EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a"));
+ EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a"));
+}
+
+class MockLayeredPool : public LayeredPool {
+ public:
+ MockLayeredPool(TestClientSocketPool* pool,
+ const std::string& group_name)
+ : pool_(pool),
+ params_(new TestSocketParams),
+ group_name_(group_name),
+ can_release_connection_(true) {
+ pool_->AddLayeredPool(this);
+ }
+
+ ~MockLayeredPool() {
+ pool_->RemoveLayeredPool(this);
+ }
+
+ int RequestSocket(TestClientSocketPool* pool) {
+ return handle_.Init(group_name_, params_, kDefaultPriority,
+ callback_.callback(), pool, BoundNetLog());
+ }
+
+ int RequestSocketWithoutLimits(TestClientSocketPool* pool) {
+ params_->set_ignore_limits(true);
+ return handle_.Init(group_name_, params_, kDefaultPriority,
+ callback_.callback(), pool, BoundNetLog());
+ }
+
+ bool ReleaseOneConnection() {
+ if (!handle_.is_initialized() || !can_release_connection_) {
+ return false;
+ }
+ handle_.socket()->Disconnect();
+ handle_.Reset();
+ return true;
+ }
+
+ void set_can_release_connection(bool can_release_connection) {
+ can_release_connection_ = can_release_connection;
+ }
+
+ MOCK_METHOD0(CloseOneIdleConnection, bool());
+
+ private:
+ TestClientSocketPool* const pool_;
+ scoped_refptr<TestSocketParams> params_;
+ ClientSocketHandle handle_;
+ TestCompletionCallback callback_;
+ const std::string group_name_;
+ bool can_release_connection_;
+};
+
+TEST_F(ClientSocketPoolBaseTest, FailToCloseIdleSocketsNotHeldByLayeredPool) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ MockLayeredPool mock_layered_pool(pool_.get(), "foo");
+ EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
+ .WillOnce(Return(false));
+ EXPECT_FALSE(pool_->CloseOneIdleConnectionInLayeredPool());
+}
+
+TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) {
+ CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ MockLayeredPool mock_layered_pool(pool_.get(), "foo");
+ EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
+ .WillOnce(Invoke(&mock_layered_pool,
+ &MockLayeredPool::ReleaseOneConnection));
+ EXPECT_TRUE(pool_->CloseOneIdleConnectionInLayeredPool());
+}
+
+// Tests the basic case of closing an idle socket in a higher layered pool when
+// a new request is issued and the lower layer pool is stalled.
+TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketsHeldByLayeredPoolWhenNeeded) {
+ CreatePool(1, 1);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ MockLayeredPool mock_layered_pool(pool_.get(), "foo");
+ EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
+ .WillOnce(Invoke(&mock_layered_pool,
+ &MockLayeredPool::ReleaseOneConnection));
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+}
+
+// Same as above, but the idle socket is in the same group as the stalled
+// socket, and closes the only other request in its group when closing requests
+// in higher layered pools. This generally shouldn't happen, but it may be
+// possible if a higher level pool issues a request and the request is
+// subsequently cancelled. Even if it's not possible, best not to crash.
+TEST_F(ClientSocketPoolBaseTest,
+ CloseIdleSocketsHeldByLayeredPoolWhenNeededSameGroup) {
+ CreatePool(2, 2);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ // Need a socket in another group for the pool to be stalled (If a group
+ // has the maximum number of connections already, it's not stalled).
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(OK, handle1.Init("group1",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ MockLayeredPool mock_layered_pool(pool_.get(), "group2");
+ EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
+ .WillOnce(Invoke(&mock_layered_pool,
+ &MockLayeredPool::ReleaseOneConnection));
+ ClientSocketHandle handle;
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("group2",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback2.WaitForResult());
+}
+
+// Tests the case when an idle socket can be closed when a new request is
+// issued, and the new request belongs to a group that was previously stalled.
+TEST_F(ClientSocketPoolBaseTest,
+ CloseIdleSocketsHeldByLayeredPoolInSameGroupWhenNeeded) {
+ CreatePool(2, 2);
+ std::list<TestConnectJob::JobType> job_types;
+ job_types.push_back(TestConnectJob::kMockJob);
+ job_types.push_back(TestConnectJob::kMockJob);
+ job_types.push_back(TestConnectJob::kMockJob);
+ job_types.push_back(TestConnectJob::kMockJob);
+ connect_job_factory_->set_job_types(&job_types);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(OK, handle1.Init("group1",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ MockLayeredPool mock_layered_pool(pool_.get(), "group2");
+ EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
+ .WillRepeatedly(Invoke(&mock_layered_pool,
+ &MockLayeredPool::ReleaseOneConnection));
+ mock_layered_pool.set_can_release_connection(false);
+
+ // The third request is made when the socket pool is in a stalled state.
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback3;
+ EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3",
+ params_,
+ kDefaultPriority,
+ callback3.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ base::RunLoop().RunUntilIdle();
+ EXPECT_FALSE(callback3.have_result());
+
+ // The fourth request is made when the pool is no longer stalled. The third
+ // request should be serviced first, since it was issued first and has the
+ // same priority.
+ mock_layered_pool.set_can_release_connection(true);
+ ClientSocketHandle handle4;
+ TestCompletionCallback callback4;
+ EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3",
+ params_,
+ kDefaultPriority,
+ callback4.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback3.WaitForResult());
+ EXPECT_FALSE(callback4.have_result());
+
+ // Closing a handle should free up another socket slot.
+ handle1.Reset();
+ EXPECT_EQ(OK, callback4.WaitForResult());
+}
+
+// Tests the case when an idle socket can be closed when a new request is
+// issued, and the new request belongs to a group that was previously stalled.
+//
+// The two differences from the above test are that the stalled requests are not
+// in the same group as the layered pool's request, and the the fourth request
+// has a higher priority than the third one, so gets a socket first.
+TEST_F(ClientSocketPoolBaseTest,
+ CloseIdleSocketsHeldByLayeredPoolInSameGroupWhenNeeded2) {
+ CreatePool(2, 2);
+ std::list<TestConnectJob::JobType> job_types;
+ job_types.push_back(TestConnectJob::kMockJob);
+ job_types.push_back(TestConnectJob::kMockJob);
+ job_types.push_back(TestConnectJob::kMockJob);
+ job_types.push_back(TestConnectJob::kMockJob);
+ connect_job_factory_->set_job_types(&job_types);
+
+ ClientSocketHandle handle1;
+ TestCompletionCallback callback1;
+ EXPECT_EQ(OK, handle1.Init("group1",
+ params_,
+ kDefaultPriority,
+ callback1.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ MockLayeredPool mock_layered_pool(pool_.get(), "group2");
+ EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection())
+ .WillRepeatedly(Invoke(&mock_layered_pool,
+ &MockLayeredPool::ReleaseOneConnection));
+ mock_layered_pool.set_can_release_connection(false);
+
+ // The third request is made when the socket pool is in a stalled state.
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback3;
+ EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3",
+ params_,
+ MEDIUM,
+ callback3.callback(),
+ pool_.get(),
+ BoundNetLog()));
+
+ base::RunLoop().RunUntilIdle();
+ EXPECT_FALSE(callback3.have_result());
+
+ // The fourth request is made when the pool is no longer stalled. This
+ // request has a higher priority than the third request, so is serviced first.
+ mock_layered_pool.set_can_release_connection(true);
+ ClientSocketHandle handle4;
+ TestCompletionCallback callback4;
+ EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3",
+ params_,
+ HIGHEST,
+ callback4.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback4.WaitForResult());
+ EXPECT_FALSE(callback3.have_result());
+
+ // Closing a handle should free up another socket slot.
+ handle1.Reset();
+ EXPECT_EQ(OK, callback3.WaitForResult());
+}
+
+TEST_F(ClientSocketPoolBaseTest,
+ CloseMultipleIdleSocketsHeldByLayeredPoolWhenNeeded) {
+ CreatePool(1, 1);
+ connect_job_factory_->set_job_type(TestConnectJob::kMockJob);
+
+ MockLayeredPool mock_layered_pool1(pool_.get(), "foo");
+ EXPECT_EQ(OK, mock_layered_pool1.RequestSocket(pool_.get()));
+ EXPECT_CALL(mock_layered_pool1, CloseOneIdleConnection())
+ .WillRepeatedly(Invoke(&mock_layered_pool1,
+ &MockLayeredPool::ReleaseOneConnection));
+ MockLayeredPool mock_layered_pool2(pool_.get(), "bar");
+ EXPECT_EQ(OK, mock_layered_pool2.RequestSocketWithoutLimits(pool_.get()));
+ EXPECT_CALL(mock_layered_pool2, CloseOneIdleConnection())
+ .WillRepeatedly(Invoke(&mock_layered_pool2,
+ &MockLayeredPool::ReleaseOneConnection));
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING, handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ pool_.get(),
+ BoundNetLog()));
+ EXPECT_EQ(OK, callback.WaitForResult());
+}
+
+// Test that when a socket pool and group are at their limits, a request
+// with |ignore_limits| triggers creation of a new socket, and gets the socket
+// instead of a request with the same priority that was issued earlier, but
+// that does not have |ignore_limits| set.
+TEST_F(ClientSocketPoolBaseTest, IgnoreLimits) {
+ scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams());
+ params_ignore_limits->set_ignore_limits(true);
+ CreatePool(1, 1);
+
+ // Issue a request to reach the socket pool limit.
+ EXPECT_EQ(OK, StartRequestWithParams("a", kDefaultPriority, params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", kDefaultPriority,
+ params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", kDefaultPriority,
+ params_ignore_limits));
+ ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(OK, request(2)->WaitForResult());
+ EXPECT_FALSE(request(1)->have_result());
+}
+
+// Test that when a socket pool and group are at their limits, a request with
+// |ignore_limits| set triggers creation of a new socket, and gets the socket
+// instead of a request with a higher priority that was issued earlier, but
+// that does not have |ignore_limits| set.
+TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsLowPriority) {
+ scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams());
+ params_ignore_limits->set_ignore_limits(true);
+ CreatePool(1, 1);
+
+ // Issue a request to reach the socket pool limit.
+ EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", LOW,
+ params_ignore_limits));
+ ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(OK, request(2)->WaitForResult());
+ EXPECT_FALSE(request(1)->have_result());
+}
+
+// Test that when a socket pool and group are at their limits, a request with
+// |ignore_limits| set triggers creation of a new socket, and gets the socket
+// instead of a request with a higher priority that was issued later and
+// does not have |ignore_limits| set.
+TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsLowPriority2) {
+ scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams());
+ params_ignore_limits->set_ignore_limits(true);
+ CreatePool(1, 1);
+
+ // Issue a request to reach the socket pool limit.
+ EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", LOW,
+ params_ignore_limits));
+ ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(OK, request(1)->WaitForResult());
+ EXPECT_FALSE(request(2)->have_result());
+}
+
+// Test that when a socket pool and group are at their limits, a ConnectJob
+// issued for a request with |ignore_limits| set is not cancelled when a request
+// without |ignore_limits| issued to the same group is cancelled.
+TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsCancelOtherJob) {
+ scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams());
+ params_ignore_limits->set_ignore_limits(true);
+ CreatePool(1, 1);
+
+ // Issue a request to reach the socket pool limit.
+ EXPECT_EQ(OK, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST,
+ params_ignore_limits));
+ ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ // Cancel the pending request without ignore_limits set. The ConnectJob
+ // should not be cancelled.
+ request(1)->handle()->Reset();
+ ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a"));
+
+ EXPECT_EQ(OK, request(2)->WaitForResult());
+ EXPECT_FALSE(request(1)->have_result());
+}
+
+// More involved test of ignore limits. Issues a bunch of requests and later
+// checks the order in which they receive sockets.
+TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsOrder) {
+ scoped_refptr<TestSocketParams> params_ignore_limits(new TestSocketParams());
+ params_ignore_limits->set_ignore_limits(true);
+ CreatePool(1, 1);
+
+ connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob);
+
+ // Requests 0 and 1 do not have ignore_limits set, so they finish last. Since
+ // the maximum number of sockets per pool is 1, the second requests does not
+ // trigger a ConnectJob.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", HIGHEST, params_));
+
+ // Requests 2 and 3 have ignore_limits set, but have a low priority, so they
+ // finish just before the first two.
+ EXPECT_EQ(ERR_IO_PENDING,
+ StartRequestWithParams("a", LOW, params_ignore_limits));
+ EXPECT_EQ(ERR_IO_PENDING,
+ StartRequestWithParams("a", LOW, params_ignore_limits));
+
+ // Request 4 finishes first, since it is high priority and ignores limits.
+ EXPECT_EQ(ERR_IO_PENDING,
+ StartRequestWithParams("a", HIGHEST, params_ignore_limits));
+
+ // Request 5 and 6 are cancelled right after starting. This should result in
+ // creating two ConnectJobs. Since only one request (Request 1) did not
+ // result in creating a ConnectJob, only one of the ConnectJobs should be
+ // cancelled when the requests are.
+ EXPECT_EQ(ERR_IO_PENDING,
+ StartRequestWithParams("a", HIGHEST, params_ignore_limits));
+ EXPECT_EQ(ERR_IO_PENDING,
+ StartRequestWithParams("a", HIGHEST, params_ignore_limits));
+ EXPECT_EQ(6, pool_->NumConnectJobsInGroup("a"));
+ request(5)->handle()->Reset();
+ EXPECT_EQ(6, pool_->NumConnectJobsInGroup("a"));
+ request(6)->handle()->Reset();
+ ASSERT_EQ(5, pool_->NumConnectJobsInGroup("a"));
+
+ // Wait for the last request to get a socket.
+ EXPECT_EQ(OK, request(1)->WaitForResult());
+
+ // Check order in which requests received sockets.
+ // These are 1-based indices, while request(x) uses 0-based indices.
+ EXPECT_EQ(1, GetOrderOfRequest(5));
+ EXPECT_EQ(2, GetOrderOfRequest(3));
+ EXPECT_EQ(3, GetOrderOfRequest(4));
+ EXPECT_EQ(4, GetOrderOfRequest(1));
+ EXPECT_EQ(5, GetOrderOfRequest(2));
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_pool_histograms.cc b/chromium/net/socket/client_socket_pool_histograms.cc
new file mode 100644
index 00000000000..9af8649c48b
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_histograms.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_pool_histograms.h"
+
+#include <string>
+
+#include "base/metrics/field_trial.h"
+#include "base/metrics/histogram.h"
+#include "net/base/net_errors.h"
+#include "net/socket/client_socket_handle.h"
+
+namespace net {
+
+using base::Histogram;
+using base::HistogramBase;
+using base::LinearHistogram;
+using base::CustomHistogram;
+
+ClientSocketPoolHistograms::ClientSocketPoolHistograms(
+ const std::string& pool_name)
+ : is_http_proxy_connection_(false),
+ is_socks_connection_(false) {
+ // UMA_HISTOGRAM_ENUMERATION
+ socket_type_ = LinearHistogram::FactoryGet("Net.SocketType_" + pool_name, 1,
+ ClientSocketHandle::NUM_TYPES, ClientSocketHandle::NUM_TYPES + 1,
+ HistogramBase::kUmaTargetedHistogramFlag);
+ // UMA_HISTOGRAM_CUSTOM_TIMES
+ request_time_ = Histogram::FactoryTimeGet(
+ "Net.SocketRequestTime_" + pool_name,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100, HistogramBase::kUmaTargetedHistogramFlag);
+ // UMA_HISTOGRAM_CUSTOM_TIMES
+ unused_idle_time_ = Histogram::FactoryTimeGet(
+ "Net.SocketIdleTimeBeforeNextUse_UnusedSocket_" + pool_name,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(6),
+ 100, HistogramBase::kUmaTargetedHistogramFlag);
+ // UMA_HISTOGRAM_CUSTOM_TIMES
+ reused_idle_time_ = Histogram::FactoryTimeGet(
+ "Net.SocketIdleTimeBeforeNextUse_ReusedSocket_" + pool_name,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(6),
+ 100, HistogramBase::kUmaTargetedHistogramFlag);
+ // UMA_HISTOGRAM_CUSTOM_ENUMERATION
+ error_code_ = CustomHistogram::FactoryGet(
+ "Net.SocketInitErrorCodes_" + pool_name,
+ GetAllErrorCodesForUma(),
+ HistogramBase::kUmaTargetedHistogramFlag);
+
+ if (pool_name == "HTTPProxy")
+ is_http_proxy_connection_ = true;
+ else if (pool_name == "SOCK")
+ is_socks_connection_ = true;
+}
+
+ClientSocketPoolHistograms::~ClientSocketPoolHistograms() {
+}
+
+void ClientSocketPoolHistograms::AddSocketType(int type) const {
+ socket_type_->Add(type);
+}
+
+void ClientSocketPoolHistograms::AddRequestTime(base::TimeDelta time) const {
+ request_time_->AddTime(time);
+}
+
+void ClientSocketPoolHistograms::AddUnusedIdleTime(base::TimeDelta time) const {
+ unused_idle_time_->AddTime(time);
+}
+
+void ClientSocketPoolHistograms::AddReusedIdleTime(base::TimeDelta time) const {
+ reused_idle_time_->AddTime(time);
+}
+
+void ClientSocketPoolHistograms::AddErrorCode(int error_code) const {
+ // Error codes are positive (since histograms expect positive sample values).
+ error_code_->Add(-error_code);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_pool_histograms.h b/chromium/net/socket/client_socket_pool_histograms.h
new file mode 100644
index 00000000000..26a406362bd
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_histograms.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_
+#define NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_
+
+#include <string>
+
+#include "base/memory/ref_counted.h"
+#include "base/time/time.h"
+#include "net/base/net_export.h"
+
+namespace base {
+class HistogramBase;
+}
+
+namespace net {
+
+class NET_EXPORT_PRIVATE ClientSocketPoolHistograms {
+ public:
+ ClientSocketPoolHistograms(const std::string& pool_name);
+ ~ClientSocketPoolHistograms();
+
+ void AddSocketType(int socket_reuse_type) const;
+ void AddRequestTime(base::TimeDelta time) const;
+ void AddUnusedIdleTime(base::TimeDelta time) const;
+ void AddReusedIdleTime(base::TimeDelta time) const;
+ void AddErrorCode(int error_code) const;
+
+ private:
+ base::HistogramBase* socket_type_;
+ base::HistogramBase* request_time_;
+ base::HistogramBase* unused_idle_time_;
+ base::HistogramBase* reused_idle_time_;
+ base::HistogramBase* error_code_;
+
+ bool is_http_proxy_connection_;
+ bool is_socks_connection_;
+
+ DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolHistograms);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_
diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc
new file mode 100644
index 00000000000..71496d28646
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_manager.cc
@@ -0,0 +1,467 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_pool_manager.h"
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+#include "base/strings/stringprintf.h"
+#include "net/base/load_flags.h"
+#include "net/http/http_proxy_client_socket_pool.h"
+#include "net/http/http_request_info.h"
+#include "net/http/http_stream_factory.h"
+#include "net/proxy/proxy_info.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/socks_client_socket_pool.h"
+#include "net/socket/ssl_client_socket_pool.h"
+#include "net/socket/transport_client_socket_pool.h"
+
+namespace net {
+
+namespace {
+
+// Limit of sockets of each socket pool.
+int g_max_sockets_per_pool[] = {
+ 256, // NORMAL_SOCKET_POOL
+ 256 // WEBSOCKET_SOCKET_POOL
+};
+
+COMPILE_ASSERT(arraysize(g_max_sockets_per_pool) ==
+ HttpNetworkSession::NUM_SOCKET_POOL_TYPES,
+ max_sockets_per_pool_length_mismatch);
+
+// Default to allow up to 6 connections per host. Experiment and tuning may
+// try other values (greater than 0). Too large may cause many problems, such
+// as home routers blocking the connections!?!? See http://crbug.com/12066.
+//
+// WebSocket connections are long-lived, and should be treated differently
+// than normal other connections. 6 connections per group sounded too small
+// for such use, thus we use a larger limit which was determined somewhat
+// arbitrarily.
+// TODO(yutak): Look at the usage and determine the right value after
+// WebSocket protocol stack starts to work.
+int g_max_sockets_per_group[] = {
+ 6, // NORMAL_SOCKET_POOL
+ 30 // WEBSOCKET_SOCKET_POOL
+};
+
+COMPILE_ASSERT(arraysize(g_max_sockets_per_group) ==
+ HttpNetworkSession::NUM_SOCKET_POOL_TYPES,
+ max_sockets_per_group_length_mismatch);
+
+// The max number of sockets to allow per proxy server. This applies both to
+// http and SOCKS proxies. See http://crbug.com/12066 and
+// http://crbug.com/44501 for details about proxy server connection limits.
+int g_max_sockets_per_proxy_server[] = {
+ kDefaultMaxSocketsPerProxyServer, // NORMAL_SOCKET_POOL
+ kDefaultMaxSocketsPerProxyServer // WEBSOCKET_SOCKET_POOL
+};
+
+COMPILE_ASSERT(arraysize(g_max_sockets_per_proxy_server) ==
+ HttpNetworkSession::NUM_SOCKET_POOL_TYPES,
+ max_sockets_per_proxy_server_length_mismatch);
+
+// The meat of the implementation for the InitSocketHandleForHttpRequest,
+// InitSocketHandleForRawConnect and PreconnectSocketsForHttpRequest methods.
+int InitSocketPoolHelper(const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ bool force_tunnel,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ int num_preconnect_streams,
+ ClientSocketHandle* socket_handle,
+ HttpNetworkSession::SocketPoolType socket_pool_type,
+ const OnHostResolutionCallback& resolution_callback,
+ const CompletionCallback& callback) {
+ scoped_refptr<TransportSocketParams> tcp_params;
+ scoped_refptr<HttpProxySocketParams> http_proxy_params;
+ scoped_refptr<SOCKSSocketParams> socks_params;
+ scoped_ptr<HostPortPair> proxy_host_port;
+
+ bool using_ssl = request_url.SchemeIs("https") ||
+ request_url.SchemeIs("wss") || force_spdy_over_ssl;
+
+ HostPortPair origin_host_port =
+ HostPortPair(request_url.HostNoBrackets(),
+ request_url.EffectiveIntPort());
+
+ if (!using_ssl && session->params().testing_fixed_http_port != 0) {
+ origin_host_port.set_port(session->params().testing_fixed_http_port);
+ } else if (using_ssl && session->params().testing_fixed_https_port != 0) {
+ origin_host_port.set_port(session->params().testing_fixed_https_port);
+ }
+
+ bool disable_resolver_cache =
+ request_load_flags & LOAD_BYPASS_CACHE ||
+ request_load_flags & LOAD_VALIDATE_CACHE ||
+ request_load_flags & LOAD_DISABLE_CACHE;
+
+ int load_flags = request_load_flags;
+ if (session->params().ignore_certificate_errors)
+ load_flags |= LOAD_IGNORE_ALL_CERT_ERRORS;
+
+ // Build the string used to uniquely identify connections of this type.
+ // Determine the host and port to connect to.
+ std::string connection_group = origin_host_port.ToString();
+ DCHECK(!connection_group.empty());
+ if (request_url.SchemeIs("ftp")) {
+ // Combining FTP with forced SPDY over SSL would be a "path to madness".
+ // Make sure we never do that.
+ DCHECK(!using_ssl);
+ connection_group = "ftp/" + connection_group;
+ }
+ if (using_ssl) {
+ // All connections in a group should use the same SSLConfig settings.
+ // Encode version_max in the connection group's name, unless it's the
+ // default version_max. (We want the common case to use the shortest
+ // encoding). A version_max of TLS 1.1 is encoded as "ssl(max:3.2)/"
+ // rather than "tlsv1.1/" because the actual protocol version, which
+ // is selected by the server, may not be TLS 1.1. Do not encode
+ // version_min in the connection group's name because version_min
+ // should be the same for all connections, whereas version_max may
+ // change for version fallbacks.
+ std::string prefix = "ssl/";
+ if (ssl_config_for_origin.version_max !=
+ SSLConfigService::default_version_max()) {
+ switch (ssl_config_for_origin.version_max) {
+ case SSL_PROTOCOL_VERSION_TLS1_2:
+ prefix = "ssl(max:3.3)/";
+ break;
+ case SSL_PROTOCOL_VERSION_TLS1_1:
+ prefix = "ssl(max:3.2)/";
+ break;
+ case SSL_PROTOCOL_VERSION_TLS1:
+ prefix = "ssl(max:3.1)/";
+ break;
+ case SSL_PROTOCOL_VERSION_SSL3:
+ prefix = "sslv3/";
+ break;
+ default:
+ CHECK(false);
+ break;
+ }
+ }
+ connection_group = prefix + connection_group;
+ }
+
+ bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0;
+ if (proxy_info.is_direct()) {
+ tcp_params = new TransportSocketParams(origin_host_port,
+ request_priority,
+ disable_resolver_cache,
+ ignore_limits,
+ resolution_callback);
+ } else {
+ ProxyServer proxy_server = proxy_info.proxy_server();
+ proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair()));
+ scoped_refptr<TransportSocketParams> proxy_tcp_params(
+ new TransportSocketParams(*proxy_host_port,
+ request_priority,
+ disable_resolver_cache,
+ ignore_limits,
+ resolution_callback));
+
+ if (proxy_info.is_http() || proxy_info.is_https()) {
+ std::string user_agent;
+ request_extra_headers.GetHeader(HttpRequestHeaders::kUserAgent,
+ &user_agent);
+ scoped_refptr<SSLSocketParams> ssl_params;
+ if (proxy_info.is_https()) {
+ // Set ssl_params, and unset proxy_tcp_params
+ ssl_params = new SSLSocketParams(proxy_tcp_params,
+ NULL,
+ NULL,
+ ProxyServer::SCHEME_DIRECT,
+ *proxy_host_port.get(),
+ ssl_config_for_proxy,
+ kPrivacyModeDisabled,
+ load_flags,
+ force_spdy_over_ssl,
+ want_spdy_over_npn);
+ proxy_tcp_params = NULL;
+ }
+
+ http_proxy_params =
+ new HttpProxySocketParams(proxy_tcp_params,
+ ssl_params,
+ request_url,
+ user_agent,
+ origin_host_port,
+ session->http_auth_cache(),
+ session->http_auth_handler_factory(),
+ session->spdy_session_pool(),
+ force_tunnel || using_ssl);
+ } else {
+ DCHECK(proxy_info.is_socks());
+ char socks_version;
+ if (proxy_server.scheme() == ProxyServer::SCHEME_SOCKS5)
+ socks_version = '5';
+ else
+ socks_version = '4';
+ connection_group = base::StringPrintf(
+ "socks%c/%s", socks_version, connection_group.c_str());
+
+ socks_params = new SOCKSSocketParams(proxy_tcp_params,
+ socks_version == '5',
+ origin_host_port,
+ request_priority);
+ }
+ }
+
+ // Change group name if privacy mode is enabled.
+ if (privacy_mode == kPrivacyModeEnabled)
+ connection_group = "pm/" + connection_group;
+
+ // Deal with SSL - which layers on top of any given proxy.
+ if (using_ssl) {
+ scoped_refptr<SSLSocketParams> ssl_params =
+ new SSLSocketParams(tcp_params,
+ socks_params,
+ http_proxy_params,
+ proxy_info.proxy_server().scheme(),
+ origin_host_port,
+ ssl_config_for_origin,
+ privacy_mode,
+ load_flags,
+ force_spdy_over_ssl,
+ want_spdy_over_npn);
+ SSLClientSocketPool* ssl_pool = NULL;
+ if (proxy_info.is_direct()) {
+ ssl_pool = session->GetSSLSocketPool(socket_pool_type);
+ } else {
+ ssl_pool = session->GetSocketPoolForSSLWithProxy(socket_pool_type,
+ *proxy_host_port);
+ }
+
+ if (num_preconnect_streams) {
+ RequestSocketsForPool(ssl_pool, connection_group, ssl_params,
+ num_preconnect_streams, net_log);
+ return OK;
+ }
+
+ return socket_handle->Init(connection_group, ssl_params,
+ request_priority, callback, ssl_pool,
+ net_log);
+ }
+
+ // Finally, get the connection started.
+
+ if (proxy_info.is_http() || proxy_info.is_https()) {
+ HttpProxyClientSocketPool* pool =
+ session->GetSocketPoolForHTTPProxy(socket_pool_type, *proxy_host_port);
+ if (num_preconnect_streams) {
+ RequestSocketsForPool(pool, connection_group, http_proxy_params,
+ num_preconnect_streams, net_log);
+ return OK;
+ }
+
+ return socket_handle->Init(connection_group, http_proxy_params,
+ request_priority, callback,
+ pool, net_log);
+ }
+
+ if (proxy_info.is_socks()) {
+ SOCKSClientSocketPool* pool =
+ session->GetSocketPoolForSOCKSProxy(socket_pool_type, *proxy_host_port);
+ if (num_preconnect_streams) {
+ RequestSocketsForPool(pool, connection_group, socks_params,
+ num_preconnect_streams, net_log);
+ return OK;
+ }
+
+ return socket_handle->Init(connection_group, socks_params,
+ request_priority, callback, pool,
+ net_log);
+ }
+
+ DCHECK(proxy_info.is_direct());
+
+ TransportClientSocketPool* pool =
+ session->GetTransportSocketPool(socket_pool_type);
+ if (num_preconnect_streams) {
+ RequestSocketsForPool(pool, connection_group, tcp_params,
+ num_preconnect_streams, net_log);
+ return OK;
+ }
+
+ return socket_handle->Init(connection_group, tcp_params,
+ request_priority, callback,
+ pool, net_log);
+}
+
+} // namespace
+
+ClientSocketPoolManager::ClientSocketPoolManager() {}
+ClientSocketPoolManager::~ClientSocketPoolManager() {}
+
+// static
+int ClientSocketPoolManager::max_sockets_per_pool(
+ HttpNetworkSession::SocketPoolType pool_type) {
+ DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES);
+ return g_max_sockets_per_pool[pool_type];
+}
+
+// static
+void ClientSocketPoolManager::set_max_sockets_per_pool(
+ HttpNetworkSession::SocketPoolType pool_type,
+ int socket_count) {
+ DCHECK_LT(0, socket_count);
+ DCHECK_GT(1000, socket_count); // Sanity check.
+ DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES);
+ g_max_sockets_per_pool[pool_type] = socket_count;
+ DCHECK_GE(g_max_sockets_per_pool[pool_type],
+ g_max_sockets_per_group[pool_type]);
+}
+
+// static
+int ClientSocketPoolManager::max_sockets_per_group(
+ HttpNetworkSession::SocketPoolType pool_type) {
+ DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES);
+ return g_max_sockets_per_group[pool_type];
+}
+
+// static
+void ClientSocketPoolManager::set_max_sockets_per_group(
+ HttpNetworkSession::SocketPoolType pool_type,
+ int socket_count) {
+ DCHECK_LT(0, socket_count);
+ // The following is a sanity check... but we should NEVER be near this value.
+ DCHECK_GT(100, socket_count);
+ DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES);
+ g_max_sockets_per_group[pool_type] = socket_count;
+
+ DCHECK_GE(g_max_sockets_per_pool[pool_type],
+ g_max_sockets_per_group[pool_type]);
+ DCHECK_GE(g_max_sockets_per_proxy_server[pool_type],
+ g_max_sockets_per_group[pool_type]);
+}
+
+// static
+int ClientSocketPoolManager::max_sockets_per_proxy_server(
+ HttpNetworkSession::SocketPoolType pool_type) {
+ DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES);
+ return g_max_sockets_per_proxy_server[pool_type];
+}
+
+// static
+void ClientSocketPoolManager::set_max_sockets_per_proxy_server(
+ HttpNetworkSession::SocketPoolType pool_type,
+ int socket_count) {
+ DCHECK_LT(0, socket_count);
+ DCHECK_GT(100, socket_count); // Sanity check.
+ DCHECK_LT(pool_type, HttpNetworkSession::NUM_SOCKET_POOL_TYPES);
+ // Assert this case early on. The max number of sockets per group cannot
+ // exceed the max number of sockets per proxy server.
+ DCHECK_LE(g_max_sockets_per_group[pool_type], socket_count);
+ g_max_sockets_per_proxy_server[pool_type] = socket_count;
+}
+
+int InitSocketHandleForHttpRequest(
+ const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ ClientSocketHandle* socket_handle,
+ const OnHostResolutionCallback& resolution_callback,
+ const CompletionCallback& callback) {
+ DCHECK(socket_handle);
+ return InitSocketPoolHelper(
+ request_url, request_extra_headers, request_load_flags, request_priority,
+ session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn,
+ ssl_config_for_origin, ssl_config_for_proxy, false, privacy_mode, net_log,
+ 0, socket_handle, HttpNetworkSession::NORMAL_SOCKET_POOL,
+ resolution_callback, callback);
+}
+
+int InitSocketHandleForWebSocketRequest(
+ const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ ClientSocketHandle* socket_handle,
+ const OnHostResolutionCallback& resolution_callback,
+ const CompletionCallback& callback) {
+ DCHECK(socket_handle);
+ return InitSocketPoolHelper(
+ request_url, request_extra_headers, request_load_flags, request_priority,
+ session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn,
+ ssl_config_for_origin, ssl_config_for_proxy, true, privacy_mode, net_log,
+ 0, socket_handle, HttpNetworkSession::WEBSOCKET_SOCKET_POOL,
+ resolution_callback, callback);
+}
+
+int InitSocketHandleForRawConnect(
+ const HostPortPair& host_port_pair,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ ClientSocketHandle* socket_handle,
+ const CompletionCallback& callback) {
+ DCHECK(socket_handle);
+ // Synthesize an HttpRequestInfo.
+ GURL request_url = GURL("http://" + host_port_pair.ToString());
+ HttpRequestHeaders request_extra_headers;
+ int request_load_flags = 0;
+ RequestPriority request_priority = MEDIUM;
+
+ return InitSocketPoolHelper(
+ request_url, request_extra_headers, request_load_flags, request_priority,
+ session, proxy_info, false, false, ssl_config_for_origin,
+ ssl_config_for_proxy, true, privacy_mode, net_log, 0, socket_handle,
+ HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(),
+ callback);
+}
+
+int PreconnectSocketsForHttpRequest(
+ const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ int num_preconnect_streams) {
+ return InitSocketPoolHelper(
+ request_url, request_extra_headers, request_load_flags, request_priority,
+ session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn,
+ ssl_config_for_origin, ssl_config_for_proxy, false, privacy_mode, net_log,
+ num_preconnect_streams, NULL, HttpNetworkSession::NORMAL_SOCKET_POOL,
+ OnHostResolutionCallback(), CompletionCallback());
+}
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_pool_manager.h b/chromium/net/socket/client_socket_pool_manager.h
new file mode 100644
index 00000000000..1b78324f233
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_manager.h
@@ -0,0 +1,169 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// ClientSocketPoolManager manages access to all ClientSocketPools. It's a
+// simple container for all of them. Most importantly, it handles the lifetime
+// and destruction order properly.
+
+#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_
+#define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_
+
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+#include "net/base/request_priority.h"
+#include "net/http/http_network_session.h"
+
+class GURL;
+
+namespace base {
+class Value;
+}
+
+namespace net {
+
+typedef base::Callback<int(const AddressList&, const BoundNetLog& net_log)>
+OnHostResolutionCallback;
+
+class BoundNetLog;
+class ClientSocketHandle;
+class HostPortPair;
+class HttpNetworkSession;
+class HttpProxyClientSocketPool;
+class HttpRequestHeaders;
+class ProxyInfo;
+class TransportClientSocketPool;
+class SOCKSClientSocketPool;
+class SSLClientSocketPool;
+
+struct SSLConfig;
+
+// This should rather be a simple constant but Windows shared libs doesn't
+// really offer much flexiblity in exporting contants.
+enum DefaultMaxValues { kDefaultMaxSocketsPerProxyServer = 32 };
+
+class NET_EXPORT_PRIVATE ClientSocketPoolManager {
+ public:
+ ClientSocketPoolManager();
+ virtual ~ClientSocketPoolManager();
+
+ // The setter methods below affect only newly created socket pools after the
+ // methods are called. Normally they should be called at program startup
+ // before any ClientSocketPoolManagerImpl is created.
+ static int max_sockets_per_pool(HttpNetworkSession::SocketPoolType pool_type);
+ static void set_max_sockets_per_pool(
+ HttpNetworkSession::SocketPoolType pool_type,
+ int socket_count);
+
+ static int max_sockets_per_group(
+ HttpNetworkSession::SocketPoolType pool_type);
+ static void set_max_sockets_per_group(
+ HttpNetworkSession::SocketPoolType pool_type,
+ int socket_count);
+
+ static int max_sockets_per_proxy_server(
+ HttpNetworkSession::SocketPoolType pool_type);
+ static void set_max_sockets_per_proxy_server(
+ HttpNetworkSession::SocketPoolType pool_type,
+ int socket_count);
+
+ virtual void FlushSocketPoolsWithError(int error) = 0;
+ virtual void CloseIdleSockets() = 0;
+ virtual TransportClientSocketPool* GetTransportSocketPool() = 0;
+ virtual SSLClientSocketPool* GetSSLSocketPool() = 0;
+ virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) = 0;
+ virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) = 0;
+ virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) = 0;
+ // Creates a Value summary of the state of the socket pools. The caller is
+ // responsible for deleting the returned value.
+ virtual base::Value* SocketPoolInfoToValue() const = 0;
+};
+
+// A helper method that uses the passed in proxy information to initialize a
+// ClientSocketHandle with the relevant socket pool. Use this method for
+// HTTP/HTTPS requests. |ssl_config_for_origin| is only used if the request
+// uses SSL and |ssl_config_for_proxy| is used if the proxy server is HTTPS.
+// |resolution_callback| will be invoked after the the hostname is
+// resolved. If |resolution_callback| does not return OK, then the
+// connection will be aborted with that value.
+int InitSocketHandleForHttpRequest(
+ const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ ClientSocketHandle* socket_handle,
+ const OnHostResolutionCallback& resolution_callback,
+ const CompletionCallback& callback);
+
+// A helper method that uses the passed in proxy information to initialize a
+// ClientSocketHandle with the relevant socket pool. Use this method for
+// HTTP/HTTPS requests for WebSocket handshake.
+// |ssl_config_for_origin| is only used if the request
+// uses SSL and |ssl_config_for_proxy| is used if the proxy server is HTTPS.
+// |resolution_callback| will be invoked after the the hostname is
+// resolved. If |resolution_callback| does not return OK, then the
+// connection will be aborted with that value.
+// This function uses WEBSOCKET_SOCKET_POOL socket pools.
+int InitSocketHandleForWebSocketRequest(
+ const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ ClientSocketHandle* socket_handle,
+ const OnHostResolutionCallback& resolution_callback,
+ const CompletionCallback& callback);
+
+// A helper method that uses the passed in proxy information to initialize a
+// ClientSocketHandle with the relevant socket pool. Use this method for
+// a raw socket connection to a host-port pair (that needs to tunnel through
+// the proxies).
+NET_EXPORT int InitSocketHandleForRawConnect(
+ const HostPortPair& host_port_pair,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ ClientSocketHandle* socket_handle,
+ const CompletionCallback& callback);
+
+// Similar to InitSocketHandleForHttpRequest except that it initiates the
+// desired number of preconnect streams from the relevant socket pool.
+int PreconnectSocketsForHttpRequest(
+ const GURL& request_url,
+ const HttpRequestHeaders& request_extra_headers,
+ int request_load_flags,
+ RequestPriority request_priority,
+ HttpNetworkSession* session,
+ const ProxyInfo& proxy_info,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn,
+ const SSLConfig& ssl_config_for_origin,
+ const SSLConfig& ssl_config_for_proxy,
+ PrivacyMode privacy_mode,
+ const BoundNetLog& net_log,
+ int num_preconnect_streams);
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_H_
diff --git a/chromium/net/socket/client_socket_pool_manager_impl.cc b/chromium/net/socket/client_socket_pool_manager_impl.cc
new file mode 100644
index 00000000000..b557874d011
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_manager_impl.cc
@@ -0,0 +1,392 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/client_socket_pool_manager_impl.h"
+
+#include "base/logging.h"
+#include "base/values.h"
+#include "net/http/http_network_session.h"
+#include "net/http/http_proxy_client_socket_pool.h"
+#include "net/socket/socks_client_socket_pool.h"
+#include "net/socket/ssl_client_socket_pool.h"
+#include "net/socket/transport_client_socket_pool.h"
+#include "net/ssl/ssl_config_service.h"
+
+namespace net {
+
+namespace {
+
+// Appends information about all |socket_pools| to the end of |list|.
+template <class MapType>
+void AddSocketPoolsToList(base::ListValue* list,
+ const MapType& socket_pools,
+ const std::string& type,
+ bool include_nested_pools) {
+ for (typename MapType::const_iterator it = socket_pools.begin();
+ it != socket_pools.end(); it++) {
+ list->Append(it->second->GetInfoAsValue(it->first.ToString(),
+ type,
+ include_nested_pools));
+ }
+}
+
+} // namespace
+
+ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl(
+ NetLog* net_log,
+ ClientSocketFactory* socket_factory,
+ HostResolver* host_resolver,
+ CertVerifier* cert_verifier,
+ ServerBoundCertService* server_bound_cert_service,
+ TransportSecurityState* transport_security_state,
+ const std::string& ssl_session_cache_shard,
+ ProxyService* proxy_service,
+ SSLConfigService* ssl_config_service,
+ HttpNetworkSession::SocketPoolType pool_type)
+ : net_log_(net_log),
+ socket_factory_(socket_factory),
+ host_resolver_(host_resolver),
+ cert_verifier_(cert_verifier),
+ server_bound_cert_service_(server_bound_cert_service),
+ transport_security_state_(transport_security_state),
+ ssl_session_cache_shard_(ssl_session_cache_shard),
+ proxy_service_(proxy_service),
+ ssl_config_service_(ssl_config_service),
+ pool_type_(pool_type),
+ transport_pool_histograms_("TCP"),
+ transport_socket_pool_(new TransportClientSocketPool(
+ max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type),
+ &transport_pool_histograms_,
+ host_resolver,
+ socket_factory_,
+ net_log)),
+ ssl_pool_histograms_("SSL2"),
+ ssl_socket_pool_(new SSLClientSocketPool(
+ max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type),
+ &ssl_pool_histograms_,
+ host_resolver,
+ cert_verifier,
+ server_bound_cert_service,
+ transport_security_state,
+ ssl_session_cache_shard,
+ socket_factory,
+ transport_socket_pool_.get(),
+ NULL /* no socks proxy */,
+ NULL /* no http proxy */,
+ ssl_config_service,
+ net_log)),
+ transport_for_socks_pool_histograms_("TCPforSOCKS"),
+ socks_pool_histograms_("SOCK"),
+ transport_for_http_proxy_pool_histograms_("TCPforHTTPProxy"),
+ transport_for_https_proxy_pool_histograms_("TCPforHTTPSProxy"),
+ ssl_for_https_proxy_pool_histograms_("SSLforHTTPSProxy"),
+ http_proxy_pool_histograms_("HTTPProxy"),
+ ssl_socket_pool_for_proxies_histograms_("SSLForProxies") {
+ CertDatabase::GetInstance()->AddObserver(this);
+}
+
+ClientSocketPoolManagerImpl::~ClientSocketPoolManagerImpl() {
+ CertDatabase::GetInstance()->RemoveObserver(this);
+}
+
+void ClientSocketPoolManagerImpl::FlushSocketPoolsWithError(int error) {
+ // Flush the highest level pools first, since higher level pools may release
+ // stuff to the lower level pools.
+
+ for (SSLSocketPoolMap::const_iterator it =
+ ssl_socket_pools_for_proxies_.begin();
+ it != ssl_socket_pools_for_proxies_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ for (HTTPProxySocketPoolMap::const_iterator it =
+ http_proxy_socket_pools_.begin();
+ it != http_proxy_socket_pools_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ for (SSLSocketPoolMap::const_iterator it =
+ ssl_socket_pools_for_https_proxies_.begin();
+ it != ssl_socket_pools_for_https_proxies_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ for (TransportSocketPoolMap::const_iterator it =
+ transport_socket_pools_for_https_proxies_.begin();
+ it != transport_socket_pools_for_https_proxies_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ for (TransportSocketPoolMap::const_iterator it =
+ transport_socket_pools_for_http_proxies_.begin();
+ it != transport_socket_pools_for_http_proxies_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ for (SOCKSSocketPoolMap::const_iterator it =
+ socks_socket_pools_.begin();
+ it != socks_socket_pools_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ for (TransportSocketPoolMap::const_iterator it =
+ transport_socket_pools_for_socks_proxies_.begin();
+ it != transport_socket_pools_for_socks_proxies_.end();
+ ++it)
+ it->second->FlushWithError(error);
+
+ ssl_socket_pool_->FlushWithError(error);
+ transport_socket_pool_->FlushWithError(error);
+}
+
+void ClientSocketPoolManagerImpl::CloseIdleSockets() {
+ // Close sockets in the highest level pools first, since higher level pools'
+ // sockets may release stuff to the lower level pools.
+ for (SSLSocketPoolMap::const_iterator it =
+ ssl_socket_pools_for_proxies_.begin();
+ it != ssl_socket_pools_for_proxies_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ for (HTTPProxySocketPoolMap::const_iterator it =
+ http_proxy_socket_pools_.begin();
+ it != http_proxy_socket_pools_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ for (SSLSocketPoolMap::const_iterator it =
+ ssl_socket_pools_for_https_proxies_.begin();
+ it != ssl_socket_pools_for_https_proxies_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ for (TransportSocketPoolMap::const_iterator it =
+ transport_socket_pools_for_https_proxies_.begin();
+ it != transport_socket_pools_for_https_proxies_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ for (TransportSocketPoolMap::const_iterator it =
+ transport_socket_pools_for_http_proxies_.begin();
+ it != transport_socket_pools_for_http_proxies_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ for (SOCKSSocketPoolMap::const_iterator it =
+ socks_socket_pools_.begin();
+ it != socks_socket_pools_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ for (TransportSocketPoolMap::const_iterator it =
+ transport_socket_pools_for_socks_proxies_.begin();
+ it != transport_socket_pools_for_socks_proxies_.end();
+ ++it)
+ it->second->CloseIdleSockets();
+
+ ssl_socket_pool_->CloseIdleSockets();
+ transport_socket_pool_->CloseIdleSockets();
+}
+
+TransportClientSocketPool*
+ClientSocketPoolManagerImpl::GetTransportSocketPool() {
+ return transport_socket_pool_.get();
+}
+
+SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSSLSocketPool() {
+ return ssl_socket_pool_.get();
+}
+
+SOCKSClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) {
+ SOCKSSocketPoolMap::const_iterator it = socks_socket_pools_.find(socks_proxy);
+ if (it != socks_socket_pools_.end()) {
+ DCHECK(ContainsKey(transport_socket_pools_for_socks_proxies_, socks_proxy));
+ return it->second;
+ }
+
+ DCHECK(!ContainsKey(transport_socket_pools_for_socks_proxies_, socks_proxy));
+
+ std::pair<TransportSocketPoolMap::iterator, bool> tcp_ret =
+ transport_socket_pools_for_socks_proxies_.insert(
+ std::make_pair(
+ socks_proxy,
+ new TransportClientSocketPool(
+ max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &transport_for_socks_pool_histograms_,
+ host_resolver_,
+ socket_factory_,
+ net_log_)));
+ DCHECK(tcp_ret.second);
+
+ std::pair<SOCKSSocketPoolMap::iterator, bool> ret =
+ socks_socket_pools_.insert(
+ std::make_pair(socks_proxy, new SOCKSClientSocketPool(
+ max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &socks_pool_histograms_,
+ host_resolver_,
+ tcp_ret.first->second,
+ net_log_)));
+
+ return ret.first->second;
+}
+
+HttpProxyClientSocketPool*
+ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) {
+ HTTPProxySocketPoolMap::const_iterator it =
+ http_proxy_socket_pools_.find(http_proxy);
+ if (it != http_proxy_socket_pools_.end()) {
+ DCHECK(ContainsKey(transport_socket_pools_for_http_proxies_, http_proxy));
+ DCHECK(ContainsKey(transport_socket_pools_for_https_proxies_, http_proxy));
+ DCHECK(ContainsKey(ssl_socket_pools_for_https_proxies_, http_proxy));
+ return it->second;
+ }
+
+ DCHECK(!ContainsKey(transport_socket_pools_for_http_proxies_, http_proxy));
+ DCHECK(!ContainsKey(transport_socket_pools_for_https_proxies_, http_proxy));
+ DCHECK(!ContainsKey(ssl_socket_pools_for_https_proxies_, http_proxy));
+
+ std::pair<TransportSocketPoolMap::iterator, bool> tcp_http_ret =
+ transport_socket_pools_for_http_proxies_.insert(
+ std::make_pair(
+ http_proxy,
+ new TransportClientSocketPool(
+ max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &transport_for_http_proxy_pool_histograms_,
+ host_resolver_,
+ socket_factory_,
+ net_log_)));
+ DCHECK(tcp_http_ret.second);
+
+ std::pair<TransportSocketPoolMap::iterator, bool> tcp_https_ret =
+ transport_socket_pools_for_https_proxies_.insert(
+ std::make_pair(
+ http_proxy,
+ new TransportClientSocketPool(
+ max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &transport_for_https_proxy_pool_histograms_,
+ host_resolver_,
+ socket_factory_,
+ net_log_)));
+ DCHECK(tcp_https_ret.second);
+
+ std::pair<SSLSocketPoolMap::iterator, bool> ssl_https_ret =
+ ssl_socket_pools_for_https_proxies_.insert(std::make_pair(
+ http_proxy,
+ new SSLClientSocketPool(max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &ssl_for_https_proxy_pool_histograms_,
+ host_resolver_,
+ cert_verifier_,
+ server_bound_cert_service_,
+ transport_security_state_,
+ ssl_session_cache_shard_,
+ socket_factory_,
+ tcp_https_ret.first->second /* https proxy */,
+ NULL /* no socks proxy */,
+ NULL /* no http proxy */,
+ ssl_config_service_.get(),
+ net_log_)));
+ DCHECK(tcp_https_ret.second);
+
+ std::pair<HTTPProxySocketPoolMap::iterator, bool> ret =
+ http_proxy_socket_pools_.insert(
+ std::make_pair(
+ http_proxy,
+ new HttpProxyClientSocketPool(
+ max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &http_proxy_pool_histograms_,
+ host_resolver_,
+ tcp_http_ret.first->second,
+ ssl_https_ret.first->second,
+ net_log_)));
+
+ return ret.first->second;
+}
+
+SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) {
+ SSLSocketPoolMap::const_iterator it =
+ ssl_socket_pools_for_proxies_.find(proxy_server);
+ if (it != ssl_socket_pools_for_proxies_.end())
+ return it->second;
+
+ SSLClientSocketPool* new_pool = new SSLClientSocketPool(
+ max_sockets_per_proxy_server(pool_type_),
+ max_sockets_per_group(pool_type_),
+ &ssl_pool_histograms_,
+ host_resolver_,
+ cert_verifier_,
+ server_bound_cert_service_,
+ transport_security_state_,
+ ssl_session_cache_shard_,
+ socket_factory_,
+ NULL, /* no tcp pool, we always go through a proxy */
+ GetSocketPoolForSOCKSProxy(proxy_server),
+ GetSocketPoolForHTTPProxy(proxy_server),
+ ssl_config_service_.get(),
+ net_log_);
+
+ std::pair<SSLSocketPoolMap::iterator, bool> ret =
+ ssl_socket_pools_for_proxies_.insert(std::make_pair(proxy_server,
+ new_pool));
+
+ return ret.first->second;
+}
+
+base::Value* ClientSocketPoolManagerImpl::SocketPoolInfoToValue() const {
+ base::ListValue* list = new base::ListValue();
+ list->Append(transport_socket_pool_->GetInfoAsValue("transport_socket_pool",
+ "transport_socket_pool",
+ false));
+ // Third parameter is false because |ssl_socket_pool_| uses
+ // |transport_socket_pool_| internally, and do not want to add it a second
+ // time.
+ list->Append(ssl_socket_pool_->GetInfoAsValue("ssl_socket_pool",
+ "ssl_socket_pool",
+ false));
+ AddSocketPoolsToList(list,
+ http_proxy_socket_pools_,
+ "http_proxy_socket_pool",
+ true);
+ AddSocketPoolsToList(list,
+ socks_socket_pools_,
+ "socks_socket_pool",
+ true);
+
+ // Third parameter is false because |ssl_socket_pools_for_proxies_| use
+ // socket pools in |http_proxy_socket_pools_| and |socks_socket_pools_|.
+ AddSocketPoolsToList(list,
+ ssl_socket_pools_for_proxies_,
+ "ssl_socket_pool_for_proxies",
+ false);
+ return list;
+}
+
+void ClientSocketPoolManagerImpl::OnCertAdded(const X509Certificate* cert) {
+ FlushSocketPoolsWithError(ERR_NETWORK_CHANGED);
+}
+
+void ClientSocketPoolManagerImpl::OnCertTrustChanged(
+ const X509Certificate* cert) {
+ // We should flush the socket pools if we removed trust from a
+ // cert, because a previously trusted server may have become
+ // untrusted.
+ //
+ // We should not flush the socket pools if we added trust to a
+ // cert.
+ //
+ // Since the OnCertTrustChanged method doesn't tell us what
+ // kind of trust change it is, we have to flush the socket
+ // pools to be safe.
+ FlushSocketPoolsWithError(ERR_NETWORK_CHANGED);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h
new file mode 100644
index 00000000000..8f6e618d2e1
--- /dev/null
+++ b/chromium/net/socket/client_socket_pool_manager_impl.h
@@ -0,0 +1,150 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_
+#define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_
+
+#include <map>
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/stl_util.h"
+#include "base/template_util.h"
+#include "base/threading/non_thread_safe.h"
+#include "net/cert/cert_database.h"
+#include "net/http/http_network_session.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/client_socket_pool_manager.h"
+
+namespace net {
+
+class CertVerifier;
+class ClientSocketFactory;
+class ClientSocketPoolHistograms;
+class HttpProxyClientSocketPool;
+class HostResolver;
+class NetLog;
+class ServerBoundCertService;
+class ProxyService;
+class SOCKSClientSocketPool;
+class SSLClientSocketPool;
+class SSLConfigService;
+class TransportClientSocketPool;
+class TransportSecurityState;
+
+namespace internal {
+
+// A helper class for auto-deleting Values in the destructor.
+template <typename Key, typename Value>
+class OwnedPoolMap : public std::map<Key, Value> {
+ public:
+ OwnedPoolMap() {
+ COMPILE_ASSERT(base::is_pointer<Value>::value,
+ value_must_be_a_pointer);
+ }
+
+ ~OwnedPoolMap() {
+ STLDeleteValues(this);
+ }
+};
+
+} // namespace internal
+
+class ClientSocketPoolManagerImpl : public base::NonThreadSafe,
+ public ClientSocketPoolManager,
+ public CertDatabase::Observer {
+ public:
+ ClientSocketPoolManagerImpl(NetLog* net_log,
+ ClientSocketFactory* socket_factory,
+ HostResolver* host_resolver,
+ CertVerifier* cert_verifier,
+ ServerBoundCertService* server_bound_cert_service,
+ TransportSecurityState* transport_security_state,
+ const std::string& ssl_session_cache_shard,
+ ProxyService* proxy_service,
+ SSLConfigService* ssl_config_service,
+ HttpNetworkSession::SocketPoolType pool_type);
+ virtual ~ClientSocketPoolManagerImpl();
+
+ virtual void FlushSocketPoolsWithError(int error) OVERRIDE;
+ virtual void CloseIdleSockets() OVERRIDE;
+
+ virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE;
+
+ virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE;
+
+ virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) OVERRIDE;
+
+ virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) OVERRIDE;
+
+ virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) OVERRIDE;
+
+ // Creates a Value summary of the state of the socket pools. The caller is
+ // responsible for deleting the returned value.
+ virtual base::Value* SocketPoolInfoToValue() const OVERRIDE;
+
+ // CertDatabase::Observer methods:
+ virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE;
+ virtual void OnCertTrustChanged(const X509Certificate* cert) OVERRIDE;
+
+ private:
+ typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*>
+ TransportSocketPoolMap;
+ typedef internal::OwnedPoolMap<HostPortPair, SOCKSClientSocketPool*>
+ SOCKSSocketPoolMap;
+ typedef internal::OwnedPoolMap<HostPortPair, HttpProxyClientSocketPool*>
+ HTTPProxySocketPoolMap;
+ typedef internal::OwnedPoolMap<HostPortPair, SSLClientSocketPool*>
+ SSLSocketPoolMap;
+
+ NetLog* const net_log_;
+ ClientSocketFactory* const socket_factory_;
+ HostResolver* const host_resolver_;
+ CertVerifier* const cert_verifier_;
+ ServerBoundCertService* const server_bound_cert_service_;
+ TransportSecurityState* const transport_security_state_;
+ const std::string ssl_session_cache_shard_;
+ ProxyService* const proxy_service_;
+ const scoped_refptr<SSLConfigService> ssl_config_service_;
+ const HttpNetworkSession::SocketPoolType pool_type_;
+
+ // Note: this ordering is important.
+
+ ClientSocketPoolHistograms transport_pool_histograms_;
+ scoped_ptr<TransportClientSocketPool> transport_socket_pool_;
+
+ ClientSocketPoolHistograms ssl_pool_histograms_;
+ scoped_ptr<SSLClientSocketPool> ssl_socket_pool_;
+
+ ClientSocketPoolHistograms transport_for_socks_pool_histograms_;
+ TransportSocketPoolMap transport_socket_pools_for_socks_proxies_;
+
+ ClientSocketPoolHistograms socks_pool_histograms_;
+ SOCKSSocketPoolMap socks_socket_pools_;
+
+ ClientSocketPoolHistograms transport_for_http_proxy_pool_histograms_;
+ TransportSocketPoolMap transport_socket_pools_for_http_proxies_;
+
+ ClientSocketPoolHistograms transport_for_https_proxy_pool_histograms_;
+ TransportSocketPoolMap transport_socket_pools_for_https_proxies_;
+
+ ClientSocketPoolHistograms ssl_for_https_proxy_pool_histograms_;
+ SSLSocketPoolMap ssl_socket_pools_for_https_proxies_;
+
+ ClientSocketPoolHistograms http_proxy_pool_histograms_;
+ HTTPProxySocketPoolMap http_proxy_socket_pools_;
+
+ ClientSocketPoolHistograms ssl_socket_pool_for_proxies_histograms_;
+ SSLSocketPoolMap ssl_socket_pools_for_proxies_;
+
+ DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolManagerImpl);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_
diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc
new file mode 100644
index 00000000000..eba01b5e9cc
--- /dev/null
+++ b/chromium/net/socket/deterministic_socket_data_unittest.cc
@@ -0,0 +1,621 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socket_test_util.h"
+
+#include <string.h>
+
+#include "base/memory/ref_counted.h"
+#include "testing/platform_test.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+//-----------------------------------------------------------------------------
+
+namespace {
+
+static const char kMsg1[] = "\0hello!\xff";
+static const int kLen1 = arraysize(kMsg1);
+static const char kMsg2[] = "\012345678\0";
+static const int kLen2 = arraysize(kMsg2);
+static const char kMsg3[] = "bye!";
+static const int kLen3 = arraysize(kMsg3);
+
+} // anonymous namespace
+
+namespace net {
+
+class DeterministicSocketDataTest : public PlatformTest {
+ public:
+ DeterministicSocketDataTest();
+
+ virtual void TearDown();
+
+ void ReentrantReadCallback(int len, int rv);
+ void ReentrantWriteCallback(const char* data, int len, int rv);
+
+ protected:
+ void Initialize(MockRead* reads, size_t reads_count, MockWrite* writes,
+ size_t writes_count);
+
+ void AssertSyncReadEquals(const char* data, int len);
+ void AssertAsyncReadEquals(const char* data, int len);
+ void AssertReadReturns(const char* data, int len, int rv);
+ void AssertReadBufferEquals(const char* data, int len);
+
+ void AssertSyncWriteEquals(const char* data, int len);
+ void AssertAsyncWriteEquals(const char* data, int len);
+ void AssertWriteReturns(const char* data, int len, int rv);
+
+ TestCompletionCallback read_callback_;
+ TestCompletionCallback write_callback_;
+ StreamSocket* sock_;
+ scoped_ptr<DeterministicSocketData> data_;
+
+ private:
+ scoped_refptr<IOBuffer> read_buf_;
+ MockConnect connect_data_;
+
+ HostPortPair endpoint_;
+ scoped_refptr<TransportSocketParams> tcp_params_;
+ ClientSocketPoolHistograms histograms_;
+ DeterministicMockClientSocketFactory socket_factory_;
+ MockTransportClientSocketPool socket_pool_;
+ ClientSocketHandle connection_;
+
+ DISALLOW_COPY_AND_ASSIGN(DeterministicSocketDataTest);
+};
+
+DeterministicSocketDataTest::DeterministicSocketDataTest()
+ : sock_(NULL),
+ read_buf_(NULL),
+ connect_data_(SYNCHRONOUS, OK),
+ endpoint_("www.google.com", 443),
+ tcp_params_(new TransportSocketParams(endpoint_,
+ LOWEST,
+ false,
+ false,
+ OnHostResolutionCallback())),
+ histograms_(std::string()),
+ socket_pool_(10, 10, &histograms_, &socket_factory_) {}
+
+void DeterministicSocketDataTest::TearDown() {
+ // Empty the current queue.
+ base::MessageLoop::current()->RunUntilIdle();
+ PlatformTest::TearDown();
+}
+
+void DeterministicSocketDataTest::Initialize(MockRead* reads,
+ size_t reads_count,
+ MockWrite* writes,
+ size_t writes_count) {
+ data_.reset(new DeterministicSocketData(reads, reads_count,
+ writes, writes_count));
+ data_->set_connect_data(connect_data_);
+ socket_factory_.AddSocketDataProvider(data_.get());
+
+ // Perform the TCP connect
+ EXPECT_EQ(OK,
+ connection_.Init(endpoint_.ToString(),
+ tcp_params_,
+ LOWEST,
+ CompletionCallback(),
+ reinterpret_cast<TransportClientSocketPool*>(&socket_pool_),
+ BoundNetLog()));
+ sock_ = connection_.socket();
+}
+
+void DeterministicSocketDataTest::AssertSyncReadEquals(const char* data,
+ int len) {
+ // Issue the read, which will complete immediately
+ AssertReadReturns(data, len, len);
+ AssertReadBufferEquals(data, len);
+}
+
+void DeterministicSocketDataTest::AssertAsyncReadEquals(const char* data,
+ int len) {
+ // Issue the read, which will be completed asynchronously
+ AssertReadReturns(data, len, ERR_IO_PENDING);
+
+ EXPECT_FALSE(read_callback_.have_result());
+ EXPECT_TRUE(sock_->IsConnected());
+ data_->RunFor(1); // Runs 1 step, to cause the callbacks to be invoked
+
+ // Now the read should complete
+ ASSERT_EQ(len, read_callback_.WaitForResult());
+ AssertReadBufferEquals(data, len);
+}
+
+void DeterministicSocketDataTest::AssertReadReturns(const char* data,
+ int len, int rv) {
+ read_buf_ = new IOBuffer(len);
+ ASSERT_EQ(rv, sock_->Read(read_buf_.get(), len, read_callback_.callback()));
+}
+
+void DeterministicSocketDataTest::AssertReadBufferEquals(const char* data,
+ int len) {
+ ASSERT_EQ(std::string(data, len), std::string(read_buf_->data(), len));
+}
+
+void DeterministicSocketDataTest::AssertSyncWriteEquals(const char* data,
+ int len) {
+ scoped_refptr<IOBuffer> buf(new IOBuffer(len));
+ memcpy(buf->data(), data, len);
+
+ // Issue the write, which will complete immediately
+ ASSERT_EQ(len, sock_->Write(buf.get(), len, write_callback_.callback()));
+}
+
+void DeterministicSocketDataTest::AssertAsyncWriteEquals(const char* data,
+ int len) {
+ // Issue the read, which will be completed asynchronously
+ AssertWriteReturns(data, len, ERR_IO_PENDING);
+
+ EXPECT_FALSE(read_callback_.have_result());
+ EXPECT_TRUE(sock_->IsConnected());
+ data_->RunFor(1); // Runs 1 step, to cause the callbacks to be invoked
+
+ ASSERT_EQ(len, write_callback_.WaitForResult());
+}
+
+void DeterministicSocketDataTest::AssertWriteReturns(const char* data,
+ int len, int rv) {
+ scoped_refptr<IOBuffer> buf(new IOBuffer(len));
+ memcpy(buf->data(), data, len);
+
+ // Issue the read, which will complete asynchronously
+ ASSERT_EQ(rv, sock_->Write(buf.get(), len, write_callback_.callback()));
+}
+
+void DeterministicSocketDataTest::ReentrantReadCallback(int len, int rv) {
+ scoped_refptr<IOBuffer> read_buf(new IOBuffer(len));
+ EXPECT_EQ(len,
+ sock_->Read(
+ read_buf.get(),
+ len,
+ base::Bind(&DeterministicSocketDataTest::ReentrantReadCallback,
+ base::Unretained(this),
+ len)));
+}
+
+void DeterministicSocketDataTest::ReentrantWriteCallback(
+ const char* data, int len, int rv) {
+ scoped_refptr<IOBuffer> write_buf(new IOBuffer(len));
+ memcpy(write_buf->data(), data, len);
+ EXPECT_EQ(len,
+ sock_->Write(
+ write_buf.get(),
+ len,
+ base::Bind(&DeterministicSocketDataTest::ReentrantWriteCallback,
+ base::Unretained(this),
+ data,
+ len)));
+}
+
+// ----------- Read
+
+TEST_F(DeterministicSocketDataTest, SingleSyncReadWhileStopped) {
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ MockRead(SYNCHRONOUS, 0, 1), // EOF
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+
+ data_->SetStopped(true);
+ AssertReadReturns(kMsg1, kLen1, ERR_UNEXPECTED);
+}
+
+TEST_F(DeterministicSocketDataTest, SingleSyncReadTooEarly) {
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 1), // Sync Read
+ MockRead(SYNCHRONOUS, 0, 2), // EOF
+ };
+
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, 0, 0)
+ };
+
+ Initialize(reads, arraysize(reads), writes, arraysize(writes));
+
+ data_->StopAfter(2);
+ ASSERT_FALSE(data_->stopped());
+ AssertReadReturns(kMsg1, kLen1, ERR_UNEXPECTED);
+}
+
+TEST_F(DeterministicSocketDataTest, SingleSyncRead) {
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ MockRead(SYNCHRONOUS, 0, 1), // EOF
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+ // Make sure we don't stop before we've read all the data
+ data_->StopAfter(1);
+ AssertSyncReadEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, MultipleSyncReads) {
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg3, kLen3, 3), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg3, kLen3, 5), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Read
+ MockRead(SYNCHRONOUS, 0, 7), // EOF
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+
+ // Make sure we don't stop before we've read all the data
+ data_->StopAfter(10);
+ AssertSyncReadEquals(kMsg1, kLen1);
+ AssertSyncReadEquals(kMsg2, kLen2);
+ AssertSyncReadEquals(kMsg3, kLen3);
+ AssertSyncReadEquals(kMsg3, kLen3);
+ AssertSyncReadEquals(kMsg2, kLen2);
+ AssertSyncReadEquals(kMsg3, kLen3);
+ AssertSyncReadEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, SingleAsyncRead) {
+ MockRead reads[] = {
+ MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read
+ MockRead(SYNCHRONOUS, 0, 1), // EOF
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+
+ AssertAsyncReadEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, MultipleAsyncReads) {
+ MockRead reads[] = {
+ MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read
+ MockRead(ASYNC, kMsg2, kLen2, 1), // Async Read
+ MockRead(ASYNC, kMsg3, kLen3, 2), // Async Read
+ MockRead(ASYNC, kMsg3, kLen3, 3), // Async Read
+ MockRead(ASYNC, kMsg2, kLen2, 4), // Async Read
+ MockRead(ASYNC, kMsg3, kLen3, 5), // Async Read
+ MockRead(ASYNC, kMsg1, kLen1, 6), // Async Read
+ MockRead(SYNCHRONOUS, 0, 7), // EOF
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+
+ AssertAsyncReadEquals(kMsg1, kLen1);
+ AssertAsyncReadEquals(kMsg2, kLen2);
+ AssertAsyncReadEquals(kMsg3, kLen3);
+ AssertAsyncReadEquals(kMsg3, kLen3);
+ AssertAsyncReadEquals(kMsg2, kLen2);
+ AssertAsyncReadEquals(kMsg3, kLen3);
+ AssertAsyncReadEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, MixedReads) {
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ MockRead(ASYNC, kMsg2, kLen2, 1), // Async Read
+ MockRead(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Read
+ MockRead(ASYNC, kMsg3, kLen3, 3), // Async Read
+ MockRead(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Read
+ MockRead(ASYNC, kMsg3, kLen3, 5), // Async Read
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Read
+ MockRead(SYNCHRONOUS, 0, 7), // EOF
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+
+ data_->StopAfter(1);
+ AssertSyncReadEquals(kMsg1, kLen1);
+ AssertAsyncReadEquals(kMsg2, kLen2);
+ data_->StopAfter(1);
+ AssertSyncReadEquals(kMsg3, kLen3);
+ AssertAsyncReadEquals(kMsg3, kLen3);
+ data_->StopAfter(1);
+ AssertSyncReadEquals(kMsg2, kLen2);
+ AssertAsyncReadEquals(kMsg3, kLen3);
+ data_->StopAfter(1);
+ AssertSyncReadEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, SyncReadFromCompletionCallback) {
+ MockRead reads[] = {
+ MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read
+ MockRead(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Read
+ };
+
+ Initialize(reads, arraysize(reads), NULL, 0);
+
+ data_->StopAfter(2);
+
+ scoped_refptr<IOBuffer> read_buf(new IOBuffer(kLen1));
+ ASSERT_EQ(ERR_IO_PENDING,
+ sock_->Read(
+ read_buf.get(),
+ kLen1,
+ base::Bind(&DeterministicSocketDataTest::ReentrantReadCallback,
+ base::Unretained(this),
+ kLen2)));
+ data_->Run();
+}
+
+// ----------- Write
+
+TEST_F(DeterministicSocketDataTest, SingleSyncWriteWhileStopped) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ data_->SetStopped(true);
+ AssertWriteReturns(kMsg1, kLen1, ERR_UNEXPECTED);
+}
+
+TEST_F(DeterministicSocketDataTest, SingleSyncWriteTooEarly) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 1), // Sync Write
+ };
+
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, 0, 0)
+ };
+
+ Initialize(reads, arraysize(reads), writes, arraysize(writes));
+
+ data_->StopAfter(2);
+ ASSERT_FALSE(data_->stopped());
+ AssertWriteReturns(kMsg1, kLen1, ERR_UNEXPECTED);
+}
+
+TEST_F(DeterministicSocketDataTest, SingleSyncWrite) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Write
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ // Make sure we don't stop before we've read all the data
+ data_->StopAfter(1);
+ AssertSyncWriteEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, MultipleSyncWrites) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg3, kLen3, 3), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg3, kLen3, 5), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Write
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ // Make sure we don't stop before we've read all the data
+ data_->StopAfter(10);
+ AssertSyncWriteEquals(kMsg1, kLen1);
+ AssertSyncWriteEquals(kMsg2, kLen2);
+ AssertSyncWriteEquals(kMsg3, kLen3);
+ AssertSyncWriteEquals(kMsg3, kLen3);
+ AssertSyncWriteEquals(kMsg2, kLen2);
+ AssertSyncWriteEquals(kMsg3, kLen3);
+ AssertSyncWriteEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, SingleAsyncWrite) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, kMsg1, kLen1, 0), // Async Write
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ AssertAsyncWriteEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, MultipleAsyncWrites) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, kMsg1, kLen1, 0), // Async Write
+ MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write
+ MockWrite(ASYNC, kMsg3, kLen3, 2), // Async Write
+ MockWrite(ASYNC, kMsg3, kLen3, 3), // Async Write
+ MockWrite(ASYNC, kMsg2, kLen2, 4), // Async Write
+ MockWrite(ASYNC, kMsg3, kLen3, 5), // Async Write
+ MockWrite(ASYNC, kMsg1, kLen1, 6), // Async Write
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ AssertAsyncWriteEquals(kMsg1, kLen1);
+ AssertAsyncWriteEquals(kMsg2, kLen2);
+ AssertAsyncWriteEquals(kMsg3, kLen3);
+ AssertAsyncWriteEquals(kMsg3, kLen3);
+ AssertAsyncWriteEquals(kMsg2, kLen2);
+ AssertAsyncWriteEquals(kMsg3, kLen3);
+ AssertAsyncWriteEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, MixedWrites) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Write
+ MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write
+ MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write
+ MockWrite(ASYNC, kMsg3, kLen3, 3), // Async Write
+ MockWrite(SYNCHRONOUS, kMsg2, kLen2, 4), // Sync Write
+ MockWrite(ASYNC, kMsg3, kLen3, 5), // Async Write
+ MockWrite(SYNCHRONOUS, kMsg1, kLen1, 6), // Sync Write
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ data_->StopAfter(1);
+ AssertSyncWriteEquals(kMsg1, kLen1);
+ AssertAsyncWriteEquals(kMsg2, kLen2);
+ data_->StopAfter(1);
+ AssertSyncWriteEquals(kMsg3, kLen3);
+ AssertAsyncWriteEquals(kMsg3, kLen3);
+ data_->StopAfter(1);
+ AssertSyncWriteEquals(kMsg2, kLen2);
+ AssertAsyncWriteEquals(kMsg3, kLen3);
+ data_->StopAfter(1);
+ AssertSyncWriteEquals(kMsg1, kLen1);
+}
+
+TEST_F(DeterministicSocketDataTest, SyncWriteFromCompletionCallback) {
+ MockWrite writes[] = {
+ MockWrite(ASYNC, kMsg1, kLen1, 0), // Async Write
+ MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Write
+ };
+
+ Initialize(NULL, 0, writes, arraysize(writes));
+
+ data_->StopAfter(2);
+
+ scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1));
+ memcpy(write_buf->data(), kMsg1, kLen1);
+ ASSERT_EQ(ERR_IO_PENDING,
+ sock_->Write(
+ write_buf.get(),
+ kLen1,
+ base::Bind(&DeterministicSocketDataTest::ReentrantWriteCallback,
+ base::Unretained(this),
+ kMsg2,
+ kLen2)));
+ data_->Run();
+}
+
+// ----------- Mixed Reads and Writes
+
+TEST_F(DeterministicSocketDataTest, MixedSyncOperations) {
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ MockRead(SYNCHRONOUS, kMsg2, kLen2, 3), // Sync Read
+ MockRead(SYNCHRONOUS, 0, 4), // EOF
+ };
+
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), // Sync Write
+ MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write
+ };
+
+ Initialize(reads, arraysize(reads), writes, arraysize(writes));
+
+ // Make sure we don't stop before we've read/written everything
+ data_->StopAfter(10);
+ AssertSyncReadEquals(kMsg1, kLen1);
+ AssertSyncWriteEquals(kMsg2, kLen2);
+ AssertSyncWriteEquals(kMsg3, kLen3);
+ AssertSyncReadEquals(kMsg2, kLen2);
+}
+
+TEST_F(DeterministicSocketDataTest, MixedAsyncOperations) {
+ MockRead reads[] = {
+ MockRead(ASYNC, kMsg1, kLen1, 0), // Sync Read
+ MockRead(ASYNC, kMsg2, kLen2, 3), // Sync Read
+ MockRead(ASYNC, 0, 4), // EOF
+ };
+
+ MockWrite writes[] = {
+ MockWrite(ASYNC, kMsg2, kLen2, 1), // Sync Write
+ MockWrite(ASYNC, kMsg3, kLen3, 2), // Sync Write
+ };
+
+ Initialize(reads, arraysize(reads), writes, arraysize(writes));
+
+ AssertAsyncReadEquals(kMsg1, kLen1);
+ AssertAsyncWriteEquals(kMsg2, kLen2);
+ AssertAsyncWriteEquals(kMsg3, kLen3);
+ AssertAsyncReadEquals(kMsg2, kLen2);
+}
+
+TEST_F(DeterministicSocketDataTest, InterleavedAsyncOperations) {
+ // Order of completion is read, write, write, read
+ MockRead reads[] = {
+ MockRead(ASYNC, kMsg1, kLen1, 0), // Async Read
+ MockRead(ASYNC, kMsg2, kLen2, 3), // Async Read
+ MockRead(ASYNC, 0, 4), // EOF
+ };
+
+ MockWrite writes[] = {
+ MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write
+ MockWrite(ASYNC, kMsg3, kLen3, 2), // Async Write
+ };
+
+ Initialize(reads, arraysize(reads), writes, arraysize(writes));
+
+ // Issue the write, which will block until the read completes
+ AssertWriteReturns(kMsg2, kLen2, ERR_IO_PENDING);
+
+ // Issue the read which will return first
+ AssertReadReturns(kMsg1, kLen1, ERR_IO_PENDING);
+
+ data_->RunFor(1);
+ ASSERT_TRUE(read_callback_.have_result());
+ ASSERT_EQ(kLen1, read_callback_.WaitForResult());
+ AssertReadBufferEquals(kMsg1, kLen1);
+
+ data_->RunFor(1);
+ ASSERT_TRUE(write_callback_.have_result());
+ ASSERT_EQ(kLen2, write_callback_.WaitForResult());
+
+ data_->StopAfter(1);
+ // Issue the read, which will block until the write completes
+ AssertReadReturns(kMsg2, kLen2, ERR_IO_PENDING);
+
+ // Issue the writes which will return first
+ AssertWriteReturns(kMsg3, kLen3, ERR_IO_PENDING);
+
+ data_->RunFor(1);
+ ASSERT_TRUE(write_callback_.have_result());
+ ASSERT_EQ(kLen3, write_callback_.WaitForResult());
+
+ data_->RunFor(1);
+ ASSERT_TRUE(read_callback_.have_result());
+ ASSERT_EQ(kLen2, read_callback_.WaitForResult());
+ AssertReadBufferEquals(kMsg2, kLen2);
+}
+
+TEST_F(DeterministicSocketDataTest, InterleavedMixedOperations) {
+ // Order of completion is read, write, write, read
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), // Sync Read
+ MockRead(ASYNC, kMsg2, kLen2, 3), // Async Read
+ MockRead(SYNCHRONOUS, 0, 4), // EOF
+ };
+
+ MockWrite writes[] = {
+ MockWrite(ASYNC, kMsg2, kLen2, 1), // Async Write
+ MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), // Sync Write
+ };
+
+ Initialize(reads, arraysize(reads), writes, arraysize(writes));
+
+ // Issue the write, which will block until the read completes
+ AssertWriteReturns(kMsg2, kLen2, ERR_IO_PENDING);
+
+ // Issue the writes which will complete immediately
+ data_->StopAfter(1);
+ AssertSyncReadEquals(kMsg1, kLen1);
+
+ data_->RunFor(1);
+ ASSERT_TRUE(write_callback_.have_result());
+ ASSERT_EQ(kLen2, write_callback_.WaitForResult());
+
+ // Issue the read, which will block until the write completes
+ AssertReadReturns(kMsg2, kLen2, ERR_IO_PENDING);
+
+ // Issue the writes which will complete immediately
+ data_->StopAfter(1);
+ AssertSyncWriteEquals(kMsg3, kLen3);
+
+ data_->RunFor(1);
+ ASSERT_TRUE(read_callback_.have_result());
+ ASSERT_EQ(kLen2, read_callback_.WaitForResult());
+ AssertReadBufferEquals(kMsg2, kLen2);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/mock_client_socket_pool_manager.cc b/chromium/net/socket/mock_client_socket_pool_manager.cc
new file mode 100644
index 00000000000..0496adb9636
--- /dev/null
+++ b/chromium/net/socket/mock_client_socket_pool_manager.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/mock_client_socket_pool_manager.h"
+
+#include "net/http/http_proxy_client_socket_pool.h"
+#include "net/socket/socks_client_socket_pool.h"
+#include "net/socket/ssl_client_socket_pool.h"
+#include "net/socket/transport_client_socket_pool.h"
+
+namespace net {
+
+MockClientSocketPoolManager::MockClientSocketPoolManager() {}
+MockClientSocketPoolManager::~MockClientSocketPoolManager() {}
+
+void MockClientSocketPoolManager::SetTransportSocketPool(
+ TransportClientSocketPool* pool) {
+ transport_socket_pool_.reset(pool);
+}
+
+void MockClientSocketPoolManager::SetSSLSocketPool(
+ SSLClientSocketPool* pool) {
+ ssl_socket_pool_.reset(pool);
+}
+
+void MockClientSocketPoolManager::SetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy,
+ SOCKSClientSocketPool* pool) {
+ socks_socket_pools_[socks_proxy] = pool;
+}
+
+void MockClientSocketPoolManager::SetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy,
+ HttpProxyClientSocketPool* pool) {
+ http_proxy_socket_pools_[http_proxy] = pool;
+}
+
+void MockClientSocketPoolManager::SetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server,
+ SSLClientSocketPool* pool) {
+ ssl_socket_pools_for_proxies_[proxy_server] = pool;
+}
+
+void MockClientSocketPoolManager::FlushSocketPoolsWithError(int error) {
+ NOTIMPLEMENTED();
+}
+
+void MockClientSocketPoolManager::CloseIdleSockets() {
+ NOTIMPLEMENTED();
+}
+
+TransportClientSocketPool*
+MockClientSocketPoolManager::GetTransportSocketPool() {
+ return transport_socket_pool_.get();
+}
+
+SSLClientSocketPool* MockClientSocketPoolManager::GetSSLSocketPool() {
+ return ssl_socket_pool_.get();
+}
+
+SOCKSClientSocketPool* MockClientSocketPoolManager::GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) {
+ SOCKSSocketPoolMap::const_iterator it = socks_socket_pools_.find(socks_proxy);
+ if (it != socks_socket_pools_.end())
+ return it->second;
+ return NULL;
+}
+
+HttpProxyClientSocketPool*
+MockClientSocketPoolManager::GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) {
+ HTTPProxySocketPoolMap::const_iterator it =
+ http_proxy_socket_pools_.find(http_proxy);
+ if (it != http_proxy_socket_pools_.end())
+ return it->second;
+ return NULL;
+}
+
+SSLClientSocketPool* MockClientSocketPoolManager::GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) {
+ SSLSocketPoolMap::const_iterator it =
+ ssl_socket_pools_for_proxies_.find(proxy_server);
+ if (it != ssl_socket_pools_for_proxies_.end())
+ return it->second;
+ return NULL;
+}
+
+base::Value* MockClientSocketPoolManager::SocketPoolInfoToValue() const {
+ NOTIMPLEMENTED();
+ return NULL;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/mock_client_socket_pool_manager.h b/chromium/net/socket/mock_client_socket_pool_manager.h
new file mode 100644
index 00000000000..c2c3792a4f6
--- /dev/null
+++ b/chromium/net/socket/mock_client_socket_pool_manager.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_MOCK_CLIENT_SOCKET_POOL_MANAGER_H_
+#define NET_SOCKET_MOCK_CLIENT_SOCKET_POOL_MANAGER_H_
+
+#include "base/basictypes.h"
+#include "net/socket/client_socket_pool_manager.h"
+#include "net/socket/client_socket_pool_manager_impl.h"
+
+namespace net {
+
+class MockClientSocketPoolManager : public ClientSocketPoolManager {
+ public:
+ MockClientSocketPoolManager();
+ virtual ~MockClientSocketPoolManager();
+
+ // Sets "override" socket pools that get used instead.
+ void SetTransportSocketPool(TransportClientSocketPool* pool);
+ void SetSSLSocketPool(SSLClientSocketPool* pool);
+ void SetSocketPoolForSOCKSProxy(const HostPortPair& socks_proxy,
+ SOCKSClientSocketPool* pool);
+ void SetSocketPoolForHTTPProxy(const HostPortPair& http_proxy,
+ HttpProxyClientSocketPool* pool);
+ void SetSocketPoolForSSLWithProxy(const HostPortPair& proxy_server,
+ SSLClientSocketPool* pool);
+
+ // ClientSocketPoolManager methods:
+ virtual void FlushSocketPoolsWithError(int error) OVERRIDE;
+ virtual void CloseIdleSockets() OVERRIDE;
+ virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE;
+ virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE;
+ virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) OVERRIDE;
+ virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) OVERRIDE;
+ virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) OVERRIDE;
+ virtual base::Value* SocketPoolInfoToValue() const OVERRIDE;
+
+ private:
+ typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*>
+ TransportSocketPoolMap;
+ typedef internal::OwnedPoolMap<HostPortPair, SOCKSClientSocketPool*>
+ SOCKSSocketPoolMap;
+ typedef internal::OwnedPoolMap<HostPortPair, HttpProxyClientSocketPool*>
+ HTTPProxySocketPoolMap;
+ typedef internal::OwnedPoolMap<HostPortPair, SSLClientSocketPool*>
+ SSLSocketPoolMap;
+
+ scoped_ptr<TransportClientSocketPool> transport_socket_pool_;
+ scoped_ptr<SSLClientSocketPool> ssl_socket_pool_;
+ SOCKSSocketPoolMap socks_socket_pools_;
+ HTTPProxySocketPoolMap http_proxy_socket_pools_;
+ SSLSocketPoolMap ssl_socket_pools_for_proxies_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockClientSocketPoolManager);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_MOCK_CLIENT_SOCKET_POOL_MANAGER_H_
diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h
new file mode 100644
index 00000000000..0bd307a3778
--- /dev/null
+++ b/chromium/net/socket/next_proto.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_NEXT_PROTO_H_
+#define NET_SOCKET_NEXT_PROTO_H_
+
+namespace net {
+
+// Next Protocol Negotiation (NPN), if successful, results in agreement on an
+// application-level string that specifies the application level protocol to
+// use over the TLS connection. NextProto enumerates the application level
+// protocols that we recognise.
+enum NextProto {
+ kProtoUnknown = 0,
+ kProtoHTTP11 = 1,
+ kProtoMinimumVersion = kProtoHTTP11,
+
+ // TODO(akalin): Stop advertising SPDY/1 and remove this.
+ kProtoSPDY1 = 2,
+ kProtoSPDYMinimumVersion = kProtoSPDY1,
+ kProtoSPDY2 = 3,
+ // TODO(akalin): Stop adverising SPDY/2.1, too.
+ kProtoSPDY21 = 4,
+ kProtoSPDY3 = 5,
+ kProtoSPDY31 = 6,
+ kProtoSPDY4a2 = 7,
+ // We lump in HTTP/2 with the SPDY protocols for now.
+ kProtoHTTP2Draft04 = 8,
+ kProtoSPDYMaximumVersion = kProtoHTTP2Draft04,
+
+ kProtoQUIC1SPDY3 = 9,
+
+ kProtoMaximumVersion = kProtoQUIC1SPDY3,
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_NEXT_PROTO_H_
diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc
new file mode 100644
index 00000000000..ae037b17c1e
--- /dev/null
+++ b/chromium/net/socket/nss_ssl_util.cc
@@ -0,0 +1,276 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/nss_ssl_util.h"
+
+#include <nss.h>
+#include <secerr.h>
+#include <ssl.h>
+#include <sslerr.h>
+#include <sslproto.h>
+
+#include <string>
+
+#include "base/bind.h"
+#include "base/lazy_instance.h"
+#include "base/logging.h"
+#include "base/memory/singleton.h"
+#include "base/threading/thread_restrictions.h"
+#include "base/values.h"
+#include "build/build_config.h"
+#include "crypto/nss_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+
+#if defined(OS_WIN)
+#include "base/win/windows_version.h"
+#endif
+
+namespace net {
+
+class NSSSSLInitSingleton {
+ public:
+ NSSSSLInitSingleton() {
+ crypto::EnsureNSSInit();
+
+ NSS_SetDomesticPolicy();
+
+ const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers();
+ const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers();
+
+ // Disable ECDSA cipher suites on platforms that do not support ECDSA
+ // signed certificates, as servers may use the presence of such
+ // ciphersuites as a hint to send an ECDSA certificate.
+ bool disableECDSA = false;
+#if defined(OS_WIN)
+ if (base::win::GetVersion() < base::win::VERSION_VISTA)
+ disableECDSA = true;
+#endif
+
+ // Explicitly enable exactly those ciphers with keys of at least 80 bits
+ for (int i = 0; i < num_ciphers; i++) {
+ SSLCipherSuiteInfo info;
+ if (SSL_GetCipherSuiteInfo(ssl_ciphers[i], &info,
+ sizeof(info)) == SECSuccess) {
+ bool enabled = info.effectiveKeyBits >= 80;
+ if (info.authAlgorithm == ssl_auth_ecdsa && disableECDSA)
+ enabled = false;
+
+ // Trim the list of cipher suites in order to keep the size of the
+ // ClientHello down. DSS, ECDH, CAMELLIA, SEED, ECC+3DES, and
+ // HMAC-SHA256 cipher suites are disabled.
+ if (info.symCipher == ssl_calg_camellia ||
+ info.symCipher == ssl_calg_seed ||
+ (info.symCipher == ssl_calg_3des && info.keaType != ssl_kea_rsa) ||
+ info.authAlgorithm == ssl_auth_dsa ||
+ info.macAlgorithm == ssl_hmac_sha256 ||
+ info.nonStandard ||
+ strcmp(info.keaTypeName, "ECDH") == 0) {
+ enabled = false;
+ }
+
+ if (ssl_ciphers[i] == TLS_DHE_DSS_WITH_AES_128_CBC_SHA) {
+ // Enabled to allow servers with only a DSA certificate to function.
+ enabled = true;
+ }
+ SSL_CipherPrefSetDefault(ssl_ciphers[i], enabled);
+ }
+ }
+
+ // Enable SSL.
+ SSL_OptionSetDefault(SSL_SECURITY, PR_TRUE);
+
+ // All other SSL options are set per-session by SSLClientSocket and
+ // SSLServerSocket.
+ }
+
+ ~NSSSSLInitSingleton() {
+ // Have to clear the cache, or NSS_Shutdown fails with SEC_ERROR_BUSY.
+ SSL_ClearSessionCache();
+ }
+};
+
+static base::LazyInstance<NSSSSLInitSingleton> g_nss_ssl_init_singleton =
+ LAZY_INSTANCE_INITIALIZER;
+
+// Initialize the NSS SSL library if it isn't already initialized. This must
+// be called before any other NSS SSL functions. This function is
+// thread-safe, and the NSS SSL library will only ever be initialized once.
+// The NSS SSL library will be properly shut down on program exit.
+void EnsureNSSSSLInit() {
+ // Initializing SSL causes us to do blocking IO.
+ // Temporarily allow it until we fix
+ // http://code.google.com/p/chromium/issues/detail?id=59847
+ base::ThreadRestrictions::ScopedAllowIO allow_io;
+
+ g_nss_ssl_init_singleton.Get();
+}
+
+// Map a Chromium net error code to an NSS error code.
+// See _MD_unix_map_default_error in the NSS source
+// tree for inspiration.
+PRErrorCode MapErrorToNSS(int result) {
+ if (result >=0)
+ return result;
+
+ switch (result) {
+ case ERR_IO_PENDING:
+ return PR_WOULD_BLOCK_ERROR;
+ case ERR_ACCESS_DENIED:
+ case ERR_NETWORK_ACCESS_DENIED:
+ // For connect, this could be mapped to PR_ADDRESS_NOT_SUPPORTED_ERROR.
+ return PR_NO_ACCESS_RIGHTS_ERROR;
+ case ERR_NOT_IMPLEMENTED:
+ return PR_NOT_IMPLEMENTED_ERROR;
+ case ERR_SOCKET_NOT_CONNECTED:
+ return PR_NOT_CONNECTED_ERROR;
+ case ERR_INTERNET_DISCONNECTED: // Equivalent to ENETDOWN.
+ return PR_NETWORK_UNREACHABLE_ERROR; // Best approximation.
+ case ERR_CONNECTION_TIMED_OUT:
+ case ERR_TIMED_OUT:
+ return PR_IO_TIMEOUT_ERROR;
+ case ERR_CONNECTION_RESET:
+ return PR_CONNECT_RESET_ERROR;
+ case ERR_CONNECTION_ABORTED:
+ return PR_CONNECT_ABORTED_ERROR;
+ case ERR_CONNECTION_REFUSED:
+ return PR_CONNECT_REFUSED_ERROR;
+ case ERR_ADDRESS_UNREACHABLE:
+ return PR_HOST_UNREACHABLE_ERROR; // Also PR_NETWORK_UNREACHABLE_ERROR.
+ case ERR_ADDRESS_INVALID:
+ return PR_ADDRESS_NOT_AVAILABLE_ERROR;
+ case ERR_NAME_NOT_RESOLVED:
+ return PR_DIRECTORY_LOOKUP_ERROR;
+ default:
+ LOG(WARNING) << "MapErrorToNSS " << result
+ << " mapped to PR_UNKNOWN_ERROR";
+ return PR_UNKNOWN_ERROR;
+ }
+}
+
+// The default error mapping function.
+// Maps an NSS error code to a network error code.
+int MapNSSError(PRErrorCode err) {
+ // TODO(port): fill this out as we learn what's important
+ switch (err) {
+ case PR_WOULD_BLOCK_ERROR:
+ return ERR_IO_PENDING;
+ case PR_ADDRESS_NOT_SUPPORTED_ERROR: // For connect.
+ case PR_NO_ACCESS_RIGHTS_ERROR:
+ return ERR_ACCESS_DENIED;
+ case PR_IO_TIMEOUT_ERROR:
+ return ERR_TIMED_OUT;
+ case PR_CONNECT_RESET_ERROR:
+ return ERR_CONNECTION_RESET;
+ case PR_CONNECT_ABORTED_ERROR:
+ return ERR_CONNECTION_ABORTED;
+ case PR_CONNECT_REFUSED_ERROR:
+ return ERR_CONNECTION_REFUSED;
+ case PR_NOT_CONNECTED_ERROR:
+ return ERR_SOCKET_NOT_CONNECTED;
+ case PR_HOST_UNREACHABLE_ERROR:
+ case PR_NETWORK_UNREACHABLE_ERROR:
+ return ERR_ADDRESS_UNREACHABLE;
+ case PR_ADDRESS_NOT_AVAILABLE_ERROR:
+ return ERR_ADDRESS_INVALID;
+ case PR_INVALID_ARGUMENT_ERROR:
+ return ERR_INVALID_ARGUMENT;
+ case PR_END_OF_FILE_ERROR:
+ return ERR_CONNECTION_CLOSED;
+ case PR_NOT_IMPLEMENTED_ERROR:
+ return ERR_NOT_IMPLEMENTED;
+
+ case SEC_ERROR_LIBRARY_FAILURE:
+ return ERR_UNEXPECTED;
+ case SEC_ERROR_INVALID_ARGS:
+ return ERR_INVALID_ARGUMENT;
+ case SEC_ERROR_NO_MEMORY:
+ return ERR_OUT_OF_MEMORY;
+ case SEC_ERROR_NO_KEY:
+ return ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY;
+ case SEC_ERROR_INVALID_KEY:
+ case SSL_ERROR_SIGN_HASHES_FAILURE:
+ LOG(ERROR) << "ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED: NSS error " << err
+ << ", OS error " << PR_GetOSError();
+ return ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED;
+ // A handshake (initial or renegotiation) may fail because some signature
+ // (for example, the signature in the ServerKeyExchange message for an
+ // ephemeral Diffie-Hellman cipher suite) is invalid.
+ case SEC_ERROR_BAD_SIGNATURE:
+ return ERR_SSL_PROTOCOL_ERROR;
+
+ case SSL_ERROR_SSL_DISABLED:
+ return ERR_NO_SSL_VERSIONS_ENABLED;
+ case SSL_ERROR_NO_CYPHER_OVERLAP:
+ case SSL_ERROR_PROTOCOL_VERSION_ALERT:
+ case SSL_ERROR_UNSUPPORTED_VERSION:
+ return ERR_SSL_VERSION_OR_CIPHER_MISMATCH;
+ case SSL_ERROR_HANDSHAKE_FAILURE_ALERT:
+ case SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT:
+ case SSL_ERROR_ILLEGAL_PARAMETER_ALERT:
+ return ERR_SSL_PROTOCOL_ERROR;
+ case SSL_ERROR_DECOMPRESSION_FAILURE_ALERT:
+ return ERR_SSL_DECOMPRESSION_FAILURE_ALERT;
+ case SSL_ERROR_BAD_MAC_ALERT:
+ return ERR_SSL_BAD_RECORD_MAC_ALERT;
+ case SSL_ERROR_DECRYPT_ERROR_ALERT:
+ return ERR_SSL_DECRYPT_ERROR_ALERT;
+ case SSL_ERROR_UNSAFE_NEGOTIATION:
+ return ERR_SSL_UNSAFE_NEGOTIATION;
+ case SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY:
+ return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY;
+ case SSL_ERROR_HANDSHAKE_NOT_COMPLETED:
+ return ERR_SSL_HANDSHAKE_NOT_COMPLETED;
+ case SEC_ERROR_BAD_KEY:
+ case SSL_ERROR_EXTRACT_PUBLIC_KEY_FAILURE:
+ // TODO(wtc): the following errors may also occur in contexts unrelated
+ // to the peer's public key. We should add new error codes for them, or
+ // map them to ERR_SSL_BAD_PEER_PUBLIC_KEY only in the right context.
+ // General unsupported/unknown key algorithm error.
+ case SEC_ERROR_UNSUPPORTED_KEYALG:
+ // General DER decoding errors.
+ case SEC_ERROR_BAD_DER:
+ case SEC_ERROR_EXTRA_INPUT:
+ return ERR_SSL_BAD_PEER_PUBLIC_KEY;
+
+ default: {
+ if (IS_SSL_ERROR(err)) {
+ LOG(WARNING) << "Unknown SSL error " << err
+ << " mapped to net::ERR_SSL_PROTOCOL_ERROR";
+ return ERR_SSL_PROTOCOL_ERROR;
+ }
+ LOG(WARNING) << "Unknown error " << err << " mapped to net::ERR_FAILED";
+ return ERR_FAILED;
+ }
+ }
+}
+
+// Returns parameters to attach to the NetLog when we receive an error in
+// response to a call to an NSS function. Used instead of
+// NetLogSSLErrorCallback with events of type TYPE_SSL_NSS_ERROR.
+base::Value* NetLogSSLFailedNSSFunctionCallback(
+ const char* function,
+ const char* param,
+ int ssl_lib_error,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("function", function);
+ if (param[0] != '\0')
+ dict->SetString("param", param);
+ dict->SetInteger("ssl_lib_error", ssl_lib_error);
+ return dict;
+}
+
+void LogFailedNSSFunction(const BoundNetLog& net_log,
+ const char* function,
+ const char* param) {
+ DCHECK(function);
+ DCHECK(param);
+ net_log.AddEvent(
+ NetLog::TYPE_SSL_NSS_ERROR,
+ base::Bind(&NetLogSSLFailedNSSFunctionCallback,
+ function, param, PR_GetError()));
+}
+
+} // namespace net
diff --git a/chromium/net/socket/nss_ssl_util.h b/chromium/net/socket/nss_ssl_util.h
new file mode 100644
index 00000000000..09ae3562cd7
--- /dev/null
+++ b/chromium/net/socket/nss_ssl_util.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file is only included in ssl_client_socket_nss.cc and
+// ssl_server_socket_nss.cc to share common functions of NSS.
+
+#ifndef NET_SOCKET_NSS_SSL_UTIL_H_
+#define NET_SOCKET_NSS_SSL_UTIL_H_
+
+#include <prerror.h>
+
+#include "net/base/net_export.h"
+
+namespace net {
+
+class BoundNetLog;
+
+// Initalize NSS SSL library.
+NET_EXPORT void EnsureNSSSSLInit();
+
+// Log a failed NSS funcion call.
+void LogFailedNSSFunction(const BoundNetLog& net_log,
+ const char* function,
+ const char* param);
+
+// Map network error code to NSS error code.
+PRErrorCode MapErrorToNSS(int result);
+
+// Map NSS error code to network error code.
+int MapNSSError(PRErrorCode err);
+
+} // namespace net
+
+#endif // NET_SOCKET_NSS_SSL_UTIL_H_
diff --git a/chromium/net/socket/server_socket.h b/chromium/net/socket/server_socket.h
new file mode 100644
index 00000000000..11151eea153
--- /dev/null
+++ b/chromium/net/socket/server_socket.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SERVER_SOCKET_H_
+#define NET_SOCKET_SERVER_SOCKET_H_
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class IPEndPoint;
+class StreamSocket;
+
+class NET_EXPORT ServerSocket {
+ public:
+ ServerSocket() { }
+ virtual ~ServerSocket() { }
+
+ // Bind the socket and start listening. Destroy the socket to stop
+ // listening.
+ virtual int Listen(const net::IPEndPoint& address, int backlog) = 0;
+
+ // Gets current address the socket is bound to.
+ virtual int GetLocalAddress(IPEndPoint* address) const = 0;
+
+ // Accept connection. Callback is called when new connection is
+ // accepted.
+ virtual int Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) = 0;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ServerSocket);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SERVER_SOCKET_H_
diff --git a/chromium/net/socket/socket.h b/chromium/net/socket/socket.h
new file mode 100644
index 00000000000..fccb25873a4
--- /dev/null
+++ b/chromium/net/socket/socket.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKET_H_
+#define NET_SOCKET_SOCKET_H_
+
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class IOBuffer;
+
+// Represents a read/write socket.
+class NET_EXPORT Socket {
+ public:
+ virtual ~Socket() {}
+
+ // Reads data, up to |buf_len| bytes, from the socket. The number of bytes
+ // read is returned, or an error is returned upon failure.
+ // ERR_SOCKET_NOT_CONNECTED should be returned if the socket is not currently
+ // connected. Zero is returned once to indicate end-of-file; the return value
+ // of subsequent calls is undefined, and may be OS dependent. ERR_IO_PENDING
+ // is returned if the operation could not be completed synchronously, in which
+ // case the result will be passed to the callback when available. If the
+ // operation is not completed immediately, the socket acquires a reference to
+ // the provided buffer until the callback is invoked or the socket is
+ // closed. If the socket is Disconnected before the read completes, the
+ // callback will not be invoked.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) = 0;
+
+ // Writes data, up to |buf_len| bytes, to the socket. Note: data may be
+ // written partially. The number of bytes written is returned, or an error
+ // is returned upon failure. ERR_SOCKET_NOT_CONNECTED should be returned if
+ // the socket is not currently connected. The return value when the
+ // connection is closed is undefined, and may be OS dependent. ERR_IO_PENDING
+ // is returned if the operation could not be completed synchronously, in which
+ // case the result will be passed to the callback when available. If the
+ // operation is not completed immediately, the socket acquires a reference to
+ // the provided buffer until the callback is invoked or the socket is
+ // closed. Implementations of this method should not modify the contents
+ // of the actual buffer that is written to the socket. If the socket is
+ // Disconnected before the write completes, the callback will not be invoked.
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) = 0;
+
+ // Set the receive buffer size (in bytes) for the socket.
+ // Note: changing this value can affect the TCP window size on some platforms.
+ // Returns true on success, or false on failure.
+ virtual bool SetReceiveBufferSize(int32 size) = 0;
+
+ // Set the send buffer size (in bytes) for the socket.
+ // Note: changing this value can affect the TCP window size on some platforms.
+ // Returns true on success, or false on failure.
+ virtual bool SetSendBufferSize(int32 size) = 0;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKET_H_
diff --git a/chromium/net/socket/socket_net_log_params.cc b/chromium/net/socket/socket_net_log_params.cc
new file mode 100644
index 00000000000..bcc12c86bcf
--- /dev/null
+++ b/chromium/net/socket/socket_net_log_params.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socket_net_log_params.h"
+
+#include "base/bind.h"
+#include "base/values.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+namespace {
+
+base::Value* NetLogSocketErrorCallback(int net_error,
+ int os_error,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetInteger("net_error", net_error);
+ dict->SetInteger("os_error", os_error);
+ return dict;
+}
+
+base::Value* NetLogHostPortPairCallback(const HostPortPair* host_and_port,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("host_and_port", host_and_port->ToString());
+ return dict;
+}
+
+base::Value* NetLogIPEndPointCallback(const IPEndPoint* address,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("address", address->ToString());
+ return dict;
+}
+
+base::Value* NetLogSourceAddressCallback(const struct sockaddr* net_address,
+ socklen_t address_len,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("source_address",
+ NetAddressToStringWithPort(net_address, address_len));
+ return dict;
+}
+
+} // namespace
+
+NetLog::ParametersCallback CreateNetLogSocketErrorCallback(int net_error,
+ int os_error) {
+ return base::Bind(&NetLogSocketErrorCallback, net_error, os_error);
+}
+
+NetLog::ParametersCallback CreateNetLogHostPortPairCallback(
+ const HostPortPair* host_and_port) {
+ return base::Bind(&NetLogHostPortPairCallback, host_and_port);
+}
+
+NetLog::ParametersCallback CreateNetLogIPEndPointCallback(
+ const IPEndPoint* address) {
+ return base::Bind(&NetLogIPEndPointCallback, address);
+}
+
+NetLog::ParametersCallback CreateNetLogSourceAddressCallback(
+ const struct sockaddr* net_address,
+ socklen_t address_len) {
+ return base::Bind(&NetLogSourceAddressCallback, net_address, address_len);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/socket_net_log_params.h b/chromium/net/socket/socket_net_log_params.h
new file mode 100644
index 00000000000..f5fe652d125
--- /dev/null
+++ b/chromium/net/socket/socket_net_log_params.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_
+#define NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_
+
+#include "net/base/net_log.h"
+#include "net/base/sys_addrinfo.h"
+
+namespace net {
+
+class HostPortPair;
+class IPEndPoint;
+
+// Creates a NetLog callback for socket error events.
+NetLog::ParametersCallback CreateNetLogSocketErrorCallback(int net_error,
+ int os_error);
+
+// Creates a NetLog callback for a HostPortPair.
+// |host_and_port| must remain valid for the lifetime of the returned callback.
+NetLog::ParametersCallback CreateNetLogHostPortPairCallback(
+ const HostPortPair* host_and_port);
+
+// Creates a NetLog callback for an IPEndPoint.
+// |address| must remain valid for the lifetime of the returned callback.
+NetLog::ParametersCallback CreateNetLogIPEndPointCallback(
+ const IPEndPoint* address);
+
+// Creates a NetLog callback for the source sockaddr on connect events.
+// |net_address| must remain valid for the lifetime of the returned callback.
+NetLog::ParametersCallback CreateNetLogSourceAddressCallback(
+ const struct sockaddr* net_address,
+ socklen_t address_len);
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_
diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc
new file mode 100644
index 00000000000..159f62e42c6
--- /dev/null
+++ b/chromium/net/socket/socket_test_util.cc
@@ -0,0 +1,1888 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socket_test_util.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/compiler_specific.h"
+#include "base/message_loop/message_loop.h"
+#include "base/run_loop.h"
+#include "base/time/time.h"
+#include "net/base/address_family.h"
+#include "net/base/address_list.h"
+#include "net/base/auth.h"
+#include "net/base/load_timing_info.h"
+#include "net/http/http_network_session.h"
+#include "net/http/http_request_headers.h"
+#include "net/http/http_response_headers.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/socket.h"
+#include "net/ssl/ssl_cert_request_info.h"
+#include "net/ssl/ssl_info.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+// Socket events are easier to debug if you log individual reads and writes.
+// Enable these if locally debugging, but they are too noisy for the waterfall.
+#if 0
+#define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() "
+#else
+#define NET_TRACE(level, s) EAT_STREAM_PARAMETERS
+#endif
+
+namespace net {
+
+namespace {
+
+inline char AsciifyHigh(char x) {
+ char nybble = static_cast<char>((x >> 4) & 0x0F);
+ return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
+}
+
+inline char AsciifyLow(char x) {
+ char nybble = static_cast<char>((x >> 0) & 0x0F);
+ return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
+}
+
+inline char Asciify(char x) {
+ if ((x < 0) || !isprint(x))
+ return '.';
+ return x;
+}
+
+void DumpData(const char* data, int data_len) {
+ if (logging::LOG_INFO < logging::GetMinLogLevel())
+ return;
+ DVLOG(1) << "Length: " << data_len;
+ const char* pfx = "Data: ";
+ if (!data || (data_len <= 0)) {
+ DVLOG(1) << pfx << "<None>";
+ } else {
+ int i;
+ for (i = 0; i <= (data_len - 4); i += 4) {
+ DVLOG(1) << pfx
+ << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
+ << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
+ << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
+ << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
+ << " '"
+ << Asciify(data[i + 0])
+ << Asciify(data[i + 1])
+ << Asciify(data[i + 2])
+ << Asciify(data[i + 3])
+ << "'";
+ pfx = " ";
+ }
+ // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
+ switch (data_len - i) {
+ case 3:
+ DVLOG(1) << pfx
+ << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
+ << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
+ << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
+ << " '"
+ << Asciify(data[i + 0])
+ << Asciify(data[i + 1])
+ << Asciify(data[i + 2])
+ << " '";
+ break;
+ case 2:
+ DVLOG(1) << pfx
+ << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
+ << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
+ << " '"
+ << Asciify(data[i + 0])
+ << Asciify(data[i + 1])
+ << " '";
+ break;
+ case 1:
+ DVLOG(1) << pfx
+ << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
+ << " '"
+ << Asciify(data[i + 0])
+ << " '";
+ break;
+ }
+ }
+}
+
+template <MockReadWriteType type>
+void DumpMockReadWrite(const MockReadWrite<type>& r) {
+ if (logging::LOG_INFO < logging::GetMinLogLevel())
+ return;
+ DVLOG(1) << "Async: " << (r.mode == ASYNC)
+ << "\nResult: " << r.result;
+ DumpData(r.data, r.data_len);
+ const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
+ DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop
+ << "\nTime: " << r.time_stamp.ToInternalValue();
+}
+
+} // namespace
+
+MockConnect::MockConnect() : mode(ASYNC), result(OK) {
+ IPAddressNumber ip;
+ CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
+ peer_addr = IPEndPoint(ip, 0);
+}
+
+MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
+ IPAddressNumber ip;
+ CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
+ peer_addr = IPEndPoint(ip, 0);
+}
+
+MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) :
+ mode(io_mode),
+ result(r),
+ peer_addr(addr) {
+}
+
+MockConnect::~MockConnect() {}
+
+StaticSocketDataProvider::StaticSocketDataProvider()
+ : reads_(NULL),
+ read_index_(0),
+ read_count_(0),
+ writes_(NULL),
+ write_index_(0),
+ write_count_(0) {
+}
+
+StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads,
+ size_t reads_count,
+ MockWrite* writes,
+ size_t writes_count)
+ : reads_(reads),
+ read_index_(0),
+ read_count_(reads_count),
+ writes_(writes),
+ write_index_(0),
+ write_count_(writes_count) {
+}
+
+StaticSocketDataProvider::~StaticSocketDataProvider() {}
+
+const MockRead& StaticSocketDataProvider::PeekRead() const {
+ DCHECK(!at_read_eof());
+ return reads_[read_index_];
+}
+
+const MockWrite& StaticSocketDataProvider::PeekWrite() const {
+ DCHECK(!at_write_eof());
+ return writes_[write_index_];
+}
+
+const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const {
+ DCHECK_LT(index, read_count_);
+ return reads_[index];
+}
+
+const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const {
+ DCHECK_LT(index, write_count_);
+ return writes_[index];
+}
+
+MockRead StaticSocketDataProvider::GetNextRead() {
+ DCHECK(!at_read_eof());
+ reads_[read_index_].time_stamp = base::Time::Now();
+ return reads_[read_index_++];
+}
+
+MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
+ if (!writes_) {
+ // Not using mock writes; succeed synchronously.
+ return MockWriteResult(SYNCHRONOUS, data.length());
+ }
+ DCHECK(!at_write_eof());
+
+ // Check that what we are writing matches the expectation.
+ // Then give the mocked return value.
+ MockWrite* w = &writes_[write_index_++];
+ w->time_stamp = base::Time::Now();
+ int result = w->result;
+ if (w->data) {
+ // Note - we can simulate a partial write here. If the expected data
+ // is a match, but shorter than the write actually written, that is legal.
+ // Example:
+ // Application writes "foobarbaz" (9 bytes)
+ // Expected write was "foo" (3 bytes)
+ // This is a success, and we return 3 to the application.
+ std::string expected_data(w->data, w->data_len);
+ EXPECT_GE(data.length(), expected_data.length());
+ std::string actual_data(data.substr(0, w->data_len));
+ EXPECT_EQ(expected_data, actual_data);
+ if (expected_data != actual_data)
+ return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
+ if (result == OK)
+ result = w->data_len;
+ }
+ return MockWriteResult(w->mode, result);
+}
+
+void StaticSocketDataProvider::Reset() {
+ read_index_ = 0;
+ write_index_ = 0;
+}
+
+DynamicSocketDataProvider::DynamicSocketDataProvider()
+ : short_read_limit_(0),
+ allow_unconsumed_reads_(false) {
+}
+
+DynamicSocketDataProvider::~DynamicSocketDataProvider() {}
+
+MockRead DynamicSocketDataProvider::GetNextRead() {
+ if (reads_.empty())
+ return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
+ MockRead result = reads_.front();
+ if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
+ reads_.pop_front();
+ } else {
+ result.data_len = short_read_limit_;
+ reads_.front().data += result.data_len;
+ reads_.front().data_len -= result.data_len;
+ }
+ return result;
+}
+
+void DynamicSocketDataProvider::Reset() {
+ reads_.clear();
+}
+
+void DynamicSocketDataProvider::SimulateRead(const char* data,
+ const size_t length) {
+ if (!allow_unconsumed_reads_) {
+ EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
+ }
+ reads_.push_back(MockRead(ASYNC, data, length));
+}
+
+SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
+ : connect(mode, result),
+ next_proto_status(SSLClientSocket::kNextProtoUnsupported),
+ was_npn_negotiated(false),
+ protocol_negotiated(kProtoUnknown),
+ client_cert_sent(false),
+ cert_request_info(NULL),
+ channel_id_sent(false) {
+}
+
+SSLSocketDataProvider::~SSLSocketDataProvider() {
+}
+
+void SSLSocketDataProvider::SetNextProto(NextProto proto) {
+ was_npn_negotiated = true;
+ next_proto_status = SSLClientSocket::kNextProtoNegotiated;
+ protocol_negotiated = proto;
+ next_proto = SSLClientSocket::NextProtoToString(proto);
+}
+
+DelayedSocketData::DelayedSocketData(
+ int write_delay, MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count)
+ : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
+ write_delay_(write_delay),
+ read_in_progress_(false),
+ weak_factory_(this) {
+ DCHECK_GE(write_delay_, 0);
+}
+
+DelayedSocketData::DelayedSocketData(
+ const MockConnect& connect, int write_delay, MockRead* reads,
+ size_t reads_count, MockWrite* writes, size_t writes_count)
+ : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
+ write_delay_(write_delay),
+ read_in_progress_(false),
+ weak_factory_(this) {
+ DCHECK_GE(write_delay_, 0);
+ set_connect_data(connect);
+}
+
+DelayedSocketData::~DelayedSocketData() {
+}
+
+void DelayedSocketData::ForceNextRead() {
+ DCHECK(read_in_progress_);
+ write_delay_ = 0;
+ CompleteRead();
+}
+
+MockRead DelayedSocketData::GetNextRead() {
+ MockRead out = MockRead(ASYNC, ERR_IO_PENDING);
+ if (write_delay_ <= 0)
+ out = StaticSocketDataProvider::GetNextRead();
+ read_in_progress_ = (out.result == ERR_IO_PENDING);
+ return out;
+}
+
+MockWriteResult DelayedSocketData::OnWrite(const std::string& data) {
+ MockWriteResult rv = StaticSocketDataProvider::OnWrite(data);
+ // Now that our write has completed, we can allow reads to continue.
+ if (!--write_delay_ && read_in_progress_)
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&DelayedSocketData::CompleteRead,
+ weak_factory_.GetWeakPtr()),
+ base::TimeDelta::FromMilliseconds(100));
+ return rv;
+}
+
+void DelayedSocketData::Reset() {
+ set_socket(NULL);
+ read_in_progress_ = false;
+ weak_factory_.InvalidateWeakPtrs();
+ StaticSocketDataProvider::Reset();
+}
+
+void DelayedSocketData::CompleteRead() {
+ if (socket() && read_in_progress_)
+ socket()->OnReadComplete(GetNextRead());
+}
+
+OrderedSocketData::OrderedSocketData(
+ MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count)
+ : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
+ sequence_number_(0), loop_stop_stage_(0),
+ blocked_(false), weak_factory_(this) {
+}
+
+OrderedSocketData::OrderedSocketData(
+ const MockConnect& connect,
+ MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count)
+ : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
+ sequence_number_(0), loop_stop_stage_(0),
+ blocked_(false), weak_factory_(this) {
+ set_connect_data(connect);
+}
+
+void OrderedSocketData::EndLoop() {
+ // If we've already stopped the loop, don't do it again until we've advanced
+ // to the next sequence_number.
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()";
+ if (loop_stop_stage_ > 0) {
+ const MockRead& next_read = StaticSocketDataProvider::PeekRead();
+ if ((next_read.sequence_number & ~MockRead::STOPLOOP) >
+ loop_stop_stage_) {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": Clearing stop index";
+ loop_stop_stage_ = 0;
+ } else {
+ return;
+ }
+ }
+ // Record the sequence_number at which we stopped the loop.
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": Posting Quit at read " << read_index();
+ loop_stop_stage_ = sequence_number_;
+}
+
+MockRead OrderedSocketData::GetNextRead() {
+ weak_factory_.InvalidateWeakPtrs();
+ blocked_ = false;
+ const MockRead& next_read = StaticSocketDataProvider::PeekRead();
+ if (next_read.sequence_number & MockRead::STOPLOOP)
+ EndLoop();
+ if ((next_read.sequence_number & ~MockRead::STOPLOOP) <=
+ sequence_number_++) {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
+ << ": Read " << read_index();
+ DumpMockReadWrite(next_read);
+ blocked_ = (next_read.result == ERR_IO_PENDING);
+ return StaticSocketDataProvider::GetNextRead();
+ }
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1
+ << ": I/O Pending";
+ MockRead result = MockRead(ASYNC, ERR_IO_PENDING);
+ DumpMockReadWrite(result);
+ blocked_ = true;
+ return result;
+}
+
+MockWriteResult OrderedSocketData::OnWrite(const std::string& data) {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": Write " << write_index();
+ DumpMockReadWrite(PeekWrite());
+ ++sequence_number_;
+ if (blocked_) {
+ // TODO(willchan): This 100ms delay seems to work around some weirdness. We
+ // should probably fix the weirdness. One example is in SpdyStream,
+ // DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the
+ // SYN_REPLY causes OnResponseReceived() to get called before
+ // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED().
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&OrderedSocketData::CompleteRead,
+ weak_factory_.GetWeakPtr()),
+ base::TimeDelta::FromMilliseconds(100));
+ }
+ return StaticSocketDataProvider::OnWrite(data);
+}
+
+void OrderedSocketData::Reset() {
+ NET_TRACE(INFO, " *** ") << "Stage "
+ << sequence_number_ << ": Reset()";
+ sequence_number_ = 0;
+ loop_stop_stage_ = 0;
+ set_socket(NULL);
+ weak_factory_.InvalidateWeakPtrs();
+ StaticSocketDataProvider::Reset();
+}
+
+void OrderedSocketData::CompleteRead() {
+ if (socket() && blocked_) {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_;
+ socket()->OnReadComplete(GetNextRead());
+ }
+}
+
+OrderedSocketData::~OrderedSocketData() {}
+
+DeterministicSocketData::DeterministicSocketData(MockRead* reads,
+ size_t reads_count, MockWrite* writes, size_t writes_count)
+ : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
+ sequence_number_(0),
+ current_read_(),
+ current_write_(),
+ stopping_sequence_number_(0),
+ stopped_(false),
+ print_debug_(false),
+ is_running_(false) {
+ VerifyCorrectSequenceNumbers(reads, reads_count, writes, writes_count);
+}
+
+DeterministicSocketData::~DeterministicSocketData() {}
+
+void DeterministicSocketData::Run() {
+ DCHECK(!is_running_);
+ is_running_ = true;
+
+ SetStopped(false);
+ int counter = 0;
+ // Continue to consume data until all data has run out, or the stopped_ flag
+ // has been set. Consuming data requires two separate operations -- running
+ // the tasks in the message loop, and explicitly invoking the read/write
+ // callbacks (simulating network I/O). We check our conditions between each,
+ // since they can change in either.
+ while ((!at_write_eof() || !at_read_eof()) && !stopped()) {
+ if (counter % 2 == 0)
+ base::RunLoop().RunUntilIdle();
+ if (counter % 2 == 1) {
+ InvokeCallbacks();
+ }
+ counter++;
+ }
+ // We're done consuming new data, but it is possible there are still some
+ // pending callbacks which we expect to complete before returning.
+ while (delegate_.get() &&
+ (delegate_->WritePending() || delegate_->ReadPending()) &&
+ !stopped()) {
+ InvokeCallbacks();
+ base::RunLoop().RunUntilIdle();
+ }
+ SetStopped(false);
+ is_running_ = false;
+}
+
+void DeterministicSocketData::RunFor(int steps) {
+ StopAfter(steps);
+ Run();
+}
+
+void DeterministicSocketData::SetStop(int seq) {
+ DCHECK_LT(sequence_number_, seq);
+ stopping_sequence_number_ = seq;
+ stopped_ = false;
+}
+
+void DeterministicSocketData::StopAfter(int seq) {
+ SetStop(sequence_number_ + seq);
+}
+
+MockRead DeterministicSocketData::GetNextRead() {
+ current_read_ = StaticSocketDataProvider::PeekRead();
+
+ // Synchronous read while stopped is an error
+ if (stopped() && current_read_.mode == SYNCHRONOUS) {
+ LOG(ERROR) << "Unable to perform synchronous IO while stopped";
+ return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
+ }
+
+ // Async read which will be called back in a future step.
+ if (sequence_number_ < current_read_.sequence_number) {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": I/O Pending";
+ MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING);
+ if (current_read_.mode == SYNCHRONOUS) {
+ LOG(ERROR) << "Unable to perform synchronous read: "
+ << current_read_.sequence_number
+ << " at stage: " << sequence_number_;
+ result = MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
+ }
+ if (print_debug_)
+ DumpMockReadWrite(result);
+ return result;
+ }
+
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": Read " << read_index();
+ if (print_debug_)
+ DumpMockReadWrite(current_read_);
+
+ // Increment the sequence number if IO is complete
+ if (current_read_.mode == SYNCHRONOUS)
+ NextStep();
+
+ DCHECK_NE(ERR_IO_PENDING, current_read_.result);
+ StaticSocketDataProvider::GetNextRead();
+
+ return current_read_;
+}
+
+MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) {
+ const MockWrite& next_write = StaticSocketDataProvider::PeekWrite();
+ current_write_ = next_write;
+
+ // Synchronous write while stopped is an error
+ if (stopped() && next_write.mode == SYNCHRONOUS) {
+ LOG(ERROR) << "Unable to perform synchronous IO while stopped";
+ return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
+ }
+
+ // Async write which will be called back in a future step.
+ if (sequence_number_ < next_write.sequence_number) {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": I/O Pending";
+ if (next_write.mode == SYNCHRONOUS) {
+ LOG(ERROR) << "Unable to perform synchronous write: "
+ << next_write.sequence_number << " at stage: " << sequence_number_;
+ return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
+ }
+ } else {
+ NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_
+ << ": Write " << write_index();
+ }
+
+ if (print_debug_)
+ DumpMockReadWrite(next_write);
+
+ // Move to the next step if I/O is synchronous, since the operation will
+ // complete when this method returns.
+ if (next_write.mode == SYNCHRONOUS)
+ NextStep();
+
+ // This is either a sync write for this step, or an async write.
+ return StaticSocketDataProvider::OnWrite(data);
+}
+
+void DeterministicSocketData::Reset() {
+ NET_TRACE(INFO, " *** ") << "Stage "
+ << sequence_number_ << ": Reset()";
+ sequence_number_ = 0;
+ StaticSocketDataProvider::Reset();
+ NOTREACHED();
+}
+
+void DeterministicSocketData::InvokeCallbacks() {
+ if (delegate_.get() && delegate_->WritePending() &&
+ (current_write().sequence_number == sequence_number())) {
+ NextStep();
+ delegate_->CompleteWrite();
+ return;
+ }
+ if (delegate_.get() && delegate_->ReadPending() &&
+ (current_read().sequence_number == sequence_number())) {
+ NextStep();
+ delegate_->CompleteRead();
+ return;
+ }
+}
+
+void DeterministicSocketData::NextStep() {
+ // Invariant: Can never move *past* the stopping step.
+ DCHECK_LT(sequence_number_, stopping_sequence_number_);
+ sequence_number_++;
+ if (sequence_number_ == stopping_sequence_number_)
+ SetStopped(true);
+}
+
+void DeterministicSocketData::VerifyCorrectSequenceNumbers(
+ MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count) {
+ size_t read = 0;
+ size_t write = 0;
+ int expected = 0;
+ while (read < reads_count || write < writes_count) {
+ // Check to see that we have a read or write at the expected
+ // state.
+ if (read < reads_count && reads[read].sequence_number == expected) {
+ ++read;
+ ++expected;
+ continue;
+ }
+ if (write < writes_count && writes[write].sequence_number == expected) {
+ ++write;
+ ++expected;
+ continue;
+ }
+ NOTREACHED() << "Missing sequence number: " << expected;
+ return;
+ }
+ DCHECK_EQ(read, reads_count);
+ DCHECK_EQ(write, writes_count);
+}
+
+MockClientSocketFactory::MockClientSocketFactory() {}
+
+MockClientSocketFactory::~MockClientSocketFactory() {}
+
+void MockClientSocketFactory::AddSocketDataProvider(
+ SocketDataProvider* data) {
+ mock_data_.Add(data);
+}
+
+void MockClientSocketFactory::AddSSLSocketDataProvider(
+ SSLSocketDataProvider* data) {
+ mock_ssl_data_.Add(data);
+}
+
+void MockClientSocketFactory::ResetNextMockIndexes() {
+ mock_data_.ResetNextIndex();
+ mock_ssl_data_.ResetNextIndex();
+}
+
+scoped_ptr<DatagramClientSocket>
+MockClientSocketFactory::CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source) {
+ SocketDataProvider* data_provider = mock_data_.GetNext();
+ scoped_ptr<MockUDPClientSocket> socket(
+ new MockUDPClientSocket(data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
+}
+
+scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
+ const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source) {
+ SocketDataProvider* data_provider = mock_data_.GetNext();
+ scoped_ptr<MockTCPClientSocket> socket(
+ new MockTCPClientSocket(addresses, net_log, data_provider));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<StreamSocket>();
+}
+
+scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) {
+ return scoped_ptr<SSLClientSocket>(
+ new MockSSLClientSocket(transport_socket.Pass(),
+ host_and_port, ssl_config,
+ mock_ssl_data_.GetNext()));
+}
+
+void MockClientSocketFactory::ClearSSLSessionCache() {
+}
+
+const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ";
+
+MockClientSocket::MockClientSocket(const BoundNetLog& net_log)
+ : weak_factory_(this),
+ connected_(false),
+ net_log_(net_log) {
+ IPAddressNumber ip;
+ CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
+ peer_addr_ = IPEndPoint(ip, 0);
+}
+
+bool MockClientSocket::SetReceiveBufferSize(int32 size) {
+ return true;
+}
+
+bool MockClientSocket::SetSendBufferSize(int32 size) {
+ return true;
+}
+
+void MockClientSocket::Disconnect() {
+ connected_ = false;
+}
+
+bool MockClientSocket::IsConnected() const {
+ return connected_;
+}
+
+bool MockClientSocket::IsConnectedAndIdle() const {
+ return connected_;
+}
+
+int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ *address = peer_addr_;
+ return OK;
+}
+
+int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
+ IPAddressNumber ip;
+ bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
+ CHECK(rv);
+ *address = IPEndPoint(ip, 123);
+ return OK;
+}
+
+const BoundNetLog& MockClientSocket::NetLog() const {
+ return net_log_;
+}
+
+void MockClientSocket::GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) {
+}
+
+int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) {
+ memset(out, 'A', outlen);
+ return OK;
+}
+
+int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) {
+ out->assign(MockClientSocket::kTlsUnique);
+ return OK;
+}
+
+ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const {
+ NOTREACHED();
+ return NULL;
+}
+
+SSLClientSocket::NextProtoStatus
+MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) {
+ proto->clear();
+ server_protos->clear();
+ return SSLClientSocket::kNextProtoUnsupported;
+}
+
+MockClientSocket::~MockClientSocket() {}
+
+void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback,
+ int result) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&MockClientSocket::RunCallback,
+ weak_factory_.GetWeakPtr(),
+ callback,
+ result));
+}
+
+void MockClientSocket::RunCallback(const net::CompletionCallback& callback,
+ int result) {
+ if (!callback.is_null())
+ callback.Run(result);
+}
+
+MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
+ net::NetLog* net_log,
+ SocketDataProvider* data)
+ : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
+ addresses_(addresses),
+ data_(data),
+ read_offset_(0),
+ read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
+ need_read_data_(true),
+ peer_closed_connection_(false),
+ pending_buf_(NULL),
+ pending_buf_len_(0),
+ was_used_to_convey_data_(false) {
+ DCHECK(data_);
+ peer_addr_ = data->connect_data().peer_addr;
+ data_->Reset();
+}
+
+MockTCPClientSocket::~MockTCPClientSocket() {}
+
+int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ // If the buffer is already in use, a read is already in progress!
+ DCHECK(pending_buf_ == NULL);
+
+ // Store our async IO data.
+ pending_buf_ = buf;
+ pending_buf_len_ = buf_len;
+ pending_callback_ = callback;
+
+ if (need_read_data_) {
+ read_data_ = data_->GetNextRead();
+ if (read_data_.result == ERR_CONNECTION_CLOSED) {
+ // This MockRead is just a marker to instruct us to set
+ // peer_closed_connection_.
+ peer_closed_connection_ = true;
+ }
+ if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
+ // This MockRead is just a marker to instruct us to set
+ // peer_closed_connection_. Skip it and get the next one.
+ read_data_ = data_->GetNextRead();
+ peer_closed_connection_ = true;
+ }
+ // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
+ // to complete the async IO manually later (via OnReadComplete).
+ if (read_data_.result == ERR_IO_PENDING) {
+ // We need to be using async IO in this case.
+ DCHECK(!callback.is_null());
+ return ERR_IO_PENDING;
+ }
+ need_read_data_ = false;
+ }
+
+ return CompleteRead();
+}
+
+int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(buf);
+ DCHECK_GT(buf_len, 0);
+
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ std::string data(buf->data(), buf_len);
+ MockWriteResult write_result = data_->OnWrite(data);
+
+ was_used_to_convey_data_ = true;
+
+ if (write_result.mode == ASYNC) {
+ RunCallbackAsync(callback, write_result.result);
+ return ERR_IO_PENDING;
+ }
+
+ return write_result.result;
+}
+
+int MockTCPClientSocket::Connect(const CompletionCallback& callback) {
+ if (connected_)
+ return OK;
+ connected_ = true;
+ peer_closed_connection_ = false;
+ if (data_->connect_data().mode == ASYNC) {
+ if (data_->connect_data().result == ERR_IO_PENDING)
+ pending_callback_ = callback;
+ else
+ RunCallbackAsync(callback, data_->connect_data().result);
+ return ERR_IO_PENDING;
+ }
+ return data_->connect_data().result;
+}
+
+void MockTCPClientSocket::Disconnect() {
+ MockClientSocket::Disconnect();
+ pending_callback_.Reset();
+}
+
+bool MockTCPClientSocket::IsConnected() const {
+ return connected_ && !peer_closed_connection_;
+}
+
+bool MockTCPClientSocket::IsConnectedAndIdle() const {
+ return IsConnected();
+}
+
+int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ if (addresses_.empty())
+ return MockClientSocket::GetPeerAddress(address);
+
+ *address = addresses_[0];
+ return OK;
+}
+
+bool MockTCPClientSocket::WasEverUsed() const {
+ return was_used_to_convey_data_;
+}
+
+bool MockTCPClientSocket::UsingTCPFastOpen() const {
+ return false;
+}
+
+bool MockTCPClientSocket::WasNpnNegotiated() const {
+ return false;
+}
+
+bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ return false;
+}
+
+void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
+ // There must be a read pending.
+ DCHECK(pending_buf_);
+ // You can't complete a read with another ERR_IO_PENDING status code.
+ DCHECK_NE(ERR_IO_PENDING, data.result);
+ // Since we've been waiting for data, need_read_data_ should be true.
+ DCHECK(need_read_data_);
+
+ read_data_ = data;
+ need_read_data_ = false;
+
+ // The caller is simulating that this IO completes right now. Don't
+ // let CompleteRead() schedule a callback.
+ read_data_.mode = SYNCHRONOUS;
+
+ CompletionCallback callback = pending_callback_;
+ int rv = CompleteRead();
+ RunCallback(callback, rv);
+}
+
+void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
+ CompletionCallback callback = pending_callback_;
+ RunCallback(callback, data.result);
+}
+
+int MockTCPClientSocket::CompleteRead() {
+ DCHECK(pending_buf_);
+ DCHECK(pending_buf_len_ > 0);
+
+ was_used_to_convey_data_ = true;
+
+ // Save the pending async IO data and reset our |pending_| state.
+ IOBuffer* buf = pending_buf_;
+ int buf_len = pending_buf_len_;
+ CompletionCallback callback = pending_callback_;
+ pending_buf_ = NULL;
+ pending_buf_len_ = 0;
+ pending_callback_.Reset();
+
+ int result = read_data_.result;
+ DCHECK(result != ERR_IO_PENDING);
+
+ if (read_data_.data) {
+ if (read_data_.data_len - read_offset_ > 0) {
+ result = std::min(buf_len, read_data_.data_len - read_offset_);
+ memcpy(buf->data(), read_data_.data + read_offset_, result);
+ read_offset_ += result;
+ if (read_offset_ == read_data_.data_len) {
+ need_read_data_ = true;
+ read_offset_ = 0;
+ }
+ } else {
+ result = 0; // EOF
+ }
+ }
+
+ if (read_data_.mode == ASYNC) {
+ DCHECK(!callback.is_null());
+ RunCallbackAsync(callback, result);
+ return ERR_IO_PENDING;
+ }
+ return result;
+}
+
+DeterministicSocketHelper::DeterministicSocketHelper(
+ net::NetLog* net_log,
+ DeterministicSocketData* data)
+ : write_pending_(false),
+ write_result_(0),
+ read_data_(),
+ read_buf_(NULL),
+ read_buf_len_(0),
+ read_pending_(false),
+ data_(data),
+ was_used_to_convey_data_(false),
+ peer_closed_connection_(false),
+ net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) {
+}
+
+DeterministicSocketHelper::~DeterministicSocketHelper() {}
+
+void DeterministicSocketHelper::CompleteWrite() {
+ was_used_to_convey_data_ = true;
+ write_pending_ = false;
+ write_callback_.Run(write_result_);
+}
+
+int DeterministicSocketHelper::CompleteRead() {
+ DCHECK_GT(read_buf_len_, 0);
+ DCHECK_LE(read_data_.data_len, read_buf_len_);
+ DCHECK(read_buf_);
+
+ was_used_to_convey_data_ = true;
+
+ if (read_data_.result == ERR_IO_PENDING)
+ read_data_ = data_->GetNextRead();
+ DCHECK_NE(ERR_IO_PENDING, read_data_.result);
+ // If read_data_.mode is ASYNC, we do not need to wait, since this is already
+ // the callback. Therefore we don't even bother to check it.
+ int result = read_data_.result;
+
+ if (read_data_.data_len > 0) {
+ DCHECK(read_data_.data);
+ result = std::min(read_buf_len_, read_data_.data_len);
+ memcpy(read_buf_->data(), read_data_.data, result);
+ }
+
+ if (read_pending_) {
+ read_pending_ = false;
+ read_callback_.Run(result);
+ }
+
+ return result;
+}
+
+int DeterministicSocketHelper::Write(
+ IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
+ DCHECK(buf);
+ DCHECK_GT(buf_len, 0);
+
+ std::string data(buf->data(), buf_len);
+ MockWriteResult write_result = data_->OnWrite(data);
+
+ if (write_result.mode == ASYNC) {
+ write_callback_ = callback;
+ write_result_ = write_result.result;
+ DCHECK(!write_callback_.is_null());
+ write_pending_ = true;
+ return ERR_IO_PENDING;
+ }
+
+ was_used_to_convey_data_ = true;
+ write_pending_ = false;
+ return write_result.result;
+}
+
+int DeterministicSocketHelper::Read(
+ IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
+
+ read_data_ = data_->GetNextRead();
+ // The buffer should always be big enough to contain all the MockRead data. To
+ // use small buffers, split the data into multiple MockReads.
+ DCHECK_LE(read_data_.data_len, buf_len);
+
+ if (read_data_.result == ERR_CONNECTION_CLOSED) {
+ // This MockRead is just a marker to instruct us to set
+ // peer_closed_connection_.
+ peer_closed_connection_ = true;
+ }
+ if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
+ // This MockRead is just a marker to instruct us to set
+ // peer_closed_connection_. Skip it and get the next one.
+ read_data_ = data_->GetNextRead();
+ peer_closed_connection_ = true;
+ }
+
+ read_buf_ = buf;
+ read_buf_len_ = buf_len;
+ read_callback_ = callback;
+
+ if (read_data_.mode == ASYNC || (read_data_.result == ERR_IO_PENDING)) {
+ read_pending_ = true;
+ DCHECK(!read_callback_.is_null());
+ return ERR_IO_PENDING;
+ }
+
+ was_used_to_convey_data_ = true;
+ return CompleteRead();
+}
+
+DeterministicMockUDPClientSocket::DeterministicMockUDPClientSocket(
+ net::NetLog* net_log,
+ DeterministicSocketData* data)
+ : connected_(false),
+ helper_(net_log, data) {
+}
+
+DeterministicMockUDPClientSocket::~DeterministicMockUDPClientSocket() {}
+
+bool DeterministicMockUDPClientSocket::WritePending() const {
+ return helper_.write_pending();
+}
+
+bool DeterministicMockUDPClientSocket::ReadPending() const {
+ return helper_.read_pending();
+}
+
+void DeterministicMockUDPClientSocket::CompleteWrite() {
+ helper_.CompleteWrite();
+}
+
+int DeterministicMockUDPClientSocket::CompleteRead() {
+ return helper_.CompleteRead();
+}
+
+int DeterministicMockUDPClientSocket::Connect(const IPEndPoint& address) {
+ if (connected_)
+ return OK;
+ connected_ = true;
+ peer_address_ = address;
+ return helper_.data()->connect_data().result;
+};
+
+int DeterministicMockUDPClientSocket::Write(
+ IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ return helper_.Write(buf, buf_len, callback);
+}
+
+int DeterministicMockUDPClientSocket::Read(
+ IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ return helper_.Read(buf, buf_len, callback);
+}
+
+bool DeterministicMockUDPClientSocket::SetReceiveBufferSize(int32 size) {
+ return true;
+}
+
+bool DeterministicMockUDPClientSocket::SetSendBufferSize(int32 size) {
+ return true;
+}
+
+void DeterministicMockUDPClientSocket::Close() {
+ connected_ = false;
+}
+
+int DeterministicMockUDPClientSocket::GetPeerAddress(
+ IPEndPoint* address) const {
+ *address = peer_address_;
+ return OK;
+}
+
+int DeterministicMockUDPClientSocket::GetLocalAddress(
+ IPEndPoint* address) const {
+ IPAddressNumber ip;
+ bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
+ CHECK(rv);
+ *address = IPEndPoint(ip, 123);
+ return OK;
+}
+
+const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const {
+ return helper_.net_log();
+}
+
+void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {}
+
+void DeterministicMockUDPClientSocket::OnConnectComplete(
+ const MockConnect& data) {
+ NOTIMPLEMENTED();
+}
+
+DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket(
+ net::NetLog* net_log,
+ DeterministicSocketData* data)
+ : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
+ helper_(net_log, data) {
+ peer_addr_ = data->connect_data().peer_addr;
+}
+
+DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {}
+
+bool DeterministicMockTCPClientSocket::WritePending() const {
+ return helper_.write_pending();
+}
+
+bool DeterministicMockTCPClientSocket::ReadPending() const {
+ return helper_.read_pending();
+}
+
+void DeterministicMockTCPClientSocket::CompleteWrite() {
+ helper_.CompleteWrite();
+}
+
+int DeterministicMockTCPClientSocket::CompleteRead() {
+ return helper_.CompleteRead();
+}
+
+int DeterministicMockTCPClientSocket::Write(
+ IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ return helper_.Write(buf, buf_len, callback);
+}
+
+int DeterministicMockTCPClientSocket::Read(
+ IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ return helper_.Read(buf, buf_len, callback);
+}
+
+// TODO(erikchen): Support connect sequencing.
+int DeterministicMockTCPClientSocket::Connect(
+ const CompletionCallback& callback) {
+ if (connected_)
+ return OK;
+ connected_ = true;
+ if (helper_.data()->connect_data().mode == ASYNC) {
+ RunCallbackAsync(callback, helper_.data()->connect_data().result);
+ return ERR_IO_PENDING;
+ }
+ return helper_.data()->connect_data().result;
+}
+
+void DeterministicMockTCPClientSocket::Disconnect() {
+ MockClientSocket::Disconnect();
+}
+
+bool DeterministicMockTCPClientSocket::IsConnected() const {
+ return connected_ && !helper_.peer_closed_connection();
+}
+
+bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const {
+ return IsConnected();
+}
+
+bool DeterministicMockTCPClientSocket::WasEverUsed() const {
+ return helper_.was_used_to_convey_data();
+}
+
+bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const {
+ return false;
+}
+
+bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const {
+ return false;
+}
+
+bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ return false;
+}
+
+void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
+
+void DeterministicMockTCPClientSocket::OnConnectComplete(
+ const MockConnect& data) {}
+
+// static
+void MockSSLClientSocket::ConnectCallback(
+ MockSSLClientSocket* ssl_client_socket,
+ const CompletionCallback& callback,
+ int rv) {
+ if (rv == OK)
+ ssl_client_socket->connected_ = true;
+ callback.Run(rv);
+}
+
+MockSSLClientSocket::MockSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_port_pair,
+ const SSLConfig& ssl_config,
+ SSLSocketDataProvider* data)
+ : MockClientSocket(
+ // Have to use the right BoundNetLog for LoadTimingInfo regression
+ // tests.
+ transport_socket->socket()->NetLog()),
+ transport_(transport_socket.Pass()),
+ data_(data),
+ is_npn_state_set_(false),
+ new_npn_value_(false),
+ is_protocol_negotiated_set_(false),
+ protocol_negotiated_(kProtoUnknown) {
+ DCHECK(data_);
+ peer_addr_ = data->connect.peer_addr;
+}
+
+MockSSLClientSocket::~MockSSLClientSocket() {
+ Disconnect();
+}
+
+int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ return transport_->socket()->Read(buf, buf_len, callback);
+}
+
+int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ return transport_->socket()->Write(buf, buf_len, callback);
+}
+
+int MockSSLClientSocket::Connect(const CompletionCallback& callback) {
+ int rv = transport_->socket()->Connect(
+ base::Bind(&ConnectCallback, base::Unretained(this), callback));
+ if (rv == OK) {
+ if (data_->connect.result == OK)
+ connected_ = true;
+ if (data_->connect.mode == ASYNC) {
+ RunCallbackAsync(callback, data_->connect.result);
+ return ERR_IO_PENDING;
+ }
+ return data_->connect.result;
+ }
+ return rv;
+}
+
+void MockSSLClientSocket::Disconnect() {
+ MockClientSocket::Disconnect();
+ if (transport_->socket() != NULL)
+ transport_->socket()->Disconnect();
+}
+
+bool MockSSLClientSocket::IsConnected() const {
+ return transport_->socket()->IsConnected();
+}
+
+bool MockSSLClientSocket::WasEverUsed() const {
+ return transport_->socket()->WasEverUsed();
+}
+
+bool MockSSLClientSocket::UsingTCPFastOpen() const {
+ return transport_->socket()->UsingTCPFastOpen();
+}
+
+int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetPeerAddress(address);
+}
+
+bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ ssl_info->Reset();
+ ssl_info->cert = data_->cert;
+ ssl_info->client_cert_sent = data_->client_cert_sent;
+ ssl_info->channel_id_sent = data_->channel_id_sent;
+ return true;
+}
+
+void MockSSLClientSocket::GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) {
+ DCHECK(cert_request_info);
+ if (data_->cert_request_info) {
+ cert_request_info->host_and_port =
+ data_->cert_request_info->host_and_port;
+ cert_request_info->client_certs = data_->cert_request_info->client_certs;
+ } else {
+ cert_request_info->Reset();
+ }
+}
+
+SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
+ std::string* proto, std::string* server_protos) {
+ *proto = data_->next_proto;
+ *server_protos = data_->server_protos;
+ return data_->next_proto_status;
+}
+
+bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) {
+ is_npn_state_set_ = true;
+ return new_npn_value_ = negotiated;
+}
+
+bool MockSSLClientSocket::WasNpnNegotiated() const {
+ if (is_npn_state_set_)
+ return new_npn_value_;
+ return data_->was_npn_negotiated;
+}
+
+NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
+ if (is_protocol_negotiated_set_)
+ return protocol_negotiated_;
+ return data_->protocol_negotiated;
+}
+
+void MockSSLClientSocket::set_protocol_negotiated(
+ NextProto protocol_negotiated) {
+ is_protocol_negotiated_set_ = true;
+ protocol_negotiated_ = protocol_negotiated;
+}
+
+bool MockSSLClientSocket::WasChannelIDSent() const {
+ return data_->channel_id_sent;
+}
+
+void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) {
+ data_->channel_id_sent = channel_id_sent;
+}
+
+ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const {
+ return data_->server_bound_cert_service;
+}
+
+void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
+ NOTIMPLEMENTED();
+}
+
+void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
+ NOTIMPLEMENTED();
+}
+
+MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
+ net::NetLog* net_log)
+ : connected_(false),
+ data_(data),
+ read_offset_(0),
+ read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
+ need_read_data_(true),
+ pending_buf_(NULL),
+ pending_buf_len_(0),
+ net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
+ weak_factory_(this) {
+ DCHECK(data_);
+ data_->Reset();
+ peer_addr_ = data->connect_data().peer_addr;
+}
+
+MockUDPClientSocket::~MockUDPClientSocket() {}
+
+int MockUDPClientSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ // If the buffer is already in use, a read is already in progress!
+ DCHECK(pending_buf_ == NULL);
+
+ // Store our async IO data.
+ pending_buf_ = buf;
+ pending_buf_len_ = buf_len;
+ pending_callback_ = callback;
+
+ if (need_read_data_) {
+ read_data_ = data_->GetNextRead();
+ // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
+ // to complete the async IO manually later (via OnReadComplete).
+ if (read_data_.result == ERR_IO_PENDING) {
+ // We need to be using async IO in this case.
+ DCHECK(!callback.is_null());
+ return ERR_IO_PENDING;
+ }
+ need_read_data_ = false;
+ }
+
+ return CompleteRead();
+}
+
+int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(buf);
+ DCHECK_GT(buf_len, 0);
+
+ if (!connected_)
+ return ERR_UNEXPECTED;
+
+ std::string data(buf->data(), buf_len);
+ MockWriteResult write_result = data_->OnWrite(data);
+
+ if (write_result.mode == ASYNC) {
+ RunCallbackAsync(callback, write_result.result);
+ return ERR_IO_PENDING;
+ }
+ return write_result.result;
+}
+
+bool MockUDPClientSocket::SetReceiveBufferSize(int32 size) {
+ return true;
+}
+
+bool MockUDPClientSocket::SetSendBufferSize(int32 size) {
+ return true;
+}
+
+void MockUDPClientSocket::Close() {
+ connected_ = false;
+}
+
+int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ *address = peer_addr_;
+ return OK;
+}
+
+int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
+ IPAddressNumber ip;
+ bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
+ CHECK(rv);
+ *address = IPEndPoint(ip, 123);
+ return OK;
+}
+
+const BoundNetLog& MockUDPClientSocket::NetLog() const {
+ return net_log_;
+}
+
+int MockUDPClientSocket::Connect(const IPEndPoint& address) {
+ connected_ = true;
+ peer_addr_ = address;
+ return OK;
+}
+
+void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
+ // There must be a read pending.
+ DCHECK(pending_buf_);
+ // You can't complete a read with another ERR_IO_PENDING status code.
+ DCHECK_NE(ERR_IO_PENDING, data.result);
+ // Since we've been waiting for data, need_read_data_ should be true.
+ DCHECK(need_read_data_);
+
+ read_data_ = data;
+ need_read_data_ = false;
+
+ // The caller is simulating that this IO completes right now. Don't
+ // let CompleteRead() schedule a callback.
+ read_data_.mode = SYNCHRONOUS;
+
+ net::CompletionCallback callback = pending_callback_;
+ int rv = CompleteRead();
+ RunCallback(callback, rv);
+}
+
+void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
+ NOTIMPLEMENTED();
+}
+
+int MockUDPClientSocket::CompleteRead() {
+ DCHECK(pending_buf_);
+ DCHECK(pending_buf_len_ > 0);
+
+ // Save the pending async IO data and reset our |pending_| state.
+ IOBuffer* buf = pending_buf_;
+ int buf_len = pending_buf_len_;
+ CompletionCallback callback = pending_callback_;
+ pending_buf_ = NULL;
+ pending_buf_len_ = 0;
+ pending_callback_.Reset();
+
+ int result = read_data_.result;
+ DCHECK(result != ERR_IO_PENDING);
+
+ if (read_data_.data) {
+ if (read_data_.data_len - read_offset_ > 0) {
+ result = std::min(buf_len, read_data_.data_len - read_offset_);
+ memcpy(buf->data(), read_data_.data + read_offset_, result);
+ read_offset_ += result;
+ if (read_offset_ == read_data_.data_len) {
+ need_read_data_ = true;
+ read_offset_ = 0;
+ }
+ } else {
+ result = 0; // EOF
+ }
+ }
+
+ if (read_data_.mode == ASYNC) {
+ DCHECK(!callback.is_null());
+ RunCallbackAsync(callback, result);
+ return ERR_IO_PENDING;
+ }
+ return result;
+}
+
+void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback,
+ int result) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&MockUDPClientSocket::RunCallback,
+ weak_factory_.GetWeakPtr(),
+ callback,
+ result));
+}
+
+void MockUDPClientSocket::RunCallback(const CompletionCallback& callback,
+ int result) {
+ if (!callback.is_null())
+ callback.Run(result);
+}
+
+TestSocketRequest::TestSocketRequest(
+ std::vector<TestSocketRequest*>* request_order, size_t* completion_count)
+ : request_order_(request_order),
+ completion_count_(completion_count),
+ callback_(base::Bind(&TestSocketRequest::OnComplete,
+ base::Unretained(this))) {
+ DCHECK(request_order);
+ DCHECK(completion_count);
+}
+
+TestSocketRequest::~TestSocketRequest() {
+}
+
+void TestSocketRequest::OnComplete(int result) {
+ SetResult(result);
+ (*completion_count_)++;
+ request_order_->push_back(this);
+}
+
+// static
+const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
+
+// static
+const int ClientSocketPoolTest::kRequestNotFound = -2;
+
+ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
+ClientSocketPoolTest::~ClientSocketPoolTest() {}
+
+int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
+ index--;
+ if (index >= requests_.size())
+ return kIndexOutOfBounds;
+
+ for (size_t i = 0; i < request_order_.size(); i++)
+ if (requests_[index] == request_order_[i])
+ return i + 1;
+
+ return kRequestNotFound;
+}
+
+bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
+ ScopedVector<TestSocketRequest>::iterator i;
+ for (i = requests_.begin(); i != requests_.end(); ++i) {
+ if ((*i)->handle()->is_initialized()) {
+ if (keep_alive == NO_KEEP_ALIVE)
+ (*i)->handle()->socket()->Disconnect();
+ (*i)->handle()->Reset();
+ base::RunLoop().RunUntilIdle();
+ return true;
+ }
+ }
+ return false;
+}
+
+void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
+ bool released_one;
+ do {
+ released_one = ReleaseOneConnection(keep_alive);
+ } while (released_one);
+}
+
+MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
+ scoped_ptr<StreamSocket> socket,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback)
+ : socket_(socket.Pass()),
+ handle_(handle),
+ user_callback_(callback) {
+}
+
+MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {}
+
+int MockTransportClientSocketPool::MockConnectJob::Connect() {
+ int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect,
+ base::Unretained(this)));
+ if (rv == OK) {
+ user_callback_.Reset();
+ OnConnect(OK);
+ }
+ return rv;
+}
+
+bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
+ const ClientSocketHandle* handle) {
+ if (handle != handle_)
+ return false;
+ socket_.reset();
+ handle_ = NULL;
+ user_callback_.Reset();
+ return true;
+}
+
+void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
+ if (!socket_.get())
+ return;
+ if (rv == OK) {
+ handle_->SetSocket(socket_.Pass());
+
+ // Needed for socket pool tests that layer other sockets on top of mock
+ // sockets.
+ LoadTimingInfo::ConnectTiming connect_timing;
+ base::TimeTicks now = base::TimeTicks::Now();
+ connect_timing.dns_start = now;
+ connect_timing.dns_end = now;
+ connect_timing.connect_start = now;
+ connect_timing.connect_end = now;
+ handle_->set_connect_timing(connect_timing);
+ } else {
+ socket_.reset();
+ }
+
+ handle_ = NULL;
+
+ if (!user_callback_.is_null()) {
+ CompletionCallback callback = user_callback_;
+ user_callback_.Reset();
+ callback.Run(rv);
+ }
+}
+
+MockTransportClientSocketPool::MockTransportClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ ClientSocketFactory* socket_factory)
+ : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
+ NULL, NULL, NULL),
+ client_socket_factory_(socket_factory),
+ release_count_(0),
+ cancel_count_(0) {
+}
+
+MockTransportClientSocketPool::~MockTransportClientSocketPool() {}
+
+int MockTransportClientSocketPool::RequestSocket(
+ const std::string& group_name, const void* socket_params,
+ RequestPriority priority, ClientSocketHandle* handle,
+ const CompletionCallback& callback, const BoundNetLog& net_log) {
+ scoped_ptr<StreamSocket> socket =
+ client_socket_factory_->CreateTransportClientSocket(
+ AddressList(), net_log.net_log(), net::NetLog::Source());
+ MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback);
+ job_list_.push_back(job);
+ handle->set_pool_id(1);
+ return job->Connect();
+}
+
+void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) {
+ std::vector<MockConnectJob*>::iterator i;
+ for (i = job_list_.begin(); i != job_list_.end(); ++i) {
+ if ((*i)->CancelHandle(handle)) {
+ cancel_count_++;
+ break;
+ }
+ }
+}
+
+void MockTransportClientSocketPool::ReleaseSocket(
+ const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ EXPECT_EQ(1, id);
+ release_count_++;
+}
+
+DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
+
+DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {}
+
+void DeterministicMockClientSocketFactory::AddSocketDataProvider(
+ DeterministicSocketData* data) {
+ mock_data_.Add(data);
+}
+
+void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider(
+ SSLSocketDataProvider* data) {
+ mock_ssl_data_.Add(data);
+}
+
+void DeterministicMockClientSocketFactory::ResetNextMockIndexes() {
+ mock_data_.ResetNextIndex();
+ mock_ssl_data_.ResetNextIndex();
+}
+
+MockSSLClientSocket* DeterministicMockClientSocketFactory::
+ GetMockSSLClientSocket(size_t index) const {
+ DCHECK_LT(index, ssl_client_sockets_.size());
+ return ssl_client_sockets_[index];
+}
+
+scoped_ptr<DatagramClientSocket>
+DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const NetLog::Source& source) {
+ DeterministicSocketData* data_provider = mock_data().GetNext();
+ scoped_ptr<DeterministicMockUDPClientSocket> socket(
+ new DeterministicMockUDPClientSocket(net_log, data_provider));
+ data_provider->set_delegate(socket->AsWeakPtr());
+ udp_client_sockets().push_back(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
+}
+
+scoped_ptr<StreamSocket>
+DeterministicMockClientSocketFactory::CreateTransportClientSocket(
+ const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source) {
+ DeterministicSocketData* data_provider = mock_data().GetNext();
+ scoped_ptr<DeterministicMockTCPClientSocket> socket(
+ new DeterministicMockTCPClientSocket(net_log, data_provider));
+ data_provider->set_delegate(socket->AsWeakPtr());
+ tcp_client_sockets().push_back(socket.get());
+ return socket.PassAs<StreamSocket>();
+}
+
+scoped_ptr<SSLClientSocket>
+DeterministicMockClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) {
+ scoped_ptr<MockSSLClientSocket> socket(
+ new MockSSLClientSocket(transport_socket.Pass(),
+ host_and_port, ssl_config,
+ mock_ssl_data_.GetNext()));
+ ssl_client_sockets_.push_back(socket.get());
+ return socket.PassAs<SSLClientSocket>();
+}
+
+void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
+}
+
+MockSOCKSClientSocketPool::MockSOCKSClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ TransportClientSocketPool* transport_pool)
+ : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms,
+ NULL, transport_pool, NULL),
+ transport_pool_(transport_pool) {
+}
+
+MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {}
+
+int MockSOCKSClientSocketPool::RequestSocket(
+ const std::string& group_name, const void* socket_params,
+ RequestPriority priority, ClientSocketHandle* handle,
+ const CompletionCallback& callback, const BoundNetLog& net_log) {
+ return transport_pool_->RequestSocket(
+ group_name, socket_params, priority, handle, callback, net_log);
+}
+
+void MockSOCKSClientSocketPool::CancelRequest(
+ const std::string& group_name,
+ ClientSocketHandle* handle) {
+ return transport_pool_->CancelRequest(group_name, handle);
+}
+
+void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id);
+}
+
+const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
+const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest);
+
+const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
+const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse);
+
+const char kSOCKS5OkRequest[] =
+ { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
+const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest);
+
+const char kSOCKS5OkResponse[] =
+ { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
+const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse);
+
+} // namespace net
diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h
new file mode 100644
index 00000000000..a888249654c
--- /dev/null
+++ b/chromium/net/socket/socket_test_util.h
@@ -0,0 +1,1198 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
+#define NET_SOCKET_SOCKET_TEST_UTIL_H_
+
+#include <cstring>
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/callback.h"
+#include "base/logging.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/scoped_vector.h"
+#include "base/memory/weak_ptr.h"
+#include "base/strings/string16.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/test_completion_callback.h"
+#include "net/http/http_auth_controller.h"
+#include "net/http/http_proxy_client_socket_pool.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/socks_client_socket_pool.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/ssl_client_socket_pool.h"
+#include "net/socket/transport_client_socket_pool.h"
+#include "net/ssl/ssl_config_service.h"
+#include "net/udp/datagram_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+enum {
+ // A private network error code used by the socket test utility classes.
+ // If the |result| member of a MockRead is
+ // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a
+ // marker that indicates the peer will close the connection after the next
+ // MockRead. The other members of that MockRead are ignored.
+ ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
+};
+
+class AsyncSocket;
+class MockClientSocket;
+class ServerBoundCertService;
+class SSLClientSocket;
+class StreamSocket;
+
+enum IoMode {
+ ASYNC,
+ SYNCHRONOUS
+};
+
+struct MockConnect {
+ // Asynchronous connection success.
+ // Creates a MockConnect with |mode| ASYC, |result| OK, and
+ // |peer_addr| 192.0.2.33.
+ MockConnect();
+ // Creates a MockConnect with the specified mode and result, with
+ // |peer_addr| 192.0.2.33.
+ MockConnect(IoMode io_mode, int r);
+ MockConnect(IoMode io_mode, int r, IPEndPoint addr);
+ ~MockConnect();
+
+ IoMode mode;
+ int result;
+ IPEndPoint peer_addr;
+};
+
+// MockRead and MockWrite shares the same interface and members, but we'd like
+// to have distinct types because we don't want to have them used
+// interchangably. To do this, a struct template is defined, and MockRead and
+// MockWrite are instantiated by using this template. Template parameter |type|
+// is not used in the struct definition (it purely exists for creating a new
+// type).
+//
+// |data| in MockRead and MockWrite has different meanings: |data| in MockRead
+// is the data returned from the socket when MockTCPClientSocket::Read() is
+// attempted, while |data| in MockWrite is the expected data that should be
+// given in MockTCPClientSocket::Write().
+enum MockReadWriteType {
+ MOCK_READ,
+ MOCK_WRITE
+};
+
+template <MockReadWriteType type>
+struct MockReadWrite {
+ // Flag to indicate that the message loop should be terminated.
+ enum {
+ STOPLOOP = 1 << 31
+ };
+
+ // Default
+ MockReadWrite() : mode(SYNCHRONOUS), result(0), data(NULL), data_len(0),
+ sequence_number(0), time_stamp(base::Time::Now()) {}
+
+ // Read/write failure (no data).
+ MockReadWrite(IoMode io_mode, int result) : mode(io_mode), result(result),
+ data(NULL), data_len(0), sequence_number(0),
+ time_stamp(base::Time::Now()) { }
+
+ // Read/write failure (no data), with sequence information.
+ MockReadWrite(IoMode io_mode, int result, int seq) : mode(io_mode),
+ result(result), data(NULL), data_len(0), sequence_number(seq),
+ time_stamp(base::Time::Now()) { }
+
+ // Asynchronous read/write success (inferred data length).
+ explicit MockReadWrite(const char* data) : mode(ASYNC), result(0),
+ data(data), data_len(strlen(data)), sequence_number(0),
+ time_stamp(base::Time::Now()) { }
+
+ // Read/write success (inferred data length).
+ MockReadWrite(IoMode io_mode, const char* data) : mode(io_mode), result(0),
+ data(data), data_len(strlen(data)), sequence_number(0),
+ time_stamp(base::Time::Now()) { }
+
+ // Read/write success.
+ MockReadWrite(IoMode io_mode, const char* data, int data_len) : mode(io_mode),
+ result(0), data(data), data_len(data_len), sequence_number(0),
+ time_stamp(base::Time::Now()) { }
+
+ // Read/write success (inferred data length) with sequence information.
+ MockReadWrite(IoMode io_mode, int seq, const char* data) : mode(io_mode),
+ result(0), data(data), data_len(strlen(data)), sequence_number(seq),
+ time_stamp(base::Time::Now()) { }
+
+ // Read/write success with sequence information.
+ MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) :
+ mode(io_mode), result(0), data(data), data_len(data_len),
+ sequence_number(seq), time_stamp(base::Time::Now()) { }
+
+ IoMode mode;
+ int result;
+ const char* data;
+ int data_len;
+
+ // For OrderedSocketData, which only allows reads to occur in a particular
+ // sequence. If a read occurs before the given |sequence_number| is reached,
+ // an ERR_IO_PENDING is returned.
+ int sequence_number; // The sequence number at which a read is allowed
+ // to occur.
+ base::Time time_stamp; // The time stamp at which the operation occurred.
+};
+
+typedef MockReadWrite<MOCK_READ> MockRead;
+typedef MockReadWrite<MOCK_WRITE> MockWrite;
+
+struct MockWriteResult {
+ MockWriteResult(IoMode io_mode, int result)
+ : mode(io_mode),
+ result(result) {}
+
+ IoMode mode;
+ int result;
+};
+
+// The SocketDataProvider is an interface used by the MockClientSocket
+// for getting data about individual reads and writes on the socket.
+class SocketDataProvider {
+ public:
+ SocketDataProvider() : socket_(NULL) {}
+
+ virtual ~SocketDataProvider() {}
+
+ // Returns the buffer and result code for the next simulated read.
+ // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
+ // that it will be called via the AsyncSocket::OnReadComplete()
+ // function at a later time.
+ virtual MockRead GetNextRead() = 0;
+ virtual MockWriteResult OnWrite(const std::string& data) = 0;
+ virtual void Reset() = 0;
+
+ // Accessor for the socket which is using the SocketDataProvider.
+ AsyncSocket* socket() { return socket_; }
+ void set_socket(AsyncSocket* socket) { socket_ = socket; }
+
+ MockConnect connect_data() const { return connect_; }
+ void set_connect_data(const MockConnect& connect) { connect_ = connect; }
+
+ private:
+ MockConnect connect_;
+ AsyncSocket* socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
+};
+
+// The AsyncSocket is an interface used by the SocketDataProvider to
+// complete the asynchronous read operation.
+class AsyncSocket {
+ public:
+ // If an async IO is pending because the SocketDataProvider returned
+ // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
+ // is called to complete the asynchronous read operation.
+ // data.async is ignored, and this read is completed synchronously as
+ // part of this call.
+ virtual void OnReadComplete(const MockRead& data) = 0;
+ virtual void OnConnectComplete(const MockConnect& data) = 0;
+};
+
+// SocketDataProvider which responds based on static tables of mock reads and
+// writes.
+class StaticSocketDataProvider : public SocketDataProvider {
+ public:
+ StaticSocketDataProvider();
+ StaticSocketDataProvider(MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+ virtual ~StaticSocketDataProvider();
+
+ // These functions get access to the next available read and write data.
+ const MockRead& PeekRead() const;
+ const MockWrite& PeekWrite() const;
+ // These functions get random access to the read and write data, for timing.
+ const MockRead& PeekRead(size_t index) const;
+ const MockWrite& PeekWrite(size_t index) const;
+ size_t read_index() const { return read_index_; }
+ size_t write_index() const { return write_index_; }
+ size_t read_count() const { return read_count_; }
+ size_t write_count() const { return write_count_; }
+
+ bool at_read_eof() const { return read_index_ >= read_count_; }
+ bool at_write_eof() const { return write_index_ >= write_count_; }
+
+ virtual void CompleteRead() {}
+
+ // SocketDataProvider implementation.
+ virtual MockRead GetNextRead() OVERRIDE;
+ virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
+ ; virtual void Reset() OVERRIDE;
+
+ private:
+ MockRead* reads_;
+ size_t read_index_;
+ size_t read_count_;
+ MockWrite* writes_;
+ size_t write_index_;
+ size_t write_count_;
+
+ DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
+};
+
+// SocketDataProvider which can make decisions about next mock reads based on
+// received writes. It can also be used to enforce order of operations, for
+// example that tested code must send the "Hello!" message before receiving
+// response. This is useful for testing conversation-like protocols like FTP.
+class DynamicSocketDataProvider : public SocketDataProvider {
+ public:
+ DynamicSocketDataProvider();
+ virtual ~DynamicSocketDataProvider();
+
+ int short_read_limit() const { return short_read_limit_; }
+ void set_short_read_limit(int limit) { short_read_limit_ = limit; }
+
+ void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
+
+ // SocketDataProvider implementation.
+ virtual MockRead GetNextRead() OVERRIDE;
+ virtual MockWriteResult OnWrite(const std::string& data) = 0;
+ virtual void Reset() OVERRIDE;
+
+ protected:
+ // The next time there is a read from this socket, it will return |data|.
+ // Before calling SimulateRead next time, the previous data must be consumed.
+ void SimulateRead(const char* data, size_t length);
+ void SimulateRead(const char* data) {
+ SimulateRead(data, std::strlen(data));
+ }
+
+ private:
+ std::deque<MockRead> reads_;
+
+ // Max number of bytes we will read at a time. 0 means no limit.
+ int short_read_limit_;
+
+ // If true, we'll not require the client to consume all data before we
+ // mock the next read.
+ bool allow_unconsumed_reads_;
+
+ DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
+};
+
+// SSLSocketDataProviders only need to keep track of the return code from calls
+// to Connect().
+struct SSLSocketDataProvider {
+ SSLSocketDataProvider(IoMode mode, int result);
+ ~SSLSocketDataProvider();
+
+ void SetNextProto(NextProto proto);
+
+ MockConnect connect;
+ SSLClientSocket::NextProtoStatus next_proto_status;
+ std::string next_proto;
+ std::string server_protos;
+ bool was_npn_negotiated;
+ NextProto protocol_negotiated;
+ bool client_cert_sent;
+ SSLCertRequestInfo* cert_request_info;
+ scoped_refptr<X509Certificate> cert;
+ bool channel_id_sent;
+ ServerBoundCertService* server_bound_cert_service;
+};
+
+// A DataProvider where the client must write a request before the reads (e.g.
+// the response) will complete.
+class DelayedSocketData : public StaticSocketDataProvider {
+ public:
+ // |write_delay| the number of MockWrites to complete before allowing
+ // a MockRead to complete.
+ // |reads| the list of MockRead completions.
+ // |writes| the list of MockWrite completions.
+ // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
+ // MockRead(true, 0, 0);
+ DelayedSocketData(int write_delay,
+ MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+
+ // |connect| the result for the connect phase.
+ // |reads| the list of MockRead completions.
+ // |write_delay| the number of MockWrites to complete before allowing
+ // a MockRead to complete.
+ // |writes| the list of MockWrite completions.
+ // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
+ // MockRead(true, 0, 0);
+ DelayedSocketData(const MockConnect& connect, int write_delay,
+ MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+ virtual ~DelayedSocketData();
+
+ void ForceNextRead();
+
+ // StaticSocketDataProvider:
+ virtual MockRead GetNextRead() OVERRIDE;
+ virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
+ virtual void Reset() OVERRIDE;
+ virtual void CompleteRead() OVERRIDE;
+
+ private:
+ int write_delay_;
+ bool read_in_progress_;
+ base::WeakPtrFactory<DelayedSocketData> weak_factory_;
+};
+
+// A DataProvider where the reads are ordered.
+// If a read is requested before its sequence number is reached, we return an
+// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
+// wait).
+// The sequence number is incremented on every read and write operation.
+// The message loop may be interrupted by setting the high bit of the sequence
+// number in the MockRead's sequence number. When that MockRead is reached,
+// we post a Quit message to the loop. This allows us to interrupt the reading
+// of data before a complete message has arrived, and provides support for
+// testing server push when the request is issued while the response is in the
+// middle of being received.
+class OrderedSocketData : public StaticSocketDataProvider {
+ public:
+ // |reads| the list of MockRead completions.
+ // |writes| the list of MockWrite completions.
+ // Note: All MockReads and MockWrites must be async.
+ // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
+ // MockRead(true, 0, 0);
+ OrderedSocketData(MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+ virtual ~OrderedSocketData();
+
+ // |connect| the result for the connect phase.
+ // |reads| the list of MockRead completions.
+ // |writes| the list of MockWrite completions.
+ // Note: All MockReads and MockWrites must be async.
+ // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
+ // MockRead(true, 0, 0);
+ OrderedSocketData(const MockConnect& connect,
+ MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+
+ // Posts a quit message to the current message loop, if one is running.
+ void EndLoop();
+
+ // StaticSocketDataProvider:
+ virtual MockRead GetNextRead() OVERRIDE;
+ virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
+ virtual void Reset() OVERRIDE;
+ virtual void CompleteRead() OVERRIDE;
+
+ private:
+ int sequence_number_;
+ int loop_stop_stage_;
+ bool blocked_;
+ base::WeakPtrFactory<OrderedSocketData> weak_factory_;
+};
+
+class DeterministicMockTCPClientSocket;
+
+// This class gives the user full control over the network activity,
+// specifically the timing of the COMPLETION of I/O operations. Regardless of
+// the order in which I/O operations are initiated, this class ensures that they
+// complete in the correct order.
+//
+// Network activity is modeled as a sequence of numbered steps which is
+// incremented whenever an I/O operation completes. This can happen under two
+// different circumstances:
+//
+// 1) Performing a synchronous I/O operation. (Invoking Read() or Write()
+// when the corresponding MockRead or MockWrite is marked !async).
+// 2) Running the Run() method of this class. The run method will invoke
+// the current MessageLoop, running all pending events, and will then
+// invoke any pending IO callbacks.
+//
+// In addition, this class allows for I/O processing to "stop" at a specified
+// step, by calling SetStop(int) or StopAfter(int). Initiating an I/O operation
+// by calling Read() or Write() while stopped is permitted if the operation is
+// asynchronous. It is an error to perform synchronous I/O while stopped.
+//
+// When creating the MockReads and MockWrites, note that the sequence number
+// refers to the number of the step in which the I/O will complete. In the
+// case of synchronous I/O, this will be the same step as the I/O is initiated.
+// However, in the case of asynchronous I/O, this I/O may be initiated in
+// a much earlier step. Furthermore, when the a Read() or Write() is separated
+// from its completion by other Read() or Writes()'s, it can not be marked
+// synchronous. If it is, ERR_UNUEXPECTED will be returned indicating that a
+// synchronous Read() or Write() could not be completed synchronously because of
+// the specific ordering constraints.
+//
+// Sequence numbers are preserved across both reads and writes. There should be
+// no gaps in sequence numbers, and no repeated sequence numbers. i.e.
+// MockRead reads[] = {
+// MockRead(false, "first read", length, 0) // sync
+// MockRead(true, "second read", length, 2) // async
+// };
+// MockWrite writes[] = {
+// MockWrite(true, "first write", length, 1), // async
+// MockWrite(false, "second write", length, 3), // sync
+// };
+//
+// Example control flow:
+// Read() is called. The current step is 0. The first available read is
+// synchronous, so the call to Read() returns length. The current step is
+// now 1. Next, Read() is called again. The next available read can
+// not be completed until step 2, so Read() returns ERR_IO_PENDING. The current
+// step is still 1. Write is called(). The first available write is able to
+// complete in this step, but is marked asynchronous. Write() returns
+// ERR_IO_PENDING. The current step is still 1. At this point RunFor(1) is
+// called which will cause the write callback to be invoked, and will then
+// stop. The current state is now 2. RunFor(1) is called again, which
+// causes the read callback to be invoked, and will then stop. Then current
+// step is 2. Write() is called again. Then next available write is
+// synchronous so the call to Write() returns length.
+//
+// For examples of how to use this class, see:
+// deterministic_socket_data_unittests.cc
+class DeterministicSocketData
+ : public StaticSocketDataProvider {
+ public:
+ // The Delegate is an abstract interface which handles the communication from
+ // the DeterministicSocketData to the Deterministic MockSocket. The
+ // MockSockets directly store a pointer to the DeterministicSocketData,
+ // whereas the DeterministicSocketData only stores a pointer to the
+ // abstract Delegate interface.
+ class Delegate {
+ public:
+ // Returns true if there is currently a write pending. That is to say, if
+ // an asynchronous write has been started but the callback has not been
+ // invoked.
+ virtual bool WritePending() const = 0;
+ // Returns true if there is currently a read pending. That is to say, if
+ // an asynchronous read has been started but the callback has not been
+ // invoked.
+ virtual bool ReadPending() const = 0;
+ // Called to complete an asynchronous write to execute the write callback.
+ virtual void CompleteWrite() = 0;
+ // Called to complete an asynchronous read to execute the read callback.
+ virtual int CompleteRead() = 0;
+
+ protected:
+ virtual ~Delegate() {}
+ };
+
+ // |reads| the list of MockRead completions.
+ // |writes| the list of MockWrite completions.
+ DeterministicSocketData(MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+ virtual ~DeterministicSocketData();
+
+ // Consume all the data up to the give stop point (via SetStop()).
+ void Run();
+
+ // Set the stop point to be |steps| from now, and then invoke Run().
+ void RunFor(int steps);
+
+ // Stop at step |seq|, which must be in the future.
+ virtual void SetStop(int seq);
+
+ // Stop |seq| steps after the current step.
+ virtual void StopAfter(int seq);
+ bool stopped() const { return stopped_; }
+ void SetStopped(bool val) { stopped_ = val; }
+ MockRead& current_read() { return current_read_; }
+ MockWrite& current_write() { return current_write_; }
+ int sequence_number() const { return sequence_number_; }
+ void set_delegate(base::WeakPtr<Delegate> delegate) {
+ delegate_ = delegate;
+ }
+
+ // StaticSocketDataProvider:
+
+ // When the socket calls Read(), that calls GetNextRead(), and expects either
+ // ERR_IO_PENDING or data.
+ virtual MockRead GetNextRead() OVERRIDE;
+
+ // When the socket calls Write(), it always completes synchronously. OnWrite()
+ // checks to make sure the written data matches the expected data. The
+ // callback will not be invoked until its sequence number is reached.
+ virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
+ virtual void Reset() OVERRIDE;
+ virtual void CompleteRead() OVERRIDE {}
+
+ private:
+ // Invoke the read and write callbacks, if the timing is appropriate.
+ void InvokeCallbacks();
+
+ void NextStep();
+
+ void VerifyCorrectSequenceNumbers(MockRead* reads, size_t reads_count,
+ MockWrite* writes, size_t writes_count);
+
+ int sequence_number_;
+ MockRead current_read_;
+ MockWrite current_write_;
+ int stopping_sequence_number_;
+ bool stopped_;
+ base::WeakPtr<Delegate> delegate_;
+ bool print_debug_;
+ bool is_running_;
+};
+
+// Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket
+// objects get instantiated, they take their data from the i'th element of this
+// array.
+template<typename T>
+class SocketDataProviderArray {
+ public:
+ SocketDataProviderArray() : next_index_(0) {}
+
+ T* GetNext() {
+ DCHECK_LT(next_index_, data_providers_.size());
+ return data_providers_[next_index_++];
+ }
+
+ void Add(T* data_provider) {
+ DCHECK(data_provider);
+ data_providers_.push_back(data_provider);
+ }
+
+ size_t next_index() { return next_index_; }
+
+ void ResetNextIndex() {
+ next_index_ = 0;
+ }
+
+ private:
+ // Index of the next |data_providers_| element to use. Not an iterator
+ // because those are invalidated on vector reallocation.
+ size_t next_index_;
+
+ // SocketDataProviders to be returned.
+ std::vector<T*> data_providers_;
+};
+
+class MockUDPClientSocket;
+class MockTCPClientSocket;
+class MockSSLClientSocket;
+
+// ClientSocketFactory which contains arrays of sockets of each type.
+// You should first fill the arrays using AddMock{SSL,}Socket. When the factory
+// is asked to create a socket, it takes next entry from appropriate array.
+// You can use ResetNextMockIndexes to reset that next entry index for all mock
+// socket types.
+class MockClientSocketFactory : public ClientSocketFactory {
+ public:
+ MockClientSocketFactory();
+ virtual ~MockClientSocketFactory();
+
+ void AddSocketDataProvider(SocketDataProvider* socket);
+ void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
+ void ResetNextMockIndexes();
+
+ SocketDataProviderArray<SocketDataProvider>& mock_data() {
+ return mock_data_;
+ }
+
+ // ClientSocketFactory
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE;
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE;
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) OVERRIDE;
+ virtual void ClearSSLSessionCache() OVERRIDE;
+
+ private:
+ SocketDataProviderArray<SocketDataProvider> mock_data_;
+ SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
+};
+
+class MockClientSocket : public SSLClientSocket {
+ public:
+ // Value returned by GetTLSUniqueChannelBinding().
+ static const char kTlsUnique[];
+
+ // The BoundNetLog is needed to test LoadTimingInfo, which uses NetLog IDs as
+ // unique socket IDs.
+ explicit MockClientSocket(const BoundNetLog& net_log);
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) = 0;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) = 0;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) = 0;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+
+ // SSLClientSocket implementation.
+ virtual void GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) OVERRIDE;
+ virtual int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) OVERRIDE;
+ virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+ virtual NextProtoStatus GetNextProto(std::string* proto,
+ std::string* server_protos) OVERRIDE;
+ virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+
+ protected:
+ virtual ~MockClientSocket();
+ void RunCallbackAsync(const CompletionCallback& callback, int result);
+ void RunCallback(const CompletionCallback& callback, int result);
+
+ base::WeakPtrFactory<MockClientSocket> weak_factory_;
+
+ // True if Connect completed successfully and Disconnect hasn't been called.
+ bool connected_;
+
+ // Address of the "remote" peer we're connected to.
+ IPEndPoint peer_addr_;
+
+ BoundNetLog net_log_;
+};
+
+class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
+ public:
+ MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log,
+ SocketDataProvider* socket);
+ virtual ~MockTCPClientSocket();
+
+ const AddressList& addresses() const { return addresses_; }
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // AsyncSocket:
+ virtual void OnReadComplete(const MockRead& data) OVERRIDE;
+ virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+
+ private:
+ int CompleteRead();
+
+ AddressList addresses_;
+
+ SocketDataProvider* data_;
+ int read_offset_;
+ MockRead read_data_;
+ bool need_read_data_;
+
+ // True if the peer has closed the connection. This allows us to simulate
+ // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
+ // TCPClientSocket.
+ bool peer_closed_connection_;
+
+ // While an asynchronous IO is pending, we save our user-buffer state.
+ IOBuffer* pending_buf_;
+ int pending_buf_len_;
+ CompletionCallback pending_callback_;
+ bool was_used_to_convey_data_;
+};
+
+// DeterministicSocketHelper is a helper class that can be used
+// to simulate net::Socket::Read() and net::Socket::Write()
+// using deterministic |data|.
+// Note: This is provided as a common helper class because
+// of the inheritance hierarchy of DeterministicMock[UDP,TCP]ClientSocket and a
+// desire not to introduce an additional common base class.
+class DeterministicSocketHelper {
+ public:
+ DeterministicSocketHelper(net::NetLog* net_log,
+ DeterministicSocketData* data);
+ virtual ~DeterministicSocketHelper();
+
+ bool write_pending() const { return write_pending_; }
+ bool read_pending() const { return read_pending_; }
+
+ void CompleteWrite();
+ int CompleteRead();
+
+ int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback);
+ int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback);
+
+ const BoundNetLog& net_log() const { return net_log_; }
+
+ bool was_used_to_convey_data() const { return was_used_to_convey_data_; }
+
+ bool peer_closed_connection() const { return peer_closed_connection_; }
+
+ DeterministicSocketData* data() const { return data_; }
+
+ private:
+ bool write_pending_;
+ CompletionCallback write_callback_;
+ int write_result_;
+
+ MockRead read_data_;
+
+ IOBuffer* read_buf_;
+ int read_buf_len_;
+ bool read_pending_;
+ CompletionCallback read_callback_;
+ DeterministicSocketData* data_;
+ bool was_used_to_convey_data_;
+ bool peer_closed_connection_;
+ BoundNetLog net_log_;
+};
+
+// Mock UDP socket to be used in conjunction with DeterministicSocketData.
+class DeterministicMockUDPClientSocket
+ : public DatagramClientSocket,
+ public AsyncSocket,
+ public DeterministicSocketData::Delegate,
+ public base::SupportsWeakPtr<DeterministicMockUDPClientSocket> {
+ public:
+ DeterministicMockUDPClientSocket(net::NetLog* net_log,
+ DeterministicSocketData* data);
+ virtual ~DeterministicMockUDPClientSocket();
+
+ // DeterministicSocketData::Delegate:
+ virtual bool WritePending() const OVERRIDE;
+ virtual bool ReadPending() const OVERRIDE;
+ virtual void CompleteWrite() OVERRIDE;
+ virtual int CompleteRead() OVERRIDE;
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ // DatagramSocket implementation.
+ virtual void Close() OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+
+ // DatagramClientSocket implementation.
+ virtual int Connect(const IPEndPoint& address) OVERRIDE;
+
+ // AsyncSocket implementation.
+ virtual void OnReadComplete(const MockRead& data) OVERRIDE;
+ virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+
+ private:
+ bool connected_;
+ IPEndPoint peer_address_;
+ DeterministicSocketHelper helper_;
+};
+
+// Mock TCP socket to be used in conjunction with DeterministicSocketData.
+class DeterministicMockTCPClientSocket
+ : public MockClientSocket,
+ public AsyncSocket,
+ public DeterministicSocketData::Delegate,
+ public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
+ public:
+ DeterministicMockTCPClientSocket(net::NetLog* net_log,
+ DeterministicSocketData* data);
+ virtual ~DeterministicMockTCPClientSocket();
+
+ // DeterministicSocketData::Delegate:
+ virtual bool WritePending() const OVERRIDE;
+ virtual bool ReadPending() const OVERRIDE;
+ virtual void CompleteWrite() OVERRIDE;
+ virtual int CompleteRead() OVERRIDE;
+
+ // Socket:
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // StreamSocket:
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // AsyncSocket:
+ virtual void OnReadComplete(const MockRead& data) OVERRIDE;
+ virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+
+ private:
+ DeterministicSocketHelper helper_;
+};
+
+class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
+ public:
+ MockSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ SSLSocketDataProvider* socket);
+ virtual ~MockSSLClientSocket();
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // SSLClientSocket implementation.
+ virtual void GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) OVERRIDE;
+ virtual NextProtoStatus GetNextProto(std::string* proto,
+ std::string* server_protos) OVERRIDE;
+ virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE;
+ virtual void set_protocol_negotiated(
+ NextProto protocol_negotiated) OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+
+ // This MockSocket does not implement the manual async IO feature.
+ virtual void OnReadComplete(const MockRead& data) OVERRIDE;
+ virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+
+ virtual bool WasChannelIDSent() const OVERRIDE;
+ virtual void set_channel_id_sent(bool channel_id_sent) OVERRIDE;
+ virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+
+ private:
+ static void ConnectCallback(MockSSLClientSocket *ssl_client_socket,
+ const CompletionCallback& callback,
+ int rv);
+
+ scoped_ptr<ClientSocketHandle> transport_;
+ SSLSocketDataProvider* data_;
+ bool is_npn_state_set_;
+ bool new_npn_value_;
+ bool is_protocol_negotiated_set_;
+ NextProto protocol_negotiated_;
+};
+
+class MockUDPClientSocket
+ : public DatagramClientSocket,
+ public AsyncSocket {
+ public:
+ MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log);
+ virtual ~MockUDPClientSocket();
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ // DatagramSocket implementation.
+ virtual void Close() OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+
+ // DatagramClientSocket implementation.
+ virtual int Connect(const IPEndPoint& address) OVERRIDE;
+
+ // AsyncSocket implementation.
+ virtual void OnReadComplete(const MockRead& data) OVERRIDE;
+ virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+
+ private:
+ int CompleteRead();
+
+ void RunCallbackAsync(const CompletionCallback& callback, int result);
+ void RunCallback(const CompletionCallback& callback, int result);
+
+ bool connected_;
+ SocketDataProvider* data_;
+ int read_offset_;
+ MockRead read_data_;
+ bool need_read_data_;
+
+ // Address of the "remote" peer we're connected to.
+ IPEndPoint peer_addr_;
+
+ // While an asynchronous IO is pending, we save our user-buffer state.
+ IOBuffer* pending_buf_;
+ int pending_buf_len_;
+ CompletionCallback pending_callback_;
+
+ BoundNetLog net_log_;
+
+ base::WeakPtrFactory<MockUDPClientSocket> weak_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockUDPClientSocket);
+};
+
+class TestSocketRequest : public TestCompletionCallbackBase {
+ public:
+ TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
+ size_t* completion_count);
+ virtual ~TestSocketRequest();
+
+ ClientSocketHandle* handle() { return &handle_; }
+
+ const net::CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result);
+
+ ClientSocketHandle handle_;
+ std::vector<TestSocketRequest*>* request_order_;
+ size_t* completion_count_;
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestSocketRequest);
+};
+
+class ClientSocketPoolTest {
+ public:
+ enum KeepAlive {
+ KEEP_ALIVE,
+
+ // A socket will be disconnected in addition to handle being reset.
+ NO_KEEP_ALIVE,
+ };
+
+ static const int kIndexOutOfBounds;
+ static const int kRequestNotFound;
+
+ ClientSocketPoolTest();
+ ~ClientSocketPoolTest();
+
+ template <typename PoolType, typename SocketParams>
+ int StartRequestUsingPool(PoolType* socket_pool,
+ const std::string& group_name,
+ RequestPriority priority,
+ const scoped_refptr<SocketParams>& socket_params) {
+ DCHECK(socket_pool);
+ TestSocketRequest* request = new TestSocketRequest(&request_order_,
+ &completion_count_);
+ requests_.push_back(request);
+ int rv = request->handle()->Init(
+ group_name, socket_params, priority, request->callback(),
+ socket_pool, BoundNetLog());
+ if (rv != ERR_IO_PENDING)
+ request_order_.push_back(request);
+ return rv;
+ }
+
+ // Provided there were n requests started, takes |index| in range 1..n
+ // and returns order in which that request completed, in range 1..n,
+ // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
+ // if that request did not complete (for example was canceled).
+ int GetOrderOfRequest(size_t index) const;
+
+ // Resets first initialized socket handle from |requests_|. If found such
+ // a handle, returns true.
+ bool ReleaseOneConnection(KeepAlive keep_alive);
+
+ // Releases connections until there is nothing to release.
+ void ReleaseAllConnections(KeepAlive keep_alive);
+
+ // Note that this uses 0-based indices, while GetOrderOfRequest takes and
+ // returns 0-based indices.
+ TestSocketRequest* request(int i) { return requests_[i]; }
+
+ size_t requests_size() const { return requests_.size(); }
+ ScopedVector<TestSocketRequest>* requests() { return &requests_; }
+ size_t completion_count() const { return completion_count_; }
+
+ private:
+ ScopedVector<TestSocketRequest> requests_;
+ std::vector<TestSocketRequest*> request_order_;
+ size_t completion_count_;
+};
+
+class MockTransportClientSocketPool : public TransportClientSocketPool {
+ public:
+ class MockConnectJob {
+ public:
+ MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle,
+ const CompletionCallback& callback);
+ ~MockConnectJob();
+
+ int Connect();
+ bool CancelHandle(const ClientSocketHandle* handle);
+
+ private:
+ void OnConnect(int rv);
+
+ scoped_ptr<StreamSocket> socket_;
+ ClientSocketHandle* handle_;
+ CompletionCallback user_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
+ };
+
+ MockTransportClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ ClientSocketFactory* socket_factory);
+
+ virtual ~MockTransportClientSocketPool();
+
+ int release_count() const { return release_count_; }
+ int cancel_count() const { return cancel_count_; }
+
+ // TransportClientSocketPool implementation.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* socket_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) OVERRIDE;
+
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) OVERRIDE;
+ virtual void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
+
+ private:
+ ClientSocketFactory* client_socket_factory_;
+ ScopedVector<MockConnectJob> job_list_;
+ int release_count_;
+ int cancel_count_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool);
+};
+
+class DeterministicMockClientSocketFactory : public ClientSocketFactory {
+ public:
+ DeterministicMockClientSocketFactory();
+ virtual ~DeterministicMockClientSocketFactory();
+
+ void AddSocketDataProvider(DeterministicSocketData* socket);
+ void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
+ void ResetNextMockIndexes();
+
+ // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
+ // created.
+ MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
+
+ SocketDataProviderArray<DeterministicSocketData>& mock_data() {
+ return mock_data_;
+ }
+ std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
+ return tcp_client_sockets_;
+ }
+ std::vector<DeterministicMockUDPClientSocket*>& udp_client_sockets() {
+ return udp_client_sockets_;
+ }
+
+ // ClientSocketFactory
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE;
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE;
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) OVERRIDE;
+ virtual void ClearSSLSessionCache() OVERRIDE;
+
+ private:
+ SocketDataProviderArray<DeterministicSocketData> mock_data_;
+ SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
+
+ // Store pointers to handed out sockets in case the test wants to get them.
+ std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
+ std::vector<DeterministicMockUDPClientSocket*> udp_client_sockets_;
+ std::vector<MockSSLClientSocket*> ssl_client_sockets_;
+};
+
+class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
+ public:
+ MockSOCKSClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ TransportClientSocketPool* transport_pool);
+
+ virtual ~MockSOCKSClientSocketPool();
+
+ // SOCKSClientSocketPool implementation.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* socket_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) OVERRIDE;
+
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) OVERRIDE;
+ virtual void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
+
+ private:
+ TransportClientSocketPool* const transport_pool_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
+};
+
+// Constants for a successful SOCKS v5 handshake.
+extern const char kSOCKS5GreetRequest[];
+extern const int kSOCKS5GreetRequestLength;
+
+extern const char kSOCKS5GreetResponse[];
+extern const int kSOCKS5GreetResponseLength;
+
+extern const char kSOCKS5OkRequest[];
+extern const int kSOCKS5OkRequestLength;
+
+extern const char kSOCKS5OkResponse[];
+extern const int kSOCKS5OkResponseLength;
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKET_TEST_UTIL_H_
diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc
new file mode 100644
index 00000000000..537b584a932
--- /dev/null
+++ b/chromium/net/socket/socks5_client_socket.cc
@@ -0,0 +1,487 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socks5_client_socket.h"
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/debug/trace_event.h"
+#include "base/format_macros.h"
+#include "base/strings/string_util.h"
+#include "base/sys_byteorder.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_log.h"
+#include "net/base/net_util.h"
+#include "net/socket/client_socket_handle.h"
+
+namespace net {
+
+const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
+const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
+const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
+const uint8 SOCKS5ClientSocket::kSOCKS5Version = 0x05;
+const uint8 SOCKS5ClientSocket::kTunnelCommand = 0x01;
+const uint8 SOCKS5ClientSocket::kNullByte = 0x00;
+
+COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4);
+COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6);
+
+SOCKS5ClientSocket::SOCKS5ClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostResolver::RequestInfo& req_info)
+ : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
+ base::Unretained(this))),
+ transport_(transport_socket.Pass()),
+ next_state_(STATE_NONE),
+ completed_handshake_(false),
+ bytes_sent_(0),
+ bytes_received_(0),
+ read_header_size(kReadHeaderSize),
+ host_request_info_(req_info),
+ net_log_(transport_->socket()->NetLog()) {
+}
+
+SOCKS5ClientSocket::~SOCKS5ClientSocket() {
+ Disconnect();
+}
+
+int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) {
+ DCHECK(transport_.get());
+ DCHECK(transport_->socket());
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DCHECK(user_callback_.is_null());
+
+ // If already connected, then just return OK.
+ if (completed_handshake_)
+ return OK;
+
+ net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT);
+
+ next_state_ = STATE_GREET_WRITE;
+ buffer_.clear();
+
+ int rv = DoLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ user_callback_ = callback;
+ } else {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv);
+ }
+ return rv;
+}
+
+void SOCKS5ClientSocket::Disconnect() {
+ completed_handshake_ = false;
+ transport_->socket()->Disconnect();
+
+ // Reset other states to make sure they aren't mistakenly used later.
+ // These are the states initialized by Connect().
+ next_state_ = STATE_NONE;
+ user_callback_.Reset();
+}
+
+bool SOCKS5ClientSocket::IsConnected() const {
+ return completed_handshake_ && transport_->socket()->IsConnected();
+}
+
+bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
+ return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
+}
+
+const BoundNetLog& SOCKS5ClientSocket::NetLog() const {
+ return net_log_;
+}
+
+void SOCKS5ClientSocket::SetSubresourceSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetSubresourceSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+void SOCKS5ClientSocket::SetOmniboxSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetOmniboxSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+bool SOCKS5ClientSocket::WasEverUsed() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->WasEverUsed();
+ }
+ NOTREACHED();
+ return false;
+}
+
+bool SOCKS5ClientSocket::UsingTCPFastOpen() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->UsingTCPFastOpen();
+ }
+ NOTREACHED();
+ return false;
+}
+
+bool SOCKS5ClientSocket::WasNpnNegotiated() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->WasNpnNegotiated();
+ }
+ NOTREACHED();
+ return false;
+}
+
+NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->GetNegotiatedProtocol();
+ }
+ NOTREACHED();
+ return kProtoUnknown;
+}
+
+bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->GetSSLInfo(ssl_info);
+ }
+ NOTREACHED();
+ return false;
+
+}
+
+// Read is called by the transport layer above to read. This can only be done
+// if the SOCKS handshake is complete.
+int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(completed_handshake_);
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DCHECK(user_callback_.is_null());
+
+ return transport_->socket()->Read(buf, buf_len, callback);
+}
+
+// Write is called by the transport layer. This can only be done if the
+// SOCKS handshake is complete.
+int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(completed_handshake_);
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DCHECK(user_callback_.is_null());
+
+ return transport_->socket()->Write(buf, buf_len, callback);
+}
+
+bool SOCKS5ClientSocket::SetReceiveBufferSize(int32 size) {
+ return transport_->socket()->SetReceiveBufferSize(size);
+}
+
+bool SOCKS5ClientSocket::SetSendBufferSize(int32 size) {
+ return transport_->socket()->SetSendBufferSize(size);
+}
+
+void SOCKS5ClientSocket::DoCallback(int result) {
+ DCHECK_NE(ERR_IO_PENDING, result);
+ DCHECK(!user_callback_.is_null());
+
+ // Since Run() may result in Read being called,
+ // clear user_callback_ up front.
+ CompletionCallback c = user_callback_;
+ user_callback_.Reset();
+ c.Run(result);
+}
+
+void SOCKS5ClientSocket::OnIOComplete(int result) {
+ DCHECK_NE(STATE_NONE, next_state_);
+ int rv = DoLoop(result);
+ if (rv != ERR_IO_PENDING) {
+ net_log_.EndEvent(NetLog::TYPE_SOCKS5_CONNECT);
+ DoCallback(rv);
+ }
+}
+
+int SOCKS5ClientSocket::DoLoop(int last_io_result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+ int rv = last_io_result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_GREET_WRITE:
+ DCHECK_EQ(OK, rv);
+ net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_WRITE);
+ rv = DoGreetWrite();
+ break;
+ case STATE_GREET_WRITE_COMPLETE:
+ rv = DoGreetWriteComplete(rv);
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_WRITE, rv);
+ break;
+ case STATE_GREET_READ:
+ DCHECK_EQ(OK, rv);
+ net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_READ);
+ rv = DoGreetRead();
+ break;
+ case STATE_GREET_READ_COMPLETE:
+ rv = DoGreetReadComplete(rv);
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_READ, rv);
+ break;
+ case STATE_HANDSHAKE_WRITE:
+ DCHECK_EQ(OK, rv);
+ net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE);
+ rv = DoHandshakeWrite();
+ break;
+ case STATE_HANDSHAKE_WRITE_COMPLETE:
+ rv = DoHandshakeWriteComplete(rv);
+ net_log_.EndEventWithNetErrorCode(
+ NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE, rv);
+ break;
+ case STATE_HANDSHAKE_READ:
+ DCHECK_EQ(OK, rv);
+ net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_READ);
+ rv = DoHandshakeRead();
+ break;
+ case STATE_HANDSHAKE_READ_COMPLETE:
+ rv = DoHandshakeReadComplete(rv);
+ net_log_.EndEventWithNetErrorCode(
+ NetLog::TYPE_SOCKS5_HANDSHAKE_READ, rv);
+ break;
+ default:
+ NOTREACHED() << "bad state";
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+ return rv;
+}
+
+const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 }; // no authentication
+const char kSOCKS5GreetReadData[] = { 0x05, 0x00 };
+
+int SOCKS5ClientSocket::DoGreetWrite() {
+ // Since we only have 1 byte to send the hostname length in, if the
+ // URL has a hostname longer than 255 characters we can't send it.
+ if (0xFF < host_request_info_.hostname().size()) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_HOSTNAME_TOO_BIG);
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ if (buffer_.empty()) {
+ buffer_ = std::string(kSOCKS5GreetWriteData,
+ arraysize(kSOCKS5GreetWriteData));
+ bytes_sent_ = 0;
+ }
+
+ next_state_ = STATE_GREET_WRITE_COMPLETE;
+ size_t handshake_buf_len = buffer_.size() - bytes_sent_;
+ handshake_buf_ = new IOBuffer(handshake_buf_len);
+ memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
+ handshake_buf_len);
+ return transport_->socket()
+ ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
+}
+
+int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
+ if (result < 0)
+ return result;
+
+ bytes_sent_ += result;
+ if (bytes_sent_ == buffer_.size()) {
+ buffer_.clear();
+ bytes_received_ = 0;
+ next_state_ = STATE_GREET_READ;
+ } else {
+ next_state_ = STATE_GREET_WRITE;
+ }
+ return OK;
+}
+
+int SOCKS5ClientSocket::DoGreetRead() {
+ next_state_ = STATE_GREET_READ_COMPLETE;
+ size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
+ handshake_buf_ = new IOBuffer(handshake_buf_len);
+ return transport_->socket()
+ ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
+}
+
+int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
+ if (result < 0)
+ return result;
+
+ if (result == 0) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ bytes_received_ += result;
+ buffer_.append(handshake_buf_->data(), result);
+ if (bytes_received_ < kGreetReadHeaderSize) {
+ next_state_ = STATE_GREET_READ;
+ return OK;
+ }
+
+ // Got the greet data.
+ if (buffer_[0] != kSOCKS5Version) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
+ NetLog::IntegerCallback("version", buffer_[0]));
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+ if (buffer_[1] != 0x00) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_AUTH,
+ NetLog::IntegerCallback("method", buffer_[1]));
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ buffer_.clear();
+ next_state_ = STATE_HANDSHAKE_WRITE;
+ return OK;
+}
+
+int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
+ const {
+ DCHECK(handshake->empty());
+
+ handshake->push_back(kSOCKS5Version);
+ handshake->push_back(kTunnelCommand); // Connect command
+ handshake->push_back(kNullByte); // Reserved null
+
+ handshake->push_back(kEndPointDomain); // The type of the address.
+
+ DCHECK_GE(static_cast<size_t>(0xFF), host_request_info_.hostname().size());
+
+ // First add the size of the hostname, followed by the hostname.
+ handshake->push_back(static_cast<unsigned char>(
+ host_request_info_.hostname().size()));
+ handshake->append(host_request_info_.hostname());
+
+ uint16 nw_port = base::HostToNet16(host_request_info_.port());
+ handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
+ return OK;
+}
+
+// Writes the SOCKS handshake data to the underlying socket connection.
+int SOCKS5ClientSocket::DoHandshakeWrite() {
+ next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
+
+ if (buffer_.empty()) {
+ int rv = BuildHandshakeWriteBuffer(&buffer_);
+ if (rv != OK)
+ return rv;
+ bytes_sent_ = 0;
+ }
+
+ int handshake_buf_len = buffer_.size() - bytes_sent_;
+ DCHECK_LT(0, handshake_buf_len);
+ handshake_buf_ = new IOBuffer(handshake_buf_len);
+ memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
+ handshake_buf_len);
+ return transport_->socket()
+ ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
+}
+
+int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
+ if (result < 0)
+ return result;
+
+ // We ignore the case when result is 0, since the underlying Write
+ // may return spurious writes while waiting on the socket.
+
+ bytes_sent_ += result;
+ if (bytes_sent_ == buffer_.size()) {
+ next_state_ = STATE_HANDSHAKE_READ;
+ buffer_.clear();
+ } else if (bytes_sent_ < buffer_.size()) {
+ next_state_ = STATE_HANDSHAKE_WRITE;
+ } else {
+ NOTREACHED();
+ }
+
+ return OK;
+}
+
+int SOCKS5ClientSocket::DoHandshakeRead() {
+ next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
+
+ if (buffer_.empty()) {
+ bytes_received_ = 0;
+ read_header_size = kReadHeaderSize;
+ }
+
+ int handshake_buf_len = read_header_size - bytes_received_;
+ handshake_buf_ = new IOBuffer(handshake_buf_len);
+ return transport_->socket()
+ ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
+}
+
+int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
+ if (result < 0)
+ return result;
+
+ // The underlying socket closed unexpectedly.
+ if (result == 0) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ buffer_.append(handshake_buf_->data(), result);
+ bytes_received_ += result;
+
+ // When the first few bytes are read, check how many more are required
+ // and accordingly increase them
+ if (bytes_received_ == kReadHeaderSize) {
+ if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
+ NetLog::IntegerCallback("version", buffer_[0]));
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+ if (buffer_[1] != 0x00) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_SERVER_ERROR,
+ NetLog::IntegerCallback("error_code", buffer_[1]));
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ // We check the type of IP/Domain the server returns and accordingly
+ // increase the size of the response. For domains, we need to read the
+ // size of the domain, so the initial request size is upto the domain
+ // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
+ // read, we substract 1 byte from the additional request size.
+ SocksEndPointAddressType address_type =
+ static_cast<SocksEndPointAddressType>(buffer_[3]);
+ if (address_type == kEndPointDomain)
+ read_header_size += static_cast<uint8>(buffer_[4]);
+ else if (address_type == kEndPointResolvedIPv4)
+ read_header_size += sizeof(struct in_addr) - 1;
+ else if (address_type == kEndPointResolvedIPv6)
+ read_header_size += sizeof(struct in6_addr) - 1;
+ else {
+ net_log_.AddEvent(NetLog::TYPE_SOCKS_UNKNOWN_ADDRESS_TYPE,
+ NetLog::IntegerCallback("address_type", buffer_[3]));
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ read_header_size += 2; // for the port.
+ next_state_ = STATE_HANDSHAKE_READ;
+ return OK;
+ }
+
+ // When the final bytes are read, setup handshake. We ignore the rest
+ // of the response since they represent the SOCKSv5 endpoint and have
+ // no use when doing a tunnel connection.
+ if (bytes_received_ == read_header_size) {
+ completed_handshake_ = true;
+ buffer_.clear();
+ next_state_ = STATE_NONE;
+ return OK;
+ }
+
+ next_state_ = STATE_HANDSHAKE_READ;
+ return OK;
+}
+
+int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetPeerAddress(address);
+}
+
+int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetLocalAddress(address);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h
new file mode 100644
index 00000000000..45216244f10
--- /dev/null
+++ b/chromium/net/socket/socks5_client_socket.h
@@ -0,0 +1,155 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKS5_CLIENT_SOCKET_H_
+#define NET_SOCKET_SOCKS5_CLIENT_SOCKET_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/gtest_prod_util.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/dns/host_resolver.h"
+#include "net/socket/stream_socket.h"
+#include "url/gurl.h"
+
+namespace net {
+
+class ClientSocketHandle;
+class BoundNetLog;
+
+// This StreamSocket is used to setup a SOCKSv5 handshake with a socks proxy.
+// Currently no SOCKSv5 authentication is supported.
+class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket {
+ public:
+ // |req_info| contains the hostname and port to which the socket above will
+ // communicate to via the SOCKS layer.
+ //
+ // Although SOCKS 5 supports 3 different modes of addressing, we will
+ // always pass it a hostname. This means the DNS resolving is done
+ // proxy side.
+ SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostResolver::RequestInfo& req_info);
+
+ // On destruction Disconnect() is called.
+ virtual ~SOCKS5ClientSocket();
+
+ // StreamSocket implementation.
+
+ // Does the SOCKS handshake and completes the protocol.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+
+ private:
+ enum State {
+ STATE_GREET_WRITE,
+ STATE_GREET_WRITE_COMPLETE,
+ STATE_GREET_READ,
+ STATE_GREET_READ_COMPLETE,
+ STATE_HANDSHAKE_WRITE,
+ STATE_HANDSHAKE_WRITE_COMPLETE,
+ STATE_HANDSHAKE_READ,
+ STATE_HANDSHAKE_READ_COMPLETE,
+ STATE_NONE,
+ };
+
+ // Addressing type that can be specified in requests or responses.
+ enum SocksEndPointAddressType {
+ kEndPointDomain = 0x03,
+ kEndPointResolvedIPv4 = 0x01,
+ kEndPointResolvedIPv6 = 0x04,
+ };
+
+ static const unsigned int kGreetReadHeaderSize;
+ static const unsigned int kWriteHeaderSize;
+ static const unsigned int kReadHeaderSize;
+ static const uint8 kSOCKS5Version;
+ static const uint8 kTunnelCommand;
+ static const uint8 kNullByte;
+
+ void DoCallback(int result);
+ void OnIOComplete(int result);
+
+ int DoLoop(int last_io_result);
+ int DoHandshakeRead();
+ int DoHandshakeReadComplete(int result);
+ int DoHandshakeWrite();
+ int DoHandshakeWriteComplete(int result);
+ int DoGreetRead();
+ int DoGreetReadComplete(int result);
+ int DoGreetWrite();
+ int DoGreetWriteComplete(int result);
+
+ // Writes the SOCKS handshake buffer into |handshake|
+ // and return OK on success.
+ int BuildHandshakeWriteBuffer(std::string* handshake) const;
+
+ CompletionCallback io_callback_;
+
+ // Stores the underlying socket.
+ scoped_ptr<ClientSocketHandle> transport_;
+
+ State next_state_;
+
+ // Stores the callback to the layer above, called on completing Connect().
+ CompletionCallback user_callback_;
+
+ // This IOBuffer is used by the class to read and write
+ // SOCKS handshake data. The length contains the expected size to
+ // read or write.
+ scoped_refptr<IOBuffer> handshake_buf_;
+
+ // While writing, this buffer stores the complete write handshake data.
+ // While reading, it stores the handshake information received so far.
+ std::string buffer_;
+
+ // This becomes true when the SOCKS handshake has completed and the
+ // overlying connection is free to communicate.
+ bool completed_handshake_;
+
+ // These contain the bytes sent / received by the SOCKS handshake.
+ size_t bytes_sent_;
+ size_t bytes_received_;
+
+ size_t read_header_size;
+
+ HostResolver::RequestInfo host_request_info_;
+
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocket);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKS5_CLIENT_SOCKET_H_
diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc
new file mode 100644
index 00000000000..4c9240ff5c1
--- /dev/null
+++ b/chromium/net/socket/socks5_client_socket_unittest.cc
@@ -0,0 +1,375 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socks5_client_socket.h"
+
+#include <algorithm>
+#include <iterator>
+#include <map>
+
+#include "base/sys_byteorder.h"
+#include "net/base/address_list.h"
+#include "net/base/net_log.h"
+#include "net/base/net_log_unittest.h"
+#include "net/base/test_completion_callback.h"
+#include "net/base/winsock_init.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/tcp_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+//-----------------------------------------------------------------------------
+
+namespace net {
+
+namespace {
+
+// Base class to test SOCKS5ClientSocket
+class SOCKS5ClientSocketTest : public PlatformTest {
+ public:
+ SOCKS5ClientSocketTest();
+ // Create a SOCKSClientSocket on top of a MockSocket.
+ scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[],
+ size_t reads_count,
+ MockWrite writes[],
+ size_t writes_count,
+ const std::string& hostname,
+ int port,
+ NetLog* net_log);
+
+ virtual void SetUp();
+
+ protected:
+ const uint16 kNwPort;
+ CapturingNetLog net_log_;
+ scoped_ptr<SOCKS5ClientSocket> user_sock_;
+ AddressList address_list_;
+ // Filled in by BuildMockSocket() and owned by its return value
+ // (which |user_sock| is set to).
+ StreamSocket* tcp_sock_;
+ TestCompletionCallback callback_;
+ scoped_ptr<MockHostResolver> host_resolver_;
+ scoped_ptr<SocketDataProvider> data_;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest);
+};
+
+SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
+ : kNwPort(base::HostToNet16(80)),
+ host_resolver_(new MockHostResolver) {
+}
+
+// Set up platform before every test case
+void SOCKS5ClientSocketTest::SetUp() {
+ PlatformTest::SetUp();
+
+ // Resolve the "localhost" AddressList used by the TCP connection to connect.
+ HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
+ TestCompletionCallback callback;
+ int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(),
+ NULL, BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ ASSERT_EQ(OK, rv);
+}
+
+scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
+ MockRead reads[],
+ size_t reads_count,
+ MockWrite writes[],
+ size_t writes_count,
+ const std::string& hostname,
+ int port,
+ NetLog* net_log) {
+ TestCompletionCallback callback;
+ data_.reset(new StaticSocketDataProvider(reads, reads_count,
+ writes, writes_count));
+ tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
+
+ int rv = tcp_sock_->Connect(callback.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(tcp_sock_->IsConnected());
+
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ // |connection| takes ownership of |tcp_sock_|, but keep a
+ // non-owning pointer to it.
+ connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
+ return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket(
+ connection.Pass(),
+ HostResolver::RequestInfo(HostPortPair(hostname, port))));
+}
+
+// Tests a complete SOCKS5 handshake and the disconnection.
+TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
+ const std::string payload_write = "random data";
+ const std::string payload_read = "moar random data";
+
+ const char kOkRequest[] = {
+ 0x05, // Version
+ 0x01, // Command (CONNECT)
+ 0x00, // Reserved.
+ 0x03, // Address type (DOMAINNAME).
+ 0x09, // Length of domain (9)
+ // Domain string:
+ 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
+ 0x00, 0x50, // 16-bit port (80)
+ };
+
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
+ MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)),
+ MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
+ MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
+ MockRead(ASYNC, payload_read.data(), payload_read.size()) };
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ "localhost", 80, &net_log_);
+
+ // At this state the TCP connection is completed but not the SOCKS handshake.
+ EXPECT_TRUE(tcp_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnected());
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(user_sock_->IsConnected());
+
+ CapturingNetLog::CapturedEntryList net_log_entries;
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
+ NetLog::TYPE_SOCKS5_CONNECT));
+
+ rv = callback_.WaitForResult();
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
+ NetLog::TYPE_SOCKS5_CONNECT));
+
+ scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
+ memcpy(buffer->data(), payload_write.data(), payload_write.size());
+ rv = user_sock_->Write(
+ buffer.get(), payload_write.size(), callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
+
+ buffer = new IOBuffer(payload_read.size());
+ rv =
+ user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
+ EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
+
+ user_sock_->Disconnect();
+ EXPECT_FALSE(tcp_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnected());
+}
+
+// Test that you can call Connect() again after having called Disconnect().
+TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
+ const std::string hostname = "my-host-name";
+ const char kSOCKS5DomainRequest[] = {
+ 0x05, // VER
+ 0x01, // CMD
+ 0x00, // RSV
+ 0x03, // ATYPE
+ };
+
+ std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest));
+ request.push_back(hostname.size());
+ request.append(hostname);
+ request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
+
+ for (int i = 0; i < 2; ++i) {
+ MockWrite data_writes[] = {
+ MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
+ MockWrite(SYNCHRONOUS, request.data(), request.size())
+ };
+ MockRead data_reads[] = {
+ MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
+ MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
+ };
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, NULL);
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+
+ user_sock_->Disconnect();
+ EXPECT_FALSE(user_sock_->IsConnected());
+ }
+}
+
+// Test that we fail trying to connect to a hosname longer than 255 bytes.
+TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
+ // Create a string of length 256, where each character is 'x'.
+ std::string large_host_name;
+ std::fill_n(std::back_inserter(large_host_name), 256, 'x');
+
+ // Create a SOCKS socket, with mock transport socket.
+ MockWrite data_writes[] = {MockWrite()};
+ MockRead data_reads[] = {MockRead()};
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ large_host_name, 80, NULL);
+
+ // Try to connect -- should fail (without having read/written anything to
+ // the transport socket first) because the hostname is too long.
+ TestCompletionCallback callback;
+ int rv = user_sock_->Connect(callback.callback());
+ EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
+}
+
+TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
+ const std::string hostname = "www.google.com";
+
+ const char kOkRequest[] = {
+ 0x05, // Version
+ 0x01, // Command (CONNECT)
+ 0x00, // Reserved.
+ 0x03, // Address type (DOMAINNAME).
+ 0x0E, // Length of domain (14)
+ // Domain string:
+ 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
+ 0x00, 0x50, // 16-bit port (80)
+ };
+
+ // Test for partial greet request write
+ {
+ const char partial1[] = { 0x05, 0x01 };
+ const char partial2[] = { 0x00 };
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, arraysize(partial1)),
+ MockWrite(ASYNC, partial2, arraysize(partial2)),
+ MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
+ MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ CapturingNetLog::CapturedEntryList net_log_entries;
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
+ NetLog::TYPE_SOCKS5_CONNECT));
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ }
+
+ // Test for partial greet response read
+ {
+ const char partial1[] = { 0x05 };
+ const char partial2[] = { 0x00 };
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
+ MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, partial1, arraysize(partial1)),
+ MockRead(ASYNC, partial2, arraysize(partial2)),
+ MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ CapturingNetLog::CapturedEntryList net_log_entries;
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ }
+
+ // Test for partial handshake request write.
+ {
+ const int kSplitPoint = 3; // Break handshake write into two parts.
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
+ MockWrite(ASYNC, kOkRequest, kSplitPoint),
+ MockWrite(ASYNC, kOkRequest + kSplitPoint,
+ arraysize(kOkRequest) - kSplitPoint)
+ };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
+ MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ CapturingNetLog::CapturedEntryList net_log_entries;
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ }
+
+ // Test for partial handshake response read
+ {
+ const int kSplitPoint = 6; // Break the handshake read into two parts.
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
+ MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest))
+ };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
+ MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
+ MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
+ kSOCKS5OkResponseLength - kSplitPoint)
+ };
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hostname, 80, &net_log_);
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ CapturingNetLog::CapturedEntryList net_log_entries;
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
+ NetLog::TYPE_SOCKS5_CONNECT));
+ }
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc
new file mode 100644
index 00000000000..1941fdbfd95
--- /dev/null
+++ b/chromium/net/socket/socks_client_socket.cc
@@ -0,0 +1,432 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socks_client_socket.h"
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/compiler_specific.h"
+#include "base/sys_byteorder.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_log.h"
+#include "net/base/net_util.h"
+#include "net/socket/client_socket_handle.h"
+
+namespace net {
+
+// Every SOCKS server requests a user-id from the client. It is optional
+// and we send an empty string.
+static const char kEmptyUserId[] = "";
+
+// For SOCKS4, the client sends 8 bytes plus the size of the user-id.
+static const unsigned int kWriteHeaderSize = 8;
+
+// For SOCKS4 the server sends 8 bytes for acknowledgement.
+static const unsigned int kReadHeaderSize = 8;
+
+// Server Response codes for SOCKS.
+static const uint8 kServerResponseOk = 0x5A;
+static const uint8 kServerResponseRejected = 0x5B;
+static const uint8 kServerResponseNotReachable = 0x5C;
+static const uint8 kServerResponseMismatchedUserId = 0x5D;
+
+static const uint8 kSOCKSVersion4 = 0x04;
+static const uint8 kSOCKSStreamRequest = 0x01;
+
+// A struct holding the essential details of the SOCKS4 Server Request.
+// The port in the header is stored in network byte order.
+struct SOCKS4ServerRequest {
+ uint8 version;
+ uint8 command;
+ uint16 nw_port;
+ uint8 ip[4];
+};
+COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
+ socks4_server_request_struct_wrong_size);
+
+// A struct holding details of the SOCKS4 Server Response.
+struct SOCKS4ServerResponse {
+ uint8 reserved_null;
+ uint8 code;
+ uint16 port;
+ uint8 ip[4];
+};
+COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
+ socks4_server_response_struct_wrong_size);
+
+SOCKSClientSocket::SOCKSClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostResolver::RequestInfo& req_info,
+ HostResolver* host_resolver)
+ : transport_(transport_socket.Pass()),
+ next_state_(STATE_NONE),
+ completed_handshake_(false),
+ bytes_sent_(0),
+ bytes_received_(0),
+ host_resolver_(host_resolver),
+ host_request_info_(req_info),
+ net_log_(transport_->socket()->NetLog()) {
+}
+
+SOCKSClientSocket::~SOCKSClientSocket() {
+ Disconnect();
+}
+
+int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
+ DCHECK(transport_.get());
+ DCHECK(transport_->socket());
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DCHECK(user_callback_.is_null());
+
+ // If already connected, then just return OK.
+ if (completed_handshake_)
+ return OK;
+
+ next_state_ = STATE_RESOLVE_HOST;
+
+ net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
+
+ int rv = DoLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ user_callback_ = callback;
+ } else {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
+ }
+ return rv;
+}
+
+void SOCKSClientSocket::Disconnect() {
+ completed_handshake_ = false;
+ host_resolver_.Cancel();
+ transport_->socket()->Disconnect();
+
+ // Reset other states to make sure they aren't mistakenly used later.
+ // These are the states initialized by Connect().
+ next_state_ = STATE_NONE;
+ user_callback_.Reset();
+}
+
+bool SOCKSClientSocket::IsConnected() const {
+ return completed_handshake_ && transport_->socket()->IsConnected();
+}
+
+bool SOCKSClientSocket::IsConnectedAndIdle() const {
+ return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
+}
+
+const BoundNetLog& SOCKSClientSocket::NetLog() const {
+ return net_log_;
+}
+
+void SOCKSClientSocket::SetSubresourceSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetSubresourceSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+void SOCKSClientSocket::SetOmniboxSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetOmniboxSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+bool SOCKSClientSocket::WasEverUsed() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->WasEverUsed();
+ }
+ NOTREACHED();
+ return false;
+}
+
+bool SOCKSClientSocket::UsingTCPFastOpen() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->UsingTCPFastOpen();
+ }
+ NOTREACHED();
+ return false;
+}
+
+bool SOCKSClientSocket::WasNpnNegotiated() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->WasNpnNegotiated();
+ }
+ NOTREACHED();
+ return false;
+}
+
+NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->GetNegotiatedProtocol();
+ }
+ NOTREACHED();
+ return kProtoUnknown;
+}
+
+bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->GetSSLInfo(ssl_info);
+ }
+ NOTREACHED();
+ return false;
+
+}
+
+// Read is called by the transport layer above to read. This can only be done
+// if the SOCKS handshake is complete.
+int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(completed_handshake_);
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DCHECK(user_callback_.is_null());
+
+ return transport_->socket()->Read(buf, buf_len, callback);
+}
+
+// Write is called by the transport layer. This can only be done if the
+// SOCKS handshake is complete.
+int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(completed_handshake_);
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DCHECK(user_callback_.is_null());
+
+ return transport_->socket()->Write(buf, buf_len, callback);
+}
+
+bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
+ return transport_->socket()->SetReceiveBufferSize(size);
+}
+
+bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
+ return transport_->socket()->SetSendBufferSize(size);
+}
+
+void SOCKSClientSocket::DoCallback(int result) {
+ DCHECK_NE(ERR_IO_PENDING, result);
+ DCHECK(!user_callback_.is_null());
+
+ // Since Run() may result in Read being called,
+ // clear user_callback_ up front.
+ CompletionCallback c = user_callback_;
+ user_callback_.Reset();
+ DVLOG(1) << "Finished setting up SOCKS handshake";
+ c.Run(result);
+}
+
+void SOCKSClientSocket::OnIOComplete(int result) {
+ DCHECK_NE(STATE_NONE, next_state_);
+ int rv = DoLoop(result);
+ if (rv != ERR_IO_PENDING) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
+ DoCallback(rv);
+ }
+}
+
+int SOCKSClientSocket::DoLoop(int last_io_result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+ int rv = last_io_result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_RESOLVE_HOST:
+ DCHECK_EQ(OK, rv);
+ rv = DoResolveHost();
+ break;
+ case STATE_RESOLVE_HOST_COMPLETE:
+ rv = DoResolveHostComplete(rv);
+ break;
+ case STATE_HANDSHAKE_WRITE:
+ DCHECK_EQ(OK, rv);
+ rv = DoHandshakeWrite();
+ break;
+ case STATE_HANDSHAKE_WRITE_COMPLETE:
+ rv = DoHandshakeWriteComplete(rv);
+ break;
+ case STATE_HANDSHAKE_READ:
+ DCHECK_EQ(OK, rv);
+ rv = DoHandshakeRead();
+ break;
+ case STATE_HANDSHAKE_READ_COMPLETE:
+ rv = DoHandshakeReadComplete(rv);
+ break;
+ default:
+ NOTREACHED() << "bad state";
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+ return rv;
+}
+
+int SOCKSClientSocket::DoResolveHost() {
+ next_state_ = STATE_RESOLVE_HOST_COMPLETE;
+ // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
+ // addresses for the target host.
+ host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
+ return host_resolver_.Resolve(
+ host_request_info_, &addresses_,
+ base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
+ net_log_);
+}
+
+int SOCKSClientSocket::DoResolveHostComplete(int result) {
+ if (result != OK) {
+ // Resolving the hostname failed; fail the request rather than automatically
+ // falling back to SOCKS4a (since it can be confusing to see invalid IP
+ // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
+ return result;
+ }
+
+ next_state_ = STATE_HANDSHAKE_WRITE;
+ return OK;
+}
+
+// Builds the buffer that is to be sent to the server.
+const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
+ SOCKS4ServerRequest request;
+ request.version = kSOCKSVersion4;
+ request.command = kSOCKSStreamRequest;
+ request.nw_port = base::HostToNet16(host_request_info_.port());
+
+ DCHECK(!addresses_.empty());
+ const IPEndPoint& endpoint = addresses_.front();
+
+ // We disabled IPv6 results when resolving the hostname, so none of the
+ // results in the list will be IPv6.
+ // TODO(eroman): we only ever use the first address in the list. It would be
+ // more robust to try all the IP addresses we have before
+ // failing the connect attempt.
+ CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
+ CHECK_LE(endpoint.address().size(), sizeof(request.ip));
+ memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
+
+ DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
+
+ std::string handshake_data(reinterpret_cast<char*>(&request),
+ sizeof(request));
+ handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
+
+ return handshake_data;
+}
+
+// Writes the SOCKS handshake data to the underlying socket connection.
+int SOCKSClientSocket::DoHandshakeWrite() {
+ next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
+
+ if (buffer_.empty()) {
+ buffer_ = BuildHandshakeWriteBuffer();
+ bytes_sent_ = 0;
+ }
+
+ int handshake_buf_len = buffer_.size() - bytes_sent_;
+ DCHECK_GT(handshake_buf_len, 0);
+ handshake_buf_ = new IOBuffer(handshake_buf_len);
+ memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
+ handshake_buf_len);
+ return transport_->socket()->Write(
+ handshake_buf_.get(),
+ handshake_buf_len,
+ base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
+}
+
+int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
+ if (result < 0)
+ return result;
+
+ // We ignore the case when result is 0, since the underlying Write
+ // may return spurious writes while waiting on the socket.
+
+ bytes_sent_ += result;
+ if (bytes_sent_ == buffer_.size()) {
+ next_state_ = STATE_HANDSHAKE_READ;
+ buffer_.clear();
+ } else if (bytes_sent_ < buffer_.size()) {
+ next_state_ = STATE_HANDSHAKE_WRITE;
+ } else {
+ return ERR_UNEXPECTED;
+ }
+
+ return OK;
+}
+
+int SOCKSClientSocket::DoHandshakeRead() {
+ next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
+
+ if (buffer_.empty()) {
+ bytes_received_ = 0;
+ }
+
+ int handshake_buf_len = kReadHeaderSize - bytes_received_;
+ handshake_buf_ = new IOBuffer(handshake_buf_len);
+ return transport_->socket()->Read(
+ handshake_buf_.get(),
+ handshake_buf_len,
+ base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
+}
+
+int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
+ if (result < 0)
+ return result;
+
+ // The underlying socket closed unexpectedly.
+ if (result == 0)
+ return ERR_CONNECTION_CLOSED;
+
+ if (bytes_received_ + result > kReadHeaderSize) {
+ // TODO(eroman): Describe failure in NetLog.
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ buffer_.append(handshake_buf_->data(), result);
+ bytes_received_ += result;
+ if (bytes_received_ < kReadHeaderSize) {
+ next_state_ = STATE_HANDSHAKE_READ;
+ return OK;
+ }
+
+ const SOCKS4ServerResponse* response =
+ reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
+
+ if (response->reserved_null != 0x00) {
+ LOG(ERROR) << "Unknown response from SOCKS server.";
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ switch (response->code) {
+ case kServerResponseOk:
+ completed_handshake_ = true;
+ return OK;
+ case kServerResponseRejected:
+ LOG(ERROR) << "SOCKS request rejected or failed";
+ return ERR_SOCKS_CONNECTION_FAILED;
+ case kServerResponseNotReachable:
+ LOG(ERROR) << "SOCKS request failed because client is not running "
+ << "identd (or not reachable from the server)";
+ return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
+ case kServerResponseMismatchedUserId:
+ LOG(ERROR) << "SOCKS request failed because client's identd could "
+ << "not confirm the user ID string in the request";
+ return ERR_SOCKS_CONNECTION_FAILED;
+ default:
+ LOG(ERROR) << "SOCKS server sent unknown response";
+ return ERR_SOCKS_CONNECTION_FAILED;
+ }
+
+ // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
+}
+
+int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetPeerAddress(address);
+}
+
+int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetLocalAddress(address);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h
new file mode 100644
index 00000000000..285c75ec295
--- /dev/null
+++ b/chromium/net/socket/socks_client_socket.h
@@ -0,0 +1,134 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_H_
+#define NET_SOCKET_SOCKS_CLIENT_SOCKET_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/gtest_prod_util.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/dns/host_resolver.h"
+#include "net/dns/single_request_host_resolver.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class ClientSocketHandle;
+class BoundNetLog;
+
+// The SOCKS client socket implementation
+class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
+ public:
+ // |req_info| contains the hostname and port to which the socket above will
+ // communicate to via the socks layer. For testing the referrer is optional.
+ SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostResolver::RequestInfo& req_info,
+ HostResolver* host_resolver);
+
+ // On destruction Disconnect() is called.
+ virtual ~SOCKSClientSocket();
+
+ // StreamSocket implementation.
+
+ // Does the SOCKS handshake and completes the protocol.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+
+ private:
+ FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, CompleteHandshake);
+ FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, SOCKS4AFailedDNS);
+ FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, SOCKS4AIfDomainInIPv6);
+
+ enum State {
+ STATE_RESOLVE_HOST,
+ STATE_RESOLVE_HOST_COMPLETE,
+ STATE_HANDSHAKE_WRITE,
+ STATE_HANDSHAKE_WRITE_COMPLETE,
+ STATE_HANDSHAKE_READ,
+ STATE_HANDSHAKE_READ_COMPLETE,
+ STATE_NONE,
+ };
+
+ void DoCallback(int result);
+ void OnIOComplete(int result);
+
+ int DoLoop(int last_io_result);
+ int DoResolveHost();
+ int DoResolveHostComplete(int result);
+ int DoHandshakeRead();
+ int DoHandshakeReadComplete(int result);
+ int DoHandshakeWrite();
+ int DoHandshakeWriteComplete(int result);
+
+ const std::string BuildHandshakeWriteBuffer() const;
+
+ // Stores the underlying socket.
+ scoped_ptr<ClientSocketHandle> transport_;
+
+ State next_state_;
+
+ // Stores the callback to the layer above, called on completing Connect().
+ CompletionCallback user_callback_;
+
+ // This IOBuffer is used by the class to read and write
+ // SOCKS handshake data. The length contains the expected size to
+ // read or write.
+ scoped_refptr<IOBuffer> handshake_buf_;
+
+ // While writing, this buffer stores the complete write handshake data.
+ // While reading, it stores the handshake information received so far.
+ std::string buffer_;
+
+ // This becomes true when the SOCKS handshake has completed and the
+ // overlying connection is free to communicate.
+ bool completed_handshake_;
+
+ // These contain the bytes sent / received by the SOCKS handshake.
+ size_t bytes_sent_;
+ size_t bytes_received_;
+
+ // Used to resolve the hostname to which the SOCKS proxy will connect.
+ SingleRequestHostResolver host_resolver_;
+ AddressList addresses_;
+ HostResolver::RequestInfo host_request_info_;
+
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocket);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_H_
diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc
new file mode 100644
index 00000000000..e49eabaa84c
--- /dev/null
+++ b/chromium/net/socket/socks_client_socket_pool.cc
@@ -0,0 +1,310 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socks_client_socket_pool.h"
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/time/time.h"
+#include "base/values.h"
+#include "net/base/net_errors.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/socks5_client_socket.h"
+#include "net/socket/socks_client_socket.h"
+#include "net/socket/transport_client_socket_pool.h"
+
+namespace net {
+
+SOCKSSocketParams::SOCKSSocketParams(
+ const scoped_refptr<TransportSocketParams>& proxy_server,
+ bool socks_v5,
+ const HostPortPair& host_port_pair,
+ RequestPriority priority)
+ : transport_params_(proxy_server),
+ destination_(host_port_pair),
+ socks_v5_(socks_v5) {
+ if (transport_params_.get())
+ ignore_limits_ = transport_params_->ignore_limits();
+ else
+ ignore_limits_ = false;
+ destination_.set_priority(priority);
+}
+
+SOCKSSocketParams::~SOCKSSocketParams() {}
+
+// SOCKSConnectJobs will time out after this many seconds. Note this is on
+// top of the timeout for the transport socket.
+static const int kSOCKSConnectJobTimeoutInSeconds = 30;
+
+SOCKSConnectJob::SOCKSConnectJob(
+ const std::string& group_name,
+ const scoped_refptr<SOCKSSocketParams>& socks_params,
+ const base::TimeDelta& timeout_duration,
+ TransportClientSocketPool* transport_pool,
+ HostResolver* host_resolver,
+ Delegate* delegate,
+ NetLog* net_log)
+ : ConnectJob(group_name, timeout_duration, delegate,
+ BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
+ socks_params_(socks_params),
+ transport_pool_(transport_pool),
+ resolver_(host_resolver),
+ callback_(base::Bind(&SOCKSConnectJob::OnIOComplete,
+ base::Unretained(this))) {
+}
+
+SOCKSConnectJob::~SOCKSConnectJob() {
+ // We don't worry about cancelling the tcp socket since the destructor in
+ // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
+ // it.
+}
+
+LoadState SOCKSConnectJob::GetLoadState() const {
+ switch (next_state_) {
+ case STATE_TRANSPORT_CONNECT:
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ return transport_socket_handle_->GetLoadState();
+ case STATE_SOCKS_CONNECT:
+ case STATE_SOCKS_CONNECT_COMPLETE:
+ return LOAD_STATE_CONNECTING;
+ default:
+ NOTREACHED();
+ return LOAD_STATE_IDLE;
+ }
+}
+
+void SOCKSConnectJob::OnIOComplete(int result) {
+ int rv = DoLoop(result);
+ if (rv != ERR_IO_PENDING)
+ NotifyDelegateOfCompletion(rv); // Deletes |this|
+}
+
+int SOCKSConnectJob::DoLoop(int result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_TRANSPORT_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoTransportConnect();
+ break;
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ rv = DoTransportConnectComplete(rv);
+ break;
+ case STATE_SOCKS_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoSOCKSConnect();
+ break;
+ case STATE_SOCKS_CONNECT_COMPLETE:
+ rv = DoSOCKSConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED() << "bad state";
+ rv = ERR_FAILED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+
+ return rv;
+}
+
+int SOCKSConnectJob::DoTransportConnect() {
+ next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
+ transport_socket_handle_.reset(new ClientSocketHandle());
+ return transport_socket_handle_->Init(
+ group_name(), socks_params_->transport_params(),
+ socks_params_->destination().priority(), callback_, transport_pool_,
+ net_log());
+}
+
+int SOCKSConnectJob::DoTransportConnectComplete(int result) {
+ if (result != OK)
+ return ERR_PROXY_CONNECTION_FAILED;
+
+ // Reset the timer to just the length of time allowed for SOCKS handshake
+ // so that a fast TCP connection plus a slow SOCKS failure doesn't take
+ // longer to timeout than it should.
+ ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds));
+ next_state_ = STATE_SOCKS_CONNECT;
+ return result;
+}
+
+int SOCKSConnectJob::DoSOCKSConnect() {
+ next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
+
+ // Add a SOCKS connection on top of the tcp socket.
+ if (socks_params_->is_socks_v5()) {
+ socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(),
+ socks_params_->destination()));
+ } else {
+ socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(),
+ socks_params_->destination(),
+ resolver_));
+ }
+ return socket_->Connect(
+ base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this)));
+}
+
+int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
+ if (result != OK) {
+ socket_->Disconnect();
+ return result;
+ }
+
+ SetSocket(socket_.Pass());
+ return result;
+}
+
+int SOCKSConnectJob::ConnectInternal() {
+ next_state_ = STATE_TRANSPORT_CONNECT;
+ return DoLoop(OK);
+}
+
+scoped_ptr<ConnectJob>
+SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
+ const std::string& group_name,
+ const PoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const {
+ return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name,
+ request.params(),
+ ConnectionTimeout(),
+ transport_pool_,
+ host_resolver_,
+ delegate,
+ net_log_));
+}
+
+base::TimeDelta
+SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
+ return transport_pool_->ConnectionTimeout() +
+ base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds);
+}
+
+SOCKSClientSocketPool::SOCKSClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ TransportClientSocketPool* transport_pool,
+ NetLog* net_log)
+ : transport_pool_(transport_pool),
+ base_(max_sockets, max_sockets_per_group, histograms,
+ ClientSocketPool::unused_idle_socket_timeout(),
+ ClientSocketPool::used_idle_socket_timeout(),
+ new SOCKSConnectJobFactory(transport_pool,
+ host_resolver,
+ net_log)) {
+ // We should always have a |transport_pool_| except in unit tests.
+ if (transport_pool_)
+ transport_pool_->AddLayeredPool(this);
+}
+
+SOCKSClientSocketPool::~SOCKSClientSocketPool() {
+ // We should always have a |transport_pool_| except in unit tests.
+ if (transport_pool_)
+ transport_pool_->RemoveLayeredPool(this);
+}
+
+int SOCKSClientSocketPool::RequestSocket(
+ const std::string& group_name, const void* socket_params,
+ RequestPriority priority, ClientSocketHandle* handle,
+ const CompletionCallback& callback, const BoundNetLog& net_log) {
+ const scoped_refptr<SOCKSSocketParams>* casted_socket_params =
+ static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params);
+
+ return base_.RequestSocket(group_name, *casted_socket_params, priority,
+ handle, callback, net_log);
+}
+
+void SOCKSClientSocketPool::RequestSockets(
+ const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
+ const scoped_refptr<SOCKSSocketParams>* casted_params =
+ static_cast<const scoped_refptr<SOCKSSocketParams>*>(params);
+
+ base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
+}
+
+void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) {
+ base_.CancelRequest(group_name, handle);
+}
+
+void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
+}
+
+void SOCKSClientSocketPool::FlushWithError(int error) {
+ base_.FlushWithError(error);
+}
+
+bool SOCKSClientSocketPool::IsStalled() const {
+ return base_.IsStalled() || transport_pool_->IsStalled();
+}
+
+void SOCKSClientSocketPool::CloseIdleSockets() {
+ base_.CloseIdleSockets();
+}
+
+int SOCKSClientSocketPool::IdleSocketCount() const {
+ return base_.idle_socket_count();
+}
+
+int SOCKSClientSocketPool::IdleSocketCountInGroup(
+ const std::string& group_name) const {
+ return base_.IdleSocketCountInGroup(group_name);
+}
+
+LoadState SOCKSClientSocketPool::GetLoadState(
+ const std::string& group_name, const ClientSocketHandle* handle) const {
+ return base_.GetLoadState(group_name, handle);
+}
+
+void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
+ base_.AddLayeredPool(layered_pool);
+}
+
+void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
+ base_.RemoveLayeredPool(layered_pool);
+}
+
+base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const {
+ base::DictionaryValue* dict = base_.GetInfoAsValue(name, type);
+ if (include_nested_pools) {
+ base::ListValue* list = new base::ListValue();
+ list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
+ "transport_socket_pool",
+ false));
+ dict->Set("nested_pools", list);
+ }
+ return dict;
+}
+
+base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const {
+ return base_.ConnectionTimeout();
+}
+
+ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
+ return base_.histograms();
+};
+
+bool SOCKSClientSocketPool::CloseOneIdleConnection() {
+ if (base_.CloseOneIdleSocket())
+ return true;
+ return base_.CloseOneIdleConnectionInLayeredPool();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h
new file mode 100644
index 00000000000..fe69a78df69
--- /dev/null
+++ b/chromium/net/socket/socks_client_socket_pool.h
@@ -0,0 +1,211 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_
+#define NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "net/base/host_port_pair.h"
+#include "net/dns/host_resolver.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/client_socket_pool_histograms.h"
+
+namespace net {
+
+class ConnectJobFactory;
+class TransportClientSocketPool;
+class TransportSocketParams;
+
+class NET_EXPORT_PRIVATE SOCKSSocketParams
+ : public base::RefCounted<SOCKSSocketParams> {
+ public:
+ SOCKSSocketParams(const scoped_refptr<TransportSocketParams>& proxy_server,
+ bool socks_v5, const HostPortPair& host_port_pair,
+ RequestPriority priority);
+
+ const scoped_refptr<TransportSocketParams>& transport_params() const {
+ return transport_params_;
+ }
+ const HostResolver::RequestInfo& destination() const { return destination_; }
+ bool is_socks_v5() const { return socks_v5_; }
+ bool ignore_limits() const { return ignore_limits_; }
+
+ private:
+ friend class base::RefCounted<SOCKSSocketParams>;
+ ~SOCKSSocketParams();
+
+ // The transport (likely TCP) connection must point toward the proxy server.
+ const scoped_refptr<TransportSocketParams> transport_params_;
+ // This is the HTTP destination.
+ HostResolver::RequestInfo destination_;
+ const bool socks_v5_;
+ bool ignore_limits_;
+
+ DISALLOW_COPY_AND_ASSIGN(SOCKSSocketParams);
+};
+
+// SOCKSConnectJob handles the handshake to a socks server after setting up
+// an underlying transport socket.
+class SOCKSConnectJob : public ConnectJob {
+ public:
+ SOCKSConnectJob(const std::string& group_name,
+ const scoped_refptr<SOCKSSocketParams>& params,
+ const base::TimeDelta& timeout_duration,
+ TransportClientSocketPool* transport_pool,
+ HostResolver* host_resolver,
+ Delegate* delegate,
+ NetLog* net_log);
+ virtual ~SOCKSConnectJob();
+
+ // ConnectJob methods.
+ virtual LoadState GetLoadState() const OVERRIDE;
+
+ private:
+ enum State {
+ STATE_TRANSPORT_CONNECT,
+ STATE_TRANSPORT_CONNECT_COMPLETE,
+ STATE_SOCKS_CONNECT,
+ STATE_SOCKS_CONNECT_COMPLETE,
+ STATE_NONE,
+ };
+
+ void OnIOComplete(int result);
+
+ // Runs the state transition loop.
+ int DoLoop(int result);
+
+ int DoTransportConnect();
+ int DoTransportConnectComplete(int result);
+ int DoSOCKSConnect();
+ int DoSOCKSConnectComplete(int result);
+
+ // Begins the transport connection and the SOCKS handshake. Returns OK on
+ // success and ERR_IO_PENDING if it cannot immediately service the request.
+ // Otherwise, it returns a net error code.
+ virtual int ConnectInternal() OVERRIDE;
+
+ scoped_refptr<SOCKSSocketParams> socks_params_;
+ TransportClientSocketPool* const transport_pool_;
+ HostResolver* const resolver_;
+
+ State next_state_;
+ CompletionCallback callback_;
+ scoped_ptr<ClientSocketHandle> transport_socket_handle_;
+ scoped_ptr<StreamSocket> socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJob);
+};
+
+class NET_EXPORT_PRIVATE SOCKSClientSocketPool
+ : public ClientSocketPool, public LayeredPool {
+ public:
+ SOCKSClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ TransportClientSocketPool* transport_pool,
+ NetLog* net_log);
+
+ virtual ~SOCKSClientSocketPool();
+
+ // ClientSocketPool implementation.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* connect_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) OVERRIDE;
+
+ virtual void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) OVERRIDE;
+
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) OVERRIDE;
+
+ virtual void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
+
+ virtual void FlushWithError(int error) OVERRIDE;
+
+ virtual bool IsStalled() const OVERRIDE;
+
+ virtual void CloseIdleSockets() OVERRIDE;
+
+ virtual int IdleSocketCount() const OVERRIDE;
+
+ virtual int IdleSocketCountInGroup(
+ const std::string& group_name) const OVERRIDE;
+
+ virtual LoadState GetLoadState(
+ const std::string& group_name,
+ const ClientSocketHandle* handle) const OVERRIDE;
+
+ virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
+
+ virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
+
+ virtual base::DictionaryValue* GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const OVERRIDE;
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+
+ virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+
+ // LayeredPool implementation.
+ virtual bool CloseOneIdleConnection() OVERRIDE;
+
+ private:
+ typedef ClientSocketPoolBase<SOCKSSocketParams> PoolBase;
+
+ class SOCKSConnectJobFactory : public PoolBase::ConnectJobFactory {
+ public:
+ SOCKSConnectJobFactory(TransportClientSocketPool* transport_pool,
+ HostResolver* host_resolver,
+ NetLog* net_log)
+ : transport_pool_(transport_pool),
+ host_resolver_(host_resolver),
+ net_log_(net_log) {}
+
+ virtual ~SOCKSConnectJobFactory() {}
+
+ // ClientSocketPoolBase::ConnectJobFactory methods.
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const PoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const OVERRIDE;
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+
+ private:
+ TransportClientSocketPool* const transport_pool_;
+ HostResolver* const host_resolver_;
+ NetLog* net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJobFactory);
+ };
+
+ TransportClientSocketPool* const transport_pool_;
+ PoolBase base_;
+
+ DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketPool);
+};
+
+REGISTER_SOCKET_PARAMS_FOR_POOL(SOCKSClientSocketPool, SOCKSSocketParams);
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc
new file mode 100644
index 00000000000..77440d36a19
--- /dev/null
+++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc
@@ -0,0 +1,297 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socks_client_socket_pool.h"
+
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/time/time.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/load_timing_info_test_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/socket_test_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+const int kMaxSockets = 32;
+const int kMaxSocketsPerGroup = 6;
+
+// Make sure |handle|'s load times are set correctly. Only connect times should
+// be set.
+void TestLoadTimingInfo(const ClientSocketHandle& handle) {
+ LoadTimingInfo load_timing_info;
+ EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ // None of these tests use a NetLog.
+ EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+
+ ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
+ CONNECT_TIMING_HAS_CONNECT_TIMES_ONLY);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+class SOCKSClientSocketPoolTest : public testing::Test {
+ protected:
+ class SOCKS5MockData {
+ public:
+ explicit SOCKS5MockData(IoMode mode) {
+ writes_.reset(new MockWrite[3]);
+ writes_[0] = MockWrite(mode, kSOCKS5GreetRequest,
+ kSOCKS5GreetRequestLength);
+ writes_[1] = MockWrite(mode, kSOCKS5OkRequest, kSOCKS5OkRequestLength);
+ writes_[2] = MockWrite(mode, 0);
+
+ reads_.reset(new MockRead[3]);
+ reads_[0] = MockRead(mode, kSOCKS5GreetResponse,
+ kSOCKS5GreetResponseLength);
+ reads_[1] = MockRead(mode, kSOCKS5OkResponse, kSOCKS5OkResponseLength);
+ reads_[2] = MockRead(mode, 0);
+
+ data_.reset(new StaticSocketDataProvider(reads_.get(), 3,
+ writes_.get(), 3));
+ }
+
+ SocketDataProvider* data_provider() { return data_.get(); }
+
+ private:
+ scoped_ptr<StaticSocketDataProvider> data_;
+ scoped_ptr<MockWrite[]> writes_;
+ scoped_ptr<MockRead[]> reads_;
+ };
+
+ SOCKSClientSocketPoolTest()
+ : ignored_transport_socket_params_(new TransportSocketParams(
+ HostPortPair("proxy", 80), MEDIUM, false, false,
+ OnHostResolutionCallback())),
+ transport_histograms_("MockTCP"),
+ transport_socket_pool_(
+ kMaxSockets, kMaxSocketsPerGroup,
+ &transport_histograms_,
+ &transport_client_socket_factory_),
+ ignored_socket_params_(new SOCKSSocketParams(
+ ignored_transport_socket_params_, true, HostPortPair("host", 80),
+ MEDIUM)),
+ socks_histograms_("SOCKSUnitTest"),
+ pool_(kMaxSockets, kMaxSocketsPerGroup,
+ &socks_histograms_,
+ NULL,
+ &transport_socket_pool_,
+ NULL) {
+ }
+
+ virtual ~SOCKSClientSocketPoolTest() {}
+
+ int StartRequest(const std::string& group_name, RequestPriority priority) {
+ return test_base_.StartRequestUsingPool(
+ &pool_, group_name, priority, ignored_socket_params_);
+ }
+
+ int GetOrderOfRequest(size_t index) const {
+ return test_base_.GetOrderOfRequest(index);
+ }
+
+ ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }
+
+ scoped_refptr<TransportSocketParams> ignored_transport_socket_params_;
+ ClientSocketPoolHistograms transport_histograms_;
+ MockClientSocketFactory transport_client_socket_factory_;
+ MockTransportClientSocketPool transport_socket_pool_;
+
+ scoped_refptr<SOCKSSocketParams> ignored_socket_params_;
+ ClientSocketPoolHistograms socks_histograms_;
+ SOCKSClientSocketPool pool_;
+ ClientSocketPoolTest test_base_;
+};
+
+TEST_F(SOCKSClientSocketPoolTest, Simple) {
+ SOCKS5MockData data(SYNCHRONOUS);
+ data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
+
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
+ &pool_, BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+}
+
+TEST_F(SOCKSClientSocketPoolTest, Async) {
+ SOCKS5MockData data(ASYNC);
+ transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
+ &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+}
+
+TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) {
+ StaticSocketDataProvider socket_data;
+ socket_data.set_connect_data(MockConnect(SYNCHRONOUS,
+ ERR_CONNECTION_REFUSED));
+ transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
+
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
+ &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+}
+
+TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) {
+ StaticSocketDataProvider socket_data;
+ socket_data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_REFUSED));
+ transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
+ &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+}
+
+TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) {
+ MockRead failed_read[] = {
+ MockRead(SYNCHRONOUS, 0),
+ };
+ StaticSocketDataProvider socket_data(
+ failed_read, arraysize(failed_read), NULL, 0);
+ socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
+
+ ClientSocketHandle handle;
+ EXPECT_EQ(0, transport_socket_pool_.release_count());
+ int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
+ &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_EQ(1, transport_socket_pool_.release_count());
+}
+
+TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) {
+ MockRead failed_read[] = {
+ MockRead(ASYNC, 0),
+ };
+ StaticSocketDataProvider socket_data(
+ failed_read, arraysize(failed_read), NULL, 0);
+ socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(0, transport_socket_pool_.release_count());
+ int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
+ &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_EQ(1, transport_socket_pool_.release_count());
+}
+
+TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) {
+ SOCKS5MockData data(SYNCHRONOUS);
+ transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
+ // We need two connections because the pool base lets one cancelled
+ // connect job proceed for potential future use.
+ SOCKS5MockData data2(SYNCHRONOUS);
+ transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
+
+ EXPECT_EQ(0, transport_socket_pool_.cancel_count());
+ int rv = StartRequest("a", LOW);
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ rv = StartRequest("a", LOW);
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ pool_.CancelRequest("a", (*requests())[0]->handle());
+ pool_.CancelRequest("a", (*requests())[1]->handle());
+ // Requests in the connect phase don't actually get cancelled.
+ EXPECT_EQ(0, transport_socket_pool_.cancel_count());
+
+ // Now wait for the TCP sockets to connect.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1));
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2));
+ EXPECT_EQ(0, transport_socket_pool_.cancel_count());
+ EXPECT_EQ(2, pool_.IdleSocketCount());
+
+ (*requests())[0]->handle()->Reset();
+ (*requests())[1]->handle()->Reset();
+}
+
+TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) {
+ SOCKS5MockData data(ASYNC);
+ data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
+ // We need two connections because the pool base lets one cancelled
+ // connect job proceed for potential future use.
+ SOCKS5MockData data2(ASYNC);
+ data2.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
+
+ EXPECT_EQ(0, transport_socket_pool_.cancel_count());
+ EXPECT_EQ(0, transport_socket_pool_.release_count());
+ int rv = StartRequest("a", LOW);
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ rv = StartRequest("a", LOW);
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ pool_.CancelRequest("a", (*requests())[0]->handle());
+ pool_.CancelRequest("a", (*requests())[1]->handle());
+ EXPECT_EQ(0, transport_socket_pool_.cancel_count());
+ // Requests in the connect phase don't actually get cancelled.
+ EXPECT_EQ(0, transport_socket_pool_.release_count());
+
+ // Now wait for the async data to reach the SOCKS connect jobs.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1));
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2));
+ EXPECT_EQ(0, transport_socket_pool_.cancel_count());
+ EXPECT_EQ(0, transport_socket_pool_.release_count());
+ EXPECT_EQ(2, pool_.IdleSocketCount());
+
+ (*requests())[0]->handle()->Reset();
+ (*requests())[1]->handle()->Reset();
+}
+
+// It would be nice to also test the timeouts in SOCKSClientSocketPool.
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc
new file mode 100644
index 00000000000..8c30838959d
--- /dev/null
+++ b/chromium/net/socket/socks_client_socket_unittest.cc
@@ -0,0 +1,415 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socks_client_socket.h"
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/net_log.h"
+#include "net/base/net_log_unittest.h"
+#include "net/base/test_completion_callback.h"
+#include "net/base/winsock_init.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/tcp_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+//-----------------------------------------------------------------------------
+
+namespace net {
+
+const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 };
+const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
+
+class SOCKSClientSocketTest : public PlatformTest {
+ public:
+ SOCKSClientSocketTest();
+ // Create a SOCKSClientSocket on top of a MockSocket.
+ scoped_ptr<SOCKSClientSocket> BuildMockSocket(
+ MockRead reads[], size_t reads_count,
+ MockWrite writes[], size_t writes_count,
+ HostResolver* host_resolver,
+ const std::string& hostname, int port,
+ NetLog* net_log);
+ virtual void SetUp();
+
+ protected:
+ scoped_ptr<SOCKSClientSocket> user_sock_;
+ AddressList address_list_;
+ // Filled in by BuildMockSocket() and owned by its return value
+ // (which |user_sock| is set to).
+ StreamSocket* tcp_sock_;
+ TestCompletionCallback callback_;
+ scoped_ptr<MockHostResolver> host_resolver_;
+ scoped_ptr<SocketDataProvider> data_;
+};
+
+SOCKSClientSocketTest::SOCKSClientSocketTest()
+ : host_resolver_(new MockHostResolver) {
+}
+
+// Set up platform before every test case
+void SOCKSClientSocketTest::SetUp() {
+ PlatformTest::SetUp();
+}
+
+scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
+ MockRead reads[],
+ size_t reads_count,
+ MockWrite writes[],
+ size_t writes_count,
+ HostResolver* host_resolver,
+ const std::string& hostname,
+ int port,
+ NetLog* net_log) {
+
+ TestCompletionCallback callback;
+ data_.reset(new StaticSocketDataProvider(reads, reads_count,
+ writes, writes_count));
+ tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
+
+ int rv = tcp_sock_->Connect(callback.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(tcp_sock_->IsConnected());
+
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ // |connection| takes ownership of |tcp_sock_|, but keep a
+ // non-owning pointer to it.
+ connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
+ return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
+ connection.Pass(),
+ HostResolver::RequestInfo(HostPortPair(hostname, port)),
+ host_resolver));
+}
+
+// Implementation of HostResolver that never completes its resolve request.
+// We use this in the test "DisconnectWhileHostResolveInProgress" to make
+// sure that the outstanding resolve request gets cancelled.
+class HangingHostResolverWithCancel : public HostResolver {
+ public:
+ HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
+
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) OVERRIDE {
+ DCHECK(addresses);
+ DCHECK_EQ(false, callback.is_null());
+ EXPECT_FALSE(HasOutstandingRequest());
+ outstanding_request_ = reinterpret_cast<RequestHandle>(1);
+ *out_req = outstanding_request_;
+ return ERR_IO_PENDING;
+ }
+
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) OVERRIDE {
+ NOTIMPLEMENTED();
+ return ERR_UNEXPECTED;
+ }
+
+ virtual void CancelRequest(RequestHandle req) OVERRIDE {
+ EXPECT_TRUE(HasOutstandingRequest());
+ EXPECT_EQ(outstanding_request_, req);
+ outstanding_request_ = NULL;
+ }
+
+ bool HasOutstandingRequest() {
+ return outstanding_request_ != NULL;
+ }
+
+ private:
+ RequestHandle outstanding_request_;
+
+ DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
+};
+
+// Tests a complete handshake and the disconnection.
+TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
+ const std::string payload_write = "random data";
+ const std::string payload_read = "moar random data";
+
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
+ MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
+ MockRead(ASYNC, payload_read.data(), payload_read.size()) };
+ CapturingNetLog log;
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
+
+ // At this state the TCP connection is completed but not the SOCKS handshake.
+ EXPECT_TRUE(tcp_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnected());
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(
+ LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
+ EXPECT_FALSE(user_sock_->IsConnected());
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, -1, NetLog::TYPE_SOCKS_CONNECT));
+
+ scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
+ memcpy(buffer->data(), payload_write.data(), payload_write.size());
+ rv = user_sock_->Write(
+ buffer.get(), payload_write.size(), callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
+
+ buffer = new IOBuffer(payload_read.size());
+ rv =
+ user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
+ EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
+
+ user_sock_->Disconnect();
+ EXPECT_FALSE(tcp_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnected());
+}
+
+// List of responses from the socks server and the errors they should
+// throw up are tested here.
+TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
+ const struct {
+ const char fail_reply[8];
+ Error fail_code;
+ } tests[] = {
+ // Failure of the server response code
+ {
+ { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
+ ERR_SOCKS_CONNECTION_FAILED,
+ },
+ // Failure of the null byte
+ {
+ { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
+ ERR_SOCKS_CONNECTION_FAILED,
+ },
+ };
+
+ //---------------------------------------
+
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
+ MockWrite data_writes[] = {
+ MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
+ MockRead data_reads[] = {
+ MockRead(SYNCHRONOUS, tests[i].fail_reply,
+ arraysize(tests[i].fail_reply)) };
+ CapturingNetLog log;
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKS_CONNECT));
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(tests[i].fail_code, rv);
+ EXPECT_FALSE(user_sock_->IsConnected());
+ EXPECT_TRUE(tcp_sock_->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, -1, NetLog::TYPE_SOCKS_CONNECT));
+ }
+}
+
+// Tests scenario when the server sends the handshake response in
+// more than one packet.
+TEST_F(SOCKSClientSocketTest, PartialServerReads) {
+ const char kSOCKSPartialReply1[] = { 0x00 };
+ const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
+
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
+ MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
+ CapturingNetLog log;
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKS_CONNECT));
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, -1, NetLog::TYPE_SOCKS_CONNECT));
+}
+
+// Tests scenario when the client sends the handshake request in
+// more than one packet.
+TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
+ const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
+ const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
+
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
+ // simulate some empty writes
+ MockWrite(ASYNC, 0),
+ MockWrite(ASYNC, 0),
+ MockWrite(ASYNC, kSOCKSPartialRequest2,
+ arraysize(kSOCKSPartialRequest2)) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
+ CapturingNetLog log;
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKS_CONNECT));
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(user_sock_->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, -1, NetLog::TYPE_SOCKS_CONNECT));
+}
+
+// Tests the case when the server sends a smaller sized handshake data
+// and closes the connection.
+TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
+ MockWrite data_writes[] = {
+ MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
+ MockRead data_reads[] = {
+ MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
+ // close connection unexpectedly
+ MockRead(SYNCHRONOUS, 0) };
+ CapturingNetLog log;
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ host_resolver_.get(),
+ "localhost", 80,
+ &log);
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKS_CONNECT));
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
+ EXPECT_FALSE(user_sock_->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, -1, NetLog::TYPE_SOCKS_CONNECT));
+}
+
+// Tries to connect to an unknown hostname. Should fail rather than
+// falling back to SOCKS4a.
+TEST_F(SOCKSClientSocketTest, FailedDNS) {
+ const char hostname[] = "unresolved.ipv4.address";
+
+ host_resolver_->rules()->AddSimulatedFailure(hostname);
+
+ CapturingNetLog log;
+
+ user_sock_ = BuildMockSocket(NULL, 0,
+ NULL, 0,
+ host_resolver_.get(),
+ hostname, 80,
+ &log);
+
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(
+ entries, 0, NetLog::TYPE_SOCKS_CONNECT));
+
+ rv = callback_.WaitForResult();
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
+ EXPECT_FALSE(user_sock_->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsEndEvent(
+ entries, -1, NetLog::TYPE_SOCKS_CONNECT));
+}
+
+// Calls Disconnect() while a host resolve is in progress. The outstanding host
+// resolve should be cancelled.
+TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
+ scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
+ new HangingHostResolverWithCancel());
+
+ // Doesn't matter what the socket data is, we will never use it -- garbage.
+ MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
+ MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
+
+ user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes),
+ hanging_resolver.get(),
+ "foo", 80,
+ NULL);
+
+ // Start connecting (will get stuck waiting for the host to resolve).
+ int rv = user_sock_->Connect(callback_.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_FALSE(user_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
+
+ // The host resolver should have received the resolve request.
+ EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
+
+ // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
+ user_sock_->Disconnect();
+
+ EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
+
+ EXPECT_FALSE(user_sock_->IsConnected());
+ EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket.cc b/chromium/net/socket/ssl_client_socket.cc
new file mode 100644
index 00000000000..54f66a1f681
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/ssl_client_socket.h"
+
+#include "base/strings/string_util.h"
+
+namespace net {
+
+SSLClientSocket::SSLClientSocket()
+ : was_npn_negotiated_(false),
+ was_spdy_negotiated_(false),
+ protocol_negotiated_(kProtoUnknown),
+ channel_id_sent_(false) {
+}
+
+// static
+NextProto SSLClientSocket::NextProtoFromString(
+ const std::string& proto_string) {
+ if (proto_string == "http1.1" || proto_string == "http/1.1") {
+ return kProtoHTTP11;
+ } else if (proto_string == "spdy/1") {
+ return kProtoSPDY1;
+ } else if (proto_string == "spdy/2") {
+ return kProtoSPDY2;
+ } else if (proto_string == "spdy/3") {
+ return kProtoSPDY3;
+ } else if (proto_string == "spdy/3.1") {
+ return kProtoSPDY31;
+ } else if (proto_string == "spdy/4a2") {
+ return kProtoSPDY4a2;
+ } else if (proto_string == "HTTP-draft-04/2.0") {
+ return kProtoHTTP2Draft04;
+ } else if (proto_string == "quic/1+spdy/3") {
+ return kProtoQUIC1SPDY3;
+ } else {
+ return kProtoUnknown;
+ }
+}
+
+// static
+const char* SSLClientSocket::NextProtoToString(NextProto next_proto) {
+ switch (next_proto) {
+ case kProtoHTTP11:
+ return "http/1.1";
+ case kProtoSPDY1:
+ return "spdy/1";
+ case kProtoSPDY2:
+ return "spdy/2";
+ case kProtoSPDY3:
+ return "spdy/3";
+ case kProtoSPDY31:
+ return "spdy/3.1";
+ case kProtoSPDY4a2:
+ return "spdy/4a2";
+ case kProtoHTTP2Draft04:
+ return "HTTP-draft-04/2.0";
+ case kProtoQUIC1SPDY3:
+ return "quic/1+spdy/3";
+ case kProtoSPDY21:
+ case kProtoUnknown:
+ break;
+ }
+ return "unknown";
+}
+
+// static
+const char* SSLClientSocket::NextProtoStatusToString(
+ const SSLClientSocket::NextProtoStatus status) {
+ switch (status) {
+ case kNextProtoUnsupported:
+ return "unsupported";
+ case kNextProtoNegotiated:
+ return "negotiated";
+ case kNextProtoNoOverlap:
+ return "no-overlap";
+ }
+ return NULL;
+}
+
+// static
+std::string SSLClientSocket::ServerProtosToString(
+ const std::string& server_protos) {
+ const char* protos = server_protos.c_str();
+ size_t protos_len = server_protos.length();
+ std::vector<std::string> server_protos_with_commas;
+ for (size_t i = 0; i < protos_len; ) {
+ const size_t len = protos[i];
+ std::string proto_str(&protos[i + 1], len);
+ server_protos_with_commas.push_back(proto_str);
+ i += len + 1;
+ }
+ return JoinString(server_protos_with_commas, ',');
+}
+
+bool SSLClientSocket::WasNpnNegotiated() const {
+ return was_npn_negotiated_;
+}
+
+NextProto SSLClientSocket::GetNegotiatedProtocol() const {
+ return protocol_negotiated_;
+}
+
+bool SSLClientSocket::IgnoreCertError(int error, int load_flags) {
+ if (error == OK || load_flags & LOAD_IGNORE_ALL_CERT_ERRORS)
+ return true;
+
+ if (error == ERR_CERT_COMMON_NAME_INVALID &&
+ (load_flags & LOAD_IGNORE_CERT_COMMON_NAME_INVALID))
+ return true;
+
+ if (error == ERR_CERT_DATE_INVALID &&
+ (load_flags & LOAD_IGNORE_CERT_DATE_INVALID))
+ return true;
+
+ if (error == ERR_CERT_AUTHORITY_INVALID &&
+ (load_flags & LOAD_IGNORE_CERT_AUTHORITY_INVALID))
+ return true;
+
+ return false;
+}
+
+bool SSLClientSocket::set_was_npn_negotiated(bool negotiated) {
+ return was_npn_negotiated_ = negotiated;
+}
+
+bool SSLClientSocket::was_spdy_negotiated() const {
+ return was_spdy_negotiated_;
+}
+
+bool SSLClientSocket::set_was_spdy_negotiated(bool negotiated) {
+ return was_spdy_negotiated_ = negotiated;
+}
+
+void SSLClientSocket::set_protocol_negotiated(NextProto protocol_negotiated) {
+ protocol_negotiated_ = protocol_negotiated;
+}
+
+bool SSLClientSocket::WasChannelIDSent() const {
+ return channel_id_sent_;
+}
+
+void SSLClientSocket::set_channel_id_sent(bool channel_id_sent) {
+ channel_id_sent_ = channel_id_sent;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h
new file mode 100644
index 00000000000..41ee0873347
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket.h
@@ -0,0 +1,141 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_H_
+#define NET_SOCKET_SSL_CLIENT_SOCKET_H_
+
+#include <string>
+
+#include "net/base/completion_callback.h"
+#include "net/base/load_flags.h"
+#include "net/base/net_errors.h"
+#include "net/socket/ssl_socket.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class CertVerifier;
+class ServerBoundCertService;
+class SSLCertRequestInfo;
+class SSLInfo;
+class TransportSecurityState;
+
+// This struct groups together several fields which are used by various
+// classes related to SSLClientSocket.
+struct SSLClientSocketContext {
+ SSLClientSocketContext()
+ : cert_verifier(NULL),
+ server_bound_cert_service(NULL),
+ transport_security_state(NULL) {}
+
+ SSLClientSocketContext(CertVerifier* cert_verifier_arg,
+ ServerBoundCertService* server_bound_cert_service_arg,
+ TransportSecurityState* transport_security_state_arg,
+ const std::string& ssl_session_cache_shard_arg)
+ : cert_verifier(cert_verifier_arg),
+ server_bound_cert_service(server_bound_cert_service_arg),
+ transport_security_state(transport_security_state_arg),
+ ssl_session_cache_shard(ssl_session_cache_shard_arg) {}
+
+ CertVerifier* cert_verifier;
+ ServerBoundCertService* server_bound_cert_service;
+ TransportSecurityState* transport_security_state;
+ // ssl_session_cache_shard is an opaque string that identifies a shard of the
+ // SSL session cache. SSL sockets with the same ssl_session_cache_shard may
+ // resume each other's SSL sessions but we'll never sessions between shards.
+ const std::string ssl_session_cache_shard;
+};
+
+// A client socket that uses SSL as the transport layer.
+//
+// NOTE: The SSL handshake occurs within the Connect method after a TCP
+// connection is established. If a SSL error occurs during the handshake,
+// Connect will fail.
+//
+class NET_EXPORT SSLClientSocket : public SSLSocket {
+ public:
+ SSLClientSocket();
+
+ // Next Protocol Negotiation (NPN) allows a TLS client and server to come to
+ // an agreement about the application level protocol to speak over a
+ // connection.
+ enum NextProtoStatus {
+ // WARNING: These values are serialized to disk. Don't change them.
+
+ kNextProtoUnsupported = 0, // The server doesn't support NPN.
+ kNextProtoNegotiated = 1, // We agreed on a protocol.
+ kNextProtoNoOverlap = 2, // No protocols in common. We requested
+ // the first protocol in our list.
+ };
+
+ // StreamSocket:
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+
+ // Gets the SSL CertificateRequest info of the socket after Connect failed
+ // with ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
+ virtual void GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) = 0;
+
+ // Get the application level protocol that we negotiated with the server.
+ // *proto is set to the resulting protocol (n.b. that the string may have
+ // embedded NULs).
+ // kNextProtoUnsupported: *proto is cleared.
+ // kNextProtoNegotiated: *proto is set to the negotiated protocol.
+ // kNextProtoNoOverlap: *proto is set to the first protocol in the
+ // supported list.
+ // *server_protos is set to the server advertised protocols.
+ virtual NextProtoStatus GetNextProto(std::string* proto,
+ std::string* server_protos) = 0;
+
+ static NextProto NextProtoFromString(const std::string& proto_string);
+
+ static const char* NextProtoToString(NextProto next_proto);
+
+ static const char* NextProtoStatusToString(const NextProtoStatus status);
+
+ // Can be used with the second argument(|server_protos|) of |GetNextProto| to
+ // construct a comma separated string of server advertised protocols.
+ static std::string ServerProtosToString(const std::string& server_protos);
+
+ static bool IgnoreCertError(int error, int load_flags);
+
+ // ClearSessionCache clears the SSL session cache, used to resume SSL
+ // sessions.
+ static void ClearSessionCache();
+
+ virtual bool set_was_npn_negotiated(bool negotiated);
+
+ virtual bool was_spdy_negotiated() const;
+
+ virtual bool set_was_spdy_negotiated(bool negotiated);
+
+ virtual void set_protocol_negotiated(NextProto protocol_negotiated);
+
+ // Returns the ServerBoundCertService used by this socket, or NULL if
+ // server bound certificates are not supported.
+ virtual ServerBoundCertService* GetServerBoundCertService() const = 0;
+
+ // Returns true if a channel ID was sent on this connection.
+ // This may be useful for protocols, like SPDY, which allow the same
+ // connection to be shared between multiple domains, each of which need
+ // a channel ID.
+ virtual bool WasChannelIDSent() const;
+
+ virtual void set_channel_id_sent(bool channel_id_sent);
+
+ private:
+ // True if NPN was responded to, independent of selecting SPDY or HTTP.
+ bool was_npn_negotiated_;
+ // True if NPN successfully negotiated SPDY.
+ bool was_spdy_negotiated_;
+ // Protocol that we negotiated with the server.
+ NextProto protocol_negotiated_;
+ // True if a channel ID was sent.
+ bool channel_id_sent_;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_CLIENT_SOCKET_H_
diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc
new file mode 100644
index 00000000000..acc1b0dee2e
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_nss.cc
@@ -0,0 +1,3504 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file includes code SSLClientSocketNSS::DoVerifyCertComplete() derived
+// from AuthCertificateCallback() in
+// mozilla/security/manager/ssl/src/nsNSSCallbacks.cpp.
+
+/* ***** BEGIN LICENSE BLOCK *****
+ * Version: MPL 1.1/GPL 2.0/LGPL 2.1
+ *
+ * The contents of this file are subject to the Mozilla Public License Version
+ * 1.1 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.mozilla.org/MPL/
+ *
+ * Software distributed under the License is distributed on an "AS IS" basis,
+ * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
+ * for the specific language governing rights and limitations under the
+ * License.
+ *
+ * The Original Code is the Netscape security libraries.
+ *
+ * The Initial Developer of the Original Code is
+ * Netscape Communications Corporation.
+ * Portions created by the Initial Developer are Copyright (C) 2000
+ * the Initial Developer. All Rights Reserved.
+ *
+ * Contributor(s):
+ * Ian McGreer <mcgreer@netscape.com>
+ * Javier Delgadillo <javi@netscape.com>
+ * Kai Engert <kengert@redhat.com>
+ *
+ * Alternatively, the contents of this file may be used under the terms of
+ * either the GNU General Public License Version 2 or later (the "GPL"), or
+ * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
+ * in which case the provisions of the GPL or the LGPL are applicable instead
+ * of those above. If you wish to allow use of your version of this file only
+ * under the terms of either the GPL or the LGPL, and not to allow others to
+ * use your version of this file under the terms of the MPL, indicate your
+ * decision by deleting the provisions above and replace them with the notice
+ * and other provisions required by the GPL or the LGPL. If you do not delete
+ * the provisions above, a recipient may use your version of this file under
+ * the terms of any one of the MPL, the GPL or the LGPL.
+ *
+ * ***** END LICENSE BLOCK ***** */
+
+#include "net/socket/ssl_client_socket_nss.h"
+
+#include <certdb.h>
+#include <hasht.h>
+#include <keyhi.h>
+#include <nspr.h>
+#include <nss.h>
+#include <ocsp.h>
+#include <pk11pub.h>
+#include <secerr.h>
+#include <sechash.h>
+#include <ssl.h>
+#include <sslerr.h>
+#include <sslproto.h>
+
+#include <algorithm>
+#include <limits>
+#include <map>
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/callback_helpers.h"
+#include "base/compiler_specific.h"
+#include "base/logging.h"
+#include "base/memory/singleton.h"
+#include "base/metrics/histogram.h"
+#include "base/single_thread_task_runner.h"
+#include "base/stl_util.h"
+#include "base/strings/string_number_conversions.h"
+#include "base/strings/string_util.h"
+#include "base/strings/stringprintf.h"
+#include "base/thread_task_runner_handle.h"
+#include "base/threading/thread_restrictions.h"
+#include "base/values.h"
+#include "crypto/ec_private_key.h"
+#include "crypto/nss_util.h"
+#include "crypto/nss_util_internal.h"
+#include "crypto/rsa_private_key.h"
+#include "crypto/scoped_nss_types.h"
+#include "net/base/address_list.h"
+#include "net/base/connection_type_histograms.h"
+#include "net/base/dns_util.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/cert/asn1_util.h"
+#include "net/cert/cert_status_flags.h"
+#include "net/cert/cert_verifier.h"
+#include "net/cert/single_request_cert_verifier.h"
+#include "net/cert/x509_certificate_net_log_param.h"
+#include "net/cert/x509_util.h"
+#include "net/http/transport_security_state.h"
+#include "net/ocsp/nss_ocsp.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/nss_ssl_util.h"
+#include "net/socket/ssl_error_params.h"
+#include "net/ssl/ssl_cert_request_info.h"
+#include "net/ssl/ssl_connection_status_flags.h"
+#include "net/ssl/ssl_info.h"
+
+#if defined(OS_WIN)
+#include <windows.h>
+#include <wincrypt.h>
+
+#include "base/win/windows_version.h"
+#elif defined(OS_MACOSX)
+#include <Security/SecBase.h>
+#include <Security/SecCertificate.h>
+#include <Security/SecIdentity.h>
+
+#include "base/mac/mac_logging.h"
+#include "base/synchronization/lock.h"
+#include "crypto/mac_security_services_lock.h"
+#elif defined(USE_NSS)
+#include <dlfcn.h>
+#endif
+
+namespace net {
+
+// State machines are easier to debug if you log state transitions.
+// Enable these if you want to see what's going on.
+#if 1
+#define EnterFunction(x)
+#define LeaveFunction(x)
+#define GotoState(s) next_handshake_state_ = s
+#else
+#define EnterFunction(x)\
+ VLOG(1) << (void *)this << " " << __FUNCTION__ << " enter " << x\
+ << "; next_handshake_state " << next_handshake_state_
+#define LeaveFunction(x)\
+ VLOG(1) << (void *)this << " " << __FUNCTION__ << " leave " << x\
+ << "; next_handshake_state " << next_handshake_state_
+#define GotoState(s)\
+ do {\
+ VLOG(1) << (void *)this << " " << __FUNCTION__ << " jump to state " << s;\
+ next_handshake_state_ = s;\
+ } while (0)
+#endif
+
+namespace {
+
+// SSL plaintext fragments are shorter than 16KB. Although the record layer
+// overhead is allowed to be 2K + 5 bytes, in practice the overhead is much
+// smaller than 1KB. So a 17KB buffer should be large enough to hold an
+// entire SSL record.
+const int kRecvBufferSize = 17 * 1024;
+const int kSendBufferSize = 17 * 1024;
+
+// Used by SSLClientSocketNSS::Core to indicate there is no read result
+// obtained by a previous operation waiting to be returned to the caller.
+// This constant can be any non-negative/non-zero value (eg: it does not
+// overlap with any value of the net::Error range, including net::OK).
+const int kNoPendingReadResult = 1;
+
+#if defined(OS_WIN)
+// CERT_OCSP_RESPONSE_PROP_ID is only implemented on Vista+, but it can be
+// set on Windows XP without error. There is some overhead from the server
+// sending the OCSP response if it supports the extension, for the subset of
+// XP clients who will request it but be unable to use it, but this is an
+// acceptable trade-off for simplicity of implementation.
+bool IsOCSPStaplingSupported() {
+ return true;
+}
+#elif defined(USE_NSS)
+typedef SECStatus
+(*CacheOCSPResponseFromSideChannelFunction)(
+ CERTCertDBHandle *handle, CERTCertificate *cert, PRTime time,
+ SECItem *encodedResponse, void *pwArg);
+
+// On Linux, we dynamically link against the system version of libnss3.so. In
+// order to continue working on systems without up-to-date versions of NSS we
+// lookup CERT_CacheOCSPResponseFromSideChannel with dlsym.
+
+// RuntimeLibNSSFunctionPointers is a singleton which caches the results of any
+// runtime symbol resolution that we need.
+class RuntimeLibNSSFunctionPointers {
+ public:
+ CacheOCSPResponseFromSideChannelFunction
+ GetCacheOCSPResponseFromSideChannelFunction() {
+ return cache_ocsp_response_from_side_channel_;
+ }
+
+ static RuntimeLibNSSFunctionPointers* GetInstance() {
+ return Singleton<RuntimeLibNSSFunctionPointers>::get();
+ }
+
+ private:
+ friend struct DefaultSingletonTraits<RuntimeLibNSSFunctionPointers>;
+
+ RuntimeLibNSSFunctionPointers() {
+ cache_ocsp_response_from_side_channel_ =
+ (CacheOCSPResponseFromSideChannelFunction)
+ dlsym(RTLD_DEFAULT, "CERT_CacheOCSPResponseFromSideChannel");
+ }
+
+ CacheOCSPResponseFromSideChannelFunction
+ cache_ocsp_response_from_side_channel_;
+};
+
+CacheOCSPResponseFromSideChannelFunction
+GetCacheOCSPResponseFromSideChannelFunction() {
+ return RuntimeLibNSSFunctionPointers::GetInstance()
+ ->GetCacheOCSPResponseFromSideChannelFunction();
+}
+
+bool IsOCSPStaplingSupported() {
+ return GetCacheOCSPResponseFromSideChannelFunction() != NULL;
+}
+#else
+// TODO(agl): Figure out if we can plumb the OCSP response into Mac's system
+// certificate validation functions.
+bool IsOCSPStaplingSupported() {
+ return false;
+}
+#endif
+
+class FreeCERTCertificate {
+ public:
+ inline void operator()(CERTCertificate* x) const {
+ CERT_DestroyCertificate(x);
+ }
+};
+typedef scoped_ptr_malloc<CERTCertificate, FreeCERTCertificate>
+ ScopedCERTCertificate;
+
+#if defined(OS_WIN)
+
+// This callback is intended to be used with CertFindChainInStore. In addition
+// to filtering by extended/enhanced key usage, we do not show expired
+// certificates and require digital signature usage in the key usage
+// extension.
+//
+// This matches our behavior on Mac OS X and that of NSS. It also matches the
+// default behavior of IE8. See http://support.microsoft.com/kb/890326 and
+// http://blogs.msdn.com/b/askie/archive/2009/06/09/my-expired-client-certificates-no-longer-display-when-connecting-to-my-web-server-using-ie8.aspx
+BOOL WINAPI ClientCertFindCallback(PCCERT_CONTEXT cert_context,
+ void* find_arg) {
+ VLOG(1) << "Calling ClientCertFindCallback from _nss";
+ // Verify the certificate's KU is good.
+ BYTE key_usage;
+ if (CertGetIntendedKeyUsage(X509_ASN_ENCODING, cert_context->pCertInfo,
+ &key_usage, 1)) {
+ if (!(key_usage & CERT_DIGITAL_SIGNATURE_KEY_USAGE))
+ return FALSE;
+ } else {
+ DWORD err = GetLastError();
+ // If |err| is non-zero, it's an actual error. Otherwise the extension
+ // just isn't present, and we treat it as if everything was allowed.
+ if (err) {
+ DLOG(ERROR) << "CertGetIntendedKeyUsage failed: " << err;
+ return FALSE;
+ }
+ }
+
+ // Verify the current time is within the certificate's validity period.
+ if (CertVerifyTimeValidity(NULL, cert_context->pCertInfo) != 0)
+ return FALSE;
+
+ // Verify private key metadata is associated with this certificate.
+ DWORD size = 0;
+ if (!CertGetCertificateContextProperty(
+ cert_context, CERT_KEY_PROV_INFO_PROP_ID, NULL, &size)) {
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+#endif
+
+void DestroyCertificates(CERTCertificate** certs, size_t len) {
+ for (size_t i = 0; i < len; i++)
+ CERT_DestroyCertificate(certs[i]);
+}
+
+// Helper functions to make it possible to log events from within the
+// SSLClientSocketNSS::Core.
+void AddLogEvent(const base::WeakPtr<BoundNetLog>& net_log,
+ NetLog::EventType event_type) {
+ if (!net_log)
+ return;
+ net_log->AddEvent(event_type);
+}
+
+// Helper function to make it possible to log events from within the
+// SSLClientSocketNSS::Core.
+void AddLogEventWithCallback(const base::WeakPtr<BoundNetLog>& net_log,
+ NetLog::EventType event_type,
+ const NetLog::ParametersCallback& callback) {
+ if (!net_log)
+ return;
+ net_log->AddEvent(event_type, callback);
+}
+
+// Helper function to make it easier to call BoundNetLog::AddByteTransferEvent
+// from within the SSLClientSocketNSS::Core.
+// AddByteTransferEvent expects to receive a const char*, which within the
+// Core is backed by an IOBuffer. If the "const char*" is bound via
+// base::Bind and posted to another thread, and the IOBuffer that backs that
+// pointer then goes out of scope on the origin thread, this would result in
+// an invalid read of a stale pointer.
+// Instead, provide a signature that accepts an IOBuffer*, so that a reference
+// to the owning IOBuffer can be bound to the Callback. This ensures that the
+// IOBuffer will stay alive long enough to cross threads if needed.
+void LogByteTransferEvent(
+ const base::WeakPtr<BoundNetLog>& net_log, NetLog::EventType event_type,
+ int len, IOBuffer* buffer) {
+ if (!net_log)
+ return;
+ net_log->AddByteTransferEvent(event_type, len, buffer->data());
+}
+
+// PeerCertificateChain is a helper object which extracts the certificate
+// chain, as given by the server, from an NSS socket and performs the needed
+// resource management. The first element of the chain is the leaf certificate
+// and the other elements are in the order given by the server.
+class PeerCertificateChain {
+ public:
+ PeerCertificateChain() {}
+ PeerCertificateChain(const PeerCertificateChain& other);
+ ~PeerCertificateChain();
+ PeerCertificateChain& operator=(const PeerCertificateChain& other);
+
+ // Resets the current chain, freeing any resources, and updates the current
+ // chain to be a copy of the chain stored in |nss_fd|.
+ // If |nss_fd| is NULL, then the current certificate chain will be freed.
+ void Reset(PRFileDesc* nss_fd);
+
+ // Returns the current certificate chain as a vector of DER-encoded
+ // base::StringPieces. The returned vector remains valid until Reset is
+ // called.
+ std::vector<base::StringPiece> AsStringPieceVector() const;
+
+ bool empty() const { return certs_.empty(); }
+ size_t size() const { return certs_.size(); }
+
+ CERTCertificate* operator[](size_t index) const {
+ DCHECK_LT(index, certs_.size());
+ return certs_[index];
+ }
+
+ private:
+ std::vector<CERTCertificate*> certs_;
+};
+
+PeerCertificateChain::PeerCertificateChain(
+ const PeerCertificateChain& other) {
+ *this = other;
+}
+
+PeerCertificateChain::~PeerCertificateChain() {
+ Reset(NULL);
+}
+
+PeerCertificateChain& PeerCertificateChain::operator=(
+ const PeerCertificateChain& other) {
+ if (this == &other)
+ return *this;
+
+ Reset(NULL);
+ certs_.reserve(other.certs_.size());
+ for (size_t i = 0; i < other.certs_.size(); ++i)
+ certs_.push_back(CERT_DupCertificate(other.certs_[i]));
+
+ return *this;
+}
+
+void PeerCertificateChain::Reset(PRFileDesc* nss_fd) {
+ for (size_t i = 0; i < certs_.size(); ++i)
+ CERT_DestroyCertificate(certs_[i]);
+ certs_.clear();
+
+ if (nss_fd == NULL)
+ return;
+
+ unsigned int num_certs = 0;
+ SECStatus rv = SSL_PeerCertificateChain(nss_fd, NULL, &num_certs, 0);
+ DCHECK_EQ(SECSuccess, rv);
+
+ // The handshake on |nss_fd| may not have completed.
+ if (num_certs == 0)
+ return;
+
+ certs_.resize(num_certs);
+ const unsigned int expected_num_certs = num_certs;
+ rv = SSL_PeerCertificateChain(nss_fd, vector_as_array(&certs_),
+ &num_certs, expected_num_certs);
+ DCHECK_EQ(SECSuccess, rv);
+ DCHECK_EQ(expected_num_certs, num_certs);
+}
+
+std::vector<base::StringPiece>
+PeerCertificateChain::AsStringPieceVector() const {
+ std::vector<base::StringPiece> v(certs_.size());
+ for (unsigned i = 0; i < certs_.size(); i++) {
+ v[i] = base::StringPiece(
+ reinterpret_cast<const char*>(certs_[i]->derCert.data),
+ certs_[i]->derCert.len);
+ }
+
+ return v;
+}
+
+// HandshakeState is a helper struct used to pass handshake state between
+// the NSS task runner and the network task runner.
+//
+// It contains members that may be read or written on the NSS task runner,
+// but which also need to be read from the network task runner. The NSS task
+// runner will notify the network task runner whenever this state changes, so
+// that the network task runner can safely make a copy, which avoids the need
+// for locking.
+struct HandshakeState {
+ HandshakeState() { Reset(); }
+
+ void Reset() {
+ next_proto_status = SSLClientSocket::kNextProtoUnsupported;
+ next_proto.clear();
+ server_protos.clear();
+ channel_id_sent = false;
+ server_cert_chain.Reset(NULL);
+ server_cert = NULL;
+ resumed_handshake = false;
+ ssl_connection_status = 0;
+ }
+
+ // Set to kNextProtoNegotiated if NPN was successfully negotiated, with the
+ // negotiated protocol stored in |next_proto|.
+ SSLClientSocket::NextProtoStatus next_proto_status;
+ std::string next_proto;
+ // If the server supports NPN, the protocols supported by the server.
+ std::string server_protos;
+
+ // True if a channel ID was sent.
+ bool channel_id_sent;
+
+ // List of DER-encoded X.509 DistinguishedName of certificate authorities
+ // allowed by the server.
+ std::vector<std::string> cert_authorities;
+
+ // Set when the handshake fully completes.
+ //
+ // The server certificate is first received from NSS as an NSS certificate
+ // chain (|server_cert_chain|) and then converted into a platform-specific
+ // X509Certificate object (|server_cert|). It's possible for some
+ // certificates to be successfully parsed by NSS, and not by the platform
+ // libraries (i.e.: when running within a sandbox, different parsing
+ // algorithms, etc), so it's not safe to assume that |server_cert| will
+ // always be non-NULL.
+ PeerCertificateChain server_cert_chain;
+ scoped_refptr<X509Certificate> server_cert;
+
+ // True if the current handshake was the result of TLS session resumption.
+ bool resumed_handshake;
+
+ // The negotiated security parameters (TLS version, cipher, extensions) of
+ // the SSL connection.
+ int ssl_connection_status;
+};
+
+// Client-side error mapping functions.
+
+// Map NSS error code to network error code.
+int MapNSSClientError(PRErrorCode err) {
+ switch (err) {
+ case SSL_ERROR_BAD_CERT_ALERT:
+ case SSL_ERROR_UNSUPPORTED_CERT_ALERT:
+ case SSL_ERROR_REVOKED_CERT_ALERT:
+ case SSL_ERROR_EXPIRED_CERT_ALERT:
+ case SSL_ERROR_CERTIFICATE_UNKNOWN_ALERT:
+ case SSL_ERROR_UNKNOWN_CA_ALERT:
+ case SSL_ERROR_ACCESS_DENIED_ALERT:
+ return ERR_BAD_SSL_CLIENT_AUTH_CERT;
+ default:
+ return MapNSSError(err);
+ }
+}
+
+// Map NSS error code from the first SSL handshake to network error code.
+int MapNSSClientHandshakeError(PRErrorCode err) {
+ switch (err) {
+ // If the server closed on us, it is a protocol error.
+ // Some TLS-intolerant servers do this when we request TLS.
+ case PR_END_OF_FILE_ERROR:
+ return ERR_SSL_PROTOCOL_ERROR;
+ default:
+ return MapNSSClientError(err);
+ }
+}
+
+} // namespace
+
+// SSLClientSocketNSS::Core provides a thread-safe, ref-counted core that is
+// able to marshal data between NSS functions and an underlying transport
+// socket.
+//
+// All public functions are meant to be called from the network task runner,
+// and any callbacks supplied will be invoked there as well, provided that
+// Detach() has not been called yet.
+//
+/////////////////////////////////////////////////////////////////////////////
+//
+// Threading within SSLClientSocketNSS and SSLClientSocketNSS::Core:
+//
+// Because NSS may block on either hardware or user input during operations
+// such as signing, creating certificates, or locating private keys, the Core
+// handles all of the interactions with the underlying NSS SSL socket, so
+// that these blocking calls can be executed on a dedicated task runner.
+//
+// Note that the network task runner and the NSS task runner may be executing
+// on the same thread. If that happens, then it's more performant to try to
+// complete as much work as possible synchronously, even if it might block,
+// rather than continually PostTask-ing to the same thread.
+//
+// Because NSS functions should only be called on the NSS task runner, while
+// I/O resources should only be accessed on the network task runner, most
+// public functions are implemented via three methods, each with different
+// task runner affinities.
+//
+// In the single-threaded mode (where the network and NSS task runners run on
+// the same thread), these are all attempted synchronously, while in the
+// multi-threaded mode, message passing is used.
+//
+// 1) NSS Task Runner: Execute NSS function (DoPayloadRead, DoPayloadWrite,
+// DoHandshake)
+// 2) NSS Task Runner: Prepare data to go from NSS to an IO function:
+// (BufferRecv, BufferSend)
+// 3) Network Task Runner: Perform IO on that data (DoBufferRecv,
+// DoBufferSend, DoGetDomainBoundCert, OnGetDomainBoundCertComplete)
+// 4) Both Task Runners: Callback for asynchronous completion or to marshal
+// data from the network task runner back to NSS (BufferRecvComplete,
+// BufferSendComplete, OnHandshakeIOComplete)
+//
+/////////////////////////////////////////////////////////////////////////////
+// Single-threaded example
+//
+// |--------------------------Network Task Runner--------------------------|
+// SSLClientSocketNSS Core (Transport Socket)
+// Read()
+// |-------------------------V
+// Read()
+// |
+// DoPayloadRead()
+// |
+// BufferRecv()
+// |
+// DoBufferRecv()
+// |-------------------------V
+// Read()
+// V-------------------------|
+// BufferRecvComplete()
+// |
+// PostOrRunCallback()
+// V-------------------------|
+// (Read Callback)
+//
+/////////////////////////////////////////////////////////////////////////////
+// Multi-threaded example:
+//
+// |--------------------Network Task Runner-------------|--NSS Task Runner--|
+// SSLClientSocketNSS Core Socket Core
+// Read()
+// |---------------------V
+// Read()
+// |-------------------------------V
+// Read()
+// |
+// DoPayloadRead()
+// |
+// BufferRecv
+// V-------------------------------|
+// DoBufferRecv
+// |----------------V
+// Read()
+// V----------------|
+// BufferRecvComplete()
+// |-------------------------------V
+// BufferRecvComplete()
+// |
+// PostOrRunCallback()
+// V-------------------------------|
+// PostOrRunCallback()
+// V---------------------|
+// (Read Callback)
+//
+/////////////////////////////////////////////////////////////////////////////
+class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
+ public:
+ // Creates a new Core.
+ //
+ // Any calls to NSS are executed on the |nss_task_runner|, while any calls
+ // that need to operate on the underlying transport, net log, or server
+ // bound certificate fetching will happen on the |network_task_runner|, so
+ // that their lifetimes match that of the owning SSLClientSocketNSS.
+ //
+ // The caller retains ownership of |transport|, |net_log|, and
+ // |server_bound_cert_service|, and they will not be accessed once Detach()
+ // has been called.
+ Core(base::SequencedTaskRunner* network_task_runner,
+ base::SequencedTaskRunner* nss_task_runner,
+ ClientSocketHandle* transport,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ BoundNetLog* net_log,
+ ServerBoundCertService* server_bound_cert_service);
+
+ // Called on the network task runner.
+ // Transfers ownership of |socket|, an NSS SSL socket, and |buffers|, the
+ // underlying memio implementation, to the Core. Returns true if the Core
+ // was successfully registered with the socket.
+ bool Init(PRFileDesc* socket, memio_Private* buffers);
+
+ // Called on the network task runner.
+ // Sets the predicted certificate chain that the peer will send, for use
+ // with the TLS CachedInfo extension. If called, it must not be called
+ // before Init() or after Connect().
+ void SetPredictedCertificates(
+ const std::vector<std::string>& predicted_certificates);
+
+ // Called on the network task runner.
+ //
+ // Attempts to perform an SSL handshake. If the handshake cannot be
+ // completed synchronously, returns ERR_IO_PENDING, invoking |callback| on
+ // the network task runner once the handshake has completed. Otherwise,
+ // returns OK on success or a network error code on failure.
+ int Connect(const CompletionCallback& callback);
+
+ // Called on the network task runner.
+ // Signals that the resources owned by the network task runner are going
+ // away. No further callbacks will be invoked on the network task runner.
+ // May be called at any time.
+ void Detach();
+
+ // Called on the network task runner.
+ // Returns the current state of the underlying SSL socket. May be called at
+ // any time.
+ const HandshakeState& state() const { return network_handshake_state_; }
+
+ // Called on the network task runner.
+ // Read() and Write() mirror the net::Socket functions of the same name.
+ // If ERR_IO_PENDING is returned, |callback| will be invoked on the network
+ // task runner at a later point, unless the caller calls Detach().
+ int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+ int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+
+ // Called on the network task runner.
+ bool IsConnected();
+ bool HasPendingAsyncOperation();
+ bool HasUnhandledReceivedData();
+
+ private:
+ friend class base::RefCountedThreadSafe<Core>;
+ ~Core();
+
+ enum State {
+ STATE_NONE,
+ STATE_HANDSHAKE,
+ STATE_GET_DOMAIN_BOUND_CERT_COMPLETE,
+ };
+
+ bool OnNSSTaskRunner() const;
+ bool OnNetworkTaskRunner() const;
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Methods that are ONLY called on the NSS task runner:
+ ////////////////////////////////////////////////////////////////////////////
+
+ // Called by NSS during full handshakes to allow the application to
+ // verify the certificate. Instead of verifying the certificate in the midst
+ // of the handshake, SECSuccess is always returned and the peer's certificate
+ // is verified afterwards.
+ // This behaviour is an artifact of the original SSLClientSocketWin
+ // implementation, which could not verify the peer's certificate until after
+ // the handshake had completed, as well as bugs in NSS that prevent
+ // SSL_RestartHandshakeAfterCertReq from working.
+ static SECStatus OwnAuthCertHandler(void* arg,
+ PRFileDesc* socket,
+ PRBool checksig,
+ PRBool is_server);
+
+ // Callbacks called by NSS when the peer requests client certificate
+ // authentication.
+ // See the documentation in third_party/nss/ssl/ssl.h for the meanings of
+ // the arguments.
+#if defined(NSS_PLATFORM_CLIENT_AUTH)
+ // When NSS has been integrated with awareness of the underlying system
+ // cryptographic libraries, this callback allows the caller to supply a
+ // native platform certificate and key for use by NSS. At most, one of
+ // either (result_certs, result_private_key) or (result_nss_certificate,
+ // result_nss_private_key) should be set.
+ // |arg| contains a pointer to the current SSLClientSocketNSS::Core.
+ static SECStatus PlatformClientAuthHandler(
+ void* arg,
+ PRFileDesc* socket,
+ CERTDistNames* ca_names,
+ CERTCertList** result_certs,
+ void** result_private_key,
+ CERTCertificate** result_nss_certificate,
+ SECKEYPrivateKey** result_nss_private_key);
+#else
+ static SECStatus ClientAuthHandler(void* arg,
+ PRFileDesc* socket,
+ CERTDistNames* ca_names,
+ CERTCertificate** result_certificate,
+ SECKEYPrivateKey** result_private_key);
+#endif
+
+ // Called by NSS once the handshake has completed.
+ // |arg| contains a pointer to the current SSLClientSocketNSS::Core.
+ static void HandshakeCallback(PRFileDesc* socket, void* arg);
+
+ // Handles an NSS error generated while handshaking or performing IO.
+ // Returns a network error code mapped from the original NSS error.
+ int HandleNSSError(PRErrorCode error, bool handshake_error);
+
+ int DoHandshakeLoop(int last_io_result);
+ int DoReadLoop(int result);
+ int DoWriteLoop(int result);
+
+ int DoHandshake();
+ int DoGetDBCertComplete(int result);
+
+ int DoPayloadRead();
+ int DoPayloadWrite();
+
+ bool DoTransportIO();
+ int BufferRecv();
+ int BufferSend();
+
+ void OnRecvComplete(int result);
+ void OnSendComplete(int result);
+
+ void DoConnectCallback(int result);
+ void DoReadCallback(int result);
+ void DoWriteCallback(int result);
+
+ // Client channel ID handler.
+ static SECStatus ClientChannelIDHandler(
+ void* arg,
+ PRFileDesc* socket,
+ SECKEYPublicKey **out_public_key,
+ SECKEYPrivateKey **out_private_key);
+
+ // ImportChannelIDKeys is a helper function for turning a DER-encoded cert and
+ // key into a SECKEYPublicKey and SECKEYPrivateKey. Returns OK upon success
+ // and an error code otherwise.
+ // Requires |domain_bound_private_key_| and |domain_bound_cert_| to have been
+ // set by a call to ServerBoundCertService->GetDomainBoundCert. The caller
+ // takes ownership of the |*cert| and |*key|.
+ int ImportChannelIDKeys(SECKEYPublicKey** public_key, SECKEYPrivateKey** key);
+
+ // Updates the NSS and platform specific certificates.
+ void UpdateServerCert();
+ // Updates the nss_handshake_state_ with the negotiated security parameters.
+ void UpdateConnectionStatus();
+ // Record histograms for channel id support during full handshakes - resumed
+ // handshakes are ignored.
+ void RecordChannelIDSupport();
+ // UpdateNextProto gets any application-layer protocol that may have been
+ // negotiated by the TLS connection.
+ void UpdateNextProto();
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Methods that are ONLY called on the network task runner:
+ ////////////////////////////////////////////////////////////////////////////
+ int DoBufferRecv(IOBuffer* buffer, int len);
+ int DoBufferSend(IOBuffer* buffer, int len);
+ int DoGetDomainBoundCert(const std::string& host);
+
+ void OnGetDomainBoundCertComplete(int result);
+ void OnHandshakeStateUpdated(const HandshakeState& state);
+ void OnNSSBufferUpdated(int amount_in_read_buffer);
+ void DidNSSRead(int result);
+ void DidNSSWrite(int result);
+ void RecordChannelIDSupportOnNetworkTaskRunner(
+ bool negotiated_channel_id,
+ bool channel_id_enabled,
+ bool supports_ecc) const;
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Methods that are called on both the network task runner and the NSS
+ // task runner.
+ ////////////////////////////////////////////////////////////////////////////
+ void OnHandshakeIOComplete(int result);
+ void BufferRecvComplete(IOBuffer* buffer, int result);
+ void BufferSendComplete(int result);
+
+ // PostOrRunCallback is a helper function to ensure that |callback| is
+ // invoked on the network task runner, but only if Detach() has not yet
+ // been called.
+ void PostOrRunCallback(const tracked_objects::Location& location,
+ const base::Closure& callback);
+
+ // Uses PostOrRunCallback and |weak_net_log_| to try and log a
+ // SSL_CLIENT_CERT_PROVIDED event, with the indicated count.
+ void AddCertProvidedEvent(int cert_count);
+
+ // Sets the handshake state |channel_id_sent| flag and logs the
+ // SSL_CHANNEL_ID_PROVIDED event.
+ void SetChannelIDProvided();
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Members that are ONLY accessed on the network task runner:
+ ////////////////////////////////////////////////////////////////////////////
+
+ // True if the owning SSLClientSocketNSS has called Detach(). No further
+ // callbacks will be invoked nor access to members owned by the network
+ // task runner.
+ bool detached_;
+
+ // The underlying transport to use for network IO.
+ ClientSocketHandle* transport_;
+ base::WeakPtrFactory<BoundNetLog> weak_net_log_factory_;
+
+ // The current handshake state. Mirrors |nss_handshake_state_|.
+ HandshakeState network_handshake_state_;
+
+ // The service for retrieving Channel ID keys. May be NULL.
+ ServerBoundCertService* server_bound_cert_service_;
+ ServerBoundCertService::RequestHandle domain_bound_cert_request_handle_;
+
+ // The information about NSS task runner.
+ int unhandled_buffer_size_;
+ bool nss_waiting_read_;
+ bool nss_waiting_write_;
+ bool nss_is_closed_;
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Members that are ONLY accessed on the NSS task runner:
+ ////////////////////////////////////////////////////////////////////////////
+ HostPortPair host_and_port_;
+ SSLConfig ssl_config_;
+
+ // NSS SSL socket.
+ PRFileDesc* nss_fd_;
+
+ // Buffers for the network end of the SSL state machine
+ memio_Private* nss_bufs_;
+
+ // Used by DoPayloadRead() when attempting to fill the caller's buffer with
+ // as much data as possible, without blocking.
+ // If DoPayloadRead() encounters an error after having read some data, stores
+ // the results to return on the *next* call to DoPayloadRead(). A value of
+ // kNoPendingReadResult indicates there is no pending result, otherwise 0
+ // indicates EOF and < 0 indicates an error.
+ int pending_read_result_;
+ // Contains the previously observed NSS error. Only valid when
+ // pending_read_result_ != kNoPendingReadResult.
+ PRErrorCode pending_read_nss_error_;
+
+ // The certificate chain, in DER form, that is expected to be received from
+ // the server.
+ std::vector<std::string> predicted_certs_;
+
+ State next_handshake_state_;
+
+ // True if channel ID extension was negotiated.
+ bool channel_id_xtn_negotiated_;
+ // True if the handshake state machine was interrupted for channel ID.
+ bool channel_id_needed_;
+ // True if the handshake state machine was interrupted for client auth.
+ bool client_auth_cert_needed_;
+ // True if NSS has called HandshakeCallback.
+ bool handshake_callback_called_;
+
+ HandshakeState nss_handshake_state_;
+
+ bool transport_recv_busy_;
+ bool transport_recv_eof_;
+ bool transport_send_busy_;
+
+ // Used by Read function.
+ scoped_refptr<IOBuffer> user_read_buf_;
+ int user_read_buf_len_;
+
+ // Used by Write function.
+ scoped_refptr<IOBuffer> user_write_buf_;
+ int user_write_buf_len_;
+
+ CompletionCallback user_connect_callback_;
+ CompletionCallback user_read_callback_;
+ CompletionCallback user_write_callback_;
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Members that are accessed on both the network task runner and the NSS
+ // task runner.
+ ////////////////////////////////////////////////////////////////////////////
+ scoped_refptr<base::SequencedTaskRunner> network_task_runner_;
+ scoped_refptr<base::SequencedTaskRunner> nss_task_runner_;
+
+ // Dereferenced only on the network task runner, but bound to tasks destined
+ // for the network task runner from the NSS task runner.
+ base::WeakPtr<BoundNetLog> weak_net_log_;
+
+ // Written on the network task runner by the |server_bound_cert_service_|,
+ // prior to invoking OnHandshakeIOComplete.
+ // Read on the NSS task runner when once OnHandshakeIOComplete is invoked
+ // on the NSS task runner.
+ std::string domain_bound_private_key_;
+ std::string domain_bound_cert_;
+
+ DISALLOW_COPY_AND_ASSIGN(Core);
+};
+
+SSLClientSocketNSS::Core::Core(
+ base::SequencedTaskRunner* network_task_runner,
+ base::SequencedTaskRunner* nss_task_runner,
+ ClientSocketHandle* transport,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ BoundNetLog* net_log,
+ ServerBoundCertService* server_bound_cert_service)
+ : detached_(false),
+ transport_(transport),
+ weak_net_log_factory_(net_log),
+ server_bound_cert_service_(server_bound_cert_service),
+ unhandled_buffer_size_(0),
+ nss_waiting_read_(false),
+ nss_waiting_write_(false),
+ nss_is_closed_(false),
+ host_and_port_(host_and_port),
+ ssl_config_(ssl_config),
+ nss_fd_(NULL),
+ nss_bufs_(NULL),
+ pending_read_result_(kNoPendingReadResult),
+ pending_read_nss_error_(0),
+ next_handshake_state_(STATE_NONE),
+ channel_id_xtn_negotiated_(false),
+ channel_id_needed_(false),
+ client_auth_cert_needed_(false),
+ handshake_callback_called_(false),
+ transport_recv_busy_(false),
+ transport_recv_eof_(false),
+ transport_send_busy_(false),
+ user_read_buf_len_(0),
+ user_write_buf_len_(0),
+ network_task_runner_(network_task_runner),
+ nss_task_runner_(nss_task_runner),
+ weak_net_log_(weak_net_log_factory_.GetWeakPtr()) {
+}
+
+SSLClientSocketNSS::Core::~Core() {
+ // TODO(wtc): Send SSL close_notify alert.
+ if (nss_fd_ != NULL) {
+ PR_Close(nss_fd_);
+ nss_fd_ = NULL;
+ }
+}
+
+bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket,
+ memio_Private* buffers) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK(!nss_fd_);
+ DCHECK(!nss_bufs_);
+
+ nss_fd_ = socket;
+ nss_bufs_ = buffers;
+
+ SECStatus rv = SECSuccess;
+
+ if (!ssl_config_.next_protos.empty()) {
+ size_t wire_length = 0;
+ for (std::vector<std::string>::const_iterator
+ i = ssl_config_.next_protos.begin();
+ i != ssl_config_.next_protos.end(); ++i) {
+ if (i->size() > 255) {
+ LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i;
+ continue;
+ }
+ wire_length += i->size();
+ wire_length++;
+ }
+ scoped_ptr<uint8[]> wire_protos(new uint8[wire_length]);
+ uint8* dst = wire_protos.get();
+ for (std::vector<std::string>::const_iterator
+ i = ssl_config_.next_protos.begin();
+ i != ssl_config_.next_protos.end(); i++) {
+ if (i->size() > 255)
+ continue;
+ *dst++ = i->size();
+ memcpy(dst, i->data(), i->size());
+ dst += i->size();
+ }
+ DCHECK_EQ(dst, wire_protos.get() + wire_length);
+ rv = SSL_SetNextProtoNego(nss_fd_, wire_protos.get(), wire_length);
+ if (rv != SECSuccess)
+ LogFailedNSSFunction(*weak_net_log_, "SSL_SetNextProtoCallback", "");
+ }
+
+ rv = SSL_AuthCertificateHook(
+ nss_fd_, SSLClientSocketNSS::Core::OwnAuthCertHandler, this);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(*weak_net_log_, "SSL_AuthCertificateHook", "");
+ return false;
+ }
+
+#if defined(NSS_PLATFORM_CLIENT_AUTH)
+ rv = SSL_GetPlatformClientAuthDataHook(
+ nss_fd_, SSLClientSocketNSS::Core::PlatformClientAuthHandler,
+ this);
+#else
+ rv = SSL_GetClientAuthDataHook(
+ nss_fd_, SSLClientSocketNSS::Core::ClientAuthHandler, this);
+#endif
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(*weak_net_log_, "SSL_GetClientAuthDataHook", "");
+ return false;
+ }
+
+ if (ssl_config_.channel_id_enabled) {
+ if (!server_bound_cert_service_) {
+ DVLOG(1) << "NULL server_bound_cert_service_, not enabling channel ID.";
+ } else if (!crypto::ECPrivateKey::IsSupported()) {
+ DVLOG(1) << "Elliptic Curve not supported, not enabling channel ID.";
+ } else if (!server_bound_cert_service_->IsSystemTimeValid()) {
+ DVLOG(1) << "System time is weird, not enabling channel ID.";
+ } else {
+ rv = SSL_SetClientChannelIDCallback(
+ nss_fd_, SSLClientSocketNSS::Core::ClientChannelIDHandler, this);
+ if (rv != SECSuccess)
+ LogFailedNSSFunction(*weak_net_log_, "SSL_SetClientChannelIDCallback",
+ "");
+ }
+ }
+
+ rv = SSL_HandshakeCallback(
+ nss_fd_, SSLClientSocketNSS::Core::HandshakeCallback, this);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(*weak_net_log_, "SSL_HandshakeCallback", "");
+ return false;
+ }
+
+ return true;
+}
+
+void SSLClientSocketNSS::Core::SetPredictedCertificates(
+ const std::vector<std::string>& predicted_certs) {
+ if (predicted_certs.empty())
+ return;
+
+ if (!OnNSSTaskRunner()) {
+ DCHECK(!detached_);
+ nss_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(&Core::SetPredictedCertificates, this, predicted_certs));
+ return;
+ }
+
+ DCHECK(nss_fd_);
+
+ predicted_certs_ = predicted_certs;
+
+ scoped_ptr<CERTCertificate*[]> certs(
+ new CERTCertificate*[predicted_certs.size()]);
+
+ for (size_t i = 0; i < predicted_certs.size(); i++) {
+ SECItem derCert;
+ derCert.data = const_cast<uint8*>(reinterpret_cast<const uint8*>(
+ predicted_certs[i].data()));
+ derCert.len = predicted_certs[i].size();
+ certs[i] = CERT_NewTempCertificate(
+ CERT_GetDefaultCertDB(), &derCert, NULL /* no nickname given */,
+ PR_FALSE /* not permanent */, PR_TRUE /* copy DER data */);
+ if (!certs[i]) {
+ DestroyCertificates(&certs[0], i);
+ NOTREACHED();
+ return;
+ }
+ }
+
+ SECStatus rv;
+#ifdef SSL_ENABLE_CACHED_INFO
+ rv = SSL_SetPredictedPeerCertificates(nss_fd_, certs.get(),
+ predicted_certs.size());
+ DCHECK_EQ(SECSuccess, rv);
+#else
+ rv = SECFailure; // Not implemented.
+#endif
+ DestroyCertificates(&certs[0], predicted_certs.size());
+
+ if (rv != SECSuccess) {
+ LOG(WARNING) << "SetPredictedCertificates failed: "
+ << host_and_port_.ToString();
+ }
+}
+
+int SSLClientSocketNSS::Core::Connect(const CompletionCallback& callback) {
+ if (!OnNSSTaskRunner()) {
+ DCHECK(!detached_);
+ bool posted = nss_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(IgnoreResult(&Core::Connect), this, callback));
+ return posted ? ERR_IO_PENDING : ERR_ABORTED;
+ }
+
+ DCHECK(OnNSSTaskRunner());
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+ DCHECK(user_read_callback_.is_null());
+ DCHECK(user_write_callback_.is_null());
+ DCHECK(user_connect_callback_.is_null());
+ DCHECK(!user_read_buf_.get());
+ DCHECK(!user_write_buf_.get());
+
+ next_handshake_state_ = STATE_HANDSHAKE;
+ int rv = DoHandshakeLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ user_connect_callback_ = callback;
+ } else if (rv > OK) {
+ rv = OK;
+ }
+ if (rv != ERR_IO_PENDING && !OnNetworkTaskRunner()) {
+ PostOrRunCallback(FROM_HERE, base::Bind(callback, rv));
+ return ERR_IO_PENDING;
+ }
+
+ return rv;
+}
+
+void SSLClientSocketNSS::Core::Detach() {
+ DCHECK(OnNetworkTaskRunner());
+
+ detached_ = true;
+ transport_ = NULL;
+ weak_net_log_factory_.InvalidateWeakPtrs();
+
+ network_handshake_state_.Reset();
+
+ domain_bound_cert_request_handle_.Cancel();
+}
+
+int SSLClientSocketNSS::Core::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (!OnNSSTaskRunner()) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK(!detached_);
+ DCHECK(transport_);
+ DCHECK(!nss_waiting_read_);
+
+ nss_waiting_read_ = true;
+ bool posted = nss_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(IgnoreResult(&Core::Read), this, make_scoped_refptr(buf),
+ buf_len, callback));
+ if (!posted) {
+ nss_is_closed_ = true;
+ nss_waiting_read_ = false;
+ }
+ return posted ? ERR_IO_PENDING : ERR_ABORTED;
+ }
+
+ DCHECK(OnNSSTaskRunner());
+ DCHECK(handshake_callback_called_);
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+ DCHECK(user_read_callback_.is_null());
+ DCHECK(user_connect_callback_.is_null());
+ DCHECK(!user_read_buf_.get());
+ DCHECK(nss_bufs_);
+
+ user_read_buf_ = buf;
+ user_read_buf_len_ = buf_len;
+
+ int rv = DoReadLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ if (OnNetworkTaskRunner())
+ nss_waiting_read_ = true;
+ user_read_callback_ = callback;
+ } else {
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+
+ if (!OnNetworkTaskRunner()) {
+ PostOrRunCallback(FROM_HERE, base::Bind(&Core::DidNSSRead, this, rv));
+ PostOrRunCallback(FROM_HERE, base::Bind(callback, rv));
+ return ERR_IO_PENDING;
+ } else {
+ DCHECK(!nss_waiting_read_);
+ if (rv <= 0)
+ nss_is_closed_ = true;
+ }
+ }
+
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (!OnNSSTaskRunner()) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK(!detached_);
+ DCHECK(transport_);
+ DCHECK(!nss_waiting_write_);
+
+ nss_waiting_write_ = true;
+ bool posted = nss_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(IgnoreResult(&Core::Write), this, make_scoped_refptr(buf),
+ buf_len, callback));
+ if (!posted) {
+ nss_is_closed_ = true;
+ nss_waiting_write_ = false;
+ }
+ return posted ? ERR_IO_PENDING : ERR_ABORTED;
+ }
+
+ DCHECK(OnNSSTaskRunner());
+ DCHECK(handshake_callback_called_);
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+ DCHECK(user_write_callback_.is_null());
+ DCHECK(user_connect_callback_.is_null());
+ DCHECK(!user_write_buf_.get());
+ DCHECK(nss_bufs_);
+
+ user_write_buf_ = buf;
+ user_write_buf_len_ = buf_len;
+
+ int rv = DoWriteLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ if (OnNetworkTaskRunner())
+ nss_waiting_write_ = true;
+ user_write_callback_ = callback;
+ } else {
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+
+ if (!OnNetworkTaskRunner()) {
+ PostOrRunCallback(FROM_HERE, base::Bind(&Core::DidNSSWrite, this, rv));
+ PostOrRunCallback(FROM_HERE, base::Bind(callback, rv));
+ return ERR_IO_PENDING;
+ } else {
+ DCHECK(!nss_waiting_write_);
+ if (rv < 0)
+ nss_is_closed_ = true;
+ }
+ }
+
+ return rv;
+}
+
+bool SSLClientSocketNSS::Core::IsConnected() {
+ DCHECK(OnNetworkTaskRunner());
+ return !nss_is_closed_;
+}
+
+bool SSLClientSocketNSS::Core::HasPendingAsyncOperation() {
+ DCHECK(OnNetworkTaskRunner());
+ return nss_waiting_read_ || nss_waiting_write_;
+}
+
+bool SSLClientSocketNSS::Core::HasUnhandledReceivedData() {
+ DCHECK(OnNetworkTaskRunner());
+ return unhandled_buffer_size_ != 0;
+}
+
+bool SSLClientSocketNSS::Core::OnNSSTaskRunner() const {
+ return nss_task_runner_->RunsTasksOnCurrentThread();
+}
+
+bool SSLClientSocketNSS::Core::OnNetworkTaskRunner() const {
+ return network_task_runner_->RunsTasksOnCurrentThread();
+}
+
+// static
+SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler(
+ void* arg,
+ PRFileDesc* socket,
+ PRBool checksig,
+ PRBool is_server) {
+ Core* core = reinterpret_cast<Core*>(arg);
+ if (!core->handshake_callback_called_) {
+ // Only need to turn off False Start in the initial handshake. Also, it is
+ // unsafe to call SSL_OptionSet in a renegotiation because the "first
+ // handshake" lock isn't already held, which will result in an assertion
+ // failure in the ssl_Get1stHandshakeLock call in SSL_OptionSet.
+ PRBool negotiated_extension;
+ SECStatus rv = SSL_HandshakeNegotiatedExtension(socket,
+ ssl_app_layer_protocol_xtn,
+ &negotiated_extension);
+ if (rv != SECSuccess || !negotiated_extension) {
+ rv = SSL_HandshakeNegotiatedExtension(socket,
+ ssl_next_proto_nego_xtn,
+ &negotiated_extension);
+ }
+ if (rv != SECSuccess || !negotiated_extension) {
+ // If the server doesn't support NPN or ALPN, then we don't do False
+ // Start with it.
+ SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE);
+ }
+ }
+
+ // Tell NSS to not verify the certificate.
+ return SECSuccess;
+}
+
+#if defined(NSS_PLATFORM_CLIENT_AUTH)
+// static
+SECStatus SSLClientSocketNSS::Core::PlatformClientAuthHandler(
+ void* arg,
+ PRFileDesc* socket,
+ CERTDistNames* ca_names,
+ CERTCertList** result_certs,
+ void** result_private_key,
+ CERTCertificate** result_nss_certificate,
+ SECKEYPrivateKey** result_nss_private_key) {
+ Core* core = reinterpret_cast<Core*>(arg);
+ DCHECK(core->OnNSSTaskRunner());
+
+ core->PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEvent, core->weak_net_log_,
+ NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED));
+
+ core->client_auth_cert_needed_ = !core->ssl_config_.send_client_cert;
+#if defined(OS_WIN)
+ if (core->ssl_config_.send_client_cert) {
+ if (core->ssl_config_.client_cert) {
+ PCCERT_CONTEXT cert_context =
+ core->ssl_config_.client_cert->os_cert_handle();
+
+ HCRYPTPROV_OR_NCRYPT_KEY_HANDLE crypt_prov = 0;
+ DWORD key_spec = 0;
+ BOOL must_free = FALSE;
+ DWORD flags = 0;
+ if (base::win::GetVersion() >= base::win::VERSION_VISTA)
+ flags |= CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG;
+
+ BOOL acquired_key = CryptAcquireCertificatePrivateKey(
+ cert_context, flags, NULL, &crypt_prov, &key_spec, &must_free);
+
+ if (acquired_key) {
+ // Should never get a cached handle back - ownership must always be
+ // transferred.
+ CHECK_EQ(must_free, TRUE);
+
+ SECItem der_cert;
+ der_cert.type = siDERCertBuffer;
+ der_cert.data = cert_context->pbCertEncoded;
+ der_cert.len = cert_context->cbCertEncoded;
+
+ // TODO(rsleevi): Error checking for NSS allocation errors.
+ CERTCertDBHandle* db_handle = CERT_GetDefaultCertDB();
+ CERTCertificate* user_cert = CERT_NewTempCertificate(
+ db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE);
+ if (!user_cert) {
+ // Importing the certificate can fail for reasons including a serial
+ // number collision. See crbug.com/97355.
+ core->AddCertProvidedEvent(0);
+ return SECFailure;
+ }
+ CERTCertList* cert_chain = CERT_NewCertList();
+ CERT_AddCertToListTail(cert_chain, user_cert);
+
+ // Add the intermediates.
+ X509Certificate::OSCertHandles intermediates =
+ core->ssl_config_.client_cert->GetIntermediateCertificates();
+ for (X509Certificate::OSCertHandles::const_iterator it =
+ intermediates.begin(); it != intermediates.end(); ++it) {
+ der_cert.data = (*it)->pbCertEncoded;
+ der_cert.len = (*it)->cbCertEncoded;
+
+ CERTCertificate* intermediate = CERT_NewTempCertificate(
+ db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE);
+ if (!intermediate) {
+ CERT_DestroyCertList(cert_chain);
+ core->AddCertProvidedEvent(0);
+ return SECFailure;
+ }
+ CERT_AddCertToListTail(cert_chain, intermediate);
+ }
+ PCERT_KEY_CONTEXT key_context = reinterpret_cast<PCERT_KEY_CONTEXT>(
+ PORT_ZAlloc(sizeof(CERT_KEY_CONTEXT)));
+ key_context->cbSize = sizeof(*key_context);
+ // NSS will free this context when no longer in use.
+ key_context->hCryptProv = crypt_prov;
+ key_context->dwKeySpec = key_spec;
+ *result_private_key = key_context;
+ *result_certs = cert_chain;
+
+ int cert_count = 1 + intermediates.size();
+ core->AddCertProvidedEvent(cert_count);
+ return SECSuccess;
+ }
+ LOG(WARNING) << "Client cert found without private key";
+ }
+
+ // Send no client certificate.
+ core->AddCertProvidedEvent(0);
+ return SECFailure;
+ }
+
+ core->nss_handshake_state_.cert_authorities.clear();
+
+ std::vector<CERT_NAME_BLOB> issuer_list(ca_names->nnames);
+ for (int i = 0; i < ca_names->nnames; ++i) {
+ issuer_list[i].cbData = ca_names->names[i].len;
+ issuer_list[i].pbData = ca_names->names[i].data;
+ core->nss_handshake_state_.cert_authorities.push_back(std::string(
+ reinterpret_cast<const char*>(ca_names->names[i].data),
+ static_cast<size_t>(ca_names->names[i].len)));
+ }
+
+ // Update the network task runner's view of the handshake state now that
+ // server certificate request has been recorded.
+ core->PostOrRunCallback(
+ FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core,
+ core->nss_handshake_state_));
+
+ // Tell NSS to suspend the client authentication. We will then abort the
+ // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
+ return SECWouldBlock;
+#elif defined(OS_MACOSX)
+ if (core->ssl_config_.send_client_cert) {
+ if (core->ssl_config_.client_cert.get()) {
+ OSStatus os_error = noErr;
+ SecIdentityRef identity = NULL;
+ SecKeyRef private_key = NULL;
+ X509Certificate::OSCertHandles chain;
+ {
+ base::AutoLock lock(crypto::GetMacSecurityServicesLock());
+ os_error = SecIdentityCreateWithCertificate(
+ NULL, core->ssl_config_.client_cert->os_cert_handle(), &identity);
+ }
+ if (os_error == noErr) {
+ os_error = SecIdentityCopyPrivateKey(identity, &private_key);
+ CFRelease(identity);
+ }
+
+ if (os_error == noErr) {
+ // TODO(rsleevi): Error checking for NSS allocation errors.
+ *result_certs = CERT_NewCertList();
+ *result_private_key = private_key;
+
+ chain.push_back(core->ssl_config_.client_cert->os_cert_handle());
+ const X509Certificate::OSCertHandles& intermediates =
+ core->ssl_config_.client_cert->GetIntermediateCertificates();
+ if (!intermediates.empty())
+ chain.insert(chain.end(), intermediates.begin(), intermediates.end());
+
+ for (size_t i = 0, chain_count = chain.size(); i < chain_count; ++i) {
+ CSSM_DATA cert_data;
+ SecCertificateRef cert_ref = chain[i];
+ os_error = SecCertificateGetData(cert_ref, &cert_data);
+ if (os_error != noErr)
+ break;
+
+ SECItem der_cert;
+ der_cert.type = siDERCertBuffer;
+ der_cert.data = cert_data.Data;
+ der_cert.len = cert_data.Length;
+ CERTCertificate* nss_cert = CERT_NewTempCertificate(
+ CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE);
+ if (!nss_cert) {
+ // In the event of an NSS error, make up an OS error and reuse
+ // the error handling below.
+ os_error = errSecCreateChainFailed;
+ break;
+ }
+ CERT_AddCertToListTail(*result_certs, nss_cert);
+ }
+ }
+
+ if (os_error == noErr) {
+ core->AddCertProvidedEvent(chain.size());
+ return SECSuccess;
+ }
+
+ OSSTATUS_LOG(WARNING, os_error)
+ << "Client cert found, but could not be used";
+ if (*result_certs) {
+ CERT_DestroyCertList(*result_certs);
+ *result_certs = NULL;
+ }
+ if (*result_private_key)
+ *result_private_key = NULL;
+ if (private_key)
+ CFRelease(private_key);
+ }
+
+ // Send no client certificate.
+ core->AddCertProvidedEvent(0);
+ return SECFailure;
+ }
+
+ core->nss_handshake_state_.cert_authorities.clear();
+
+ // Retrieve the cert issuers accepted by the server.
+ std::vector<CertPrincipal> valid_issuers;
+ int n = ca_names->nnames;
+ for (int i = 0; i < n; i++) {
+ core->nss_handshake_state_.cert_authorities.push_back(std::string(
+ reinterpret_cast<const char*>(ca_names->names[i].data),
+ static_cast<size_t>(ca_names->names[i].len)));
+ }
+
+ // Update the network task runner's view of the handshake state now that
+ // server certificate request has been recorded.
+ core->PostOrRunCallback(
+ FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core,
+ core->nss_handshake_state_));
+
+ // Tell NSS to suspend the client authentication. We will then abort the
+ // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
+ return SECWouldBlock;
+#else
+ return SECFailure;
+#endif
+}
+
+#elif defined(OS_IOS)
+
+SECStatus SSLClientSocketNSS::Core::ClientAuthHandler(
+ void* arg,
+ PRFileDesc* socket,
+ CERTDistNames* ca_names,
+ CERTCertificate** result_certificate,
+ SECKEYPrivateKey** result_private_key) {
+ Core* core = reinterpret_cast<Core*>(arg);
+ DCHECK(core->OnNSSTaskRunner());
+
+ core->PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEvent, core->weak_net_log_,
+ NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED));
+
+ // TODO(droger): Support client auth on iOS. See http://crbug.com/145954).
+ LOG(WARNING) << "Client auth is not supported";
+
+ // Never send a certificate.
+ core->AddCertProvidedEvent(0);
+ return SECFailure;
+}
+
+#else // NSS_PLATFORM_CLIENT_AUTH
+
+// static
+// Based on Mozilla's NSS_GetClientAuthData.
+SECStatus SSLClientSocketNSS::Core::ClientAuthHandler(
+ void* arg,
+ PRFileDesc* socket,
+ CERTDistNames* ca_names,
+ CERTCertificate** result_certificate,
+ SECKEYPrivateKey** result_private_key) {
+ Core* core = reinterpret_cast<Core*>(arg);
+ DCHECK(core->OnNSSTaskRunner());
+
+ core->PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEvent, core->weak_net_log_,
+ NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED));
+
+ // Regular client certificate requested.
+ core->client_auth_cert_needed_ = !core->ssl_config_.send_client_cert;
+ void* wincx = SSL_RevealPinArg(socket);
+
+ if (core->ssl_config_.send_client_cert) {
+ // Second pass: a client certificate should have been selected.
+ if (core->ssl_config_.client_cert.get()) {
+ CERTCertificate* cert =
+ CERT_DupCertificate(core->ssl_config_.client_cert->os_cert_handle());
+ SECKEYPrivateKey* privkey = PK11_FindKeyByAnyCert(cert, wincx);
+ if (privkey) {
+ // TODO(jsorianopastor): We should wait for server certificate
+ // verification before sending our credentials. See
+ // http://crbug.com/13934.
+ *result_certificate = cert;
+ *result_private_key = privkey;
+ // A cert_count of -1 means the number of certificates is unknown.
+ // NSS will construct the certificate chain.
+ core->AddCertProvidedEvent(-1);
+
+ return SECSuccess;
+ }
+ LOG(WARNING) << "Client cert found without private key";
+ }
+ // Send no client certificate.
+ core->AddCertProvidedEvent(0);
+ return SECFailure;
+ }
+
+ // First pass: client certificate is needed.
+ core->nss_handshake_state_.cert_authorities.clear();
+
+ // Retrieve the DER-encoded DistinguishedName of the cert issuers accepted by
+ // the server and save them in |cert_authorities|.
+ for (int i = 0; i < ca_names->nnames; i++) {
+ core->nss_handshake_state_.cert_authorities.push_back(std::string(
+ reinterpret_cast<const char*>(ca_names->names[i].data),
+ static_cast<size_t>(ca_names->names[i].len)));
+ }
+
+ // Update the network task runner's view of the handshake state now that
+ // server certificate request has been recorded.
+ core->PostOrRunCallback(
+ FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core,
+ core->nss_handshake_state_));
+
+ // Tell NSS to suspend the client authentication. We will then abort the
+ // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
+ return SECWouldBlock;
+}
+#endif // NSS_PLATFORM_CLIENT_AUTH
+
+// static
+void SSLClientSocketNSS::Core::HandshakeCallback(
+ PRFileDesc* socket,
+ void* arg) {
+ Core* core = reinterpret_cast<Core*>(arg);
+ DCHECK(core->OnNSSTaskRunner());
+
+ core->handshake_callback_called_ = true;
+
+ HandshakeState* nss_state = &core->nss_handshake_state_;
+
+ PRBool last_handshake_resumed;
+ SECStatus rv = SSL_HandshakeResumedSession(socket, &last_handshake_resumed);
+ if (rv == SECSuccess && last_handshake_resumed) {
+ nss_state->resumed_handshake = true;
+ } else {
+ nss_state->resumed_handshake = false;
+ }
+
+ core->RecordChannelIDSupport();
+ core->UpdateServerCert();
+ core->UpdateConnectionStatus();
+ core->UpdateNextProto();
+
+ // Update the network task runners view of the handshake state whenever
+ // a handshake has completed.
+ core->PostOrRunCallback(
+ FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core,
+ *nss_state));
+}
+
+int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error,
+ bool handshake_error) {
+ DCHECK(OnNSSTaskRunner());
+
+ int net_error = handshake_error ? MapNSSClientHandshakeError(nss_error) :
+ MapNSSClientError(nss_error);
+
+#if defined(OS_WIN)
+ // On Windows, a handle to the HCRYPTPROV is cached in the X509Certificate
+ // os_cert_handle() as an optimization. However, if the certificate
+ // private key is stored on a smart card, and the smart card is removed,
+ // the cached HCRYPTPROV will not be able to obtain the HCRYPTKEY again,
+ // preventing client certificate authentication. Because the
+ // X509Certificate may outlive the individual SSLClientSocketNSS, due to
+ // caching in X509Certificate, this failure ends up preventing client
+ // certificate authentication with the same certificate for all future
+ // attempts, even after the smart card has been re-inserted. By setting
+ // the CERT_KEY_PROV_HANDLE_PROP_ID to NULL, the cached HCRYPTPROV will
+ // typically be freed. This allows a new HCRYPTPROV to be obtained from
+ // the certificate on the next attempt, which should succeed if the smart
+ // card has been re-inserted, or will typically prompt the user to
+ // re-insert the smart card if not.
+ if ((net_error == ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY ||
+ net_error == ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED) &&
+ ssl_config_.send_client_cert && ssl_config_.client_cert) {
+ CertSetCertificateContextProperty(
+ ssl_config_.client_cert->os_cert_handle(),
+ CERT_KEY_PROV_HANDLE_PROP_ID, 0, NULL);
+ }
+#endif
+
+ return net_error;
+}
+
+int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) {
+ DCHECK(OnNSSTaskRunner());
+
+ int rv = last_io_result;
+ do {
+ // Default to STATE_NONE for next state.
+ State state = next_handshake_state_;
+ GotoState(STATE_NONE);
+
+ switch (state) {
+ case STATE_HANDSHAKE:
+ rv = DoHandshake();
+ break;
+ case STATE_GET_DOMAIN_BOUND_CERT_COMPLETE:
+ rv = DoGetDBCertComplete(rv);
+ break;
+ case STATE_NONE:
+ default:
+ rv = ERR_UNEXPECTED;
+ LOG(DFATAL) << "unexpected state " << state;
+ break;
+ }
+
+ // Do the actual network I/O
+ bool network_moved = DoTransportIO();
+ if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) {
+ // In general we exit the loop if rv is ERR_IO_PENDING. In this
+ // special case we keep looping even if rv is ERR_IO_PENDING because
+ // the transport IO may allow DoHandshake to make progress.
+ DCHECK(rv == OK || rv == ERR_IO_PENDING);
+ rv = OK; // This causes us to stay in the loop.
+ }
+ } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::DoReadLoop(int result) {
+ DCHECK(OnNSSTaskRunner());
+ DCHECK(handshake_callback_called_);
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+
+ if (result < 0)
+ return result;
+
+ if (!nss_bufs_) {
+ LOG(DFATAL) << "!nss_bufs_";
+ int rv = ERR_UNEXPECTED;
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogSSLErrorCallback(rv, 0)));
+ return rv;
+ }
+
+ bool network_moved;
+ int rv;
+ do {
+ rv = DoPayloadRead();
+ network_moved = DoTransportIO();
+ } while (rv == ERR_IO_PENDING && network_moved);
+
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::DoWriteLoop(int result) {
+ DCHECK(OnNSSTaskRunner());
+ DCHECK(handshake_callback_called_);
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+
+ if (result < 0)
+ return result;
+
+ if (!nss_bufs_) {
+ LOG(DFATAL) << "!nss_bufs_";
+ int rv = ERR_UNEXPECTED;
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogSSLErrorCallback(rv, 0)));
+ return rv;
+ }
+
+ bool network_moved;
+ int rv;
+ do {
+ rv = DoPayloadWrite();
+ network_moved = DoTransportIO();
+ } while (rv == ERR_IO_PENDING && network_moved);
+
+ LeaveFunction(rv);
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::DoHandshake() {
+ DCHECK(OnNSSTaskRunner());
+
+ int net_error = net::OK;
+ SECStatus rv = SSL_ForceHandshake(nss_fd_);
+
+ // Note: this function may be called multiple times during the handshake, so
+ // even though channel id and client auth are separate else cases, they can
+ // both be used during a single SSL handshake.
+ if (channel_id_needed_) {
+ GotoState(STATE_GET_DOMAIN_BOUND_CERT_COMPLETE);
+ net_error = ERR_IO_PENDING;
+ } else if (client_auth_cert_needed_) {
+ net_error = ERR_SSL_CLIENT_AUTH_CERT_NEEDED;
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_HANDSHAKE_ERROR,
+ CreateNetLogSSLErrorCallback(net_error, 0)));
+
+ // If the handshake already succeeded (because the server requests but
+ // doesn't require a client cert), we need to invalidate the SSL session
+ // so that we won't try to resume the non-client-authenticated session in
+ // the next handshake. This will cause the server to ask for a client
+ // cert again.
+ if (rv == SECSuccess && SSL_InvalidateSession(nss_fd_) != SECSuccess)
+ LOG(WARNING) << "Couldn't invalidate SSL session: " << PR_GetError();
+ } else if (rv == SECSuccess) {
+ if (!handshake_callback_called_) {
+ // Workaround for https://bugzilla.mozilla.org/show_bug.cgi?id=562434 -
+ // SSL_ForceHandshake returned SECSuccess prematurely.
+ rv = SECFailure;
+ net_error = ERR_SSL_PROTOCOL_ERROR;
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_HANDSHAKE_ERROR,
+ CreateNetLogSSLErrorCallback(net_error, 0)));
+ } else {
+ #if defined(SSL_ENABLE_OCSP_STAPLING)
+ // TODO(agl): figure out how to plumb an OCSP response into the Mac
+ // system library and update IsOCSPStaplingSupported for Mac.
+ if (IsOCSPStaplingSupported()) {
+ const SECItemArray* ocsp_responses =
+ SSL_PeerStapledOCSPResponses(nss_fd_);
+ if (ocsp_responses->len) {
+ #if defined(OS_WIN)
+ if (nss_handshake_state_.server_cert) {
+ CRYPT_DATA_BLOB ocsp_response_blob;
+ ocsp_response_blob.cbData = ocsp_responses->items[0].len;
+ ocsp_response_blob.pbData = ocsp_responses->items[0].data;
+ BOOL ok = CertSetCertificateContextProperty(
+ nss_handshake_state_.server_cert->os_cert_handle(),
+ CERT_OCSP_RESPONSE_PROP_ID,
+ CERT_SET_PROPERTY_IGNORE_PERSIST_ERROR_FLAG,
+ &ocsp_response_blob);
+ if (!ok) {
+ VLOG(1) << "Failed to set OCSP response property: "
+ << GetLastError();
+ }
+ }
+ #elif defined(USE_NSS)
+ CacheOCSPResponseFromSideChannelFunction cache_ocsp_response =
+ GetCacheOCSPResponseFromSideChannelFunction();
+
+ cache_ocsp_response(
+ CERT_GetDefaultCertDB(),
+ nss_handshake_state_.server_cert_chain[0], PR_Now(),
+ &ocsp_responses->items[0], NULL);
+ #endif
+ }
+ }
+ #endif
+ }
+ // Done!
+ } else {
+ PRErrorCode prerr = PR_GetError();
+ net_error = HandleNSSError(prerr, true);
+
+ // Some network devices that inspect application-layer packets seem to
+ // inject TCP reset packets to break the connections when they see
+ // TLS 1.1 in ClientHello or ServerHello. See http://crbug.com/130293.
+ //
+ // Only allow ERR_CONNECTION_RESET to trigger a fallback from TLS 1.1 or
+ // 1.2. We don't lose much in this fallback because the explicit IV for CBC
+ // mode in TLS 1.1 is approximated by record splitting in TLS 1.0. The
+ // fallback will be more painful for TLS 1.2 when we have GCM support.
+ //
+ // ERR_CONNECTION_RESET is a common network error, so we don't want it
+ // to trigger a version fallback in general, especially the TLS 1.0 ->
+ // SSL 3.0 fallback, which would drop TLS extensions.
+ if (prerr == PR_CONNECT_RESET_ERROR &&
+ ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1) {
+ net_error = ERR_SSL_PROTOCOL_ERROR;
+ }
+
+ // If not done, stay in this state
+ if (net_error == ERR_IO_PENDING) {
+ GotoState(STATE_HANDSHAKE);
+ } else {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_HANDSHAKE_ERROR,
+ CreateNetLogSSLErrorCallback(net_error, prerr)));
+ }
+ }
+
+ return net_error;
+}
+
+int SSLClientSocketNSS::Core::DoGetDBCertComplete(int result) {
+ SECStatus rv;
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&BoundNetLog::EndEventWithNetErrorCode, weak_net_log_,
+ NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT, result));
+
+ channel_id_needed_ = false;
+
+ if (result != OK)
+ return result;
+
+ SECKEYPublicKey* public_key;
+ SECKEYPrivateKey* private_key;
+ int error = ImportChannelIDKeys(&public_key, &private_key);
+ if (error != OK)
+ return error;
+
+ rv = SSL_RestartHandshakeAfterChannelIDReq(nss_fd_, public_key, private_key);
+ if (rv != SECSuccess)
+ return MapNSSError(PORT_GetError());
+
+ SetChannelIDProvided();
+ GotoState(STATE_HANDSHAKE);
+ return OK;
+}
+
+int SSLClientSocketNSS::Core::DoPayloadRead() {
+ DCHECK(OnNSSTaskRunner());
+ DCHECK(user_read_buf_.get());
+ DCHECK_GT(user_read_buf_len_, 0);
+
+ int rv;
+ // If a previous greedy read resulted in an error that was not consumed (eg:
+ // due to the caller having read some data successfully), then return that
+ // pending error now.
+ if (pending_read_result_ != kNoPendingReadResult) {
+ rv = pending_read_result_;
+ PRErrorCode prerr = pending_read_nss_error_;
+ pending_read_result_ = kNoPendingReadResult;
+ pending_read_nss_error_ = 0;
+
+ if (rv == 0) {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&LogByteTransferEvent, weak_net_log_,
+ NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv,
+ scoped_refptr<IOBuffer>(user_read_buf_)));
+ } else {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogSSLErrorCallback(rv, prerr)));
+ }
+ return rv;
+ }
+
+ // Perform a greedy read, attempting to read as much as the caller has
+ // requested. In the current NSS implementation, PR_Read will return
+ // exactly one SSL application data record's worth of data per invocation.
+ // The record size is dictated by the server, and may be noticeably smaller
+ // than the caller's buffer. This may be as little as a single byte, if the
+ // server is performing 1/n-1 record splitting.
+ //
+ // However, this greedy read may result in renegotiations/re-handshakes
+ // happening or may lead to some data being read, followed by an EOF (such as
+ // a TLS close-notify). If at least some data was read, then that result
+ // should be deferred until the next call to DoPayloadRead(). Otherwise, if no
+ // data was read, it's safe to return the error or EOF immediately.
+ int total_bytes_read = 0;
+ do {
+ rv = PR_Read(nss_fd_, user_read_buf_->data() + total_bytes_read,
+ user_read_buf_len_ - total_bytes_read);
+ if (rv > 0)
+ total_bytes_read += rv;
+ } while (total_bytes_read < user_read_buf_len_ && rv > 0);
+ int amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_);
+ PostOrRunCallback(FROM_HERE, base::Bind(&Core::OnNSSBufferUpdated, this,
+ amount_in_read_buffer));
+
+ if (total_bytes_read == user_read_buf_len_) {
+ // The caller's entire request was satisfied without error. No further
+ // processing needed.
+ rv = total_bytes_read;
+ } else {
+ // Otherwise, an error occurred (rv <= 0). The error needs to be handled
+ // immediately, while the NSPR/NSS errors are still available in
+ // thread-local storage. However, the handled/remapped error code should
+ // only be returned if no application data was already read; if it was, the
+ // error code should be deferred until the next call of DoPayloadRead.
+ //
+ // If no data was read, |*next_result| will point to the return value of
+ // this function. If at least some data was read, |*next_result| will point
+ // to |pending_read_error_|, to be returned in a future call to
+ // DoPayloadRead() (e.g.: after the current data is handled).
+ int* next_result = &rv;
+ if (total_bytes_read > 0) {
+ pending_read_result_ = rv;
+ rv = total_bytes_read;
+ next_result = &pending_read_result_;
+ }
+
+ if (client_auth_cert_needed_) {
+ *next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED;
+ pending_read_nss_error_ = 0;
+ } else if (*next_result < 0) {
+ // If *next_result == 0, then that indicates EOF, and no special error
+ // handling is needed.
+ pending_read_nss_error_ = PR_GetError();
+ *next_result = HandleNSSError(pending_read_nss_error_, false);
+ if (rv > 0 && *next_result == ERR_IO_PENDING) {
+ // If at least some data was read from PR_Read(), do not treat
+ // insufficient data as an error to return in the next call to
+ // DoPayloadRead() - instead, let the call fall through to check
+ // PR_Read() again. This is because DoTransportIO() may complete
+ // in between the next call to DoPayloadRead(), and thus it is
+ // important to check PR_Read() on subsequent invocations to see
+ // if a complete record may now be read.
+ pending_read_nss_error_ = 0;
+ pending_read_result_ = kNoPendingReadResult;
+ }
+ }
+ }
+
+ DCHECK_NE(ERR_IO_PENDING, pending_read_result_);
+
+ if (rv >= 0) {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&LogByteTransferEvent, weak_net_log_,
+ NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv,
+ scoped_refptr<IOBuffer>(user_read_buf_)));
+ } else if (rv != ERR_IO_PENDING) {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogSSLErrorCallback(rv, pending_read_nss_error_)));
+ pending_read_nss_error_ = 0;
+ }
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::DoPayloadWrite() {
+ DCHECK(OnNSSTaskRunner());
+
+ DCHECK(user_write_buf_.get());
+
+ int old_amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_);
+ int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_);
+ int new_amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_);
+ // PR_Write could potentially consume the unhandled data in the memio read
+ // buffer if a renegotiation is in progress. If the buffer is consumed,
+ // notify the latest buffer size to NetworkRunner.
+ if (old_amount_in_read_buffer != new_amount_in_read_buffer) {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&Core::OnNSSBufferUpdated, this, new_amount_in_read_buffer));
+ }
+ if (rv >= 0) {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&LogByteTransferEvent, weak_net_log_,
+ NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv,
+ scoped_refptr<IOBuffer>(user_write_buf_)));
+ return rv;
+ }
+ PRErrorCode prerr = PR_GetError();
+ if (prerr == PR_WOULD_BLOCK_ERROR)
+ return ERR_IO_PENDING;
+
+ rv = HandleNSSError(prerr, false);
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_WRITE_ERROR,
+ CreateNetLogSSLErrorCallback(rv, prerr)));
+ return rv;
+}
+
+// Do as much network I/O as possible between the buffer and the
+// transport socket. Return true if some I/O performed, false
+// otherwise (error or ERR_IO_PENDING).
+bool SSLClientSocketNSS::Core::DoTransportIO() {
+ DCHECK(OnNSSTaskRunner());
+
+ bool network_moved = false;
+ if (nss_bufs_ != NULL) {
+ int rv;
+ // Read and write as much data as we can. The loop is neccessary
+ // because Write() may return synchronously.
+ do {
+ rv = BufferSend();
+ if (rv != ERR_IO_PENDING && rv != 0)
+ network_moved = true;
+ } while (rv > 0);
+ if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING)
+ network_moved = true;
+ }
+ return network_moved;
+}
+
+int SSLClientSocketNSS::Core::BufferRecv() {
+ DCHECK(OnNSSTaskRunner());
+
+ if (transport_recv_busy_)
+ return ERR_IO_PENDING;
+
+ // If NSS is blocked on reading from |nss_bufs_|, because it is empty,
+ // determine how much data NSS wants to read. If NSS was not blocked,
+ // this will return 0.
+ int requested = memio_GetReadRequest(nss_bufs_);
+ if (requested == 0) {
+ // This is not a perfect match of error codes, as no operation is
+ // actually pending. However, returning 0 would be interpreted as a
+ // possible sign of EOF, which is also an inappropriate match.
+ return ERR_IO_PENDING;
+ }
+
+ char* buf;
+ int nb = memio_GetReadParams(nss_bufs_, &buf);
+ int rv;
+ if (!nb) {
+ // buffer too full to read into, so no I/O possible at moment
+ rv = ERR_IO_PENDING;
+ } else {
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(nb));
+ if (OnNetworkTaskRunner()) {
+ rv = DoBufferRecv(read_buffer.get(), nb);
+ } else {
+ bool posted = network_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(IgnoreResult(&Core::DoBufferRecv), this, read_buffer,
+ nb));
+ rv = posted ? ERR_IO_PENDING : ERR_ABORTED;
+ }
+
+ if (rv == ERR_IO_PENDING) {
+ transport_recv_busy_ = true;
+ } else {
+ if (rv > 0) {
+ memcpy(buf, read_buffer->data(), rv);
+ } else if (rv == 0) {
+ transport_recv_eof_ = true;
+ }
+ memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv));
+ }
+ }
+ return rv;
+}
+
+// Return 0 if nss_bufs_ was empty,
+// > 0 for bytes transferred immediately,
+// < 0 for error (or the non-error ERR_IO_PENDING).
+int SSLClientSocketNSS::Core::BufferSend() {
+ DCHECK(OnNSSTaskRunner());
+
+ if (transport_send_busy_)
+ return ERR_IO_PENDING;
+
+ const char* buf1;
+ const char* buf2;
+ unsigned int len1, len2;
+ memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2);
+ const unsigned int len = len1 + len2;
+
+ int rv = 0;
+ if (len) {
+ scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len));
+ memcpy(send_buffer->data(), buf1, len1);
+ memcpy(send_buffer->data() + len1, buf2, len2);
+
+ if (OnNetworkTaskRunner()) {
+ rv = DoBufferSend(send_buffer.get(), len);
+ } else {
+ bool posted = network_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(IgnoreResult(&Core::DoBufferSend), this, send_buffer,
+ len));
+ rv = posted ? ERR_IO_PENDING : ERR_ABORTED;
+ }
+
+ if (rv == ERR_IO_PENDING) {
+ transport_send_busy_ = true;
+ } else {
+ memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv));
+ }
+ }
+
+ return rv;
+}
+
+void SSLClientSocketNSS::Core::OnRecvComplete(int result) {
+ DCHECK(OnNSSTaskRunner());
+
+ if (next_handshake_state_ == STATE_HANDSHAKE) {
+ OnHandshakeIOComplete(result);
+ return;
+ }
+
+ // Network layer received some data, check if client requested to read
+ // decrypted data.
+ if (!user_read_buf_.get())
+ return;
+
+ int rv = DoReadLoop(result);
+ if (rv != ERR_IO_PENDING)
+ DoReadCallback(rv);
+}
+
+void SSLClientSocketNSS::Core::OnSendComplete(int result) {
+ DCHECK(OnNSSTaskRunner());
+
+ if (next_handshake_state_ == STATE_HANDSHAKE) {
+ OnHandshakeIOComplete(result);
+ return;
+ }
+
+ // OnSendComplete may need to call DoPayloadRead while the renegotiation
+ // handshake is in progress.
+ int rv_read = ERR_IO_PENDING;
+ int rv_write = ERR_IO_PENDING;
+ bool network_moved;
+ do {
+ if (user_read_buf_.get())
+ rv_read = DoPayloadRead();
+ if (user_write_buf_.get())
+ rv_write = DoPayloadWrite();
+ network_moved = DoTransportIO();
+ } while (rv_read == ERR_IO_PENDING && rv_write == ERR_IO_PENDING &&
+ (user_read_buf_.get() || user_write_buf_.get()) && network_moved);
+
+ // If the parent SSLClientSocketNSS is deleted during the processing of the
+ // Read callback and OnNSSTaskRunner() == OnNetworkTaskRunner(), then the Core
+ // will be detached (and possibly deleted). Guard against deletion by taking
+ // an extra reference, then check if the Core was detached before invoking the
+ // next callback.
+ scoped_refptr<Core> guard(this);
+ if (user_read_buf_.get() && rv_read != ERR_IO_PENDING)
+ DoReadCallback(rv_read);
+
+ if (OnNetworkTaskRunner() && detached_)
+ return;
+
+ if (user_write_buf_.get() && rv_write != ERR_IO_PENDING)
+ DoWriteCallback(rv_write);
+}
+
+// As part of Connect(), the SSLClientSocketNSS object performs an SSL
+// handshake. This requires network IO, which in turn calls
+// BufferRecvComplete() with a non-zero byte count. This byte count eventually
+// winds its way through the state machine and ends up being passed to the
+// callback. For Read() and Write(), that's what we want. But for Connect(),
+// the caller expects OK (i.e. 0) for success.
+void SSLClientSocketNSS::Core::DoConnectCallback(int rv) {
+ DCHECK(OnNSSTaskRunner());
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ DCHECK(!user_connect_callback_.is_null());
+
+ base::Closure c = base::Bind(
+ base::ResetAndReturn(&user_connect_callback_),
+ rv > OK ? OK : rv);
+ PostOrRunCallback(FROM_HERE, c);
+}
+
+void SSLClientSocketNSS::Core::DoReadCallback(int rv) {
+ DCHECK(OnNSSTaskRunner());
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ DCHECK(!user_read_callback_.is_null());
+
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+ int amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_);
+ // This is used to curry the |amount_int_read_buffer| and |user_cb| back to
+ // the network task runner.
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&Core::OnNSSBufferUpdated, this, amount_in_read_buffer));
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&Core::DidNSSRead, this, rv));
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(base::ResetAndReturn(&user_read_callback_), rv));
+}
+
+void SSLClientSocketNSS::Core::DoWriteCallback(int rv) {
+ DCHECK(OnNSSTaskRunner());
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ DCHECK(!user_write_callback_.is_null());
+
+ // Since Run may result in Write being called, clear |user_write_callback_|
+ // up front.
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+ // Update buffer status because DoWriteLoop called DoTransportIO which may
+ // perform read operations.
+ int amount_in_read_buffer = memio_GetReadableBufferSize(nss_bufs_);
+ // This is used to curry the |amount_int_read_buffer| and |user_cb| back to
+ // the network task runner.
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&Core::OnNSSBufferUpdated, this, amount_in_read_buffer));
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&Core::DidNSSWrite, this, rv));
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(base::ResetAndReturn(&user_write_callback_), rv));
+}
+
+SECStatus SSLClientSocketNSS::Core::ClientChannelIDHandler(
+ void* arg,
+ PRFileDesc* socket,
+ SECKEYPublicKey **out_public_key,
+ SECKEYPrivateKey **out_private_key) {
+ Core* core = reinterpret_cast<Core*>(arg);
+ DCHECK(core->OnNSSTaskRunner());
+
+ core->PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEvent, core->weak_net_log_,
+ NetLog::TYPE_SSL_CHANNEL_ID_REQUESTED));
+
+ // We have negotiated the TLS channel ID extension.
+ core->channel_id_xtn_negotiated_ = true;
+ std::string host = core->host_and_port_.host();
+ int error = ERR_UNEXPECTED;
+ if (core->OnNetworkTaskRunner()) {
+ error = core->DoGetDomainBoundCert(host);
+ } else {
+ bool posted = core->network_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(
+ IgnoreResult(&Core::DoGetDomainBoundCert),
+ core, host));
+ error = posted ? ERR_IO_PENDING : ERR_ABORTED;
+ }
+
+ if (error == ERR_IO_PENDING) {
+ // Asynchronous case.
+ core->channel_id_needed_ = true;
+ return SECWouldBlock;
+ }
+
+ core->PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&BoundNetLog::EndEventWithNetErrorCode, core->weak_net_log_,
+ NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT, error));
+ SECStatus rv = SECSuccess;
+ if (error == OK) {
+ // Synchronous success.
+ int result = core->ImportChannelIDKeys(out_public_key, out_private_key);
+ if (result == OK)
+ core->SetChannelIDProvided();
+ else
+ rv = SECFailure;
+ } else {
+ rv = SECFailure;
+ }
+
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::ImportChannelIDKeys(SECKEYPublicKey** public_key,
+ SECKEYPrivateKey** key) {
+ // Set the certificate.
+ SECItem cert_item;
+ cert_item.data = (unsigned char*) domain_bound_cert_.data();
+ cert_item.len = domain_bound_cert_.size();
+ ScopedCERTCertificate cert(CERT_NewTempCertificate(CERT_GetDefaultCertDB(),
+ &cert_item,
+ NULL,
+ PR_FALSE,
+ PR_TRUE));
+ if (cert == NULL)
+ return MapNSSError(PORT_GetError());
+
+ // Set the private key.
+ if (!crypto::ECPrivateKey::ImportFromEncryptedPrivateKeyInfo(
+ ServerBoundCertService::kEPKIPassword,
+ reinterpret_cast<const unsigned char*>(
+ domain_bound_private_key_.data()),
+ domain_bound_private_key_.size(),
+ &cert->subjectPublicKeyInfo,
+ false,
+ false,
+ key,
+ public_key)) {
+ int error = MapNSSError(PORT_GetError());
+ return error;
+ }
+
+ return OK;
+}
+
+void SSLClientSocketNSS::Core::UpdateServerCert() {
+ nss_handshake_state_.server_cert_chain.Reset(nss_fd_);
+ nss_handshake_state_.server_cert = X509Certificate::CreateFromDERCertChain(
+ nss_handshake_state_.server_cert_chain.AsStringPieceVector());
+ if (nss_handshake_state_.server_cert.get()) {
+ // Since this will be called asynchronously on another thread, it needs to
+ // own a reference to the certificate.
+ NetLog::ParametersCallback net_log_callback =
+ base::Bind(&NetLogX509CertificateCallback,
+ nss_handshake_state_.server_cert);
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_CERTIFICATES_RECEIVED,
+ net_log_callback));
+ }
+}
+
+void SSLClientSocketNSS::Core::UpdateConnectionStatus() {
+ SSLChannelInfo channel_info;
+ SECStatus ok = SSL_GetChannelInfo(nss_fd_,
+ &channel_info, sizeof(channel_info));
+ if (ok == SECSuccess &&
+ channel_info.length == sizeof(channel_info) &&
+ channel_info.cipherSuite) {
+ nss_handshake_state_.ssl_connection_status |=
+ (static_cast<int>(channel_info.cipherSuite) &
+ SSL_CONNECTION_CIPHERSUITE_MASK) <<
+ SSL_CONNECTION_CIPHERSUITE_SHIFT;
+
+ nss_handshake_state_.ssl_connection_status |=
+ (static_cast<int>(channel_info.compressionMethod) &
+ SSL_CONNECTION_COMPRESSION_MASK) <<
+ SSL_CONNECTION_COMPRESSION_SHIFT;
+
+ // NSS 3.14.x doesn't have a version macro for TLS 1.2 (because NSS didn't
+ // support it yet), so use 0x0303 directly.
+ int version = SSL_CONNECTION_VERSION_UNKNOWN;
+ if (channel_info.protocolVersion < SSL_LIBRARY_VERSION_3_0) {
+ // All versions less than SSL_LIBRARY_VERSION_3_0 are treated as SSL
+ // version 2.
+ version = SSL_CONNECTION_VERSION_SSL2;
+ } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) {
+ version = SSL_CONNECTION_VERSION_SSL3;
+ } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_1_TLS) {
+ version = SSL_CONNECTION_VERSION_TLS1;
+ } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_1) {
+ version = SSL_CONNECTION_VERSION_TLS1_1;
+ } else if (channel_info.protocolVersion == 0x0303) {
+ version = SSL_CONNECTION_VERSION_TLS1_2;
+ }
+ nss_handshake_state_.ssl_connection_status |=
+ (version & SSL_CONNECTION_VERSION_MASK) <<
+ SSL_CONNECTION_VERSION_SHIFT;
+ }
+
+ PRBool peer_supports_renego_ext;
+ ok = SSL_HandshakeNegotiatedExtension(nss_fd_, ssl_renegotiation_info_xtn,
+ &peer_supports_renego_ext);
+ if (ok == SECSuccess) {
+ if (!peer_supports_renego_ext) {
+ nss_handshake_state_.ssl_connection_status |=
+ SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION;
+ // Log an informational message if the server does not support secure
+ // renegotiation (RFC 5746).
+ VLOG(1) << "The server " << host_and_port_.ToString()
+ << " does not support the TLS renegotiation_info extension.";
+ }
+ UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported",
+ peer_supports_renego_ext, 2);
+
+ // We would like to eliminate fallback to SSLv3 for non-buggy servers
+ // because of security concerns. For example, Google offers forward
+ // secrecy with ECDHE but that requires TLS 1.0. An attacker can block
+ // TLSv1 connections and force us to downgrade to SSLv3 and remove forward
+ // secrecy.
+ //
+ // Yngve from Opera has suggested using the renegotiation extension as an
+ // indicator that SSLv3 fallback was mistaken:
+ // tools.ietf.org/html/draft-pettersen-tls-version-rollback-removal-00 .
+ //
+ // As a first step, measure how often clients perform version fallback
+ // while the server advertises support secure renegotiation.
+ if (ssl_config_.version_fallback &&
+ channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) {
+ UMA_HISTOGRAM_BOOLEAN("Net.SSLv3FallbackToRenegoPatchedServer",
+ peer_supports_renego_ext == PR_TRUE);
+ }
+ }
+
+ if (ssl_config_.version_fallback) {
+ nss_handshake_state_.ssl_connection_status |=
+ SSL_CONNECTION_VERSION_FALLBACK;
+ }
+}
+
+void SSLClientSocketNSS::Core::UpdateNextProto() {
+ uint8 buf[256];
+ SSLNextProtoState state;
+ unsigned buf_len;
+
+ SECStatus rv = SSL_GetNextProto(nss_fd_, &state, buf, &buf_len, sizeof(buf));
+ if (rv != SECSuccess)
+ return;
+
+ nss_handshake_state_.next_proto =
+ std::string(reinterpret_cast<char*>(buf), buf_len);
+ switch (state) {
+ case SSL_NEXT_PROTO_NEGOTIATED:
+ case SSL_NEXT_PROTO_SELECTED:
+ nss_handshake_state_.next_proto_status = kNextProtoNegotiated;
+ break;
+ case SSL_NEXT_PROTO_NO_OVERLAP:
+ nss_handshake_state_.next_proto_status = kNextProtoNoOverlap;
+ break;
+ case SSL_NEXT_PROTO_NO_SUPPORT:
+ nss_handshake_state_.next_proto_status = kNextProtoUnsupported;
+ break;
+ default:
+ NOTREACHED();
+ break;
+ }
+}
+
+void SSLClientSocketNSS::Core::RecordChannelIDSupport() {
+ DCHECK(OnNSSTaskRunner());
+ if (nss_handshake_state_.resumed_handshake)
+ return;
+
+ // Copy the NSS task runner-only state to the network task runner and
+ // log histograms from there, since the histograms also need access to the
+ // network task runner state.
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&Core::RecordChannelIDSupportOnNetworkTaskRunner,
+ this,
+ channel_id_xtn_negotiated_,
+ ssl_config_.channel_id_enabled,
+ crypto::ECPrivateKey::IsSupported()));
+}
+
+void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNetworkTaskRunner(
+ bool negotiated_channel_id,
+ bool channel_id_enabled,
+ bool supports_ecc) const {
+ DCHECK(OnNetworkTaskRunner());
+
+ // Since this enum is used for a histogram, do not change or re-use values.
+ enum {
+ DISABLED = 0,
+ CLIENT_ONLY = 1,
+ CLIENT_AND_SERVER = 2,
+ CLIENT_NO_ECC = 3,
+ CLIENT_BAD_SYSTEM_TIME = 4,
+ CLIENT_NO_SERVER_BOUND_CERT_SERVICE = 5,
+ DOMAIN_BOUND_CERT_USAGE_MAX
+ } supported = DISABLED;
+ if (negotiated_channel_id) {
+ supported = CLIENT_AND_SERVER;
+ } else if (channel_id_enabled) {
+ if (!server_bound_cert_service_)
+ supported = CLIENT_NO_SERVER_BOUND_CERT_SERVICE;
+ else if (!supports_ecc)
+ supported = CLIENT_NO_ECC;
+ else if (!server_bound_cert_service_->IsSystemTimeValid())
+ supported = CLIENT_BAD_SYSTEM_TIME;
+ else
+ supported = CLIENT_ONLY;
+ }
+ UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.Support", supported,
+ DOMAIN_BOUND_CERT_USAGE_MAX);
+}
+
+int SSLClientSocketNSS::Core::DoBufferRecv(IOBuffer* read_buffer, int len) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK_GT(len, 0);
+
+ if (detached_)
+ return ERR_ABORTED;
+
+ int rv = transport_->socket()->Read(
+ read_buffer, len,
+ base::Bind(&Core::BufferRecvComplete, base::Unretained(this),
+ scoped_refptr<IOBuffer>(read_buffer)));
+
+ if (!OnNSSTaskRunner() && rv != ERR_IO_PENDING) {
+ nss_task_runner_->PostTask(
+ FROM_HERE, base::Bind(&Core::BufferRecvComplete, this,
+ scoped_refptr<IOBuffer>(read_buffer), rv));
+ return rv;
+ }
+
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::DoBufferSend(IOBuffer* send_buffer, int len) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK_GT(len, 0);
+
+ if (detached_)
+ return ERR_ABORTED;
+
+ int rv = transport_->socket()->Write(
+ send_buffer, len,
+ base::Bind(&Core::BufferSendComplete,
+ base::Unretained(this)));
+
+ if (!OnNSSTaskRunner() && rv != ERR_IO_PENDING) {
+ nss_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(&Core::BufferSendComplete, this, rv));
+ return rv;
+ }
+
+ return rv;
+}
+
+int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) {
+ DCHECK(OnNetworkTaskRunner());
+
+ if (detached_)
+ return ERR_FAILED;
+
+ weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT);
+
+ int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert(
+ host,
+ &domain_bound_private_key_,
+ &domain_bound_cert_,
+ base::Bind(&Core::OnGetDomainBoundCertComplete, base::Unretained(this)),
+ &domain_bound_cert_request_handle_);
+
+ if (rv != ERR_IO_PENDING && !OnNSSTaskRunner()) {
+ nss_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(&Core::OnHandshakeIOComplete, this, rv));
+ return ERR_IO_PENDING;
+ }
+
+ return rv;
+}
+
+void SSLClientSocketNSS::Core::OnHandshakeStateUpdated(
+ const HandshakeState& state) {
+ DCHECK(OnNetworkTaskRunner());
+ network_handshake_state_ = state;
+}
+
+void SSLClientSocketNSS::Core::OnNSSBufferUpdated(int amount_in_read_buffer) {
+ DCHECK(OnNetworkTaskRunner());
+ unhandled_buffer_size_ = amount_in_read_buffer;
+}
+
+void SSLClientSocketNSS::Core::DidNSSRead(int result) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK(nss_waiting_read_);
+ nss_waiting_read_ = false;
+ if (result <= 0)
+ nss_is_closed_ = true;
+}
+
+void SSLClientSocketNSS::Core::DidNSSWrite(int result) {
+ DCHECK(OnNetworkTaskRunner());
+ DCHECK(nss_waiting_write_);
+ nss_waiting_write_ = false;
+ if (result < 0)
+ nss_is_closed_ = true;
+}
+
+void SSLClientSocketNSS::Core::BufferSendComplete(int result) {
+ if (!OnNSSTaskRunner()) {
+ if (detached_)
+ return;
+
+ nss_task_runner_->PostTask(
+ FROM_HERE, base::Bind(&Core::BufferSendComplete, this, result));
+ return;
+ }
+
+ DCHECK(OnNSSTaskRunner());
+
+ memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result));
+ transport_send_busy_ = false;
+ OnSendComplete(result);
+}
+
+void SSLClientSocketNSS::Core::OnHandshakeIOComplete(int result) {
+ if (!OnNSSTaskRunner()) {
+ if (detached_)
+ return;
+
+ nss_task_runner_->PostTask(
+ FROM_HERE, base::Bind(&Core::OnHandshakeIOComplete, this, result));
+ return;
+ }
+
+ DCHECK(OnNSSTaskRunner());
+
+ int rv = DoHandshakeLoop(result);
+ if (rv != ERR_IO_PENDING)
+ DoConnectCallback(rv);
+}
+
+void SSLClientSocketNSS::Core::OnGetDomainBoundCertComplete(int result) {
+ DVLOG(1) << __FUNCTION__ << " " << result;
+ DCHECK(OnNetworkTaskRunner());
+
+ OnHandshakeIOComplete(result);
+}
+
+void SSLClientSocketNSS::Core::BufferRecvComplete(
+ IOBuffer* read_buffer,
+ int result) {
+ DCHECK(read_buffer);
+
+ if (!OnNSSTaskRunner()) {
+ if (detached_)
+ return;
+
+ nss_task_runner_->PostTask(
+ FROM_HERE, base::Bind(&Core::BufferRecvComplete, this,
+ scoped_refptr<IOBuffer>(read_buffer), result));
+ return;
+ }
+
+ DCHECK(OnNSSTaskRunner());
+
+ if (result > 0) {
+ char* buf;
+ int nb = memio_GetReadParams(nss_bufs_, &buf);
+ CHECK_GE(nb, result);
+ memcpy(buf, read_buffer->data(), result);
+ } else if (result == 0) {
+ transport_recv_eof_ = true;
+ }
+
+ memio_PutReadResult(nss_bufs_, MapErrorToNSS(result));
+ transport_recv_busy_ = false;
+ OnRecvComplete(result);
+}
+
+void SSLClientSocketNSS::Core::PostOrRunCallback(
+ const tracked_objects::Location& location,
+ const base::Closure& task) {
+ if (!OnNetworkTaskRunner()) {
+ network_task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(&Core::PostOrRunCallback, this, location, task));
+ return;
+ }
+
+ if (detached_ || task.is_null())
+ return;
+ task.Run();
+}
+
+void SSLClientSocketNSS::Core::AddCertProvidedEvent(int cert_count) {
+ PostOrRunCallback(
+ FROM_HERE,
+ base::Bind(&AddLogEventWithCallback, weak_net_log_,
+ NetLog::TYPE_SSL_CLIENT_CERT_PROVIDED,
+ NetLog::IntegerCallback("cert_count", cert_count)));
+}
+
+void SSLClientSocketNSS::Core::SetChannelIDProvided() {
+ PostOrRunCallback(
+ FROM_HERE, base::Bind(&AddLogEvent, weak_net_log_,
+ NetLog::TYPE_SSL_CHANNEL_ID_PROVIDED));
+ nss_handshake_state_.channel_id_sent = true;
+ // Update the network task runner's view of the handshake state now that
+ // channel id has been sent.
+ PostOrRunCallback(
+ FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, this,
+ nss_handshake_state_));
+}
+
+SSLClientSocketNSS::SSLClientSocketNSS(
+ base::SequencedTaskRunner* nss_task_runner,
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context)
+ : nss_task_runner_(nss_task_runner),
+ transport_(transport_socket.Pass()),
+ host_and_port_(host_and_port),
+ ssl_config_(ssl_config),
+ cert_verifier_(context.cert_verifier),
+ server_bound_cert_service_(context.server_bound_cert_service),
+ ssl_session_cache_shard_(context.ssl_session_cache_shard),
+ completed_handshake_(false),
+ next_handshake_state_(STATE_NONE),
+ nss_fd_(NULL),
+ net_log_(transport_->socket()->NetLog()),
+ transport_security_state_(context.transport_security_state),
+ valid_thread_id_(base::kInvalidThreadId) {
+ EnterFunction("");
+ InitCore();
+ LeaveFunction("");
+}
+
+SSLClientSocketNSS::~SSLClientSocketNSS() {
+ EnterFunction("");
+ Disconnect();
+ LeaveFunction("");
+}
+
+// static
+void SSLClientSocket::ClearSessionCache() {
+ // SSL_ClearSessionCache can't be called before NSS is initialized. Don't
+ // bother initializing NSS just to clear an empty SSL session cache.
+ if (!NSS_IsInitialized())
+ return;
+
+ SSL_ClearSessionCache();
+}
+
+bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) {
+ EnterFunction("");
+ ssl_info->Reset();
+ if (core_->state().server_cert_chain.empty() ||
+ !core_->state().server_cert_chain[0]) {
+ return false;
+ }
+
+ ssl_info->cert_status = server_cert_verify_result_.cert_status;
+ ssl_info->cert = server_cert_verify_result_.verified_cert;
+ ssl_info->connection_status =
+ core_->state().ssl_connection_status;
+ ssl_info->public_key_hashes = server_cert_verify_result_.public_key_hashes;
+ for (HashValueVector::const_iterator i = side_pinned_public_keys_.begin();
+ i != side_pinned_public_keys_.end(); ++i) {
+ ssl_info->public_key_hashes.push_back(*i);
+ }
+ ssl_info->is_issued_by_known_root =
+ server_cert_verify_result_.is_issued_by_known_root;
+ ssl_info->client_cert_sent =
+ ssl_config_.send_client_cert && ssl_config_.client_cert.get();
+ ssl_info->channel_id_sent = WasChannelIDSent();
+
+ PRUint16 cipher_suite = SSLConnectionStatusToCipherSuite(
+ core_->state().ssl_connection_status);
+ SSLCipherSuiteInfo cipher_info;
+ SECStatus ok = SSL_GetCipherSuiteInfo(cipher_suite,
+ &cipher_info, sizeof(cipher_info));
+ if (ok == SECSuccess) {
+ ssl_info->security_bits = cipher_info.effectiveKeyBits;
+ } else {
+ ssl_info->security_bits = -1;
+ LOG(DFATAL) << "SSL_GetCipherSuiteInfo returned " << PR_GetError()
+ << " for cipherSuite " << cipher_suite;
+ }
+
+ ssl_info->handshake_type = core_->state().resumed_handshake ?
+ SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL;
+
+ LeaveFunction("");
+ return true;
+}
+
+void SSLClientSocketNSS::GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) {
+ EnterFunction("");
+ // TODO(rch): switch SSLCertRequestInfo.host_and_port to a HostPortPair
+ cert_request_info->host_and_port = host_and_port_.ToString();
+ cert_request_info->cert_authorities = core_->state().cert_authorities;
+ LeaveFunction("");
+}
+
+int SSLClientSocketNSS::ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+
+ // SSL_ExportKeyingMaterial may block the current thread if |core_| is in
+ // the midst of a handshake.
+ SECStatus result = SSL_ExportKeyingMaterial(
+ nss_fd_, label.data(), label.size(), has_context,
+ reinterpret_cast<const unsigned char*>(context.data()),
+ context.length(), out, outlen);
+ if (result != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_ExportKeyingMaterial", "");
+ return MapNSSError(PORT_GetError());
+ }
+ return OK;
+}
+
+int SSLClientSocketNSS::GetTLSUniqueChannelBinding(std::string* out) {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ unsigned char buf[64];
+ unsigned int len;
+ SECStatus result = SSL_GetChannelBinding(nss_fd_,
+ SSL_CHANNEL_BINDING_TLS_UNIQUE,
+ buf, &len, arraysize(buf));
+ if (result != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", "");
+ return MapNSSError(PORT_GetError());
+ }
+ out->assign(reinterpret_cast<char*>(buf), len);
+ return OK;
+}
+
+SSLClientSocket::NextProtoStatus
+SSLClientSocketNSS::GetNextProto(std::string* proto,
+ std::string* server_protos) {
+ *proto = core_->state().next_proto;
+ *server_protos = core_->state().server_protos;
+ return core_->state().next_proto_status;
+}
+
+int SSLClientSocketNSS::Connect(const CompletionCallback& callback) {
+ EnterFunction("");
+ DCHECK(transport_.get());
+ // It is an error to create an SSLClientSocket whose context has no
+ // TransportSecurityState.
+ DCHECK(transport_security_state_);
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+ DCHECK(user_connect_callback_.is_null());
+ DCHECK(!callback.is_null());
+
+ EnsureThreadIdAssigned();
+
+ net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT);
+
+ int rv = Init();
+ if (rv != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ return rv;
+ }
+
+ rv = InitializeSSLOptions();
+ if (rv != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ return rv;
+ }
+
+ rv = InitializeSSLPeerName();
+ if (rv != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ return rv;
+ }
+
+ GotoState(STATE_HANDSHAKE);
+
+ rv = DoHandshakeLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ user_connect_callback_ = callback;
+ } else {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ }
+
+ LeaveFunction("");
+ return rv > OK ? OK : rv;
+}
+
+void SSLClientSocketNSS::Disconnect() {
+ EnterFunction("");
+
+ CHECK(CalledOnValidThread());
+
+ // Shut down anything that may call us back.
+ core_->Detach();
+ verifier_.reset();
+ transport_->socket()->Disconnect();
+
+ // Reset object state.
+ user_connect_callback_.Reset();
+ server_cert_verify_result_.Reset();
+ completed_handshake_ = false;
+ start_cert_verification_time_ = base::TimeTicks();
+ InitCore();
+
+ LeaveFunction("");
+}
+
+bool SSLClientSocketNSS::IsConnected() const {
+ EnterFunction("");
+ bool ret = completed_handshake_ &&
+ (core_->HasPendingAsyncOperation() ||
+ (core_->IsConnected() && core_->HasUnhandledReceivedData()) ||
+ transport_->socket()->IsConnected());
+ LeaveFunction("");
+ return ret;
+}
+
+bool SSLClientSocketNSS::IsConnectedAndIdle() const {
+ EnterFunction("");
+ bool ret = completed_handshake_ &&
+ !core_->HasPendingAsyncOperation() &&
+ !(core_->IsConnected() && core_->HasUnhandledReceivedData()) &&
+ transport_->socket()->IsConnectedAndIdle();
+ LeaveFunction("");
+ return ret;
+}
+
+int SSLClientSocketNSS::GetPeerAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetPeerAddress(address);
+}
+
+int SSLClientSocketNSS::GetLocalAddress(IPEndPoint* address) const {
+ return transport_->socket()->GetLocalAddress(address);
+}
+
+const BoundNetLog& SSLClientSocketNSS::NetLog() const {
+ return net_log_;
+}
+
+void SSLClientSocketNSS::SetSubresourceSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetSubresourceSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+void SSLClientSocketNSS::SetOmniboxSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetOmniboxSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+bool SSLClientSocketNSS::WasEverUsed() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->WasEverUsed();
+ }
+ NOTREACHED();
+ return false;
+}
+
+bool SSLClientSocketNSS::UsingTCPFastOpen() const {
+ if (transport_.get() && transport_->socket()) {
+ return transport_->socket()->UsingTCPFastOpen();
+ }
+ NOTREACHED();
+ return false;
+}
+
+int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(core_.get());
+ DCHECK(!callback.is_null());
+
+ EnterFunction(buf_len);
+ int rv = core_->Read(buf, buf_len, callback);
+ LeaveFunction(rv);
+
+ return rv;
+}
+
+int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(core_.get());
+ DCHECK(!callback.is_null());
+
+ EnterFunction(buf_len);
+ int rv = core_->Write(buf, buf_len, callback);
+ LeaveFunction(rv);
+
+ return rv;
+}
+
+bool SSLClientSocketNSS::SetReceiveBufferSize(int32 size) {
+ return transport_->socket()->SetReceiveBufferSize(size);
+}
+
+bool SSLClientSocketNSS::SetSendBufferSize(int32 size) {
+ return transport_->socket()->SetSendBufferSize(size);
+}
+
+int SSLClientSocketNSS::Init() {
+ EnterFunction("");
+ // Initialize the NSS SSL library in a threadsafe way. This also
+ // initializes the NSS base library.
+ EnsureNSSSSLInit();
+ if (!NSS_IsInitialized())
+ return ERR_UNEXPECTED;
+#if defined(USE_NSS) || defined(OS_IOS)
+ if (ssl_config_.cert_io_enabled) {
+ // We must call EnsureNSSHttpIOInit() here, on the IO thread, to get the IO
+ // loop by MessageLoopForIO::current().
+ // X509Certificate::Verify() runs on a worker thread of CertVerifier.
+ EnsureNSSHttpIOInit();
+ }
+#endif
+
+ LeaveFunction("");
+ return OK;
+}
+
+void SSLClientSocketNSS::InitCore() {
+ core_ = new Core(base::ThreadTaskRunnerHandle::Get().get(),
+ nss_task_runner_.get(),
+ transport_.get(),
+ host_and_port_,
+ ssl_config_,
+ &net_log_,
+ server_bound_cert_service_);
+}
+
+int SSLClientSocketNSS::InitializeSSLOptions() {
+ // Transport connected, now hook it up to nss
+ nss_fd_ = memio_CreateIOLayer(kRecvBufferSize, kSendBufferSize);
+ if (nss_fd_ == NULL) {
+ return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR error code.
+ }
+
+ // Grab pointer to buffers
+ memio_Private* nss_bufs = memio_GetSecret(nss_fd_);
+
+ /* Create SSL state machine */
+ /* Push SSL onto our fake I/O socket */
+ nss_fd_ = SSL_ImportFD(NULL, nss_fd_);
+ if (nss_fd_ == NULL) {
+ LogFailedNSSFunction(net_log_, "SSL_ImportFD", "");
+ return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code.
+ }
+ // TODO(port): set more ssl options! Check errors!
+
+ int rv;
+
+ rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2");
+ return ERR_UNEXPECTED;
+ }
+
+ // Don't do V2 compatible hellos because they don't support TLS extensions.
+ rv = SSL_OptionSet(nss_fd_, SSL_V2_COMPATIBLE_HELLO, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_V2_COMPATIBLE_HELLO");
+ return ERR_UNEXPECTED;
+ }
+
+ SSLVersionRange version_range;
+ version_range.min = ssl_config_.version_min;
+ version_range.max = ssl_config_.version_max;
+ rv = SSL_VersionRangeSet(nss_fd_, &version_range);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_VersionRangeSet", "");
+ return ERR_NO_SSL_VERSIONS_ENABLED;
+ }
+
+ for (std::vector<uint16>::const_iterator it =
+ ssl_config_.disabled_cipher_suites.begin();
+ it != ssl_config_.disabled_cipher_suites.end(); ++it) {
+ // This will fail if the specified cipher is not implemented by NSS, but
+ // the failure is harmless.
+ SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE);
+ }
+
+ // Support RFC 5077
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(
+ net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS");
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START,
+ ssl_config_.false_start_enabled);
+ if (rv != SECSuccess)
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_FALSE_START");
+
+ // We allow servers to request renegotiation. Since we're a client,
+ // prohibiting this is rather a waste of time. Only servers are in a
+ // position to prevent renegotiation attacks.
+ // http://extendedsubset.com/?p=8
+
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_RENEGOTIATION,
+ SSL_RENEGOTIATE_TRANSITIONAL);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(
+ net_log_, "SSL_OptionSet", "SSL_ENABLE_RENEGOTIATION");
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_CBC_RANDOM_IV, PR_TRUE);
+ if (rv != SECSuccess)
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_CBC_RANDOM_IV");
+
+// Added in NSS 3.15
+#ifdef SSL_ENABLE_OCSP_STAPLING
+ if (IsOCSPStaplingSupported()) {
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet",
+ "SSL_ENABLE_OCSP_STAPLING");
+ }
+ }
+#endif
+
+// Chromium patch to libssl
+#ifdef SSL_ENABLE_CACHED_INFO
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_CACHED_INFO,
+ ssl_config_.cached_info_enabled);
+ if (rv != SECSuccess)
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_CACHED_INFO");
+#endif
+
+ rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT");
+ return ERR_UNEXPECTED;
+ }
+
+ if (!core_->Init(nss_fd_, nss_bufs))
+ return ERR_UNEXPECTED;
+
+ // Tell SSL the hostname we're trying to connect to.
+ SSL_SetURL(nss_fd_, host_and_port_.host().c_str());
+
+ // Tell SSL we're a client; needed if not letting NSPR do socket I/O
+ SSL_ResetHandshake(nss_fd_, PR_FALSE);
+
+ return OK;
+}
+
+int SSLClientSocketNSS::InitializeSSLPeerName() {
+ // Tell NSS who we're connected to
+ IPEndPoint peer_address;
+ int err = transport_->socket()->GetPeerAddress(&peer_address);
+ if (err != OK)
+ return err;
+
+ SockaddrStorage storage;
+ if (!peer_address.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_UNEXPECTED;
+
+ PRNetAddr peername;
+ memset(&peername, 0, sizeof(peername));
+ DCHECK_LE(static_cast<size_t>(storage.addr_len), sizeof(peername));
+ size_t len = std::min(static_cast<size_t>(storage.addr_len),
+ sizeof(peername));
+ memcpy(&peername, storage.addr, len);
+
+ // Adjust the address family field for BSD, whose sockaddr
+ // structure has a one-byte length and one-byte address family
+ // field at the beginning. PRNetAddr has a two-byte address
+ // family field at the beginning.
+ peername.raw.family = storage.addr->sa_family;
+
+ memio_SetPeerName(nss_fd_, &peername);
+
+ // Set the peer ID for session reuse. This is necessary when we create an
+ // SSL tunnel through a proxy -- GetPeerName returns the proxy's address
+ // rather than the destination server's address in that case.
+ std::string peer_id = host_and_port_.ToString();
+ // If the ssl_session_cache_shard_ is non-empty, we append it to the peer id.
+ // This will cause session cache misses between sockets with different values
+ // of ssl_session_cache_shard_ and this is used to partition the session cache
+ // for incognito mode.
+ if (!ssl_session_cache_shard_.empty()) {
+ peer_id += "/" + ssl_session_cache_shard_;
+ }
+ SECStatus rv = SSL_SetSockPeerID(nss_fd_, const_cast<char*>(peer_id.c_str()));
+ if (rv != SECSuccess)
+ LogFailedNSSFunction(net_log_, "SSL_SetSockPeerID", peer_id.c_str());
+
+ return OK;
+}
+
+void SSLClientSocketNSS::DoConnectCallback(int rv) {
+ EnterFunction(rv);
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ DCHECK(!user_connect_callback_.is_null());
+
+ base::ResetAndReturn(&user_connect_callback_).Run(rv > OK ? OK : rv);
+ LeaveFunction("");
+}
+
+void SSLClientSocketNSS::OnHandshakeIOComplete(int result) {
+ EnterFunction(result);
+ int rv = DoHandshakeLoop(result);
+ if (rv != ERR_IO_PENDING) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ DoConnectCallback(rv);
+ }
+ LeaveFunction("");
+}
+
+int SSLClientSocketNSS::DoHandshakeLoop(int last_io_result) {
+ EnterFunction(last_io_result);
+ int rv = last_io_result;
+ do {
+ // Default to STATE_NONE for next state.
+ // (This is a quirk carried over from the windows
+ // implementation. It makes reading the logs a bit harder.)
+ // State handlers can and often do call GotoState just
+ // to stay in the current state.
+ State state = next_handshake_state_;
+ GotoState(STATE_NONE);
+ switch (state) {
+ case STATE_HANDSHAKE:
+ rv = DoHandshake();
+ break;
+ case STATE_HANDSHAKE_COMPLETE:
+ rv = DoHandshakeComplete(rv);
+ break;
+ case STATE_VERIFY_CERT:
+ DCHECK(rv == OK);
+ rv = DoVerifyCert(rv);
+ break;
+ case STATE_VERIFY_CERT_COMPLETE:
+ rv = DoVerifyCertComplete(rv);
+ break;
+ case STATE_NONE:
+ default:
+ rv = ERR_UNEXPECTED;
+ LOG(DFATAL) << "unexpected state " << state;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
+ LeaveFunction("");
+ return rv;
+}
+
+int SSLClientSocketNSS::DoHandshake() {
+ EnterFunction("");
+ int rv = core_->Connect(
+ base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete,
+ base::Unretained(this)));
+ GotoState(STATE_HANDSHAKE_COMPLETE);
+
+ LeaveFunction(rv);
+ return rv;
+}
+
+int SSLClientSocketNSS::DoHandshakeComplete(int result) {
+ EnterFunction(result);
+
+ if (result == OK) {
+ // SSL handshake is completed. Let's verify the certificate.
+ GotoState(STATE_VERIFY_CERT);
+ // Done!
+ }
+ set_channel_id_sent(core_->state().channel_id_sent);
+
+ LeaveFunction(result);
+ return result;
+}
+
+
+int SSLClientSocketNSS::DoVerifyCert(int result) {
+ DCHECK(!core_->state().server_cert_chain.empty());
+ DCHECK(core_->state().server_cert_chain[0]);
+
+ GotoState(STATE_VERIFY_CERT_COMPLETE);
+
+ // If the certificate is expected to be bad we can use the expectation as
+ // the cert status.
+ base::StringPiece der_cert(
+ reinterpret_cast<char*>(
+ core_->state().server_cert_chain[0]->derCert.data),
+ core_->state().server_cert_chain[0]->derCert.len);
+ CertStatus cert_status;
+ if (ssl_config_.IsAllowedBadCert(der_cert, &cert_status)) {
+ DCHECK(start_cert_verification_time_.is_null());
+ VLOG(1) << "Received an expected bad cert with status: " << cert_status;
+ server_cert_verify_result_.Reset();
+ server_cert_verify_result_.cert_status = cert_status;
+ server_cert_verify_result_.verified_cert = core_->state().server_cert;
+ return OK;
+ }
+
+ // We may have failed to create X509Certificate object if we are
+ // running inside sandbox.
+ if (!core_->state().server_cert.get()) {
+ server_cert_verify_result_.Reset();
+ server_cert_verify_result_.cert_status = CERT_STATUS_INVALID;
+ return ERR_CERT_INVALID;
+ }
+
+ start_cert_verification_time_ = base::TimeTicks::Now();
+
+ int flags = 0;
+ if (ssl_config_.rev_checking_enabled)
+ flags |= CertVerifier::VERIFY_REV_CHECKING_ENABLED;
+ if (ssl_config_.verify_ev_cert)
+ flags |= CertVerifier::VERIFY_EV_CERT;
+ if (ssl_config_.cert_io_enabled)
+ flags |= CertVerifier::VERIFY_CERT_IO_ENABLED;
+ if (ssl_config_.rev_checking_required_local_anchors)
+ flags |= CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS;
+ verifier_.reset(new SingleRequestCertVerifier(cert_verifier_));
+ return verifier_->Verify(
+ core_->state().server_cert.get(),
+ host_and_port_.host(),
+ flags,
+ SSLConfigService::GetCRLSet().get(),
+ &server_cert_verify_result_,
+ base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete,
+ base::Unretained(this)),
+ net_log_);
+}
+
+// Derived from AuthCertificateCallback() in
+// mozilla/source/security/manager/ssl/src/nsNSSCallbacks.cpp.
+int SSLClientSocketNSS::DoVerifyCertComplete(int result) {
+ verifier_.reset();
+
+ if (!start_cert_verification_time_.is_null()) {
+ base::TimeDelta verify_time =
+ base::TimeTicks::Now() - start_cert_verification_time_;
+ if (result == OK)
+ UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTime", verify_time);
+ else
+ UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTimeError", verify_time);
+ }
+
+ // We used to remember the intermediate CA certs in the NSS database
+ // persistently. However, NSS opens a connection to the SQLite database
+ // during NSS initialization and doesn't close the connection until NSS
+ // shuts down. If the file system where the database resides is gone,
+ // the database connection goes bad. What's worse, the connection won't
+ // recover when the file system comes back. Until this NSS or SQLite bug
+ // is fixed, we need to avoid using the NSS database for non-essential
+ // purposes. See https://bugzilla.mozilla.org/show_bug.cgi?id=508081 and
+ // http://crbug.com/15630 for more info.
+
+ // TODO(hclam): Skip logging if server cert was expected to be bad because
+ // |server_cert_verify_result_| doesn't contain all the information about
+ // the cert.
+ if (result == OK)
+ LogConnectionTypeMetrics();
+
+ completed_handshake_ = true;
+
+#if defined(OFFICIAL_BUILD) && !defined(OS_ANDROID) && !defined(OS_IOS)
+ // Take care of any mandates for public key pinning.
+ //
+ // Pinning is only enabled for official builds to make sure that others don't
+ // end up with pins that cannot be easily updated.
+ //
+ // TODO(agl): We might have an issue here where a request for foo.example.com
+ // merges into a SPDY connection to www.example.com, and gets a different
+ // certificate.
+
+ // Perform pin validation if, and only if, all these conditions obtain:
+ //
+ // * a TransportSecurityState object is available;
+ // * the server's certificate chain is valid (or suffers from only a minor
+ // error);
+ // * the server's certificate chain chains up to a known root (i.e. not a
+ // user-installed trust anchor); and
+ // * the build is recent (very old builds should fail open so that users
+ // have some chance to recover).
+ //
+ const CertStatus cert_status = server_cert_verify_result_.cert_status;
+ if (transport_security_state_ &&
+ (result == OK ||
+ (IsCertificateError(result) && IsCertStatusMinorError(cert_status))) &&
+ server_cert_verify_result_.is_issued_by_known_root &&
+ TransportSecurityState::IsBuildTimely()) {
+ bool sni_available =
+ ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1 ||
+ ssl_config_.version_fallback;
+ const std::string& host = host_and_port_.host();
+
+ TransportSecurityState::DomainState domain_state;
+ if (transport_security_state_->GetDomainState(host, sni_available,
+ &domain_state) &&
+ domain_state.HasPublicKeyPins()) {
+ if (!domain_state.CheckPublicKeyPins(
+ server_cert_verify_result_.public_key_hashes)) {
+ result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN;
+ UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", false);
+ TransportSecurityState::ReportUMAOnPinFailure(host);
+ } else {
+ UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", true);
+ }
+ }
+ }
+#endif
+
+ // Exit DoHandshakeLoop and return the result to the caller to Connect.
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+ return result;
+}
+
+void SSLClientSocketNSS::LogConnectionTypeMetrics() const {
+ UpdateConnectionTypeHistograms(CONNECTION_SSL);
+ int ssl_version = SSLConnectionStatusToVersion(
+ core_->state().ssl_connection_status);
+ switch (ssl_version) {
+ case SSL_CONNECTION_VERSION_SSL2:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2);
+ break;
+ case SSL_CONNECTION_VERSION_SSL3:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3);
+ break;
+ case SSL_CONNECTION_VERSION_TLS1:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1);
+ break;
+ case SSL_CONNECTION_VERSION_TLS1_1:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1);
+ break;
+ case SSL_CONNECTION_VERSION_TLS1_2:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2);
+ break;
+ };
+}
+
+void SSLClientSocketNSS::EnsureThreadIdAssigned() const {
+ base::AutoLock auto_lock(lock_);
+ if (valid_thread_id_ != base::kInvalidThreadId)
+ return;
+ valid_thread_id_ = base::PlatformThread::CurrentId();
+}
+
+bool SSLClientSocketNSS::CalledOnValidThread() const {
+ EnsureThreadIdAssigned();
+ base::AutoLock auto_lock(lock_);
+ return valid_thread_id_ == base::PlatformThread::CurrentId();
+}
+
+ServerBoundCertService* SSLClientSocketNSS::GetServerBoundCertService() const {
+ return server_bound_cert_service_;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h
new file mode 100644
index 00000000000..b41d28d74a8
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_nss.h
@@ -0,0 +1,196 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_NSS_H_
+#define NET_SOCKET_SSL_CLIENT_SOCKET_NSS_H_
+
+#include <certt.h>
+#include <keyt.h>
+#include <nspr.h>
+#include <nss.h>
+
+#include <string>
+#include <vector>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/synchronization/lock.h"
+#include "base/threading/platform_thread.h"
+#include "base/time/time.h"
+#include "base/timer/timer.h"
+#include "net/base/completion_callback.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/base/nss_memio.h"
+#include "net/cert/cert_verify_result.h"
+#include "net/cert/x509_certificate.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/ssl/server_bound_cert_service.h"
+#include "net/ssl/ssl_config_service.h"
+
+namespace base {
+class SequencedTaskRunner;
+}
+
+namespace net {
+
+class BoundNetLog;
+class CertVerifier;
+class ClientSocketHandle;
+class ServerBoundCertService;
+class SingleRequestCertVerifier;
+class TransportSecurityState;
+class X509Certificate;
+
+// An SSL client socket implemented with Mozilla NSS.
+class SSLClientSocketNSS : public SSLClientSocket {
+ public:
+ // Takes ownership of the |transport_socket|, which must already be connected.
+ // The hostname specified in |host_and_port| will be compared with the name(s)
+ // in the server's certificate during the SSL handshake. If SSL client
+ // authentication is requested, the host_and_port field of SSLCertRequestInfo
+ // will be populated with |host_and_port|. |ssl_config| specifies
+ // the SSL settings.
+ //
+ // Because calls to NSS may block, such as due to needing to access slow
+ // hardware or needing to synchronously unlock protected tokens, calls to
+ // NSS may optionally be run on a dedicated thread. If synchronous/blocking
+ // behaviour is desired, for performance or compatibility, the current task
+ // runner should be supplied instead.
+ SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner,
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context);
+ virtual ~SSLClientSocketNSS();
+
+ // SSLClientSocket implementation.
+ virtual void GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) OVERRIDE;
+ virtual NextProtoStatus GetNextProto(std::string* proto,
+ std::string* server_protos) OVERRIDE;
+
+ // SSLSocket implementation.
+ virtual int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) OVERRIDE;
+ virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+ virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+
+ private:
+ // Helper class to handle marshalling any NSS interaction to and from the
+ // NSS and network task runners. Not every call needs to happen on the Core
+ class Core;
+
+ enum State {
+ STATE_NONE,
+ STATE_HANDSHAKE,
+ STATE_HANDSHAKE_COMPLETE,
+ STATE_VERIFY_CERT,
+ STATE_VERIFY_CERT_COMPLETE,
+ };
+
+ int Init();
+ void InitCore();
+
+ // Initializes NSS SSL options. Returns a net error code.
+ int InitializeSSLOptions();
+
+ // Initializes the socket peer name in SSL. Returns a net error code.
+ int InitializeSSLPeerName();
+
+ void DoConnectCallback(int result);
+ void OnHandshakeIOComplete(int result);
+
+ int DoHandshakeLoop(int last_io_result);
+ int DoHandshake();
+ int DoHandshakeComplete(int result);
+ int DoVerifyCert(int result);
+ int DoVerifyCertComplete(int result);
+
+ void LogConnectionTypeMetrics() const;
+
+ // The following methods are for debugging bug 65948. Will remove this code
+ // after fixing bug 65948.
+ void EnsureThreadIdAssigned() const;
+ bool CalledOnValidThread() const;
+
+ // The task runner used to perform NSS operations.
+ scoped_refptr<base::SequencedTaskRunner> nss_task_runner_;
+ scoped_ptr<ClientSocketHandle> transport_;
+ HostPortPair host_and_port_;
+ SSLConfig ssl_config_;
+
+ scoped_refptr<Core> core_;
+
+ CompletionCallback user_connect_callback_;
+
+ CertVerifyResult server_cert_verify_result_;
+ HashValueVector side_pinned_public_keys_;
+
+ CertVerifier* const cert_verifier_;
+ scoped_ptr<SingleRequestCertVerifier> verifier_;
+
+ // The service for retrieving Channel ID keys. May be NULL.
+ ServerBoundCertService* server_bound_cert_service_;
+
+ // ssl_session_cache_shard_ is an opaque string that partitions the SSL
+ // session cache. i.e. sessions created with one value will not attempt to
+ // resume on the socket with a different value.
+ const std::string ssl_session_cache_shard_;
+
+ // True if the SSL handshake has been completed.
+ bool completed_handshake_;
+
+ State next_handshake_state_;
+
+ // The NSS SSL state machine. This is owned by |core_|.
+ // TODO(rsleevi): http://crbug.com/130616 - Remove this member once
+ // ExportKeyingMaterial is updated to be asynchronous.
+ PRFileDesc* nss_fd_;
+
+ BoundNetLog net_log_;
+
+ base::TimeTicks start_cert_verification_time_;
+
+ TransportSecurityState* transport_security_state_;
+
+ // The following two variables are added for debugging bug 65948. Will
+ // remove this code after fixing bug 65948.
+ // Added the following code Debugging in release mode.
+ mutable base::Lock lock_;
+ // This is mutable so that CalledOnValidThread can set it.
+ // It's guarded by |lock_|.
+ mutable base::PlatformThreadId valid_thread_id_;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_CLIENT_SOCKET_NSS_H_
diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc
new file mode 100644
index 00000000000..4591cec5b9d
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_openssl.cc
@@ -0,0 +1,1435 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// OpenSSL binding for SSLClientSocket. The class layout and general principle
+// of operation is derived from SSLClientSocketNSS.
+
+#include "net/socket/ssl_client_socket_openssl.h"
+
+#include <openssl/err.h>
+#include <openssl/opensslv.h>
+#include <openssl/ssl.h>
+
+#include "base/bind.h"
+#include "base/callback_helpers.h"
+#include "base/memory/singleton.h"
+#include "base/metrics/histogram.h"
+#include "base/synchronization/lock.h"
+#include "crypto/openssl_util.h"
+#include "net/base/net_errors.h"
+#include "net/cert/cert_verifier.h"
+#include "net/cert/single_request_cert_verifier.h"
+#include "net/cert/x509_certificate_net_log_param.h"
+#include "net/socket/ssl_error_params.h"
+#include "net/ssl/openssl_client_key_store.h"
+#include "net/ssl/ssl_cert_request_info.h"
+#include "net/ssl/ssl_connection_status_flags.h"
+#include "net/ssl/ssl_info.h"
+
+namespace net {
+
+namespace {
+
+// Enable this to see logging for state machine state transitions.
+#if 0
+#define GotoState(s) do { DVLOG(2) << (void *)this << " " << __FUNCTION__ << \
+ " jump to state " << s; \
+ next_handshake_state_ = s; } while (0)
+#else
+#define GotoState(s) next_handshake_state_ = s
+#endif
+
+const int kSessionCacheTimeoutSeconds = 60 * 60;
+const size_t kSessionCacheMaxEntires = 1024;
+
+// This constant can be any non-negative/non-zero value (eg: it does not
+// overlap with any value of the net::Error range, including net::OK).
+const int kNoPendingReadResult = 1;
+
+// If a client doesn't have a list of protocols that it supports, but
+// the server supports NPN, choosing "http/1.1" is the best answer.
+const char kDefaultSupportedNPNProtocol[] = "http/1.1";
+
+#if OPENSSL_VERSION_NUMBER < 0x1000103fL
+// This method doesn't seem to have made it into the OpenSSL headers.
+unsigned long SSL_CIPHER_get_id(const SSL_CIPHER* cipher) { return cipher->id; }
+#endif
+
+// Used for encoding the |connection_status| field of an SSLInfo object.
+int EncodeSSLConnectionStatus(int cipher_suite,
+ int compression,
+ int version) {
+ return ((cipher_suite & SSL_CONNECTION_CIPHERSUITE_MASK) <<
+ SSL_CONNECTION_CIPHERSUITE_SHIFT) |
+ ((compression & SSL_CONNECTION_COMPRESSION_MASK) <<
+ SSL_CONNECTION_COMPRESSION_SHIFT) |
+ ((version & SSL_CONNECTION_VERSION_MASK) <<
+ SSL_CONNECTION_VERSION_SHIFT);
+}
+
+// Returns the net SSL version number (see ssl_connection_status_flags.h) for
+// this SSL connection.
+int GetNetSSLVersion(SSL* ssl) {
+ switch (SSL_version(ssl)) {
+ case SSL2_VERSION:
+ return SSL_CONNECTION_VERSION_SSL2;
+ case SSL3_VERSION:
+ return SSL_CONNECTION_VERSION_SSL3;
+ case TLS1_VERSION:
+ return SSL_CONNECTION_VERSION_TLS1;
+ case 0x0302:
+ return SSL_CONNECTION_VERSION_TLS1_1;
+ case 0x0303:
+ return SSL_CONNECTION_VERSION_TLS1_2;
+ default:
+ return SSL_CONNECTION_VERSION_UNKNOWN;
+ }
+}
+
+int MapOpenSSLErrorSSL() {
+ // Walk down the error stack to find the SSLerr generated reason.
+ unsigned long error_code;
+ do {
+ error_code = ERR_get_error();
+ if (error_code == 0)
+ return ERR_SSL_PROTOCOL_ERROR;
+ } while (ERR_GET_LIB(error_code) != ERR_LIB_SSL);
+
+ DVLOG(1) << "OpenSSL SSL error, reason: " << ERR_GET_REASON(error_code)
+ << ", name: " << ERR_error_string(error_code, NULL);
+ switch (ERR_GET_REASON(error_code)) {
+ case SSL_R_READ_TIMEOUT_EXPIRED:
+ return ERR_TIMED_OUT;
+ case SSL_R_BAD_RESPONSE_ARGUMENT:
+ return ERR_INVALID_ARGUMENT;
+ case SSL_R_UNKNOWN_CERTIFICATE_TYPE:
+ case SSL_R_UNKNOWN_CIPHER_TYPE:
+ case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE:
+ case SSL_R_UNKNOWN_PKEY_TYPE:
+ case SSL_R_UNKNOWN_REMOTE_ERROR_TYPE:
+ case SSL_R_UNKNOWN_SSL_VERSION:
+ return ERR_NOT_IMPLEMENTED;
+ case SSL_R_UNSUPPORTED_SSL_VERSION:
+ case SSL_R_NO_CIPHER_MATCH:
+ case SSL_R_NO_SHARED_CIPHER:
+ case SSL_R_TLSV1_ALERT_INSUFFICIENT_SECURITY:
+ case SSL_R_TLSV1_ALERT_PROTOCOL_VERSION:
+ case SSL_R_UNSUPPORTED_PROTOCOL:
+ return ERR_SSL_VERSION_OR_CIPHER_MISMATCH;
+ case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE:
+ case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE:
+ case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED:
+ case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED:
+ case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN:
+ case SSL_R_TLSV1_ALERT_ACCESS_DENIED:
+ case SSL_R_TLSV1_ALERT_UNKNOWN_CA:
+ return ERR_BAD_SSL_CLIENT_AUTH_CERT;
+ case SSL_R_BAD_DECOMPRESSION:
+ case SSL_R_SSLV3_ALERT_DECOMPRESSION_FAILURE:
+ return ERR_SSL_DECOMPRESSION_FAILURE_ALERT;
+ case SSL_R_SSLV3_ALERT_BAD_RECORD_MAC:
+ return ERR_SSL_BAD_RECORD_MAC_ALERT;
+ case SSL_R_TLSV1_ALERT_DECRYPT_ERROR:
+ return ERR_SSL_DECRYPT_ERROR_ALERT;
+ case SSL_R_UNSAFE_LEGACY_RENEGOTIATION_DISABLED:
+ return ERR_SSL_UNSAFE_NEGOTIATION;
+ case SSL_R_WRONG_NUMBER_OF_KEY_BITS:
+ return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY;
+ // SSL_R_UNKNOWN_PROTOCOL is reported if premature application data is
+ // received (see http://crbug.com/42538), and also if all the protocol
+ // versions supported by the server were disabled in this socket instance.
+ // Mapped to ERR_SSL_PROTOCOL_ERROR for compatibility with other SSL sockets
+ // in the former scenario.
+ case SSL_R_UNKNOWN_PROTOCOL:
+ case SSL_R_SSL_HANDSHAKE_FAILURE:
+ case SSL_R_DECRYPTION_FAILED:
+ case SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC:
+ case SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG:
+ case SSL_R_DIGEST_CHECK_FAILED:
+ case SSL_R_DUPLICATE_COMPRESSION_ID:
+ case SSL_R_ECGROUP_TOO_LARGE_FOR_CIPHER:
+ case SSL_R_ENCRYPTED_LENGTH_TOO_LONG:
+ case SSL_R_ERROR_IN_RECEIVED_CIPHER_LIST:
+ case SSL_R_EXCESSIVE_MESSAGE_SIZE:
+ case SSL_R_EXTRA_DATA_IN_MESSAGE:
+ case SSL_R_GOT_A_FIN_BEFORE_A_CCS:
+ case SSL_R_ILLEGAL_PADDING:
+ case SSL_R_INVALID_CHALLENGE_LENGTH:
+ case SSL_R_INVALID_COMMAND:
+ case SSL_R_INVALID_PURPOSE:
+ case SSL_R_INVALID_STATUS_RESPONSE:
+ case SSL_R_INVALID_TICKET_KEYS_LENGTH:
+ case SSL_R_KEY_ARG_TOO_LONG:
+ case SSL_R_READ_WRONG_PACKET_TYPE:
+ case SSL_R_SSLV3_ALERT_UNEXPECTED_MESSAGE:
+ // TODO(joth): SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE may be returned from the
+ // server after receiving ClientHello if there's no common supported cipher.
+ // Ideally we'd map that specific case to ERR_SSL_VERSION_OR_CIPHER_MISMATCH
+ // to match the NSS implementation. See also http://goo.gl/oMtZW
+ case SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE:
+ case SSL_R_SSLV3_ALERT_NO_CERTIFICATE:
+ case SSL_R_SSLV3_ALERT_ILLEGAL_PARAMETER:
+ case SSL_R_TLSV1_ALERT_DECODE_ERROR:
+ case SSL_R_TLSV1_ALERT_DECRYPTION_FAILED:
+ case SSL_R_TLSV1_ALERT_EXPORT_RESTRICTION:
+ case SSL_R_TLSV1_ALERT_INTERNAL_ERROR:
+ case SSL_R_TLSV1_ALERT_NO_RENEGOTIATION:
+ case SSL_R_TLSV1_ALERT_RECORD_OVERFLOW:
+ case SSL_R_TLSV1_ALERT_USER_CANCELLED:
+ return ERR_SSL_PROTOCOL_ERROR;
+ default:
+ LOG(WARNING) << "Unmapped error reason: " << ERR_GET_REASON(error_code);
+ return ERR_FAILED;
+ }
+}
+
+// Converts an OpenSSL error code into a net error code, walking the OpenSSL
+// error stack if needed. Note that |tracer| is not currently used in the
+// implementation, but is passed in anyway as this ensures the caller will clear
+// any residual codes left on the error stack.
+int MapOpenSSLError(int err, const crypto::OpenSSLErrStackTracer& tracer) {
+ switch (err) {
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ return ERR_IO_PENDING;
+ case SSL_ERROR_SYSCALL:
+ LOG(ERROR) << "OpenSSL SYSCALL error, earliest error code in "
+ "error queue: " << ERR_peek_error() << ", errno: "
+ << errno;
+ return ERR_SSL_PROTOCOL_ERROR;
+ case SSL_ERROR_SSL:
+ return MapOpenSSLErrorSSL();
+ default:
+ // TODO(joth): Implement full mapping.
+ LOG(WARNING) << "Unknown OpenSSL error " << err;
+ return ERR_SSL_PROTOCOL_ERROR;
+ }
+}
+
+// We do certificate verification after handshake, so we disable the default
+// by registering a no-op verify function.
+int NoOpVerifyCallback(X509_STORE_CTX*, void *) {
+ DVLOG(3) << "skipping cert verify";
+ return 1;
+}
+
+// OpenSSL manages a cache of SSL_SESSION, this class provides the application
+// side policy for that cache about session re-use: we retain one session per
+// unique HostPortPair, per shard.
+class SSLSessionCache {
+ public:
+ SSLSessionCache() {}
+
+ void OnSessionAdded(const HostPortPair& host_and_port,
+ const std::string& shard,
+ SSL_SESSION* session) {
+ // Declare the session cleaner-upper before the lock, so any call into
+ // OpenSSL to free the session will happen after the lock is released.
+ crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free;
+ base::AutoLock lock(lock_);
+
+ DCHECK_EQ(0U, session_map_.count(session));
+ const std::string cache_key = GetCacheKey(host_and_port, shard);
+
+ std::pair<HostPortMap::iterator, bool> res =
+ host_port_map_.insert(std::make_pair(cache_key, session));
+ if (!res.second) { // Already exists: replace old entry.
+ session_to_free.reset(res.first->second);
+ session_map_.erase(session_to_free.get());
+ res.first->second = session;
+ }
+ DVLOG(2) << "Adding session " << session << " => "
+ << cache_key << ", new entry = " << res.second;
+ DCHECK(host_port_map_[cache_key] == session);
+ session_map_[session] = res.first;
+ DCHECK_EQ(host_port_map_.size(), session_map_.size());
+ DCHECK_LE(host_port_map_.size(), kSessionCacheMaxEntires);
+ }
+
+ void OnSessionRemoved(SSL_SESSION* session) {
+ // Declare the session cleaner-upper before the lock, so any call into
+ // OpenSSL to free the session will happen after the lock is released.
+ crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free;
+ base::AutoLock lock(lock_);
+
+ SessionMap::iterator it = session_map_.find(session);
+ if (it == session_map_.end())
+ return;
+ DVLOG(2) << "Remove session " << session << " => " << it->second->first;
+ DCHECK(it->second->second == session);
+ host_port_map_.erase(it->second);
+ session_map_.erase(it);
+ session_to_free.reset(session);
+ DCHECK_EQ(host_port_map_.size(), session_map_.size());
+ }
+
+ // Looks up the host:port in the cache, and if a session is found it is added
+ // to |ssl|, returning true on success.
+ bool SetSSLSession(SSL* ssl, const HostPortPair& host_and_port,
+ const std::string& shard) {
+ base::AutoLock lock(lock_);
+ const std::string cache_key = GetCacheKey(host_and_port, shard);
+ HostPortMap::iterator it = host_port_map_.find(cache_key);
+ if (it == host_port_map_.end())
+ return false;
+ DVLOG(2) << "Lookup session: " << it->second << " => " << cache_key;
+ SSL_SESSION* session = it->second;
+ DCHECK(session);
+ DCHECK(session_map_[session] == it);
+ // Ideally we'd release |lock_| before calling into OpenSSL here, however
+ // that opens a small risk |session| will go out of scope before it is used.
+ // Alternatively we would take a temporary local refcount on |session|,
+ // except OpenSSL does not provide a public API for adding a ref (c.f.
+ // SSL_SESSION_free which decrements the ref).
+ return SSL_set_session(ssl, session) == 1;
+ }
+
+ // Flush removes all entries from the cache. This is called when a client
+ // certificate is added.
+ void Flush() {
+ for (HostPortMap::iterator i = host_port_map_.begin();
+ i != host_port_map_.end(); i++) {
+ SSL_SESSION_free(i->second);
+ }
+ host_port_map_.clear();
+ session_map_.clear();
+ }
+
+ private:
+ static std::string GetCacheKey(const HostPortPair& host_and_port,
+ const std::string& shard) {
+ return host_and_port.ToString() + "/" + shard;
+ }
+
+ // A pair of maps to allow bi-directional lookups between host:port and an
+ // associated session.
+ typedef std::map<std::string, SSL_SESSION*> HostPortMap;
+ typedef std::map<SSL_SESSION*, HostPortMap::iterator> SessionMap;
+ HostPortMap host_port_map_;
+ SessionMap session_map_;
+
+ // Protects access to both the above maps.
+ base::Lock lock_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLSessionCache);
+};
+
+class SSLContext {
+ public:
+ static SSLContext* GetInstance() { return Singleton<SSLContext>::get(); }
+ SSL_CTX* ssl_ctx() { return ssl_ctx_.get(); }
+ SSLSessionCache* session_cache() { return &session_cache_; }
+
+ SSLClientSocketOpenSSL* GetClientSocketFromSSL(SSL* ssl) {
+ DCHECK(ssl);
+ SSLClientSocketOpenSSL* socket = static_cast<SSLClientSocketOpenSSL*>(
+ SSL_get_ex_data(ssl, ssl_socket_data_index_));
+ DCHECK(socket);
+ return socket;
+ }
+
+ bool SetClientSocketForSSL(SSL* ssl, SSLClientSocketOpenSSL* socket) {
+ return SSL_set_ex_data(ssl, ssl_socket_data_index_, socket) != 0;
+ }
+
+ private:
+ friend struct DefaultSingletonTraits<SSLContext>;
+
+ SSLContext() {
+ crypto::EnsureOpenSSLInit();
+ ssl_socket_data_index_ = SSL_get_ex_new_index(0, 0, 0, 0, 0);
+ DCHECK_NE(ssl_socket_data_index_, -1);
+ ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method()));
+ SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), NoOpVerifyCallback, NULL);
+ SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_CLIENT);
+ SSL_CTX_sess_set_new_cb(ssl_ctx_.get(), NewSessionCallbackStatic);
+ SSL_CTX_sess_set_remove_cb(ssl_ctx_.get(), RemoveSessionCallbackStatic);
+ SSL_CTX_set_timeout(ssl_ctx_.get(), kSessionCacheTimeoutSeconds);
+ SSL_CTX_sess_set_cache_size(ssl_ctx_.get(), kSessionCacheMaxEntires);
+ SSL_CTX_set_client_cert_cb(ssl_ctx_.get(), ClientCertCallback);
+#if defined(OPENSSL_NPN_NEGOTIATED)
+ // TODO(kristianm): Only select this if ssl_config_.next_proto is not empty.
+ // It would be better if the callback were not a global setting,
+ // but that is an OpenSSL issue.
+ SSL_CTX_set_next_proto_select_cb(ssl_ctx_.get(), SelectNextProtoCallback,
+ NULL);
+#endif
+ }
+
+ static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) {
+ return GetInstance()->NewSessionCallback(ssl, session);
+ }
+
+ int NewSessionCallback(SSL* ssl, SSL_SESSION* session) {
+ SSLClientSocketOpenSSL* socket = GetClientSocketFromSSL(ssl);
+ session_cache_.OnSessionAdded(socket->host_and_port(),
+ socket->ssl_session_cache_shard(),
+ session);
+ return 1; // 1 => We took ownership of |session|.
+ }
+
+ static void RemoveSessionCallbackStatic(SSL_CTX* ctx, SSL_SESSION* session) {
+ return GetInstance()->RemoveSessionCallback(ctx, session);
+ }
+
+ void RemoveSessionCallback(SSL_CTX* ctx, SSL_SESSION* session) {
+ DCHECK(ctx == ssl_ctx());
+ session_cache_.OnSessionRemoved(session);
+ }
+
+ static int ClientCertCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey) {
+ SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl);
+ CHECK(socket);
+ return socket->ClientCertRequestCallback(ssl, x509, pkey);
+ }
+
+ static int SelectNextProtoCallback(SSL* ssl,
+ unsigned char** out, unsigned char* outlen,
+ const unsigned char* in,
+ unsigned int inlen, void* arg) {
+ SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl);
+ return socket->SelectNextProtoCallback(out, outlen, in, inlen);
+ }
+
+ // This is the index used with SSL_get_ex_data to retrieve the owner
+ // SSLClientSocketOpenSSL object from an SSL instance.
+ int ssl_socket_data_index_;
+
+ // session_cache_ must appear before |ssl_ctx_| because the destruction of
+ // |ssl_ctx_| may trigger callbacks into |session_cache_|. Therefore,
+ // |session_cache_| must be destructed after |ssl_ctx_|.
+ SSLSessionCache session_cache_;
+ crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_;
+};
+
+// Utility to construct the appropriate set & clear masks for use the OpenSSL
+// options and mode configuration functions. (SSL_set_options etc)
+struct SslSetClearMask {
+ SslSetClearMask() : set_mask(0), clear_mask(0) {}
+ void ConfigureFlag(long flag, bool state) {
+ (state ? set_mask : clear_mask) |= flag;
+ // Make sure we haven't got any intersection in the set & clear options.
+ DCHECK_EQ(0, set_mask & clear_mask) << flag << ":" << state;
+ }
+ long set_mask;
+ long clear_mask;
+};
+
+} // namespace
+
+// static
+void SSLClientSocket::ClearSessionCache() {
+ SSLContext* context = SSLContext::GetInstance();
+ context->session_cache()->Flush();
+}
+
+SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context)
+ : transport_send_busy_(false),
+ transport_recv_busy_(false),
+ transport_recv_eof_(false),
+ weak_factory_(this),
+ pending_read_error_(kNoPendingReadResult),
+ completed_handshake_(false),
+ client_auth_cert_needed_(false),
+ cert_verifier_(context.cert_verifier),
+ ssl_(NULL),
+ transport_bio_(NULL),
+ transport_(transport_socket.Pass()),
+ host_and_port_(host_and_port),
+ ssl_config_(ssl_config),
+ ssl_session_cache_shard_(context.ssl_session_cache_shard),
+ trying_cached_session_(false),
+ next_handshake_state_(STATE_NONE),
+ npn_status_(kNextProtoUnsupported),
+ net_log_(transport_->socket()->NetLog()) {
+}
+
+SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() {
+ Disconnect();
+}
+
+bool SSLClientSocketOpenSSL::Init() {
+ DCHECK(!ssl_);
+ DCHECK(!transport_bio_);
+
+ SSLContext* context = SSLContext::GetInstance();
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+
+ ssl_ = SSL_new(context->ssl_ctx());
+ if (!ssl_ || !context->SetClientSocketForSSL(ssl_, this))
+ return false;
+
+ if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str()))
+ return false;
+
+ trying_cached_session_ =
+ context->session_cache()->SetSSLSession(ssl_, host_and_port_,
+ ssl_session_cache_shard_);
+
+ BIO* ssl_bio = NULL;
+ // 0 => use default buffer sizes.
+ if (!BIO_new_bio_pair(&ssl_bio, 0, &transport_bio_, 0))
+ return false;
+ DCHECK(ssl_bio);
+ DCHECK(transport_bio_);
+
+ SSL_set_bio(ssl_, ssl_bio, ssl_bio);
+
+ // OpenSSL defaults some options to on, others to off. To avoid ambiguity,
+ // set everything we care about to an absolute value.
+ SslSetClearMask options;
+ options.ConfigureFlag(SSL_OP_NO_SSLv2, true);
+ bool ssl3_enabled = (ssl_config_.version_min == SSL_PROTOCOL_VERSION_SSL3);
+ options.ConfigureFlag(SSL_OP_NO_SSLv3, !ssl3_enabled);
+ bool tls1_enabled = (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1 &&
+ ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1);
+ options.ConfigureFlag(SSL_OP_NO_TLSv1, !tls1_enabled);
+#if defined(SSL_OP_NO_TLSv1_1)
+ bool tls1_1_enabled =
+ (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_1 &&
+ ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1);
+ options.ConfigureFlag(SSL_OP_NO_TLSv1_1, !tls1_1_enabled);
+#endif
+#if defined(SSL_OP_NO_TLSv1_2)
+ bool tls1_2_enabled =
+ (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_2 &&
+ ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_2);
+ options.ConfigureFlag(SSL_OP_NO_TLSv1_2, !tls1_2_enabled);
+#endif
+
+#if defined(SSL_OP_NO_COMPRESSION)
+ options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true);
+#endif
+
+ // TODO(joth): Set this conditionally, see http://crbug.com/55410
+ options.ConfigureFlag(SSL_OP_LEGACY_SERVER_CONNECT, true);
+
+ SSL_set_options(ssl_, options.set_mask);
+ SSL_clear_options(ssl_, options.clear_mask);
+
+ // Same as above, this time for the SSL mode.
+ SslSetClearMask mode;
+
+#if defined(SSL_MODE_RELEASE_BUFFERS)
+ mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true);
+#endif
+
+#if defined(SSL_MODE_SMALL_BUFFERS)
+ mode.ConfigureFlag(SSL_MODE_SMALL_BUFFERS, true);
+#endif
+
+ SSL_set_mode(ssl_, mode.set_mask);
+ SSL_clear_mode(ssl_, mode.clear_mask);
+
+ // Removing ciphers by ID from OpenSSL is a bit involved as we must use the
+ // textual name with SSL_set_cipher_list because there is no public API to
+ // directly remove a cipher by ID.
+ STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(ssl_);
+ DCHECK(ciphers);
+ // See SSLConfig::disabled_cipher_suites for description of the suites
+ // disabled by default. Note that !SHA384 only removes HMAC-SHA384 cipher
+ // suites, not GCM cipher suites with SHA384 as the handshake hash.
+ std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA384:!aECDH");
+ // Walk through all the installed ciphers, seeing if any need to be
+ // appended to the cipher removal |command|.
+ for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) {
+ const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i);
+ const uint16 id = SSL_CIPHER_get_id(cipher);
+ // Remove any ciphers with a strength of less than 80 bits. Note the NSS
+ // implementation uses "effective" bits here but OpenSSL does not provide
+ // this detail. This only impacts Triple DES: reports 112 vs. 168 bits,
+ // both of which are greater than 80 anyway.
+ bool disable = SSL_CIPHER_get_bits(cipher, NULL) < 80;
+ if (!disable) {
+ disable = std::find(ssl_config_.disabled_cipher_suites.begin(),
+ ssl_config_.disabled_cipher_suites.end(), id) !=
+ ssl_config_.disabled_cipher_suites.end();
+ }
+ if (disable) {
+ const char* name = SSL_CIPHER_get_name(cipher);
+ DVLOG(3) << "Found cipher to remove: '" << name << "', ID: " << id
+ << " strength: " << SSL_CIPHER_get_bits(cipher, NULL);
+ command.append(":!");
+ command.append(name);
+ }
+ }
+ int rv = SSL_set_cipher_list(ssl_, command.c_str());
+ // If this fails (rv = 0) it means there are no ciphers enabled on this SSL.
+ // This will almost certainly result in the socket failing to complete the
+ // handshake at which point the appropriate error is bubbled up to the client.
+ LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') "
+ "returned " << rv;
+ return true;
+}
+
+int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl,
+ X509** x509,
+ EVP_PKEY** pkey) {
+ DVLOG(3) << "OpenSSL ClientCertRequestCallback called";
+ DCHECK(ssl == ssl_);
+ DCHECK(*x509 == NULL);
+ DCHECK(*pkey == NULL);
+
+ if (!ssl_config_.send_client_cert) {
+ // First pass: we know that a client certificate is needed, but we do not
+ // have one at hand.
+ client_auth_cert_needed_ = true;
+ STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl);
+ for (int i = 0; i < sk_X509_NAME_num(authorities); i++) {
+ X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i);
+ unsigned char* str = NULL;
+ int length = i2d_X509_NAME(ca_name, &str);
+ cert_authorities_.push_back(std::string(
+ reinterpret_cast<const char*>(str),
+ static_cast<size_t>(length)));
+ OPENSSL_free(str);
+ }
+
+ return -1; // Suspends handshake.
+ }
+
+ // Second pass: a client certificate should have been selected.
+ if (ssl_config_.client_cert.get()) {
+ // A note about ownership: FetchClientCertPrivateKey() increments
+ // the reference count of the EVP_PKEY. Ownership of this reference
+ // is passed directly to OpenSSL, which will release the reference
+ // using EVP_PKEY_free() when the SSL object is destroyed.
+ OpenSSLClientKeyStore::ScopedEVP_PKEY privkey;
+ if (OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey(
+ ssl_config_.client_cert.get(), &privkey)) {
+ // TODO(joth): (copied from NSS) We should wait for server certificate
+ // verification before sending our credentials. See http://crbug.com/13934
+ *x509 = X509Certificate::DupOSCertHandle(
+ ssl_config_.client_cert->os_cert_handle());
+ *pkey = privkey.release();
+ return 1;
+ }
+ LOG(WARNING) << "Client cert found without private key";
+ }
+
+ // Send no client certificate.
+ return 0;
+}
+
+// SSLClientSocket methods
+
+bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) {
+ ssl_info->Reset();
+ if (!server_cert_.get())
+ return false;
+
+ ssl_info->cert = server_cert_verify_result_.verified_cert;
+ ssl_info->cert_status = server_cert_verify_result_.cert_status;
+ ssl_info->is_issued_by_known_root =
+ server_cert_verify_result_.is_issued_by_known_root;
+ ssl_info->public_key_hashes =
+ server_cert_verify_result_.public_key_hashes;
+ ssl_info->client_cert_sent =
+ ssl_config_.send_client_cert && ssl_config_.client_cert.get();
+ ssl_info->channel_id_sent = WasChannelIDSent();
+
+ const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_);
+ CHECK(cipher);
+ ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL);
+ const COMP_METHOD* compression = SSL_get_current_compression(ssl_);
+
+ ssl_info->connection_status = EncodeSSLConnectionStatus(
+ SSL_CIPHER_get_id(cipher),
+ compression ? compression->type : 0,
+ GetNetSSLVersion(ssl_));
+
+ bool peer_supports_renego_ext = !!SSL_get_secure_renegotiation_support(ssl_);
+ if (!peer_supports_renego_ext)
+ ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION;
+ UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported",
+ implicit_cast<int>(peer_supports_renego_ext), 2);
+
+ if (ssl_config_.version_fallback)
+ ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK;
+
+ ssl_info->handshake_type = SSL_session_reused(ssl_) ?
+ SSLInfo::HANDSHAKE_RESUME : SSLInfo::HANDSHAKE_FULL;
+
+ DVLOG(3) << "Encoded connection status: cipher suite = "
+ << SSLConnectionStatusToCipherSuite(ssl_info->connection_status)
+ << " version = "
+ << SSLConnectionStatusToVersion(ssl_info->connection_status);
+ return true;
+}
+
+void SSLClientSocketOpenSSL::GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) {
+ cert_request_info->host_and_port = host_and_port_.ToString();
+ cert_request_info->cert_authorities = cert_authorities_;
+}
+
+int SSLClientSocketOpenSSL::ExportKeyingMaterial(
+ const base::StringPiece& label,
+ bool has_context, const base::StringPiece& context,
+ unsigned char* out, unsigned int outlen) {
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+
+ int rv = SSL_export_keying_material(
+ ssl_, out, outlen, const_cast<char*>(label.data()),
+ label.size(),
+ reinterpret_cast<unsigned char*>(const_cast<char*>(context.data())),
+ context.length(),
+ context.length() > 0);
+
+ if (rv != 1) {
+ int ssl_error = SSL_get_error(ssl_, rv);
+ LOG(ERROR) << "Failed to export keying material;"
+ << " returned " << rv
+ << ", SSL error code " << ssl_error;
+ return MapOpenSSLError(ssl_error, err_tracer);
+ }
+ return OK;
+}
+
+int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) {
+ return ERR_NOT_IMPLEMENTED;
+}
+
+SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto(
+ std::string* proto, std::string* server_protos) {
+ *proto = npn_proto_;
+ *server_protos = server_protos_;
+ return npn_status_;
+}
+
+ServerBoundCertService*
+SSLClientSocketOpenSSL::GetServerBoundCertService() const {
+ return NULL;
+}
+
+void SSLClientSocketOpenSSL::DoReadCallback(int rv) {
+ // Since Run may result in Read being called, clear |user_read_callback_|
+ // up front.
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+ base::ResetAndReturn(&user_read_callback_).Run(rv);
+}
+
+void SSLClientSocketOpenSSL::DoWriteCallback(int rv) {
+ // Since Run may result in Write being called, clear |user_write_callback_|
+ // up front.
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+ base::ResetAndReturn(&user_write_callback_).Run(rv);
+}
+
+// StreamSocket implementation.
+int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) {
+ net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT);
+
+ // Set up new ssl object.
+ if (!Init()) {
+ int result = ERR_UNEXPECTED;
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, result);
+ return result;
+ }
+
+ // Set SSL to client mode. Handshake happens in the loop below.
+ SSL_set_connect_state(ssl_);
+
+ GotoState(STATE_HANDSHAKE);
+ int rv = DoHandshakeLoop(net::OK);
+ if (rv == ERR_IO_PENDING) {
+ user_connect_callback_ = callback;
+ } else {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ }
+
+ return rv > OK ? OK : rv;
+}
+
+void SSLClientSocketOpenSSL::Disconnect() {
+ if (ssl_) {
+ // Calling SSL_shutdown prevents the session from being marked as
+ // unresumable.
+ SSL_shutdown(ssl_);
+ SSL_free(ssl_);
+ ssl_ = NULL;
+ }
+ if (transport_bio_) {
+ BIO_free_all(transport_bio_);
+ transport_bio_ = NULL;
+ }
+
+ // Shut down anything that may call us back.
+ verifier_.reset();
+ transport_->socket()->Disconnect();
+
+ // Null all callbacks, delete all buffers.
+ transport_send_busy_ = false;
+ send_buffer_ = NULL;
+ transport_recv_busy_ = false;
+ transport_recv_eof_ = false;
+ recv_buffer_ = NULL;
+
+ user_connect_callback_.Reset();
+ user_read_callback_.Reset();
+ user_write_callback_.Reset();
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+
+ server_cert_verify_result_.Reset();
+ completed_handshake_ = false;
+
+ cert_authorities_.clear();
+ client_auth_cert_needed_ = false;
+}
+
+int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) {
+ int rv = last_io_result;
+ do {
+ // Default to STATE_NONE for next state.
+ // (This is a quirk carried over from the windows
+ // implementation. It makes reading the logs a bit harder.)
+ // State handlers can and often do call GotoState just
+ // to stay in the current state.
+ State state = next_handshake_state_;
+ GotoState(STATE_NONE);
+ switch (state) {
+ case STATE_HANDSHAKE:
+ rv = DoHandshake();
+ break;
+ case STATE_VERIFY_CERT:
+ DCHECK(rv == OK);
+ rv = DoVerifyCert(rv);
+ break;
+ case STATE_VERIFY_CERT_COMPLETE:
+ rv = DoVerifyCertComplete(rv);
+ break;
+ case STATE_NONE:
+ default:
+ rv = ERR_UNEXPECTED;
+ NOTREACHED() << "unexpected state" << state;
+ break;
+ }
+
+ bool network_moved = DoTransportIO();
+ if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) {
+ // In general we exit the loop if rv is ERR_IO_PENDING. In this
+ // special case we keep looping even if rv is ERR_IO_PENDING because
+ // the transport IO may allow DoHandshake to make progress.
+ rv = OK; // This causes us to stay in the loop.
+ }
+ } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
+ return rv;
+}
+
+int SSLClientSocketOpenSSL::DoHandshake() {
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+ int net_error = net::OK;
+ int rv = SSL_do_handshake(ssl_);
+
+ if (client_auth_cert_needed_) {
+ net_error = ERR_SSL_CLIENT_AUTH_CERT_NEEDED;
+ // If the handshake already succeeded (because the server requests but
+ // doesn't require a client cert), we need to invalidate the SSL session
+ // so that we won't try to resume the non-client-authenticated session in
+ // the next handshake. This will cause the server to ask for a client
+ // cert again.
+ if (rv == 1) {
+ // Remove from session cache but don't clear this connection.
+ SSL_SESSION* session = SSL_get_session(ssl_);
+ if (session) {
+ int rv = SSL_CTX_remove_session(SSL_get_SSL_CTX(ssl_), session);
+ LOG_IF(WARNING, !rv) << "Couldn't invalidate SSL session: " << session;
+ }
+ }
+ } else if (rv == 1) {
+ if (trying_cached_session_ && logging::DEBUG_MODE) {
+ DVLOG(2) << "Result of session reuse for " << host_and_port_.ToString()
+ << " is: " << (SSL_session_reused(ssl_) ? "Success" : "Fail");
+ }
+ // SSL handshake is completed. Let's verify the certificate.
+ const bool got_cert = !!UpdateServerCert();
+ DCHECK(got_cert);
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_CERTIFICATES_RECEIVED,
+ base::Bind(&NetLogX509CertificateCallback,
+ base::Unretained(server_cert_.get())));
+ GotoState(STATE_VERIFY_CERT);
+ } else {
+ int ssl_error = SSL_get_error(ssl_, rv);
+ net_error = MapOpenSSLError(ssl_error, err_tracer);
+
+ // If not done, stay in this state
+ if (net_error == ERR_IO_PENDING) {
+ GotoState(STATE_HANDSHAKE);
+ } else {
+ LOG(ERROR) << "handshake failed; returned " << rv
+ << ", SSL error code " << ssl_error
+ << ", net_error " << net_error;
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_HANDSHAKE_ERROR,
+ CreateNetLogSSLErrorCallback(net_error, ssl_error));
+ }
+ }
+ return net_error;
+}
+
+// SelectNextProtoCallback is called by OpenSSL during the handshake. If the
+// server supports NPN, selects a protocol from the list that the server
+// provides. According to third_party/openssl/openssl/ssl/ssl_lib.c, the
+// callback can assume that |in| is syntactically valid.
+int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out,
+ unsigned char* outlen,
+ const unsigned char* in,
+ unsigned int inlen) {
+#if defined(OPENSSL_NPN_NEGOTIATED)
+ if (ssl_config_.next_protos.empty()) {
+ *out = reinterpret_cast<uint8*>(
+ const_cast<char*>(kDefaultSupportedNPNProtocol));
+ *outlen = arraysize(kDefaultSupportedNPNProtocol) - 1;
+ npn_status_ = kNextProtoUnsupported;
+ return SSL_TLSEXT_ERR_OK;
+ }
+
+ // Assume there's no overlap between our protocols and the server's list.
+ npn_status_ = kNextProtoNoOverlap;
+
+ // For each protocol in server preference order, see if we support it.
+ for (unsigned int i = 0; i < inlen; i += in[i] + 1) {
+ for (std::vector<std::string>::const_iterator
+ j = ssl_config_.next_protos.begin();
+ j != ssl_config_.next_protos.end(); ++j) {
+ if (in[i] == j->size() &&
+ memcmp(&in[i + 1], j->data(), in[i]) == 0) {
+ // We found a match.
+ *out = const_cast<unsigned char*>(in) + i + 1;
+ *outlen = in[i];
+ npn_status_ = kNextProtoNegotiated;
+ break;
+ }
+ }
+ if (npn_status_ == kNextProtoNegotiated)
+ break;
+ }
+
+ // If we didn't find a protocol, we select the first one from our list.
+ if (npn_status_ == kNextProtoNoOverlap) {
+ *out = reinterpret_cast<uint8*>(const_cast<char*>(
+ ssl_config_.next_protos[0].data()));
+ *outlen = ssl_config_.next_protos[0].size();
+ }
+
+ npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen);
+ server_protos_.assign(reinterpret_cast<const char*>(in), inlen);
+ DVLOG(2) << "next protocol: '" << npn_proto_ << "' status: " << npn_status_;
+#endif
+ return SSL_TLSEXT_ERR_OK;
+}
+
+int SSLClientSocketOpenSSL::DoVerifyCert(int result) {
+ DCHECK(server_cert_.get());
+ GotoState(STATE_VERIFY_CERT_COMPLETE);
+
+ CertStatus cert_status;
+ if (ssl_config_.IsAllowedBadCert(server_cert_.get(), &cert_status)) {
+ VLOG(1) << "Received an expected bad cert with status: " << cert_status;
+ server_cert_verify_result_.Reset();
+ server_cert_verify_result_.cert_status = cert_status;
+ server_cert_verify_result_.verified_cert = server_cert_;
+ return OK;
+ }
+
+ int flags = 0;
+ if (ssl_config_.rev_checking_enabled)
+ flags |= CertVerifier::VERIFY_REV_CHECKING_ENABLED;
+ if (ssl_config_.verify_ev_cert)
+ flags |= CertVerifier::VERIFY_EV_CERT;
+ if (ssl_config_.cert_io_enabled)
+ flags |= CertVerifier::VERIFY_CERT_IO_ENABLED;
+ if (ssl_config_.rev_checking_required_local_anchors)
+ flags |= CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS;
+ verifier_.reset(new SingleRequestCertVerifier(cert_verifier_));
+ return verifier_->Verify(
+ server_cert_.get(),
+ host_and_port_.host(),
+ flags,
+ NULL /* no CRL set */,
+ &server_cert_verify_result_,
+ base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete,
+ base::Unretained(this)),
+ net_log_);
+}
+
+int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) {
+ verifier_.reset();
+
+ if (result == OK) {
+ // TODO(joth): Work out if we need to remember the intermediate CA certs
+ // when the server sends them to us, and do so here.
+ } else {
+ DVLOG(1) << "DoVerifyCertComplete error " << ErrorToString(result)
+ << " (" << result << ")";
+ }
+
+ completed_handshake_ = true;
+ // Exit DoHandshakeLoop and return the result to the caller to Connect.
+ DCHECK_EQ(STATE_NONE, next_handshake_state_);
+ return result;
+}
+
+X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() {
+ if (server_cert_.get())
+ return server_cert_.get();
+
+ crypto::ScopedOpenSSL<X509, X509_free> cert(SSL_get_peer_certificate(ssl_));
+ if (!cert.get()) {
+ LOG(WARNING) << "SSL_get_peer_certificate returned NULL";
+ return NULL;
+ }
+
+ // Unlike SSL_get_peer_certificate, SSL_get_peer_cert_chain does not
+ // increment the reference so sk_X509_free does not need to be called.
+ STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_);
+ X509Certificate::OSCertHandles intermediates;
+ if (chain) {
+ for (int i = 0; i < sk_X509_num(chain); ++i)
+ intermediates.push_back(sk_X509_value(chain, i));
+ }
+ server_cert_ = X509Certificate::CreateFromHandle(cert.get(), intermediates);
+ DCHECK(server_cert_.get());
+
+ return server_cert_.get();
+}
+
+bool SSLClientSocketOpenSSL::DoTransportIO() {
+ bool network_moved = false;
+ int rv;
+ // Read and write as much data as possible. The loop is necessary because
+ // Write() may return synchronously.
+ do {
+ rv = BufferSend();
+ if (rv != ERR_IO_PENDING && rv != 0)
+ network_moved = true;
+ } while (rv > 0);
+ if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING)
+ network_moved = true;
+ return network_moved;
+}
+
+int SSLClientSocketOpenSSL::BufferSend(void) {
+ if (transport_send_busy_)
+ return ERR_IO_PENDING;
+
+ if (!send_buffer_.get()) {
+ // Get a fresh send buffer out of the send BIO.
+ size_t max_read = BIO_ctrl_pending(transport_bio_);
+ if (!max_read)
+ return 0; // Nothing pending in the OpenSSL write BIO.
+ send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read);
+ int read_bytes = BIO_read(transport_bio_, send_buffer_->data(), max_read);
+ DCHECK_GT(read_bytes, 0);
+ CHECK_EQ(static_cast<int>(max_read), read_bytes);
+ }
+
+ int rv = transport_->socket()->Write(
+ send_buffer_.get(),
+ send_buffer_->BytesRemaining(),
+ base::Bind(&SSLClientSocketOpenSSL::BufferSendComplete,
+ base::Unretained(this)));
+ if (rv == ERR_IO_PENDING) {
+ transport_send_busy_ = true;
+ } else {
+ TransportWriteComplete(rv);
+ }
+ return rv;
+}
+
+void SSLClientSocketOpenSSL::BufferSendComplete(int result) {
+ transport_send_busy_ = false;
+ TransportWriteComplete(result);
+ OnSendComplete(result);
+}
+
+void SSLClientSocketOpenSSL::TransportWriteComplete(int result) {
+ DCHECK(ERR_IO_PENDING != result);
+ if (result < 0) {
+ // Got a socket write error; close the BIO to indicate this upward.
+ DVLOG(1) << "TransportWriteComplete error " << result;
+ (void)BIO_shutdown_wr(transport_bio_);
+ BIO_set_mem_eof_return(transport_bio_, 0);
+ send_buffer_ = NULL;
+ } else {
+ DCHECK(send_buffer_.get());
+ send_buffer_->DidConsume(result);
+ DCHECK_GE(send_buffer_->BytesRemaining(), 0);
+ if (send_buffer_->BytesRemaining() <= 0)
+ send_buffer_ = NULL;
+ }
+}
+
+int SSLClientSocketOpenSSL::BufferRecv(void) {
+ if (transport_recv_busy_)
+ return ERR_IO_PENDING;
+
+ // Determine how much was requested from |transport_bio_| that was not
+ // actually available.
+ size_t requested = BIO_ctrl_get_read_request(transport_bio_);
+ if (requested == 0) {
+ // This is not a perfect match of error codes, as no operation is
+ // actually pending. However, returning 0 would be interpreted as
+ // a possible sign of EOF, which is also an inappropriate match.
+ return ERR_IO_PENDING;
+ }
+
+ // Known Issue: While only reading |requested| data is the more correct
+ // implementation, it has the downside of resulting in frequent reads:
+ // One read for the SSL record header (~5 bytes) and one read for the SSL
+ // record body. Rather than issuing these reads to the underlying socket
+ // (and constantly allocating new IOBuffers), a single Read() request to
+ // fill |transport_bio_| is issued. As long as an SSL client socket cannot
+ // be gracefully shutdown (via SSL close alerts) and re-used for non-SSL
+ // traffic, this over-subscribed Read()ing will not cause issues.
+ size_t max_write = BIO_ctrl_get_write_guarantee(transport_bio_);
+ if (!max_write)
+ return ERR_IO_PENDING;
+
+ recv_buffer_ = new IOBuffer(max_write);
+ int rv = transport_->socket()->Read(
+ recv_buffer_.get(),
+ max_write,
+ base::Bind(&SSLClientSocketOpenSSL::BufferRecvComplete,
+ base::Unretained(this)));
+ if (rv == ERR_IO_PENDING) {
+ transport_recv_busy_ = true;
+ } else {
+ TransportReadComplete(rv);
+ }
+ return rv;
+}
+
+void SSLClientSocketOpenSSL::BufferRecvComplete(int result) {
+ TransportReadComplete(result);
+ OnRecvComplete(result);
+}
+
+void SSLClientSocketOpenSSL::TransportReadComplete(int result) {
+ DCHECK(ERR_IO_PENDING != result);
+ if (result <= 0) {
+ DVLOG(1) << "TransportReadComplete result " << result;
+ // Received 0 (end of file) or an error. Either way, bubble it up to the
+ // SSL layer via the BIO. TODO(joth): consider stashing the error code, to
+ // relay up to the SSL socket client (i.e. via DoReadCallback).
+ if (result == 0)
+ transport_recv_eof_ = true;
+ BIO_set_mem_eof_return(transport_bio_, 0);
+ (void)BIO_shutdown_wr(transport_bio_);
+ } else {
+ DCHECK(recv_buffer_.get());
+ int ret = BIO_write(transport_bio_, recv_buffer_->data(), result);
+ // A write into a memory BIO should always succeed.
+ CHECK_EQ(result, ret);
+ }
+ recv_buffer_ = NULL;
+ transport_recv_busy_ = false;
+}
+
+void SSLClientSocketOpenSSL::DoConnectCallback(int rv) {
+ if (!user_connect_callback_.is_null()) {
+ CompletionCallback c = user_connect_callback_;
+ user_connect_callback_.Reset();
+ c.Run(rv > OK ? OK : rv);
+ }
+}
+
+void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) {
+ int rv = DoHandshakeLoop(result);
+ if (rv != ERR_IO_PENDING) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ DoConnectCallback(rv);
+ }
+}
+
+void SSLClientSocketOpenSSL::OnSendComplete(int result) {
+ if (next_handshake_state_ == STATE_HANDSHAKE) {
+ // In handshake phase.
+ OnHandshakeIOComplete(result);
+ return;
+ }
+
+ // OnSendComplete may need to call DoPayloadRead while the renegotiation
+ // handshake is in progress.
+ int rv_read = ERR_IO_PENDING;
+ int rv_write = ERR_IO_PENDING;
+ bool network_moved;
+ do {
+ if (user_read_buf_.get())
+ rv_read = DoPayloadRead();
+ if (user_write_buf_.get())
+ rv_write = DoPayloadWrite();
+ network_moved = DoTransportIO();
+ } while (rv_read == ERR_IO_PENDING && rv_write == ERR_IO_PENDING &&
+ (user_read_buf_.get() || user_write_buf_.get()) && network_moved);
+
+ // Performing the Read callback may cause |this| to be deleted. If this
+ // happens, the Write callback should not be invoked. Guard against this by
+ // holding a WeakPtr to |this| and ensuring it's still valid.
+ base::WeakPtr<SSLClientSocketOpenSSL> guard(weak_factory_.GetWeakPtr());
+ if (user_read_buf_.get() && rv_read != ERR_IO_PENDING)
+ DoReadCallback(rv_read);
+
+ if (!guard.get())
+ return;
+
+ if (user_write_buf_.get() && rv_write != ERR_IO_PENDING)
+ DoWriteCallback(rv_write);
+}
+
+void SSLClientSocketOpenSSL::OnRecvComplete(int result) {
+ if (next_handshake_state_ == STATE_HANDSHAKE) {
+ // In handshake phase.
+ OnHandshakeIOComplete(result);
+ return;
+ }
+
+ // Network layer received some data, check if client requested to read
+ // decrypted data.
+ if (!user_read_buf_.get())
+ return;
+
+ int rv = DoReadLoop(result);
+ if (rv != ERR_IO_PENDING)
+ DoReadCallback(rv);
+}
+
+bool SSLClientSocketOpenSSL::IsConnected() const {
+ // If the handshake has not yet completed.
+ if (!completed_handshake_)
+ return false;
+ // If an asynchronous operation is still pending.
+ if (user_read_buf_.get() || user_write_buf_.get())
+ return true;
+
+ return transport_->socket()->IsConnected();
+}
+
+bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const {
+ // If the handshake has not yet completed.
+ if (!completed_handshake_)
+ return false;
+ // If an asynchronous operation is still pending.
+ if (user_read_buf_.get() || user_write_buf_.get())
+ return false;
+ // If there is data waiting to be sent, or data read from the network that
+ // has not yet been consumed.
+ if (BIO_ctrl_pending(transport_bio_) > 0 ||
+ BIO_ctrl_wpending(transport_bio_) > 0) {
+ return false;
+ }
+
+ return transport_->socket()->IsConnectedAndIdle();
+}
+
+int SSLClientSocketOpenSSL::GetPeerAddress(IPEndPoint* addressList) const {
+ return transport_->socket()->GetPeerAddress(addressList);
+}
+
+int SSLClientSocketOpenSSL::GetLocalAddress(IPEndPoint* addressList) const {
+ return transport_->socket()->GetLocalAddress(addressList);
+}
+
+const BoundNetLog& SSLClientSocketOpenSSL::NetLog() const {
+ return net_log_;
+}
+
+void SSLClientSocketOpenSSL::SetSubresourceSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetSubresourceSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+void SSLClientSocketOpenSSL::SetOmniboxSpeculation() {
+ if (transport_.get() && transport_->socket()) {
+ transport_->socket()->SetOmniboxSpeculation();
+ } else {
+ NOTREACHED();
+ }
+}
+
+bool SSLClientSocketOpenSSL::WasEverUsed() const {
+ if (transport_.get() && transport_->socket())
+ return transport_->socket()->WasEverUsed();
+
+ NOTREACHED();
+ return false;
+}
+
+bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const {
+ if (transport_.get() && transport_->socket())
+ return transport_->socket()->UsingTCPFastOpen();
+
+ NOTREACHED();
+ return false;
+}
+
+// Socket methods
+
+int SSLClientSocketOpenSSL::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ user_read_buf_ = buf;
+ user_read_buf_len_ = buf_len;
+
+ int rv = DoReadLoop(OK);
+
+ if (rv == ERR_IO_PENDING) {
+ user_read_callback_ = callback;
+ } else {
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+ }
+
+ return rv;
+}
+
+int SSLClientSocketOpenSSL::DoReadLoop(int result) {
+ if (result < 0)
+ return result;
+
+ bool network_moved;
+ int rv;
+ do {
+ rv = DoPayloadRead();
+ network_moved = DoTransportIO();
+ } while (rv == ERR_IO_PENDING && network_moved);
+
+ return rv;
+}
+
+int SSLClientSocketOpenSSL::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ user_write_buf_ = buf;
+ user_write_buf_len_ = buf_len;
+
+ int rv = DoWriteLoop(OK);
+
+ if (rv == ERR_IO_PENDING) {
+ user_write_callback_ = callback;
+ } else {
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+ }
+
+ return rv;
+}
+
+int SSLClientSocketOpenSSL::DoWriteLoop(int result) {
+ if (result < 0)
+ return result;
+
+ bool network_moved;
+ int rv;
+ do {
+ rv = DoPayloadWrite();
+ network_moved = DoTransportIO();
+ } while (rv == ERR_IO_PENDING && network_moved);
+
+ return rv;
+}
+
+bool SSLClientSocketOpenSSL::SetReceiveBufferSize(int32 size) {
+ return transport_->socket()->SetReceiveBufferSize(size);
+}
+
+bool SSLClientSocketOpenSSL::SetSendBufferSize(int32 size) {
+ return transport_->socket()->SetSendBufferSize(size);
+}
+
+int SSLClientSocketOpenSSL::DoPayloadRead() {
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+
+ int rv;
+ if (pending_read_error_ != kNoPendingReadResult) {
+ rv = pending_read_error_;
+ pending_read_error_ = kNoPendingReadResult;
+ if (rv == 0) {
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
+ rv, user_read_buf_->data());
+ }
+ return rv;
+ }
+
+ int total_bytes_read = 0;
+ do {
+ rv = SSL_read(ssl_, user_read_buf_->data() + total_bytes_read,
+ user_read_buf_len_ - total_bytes_read);
+ if (rv > 0)
+ total_bytes_read += rv;
+ } while (total_bytes_read < user_read_buf_len_ && rv > 0);
+
+ if (total_bytes_read == user_read_buf_len_) {
+ rv = total_bytes_read;
+ } else {
+ // Otherwise, an error occurred (rv <= 0). The error needs to be handled
+ // immediately, while the OpenSSL errors are still available in
+ // thread-local storage. However, the handled/remapped error code should
+ // only be returned if no application data was already read; if it was, the
+ // error code should be deferred until the next call of DoPayloadRead.
+ //
+ // If no data was read, |*next_result| will point to the return value of
+ // this function. If at least some data was read, |*next_result| will point
+ // to |pending_read_error_|, to be returned in a future call to
+ // DoPayloadRead() (e.g.: after the current data is handled).
+ int *next_result = &rv;
+ if (total_bytes_read > 0) {
+ pending_read_error_ = rv;
+ rv = total_bytes_read;
+ next_result = &pending_read_error_;
+ }
+
+ if (client_auth_cert_needed_) {
+ *next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED;
+ } else if (*next_result < 0) {
+ int err = SSL_get_error(ssl_, *next_result);
+ *next_result = MapOpenSSLError(err, err_tracer);
+ if (rv > 0 && *next_result == ERR_IO_PENDING) {
+ // If at least some data was read from SSL_read(), do not treat
+ // insufficient data as an error to return in the next call to
+ // DoPayloadRead() - instead, let the call fall through to check
+ // SSL_read() again. This is because DoTransportIO() may complete
+ // in between the next call to DoPayloadRead(), and thus it is
+ // important to check SSL_read() on subsequent invocations to see
+ // if a complete record may now be read.
+ *next_result = kNoPendingReadResult;
+ }
+ }
+ }
+
+ if (rv >= 0) {
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv,
+ user_read_buf_->data());
+ }
+ return rv;
+}
+
+int SSLClientSocketOpenSSL::DoPayloadWrite() {
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+ int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_);
+
+ if (rv >= 0) {
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv,
+ user_write_buf_->data());
+ return rv;
+ }
+
+ int err = SSL_get_error(ssl_, rv);
+ return MapOpenSSLError(err, err_tracer);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h
new file mode 100644
index 00000000000..f66d95cc69d
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_openssl.h
@@ -0,0 +1,203 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_OPENSSL_H_
+#define NET_SOCKET_SSL_CLIENT_SOCKET_OPENSSL_H_
+
+#include <string>
+
+#include "base/compiler_specific.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "net/base/completion_callback.h"
+#include "net/base/io_buffer.h"
+#include "net/cert/cert_verify_result.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/ssl/ssl_config_service.h"
+
+// Avoid including misc OpenSSL headers, i.e.:
+// <openssl/bio.h>
+typedef struct bio_st BIO;
+// <openssl/evp.h>
+typedef struct evp_pkey_st EVP_PKEY;
+// <openssl/ssl.h>
+typedef struct ssl_st SSL;
+// <openssl/x509.h>
+typedef struct x509_st X509;
+
+namespace net {
+
+class CertVerifier;
+class SingleRequestCertVerifier;
+class SSLCertRequestInfo;
+class SSLInfo;
+
+// An SSL client socket implemented with OpenSSL.
+class SSLClientSocketOpenSSL : public SSLClientSocket {
+ public:
+ // Takes ownership of the transport_socket, which may already be connected.
+ // The given hostname will be compared with the name(s) in the server's
+ // certificate during the SSL handshake. ssl_config specifies the SSL
+ // settings.
+ SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context);
+ virtual ~SSLClientSocketOpenSSL();
+
+ const HostPortPair& host_and_port() const { return host_and_port_; }
+ const std::string& ssl_session_cache_shard() const {
+ return ssl_session_cache_shard_;
+ }
+
+ // Callback from the SSL layer that indicates the remote server is requesting
+ // a certificate for this client.
+ int ClientCertRequestCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey);
+
+ // Callback from the SSL layer to check which NPN protocol we are supporting
+ int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen,
+ const unsigned char* in, unsigned int inlen);
+
+ // SSLClientSocket implementation.
+ virtual void GetSSLCertRequestInfo(
+ SSLCertRequestInfo* cert_request_info) OVERRIDE;
+ virtual NextProtoStatus GetNextProto(std::string* proto,
+ std::string* server_protos) OVERRIDE;
+ virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+
+ // SSLSocket implementation.
+ virtual int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) OVERRIDE;
+ virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ private:
+ bool Init();
+ void DoReadCallback(int result);
+ void DoWriteCallback(int result);
+
+ bool DoTransportIO();
+ int DoHandshake();
+ int DoVerifyCert(int result);
+ int DoVerifyCertComplete(int result);
+ void DoConnectCallback(int result);
+ X509Certificate* UpdateServerCert();
+
+ void OnHandshakeIOComplete(int result);
+ void OnSendComplete(int result);
+ void OnRecvComplete(int result);
+
+ int DoHandshakeLoop(int last_io_result);
+ int DoReadLoop(int result);
+ int DoWriteLoop(int result);
+ int DoPayloadRead();
+ int DoPayloadWrite();
+
+ int BufferSend();
+ int BufferRecv();
+ void BufferSendComplete(int result);
+ void BufferRecvComplete(int result);
+ void TransportWriteComplete(int result);
+ void TransportReadComplete(int result);
+
+ bool transport_send_busy_;
+ bool transport_recv_busy_;
+ bool transport_recv_eof_;
+
+ scoped_refptr<DrainableIOBuffer> send_buffer_;
+ scoped_refptr<IOBuffer> recv_buffer_;
+
+ CompletionCallback user_connect_callback_;
+ CompletionCallback user_read_callback_;
+ CompletionCallback user_write_callback_;
+
+ base::WeakPtrFactory<SSLClientSocketOpenSSL> weak_factory_;
+
+ // Used by Read function.
+ scoped_refptr<IOBuffer> user_read_buf_;
+ int user_read_buf_len_;
+
+ // Used by Write function.
+ scoped_refptr<IOBuffer> user_write_buf_;
+ int user_write_buf_len_;
+
+ // Used by DoPayloadRead() when attempting to fill the caller's buffer with
+ // as much data as possible without blocking.
+ // If DoPayloadRead() encounters an error after having read some data, stores
+ // the result to return on the *next* call to DoPayloadRead(). A value > 0
+ // indicates there is no pending result, otherwise 0 indicates EOF and < 0
+ // indicates an error.
+ int pending_read_error_;
+
+ // Set when handshake finishes.
+ scoped_refptr<X509Certificate> server_cert_;
+ CertVerifyResult server_cert_verify_result_;
+ bool completed_handshake_;
+
+ // Stores client authentication information between ClientAuthHandler and
+ // GetSSLCertRequestInfo calls.
+ bool client_auth_cert_needed_;
+ // List of DER-encoded X.509 DistinguishedName of certificate authorities
+ // allowed by the server.
+ std::vector<std::string> cert_authorities_;
+
+ CertVerifier* const cert_verifier_;
+ scoped_ptr<SingleRequestCertVerifier> verifier_;
+
+ // OpenSSL stuff
+ SSL* ssl_;
+ BIO* transport_bio_;
+
+ scoped_ptr<ClientSocketHandle> transport_;
+ const HostPortPair host_and_port_;
+ SSLConfig ssl_config_;
+ // ssl_session_cache_shard_ is an opaque string that partitions the SSL
+ // session cache. i.e. sessions created with one value will not attempt to
+ // resume on the socket with a different value.
+ const std::string ssl_session_cache_shard_;
+
+ // Used for session cache diagnostics.
+ bool trying_cached_session_;
+
+ enum State {
+ STATE_NONE,
+ STATE_HANDSHAKE,
+ STATE_VERIFY_CERT,
+ STATE_VERIFY_CERT_COMPLETE,
+ };
+ State next_handshake_state_;
+ NextProtoStatus npn_status_;
+ std::string npn_proto_;
+ std::string server_protos_;
+ BoundNetLog net_log_;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_CLIENT_SOCKET_OPENSSL_H_
diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
new file mode 100644
index 00000000000..04f899903ac
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
@@ -0,0 +1,279 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/ssl_client_socket.h"
+
+#include <errno.h>
+#include <string.h>
+
+#include <openssl/bio.h>
+#include <openssl/bn.h>
+#include <openssl/evp.h>
+#include <openssl/pem.h>
+#include <openssl/rsa.h>
+
+#include "base/file_util.h"
+#include "base/files/file_path.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_handle.h"
+#include "base/values.h"
+#include "crypto/openssl_util.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_log_unittest.h"
+#include "net/base/test_completion_callback.h"
+#include "net/base/test_data_directory.h"
+#include "net/cert/mock_cert_verifier.h"
+#include "net/cert/test_root_certs.h"
+#include "net/dns/host_resolver.h"
+#include "net/http/transport_security_state.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/tcp_client_socket.h"
+#include "net/ssl/openssl_client_key_store.h"
+#include "net/ssl/ssl_cert_request_info.h"
+#include "net/ssl/ssl_config_service.h"
+#include "net/test/cert_test_util.h"
+#include "net/test/spawned_test_server/spawned_test_server.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+namespace {
+
+typedef OpenSSLClientKeyStore::ScopedEVP_PKEY ScopedEVP_PKEY;
+
+// BIO_free is a macro, it can't be used as a template parameter.
+void BIO_free_func(BIO* bio) {
+ BIO_free(bio);
+}
+
+typedef crypto::ScopedOpenSSL<BIO, BIO_free_func> ScopedBIO;
+typedef crypto::ScopedOpenSSL<RSA, RSA_free> ScopedRSA;
+typedef crypto::ScopedOpenSSL<BIGNUM, BN_free> ScopedBIGNUM;
+
+const SSLConfig kDefaultSSLConfig;
+
+// Loads a PEM-encoded private key file into a scoped EVP_PKEY object.
+// |filepath| is the private key file path.
+// |*pkey| is reset to the new EVP_PKEY on success, untouched otherwise.
+// Returns true on success, false on failure.
+bool LoadPrivateKeyOpenSSL(
+ const base::FilePath& filepath,
+ OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) {
+ std::string data;
+ if (!file_util::ReadFileToString(filepath, &data)) {
+ LOG(ERROR) << "Could not read private key file: "
+ << filepath.value() << ": " << strerror(errno);
+ return false;
+ }
+ ScopedBIO bio(
+ BIO_new_mem_buf(
+ const_cast<char*>(reinterpret_cast<const char*>(data.data())),
+ static_cast<int>(data.size())));
+ if (!bio.get()) {
+ LOG(ERROR) << "Could not allocate BIO for buffer?";
+ return false;
+ }
+ EVP_PKEY* result = PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL);
+ if (result == NULL) {
+ LOG(ERROR) << "Could not decode private key file: "
+ << filepath.value();
+ return false;
+ }
+ pkey->reset(result);
+ return true;
+}
+
+class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest {
+ public:
+ SSLClientSocketOpenSSLClientAuthTest()
+ : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
+ cert_verifier_(new net::MockCertVerifier),
+ transport_security_state_(new net::TransportSecurityState) {
+ cert_verifier_->set_default_result(net::OK);
+ context_.cert_verifier = cert_verifier_.get();
+ context_.transport_security_state = transport_security_state_.get();
+ key_store_ = net::OpenSSLClientKeyStore::GetInstance();
+ }
+
+ virtual ~SSLClientSocketOpenSSLClientAuthTest() {
+ key_store_->Flush();
+ }
+
+ protected:
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<StreamSocket> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config) {
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ connection->SetSocket(transport_socket.Pass());
+ return socket_factory_->CreateSSLClientSocket(connection.Pass(),
+ host_and_port,
+ ssl_config,
+ context_);
+ }
+
+ // Connect to a HTTPS test server.
+ bool ConnectToTestServer(SpawnedTestServer::SSLOptions& ssl_options) {
+ test_server_.reset(new SpawnedTestServer(SpawnedTestServer::TYPE_HTTPS,
+ ssl_options,
+ base::FilePath()));
+ if (!test_server_->Start()) {
+ LOG(ERROR) << "Could not start SpawnedTestServer";
+ return false;
+ }
+
+ if (!test_server_->GetAddressList(&addr_)) {
+ LOG(ERROR) << "Could not get SpawnedTestServer address list";
+ return false;
+ }
+
+ transport_.reset(new TCPClientSocket(
+ addr_, &log_, NetLog::Source()));
+ int rv = callback_.GetResult(
+ transport_->Connect(callback_.callback()));
+ if (rv != OK) {
+ LOG(ERROR) << "Could not connect to SpawnedTestServer";
+ return false;
+ }
+ return true;
+ }
+
+ // Record a certificate's private key to ensure it can be used
+ // by the OpenSSL-based SSLClientSocket implementation.
+ // |ssl_config| provides a client certificate.
+ // |private_key| must be an EVP_PKEY for the corresponding private key.
+ // Returns true on success, false on failure.
+ bool RecordPrivateKey(SSLConfig& ssl_config,
+ EVP_PKEY* private_key) {
+ return key_store_->RecordClientCertPrivateKey(
+ ssl_config.client_cert.get(), private_key);
+ }
+
+ // Create an SSLClientSocket object and use it to connect to a test
+ // server, then wait for connection results. This must be called after
+ // a succesful ConnectToTestServer() call.
+ // |ssl_config| the SSL configuration to use.
+ // |result| will retrieve the ::Connect() result value.
+ // Returns true on succes, false otherwise. Success means that the socket
+ // could be created and its Connect() was called, not that the connection
+ // itself was a success.
+ bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config,
+ int* result) {
+ sock_ = CreateSSLClientSocket(transport_.Pass(),
+ test_server_->host_port_pair(),
+ ssl_config);
+
+ if (sock_->IsConnected()) {
+ LOG(ERROR) << "SSL Socket prematurely connected";
+ return false;
+ }
+
+ *result = callback_.GetResult(sock_->Connect(callback_.callback()));
+ return true;
+ }
+
+
+ // Check that the client certificate was sent.
+ // Returns true on success.
+ bool CheckSSLClientSocketSentCert() {
+ SSLInfo ssl_info;
+ sock_->GetSSLInfo(&ssl_info);
+ return ssl_info.client_cert_sent;
+ }
+
+ ClientSocketFactory* socket_factory_;
+ scoped_ptr<MockCertVerifier> cert_verifier_;
+ scoped_ptr<TransportSecurityState> transport_security_state_;
+ SSLClientSocketContext context_;
+ OpenSSLClientKeyStore* key_store_;
+ scoped_ptr<SpawnedTestServer> test_server_;
+ AddressList addr_;
+ TestCompletionCallback callback_;
+ CapturingNetLog log_;
+ scoped_ptr<StreamSocket> transport_;
+ scoped_ptr<SSLClientSocket> sock_;
+};
+
+// Connect to a server requesting client authentication, do not send
+// any client certificates. It should refuse the connection.
+TEST_F(SSLClientSocketOpenSSLClientAuthTest, NoCert) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+
+ ASSERT_TRUE(ConnectToTestServer(ssl_options));
+
+ base::FilePath certs_dir = GetTestCertsDirectory();
+ SSLConfig ssl_config = kDefaultSSLConfig;
+
+ int rv;
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv));
+
+ EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv);
+ EXPECT_FALSE(sock_->IsConnected());
+}
+
+// Connect to a server requesting client authentication, and send it
+// an empty certificate. It should refuse the connection.
+TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendEmptyCert) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+ ssl_options.client_authorities.push_back(
+ GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem"));
+
+ ASSERT_TRUE(ConnectToTestServer(ssl_options));
+
+ base::FilePath certs_dir = GetTestCertsDirectory();
+ SSLConfig ssl_config = kDefaultSSLConfig;
+ ssl_config.send_client_cert = true;
+ ssl_config.client_cert = NULL;
+
+ int rv;
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv));
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock_->IsConnected());
+}
+
+// Connect to a server requesting client authentication. Send it a
+// matching certificate. It should allow the connection.
+TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+ ssl_options.client_authorities.push_back(
+ GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem"));
+
+ ASSERT_TRUE(ConnectToTestServer(ssl_options));
+
+ base::FilePath certs_dir = GetTestCertsDirectory();
+ SSLConfig ssl_config = kDefaultSSLConfig;
+ ssl_config.send_client_cert = true;
+ ssl_config.client_cert = ImportCertFromFile(certs_dir, "client_1.pem");
+
+ // This is required to ensure that signing works with the client
+ // certificate's private key.
+ OpenSSLClientKeyStore::ScopedEVP_PKEY client_private_key;
+ ASSERT_TRUE(LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"),
+ &client_private_key));
+ EXPECT_TRUE(RecordPrivateKey(ssl_config, client_private_key.get()));
+
+ int rv;
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv));
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock_->IsConnected());
+
+ EXPECT_TRUE(CheckSSLClientSocketSentCert());
+
+ sock_->Disconnect();
+ EXPECT_FALSE(sock_->IsConnected());
+}
+
+} // namespace
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc
new file mode 100644
index 00000000000..d07c76ffb49
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_pool.cc
@@ -0,0 +1,664 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/ssl_client_socket_pool.h"
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/metrics/field_trial.h"
+#include "base/metrics/histogram.h"
+#include "base/metrics/sparse_histogram.h"
+#include "base/values.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/net_errors.h"
+#include "net/http/http_proxy_client_socket.h"
+#include "net/http/http_proxy_client_socket_pool.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/socks_client_socket_pool.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/transport_client_socket_pool.h"
+#include "net/ssl/ssl_cert_request_info.h"
+#include "net/ssl/ssl_connection_status_flags.h"
+#include "net/ssl/ssl_info.h"
+
+namespace net {
+
+SSLSocketParams::SSLSocketParams(
+ const scoped_refptr<TransportSocketParams>& transport_params,
+ const scoped_refptr<SOCKSSocketParams>& socks_params,
+ const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
+ ProxyServer::Scheme proxy,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ PrivacyMode privacy_mode,
+ int load_flags,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn)
+ : transport_params_(transport_params),
+ http_proxy_params_(http_proxy_params),
+ socks_params_(socks_params),
+ proxy_(proxy),
+ host_and_port_(host_and_port),
+ ssl_config_(ssl_config),
+ privacy_mode_(privacy_mode),
+ load_flags_(load_flags),
+ force_spdy_over_ssl_(force_spdy_over_ssl),
+ want_spdy_over_npn_(want_spdy_over_npn),
+ ignore_limits_(false) {
+ switch (proxy_) {
+ case ProxyServer::SCHEME_DIRECT:
+ DCHECK(transport_params_.get() != NULL);
+ DCHECK(http_proxy_params_.get() == NULL);
+ DCHECK(socks_params_.get() == NULL);
+ ignore_limits_ = transport_params_->ignore_limits();
+ break;
+ case ProxyServer::SCHEME_HTTP:
+ case ProxyServer::SCHEME_HTTPS:
+ DCHECK(transport_params_.get() == NULL);
+ DCHECK(http_proxy_params_.get() != NULL);
+ DCHECK(socks_params_.get() == NULL);
+ ignore_limits_ = http_proxy_params_->ignore_limits();
+ break;
+ case ProxyServer::SCHEME_SOCKS4:
+ case ProxyServer::SCHEME_SOCKS5:
+ DCHECK(transport_params_.get() == NULL);
+ DCHECK(http_proxy_params_.get() == NULL);
+ DCHECK(socks_params_.get() != NULL);
+ ignore_limits_ = socks_params_->ignore_limits();
+ break;
+ default:
+ LOG(DFATAL) << "unknown proxy type";
+ break;
+ }
+}
+
+SSLSocketParams::~SSLSocketParams() {}
+
+// Timeout for the SSL handshake portion of the connect.
+static const int kSSLHandshakeTimeoutInSeconds = 30;
+
+SSLConnectJob::SSLConnectJob(const std::string& group_name,
+ const scoped_refptr<SSLSocketParams>& params,
+ const base::TimeDelta& timeout_duration,
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ const SSLClientSocketContext& context,
+ Delegate* delegate,
+ NetLog* net_log)
+ : ConnectJob(group_name,
+ timeout_duration,
+ delegate,
+ BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
+ params_(params),
+ transport_pool_(transport_pool),
+ socks_pool_(socks_pool),
+ http_proxy_pool_(http_proxy_pool),
+ client_socket_factory_(client_socket_factory),
+ host_resolver_(host_resolver),
+ context_(context.cert_verifier,
+ context.server_bound_cert_service,
+ context.transport_security_state,
+ (params->privacy_mode() == kPrivacyModeEnabled
+ ? "pm/" + context.ssl_session_cache_shard
+ : context.ssl_session_cache_shard)),
+ callback_(base::Bind(&SSLConnectJob::OnIOComplete,
+ base::Unretained(this))) {}
+
+SSLConnectJob::~SSLConnectJob() {}
+
+LoadState SSLConnectJob::GetLoadState() const {
+ switch (next_state_) {
+ case STATE_TUNNEL_CONNECT_COMPLETE:
+ if (transport_socket_handle_->socket())
+ return LOAD_STATE_ESTABLISHING_PROXY_TUNNEL;
+ // else, fall through.
+ case STATE_TRANSPORT_CONNECT:
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ case STATE_SOCKS_CONNECT:
+ case STATE_SOCKS_CONNECT_COMPLETE:
+ case STATE_TUNNEL_CONNECT:
+ return transport_socket_handle_->GetLoadState();
+ case STATE_SSL_CONNECT:
+ case STATE_SSL_CONNECT_COMPLETE:
+ return LOAD_STATE_SSL_HANDSHAKE;
+ default:
+ NOTREACHED();
+ return LOAD_STATE_IDLE;
+ }
+}
+
+void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle* handle) {
+ // Headers in |error_response_info_| indicate a proxy tunnel setup
+ // problem. See DoTunnelConnectComplete.
+ if (error_response_info_.headers.get()) {
+ handle->set_pending_http_proxy_connection(
+ transport_socket_handle_.release());
+ }
+ handle->set_ssl_error_response_info(error_response_info_);
+ if (!connect_timing_.ssl_start.is_null())
+ handle->set_is_ssl_error(true);
+}
+
+void SSLConnectJob::OnIOComplete(int result) {
+ int rv = DoLoop(result);
+ if (rv != ERR_IO_PENDING)
+ NotifyDelegateOfCompletion(rv); // Deletes |this|.
+}
+
+int SSLConnectJob::DoLoop(int result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_TRANSPORT_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoTransportConnect();
+ break;
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ rv = DoTransportConnectComplete(rv);
+ break;
+ case STATE_SOCKS_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoSOCKSConnect();
+ break;
+ case STATE_SOCKS_CONNECT_COMPLETE:
+ rv = DoSOCKSConnectComplete(rv);
+ break;
+ case STATE_TUNNEL_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoTunnelConnect();
+ break;
+ case STATE_TUNNEL_CONNECT_COMPLETE:
+ rv = DoTunnelConnectComplete(rv);
+ break;
+ case STATE_SSL_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoSSLConnect();
+ break;
+ case STATE_SSL_CONNECT_COMPLETE:
+ rv = DoSSLConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED() << "bad state";
+ rv = ERR_FAILED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+
+ return rv;
+}
+
+int SSLConnectJob::DoTransportConnect() {
+ DCHECK(transport_pool_);
+
+ next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
+ transport_socket_handle_.reset(new ClientSocketHandle());
+ scoped_refptr<TransportSocketParams> transport_params =
+ params_->transport_params();
+ return transport_socket_handle_->Init(
+ group_name(), transport_params,
+ transport_params->destination().priority(), callback_, transport_pool_,
+ net_log());
+}
+
+int SSLConnectJob::DoTransportConnectComplete(int result) {
+ if (result == OK)
+ next_state_ = STATE_SSL_CONNECT;
+
+ return result;
+}
+
+int SSLConnectJob::DoSOCKSConnect() {
+ DCHECK(socks_pool_);
+ next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
+ transport_socket_handle_.reset(new ClientSocketHandle());
+ scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params();
+ return transport_socket_handle_->Init(
+ group_name(), socks_params, socks_params->destination().priority(),
+ callback_, socks_pool_, net_log());
+}
+
+int SSLConnectJob::DoSOCKSConnectComplete(int result) {
+ if (result == OK)
+ next_state_ = STATE_SSL_CONNECT;
+
+ return result;
+}
+
+int SSLConnectJob::DoTunnelConnect() {
+ DCHECK(http_proxy_pool_);
+ next_state_ = STATE_TUNNEL_CONNECT_COMPLETE;
+
+ transport_socket_handle_.reset(new ClientSocketHandle());
+ scoped_refptr<HttpProxySocketParams> http_proxy_params =
+ params_->http_proxy_params();
+ return transport_socket_handle_->Init(
+ group_name(), http_proxy_params,
+ http_proxy_params->destination().priority(), callback_, http_proxy_pool_,
+ net_log());
+}
+
+int SSLConnectJob::DoTunnelConnectComplete(int result) {
+ // Extract the information needed to prompt for appropriate proxy
+ // authentication so that when ClientSocketPoolBaseHelper calls
+ // |GetAdditionalErrorState|, we can easily set the state.
+ if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) {
+ error_response_info_ = transport_socket_handle_->ssl_error_response_info();
+ } else if (result == ERR_PROXY_AUTH_REQUESTED ||
+ result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) {
+ StreamSocket* socket = transport_socket_handle_->socket();
+ HttpProxyClientSocket* tunnel_socket =
+ static_cast<HttpProxyClientSocket*>(socket);
+ error_response_info_ = *tunnel_socket->GetConnectResponseInfo();
+ }
+ if (result < 0)
+ return result;
+
+ next_state_ = STATE_SSL_CONNECT;
+ return result;
+}
+
+int SSLConnectJob::DoSSLConnect() {
+ next_state_ = STATE_SSL_CONNECT_COMPLETE;
+ // Reset the timeout to just the time allowed for the SSL handshake.
+ ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds));
+
+ // If the handle has a fresh socket, get its connect start and DNS times.
+ // This should always be the case.
+ const LoadTimingInfo::ConnectTiming& socket_connect_timing =
+ transport_socket_handle_->connect_timing();
+ if (!transport_socket_handle_->is_reused() &&
+ !socket_connect_timing.connect_start.is_null()) {
+ // Overwriting |connect_start| serves two purposes - it adjusts timing so
+ // |connect_start| doesn't include dns times, and it adjusts the time so
+ // as not to include time spent waiting for an idle socket.
+ connect_timing_.connect_start = socket_connect_timing.connect_start;
+ connect_timing_.dns_start = socket_connect_timing.dns_start;
+ connect_timing_.dns_end = socket_connect_timing.dns_end;
+ }
+
+ connect_timing_.ssl_start = base::TimeTicks::Now();
+
+ ssl_socket_ = client_socket_factory_->CreateSSLClientSocket(
+ transport_socket_handle_.Pass(),
+ params_->host_and_port(),
+ params_->ssl_config(),
+ context_);
+ return ssl_socket_->Connect(callback_);
+}
+
+int SSLConnectJob::DoSSLConnectComplete(int result) {
+ connect_timing_.ssl_end = base::TimeTicks::Now();
+
+ SSLClientSocket::NextProtoStatus status =
+ SSLClientSocket::kNextProtoUnsupported;
+ std::string proto;
+ std::string server_protos;
+ // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket
+ // that hasn't had SSL_ImportFD called on it. If we get a certificate error
+ // here, then we know that we called SSL_ImportFD.
+ if (result == OK || IsCertificateError(result))
+ status = ssl_socket_->GetNextProto(&proto, &server_protos);
+
+ // If we want spdy over npn, make sure it succeeded.
+ if (status == SSLClientSocket::kNextProtoNegotiated) {
+ ssl_socket_->set_was_npn_negotiated(true);
+ NextProto protocol_negotiated =
+ SSLClientSocket::NextProtoFromString(proto);
+ ssl_socket_->set_protocol_negotiated(protocol_negotiated);
+ // If we negotiated a SPDY version, it must have been present in
+ // SSLConfig::next_protos.
+ // TODO(mbelshe): Verify this.
+ if (protocol_negotiated >= kProtoSPDYMinimumVersion &&
+ protocol_negotiated <= kProtoSPDYMaximumVersion) {
+ ssl_socket_->set_was_spdy_negotiated(true);
+ }
+ }
+ if (params_->want_spdy_over_npn() && !ssl_socket_->was_spdy_negotiated())
+ return ERR_NPN_NEGOTIATION_FAILED;
+
+ // Spdy might be turned on by default, or it might be over npn.
+ bool using_spdy = params_->force_spdy_over_ssl() ||
+ params_->want_spdy_over_npn();
+
+ if (result == OK ||
+ ssl_socket_->IgnoreCertError(result, params_->load_flags())) {
+ DCHECK(!connect_timing_.ssl_start.is_null());
+ base::TimeDelta connect_duration =
+ connect_timing_.ssl_end - connect_timing_.ssl_start;
+ if (using_spdy) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SpdyConnectionLatency_2",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ }
+#if defined(SPDY_PROXY_AUTH_ORIGIN)
+ bool using_data_reduction_proxy = params_->host_and_port().Equals(
+ HostPortPair::FromURL(GURL(SPDY_PROXY_AUTH_ORIGIN)));
+ if (using_data_reduction_proxy) {
+ UMA_HISTOGRAM_CUSTOM_TIMES(
+ "Net.SSL_Connection_Latency_DataReductionProxy",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ }
+#endif
+
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_2",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+
+ SSLInfo ssl_info;
+ ssl_socket_->GetSSLInfo(&ssl_info);
+
+ UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSL_CipherSuite",
+ SSLConnectionStatusToCipherSuite(
+ ssl_info.connection_status));
+
+ if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Resume_Handshake",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ } else if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_FULL) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Full_Handshake",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ }
+
+ const std::string& host = params_->host_and_port().host();
+ bool is_google = host == "google.com" ||
+ (host.size() > 11 &&
+ host.rfind(".google.com") == host.size() - 11);
+ if (is_google) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google2",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google_"
+ "Resume_Handshake",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ } else if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_FULL) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google_"
+ "Full_Handshake",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(1),
+ 100);
+ }
+ }
+ }
+
+ if (result == OK || IsCertificateError(result)) {
+ SetSocket(ssl_socket_.PassAs<StreamSocket>());
+ } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) {
+ error_response_info_.cert_request_info = new SSLCertRequestInfo;
+ ssl_socket_->GetSSLCertRequestInfo(
+ error_response_info_.cert_request_info.get());
+ }
+
+ return result;
+}
+
+int SSLConnectJob::ConnectInternal() {
+ switch (params_->proxy()) {
+ case ProxyServer::SCHEME_DIRECT:
+ next_state_ = STATE_TRANSPORT_CONNECT;
+ break;
+ case ProxyServer::SCHEME_HTTP:
+ case ProxyServer::SCHEME_HTTPS:
+ next_state_ = STATE_TUNNEL_CONNECT;
+ break;
+ case ProxyServer::SCHEME_SOCKS4:
+ case ProxyServer::SCHEME_SOCKS5:
+ next_state_ = STATE_SOCKS_CONNECT;
+ break;
+ default:
+ NOTREACHED() << "unknown proxy type";
+ break;
+ }
+ return DoLoop(OK);
+}
+
+SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ const SSLClientSocketContext& context,
+ NetLog* net_log)
+ : transport_pool_(transport_pool),
+ socks_pool_(socks_pool),
+ http_proxy_pool_(http_proxy_pool),
+ client_socket_factory_(client_socket_factory),
+ host_resolver_(host_resolver),
+ context_(context),
+ net_log_(net_log) {
+ base::TimeDelta max_transport_timeout = base::TimeDelta();
+ base::TimeDelta pool_timeout;
+ if (transport_pool_)
+ max_transport_timeout = transport_pool_->ConnectionTimeout();
+ if (socks_pool_) {
+ pool_timeout = socks_pool_->ConnectionTimeout();
+ if (pool_timeout > max_transport_timeout)
+ max_transport_timeout = pool_timeout;
+ }
+ if (http_proxy_pool_) {
+ pool_timeout = http_proxy_pool_->ConnectionTimeout();
+ if (pool_timeout > max_transport_timeout)
+ max_transport_timeout = pool_timeout;
+ }
+ timeout_ = max_transport_timeout +
+ base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds);
+}
+
+SSLClientSocketPool::SSLClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ CertVerifier* cert_verifier,
+ ServerBoundCertService* server_bound_cert_service,
+ TransportSecurityState* transport_security_state,
+ const std::string& ssl_session_cache_shard,
+ ClientSocketFactory* client_socket_factory,
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ SSLConfigService* ssl_config_service,
+ NetLog* net_log)
+ : transport_pool_(transport_pool),
+ socks_pool_(socks_pool),
+ http_proxy_pool_(http_proxy_pool),
+ base_(max_sockets, max_sockets_per_group, histograms,
+ ClientSocketPool::unused_idle_socket_timeout(),
+ ClientSocketPool::used_idle_socket_timeout(),
+ new SSLConnectJobFactory(transport_pool,
+ socks_pool,
+ http_proxy_pool,
+ client_socket_factory,
+ host_resolver,
+ SSLClientSocketContext(
+ cert_verifier,
+ server_bound_cert_service,
+ transport_security_state,
+ ssl_session_cache_shard),
+ net_log)),
+ ssl_config_service_(ssl_config_service) {
+ if (ssl_config_service_.get())
+ ssl_config_service_->AddObserver(this);
+ if (transport_pool_)
+ transport_pool_->AddLayeredPool(this);
+ if (socks_pool_)
+ socks_pool_->AddLayeredPool(this);
+ if (http_proxy_pool_)
+ http_proxy_pool_->AddLayeredPool(this);
+}
+
+SSLClientSocketPool::~SSLClientSocketPool() {
+ if (http_proxy_pool_)
+ http_proxy_pool_->RemoveLayeredPool(this);
+ if (socks_pool_)
+ socks_pool_->RemoveLayeredPool(this);
+ if (transport_pool_)
+ transport_pool_->RemoveLayeredPool(this);
+ if (ssl_config_service_.get())
+ ssl_config_service_->RemoveObserver(this);
+}
+
+scoped_ptr<ConnectJob>
+SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
+ const std::string& group_name,
+ const PoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const {
+ return scoped_ptr<ConnectJob>(
+ new SSLConnectJob(group_name, request.params(), ConnectionTimeout(),
+ transport_pool_, socks_pool_, http_proxy_pool_,
+ client_socket_factory_, host_resolver_,
+ context_, delegate, net_log_));
+}
+
+base::TimeDelta
+SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() const {
+ return timeout_;
+}
+
+int SSLClientSocketPool::RequestSocket(const std::string& group_name,
+ const void* socket_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) {
+ const scoped_refptr<SSLSocketParams>* casted_socket_params =
+ static_cast<const scoped_refptr<SSLSocketParams>*>(socket_params);
+
+ return base_.RequestSocket(group_name, *casted_socket_params, priority,
+ handle, callback, net_log);
+}
+
+void SSLClientSocketPool::RequestSockets(
+ const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
+ const scoped_refptr<SSLSocketParams>* casted_params =
+ static_cast<const scoped_refptr<SSLSocketParams>*>(params);
+
+ base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
+}
+
+void SSLClientSocketPool::CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) {
+ base_.CancelRequest(group_name, handle);
+}
+
+void SSLClientSocketPool::ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
+}
+
+void SSLClientSocketPool::FlushWithError(int error) {
+ base_.FlushWithError(error);
+}
+
+bool SSLClientSocketPool::IsStalled() const {
+ return base_.IsStalled() ||
+ (transport_pool_ && transport_pool_->IsStalled()) ||
+ (socks_pool_ && socks_pool_->IsStalled()) ||
+ (http_proxy_pool_ && http_proxy_pool_->IsStalled());
+}
+
+void SSLClientSocketPool::CloseIdleSockets() {
+ base_.CloseIdleSockets();
+}
+
+int SSLClientSocketPool::IdleSocketCount() const {
+ return base_.idle_socket_count();
+}
+
+int SSLClientSocketPool::IdleSocketCountInGroup(
+ const std::string& group_name) const {
+ return base_.IdleSocketCountInGroup(group_name);
+}
+
+LoadState SSLClientSocketPool::GetLoadState(
+ const std::string& group_name, const ClientSocketHandle* handle) const {
+ return base_.GetLoadState(group_name, handle);
+}
+
+void SSLClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
+ base_.AddLayeredPool(layered_pool);
+}
+
+void SSLClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
+ base_.RemoveLayeredPool(layered_pool);
+}
+
+base::DictionaryValue* SSLClientSocketPool::GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const {
+ base::DictionaryValue* dict = base_.GetInfoAsValue(name, type);
+ if (include_nested_pools) {
+ base::ListValue* list = new base::ListValue();
+ if (transport_pool_) {
+ list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
+ "transport_socket_pool",
+ false));
+ }
+ if (socks_pool_) {
+ list->Append(socks_pool_->GetInfoAsValue("socks_pool",
+ "socks_pool",
+ true));
+ }
+ if (http_proxy_pool_) {
+ list->Append(http_proxy_pool_->GetInfoAsValue("http_proxy_pool",
+ "http_proxy_pool",
+ true));
+ }
+ dict->Set("nested_pools", list);
+ }
+ return dict;
+}
+
+base::TimeDelta SSLClientSocketPool::ConnectionTimeout() const {
+ return base_.ConnectionTimeout();
+}
+
+ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const {
+ return base_.histograms();
+}
+
+void SSLClientSocketPool::OnSSLConfigChanged() {
+ FlushWithError(ERR_NETWORK_CHANGED);
+}
+
+bool SSLClientSocketPool::CloseOneIdleConnection() {
+ if (base_.CloseOneIdleSocket())
+ return true;
+ return base_.CloseOneIdleConnectionInLayeredPool();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h
new file mode 100644
index 00000000000..431a1b7ceea
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_pool.h
@@ -0,0 +1,297 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
+#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
+
+#include <string>
+
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "net/base/privacy_mode.h"
+#include "net/dns/host_resolver.h"
+#include "net/http/http_response_info.h"
+#include "net/proxy/proxy_server.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/ssl/ssl_config_service.h"
+
+namespace net {
+
+class CertVerifier;
+class ClientSocketFactory;
+class ConnectJobFactory;
+class HostPortPair;
+class HttpProxyClientSocketPool;
+class HttpProxySocketParams;
+class SOCKSClientSocketPool;
+class SOCKSSocketParams;
+class SSLClientSocket;
+class TransportClientSocketPool;
+class TransportSecurityState;
+class TransportSocketParams;
+
+// SSLSocketParams only needs the socket params for the transport socket
+// that will be used (denoted by |proxy|).
+class NET_EXPORT_PRIVATE SSLSocketParams
+ : public base::RefCounted<SSLSocketParams> {
+ public:
+ SSLSocketParams(const scoped_refptr<TransportSocketParams>& transport_params,
+ const scoped_refptr<SOCKSSocketParams>& socks_params,
+ const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
+ ProxyServer::Scheme proxy,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ PrivacyMode privacy_mode,
+ int load_flags,
+ bool force_spdy_over_ssl,
+ bool want_spdy_over_npn);
+
+ const scoped_refptr<TransportSocketParams>& transport_params() {
+ return transport_params_;
+ }
+ const scoped_refptr<HttpProxySocketParams>& http_proxy_params() {
+ return http_proxy_params_;
+ }
+ const scoped_refptr<SOCKSSocketParams>& socks_params() {
+ return socks_params_;
+ }
+ ProxyServer::Scheme proxy() const { return proxy_; }
+ const HostPortPair& host_and_port() const { return host_and_port_; }
+ const SSLConfig& ssl_config() const { return ssl_config_; }
+ PrivacyMode privacy_mode() const { return privacy_mode_; }
+ int load_flags() const { return load_flags_; }
+ bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; }
+ bool want_spdy_over_npn() const { return want_spdy_over_npn_; }
+ bool ignore_limits() const { return ignore_limits_; }
+
+ private:
+ friend class base::RefCounted<SSLSocketParams>;
+ ~SSLSocketParams();
+
+ const scoped_refptr<TransportSocketParams> transport_params_;
+ const scoped_refptr<HttpProxySocketParams> http_proxy_params_;
+ const scoped_refptr<SOCKSSocketParams> socks_params_;
+ const ProxyServer::Scheme proxy_;
+ const HostPortPair host_and_port_;
+ const SSLConfig ssl_config_;
+ const PrivacyMode privacy_mode_;
+ const int load_flags_;
+ const bool force_spdy_over_ssl_;
+ const bool want_spdy_over_npn_;
+ bool ignore_limits_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLSocketParams);
+};
+
+// SSLConnectJob handles the SSL handshake after setting up the underlying
+// connection as specified in the params.
+class SSLConnectJob : public ConnectJob {
+ public:
+ SSLConnectJob(
+ const std::string& group_name,
+ const scoped_refptr<SSLSocketParams>& params,
+ const base::TimeDelta& timeout_duration,
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ const SSLClientSocketContext& context,
+ Delegate* delegate,
+ NetLog* net_log);
+ virtual ~SSLConnectJob();
+
+ // ConnectJob methods.
+ virtual LoadState GetLoadState() const OVERRIDE;
+
+ virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE;
+
+ private:
+ enum State {
+ STATE_TRANSPORT_CONNECT,
+ STATE_TRANSPORT_CONNECT_COMPLETE,
+ STATE_SOCKS_CONNECT,
+ STATE_SOCKS_CONNECT_COMPLETE,
+ STATE_TUNNEL_CONNECT,
+ STATE_TUNNEL_CONNECT_COMPLETE,
+ STATE_SSL_CONNECT,
+ STATE_SSL_CONNECT_COMPLETE,
+ STATE_NONE,
+ };
+
+ void OnIOComplete(int result);
+
+ // Runs the state transition loop.
+ int DoLoop(int result);
+
+ int DoTransportConnect();
+ int DoTransportConnectComplete(int result);
+ int DoSOCKSConnect();
+ int DoSOCKSConnectComplete(int result);
+ int DoTunnelConnect();
+ int DoTunnelConnectComplete(int result);
+ int DoSSLConnect();
+ int DoSSLConnectComplete(int result);
+
+ // Starts the SSL connection process. Returns OK on success and
+ // ERR_IO_PENDING if it cannot immediately service the request.
+ // Otherwise, it returns a net error code.
+ virtual int ConnectInternal() OVERRIDE;
+
+ scoped_refptr<SSLSocketParams> params_;
+ TransportClientSocketPool* const transport_pool_;
+ SOCKSClientSocketPool* const socks_pool_;
+ HttpProxyClientSocketPool* const http_proxy_pool_;
+ ClientSocketFactory* const client_socket_factory_;
+ HostResolver* const host_resolver_;
+
+ const SSLClientSocketContext context_;
+
+ State next_state_;
+ CompletionCallback callback_;
+ scoped_ptr<ClientSocketHandle> transport_socket_handle_;
+ scoped_ptr<SSLClientSocket> ssl_socket_;
+
+ HttpResponseInfo error_response_info_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLConnectJob);
+};
+
+class NET_EXPORT_PRIVATE SSLClientSocketPool
+ : public ClientSocketPool,
+ public LayeredPool,
+ public SSLConfigService::Observer {
+ public:
+ // Only the pools that will be used are required. i.e. if you never
+ // try to create an SSL over SOCKS socket, |socks_pool| may be NULL.
+ SSLClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ CertVerifier* cert_verifier,
+ ServerBoundCertService* server_bound_cert_service,
+ TransportSecurityState* transport_security_state,
+ const std::string& ssl_session_cache_shard,
+ ClientSocketFactory* client_socket_factory,
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ SSLConfigService* ssl_config_service,
+ NetLog* net_log);
+
+ virtual ~SSLClientSocketPool();
+
+ // ClientSocketPool implementation.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* connect_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) OVERRIDE;
+
+ virtual void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) OVERRIDE;
+
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) OVERRIDE;
+
+ virtual void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
+
+ virtual void FlushWithError(int error) OVERRIDE;
+
+ virtual bool IsStalled() const OVERRIDE;
+
+ virtual void CloseIdleSockets() OVERRIDE;
+
+ virtual int IdleSocketCount() const OVERRIDE;
+
+ virtual int IdleSocketCountInGroup(
+ const std::string& group_name) const OVERRIDE;
+
+ virtual LoadState GetLoadState(
+ const std::string& group_name,
+ const ClientSocketHandle* handle) const OVERRIDE;
+
+ virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
+
+ virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
+
+ virtual base::DictionaryValue* GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const OVERRIDE;
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+
+ virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+
+ // LayeredPool implementation.
+ virtual bool CloseOneIdleConnection() OVERRIDE;
+
+ private:
+ typedef ClientSocketPoolBase<SSLSocketParams> PoolBase;
+
+ // SSLConfigService::Observer implementation.
+
+ // When the user changes the SSL config, we flush all idle sockets so they
+ // won't get re-used.
+ virtual void OnSSLConfigChanged() OVERRIDE;
+
+ class SSLConnectJobFactory : public PoolBase::ConnectJobFactory {
+ public:
+ SSLConnectJobFactory(
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ const SSLClientSocketContext& context,
+ NetLog* net_log);
+
+ virtual ~SSLConnectJobFactory() {}
+
+ // ClientSocketPoolBase::ConnectJobFactory methods.
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const PoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const OVERRIDE;
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+
+ private:
+ TransportClientSocketPool* const transport_pool_;
+ SOCKSClientSocketPool* const socks_pool_;
+ HttpProxyClientSocketPool* const http_proxy_pool_;
+ ClientSocketFactory* const client_socket_factory_;
+ HostResolver* const host_resolver_;
+ const SSLClientSocketContext context_;
+ base::TimeDelta timeout_;
+ NetLog* net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory);
+ };
+
+ TransportClientSocketPool* const transport_pool_;
+ SOCKSClientSocketPool* const socks_pool_;
+ HttpProxyClientSocketPool* const http_proxy_pool_;
+ PoolBase base_;
+ const scoped_refptr<SSLConfigService> ssl_config_service_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
+};
+
+REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams);
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc
new file mode 100644
index 00000000000..280f6e7af1a
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc
@@ -0,0 +1,857 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/http/http_proxy_client_socket_pool.h"
+
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/strings/string_util.h"
+#include "base/strings/utf_string_conversions.h"
+#include "base/time/time.h"
+#include "net/base/auth.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/load_timing_info_test_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/test_completion_callback.h"
+#include "net/cert/cert_verifier.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/http/http_auth_handler_factory.h"
+#include "net/http/http_network_session.h"
+#include "net/http/http_request_headers.h"
+#include "net/http/http_response_headers.h"
+#include "net/http/http_server_properties_impl.h"
+#include "net/proxy/proxy_service.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/next_proto.h"
+#include "net/socket/socket_test_util.h"
+#include "net/spdy/spdy_session.h"
+#include "net/spdy/spdy_session_pool.h"
+#include "net/spdy/spdy_test_util_common.h"
+#include "net/ssl/ssl_config_service_defaults.h"
+#include "net/test/test_certificate_data.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+const int kMaxSockets = 32;
+const int kMaxSocketsPerGroup = 6;
+
+// Make sure |handle|'s load times are set correctly. DNS and connect start
+// times comes from mock client sockets in these tests, so primarily serves to
+// check those times were copied, and ssl times / connect end are set correctly.
+void TestLoadTimingInfo(const ClientSocketHandle& handle) {
+ LoadTimingInfo load_timing_info;
+ EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+ // None of these tests use a NetLog.
+ EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasTimes(
+ load_timing_info.connect_timing,
+ CONNECT_TIMING_HAS_SSL_TIMES | CONNECT_TIMING_HAS_DNS_TIMES);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+// Just like TestLoadTimingInfo, except DNS times are expected to be null, for
+// tests over proxies that do DNS lookups themselves.
+void TestLoadTimingInfoNoDns(const ClientSocketHandle& handle) {
+ LoadTimingInfo load_timing_info;
+ EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ // None of these tests use a NetLog.
+ EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+
+ ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
+ CONNECT_TIMING_HAS_SSL_TIMES);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+class SSLClientSocketPoolTest
+ : public testing::Test,
+ public ::testing::WithParamInterface<NextProto> {
+ protected:
+ SSLClientSocketPoolTest()
+ : proxy_service_(ProxyService::CreateDirect()),
+ ssl_config_service_(new SSLConfigServiceDefaults),
+ http_auth_handler_factory_(
+ HttpAuthHandlerFactory::CreateDefault(&host_resolver_)),
+ session_(CreateNetworkSession()),
+ direct_transport_socket_params_(
+ new TransportSocketParams(HostPortPair("host", 443),
+ MEDIUM,
+ false,
+ false,
+ OnHostResolutionCallback())),
+ transport_histograms_("MockTCP"),
+ transport_socket_pool_(kMaxSockets,
+ kMaxSocketsPerGroup,
+ &transport_histograms_,
+ &socket_factory_),
+ proxy_transport_socket_params_(
+ new TransportSocketParams(HostPortPair("proxy", 443),
+ MEDIUM,
+ false,
+ false,
+ OnHostResolutionCallback())),
+ socks_socket_params_(
+ new SOCKSSocketParams(proxy_transport_socket_params_,
+ true,
+ HostPortPair("sockshost", 443),
+ MEDIUM)),
+ socks_histograms_("MockSOCKS"),
+ socks_socket_pool_(kMaxSockets,
+ kMaxSocketsPerGroup,
+ &socks_histograms_,
+ &transport_socket_pool_),
+ http_proxy_socket_params_(
+ new HttpProxySocketParams(proxy_transport_socket_params_,
+ NULL,
+ GURL("http://host"),
+ std::string(),
+ HostPortPair("host", 80),
+ session_->http_auth_cache(),
+ session_->http_auth_handler_factory(),
+ session_->spdy_session_pool(),
+ true)),
+ http_proxy_histograms_("MockHttpProxy"),
+ http_proxy_socket_pool_(kMaxSockets,
+ kMaxSocketsPerGroup,
+ &http_proxy_histograms_,
+ &host_resolver_,
+ &transport_socket_pool_,
+ NULL,
+ NULL) {
+ scoped_refptr<SSLConfigService> ssl_config_service(
+ new SSLConfigServiceDefaults);
+ ssl_config_service->GetSSLConfig(&ssl_config_);
+ }
+
+ void CreatePool(bool transport_pool, bool http_proxy_pool, bool socks_pool) {
+ ssl_histograms_.reset(new ClientSocketPoolHistograms("SSLUnitTest"));
+ pool_.reset(new SSLClientSocketPool(
+ kMaxSockets,
+ kMaxSocketsPerGroup,
+ ssl_histograms_.get(),
+ NULL /* host_resolver */,
+ NULL /* cert_verifier */,
+ NULL /* server_bound_cert_service */,
+ NULL /* transport_security_state */,
+ std::string() /* ssl_session_cache_shard */,
+ &socket_factory_,
+ transport_pool ? &transport_socket_pool_ : NULL,
+ socks_pool ? &socks_socket_pool_ : NULL,
+ http_proxy_pool ? &http_proxy_socket_pool_ : NULL,
+ NULL,
+ NULL));
+ }
+
+ scoped_refptr<SSLSocketParams> SSLParams(ProxyServer::Scheme proxy,
+ bool want_spdy_over_npn) {
+ return make_scoped_refptr(new SSLSocketParams(
+ proxy == ProxyServer::SCHEME_DIRECT ? direct_transport_socket_params_
+ : NULL,
+ proxy == ProxyServer::SCHEME_SOCKS5 ? socks_socket_params_ : NULL,
+ proxy == ProxyServer::SCHEME_HTTP ? http_proxy_socket_params_ : NULL,
+ proxy,
+ HostPortPair("host", 443),
+ ssl_config_,
+ kPrivacyModeDisabled,
+ 0,
+ false,
+ want_spdy_over_npn));
+ }
+
+ void AddAuthToCache() {
+ const base::string16 kFoo(ASCIIToUTF16("foo"));
+ const base::string16 kBar(ASCIIToUTF16("bar"));
+ session_->http_auth_cache()->Add(GURL("http://proxy:443/"),
+ "MyRealm1",
+ HttpAuth::AUTH_SCHEME_BASIC,
+ "Basic realm=MyRealm1",
+ AuthCredentials(kFoo, kBar),
+ "/");
+ }
+
+ HttpNetworkSession* CreateNetworkSession() {
+ HttpNetworkSession::Params params;
+ params.host_resolver = &host_resolver_;
+ params.cert_verifier = cert_verifier_.get();
+ params.transport_security_state = transport_security_state_.get();
+ params.proxy_service = proxy_service_.get();
+ params.client_socket_factory = &socket_factory_;
+ params.ssl_config_service = ssl_config_service_.get();
+ params.http_auth_handler_factory = http_auth_handler_factory_.get();
+ params.http_server_properties =
+ http_server_properties_.GetWeakPtr();
+ params.enable_spdy_compression = false;
+ params.spdy_default_protocol = GetParam();
+ return new HttpNetworkSession(params);
+ }
+
+ void TestIPPoolingDisabled(SSLSocketDataProvider* ssl);
+
+ MockClientSocketFactory socket_factory_;
+ MockCachingHostResolver host_resolver_;
+ scoped_ptr<CertVerifier> cert_verifier_;
+ scoped_ptr<TransportSecurityState> transport_security_state_;
+ const scoped_ptr<ProxyService> proxy_service_;
+ const scoped_refptr<SSLConfigService> ssl_config_service_;
+ const scoped_ptr<HttpAuthHandlerFactory> http_auth_handler_factory_;
+ HttpServerPropertiesImpl http_server_properties_;
+ const scoped_refptr<HttpNetworkSession> session_;
+
+ scoped_refptr<TransportSocketParams> direct_transport_socket_params_;
+ ClientSocketPoolHistograms transport_histograms_;
+ MockTransportClientSocketPool transport_socket_pool_;
+
+ scoped_refptr<TransportSocketParams> proxy_transport_socket_params_;
+
+ scoped_refptr<SOCKSSocketParams> socks_socket_params_;
+ ClientSocketPoolHistograms socks_histograms_;
+ MockSOCKSClientSocketPool socks_socket_pool_;
+
+ scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_;
+ ClientSocketPoolHistograms http_proxy_histograms_;
+ HttpProxyClientSocketPool http_proxy_socket_pool_;
+
+ SSLConfig ssl_config_;
+ scoped_ptr<ClientSocketPoolHistograms> ssl_histograms_;
+ scoped_ptr<SSLClientSocketPool> pool_;
+};
+
+INSTANTIATE_TEST_CASE_P(
+ NextProto,
+ SSLClientSocketPoolTest,
+ testing::Values(kProtoSPDY2, kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2,
+ kProtoHTTP2Draft04));
+
+TEST_P(SSLClientSocketPoolTest, TCPFail) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED));
+ socket_factory_.AddSocketDataProvider(&data);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", params, MEDIUM, CompletionCallback(), pool_.get(),
+ BoundNetLog());
+ EXPECT_EQ(ERR_CONNECTION_FAILED, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, TCPFailAsync) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_FAILED));
+ socket_factory_.AddSocketDataProvider(&data);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, BasicDirect) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, BasicDirectAsync) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, DirectCertError) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, ERR_CERT_COMMON_NAME_INVALID);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, DirectSSLError) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_TRUE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, DirectWithNPN) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.SetNextProto(kProtoHTTP11);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+ SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket());
+ EXPECT_TRUE(ssl_socket->WasNpnNegotiated());
+}
+
+TEST_P(SSLClientSocketPoolTest, DirectNoSPDY) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.SetNextProto(kProtoHTTP11);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ true);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_NPN_NEGOTIATION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_TRUE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, DirectGotSPDY) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.SetNextProto(GetParam());
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ true);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+
+ SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket());
+ EXPECT_TRUE(ssl_socket->WasNpnNegotiated());
+ std::string proto;
+ std::string server_protos;
+ ssl_socket->GetNextProto(&proto, &server_protos);
+ EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto));
+}
+
+TEST_P(SSLClientSocketPoolTest, DirectGotBonusSPDY) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.SetNextProto(GetParam());
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_DIRECT,
+ true);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfo(handle);
+
+ SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket());
+ EXPECT_TRUE(ssl_socket->WasNpnNegotiated());
+ std::string proto;
+ std::string server_protos;
+ ssl_socket->GetNextProto(&proto, &server_protos);
+ EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto));
+}
+
+TEST_P(SSLClientSocketPoolTest, SOCKSFail) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED));
+ socket_factory_.AddSocketDataProvider(&data);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_CONNECTION_FAILED, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, SOCKSFailAsync) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_FAILED));
+ socket_factory_.AddSocketDataProvider(&data);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, SOCKSBasic) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ // SOCKS5 generally has no DNS times, but the mock SOCKS5 sockets used here
+ // don't go through the real logic, unlike in the HTTP proxy tests.
+ TestLoadTimingInfo(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, SOCKSBasicAsync) {
+ StaticSocketDataProvider data;
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_SOCKS5,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ // SOCKS5 generally has no DNS times, but the mock SOCKS5 sockets used here
+ // don't go through the real logic, unlike in the HTTP proxy tests.
+ TestLoadTimingInfo(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, HttpProxyFail) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED));
+ socket_factory_.AddSocketDataProvider(&data);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, HttpProxyFailAsync) {
+ StaticSocketDataProvider data;
+ data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_FAILED));
+ socket_factory_.AddSocketDataProvider(&data);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+}
+
+TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) {
+ MockWrite writes[] = {
+ MockWrite(SYNCHRONOUS,
+ "CONNECT host:80 HTTP/1.1\r\n"
+ "Host: host\r\n"
+ "Proxy-Connection: keep-alive\r\n"
+ "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
+ };
+ MockRead reads[] = {
+ MockRead(SYNCHRONOUS, "HTTP/1.1 200 Connection Established\r\n\r\n"),
+ };
+ StaticSocketDataProvider data(reads, arraysize(reads), writes,
+ arraysize(writes));
+ data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
+ socket_factory_.AddSocketDataProvider(&data);
+ AddAuthToCache();
+ SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoNoDns(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) {
+ MockWrite writes[] = {
+ MockWrite("CONNECT host:80 HTTP/1.1\r\n"
+ "Host: host\r\n"
+ "Proxy-Connection: keep-alive\r\n"
+ "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
+ };
+ MockRead reads[] = {
+ MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"),
+ };
+ StaticSocketDataProvider data(reads, arraysize(reads), writes,
+ arraysize(writes));
+ socket_factory_.AddSocketDataProvider(&data);
+ AddAuthToCache();
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoNoDns(handle);
+}
+
+TEST_P(SSLClientSocketPoolTest, NeedProxyAuth) {
+ MockWrite writes[] = {
+ MockWrite("CONNECT host:80 HTTP/1.1\r\n"
+ "Host: host\r\n"
+ "Proxy-Connection: keep-alive\r\n\r\n"),
+ };
+ MockRead reads[] = {
+ MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
+ MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
+ MockRead("Content-Length: 10\r\n\r\n"),
+ MockRead("0123456789"),
+ };
+ StaticSocketDataProvider data(reads, arraysize(reads), writes,
+ arraysize(writes));
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(false, true /* http proxy pool */, true /* socks pool */);
+ scoped_refptr<SSLSocketParams> params = SSLParams(ProxyServer::SCHEME_HTTP,
+ false);
+
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ int rv = handle.Init(
+ "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_FALSE(handle.is_ssl_error());
+ const HttpResponseInfo& tunnel_info = handle.ssl_error_response_info();
+ EXPECT_EQ(tunnel_info.headers->response_code(), 407);
+ scoped_ptr<ClientSocketHandle> tunnel_handle(
+ handle.release_pending_http_proxy_connection());
+ EXPECT_TRUE(tunnel_handle->socket());
+ EXPECT_FALSE(tunnel_handle->socket()->IsConnected());
+}
+
+TEST_P(SSLClientSocketPoolTest, IPPooling) {
+ const int kTestPort = 80;
+ struct TestHosts {
+ std::string name;
+ std::string iplist;
+ SpdySessionKey key;
+ AddressList addresses;
+ } test_hosts[] = {
+ { "www.webkit.org", "192.0.2.33,192.168.0.1,192.168.0.5" },
+ { "code.google.com", "192.168.0.2,192.168.0.3,192.168.0.5" },
+ { "js.webkit.org", "192.168.0.4,192.168.0.1,192.0.2.33" },
+ };
+
+ host_resolver_.set_synchronous_mode(true);
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) {
+ host_resolver_.rules()->AddIPLiteralRule(
+ test_hosts[i].name, test_hosts[i].iplist, std::string());
+
+ // This test requires that the HostResolver cache be populated. Normal
+ // code would have done this already, but we do it manually.
+ HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort));
+ host_resolver_.Resolve(info, &test_hosts[i].addresses, CompletionCallback(),
+ NULL, BoundNetLog());
+
+ // Setup a SpdySessionKey
+ test_hosts[i].key = SpdySessionKey(
+ HostPortPair(test_hosts[i].name, kTestPort), ProxyServer::Direct(),
+ kPrivacyModeDisabled);
+ }
+
+ MockRead reads[] = {
+ MockRead(ASYNC, ERR_IO_PENDING),
+ };
+ StaticSocketDataProvider data(reads, arraysize(reads), NULL, 0);
+ socket_factory_.AddSocketDataProvider(&data);
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.cert = X509Certificate::CreateFromBytes(
+ reinterpret_cast<const char*>(webkit_der), sizeof(webkit_der));
+ ssl.SetNextProto(GetParam());
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ base::WeakPtr<SpdySession> spdy_session =
+ CreateSecureSpdySession(session_, test_hosts[0].key, BoundNetLog());
+
+ EXPECT_TRUE(
+ HasSpdySession(session_->spdy_session_pool(), test_hosts[0].key));
+ EXPECT_FALSE(
+ HasSpdySession(session_->spdy_session_pool(), test_hosts[1].key));
+ EXPECT_TRUE(
+ HasSpdySession(session_->spdy_session_pool(), test_hosts[2].key));
+
+ session_->spdy_session_pool()->CloseAllSessions();
+}
+
+void SSLClientSocketPoolTest::TestIPPoolingDisabled(
+ SSLSocketDataProvider* ssl) {
+ const int kTestPort = 80;
+ struct TestHosts {
+ std::string name;
+ std::string iplist;
+ SpdySessionKey key;
+ AddressList addresses;
+ } test_hosts[] = {
+ { "www.webkit.org", "192.0.2.33,192.168.0.1,192.168.0.5" },
+ { "js.webkit.com", "192.168.0.4,192.168.0.1,192.0.2.33" },
+ };
+
+ TestCompletionCallback callback;
+ int rv;
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) {
+ host_resolver_.rules()->AddIPLiteralRule(
+ test_hosts[i].name, test_hosts[i].iplist, std::string());
+
+ // This test requires that the HostResolver cache be populated. Normal
+ // code would have done this already, but we do it manually.
+ HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort));
+ rv = host_resolver_.Resolve(info, &test_hosts[i].addresses,
+ callback.callback(), NULL, BoundNetLog());
+ EXPECT_EQ(OK, callback.GetResult(rv));
+
+ // Setup a SpdySessionKey
+ test_hosts[i].key = SpdySessionKey(
+ HostPortPair(test_hosts[i].name, kTestPort), ProxyServer::Direct(),
+ kPrivacyModeDisabled);
+ }
+
+ MockRead reads[] = {
+ MockRead(ASYNC, ERR_IO_PENDING),
+ };
+ StaticSocketDataProvider data(reads, arraysize(reads), NULL, 0);
+ socket_factory_.AddSocketDataProvider(&data);
+ socket_factory_.AddSSLSocketDataProvider(ssl);
+
+ CreatePool(true /* tcp pool */, false, false);
+ base::WeakPtr<SpdySession> spdy_session =
+ CreateSecureSpdySession(session_, test_hosts[0].key, BoundNetLog());
+
+ EXPECT_TRUE(
+ HasSpdySession(session_->spdy_session_pool(), test_hosts[0].key));
+ EXPECT_FALSE(
+ HasSpdySession(session_->spdy_session_pool(), test_hosts[1].key));
+
+ session_->spdy_session_pool()->CloseAllSessions();
+}
+
+// Verifies that an SSL connection with client authentication disables SPDY IP
+// pooling.
+TEST_P(SSLClientSocketPoolTest, IPPoolingClientCert) {
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.cert = X509Certificate::CreateFromBytes(
+ reinterpret_cast<const char*>(webkit_der), sizeof(webkit_der));
+ ssl.client_cert_sent = true;
+ ssl.SetNextProto(GetParam());
+ TestIPPoolingDisabled(&ssl);
+}
+
+// Verifies that an SSL connection with channel ID disables SPDY IP pooling.
+TEST_P(SSLClientSocketPoolTest, IPPoolingChannelID) {
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.channel_id_sent = true;
+ ssl.SetNextProto(GetParam());
+ TestIPPoolingDisabled(&ssl);
+}
+
+// It would be nice to also test the timeouts in SSLClientSocketPool.
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc
new file mode 100644
index 00000000000..f791928580f
--- /dev/null
+++ b/chromium/net/socket/ssl_client_socket_unittest.cc
@@ -0,0 +1,1798 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/ssl_client_socket.h"
+
+#include "base/callback_helpers.h"
+#include "base/memory/ref_counted.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_log_unittest.h"
+#include "net/base/test_completion_callback.h"
+#include "net/base/test_data_directory.h"
+#include "net/cert/mock_cert_verifier.h"
+#include "net/cert/test_root_certs.h"
+#include "net/dns/host_resolver.h"
+#include "net/http/transport_security_state.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/tcp_client_socket.h"
+#include "net/ssl/ssl_cert_request_info.h"
+#include "net/ssl/ssl_config_service.h"
+#include "net/test/cert_test_util.h"
+#include "net/test/spawned_test_server/spawned_test_server.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+//-----------------------------------------------------------------------------
+
+namespace net {
+
+namespace {
+
+const SSLConfig kDefaultSSLConfig;
+
+// WrappedStreamSocket is a base class that wraps an existing StreamSocket,
+// forwarding the Socket and StreamSocket interfaces to the underlying
+// transport.
+// This is to provide a common base class for subclasses to override specific
+// StreamSocket methods for testing, while still communicating with a 'real'
+// StreamSocket.
+class WrappedStreamSocket : public StreamSocket {
+ public:
+ explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport)
+ : transport_(transport.Pass()) {}
+ virtual ~WrappedStreamSocket() {}
+
+ // StreamSocket implementation:
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ return transport_->Connect(callback);
+ }
+ virtual void Disconnect() OVERRIDE { transport_->Disconnect(); }
+ virtual bool IsConnected() const OVERRIDE {
+ return transport_->IsConnected();
+ }
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ return transport_->IsConnectedAndIdle();
+ }
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ return transport_->GetPeerAddress(address);
+ }
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ return transport_->GetLocalAddress(address);
+ }
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return transport_->NetLog();
+ }
+ virtual void SetSubresourceSpeculation() OVERRIDE {
+ transport_->SetSubresourceSpeculation();
+ }
+ virtual void SetOmniboxSpeculation() OVERRIDE {
+ transport_->SetOmniboxSpeculation();
+ }
+ virtual bool WasEverUsed() const OVERRIDE {
+ return transport_->WasEverUsed();
+ }
+ virtual bool UsingTCPFastOpen() const OVERRIDE {
+ return transport_->UsingTCPFastOpen();
+ }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return transport_->WasNpnNegotiated();
+ }
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return transport_->GetNegotiatedProtocol();
+ }
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ return transport_->GetSSLInfo(ssl_info);
+ }
+
+ // Socket implementation:
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return transport_->Read(buf, buf_len, callback);
+ }
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return transport_->Write(buf, buf_len, callback);
+ }
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
+ return transport_->SetReceiveBufferSize(size);
+ }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE {
+ return transport_->SetSendBufferSize(size);
+ }
+
+ protected:
+ scoped_ptr<StreamSocket> transport_;
+};
+
+// ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that
+// will ensure a certain amount of data is internally buffered before
+// satisfying a Read() request. It exists to mimic OS-level internal
+// buffering, but in a way to guarantee that X number of bytes will be
+// returned to callers of Read(), regardless of how quickly the OS receives
+// them from the TestServer.
+class ReadBufferingStreamSocket : public WrappedStreamSocket {
+ public:
+ explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport);
+ virtual ~ReadBufferingStreamSocket() {}
+
+ // Socket implementation:
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // Sets the internal buffer to |size|. This must not be greater than
+ // the largest value supplied to Read() - that is, it does not handle
+ // having "leftovers" at the end of Read().
+ // Each call to Read() will be prevented from completion until at least
+ // |size| data has been read.
+ // Set to 0 to turn off buffering, causing Read() to transparently
+ // read via the underlying transport.
+ void SetBufferSize(int size);
+
+ private:
+ enum State {
+ STATE_NONE,
+ STATE_READ,
+ STATE_READ_COMPLETE,
+ };
+
+ int DoLoop(int result);
+ int DoRead();
+ int DoReadComplete(int result);
+ void OnReadCompleted(int result);
+
+ State state_;
+ scoped_refptr<GrowableIOBuffer> read_buffer_;
+ int buffer_size_;
+
+ scoped_refptr<IOBuffer> user_read_buf_;
+ CompletionCallback user_read_callback_;
+};
+
+ReadBufferingStreamSocket::ReadBufferingStreamSocket(
+ scoped_ptr<StreamSocket> transport)
+ : WrappedStreamSocket(transport.Pass()),
+ read_buffer_(new GrowableIOBuffer()),
+ buffer_size_(0) {}
+
+void ReadBufferingStreamSocket::SetBufferSize(int size) {
+ DCHECK(!user_read_buf_.get());
+ buffer_size_ = size;
+ read_buffer_->SetCapacity(size);
+}
+
+int ReadBufferingStreamSocket::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (buffer_size_ == 0)
+ return transport_->Read(buf, buf_len, callback);
+
+ if (buf_len < buffer_size_)
+ return ERR_UNEXPECTED;
+
+ state_ = STATE_READ;
+ user_read_buf_ = buf;
+ int result = DoLoop(OK);
+ if (result == ERR_IO_PENDING)
+ user_read_callback_ = callback;
+ else
+ user_read_buf_ = NULL;
+ return result;
+}
+
+int ReadBufferingStreamSocket::DoLoop(int result) {
+ int rv = result;
+ do {
+ State current_state = state_;
+ state_ = STATE_NONE;
+ switch (current_state) {
+ case STATE_READ:
+ rv = DoRead();
+ break;
+ case STATE_READ_COMPLETE:
+ rv = DoReadComplete(rv);
+ break;
+ case STATE_NONE:
+ default:
+ NOTREACHED() << "Unexpected state: " << current_state;
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && state_ != STATE_NONE);
+ return rv;
+}
+
+int ReadBufferingStreamSocket::DoRead() {
+ state_ = STATE_READ_COMPLETE;
+ int rv =
+ transport_->Read(read_buffer_.get(),
+ read_buffer_->RemainingCapacity(),
+ base::Bind(&ReadBufferingStreamSocket::OnReadCompleted,
+ base::Unretained(this)));
+ return rv;
+}
+
+int ReadBufferingStreamSocket::DoReadComplete(int result) {
+ state_ = STATE_NONE;
+ if (result <= 0)
+ return result;
+
+ read_buffer_->set_offset(read_buffer_->offset() + result);
+ if (read_buffer_->RemainingCapacity() > 0) {
+ state_ = STATE_READ;
+ return OK;
+ }
+
+ memcpy(user_read_buf_->data(),
+ read_buffer_->StartOfBuffer(),
+ read_buffer_->capacity());
+ read_buffer_->set_offset(0);
+ return read_buffer_->capacity();
+}
+
+void ReadBufferingStreamSocket::OnReadCompleted(int result) {
+ result = DoLoop(result);
+ if (result == ERR_IO_PENDING)
+ return;
+
+ user_read_buf_ = NULL;
+ base::ResetAndReturn(&user_read_callback_).Run(result);
+}
+
+// Simulates synchronously receiving an error during Read() or Write()
+class SynchronousErrorStreamSocket : public WrappedStreamSocket {
+ public:
+ explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport);
+ virtual ~SynchronousErrorStreamSocket() {}
+
+ // Socket implementation:
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // Sets the next Read() call and all future calls to return |error|.
+ // If there is already a pending asynchronous read, the configured error
+ // will not be returned until that asynchronous read has completed and Read()
+ // is called again.
+ void SetNextReadError(Error error) {
+ DCHECK_GE(0, error);
+ have_read_error_ = true;
+ pending_read_error_ = error;
+ }
+
+ // Sets the next Write() call and all future calls to return |error|.
+ // If there is already a pending asynchronous write, the configured error
+ // will not be returned until that asynchronous write has completed and
+ // Write() is called again.
+ void SetNextWriteError(Error error) {
+ DCHECK_GE(0, error);
+ have_write_error_ = true;
+ pending_write_error_ = error;
+ }
+
+ private:
+ bool have_read_error_;
+ int pending_read_error_;
+
+ bool have_write_error_;
+ int pending_write_error_;
+
+ DISALLOW_COPY_AND_ASSIGN(SynchronousErrorStreamSocket);
+};
+
+SynchronousErrorStreamSocket::SynchronousErrorStreamSocket(
+ scoped_ptr<StreamSocket> transport)
+ : WrappedStreamSocket(transport.Pass()),
+ have_read_error_(false),
+ pending_read_error_(OK),
+ have_write_error_(false),
+ pending_write_error_(OK) {}
+
+int SynchronousErrorStreamSocket::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (have_read_error_)
+ return pending_read_error_;
+ return transport_->Read(buf, buf_len, callback);
+}
+
+int SynchronousErrorStreamSocket::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ if (have_write_error_)
+ return pending_write_error_;
+ return transport_->Write(buf, buf_len, callback);
+}
+
+// FakeBlockingStreamSocket wraps an existing StreamSocket and simulates the
+// underlying transport needing to complete things asynchronously in a
+// deterministic manner (e.g.: independent of the TestServer and the OS's
+// semantics).
+class FakeBlockingStreamSocket : public WrappedStreamSocket {
+ public:
+ explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport);
+ virtual ~FakeBlockingStreamSocket() {}
+
+ // Socket implementation:
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return read_state_.RunWrappedFunction(buf, buf_len, callback);
+ }
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return write_state_.RunWrappedFunction(buf, buf_len, callback);
+ }
+
+ // Causes the next call to Read() to return ERR_IO_PENDING, not completing
+ // (invoking the callback) until UnblockRead() has been called and the
+ // underlying transport has completed.
+ void SetNextReadShouldBlock() { read_state_.SetShouldBlock(); }
+ void UnblockRead() { read_state_.Unblock(); }
+
+ // Causes the next call to Write() to return ERR_IO_PENDING, not completing
+ // (invoking the callback) until UnblockWrite() has been called and the
+ // underlying transport has completed.
+ void SetNextWriteShouldBlock() { write_state_.SetShouldBlock(); }
+ void UnblockWrite() { write_state_.Unblock(); }
+
+ private:
+ // Tracks the state for simulating a blocking Read/Write operation.
+ class BlockingState {
+ public:
+ // Wrapper for the underlying Socket function to call (ie: Read/Write).
+ typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)>
+ WrappedSocketFunction;
+
+ explicit BlockingState(const WrappedSocketFunction& function);
+ ~BlockingState() {}
+
+ // Sets the next call to RunWrappedFunction() to block, returning
+ // ERR_IO_PENDING and not invoking the user callback until Unblock() is
+ // called.
+ void SetShouldBlock();
+
+ // Unblocks the currently blocked pending function, invoking the user
+ // callback if the results are immediately available.
+ // Note: It's not valid to call this unless SetShouldBlock() has been
+ // called beforehand.
+ void Unblock();
+
+ // Performs the wrapped socket function on the underlying transport. If
+ // configured to block via SetShouldBlock(), then |user_callback| will not
+ // be invoked until Unblock() has been called.
+ int RunWrappedFunction(IOBuffer* buf,
+ int len,
+ const CompletionCallback& user_callback);
+
+ private:
+ // Handles completion from the underlying wrapped socket function.
+ void OnCompleted(int result);
+
+ WrappedSocketFunction wrapped_function_;
+ bool should_block_;
+ bool have_result_;
+ int pending_result_;
+ CompletionCallback user_callback_;
+ };
+
+ BlockingState read_state_;
+ BlockingState write_state_;
+
+ DISALLOW_COPY_AND_ASSIGN(FakeBlockingStreamSocket);
+};
+
+FakeBlockingStreamSocket::FakeBlockingStreamSocket(
+ scoped_ptr<StreamSocket> transport)
+ : WrappedStreamSocket(transport.Pass()),
+ read_state_(base::Bind(&Socket::Read,
+ base::Unretained(transport_.get()))),
+ write_state_(base::Bind(&Socket::Write,
+ base::Unretained(transport_.get()))) {}
+
+FakeBlockingStreamSocket::BlockingState::BlockingState(
+ const WrappedSocketFunction& function)
+ : wrapped_function_(function),
+ should_block_(false),
+ have_result_(false),
+ pending_result_(OK) {}
+
+void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() {
+ DCHECK(!should_block_);
+ should_block_ = true;
+}
+
+void FakeBlockingStreamSocket::BlockingState::Unblock() {
+ DCHECK(should_block_);
+ should_block_ = false;
+
+ // If the operation is still pending in the underlying transport, immediately
+ // return - OnCompleted() will handle invoking the callback once the transport
+ // has completed.
+ if (!have_result_)
+ return;
+
+ have_result_ = false;
+
+ base::ResetAndReturn(&user_callback_).Run(pending_result_);
+}
+
+int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction(
+ IOBuffer* buf,
+ int len,
+ const CompletionCallback& callback) {
+
+ // The callback to be called by the underlying transport. Either forward
+ // directly to the user's callback if not set to block, or intercept it with
+ // OnCompleted so that the user's callback is not invoked until Unblock() is
+ // called.
+ CompletionCallback transport_callback =
+ !should_block_ ? callback : base::Bind(&BlockingState::OnCompleted,
+ base::Unretained(this));
+ int rv = wrapped_function_.Run(buf, len, transport_callback);
+ if (should_block_) {
+ user_callback_ = callback;
+ // May have completed synchronously.
+ have_result_ = (rv != ERR_IO_PENDING);
+ pending_result_ = rv;
+ return ERR_IO_PENDING;
+ }
+
+ return rv;
+}
+
+void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) {
+ if (should_block_) {
+ // Store the result so that the callback can be invoked once Unblock() is
+ // called.
+ have_result_ = true;
+ pending_result_ = result;
+ return;
+ }
+
+ // Otherwise, the Unblock() function was called before the underlying
+ // transport completed, so run the user's callback immediately.
+ base::ResetAndReturn(&user_callback_).Run(result);
+}
+
+// CompletionCallback that will delete the associated StreamSocket when
+// the callback is invoked.
+class DeleteSocketCallback : public TestCompletionCallbackBase {
+ public:
+ explicit DeleteSocketCallback(StreamSocket* socket)
+ : socket_(socket),
+ callback_(base::Bind(&DeleteSocketCallback::OnComplete,
+ base::Unretained(this))) {}
+ virtual ~DeleteSocketCallback() {}
+
+ const CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result) {
+ if (socket_) {
+ delete socket_;
+ socket_ = NULL;
+ } else {
+ ADD_FAILURE() << "Deleting socket twice";
+ }
+ SetResult(result);
+ }
+
+ StreamSocket* socket_;
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback);
+};
+
+class SSLClientSocketTest : public PlatformTest {
+ public:
+ SSLClientSocketTest()
+ : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
+ cert_verifier_(new MockCertVerifier),
+ transport_security_state_(new TransportSecurityState) {
+ cert_verifier_->set_default_result(OK);
+ context_.cert_verifier = cert_verifier_.get();
+ context_.transport_security_state = transport_security_state_.get();
+ }
+
+ protected:
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<StreamSocket> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config) {
+ scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
+ connection->SetSocket(transport_socket.Pass());
+ return socket_factory_->CreateSSLClientSocket(
+ connection.Pass(), host_and_port, ssl_config, context_);
+ }
+
+ ClientSocketFactory* socket_factory_;
+ scoped_ptr<MockCertVerifier> cert_verifier_;
+ scoped_ptr<TransportSecurityState> transport_security_state_;
+ SSLClientSocketContext context_;
+};
+
+//-----------------------------------------------------------------------------
+
+// LogContainsSSLConnectEndEvent returns true if the given index in the given
+// log is an SSL connect end event. The NSS sockets will cork in an attempt to
+// merge the first application data record with the Finished message when false
+// starting. However, in order to avoid the server timing out the handshake,
+// they'll give up waiting for application data and send the Finished after a
+// timeout. This means that an SSL connect end event may appear as a socket
+// write.
+static bool LogContainsSSLConnectEndEvent(
+ const CapturingNetLog::CapturedEntryList& log,
+ int i) {
+ return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) ||
+ LogContainsEvent(
+ log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);
+}
+;
+
+TEST_F(SSLClientSocketTest, Connect) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ EXPECT_FALSE(sock->IsConnected());
+
+ rv = sock->Connect(callback.callback());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
+
+ sock->Disconnect();
+ EXPECT_FALSE(sock->IsConnected());
+}
+
+TEST_F(SSLClientSocketTest, ConnectExpired) {
+ SpawnedTestServer::SSLOptions ssl_options(
+ SpawnedTestServer::SSLOptions::CERT_EXPIRED);
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ EXPECT_FALSE(sock->IsConnected());
+
+ rv = sock->Connect(callback.callback());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_EQ(ERR_CERT_DATE_INVALID, rv);
+
+ // Rather than testing whether or not the underlying socket is connected,
+ // test that the handshake has finished. This is because it may be
+ // desirable to disconnect the socket before showing a user prompt, since
+ // the user may take indefinitely long to respond.
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
+}
+
+TEST_F(SSLClientSocketTest, ConnectMismatched) {
+ SpawnedTestServer::SSLOptions ssl_options(
+ SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME);
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID);
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ EXPECT_FALSE(sock->IsConnected());
+
+ rv = sock->Connect(callback.callback());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv);
+
+ // Rather than testing whether or not the underlying socket is connected,
+ // test that the handshake has finished. This is because it may be
+ // desirable to disconnect the socket before showing a user prompt, since
+ // the user may take indefinitely long to respond.
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
+}
+
+// Attempt to connect to a page which requests a client certificate. It should
+// return an error code on connect.
+TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ EXPECT_FALSE(sock->IsConnected());
+
+ rv = sock->Connect(callback.callback());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ log.GetEntries(&entries);
+ // Because we prematurely kill the handshake at CertificateRequest,
+ // the server may still send data (notably the ServerHelloDone)
+ // after the error is returned. As a result, the SSL_CONNECT may not
+ // be the last entry. See http://crbug.com/54445. We use
+ // ExpectLogContainsSomewhere instead of
+ // LogContainsSSLConnectEndEvent to avoid assuming, e.g., only one
+ // extra read instead of two. This occurs before the handshake ends,
+ // so the corking logic of LogContainsSSLConnectEndEvent isn't
+ // necessary.
+ //
+ // TODO(davidben): When SSL_RestartHandshakeAfterCertReq in NSS is
+ // fixed and we can respond to the first CertificateRequest
+ // without closing the socket, add a unit test for sending the
+ // certificate. This test may still be useful as we'll want to close
+ // the socket on a timeout if the user takes a long time to pick a
+ // cert. Related bug: https://bugzilla.mozilla.org/show_bug.cgi?id=542832
+ ExpectLogContainsSomewhere(
+ entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END);
+ EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv);
+ EXPECT_FALSE(sock->IsConnected());
+}
+
+// Connect to a server requesting optional client authentication. Send it a
+// null certificate. It should allow the connection.
+//
+// TODO(davidben): Also test providing an actual certificate.
+TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ SSLConfig ssl_config = kDefaultSSLConfig;
+ ssl_config.send_client_cert = true;
+ ssl_config.client_cert = NULL;
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+
+ EXPECT_FALSE(sock->IsConnected());
+
+ // Our test server accepts certificate-less connections.
+ // TODO(davidben): Add a test which requires them and verify the error.
+ rv = sock->Connect(callback.callback());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
+
+ // We responded to the server's certificate request with a Certificate
+ // message with no client certificate in it. ssl_info.client_cert_sent
+ // should be false in this case.
+ SSLInfo ssl_info;
+ sock->GetSSLInfo(&ssl_info);
+ EXPECT_FALSE(ssl_info.client_cert_sent);
+
+ sock->Disconnect();
+ EXPECT_FALSE(sock->IsConnected());
+}
+
+// TODO(wtc): Add unit tests for IsConnectedAndIdle:
+// - Server closes an SSL connection (with a close_notify alert message).
+// - Server closes the underlying TCP connection directly.
+// - Server sends data unexpectedly.
+
+TEST_F(SSLClientSocketTest, Read) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
+
+ rv = sock->Write(
+ request_buffer.get(), arraysize(request_text) - 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ for (;;) {
+ rv = sock->Read(buf.get(), 4096, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_GE(rv, 0);
+ if (rv <= 0)
+ break;
+ }
+}
+
+// Tests that the SSLClientSocket properly handles when the underlying transport
+// synchronously returns an error code - such as if an intermediary terminates
+// the socket connection uncleanly.
+// This is a regression test for http://crbug.com/238536
+TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to avoid handshake non-determinism.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ static const int kRequestTextSize =
+ static_cast<int>(arraysize(request_text) - 1);
+ scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
+ memcpy(request_buffer->data(), request_text, kRequestTextSize);
+
+ rv = callback.GetResult(
+ sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
+ EXPECT_EQ(kRequestTextSize, rv);
+
+ // Simulate an unclean/forcible shutdown.
+ raw_transport->SetNextReadError(ERR_CONNECTION_RESET);
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+
+ // Note: This test will hang if this bug has regressed. Simply checking that
+ // rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate
+ // result when using a dedicated task runner for NSS.
+ rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback()));
+
+#if !defined(USE_OPENSSL)
+ // SSLClientSocketNSS records the error exactly
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+#else
+ // SSLClientSocketOpenSSL treats any errors as a simple EOF.
+ EXPECT_EQ(0, rv);
+#endif
+}
+
+// Tests that the SSLClientSocket properly handles when the underlying transport
+// asynchronously returns an error code while writing data - such as if an
+// intermediary terminates the socket connection uncleanly.
+// This is a regression test for http://crbug.com/249848
+TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
+ // is retained in order to configure additional errors.
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to avoid handshake non-determinism.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ static const int kRequestTextSize =
+ static_cast<int>(arraysize(request_text) - 1);
+ scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
+ memcpy(request_buffer->data(), request_text, kRequestTextSize);
+
+ // Simulate an unclean/forcible shutdown on the underlying socket.
+ // However, simulate this error asynchronously.
+ raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
+ raw_transport->SetNextWriteShouldBlock();
+
+ // This write should complete synchronously, because the TLS ciphertext
+ // can be created and placed into the outgoing buffers independent of the
+ // underlying transport.
+ rv = callback.GetResult(
+ sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
+ EXPECT_EQ(kRequestTextSize, rv);
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+
+ rv = sock->Read(buf.get(), 4096, callback.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ // Now unblock the outgoing request, having it fail with the connection
+ // being reset.
+ raw_transport->UnblockWrite();
+
+ // Note: This will cause an inifite loop if this bug has regressed. Simply
+ // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING
+ // is a legitimate result when using a dedicated task runner for NSS.
+ rv = callback.GetResult(rv);
+
+#if !defined(USE_OPENSSL)
+ // SSLClientSocketNSS records the error exactly
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+#else
+ // SSLClientSocketOpenSSL treats any errors as a simple EOF.
+ EXPECT_EQ(0, rv);
+#endif
+}
+
+// Test the full duplex mode, with Read and Write pending at the same time.
+// This test also serves as a regression test for http://crbug.com/29815.
+TEST_F(SSLClientSocketTest, Read_FullDuplex) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback; // Used for everything except Write.
+
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ // Issue a "hanging" Read first.
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ rv = sock->Read(buf.get(), 4096, callback.callback());
+ // We haven't written the request, so there should be no response yet.
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ // Write the request.
+ // The request is padded with a User-Agent header to a size that causes the
+ // memio circular buffer (4k bytes) in SSLClientSocketNSS to wrap around.
+ // This tests the fix for http://crbug.com/29815.
+ std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
+ for (int i = 0; i < 3770; ++i)
+ request_text.push_back('*');
+ request_text.append("\r\n\r\n");
+ scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text));
+
+ TestCompletionCallback callback2; // Used for Write only.
+ rv = sock->Write(
+ request_buffer.get(), request_text.size(), callback2.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback2.WaitForResult();
+ EXPECT_EQ(static_cast<int>(request_text.size()), rv);
+
+ // Now get the Read result.
+ rv = callback.WaitForResult();
+ EXPECT_GT(rv, 0);
+}
+
+// Attempts to Read() and Write() from an SSLClientSocketNSS in full duplex
+// mode when the underlying transport is blocked on sending data. When the
+// underlying transport completes due to an error, it should invoke both the
+// Read() and Write() callbacks. If the socket is deleted by the Read()
+// callback, the Write() callback should not be invoked.
+// Regression test for http://crbug.com/232633
+TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
+ // is retained in order to configure additional errors.
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
+
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to avoid handshake non-determinism.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ scoped_ptr<SSLClientSocket> sock =
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ ssl_config);
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
+ request_text.append(20 * 1024, '*');
+ request_text.append("\r\n\r\n");
+ scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer(
+ new StringIOBuffer(request_text), request_text.size()));
+
+ // Simulate errors being returned from the underlying Read() and Write() ...
+ raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET);
+ raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
+ // ... but have those errors returned asynchronously. Because the Write() will
+ // return first, this will trigger the error.
+ raw_transport->SetNextReadShouldBlock();
+ raw_transport->SetNextWriteShouldBlock();
+
+ // Enqueue a Read() before calling Write(), which should "hang" due to
+ // the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return.
+ SSLClientSocket* raw_sock = sock.get();
+ DeleteSocketCallback read_callback(sock.release());
+ scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096));
+ rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback());
+
+ // Ensure things didn't complete synchronously, otherwise |sock| is invalid.
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ ASSERT_FALSE(read_callback.have_result());
+
+#if !defined(USE_OPENSSL)
+ // NSS follows a pattern where a call to PR_Write will only consume as
+ // much data as it can encode into application data records before the
+ // internal memio buffer is full, which should only fill if writing a large
+ // amount of data and the underlying transport is blocked. Once this happens,
+ // NSS will return (total size of all application data records it wrote) - 1,
+ // with the caller expected to resume with the remaining unsent data.
+ //
+ // This causes SSLClientSocketNSS::Write to return that it wrote some data
+ // before it will return ERR_IO_PENDING, so make an extra call to Write() to
+ // get the socket in the state needed for the test below.
+ //
+ // This is not needed for OpenSSL, because for OpenSSL,
+ // SSL_MODE_ENABLE_PARTIAL_WRITE is not specified - thus
+ // SSLClientSocketOpenSSL::Write() will not return until all of
+ // |request_buffer| has been written to the underlying BIO (although not
+ // necessarily the underlying transport).
+ rv = callback.GetResult(raw_sock->Write(request_buffer.get(),
+ request_buffer->BytesRemaining(),
+ callback.callback()));
+ ASSERT_LT(0, rv);
+ request_buffer->DidConsume(rv);
+
+ // Guard to ensure that |request_buffer| was larger than all of the internal
+ // buffers (transport, memio, NSS) along the way - otherwise the next call
+ // to Write() will crash with an invalid buffer.
+ ASSERT_LT(0, request_buffer->BytesRemaining());
+#endif
+
+ // Attempt to write the remaining data. NSS will not be able to consume the
+ // application data because the internal buffers are full, while OpenSSL will
+ // return that its blocked because the underlying transport is blocked.
+ rv = raw_sock->Write(request_buffer.get(),
+ request_buffer->BytesRemaining(),
+ callback.callback());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ ASSERT_FALSE(callback.have_result());
+
+ // Now unblock Write(), which will invoke OnSendComplete and (eventually)
+ // call the Read() callback, deleting the socket and thus aborting calling
+ // the Write() callback.
+ raw_transport->UnblockWrite();
+
+ rv = read_callback.WaitForResult();
+
+#if !defined(USE_OPENSSL)
+ // NSS records the error exactly.
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+#else
+ // OpenSSL treats any errors as a simple EOF.
+ EXPECT_EQ(0, rv);
+#endif
+
+ // The Write callback should not have been called.
+ EXPECT_FALSE(callback.have_result());
+}
+
+TEST_F(SSLClientSocketTest, Read_SmallChunks) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
+
+ rv = sock->Write(
+ request_buffer.get(), arraysize(request_text) - 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(1));
+ for (;;) {
+ rv = sock->Read(buf.get(), 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_GE(rv, 0);
+ if (rv <= 0)
+ break;
+ }
+}
+
+TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<ReadBufferingStreamSocket> transport(
+ new ReadBufferingStreamSocket(real_transport.Pass()));
+ ReadBufferingStreamSocket* raw_transport = transport.get();
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ ASSERT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
+ test_server.host_port_pair(),
+ kDefaultSSLConfig));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ ASSERT_EQ(OK, rv);
+ ASSERT_TRUE(sock->IsConnected());
+
+ const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n";
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
+
+ rv = callback.GetResult(sock->Write(
+ request_buffer.get(), arraysize(request_text) - 1, callback.callback()));
+ ASSERT_GT(rv, 0);
+ ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
+
+ // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of
+ // data (the max SSL record size) at a time. Ensure that at least 15K worth
+ // of SSL data is buffered first. The 15K of buffered data is made up of
+ // many smaller SSL records (the TestServer writes along 1350 byte
+ // plaintext boundaries), although there may also be a few records that are
+ // smaller or larger, due to timing and SSL False Start.
+ // 15K was chosen because 15K is smaller than the 17K (max) read issued by
+ // the SSLClientSocket implementation, and larger than the minimum amount
+ // of ciphertext necessary to contain the 8K of plaintext requested below.
+ raw_transport->SetBufferSize(15000);
+
+ scoped_refptr<IOBuffer> buffer(new IOBuffer(8192));
+ rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback()));
+ ASSERT_EQ(rv, 8192);
+}
+
+TEST_F(SSLClientSocketTest, Read_Interrupted) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
+
+ rv = sock->Write(
+ request_buffer.get(), arraysize(request_text) - 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
+
+ // Do a partial read and then exit. This test should not crash!
+ scoped_refptr<IOBuffer> buf(new IOBuffer(512));
+ rv = sock->Read(buf.get(), 512, callback.callback());
+ EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_GT(rv, 0);
+}
+
+TEST_F(SSLClientSocketTest, Read_FullLogging) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ log.SetLogLevel(NetLog::LOG_ALL);
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
+
+ rv = sock->Write(
+ request_buffer.get(), arraysize(request_text) - 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ size_t last_index = ExpectLogContainsSomewhereAfter(
+ entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ for (;;) {
+ rv = sock->Read(buf.get(), 4096, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_GE(rv, 0);
+ if (rv <= 0)
+ break;
+
+ log.GetEntries(&entries);
+ last_index =
+ ExpectLogContainsSomewhereAfter(entries,
+ last_index + 1,
+ NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
+ NetLog::PHASE_NONE);
+ }
+}
+
+// Regression test for http://crbug.com/42538
+TEST_F(SSLClientSocketTest, PrematureApplicationData) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ TestCompletionCallback callback;
+
+ static const unsigned char application_data[] = {
+ 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b,
+ 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46,
+ 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6,
+ 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36,
+ 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d,
+ 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f,
+ 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57,
+ 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82,
+ 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
+ 0x0a};
+
+ // All reads and writes complete synchronously (async=false).
+ MockRead data_reads[] = {
+ MockRead(SYNCHRONOUS,
+ reinterpret_cast<const char*>(application_data),
+ arraysize(application_data)),
+ MockRead(SYNCHRONOUS, OK), };
+
+ StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0);
+
+ scoped_ptr<StreamSocket> transport(
+ new MockTCPClientSocket(addr, NULL, &data));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
+}
+
+TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
+ // Rather than exhaustively disabling every RC4 ciphersuite defined at
+ // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml,
+ // only disabling those cipher suites that the test server actually
+ // implements.
+ const uint16 kCiphersToDisable[] = {0x0005, // TLS_RSA_WITH_RC4_128_SHA
+ };
+
+ SpawnedTestServer::SSLOptions ssl_options;
+ // Enable only RC4 on the test server.
+ ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4;
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ SSLConfig ssl_config;
+ for (size_t i = 0; i < arraysize(kCiphersToDisable); ++i)
+ ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+
+ EXPECT_FALSE(sock->IsConnected());
+
+ rv = sock->Connect(callback.callback());
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+
+ // NSS has special handling that maps a handshake_failure alert received
+ // immediately after a client_hello to be a mismatched cipher suite error,
+ // leading to ERR_SSL_VERSION_OR_CIPHER_MISMATCH. When using OpenSSL or
+ // Secure Transport (OS X), the handshake_failure is bubbled up without any
+ // interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure
+ // indicates that no cipher suite was negotiated with the test server.
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH ||
+ rv == ERR_SSL_PROTOCOL_ERROR);
+ // The exact ordering differs between SSLClientSocketNSS (which issues an
+ // extra read) and SSLClientSocketMac (which does not). Just make sure the
+ // error appears somewhere in the log.
+ log.GetEntries(&entries);
+ ExpectLogContainsSomewhere(
+ entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE);
+
+ // We cannot test sock->IsConnected(), as the NSS implementation disconnects
+ // the socket when it encounters an error, whereas other implementations
+ // leave it connected.
+ // Because this an error that the test server is mutually aware of, as opposed
+ // to being an error such as a certificate name mismatch, which is
+ // client-only, the exact index of the SSL connect end depends on how
+ // quickly the test server closes the underlying socket. If the test server
+ // closes before the IO message loop pumps messages, there may be a 0-byte
+ // Read event in the NetLog due to TCPClientSocket picking up the EOF. As a
+ // result, the SSL connect end event will be the second-to-last entry,
+ // rather than the last entry.
+ EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1) ||
+ LogContainsSSLConnectEndEvent(entries, -2));
+}
+
+// When creating an SSLClientSocket, it is allowed to pass in a
+// ClientSocketHandle that is not obtained from a client socket pool.
+// Here we verify that such a simple ClientSocketHandle, not associated with any
+// client socket pool, can be destroyed safely.
+TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle());
+ socket_handle->SetSocket(transport.Pass());
+
+ scoped_ptr<SSLClientSocket> sock(
+ socket_factory_->CreateSSLClientSocket(socket_handle.Pass(),
+ test_server.host_port_pair(),
+ kDefaultSSLConfig,
+ context_));
+
+ EXPECT_FALSE(sock->IsConnected());
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+}
+
+// Verifies that SSLClientSocket::ExportKeyingMaterial return a success
+// code and different keying label results in different keying material.
+TEST_F(SSLClientSocketTest, ExportKeyingMaterial) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ const int kKeyingMaterialSize = 32;
+ const char* kKeyingLabel1 = "client-socket-test-1";
+ const char* kKeyingContext = "";
+ unsigned char client_out1[kKeyingMaterialSize];
+ memset(client_out1, 0, sizeof(client_out1));
+ rv = sock->ExportKeyingMaterial(
+ kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1));
+ EXPECT_EQ(rv, OK);
+
+ const char* kKeyingLabel2 = "client-socket-test-2";
+ unsigned char client_out2[kKeyingMaterialSize];
+ memset(client_out2, 0, sizeof(client_out2));
+ rv = sock->ExportKeyingMaterial(
+ kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2));
+ EXPECT_EQ(rv, OK);
+ EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0);
+}
+
+// Verifies that SSLClientSocket::ClearSessionCache can be called without
+// explicit NSS initialization.
+TEST(SSLClientSocket, ClearSessionCache) {
+ SSLClientSocket::ClearSessionCache();
+}
+
+// This tests that SSLInfo contains a properly re-constructed certificate
+// chain. That, in turn, verifies that GetSSLInfo is giving us the chain as
+// verified, not the chain as served by the server. (They may be different.)
+//
+// CERT_CHAIN_WRONG_ROOT is redundant-server-chain.pem. It contains A
+// (end-entity) -> B -> C, and C is signed by D. redundant-validated-chain.pem
+// contains a chain of A -> B -> C2, where C2 is the same public key as C, but
+// a self-signed root. Such a situation can occur when a new root (C2) is
+// cross-certified by an old root (D) and has two different versions of its
+// floating around. Servers may supply C2 as an intermediate, but the
+// SSLClientSocket should return the chain that was verified, from
+// verify_result, instead.
+TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) {
+ // By default, cause the CertVerifier to treat all certificates as
+ // expired.
+ cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
+
+ // We will expect SSLInfo to ultimately contain this chain.
+ CertificateList certs =
+ CreateCertificateListFromFile(GetTestCertsDirectory(),
+ "redundant-validated-chain.pem",
+ X509Certificate::FORMAT_AUTO);
+ ASSERT_EQ(3U, certs.size());
+
+ X509Certificate::OSCertHandles temp_intermediates;
+ temp_intermediates.push_back(certs[1]->os_cert_handle());
+ temp_intermediates.push_back(certs[2]->os_cert_handle());
+
+ CertVerifyResult verify_result;
+ verify_result.verified_cert = X509Certificate::CreateFromHandle(
+ certs[0]->os_cert_handle(), temp_intermediates);
+
+ // Add a rule that maps the server cert (A) to the chain of A->B->C2
+ // rather than A->B->C.
+ cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK);
+
+ // Load and install the root for the validated chain.
+ scoped_refptr<X509Certificate> root_cert = ImportCertFromFile(
+ GetTestCertsDirectory(), "redundant-validated-chain-root.pem");
+ ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert);
+ ScopedTestRoot scoped_root(root_cert.get());
+
+ // Set up a test server with CERT_CHAIN_WRONG_ROOT.
+ SpawnedTestServer::SSLOptions ssl_options(
+ SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT);
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS,
+ ssl_options,
+ base::FilePath(FILE_PATH_LITERAL("net/data/ssl")));
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+ EXPECT_FALSE(sock->IsConnected());
+ rv = sock->Connect(callback.callback());
+
+ CapturingNetLog::CapturedEntryList entries;
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+ log.GetEntries(&entries);
+ EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
+
+ SSLInfo ssl_info;
+ sock->GetSSLInfo(&ssl_info);
+
+ // Verify that SSLInfo contains the corrected re-constructed chain A -> B
+ // -> C2.
+ const X509Certificate::OSCertHandles& intermediates =
+ ssl_info.cert->GetIntermediateCertificates();
+ ASSERT_EQ(2U, intermediates.size());
+ EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(),
+ certs[0]->os_cert_handle()));
+ EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0],
+ certs[1]->os_cert_handle()));
+ EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1],
+ certs[2]->os_cert_handle()));
+
+ sock->Disconnect();
+ EXPECT_FALSE(sock->IsConnected());
+}
+
+// Verifies the correctness of GetSSLCertRequestInfo.
+class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
+ protected:
+ // Creates a test server with the given SSLOptions, connects to it and returns
+ // the SSLCertRequestInfo reported by the socket.
+ scoped_refptr<SSLCertRequestInfo> GetCertRequest(
+ SpawnedTestServer::SSLOptions ssl_options) {
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ if (!test_server.Start())
+ return NULL;
+
+ AddressList addr;
+ if (!test_server.GetAddressList(&addr))
+ return NULL;
+
+ TestCompletionCallback callback;
+ CapturingNetLog log;
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, &log, NetLog::Source()));
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
+ EXPECT_FALSE(sock->IsConnected());
+
+ rv = sock->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo();
+ sock->GetSSLCertRequestInfo(request_info.get());
+ sock->Disconnect();
+ EXPECT_FALSE(sock->IsConnected());
+
+ return request_info;
+ }
+};
+
+TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+ scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options);
+ ASSERT_TRUE(request_info.get());
+ EXPECT_EQ(0u, request_info->cert_authorities.size());
+}
+
+TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) {
+ const base::FilePath::CharType kThawteFile[] =
+ FILE_PATH_LITERAL("thawte.single.pem");
+ const unsigned char kThawteDN[] = {
+ 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+ 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a,
+ 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e,
+ 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79,
+ 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03,
+ 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20,
+ 0x53, 0x47, 0x43, 0x20, 0x43, 0x41};
+ const size_t kThawteLen = sizeof(kThawteDN);
+
+ const base::FilePath::CharType kDiginotarFile[] =
+ FILE_PATH_LITERAL("diginotar_root_ca.pem");
+ const unsigned char kDiginotarDN[] = {
+ 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
+ 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a,
+ 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31,
+ 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69,
+ 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74,
+ 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48,
+ 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f,
+ 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e,
+ 0x6c};
+ const size_t kDiginotarLen = sizeof(kDiginotarDN);
+
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.request_client_certificate = true;
+ ssl_options.client_authorities.push_back(
+ GetTestClientCertsDirectory().Append(kThawteFile));
+ ssl_options.client_authorities.push_back(
+ GetTestClientCertsDirectory().Append(kDiginotarFile));
+ scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options);
+ ASSERT_TRUE(request_info.get());
+ ASSERT_EQ(2u, request_info->cert_authorities.size());
+ EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen),
+ request_info->cert_authorities[0]);
+ EXPECT_EQ(
+ std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen),
+ request_info->cert_authorities[1]);
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_error_params.cc b/chromium/net/socket/ssl_error_params.cc
new file mode 100644
index 00000000000..37561f0de48
--- /dev/null
+++ b/chromium/net/socket/ssl_error_params.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/ssl_error_params.h"
+
+#include "base/bind.h"
+#include "base/values.h"
+
+namespace net {
+
+namespace {
+
+base::Value* NetLogSSLErrorCallback(int net_error,
+ int ssl_lib_error,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetInteger("net_error", net_error);
+ if (ssl_lib_error)
+ dict->SetInteger("ssl_lib_error", ssl_lib_error);
+ return dict;
+}
+
+} // namespace
+
+NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error,
+ int ssl_lib_error) {
+ return base::Bind(&NetLogSSLErrorCallback, net_error, ssl_lib_error);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_error_params.h b/chromium/net/socket/ssl_error_params.h
new file mode 100644
index 00000000000..07a1c4d99d9
--- /dev/null
+++ b/chromium/net/socket/ssl_error_params.h
@@ -0,0 +1,18 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_ERROR_PARAMS_H_
+#define NET_SOCKET_SSL_ERROR_PARAMS_H_
+
+#include "net/base/net_log.h"
+
+namespace net {
+
+// Creates NetLog callback for when we receive an SSL error.
+NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error,
+ int ssl_lib_error);
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_ERROR_PARAMS_H_
diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h
new file mode 100644
index 00000000000..8b607bf80cf
--- /dev/null
+++ b/chromium/net/socket/ssl_server_socket.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_SERVER_SOCKET_H_
+#define NET_SOCKET_SSL_SERVER_SOCKET_H_
+
+#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+#include "net/socket/ssl_socket.h"
+#include "net/socket/stream_socket.h"
+
+namespace crypto {
+class RSAPrivateKey;
+} // namespace crypto
+
+namespace net {
+
+struct SSLConfig;
+class X509Certificate;
+
+class SSLServerSocket : public SSLSocket {
+ public:
+ virtual ~SSLServerSocket() {}
+
+ // Perform the SSL server handshake, and notify the supplied callback
+ // if the process completes asynchronously. If Disconnect is called before
+ // completion then the callback will be silently, as for other StreamSocket
+ // calls.
+ virtual int Handshake(const CompletionCallback& callback) = 0;
+};
+
+// Configures the underlying SSL library for the use of SSL server sockets.
+//
+// Due to the requirements of the underlying libraries, this should be called
+// early in process initialization, before any SSL socket, client or server,
+// has been used.
+//
+// Note: If a process does not use SSL server sockets, this call may be
+// omitted.
+NET_EXPORT void EnableSSLServerSockets();
+
+// Creates an SSL server socket over an already-connected transport socket.
+// The caller must provide the server certificate and private key to use.
+//
+// The returned SSLServerSocket takes ownership of |socket|. Stubbed versions
+// of CreateSSLServerSocket will delete |socket| and return NULL.
+// It takes a reference to |certificate|.
+// The |key| and |ssl_config| parameters are copied. |key| cannot be const
+// because the methods used to copy its contents are non-const.
+//
+// The caller starts the SSL server handshake by calling Handshake on the
+// returned socket.
+NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
+ X509Certificate* certificate,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config);
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_SERVER_SOCKET_H_
diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc
new file mode 100644
index 00000000000..7e5d70118ac
--- /dev/null
+++ b/chromium/net/socket/ssl_server_socket_nss.cc
@@ -0,0 +1,828 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/ssl_server_socket_nss.h"
+
+#if defined(OS_WIN)
+#include <winsock2.h>
+#endif
+
+#if defined(USE_SYSTEM_SSL)
+#include <dlfcn.h>
+#endif
+#if defined(OS_MACOSX)
+#include <Security/Security.h>
+#endif
+#include <certdb.h>
+#include <cryptohi.h>
+#include <hasht.h>
+#include <keyhi.h>
+#include <nspr.h>
+#include <nss.h>
+#include <pk11pub.h>
+#include <secerr.h>
+#include <sechash.h>
+#include <ssl.h>
+#include <sslerr.h>
+#include <sslproto.h>
+
+#include <limits>
+
+#include "base/lazy_instance.h"
+#include "base/memory/ref_counted.h"
+#include "crypto/rsa_private_key.h"
+#include "crypto/nss_util_internal.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/socket/nss_ssl_util.h"
+#include "net/socket/ssl_error_params.h"
+
+// SSL plaintext fragments are shorter than 16KB. Although the record layer
+// overhead is allowed to be 2K + 5 bytes, in practice the overhead is much
+// smaller than 1KB. So a 17KB buffer should be large enough to hold an
+// entire SSL record.
+static const int kRecvBufferSize = 17 * 1024;
+static const int kSendBufferSize = 17 * 1024;
+
+#define GotoState(s) next_handshake_state_ = s
+
+namespace net {
+
+namespace {
+
+bool g_nss_server_sockets_init = false;
+
+class NSSSSLServerInitSingleton {
+ public:
+ NSSSSLServerInitSingleton() {
+ EnsureNSSSSLInit();
+
+ SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL);
+ g_nss_server_sockets_init = true;
+ }
+
+ ~NSSSSLServerInitSingleton() {
+ SSL_ShutdownServerSessionIDCache();
+ g_nss_server_sockets_init = false;
+ }
+};
+
+static base::LazyInstance<NSSSSLServerInitSingleton>
+ g_nss_ssl_server_init_singleton = LAZY_INSTANCE_INITIALIZER;
+
+} // namespace
+
+void EnableSSLServerSockets() {
+ g_nss_ssl_server_init_singleton.Get();
+}
+
+scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
+ X509Certificate* cert,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config) {
+ DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been"
+ << "called yet!";
+
+ return scoped_ptr<SSLServerSocket>(
+ new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config));
+}
+
+SSLServerSocketNSS::SSLServerSocketNSS(
+ scoped_ptr<StreamSocket> transport_socket,
+ scoped_refptr<X509Certificate> cert,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config)
+ : transport_send_busy_(false),
+ transport_recv_busy_(false),
+ user_read_buf_len_(0),
+ user_write_buf_len_(0),
+ nss_fd_(NULL),
+ nss_bufs_(NULL),
+ transport_socket_(transport_socket.Pass()),
+ ssl_config_(ssl_config),
+ cert_(cert),
+ next_handshake_state_(STATE_NONE),
+ completed_handshake_(false) {
+ ssl_config_.false_start_enabled = false;
+ ssl_config_.version_min = SSL_PROTOCOL_VERSION_SSL3;
+ ssl_config_.version_max = SSL_PROTOCOL_VERSION_TLS1_1;
+
+ // TODO(hclam): Need a better way to clone a key.
+ std::vector<uint8> key_bytes;
+ CHECK(key->ExportPrivateKey(&key_bytes));
+ key_.reset(crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_bytes));
+ CHECK(key_.get());
+}
+
+SSLServerSocketNSS::~SSLServerSocketNSS() {
+ if (nss_fd_ != NULL) {
+ PR_Close(nss_fd_);
+ nss_fd_ = NULL;
+ }
+}
+
+int SSLServerSocketNSS::Handshake(const CompletionCallback& callback) {
+ net_log_.BeginEvent(NetLog::TYPE_SSL_SERVER_HANDSHAKE);
+
+ int rv = Init();
+ if (rv != OK) {
+ LOG(ERROR) << "Failed to initialize NSS";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
+ return rv;
+ }
+
+ rv = InitializeSSLOptions();
+ if (rv != OK) {
+ LOG(ERROR) << "Failed to initialize SSL options";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
+ return rv;
+ }
+
+ // Set peer address. TODO(hclam): This should be in a separate method.
+ PRNetAddr peername;
+ memset(&peername, 0, sizeof(peername));
+ peername.raw.family = AF_INET;
+ memio_SetPeerName(nss_fd_, &peername);
+
+ GotoState(STATE_HANDSHAKE);
+ rv = DoHandshakeLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ user_handshake_callback_ = callback;
+ } else {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
+ }
+
+ return rv > OK ? OK : rv;
+}
+
+int SSLServerSocketNSS::ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ SECStatus result = SSL_ExportKeyingMaterial(
+ nss_fd_, label.data(), label.size(), has_context,
+ reinterpret_cast<const unsigned char*>(context.data()),
+ context.length(), out, outlen);
+ if (result != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_ExportKeyingMaterial", "");
+ return MapNSSError(PORT_GetError());
+ }
+ return OK;
+}
+
+int SSLServerSocketNSS::GetTLSUniqueChannelBinding(std::string* out) {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ unsigned char buf[64];
+ unsigned int len;
+ SECStatus result = SSL_GetChannelBinding(nss_fd_,
+ SSL_CHANNEL_BINDING_TLS_UNIQUE,
+ buf, &len, arraysize(buf));
+ if (result != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", "");
+ return MapNSSError(PORT_GetError());
+ }
+ out->assign(reinterpret_cast<char*>(buf), len);
+ return OK;
+}
+
+int SSLServerSocketNSS::Connect(const CompletionCallback& callback) {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(user_read_callback_.is_null());
+ DCHECK(user_handshake_callback_.is_null());
+ DCHECK(!user_read_buf_.get());
+ DCHECK(nss_bufs_);
+ DCHECK(!callback.is_null());
+
+ user_read_buf_ = buf;
+ user_read_buf_len_ = buf_len;
+
+ DCHECK(completed_handshake_);
+
+ int rv = DoReadLoop(OK);
+
+ if (rv == ERR_IO_PENDING) {
+ user_read_callback_ = callback;
+ } else {
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+ }
+ return rv;
+}
+
+int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(user_write_callback_.is_null());
+ DCHECK(!user_write_buf_.get());
+ DCHECK(nss_bufs_);
+ DCHECK(!callback.is_null());
+
+ user_write_buf_ = buf;
+ user_write_buf_len_ = buf_len;
+
+ int rv = DoWriteLoop(OK);
+
+ if (rv == ERR_IO_PENDING) {
+ user_write_callback_ = callback;
+ } else {
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+ }
+ return rv;
+}
+
+bool SSLServerSocketNSS::SetReceiveBufferSize(int32 size) {
+ return transport_socket_->SetReceiveBufferSize(size);
+}
+
+bool SSLServerSocketNSS::SetSendBufferSize(int32 size) {
+ return transport_socket_->SetSendBufferSize(size);
+}
+
+bool SSLServerSocketNSS::IsConnected() const {
+ return completed_handshake_;
+}
+
+void SSLServerSocketNSS::Disconnect() {
+ transport_socket_->Disconnect();
+}
+
+bool SSLServerSocketNSS::IsConnectedAndIdle() const {
+ return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
+}
+
+int SSLServerSocketNSS::GetPeerAddress(IPEndPoint* address) const {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ return transport_socket_->GetPeerAddress(address);
+}
+
+int SSLServerSocketNSS::GetLocalAddress(IPEndPoint* address) const {
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ return transport_socket_->GetLocalAddress(address);
+}
+
+const BoundNetLog& SSLServerSocketNSS::NetLog() const {
+ return net_log_;
+}
+
+void SSLServerSocketNSS::SetSubresourceSpeculation() {
+ transport_socket_->SetSubresourceSpeculation();
+}
+
+void SSLServerSocketNSS::SetOmniboxSpeculation() {
+ transport_socket_->SetOmniboxSpeculation();
+}
+
+bool SSLServerSocketNSS::WasEverUsed() const {
+ return transport_socket_->WasEverUsed();
+}
+
+bool SSLServerSocketNSS::UsingTCPFastOpen() const {
+ return transport_socket_->UsingTCPFastOpen();
+}
+
+bool SSLServerSocketNSS::WasNpnNegotiated() const {
+ return false;
+}
+
+NextProto SSLServerSocketNSS::GetNegotiatedProtocol() const {
+ // NPN is not supported by this class.
+ return kProtoUnknown;
+}
+
+bool SSLServerSocketNSS::GetSSLInfo(SSLInfo* ssl_info) {
+ NOTIMPLEMENTED();
+ return false;
+}
+
+int SSLServerSocketNSS::InitializeSSLOptions() {
+ // Transport connected, now hook it up to nss
+ nss_fd_ = memio_CreateIOLayer(kRecvBufferSize, kSendBufferSize);
+ if (nss_fd_ == NULL) {
+ return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR error code.
+ }
+
+ // Grab pointer to buffers
+ nss_bufs_ = memio_GetSecret(nss_fd_);
+
+ /* Create SSL state machine */
+ /* Push SSL onto our fake I/O socket */
+ nss_fd_ = SSL_ImportFD(NULL, nss_fd_);
+ if (nss_fd_ == NULL) {
+ LogFailedNSSFunction(net_log_, "SSL_ImportFD", "");
+ return ERR_OUT_OF_MEMORY; // TODO(port): map NSPR/NSS error code.
+ }
+ // TODO(port): set more ssl options! Check errors!
+
+ int rv;
+
+ rv = SSL_OptionSet(nss_fd_, SSL_SECURITY, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_SECURITY");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SSL2, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_SSL2");
+ return ERR_UNEXPECTED;
+ }
+
+ SSLVersionRange version_range;
+ version_range.min = ssl_config_.version_min;
+ version_range.max = ssl_config_.version_max;
+ rv = SSL_VersionRangeSet(nss_fd_, &version_range);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_VersionRangeSet", "");
+ return ERR_NO_SSL_VERSIONS_ENABLED;
+ }
+
+ for (std::vector<uint16>::const_iterator it =
+ ssl_config_.disabled_cipher_suites.begin();
+ it != ssl_config_.disabled_cipher_suites.end(); ++it) {
+ // This will fail if the specified cipher is not implemented by NSS, but
+ // the failure is harmless.
+ SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE);
+ }
+
+ // Server socket doesn't need session tickets.
+ rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(
+ net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS");
+ }
+
+ // Doing this will force PR_Accept perform handshake as server.
+ rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_CLIENT, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_CLIENT");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_HANDSHAKE_AS_SERVER, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_HANDSHAKE_AS_SERVER");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_REQUEST_CERTIFICATE, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUEST_CERTIFICATE");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_OptionSet(nss_fd_, SSL_REQUIRE_CERTIFICATE, PR_FALSE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_REQUIRE_CERTIFICATE");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_AuthCertificateHook(nss_fd_, OwnAuthCertHandler, this);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_AuthCertificateHook", "");
+ return ERR_UNEXPECTED;
+ }
+
+ rv = SSL_HandshakeCallback(nss_fd_, HandshakeCallback, this);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_HandshakeCallback", "");
+ return ERR_UNEXPECTED;
+ }
+
+ // Get a certificate of CERTCertificate structure.
+ std::string der_string;
+ if (!X509Certificate::GetDEREncoded(cert_->os_cert_handle(), &der_string))
+ return ERR_UNEXPECTED;
+
+ SECItem der_cert;
+ der_cert.data = reinterpret_cast<unsigned char*>(const_cast<char*>(
+ der_string.data()));
+ der_cert.len = der_string.length();
+ der_cert.type = siDERCertBuffer;
+
+ // Parse into a CERTCertificate structure.
+ CERTCertificate* cert = CERT_NewTempCertificate(
+ CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE);
+ if (!cert) {
+ LogFailedNSSFunction(net_log_, "CERT_NewTempCertificate", "");
+ return MapNSSError(PORT_GetError());
+ }
+
+ // Get a key of SECKEYPrivateKey* structure.
+ std::vector<uint8> key_vector;
+ if (!key_->ExportPrivateKey(&key_vector)) {
+ CERT_DestroyCertificate(cert);
+ return ERR_UNEXPECTED;
+ }
+
+ SECKEYPrivateKeyStr* private_key = NULL;
+ PK11SlotInfo* slot = crypto::GetPrivateNSSKeySlot();
+ if (!slot) {
+ CERT_DestroyCertificate(cert);
+ return ERR_UNEXPECTED;
+ }
+
+ SECItem der_private_key_info;
+ der_private_key_info.data =
+ const_cast<unsigned char*>(&key_vector.front());
+ der_private_key_info.len = key_vector.size();
+ // The server's RSA private key must be imported into NSS with the
+ // following key usage bits:
+ // - KU_KEY_ENCIPHERMENT, required for the RSA key exchange algorithm.
+ // - KU_DIGITAL_SIGNATURE, required for the DHE_RSA and ECDHE_RSA key
+ // exchange algorithms.
+ const unsigned int key_usage = KU_KEY_ENCIPHERMENT | KU_DIGITAL_SIGNATURE;
+ rv = PK11_ImportDERPrivateKeyInfoAndReturnKey(
+ slot, &der_private_key_info, NULL, NULL, PR_FALSE, PR_FALSE,
+ key_usage, &private_key, NULL);
+ PK11_FreeSlot(slot);
+ if (rv != SECSuccess) {
+ CERT_DestroyCertificate(cert);
+ return ERR_UNEXPECTED;
+ }
+
+ // Assign server certificate and private key.
+ SSLKEAType cert_kea = NSS_FindCertKEAType(cert);
+ rv = SSL_ConfigSecureServer(nss_fd_, cert, private_key, cert_kea);
+ CERT_DestroyCertificate(cert);
+ SECKEY_DestroyPrivateKey(private_key);
+
+ if (rv != SECSuccess) {
+ PRErrorCode prerr = PR_GetError();
+ LOG(ERROR) << "Failed to config SSL server: " << prerr;
+ LogFailedNSSFunction(net_log_, "SSL_ConfigureSecureServer", "");
+ return ERR_UNEXPECTED;
+ }
+
+ // Tell SSL we're a server; needed if not letting NSPR do socket I/O
+ rv = SSL_ResetHandshake(nss_fd_, PR_TRUE);
+ if (rv != SECSuccess) {
+ LogFailedNSSFunction(net_log_, "SSL_ResetHandshake", "");
+ return ERR_UNEXPECTED;
+ }
+
+ return OK;
+}
+
+void SSLServerSocketNSS::OnSendComplete(int result) {
+ if (next_handshake_state_ == STATE_HANDSHAKE) {
+ // In handshake phase.
+ OnHandshakeIOComplete(result);
+ return;
+ }
+
+ if (!completed_handshake_)
+ return;
+
+ if (user_write_buf_.get()) {
+ int rv = DoWriteLoop(result);
+ if (rv != ERR_IO_PENDING)
+ DoWriteCallback(rv);
+ } else {
+ // Ensure that any queued ciphertext is flushed.
+ DoTransportIO();
+ }
+}
+
+void SSLServerSocketNSS::OnRecvComplete(int result) {
+ if (next_handshake_state_ == STATE_HANDSHAKE) {
+ // In handshake phase.
+ OnHandshakeIOComplete(result);
+ return;
+ }
+
+ // Network layer received some data, check if client requested to read
+ // decrypted data.
+ if (!user_read_buf_.get() || !completed_handshake_)
+ return;
+
+ int rv = DoReadLoop(result);
+ if (rv != ERR_IO_PENDING)
+ DoReadCallback(rv);
+}
+
+void SSLServerSocketNSS::OnHandshakeIOComplete(int result) {
+ int rv = DoHandshakeLoop(result);
+ if (rv != ERR_IO_PENDING) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_SERVER_HANDSHAKE, rv);
+ if (!user_handshake_callback_.is_null())
+ DoHandshakeCallback(rv);
+ }
+}
+
+// Return 0 for EOF,
+// > 0 for bytes transferred immediately,
+// < 0 for error (or the non-error ERR_IO_PENDING).
+int SSLServerSocketNSS::BufferSend(void) {
+ if (transport_send_busy_)
+ return ERR_IO_PENDING;
+
+ const char* buf1;
+ const char* buf2;
+ unsigned int len1, len2;
+ memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2);
+ const unsigned int len = len1 + len2;
+
+ int rv = 0;
+ if (len) {
+ scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len));
+ memcpy(send_buffer->data(), buf1, len1);
+ memcpy(send_buffer->data() + len1, buf2, len2);
+ rv = transport_socket_->Write(
+ send_buffer.get(),
+ len,
+ base::Bind(&SSLServerSocketNSS::BufferSendComplete,
+ base::Unretained(this)));
+ if (rv == ERR_IO_PENDING) {
+ transport_send_busy_ = true;
+ } else {
+ memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv));
+ }
+ }
+
+ return rv;
+}
+
+void SSLServerSocketNSS::BufferSendComplete(int result) {
+ memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result));
+ transport_send_busy_ = false;
+ OnSendComplete(result);
+}
+
+int SSLServerSocketNSS::BufferRecv(void) {
+ if (transport_recv_busy_) return ERR_IO_PENDING;
+
+ char* buf;
+ int nb = memio_GetReadParams(nss_bufs_, &buf);
+ int rv;
+ if (!nb) {
+ // buffer too full to read into, so no I/O possible at moment
+ rv = ERR_IO_PENDING;
+ } else {
+ recv_buffer_ = new IOBuffer(nb);
+ rv = transport_socket_->Read(
+ recv_buffer_.get(),
+ nb,
+ base::Bind(&SSLServerSocketNSS::BufferRecvComplete,
+ base::Unretained(this)));
+ if (rv == ERR_IO_PENDING) {
+ transport_recv_busy_ = true;
+ } else {
+ if (rv > 0)
+ memcpy(buf, recv_buffer_->data(), rv);
+ memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv));
+ recv_buffer_ = NULL;
+ }
+ }
+ return rv;
+}
+
+void SSLServerSocketNSS::BufferRecvComplete(int result) {
+ if (result > 0) {
+ char* buf;
+ memio_GetReadParams(nss_bufs_, &buf);
+ memcpy(buf, recv_buffer_->data(), result);
+ }
+ recv_buffer_ = NULL;
+ memio_PutReadResult(nss_bufs_, MapErrorToNSS(result));
+ transport_recv_busy_ = false;
+ OnRecvComplete(result);
+}
+
+// Do as much network I/O as possible between the buffer and the
+// transport socket. Return true if some I/O performed, false
+// otherwise (error or ERR_IO_PENDING).
+bool SSLServerSocketNSS::DoTransportIO() {
+ bool network_moved = false;
+ if (nss_bufs_ != NULL) {
+ int rv;
+ // Read and write as much data as we can. The loop is neccessary
+ // because Write() may return synchronously.
+ do {
+ rv = BufferSend();
+ if (rv > 0)
+ network_moved = true;
+ } while (rv > 0);
+ if (BufferRecv() >= 0)
+ network_moved = true;
+ }
+ return network_moved;
+}
+
+int SSLServerSocketNSS::DoPayloadRead() {
+ DCHECK(user_read_buf_.get());
+ DCHECK_GT(user_read_buf_len_, 0);
+ int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_);
+ if (rv >= 0)
+ return rv;
+ PRErrorCode prerr = PR_GetError();
+ if (prerr == PR_WOULD_BLOCK_ERROR) {
+ return ERR_IO_PENDING;
+ }
+ rv = MapNSSError(prerr);
+ net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogSSLErrorCallback(rv, prerr));
+ return rv;
+}
+
+int SSLServerSocketNSS::DoPayloadWrite() {
+ DCHECK(user_write_buf_.get());
+ int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_);
+ if (rv >= 0)
+ return rv;
+ PRErrorCode prerr = PR_GetError();
+ if (prerr == PR_WOULD_BLOCK_ERROR) {
+ return ERR_IO_PENDING;
+ }
+ rv = MapNSSError(prerr);
+ net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR,
+ CreateNetLogSSLErrorCallback(rv, prerr));
+ return rv;
+}
+
+int SSLServerSocketNSS::DoHandshakeLoop(int last_io_result) {
+ int rv = last_io_result;
+ do {
+ // Default to STATE_NONE for next state.
+ // (This is a quirk carried over from the windows
+ // implementation. It makes reading the logs a bit harder.)
+ // State handlers can and often do call GotoState just
+ // to stay in the current state.
+ State state = next_handshake_state_;
+ GotoState(STATE_NONE);
+ switch (state) {
+ case STATE_HANDSHAKE:
+ rv = DoHandshake();
+ break;
+ case STATE_NONE:
+ default:
+ rv = ERR_UNEXPECTED;
+ LOG(DFATAL) << "unexpected state " << state;
+ break;
+ }
+
+ // Do the actual network I/O
+ bool network_moved = DoTransportIO();
+ if (network_moved && next_handshake_state_ == STATE_HANDSHAKE) {
+ // In general we exit the loop if rv is ERR_IO_PENDING. In this
+ // special case we keep looping even if rv is ERR_IO_PENDING because
+ // the transport IO may allow DoHandshake to make progress.
+ rv = OK; // This causes us to stay in the loop.
+ }
+ } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
+ return rv;
+}
+
+int SSLServerSocketNSS::DoReadLoop(int result) {
+ DCHECK(completed_handshake_);
+ DCHECK(next_handshake_state_ == STATE_NONE);
+
+ if (result < 0)
+ return result;
+
+ if (!nss_bufs_) {
+ LOG(DFATAL) << "!nss_bufs_";
+ int rv = ERR_UNEXPECTED;
+ net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogSSLErrorCallback(rv, 0));
+ return rv;
+ }
+
+ bool network_moved;
+ int rv;
+ do {
+ rv = DoPayloadRead();
+ network_moved = DoTransportIO();
+ } while (rv == ERR_IO_PENDING && network_moved);
+ return rv;
+}
+
+int SSLServerSocketNSS::DoWriteLoop(int result) {
+ DCHECK(completed_handshake_);
+ DCHECK(next_handshake_state_ == STATE_NONE);
+
+ if (result < 0)
+ return result;
+
+ if (!nss_bufs_) {
+ LOG(DFATAL) << "!nss_bufs_";
+ int rv = ERR_UNEXPECTED;
+ net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR,
+ CreateNetLogSSLErrorCallback(rv, 0));
+ return rv;
+ }
+
+ bool network_moved;
+ int rv;
+ do {
+ rv = DoPayloadWrite();
+ network_moved = DoTransportIO();
+ } while (rv == ERR_IO_PENDING && network_moved);
+ return rv;
+}
+
+int SSLServerSocketNSS::DoHandshake() {
+ int net_error = OK;
+ SECStatus rv = SSL_ForceHandshake(nss_fd_);
+
+ if (rv == SECSuccess) {
+ completed_handshake_ = true;
+ } else {
+ PRErrorCode prerr = PR_GetError();
+ net_error = MapNSSError(prerr);
+
+ // If not done, stay in this state
+ if (net_error == ERR_IO_PENDING) {
+ GotoState(STATE_HANDSHAKE);
+ } else {
+ LOG(ERROR) << "handshake failed; NSS error code " << prerr
+ << ", net_error " << net_error;
+ net_log_.AddEvent(NetLog::TYPE_SSL_HANDSHAKE_ERROR,
+ CreateNetLogSSLErrorCallback(net_error, prerr));
+ }
+ }
+ return net_error;
+}
+
+void SSLServerSocketNSS::DoHandshakeCallback(int rv) {
+ DCHECK_NE(rv, ERR_IO_PENDING);
+
+ CompletionCallback c = user_handshake_callback_;
+ user_handshake_callback_.Reset();
+ c.Run(rv > OK ? OK : rv);
+}
+
+void SSLServerSocketNSS::DoReadCallback(int rv) {
+ DCHECK(rv != ERR_IO_PENDING);
+ DCHECK(!user_read_callback_.is_null());
+
+ // Since Run may result in Read being called, clear |user_read_callback_|
+ // up front.
+ CompletionCallback c = user_read_callback_;
+ user_read_callback_.Reset();
+ user_read_buf_ = NULL;
+ user_read_buf_len_ = 0;
+ c.Run(rv);
+}
+
+void SSLServerSocketNSS::DoWriteCallback(int rv) {
+ DCHECK(rv != ERR_IO_PENDING);
+ DCHECK(!user_write_callback_.is_null());
+
+ // Since Run may result in Write being called, clear |user_write_callback_|
+ // up front.
+ CompletionCallback c = user_write_callback_;
+ user_write_callback_.Reset();
+ user_write_buf_ = NULL;
+ user_write_buf_len_ = 0;
+ c.Run(rv);
+}
+
+// static
+// NSS calls this if an incoming certificate needs to be verified.
+// Do nothing but return SECSuccess.
+// This is called only in full handshake mode.
+// Peer certificate is retrieved in HandshakeCallback() later, which is called
+// in full handshake mode or in resumption handshake mode.
+SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg,
+ PRFileDesc* socket,
+ PRBool checksig,
+ PRBool is_server) {
+ // TODO(hclam): Implement.
+ // Tell NSS to not verify the certificate.
+ return SECSuccess;
+}
+
+// static
+// NSS calls this when handshake is completed.
+// After the SSL handshake is finished we need to verify the certificate.
+void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket,
+ void* arg) {
+ // TODO(hclam): Implement.
+}
+
+int SSLServerSocketNSS::Init() {
+ // Initialize the NSS SSL library in a threadsafe way. This also
+ // initializes the NSS base library.
+ EnsureNSSSSLInit();
+ if (!NSS_IsInitialized())
+ return ERR_UNEXPECTED;
+
+ EnableSSLServerSockets();
+ return OK;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h
new file mode 100644
index 00000000000..8bbb0e338ac
--- /dev/null
+++ b/chromium/net/socket/ssl_server_socket_nss.h
@@ -0,0 +1,150 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_
+#define NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_
+
+#include <certt.h>
+#include <keyt.h>
+#include <nspr.h>
+#include <nss.h>
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/completion_callback.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/net_log.h"
+#include "net/base/nss_memio.h"
+#include "net/socket/ssl_server_socket.h"
+#include "net/ssl/ssl_config_service.h"
+
+namespace net {
+
+class SSLServerSocketNSS : public SSLServerSocket {
+ public:
+ // See comments on CreateSSLServerSocket for details of how these
+ // parameters are used.
+ SSLServerSocketNSS(scoped_ptr<StreamSocket> socket,
+ scoped_refptr<X509Certificate> certificate,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config);
+ virtual ~SSLServerSocketNSS();
+
+ // SSLServerSocket interface.
+ virtual int Handshake(const CompletionCallback& callback) OVERRIDE;
+
+ // SSLSocket interface.
+ virtual int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) OVERRIDE;
+ virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+
+ // Socket interface (via StreamSocket).
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ private:
+ enum State {
+ STATE_NONE,
+ STATE_HANDSHAKE,
+ };
+
+ int InitializeSSLOptions();
+
+ void OnSendComplete(int result);
+ void OnRecvComplete(int result);
+ void OnHandshakeIOComplete(int result);
+
+ int BufferSend();
+ void BufferSendComplete(int result);
+ int BufferRecv();
+ void BufferRecvComplete(int result);
+ bool DoTransportIO();
+ int DoPayloadRead();
+ int DoPayloadWrite();
+
+ int DoHandshakeLoop(int last_io_result);
+ int DoReadLoop(int result);
+ int DoWriteLoop(int result);
+ int DoHandshake();
+ void DoHandshakeCallback(int result);
+ void DoReadCallback(int result);
+ void DoWriteCallback(int result);
+
+ static SECStatus OwnAuthCertHandler(void* arg,
+ PRFileDesc* socket,
+ PRBool checksig,
+ PRBool is_server);
+ static void HandshakeCallback(PRFileDesc* socket, void* arg);
+
+ virtual int Init();
+
+ // Members used to send and receive buffer.
+ bool transport_send_busy_;
+ bool transport_recv_busy_;
+
+ scoped_refptr<IOBuffer> recv_buffer_;
+
+ BoundNetLog net_log_;
+
+ CompletionCallback user_handshake_callback_;
+ CompletionCallback user_read_callback_;
+ CompletionCallback user_write_callback_;
+
+ // Used by Read function.
+ scoped_refptr<IOBuffer> user_read_buf_;
+ int user_read_buf_len_;
+
+ // Used by Write function.
+ scoped_refptr<IOBuffer> user_write_buf_;
+ int user_write_buf_len_;
+
+ // The NSS SSL state machine
+ PRFileDesc* nss_fd_;
+
+ // Buffers for the network end of the SSL state machine
+ memio_Private* nss_bufs_;
+
+ // StreamSocket for sending and receiving data.
+ scoped_ptr<StreamSocket> transport_socket_;
+
+ // Options for the SSL socket.
+ SSLConfig ssl_config_;
+
+ // Certificate for the server.
+ scoped_refptr<X509Certificate> cert_;
+
+ // Private key used by the server.
+ scoped_ptr<crypto::RSAPrivateKey> key_;
+
+ State next_handshake_state_;
+ bool completed_handshake_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLServerSocketNSS);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_SERVER_SOCKET_NSS_H_
diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc
new file mode 100644
index 00000000000..c327f2caf10
--- /dev/null
+++ b/chromium/net/socket/ssl_server_socket_openssl.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/logging.h"
+#include "net/socket/ssl_server_socket.h"
+
+// TODO(bulach): Provide simple stubs for EnableSSLServerSockets and
+// CreateSSLServerSocket so that when building for OpenSSL rather than NSS,
+// so that the code using SSL server sockets can be compiled and disabled
+// programatically rather than requiring to be carved out from the compile.
+
+namespace net {
+
+void EnableSSLServerSockets() {
+ NOTIMPLEMENTED();
+}
+
+scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
+ scoped_ptr<StreamSocket> socket,
+ X509Certificate* certificate,
+ crypto::RSAPrivateKey* key,
+ const SSLConfig& ssl_config) {
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLServerSocket>();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc
new file mode 100644
index 00000000000..64c85490b29
--- /dev/null
+++ b/chromium/net/socket/ssl_server_socket_unittest.cc
@@ -0,0 +1,588 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This test suite uses SSLClientSocket to test the implementation of
+// SSLServerSocket. In order to establish connections between the sockets
+// we need two additional classes:
+// 1. FakeSocket
+// Connects SSL socket to FakeDataChannel. This class is just a stub.
+//
+// 2. FakeDataChannel
+// Implements the actual exchange of data between two FakeSockets.
+//
+// Implementations of these two classes are included in this file.
+
+#include "net/socket/ssl_server_socket.h"
+
+#include <stdlib.h>
+
+#include <queue>
+
+#include "base/compiler_specific.h"
+#include "base/file_util.h"
+#include "base/files/file_path.h"
+#include "base/message_loop/message_loop.h"
+#include "base/path_service.h"
+#include "crypto/nss_util.h"
+#include "crypto/rsa_private_key.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/test_data_directory.h"
+#include "net/cert/cert_status_flags.h"
+#include "net/cert/mock_cert_verifier.h"
+#include "net/cert/x509_certificate.h"
+#include "net/http/transport_security_state.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
+#include "net/ssl/ssl_config_service.h"
+#include "net/ssl/ssl_info.h"
+#include "net/test/cert_test_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+namespace {
+
+class FakeDataChannel {
+ public:
+ FakeDataChannel()
+ : read_buf_len_(0),
+ weak_factory_(this),
+ closed_(false),
+ write_called_after_close_(false) {
+ }
+
+ int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
+ if (closed_)
+ return 0;
+ if (data_.empty()) {
+ read_callback_ = callback;
+ read_buf_ = buf;
+ read_buf_len_ = buf_len;
+ return net::ERR_IO_PENDING;
+ }
+ return PropogateData(buf, buf_len);
+ }
+
+ int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
+ if (closed_) {
+ if (write_called_after_close_)
+ return net::ERR_CONNECTION_RESET;
+ write_called_after_close_ = true;
+ write_callback_ = callback;
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback,
+ weak_factory_.GetWeakPtr()));
+ return net::ERR_IO_PENDING;
+ }
+ data_.push(new net::DrainableIOBuffer(buf, buf_len));
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
+ weak_factory_.GetWeakPtr()));
+ return buf_len;
+ }
+
+ // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
+ // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
+ // after the FakeDataChannel is closed, the first Write() call completes
+ // asynchronously, which is necessary to reproduce bug 127822.
+ void Close() {
+ closed_ = true;
+ }
+
+ private:
+ void DoReadCallback() {
+ if (read_callback_.is_null() || data_.empty())
+ return;
+
+ int copied = PropogateData(read_buf_, read_buf_len_);
+ CompletionCallback callback = read_callback_;
+ read_callback_.Reset();
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ callback.Run(copied);
+ }
+
+ void DoWriteCallback() {
+ if (write_callback_.is_null())
+ return;
+
+ CompletionCallback callback = write_callback_;
+ write_callback_.Reset();
+ callback.Run(net::ERR_CONNECTION_RESET);
+ }
+
+ int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
+ scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
+ int copied = std::min(buf->BytesRemaining(), read_buf_len);
+ memcpy(read_buf->data(), buf->data(), copied);
+ buf->DidConsume(copied);
+
+ if (!buf->BytesRemaining())
+ data_.pop();
+ return copied;
+ }
+
+ CompletionCallback read_callback_;
+ scoped_refptr<net::IOBuffer> read_buf_;
+ int read_buf_len_;
+
+ CompletionCallback write_callback_;
+
+ std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
+
+ base::WeakPtrFactory<FakeDataChannel> weak_factory_;
+
+ // True if Close() has been called.
+ bool closed_;
+
+ // Controls the completion of Write() after the FakeDataChannel is closed.
+ // After the FakeDataChannel is closed, the first Write() call completes
+ // asynchronously.
+ bool write_called_after_close_;
+
+ DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
+};
+
+class FakeSocket : public StreamSocket {
+ public:
+ FakeSocket(FakeDataChannel* incoming_channel,
+ FakeDataChannel* outgoing_channel)
+ : incoming_(incoming_channel),
+ outgoing_(outgoing_channel) {
+ }
+
+ virtual ~FakeSocket() {
+ }
+
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ // Read random number of bytes.
+ buf_len = rand() % buf_len + 1;
+ return incoming_->Read(buf, buf_len, callback);
+ }
+
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ // Write random number of bytes.
+ buf_len = rand() % buf_len + 1;
+ return outgoing_->Write(buf, buf_len, callback);
+ }
+
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
+ return true;
+ }
+
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE {
+ return true;
+ }
+
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ return net::OK;
+ }
+
+ virtual void Disconnect() OVERRIDE {
+ incoming_->Close();
+ outgoing_->Close();
+ }
+
+ virtual bool IsConnected() const OVERRIDE {
+ return true;
+ }
+
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ return true;
+ }
+
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ net::IPAddressNumber ip_address(net::kIPv4AddressSize);
+ *address = net::IPEndPoint(ip_address, 0 /*port*/);
+ return net::OK;
+ }
+
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ net::IPAddressNumber ip_address(4);
+ *address = net::IPEndPoint(ip_address, 0);
+ return net::OK;
+ }
+
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return net_log_;
+ }
+
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+
+ virtual bool WasEverUsed() const OVERRIDE {
+ return true;
+ }
+
+ virtual bool UsingTCPFastOpen() const OVERRIDE {
+ return false;
+ }
+
+
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return kProtoUnknown;
+ }
+
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ return false;
+ }
+
+ private:
+ net::BoundNetLog net_log_;
+ FakeDataChannel* incoming_;
+ FakeDataChannel* outgoing_;
+
+ DISALLOW_COPY_AND_ASSIGN(FakeSocket);
+};
+
+} // namespace
+
+// Verify the correctness of the test helper classes first.
+TEST(FakeSocketTest, DataTransfer) {
+ // Establish channels between two sockets.
+ FakeDataChannel channel_1;
+ FakeDataChannel channel_2;
+ FakeSocket client(&channel_1, &channel_2);
+ FakeSocket server(&channel_2, &channel_1);
+
+ const char kTestData[] = "testing123";
+ const int kTestDataSize = strlen(kTestData);
+ const int kReadBufSize = 1024;
+ scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
+ scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
+
+ // Write then read.
+ int written =
+ server.Write(write_buf.get(), kTestDataSize, CompletionCallback());
+ EXPECT_GT(written, 0);
+ EXPECT_LE(written, kTestDataSize);
+
+ int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback());
+ EXPECT_GT(read, 0);
+ EXPECT_LE(read, written);
+ EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
+
+ // Read then write.
+ TestCompletionCallback callback;
+ EXPECT_EQ(net::ERR_IO_PENDING,
+ server.Read(read_buf.get(), kReadBufSize, callback.callback()));
+
+ written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback());
+ EXPECT_GT(written, 0);
+ EXPECT_LE(written, kTestDataSize);
+
+ read = callback.WaitForResult();
+ EXPECT_GT(read, 0);
+ EXPECT_LE(read, written);
+ EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
+}
+
+class SSLServerSocketTest : public PlatformTest {
+ public:
+ SSLServerSocketTest()
+ : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
+ cert_verifier_(new MockCertVerifier()),
+ transport_security_state_(new TransportSecurityState) {
+ cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID);
+ }
+
+ protected:
+ void Initialize() {
+ scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
+ client_connection->SetSocket(
+ scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
+ scoped_ptr<StreamSocket> server_socket(
+ new FakeSocket(&channel_2_, &channel_1_));
+
+ base::FilePath certs_dir(GetTestCertsDirectory());
+
+ base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
+ std::string cert_der;
+ ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der));
+
+ scoped_refptr<net::X509Certificate> cert =
+ X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
+
+ base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
+ std::string key_string;
+ ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string));
+ std::vector<uint8> key_vector(
+ reinterpret_cast<const uint8*>(key_string.data()),
+ reinterpret_cast<const uint8*>(key_string.data() +
+ key_string.length()));
+
+ scoped_ptr<crypto::RSAPrivateKey> private_key(
+ crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
+
+ net::SSLConfig ssl_config;
+ ssl_config.cached_info_enabled = false;
+ ssl_config.false_start_enabled = false;
+ ssl_config.channel_id_enabled = false;
+ ssl_config.version_min = SSL_PROTOCOL_VERSION_SSL3;
+ ssl_config.version_max = SSL_PROTOCOL_VERSION_TLS1_1;
+
+ // Certificate provided by the host doesn't need authority.
+ net::SSLConfig::CertAndStatus cert_and_status;
+ cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
+ cert_and_status.der_cert = cert_der;
+ ssl_config.allowed_bad_certs.push_back(cert_and_status);
+
+ net::HostPortPair host_and_pair("unittest", 0);
+ net::SSLClientSocketContext context;
+ context.cert_verifier = cert_verifier_.get();
+ context.transport_security_state = transport_security_state_.get();
+ client_socket_ =
+ socket_factory_->CreateSSLClientSocket(
+ client_connection.Pass(), host_and_pair, ssl_config, context);
+ server_socket_ = net::CreateSSLServerSocket(
+ server_socket.Pass(),
+ cert.get(), private_key.get(), net::SSLConfig());
+ }
+
+ FakeDataChannel channel_1_;
+ FakeDataChannel channel_2_;
+ scoped_ptr<net::SSLClientSocket> client_socket_;
+ scoped_ptr<net::SSLServerSocket> server_socket_;
+ net::ClientSocketFactory* socket_factory_;
+ scoped_ptr<net::MockCertVerifier> cert_verifier_;
+ scoped_ptr<net::TransportSecurityState> transport_security_state_;
+};
+
+// SSLServerSocket is only implemented using NSS.
+#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX)
+
+// This test only executes creation of client and server sockets. This is to
+// test that creation of sockets doesn't crash and have minimal code to run
+// under valgrind in order to help debugging memory problems.
+TEST_F(SSLServerSocketTest, Initialize) {
+ Initialize();
+}
+
+// This test executes Connect() on SSLClientSocket and Handshake() on
+// SSLServerSocket to make sure handshaking between the two sockets is
+// completed successfully.
+TEST_F(SSLServerSocketTest, Handshake) {
+ Initialize();
+
+ TestCompletionCallback connect_callback;
+ TestCompletionCallback handshake_callback;
+
+ int server_ret = server_socket_->Handshake(handshake_callback.callback());
+ EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+
+ int client_ret = client_socket_->Connect(connect_callback.callback());
+ EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+
+ if (client_ret == net::ERR_IO_PENDING) {
+ EXPECT_EQ(net::OK, connect_callback.WaitForResult());
+ }
+ if (server_ret == net::ERR_IO_PENDING) {
+ EXPECT_EQ(net::OK, handshake_callback.WaitForResult());
+ }
+
+ // Make sure the cert status is expected.
+ SSLInfo ssl_info;
+ client_socket_->GetSSLInfo(&ssl_info);
+ EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
+}
+
+TEST_F(SSLServerSocketTest, DataTransfer) {
+ Initialize();
+
+ TestCompletionCallback connect_callback;
+ TestCompletionCallback handshake_callback;
+
+ // Establish connection.
+ int client_ret = client_socket_->Connect(connect_callback.callback());
+ ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+
+ int server_ret = server_socket_->Handshake(handshake_callback.callback());
+ ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+
+ client_ret = connect_callback.GetResult(client_ret);
+ ASSERT_EQ(net::OK, client_ret);
+ server_ret = handshake_callback.GetResult(server_ret);
+ ASSERT_EQ(net::OK, server_ret);
+
+ const int kReadBufSize = 1024;
+ scoped_refptr<net::StringIOBuffer> write_buf =
+ new net::StringIOBuffer("testing123");
+ scoped_refptr<net::DrainableIOBuffer> read_buf =
+ new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize),
+ kReadBufSize);
+
+ // Write then read.
+ TestCompletionCallback write_callback;
+ TestCompletionCallback read_callback;
+ server_ret = server_socket_->Write(
+ write_buf.get(), write_buf->size(), write_callback.callback());
+ EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ client_ret = client_socket_->Read(
+ read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
+ EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+
+ server_ret = write_callback.GetResult(server_ret);
+ EXPECT_GT(server_ret, 0);
+ client_ret = read_callback.GetResult(client_ret);
+ ASSERT_GT(client_ret, 0);
+
+ read_buf->DidConsume(client_ret);
+ while (read_buf->BytesConsumed() < write_buf->size()) {
+ client_ret = client_socket_->Read(
+ read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
+ EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+ client_ret = read_callback.GetResult(client_ret);
+ ASSERT_GT(client_ret, 0);
+ read_buf->DidConsume(client_ret);
+ }
+ EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
+ read_buf->SetOffset(0);
+ EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
+
+ // Read then write.
+ write_buf = new net::StringIOBuffer("hello123");
+ server_ret = server_socket_->Read(
+ read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
+ EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ client_ret = client_socket_->Write(
+ write_buf.get(), write_buf->size(), write_callback.callback());
+ EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+
+ server_ret = read_callback.GetResult(server_ret);
+ ASSERT_GT(server_ret, 0);
+ client_ret = write_callback.GetResult(client_ret);
+ EXPECT_GT(client_ret, 0);
+
+ read_buf->DidConsume(server_ret);
+ while (read_buf->BytesConsumed() < write_buf->size()) {
+ server_ret = server_socket_->Read(
+ read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
+ EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ server_ret = read_callback.GetResult(server_ret);
+ ASSERT_GT(server_ret, 0);
+ read_buf->DidConsume(server_ret);
+ }
+ EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
+ read_buf->SetOffset(0);
+ EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
+}
+
+// A regression test for bug 127822 (http://crbug.com/127822).
+// If the server closes the connection after the handshake is finished,
+// the client's Write() call should not cause an infinite loop.
+// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
+TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
+ Initialize();
+
+ TestCompletionCallback connect_callback;
+ TestCompletionCallback handshake_callback;
+
+ // Establish connection.
+ int client_ret = client_socket_->Connect(connect_callback.callback());
+ ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+
+ int server_ret = server_socket_->Handshake(handshake_callback.callback());
+ ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+
+ client_ret = connect_callback.GetResult(client_ret);
+ ASSERT_EQ(net::OK, client_ret);
+ server_ret = handshake_callback.GetResult(server_ret);
+ ASSERT_EQ(net::OK, server_ret);
+
+ scoped_refptr<net::StringIOBuffer> write_buf =
+ new net::StringIOBuffer("testing123");
+
+ // The server closes the connection. The server needs to write some
+ // data first so that the client's Read() calls from the transport
+ // socket won't return ERR_IO_PENDING. This ensures that the client
+ // will call Read() on the transport socket again.
+ TestCompletionCallback write_callback;
+
+ server_ret = server_socket_->Write(
+ write_buf.get(), write_buf->size(), write_callback.callback());
+ EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+
+ server_ret = write_callback.GetResult(server_ret);
+ EXPECT_GT(server_ret, 0);
+
+ server_socket_->Disconnect();
+
+ // The client writes some data. This should not cause an infinite loop.
+ client_ret = client_socket_->Write(
+ write_buf.get(), write_buf->size(), write_callback.callback());
+ EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+
+ client_ret = write_callback.GetResult(client_ret);
+ EXPECT_GT(client_ret, 0);
+
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, base::MessageLoop::QuitClosure(),
+ base::TimeDelta::FromMilliseconds(10));
+ base::MessageLoop::current()->Run();
+}
+
+// This test executes ExportKeyingMaterial() on the client and server sockets,
+// after connecting them, and verifies that the results match.
+// This test will fail if False Start is enabled (see crbug.com/90208).
+TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
+ Initialize();
+
+ TestCompletionCallback connect_callback;
+ TestCompletionCallback handshake_callback;
+
+ int client_ret = client_socket_->Connect(connect_callback.callback());
+ ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+
+ int server_ret = server_socket_->Handshake(handshake_callback.callback());
+ ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+
+ if (client_ret == net::ERR_IO_PENDING) {
+ ASSERT_EQ(net::OK, connect_callback.WaitForResult());
+ }
+ if (server_ret == net::ERR_IO_PENDING) {
+ ASSERT_EQ(net::OK, handshake_callback.WaitForResult());
+ }
+
+ const int kKeyingMaterialSize = 32;
+ const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test";
+ const char* kKeyingContext = "";
+ unsigned char server_out[kKeyingMaterialSize];
+ int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel,
+ false, kKeyingContext,
+ server_out, sizeof(server_out));
+ ASSERT_EQ(net::OK, rv);
+
+ unsigned char client_out[kKeyingMaterialSize];
+ rv = client_socket_->ExportKeyingMaterial(kKeyingLabel,
+ false, kKeyingContext,
+ client_out, sizeof(client_out));
+ ASSERT_EQ(net::OK, rv);
+ EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
+
+ const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad";
+ unsigned char client_bad[kKeyingMaterialSize];
+ rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad,
+ false, kKeyingContext,
+ client_bad, sizeof(client_bad));
+ ASSERT_EQ(rv, net::OK);
+ EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
+}
+#endif
+
+} // namespace net
diff --git a/chromium/net/socket/ssl_socket.h b/chromium/net/socket/ssl_socket.h
new file mode 100644
index 00000000000..68d1e4a2bfe
--- /dev/null
+++ b/chromium/net/socket/ssl_socket.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SSL_SOCKET_H_
+#define NET_SOCKET_SSL_SOCKET_H_
+
+#include "base/basictypes.h"
+#include "base/strings/string_piece.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+// SSLSocket interface defines method that are common between client
+// and server SSL sockets.
+class NET_EXPORT SSLSocket : public StreamSocket {
+public:
+ virtual ~SSLSocket() {}
+
+ // Exports data derived from the SSL master-secret (see RFC 5705).
+ // If |has_context| is false, uses the no-context construction from the
+ // RFC and |context| is ignored. The call will fail with an error if
+ // the socket is not connected or the SSL implementation does not
+ // support the operation.
+ virtual int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) = 0;
+
+ // Stores the the tls-unique channel binding (see RFC 5929) in |*out|.
+ virtual int GetTLSUniqueChannelBinding(std::string* out) = 0;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SSL_SOCKET_H_
diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc
new file mode 100644
index 00000000000..c85c671800d
--- /dev/null
+++ b/chromium/net/socket/stream_listen_socket.cc
@@ -0,0 +1,308 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/stream_listen_socket.h"
+
+#if defined(OS_WIN)
+// winsock2.h must be included first in order to ensure it is included before
+// windows.h.
+#include <winsock2.h>
+#elif defined(OS_POSIX)
+#include <arpa/inet.h>
+#include <errno.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include "net/base/net_errors.h"
+#endif
+
+#include "base/logging.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/posix/eintr_wrapper.h"
+#include "base/sys_byteorder.h"
+#include "base/threading/platform_thread.h"
+#include "build/build_config.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+
+using std::string;
+
+#if defined(OS_WIN)
+typedef int socklen_t;
+#endif // defined(OS_WIN)
+
+namespace net {
+
+namespace {
+
+const int kReadBufSize = 4096;
+
+} // namespace
+
+#if defined(OS_WIN)
+const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET;
+const int StreamListenSocket::kSocketError = SOCKET_ERROR;
+#elif defined(OS_POSIX)
+const SocketDescriptor StreamListenSocket::kInvalidSocket = -1;
+const int StreamListenSocket::kSocketError = -1;
+#endif
+
+StreamListenSocket::StreamListenSocket(SocketDescriptor s,
+ StreamListenSocket::Delegate* del)
+ : socket_delegate_(del),
+ socket_(s),
+ reads_paused_(false),
+ has_pending_reads_(false) {
+#if defined(OS_WIN)
+ socket_event_ = WSACreateEvent();
+ // TODO(ibrar): error handling in case of socket_event_ == WSA_INVALID_EVENT.
+ WatchSocket(NOT_WAITING);
+#elif defined(OS_POSIX)
+ wait_state_ = NOT_WAITING;
+#endif
+}
+
+StreamListenSocket::~StreamListenSocket() {
+#if defined(OS_WIN)
+ if (socket_event_) {
+ WSACloseEvent(socket_event_);
+ socket_event_ = WSA_INVALID_EVENT;
+ }
+#endif
+ CloseSocket(socket_);
+}
+
+void StreamListenSocket::Send(const char* bytes, int len,
+ bool append_linefeed) {
+ SendInternal(bytes, len);
+ if (append_linefeed)
+ SendInternal("\r\n", 2);
+}
+
+void StreamListenSocket::Send(const string& str, bool append_linefeed) {
+ Send(str.data(), static_cast<int>(str.length()), append_linefeed);
+}
+
+int StreamListenSocket::GetLocalAddress(IPEndPoint* address) {
+ SockaddrStorage storage;
+ if (getsockname(socket_, storage.addr, &storage.addr_len)) {
+#if defined(OS_WIN)
+ int err = WSAGetLastError();
+#else
+ int err = errno;
+#endif
+ return MapSystemError(err);
+ }
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_FAILED;
+ return OK;
+}
+
+SocketDescriptor StreamListenSocket::AcceptSocket() {
+ SocketDescriptor conn = HANDLE_EINTR(accept(socket_, NULL, NULL));
+ if (conn == kInvalidSocket)
+ LOG(ERROR) << "Error accepting connection.";
+ else
+ SetNonBlocking(conn);
+ return conn;
+}
+
+void StreamListenSocket::SendInternal(const char* bytes, int len) {
+ char* send_buf = const_cast<char *>(bytes);
+ int len_left = len;
+ while (true) {
+ int sent = HANDLE_EINTR(send(socket_, send_buf, len_left, 0));
+ if (sent == len_left) { // A shortcut to avoid extraneous checks.
+ break;
+ }
+ if (sent == kSocketError) {
+#if defined(OS_WIN)
+ if (WSAGetLastError() != WSAEWOULDBLOCK) {
+ LOG(ERROR) << "send failed: WSAGetLastError()==" << WSAGetLastError();
+#elif defined(OS_POSIX)
+ if (errno != EWOULDBLOCK && errno != EAGAIN) {
+ LOG(ERROR) << "send failed: errno==" << errno;
+#endif
+ break;
+ }
+ // Otherwise we would block, and now we have to wait for a retry.
+ // Fall through to PlatformThread::YieldCurrentThread()
+ } else {
+ // sent != len_left according to the shortcut above.
+ // Shift the buffer start and send the remainder after a short while.
+ send_buf += sent;
+ len_left -= sent;
+ }
+ base::PlatformThread::YieldCurrentThread();
+ }
+}
+
+void StreamListenSocket::Listen() {
+ int backlog = 10; // TODO(erikkay): maybe don't allow any backlog?
+ if (listen(socket_, backlog) == -1) {
+ // TODO(erikkay): error handling.
+ LOG(ERROR) << "Could not listen on socket.";
+ return;
+ }
+#if defined(OS_POSIX)
+ WatchSocket(WAITING_ACCEPT);
+#endif
+}
+
+void StreamListenSocket::Read() {
+ char buf[kReadBufSize + 1]; // +1 for null termination.
+ int len;
+ do {
+ len = HANDLE_EINTR(recv(socket_, buf, kReadBufSize, 0));
+ if (len == kSocketError) {
+#if defined(OS_WIN)
+ int err = WSAGetLastError();
+ if (err == WSAEWOULDBLOCK) {
+#elif defined(OS_POSIX)
+ if (errno == EWOULDBLOCK || errno == EAGAIN) {
+#endif
+ break;
+ } else {
+ // TODO(ibrar): some error handling required here.
+ break;
+ }
+ } else if (len == 0) {
+ // In Windows, Close() is called by OnObjectSignaled. In POSIX, we need
+ // to call it here.
+#if defined(OS_POSIX)
+ Close();
+#endif
+ } else {
+ // TODO(ibrar): maybe change DidRead to take a length instead.
+ DCHECK_GT(len, 0);
+ DCHECK_LE(len, kReadBufSize);
+ buf[len] = 0; // Already create a buffer with +1 length.
+ socket_delegate_->DidRead(this, buf, len);
+ }
+ } while (len == kReadBufSize);
+}
+
+void StreamListenSocket::Close() {
+#if defined(OS_POSIX)
+ if (wait_state_ == NOT_WAITING)
+ return;
+ wait_state_ = NOT_WAITING;
+#endif
+ UnwatchSocket();
+ socket_delegate_->DidClose(this);
+}
+
+void StreamListenSocket::CloseSocket(SocketDescriptor s) {
+ if (s && s != kInvalidSocket) {
+ UnwatchSocket();
+#if defined(OS_WIN)
+ closesocket(s);
+#elif defined(OS_POSIX)
+ close(s);
+#endif
+ }
+}
+
+void StreamListenSocket::WatchSocket(WaitState state) {
+#if defined(OS_WIN)
+ WSAEventSelect(socket_, socket_event_, FD_ACCEPT | FD_CLOSE | FD_READ);
+ watcher_.StartWatching(socket_event_, this);
+#elif defined(OS_POSIX)
+ // Implicitly calls StartWatchingFileDescriptor().
+ base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_READ, &watcher_, this);
+ wait_state_ = state;
+#endif
+}
+
+void StreamListenSocket::UnwatchSocket() {
+#if defined(OS_WIN)
+ watcher_.StopWatching();
+#elif defined(OS_POSIX)
+ watcher_.StopWatchingFileDescriptor();
+#endif
+}
+
+// TODO(ibrar): We can add these functions into OS dependent files.
+#if defined(OS_WIN)
+// MessageLoop watcher callback.
+void StreamListenSocket::OnObjectSignaled(HANDLE object) {
+ WSANETWORKEVENTS ev;
+ if (kSocketError == WSAEnumNetworkEvents(socket_, socket_event_, &ev)) {
+ // TODO
+ return;
+ }
+
+ if (ev.lNetworkEvents & FD_CLOSE) {
+ Close();
+ // Close might have deleted this object. We should return immediately.
+ return;
+ }
+
+ // The object was reset by WSAEnumNetworkEvents. Watch for the next signal.
+ watcher_.StartWatching(object, this);
+
+ if (ev.lNetworkEvents == 0) {
+ // Occasionally the event is set even though there is no new data.
+ // The net seems to think that this is ignorable.
+ return;
+ }
+ if (ev.lNetworkEvents & FD_ACCEPT) {
+ Accept();
+ }
+ if (ev.lNetworkEvents & FD_READ) {
+ if (reads_paused_) {
+ has_pending_reads_ = true;
+ } else {
+ Read();
+ // Read() might call Close() internally and 'this' can be invalid here
+ return;
+ }
+ }
+}
+#elif defined(OS_POSIX)
+void StreamListenSocket::OnFileCanReadWithoutBlocking(int fd) {
+ switch (wait_state_) {
+ case WAITING_ACCEPT:
+ Accept();
+ break;
+ case WAITING_READ:
+ if (reads_paused_) {
+ has_pending_reads_ = true;
+ } else {
+ Read();
+ }
+ break;
+ default:
+ // Close() is called by Read() in the Linux case.
+ NOTREACHED();
+ break;
+ }
+}
+
+void StreamListenSocket::OnFileCanWriteWithoutBlocking(int fd) {
+ // MessagePumpLibevent callback, we don't listen for write events
+ // so we shouldn't ever reach here.
+ NOTREACHED();
+}
+
+#endif
+
+void StreamListenSocket::PauseReads() {
+ DCHECK(!reads_paused_);
+ reads_paused_ = true;
+}
+
+void StreamListenSocket::ResumeReads() {
+ DCHECK(reads_paused_);
+ reads_paused_ = false;
+ if (has_pending_reads_) {
+ has_pending_reads_ = false;
+ Read();
+ }
+}
+
+} // namespace net
diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h
new file mode 100644
index 00000000000..6f03eefaca2
--- /dev/null
+++ b/chromium/net/socket/stream_listen_socket.h
@@ -0,0 +1,155 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Stream-based listen socket implementation that handles reading and writing
+// to the socket, but does not handle creating the socket nor connecting
+// sockets, which are handled by subclasses on creation and in Accept,
+// respectively.
+
+// StreamListenSocket handles IO asynchronously in the specified MessageLoop.
+// This class is NOT thread safe. It uses WSAEVENT handles to monitor activity
+// in a given MessageLoop. This means that callbacks will happen in that loop's
+// thread always and that all other methods (including constructor and
+// destructor) should also be called from the same thread.
+
+#ifndef NET_SOCKET_STREAM_LISTEN_SOCKET_H_
+#define NET_SOCKET_STREAM_LISTEN_SOCKET_H_
+
+#include "build/build_config.h"
+
+#if defined(OS_WIN)
+#include <winsock2.h>
+#endif
+#include <string>
+#if defined(OS_WIN)
+#include "base/win/object_watcher.h"
+#elif defined(OS_POSIX)
+#include "base/message_loop/message_loop.h"
+#endif
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "net/base/net_export.h"
+#include "net/socket/stream_listen_socket.h"
+
+#if defined(OS_POSIX)
+typedef int SocketDescriptor;
+#else
+typedef SOCKET SocketDescriptor;
+#endif
+
+namespace net {
+
+class IPEndPoint;
+
+class NET_EXPORT StreamListenSocket
+ : public base::RefCountedThreadSafe<StreamListenSocket>,
+#if defined(OS_WIN)
+ public base::win::ObjectWatcher::Delegate {
+#elif defined(OS_POSIX)
+ public base::MessageLoopForIO::Watcher {
+#endif
+
+ public:
+ // TODO(erikkay): this delegate should really be split into two parts
+ // to split up the listener from the connected socket. Perhaps this class
+ // should be split up similarly.
+ class Delegate {
+ public:
+ // |server| is the original listening Socket, connection is the new
+ // Socket that was created. Ownership of |connection| is transferred
+ // to the delegate with this call.
+ virtual void DidAccept(StreamListenSocket* server,
+ StreamListenSocket* connection) = 0;
+ virtual void DidRead(StreamListenSocket* connection,
+ const char* data,
+ int len) = 0;
+ virtual void DidClose(StreamListenSocket* sock) = 0;
+
+ protected:
+ virtual ~Delegate() {}
+ };
+
+ // Send data to the socket.
+ void Send(const char* bytes, int len, bool append_linefeed = false);
+ void Send(const std::string& str, bool append_linefeed = false);
+
+ // Copies the local address to |address|. Returns a network error code.
+ int GetLocalAddress(IPEndPoint* address);
+
+ static const SocketDescriptor kInvalidSocket;
+ static const int kSocketError;
+
+ protected:
+ enum WaitState {
+ NOT_WAITING = 0,
+ WAITING_ACCEPT = 1,
+ WAITING_READ = 2
+ };
+
+ StreamListenSocket(SocketDescriptor s, Delegate* del);
+ virtual ~StreamListenSocket();
+
+ SocketDescriptor AcceptSocket();
+ virtual void Accept() = 0;
+
+ void Listen();
+ void Read();
+ void Close();
+ void CloseSocket(SocketDescriptor s);
+
+ // Pass any value in case of Windows, because in Windows
+ // we are not using state.
+ void WatchSocket(WaitState state);
+ void UnwatchSocket();
+
+ Delegate* const socket_delegate_;
+
+ private:
+ friend class base::RefCountedThreadSafe<StreamListenSocket>;
+ friend class TransportClientSocketTest;
+
+ void SendInternal(const char* bytes, int len);
+
+#if defined(OS_WIN)
+ // ObjectWatcher delegate.
+ virtual void OnObjectSignaled(HANDLE object);
+ base::win::ObjectWatcher watcher_;
+ HANDLE socket_event_;
+#elif defined(OS_POSIX)
+ // Called by MessagePumpLibevent when the socket is ready to do I/O.
+ virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
+ virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
+ WaitState wait_state_;
+ // The socket's libevent wrapper.
+ base::MessageLoopForIO::FileDescriptorWatcher watcher_;
+#endif
+
+ // NOTE: This is for unit test use only!
+ // Pause/Resume calling Read(). Note that ResumeReads() will also call
+ // Read() if there is anything to read.
+ void PauseReads();
+ void ResumeReads();
+
+ const SocketDescriptor socket_;
+ bool reads_paused_;
+ bool has_pending_reads_;
+
+ DISALLOW_COPY_AND_ASSIGN(StreamListenSocket);
+};
+
+// Abstract factory that must be subclassed for each subclass of
+// StreamListenSocket.
+class NET_EXPORT StreamListenSocketFactory {
+ public:
+ virtual ~StreamListenSocketFactory() {}
+
+ // Returns a new instance of StreamListenSocket or NULL if an error occurred.
+ virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const = 0;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_STREAM_LISTEN_SOCKET_H_
diff --git a/chromium/net/socket/stream_socket.cc b/chromium/net/socket/stream_socket.cc
new file mode 100644
index 00000000000..fb194f64625
--- /dev/null
+++ b/chromium/net/socket/stream_socket.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/stream_socket.h"
+
+#include "base/metrics/field_trial.h"
+#include "base/metrics/histogram.h"
+#include "base/strings/string_number_conversions.h"
+#include "base/values.h"
+
+namespace net {
+
+StreamSocket::UseHistory::UseHistory()
+ : was_ever_connected_(false),
+ was_used_to_convey_data_(false),
+ omnibox_speculation_(false),
+ subresource_speculation_(false) {
+}
+
+StreamSocket::UseHistory::~UseHistory() {
+ EmitPreconnectionHistograms();
+}
+
+void StreamSocket::UseHistory::Reset() {
+ EmitPreconnectionHistograms();
+ was_ever_connected_ = false;
+ was_used_to_convey_data_ = false;
+ // omnibox_speculation_ and subresource_speculation_ values
+ // are intentionally preserved.
+}
+
+void StreamSocket::UseHistory::set_was_ever_connected() {
+ DCHECK(!was_used_to_convey_data_);
+ was_ever_connected_ = true;
+}
+
+void StreamSocket::UseHistory::set_was_used_to_convey_data() {
+ DCHECK(was_ever_connected_);
+ was_used_to_convey_data_ = true;
+}
+
+
+void StreamSocket::UseHistory::set_subresource_speculation() {
+ DCHECK(was_ever_connected_);
+ // TODO(jar): We should transition to marking a socket (or stream) at
+ // construction time as being created for speculative reasons. This current
+ // approach of trying to track use of a socket to convey data can make
+ // mistakes when other sockets (such as ones sitting in the pool for a long
+ // time) are issued. Unused sockets can be left over when a when a set of
+ // connections to a host are made, and one is "unlucky" and takes so long to
+ // complete a connection, that another socket is used, and recycled before a
+ // second connection comes available. Similarly, re-try connections can leave
+ // an original (slow to connect socket) in the pool, and that can be issued
+ // to a speculative requester. In any cases such old sockets will fail when an
+ // attempt is made to used them!... and then it will look like a speculative
+ // socket was discarded without any user!?!?!
+ if (was_used_to_convey_data_)
+ return;
+ subresource_speculation_ = true;
+}
+
+void StreamSocket::UseHistory::set_omnibox_speculation() {
+ DCHECK(was_ever_connected_);
+ if (was_used_to_convey_data_)
+ return;
+ omnibox_speculation_ = true;
+}
+
+bool StreamSocket::UseHistory::was_used_to_convey_data() const {
+ DCHECK(!was_used_to_convey_data_ || was_ever_connected_);
+ return was_used_to_convey_data_;
+}
+
+void StreamSocket::UseHistory::EmitPreconnectionHistograms() const {
+ DCHECK(!subresource_speculation_ || !omnibox_speculation_);
+ // 0 ==> non-speculative, never connected.
+ // 1 ==> non-speculative never used (but connected).
+ // 2 ==> non-speculative and used.
+ // 3 ==> omnibox_speculative never connected.
+ // 4 ==> omnibox_speculative never used (but connected).
+ // 5 ==> omnibox_speculative and used.
+ // 6 ==> subresource_speculative never connected.
+ // 7 ==> subresource_speculative never used (but connected).
+ // 8 ==> subresource_speculative and used.
+ int result;
+ if (was_used_to_convey_data_)
+ result = 2;
+ else if (was_ever_connected_)
+ result = 1;
+ else
+ result = 0; // Never used, and not really connected.
+
+ if (omnibox_speculation_)
+ result += 3;
+ else if (subresource_speculation_)
+ result += 6;
+ UMA_HISTOGRAM_ENUMERATION("Net.PreconnectUtilization2", result, 9);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/stream_socket.h b/chromium/net/socket/stream_socket.h
new file mode 100644
index 00000000000..38eec86dd92
--- /dev/null
+++ b/chromium/net/socket/stream_socket.h
@@ -0,0 +1,138 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_STREAM_SOCKET_H_
+#define NET_SOCKET_STREAM_SOCKET_H_
+
+#include "net/base/net_log.h"
+#include "net/socket/next_proto.h"
+#include "net/socket/socket.h"
+
+namespace net {
+
+class AddressList;
+class IPEndPoint;
+class SSLInfo;
+
+class NET_EXPORT_PRIVATE StreamSocket : public Socket {
+ public:
+ virtual ~StreamSocket() {}
+
+ // Called to establish a connection. Returns OK if the connection could be
+ // established synchronously. Otherwise, ERR_IO_PENDING is returned and the
+ // given callback will run asynchronously when the connection is established
+ // or when an error occurs. The result is some other error code if the
+ // connection could not be established.
+ //
+ // The socket's Read and Write methods may not be called until Connect
+ // succeeds.
+ //
+ // It is valid to call Connect on an already connected socket, in which case
+ // OK is simply returned.
+ //
+ // Connect may also be called again after a call to the Disconnect method.
+ //
+ virtual int Connect(const CompletionCallback& callback) = 0;
+
+ // Called to disconnect a socket. Does nothing if the socket is already
+ // disconnected. After calling Disconnect it is possible to call Connect
+ // again to establish a new connection.
+ //
+ // If IO (Connect, Read, or Write) is pending when the socket is
+ // disconnected, the pending IO is cancelled, and the completion callback
+ // will not be called.
+ virtual void Disconnect() = 0;
+
+ // Called to test if the connection is still alive. Returns false if a
+ // connection wasn't established or the connection is dead.
+ virtual bool IsConnected() const = 0;
+
+ // Called to test if the connection is still alive and idle. Returns false
+ // if a connection wasn't established, the connection is dead, or some data
+ // have been received.
+ virtual bool IsConnectedAndIdle() const = 0;
+
+ // Copies the peer address to |address| and returns a network error code.
+ // ERR_SOCKET_NOT_CONNECTED will be returned if the socket is not connected.
+ virtual int GetPeerAddress(IPEndPoint* address) const = 0;
+
+ // Copies the local address to |address| and returns a network error code.
+ // ERR_SOCKET_NOT_CONNECTED will be returned if the socket is not bound.
+ virtual int GetLocalAddress(IPEndPoint* address) const = 0;
+
+ // Gets the NetLog for this socket.
+ virtual const BoundNetLog& NetLog() const = 0;
+
+ // Set the annotation to indicate this socket was created for speculative
+ // reasons. This call is generally forwarded to a basic TCPClientSocket*,
+ // where a UseHistory can be updated.
+ virtual void SetSubresourceSpeculation() = 0;
+ virtual void SetOmniboxSpeculation() = 0;
+
+ // Returns true if the underlying transport socket ever had any reads or
+ // writes. StreamSockets layered on top of transport sockets should forward
+ // this call to the transport socket.
+ virtual bool WasEverUsed() const = 0;
+
+ // Returns true if the underlying transport socket is using TCP FastOpen.
+ // TCP FastOpen is an experiment with sending data in the TCP SYN packet.
+ virtual bool UsingTCPFastOpen() const = 0;
+
+ // Returns true if NPN was negotiated during the connection of this socket.
+ virtual bool WasNpnNegotiated() const = 0;
+
+ // Returns the protocol negotiated via NPN for this socket, or
+ // kProtoUnknown will be returned if NPN is not applicable.
+ virtual NextProto GetNegotiatedProtocol() const = 0;
+
+ // Gets the SSL connection information of the socket. Returns false if
+ // SSL was not used by this socket.
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) = 0;
+
+ protected:
+ // The following class is only used to gather statistics about the history of
+ // a socket. It is only instantiated and used in basic sockets, such as
+ // TCPClientSocket* instances. Other classes that are derived from
+ // StreamSocket should forward any potential settings to their underlying
+ // transport sockets.
+ class UseHistory {
+ public:
+ UseHistory();
+ ~UseHistory();
+
+ // Resets the state of UseHistory and emits histograms for the
+ // current state.
+ void Reset();
+
+ void set_was_ever_connected();
+ void set_was_used_to_convey_data();
+
+ // The next two setters only have any impact if the socket has not yet been
+ // used to transmit data. If called later, we assume that the socket was
+ // reused from the pool, and was NOT constructed to service a speculative
+ // request.
+ void set_subresource_speculation();
+ void set_omnibox_speculation();
+
+ bool was_used_to_convey_data() const;
+
+ private:
+ // Summarize the statistics for this socket.
+ void EmitPreconnectionHistograms() const;
+ // Indicate if this was ever connected.
+ bool was_ever_connected_;
+ // Indicate if this socket was ever used to transmit or receive data.
+ bool was_used_to_convey_data_;
+
+ // Indicate if this socket was first created for speculative use, and
+ // identify the motivation.
+ bool omnibox_speculation_;
+ bool subresource_speculation_;
+ DISALLOW_COPY_AND_ASSIGN(UseHistory);
+ };
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_STREAM_SOCKET_H_
diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc
new file mode 100644
index 00000000000..dbd21056f39
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_client_socket.h"
+
+#include "base/file_util.h"
+#include "base/files/file_path.h"
+
+namespace net {
+
+namespace {
+
+#if defined(OS_LINUX)
+
+// Checks to see if the system supports TCP FastOpen. Notably, it requires
+// kernel support. Additionally, this checks system configuration to ensure that
+// it's enabled.
+bool SystemSupportsTCPFastOpen() {
+ static const base::FilePath::CharType kTCPFastOpenProcFilePath[] =
+ "/proc/sys/net/ipv4/tcp_fastopen";
+ std::string system_enabled_tcp_fastopen;
+ if (!file_util::ReadFileToString(
+ base::FilePath(kTCPFastOpenProcFilePath),
+ &system_enabled_tcp_fastopen)) {
+ return false;
+ }
+
+ // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225
+ // TFO_CLIENT_ENABLE is the LSB
+ if (system_enabled_tcp_fastopen.empty() ||
+ (system_enabled_tcp_fastopen[0] & 0x1) == 0) {
+ return false;
+ }
+
+ return true;
+}
+
+#else
+
+bool SystemSupportsTCPFastOpen() {
+ return false;
+}
+
+#endif
+
+}
+
+static bool g_tcp_fastopen_enabled = false;
+
+void SetTCPFastOpenEnabled(bool value) {
+ g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen();
+}
+
+bool IsTCPFastOpenEnabled() {
+ return g_tcp_fastopen_enabled;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h
new file mode 100644
index 00000000000..8a2c0cd73f0
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_H_
+#define NET_SOCKET_TCP_CLIENT_SOCKET_H_
+
+#include "build/build_config.h"
+#include "net/base/net_export.h"
+
+#if defined(OS_WIN)
+#include "net/socket/tcp_client_socket_win.h"
+#elif defined(OS_POSIX)
+#include "net/socket/tcp_client_socket_libevent.h"
+#endif
+
+namespace net {
+
+// A client socket that uses TCP as the transport layer.
+#if defined(OS_WIN)
+typedef TCPClientSocketWin TCPClientSocket;
+#elif defined(OS_POSIX)
+typedef TCPClientSocketLibevent TCPClientSocket;
+#endif
+
+// Enable/disable experimental TCP FastOpen option.
+// Not thread safe. Must be called during initialization/startup only.
+NET_EXPORT void SetTCPFastOpenEnabled(bool value);
+
+// Check if the TCP FastOpen option is enabled.
+bool IsTCPFastOpenEnabled();
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_CLIENT_SOCKET_H_
diff --git a/chromium/net/socket/tcp_client_socket_libevent.cc b/chromium/net/socket/tcp_client_socket_libevent.cc
new file mode 100644
index 00000000000..2f7e4b4b255
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket_libevent.cc
@@ -0,0 +1,830 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_client_socket.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <sys/socket.h>
+#include <netinet/tcp.h>
+#if defined(OS_POSIX)
+#include <netinet/in.h>
+#endif
+
+#include "base/logging.h"
+#include "base/message_loop/message_loop.h"
+#include "base/metrics/histogram.h"
+#include "base/metrics/stats_counters.h"
+#include "base/posix/eintr_wrapper.h"
+#include "base/strings/string_util.h"
+#include "net/base/connection_type_histograms.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_util.h"
+#include "net/base/network_change_notifier.h"
+#include "net/socket/socket_net_log_params.h"
+
+// If we don't have a definition for TCPI_OPT_SYN_DATA, create one.
+#ifndef TCPI_OPT_SYN_DATA
+#define TCPI_OPT_SYN_DATA 32
+#endif
+
+namespace net {
+
+namespace {
+
+const int kInvalidSocket = -1;
+const int kTCPKeepAliveSeconds = 45;
+
+// SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets
+// will wait up to 200ms for more data to complete a packet before transmitting.
+// After calling this function, the kernel will not wait. See TCP_NODELAY in
+// `man 7 tcp`.
+bool SetTCPNoDelay(int fd, bool no_delay) {
+ int on = no_delay ? 1 : 0;
+ int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on,
+ sizeof(on));
+ return error == 0;
+}
+
+// SetTCPKeepAlive sets SO_KEEPALIVE.
+bool SetTCPKeepAlive(int fd, bool enable, int delay) {
+ int on = enable ? 1 : 0;
+ if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on))) {
+ PLOG(ERROR) << "Failed to set SO_KEEPALIVE on fd: " << fd;
+ return false;
+ }
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ // Set seconds until first TCP keep alive.
+ if (setsockopt(fd, SOL_TCP, TCP_KEEPIDLE, &delay, sizeof(delay))) {
+ PLOG(ERROR) << "Failed to set TCP_KEEPIDLE on fd: " << fd;
+ return false;
+ }
+ // Set seconds between TCP keep alives.
+ if (setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &delay, sizeof(delay))) {
+ PLOG(ERROR) << "Failed to set TCP_KEEPINTVL on fd: " << fd;
+ return false;
+ }
+#endif
+ return true;
+}
+
+// Sets socket parameters. Returns the OS error code (or 0 on
+// success).
+int SetupSocket(int socket) {
+ if (SetNonBlocking(socket))
+ return errno;
+
+ // This mirrors the behaviour on Windows. See the comment in
+ // tcp_client_socket_win.cc after searching for "NODELAY".
+ SetTCPNoDelay(socket, true); // If SetTCPNoDelay fails, we don't care.
+ SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds);
+
+ return 0;
+}
+
+// Creates a new socket and sets default parameters for it. Returns
+// the OS error code (or 0 on success).
+int CreateSocket(int family, int* socket) {
+ *socket = ::socket(family, SOCK_STREAM, IPPROTO_TCP);
+ if (*socket == kInvalidSocket)
+ return errno;
+ int error = SetupSocket(*socket);
+ if (error) {
+ if (HANDLE_EINTR(close(*socket)) < 0)
+ PLOG(ERROR) << "close";
+ *socket = kInvalidSocket;
+ return error;
+ }
+ return 0;
+}
+
+int MapConnectError(int os_error) {
+ switch (os_error) {
+ case EACCES:
+ return ERR_NETWORK_ACCESS_DENIED;
+ case ETIMEDOUT:
+ return ERR_CONNECTION_TIMED_OUT;
+ default: {
+ int net_error = MapSystemError(os_error);
+ if (net_error == ERR_FAILED)
+ return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED.
+
+ // Give a more specific error when the user is offline.
+ if (net_error == ERR_ADDRESS_UNREACHABLE &&
+ NetworkChangeNotifier::IsOffline()) {
+ return ERR_INTERNET_DISCONNECTED;
+ }
+ return net_error;
+ }
+ }
+}
+
+} // namespace
+
+//-----------------------------------------------------------------------------
+
+TCPClientSocketLibevent::TCPClientSocketLibevent(
+ const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(kInvalidSocket),
+ bound_socket_(kInvalidSocket),
+ addresses_(addresses),
+ current_address_index_(-1),
+ read_watcher_(this),
+ write_watcher_(this),
+ next_connect_state_(CONNECT_STATE_NONE),
+ connect_os_error_(0),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
+ previously_disconnected_(false),
+ use_tcp_fastopen_(IsTCPFastOpenEnabled()),
+ tcp_fastopen_connected_(false),
+ fast_open_status_(FAST_OPEN_STATUS_UNKNOWN) {
+ net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
+ source.ToEventParametersCallback());
+}
+
+TCPClientSocketLibevent::~TCPClientSocketLibevent() {
+ Disconnect();
+ net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
+ if (tcp_fastopen_connected_) {
+ UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection",
+ fast_open_status_, FAST_OPEN_MAX_VALUE);
+ }
+}
+
+int TCPClientSocketLibevent::AdoptSocket(int socket) {
+ DCHECK_EQ(socket_, kInvalidSocket);
+
+ int error = SetupSocket(socket);
+ if (error)
+ return MapSystemError(error);
+
+ socket_ = socket;
+
+ // This is to make GetPeerAddress() work. It's up to the caller ensure
+ // that |address_| contains a reasonable address for this
+ // socket. (i.e. at least match IPv4 vs IPv6!).
+ current_address_index_ = 0;
+ use_history_.set_was_ever_connected();
+
+ return OK;
+}
+
+int TCPClientSocketLibevent::Bind(const IPEndPoint& address) {
+ if (current_address_index_ >= 0 || bind_address_.get()) {
+ // Cannot bind the socket if we are already bound connected or
+ // connecting.
+ return ERR_UNEXPECTED;
+ }
+
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+
+ // Create |bound_socket_| and try to bind it to |address|.
+ int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_);
+ if (error)
+ return MapSystemError(error);
+
+ if (HANDLE_EINTR(bind(bound_socket_, storage.addr, storage.addr_len))) {
+ error = errno;
+ if (HANDLE_EINTR(close(bound_socket_)) < 0)
+ PLOG(ERROR) << "close";
+ bound_socket_ = kInvalidSocket;
+ return MapSystemError(error);
+ }
+
+ bind_address_.reset(new IPEndPoint(address));
+
+ return 0;
+}
+
+int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+
+ // If already connected, then just return OK.
+ if (socket_ != kInvalidSocket)
+ return OK;
+
+ base::StatsCounter connects("tcp.connect");
+ connects.Increment();
+
+ DCHECK(!waiting_connect());
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT,
+ addresses_.CreateNetLogCallback());
+
+ // We will try to connect to each address in addresses_. Start with the
+ // first one in the list.
+ next_connect_state_ = CONNECT_STATE_CONNECT;
+ current_address_index_ = 0;
+
+ int rv = DoConnectLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ // Synchronous operation not supported.
+ DCHECK(!callback.is_null());
+ write_callback_ = callback;
+ } else {
+ LogConnectCompletion(rv);
+ }
+
+ return rv;
+}
+
+int TCPClientSocketLibevent::DoConnectLoop(int result) {
+ DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
+
+ int rv = result;
+ do {
+ ConnectState state = next_connect_state_;
+ next_connect_state_ = CONNECT_STATE_NONE;
+ switch (state) {
+ case CONNECT_STATE_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoConnect();
+ break;
+ case CONNECT_STATE_CONNECT_COMPLETE:
+ rv = DoConnectComplete(rv);
+ break;
+ default:
+ LOG(DFATAL) << "bad state";
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
+
+ return rv;
+}
+
+int TCPClientSocketLibevent::DoConnect() {
+ DCHECK_GE(current_address_index_, 0);
+ DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
+ DCHECK_EQ(0, connect_os_error_);
+
+ const IPEndPoint& endpoint = addresses_[current_address_index_];
+
+ if (previously_disconnected_) {
+ use_history_.Reset();
+ previously_disconnected_ = false;
+ }
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ CreateNetLogIPEndPointCallback(&endpoint));
+
+ next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
+
+ if (bound_socket_ != kInvalidSocket) {
+ DCHECK(bind_address_.get());
+ socket_ = bound_socket_;
+ bound_socket_ = kInvalidSocket;
+ } else {
+ // Create a non-blocking socket.
+ connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_);
+ if (connect_os_error_)
+ return MapSystemError(connect_os_error_);
+
+ if (bind_address_.get()) {
+ SockaddrStorage storage;
+ if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+ if (HANDLE_EINTR(bind(socket_, storage.addr, storage.addr_len)))
+ return MapSystemError(errno);
+ }
+ }
+
+ // Connect the socket.
+ if (!use_tcp_fastopen_) {
+ SockaddrStorage storage;
+ if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+
+ if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) {
+ // Connected without waiting!
+ return OK;
+ }
+ } else {
+ // With TCP FastOpen, we pretend that the socket is connected.
+ DCHECK(!tcp_fastopen_connected_);
+ return OK;
+ }
+
+ // Check if the connect() failed synchronously.
+ connect_os_error_ = errno;
+ if (connect_os_error_ != EINPROGRESS)
+ return MapConnectError(connect_os_error_);
+
+ // Otherwise the connect() is going to complete asynchronously, so watch
+ // for its completion.
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_WRITE,
+ &write_socket_watcher_, &write_watcher_)) {
+ connect_os_error_ = errno;
+ DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_;
+ return MapSystemError(connect_os_error_);
+ }
+
+ return ERR_IO_PENDING;
+}
+
+int TCPClientSocketLibevent::DoConnectComplete(int result) {
+ // Log the end of this attempt (and any OS error it threw).
+ int os_error = connect_os_error_;
+ connect_os_error_ = 0;
+ if (result != OK) {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ NetLog::IntegerCallback("os_error", os_error));
+ } else {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
+ }
+
+ if (result == OK) {
+ write_socket_watcher_.StopWatchingFileDescriptor();
+ use_history_.set_was_ever_connected();
+ return OK; // Done!
+ }
+
+ // Close whatever partially connected socket we currently have.
+ DoDisconnect();
+
+ // Try to fall back to the next address in the list.
+ if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
+ next_connect_state_ = CONNECT_STATE_CONNECT;
+ ++current_address_index_;
+ return OK;
+ }
+
+ // Otherwise there is nothing to fall back to, so give up.
+ return result;
+}
+
+void TCPClientSocketLibevent::Disconnect() {
+ DCHECK(CalledOnValidThread());
+
+ DoDisconnect();
+ current_address_index_ = -1;
+ bind_address_.reset();
+}
+
+void TCPClientSocketLibevent::DoDisconnect() {
+ if (socket_ == kInvalidSocket)
+ return;
+
+ bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ ok = write_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ if (HANDLE_EINTR(close(socket_)) < 0)
+ PLOG(ERROR) << "close";
+ socket_ = kInvalidSocket;
+ previously_disconnected_ = true;
+}
+
+bool TCPClientSocketLibevent::IsConnected() const {
+ DCHECK(CalledOnValidThread());
+
+ if (socket_ == kInvalidSocket || waiting_connect())
+ return false;
+
+ if (use_tcp_fastopen_ && !tcp_fastopen_connected_) {
+ // With TCP FastOpen, we pretend that the socket is connected.
+ // This allows GetPeerAddress() to return current_ai_ as the peer
+ // address. Since we don't fail over to the next address if
+ // sendto() fails, current_ai_ is the only possible peer address.
+ CHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
+ return true;
+ }
+
+ // Check if connection is alive.
+ char c;
+ int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK));
+ if (rv == 0)
+ return false;
+ if (rv == -1 && errno != EAGAIN && errno != EWOULDBLOCK)
+ return false;
+
+ return true;
+}
+
+bool TCPClientSocketLibevent::IsConnectedAndIdle() const {
+ DCHECK(CalledOnValidThread());
+
+ if (socket_ == kInvalidSocket || waiting_connect())
+ return false;
+
+ // TODO(wtc): should we also handle the TCP FastOpen case here,
+ // as we do in IsConnected()?
+
+ // Check if connection is alive and we haven't received any data
+ // unexpectedly.
+ char c;
+ int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK));
+ if (rv >= 0)
+ return false;
+ if (errno != EAGAIN && errno != EWOULDBLOCK)
+ return false;
+
+ return true;
+}
+
+int TCPClientSocketLibevent::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_);
+ DCHECK(!waiting_connect());
+ DCHECK(read_callback_.is_null());
+ // Synchronous operation not supported
+ DCHECK(!callback.is_null());
+ DCHECK_GT(buf_len, 0);
+
+ int nread = HANDLE_EINTR(read(socket_, buf->data(), buf_len));
+ if (nread >= 0) {
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ read_bytes.Add(nread);
+ if (nread > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread,
+ buf->data());
+ RecordFastOpenStatus();
+ return nread;
+ }
+ if (errno != EAGAIN && errno != EWOULDBLOCK) {
+ int net_error = MapSystemError(errno);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(net_error, errno));
+ return net_error;
+ }
+
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_READ,
+ &read_socket_watcher_, &read_watcher_)) {
+ DVLOG(1) << "WatchFileDescriptor failed on read, errno " << errno;
+ return MapSystemError(errno);
+ }
+
+ read_buf_ = buf;
+ read_buf_len_ = buf_len;
+ read_callback_ = callback;
+ return ERR_IO_PENDING;
+}
+
+int TCPClientSocketLibevent::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_);
+ DCHECK(!waiting_connect());
+ DCHECK(write_callback_.is_null());
+ // Synchronous operation not supported
+ DCHECK(!callback.is_null());
+ DCHECK_GT(buf_len, 0);
+
+ int nwrite = InternalWrite(buf, buf_len);
+ if (nwrite >= 0) {
+ base::StatsCounter write_bytes("tcp.write_bytes");
+ write_bytes.Add(nwrite);
+ if (nwrite > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite,
+ buf->data());
+ return nwrite;
+ }
+ if (errno != EAGAIN && errno != EWOULDBLOCK) {
+ int net_error = MapSystemError(errno);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
+ CreateNetLogSocketErrorCallback(net_error, errno));
+ return net_error;
+ }
+
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_WRITE,
+ &write_socket_watcher_, &write_watcher_)) {
+ DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno;
+ return MapSystemError(errno);
+ }
+
+ write_buf_ = buf;
+ write_buf_len_ = buf_len;
+ write_callback_ = callback;
+ return ERR_IO_PENDING;
+}
+
+int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) {
+ int nwrite;
+ if (use_tcp_fastopen_ && !tcp_fastopen_connected_) {
+ SockaddrStorage storage;
+ if (!addresses_[current_address_index_].ToSockAddr(storage.addr,
+ &storage.addr_len)) {
+ errno = EINVAL;
+ return -1;
+ }
+
+ int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN.
+#if defined(OS_LINUX)
+ // sendto() will fail with EPIPE when the system doesn't support TCP Fast
+ // Open. Theoretically that shouldn't happen since the caller should check
+ // for system support on startup, but users may dynamically disable TCP Fast
+ // Open via sysctl.
+ flags |= MSG_NOSIGNAL;
+#endif // defined(OS_LINUX)
+ nwrite = HANDLE_EINTR(sendto(socket_,
+ buf->data(),
+ buf_len,
+ flags,
+ storage.addr,
+ storage.addr_len));
+ tcp_fastopen_connected_ = true;
+
+ if (nwrite < 0) {
+ DCHECK_NE(EPIPE, errno);
+
+ // If errno == EINPROGRESS, that means the kernel didn't have a cookie
+ // and would block. The kernel is internally doing a connect() though.
+ // Remap EINPROGRESS to EAGAIN so we treat this the same as our other
+ // asynchronous cases. Note that the user buffer has not been copied to
+ // kernel space.
+ if (errno == EINPROGRESS) {
+ errno = EAGAIN;
+ fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN;
+ } else {
+ fast_open_status_ = FAST_OPEN_ERROR;
+ }
+ } else {
+ fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN;
+ }
+ } else {
+ nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len));
+ }
+ return nwrite;
+}
+
+bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) {
+ DCHECK(CalledOnValidThread());
+ int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
+ reinterpret_cast<const char*>(&size),
+ sizeof(size));
+ DCHECK(!rv) << "Could not set socket receive buffer size: " << errno;
+ return rv == 0;
+}
+
+bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) {
+ DCHECK(CalledOnValidThread());
+ int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
+ reinterpret_cast<const char*>(&size),
+ sizeof(size));
+ DCHECK(!rv) << "Could not set socket send buffer size: " << errno;
+ return rv == 0;
+}
+
+bool TCPClientSocketLibevent::SetKeepAlive(bool enable, int delay) {
+ int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_;
+ return SetTCPKeepAlive(socket, enable, delay);
+}
+
+bool TCPClientSocketLibevent::SetNoDelay(bool no_delay) {
+ int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_;
+ return SetTCPNoDelay(socket, no_delay);
+}
+
+void TCPClientSocketLibevent::ReadWatcher::OnFileCanReadWithoutBlocking(int) {
+ socket_->RecordFastOpenStatus();
+ if (!socket_->read_callback_.is_null())
+ socket_->DidCompleteRead();
+}
+
+void TCPClientSocketLibevent::WriteWatcher::OnFileCanWriteWithoutBlocking(int) {
+ if (socket_->waiting_connect()) {
+ socket_->DidCompleteConnect();
+ } else if (!socket_->write_callback_.is_null()) {
+ socket_->DidCompleteWrite();
+ }
+}
+
+void TCPClientSocketLibevent::LogConnectCompletion(int net_error) {
+ if (net_error == OK)
+ UpdateConnectionTypeHistograms(CONNECTION_ANY);
+
+ if (net_error != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, net_error);
+ return;
+ }
+
+ SockaddrStorage storage;
+ int rv = getsockname(socket_, storage.addr, &storage.addr_len);
+ if (rv != 0) {
+ PLOG(ERROR) << "getsockname() [rv: " << rv << "] error: ";
+ NOTREACHED();
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, rv);
+ return;
+ }
+
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT,
+ CreateNetLogSourceAddressCallback(storage.addr,
+ storage.addr_len));
+}
+
+void TCPClientSocketLibevent::DoReadCallback(int rv) {
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ DCHECK(!read_callback_.is_null());
+
+ // since Run may result in Read being called, clear read_callback_ up front.
+ CompletionCallback c = read_callback_;
+ read_callback_.Reset();
+ c.Run(rv);
+}
+
+void TCPClientSocketLibevent::DoWriteCallback(int rv) {
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ DCHECK(!write_callback_.is_null());
+
+ // since Run may result in Write being called, clear write_callback_ up front.
+ CompletionCallback c = write_callback_;
+ write_callback_.Reset();
+ c.Run(rv);
+}
+
+void TCPClientSocketLibevent::DidCompleteConnect() {
+ DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
+
+ // Get the error that connect() completed with.
+ int os_error = 0;
+ socklen_t len = sizeof(os_error);
+ if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0)
+ os_error = errno;
+
+ // TODO(eroman): Is this check really necessary?
+ if (os_error == EINPROGRESS || os_error == EALREADY) {
+ NOTREACHED(); // This indicates a bug in libevent or our code.
+ return;
+ }
+
+ connect_os_error_ = os_error;
+ int rv = DoConnectLoop(MapConnectError(os_error));
+ if (rv != ERR_IO_PENDING) {
+ LogConnectCompletion(rv);
+ DoWriteCallback(rv);
+ }
+}
+
+void TCPClientSocketLibevent::DidCompleteRead() {
+ int bytes_transferred;
+ bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(),
+ read_buf_len_));
+
+ int result;
+ if (bytes_transferred >= 0) {
+ result = bytes_transferred;
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ read_bytes.Add(bytes_transferred);
+ if (bytes_transferred > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result,
+ read_buf_->data());
+ } else {
+ result = MapSystemError(errno);
+ if (result != ERR_IO_PENDING) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(result, errno));
+ }
+ }
+
+ if (result != ERR_IO_PENDING) {
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ DoReadCallback(result);
+ }
+}
+
+void TCPClientSocketLibevent::DidCompleteWrite() {
+ int bytes_transferred;
+ bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(),
+ write_buf_len_));
+
+ int result;
+ if (bytes_transferred >= 0) {
+ result = bytes_transferred;
+ base::StatsCounter write_bytes("tcp.write_bytes");
+ write_bytes.Add(bytes_transferred);
+ if (bytes_transferred > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result,
+ write_buf_->data());
+ } else {
+ result = MapSystemError(errno);
+ if (result != ERR_IO_PENDING) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
+ CreateNetLogSocketErrorCallback(result, errno));
+ }
+ }
+
+ if (result != ERR_IO_PENDING) {
+ write_buf_ = NULL;
+ write_buf_len_ = 0;
+ write_socket_watcher_.StopWatchingFileDescriptor();
+ DoWriteCallback(result);
+ }
+}
+
+int TCPClientSocketLibevent::GetPeerAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ *address = addresses_[current_address_index_];
+ return OK;
+}
+
+int TCPClientSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+ if (socket_ == kInvalidSocket) {
+ if (bind_address_.get()) {
+ *address = *bind_address_;
+ return OK;
+ }
+ return ERR_SOCKET_NOT_CONNECTED;
+ }
+
+ SockaddrStorage storage;
+ if (getsockname(socket_, storage.addr, &storage.addr_len))
+ return MapSystemError(errno);
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_FAILED;
+
+ return OK;
+}
+
+void TCPClientSocketLibevent::RecordFastOpenStatus() {
+ if (use_tcp_fastopen_ &&
+ (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ||
+ fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) {
+ DCHECK_NE(FAST_OPEN_STATUS_UNKNOWN, fast_open_status_);
+ bool getsockopt_success(false);
+ bool server_acked_data(false);
+#if defined(TCP_INFO)
+ // Probe to see the if the socket used TCP Fast Open.
+ tcp_info info;
+ socklen_t info_len = sizeof(tcp_info);
+ getsockopt_success =
+ getsockopt(socket_, IPPROTO_TCP, TCP_INFO, &info, &info_len) == 0 &&
+ info_len == sizeof(tcp_info);
+ server_acked_data = getsockopt_success &&
+ (info.tcpi_options & TCPI_OPT_SYN_DATA);
+#endif
+ if (getsockopt_success) {
+ if (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN) {
+ fast_open_status_ = (server_acked_data ? FAST_OPEN_SYN_DATA_ACK :
+ FAST_OPEN_SYN_DATA_NACK);
+ } else {
+ fast_open_status_ = (server_acked_data ? FAST_OPEN_NO_SYN_DATA_ACK :
+ FAST_OPEN_NO_SYN_DATA_NACK);
+ }
+ } else {
+ fast_open_status_ = (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ?
+ FAST_OPEN_SYN_DATA_FAILED :
+ FAST_OPEN_NO_SYN_DATA_FAILED);
+ }
+ }
+}
+
+const BoundNetLog& TCPClientSocketLibevent::NetLog() const {
+ return net_log_;
+}
+
+void TCPClientSocketLibevent::SetSubresourceSpeculation() {
+ use_history_.set_subresource_speculation();
+}
+
+void TCPClientSocketLibevent::SetOmniboxSpeculation() {
+ use_history_.set_omnibox_speculation();
+}
+
+bool TCPClientSocketLibevent::WasEverUsed() const {
+ return use_history_.was_used_to_convey_data();
+}
+
+bool TCPClientSocketLibevent::UsingTCPFastOpen() const {
+ return use_tcp_fastopen_;
+}
+
+bool TCPClientSocketLibevent::WasNpnNegotiated() const {
+ return false;
+}
+
+NextProto TCPClientSocketLibevent::GetNegotiatedProtocol() const {
+ return kProtoUnknown;
+}
+
+bool TCPClientSocketLibevent::GetSSLInfo(SSLInfo* ssl_info) {
+ return false;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket_libevent.h b/chromium/net/socket/tcp_client_socket_libevent.h
new file mode 100644
index 00000000000..e5a0d8deab4
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket_libevent.h
@@ -0,0 +1,256 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_
+#define NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_
+
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/threading/non_thread_safe.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_log.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class BoundNetLog;
+
+// A client socket that uses TCP as the transport layer.
+class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket,
+ public base::NonThreadSafe {
+ public:
+ // The IP address(es) and port number to connect to. The TCP socket will try
+ // each IP address in the list until it succeeds in establishing a
+ // connection.
+ TCPClientSocketLibevent(const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source);
+
+ virtual ~TCPClientSocketLibevent();
+
+ // AdoptSocket causes the given, connected socket to be adopted as a TCP
+ // socket. This object must not be connected. This object takes ownership of
+ // the given socket and then acts as if Connect() had been called. This
+ // function is used by TCPServerSocket() to adopt accepted connections
+ // and for testing.
+ int AdoptSocket(int socket);
+
+ // Binds the socket to a local IP address and port.
+ int Bind(const IPEndPoint& address);
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE;
+ virtual void Disconnect() OVERRIDE;
+ virtual bool IsConnected() const OVERRIDE;
+ virtual bool IsConnectedAndIdle() const OVERRIDE;
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual const BoundNetLog& NetLog() const OVERRIDE;
+ virtual void SetSubresourceSpeculation() OVERRIDE;
+ virtual void SetOmniboxSpeculation() OVERRIDE;
+ virtual bool WasEverUsed() const OVERRIDE;
+ virtual bool UsingTCPFastOpen() const OVERRIDE;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ // Multiple outstanding requests are not supported.
+ // Full duplex mode (reading and writing at the same time) is supported
+ virtual int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) OVERRIDE;
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE;
+
+ virtual bool SetKeepAlive(bool enable, int delay);
+ virtual bool SetNoDelay(bool no_delay);
+
+ private:
+ // State machine for connecting the socket.
+ enum ConnectState {
+ CONNECT_STATE_CONNECT,
+ CONNECT_STATE_CONNECT_COMPLETE,
+ CONNECT_STATE_NONE,
+ };
+
+ // States that a fast open socket attempt can result in.
+ enum FastOpenStatus {
+ FAST_OPEN_STATUS_UNKNOWN,
+
+ // The initial fast open connect attempted returned synchronously,
+ // indicating that we had and sent a cookie along with the initial data.
+ FAST_OPEN_FAST_CONNECT_RETURN,
+
+ // The initial fast open connect attempted returned asynchronously,
+ // indicating that we did not have a cookie for the server.
+ FAST_OPEN_SLOW_CONNECT_RETURN,
+
+ // Some other error occurred on connection, so we couldn't tell if
+ // fast open would have worked.
+ FAST_OPEN_ERROR,
+
+ // An attempt to do a fast open succeeded immediately
+ // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
+ // had acked the data we sent.
+ FAST_OPEN_SYN_DATA_ACK,
+
+ // An attempt to do a fast open succeeded immediately
+ // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
+ // had nacked the data we sent.
+ FAST_OPEN_SYN_DATA_NACK,
+
+ // An attempt to do a fast open succeeded immediately
+ // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the
+ // socket was using fast open failed.
+ FAST_OPEN_SYN_DATA_FAILED,
+
+ // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // and we later confirmed that the server had acked initial data. This
+ // should never happen (we didn't send data, so it shouldn't have
+ // been acked).
+ FAST_OPEN_NO_SYN_DATA_ACK,
+
+ // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // and we later discovered that the server had nacked initial data. This
+ // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN.
+ FAST_OPEN_NO_SYN_DATA_NACK,
+
+ // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // and our later probe for ack/nack state failed.
+ FAST_OPEN_NO_SYN_DATA_FAILED,
+
+ FAST_OPEN_MAX_VALUE
+ };
+
+ class ReadWatcher : public base::MessageLoopForIO::Watcher {
+ public:
+ explicit ReadWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {}
+
+ // MessageLoopForIO::Watcher methods
+
+ virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE;
+
+ virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE {}
+
+ private:
+ TCPClientSocketLibevent* const socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(ReadWatcher);
+ };
+
+ class WriteWatcher : public base::MessageLoopForIO::Watcher {
+ public:
+ explicit WriteWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {}
+
+ // MessageLoopForIO::Watcher implementation.
+ virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {}
+ virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE;
+
+ private:
+ TCPClientSocketLibevent* const socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(WriteWatcher);
+ };
+
+ // State machine used by Connect().
+ int DoConnectLoop(int result);
+ int DoConnect();
+ int DoConnectComplete(int result);
+
+ // Helper used by Disconnect(), which disconnects minus the logging and
+ // resetting of current_address_index_.
+ void DoDisconnect();
+
+ void DoReadCallback(int rv);
+ void DoWriteCallback(int rv);
+ void DidCompleteRead();
+ void DidCompleteWrite();
+ void DidCompleteConnect();
+
+ // Returns true if a Connect() is in progress.
+ bool waiting_connect() const {
+ return next_connect_state_ != CONNECT_STATE_NONE;
+ }
+
+ // Helper to add a TCP_CONNECT (end) event to the NetLog.
+ void LogConnectCompletion(int net_error);
+
+ // Internal function to write to a socket.
+ int InternalWrite(IOBuffer* buf, int buf_len);
+
+ // Called when the socket is known to be in a connected state.
+ void RecordFastOpenStatus();
+
+ int socket_;
+
+ // Local IP address and port we are bound to. Set to NULL if Bind()
+ // was't called (in that cases OS chooses address/port).
+ scoped_ptr<IPEndPoint> bind_address_;
+
+ // Stores bound socket between Bind() and Connect() calls.
+ int bound_socket_;
+
+ // The list of addresses we should try in order to establish a connection.
+ AddressList addresses_;
+
+ // Where we are in above list. Set to -1 if uninitialized.
+ int current_address_index_;
+
+ // The socket's libevent wrappers
+ base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_;
+ base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_;
+
+ // The corresponding watchers for reads and writes.
+ ReadWatcher read_watcher_;
+ WriteWatcher write_watcher_;
+
+ // The buffer used by OnSocketReady to retry Read requests
+ scoped_refptr<IOBuffer> read_buf_;
+ int read_buf_len_;
+
+ // The buffer used by OnSocketReady to retry Write requests
+ scoped_refptr<IOBuffer> write_buf_;
+ int write_buf_len_;
+
+ // External callback; called when read is complete.
+ CompletionCallback read_callback_;
+
+ // External callback; called when write is complete.
+ CompletionCallback write_callback_;
+
+ // The next state for the Connect() state machine.
+ ConnectState next_connect_state_;
+
+ // The OS error that CONNECT_STATE_CONNECT last completed with.
+ int connect_os_error_;
+
+ BoundNetLog net_log_;
+
+ // This socket was previously disconnected and has not been re-connected.
+ bool previously_disconnected_;
+
+ // Record of connectivity and transmissions, for use in speculative connection
+ // histograms.
+ UseHistory use_history_;
+
+ // Enables experimental TCP FastOpen option.
+ const bool use_tcp_fastopen_;
+
+ // True when TCP FastOpen is in use and we have done the connect.
+ bool tcp_fastopen_connected_;
+
+ enum FastOpenStatus fast_open_status_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPClientSocketLibevent);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_
diff --git a/chromium/net/socket/tcp_client_socket_unittest.cc b/chromium/net/socket/tcp_client_socket_unittest.cc
new file mode 100644
index 00000000000..ce0c53559f8
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket_unittest.cc
@@ -0,0 +1,113 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file contains some tests for TCPClientSocket.
+// transport_client_socket_unittest.cc contans some other tests that
+// are common for TCP and other types of sockets.
+
+#include "net/socket/tcp_client_socket.h"
+
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/tcp_server_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+// Try binding a socket to loopback interface and verify that we can
+// still connect to a server on the same interface.
+TEST(TCPClientSocketTest, BindLoopbackToLoopback) {
+ IPAddressNumber lo_address;
+ ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address));
+
+ TCPServerSocket server(NULL, NetLog::Source());
+ ASSERT_EQ(OK, server.Listen(IPEndPoint(lo_address, 0), 1));
+ IPEndPoint server_address;
+ ASSERT_EQ(OK, server.GetLocalAddress(&server_address));
+
+ TCPClientSocket socket(AddressList(server_address), NULL, NetLog::Source());
+
+ EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0)));
+
+ IPEndPoint local_address_result;
+ EXPECT_EQ(OK, socket.GetLocalAddress(&local_address_result));
+ EXPECT_EQ(lo_address, local_address_result.address());
+
+ TestCompletionCallback connect_callback;
+ EXPECT_EQ(ERR_IO_PENDING, socket.Connect(connect_callback.callback()));
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<StreamSocket> accepted_socket;
+ int result = server.Accept(&accepted_socket, accept_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+
+ EXPECT_TRUE(socket.IsConnected());
+ socket.Disconnect();
+ EXPECT_FALSE(socket.IsConnected());
+ EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED,
+ socket.GetLocalAddress(&local_address_result));
+}
+
+// Try to bind socket to the loopback interface and connect to an
+// external address, verify that connection fails.
+TEST(TCPClientSocketTest, BindLoopbackToExternal) {
+ IPAddressNumber external_ip;
+ ASSERT_TRUE(ParseIPLiteralToNumber("72.14.213.105", &external_ip));
+ TCPClientSocket socket(AddressList::CreateFromIPAddress(external_ip, 80),
+ NULL, NetLog::Source());
+
+ IPAddressNumber lo_address;
+ ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address));
+ EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0)));
+
+ TestCompletionCallback connect_callback;
+ int result = socket.Connect(connect_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = connect_callback.WaitForResult();
+
+ // We may get different errors here on different system, but
+ // connect() is not expected to succeed.
+ EXPECT_NE(OK, result);
+}
+
+// Bind a socket to the IPv4 loopback interface and try to connect to
+// the IPv6 loopback interface, verify that connection fails.
+TEST(TCPClientSocketTest, BindLoopbackToIPv6) {
+ IPAddressNumber ipv6_lo_ip;
+ ASSERT_TRUE(ParseIPLiteralToNumber("::1", &ipv6_lo_ip));
+ TCPServerSocket server(NULL, NetLog::Source());
+ int listen_result = server.Listen(IPEndPoint(ipv6_lo_ip, 0), 1);
+ if (listen_result != OK) {
+ LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is disabled."
+ " Skipping the test";
+ return;
+ }
+
+ IPEndPoint server_address;
+ ASSERT_EQ(OK, server.GetLocalAddress(&server_address));
+ TCPClientSocket socket(AddressList(server_address), NULL, NetLog::Source());
+
+ IPAddressNumber ipv4_lo_ip;
+ ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &ipv4_lo_ip));
+ EXPECT_EQ(OK, socket.Bind(IPEndPoint(ipv4_lo_ip, 0)));
+
+ TestCompletionCallback connect_callback;
+ int result = socket.Connect(connect_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = connect_callback.WaitForResult();
+
+ EXPECT_NE(OK, result);
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket_win.cc b/chromium/net/socket/tcp_client_socket_win.cc
new file mode 100644
index 00000000000..9b0a5b50bf1
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket_win.cc
@@ -0,0 +1,1045 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_client_socket_win.h"
+
+#include <mstcpip.h>
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/metrics/stats_counters.h"
+#include "base/strings/string_util.h"
+#include "base/win/object_watcher.h"
+#include "base/win/windows_version.h"
+#include "net/base/connection_type_histograms.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_util.h"
+#include "net/base/network_change_notifier.h"
+#include "net/base/winsock_init.h"
+#include "net/base/winsock_util.h"
+#include "net/socket/socket_net_log_params.h"
+
+namespace net {
+
+namespace {
+
+const int kTCPKeepAliveSeconds = 45;
+bool g_disable_overlapped_reads = false;
+
+bool SetSocketReceiveBufferSize(SOCKET socket, int32 size) {
+ int rv = setsockopt(socket, SOL_SOCKET, SO_RCVBUF,
+ reinterpret_cast<const char*>(&size), sizeof(size));
+ DCHECK(!rv) << "Could not set socket receive buffer size: " << GetLastError();
+ return rv == 0;
+}
+
+bool SetSocketSendBufferSize(SOCKET socket, int32 size) {
+ int rv = setsockopt(socket, SOL_SOCKET, SO_SNDBUF,
+ reinterpret_cast<const char*>(&size), sizeof(size));
+ DCHECK(!rv) << "Could not set socket send buffer size: " << GetLastError();
+ return rv == 0;
+}
+
+// Disable Nagle.
+// The Nagle implementation on windows is governed by RFC 896. The idea
+// behind Nagle is to reduce small packets on the network. When Nagle is
+// enabled, if a partial packet has been sent, the TCP stack will disallow
+// further *partial* packets until an ACK has been received from the other
+// side. Good applications should always strive to send as much data as
+// possible and avoid partial-packet sends. However, in most real world
+// applications, there are edge cases where this does not happen, and two
+// partial packets may be sent back to back. For a browser, it is NEVER
+// a benefit to delay for an RTT before the second packet is sent.
+//
+// As a practical example in Chromium today, consider the case of a small
+// POST. I have verified this:
+// Client writes 649 bytes of header (partial packet #1)
+// Client writes 50 bytes of POST data (partial packet #2)
+// In the above example, with Nagle, a RTT delay is inserted between these
+// two sends due to nagle. RTTs can easily be 100ms or more. The best
+// fix is to make sure that for POSTing data, we write as much data as
+// possible and minimize partial packets. We will fix that. But disabling
+// Nagle also ensure we don't run into this delay in other edge cases.
+// See also:
+// http://technet.microsoft.com/en-us/library/bb726981.aspx
+bool DisableNagle(SOCKET socket, bool disable) {
+ BOOL val = disable ? TRUE : FALSE;
+ int rv = setsockopt(socket, IPPROTO_TCP, TCP_NODELAY,
+ reinterpret_cast<const char*>(&val),
+ sizeof(val));
+ DCHECK(!rv) << "Could not disable nagle";
+ return rv == 0;
+}
+
+// Enable TCP Keep-Alive to prevent NAT routers from timing out TCP
+// connections. See http://crbug.com/27400 for details.
+bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) {
+ int delay = delay_secs * 1000;
+ struct tcp_keepalive keepalive_vals = {
+ enable ? 1 : 0, // TCP keep-alive on.
+ delay, // Delay seconds before sending first TCP keep-alive packet.
+ delay, // Delay seconds between sending TCP keep-alive packets.
+ };
+ DWORD bytes_returned = 0xABAB;
+ int rv = WSAIoctl(socket, SIO_KEEPALIVE_VALS, &keepalive_vals,
+ sizeof(keepalive_vals), NULL, 0,
+ &bytes_returned, NULL, NULL);
+ DCHECK(!rv) << "Could not enable TCP Keep-Alive for socket: " << socket
+ << " [error: " << WSAGetLastError() << "].";
+
+ // Disregard any failure in disabling nagle or enabling TCP Keep-Alive.
+ return rv == 0;
+}
+
+// Sets socket parameters. Returns the OS error code (or 0 on
+// success).
+int SetupSocket(SOCKET socket) {
+ // Increase the socket buffer sizes from the default sizes for WinXP. In
+ // performance testing, there is substantial benefit by increasing from 8KB
+ // to 64KB.
+ // See also:
+ // http://support.microsoft.com/kb/823764/EN-US
+ // On Vista, if we manually set these sizes, Vista turns off its receive
+ // window auto-tuning feature.
+ // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx
+ // Since Vista's auto-tune is better than any static value we can could set,
+ // only change these on pre-vista machines.
+ if (base::win::GetVersion() < base::win::VERSION_VISTA) {
+ const int32 kSocketBufferSize = 64 * 1024;
+ SetSocketReceiveBufferSize(socket, kSocketBufferSize);
+ SetSocketSendBufferSize(socket, kSocketBufferSize);
+ }
+
+ DisableNagle(socket, true);
+ SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds);
+ return 0;
+}
+
+// Creates a new socket and sets default parameters for it. Returns
+// the OS error code (or 0 on success).
+int CreateSocket(int family, SOCKET* socket) {
+ *socket = CreatePlatformSocket(family, SOCK_STREAM, IPPROTO_TCP);
+ if (*socket == INVALID_SOCKET) {
+ int os_error = WSAGetLastError();
+ LOG(ERROR) << "CreatePlatformSocket failed: " << os_error;
+ return os_error;
+ }
+ int error = SetupSocket(*socket);
+ if (error) {
+ if (closesocket(*socket) < 0)
+ PLOG(ERROR) << "closesocket";
+ *socket = INVALID_SOCKET;
+ return error;
+ }
+ return 0;
+}
+
+int MapConnectError(int os_error) {
+ switch (os_error) {
+ // connect fails with WSAEACCES when Windows Firewall blocks the
+ // connection.
+ case WSAEACCES:
+ return ERR_NETWORK_ACCESS_DENIED;
+ case WSAETIMEDOUT:
+ return ERR_CONNECTION_TIMED_OUT;
+ default: {
+ int net_error = MapSystemError(os_error);
+ if (net_error == ERR_FAILED)
+ return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED.
+
+ // Give a more specific error when the user is offline.
+ if (net_error == ERR_ADDRESS_UNREACHABLE &&
+ NetworkChangeNotifier::IsOffline()) {
+ return ERR_INTERNET_DISCONNECTED;
+ }
+
+ return net_error;
+ }
+ }
+}
+
+} // namespace
+
+//-----------------------------------------------------------------------------
+
+// This class encapsulates all the state that has to be preserved as long as
+// there is a network IO operation in progress. If the owner TCPClientSocketWin
+// is destroyed while an operation is in progress, the Core is detached and it
+// lives until the operation completes and the OS doesn't reference any resource
+// declared on this class anymore.
+class TCPClientSocketWin::Core : public base::RefCounted<Core> {
+ public:
+ explicit Core(TCPClientSocketWin* socket);
+
+ // Start watching for the end of a read or write operation.
+ void WatchForRead();
+ void WatchForWrite();
+
+ // The TCPClientSocketWin is going away.
+ void Detach() { socket_ = NULL; }
+
+ // Throttle the read size based on our current slow start state.
+ // Returns the throttled read size.
+ int ThrottleReadSize(int size) {
+ if (slow_start_throttle_ < kMaxSlowStartThrottle) {
+ size = std::min(size, slow_start_throttle_);
+ slow_start_throttle_ *= 2;
+ }
+ return size;
+ }
+
+ // The separate OVERLAPPED variables for asynchronous operation.
+ // |read_overlapped_| is used for both Connect() and Read().
+ // |write_overlapped_| is only used for Write();
+ OVERLAPPED read_overlapped_;
+ OVERLAPPED write_overlapped_;
+
+ // The buffers used in Read() and Write().
+ scoped_refptr<IOBuffer> read_iobuffer_;
+ scoped_refptr<IOBuffer> write_iobuffer_;
+ int read_buffer_length_;
+ int write_buffer_length_;
+
+ // Remember the state of g_disable_overlapped_reads for the duration of the
+ // socket based on what it was when the socket was created.
+ bool disable_overlapped_reads_;
+ bool non_blocking_reads_initialized_;
+
+ private:
+ friend class base::RefCounted<Core>;
+
+ class ReadDelegate : public base::win::ObjectWatcher::Delegate {
+ public:
+ explicit ReadDelegate(Core* core) : core_(core) {}
+ virtual ~ReadDelegate() {}
+
+ // base::ObjectWatcher::Delegate methods:
+ virtual void OnObjectSignaled(HANDLE object);
+
+ private:
+ Core* const core_;
+ };
+
+ class WriteDelegate : public base::win::ObjectWatcher::Delegate {
+ public:
+ explicit WriteDelegate(Core* core) : core_(core) {}
+ virtual ~WriteDelegate() {}
+
+ // base::ObjectWatcher::Delegate methods:
+ virtual void OnObjectSignaled(HANDLE object);
+
+ private:
+ Core* const core_;
+ };
+
+ ~Core();
+
+ // The socket that created this object.
+ TCPClientSocketWin* socket_;
+
+ // |reader_| handles the signals from |read_watcher_|.
+ ReadDelegate reader_;
+ // |writer_| handles the signals from |write_watcher_|.
+ WriteDelegate writer_;
+
+ // |read_watcher_| watches for events from Connect() and Read().
+ base::win::ObjectWatcher read_watcher_;
+ // |write_watcher_| watches for events from Write();
+ base::win::ObjectWatcher write_watcher_;
+
+ // When doing reads from the socket, we try to mirror TCP's slow start.
+ // We do this because otherwise the async IO subsystem artifically delays
+ // returning data to the application.
+ static const int kInitialSlowStartThrottle = 1 * 1024;
+ static const int kMaxSlowStartThrottle = 32 * kInitialSlowStartThrottle;
+ int slow_start_throttle_;
+
+ DISALLOW_COPY_AND_ASSIGN(Core);
+};
+
+TCPClientSocketWin::Core::Core(
+ TCPClientSocketWin* socket)
+ : read_buffer_length_(0),
+ write_buffer_length_(0),
+ disable_overlapped_reads_(g_disable_overlapped_reads),
+ non_blocking_reads_initialized_(false),
+ socket_(socket),
+ reader_(this),
+ writer_(this),
+ slow_start_throttle_(kInitialSlowStartThrottle) {
+ memset(&read_overlapped_, 0, sizeof(read_overlapped_));
+ memset(&write_overlapped_, 0, sizeof(write_overlapped_));
+
+ read_overlapped_.hEvent = WSACreateEvent();
+ write_overlapped_.hEvent = WSACreateEvent();
+}
+
+TCPClientSocketWin::Core::~Core() {
+ // Make sure the message loop is not watching this object anymore.
+ read_watcher_.StopWatching();
+ write_watcher_.StopWatching();
+
+ WSACloseEvent(read_overlapped_.hEvent);
+ memset(&read_overlapped_, 0xaf, sizeof(read_overlapped_));
+ WSACloseEvent(write_overlapped_.hEvent);
+ memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
+}
+
+void TCPClientSocketWin::Core::WatchForRead() {
+ // We grab an extra reference because there is an IO operation in progress.
+ // Balanced in ReadDelegate::OnObjectSignaled().
+ AddRef();
+ read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_);
+}
+
+void TCPClientSocketWin::Core::WatchForWrite() {
+ // We grab an extra reference because there is an IO operation in progress.
+ // Balanced in WriteDelegate::OnObjectSignaled().
+ AddRef();
+ write_watcher_.StartWatching(write_overlapped_.hEvent, &writer_);
+}
+
+void TCPClientSocketWin::Core::ReadDelegate::OnObjectSignaled(
+ HANDLE object) {
+ DCHECK_EQ(object, core_->read_overlapped_.hEvent);
+ if (core_->socket_) {
+ if (core_->socket_->waiting_connect()) {
+ core_->socket_->DidCompleteConnect();
+ } else if (core_->disable_overlapped_reads_) {
+ core_->socket_->DidSignalRead();
+ } else {
+ core_->socket_->DidCompleteRead();
+ }
+ }
+
+ core_->Release();
+}
+
+void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled(
+ HANDLE object) {
+ DCHECK_EQ(object, core_->write_overlapped_.hEvent);
+ if (core_->socket_)
+ core_->socket_->DidCompleteWrite();
+
+ core_->Release();
+}
+
+//-----------------------------------------------------------------------------
+
+TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(INVALID_SOCKET),
+ bound_socket_(INVALID_SOCKET),
+ addresses_(addresses),
+ current_address_index_(-1),
+ waiting_read_(false),
+ waiting_write_(false),
+ next_connect_state_(CONNECT_STATE_NONE),
+ connect_os_error_(0),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
+ previously_disconnected_(false) {
+ net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
+ source.ToEventParametersCallback());
+ EnsureWinsockInit();
+}
+
+TCPClientSocketWin::~TCPClientSocketWin() {
+ Disconnect();
+ net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
+}
+
+int TCPClientSocketWin::AdoptSocket(SOCKET socket) {
+ DCHECK_EQ(socket_, INVALID_SOCKET);
+
+ int error = SetupSocket(socket);
+ if (error)
+ return MapSystemError(error);
+
+ socket_ = socket;
+ SetNonBlocking(socket_);
+
+ core_ = new Core(this);
+ current_address_index_ = 0;
+ use_history_.set_was_ever_connected();
+
+ return OK;
+}
+
+int TCPClientSocketWin::Bind(const IPEndPoint& address) {
+ if (current_address_index_ >= 0 || bind_address_.get()) {
+ // Cannot bind the socket if we are already connected or connecting.
+ return ERR_UNEXPECTED;
+ }
+
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+
+ // Create |bound_socket_| and try to bind it to |address|.
+ int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_);
+ if (error)
+ return MapSystemError(error);
+
+ if (bind(bound_socket_, storage.addr, storage.addr_len)) {
+ error = errno;
+ if (closesocket(bound_socket_) < 0)
+ PLOG(ERROR) << "closesocket";
+ bound_socket_ = INVALID_SOCKET;
+ return MapSystemError(error);
+ }
+
+ bind_address_.reset(new IPEndPoint(address));
+
+ return 0;
+}
+
+
+int TCPClientSocketWin::Connect(const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+
+ // If already connected, then just return OK.
+ if (socket_ != INVALID_SOCKET)
+ return OK;
+
+ base::StatsCounter connects("tcp.connect");
+ connects.Increment();
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT,
+ addresses_.CreateNetLogCallback());
+
+ // We will try to connect to each address in addresses_. Start with the
+ // first one in the list.
+ next_connect_state_ = CONNECT_STATE_CONNECT;
+ current_address_index_ = 0;
+
+ int rv = DoConnectLoop(OK);
+ if (rv == ERR_IO_PENDING) {
+ // Synchronous operation not supported.
+ DCHECK(!callback.is_null());
+ // TODO(ajwong): Is setting read_callback_ the right thing to do here??
+ read_callback_ = callback;
+ } else {
+ LogConnectCompletion(rv);
+ }
+
+ return rv;
+}
+
+int TCPClientSocketWin::DoConnectLoop(int result) {
+ DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
+
+ int rv = result;
+ do {
+ ConnectState state = next_connect_state_;
+ next_connect_state_ = CONNECT_STATE_NONE;
+ switch (state) {
+ case CONNECT_STATE_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoConnect();
+ break;
+ case CONNECT_STATE_CONNECT_COMPLETE:
+ rv = DoConnectComplete(rv);
+ break;
+ default:
+ LOG(DFATAL) << "bad state " << state;
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
+
+ return rv;
+}
+
+int TCPClientSocketWin::DoConnect() {
+ DCHECK_GE(current_address_index_, 0);
+ DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
+ DCHECK_EQ(0, connect_os_error_);
+
+ const IPEndPoint& endpoint = addresses_[current_address_index_];
+
+ if (previously_disconnected_) {
+ use_history_.Reset();
+ previously_disconnected_ = false;
+ }
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ CreateNetLogIPEndPointCallback(&endpoint));
+
+ next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
+
+ if (bound_socket_ != INVALID_SOCKET) {
+ DCHECK(bind_address_.get());
+ socket_ = bound_socket_;
+ bound_socket_ = INVALID_SOCKET;
+ } else {
+ connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_);
+ if (connect_os_error_ != 0)
+ return MapSystemError(connect_os_error_);
+
+ if (bind_address_.get()) {
+ SockaddrStorage storage;
+ if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+ if (bind(socket_, storage.addr, storage.addr_len))
+ return MapSystemError(errno);
+ }
+ }
+
+ DCHECK(!core_);
+ core_ = new Core(this);
+ // WSAEventSelect sets the socket to non-blocking mode as a side effect.
+ // Our connect() and recv() calls require that the socket be non-blocking.
+ WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT);
+
+ SockaddrStorage storage;
+ if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_INVALID_ARGUMENT;
+ if (!connect(socket_, storage.addr, storage.addr_len)) {
+ // Connected without waiting!
+ //
+ // The MSDN page for connect says:
+ // With a nonblocking socket, the connection attempt cannot be completed
+ // immediately. In this case, connect will return SOCKET_ERROR, and
+ // WSAGetLastError will return WSAEWOULDBLOCK.
+ // which implies that for a nonblocking socket, connect never returns 0.
+ // It's not documented whether the event object will be signaled or not
+ // if connect does return 0. So the code below is essentially dead code
+ // and we don't know if it's correct.
+ NOTREACHED();
+
+ if (ResetEventIfSignaled(core_->read_overlapped_.hEvent))
+ return OK;
+ } else {
+ int os_error = WSAGetLastError();
+ if (os_error != WSAEWOULDBLOCK) {
+ LOG(ERROR) << "connect failed: " << os_error;
+ connect_os_error_ = os_error;
+ return MapConnectError(os_error);
+ }
+ }
+
+ core_->WatchForRead();
+ return ERR_IO_PENDING;
+}
+
+int TCPClientSocketWin::DoConnectComplete(int result) {
+ // Log the end of this attempt (and any OS error it threw).
+ int os_error = connect_os_error_;
+ connect_os_error_ = 0;
+ if (result != OK) {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ NetLog::IntegerCallback("os_error", os_error));
+ } else {
+ net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
+ }
+
+ if (result == OK) {
+ use_history_.set_was_ever_connected();
+ return OK; // Done!
+ }
+
+ // Close whatever partially connected socket we currently have.
+ DoDisconnect();
+
+ // Try to fall back to the next address in the list.
+ if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
+ next_connect_state_ = CONNECT_STATE_CONNECT;
+ ++current_address_index_;
+ return OK;
+ }
+
+ // Otherwise there is nothing to fall back to, so give up.
+ return result;
+}
+
+void TCPClientSocketWin::Disconnect() {
+ DCHECK(CalledOnValidThread());
+
+ DoDisconnect();
+ current_address_index_ = -1;
+ bind_address_.reset();
+}
+
+void TCPClientSocketWin::DoDisconnect() {
+ DCHECK(CalledOnValidThread());
+
+ if (socket_ == INVALID_SOCKET)
+ return;
+
+ // Note: don't use CancelIo to cancel pending IO because it doesn't work
+ // when there is a Winsock layered service provider.
+
+ // In most socket implementations, closing a socket results in a graceful
+ // connection shutdown, but in Winsock we have to call shutdown explicitly.
+ // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure"
+ // at http://msdn.microsoft.com/en-us/library/ms738547.aspx
+ shutdown(socket_, SD_SEND);
+
+ // This cancels any pending IO.
+ closesocket(socket_);
+ socket_ = INVALID_SOCKET;
+
+ if (waiting_connect()) {
+ // We closed the socket, so this notification will never come.
+ // From MSDN' WSAEventSelect documentation:
+ // "Closing a socket with closesocket also cancels the association and
+ // selection of network events specified in WSAEventSelect for the socket".
+ core_->Release();
+ }
+
+ waiting_read_ = false;
+ waiting_write_ = false;
+
+ core_->Detach();
+ core_ = NULL;
+
+ previously_disconnected_ = true;
+}
+
+bool TCPClientSocketWin::IsConnected() const {
+ DCHECK(CalledOnValidThread());
+
+ if (socket_ == INVALID_SOCKET || waiting_connect())
+ return false;
+
+ if (waiting_read_)
+ return true;
+
+ // Check if connection is alive.
+ char c;
+ int rv = recv(socket_, &c, 1, MSG_PEEK);
+ if (rv == 0)
+ return false;
+ if (rv == SOCKET_ERROR && WSAGetLastError() != WSAEWOULDBLOCK)
+ return false;
+
+ return true;
+}
+
+bool TCPClientSocketWin::IsConnectedAndIdle() const {
+ DCHECK(CalledOnValidThread());
+
+ if (socket_ == INVALID_SOCKET || waiting_connect())
+ return false;
+
+ if (waiting_read_)
+ return true;
+
+ // Check if connection is alive and we haven't received any data
+ // unexpectedly.
+ char c;
+ int rv = recv(socket_, &c, 1, MSG_PEEK);
+ if (rv >= 0)
+ return false;
+ if (WSAGetLastError() != WSAEWOULDBLOCK)
+ return false;
+
+ return true;
+}
+
+int TCPClientSocketWin::GetPeerAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+ if (!IsConnected())
+ return ERR_SOCKET_NOT_CONNECTED;
+ *address = addresses_[current_address_index_];
+ return OK;
+}
+
+int TCPClientSocketWin::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+ if (socket_ == INVALID_SOCKET) {
+ if (bind_address_.get()) {
+ *address = *bind_address_;
+ return OK;
+ }
+ return ERR_SOCKET_NOT_CONNECTED;
+ }
+
+ struct sockaddr_storage addr_storage;
+ socklen_t addr_len = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ if (getsockname(socket_, addr, &addr_len))
+ return MapSystemError(WSAGetLastError());
+ if (!address->FromSockAddr(addr, addr_len))
+ return ERR_FAILED;
+ return OK;
+}
+
+void TCPClientSocketWin::SetSubresourceSpeculation() {
+ use_history_.set_subresource_speculation();
+}
+
+void TCPClientSocketWin::SetOmniboxSpeculation() {
+ use_history_.set_omnibox_speculation();
+}
+
+bool TCPClientSocketWin::WasEverUsed() const {
+ return use_history_.was_used_to_convey_data();
+}
+
+bool TCPClientSocketWin::UsingTCPFastOpen() const {
+ // Not supported on windows.
+ return false;
+}
+
+bool TCPClientSocketWin::WasNpnNegotiated() const {
+ return false;
+}
+
+NextProto TCPClientSocketWin::GetNegotiatedProtocol() const {
+ return kProtoUnknown;
+}
+
+bool TCPClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) {
+ return false;
+}
+
+int TCPClientSocketWin::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_NE(socket_, INVALID_SOCKET);
+ DCHECK(!waiting_read_);
+ DCHECK(read_callback_.is_null());
+ DCHECK(!core_->read_iobuffer_);
+
+ return DoRead(buf, buf_len, callback);
+}
+
+int TCPClientSocketWin::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_NE(socket_, INVALID_SOCKET);
+ DCHECK(!waiting_write_);
+ DCHECK(write_callback_.is_null());
+ DCHECK_GT(buf_len, 0);
+ DCHECK(!core_->write_iobuffer_);
+
+ base::StatsCounter writes("tcp.writes");
+ writes.Increment();
+
+ WSABUF write_buffer;
+ write_buffer.len = buf_len;
+ write_buffer.buf = buf->data();
+
+ // TODO(wtc): Remove the assertion after enough testing.
+ AssertEventNotSignaled(core_->write_overlapped_.hEvent);
+ DWORD num;
+ int rv = WSASend(socket_, &write_buffer, 1, &num, 0,
+ &core_->write_overlapped_, NULL);
+ if (rv == 0) {
+ if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
+ rv = static_cast<int>(num);
+ if (rv > buf_len || rv < 0) {
+ // It seems that some winsock interceptors report that more was written
+ // than was available. Treat this as an error. http://crbug.com/27870
+ LOG(ERROR) << "Detected broken LSP: Asked to write " << buf_len
+ << " bytes, but " << rv << " bytes reported.";
+ return ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES;
+ }
+ base::StatsCounter write_bytes("tcp.write_bytes");
+ write_bytes.Add(rv);
+ if (rv > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv,
+ buf->data());
+ return rv;
+ }
+ } else {
+ int os_error = WSAGetLastError();
+ if (os_error != WSA_IO_PENDING) {
+ int net_error = MapSystemError(os_error);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
+ CreateNetLogSocketErrorCallback(net_error, os_error));
+ return net_error;
+ }
+ }
+ waiting_write_ = true;
+ write_callback_ = callback;
+ core_->write_iobuffer_ = buf;
+ core_->write_buffer_length_ = buf_len;
+ core_->WatchForWrite();
+ return ERR_IO_PENDING;
+}
+
+bool TCPClientSocketWin::SetReceiveBufferSize(int32 size) {
+ DCHECK(CalledOnValidThread());
+ return SetSocketReceiveBufferSize(socket_, size);
+}
+
+bool TCPClientSocketWin::SetSendBufferSize(int32 size) {
+ DCHECK(CalledOnValidThread());
+ return SetSocketSendBufferSize(socket_, size);
+}
+
+bool TCPClientSocketWin::SetKeepAlive(bool enable, int delay) {
+ return SetTCPKeepAlive(socket_, enable, delay);
+}
+
+bool TCPClientSocketWin::SetNoDelay(bool no_delay) {
+ return DisableNagle(socket_, no_delay);
+}
+
+void TCPClientSocketWin::DisableOverlappedReads() {
+ g_disable_overlapped_reads = true;
+}
+
+void TCPClientSocketWin::LogConnectCompletion(int net_error) {
+ if (net_error == OK)
+ UpdateConnectionTypeHistograms(CONNECTION_ANY);
+
+ if (net_error != OK) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, net_error);
+ return;
+ }
+
+ struct sockaddr_storage source_address;
+ socklen_t addrlen = sizeof(source_address);
+ int rv = getsockname(
+ socket_, reinterpret_cast<struct sockaddr*>(&source_address), &addrlen);
+ if (rv != 0) {
+ LOG(ERROR) << "getsockname() [rv: " << rv
+ << "] error: " << WSAGetLastError();
+ NOTREACHED();
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, rv);
+ return;
+ }
+
+ net_log_.EndEvent(
+ NetLog::TYPE_TCP_CONNECT,
+ CreateNetLogSourceAddressCallback(
+ reinterpret_cast<const struct sockaddr*>(&source_address),
+ sizeof(source_address)));
+}
+
+int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ if (core_->disable_overlapped_reads_) {
+ if (!core_->non_blocking_reads_initialized_) {
+ WSAEventSelect(socket_, core_->read_overlapped_.hEvent,
+ FD_READ | FD_CLOSE);
+ core_->non_blocking_reads_initialized_ = true;
+ }
+ int rv = recv(socket_, buf->data(), buf_len, 0);
+ if (rv == SOCKET_ERROR) {
+ int os_error = WSAGetLastError();
+ if (os_error != WSAEWOULDBLOCK) {
+ int net_error = MapSystemError(os_error);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(net_error, os_error));
+ return net_error;
+ }
+ } else {
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ if (rv > 0) {
+ use_history_.set_was_used_to_convey_data();
+ read_bytes.Add(rv);
+ }
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv,
+ buf->data());
+ return rv;
+ }
+ } else {
+ buf_len = core_->ThrottleReadSize(buf_len);
+
+ WSABUF read_buffer;
+ read_buffer.len = buf_len;
+ read_buffer.buf = buf->data();
+
+ // TODO(wtc): Remove the assertion after enough testing.
+ AssertEventNotSignaled(core_->read_overlapped_.hEvent);
+ DWORD num;
+ DWORD flags = 0;
+ int rv = WSARecv(socket_, &read_buffer, 1, &num, &flags,
+ &core_->read_overlapped_, NULL);
+ if (rv == 0) {
+ if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) {
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ if (num > 0) {
+ use_history_.set_was_used_to_convey_data();
+ read_bytes.Add(num);
+ }
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num,
+ buf->data());
+ return static_cast<int>(num);
+ }
+ } else {
+ int os_error = WSAGetLastError();
+ if (os_error != WSA_IO_PENDING) {
+ int net_error = MapSystemError(os_error);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(net_error, os_error));
+ return net_error;
+ }
+ }
+ }
+
+ waiting_read_ = true;
+ read_callback_ = callback;
+ core_->read_iobuffer_ = buf;
+ core_->read_buffer_length_ = buf_len;
+ core_->WatchForRead();
+ return ERR_IO_PENDING;
+}
+
+void TCPClientSocketWin::DoReadCallback(int rv) {
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ DCHECK(!read_callback_.is_null());
+
+ // Since Run may result in Read being called, clear read_callback_ up front.
+ CompletionCallback c = read_callback_;
+ read_callback_.Reset();
+ c.Run(rv);
+}
+
+void TCPClientSocketWin::DoWriteCallback(int rv) {
+ DCHECK_NE(rv, ERR_IO_PENDING);
+ DCHECK(!write_callback_.is_null());
+
+ // since Run may result in Write being called, clear write_callback_ up front.
+ CompletionCallback c = write_callback_;
+ write_callback_.Reset();
+ c.Run(rv);
+}
+
+void TCPClientSocketWin::DidCompleteConnect() {
+ DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
+ int result;
+
+ WSANETWORKEVENTS events;
+ int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent,
+ &events);
+ int os_error = 0;
+ if (rv == SOCKET_ERROR) {
+ NOTREACHED();
+ os_error = WSAGetLastError();
+ result = MapSystemError(os_error);
+ } else if (events.lNetworkEvents & FD_CONNECT) {
+ os_error = events.iErrorCode[FD_CONNECT_BIT];
+ result = MapConnectError(os_error);
+ } else {
+ NOTREACHED();
+ result = ERR_UNEXPECTED;
+ }
+
+ connect_os_error_ = os_error;
+ rv = DoConnectLoop(result);
+ if (rv != ERR_IO_PENDING) {
+ LogConnectCompletion(rv);
+ DoReadCallback(rv);
+ }
+}
+
+void TCPClientSocketWin::DidCompleteRead() {
+ DCHECK(waiting_read_);
+ DWORD num_bytes, flags;
+ BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_,
+ &num_bytes, FALSE, &flags);
+ waiting_read_ = false;
+ int rv;
+ if (ok) {
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ read_bytes.Add(num_bytes);
+ if (num_bytes > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED,
+ num_bytes, core_->read_iobuffer_->data());
+ rv = static_cast<int>(num_bytes);
+ } else {
+ int os_error = WSAGetLastError();
+ rv = MapSystemError(os_error);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(rv, os_error));
+ }
+ WSAResetEvent(core_->read_overlapped_.hEvent);
+ core_->read_iobuffer_ = NULL;
+ core_->read_buffer_length_ = 0;
+ DoReadCallback(rv);
+}
+
+void TCPClientSocketWin::DidCompleteWrite() {
+ DCHECK(waiting_write_);
+
+ DWORD num_bytes, flags;
+ BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
+ &num_bytes, FALSE, &flags);
+ WSAResetEvent(core_->write_overlapped_.hEvent);
+ waiting_write_ = false;
+ int rv;
+ if (!ok) {
+ int os_error = WSAGetLastError();
+ rv = MapSystemError(os_error);
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
+ CreateNetLogSocketErrorCallback(rv, os_error));
+ } else {
+ rv = static_cast<int>(num_bytes);
+ if (rv > core_->write_buffer_length_ || rv < 0) {
+ // It seems that some winsock interceptors report that more was written
+ // than was available. Treat this as an error. http://crbug.com/27870
+ LOG(ERROR) << "Detected broken LSP: Asked to write "
+ << core_->write_buffer_length_ << " bytes, but " << rv
+ << " bytes reported.";
+ rv = ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES;
+ } else {
+ base::StatsCounter write_bytes("tcp.write_bytes");
+ write_bytes.Add(num_bytes);
+ if (num_bytes > 0)
+ use_history_.set_was_used_to_convey_data();
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, num_bytes,
+ core_->write_iobuffer_->data());
+ }
+ }
+ core_->write_iobuffer_ = NULL;
+ DoWriteCallback(rv);
+}
+
+void TCPClientSocketWin::DidSignalRead() {
+ DCHECK(waiting_read_);
+ int os_error = 0;
+ WSANETWORKEVENTS network_events;
+ int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent,
+ &network_events);
+ if (rv == SOCKET_ERROR) {
+ os_error = WSAGetLastError();
+ rv = MapSystemError(os_error);
+ } else if (network_events.lNetworkEvents) {
+ DCHECK_EQ(network_events.lNetworkEvents & ~(FD_READ | FD_CLOSE), 0);
+ // If network_events.lNetworkEvents is FD_CLOSE and
+ // network_events.iErrorCode[FD_CLOSE_BIT] is 0, it is a graceful
+ // connection closure. It is tempting to directly set rv to 0 in
+ // this case, but the MSDN pages for WSAEventSelect and
+ // WSAAsyncSelect recommend we still call DoRead():
+ // FD_CLOSE should only be posted after all data is read from a
+ // socket, but an application should check for remaining data upon
+ // receipt of FD_CLOSE to avoid any possibility of losing data.
+ //
+ // If network_events.iErrorCode[FD_READ_BIT] or
+ // network_events.iErrorCode[FD_CLOSE_BIT] is nonzero, still call
+ // DoRead() because recv() reports a more accurate error code
+ // (WSAECONNRESET vs. WSAECONNABORTED) when the connection was
+ // reset.
+ rv = DoRead(core_->read_iobuffer_, core_->read_buffer_length_,
+ read_callback_);
+ if (rv == ERR_IO_PENDING)
+ return;
+ } else {
+ // This may happen because Read() may succeed synchronously and
+ // consume all the received data without resetting the event object.
+ core_->WatchForRead();
+ return;
+ }
+ waiting_read_ = false;
+ core_->read_iobuffer_ = NULL;
+ core_->read_buffer_length_ = 0;
+ DoReadCallback(rv);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_client_socket_win.h b/chromium/net/socket/tcp_client_socket_win.h
new file mode 100644
index 00000000000..26c8b9feff2
--- /dev/null
+++ b/chromium/net/socket/tcp_client_socket_win.h
@@ -0,0 +1,162 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_
+#define NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_
+
+#include <winsock2.h>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/threading/non_thread_safe.h"
+#include "net/base/address_list.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_log.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class BoundNetLog;
+
+class NET_EXPORT TCPClientSocketWin : public StreamSocket,
+ NON_EXPORTED_BASE(base::NonThreadSafe) {
+ public:
+ // The IP address(es) and port number to connect to. The TCP socket will try
+ // each IP address in the list until it succeeds in establishing a
+ // connection.
+ TCPClientSocketWin(const AddressList& addresses,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source);
+
+ virtual ~TCPClientSocketWin();
+
+ // AdoptSocket causes the given, connected socket to be adopted as a TCP
+ // socket. This object must not be connected. This object takes ownership of
+ // the given socket and then acts as if Connect() had been called. This
+ // function is used by TCPServerSocket() to adopt accepted connections
+ // and for testing.
+ int AdoptSocket(SOCKET socket);
+
+ // Binds the socket to a local IP address and port.
+ int Bind(const IPEndPoint& address);
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback);
+ virtual void Disconnect();
+ virtual bool IsConnected() const;
+ virtual bool IsConnectedAndIdle() const;
+ virtual int GetPeerAddress(IPEndPoint* address) const;
+ virtual int GetLocalAddress(IPEndPoint* address) const;
+ virtual const BoundNetLog& NetLog() const { return net_log_; }
+ virtual void SetSubresourceSpeculation();
+ virtual void SetOmniboxSpeculation();
+ virtual bool WasEverUsed() const;
+ virtual bool UsingTCPFastOpen() const;
+ virtual bool WasNpnNegotiated() const OVERRIDE;
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+
+ // Socket implementation.
+ // Multiple outstanding requests are not supported.
+ // Full duplex mode (reading and writing at the same time) is supported
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback);
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback);
+
+ virtual bool SetReceiveBufferSize(int32 size);
+ virtual bool SetSendBufferSize(int32 size);
+
+ virtual bool SetKeepAlive(bool enable, int delay);
+ virtual bool SetNoDelay(bool no_delay);
+
+ // Perform reads in non-blocking mode instead of overlapped mode.
+ // Used for experiments.
+ static void DisableOverlappedReads();
+
+ private:
+ // State machine for connecting the socket.
+ enum ConnectState {
+ CONNECT_STATE_CONNECT,
+ CONNECT_STATE_CONNECT_COMPLETE,
+ CONNECT_STATE_NONE,
+ };
+
+ class Core;
+
+ // State machine used by Connect().
+ int DoConnectLoop(int result);
+ int DoConnect();
+ int DoConnectComplete(int result);
+
+ // Helper used by Disconnect(), which disconnects minus the logging and
+ // resetting of current_address_index_.
+ void DoDisconnect();
+
+ // Returns true if a Connect() is in progress.
+ bool waiting_connect() const {
+ return next_connect_state_ != CONNECT_STATE_NONE;
+ }
+
+ // Called after Connect() has completed with |net_error|.
+ void LogConnectCompletion(int net_error);
+
+ int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+ void DoReadCallback(int rv);
+ void DoWriteCallback(int rv);
+ void DidCompleteConnect();
+ void DidCompleteRead();
+ void DidCompleteWrite();
+ void DidSignalRead();
+
+ SOCKET socket_;
+
+ // Local IP address and port we are bound to. Set to NULL if Bind()
+ // was't called (in that cases OS chooses address/port).
+ scoped_ptr<IPEndPoint> bind_address_;
+
+ // Stores bound socket between Bind() and Connect() calls.
+ SOCKET bound_socket_;
+
+ // The list of addresses we should try in order to establish a connection.
+ AddressList addresses_;
+
+ // Where we are in above list. Set to -1 if uninitialized.
+ int current_address_index_;
+
+ // The various states that the socket could be in.
+ bool waiting_read_;
+ bool waiting_write_;
+
+ // The core of the socket that can live longer than the socket itself. We pass
+ // resources to the Windows async IO functions and we have to make sure that
+ // they are not destroyed while the OS still references them.
+ scoped_refptr<Core> core_;
+
+ // External callback; called when connect or read is complete.
+ CompletionCallback read_callback_;
+
+ // External callback; called when write is complete.
+ CompletionCallback write_callback_;
+
+ // The next state for the Connect() state machine.
+ ConnectState next_connect_state_;
+
+ // The OS error that CONNECT_STATE_CONNECT last completed with.
+ int connect_os_error_;
+
+ BoundNetLog net_log_;
+
+ // This socket was previously disconnected and has not been re-connected.
+ bool previously_disconnected_;
+
+ // Record of connectivity and transmissions, for use in speculative connection
+ // histograms.
+ UseHistory use_history_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPClientSocketWin);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_
diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc
new file mode 100644
index 00000000000..aab2e45d0e9
--- /dev/null
+++ b/chromium/net/socket/tcp_listen_socket.cc
@@ -0,0 +1,128 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_listen_socket.h"
+
+#if defined(OS_WIN)
+// winsock2.h must be included first in order to ensure it is included before
+// windows.h.
+#include <winsock2.h>
+#elif defined(OS_POSIX)
+#include <arpa/inet.h>
+#include <errno.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include "net/base/net_errors.h"
+#endif
+
+#include "base/logging.h"
+#include "base/sys_byteorder.h"
+#include "base/threading/platform_thread.h"
+#include "build/build_config.h"
+#include "net/base/net_util.h"
+#include "net/base/winsock_init.h"
+
+using std::string;
+
+namespace net {
+
+// static
+scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen(
+ const string& ip, int port, StreamListenSocket::Delegate* del) {
+ SocketDescriptor s = CreateAndBind(ip, port);
+ if (s == kInvalidSocket)
+ return NULL;
+ scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del));
+ sock->Listen();
+ return sock;
+}
+
+TCPListenSocket::TCPListenSocket(SocketDescriptor s,
+ StreamListenSocket::Delegate* del)
+ : StreamListenSocket(s, del) {
+}
+
+TCPListenSocket::~TCPListenSocket() {}
+
+SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) {
+#if defined(OS_WIN)
+ EnsureWinsockInit();
+#endif
+
+ SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (s != kInvalidSocket) {
+#if defined(OS_POSIX)
+ // Allow rapid reuse.
+ static const int kOn = 1;
+ setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &kOn, sizeof(kOn));
+#endif
+ sockaddr_in addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = inet_addr(ip.c_str());
+ addr.sin_port = base::HostToNet16(port);
+ if (bind(s, reinterpret_cast<sockaddr*>(&addr), sizeof(addr))) {
+#if defined(OS_WIN)
+ closesocket(s);
+#elif defined(OS_POSIX)
+ close(s);
+#endif
+ LOG(ERROR) << "Could not bind socket to " << ip << ":" << port;
+ s = kInvalidSocket;
+ }
+ }
+ return s;
+}
+
+SocketDescriptor TCPListenSocket::CreateAndBindAnyPort(const string& ip,
+ int* port) {
+ SocketDescriptor s = CreateAndBind(ip, 0);
+ if (s == kInvalidSocket)
+ return kInvalidSocket;
+ sockaddr_in addr;
+ socklen_t addr_size = sizeof(addr);
+ bool failed = getsockname(s, reinterpret_cast<struct sockaddr*>(&addr),
+ &addr_size) != 0;
+ if (addr_size != sizeof(addr))
+ failed = true;
+ if (failed) {
+ LOG(ERROR) << "Could not determine bound port, getsockname() failed";
+#if defined(OS_WIN)
+ closesocket(s);
+#elif defined(OS_POSIX)
+ close(s);
+#endif
+ return kInvalidSocket;
+ }
+ *port = base::NetToHost16(addr.sin_port);
+ return s;
+}
+
+void TCPListenSocket::Accept() {
+ SocketDescriptor conn = AcceptSocket();
+ if (conn == kInvalidSocket)
+ return;
+ scoped_refptr<TCPListenSocket> sock(
+ new TCPListenSocket(conn, socket_delegate_));
+ // It's up to the delegate to AddRef if it wants to keep it around.
+#if defined(OS_POSIX)
+ sock->WatchSocket(WAITING_READ);
+#endif
+ socket_delegate_->DidAccept(this, sock.get());
+}
+
+TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port)
+ : ip_(ip),
+ port_(port) {
+}
+
+TCPListenSocketFactory::~TCPListenSocketFactory() {}
+
+scoped_refptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return TCPListenSocket::CreateAndListen(ip_, port_, delegate);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h
new file mode 100644
index 00000000000..dbc5347e945
--- /dev/null
+++ b/chromium/net/socket/tcp_listen_socket.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_LISTEN_SOCKET_H_
+#define NET_SOCKET_TCP_LISTEN_SOCKET_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "net/base/net_export.h"
+#include "net/socket/stream_listen_socket.h"
+
+namespace net {
+
+// Implements a TCP socket. Note that this is ref counted.
+class NET_EXPORT TCPListenSocket : public StreamListenSocket {
+ public:
+ // Listen on port for the specified IP address. Use 127.0.0.1 to only
+ // accept local connections.
+ static scoped_refptr<TCPListenSocket> CreateAndListen(
+ const std::string& ip, int port, StreamListenSocket::Delegate* del);
+
+ // Get raw TCP socket descriptor bound to ip:port.
+ static SocketDescriptor CreateAndBind(const std::string& ip, int port);
+
+ // Get raw TCP socket descriptor bound to ip and return port it is bound to.
+ static SocketDescriptor CreateAndBindAnyPort(const std::string& ip,
+ int* port);
+
+ protected:
+ friend class scoped_refptr<TCPListenSocket>;
+
+ TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del);
+ virtual ~TCPListenSocket();
+
+ // Implements StreamListenSocket::Accept.
+ virtual void Accept() OVERRIDE;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(TCPListenSocket);
+};
+
+// Factory that can be used to instantiate TCPListenSocket.
+class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory {
+ public:
+ TCPListenSocketFactory(const std::string& ip, int port);
+ virtual ~TCPListenSocketFactory();
+
+ // StreamListenSocketFactory overrides.
+ virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const OVERRIDE;
+
+ private:
+ const std::string ip_;
+ const int port_;
+
+ DISALLOW_COPY_AND_ASSIGN(TCPListenSocketFactory);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_LISTEN_SOCKET_H_
diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc
new file mode 100644
index 00000000000..d13b784cbdc
--- /dev/null
+++ b/chromium/net/socket/tcp_listen_socket_unittest.cc
@@ -0,0 +1,291 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_listen_socket_unittest.h"
+
+#include <fcntl.h>
+#include <sys/types.h>
+
+#include "base/bind.h"
+#include "base/posix/eintr_wrapper.h"
+#include "base/sys_byteorder.h"
+#include "net/base/net_util.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+const int TCPListenSocketTester::kTestPort = 9999;
+
+static const int kReadBufSize = 1024;
+static const char kHelloWorld[] = "HELLO, WORLD";
+static const int kMaxQueueSize = 20;
+static const char kLoopback[] = "127.0.0.1";
+static const int kDefaultTimeoutMs = 5000;
+
+TCPListenSocketTester::TCPListenSocketTester()
+ : loop_(NULL), server_(NULL), connection_(NULL), cv_(&lock_) {}
+
+void TCPListenSocketTester::SetUp() {
+ base::Thread::Options options;
+ options.message_loop_type = base::MessageLoop::TYPE_IO;
+ thread_.reset(new base::Thread("socketio_test"));
+ thread_->StartWithOptions(options);
+ loop_ = reinterpret_cast<base::MessageLoopForIO*>(thread_->message_loop());
+
+ loop_->PostTask(FROM_HERE, base::Bind(
+ &TCPListenSocketTester::Listen, this));
+
+ // verify Listen succeeded
+ NextAction();
+ ASSERT_FALSE(server_.get() == NULL);
+ ASSERT_EQ(ACTION_LISTEN, last_action_.type());
+
+ // verify the connect/accept and setup test_socket_
+ test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_);
+ struct sockaddr_in client;
+ client.sin_family = AF_INET;
+ client.sin_addr.s_addr = inet_addr(kLoopback);
+ client.sin_port = base::HostToNet16(kTestPort);
+ int ret = HANDLE_EINTR(
+ connect(test_socket_, reinterpret_cast<sockaddr*>(&client),
+ sizeof(client)));
+#if defined(OS_POSIX)
+ // The connect() call may be interrupted by a signal. When connect()
+ // is retried on EINTR, it fails with EISCONN.
+ if (ret == StreamListenSocket::kSocketError)
+ ASSERT_EQ(EISCONN, errno);
+#else
+ // Don't have signals.
+ ASSERT_NE(StreamListenSocket::kSocketError, ret);
+#endif
+
+ NextAction();
+ ASSERT_EQ(ACTION_ACCEPT, last_action_.type());
+}
+
+void TCPListenSocketTester::TearDown() {
+#if defined(OS_WIN)
+ ASSERT_EQ(0, closesocket(test_socket_));
+#elif defined(OS_POSIX)
+ ASSERT_EQ(0, HANDLE_EINTR(close(test_socket_)));
+#endif
+ NextAction();
+ ASSERT_EQ(ACTION_CLOSE, last_action_.type());
+
+ loop_->PostTask(FROM_HERE, base::Bind(
+ &TCPListenSocketTester::Shutdown, this));
+ NextAction();
+ ASSERT_EQ(ACTION_SHUTDOWN, last_action_.type());
+
+ thread_.reset();
+ loop_ = NULL;
+}
+
+void TCPListenSocketTester::ReportAction(
+ const TCPListenSocketTestAction& action) {
+ base::AutoLock locked(lock_);
+ queue_.push_back(action);
+ cv_.Broadcast();
+}
+
+void TCPListenSocketTester::NextAction() {
+ base::AutoLock locked(lock_);
+ while (queue_.empty())
+ cv_.Wait();
+ last_action_ = queue_.front();
+ queue_.pop_front();
+}
+
+int TCPListenSocketTester::ClearTestSocket() {
+ char buf[kReadBufSize];
+ int len_ret = 0;
+ do {
+ int len = HANDLE_EINTR(recv(test_socket_, buf, kReadBufSize, 0));
+ if (len == StreamListenSocket::kSocketError || len == 0) {
+ break;
+ } else {
+ len_ret += len;
+ }
+ } while (true);
+ return len_ret;
+}
+
+void TCPListenSocketTester::Shutdown() {
+ connection_->Release();
+ connection_ = NULL;
+ server_->Release();
+ server_ = NULL;
+ ReportAction(TCPListenSocketTestAction(ACTION_SHUTDOWN));
+}
+
+void TCPListenSocketTester::Listen() {
+ server_ = DoListen();
+ ASSERT_TRUE(server_.get());
+ server_->AddRef();
+ ReportAction(TCPListenSocketTestAction(ACTION_LISTEN));
+}
+
+void TCPListenSocketTester::SendFromTester() {
+ connection_->Send(kHelloWorld);
+ ReportAction(TCPListenSocketTestAction(ACTION_SEND));
+}
+
+void TCPListenSocketTester::TestClientSend() {
+ ASSERT_TRUE(Send(test_socket_, kHelloWorld));
+ NextAction();
+ ASSERT_EQ(ACTION_READ, last_action_.type());
+ ASSERT_EQ(last_action_.data(), kHelloWorld);
+}
+
+void TCPListenSocketTester::TestClientSendLong() {
+ size_t hello_len = strlen(kHelloWorld);
+ std::string long_string;
+ size_t long_len = 0;
+ for (int i = 0; i < 200; i++) {
+ long_string += kHelloWorld;
+ long_len += hello_len;
+ }
+ ASSERT_TRUE(Send(test_socket_, long_string));
+ size_t read_len = 0;
+ while (read_len < long_len) {
+ NextAction();
+ ASSERT_EQ(ACTION_READ, last_action_.type());
+ std::string last_data = last_action_.data();
+ size_t len = last_data.length();
+ if (long_string.compare(read_len, len, last_data)) {
+ ASSERT_EQ(long_string.compare(read_len, len, last_data), 0);
+ }
+ read_len += last_data.length();
+ }
+ ASSERT_EQ(read_len, long_len);
+}
+
+void TCPListenSocketTester::TestServerSend() {
+ loop_->PostTask(FROM_HERE, base::Bind(
+ &TCPListenSocketTester::SendFromTester, this));
+ NextAction();
+ ASSERT_EQ(ACTION_SEND, last_action_.type());
+ const int buf_len = 200;
+ char buf[buf_len+1];
+ unsigned recv_len = 0;
+ while (recv_len < strlen(kHelloWorld)) {
+ int r = HANDLE_EINTR(recv(test_socket_,
+ buf + recv_len, buf_len - recv_len, 0));
+ ASSERT_GE(r, 0);
+ recv_len += static_cast<unsigned>(r);
+ if (!r)
+ break;
+ }
+ buf[recv_len] = 0;
+ ASSERT_STREQ(kHelloWorld, buf);
+}
+
+void TCPListenSocketTester::TestServerSendMultiple() {
+ // Send enough data to exceed the socket receive window. 20kb is probably a
+ // safe bet.
+ int send_count = (1024*20) / (sizeof(kHelloWorld)-1);
+
+ // Send multiple writes. Since no reading is occurring the data should be
+ // buffered in TCPListenSocket.
+ for (int i = 0; i < send_count; ++i) {
+ loop_->PostTask(FROM_HERE, base::Bind(
+ &TCPListenSocketTester::SendFromTester, this));
+ NextAction();
+ ASSERT_EQ(ACTION_SEND, last_action_.type());
+ }
+
+ // Make multiple reads. All of the data should eventually be returned.
+ char buf[sizeof(kHelloWorld)];
+ const int buf_len = sizeof(kHelloWorld);
+ for (int i = 0; i < send_count; ++i) {
+ unsigned recv_len = 0;
+ while (recv_len < buf_len-1) {
+ int r = HANDLE_EINTR(recv(test_socket_,
+ buf + recv_len, buf_len - 1 - recv_len, 0));
+ ASSERT_GE(r, 0);
+ recv_len += static_cast<unsigned>(r);
+ if (!r)
+ break;
+ }
+ buf[recv_len] = 0;
+ ASSERT_STREQ(kHelloWorld, buf);
+ }
+}
+
+bool TCPListenSocketTester::Send(SocketDescriptor sock,
+ const std::string& str) {
+ int len = static_cast<int>(str.length());
+ int send_len = HANDLE_EINTR(send(sock, str.data(), len, 0));
+ if (send_len == StreamListenSocket::kSocketError) {
+ LOG(ERROR) << "send failed: " << errno;
+ return false;
+ } else if (send_len != len) {
+ return false;
+ }
+ return true;
+}
+
+void TCPListenSocketTester::DidAccept(StreamListenSocket* server,
+ StreamListenSocket* connection) {
+ connection_ = connection;
+ connection_->AddRef();
+ ReportAction(TCPListenSocketTestAction(ACTION_ACCEPT));
+}
+
+void TCPListenSocketTester::DidRead(StreamListenSocket* connection,
+ const char* data,
+ int len) {
+ std::string str(data, len);
+ ReportAction(TCPListenSocketTestAction(ACTION_READ, str));
+}
+
+void TCPListenSocketTester::DidClose(StreamListenSocket* sock) {
+ ReportAction(TCPListenSocketTestAction(ACTION_CLOSE));
+}
+
+TCPListenSocketTester::~TCPListenSocketTester() {}
+
+scoped_refptr<TCPListenSocket> TCPListenSocketTester::DoListen() {
+ return TCPListenSocket::CreateAndListen(kLoopback, kTestPort, this);
+}
+
+class TCPListenSocketTest: public PlatformTest {
+ public:
+ TCPListenSocketTest() {
+ tester_ = NULL;
+ }
+
+ virtual void SetUp() {
+ PlatformTest::SetUp();
+ tester_ = new TCPListenSocketTester();
+ tester_->SetUp();
+ }
+
+ virtual void TearDown() {
+ PlatformTest::TearDown();
+ tester_->TearDown();
+ tester_ = NULL;
+ }
+
+ scoped_refptr<TCPListenSocketTester> tester_;
+};
+
+TEST_F(TCPListenSocketTest, ClientSend) {
+ tester_->TestClientSend();
+}
+
+TEST_F(TCPListenSocketTest, ClientSendLong) {
+ tester_->TestClientSendLong();
+}
+
+TEST_F(TCPListenSocketTest, ServerSend) {
+ tester_->TestServerSend();
+}
+
+TEST_F(TCPListenSocketTest, ServerSendMultiple) {
+ tester_->TestServerSendMultiple();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h
new file mode 100644
index 00000000000..048a0186705
--- /dev/null
+++ b/chromium/net/socket/tcp_listen_socket_unittest.h
@@ -0,0 +1,122 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_BASE_LISTEN_SOCKET_UNITTEST_H_
+#define NET_BASE_LISTEN_SOCKET_UNITTEST_H_
+
+#include "build/build_config.h"
+
+#if defined(OS_WIN)
+#include <winsock2.h>
+#elif defined(OS_POSIX)
+#include <arpa/inet.h>
+#include <errno.h>
+#include <sys/socket.h>
+#endif
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/strings/string_util.h"
+#include "base/synchronization/condition_variable.h"
+#include "base/synchronization/lock.h"
+#include "base/threading/thread.h"
+#include "net/base/net_util.h"
+#include "net/base/winsock_init.h"
+#include "net/socket/tcp_listen_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+enum ActionType {
+ ACTION_NONE = 0,
+ ACTION_LISTEN = 1,
+ ACTION_ACCEPT = 2,
+ ACTION_READ = 3,
+ ACTION_SEND = 4,
+ ACTION_CLOSE = 5,
+ ACTION_SHUTDOWN = 6
+};
+
+class TCPListenSocketTestAction {
+ public:
+ TCPListenSocketTestAction() : action_(ACTION_NONE) {}
+ explicit TCPListenSocketTestAction(ActionType action) : action_(action) {}
+ TCPListenSocketTestAction(ActionType action, std::string data)
+ : action_(action),
+ data_(data) {}
+
+ const std::string data() const { return data_; }
+ ActionType type() const { return action_; }
+
+ private:
+ ActionType action_;
+ std::string data_;
+};
+
+
+// This had to be split out into a separate class because I couldn't
+// make the testing::Test class refcounted.
+class TCPListenSocketTester :
+ public StreamListenSocket::Delegate,
+ public base::RefCountedThreadSafe<TCPListenSocketTester> {
+
+ public:
+ TCPListenSocketTester();
+
+ void SetUp();
+ void TearDown();
+
+ void ReportAction(const TCPListenSocketTestAction& action);
+ void NextAction();
+
+ // read all pending data from the test socket
+ int ClearTestSocket();
+ // Release the connection and server sockets
+ void Shutdown();
+ void Listen();
+ void SendFromTester();
+ // verify the send/read from client to server
+ void TestClientSend();
+ // verify send/read of a longer string
+ void TestClientSendLong();
+ // verify a send/read from server to client
+ void TestServerSend();
+ // verify multiple sends and reads from server to client.
+ void TestServerSendMultiple();
+
+ virtual bool Send(SocketDescriptor sock, const std::string& str);
+
+ // StreamListenSocket::Delegate:
+ virtual void DidAccept(StreamListenSocket* server,
+ StreamListenSocket* connection) OVERRIDE;
+ virtual void DidRead(StreamListenSocket* connection, const char* data,
+ int len) OVERRIDE;
+ virtual void DidClose(StreamListenSocket* sock) OVERRIDE;
+
+ scoped_ptr<base::Thread> thread_;
+ base::MessageLoopForIO* loop_;
+ scoped_refptr<TCPListenSocket> server_;
+ StreamListenSocket* connection_;
+ TCPListenSocketTestAction last_action_;
+
+ SocketDescriptor test_socket_;
+ static const int kTestPort;
+
+ base::Lock lock_; // protects |queue_| and wraps |cv_|
+ base::ConditionVariable cv_;
+ std::deque<TCPListenSocketTestAction> queue_;
+
+ protected:
+ friend class base::RefCountedThreadSafe<TCPListenSocketTester>;
+
+ virtual ~TCPListenSocketTester();
+
+ virtual scoped_refptr<TCPListenSocket> DoListen();
+};
+
+} // namespace net
+
+#endif // NET_BASE_LISTEN_SOCKET_UNITTEST_H_
diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h
new file mode 100644
index 00000000000..4970a150e8d
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_SERVER_SOCKET_H_
+#define NET_SOCKET_TCP_SERVER_SOCKET_H_
+
+#include "build/build_config.h"
+
+#if defined(OS_WIN)
+#include "net/socket/tcp_server_socket_win.h"
+#elif defined(OS_POSIX)
+#include "net/socket/tcp_server_socket_libevent.h"
+#endif
+
+namespace net {
+
+#if defined(OS_WIN)
+typedef TCPServerSocketWin TCPServerSocket;
+#elif defined(OS_POSIX)
+typedef TCPServerSocketLibevent TCPServerSocket;
+#endif
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_SERVER_SOCKET_H_
diff --git a/chromium/net/socket/tcp_server_socket_libevent.cc b/chromium/net/socket/tcp_server_socket_libevent.cc
new file mode 100644
index 00000000000..38dda962f46
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket_libevent.cc
@@ -0,0 +1,223 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_server_socket_libevent.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <sys/socket.h>
+
+#include "build/build_config.h"
+
+#if defined(OS_POSIX)
+#include <netinet/in.h>
+#endif
+
+#include "base/posix/eintr_wrapper.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/socket/socket_net_log_params.h"
+#include "net/socket/tcp_client_socket.h"
+
+namespace net {
+
+namespace {
+
+const int kInvalidSocket = -1;
+
+} // namespace
+
+TCPServerSocketLibevent::TCPServerSocketLibevent(
+ net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(kInvalidSocket),
+ accept_socket_(NULL),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
+ net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
+ source.ToEventParametersCallback());
+}
+
+TCPServerSocketLibevent::~TCPServerSocketLibevent() {
+ if (socket_ != kInvalidSocket)
+ Close();
+ net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
+}
+
+int TCPServerSocketLibevent::Listen(const IPEndPoint& address, int backlog) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_GT(backlog, 0);
+ DCHECK_EQ(socket_, kInvalidSocket);
+
+ socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP);
+ if (socket_ < 0) {
+ PLOG(ERROR) << "socket() returned an error";
+ return MapSystemError(errno);
+ }
+
+ if (SetNonBlocking(socket_)) {
+ int result = MapSystemError(errno);
+ Close();
+ return result;
+ }
+
+ int result = SetSocketOptions();
+ if (result != OK) {
+ Close();
+ return result;
+ }
+
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len)) {
+ Close();
+ return ERR_ADDRESS_INVALID;
+ }
+
+ result = bind(socket_, storage.addr, storage.addr_len);
+ if (result < 0) {
+ PLOG(ERROR) << "bind() returned an error";
+ result = MapSystemError(errno);
+ Close();
+ return result;
+ }
+
+ result = listen(socket_, backlog);
+ if (result < 0) {
+ PLOG(ERROR) << "listen() returned an error";
+ result = MapSystemError(errno);
+ Close();
+ return result;
+ }
+
+ return OK;
+}
+
+int TCPServerSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+
+ SockaddrStorage storage;
+ if (getsockname(socket_, storage.addr, &storage.addr_len) < 0)
+ return MapSystemError(errno);
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_FAILED;
+
+ return OK;
+}
+
+int TCPServerSocketLibevent::Accept(
+ scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(socket);
+ DCHECK(!callback.is_null());
+ DCHECK(accept_callback_.is_null());
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
+
+ int result = AcceptInternal(socket);
+
+ if (result == ERR_IO_PENDING) {
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_, true, base::MessageLoopForIO::WATCH_READ,
+ &accept_socket_watcher_, this)) {
+ PLOG(ERROR) << "WatchFileDescriptor failed on read";
+ return MapSystemError(errno);
+ }
+
+ accept_socket_ = socket;
+ accept_callback_ = callback;
+ }
+
+ return result;
+}
+
+int TCPServerSocketLibevent::SetSocketOptions() {
+ // SO_REUSEADDR is useful for server sockets to bind to a recently unbound
+ // port. When a socket is closed, the end point changes its state to TIME_WAIT
+ // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer
+ // acknowledges its closure. For server sockets, it is usually safe to
+ // bind to a TIME_WAIT end point immediately, which is a widely adopted
+ // behavior.
+ //
+ // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to
+ // an end point that is already bound by another socket. To do that one must
+ // set SO_REUSEPORT instead. This option is not provided on Linux prior
+ // to 3.9.
+ //
+ // SO_REUSEPORT is provided in MacOS X and iOS.
+ int true_value = 1;
+ int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &true_value,
+ sizeof(true_value));
+ if (rv < 0)
+ return MapSystemError(errno);
+ return OK;
+}
+
+int TCPServerSocketLibevent::AcceptInternal(
+ scoped_ptr<StreamSocket>* socket) {
+ SockaddrStorage storage;
+ int new_socket = HANDLE_EINTR(accept(socket_,
+ storage.addr,
+ &storage.addr_len));
+ if (new_socket < 0) {
+ int net_error = MapSystemError(errno);
+ if (net_error != ERR_IO_PENDING)
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
+ return net_error;
+ }
+
+ IPEndPoint address;
+ if (!address.FromSockAddr(storage.addr, storage.addr_len)) {
+ NOTREACHED();
+ if (HANDLE_EINTR(close(new_socket)) < 0)
+ PLOG(ERROR) << "close";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED);
+ return ERR_FAILED;
+ }
+ scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket(
+ AddressList(address),
+ net_log_.net_log(), net_log_.source()));
+ int adopt_result = tcp_socket->AdoptSocket(new_socket);
+ if (adopt_result != OK) {
+ if (HANDLE_EINTR(close(new_socket)) < 0)
+ PLOG(ERROR) << "close";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
+ return adopt_result;
+ }
+ socket->reset(tcp_socket.release());
+ net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
+ CreateNetLogIPEndPointCallback(&address));
+ return OK;
+}
+
+void TCPServerSocketLibevent::Close() {
+ if (socket_ != kInvalidSocket) {
+ bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ if (HANDLE_EINTR(close(socket_)) < 0)
+ PLOG(ERROR) << "close";
+ socket_ = kInvalidSocket;
+ }
+}
+
+void TCPServerSocketLibevent::OnFileCanReadWithoutBlocking(int fd) {
+ DCHECK(CalledOnValidThread());
+
+ int result = AcceptInternal(accept_socket_);
+ if (result != ERR_IO_PENDING) {
+ accept_socket_ = NULL;
+ bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ CompletionCallback callback = accept_callback_;
+ accept_callback_.Reset();
+ callback.Run(result);
+ }
+}
+
+void TCPServerSocketLibevent::OnFileCanWriteWithoutBlocking(int fd) {
+ NOTREACHED();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket_libevent.h b/chromium/net/socket/tcp_server_socket_libevent.h
new file mode 100644
index 00000000000..fe69472a653
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket_libevent.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_
+#define NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_
+
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/threading/non_thread_safe.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_log.h"
+#include "net/socket/server_socket.h"
+
+namespace net {
+
+class IPEndPoint;
+
+class NET_EXPORT_PRIVATE TCPServerSocketLibevent :
+ public ServerSocket,
+ public base::NonThreadSafe,
+ public base::MessageLoopForIO::Watcher {
+ public:
+ TCPServerSocketLibevent(net::NetLog* net_log,
+ const net::NetLog::Source& source);
+ virtual ~TCPServerSocketLibevent();
+
+ // net::ServerSocket implementation.
+ virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // MessageLoopForIO::Watcher implementation.
+ virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
+ virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
+
+ private:
+ int SetSocketOptions();
+ int AcceptInternal(scoped_ptr<StreamSocket>* socket);
+ void Close();
+
+ int socket_;
+
+ base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
+
+ scoped_ptr<StreamSocket>* accept_socket_;
+ CompletionCallback accept_callback_;
+
+ BoundNetLog net_log_;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_
diff --git a/chromium/net/socket/tcp_server_socket_unittest.cc b/chromium/net/socket/tcp_server_socket_unittest.cc
new file mode 100644
index 00000000000..fd81e550d08
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket_unittest.cc
@@ -0,0 +1,251 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_server_socket.h"
+
+#include <string>
+#include <vector>
+
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/tcp_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+namespace {
+const int kListenBacklog = 5;
+
+class TCPServerSocketTest : public PlatformTest {
+ protected:
+ TCPServerSocketTest()
+ : socket_(NULL, NetLog::Source()) {
+ }
+
+ void SetUpIPv4() {
+ IPEndPoint address;
+ ParseAddress("127.0.0.1", 0, &address);
+ ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog));
+ ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
+ }
+
+ void SetUpIPv6(bool* success) {
+ *success = false;
+ IPEndPoint address;
+ ParseAddress("::1", 0, &address);
+ if (socket_.Listen(address, kListenBacklog) != 0) {
+ LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
+ "disabled. Skipping the test";
+ return;
+ }
+ ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
+ *success = true;
+ }
+
+ void ParseAddress(std::string ip_str, int port, IPEndPoint* address) {
+ IPAddressNumber ip_number;
+ bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
+ if (!rv)
+ return;
+ *address = IPEndPoint(ip_number, port);
+ }
+
+ static IPEndPoint GetPeerAddress(StreamSocket* socket) {
+ IPEndPoint address;
+ EXPECT_EQ(OK, socket->GetPeerAddress(&address));
+ return address;
+ }
+
+ AddressList local_address_list() const {
+ return AddressList(local_address_);
+ }
+
+ TCPServerSocket socket_;
+ IPEndPoint local_address_;
+};
+
+TEST_F(TCPServerSocketTest, Accept) {
+ ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<StreamSocket> accepted_socket;
+ int result = socket_.Accept(&accepted_socket, accept_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ ASSERT_TRUE(accepted_socket.get() != NULL);
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
+ local_address_.address());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+}
+
+// Test Accept() callback.
+TEST_F(TCPServerSocketTest, AcceptAsync) {
+ ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<StreamSocket> accepted_socket;
+
+ ASSERT_EQ(ERR_IO_PENDING,
+ socket_.Accept(&accepted_socket, accept_callback.callback()));
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+
+ EXPECT_TRUE(accepted_socket != NULL);
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
+ local_address_.address());
+}
+
+// Accept two connections simultaneously.
+TEST_F(TCPServerSocketTest, Accept2Connections) {
+ ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<StreamSocket> accepted_socket;
+
+ ASSERT_EQ(ERR_IO_PENDING,
+ socket_.Accept(&accepted_socket, accept_callback.callback()));
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback connect_callback2;
+ TCPClientSocket connecting_socket2(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket2.Connect(connect_callback2.callback());
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+
+ TestCompletionCallback accept_callback2;
+ scoped_ptr<StreamSocket> accepted_socket2;
+ int result = socket_.Accept(&accepted_socket2, accept_callback2.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback2.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+
+ EXPECT_TRUE(accepted_socket != NULL);
+ EXPECT_TRUE(accepted_socket2 != NULL);
+ EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
+
+ EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
+ local_address_.address());
+ EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
+ local_address_.address());
+}
+
+TEST_F(TCPServerSocketTest, AcceptIPv6) {
+ bool initialized = false;
+ ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized));
+ if (!initialized)
+ return;
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<StreamSocket> accepted_socket;
+ int result = socket_.Accept(&accepted_socket, accept_callback.callback());
+ if (result == ERR_IO_PENDING)
+ result = accept_callback.WaitForResult();
+ ASSERT_EQ(OK, result);
+
+ ASSERT_TRUE(accepted_socket.get() != NULL);
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
+ local_address_.address());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+}
+
+TEST_F(TCPServerSocketTest, AcceptIO) {
+ ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
+
+ TestCompletionCallback connect_callback;
+ TCPClientSocket connecting_socket(local_address_list(),
+ NULL, NetLog::Source());
+ connecting_socket.Connect(connect_callback.callback());
+
+ TestCompletionCallback accept_callback;
+ scoped_ptr<StreamSocket> accepted_socket;
+ int result = socket_.Accept(&accepted_socket, accept_callback.callback());
+ ASSERT_EQ(OK, accept_callback.GetResult(result));
+
+ ASSERT_TRUE(accepted_socket.get() != NULL);
+
+ // Both sockets should be on the loopback network interface.
+ EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
+ local_address_.address());
+
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
+
+ const std::string message("test message");
+ std::vector<char> buffer(message.size());
+
+ size_t bytes_written = 0;
+ while (bytes_written < message.size()) {
+ scoped_refptr<net::IOBufferWithSize> write_buffer(
+ new net::IOBufferWithSize(message.size() - bytes_written));
+ memmove(write_buffer->data(), message.data(), message.size());
+
+ TestCompletionCallback write_callback;
+ int write_result = accepted_socket->Write(
+ write_buffer.get(), write_buffer->size(), write_callback.callback());
+ write_result = write_callback.GetResult(write_result);
+ ASSERT_TRUE(write_result >= 0);
+ ASSERT_TRUE(bytes_written + write_result <= message.size());
+ bytes_written += write_result;
+ }
+
+ size_t bytes_read = 0;
+ while (bytes_read < message.size()) {
+ scoped_refptr<net::IOBufferWithSize> read_buffer(
+ new net::IOBufferWithSize(message.size() - bytes_read));
+ TestCompletionCallback read_callback;
+ int read_result = connecting_socket.Read(
+ read_buffer.get(), read_buffer->size(), read_callback.callback());
+ read_result = read_callback.GetResult(read_result);
+ ASSERT_TRUE(read_result >= 0);
+ ASSERT_TRUE(bytes_read + read_result <= message.size());
+ memmove(&buffer[bytes_read], read_buffer->data(), read_result);
+ bytes_read += read_result;
+ }
+
+ std::string received_message(buffer.begin(), buffer.end());
+ ASSERT_EQ(message, received_message);
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket_win.cc b/chromium/net/socket/tcp_server_socket_win.cc
new file mode 100644
index 00000000000..0ac77be5e81
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket_win.cc
@@ -0,0 +1,217 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_server_socket_win.h"
+
+#include <mstcpip.h>
+
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/base/winsock_init.h"
+#include "net/base/winsock_util.h"
+#include "net/socket/socket_net_log_params.h"
+#include "net/socket/tcp_client_socket.h"
+
+namespace net {
+
+TCPServerSocketWin::TCPServerSocketWin(net::NetLog* net_log,
+ const net::NetLog::Source& source)
+ : socket_(INVALID_SOCKET),
+ socket_event_(WSA_INVALID_EVENT),
+ accept_socket_(NULL),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
+ net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
+ source.ToEventParametersCallback());
+ EnsureWinsockInit();
+}
+
+TCPServerSocketWin::~TCPServerSocketWin() {
+ Close();
+ net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
+}
+
+int TCPServerSocketWin::Listen(const IPEndPoint& address, int backlog) {
+ DCHECK(CalledOnValidThread());
+ DCHECK_GT(backlog, 0);
+ DCHECK_EQ(socket_, INVALID_SOCKET);
+ DCHECK_EQ(socket_event_, WSA_INVALID_EVENT);
+
+ socket_event_ = WSACreateEvent();
+ if (socket_event_ == WSA_INVALID_EVENT) {
+ PLOG(ERROR) << "WSACreateEvent()";
+ return ERR_FAILED;
+ }
+
+ socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP);
+ if (socket_ == INVALID_SOCKET) {
+ PLOG(ERROR) << "socket() returned an error";
+ return MapSystemError(WSAGetLastError());
+ }
+
+ if (SetNonBlocking(socket_)) {
+ int result = MapSystemError(WSAGetLastError());
+ Close();
+ return result;
+ }
+
+ int result = SetSocketOptions();
+ if (result != OK) {
+ Close();
+ return result;
+ }
+
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len)) {
+ Close();
+ return ERR_ADDRESS_INVALID;
+ }
+
+ result = bind(socket_, storage.addr, storage.addr_len);
+ if (result < 0) {
+ PLOG(ERROR) << "bind() returned an error";
+ result = MapSystemError(WSAGetLastError());
+ Close();
+ return result;
+ }
+
+ result = listen(socket_, backlog);
+ if (result < 0) {
+ PLOG(ERROR) << "listen() returned an error";
+ result = MapSystemError(WSAGetLastError());
+ Close();
+ return result;
+ }
+
+ return OK;
+}
+
+int TCPServerSocketWin::GetLocalAddress(IPEndPoint* address) const {
+ DCHECK(CalledOnValidThread());
+ DCHECK(address);
+
+ SockaddrStorage storage;
+ if (getsockname(socket_, storage.addr, &storage.addr_len))
+ return MapSystemError(WSAGetLastError());
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_FAILED;
+
+ return OK;
+}
+
+int TCPServerSocketWin::Accept(
+ scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(socket);
+ DCHECK(!callback.is_null());
+ DCHECK(accept_callback_.is_null());
+
+ net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
+
+ int result = AcceptInternal(socket);
+
+ if (result == ERR_IO_PENDING) {
+ // Start watching
+ WSAEventSelect(socket_, socket_event_, FD_ACCEPT);
+ accept_watcher_.StartWatching(socket_event_, this);
+
+ accept_socket_ = socket;
+ accept_callback_ = callback;
+ }
+
+ return result;
+}
+
+int TCPServerSocketWin::SetSocketOptions() {
+ // On Windows, a bound end point can be hijacked by another process by
+ // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE
+ // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the
+ // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another
+ // socket to forcibly bind to the end point until the end point is unbound.
+ // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE.
+ // MSDN: http://goo.gl/M6fjQ.
+ //
+ // Unlike on *nix, on Windows a TCP server socket can always bind to an end
+ // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not
+ // needed here.
+ //
+ // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end
+ // point in TIME_WAIT status. It does not have this effect for a TCP server
+ // socket.
+
+ BOOL true_value = 1;
+ int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE,
+ reinterpret_cast<const char*>(&true_value),
+ sizeof(true_value));
+ if (rv < 0)
+ return MapSystemError(errno);
+ return OK;
+}
+
+int TCPServerSocketWin::AcceptInternal(scoped_ptr<StreamSocket>* socket) {
+ SockaddrStorage storage;
+ int new_socket = accept(socket_, storage.addr, &storage.addr_len);
+ if (new_socket < 0) {
+ int net_error = MapSystemError(WSAGetLastError());
+ if (net_error != ERR_IO_PENDING)
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
+ return net_error;
+ }
+
+ IPEndPoint address;
+ if (!address.FromSockAddr(storage.addr, storage.addr_len)) {
+ NOTREACHED();
+ if (closesocket(new_socket) < 0)
+ PLOG(ERROR) << "closesocket";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED);
+ return ERR_FAILED;
+ }
+ scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket(
+ AddressList(address),
+ net_log_.net_log(), net_log_.source()));
+ int adopt_result = tcp_socket->AdoptSocket(new_socket);
+ if (adopt_result != OK) {
+ if (closesocket(new_socket) < 0)
+ PLOG(ERROR) << "closesocket";
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
+ return adopt_result;
+ }
+ socket->reset(tcp_socket.release());
+ net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
+ CreateNetLogIPEndPointCallback(&address));
+ return OK;
+}
+
+void TCPServerSocketWin::Close() {
+ if (socket_ != INVALID_SOCKET) {
+ if (closesocket(socket_) < 0)
+ PLOG(ERROR) << "closesocket";
+ socket_ = INVALID_SOCKET;
+ }
+
+ if (socket_event_) {
+ WSACloseEvent(socket_event_);
+ socket_event_ = WSA_INVALID_EVENT;
+ }
+}
+
+void TCPServerSocketWin::OnObjectSignaled(HANDLE object) {
+ WSANETWORKEVENTS ev;
+ if (WSAEnumNetworkEvents(socket_, socket_event_, &ev) == SOCKET_ERROR) {
+ PLOG(ERROR) << "WSAEnumNetworkEvents()";
+ return;
+ }
+
+ if (ev.lNetworkEvents & FD_ACCEPT) {
+ int result = AcceptInternal(accept_socket_);
+ if (result != ERR_IO_PENDING) {
+ accept_socket_ = NULL;
+ CompletionCallback callback = accept_callback_;
+ accept_callback_.Reset();
+ callback.Run(result);
+ }
+ }
+}
+
+} // namespace net
diff --git a/chromium/net/socket/tcp_server_socket_win.h b/chromium/net/socket/tcp_server_socket_win.h
new file mode 100644
index 00000000000..5a1d378ad9b
--- /dev/null
+++ b/chromium/net/socket/tcp_server_socket_win.h
@@ -0,0 +1,58 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_
+#define NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_
+
+#include <winsock2.h>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/win/object_watcher.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_log.h"
+#include "net/socket/server_socket.h"
+
+namespace net {
+
+class IPEndPoint;
+
+class NET_EXPORT_PRIVATE TCPServerSocketWin
+ : public ServerSocket,
+ NON_EXPORTED_BASE(public base::NonThreadSafe),
+ public base::win::ObjectWatcher::Delegate {
+ public:
+ TCPServerSocketWin(net::NetLog* net_log,
+ const net::NetLog::Source& source);
+ ~TCPServerSocketWin();
+
+ // net::ServerSocket implementation.
+ virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE;
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ virtual int Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) OVERRIDE;
+
+ // base::ObjectWatcher::Delegate implementation.
+ virtual void OnObjectSignaled(HANDLE object);
+
+ private:
+ int SetSocketOptions();
+ int AcceptInternal(scoped_ptr<StreamSocket>* socket);
+ void Close();
+
+ SOCKET socket_;
+ HANDLE socket_event_;
+
+ base::win::ObjectWatcher accept_watcher_;
+
+ scoped_ptr<StreamSocket>* accept_socket_;
+ CompletionCallback accept_callback_;
+
+ BoundNetLog net_log_;
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_
diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc
new file mode 100644
index 00000000000..6d0afac59fb
--- /dev/null
+++ b/chromium/net/socket/transport_client_socket_pool.cc
@@ -0,0 +1,477 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/transport_client_socket_pool.h"
+
+#include <algorithm>
+
+#include "base/compiler_specific.h"
+#include "base/logging.h"
+#include "base/message_loop/message_loop.h"
+#include "base/metrics/histogram.h"
+#include "base/strings/string_util.h"
+#include "base/time/time.h"
+#include "base/values.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/socket_net_log_params.h"
+#include "net/socket/tcp_client_socket.h"
+
+using base::TimeDelta;
+
+namespace net {
+
+// TODO(willchan): Base this off RTT instead of statically setting it. Note we
+// choose a timeout that is different from the backup connect job timer so they
+// don't synchronize.
+const int TransportConnectJob::kIPv6FallbackTimerInMs = 300;
+
+namespace {
+
+// Returns true iff all addresses in |list| are in the IPv6 family.
+bool AddressListOnlyContainsIPv6(const AddressList& list) {
+ DCHECK(!list.empty());
+ for (AddressList::const_iterator iter = list.begin(); iter != list.end();
+ ++iter) {
+ if (iter->GetFamily() != ADDRESS_FAMILY_IPV6)
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+TransportSocketParams::TransportSocketParams(
+ const HostPortPair& host_port_pair,
+ RequestPriority priority,
+ bool disable_resolver_cache,
+ bool ignore_limits,
+ const OnHostResolutionCallback& host_resolution_callback)
+ : destination_(host_port_pair),
+ ignore_limits_(ignore_limits),
+ host_resolution_callback_(host_resolution_callback) {
+ Initialize(priority, disable_resolver_cache);
+}
+
+TransportSocketParams::~TransportSocketParams() {}
+
+void TransportSocketParams::Initialize(RequestPriority priority,
+ bool disable_resolver_cache) {
+ destination_.set_priority(priority);
+ if (disable_resolver_cache)
+ destination_.set_allow_cached_response(false);
+}
+
+// TransportConnectJobs will time out after this many seconds. Note this is
+// the total time, including both host resolution and TCP connect() times.
+//
+// TODO(eroman): The use of this constant needs to be re-evaluated. The time
+// needed for TCPClientSocketXXX::Connect() can be arbitrarily long, since
+// the address list may contain many alternatives, and most of those may
+// timeout. Even worse, the per-connect timeout threshold varies greatly
+// between systems (anywhere from 20 seconds to 190 seconds).
+// See comment #12 at http://crbug.com/23364 for specifics.
+static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes.
+
+TransportConnectJob::TransportConnectJob(
+ const std::string& group_name,
+ const scoped_refptr<TransportSocketParams>& params,
+ base::TimeDelta timeout_duration,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ Delegate* delegate,
+ NetLog* net_log)
+ : ConnectJob(group_name, timeout_duration, delegate,
+ BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
+ params_(params),
+ client_socket_factory_(client_socket_factory),
+ resolver_(host_resolver),
+ next_state_(STATE_NONE) {
+}
+
+TransportConnectJob::~TransportConnectJob() {
+ // We don't worry about cancelling the host resolution and TCP connect, since
+ // ~SingleRequestHostResolver and ~StreamSocket will take care of it.
+}
+
+LoadState TransportConnectJob::GetLoadState() const {
+ switch (next_state_) {
+ case STATE_RESOLVE_HOST:
+ case STATE_RESOLVE_HOST_COMPLETE:
+ return LOAD_STATE_RESOLVING_HOST;
+ case STATE_TRANSPORT_CONNECT:
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ return LOAD_STATE_CONNECTING;
+ default:
+ NOTREACHED();
+ return LOAD_STATE_IDLE;
+ }
+}
+
+// static
+void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) {
+ for (AddressList::iterator i = list->begin(); i != list->end(); ++i) {
+ if (i->GetFamily() == ADDRESS_FAMILY_IPV4) {
+ std::rotate(list->begin(), i, list->end());
+ break;
+ }
+ }
+}
+
+void TransportConnectJob::OnIOComplete(int result) {
+ int rv = DoLoop(result);
+ if (rv != ERR_IO_PENDING)
+ NotifyDelegateOfCompletion(rv); // Deletes |this|
+}
+
+int TransportConnectJob::DoLoop(int result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_RESOLVE_HOST:
+ DCHECK_EQ(OK, rv);
+ rv = DoResolveHost();
+ break;
+ case STATE_RESOLVE_HOST_COMPLETE:
+ rv = DoResolveHostComplete(rv);
+ break;
+ case STATE_TRANSPORT_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoTransportConnect();
+ break;
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ rv = DoTransportConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED();
+ rv = ERR_FAILED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+
+ return rv;
+}
+
+int TransportConnectJob::DoResolveHost() {
+ next_state_ = STATE_RESOLVE_HOST_COMPLETE;
+ connect_timing_.dns_start = base::TimeTicks::Now();
+
+ return resolver_.Resolve(
+ params_->destination(), &addresses_,
+ base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)),
+ net_log());
+}
+
+int TransportConnectJob::DoResolveHostComplete(int result) {
+ connect_timing_.dns_end = base::TimeTicks::Now();
+ // Overwrite connection start time, since for connections that do not go
+ // through proxies, |connect_start| should not include dns lookup time.
+ connect_timing_.connect_start = connect_timing_.dns_end;
+
+ if (result == OK) {
+ // Invoke callback, and abort if it fails.
+ if (!params_->host_resolution_callback().is_null())
+ result = params_->host_resolution_callback().Run(addresses_, net_log());
+
+ if (result == OK)
+ next_state_ = STATE_TRANSPORT_CONNECT;
+ }
+ return result;
+}
+
+int TransportConnectJob::DoTransportConnect() {
+ next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
+ transport_socket_ = client_socket_factory_->CreateTransportClientSocket(
+ addresses_, net_log().net_log(), net_log().source());
+ int rv = transport_socket_->Connect(
+ base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)));
+ if (rv == ERR_IO_PENDING &&
+ addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV6 &&
+ !AddressListOnlyContainsIPv6(addresses_)) {
+ fallback_timer_.Start(FROM_HERE,
+ base::TimeDelta::FromMilliseconds(kIPv6FallbackTimerInMs),
+ this, &TransportConnectJob::DoIPv6FallbackTransportConnect);
+ }
+ return rv;
+}
+
+int TransportConnectJob::DoTransportConnectComplete(int result) {
+ if (result == OK) {
+ bool is_ipv4 = addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV4;
+ DCHECK(!connect_timing_.connect_start.is_null());
+ DCHECK(!connect_timing_.dns_start.is_null());
+ base::TimeTicks now = base::TimeTicks::Now();
+ base::TimeDelta total_duration = now - connect_timing_.dns_start;
+ UMA_HISTOGRAM_CUSTOM_TIMES(
+ "Net.DNS_Resolution_And_TCP_Connection_Latency2",
+ total_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+
+ base::TimeDelta connect_duration = now - connect_timing_.connect_start;
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+
+ if (is_ipv4) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ } else {
+ if (AddressListOnlyContainsIPv6(addresses_)) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ } else {
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ }
+ }
+ SetSocket(transport_socket_.Pass());
+ fallback_timer_.Stop();
+ } else {
+ // Be a bit paranoid and kill off the fallback members to prevent reuse.
+ fallback_transport_socket_.reset();
+ fallback_addresses_.reset();
+ }
+
+ return result;
+}
+
+void TransportConnectJob::DoIPv6FallbackTransportConnect() {
+ // The timer should only fire while we're waiting for the main connect to
+ // succeed.
+ if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) {
+ NOTREACHED();
+ return;
+ }
+
+ DCHECK(!fallback_transport_socket_.get());
+ DCHECK(!fallback_addresses_.get());
+
+ fallback_addresses_.reset(new AddressList(addresses_));
+ MakeAddressListStartWithIPv4(fallback_addresses_.get());
+ fallback_transport_socket_ =
+ client_socket_factory_->CreateTransportClientSocket(
+ *fallback_addresses_, net_log().net_log(), net_log().source());
+ fallback_connect_start_time_ = base::TimeTicks::Now();
+ int rv = fallback_transport_socket_->Connect(
+ base::Bind(
+ &TransportConnectJob::DoIPv6FallbackTransportConnectComplete,
+ base::Unretained(this)));
+ if (rv != ERR_IO_PENDING)
+ DoIPv6FallbackTransportConnectComplete(rv);
+}
+
+void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) {
+ // This should only happen when we're waiting for the main connect to succeed.
+ if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) {
+ NOTREACHED();
+ return;
+ }
+
+ DCHECK_NE(ERR_IO_PENDING, result);
+ DCHECK(fallback_transport_socket_.get());
+ DCHECK(fallback_addresses_.get());
+
+ if (result == OK) {
+ DCHECK(!fallback_connect_start_time_.is_null());
+ DCHECK(!connect_timing_.dns_start.is_null());
+ base::TimeTicks now = base::TimeTicks::Now();
+ base::TimeDelta total_duration = now - connect_timing_.dns_start;
+ UMA_HISTOGRAM_CUSTOM_TIMES(
+ "Net.DNS_Resolution_And_TCP_Connection_Latency2",
+ total_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+
+ base::TimeDelta connect_duration = now - fallback_connect_start_time_;
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ SetSocket(fallback_transport_socket_.Pass());
+ next_state_ = STATE_NONE;
+ transport_socket_.reset();
+ } else {
+ // Be a bit paranoid and kill off the fallback members to prevent reuse.
+ fallback_transport_socket_.reset();
+ fallback_addresses_.reset();
+ }
+ NotifyDelegateOfCompletion(result); // Deletes |this|
+}
+
+int TransportConnectJob::ConnectInternal() {
+ next_state_ = STATE_RESOLVE_HOST;
+ return DoLoop(OK);
+}
+
+scoped_ptr<ConnectJob>
+ TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob(
+ const std::string& group_name,
+ const PoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const {
+ return scoped_ptr<ConnectJob>(
+ new TransportConnectJob(group_name,
+ request.params(),
+ ConnectionTimeout(),
+ client_socket_factory_,
+ host_resolver_,
+ delegate,
+ net_log_));
+}
+
+base::TimeDelta
+ TransportClientSocketPool::TransportConnectJobFactory::ConnectionTimeout()
+ const {
+ return base::TimeDelta::FromSeconds(kTransportConnectJobTimeoutInSeconds);
+}
+
+TransportClientSocketPool::TransportClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ ClientSocketFactory* client_socket_factory,
+ NetLog* net_log)
+ : base_(max_sockets, max_sockets_per_group, histograms,
+ ClientSocketPool::unused_idle_socket_timeout(),
+ ClientSocketPool::used_idle_socket_timeout(),
+ new TransportConnectJobFactory(client_socket_factory,
+ host_resolver, net_log)) {
+ base_.EnableConnectBackupJobs();
+}
+
+TransportClientSocketPool::~TransportClientSocketPool() {}
+
+int TransportClientSocketPool::RequestSocket(
+ const std::string& group_name,
+ const void* params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) {
+ const scoped_refptr<TransportSocketParams>* casted_params =
+ static_cast<const scoped_refptr<TransportSocketParams>*>(params);
+
+ if (net_log.IsLoggingAllEvents()) {
+ // TODO(eroman): Split out the host and port parameters.
+ net_log.AddEvent(
+ NetLog::TYPE_TCP_CLIENT_SOCKET_POOL_REQUESTED_SOCKET,
+ CreateNetLogHostPortPairCallback(
+ &casted_params->get()->destination().host_port_pair()));
+ }
+
+ return base_.RequestSocket(group_name, *casted_params, priority, handle,
+ callback, net_log);
+}
+
+void TransportClientSocketPool::RequestSockets(
+ const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
+ const scoped_refptr<TransportSocketParams>* casted_params =
+ static_cast<const scoped_refptr<TransportSocketParams>*>(params);
+
+ if (net_log.IsLoggingAllEvents()) {
+ // TODO(eroman): Split out the host and port parameters.
+ net_log.AddEvent(
+ NetLog::TYPE_TCP_CLIENT_SOCKET_POOL_REQUESTED_SOCKETS,
+ CreateNetLogHostPortPairCallback(
+ &casted_params->get()->destination().host_port_pair()));
+ }
+
+ base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
+}
+
+void TransportClientSocketPool::CancelRequest(
+ const std::string& group_name,
+ ClientSocketHandle* handle) {
+ base_.CancelRequest(group_name, handle);
+}
+
+void TransportClientSocketPool::ReleaseSocket(
+ const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ base_.ReleaseSocket(group_name, socket.Pass(), id);
+}
+
+void TransportClientSocketPool::FlushWithError(int error) {
+ base_.FlushWithError(error);
+}
+
+bool TransportClientSocketPool::IsStalled() const {
+ return base_.IsStalled();
+}
+
+void TransportClientSocketPool::CloseIdleSockets() {
+ base_.CloseIdleSockets();
+}
+
+int TransportClientSocketPool::IdleSocketCount() const {
+ return base_.idle_socket_count();
+}
+
+int TransportClientSocketPool::IdleSocketCountInGroup(
+ const std::string& group_name) const {
+ return base_.IdleSocketCountInGroup(group_name);
+}
+
+LoadState TransportClientSocketPool::GetLoadState(
+ const std::string& group_name, const ClientSocketHandle* handle) const {
+ return base_.GetLoadState(group_name, handle);
+}
+
+void TransportClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
+ base_.AddLayeredPool(layered_pool);
+}
+
+void TransportClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
+ base_.RemoveLayeredPool(layered_pool);
+}
+
+base::DictionaryValue* TransportClientSocketPool::GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const {
+ return base_.GetInfoAsValue(name, type);
+}
+
+base::TimeDelta TransportClientSocketPool::ConnectionTimeout() const {
+ return base_.ConnectionTimeout();
+}
+
+ClientSocketPoolHistograms* TransportClientSocketPool::histograms() const {
+ return base_.histograms();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h
new file mode 100644
index 00000000000..f07dc1f5675
--- /dev/null
+++ b/chromium/net/socket/transport_client_socket_pool.h
@@ -0,0 +1,221 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
+#define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "base/timer/timer.h"
+#include "net/base/host_port_pair.h"
+#include "net/dns/host_resolver.h"
+#include "net/dns/single_request_host_resolver.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/client_socket_pool_histograms.h"
+
+namespace net {
+
+class ClientSocketFactory;
+
+typedef base::Callback<int(const AddressList&, const BoundNetLog& net_log)>
+OnHostResolutionCallback;
+
+class NET_EXPORT_PRIVATE TransportSocketParams
+ : public base::RefCounted<TransportSocketParams> {
+ public:
+ // |host_resolution_callback| will be invoked after the the hostname is
+ // resolved. If |host_resolution_callback| does not return OK, then the
+ // connection will be aborted with that value.
+ TransportSocketParams(
+ const HostPortPair& host_port_pair,
+ RequestPriority priority,
+ bool disable_resolver_cache,
+ bool ignore_limits,
+ const OnHostResolutionCallback& host_resolution_callback);
+
+ const HostResolver::RequestInfo& destination() const { return destination_; }
+ bool ignore_limits() const { return ignore_limits_; }
+ const OnHostResolutionCallback& host_resolution_callback() const {
+ return host_resolution_callback_;
+ }
+
+ private:
+ friend class base::RefCounted<TransportSocketParams>;
+ ~TransportSocketParams();
+
+ void Initialize(RequestPriority priority, bool disable_resolver_cache);
+
+ HostResolver::RequestInfo destination_;
+ bool ignore_limits_;
+ const OnHostResolutionCallback host_resolution_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(TransportSocketParams);
+};
+
+// TransportConnectJob handles the host resolution necessary for socket creation
+// and the transport (likely TCP) connect. TransportConnectJob also has fallback
+// logic for IPv6 connect() timeouts (which may happen due to networks / routers
+// with broken IPv6 support). Those timeouts take 20s, so rather than make the
+// user wait 20s for the timeout to fire, we use a fallback timer
+// (kIPv6FallbackTimerInMs) and start a connect() to a IPv4 address if the timer
+// fires. Then we race the IPv4 connect() against the IPv6 connect() (which has
+// a headstart) and return the one that completes first to the socket pool.
+class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob {
+ public:
+ TransportConnectJob(const std::string& group_name,
+ const scoped_refptr<TransportSocketParams>& params,
+ base::TimeDelta timeout_duration,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ Delegate* delegate,
+ NetLog* net_log);
+ virtual ~TransportConnectJob();
+
+ // ConnectJob methods.
+ virtual LoadState GetLoadState() const OVERRIDE;
+
+ // Rolls |addrlist| forward until the first IPv4 address, if any.
+ // WARNING: this method should only be used to implement the prefer-IPv4 hack.
+ static void MakeAddressListStartWithIPv4(AddressList* addrlist);
+
+ static const int kIPv6FallbackTimerInMs;
+
+ private:
+ enum State {
+ STATE_RESOLVE_HOST,
+ STATE_RESOLVE_HOST_COMPLETE,
+ STATE_TRANSPORT_CONNECT,
+ STATE_TRANSPORT_CONNECT_COMPLETE,
+ STATE_NONE,
+ };
+
+ void OnIOComplete(int result);
+
+ // Runs the state transition loop.
+ int DoLoop(int result);
+
+ int DoResolveHost();
+ int DoResolveHostComplete(int result);
+ int DoTransportConnect();
+ int DoTransportConnectComplete(int result);
+
+ // Not part of the state machine.
+ void DoIPv6FallbackTransportConnect();
+ void DoIPv6FallbackTransportConnectComplete(int result);
+
+ // Begins the host resolution and the TCP connect. Returns OK on success
+ // and ERR_IO_PENDING if it cannot immediately service the request.
+ // Otherwise, it returns a net error code.
+ virtual int ConnectInternal() OVERRIDE;
+
+ scoped_refptr<TransportSocketParams> params_;
+ ClientSocketFactory* const client_socket_factory_;
+ SingleRequestHostResolver resolver_;
+ AddressList addresses_;
+ State next_state_;
+
+ scoped_ptr<StreamSocket> transport_socket_;
+
+ scoped_ptr<StreamSocket> fallback_transport_socket_;
+ scoped_ptr<AddressList> fallback_addresses_;
+ base::TimeTicks fallback_connect_start_time_;
+ base::OneShotTimer<TransportConnectJob> fallback_timer_;
+
+ DISALLOW_COPY_AND_ASSIGN(TransportConnectJob);
+};
+
+class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
+ public:
+ TransportClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ ClientSocketFactory* client_socket_factory,
+ NetLog* net_log);
+
+ virtual ~TransportClientSocketPool();
+
+ // ClientSocketPool implementation.
+ virtual int RequestSocket(const std::string& group_name,
+ const void* resolve_info,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) OVERRIDE;
+ virtual void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) OVERRIDE;
+ virtual void FlushWithError(int error) OVERRIDE;
+ virtual bool IsStalled() const OVERRIDE;
+ virtual void CloseIdleSockets() OVERRIDE;
+ virtual int IdleSocketCount() const OVERRIDE;
+ virtual int IdleSocketCountInGroup(
+ const std::string& group_name) const OVERRIDE;
+ virtual LoadState GetLoadState(
+ const std::string& group_name,
+ const ClientSocketHandle* handle) const OVERRIDE;
+ virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
+ virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
+ virtual base::DictionaryValue* GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const OVERRIDE;
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+ virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+
+ private:
+ typedef ClientSocketPoolBase<TransportSocketParams> PoolBase;
+
+ class TransportConnectJobFactory
+ : public PoolBase::ConnectJobFactory {
+ public:
+ TransportConnectJobFactory(ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ NetLog* net_log)
+ : client_socket_factory_(client_socket_factory),
+ host_resolver_(host_resolver),
+ net_log_(net_log) {}
+
+ virtual ~TransportConnectJobFactory() {}
+
+ // ClientSocketPoolBase::ConnectJobFactory methods.
+
+ virtual scoped_ptr<ConnectJob> NewConnectJob(
+ const std::string& group_name,
+ const PoolBase::Request& request,
+ ConnectJob::Delegate* delegate) const OVERRIDE;
+
+ virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+
+ private:
+ ClientSocketFactory* const client_socket_factory_;
+ HostResolver* const host_resolver_;
+ NetLog* net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(TransportConnectJobFactory);
+ };
+
+ PoolBase base_;
+
+ DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool);
+};
+
+REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool,
+ TransportSocketParams);
+
+} // namespace net
+
+#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc
new file mode 100644
index 00000000000..c607a38b78a
--- /dev/null
+++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc
@@ -0,0 +1,1355 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/transport_client_socket_pool.h"
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/logging.h"
+#include "base/message_loop/message_loop.h"
+#include "base/threading/platform_thread.h"
+#include "net/base/capturing_net_log.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/load_timing_info_test_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+using internal::ClientSocketPoolBaseHelper;
+
+namespace {
+
+const int kMaxSockets = 32;
+const int kMaxSocketsPerGroup = 6;
+const net::RequestPriority kDefaultPriority = LOW;
+
+// Make sure |handle| sets load times correctly when it has been assigned a
+// reused socket.
+void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
+ LoadTimingInfo load_timing_info;
+ // Only pass true in as |is_reused|, as in general, HttpStream types should
+ // have stricter concepts of reuse than socket pools.
+ EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
+
+ EXPECT_TRUE(load_timing_info.socket_reused);
+ EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+// Make sure |handle| sets load times correctly when it has been assigned a
+// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner
+// of a connection where |is_reused| is false may consider the connection
+// reused.
+void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
+ EXPECT_FALSE(handle.is_reused());
+
+ LoadTimingInfo load_timing_info;
+ EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+ EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
+ CONNECT_TIMING_HAS_DNS_TIMES);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+
+ TestLoadTimingInfoConnectedReused(handle);
+}
+
+void SetIPv4Address(IPEndPoint* address) {
+ IPAddressNumber number;
+ CHECK(ParseIPLiteralToNumber("1.1.1.1", &number));
+ *address = IPEndPoint(number, 80);
+}
+
+void SetIPv6Address(IPEndPoint* address) {
+ IPAddressNumber number;
+ CHECK(ParseIPLiteralToNumber("1:abcd::3:4:ff", &number));
+ *address = IPEndPoint(number, 80);
+}
+
+class MockClientSocket : public StreamSocket {
+ public:
+ MockClientSocket(const AddressList& addrlist, net::NetLog* net_log)
+ : connected_(false),
+ addrlist_(addrlist),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
+ }
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ connected_ = true;
+ return OK;
+ }
+ virtual void Disconnect() OVERRIDE {
+ connected_ = false;
+ }
+ virtual bool IsConnected() const OVERRIDE {
+ return connected_;
+ }
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ return connected_;
+ }
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ if (!connected_)
+ return ERR_SOCKET_NOT_CONNECTED;
+ if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
+ SetIPv4Address(address);
+ else
+ SetIPv6Address(address);
+ return OK;
+ }
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return net_log_;
+ }
+
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+ virtual bool WasEverUsed() const OVERRIDE { return false; }
+ virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ return false;
+ }
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_FAILED;
+ }
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_FAILED;
+ }
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; }
+
+ private:
+ bool connected_;
+ const AddressList addrlist_;
+ BoundNetLog net_log_;
+};
+
+class MockFailingClientSocket : public StreamSocket {
+ public:
+ MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
+ : addrlist_(addrlist),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
+ }
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ return ERR_CONNECTION_FAILED;
+ }
+
+ virtual void Disconnect() OVERRIDE {}
+
+ virtual bool IsConnected() const OVERRIDE {
+ return false;
+ }
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ return false;
+ }
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return net_log_;
+ }
+
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+ virtual bool WasEverUsed() const OVERRIDE { return false; }
+ virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ return false;
+ }
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_FAILED;
+ }
+
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_FAILED;
+ }
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; }
+
+ private:
+ const AddressList addrlist_;
+ BoundNetLog net_log_;
+};
+
+class MockPendingClientSocket : public StreamSocket {
+ public:
+ // |should_connect| indicates whether the socket should successfully complete
+ // or fail.
+ // |should_stall| indicates that this socket should never connect.
+ // |delay_ms| is the delay, in milliseconds, before simulating a connect.
+ MockPendingClientSocket(
+ const AddressList& addrlist,
+ bool should_connect,
+ bool should_stall,
+ base::TimeDelta delay,
+ net::NetLog* net_log)
+ : weak_factory_(this),
+ should_connect_(should_connect),
+ should_stall_(should_stall),
+ delay_(delay),
+ is_connected_(false),
+ addrlist_(addrlist),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
+ }
+
+ // StreamSocket implementation.
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&MockPendingClientSocket::DoCallback,
+ weak_factory_.GetWeakPtr(), callback),
+ delay_);
+ return ERR_IO_PENDING;
+ }
+
+ virtual void Disconnect() OVERRIDE {}
+
+ virtual bool IsConnected() const OVERRIDE {
+ return is_connected_;
+ }
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ return is_connected_;
+ }
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ return ERR_UNEXPECTED;
+ }
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ if (!is_connected_)
+ return ERR_SOCKET_NOT_CONNECTED;
+ if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
+ SetIPv4Address(address);
+ else
+ SetIPv6Address(address);
+ return OK;
+ }
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return net_log_;
+ }
+
+ virtual void SetSubresourceSpeculation() OVERRIDE {}
+ virtual void SetOmniboxSpeculation() OVERRIDE {}
+ virtual bool WasEverUsed() const OVERRIDE { return false; }
+ virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ return false;
+ }
+
+ // Socket implementation.
+ virtual int Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_FAILED;
+ }
+
+ virtual int Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) OVERRIDE {
+ return ERR_FAILED;
+ }
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { return true; }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE { return true; }
+
+ private:
+ void DoCallback(const CompletionCallback& callback) {
+ if (should_stall_)
+ return;
+
+ if (should_connect_) {
+ is_connected_ = true;
+ callback.Run(OK);
+ } else {
+ is_connected_ = false;
+ callback.Run(ERR_CONNECTION_FAILED);
+ }
+ }
+
+ base::WeakPtrFactory<MockPendingClientSocket> weak_factory_;
+ bool should_connect_;
+ bool should_stall_;
+ base::TimeDelta delay_;
+ bool is_connected_;
+ const AddressList addrlist_;
+ BoundNetLog net_log_;
+};
+
+class MockClientSocketFactory : public ClientSocketFactory {
+ public:
+ enum ClientSocketType {
+ MOCK_CLIENT_SOCKET,
+ MOCK_FAILING_CLIENT_SOCKET,
+ MOCK_PENDING_CLIENT_SOCKET,
+ MOCK_PENDING_FAILING_CLIENT_SOCKET,
+ // A delayed socket will pause before connecting through the message loop.
+ MOCK_DELAYED_CLIENT_SOCKET,
+ // A stalled socket that never connects at all.
+ MOCK_STALLED_CLIENT_SOCKET,
+ };
+
+ explicit MockClientSocketFactory(NetLog* net_log)
+ : net_log_(net_log), allocation_count_(0),
+ client_socket_type_(MOCK_CLIENT_SOCKET), client_socket_types_(NULL),
+ client_socket_index_(0), client_socket_index_max_(0),
+ delay_(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
+
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) OVERRIDE {
+ NOTREACHED();
+ return scoped_ptr<DatagramClientSocket>();
+ }
+
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* /* net_log */,
+ const NetLog::Source& /* source */) OVERRIDE {
+ allocation_count_++;
+
+ ClientSocketType type = client_socket_type_;
+ if (client_socket_types_ &&
+ client_socket_index_ < client_socket_index_max_) {
+ type = client_socket_types_[client_socket_index_++];
+ }
+
+ switch (type) {
+ case MOCK_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockClientSocket(addresses, net_log_));
+ case MOCK_FAILING_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockFailingClientSocket(addresses, net_log_));
+ case MOCK_PENDING_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, false, base::TimeDelta(), net_log_));
+ case MOCK_PENDING_FAILING_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, false, false, base::TimeDelta(), net_log_));
+ case MOCK_DELAYED_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, false, delay_, net_log_));
+ case MOCK_STALLED_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockPendingClientSocket(
+ addresses, true, true, base::TimeDelta(), net_log_));
+ default:
+ NOTREACHED();
+ return scoped_ptr<StreamSocket>(
+ new MockClientSocket(addresses, net_log_));
+ }
+ }
+
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) OVERRIDE {
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLClientSocket>();
+ }
+
+ virtual void ClearSSLSessionCache() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+
+ int allocation_count() const { return allocation_count_; }
+
+ // Set the default ClientSocketType.
+ void set_client_socket_type(ClientSocketType type) {
+ client_socket_type_ = type;
+ }
+
+ // Set a list of ClientSocketTypes to be used.
+ void set_client_socket_types(ClientSocketType* type_list, int num_types) {
+ DCHECK_GT(num_types, 0);
+ client_socket_types_ = type_list;
+ client_socket_index_ = 0;
+ client_socket_index_max_ = num_types;
+ }
+
+ void set_delay(base::TimeDelta delay) { delay_ = delay; }
+
+ private:
+ NetLog* net_log_;
+ int allocation_count_;
+ ClientSocketType client_socket_type_;
+ ClientSocketType* client_socket_types_;
+ int client_socket_index_;
+ int client_socket_index_max_;
+ base::TimeDelta delay_;
+};
+
+class TransportClientSocketPoolTest : public testing::Test {
+ protected:
+ TransportClientSocketPoolTest()
+ : connect_backup_jobs_enabled_(
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)),
+ params_(
+ new TransportSocketParams(HostPortPair("www.google.com", 80),
+ kDefaultPriority, false, false,
+ OnHostResolutionCallback())),
+ low_params_(
+ new TransportSocketParams(HostPortPair("www.google.com", 80),
+ LOW, false, false,
+ OnHostResolutionCallback())),
+ histograms_(new ClientSocketPoolHistograms("TCPUnitTest")),
+ host_resolver_(new MockHostResolver),
+ client_socket_factory_(&net_log_),
+ pool_(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL) {
+ }
+
+ virtual ~TransportClientSocketPoolTest() {
+ internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(
+ connect_backup_jobs_enabled_);
+ }
+
+ int StartRequest(const std::string& group_name, RequestPriority priority) {
+ scoped_refptr<TransportSocketParams> params(new TransportSocketParams(
+ HostPortPair("www.google.com", 80), MEDIUM, false, false,
+ OnHostResolutionCallback()));
+ return test_base_.StartRequestUsingPool(
+ &pool_, group_name, priority, params);
+ }
+
+ int GetOrderOfRequest(size_t index) {
+ return test_base_.GetOrderOfRequest(index);
+ }
+
+ bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) {
+ return test_base_.ReleaseOneConnection(keep_alive);
+ }
+
+ void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) {
+ test_base_.ReleaseAllConnections(keep_alive);
+ }
+
+ ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }
+ size_t completion_count() const { return test_base_.completion_count(); }
+
+ bool connect_backup_jobs_enabled_;
+ CapturingNetLog net_log_;
+ scoped_refptr<TransportSocketParams> params_;
+ scoped_refptr<TransportSocketParams> low_params_;
+ scoped_ptr<ClientSocketPoolHistograms> histograms_;
+ scoped_ptr<MockHostResolver> host_resolver_;
+ MockClientSocketFactory client_socket_factory_;
+ TransportClientSocketPool pool_;
+ ClientSocketPoolTest test_base_;
+};
+
+TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) {
+ IPAddressNumber ip_number;
+ ASSERT_TRUE(ParseIPLiteralToNumber("192.168.1.1", &ip_number));
+ IPEndPoint addrlist_v4_1(ip_number, 80);
+ ASSERT_TRUE(ParseIPLiteralToNumber("192.168.1.2", &ip_number));
+ IPEndPoint addrlist_v4_2(ip_number, 80);
+ ASSERT_TRUE(ParseIPLiteralToNumber("2001:4860:b006::64", &ip_number));
+ IPEndPoint addrlist_v6_1(ip_number, 80);
+ ASSERT_TRUE(ParseIPLiteralToNumber("2001:4860:b006::66", &ip_number));
+ IPEndPoint addrlist_v6_2(ip_number, 80);
+
+ AddressList addrlist;
+
+ // Test 1: IPv4 only. Expect no change.
+ addrlist.clear();
+ addrlist.push_back(addrlist_v4_1);
+ addrlist.push_back(addrlist_v4_2);
+ TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist);
+ ASSERT_EQ(2u, addrlist.size());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[1].GetFamily());
+
+ // Test 2: IPv6 only. Expect no change.
+ addrlist.clear();
+ addrlist.push_back(addrlist_v6_1);
+ addrlist.push_back(addrlist_v6_2);
+ TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist);
+ ASSERT_EQ(2u, addrlist.size());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[0].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[1].GetFamily());
+
+ // Test 3: IPv4 then IPv6. Expect no change.
+ addrlist.clear();
+ addrlist.push_back(addrlist_v4_1);
+ addrlist.push_back(addrlist_v4_2);
+ addrlist.push_back(addrlist_v6_1);
+ addrlist.push_back(addrlist_v6_2);
+ TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist);
+ ASSERT_EQ(4u, addrlist.size());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[1].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[2].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[3].GetFamily());
+
+ // Test 4: IPv6, IPv4, IPv6, IPv4. Expect first IPv6 moved to the end.
+ addrlist.clear();
+ addrlist.push_back(addrlist_v6_1);
+ addrlist.push_back(addrlist_v4_1);
+ addrlist.push_back(addrlist_v6_2);
+ addrlist.push_back(addrlist_v4_2);
+ TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist);
+ ASSERT_EQ(4u, addrlist.size());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[1].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[2].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[3].GetFamily());
+
+ // Test 5: IPv6, IPv6, IPv4, IPv4. Expect first two IPv6's moved to the end.
+ addrlist.clear();
+ addrlist.push_back(addrlist_v6_1);
+ addrlist.push_back(addrlist_v6_2);
+ addrlist.push_back(addrlist_v4_1);
+ addrlist.push_back(addrlist_v4_2);
+ TransportConnectJob::MakeAddressListStartWithIPv4(&addrlist);
+ ASSERT_EQ(4u, addrlist.size());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[0].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, addrlist[1].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[2].GetFamily());
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, addrlist[3].GetFamily());
+}
+
+TEST_F(TransportClientSocketPoolTest, Basic) {
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoConnectedNotReused(handle);
+}
+
+TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) {
+ host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name");
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ HostPortPair host_port_pair("unresolvable.host.name", 80);
+ scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
+ host_port_pair, kDefaultPriority, false, false,
+ OnHostResolutionCallback()));
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", dest, kDefaultPriority, callback.callback(),
+ &pool_, BoundNetLog()));
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult());
+}
+
+TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) {
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET);
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", params_, kDefaultPriority, callback.callback(),
+ &pool_, BoundNetLog()));
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+
+ // Make the host resolutions complete synchronously this time.
+ host_resolver_->set_synchronous_mode(true);
+ EXPECT_EQ(ERR_CONNECTION_FAILED,
+ handle.Init("a", params_, kDefaultPriority, callback.callback(),
+ &pool_, BoundNetLog()));
+}
+
+TEST_F(TransportClientSocketPoolTest, PendingRequests) {
+ // First request finishes asynchronously.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, (*requests())[0]->WaitForResult());
+
+ // Make all subsequent host resolutions complete synchronously.
+ host_resolver_->set_synchronous_mode(true);
+
+ // Rest of them finish synchronously, until we reach the per-group limit.
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+
+ // The rest are pending since we've used all active sockets.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+
+ ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE);
+
+ EXPECT_EQ(kMaxSocketsPerGroup, client_socket_factory_.allocation_count());
+
+ // One initial asynchronous request and then 10 pending requests.
+ EXPECT_EQ(11U, completion_count());
+
+ // First part of requests, all with the same priority, finishes in FIFO order.
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+ EXPECT_EQ(6, GetOrderOfRequest(6));
+
+ // Make sure that rest of the requests complete in the order of priority.
+ EXPECT_EQ(7, GetOrderOfRequest(7));
+ EXPECT_EQ(14, GetOrderOfRequest(8));
+ EXPECT_EQ(15, GetOrderOfRequest(9));
+ EXPECT_EQ(10, GetOrderOfRequest(10));
+ EXPECT_EQ(13, GetOrderOfRequest(11));
+ EXPECT_EQ(8, GetOrderOfRequest(12));
+ EXPECT_EQ(16, GetOrderOfRequest(13));
+ EXPECT_EQ(11, GetOrderOfRequest(14));
+ EXPECT_EQ(12, GetOrderOfRequest(15));
+ EXPECT_EQ(9, GetOrderOfRequest(16));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(17));
+}
+
+TEST_F(TransportClientSocketPoolTest, PendingRequests_NoKeepAlive) {
+ // First request finishes asynchronously.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, (*requests())[0]->WaitForResult());
+
+ // Make all subsequent host resolutions complete synchronously.
+ host_resolver_->set_synchronous_mode(true);
+
+ // Rest of them finish synchronously, until we reach the per-group limit.
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+
+ // The rest are pending since we've used all active sockets.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ // The pending requests should finish successfully.
+ EXPECT_EQ(OK, (*requests())[6]->WaitForResult());
+ EXPECT_EQ(OK, (*requests())[7]->WaitForResult());
+ EXPECT_EQ(OK, (*requests())[8]->WaitForResult());
+ EXPECT_EQ(OK, (*requests())[9]->WaitForResult());
+ EXPECT_EQ(OK, (*requests())[10]->WaitForResult());
+
+ EXPECT_EQ(static_cast<int>(requests()->size()),
+ client_socket_factory_.allocation_count());
+
+ // First asynchronous request, and then last 5 pending requests.
+ EXPECT_EQ(6U, completion_count());
+}
+
+// This test will start up a RequestSocket() and then immediately Cancel() it.
+// The pending host resolution will eventually complete, and destroy the
+// ClientSocketPool which will crash if the group was not cleared properly.
+TEST_F(TransportClientSocketPoolTest, CancelRequestClearGroup) {
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", params_, kDefaultPriority, callback.callback(),
+ &pool_, BoundNetLog()));
+ handle.Reset();
+}
+
+TEST_F(TransportClientSocketPoolTest, TwoRequestsCancelOne) {
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", params_, kDefaultPriority, callback.callback(),
+ &pool_, BoundNetLog()));
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a", params_, kDefaultPriority, callback2.callback(),
+ &pool_, BoundNetLog()));
+
+ handle.Reset();
+
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ handle2.Reset();
+}
+
+TEST_F(TransportClientSocketPoolTest, ConnectCancelConnect) {
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", params_, kDefaultPriority, callback.callback(),
+ &pool_, BoundNetLog()));
+
+ handle.Reset();
+
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a", params_, kDefaultPriority, callback2.callback(),
+ &pool_, BoundNetLog()));
+
+ host_resolver_->set_synchronous_mode(true);
+ // At this point, handle has two ConnectingSockets out for it. Due to the
+ // setting the mock resolver into synchronous mode, the host resolution for
+ // both will return in the same loop of the MessageLoop. The client socket
+ // is a pending socket, so the Connect() will asynchronously complete on the
+ // next loop of the MessageLoop. That means that the first
+ // ConnectingSocket will enter OnIOComplete, and then the second one will.
+ // If the first one is not cancelled, it will advance the load state, and
+ // then the second one will crash.
+
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ EXPECT_FALSE(callback.have_result());
+
+ handle.Reset();
+}
+
+TEST_F(TransportClientSocketPoolTest, CancelRequest) {
+ // First request finishes asynchronously.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, (*requests())[0]->WaitForResult());
+
+ // Make all subsequent host resolutions complete synchronously.
+ host_resolver_->set_synchronous_mode(true);
+
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+
+ // Reached per-group limit, queue up requests.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", MEDIUM));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", HIGHEST));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOW));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", LOWEST));
+
+ // Cancel a request.
+ size_t index_to_cancel = kMaxSocketsPerGroup + 2;
+ EXPECT_FALSE((*requests())[index_to_cancel]->handle()->is_initialized());
+ (*requests())[index_to_cancel]->handle()->Reset();
+
+ ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE);
+
+ EXPECT_EQ(kMaxSocketsPerGroup,
+ client_socket_factory_.allocation_count());
+ EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+ EXPECT_EQ(6, GetOrderOfRequest(6));
+ EXPECT_EQ(14, GetOrderOfRequest(7));
+ EXPECT_EQ(7, GetOrderOfRequest(8));
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound,
+ GetOrderOfRequest(9)); // Canceled request.
+ EXPECT_EQ(9, GetOrderOfRequest(10));
+ EXPECT_EQ(10, GetOrderOfRequest(11));
+ EXPECT_EQ(11, GetOrderOfRequest(12));
+ EXPECT_EQ(8, GetOrderOfRequest(13));
+ EXPECT_EQ(12, GetOrderOfRequest(14));
+ EXPECT_EQ(13, GetOrderOfRequest(15));
+ EXPECT_EQ(15, GetOrderOfRequest(16));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(17));
+}
+
+class RequestSocketCallback : public TestCompletionCallbackBase {
+ public:
+ RequestSocketCallback(ClientSocketHandle* handle,
+ TransportClientSocketPool* pool)
+ : handle_(handle),
+ pool_(pool),
+ within_callback_(false),
+ callback_(base::Bind(&RequestSocketCallback::OnComplete,
+ base::Unretained(this))) {
+ }
+
+ virtual ~RequestSocketCallback() {}
+
+ const CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result) {
+ SetResult(result);
+ ASSERT_EQ(OK, result);
+
+ if (!within_callback_) {
+ // Don't allow reuse of the socket. Disconnect it and then release it and
+ // run through the MessageLoop once to get it completely released.
+ handle_->socket()->Disconnect();
+ handle_->Reset();
+ {
+ base::MessageLoop::ScopedNestableTaskAllower allow(
+ base::MessageLoop::current());
+ base::MessageLoop::current()->RunUntilIdle();
+ }
+ within_callback_ = true;
+ scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
+ HostPortPair("www.google.com", 80), LOWEST, false, false,
+ OnHostResolutionCallback()));
+ int rv = handle_->Init("a", dest, LOWEST, callback(), pool_,
+ BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ }
+ }
+
+ ClientSocketHandle* const handle_;
+ TransportClientSocketPool* const pool_;
+ bool within_callback_;
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(RequestSocketCallback);
+};
+
+TEST_F(TransportClientSocketPoolTest, RequestTwice) {
+ ClientSocketHandle handle;
+ RequestSocketCallback callback(&handle, &pool_);
+ scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
+ HostPortPair("www.google.com", 80), LOWEST, false, false,
+ OnHostResolutionCallback()));
+ int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_,
+ BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ // The callback is going to request "www.google.com". We want it to complete
+ // synchronously this time.
+ host_resolver_->set_synchronous_mode(true);
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ handle.Reset();
+}
+
+// Make sure that pending requests get serviced after active requests get
+// cancelled.
+TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) {
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
+
+ // Queue up all the requests
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ // Now, kMaxSocketsPerGroup requests should be active. Let's cancel them.
+ ASSERT_LE(kMaxSocketsPerGroup, static_cast<int>(requests()->size()));
+ for (int i = 0; i < kMaxSocketsPerGroup; i++)
+ (*requests())[i]->handle()->Reset();
+
+ // Let's wait for the rest to complete now.
+ for (size_t i = kMaxSocketsPerGroup; i < requests()->size(); ++i) {
+ EXPECT_EQ(OK, (*requests())[i]->WaitForResult());
+ (*requests())[i]->handle()->Reset();
+ }
+
+ EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count());
+}
+
+// Make sure that pending requests get serviced after active requests fail.
+TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) {
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET);
+
+ const int kNumRequests = 2 * kMaxSocketsPerGroup + 1;
+ ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang.
+
+ // Queue up all the requests
+ for (int i = 0; i < kNumRequests; i++)
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ for (int i = 0; i < kNumRequests; i++)
+ EXPECT_EQ(ERR_CONNECTION_FAILED, (*requests())[i]->WaitForResult());
+}
+
+TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) {
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoConnectedNotReused(handle);
+
+ handle.Reset();
+ // Need to run all pending to release the socket back to the pool.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Now we should have 1 idle socket.
+ EXPECT_EQ(1, pool_.IdleSocketCount());
+
+ rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+ TestLoadTimingInfoConnectedReused(handle);
+}
+
+TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) {
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+
+ handle.Reset();
+
+ // Need to run all pending to release the socket back to the pool.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Now we should have 1 idle socket.
+ EXPECT_EQ(1, pool_.IdleSocketCount());
+
+ // After an IP address change, we should have 0 idle sockets.
+ NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests();
+ base::MessageLoop::current()->RunUntilIdle(); // Notification happens async.
+
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+}
+
+TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) {
+ // Case 1 tests the first socket stalling, and the backup connecting.
+ MockClientSocketFactory::ClientSocketType case1_types[] = {
+ // The first socket will not connect.
+ MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ // The second socket will connect more quickly.
+ MockClientSocketFactory::MOCK_CLIENT_SOCKET
+ };
+
+ // Case 2 tests the first socket being slow, so that we start the
+ // second connect, but the second connect stalls, and we still
+ // complete the first.
+ MockClientSocketFactory::ClientSocketType case2_types[] = {
+ // The first socket will connect, although delayed.
+ MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ // The second socket will not connect.
+ MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET
+ };
+
+ MockClientSocketFactory::ClientSocketType* cases[2] = {
+ case1_types,
+ case2_types
+ };
+
+ for (size_t index = 0; index < arraysize(cases); ++index) {
+ client_socket_factory_.set_client_socket_types(cases[index], 2);
+
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ // Create the first socket, set the timer.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Wait for the backup socket timer to fire.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs + 50));
+
+ // Let the appropriate socket connect.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+
+ // One socket is stalled, the other is active.
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+ handle.Reset();
+
+ // Close all pending connect jobs and existing sockets.
+ pool_.FlushWithError(ERR_NETWORK_CHANGED);
+ }
+}
+
+// Test the case where a socket took long enough to start the creation
+// of the backup socket, but then we cancelled the request after that.
+TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) {
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET);
+
+ enum { CANCEL_BEFORE_WAIT, CANCEL_AFTER_WAIT };
+
+ for (int index = CANCEL_BEFORE_WAIT; index < CANCEL_AFTER_WAIT; ++index) {
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("c", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ // Create the first socket, set the timer.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ if (index == CANCEL_AFTER_WAIT) {
+ // Wait for the backup socket timer to fire.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs));
+ }
+
+ // Let the appropriate socket connect.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ handle.Reset();
+
+ EXPECT_FALSE(callback.have_result());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ // One socket is stalled, the other is active.
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+ }
+}
+
+// Test the case where a socket took long enough to start the creation
+// of the backup socket and never completes, and then the backup
+// connection fails.
+TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) {
+ MockClientSocketFactory::ClientSocketType case_types[] = {
+ // The first socket will not connect.
+ MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ // The second socket will fail immediately.
+ MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET
+ };
+
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ // Create the first socket, set the timer.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Wait for the backup socket timer to fire.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs));
+
+ // Let the second connect be synchronous. Otherwise, the emulated
+ // host resolution takes an extra trip through the message loop.
+ host_resolver_->set_synchronous_mode(true);
+
+ // Let the appropriate socket connect.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+ handle.Reset();
+
+ // Reset for the next case.
+ host_resolver_->set_synchronous_mode(false);
+}
+
+// Test the case where a socket took long enough to start the creation
+// of the backup socket and eventually completes, but the backup socket
+// fails.
+TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) {
+ MockClientSocketFactory::ClientSocketType case_types[] = {
+ // The first socket will connect, although delayed.
+ MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ // The second socket will not connect.
+ MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET
+ };
+
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+ client_socket_factory_.set_delay(base::TimeDelta::FromSeconds(5));
+
+ EXPECT_EQ(0, pool_.IdleSocketCount());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ // Create the first socket, set the timer.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ // Wait for the backup socket timer to fire.
+ base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs));
+
+ // Let the second connect be synchronous. Otherwise, the emulated
+ // host resolution takes an extra trip through the message loop.
+ host_resolver_->set_synchronous_mode(true);
+
+ // Let the appropriate socket connect.
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+ handle.Reset();
+
+ // Reset for the next case.
+ host_resolver_->set_synchronous_mode(false);
+}
+
+// Test the case of the IPv6 address stalling, and falling back to the IPv4
+// socket which finishes first.
+TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockClientSocketFactory::ClientSocketType case_types[] = {
+ // This is the IPv6 socket.
+ MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET
+ };
+
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+
+ // Resolve an AddressList with a IPv6 address first and then a IPv4 address.
+ host_resolver_->rules()
+ ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv4AddressSize, endpoint.address().size());
+ EXPECT_EQ(2, client_socket_factory_.allocation_count());
+}
+
+// Test the case of the IPv6 address being slow, thus falling back to trying to
+// connect to the IPv4 address, but having the connect to the IPv6 address
+// finish first.
+TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockClientSocketFactory::ClientSocketType case_types[] = {
+ // This is the IPv6 socket.
+ MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET
+ };
+
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+ client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds(
+ TransportConnectJob::kIPv6FallbackTimerInMs + 50));
+
+ // Resolve an AddressList with a IPv6 address first and then a IPv4 address.
+ host_resolver_->rules()
+ ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv6AddressSize, endpoint.address().size());
+ EXPECT_EQ(2, client_socket_factory_.allocation_count());
+}
+
+TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+
+ // Resolve an AddressList with only IPv6 addresses.
+ host_resolver_->rules()
+ ->AddIPLiteralRule("*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv6AddressSize, endpoint.address().size());
+ EXPECT_EQ(1, client_socket_factory_.allocation_count());
+}
+
+TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ client_socket_factory_.set_client_socket_type(
+ MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+
+ // Resolve an AddressList with only IPv4 addresses.
+ host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv4AddressSize, endpoint.address().size());
+ EXPECT_EQ(1, client_socket_factory_.allocation_count());
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc
new file mode 100644
index 00000000000..5c5a303b82f
--- /dev/null
+++ b/chromium/net/socket/transport_client_socket_unittest.cc
@@ -0,0 +1,449 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/tcp_client_socket.h"
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_log_unittest.h"
+#include "net/base/test_completion_callback.h"
+#include "net/base/winsock_init.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/tcp_listen_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+namespace net {
+
+namespace {
+
+const char kServerReply[] = "HTTP/1.1 404 Not Found";
+
+enum ClientSocketTestTypes {
+ TCP,
+ SCTP
+};
+
+} // namespace
+
+class TransportClientSocketTest
+ : public StreamListenSocket::Delegate,
+ public ::testing::TestWithParam<ClientSocketTestTypes> {
+ public:
+ TransportClientSocketTest()
+ : listen_port_(0),
+ socket_factory_(ClientSocketFactory::GetDefaultFactory()),
+ close_server_socket_on_next_send_(false) {
+ }
+
+ virtual ~TransportClientSocketTest() {
+ }
+
+ // Implement StreamListenSocket::Delegate methods
+ virtual void DidAccept(StreamListenSocket* server,
+ StreamListenSocket* connection) OVERRIDE {
+ connected_sock_ = reinterpret_cast<TCPListenSocket*>(connection);
+ }
+ virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE {
+ // TODO(dkegel): this might not be long enough to tickle some bugs.
+ connected_sock_->Send(kServerReply, arraysize(kServerReply) - 1,
+ false /* Don't append line feed */);
+ if (close_server_socket_on_next_send_)
+ CloseServerSocket();
+ }
+ virtual void DidClose(StreamListenSocket* sock) OVERRIDE {}
+
+ // Testcase hooks
+ virtual void SetUp();
+
+ void CloseServerSocket() {
+ // delete the connected_sock_, which will close it.
+ connected_sock_ = NULL;
+ }
+
+ void PauseServerReads() {
+ connected_sock_->PauseReads();
+ }
+
+ void ResumeServerReads() {
+ connected_sock_->ResumeReads();
+ }
+
+ int DrainClientSocket(IOBuffer* buf,
+ uint32 buf_len,
+ uint32 bytes_to_read,
+ TestCompletionCallback* callback);
+
+ void SendClientRequest();
+
+ void set_close_server_socket_on_next_send(bool close) {
+ close_server_socket_on_next_send_ = close;
+ }
+
+ protected:
+ int listen_port_;
+ CapturingNetLog net_log_;
+ ClientSocketFactory* const socket_factory_;
+ scoped_ptr<StreamSocket> sock_;
+
+ private:
+ scoped_refptr<TCPListenSocket> listen_sock_;
+ scoped_refptr<TCPListenSocket> connected_sock_;
+ bool close_server_socket_on_next_send_;
+};
+
+void TransportClientSocketTest::SetUp() {
+ ::testing::TestWithParam<ClientSocketTestTypes>::SetUp();
+
+ // Find a free port to listen on
+ scoped_refptr<TCPListenSocket> sock;
+ int port;
+ // Range of ports to listen on. Shouldn't need to try many.
+ const int kMinPort = 10100;
+ const int kMaxPort = 10200;
+#if defined(OS_WIN)
+ EnsureWinsockInit();
+#endif
+ for (port = kMinPort; port < kMaxPort; port++) {
+ sock = TCPListenSocket::CreateAndListen("127.0.0.1", port, this);
+ if (sock.get())
+ break;
+ }
+ ASSERT_TRUE(sock.get() != NULL);
+ listen_sock_ = sock;
+ listen_port_ = port;
+
+ AddressList addr;
+ // MockHostResolver resolves everything to 127.0.0.1.
+ scoped_ptr<HostResolver> resolver(new MockHostResolver());
+ HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_));
+ TestCompletionCallback callback;
+ int rv = resolver->Resolve(info, &addr, callback.callback(), NULL,
+ BoundNetLog());
+ CHECK_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ CHECK_EQ(rv, OK);
+ sock_ =
+ socket_factory_->CreateTransportClientSocket(addr,
+ &net_log_,
+ NetLog::Source());
+}
+
+int TransportClientSocketTest::DrainClientSocket(
+ IOBuffer* buf, uint32 buf_len,
+ uint32 bytes_to_read, TestCompletionCallback* callback) {
+ int rv = OK;
+ uint32 bytes_read = 0;
+
+ while (bytes_read < bytes_to_read) {
+ rv = sock_->Read(buf, buf_len, callback->callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback->WaitForResult();
+
+ EXPECT_GE(rv, 0);
+ bytes_read += rv;
+ }
+
+ return static_cast<int>(bytes_read);
+}
+
+void TransportClientSocketTest::SendClientRequest() {
+ const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
+ scoped_refptr<IOBuffer> request_buffer(
+ new IOBuffer(arraysize(request_text) - 1));
+ TestCompletionCallback callback;
+ int rv;
+
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
+ rv = sock_->Write(
+ request_buffer.get(), arraysize(request_text) - 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, static_cast<int>(arraysize(request_text) - 1));
+}
+
+// TODO(leighton): Add SCTP to this list when it is ready.
+INSTANTIATE_TEST_CASE_P(StreamSocket,
+ TransportClientSocketTest,
+ ::testing::Values(TCP));
+
+TEST_P(TransportClientSocketTest, Connect) {
+ TestCompletionCallback callback;
+ EXPECT_FALSE(sock_->IsConnected());
+
+ int rv = sock_->Connect(callback.callback());
+
+ net::CapturingNetLog::CapturedEntryList net_log_entries;
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(net::LogContainsBeginEvent(
+ net_log_entries, 0, net::NetLog::TYPE_SOCKET_ALIVE));
+ EXPECT_TRUE(net::LogContainsBeginEvent(
+ net_log_entries, 1, net::NetLog::TYPE_TCP_CONNECT));
+ if (rv != OK) {
+ ASSERT_EQ(rv, ERR_IO_PENDING);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, OK);
+ }
+
+ EXPECT_TRUE(sock_->IsConnected());
+ net_log_.GetEntries(&net_log_entries);
+ EXPECT_TRUE(net::LogContainsEndEvent(
+ net_log_entries, -1, net::NetLog::TYPE_TCP_CONNECT));
+
+ sock_->Disconnect();
+ EXPECT_FALSE(sock_->IsConnected());
+}
+
+TEST_P(TransportClientSocketTest, IsConnected) {
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ TestCompletionCallback callback;
+ uint32 bytes_read;
+
+ EXPECT_FALSE(sock_->IsConnected());
+ EXPECT_FALSE(sock_->IsConnectedAndIdle());
+ int rv = sock_->Connect(callback.callback());
+ if (rv != OK) {
+ ASSERT_EQ(rv, ERR_IO_PENDING);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, OK);
+ }
+ EXPECT_TRUE(sock_->IsConnected());
+ EXPECT_TRUE(sock_->IsConnectedAndIdle());
+
+ // Send the request and wait for the server to respond.
+ SendClientRequest();
+
+ // Drain a single byte so we know we've received some data.
+ bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
+ ASSERT_EQ(bytes_read, 1u);
+
+ // Socket should be considered connected, but not idle, due to
+ // pending data.
+ EXPECT_TRUE(sock_->IsConnected());
+ EXPECT_FALSE(sock_->IsConnectedAndIdle());
+
+ bytes_read = DrainClientSocket(
+ buf.get(), 4096, arraysize(kServerReply) - 2, &callback);
+ ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2);
+
+ // After draining the data, the socket should be back to connected
+ // and idle.
+ EXPECT_TRUE(sock_->IsConnected());
+ EXPECT_TRUE(sock_->IsConnectedAndIdle());
+
+ // This time close the server socket immediately after the server response.
+ set_close_server_socket_on_next_send(true);
+ SendClientRequest();
+
+ bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
+ ASSERT_EQ(bytes_read, 1u);
+
+ // As above because of data.
+ EXPECT_TRUE(sock_->IsConnected());
+ EXPECT_FALSE(sock_->IsConnectedAndIdle());
+
+ bytes_read = DrainClientSocket(
+ buf.get(), 4096, arraysize(kServerReply) - 2, &callback);
+ ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2);
+
+ // Once the data is drained, the socket should now be seen as not
+ // connected.
+ if (sock_->IsConnected()) {
+ // In the unlikely event that the server's connection closure is not
+ // processed in time, wait for the connection to be closed.
+ rv = sock_->Read(buf.get(), 4096, callback.callback());
+ EXPECT_EQ(0, callback.GetResult(rv));
+ EXPECT_FALSE(sock_->IsConnected());
+ }
+ EXPECT_FALSE(sock_->IsConnectedAndIdle());
+}
+
+TEST_P(TransportClientSocketTest, Read) {
+ TestCompletionCallback callback;
+ int rv = sock_->Connect(callback.callback());
+ if (rv != OK) {
+ ASSERT_EQ(rv, ERR_IO_PENDING);
+
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, OK);
+ }
+ SendClientRequest();
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ uint32 bytes_read = DrainClientSocket(
+ buf.get(), 4096, arraysize(kServerReply) - 1, &callback);
+ ASSERT_EQ(bytes_read, arraysize(kServerReply) - 1);
+
+ // All data has been read now. Read once more to force an ERR_IO_PENDING, and
+ // then close the server socket, and note the close.
+
+ rv = sock_->Read(buf.get(), 4096, callback.callback());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ CloseServerSocket();
+ EXPECT_EQ(0, callback.WaitForResult());
+}
+
+TEST_P(TransportClientSocketTest, Read_SmallChunks) {
+ TestCompletionCallback callback;
+ int rv = sock_->Connect(callback.callback());
+ if (rv != OK) {
+ ASSERT_EQ(rv, ERR_IO_PENDING);
+
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, OK);
+ }
+ SendClientRequest();
+
+ scoped_refptr<IOBuffer> buf(new IOBuffer(1));
+ uint32 bytes_read = 0;
+ while (bytes_read < arraysize(kServerReply) - 1) {
+ rv = sock_->Read(buf.get(), 1, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ ASSERT_EQ(1, rv);
+ bytes_read += rv;
+ }
+
+ // All data has been read now. Read once more to force an ERR_IO_PENDING, and
+ // then close the server socket, and note the close.
+
+ rv = sock_->Read(buf.get(), 1, callback.callback());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+ CloseServerSocket();
+ EXPECT_EQ(0, callback.WaitForResult());
+}
+
+TEST_P(TransportClientSocketTest, Read_Interrupted) {
+ TestCompletionCallback callback;
+ int rv = sock_->Connect(callback.callback());
+ if (rv != OK) {
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, OK);
+ }
+ SendClientRequest();
+
+ // Do a partial read and then exit. This test should not crash!
+ scoped_refptr<IOBuffer> buf(new IOBuffer(16));
+ rv = sock_->Read(buf.get(), 16, callback.callback());
+ EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+
+ EXPECT_NE(0, rv);
+}
+
+TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) {
+ TestCompletionCallback callback;
+ int rv = sock_->Connect(callback.callback());
+ if (rv != OK) {
+ ASSERT_EQ(rv, ERR_IO_PENDING);
+
+ rv = callback.WaitForResult();
+ EXPECT_EQ(rv, OK);
+ }
+
+ // Read first. There's no data, so it should return ERR_IO_PENDING.
+ const int kBufLen = 4096;
+ scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen));
+ rv = sock_->Read(buf.get(), kBufLen, callback.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ PauseServerReads();
+ const int kWriteBufLen = 64 * 1024;
+ scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen));
+ char* request_data = request_buffer->data();
+ memset(request_data, 'A', kWriteBufLen);
+ TestCompletionCallback write_callback;
+
+ while (true) {
+ rv = sock_->Write(
+ request_buffer.get(), kWriteBufLen, write_callback.callback());
+ ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING) {
+ ResumeServerReads();
+ rv = write_callback.WaitForResult();
+ break;
+ }
+ }
+
+ // At this point, both read and write have returned ERR_IO_PENDING, and the
+ // write callback has executed. We wait for the read callback to run now to
+ // make sure that the socket can handle full duplex communications.
+
+ rv = callback.WaitForResult();
+ EXPECT_GE(rv, 0);
+}
+
+TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) {
+ TestCompletionCallback callback;
+ int rv = sock_->Connect(callback.callback());
+ if (rv != OK) {
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ }
+
+ PauseServerReads();
+ const int kWriteBufLen = 64 * 1024;
+ scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen));
+ char* request_data = request_buffer->data();
+ memset(request_data, 'A', kWriteBufLen);
+ TestCompletionCallback write_callback;
+
+ while (true) {
+ rv = sock_->Write(
+ request_buffer.get(), kWriteBufLen, write_callback.callback());
+ ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+
+ if (rv == ERR_IO_PENDING)
+ break;
+ }
+
+ // Now we have the Write() blocked on ERR_IO_PENDING. It's time to force the
+ // Read() to block on ERR_IO_PENDING too.
+
+ const int kBufLen = 4096;
+ scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen));
+ while (true) {
+ rv = sock_->Read(buf.get(), kBufLen, callback.callback());
+ ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
+ if (rv == ERR_IO_PENDING)
+ break;
+ }
+
+ // At this point, both read and write have returned ERR_IO_PENDING. Now we
+ // run the write and read callbacks to make sure they can handle full duplex
+ // communications.
+
+ ResumeServerReads();
+ rv = write_callback.WaitForResult();
+ EXPECT_GE(rv, 0);
+
+ // It's possible the read is blocked because it's already read all the data.
+ // Close the server socket, so there will at least be a 0-byte read.
+ CloseServerSocket();
+
+ rv = callback.WaitForResult();
+ EXPECT_GE(rv, 0);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc
new file mode 100644
index 00000000000..5b6b2498245
--- /dev/null
+++ b/chromium/net/socket/unix_domain_socket_posix.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/unix_domain_socket_posix.h"
+
+#include <cstring>
+#include <string>
+
+#include <errno.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/posix/eintr_wrapper.h"
+#include "base/threading/platform_thread.h"
+#include "build/build_config.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+namespace {
+
+bool NoAuthenticationCallback(uid_t, gid_t) {
+ return true;
+}
+
+bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ struct ucred user_cred;
+ socklen_t len = sizeof(user_cred);
+ if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1)
+ return false;
+ *user_id = user_cred.uid;
+ *group_id = user_cred.gid;
+#else
+ if (getpeereid(socket, user_id, group_id) == -1)
+ return false;
+#endif
+ return true;
+}
+
+} // namespace
+
+// static
+UnixDomainSocket::AuthCallback NoAuthentication() {
+ return base::Bind(NoAuthenticationCallback);
+}
+
+// static
+UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace) {
+ SocketDescriptor s = CreateAndBind(path, use_abstract_namespace);
+ if (s == kInvalidSocket && !fallback_path.empty())
+ s = CreateAndBind(fallback_path, use_abstract_namespace);
+ if (s == kInvalidSocket)
+ return NULL;
+ UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback);
+ sock->Listen();
+ return sock;
+}
+
+// static
+scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback) {
+ return CreateAndListenInternal(path, "", del, auth_callback, false);
+}
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+// static
+scoped_refptr<UnixDomainSocket>
+UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback) {
+ return make_scoped_refptr(
+ CreateAndListenInternal(path, fallback_path, del, auth_callback, true));
+}
+#endif
+
+UnixDomainSocket::UnixDomainSocket(
+ SocketDescriptor s,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback)
+ : StreamListenSocket(s, del),
+ auth_callback_(auth_callback) {}
+
+UnixDomainSocket::~UnixDomainSocket() {}
+
+// static
+SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
+ bool use_abstract_namespace) {
+ sockaddr_un addr;
+ static const size_t kPathMax = sizeof(addr.sun_path);
+ if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
+ return kInvalidSocket;
+ const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0);
+ if (s == kInvalidSocket)
+ return kInvalidSocket;
+ memset(&addr, 0, sizeof(addr));
+ addr.sun_family = AF_UNIX;
+ socklen_t addr_len;
+ if (use_abstract_namespace) {
+ // Convert the path given into abstract socket name. It must start with
+ // the '\0' character, so we are adding it. |addr_len| must specify the
+ // length of the structure exactly, as potentially the socket name may
+ // have '\0' characters embedded (although we don't support this).
+ // Note that addr.sun_path is already zero initialized.
+ memcpy(addr.sun_path + 1, path.c_str(), path.size());
+ addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
+ } else {
+ memcpy(addr.sun_path, path.c_str(), path.size());
+ addr_len = sizeof(sockaddr_un);
+ }
+ if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) {
+ LOG(ERROR) << "Could not bind unix domain socket to " << path;
+ if (use_abstract_namespace)
+ LOG(ERROR) << " (with abstract namespace enabled)";
+ if (HANDLE_EINTR(close(s)) < 0)
+ LOG(ERROR) << "close() error";
+ return kInvalidSocket;
+ }
+ return s;
+}
+
+void UnixDomainSocket::Accept() {
+ SocketDescriptor conn = StreamListenSocket::AcceptSocket();
+ if (conn == kInvalidSocket)
+ return;
+ uid_t user_id;
+ gid_t group_id;
+ if (!GetPeerIds(conn, &user_id, &group_id) ||
+ !auth_callback_.Run(user_id, group_id)) {
+ if (HANDLE_EINTR(close(conn)) < 0)
+ LOG(ERROR) << "close() error";
+ return;
+ }
+ scoped_refptr<UnixDomainSocket> sock(
+ new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
+ // It's up to the delegate to AddRef if it wants to keep it around.
+ sock->WatchSocket(WAITING_READ);
+ socket_delegate_->DidAccept(this, sock.get());
+}
+
+UnixDomainSocketFactory::UnixDomainSocketFactory(
+ const std::string& path,
+ const UnixDomainSocket::AuthCallback& auth_callback)
+ : path_(path),
+ auth_callback_(auth_callback) {}
+
+UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
+
+scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return UnixDomainSocket::CreateAndListen(
+ path_, delegate, auth_callback_);
+}
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+
+UnixDomainSocketWithAbstractNamespaceFactory::
+UnixDomainSocketWithAbstractNamespaceFactory(
+ const std::string& path,
+ const std::string& fallback_path,
+ const UnixDomainSocket::AuthCallback& auth_callback)
+ : UnixDomainSocketFactory(path, auth_callback),
+ fallback_path_(fallback_path) {}
+
+UnixDomainSocketWithAbstractNamespaceFactory::
+~UnixDomainSocketWithAbstractNamespaceFactory() {}
+
+scoped_refptr<StreamListenSocket>
+UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ path_, fallback_path_, delegate, auth_callback_);
+}
+
+#endif
+
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h
new file mode 100644
index 00000000000..2ef06803d24
--- /dev/null
+++ b/chromium/net/socket/unix_domain_socket_posix.h
@@ -0,0 +1,126 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
+#define NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/callback_forward.h"
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "build/build_config.h"
+#include "net/base/net_export.h"
+#include "net/socket/stream_listen_socket.h"
+
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+// Feature only supported on Linux currently. This lets the Unix Domain Socket
+// not be backed by the file system.
+#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
+#endif
+
+namespace net {
+
+// Unix Domain Socket Implementation. Supports abstract namespaces on Linux.
+class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
+ public:
+ // Callback that returns whether the already connected client, identified by
+ // its process |user_id| and |group_id|, is allowed to keep the connection
+ // open. Note that the socket is closed immediately in case the callback
+ // returns false.
+ typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback;
+
+ // Returns an authentication callback that always grants access for
+ // convenience in case you don't want to use authentication.
+ static AuthCallback NoAuthentication();
+
+ // Note that the returned UnixDomainSocket instance does not take ownership of
+ // |del|.
+ static scoped_refptr<UnixDomainSocket> CreateAndListen(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+ // Same as above except that the created socket uses the abstract namespace
+ // which is a Linux-only feature. If |fallback_path| is not empty,
+ // make the second attempt with the provided fallback name.
+ static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+#endif
+
+ private:
+ UnixDomainSocket(SocketDescriptor s,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+ virtual ~UnixDomainSocket();
+
+ static UnixDomainSocket* CreateAndListenInternal(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace);
+
+ static SocketDescriptor CreateAndBind(const std::string& path,
+ bool use_abstract_namespace);
+
+ // StreamListenSocket:
+ virtual void Accept() OVERRIDE;
+
+ AuthCallback auth_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket);
+};
+
+// Factory that can be used to instantiate UnixDomainSocket.
+class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory {
+ public:
+ // Note that this class does not take ownership of the provided delegate.
+ UnixDomainSocketFactory(const std::string& path,
+ const UnixDomainSocket::AuthCallback& auth_callback);
+ virtual ~UnixDomainSocketFactory();
+
+ // StreamListenSocketFactory:
+ virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const OVERRIDE;
+
+ protected:
+ const std::string path_;
+ const UnixDomainSocket::AuthCallback auth_callback_;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory);
+};
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+// Use this factory to instantiate UnixDomainSocket using the abstract
+// namespace feature (only supported on Linux).
+class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory
+ : public UnixDomainSocketFactory {
+ public:
+ UnixDomainSocketWithAbstractNamespaceFactory(
+ const std::string& path,
+ const std::string& fallback_path,
+ const UnixDomainSocket::AuthCallback& auth_callback);
+ virtual ~UnixDomainSocketWithAbstractNamespaceFactory();
+
+ // UnixDomainSocketFactory:
+ virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const OVERRIDE;
+
+ private:
+ std::string fallback_path_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory);
+};
+#endif
+
+} // namespace net
+
+#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_socket_posix_unittest.cc
new file mode 100644
index 00000000000..5abe03b4ae3
--- /dev/null
+++ b/chromium/net/socket/unix_domain_socket_posix_unittest.cc
@@ -0,0 +1,338 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <queue>
+#include <string>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/file_util.h"
+#include "base/files/file_path.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/posix/eintr_wrapper.h"
+#include "base/synchronization/condition_variable.h"
+#include "base/synchronization/lock.h"
+#include "base/threading/platform_thread.h"
+#include "base/threading/thread.h"
+#include "net/socket/unix_domain_socket_posix.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using std::queue;
+using std::string;
+
+namespace net {
+namespace {
+
+const char kSocketFilename[] = "unix_domain_socket_for_testing";
+const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2";
+const char kInvalidSocketPath[] = "/invalid/path";
+const char kMsg[] = "hello";
+
+enum EventType {
+ EVENT_ACCEPT,
+ EVENT_AUTH_DENIED,
+ EVENT_AUTH_GRANTED,
+ EVENT_CLOSE,
+ EVENT_LISTEN,
+ EVENT_READ,
+};
+
+string MakeSocketPath(const string& socket_file_name) {
+ base::FilePath temp_dir;
+ file_util::GetTempDir(&temp_dir);
+ return temp_dir.Append(socket_file_name).value();
+}
+
+string MakeSocketPath() {
+ return MakeSocketPath(kSocketFilename);
+}
+
+class EventManager : public base::RefCounted<EventManager> {
+ public:
+ EventManager() : condition_(&mutex_) {}
+
+ bool HasPendingEvent() {
+ base::AutoLock lock(mutex_);
+ return !events_.empty();
+ }
+
+ void Notify(EventType event) {
+ base::AutoLock lock(mutex_);
+ events_.push(event);
+ condition_.Broadcast();
+ }
+
+ EventType WaitForEvent() {
+ base::AutoLock lock(mutex_);
+ while (events_.empty())
+ condition_.Wait();
+ EventType event = events_.front();
+ events_.pop();
+ return event;
+ }
+
+ private:
+ friend class base::RefCounted<EventManager>;
+ virtual ~EventManager() {}
+
+ queue<EventType> events_;
+ base::Lock mutex_;
+ base::ConditionVariable condition_;
+};
+
+class TestListenSocketDelegate : public StreamListenSocket::Delegate {
+ public:
+ explicit TestListenSocketDelegate(
+ const scoped_refptr<EventManager>& event_manager)
+ : event_manager_(event_manager) {}
+
+ virtual void DidAccept(StreamListenSocket* server,
+ StreamListenSocket* connection) OVERRIDE {
+ LOG(ERROR) << __PRETTY_FUNCTION__;
+ connection_ = connection;
+ Notify(EVENT_ACCEPT);
+ }
+
+ virtual void DidRead(StreamListenSocket* connection,
+ const char* data,
+ int len) OVERRIDE {
+ {
+ base::AutoLock lock(mutex_);
+ DCHECK(len);
+ data_.assign(data, len - 1);
+ }
+ Notify(EVENT_READ);
+ }
+
+ virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
+ Notify(EVENT_CLOSE);
+ }
+
+ void OnListenCompleted() {
+ Notify(EVENT_LISTEN);
+ }
+
+ string ReceivedData() {
+ base::AutoLock lock(mutex_);
+ return data_;
+ }
+
+ private:
+ void Notify(EventType event) {
+ event_manager_->Notify(event);
+ }
+
+ const scoped_refptr<EventManager> event_manager_;
+ scoped_refptr<StreamListenSocket> connection_;
+ base::Lock mutex_;
+ string data_;
+};
+
+bool UserCanConnectCallback(
+ bool allow_user, const scoped_refptr<EventManager>& event_manager,
+ uid_t, gid_t) {
+ event_manager->Notify(
+ allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
+ return allow_user;
+}
+
+class UnixDomainSocketTestHelper : public testing::Test {
+ public:
+ void CreateAndListen() {
+ socket_ = UnixDomainSocket::CreateAndListen(
+ file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
+ socket_delegate_->OnListenCompleted();
+ }
+
+ protected:
+ UnixDomainSocketTestHelper(const string& path, bool allow_user)
+ : file_path_(path),
+ allow_user_(allow_user) {}
+
+ virtual void SetUp() OVERRIDE {
+ event_manager_ = new EventManager();
+ socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
+ DeleteSocketFile();
+ }
+
+ virtual void TearDown() OVERRIDE {
+ DeleteSocketFile();
+ socket_ = NULL;
+ socket_delegate_.reset();
+ event_manager_ = NULL;
+ }
+
+ UnixDomainSocket::AuthCallback MakeAuthCallback() {
+ return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
+ }
+
+ void DeleteSocketFile() {
+ ASSERT_FALSE(file_path_.empty());
+ base::DeleteFile(file_path_, false /* not recursive */);
+ }
+
+ SocketDescriptor CreateClientSocket() {
+ const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0);
+ if (sock < 0) {
+ LOG(ERROR) << "socket() error";
+ return StreamListenSocket::kInvalidSocket;
+ }
+ sockaddr_un addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sun_family = AF_UNIX;
+ socklen_t addr_len;
+ strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
+ addr_len = sizeof(sockaddr_un);
+ if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
+ LOG(ERROR) << "connect() error";
+ return StreamListenSocket::kInvalidSocket;
+ }
+ return sock;
+ }
+
+ scoped_ptr<base::Thread> CreateAndRunServerThread() {
+ base::Thread::Options options;
+ options.message_loop_type = base::MessageLoop::TYPE_IO;
+ scoped_ptr<base::Thread> thread(new base::Thread("socketio_test"));
+ thread->StartWithOptions(options);
+ thread->message_loop()->PostTask(
+ FROM_HERE,
+ base::Bind(&UnixDomainSocketTestHelper::CreateAndListen,
+ base::Unretained(this)));
+ return thread.Pass();
+ }
+
+ const base::FilePath file_path_;
+ const bool allow_user_;
+ scoped_refptr<EventManager> event_manager_;
+ scoped_ptr<TestListenSocketDelegate> socket_delegate_;
+ scoped_refptr<UnixDomainSocket> socket_;
+};
+
+class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
+ protected:
+ UnixDomainSocketTest()
+ : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
+};
+
+class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper {
+ protected:
+ UnixDomainSocketTestWithInvalidPath()
+ : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {}
+};
+
+class UnixDomainSocketTestWithForbiddenUser
+ : public UnixDomainSocketTestHelper {
+ protected:
+ UnixDomainSocketTestWithForbiddenUser()
+ : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
+};
+
+TEST_F(UnixDomainSocketTest, CreateAndListen) {
+ CreateAndListen();
+ EXPECT_FALSE(socket_.get() == NULL);
+}
+
+TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
+ CreateAndListen();
+ EXPECT_TRUE(socket_.get() == NULL);
+}
+
+#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
+// Test with an invalid path to make sure that the socket is not backed by a
+// file.
+TEST_F(UnixDomainSocketTestWithInvalidPath,
+ CreateAndListenWithAbstractNamespace) {
+ socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
+ EXPECT_FALSE(socket_.get() == NULL);
+}
+
+TEST_F(UnixDomainSocketTest, TestFallbackName) {
+ scoped_refptr<UnixDomainSocket> existing_socket =
+ UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
+ EXPECT_FALSE(existing_socket.get() == NULL);
+ // First, try to bind socket with the same name with no fallback name.
+ socket_ =
+ UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
+ EXPECT_TRUE(socket_.get() == NULL);
+ // Now with a fallback name.
+ socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ file_path_.value(),
+ MakeSocketPath(kFallbackSocketName),
+ socket_delegate_.get(),
+ MakeAuthCallback());
+ EXPECT_FALSE(socket_.get() == NULL);
+ existing_socket = NULL;
+}
+#endif
+
+TEST_F(UnixDomainSocketTest, TestWithClient) {
+ const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
+ EventType event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_LISTEN, event);
+
+ // Create the client socket.
+ const SocketDescriptor sock = CreateClientSocket();
+ ASSERT_NE(StreamListenSocket::kInvalidSocket, sock);
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_AUTH_GRANTED, event);
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_ACCEPT, event);
+
+ // Send a message from the client to the server.
+ ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
+ ASSERT_NE(-1, ret);
+ ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret));
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_READ, event);
+ ASSERT_EQ(kMsg, socket_delegate_->ReceivedData());
+
+ // Close the client socket.
+ ret = HANDLE_EINTR(close(sock));
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_CLOSE, event);
+}
+
+TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
+ const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
+ EventType event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_LISTEN, event);
+ const SocketDescriptor sock = CreateClientSocket();
+ ASSERT_NE(StreamListenSocket::kInvalidSocket, sock);
+
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_AUTH_DENIED, event);
+
+ // Wait until the file descriptor is closed by the server.
+ struct pollfd poll_fd;
+ poll_fd.fd = sock;
+ poll_fd.events = POLLIN;
+ poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */);
+
+ // Send() must fail.
+ ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
+ ASSERT_EQ(-1, ret);
+ ASSERT_EQ(EPIPE, errno);
+ ASSERT_FALSE(event_manager_->HasPendingEvent());
+}
+
+} // namespace
+} // namespace net