diff --git a/bridge_integration_test.go b/bridge_integration_test.go index f8ef57e..f046463 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1690,16 +1690,10 @@ func createMockMCPSrv(t *testing.T) http.Handler { return server.NewStreamableHTTPServer(s) } -func openaiCfg(url, key string) aibridge.OpenAIConfig { - return aibridge.OpenAIConfig{ - BaseURL: url, - Key: key, - } +func openaiCfg(url, key string) *aibridge.OpenAIConfig { + return aibridge.NewProviderConfig(url, key, "") } -func anthropicCfg(url, key string) aibridge.AnthropicConfig { - return aibridge.AnthropicConfig{ - BaseURL: url, - Key: key, - } +func anthropicCfg(url, key string) *aibridge.AnthropicConfig { + return aibridge.NewProviderConfig(url, key, "") } diff --git a/config.go b/config.go index 8dc6f1d..79eb9bd 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,67 @@ package aibridge +import "go.uber.org/atomic" + type ProviderConfig struct { - BaseURL, Key string + baseURL, key atomic.String + upstreamLoggingDir atomic.String + enableUpstreamLogging atomic.Bool +} + +// NewProviderConfig creates a new ProviderConfig with the given values. +func NewProviderConfig(baseURL, key, upstreamLoggingDir string) *ProviderConfig { + cfg := &ProviderConfig{} + cfg.baseURL.Store(baseURL) + cfg.key.Store(key) + cfg.upstreamLoggingDir.Store(upstreamLoggingDir) + return cfg +} + +// BaseURL returns the base URL for the provider. +func (c *ProviderConfig) BaseURL() string { + return c.baseURL.Load() +} + +// SetBaseURL sets the base URL for the provider. +func (c *ProviderConfig) SetBaseURL(baseURL string) { + c.baseURL.Store(baseURL) +} + +// Key returns the API key for the provider. +func (c *ProviderConfig) Key() string { + return c.key.Load() +} + +// SetKey sets the API key for the provider. +func (c *ProviderConfig) SetKey(key string) { + c.key.Store(key) +} + +// UpstreamLoggingDir returns the base directory for upstream logging. +// If empty, the OS's tempdir will be used. +// Logs are written to $UpstreamLoggingDir/$provider/$model/$interceptionID.{req,res}.log +func (c *ProviderConfig) UpstreamLoggingDir() string { + return c.upstreamLoggingDir.Load() +} + +// SetUpstreamLoggingDir sets the base directory for upstream logging. +func (c *ProviderConfig) SetUpstreamLoggingDir(dir string) { + c.upstreamLoggingDir.Store(dir) +} + +// SetEnableUpstreamLogging enables or disables upstream logging at runtime. +func (c *ProviderConfig) SetEnableUpstreamLogging(enabled bool) { + c.enableUpstreamLogging.Store(enabled) +} + +// IsUpstreamLoggingEnabled returns whether upstream logging is currently enabled. +func (c *ProviderConfig) IsUpstreamLoggingEnabled() bool { + return c.enableUpstreamLogging.Load() } type ( - OpenAIConfig ProviderConfig - AnthropicConfig ProviderConfig + OpenAIConfig = ProviderConfig + AnthropicConfig = ProviderConfig ) type AWSBedrockConfig struct { @@ -19,7 +74,7 @@ type AWSBedrockConfig struct { } type Config struct { - OpenAI ProviderConfig - Anthropic ProviderConfig + OpenAI OpenAIConfig + Anthropic AnthropicConfig Bedrock AWSBedrockConfig } diff --git a/go.mod b/go.mod index 47fd45d..f3befae 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 + go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b @@ -20,25 +21,28 @@ require ( // AI-related libs. require ( - github.com/anthropics/anthropic-sdk-go v1.13.0 - github.com/aws/aws-sdk-go-v2/config v1.27.27 - github.com/aws/aws-sdk-go-v2/credentials v1.17.27 + github.com/anthropics/anthropic-sdk-go v1.12.0 github.com/openai/openai-go/v2 v2.7.0 ) require ( - github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.17 + github.com/aws/aws-sdk-go-v2/credentials v1.18.21 +) + +require ( + github.com/aws/aws-sdk-go-v2 v1.39.6 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect - github.com/aws/smithy-go v1.20.3 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect + github.com/aws/smithy-go v1.23.2 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index d0b79c8..1f287b2 100644 --- a/go.sum +++ b/go.sum @@ -7,36 +7,36 @@ cloud.google.com/go/logging v1.8.1 h1:26skQWPeYhvIasWKm48+Eq7oUqdcdbwsCVwz5Ys0Fv cloud.google.com/go/logging v1.8.1/go.mod h1:TJjR+SimHwuC8MZ9cjByQulAMgni+RkXeI3wwctHJEI= cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI= cloud.google.com/go/longrunning v0.5.1/go.mod h1:spvimkwdz6SPWKEt/XBij79E9fiTkHSQl/fRUUQJYJc= -github.com/anthropics/anthropic-sdk-go v1.13.0 h1:Bhbe8sRoDPtipttg8bQYrMCKe2b79+q6rFW1vOKEUKI= -github.com/anthropics/anthropic-sdk-go v1.13.0/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= -github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= -github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/anthropics/anthropic-sdk-go v1.12.0 h1:xPqlGnq7rWrTiHazIvCiumA0u7mGQnwDQtvA1M82h9U= +github.com/anthropics/anthropic-sdk-go v1.12.0/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= +github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= +github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM= -github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= -github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= -github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= -github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= -github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= -github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/aws-sdk-go-v2/config v1.31.17 h1:QFl8lL6RgakNK86vusim14P2k8BFSxjvUkcWLDjgz9Y= +github.com/aws/aws-sdk-go-v2/config v1.31.17/go.mod h1:V8P7ILjp/Uef/aX8TjGk6OHZN6IKPM5YW6S78QnRD5c= +github.com/aws/aws-sdk-go-v2/credentials v1.18.21 h1:56HGpsgnmD+2/KpG0ikvvR8+3v3COCwaF4r+oWwOeNA= +github.com/aws/aws-sdk-go-v2/credentials v1.18.21/go.mod h1:3YELwedmQbw7cXNaII2Wywd+YY58AmLPwX4LzARgmmA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13/go.mod h1:Peg/GBAQ6JDt+RoBf4meB1wylmAipb7Kg2ZFakZTlwk= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.1/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 h1:OWs0/j2UYR5LOGi88sD5/lhN6TDLG6SfA7CqsQO9zF0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= +github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 h1:mLlUgHn02ue8whiR4BmxxGJLR2gwU6s6ZzJ5wDamBUs= +github.com/aws/aws-sdk-go-v2/service/sts v1.39.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= +github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= +github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= @@ -138,6 +138,8 @@ go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiM go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 5049e54..fbb7cbe 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -24,7 +24,7 @@ type AnthropicMessagesInterceptionBase struct { id uuid.UUID req *MessageNewParamsWrapper - cfg AnthropicConfig + cfg *AnthropicConfig bedrockCfg *AWSBedrockConfig logger slog.Logger @@ -82,7 +82,6 @@ func (i *AnthropicMessagesInterceptionBase) injectTools() { // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. i.req.ToolChoice = anthropic.ToolChoiceUnionParam{ OfAny: &anthropic.ToolChoiceAnyParam{ - Type: "auto", DisableParallelToolUse: anthropic.Bool(true), }, } @@ -97,8 +96,14 @@ func (i *AnthropicMessagesInterceptionBase) isSmallFastModel() bool { } func (i *AnthropicMessagesInterceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { - opts = append(opts, option.WithAPIKey(i.cfg.Key)) - opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) + opts = append(opts, option.WithAPIKey(i.cfg.Key())) + opts = append(opts, option.WithBaseURL(i.cfg.BaseURL())) + + if i.cfg.IsUpstreamLoggingEnabled() { + if middleware := createLoggingMiddleware(i.logger, i.cfg, ProviderAnthropic, i.id.String(), i.Model()); middleware != nil { + opts = append(opts, option.WithMiddleware(middleware)) + } + } if i.bedrockCfg != nil { ctx, cancel := context.WithTimeout(ctx, time.Second*30) diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index a1f71e6..2627395 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -22,7 +22,7 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index ef8aabd..7875565 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -25,7 +25,7 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 20db323..06a4352 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -19,19 +19,13 @@ type OpenAIChatInterceptionBase struct { id uuid.UUID req *ChatCompletionNewParamsWrapper - baseURL, key string - logger slog.Logger + cfg *OpenAIConfig + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier } -func (i *OpenAIChatInterceptionBase) newCompletionsService(baseURL, key string) openai.ChatCompletionService { - opts := []option.RequestOption{option.WithAPIKey(key), option.WithBaseURL(baseURL)} - - return openai.NewChatCompletionService(opts...) -} - func (i *OpenAIChatInterceptionBase) ID() uuid.UUID { return i.id } @@ -125,3 +119,17 @@ func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, o _, _ = w.Write(out) } } + +func (i *OpenAIChatInterceptionBase) newChatCompletionService() openai.ChatCompletionService { + var opts []option.RequestOption + opts = append(opts, option.WithAPIKey(i.cfg.Key())) + opts = append(opts, option.WithBaseURL(i.cfg.BaseURL())) + + if i.cfg.IsUpstreamLoggingEnabled() { + if middleware := createLoggingMiddleware(i.logger, i.cfg, ProviderOpenAI, i.id.String(), i.Model()); middleware != nil { + opts = append(opts, option.WithMiddleware(middleware)) + } + } + + return openai.NewChatCompletionService(opts...) +} diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 757c933..3c6b859 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -22,12 +22,11 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *OpenAIConfig) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -45,7 +44,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - svc := i.newCompletionsService(i.baseURL, i.key) + svc := i.newChatCompletionService() logger := i.logger.With(slog.F("model", i.req.Model)) var ( diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index ccabb35..4d61ec1 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -25,12 +25,11 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *OpenAIConfig) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -69,7 +68,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - svc := i.newCompletionsService(i.baseURL, i.key) + svc := i.newChatCompletionService() logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) diff --git a/provider_anthropic.go b/provider_anthropic.go index 7e9c99f..8a7d62b 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -18,7 +18,7 @@ var _ Provider = &AnthropicProvider{} // AnthropicProvider allows for interactions with the Anthropic API. type AnthropicProvider struct { - cfg AnthropicConfig + cfg *AnthropicConfig bedrockCfg *AWSBedrockConfig } @@ -28,12 +28,16 @@ const ( routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages ) -func NewAnthropicProvider(cfg AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://api.anthropic.com/" +func NewAnthropicProvider(cfg *AnthropicConfig, bedrockCfg *AWSBedrockConfig) *AnthropicProvider { + if cfg == nil { + panic("ProviderConfig cannot be nil") } - if cfg.Key == "" { - cfg.Key = os.Getenv("ANTHROPIC_API_KEY") + + if cfg.BaseURL() == "" { + cfg.SetBaseURL("https://api.anthropic.com/") + } + if cfg.Key() == "" { + cfg.SetKey(os.Getenv("ANTHROPIC_API_KEY")) } return &AnthropicProvider{ @@ -84,7 +88,7 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req } func (p *AnthropicProvider) BaseURL() string { - return p.cfg.BaseURL + return p.cfg.BaseURL() } func (p *AnthropicProvider) AuthHeader() string { @@ -96,7 +100,7 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), p.cfg.Key) + headers.Set(p.AuthHeader(), p.cfg.Key()) } func getAnthropicErrorResponse(err error) *AnthropicErrorResponse { diff --git a/provider_openai.go b/provider_openai.go index 0fc31a6..bde4262 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -14,7 +14,7 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - baseURL, key string + cfg *OpenAIConfig } const ( @@ -23,18 +23,21 @@ const ( routeChatCompletions = "/openai/v1/chat/completions" // https://platform.openai.com/docs/api-reference/chat ) -func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://api.openai.com/v1/" +func NewOpenAIProvider(cfg *OpenAIConfig) *OpenAIProvider { + if cfg == nil { + panic("ProviderConfig cannot be nil") } - if cfg.Key == "" { - cfg.Key = os.Getenv("OPENAI_API_KEY") + if cfg.BaseURL() == "" { + cfg.SetBaseURL("https://api.openai.com/v1/") + } + + if cfg.Key() == "" { + cfg.SetKey(os.Getenv("OPENAI_API_KEY")) } return &OpenAIProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + cfg: cfg, } } @@ -74,9 +77,9 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } if req.Stream { - return NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key), nil + return NewOpenAIStreamingChatInterception(id, &req, p.cfg), nil } else { - return NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key), nil + return NewOpenAIBlockingChatInterception(id, &req, p.cfg), nil } } @@ -84,7 +87,7 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } func (p *OpenAIProvider) BaseURL() string { - return p.baseURL + return p.cfg.BaseURL() } func (p *OpenAIProvider) AuthHeader() string { @@ -96,5 +99,5 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), "Bearer "+p.key) + headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key()) } diff --git a/request_logger.go b/request_logger.go new file mode 100644 index 0000000..0a71c4a --- /dev/null +++ b/request_logger.go @@ -0,0 +1,121 @@ +package aibridge + +import ( + "context" + "fmt" + "log" + "net/http" + "net/http/httputil" + "os" + "path/filepath" + "strings" + + "cdr.dev/slog" +) + +// SanitizeModelName makes a model name safe for use as a directory name. +// Replaces filesystem-unsafe characters with underscores. +func SanitizeModelName(model string) string { + repl := "_" + replacer := strings.NewReplacer( + "/", repl, + "\\", repl, + ":", repl, + "*", repl, + "?", repl, + "\"", repl, + "<", repl, + ">", repl, + "|", repl, + ) + return replacer.Replace(model) +} + +// logUpstreamRequest logs an HTTP request with the given ID and model name. +// The prefix format is: [req] [id] [model] +func logUpstreamRequest(logger *log.Logger, id, model string, req *http.Request) { + if logger == nil { + return + } + + if reqDump, err := httputil.DumpRequest(req, true); err == nil { + logger.Printf("[req] [%s] [%s] %s", id, model, reqDump) + } +} + +// logUpstreamResponse logs an HTTP response with the given ID and model name. +// The prefix format is: [res] [id] [model] +func logUpstreamResponse(logger *log.Logger, id, model string, resp *http.Response) { + if logger == nil { + return + } + + if respDump, err := httputil.DumpResponse(resp, true); err == nil { + logger.Printf("[res] [%s] [%s] %s", id, model, respDump) + } +} + +// logUpstreamError logs an error that occurred during request/response processing. +// The prefix format is: [res] [id] [model] Error: +func logUpstreamError(logger *log.Logger, id, model string, err error) { + if logger == nil { + return + } + + logger.Printf("[res] [%s] [%s] Error: %v", id, model, err) +} + +// createLoggingMiddleware creates a middleware function that logs requests and responses. +// Logs are written to $baseDir/$provider/$model/$id.req.log and $baseDir/$provider/$model/$id.res.log +// where baseDir is from cfg.UpstreamLoggingDir or os.TempDir() if not specified. +// Returns nil if logging setup fails, logging errors via the provided logger. +func createLoggingMiddleware(logger slog.Logger, cfg *ProviderConfig, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { + ctx := context.Background() + safeModel := SanitizeModelName(model) + + baseDir := cfg.UpstreamLoggingDir() + if baseDir == "" { + baseDir = os.TempDir() + } + + logDir := filepath.Join(baseDir, provider, safeModel) + + // Create the directory structure if it doesn't exist + if err := os.MkdirAll(logDir, 0755); err != nil { + logger.Warn(ctx, "failed to create log directory", slog.Error(err), slog.F("dir", logDir)) + return nil + } + + reqLogPath := filepath.Join(logDir, fmt.Sprintf("%s.req.log", id)) + resLogPath := filepath.Join(logDir, fmt.Sprintf("%s.res.log", id)) + + reqLogFile, err := os.OpenFile(reqLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + logger.Warn(ctx, "failed to open request log file", slog.Error(err), slog.F("path", reqLogPath)) + return nil + } + + resLogFile, err := os.OpenFile(resLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + reqLogFile.Close() + logger.Warn(ctx, "failed to open response log file", slog.Error(err), slog.F("path", resLogPath)) + return nil + } + + reqLogger := log.New(reqLogFile, "", log.LstdFlags) + resLogger := log.New(resLogFile, "", log.LstdFlags) + + return func(req *http.Request, next func(*http.Request) (*http.Response, error)) (*http.Response, error) { + logUpstreamRequest(reqLogger, id, model, req) + + resp, err := next(req) + if err != nil { + logUpstreamError(resLogger, id, model, err) + return resp, err + } + + logUpstreamResponse(resLogger, id, model, resp) + + return resp, err + } +} diff --git a/request_logger_test.go b/request_logger_test.go new file mode 100644 index 0000000..eda3a31 --- /dev/null +++ b/request_logger_test.go @@ -0,0 +1,171 @@ +package aibridge_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/stretchr/testify/require" + "golang.org/x/tools/txtar" +) + +func TestRequestLogging(t *testing.T) { + t.Parallel() + + testCases := []struct { + provider string + fixture []byte + route string + createProvider func(*aibridge.AnthropicConfig) aibridge.Provider + }{ + { + provider: aibridge.ProviderAnthropic, + fixture: antSimple, + route: "/anthropic/v1/messages", + createProvider: func(cfg *aibridge.AnthropicConfig) aibridge.Provider { + return aibridge.NewAnthropicProvider(cfg, nil) + }, + }, + { + provider: aibridge.ProviderOpenAI, + fixture: oaiSimple, + route: "/openai/v1/chat/completions", + createProvider: func(cfg *aibridge.AnthropicConfig) aibridge.Provider { + return aibridge.NewOpenAIProvider(cfg) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.provider, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + // Use a temp dir for this test + tmpDir := t.TempDir() + + // Parse fixture + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureNonStreamingResponse) + + // Create mock server + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureNonStreamingResponse]) + })) + t.Cleanup(srv.Close) + + cfg := aibridge.NewProviderConfig(srv.URL, apiKey, tmpDir) + cfg.SetEnableUpstreamLogging(true) + + provider := tc.createProvider(cfg) + client := &mockRecorderClient{} + mcpProxy := mcp.NewServerProxyManager(nil) + + bridge, err := aibridge.NewRequestBridge(context.Background(), []aibridge.Provider{provider}, client, mcpProxy, nil, logger) + require.NoError(t, err) + t.Cleanup(func() { + _ = bridge.Shutdown(context.Background()) + }) + + // Make a request + req, err := http.NewRequestWithContext(t.Context(), "POST", tc.route, strings.NewReader(string(files[fixtureRequest]))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(aibridge.AsActor(req.Context(), userID, nil)) + rec := httptest.NewRecorder() + bridge.ServeHTTP(rec, req) + require.Equal(t, 200, rec.Code) + + // Check that log files were created + // Parse the request to get the model name + var reqData map[string]any + require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqData)) + model := reqData["model"].(string) + + logDir := filepath.Join(tmpDir, tc.provider, model) + entries, err := os.ReadDir(logDir) + require.NoError(t, err, "log directory should exist") + require.NotEmpty(t, entries, "log directory should contain files") + + // Should have at least one .req.log and one .res.log file + var hasReq, hasRes bool + for _, entry := range entries { + name := entry.Name() + if strings.HasSuffix(name, ".req.log") { + hasReq = true + // Verify the file has content + content, err := os.ReadFile(filepath.Join(logDir, name)) + require.NoError(t, err) + require.NotEmpty(t, content, "request log should have content") + require.Contains(t, string(content), "POST") + } else if strings.HasSuffix(name, ".res.log") { + hasRes = true + // Verify the file has content + content, err := os.ReadFile(filepath.Join(logDir, name)) + require.NoError(t, err) + require.NotEmpty(t, content, "response log should have content") + require.Contains(t, string(content), "200") + } + } + require.True(t, hasReq, "should have request log file") + require.True(t, hasRes, "should have response log file") + }) + } +} + +func TestSanitizeModelName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple model", + input: "gpt-4o", + expected: "gpt-4o", + }, + { + name: "model with slash", + input: "gpt-4o/mini", + expected: "gpt-4o_mini", + }, + { + name: "model with colon", + input: "o1:2024-12-17", + expected: "o1_2024-12-17", + }, + { + name: "model with backslash", + input: "model\\name", + expected: "model_name", + }, + { + name: "model with multiple special chars", + input: "model:name/version?", + expected: "model_name_version_", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := aibridge.SanitizeModelName(tt.input) + require.Equal(t, tt.expected, result) + }) + } +}